Skip to content

Commit 48e25be

Browse files
committed
lint
1 parent c72a259 commit 48e25be

File tree

2 files changed

+13
-8
lines changed

2 files changed

+13
-8
lines changed

td3a_cpp/tutorial/td_mul_cython.pyx

+3
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ cnumpy.import_array()
1212

1313

1414
def multiply_matrix(m1, m2):
15+
"Matrix multiplication"
1516
m3 = pynumpy.zeros((m1.shape[0], m2.shape[1]), dtype=m1.dtype)
1617
for i in range(0, m1.shape[0]):
1718
for j in range(0, m2.shape[1]):
@@ -23,6 +24,7 @@ def multiply_matrix(m1, m2):
2324
cdef void _c_multiply_matrix(double[:, :] m1, double[:, :] m2,
2425
double[:, :] m3,
2526
cython.int ni, cython.int nj, cython.int nk) nogil:
27+
"Matrix multiplication wuth cython"
2628
cdef cython.int i, j, k
2729
for i in prange(0, ni):
2830
for j in range(0, nj):
@@ -31,6 +33,7 @@ cdef void _c_multiply_matrix(double[:, :] m1, double[:, :] m2,
3133

3234

3335
def c_multiply_matrix(double[:, :] m1, double[:, :] m2):
36+
"Matrix multiplication calling the cython version"
3437
m3 = pynumpy.zeros((m1.shape[0], m2.shape[1]), dtype=pynumpy.float64)
3538
cdef cython.int ni = m1.shape[0]
3639
cdef cython.int nj = m2.shape[1]

tests/test_tutorial_td.py

+10-8
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
from numpy.testing import assert_almost_equal
77
from td3a_cpp.tutorial.td_mul_cython import (
88
multiply_matrix, c_multiply_matrix)
9-
from td3a_cpp.tutorial.mul_cython_omp import dmul_cython_omp
109

1110

1211
class TestTutorialTD(unittest.TestCase):
@@ -28,15 +27,18 @@ def test_matrix_cmultiply_matrix(self):
2827
def test_timeit(self):
2928
va = numpy.random.randn(300, 400).astype(numpy.float64)
3029
vb = numpy.random.randn(400, 500).astype(numpy.float64)
31-
res1 = va @ vb
32-
res2 = c_multiply_matrix(va, vb)
33-
ctx = {'va': va, 'vb': vb, 'c_multiply_matrix': c_multiply_matrix}
30+
ctx = {'va': va, 'vb': vb, 'c_multiply_matrix': c_multiply_matrix,
31+
'multiply_matrix': multiply_matrix}
3432
res1 = timeit.timeit('va @ vb', number=10, globals=ctx)
35-
res2 = timeit.timeit('c_multiply_matrix(va, vb)', number=10, globals=ctx)
33+
res2 = timeit.timeit(
34+
'c_multiply_matrix(va, vb)', number=10, globals=ctx)
35+
res3 = timeit.timeit(
36+
'multiply_matrix(va, vb)', number=10, globals=ctx)
3637
self.assertLess(res1, res2) # numpy is much faster.
37-
ratio = res2 / res1
38-
self.assertGreater(ratio, 1) # ratio = number of times numpy is faster
39-
# print(ratio)
38+
ratio1 = res2 / res1
39+
self.assertGreater(ratio1, 1) # ratio1 = number of times numpy is faster
40+
ratio2 = res3 / res1
41+
self.assertGreater(ratio2, 1)
4042

4143

4244
if __name__ == '__main__':

0 commit comments

Comments
 (0)