diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst index 3c39cc0a..66e1a373 100644 --- a/docs/source/api/index.rst +++ b/docs/source/api/index.rst @@ -42,6 +42,7 @@ Basic Operators .. autosummary:: :toctree: generated/ + MPIMatrixMult MPIBlockDiag MPIStackedBlockDiag MPIVStack diff --git a/examples/plot_matrixmult.py b/examples/plot_matrixmult.py new file mode 100644 index 00000000..413206f4 --- /dev/null +++ b/examples/plot_matrixmult.py @@ -0,0 +1,223 @@ +""" +Distributed Matrix Multiplication +================================= +This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` +operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}` +blocked over rows (i.e., blocks of rows are stored over different ranks) and a +matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are +stored over different ranks), with equal number of row and column blocks. +Similarly, the adjoint operation can be peformed with a matrix :math:`\mathbf{Y}` +blocked in the same fashion of matrix :math:`\mathbf{X}. + +Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly +stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is +effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where +the different blocks are flattened and stored on different ranks. Note that to +optimize communications, the ranks are organized in a 2D grid and some of the +row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are +replicated across different ranks - see below for details. + +""" +from matplotlib import pyplot as plt +import math +import numpy as np +from mpi4py import MPI + +from pylops_mpi import DistributedArray, Partition +from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult + +plt.close("all") + +############################################################################### +# We set the seed such that all processes can create the input matrices filled +# with the same random number. In practical application, such matrices will be +# filled with data that is appropriate that is appropriate the use-case. +np.random.seed(42) + +############################################################################### +# Next we obtain the MPI parameters for each rank and check that the number +# of processes (``size``) is a square number +comm = MPI.COMM_WORLD +rank = comm.Get_rank() # rank of current process +size = comm.Get_size() # number of processes + +p_prime = math.isqrt(size) +repl_factor = p_prime + +if (p_prime * repl_factor) != size: + print(f"Number of processes must be a square number, provided {size} instead...") + exit(-1) + +############################################################################### +# We are now ready to create the input matrices :math:`\mathbf{A}` of size +# :math:`M \times k` :math:`\mathbf{A}` of size and :math:`\mathbf{A}` of size +# :math:`K \times N`. +N, K, M = 4, 4, 4 +A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K) +X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M) + +################################################################################ +# The processes are now arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid, +# where :math:`P` is the total number of processes. +# +# We define +# +# .. math:: +# P' = \bigl \lceil \sqrt{P} \bigr \rceil +# +# and the replication factor +# +# .. math:: +# R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil. +# +# Each process is therefore assigned a pair of coordinates +# :math:`(g, l)` within this grid: +# +# .. math:: +# g = \mathrm{rank} \bmod P', +# \quad +# l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor. +# +#For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout: +# +#.. raw:: html +# +#
+# ┌────────────┬────────────┐ +# │ Rank 0 │ Rank 1 │ +# │ (g=0, l=0) │ (g=1, l=0) │ +# ├────────────┼────────────┤ +# │ Rank 2 │ Rank 3 │ +# │ (g=0, l=1) │ (g=1, l=1) │ +# └────────────┴────────────┘ +#
+ +my_group = rank % p_prime +my_layer = rank // p_prime + +# Create sub‐communicators +layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer +group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group + +################################################################################ +# At this point we divide the rows and columns of :math:`\mathbf{A}` and +# :math:`\mathbf{X}`, respectively, such that each rank ends up with: +# +# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}` +# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}` +# +# .. raw:: html +# +#
+# Matrix A (4 x 4): +# ┌─────────────────┐ +# │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0) +# │ a21 a22 a23 a24 │ +# ├─────────────────┤ +# │ a41 a42 a43 a44 │ <- Rows 2–3 (Group 1) +# │ a51 a52 a53 a54 │ +# └─────────────────┘ +#
+# +# .. raw:: html +# +#
+# Matrix X (4 x 4): +# ┌─────────┬─────────┐ +# │ b11 b12 │ b13 b14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1) +# │ b21 b22 │ b23 b24 │ +# │ b31 b32 │ b33 b34 │ +# │ b41 b42 │ b43 b44 │ +# └─────────┴─────────┘ +#
+# + +blk_rows = int(math.ceil(N / p_prime)) +blk_cols = int(math.ceil(M / p_prime)) + +rs = my_group * blk_rows +re = min(N, rs + blk_rows) +my_own_rows = re - rs + +cs = my_layer * blk_cols +ce = min(M, cs + blk_cols) +my_own_cols = ce - cs + +A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy() + +################################################################################ +# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult` +# operator and the input matrix math:`\mathbf{X}` +Aop = MPIMatrixMult(A_p, M, dtype="float32") + +col_lens = comm.allgather(my_own_cols) +total_cols = np.sum(col_lens) +x = DistributedArray(global_shape=K * total_cols, + local_shapes=[K * col_len for col_len in col_lens], + partition=Partition.SCATTER, + mask=[i % p_prime for i in range(comm.Get_size())], + base_comm=comm, + dtype="float32") +x[:] = X_p.flatten() + +################################################################################ +# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which effectively +# implements a distributed matrix-matrix multiplication :math:`Y = \mathbf{AX}`) +# Note :math:`\mathbf{Y}` is distributed in the same way as the input +# :math:`\mathbf{X}`. +y = Aop @ x + +############################################################################### +# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{x}` +# (which effectively implements a distributed matrix-matrix multiplication +# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{X}`). Note that +# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input +# :math:`\mathbf{X}`. +xadj = Aop.H @ y + +############################################################################### +# To conclude we verify our result against the equivalent serial version of +# the operation by gathering the resulting matrices in rank0 and reorganizing +# the returned 1D-arrays into 2D-arrays. + +# Local benchmarks +y = y.asarray(masked=True) +col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)] +y_blocks = [] +offset = 0 +for cnt in col_counts: + block_size = N * cnt + y_blocks.append( + y[offset: offset + block_size].reshape(N, cnt) + ) + offset += block_size +y = np.hstack(y_blocks) + +xadj = xadj.asarray(masked=True) +xadj_blocks = [] +offset = 0 +for cnt in col_counts: + block_size = K * cnt + xadj_blocks.append( + xadj[offset: offset + block_size].reshape(K, cnt) + ) + offset += block_size +xadj = np.hstack(xadj_blocks) + +if rank == 0: + y_loc = (A @ X).squeeze() + xadj_loc = (A.T.dot(y_loc.conj())).conj().squeeze() + + if not np.allclose(y, y_loc, rtol=1e-6): + print(f" FORWARD VERIFICATION FAILED") + print(f'distributed: {y}') + print(f'expected: {y_loc}') + else: + print(f"FORWARD VERIFICATION PASSED") + + if not np.allclose(xadj, xadj_loc, rtol=1e-6): + print(f" ADJOINT VERIFICATION FAILED") + print(f'distributed: {xadj}') + print(f'expected: {xadj_loc}') + else: + print(f"ADJOINT VERIFICATION PASSED") diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py new file mode 100644 index 00000000..ac15e2e5 --- /dev/null +++ b/pylops_mpi/basicoperators/MatrixMult.py @@ -0,0 +1,213 @@ +import numpy as np +import math +from mpi4py import MPI +from pylops.utils.backend import get_module +from pylops.utils.typing import DTypeLike, NDArray + +from pylops_mpi import ( + DistributedArray, + MPILinearOperator, + Partition +) + + +class MPIMatrixMult(MPILinearOperator): + r"""MPI Matrix multiplication + + Implement distributed matrix-matrix multiplication between a matrix + :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored + over different ranks) and the input model and data vector, which are both to + be interpreted as matrices blocked over columns. + + Parameters + ---------- + A : :obj:`numpy.ndarray` + Local block of the matrix of shape :math:`[M_{loc} \times K]` + where ``M_loc`` is the number of rows stored on this MPI rank and + ``K`` is the global number of columns. + M : :obj:`int` + Global leading dimension (i.e., number of columns) of the matrices + representing the input model and data vectors. + saveAt : :obj:`bool`, optional + Save ``A`` and ``A.H`` to speed up the computation of adjoint + (``True``) or create ``A.H`` on-the-fly (``False``) + Note that ``saveAt=True`` will double the amount of required memory. + Default is ``False``. + base_comm : :obj:`mpi4py.MPI.Comm`, optional + MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``. + dtype : :obj:`str`, optional + Type of elements in input array. + + Attributes + ---------- + shape : :obj:`tuple` + Operator shape + + Raises + ------ + Exception + If the operator is created without a square number of mpi ranks. + ValueError + If input vector does not have the correct partition type. + + Notes + ----- + This operator performs a matrix-matrix multiplication, whose forward + operation can be described as :math:`Y = A \cdot X` where: + + - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]` + - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]` + - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]` + + whilst the adjoint operation is represented by + :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where + :math:`\mathbf{A}^H` is the complex conjugate and transpose of :math:`\mathbf{A}`. + + This implementation is based on a 1D block distribution of the operator + matrix and reshaped model and data vectors replicated across math:`P` + processes by a factor equivalent to :math:`\sqrt{P}` across a square process + grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically: + + - The matrix ``A`` is distributed across MPI processes in a block-row fashion + and each process holds a local block of ``A`` with shape + :math:`[N_{loc} \times K]` + - The operand matrix ``X`` is distributed in a block-column fashion and + each process holds a local block of ``X`` with shape + :math:`[K \times M_{loc}]` + - Communication is minimized by using a 2D process grid layout + + **Forward Operation step-by-step** + + 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X`` + of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local`` + is the number of columns assigned to the current process. + + 2. **Data Broadcasting**: Within each layer (processes with same ``layer_id``), + the operand data is broadcast from the process whose ``group_id`` matches + the ``layer_id``. This ensures all processes in a layer have access to + the same operand columns. + + 3. **Local Computation**: Each process computes ``A_local @ X_local`` where: + - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``) + - ``X_local`` is the broadcasted operand (shape ``K x M_local``) + + 4. **Layer Gather**: Results from all processes in each layer are gathered + using ``allgather`` to reconstruct the full result matrix vertically. + + **Adjoint Operation step-by-step** + + The adjoint operation performs the conjugate transpose multiplication: + + 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)`` + representing the local columns of the input matrix. + + 2. **Local Adjoint Computation**: + Each process computes ``A_local.H @ X_tile`` + where ``A_local.H`` is either: + - Pre-computed ``At`` (if ``saveAt=True``) + - Computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``) + Each process multiplies its transposed local ``A`` block ``A_local^H`` + (shape ``K x N_block``) + with the extracted ``X_tile`` (shape ``N_block x M_local``), + producing a partial result of shape ``(K, M_local)``. + This computes the local contribution of columns of ``A^H`` to the final result. + + 3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the + sum of contributions from all column blocks of ``A^H``, processes in the + same layer perform an ``allreduce`` sum to combine their partial results. + This gives the complete ``(K, M_local)`` result for their assigned columns. + + """ + def __init__( + self, + A: NDArray, + M: int, + saveAt: bool = False, + base_comm: MPI.Comm = MPI.COMM_WORLD, + dtype: DTypeLike = "float64", + ) -> None: + rank = base_comm.Get_rank() + size = base_comm.Get_size() + + # Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size + self._P_prime = math.isqrt(size) + self._C = self._P_prime + if self._P_prime * self._C != size: + raise Exception(f"Number of processes must be a square number, provided {size} instead...") + + # Compute this process's group and layer indices + self._group_id = rank % self._P_prime + self._layer_id = rank // self._P_prime + + # Split communicators by layer (rows) and by group (columns) + self.base_comm = base_comm + self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id) + self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id) + + self.A = A.astype(np.dtype(dtype)) + if saveAt: self.At = A.T.conj() + + self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM) + self.K = A.shape[1] + self.M = M + + block_cols = int(math.ceil(self.M / self._P_prime)) + blk_rows = int(math.ceil(self.N / self._P_prime)) + + self._row_start = self._group_id * blk_rows + self._row_end = min(self.N, self._row_start + blk_rows) + + self._col_start = self._layer_id * block_cols + self._col_end = min(self.M, self._col_start + block_cols) + + self._local_ncols = self._col_end - self._col_start + self._rank_col_lens = self.base_comm.allgather(self._local_ncols) + total_ncols = np.sum(self._rank_col_lens) + + self.dims = (self.K, total_ncols) + self.dimsd = (self.N, total_ncols) + shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims))) + super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm) + + def _matvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...") + + y = DistributedArray(global_shape=(self.N * self.dimsd[1]), + local_shapes=[(self.N * c) for c in self._rank_col_lens], + mask=x.mask, + partition=Partition.SCATTER, + dtype=self.dtype) + + my_own_cols = self._rank_col_lens[self.rank] + x_arr = x.local_array.reshape((self.dims[0], my_own_cols)) + X_local = x_arr.astype(self.dtype) + Y_local = ncp.vstack( + self._layer_comm.allgather( + ncp.matmul(self.A, X_local) + ) + ) + y[:] = Y_local.flatten() + return y + + def _rmatvec(self, x: DistributedArray) -> DistributedArray: + ncp = get_module(x.engine) + if x.partition != Partition.SCATTER: + raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.") + + y = DistributedArray( + global_shape=(self.K * self.dimsd[1]), + local_shapes=[self.K * c for c in self._rank_col_lens], + mask=x.mask, + partition=Partition.SCATTER, + dtype=self.dtype, + ) + + x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype) + X_tile = x_arr[self._row_start:self._row_end, :] + A_local = self.At if hasattr(self, "At") else self.A.T.conj() + Y_local = ncp.matmul(A_local, X_tile) + y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM) + y[:] = y_layer.flatten() + return y diff --git a/pylops_mpi/basicoperators/__init__.py b/pylops_mpi/basicoperators/__init__.py index 566b59fb..8db5a988 100644 --- a/pylops_mpi/basicoperators/__init__.py +++ b/pylops_mpi/basicoperators/__init__.py @@ -7,6 +7,7 @@ functionalities using MPI. A list of operators present in pylops_mpi.basicoperators: + MPIMatrixMult Matrix Multiplication operator MPIBlockDiag Block Diagonal arrangement of PyLops operators MPIStackedBlockDiag Block Diagonal arrangement of PyLops-MPI operators MPIVStack Vertical Stacking of PyLops operators @@ -19,6 +20,7 @@ """ +from .MatrixMult import * from .BlockDiag import * from .VStack import * from .HStack import * @@ -28,6 +30,7 @@ from .Gradient import * __all__ = [ + "MPIMatrixMult", "MPIBlockDiag", "MPIStackedBlockDiag", "MPIVStack", diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py new file mode 100644 index 00000000..61f6ee7e --- /dev/null +++ b/tests/test_matrixmult.py @@ -0,0 +1,135 @@ +import pytest +import numpy as np +from numpy.testing import assert_allclose +from mpi4py import MPI +import math +import sys + +from pylops_mpi import DistributedArray, Partition +from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult + +np.random.seed(42) + +comm = MPI.COMM_WORLD +rank = comm.Get_rank() +size = comm.Get_size() + +# Define test cases: (N K, M, dtype_str) +# M, K, N are matrix dimensions A(N,K), B(K,M) +# P_prime will be ceil(sqrt(size)). +test_params = [ + pytest.param(37, 37, 37, "float32", id="f32_37_37_37"), + pytest.param(50, 30, 40, "float64", id="f64_50_30_40"), + pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"), + pytest.param( 3, 4, 5, "float32", id="f32_3_4_5"), + pytest.param( 1, 2, 1, "float64", id="f64_1_2_1",), + pytest.param( 2, 1, 3, "float32", id="f32_2_1_3",), +] + + +@pytest.mark.mpi(min_size=1) # SUMMA should also work for 1 process. +@pytest.mark.parametrize("M, K, N, dtype_str", test_params) +def test_SUMMAMatrixMult(N, K, M, dtype_str): + p_prime = math.isqrt(size) + C = p_prime + if p_prime * C != size: + pytest.skip(f"Number of processes must be a square number, provided {size} instead...") + + dtype = np.dtype(dtype_str) + + cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0 + base_float_dtype = np.float32 if dtype == np.complex64 else np.float64 + + my_group = rank % p_prime + my_layer = rank // p_prime + + # Create sub-communicators + layer_comm = comm.Split(color=my_layer, key=my_group) + group_comm = comm.Split(color=my_group, key=my_layer) + + # Calculate local matrix dimensions + blk_rows_A = int(math.ceil(N / p_prime)) + row_start_A = my_group * blk_rows_A + row_end_A = min(N, row_start_A + blk_rows_A) + + blk_cols_X = int(math.ceil(M / p_prime)) + col_start_X = my_layer * blk_cols_X + col_end_X = min(M, col_start_X + blk_cols_X) + local_col_X_len = max(0, col_end_X - col_start_X) + + A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) + A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5 + A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype) + + X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) + X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7 + X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype) + + A_p = A_glob[row_start_A:row_end_A,:] + X_p = X_glob[:,col_start_X:col_end_X] + + # Create MPIMatrixMult operator + Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str) + + # Create DistributedArray for input x (representing B flattened) + all_local_col_len = comm.allgather(local_col_X_len) + total_cols = np.sum(all_local_col_len) + + x_dist = DistributedArray( + global_shape=(K * total_cols), + local_shapes=[K * cl_b for cl_b in all_local_col_len], + partition=Partition.SCATTER, + base_comm=comm, + mask=[i % p_prime for i in range(size)], + dtype=dtype + ) + + x_dist.local_array[:] = X_p.ravel() + + # Forward operation: y = A @ B (distributed) + y_dist = Aop @ x_dist + # Adjoint operation: xadj = A.H @ y (distributed) + xadj_dist = Aop.H @ y_dist + + y = y_dist.asarray(masked=True) + col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)] + y_blocks = [] + offset = 0 + for cnt in col_counts: + block_size = N * cnt + y_blocks.append( + y[offset: offset + block_size].reshape(N, cnt) + ) + offset += block_size + y = np.hstack(y_blocks) + + xadj = xadj_dist.asarray(masked=True) + xadj_blocks = [] + offset = 0 + for cnt in col_counts: + block_size = K * cnt + xadj_blocks.append( + xadj[offset: offset + block_size].reshape(K, cnt) + ) + offset += block_size + xadj = np.hstack(xadj_blocks) + + if rank == 0: + y_loc = A_glob @ X_glob + assert_allclose( + y.squeeze(), + y_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Forward verification failed." + ) + + xadj_loc = A_glob.conj().T @ y_loc + assert_allclose( + xadj.squeeze(), + xadj_loc.squeeze(), + rtol=np.finfo(np.dtype(dtype)).resolution, + err_msg=f"Rank {rank}: Ajoint verification failed." + ) + + group_comm.Free() + layer_comm.Free() \ No newline at end of file