Skip to content

Commit 4d65694

Browse files
youdymookeon
authored andcommitted
Add 7 Matrix Unit Tests and Implement Matrix Multiplication (keon#497)
* add_matrix_multiplication * fix_rotate_image * test_matrix * add_matrix_multiplication * fix_matrix_multiplication * fix test_matrix * fix test_matrix * fix test_matrix
1 parent 638b43c commit 4d65694

File tree

5 files changed

+244
-46
lines changed

5 files changed

+244
-46
lines changed

README.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -212,6 +212,7 @@ If you want to uninstall algorithms, it is as simple as:
212212
- [copy_transform](algorithms/matrix/copy_transform.py)
213213
- [count_paths](algorithms/matrix/count_paths.py)
214214
- [matrix_rotation.txt](algorithms/matrix/matrix_rotation.txt)
215+
- [matrix_multiplication](algorithms/matrix/multiply.py)
215216
- [rotate_image](algorithms/matrix/rotate_image.py)
216217
- [search_in_sorted_matrix](algorithms/matrix/search_in_sorted_matrix.py)
217218
- [sparse_dot_vector](algorithms/matrix/sparse_dot_vector.py)

algorithms/matrix/multiply.py

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
"""
2+
This algorithm takes two compatible two dimensional matrix
3+
and return their product
4+
Space complexity: O(n^2)
5+
Possible edge case: the number of columns of multiplicand not consistent with
6+
the number of rows of multiplier, will raise exception
7+
"""
8+
9+
10+
def multiply(multiplicand: list, multiplier: list) -> list:
11+
"""
12+
:type A: List[List[int]]
13+
:type B: List[List[int]]
14+
:rtype: List[List[int]]
15+
"""
16+
multiplicand_row, multiplicand_col = len(
17+
multiplicand), len(multiplicand[0])
18+
multiplier_row, multiplier_col = len(multiplier), len(multiplier[0])
19+
if(multiplicand_col != multiplier_row):
20+
raise Exception(
21+
"Multiplicand matrix not compatible with Multiplier matrix.")
22+
# create a result matrix
23+
result = [[0] * multiplier_col for i in range(multiplicand_row)]
24+
for i in range(multiplicand_row):
25+
for j in range(multiplier_col):
26+
for k in range(len(multiplier)):
27+
result[i][j] += multiplicand[i][k] * multiplier[k][j]
28+
return result

algorithms/matrix/rotate_image.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,13 +14,14 @@
1414
# 4 5 6 => 4 5 6 => 8 5 2
1515
# 7 8 9 1 2 3 9 6 3
1616

17-
def rotate(mat):
17+
def rotate(mat):
1818
if not mat:
1919
return mat
2020
mat.reverse()
2121
for i in range(len(mat)):
2222
for j in range(i):
2323
mat[i][j], mat[j][i] = mat[j][i], mat[i][j]
24+
return mat
2425

2526

2627
if __name__ == "__main__":

algorithms/matrix/sudoku_validator.py

Lines changed: 0 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -82,32 +82,3 @@ def valid_solution_set (board):
8282
return False
8383

8484
return True
85-
86-
# test cases
87-
# To avoid congestion I'll leave testing all the functions to the reader. Just change the name of the function in the below test cases.
88-
import unittest
89-
class TestSuite(unittest.TestCase):
90-
def test_valid(self):
91-
self.assertTrue(valid_solution([[5, 3, 4, 6, 7, 8, 9, 1, 2],
92-
[6, 7, 2, 1, 9, 5, 3, 4, 8],
93-
[1, 9, 8, 3, 4, 2, 5, 6, 7],
94-
[8, 5, 9, 7, 6, 1, 4, 2, 3],
95-
[4, 2, 6, 8, 5, 3, 7, 9, 1],
96-
[7, 1, 3, 9, 2, 4, 8, 5, 6],
97-
[9, 6, 1, 5, 3, 7, 2, 8, 4],
98-
[2, 8, 7, 4, 1, 9, 6, 3, 5],
99-
[3, 4, 5, 2, 8, 6, 1, 7, 9]]))
100-
101-
def test_invalid(self):
102-
self.assertFalse(valid_solution([[5, 3, 4, 6, 7, 8, 9, 1, 2],
103-
[6, 7, 2, 1, 9, 0, 3, 4, 9],
104-
[1, 0, 0, 3, 4, 2, 5, 6, 0],
105-
[8, 5, 9, 7, 6, 1, 0, 2, 0],
106-
[4, 2, 6, 8, 5, 3, 7, 9, 1],
107-
[7, 1, 3, 9, 2, 4, 8, 5, 6],
108-
[9, 0, 1, 5, 3, 7, 2, 1, 4],
109-
[2, 8, 7, 4, 1, 9, 6, 3, 5],
110-
[3, 0, 0, 4, 8, 1, 1, 7, 9]]))
111-
112-
if __name__ == "__main__":
113-
unittest.main()

tests/test_matrix.py

Lines changed: 213 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,34 +1,86 @@
11
from algorithms.matrix import (
2-
crout_matrix_decomposition
3-
)
2+
bomb_enemy,
3+
copy_transform,
4+
crout_matrix_decomposition,
5+
multiply,
6+
rotate_image,
7+
sparse_dot_vector,
8+
spiral_traversal,
9+
sudoku_validator
10+
)
11+
import unittest
412

513

6-
import unittest
14+
class TestBombEnemy(unittest.TestCase):
15+
def test_3x4(self):
16+
grid1 = [
17+
["0","E","0","0"],
18+
["E","0","W","E"],
19+
["0","E","0","0"]
20+
]
21+
self.assertEqual(3, bomb_enemy.max_killed_enemies(grid1))
22+
23+
grid1 = [
24+
["0", "E", "0", "E"],
25+
["E", "E", "E", "0"],
26+
["E", "0", "W", "E"],
27+
["0", "E", "0", "0"]
28+
]
29+
grid2 = [
30+
["0", "0", "0", "E"],
31+
["E", "0", "0", "0"],
32+
["E", "0", "W", "E"],
33+
["0", "E", "0", "0"]
34+
]
35+
self.assertEqual(5, bomb_enemy.max_killed_enemies(grid1))
36+
self.assertEqual(3, bomb_enemy.max_killed_enemies(grid2))
37+
38+
39+
class TestCopyTransform(unittest.TestCase):
40+
"""[summary]
41+
Test for the file copy_transform.py
42+
43+
Arguments:
44+
unittest {[type]} -- [description]
45+
"""
46+
47+
def test_copy_transform(self):
48+
self.assertEqual(copy_transform.rotate_clockwise(
49+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [[7, 4, 1], [8, 5, 2], [9, 6, 3]])
50+
51+
self.assertEqual(copy_transform.rotate_counterclockwise(
52+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [[3, 6, 9], [2, 5, 8], [1, 4, 7]])
53+
54+
self.assertEqual(copy_transform.top_left_invert(
55+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [[1, 4, 7], [2, 5, 8], [3, 6, 9]])
56+
57+
self.assertEqual(copy_transform.bottom_left_invert(
58+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [[9, 6, 3], [8, 5, 2], [7, 4, 1]])
759

860

961
class TestCroutMatrixDecomposition(unittest.TestCase):
1062
"""[summary]
11-
Test for the file crout_matrix_decomposition.crout_matrix_decomposition.py
63+
Test for the file crout_matrix_decomposition.py
1264
1365
Arguments:
1466
unittest {[type]} -- [description]
1567
"""
16-
68+
1769
def test_crout_matrix_decomposition(self):
1870
self.assertEqual(([[9.0, 0.0], [7.0, 0.0]],
1971
[[1.0, 1.0], [0.0, 1.0]]),
2072
crout_matrix_decomposition.crout_matrix_decomposition(
21-
[[9,9], [7,7]]))
22-
73+
[[9, 9], [7, 7]]))
74+
2375
self.assertEqual(([[1.0, 0.0, 0.0],
2476
[3.0, -2.0, 0.0],
2577
[6.0, -5.0, 0.0]],
26-
[[1.0, 2.0, 3.0],
27-
[0.0, 1.0, 2.0],
28-
[0.0, 0.0, 1.0]]),
78+
[[1.0, 2.0, 3.0],
79+
[0.0, 1.0, 2.0],
80+
[0.0, 0.0, 1.0]]),
2981
crout_matrix_decomposition.crout_matrix_decomposition(
30-
[[1,2,3],[3,4,5],[6,7,8]]))
31-
82+
[[1, 2, 3], [3, 4, 5], [6, 7, 8]]))
83+
3284
self.assertEqual(([[2.0, 0, 0, 0],
3385
[4.0, -1.0, 0, 0],
3486
[6.0, -2.0, 2.0, 0],
@@ -38,9 +90,154 @@ def test_crout_matrix_decomposition(self):
3890
[0, 0, 1.0, 0.0],
3991
[0, 0, 0, 1.0]]),
4092
crout_matrix_decomposition.crout_matrix_decomposition(
41-
[[2,1,3,1], [4,1,4,1], [6,1,7,1], [8,1,9,1]]))
42-
43-
44-
93+
[[2, 1, 3, 1], [4, 1, 4, 1], [6, 1, 7, 1], [8, 1, 9, 1]]))
94+
95+
96+
class TestMultiply(unittest.TestCase):
97+
"""[summary]
98+
Test for the file multiply.py
99+
100+
Arguments:
101+
unittest {[type]} -- [description]
102+
"""
103+
104+
def test_multiply(self):
105+
self.assertEqual(multiply.multiply(
106+
[[1, 2, 3], [2, 1, 1]], [[1], [2], [3]]), [[14], [7]])
107+
108+
109+
class TestRotateImage(unittest.TestCase):
110+
"""[summary]
111+
Test for the file rotate_image.py
112+
113+
Arguments:
114+
unittest {[type]} -- [description]
115+
"""
116+
117+
def test_rotate_image(self):
118+
self.assertEqual(rotate_image.rotate(
119+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [[7, 4, 1], [8, 5, 2], [9, 6, 3]])
120+
121+
122+
class TestSparseDotVector(unittest.TestCase):
123+
"""[summary]
124+
Test for the file sparse_dot_vector.py
125+
126+
Arguments:
127+
unittest {[type]} -- [description]
128+
"""
129+
130+
def test_sparse_dot_vector(self):
131+
self.assertEqual(sparse_dot_vector.dot_product(sparse_dot_vector.vector_to_index_value_list(
132+
[1., 2., 3.]), sparse_dot_vector.vector_to_index_value_list([0., 2., 2.])), 10)
133+
134+
135+
class TestSpiralTraversal(unittest.TestCase):
136+
"""[summary]
137+
Test for the file spiral_traversal.py
138+
139+
Arguments:
140+
unittest {[type]} -- [description]
141+
"""
142+
143+
def test_spiral_traversal(self):
144+
self.assertEqual(spiral_traversal.spiral_traversal(
145+
[[1, 2, 3], [4, 5, 6], [7, 8, 9]]), [1, 2, 3, 6, 9, 8, 7, 4, 5])
146+
147+
148+
class TestSudokuValidator(unittest.TestCase):
149+
"""[summary]
150+
Test for the file sudoku_validator.py
151+
152+
Arguments:
153+
unittest {[type]} -- [description]
154+
"""
155+
156+
def test_sudoku_validator(self):
157+
self.assertTrue(
158+
sudoku_validator.valid_solution(
159+
[
160+
[5, 3, 4, 6, 7, 8, 9, 1, 2],
161+
[6, 7, 2, 1, 9, 5, 3, 4, 8],
162+
[1, 9, 8, 3, 4, 2, 5, 6, 7],
163+
[8, 5, 9, 7, 6, 1, 4, 2, 3],
164+
[4, 2, 6, 8, 5, 3, 7, 9, 1],
165+
[7, 1, 3, 9, 2, 4, 8, 5, 6],
166+
[9, 6, 1, 5, 3, 7, 2, 8, 4],
167+
[2, 8, 7, 4, 1, 9, 6, 3, 5],
168+
[3, 4, 5, 2, 8, 6, 1, 7, 9]
169+
]))
170+
171+
self.assertTrue(
172+
sudoku_validator.valid_solution_hashtable(
173+
[
174+
[5, 3, 4, 6, 7, 8, 9, 1, 2],
175+
[6, 7, 2, 1, 9, 5, 3, 4, 8],
176+
[1, 9, 8, 3, 4, 2, 5, 6, 7],
177+
[8, 5, 9, 7, 6, 1, 4, 2, 3],
178+
[4, 2, 6, 8, 5, 3, 7, 9, 1],
179+
[7, 1, 3, 9, 2, 4, 8, 5, 6],
180+
[9, 6, 1, 5, 3, 7, 2, 8, 4],
181+
[2, 8, 7, 4, 1, 9, 6, 3, 5],
182+
[3, 4, 5, 2, 8, 6, 1, 7, 9]
183+
]))
184+
185+
self.assertTrue(
186+
sudoku_validator.valid_solution_set(
187+
[
188+
[5, 3, 4, 6, 7, 8, 9, 1, 2],
189+
[6, 7, 2, 1, 9, 5, 3, 4, 8],
190+
[1, 9, 8, 3, 4, 2, 5, 6, 7],
191+
[8, 5, 9, 7, 6, 1, 4, 2, 3],
192+
[4, 2, 6, 8, 5, 3, 7, 9, 1],
193+
[7, 1, 3, 9, 2, 4, 8, 5, 6],
194+
[9, 6, 1, 5, 3, 7, 2, 8, 4],
195+
[2, 8, 7, 4, 1, 9, 6, 3, 5],
196+
[3, 4, 5, 2, 8, 6, 1, 7, 9]
197+
]))
198+
199+
self.assertFalse(
200+
sudoku_validator.valid_solution(
201+
[
202+
[5, 3, 4, 6, 7, 8, 9, 1, 2],
203+
[6, 7, 2, 1, 9, 0, 3, 4, 9],
204+
[1, 0, 0, 3, 4, 2, 5, 6, 0],
205+
[8, 5, 9, 7, 6, 1, 0, 2, 0],
206+
[4, 2, 6, 8, 5, 3, 7, 9, 1],
207+
[7, 1, 3, 9, 2, 4, 8, 5, 6],
208+
[9, 0, 1, 5, 3, 7, 2, 1, 4],
209+
[2, 8, 7, 4, 1, 9, 6, 3, 5],
210+
[3, 0, 0, 4, 8, 1, 1, 7, 9]
211+
]))
212+
213+
self.assertFalse(
214+
sudoku_validator.valid_solution_hashtable(
215+
[
216+
[5, 3, 4, 6, 7, 8, 9, 1, 2],
217+
[6, 7, 2, 1, 9, 0, 3, 4, 9],
218+
[1, 0, 0, 3, 4, 2, 5, 6, 0],
219+
[8, 5, 9, 7, 6, 1, 0, 2, 0],
220+
[4, 2, 6, 8, 5, 3, 7, 9, 1],
221+
[7, 1, 3, 9, 2, 4, 8, 5, 6],
222+
[9, 0, 1, 5, 3, 7, 2, 1, 4],
223+
[2, 8, 7, 4, 1, 9, 6, 3, 5],
224+
[3, 0, 0, 4, 8, 1, 1, 7, 9]
225+
]))
226+
227+
self.assertFalse(
228+
sudoku_validator.valid_solution_set(
229+
[
230+
[5, 3, 4, 6, 7, 8, 9, 1, 2],
231+
[6, 7, 2, 1, 9, 0, 3, 4, 9],
232+
[1, 0, 0, 3, 4, 2, 5, 6, 0],
233+
[8, 5, 9, 7, 6, 1, 0, 2, 0],
234+
[4, 2, 6, 8, 5, 3, 7, 9, 1],
235+
[7, 1, 3, 9, 2, 4, 8, 5, 6],
236+
[9, 0, 1, 5, 3, 7, 2, 1, 4],
237+
[2, 8, 7, 4, 1, 9, 6, 3, 5],
238+
[3, 0, 0, 4, 8, 1, 1, 7, 9]
239+
]))
240+
241+
45242
if __name__ == "__main__":
46243
unittest.main()

0 commit comments

Comments
 (0)