"""
Distributed Matrix Multiplication
=================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MatrixMult.MPIMatrixMult`.
This class provides a way to distribute arrays across multiple processes in
a parallel computing environment.
"""
from matplotlib import pyplot as plt
import math
import numpy as np
from mpi4py import MPI

from pylops_mpi import DistributedArray, Partition
from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult

plt.close("all")
###############################################################################
# We set the seed so that every process starts out with the same matrices.
# In practice, this data would be loaded in a manner appropriate to the use case.
np.random.seed(42)

# MPI parameters
comm = MPI.COMM_WORLD
rank = comm.Get_rank()  # rank of current process
size = comm.Get_size()  # number of processes

p_prime = int(math.ceil(math.sqrt(size)))
C = int(math.ceil(size / p_prime))

if (p_prime * C) != size:
    print("No. of procs has to be a square number")
    exit(-1)

# matrix dims
M, K, N = 4, 4, 4
A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
X = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
################################################################################
# Process Grid Organization
# *************************
#
# The processes are arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid,
# where :math:`P` is the total number of processes.
#
# Define
#
# .. math::
#    P' = \bigl\lceil \sqrt{P} \bigr\rceil
#
# and the replication factor
#
# .. math::
#    C = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
#
# Each process is assigned a pair of coordinates :math:`(g, l)` within this grid:
#
# .. math::
#    g = \mathrm{rank} \bmod P',
#    \quad
#    l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
#
# For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
#
# .. raw:: html
#
#    <div style="text-align: center; font-family: monospace; white-space: pre;">
#    ┌────────────┬────────────┐
#    │   Rank 0   │   Rank 1   │
#    │ (g=0, l=0) │ (g=1, l=0) │
#    ├────────────┼────────────┤
#    │   Rank 2   │   Rank 3   │
#    │ (g=0, l=1) │ (g=1, l=1) │
#    └────────────┴────────────┘
#    </div>

my_group = rank % p_prime
my_layer = rank // p_prime

# Create the sub-communicators
layer_comm = comm.Split(color=my_layer, key=my_group)  # all procs in same layer
group_comm = comm.Split(color=my_group, key=my_layer)  # all procs in same group
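# Note: with comm.Split, all ranks passing the same ``color`` end up in the same
# sub-communicator, and ``key`` sets their relative ordering within it. Hence
# layer_comm groups ranks sharing my_layer, and group_comm groups ranks sharing
# my_group.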

blk_rows = int(math.ceil(M / p_prime))
blk_cols = int(math.ceil(N / p_prime))

rs = my_group * blk_rows
re = min(M, rs + blk_rows)
my_own_rows = re - rs

cs = my_layer * blk_cols
ce = min(N, cs + blk_cols)
my_own_cols = ce - cs

################################################################################
# Each rank will end up with:
#
# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
#
# as follows:
A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()

################################################################################
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Matrix A (4 x 4):</b>
#    ┌─────────────────┐
#    │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0)
#    │ a21 a22 a23 a24 │
#    ├─────────────────┤
#    │ a31 a32 a33 a34 │ <- Rows 2–3 (Group 1)
#    │ a41 a42 a43 a44 │
#    └─────────────────┘
#    </div>
#
# .. raw:: html
#
#    <div style="text-align: left; font-family: monospace; white-space: pre;">
#    <b>Matrix X (4 x 4):</b>
#    ┌─────────┬─────────┐
#    │ x11 x12 │ x13 x14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1)
#    │ x21 x22 │ x23 x24 │
#    │ x31 x32 │ x33 x34 │
#    │ x41 x42 │ x43 x44 │
#    └─────────┴─────────┘
#    </div>
#
################################################################################
# Forward Operation
# *****************
# To perform the distributed matrix-matrix multiplication :math:`Y = \text{Aop} \times X`,
# we need to create the distributed operator :math:`\text{Aop}` and the distributed
# operand :math:`x` from :math:`A_p` and :math:`X_p`, respectively.
Aop = MPIMatrixMult(A_p, N, dtype="float32")
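# Note: each rank passes only its own row-block A_p to the operator; the second
# argument (N here) is the global number of columns of the operand and result.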
################################################################################
# The operand :math:`x` is created as a :py:class:`pylops_mpi.DistributedArray`,
# with each rank holding its own flattened column-block :math:`X_p`.
col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
x = DistributedArray(global_shape=K * total_cols,
                     local_shapes=[K * col_len for col_len in col_lens],
                     partition=Partition.SCATTER,
                     mask=[i // p_prime for i in range(comm.Get_size())],
                     base_comm=comm,
                     dtype="float32")
x[:] = X_p.flatten()
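# Note: ranks in the same layer compute the same cs:ce range, so they fill x with
# identical copies of their column-block; the mask (i // p_prime) labels these
# layer groups for the DistributedArray.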
################################################################################
# Applying the forward operator then yields a distributed :math:`y`, partitioned
# in the same fashion as :math:`x`.
y = Aop @ x
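# Each rank's y.local_array now holds the flattened column-block Y[:, cs:ce] of
# the global result, i.e. the same column partition used for x.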

###############################################################################
# Adjoint Operation
# *****************
# In a similar fashion we can perform the adjoint operation
# :math:`X_{adj} = \text{Aop}^H Y`.
xadj = Aop.H @ y
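# Likewise, xadj.local_array holds the flattened column-block (A^H Y)[:, cs:ce]
# on each rank.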
###############################################################################
# Finally we verify the result against the equivalent serial version of the
# operation. Each rank checks that it has computed the correct values for its
# own partition.
y_loc = A @ X
xadj_loc = (A.T.dot(y_loc.conj())).conj()

expected_y_loc = y_loc[:, cs:ce].flatten().astype(np.float32)
expected_xadj_loc = xadj_loc[:, cs:ce].flatten().astype(np.float32)

if not np.allclose(y.local_array, expected_y_loc, rtol=1e-6):
    print(f"RANK {rank}: FORWARD VERIFICATION FAILED")
    print(f'{rank} local: {y.local_array}, expected: {y_loc[:, cs:ce]}')
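
###############################################################################
# A minimal sketch of the matching adjoint check (assumed continuation of the
# example), reusing the ``expected_xadj_loc`` computed above:
if not np.allclose(xadj.local_array, expected_xadj_loc, rtol=1e-6):
    print(f"RANK {rank}: ADJOINT VERIFICATION FAILED")
    print(f'{rank} local: {xadj.local_array}, expected: {xadj_loc[:, cs:ce]}')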