Skip to content

Commit 9ae71d5

Browse files
authored
Add files via upload
0 parents  commit 9ae71d5

File tree

1 file changed

+94
-0
lines changed

1 file changed

+94
-0
lines changed

parallel_matrix_multiply.py

Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
import time
2+
import numpy as np
3+
import multiprocessing
4+
5+
6+
def default_multiply(matrix_a, matrix_b):
7+
dim = matrix_a.shape[0]
8+
result = np.zeros((dim, dim), dtype='int')
9+
matrix_b = np.transpose(matrix_b)
10+
for i in range(dim):
11+
result[i] = np.sum(matrix_a[i] * matrix_b, axis=1)
12+
return result
13+
14+
15+
def matrix_split(matrix):
16+
n = matrix.shape[0] // 2
17+
return matrix[:n, :n], matrix[:n, n:], matrix[n:, :n], matrix[n:, n:]
18+
19+
20+
def strassen(matrix_a, matrix_b):
21+
dim = matrix_a.shape[0]
22+
if dim % 4 != 0:
23+
return default_multiply(matrix_a, matrix_b)
24+
25+
A, B, C, D = matrix_split(matrix_a)
26+
E, F, G, H = matrix_split(matrix_b)
27+
28+
p1 = strassen(A + D, E + H)
29+
p2 = strassen(C + D, E)
30+
p3 = strassen(A, F - H)
31+
p4 = strassen(D, G - E)
32+
p5 = strassen(A + B, H)
33+
p6 = strassen(C - A, E + F)
34+
p7 = strassen(B - D, G + H)
35+
36+
top_left = p1 + p4 - p5 + p7
37+
top_right = p3 + p5
38+
bot_left = p2 + p4
39+
bot_right = p1 - p2 + p3 + p6
40+
41+
result = np.vstack((np.hstack((top_left, top_right)), np.hstack((bot_left, bot_right))))
42+
return result
43+
44+
45+
def pad_zeros(matrix):
46+
zeros = 0
47+
size = matrix.shape[0]
48+
while (size + zeros) % 4 != 0:
49+
zeros += 1
50+
pad_h = np.zeros((size, zeros), dtype='int')
51+
pad_v = np.zeros((zeros, size + zeros), dtype='int')
52+
53+
matrix = np.hstack((matrix, pad_h))
54+
return np.vstack((matrix, pad_v))
55+
56+
57+
def parallel_multiply_matrices(matrix_a, matrix_b):
58+
dim = matrix_a.shape[0]
59+
if dim % 4 != 0:
60+
matrix_a = pad_zeros(matrix_a)
61+
matrix_b = pad_zeros(matrix_b)
62+
63+
A, B, C, D = matrix_split(matrix_a)
64+
E, F, G, H = matrix_split(matrix_b)
65+
66+
workers = [(A + D, E + H), (C + D, E), (A, F - H), (D, G - E), (A + B, H), (C - A, E + F), (B - D, G + H)]
67+
68+
with multiprocessing.Pool() as pool:
69+
results = pool.starmap(strassen, workers)
70+
pool.close()
71+
pool.join()
72+
73+
top_left = results[0] + results[3] - results[4] + results[6]
74+
top_right = results[2] + results[4]
75+
bot_left = results[1] + results[3]
76+
bot_right = results[0] - results[1] + results[2] + results[5]
77+
78+
result = np.vstack((np.hstack((top_left, top_right)), np.hstack((bot_left, bot_right))))
79+
return result[:dim, :dim]
80+
81+
82+
if __name__=='__main__':
83+
n = 1000
84+
matrix1 = np.random.randint(0, 2, size=(n, n))
85+
matrix2 = np.random.randint(0, 2, size=(n, n))
86+
87+
#print(matrix1, '\n')
88+
#print(matrix2, '\n')
89+
start = time.time()
90+
result = parallel_multiply_matrices(matrix1, matrix2)
91+
end = time.time()
92+
print("Execution time: ", end-start)
93+
print("Result:")
94+
print(result, '\n')

0 commit comments

Comments
 (0)