Skip to content

Commit 6a2d970

Browse files
committed
Addressing changes
1 parent 9e1a49f commit 6a2d970

File tree

3 files changed

+57
-120
lines changed

3 files changed

+57
-120
lines changed

examples/plot_matrixmult.py

Lines changed: 12 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -27,9 +27,9 @@
2727
exit(-1)
2828

2929
# matrix dims
30-
M = 32
31-
K = 35
32-
N = 37
30+
M = 5
31+
K = 4
32+
N = 3
3333

3434
A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
3535
B = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
@@ -47,21 +47,16 @@
4747
# - :math:`B_{p} \in \mathbb{R}^{K\times \text{my\_own\_cols}}`
4848
# where
4949
blk_rows = int(math.ceil(M / P_prime))
50-
row_start = my_group * blk_rows
51-
row_end = min(M, row_start + blk_rows)
52-
my_own_rows = row_end - row_start
53-
5450
blk_cols = int(math.ceil(N / P_prime))
55-
col_start = my_layer * blk_cols
56-
col_end = min(N, col_start + blk_cols)
57-
my_own_cols = col_end - col_start
58-
5951

60-
rs = (rank % P_prime) * blk_rows
52+
rs = my_group * blk_rows
6153
re = min(M, rs + blk_rows)
54+
my_own_rows = re - rs
6255

63-
cs = (rank // P_prime) * blk_cols
56+
cs = my_layer * blk_cols
6457
ce = min(N, cs + blk_cols)
58+
my_own_cols = ce - cs
59+
6560
A_p, B_p = A[rs:re, :].copy(), B[:, cs:ce].copy()
6661

6762
Aop = MPIMatrixMult(A_p, N, dtype="float32")
@@ -81,19 +76,19 @@
8176
xadj_loc = (A.T.dot(y_loc.conj())).conj()
8277

8378

84-
expected_y_loc = y_loc[:, col_start:col_end].flatten().astype(np.float32)
85-
expected_xadj_loc = xadj_loc[:, col_start:col_end].flatten().astype(np.float32)
79+
expected_y_loc = y_loc[:, cs:ce].flatten().astype(np.float32)
80+
expected_xadj_loc = xadj_loc[:, cs:ce].flatten().astype(np.float32)
8681

8782
xadj = Aop.H @ y
8883
if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
8984
print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
90-
print(f'{rank} local: {y.local_array}, expected: {y_loc[:, col_start:col_end]}')
85+
print(f'{rank} local: {y.local_array}, expected: {y_loc[:, cs:ce]}')
9186
else:
9287
print(f"RANK {rank}: FORWARD VERIFICATION PASSED")
9388

9489
if not np.allclose(xadj.local_array, expected_xadj_loc, rtol=1e-6):
9590
print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
96-
print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, col_start:col_end]}')
91+
print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, cs:ce]}')
9792
else:
9893
print(f"RANK {rank}: ADJOINT VERIFICATION PASSED")
9994

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 40 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ def __init__(
1616
self,
1717
A: NDArray,
1818
N: int,
19+
saveAt: bool = False,
1920
base_comm: MPI.Comm = MPI.COMM_WORLD,
2021
dtype: DTypeLike = "float64",
2122
) -> None:
@@ -25,113 +26,91 @@ def __init__(
2526
# Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
2627
self._P_prime = int(math.ceil(math.sqrt(size)))
2728
self._C = int(math.ceil(size / self._P_prime))
28-
if self._P_prime * self._C < size:
29+
if self._P_prime * self._C != size:
2930
raise Exception("Number of Procs must be a square number")
3031

3132
# Compute this process's group and layer indices
3233
self._group_id = rank % self._P_prime
3334
self._layer_id = rank // self._P_prime
3435

3536
# Split communicators by layer (rows) and by group (columns)
36-
self.base_comm = base_comm
37+
self.base_comm = base_comm
3738
self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
3839
self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
3940
self.A = A.astype(np.dtype(dtype))
41+
if saveAt: self.At = A.T.conj()
4042

4143
self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
4244
self.K = A.shape[1]
4345
self.N = N
4446

4547
# Determine how many columns each group holds
4648
block_cols = int(math.ceil(self.N / self._P_prime))
47-
local_col_start = self._group_id * block_cols
48-
local_col_end = min(self.N, local_col_start + block_cols)
49-
local_ncols = local_col_end - local_col_start
49+
blk_rows = int(math.ceil(self.M / self._P_prime))
5050

51-
# Sum up the total number of input columns across all processes
52-
total_ncols = base_comm.allreduce(local_ncols, op=MPI.SUM)
53-
self.dims = (self.K, total_ncols)
51+
self._row_start = self._group_id * blk_rows
52+
self._row_end = min(self.M, self._row_start + blk_rows)
53+
54+
self._col_start = self._layer_id * block_cols
55+
self._col_end = min(self.N, self._col_start + block_cols)
5456

55-
# Recompute how many output columns each layer holds
56-
layer_col_start = self._layer_id * block_cols
57-
layer_col_end = min(self.N, layer_col_start + block_cols)
58-
layer_ncols = layer_col_end - layer_col_start
59-
total_layer_cols = self.base_comm.allreduce(layer_ncols, op=MPI.SUM)
57+
self._local_ncols = self._col_end - self._col_start
58+
self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
59+
total_ncols = np.sum(self._rank_col_lens)
6060

61-
self.dimsd = (self.M, total_layer_cols)
61+
self.dims = (self.K, total_ncols)
62+
self.dimsd = (self.M, total_ncols)
6263
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
6364
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
64-
65+
6566
def _matvec(self, x: DistributedArray) -> DistributedArray:
6667
ncp = get_module(x.engine)
6768
if x.partition != Partition.SCATTER:
6869
raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
69-
blk_cols = int(math.ceil(self.N / self._P_prime))
70-
col_start = self._layer_id * blk_cols
71-
col_end = min(self.N, col_start + blk_cols)
72-
my_own_cols = max(0, col_end - col_start)
73-
x = x.local_array.reshape((self.dims[0], my_own_cols))
74-
x = x.astype(self.dtype)
75-
76-
B_block = self._layer_comm.bcast(x if self._group_id == self._layer_id else None, root=self._layer_id)
77-
C_local = ncp.vstack(
70+
71+
my_own_cols = self._rank_col_lens[self.rank]
72+
x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
73+
x_arr = x_arr.astype(self.dtype)
74+
75+
X_local = self._layer_comm.bcast(x_arr if self._group_id == self._layer_id else None, root=self._layer_id)
76+
Y_local = ncp.vstack(
7877
self._layer_comm.allgather(
79-
ncp.matmul(self.A, B_block)
78+
ncp.matmul(self.A, X_local)
8079
)
8180
)
8281

83-
layer_col_start = self._layer_id * blk_cols
84-
layer_col_end = min(self.N, layer_col_start + blk_cols)
85-
layer_ncols = max(0, layer_col_end - layer_col_start)
86-
layer_col_lens = self.base_comm.allgather(layer_ncols)
87-
mask = [i // self._P_prime for i in range(self.size)]
88-
89-
y = DistributedArray(global_shape= (self.M * self.dimsd[1]),
90-
local_shapes=[(self.M * c) for c in layer_col_lens],
91-
mask=mask,
82+
y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
83+
local_shapes=[(self.M * c) for c in self._rank_col_lens],
84+
mask=x.mask,
9285
partition=Partition.SCATTER,
9386
dtype=self.dtype)
94-
y[:] = C_local.flatten()
87+
y[:] = Y_local.flatten()
9588
return y
9689

9790
def _rmatvec(self, x: DistributedArray) -> DistributedArray:
9891
ncp = get_module(x.engine)
9992
if x.partition != Partition.SCATTER:
10093
raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
10194

102-
# Determine local column block for this layer
103-
blk_cols = int(math.ceil(self.N / self._P_prime))
104-
layer_col_start = self._layer_id * blk_cols
105-
layer_col_end = min(self.N, layer_col_start + blk_cols)
106-
layer_ncols = layer_col_end - layer_col_start
107-
layer_col_lens = self.base_comm.allgather(layer_ncols)
108-
x = x.local_array.reshape((self.M, layer_ncols)).astype(self.dtype)
109-
110-
# Determine local row block for this process group
111-
blk_rows = int(math.ceil(self.M / self._P_prime))
112-
row_start = self._group_id * blk_rows
113-
row_end = min(self.M, row_start + blk_rows)
114-
115-
B_tile = x[row_start:row_end, :].astype(self.dtype)
116-
A_local = self.A.T.conj().astype(self.dtype)
117-
118-
m, b = A_local.shape
119-
pad = (-m) % self._P_prime
120-
r = (m + pad) // self._P_prime
121-
A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
95+
x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
96+
X_tile = x_arr[self._row_start:self._row_end, :]
97+
98+
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
99+
m, b = A_local.shape
100+
pad = (-m) % self._P_prime
101+
r = (m + pad) // self._P_prime
102+
A_pad = np.pad(A_local, ((0, pad), (0, 0)), mode='constant', constant_values=self.dtype.type(0.0))
122103
A_batch = A_pad.reshape(self._P_prime, r, b)
123104

124-
# Perform local matmul and unpad
125-
Y_batch = ncp.matmul(A_batch, B_tile).astype(self.dtype)
126-
Y_pad = Y_batch.reshape(r * self._P_prime, -1)
105+
Y_batch = ncp.matmul(A_batch, X_tile)
106+
Y_pad = Y_batch.reshape(r * self._P_prime, -1)
127107
y_local = Y_pad[:m, :]
128108
y_layer = self._layer_comm.allreduce(y_local, op=MPI.SUM)
129109

130-
mask = [i // self._P_prime for i in range(self.size)]
131110
y = DistributedArray(
132111
global_shape=(self.K * self.dimsd[1]),
133-
local_shapes=[self.K * c for c in layer_col_lens],
134-
mask=mask,
112+
local_shapes=[self.K * c for c in self._rank_col_lens],
113+
mask=x.mask,
135114
partition=Partition.SCATTER,
136115
dtype=self.dtype,
137116
)

tests/test_matrixmult.py

Lines changed: 5 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -53,13 +53,10 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
5353
my_own_rows_A = max(0, row_end_A - row_start_A)
5454

5555
blk_cols_BC = int(math.ceil(N / P_prime))
56-
col_start_B = my_group * blk_cols_BC
56+
col_start_B = my_layer * blk_cols_BC
5757
col_end_B = min(N, col_start_B + blk_cols_BC)
5858
my_own_cols_B = max(0, col_end_B - col_start_B)
5959

60-
# Initialize local matrices
61-
A_p = np.empty((my_own_rows_A, K), dtype=dtype)
62-
B_p = np.empty((K, my_own_cols_B), dtype=dtype)
6360

6461
A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
6562
A_glob_imag = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) * 0.5
@@ -69,53 +66,19 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
6966
B_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7
7067
B_glob = (B_glob_real + cmplx * B_glob_imag).astype(dtype)
7168

72-
if rank == 0:
73-
# Distribute matrix blocks to all ranks
74-
for dest_rank in range(size):
75-
dest_my_group = dest_rank % P_prime
76-
77-
# Calculate destination rank's block dimensions
78-
dest_row_start_A = dest_my_group * blk_rows_A
79-
dest_row_end_A = min(M, dest_row_start_A + blk_rows_A)
80-
dest_my_own_rows_A = max(0, dest_row_end_A - dest_row_start_A)
81-
82-
dest_col_start_B = dest_my_group * blk_cols_BC
83-
dest_col_end_B = min(N, dest_col_start_B + blk_cols_BC)
84-
dest_my_own_cols_B = max(0, dest_col_end_B - dest_col_start_B)
85-
86-
A_block_send = A_glob[dest_row_start_A:dest_row_end_A, :].copy()
87-
B_block_send = B_glob[:, dest_col_start_B:dest_col_end_B].copy()
88-
89-
# Validate block shapes
90-
assert A_block_send.shape == (dest_my_own_rows_A, K)
91-
assert B_block_send.shape == (K, dest_my_own_cols_B)
92-
93-
if dest_rank == 0:
94-
A_p, B_p = A_block_send, B_block_send
95-
else:
96-
if A_block_send.size > 0:
97-
comm.Send(A_block_send, dest=dest_rank, tag=100 + dest_rank)
98-
if B_block_send.size > 0:
99-
comm.Send(B_block_send, dest=dest_rank, tag=200 + dest_rank)
100-
else:
101-
if A_p.size > 0:
102-
comm.Recv(A_p, source=0, tag=100 + rank)
103-
if B_p.size > 0:
104-
comm.Recv(B_p, source=0, tag=200 + rank)
105-
106-
comm.Barrier()
69+
A_p = A_glob[row_start_A:row_end_A,:]
70+
B_p = B_glob[:,col_start_B:col_end_B]
10771

10872
# Create SUMMAMatrixMult operator
10973
Aop = MPIMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
11074

11175
# Create DistributedArray for input x (representing B flattened)
11276
all_my_own_cols_B = comm.allgather(my_own_cols_B)
113-
total_cols = sum(all_my_own_cols_B)
114-
local_shapes_x = [(K * cl_b,) for cl_b in all_my_own_cols_B]
77+
total_cols = np.sum(all_my_own_cols_B)
11578

11679
x_dist = DistributedArray(
11780
global_shape=(K * total_cols),
118-
local_shapes=local_shapes_x,
81+
local_shapes=[K * cl_b for cl_b in all_my_own_cols_B],
11982
partition=Partition.SCATTER,
12083
base_comm=comm,
12184
dtype=dtype

0 commit comments

Comments
 (0)