Skip to content

Commit 7ac593d

Browse files
committed
Added comments to example
1 parent 66f1770 commit 7ac593d

File tree

1 file changed

+112
-34
lines changed

1 file changed

+112
-34
lines changed

examples/plot_matrixmult.py

Lines changed: 112 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -1,53 +1,86 @@
11
"""
22
Distributed Matrix Multiplication
3-
=========================
4-
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMultiply.SUMMAMatrixMult`.
3+
=================================
4+
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMult.MPIMatrixMult`.
55
This class provides a way to distribute arrays across multiple processes in
66
a parallel computing environment.
77
"""
8-
8+
from matplotlib import pyplot as plt
99
import math
1010
import numpy as np
1111
from mpi4py import MPI
1212

1313
from pylops_mpi import DistributedArray, Partition
1414
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
1515

16+
plt.close("all")
17+
###############################################################################
18+
# We set the seed such that all processes initially start out with the same initial matrix.
19+
# Ideally this data would be loaded in a manner appropriate to the use-case.
1620
np.random.seed(42)
1721

22+
# MPI parameters
1823
comm = MPI.COMM_WORLD
19-
rank = comm.Get_rank()
20-
n_procs = comm.Get_size()
24+
rank = comm.Get_rank() # rank of current process
25+
size = comm.Get_size() # number of processes
2126

22-
P_prime = int(math.ceil(math.sqrt(n_procs)))
23-
C = int(math.ceil(n_procs / P_prime))
27+
p_prime = int(math.ceil(math.sqrt(size)))
28+
C = int(math.ceil(size / p_prime))
2429

25-
if (P_prime * C) != n_procs:
30+
if (p_prime * C) != size:
2631
print("No. of procs has to be a square number")
2732
exit(-1)
2833

2934
# matrix dims
30-
M = 33
31-
K = 34
32-
N = 37
33-
35+
M, K, N = 4, 4, 4
3436
A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
35-
B = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
36-
37-
my_group = rank % P_prime
38-
my_layer = rank // P_prime
39-
40-
# sub‐communicators
37+
X = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
38+
################################################################################
39+
#Process Grid Organization
40+
#*************************
41+
#
42+
#The processes are arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid, where :math:`P` is the total number of processes.
43+
#
44+
#Define
45+
#
46+
#.. math::
47+
# P' = \bigl \lceil \sqrt{P} \bigr \rceil
48+
#
49+
#and the replication factor
50+
#
51+
#.. math::
52+
# C = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
53+
#
54+
#Each process is assigned a pair of coordinates :math:`(g, l)` within this grid:
55+
#
56+
#.. math::
57+
# g = \mathrm{rank} \bmod P',
58+
# \quad
59+
# l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
60+
#
61+
#For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
62+
#
63+
#.. raw:: html
64+
#
65+
# <div style="text-align: center; font-family: monospace; white-space: pre;">
66+
# ┌────────────┬────────────┐
67+
# │ Rank 0 │ Rank 1 │
68+
# │ (g=0, l=0) │ (g=1, l=0) │
69+
# ├────────────┼────────────┤
70+
# │ Rank 2 │ Rank 3 │
71+
# │ (g=0, l=1) │ (g=1, l=1) │
72+
# └────────────┴────────────┘
73+
# </div>
74+
75+
my_group = rank % p_prime
76+
my_layer = rank // p_prime
77+
78+
# Create the sub‐communicators
4179
layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer
4280
group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group
4381

44-
45-
#Each rank will end up with:
46-
# - :math:`A_{p} \in \mathbb{R}^{\text{my\_own\_rows}\times K}`
47-
# - :math:`B_{p} \in \mathbb{R}^{K\times \text{my\_own\_cols}}`
48-
# where
49-
blk_rows = int(math.ceil(M / P_prime))
50-
blk_cols = int(math.ceil(N / P_prime))
82+
blk_rows = int(math.ceil(M / p_prime))
83+
blk_cols = int(math.ceil(N / p_prime))
5184

5285
rs = my_group * blk_rows
5386
re = min(M, rs + blk_rows)
@@ -57,29 +90,74 @@
5790
ce = min(N, cs + blk_cols)
5891
my_own_cols = ce - cs
5992

60-
A_p, B_p = A[rs:re, :].copy(), B[:, cs:ce].copy()
61-
93+
################################################################################
94+
#Each rank will end up with:
95+
# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
96+
# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
97+
#as follows:
98+
A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
99+
100+
################################################################################
101+
#.. raw:: html
102+
#
103+
# <div style="text-align: left; font-family: monospace; white-space: pre;">
104+
# <b>Matrix A (4 x 4):</b>
105+
# ┌─────────────────┐
106+
# │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0)
107+
# │ a21 a22 a23 a24 │
108+
# ├─────────────────┤
109+
# │ a41 a42 a43 a44 │ <- Rows 2–3 (Group 1)
110+
# │ a51 a52 a53 a54 │
111+
# └─────────────────┘
112+
# </div>
113+
#
114+
#.. raw:: html
115+
#
116+
# <div style="text-align: left; font-family: monospace; white-space: pre;">
117+
# <b>Matrix B (4 x 4):</b>
118+
# ┌─────────┬─────────┐
119+
# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1)
120+
# │ b21 b22 │ b23 b24 │
121+
# │ b31 b32 │ b33 b34 │
122+
# │ b41 b42 │ b43 b44 │
123+
# └─────────┴─────────┘
124+
#
125+
# </div>
126+
#
127+
128+
################################################################################
129+
#Forward Operation
130+
#*****************
131+
#To perform our distributed matrix-matrix multiplication :math:`Y = \text{Aop} \times X` we need to create our distributed operator :math:`\text{Aop}` and distributed operand :math:`X` from :math:`A_p` and
132+
#:math:`X_p` respectively
62133
Aop = MPIMatrixMult(A_p, N, dtype="float32")
134+
################################################################################
135+
# While as well passing the appropriate values.
63136
col_lens = comm.allgather(my_own_cols)
64137
total_cols = np.sum(col_lens)
65138
x = DistributedArray(global_shape=K * total_cols,
66139
local_shapes=[K * col_len for col_len in col_lens],
67140
partition=Partition.SCATTER,
68-
mask=[i % P_prime for i in range(comm.Get_size())],
141+
mask=[i // p_prime for i in range(comm.Get_size())],
69142
base_comm=comm,
70143
dtype="float32")
71-
x[:] = B_p.flatten()
144+
x[:] = X_p.flatten()
145+
################################################################################
146+
#When we perform the matrix-matrix multiplication we shall then obtain a distributed :math:`Y` in the same way our :math:`X` was distributed.
72147
y = Aop @ x
73-
74-
# ======================= VERIFICATION =================-=============
75-
y_loc = A @ B
148+
###############################################################################
149+
#Adjoint Operation
150+
#*****************
151+
# In a similar fashion we then perform the Adjoint :math:`Xadj = A^H * Y`
152+
xadj = Aop.H @ y
153+
###############################################################################
154+
#Here we verify the result against the equivalent serial version of the operation. Each rank checks that it has computed the correct values for it partition.
155+
y_loc = A @ X
76156
xadj_loc = (A.T.dot(y_loc.conj())).conj()
77157

78-
79158
expected_y_loc = y_loc[:, cs:ce].flatten().astype(np.float32)
80159
expected_xadj_loc = xadj_loc[:, cs:ce].flatten().astype(np.float32)
81160

82-
xadj = Aop.H @ y
83161
if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
84162
print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
85163
print(f'{rank} local: {y.local_array}, expected: {y_loc[:, cs:ce]}')

0 commit comments

Comments
 (0)