diff --git a/docs/source/api/index.rst b/docs/source/api/index.rst
index 3c39cc0a..66e1a373 100644
--- a/docs/source/api/index.rst
+++ b/docs/source/api/index.rst
@@ -42,6 +42,7 @@ Basic Operators
.. autosummary::
:toctree: generated/
+ MPIMatrixMult
MPIBlockDiag
MPIStackedBlockDiag
MPIVStack
diff --git a/examples/plot_matrixmult.py b/examples/plot_matrixmult.py
new file mode 100644
index 00000000..413206f4
--- /dev/null
+++ b/examples/plot_matrixmult.py
@@ -0,0 +1,223 @@
+"""
+Distributed Matrix Multiplication
+=================================
+This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
+blocked over rows (i.e., blocks of rows are stored over different ranks) and a
+matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are
+stored over different ranks), with equal number of row and column blocks.
+Similarly, the adjoint operation can be performed with a matrix :math:`\mathbf{Y}`
+blocked in the same fashion as matrix :math:`\mathbf{X}`.
+
+Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
+stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
+effectively represented by a 1-D :py:class:`pylops_mpi.DistributedArray` where
+the different blocks are flattened and stored on different ranks. To optimize
+communications, the ranks are organized in a 2D grid and some of the
+row blocks of :math:`\mathbf{A}` and column blocks of :math:`\mathbf{X}` are
+replicated across different ranks - see below for details.
+
+"""
+from matplotlib import pyplot as plt
+import math
+import numpy as np
+from mpi4py import MPI
+
+from pylops_mpi import DistributedArray, Partition
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
+
+plt.close("all")
+
+###############################################################################
+# We set the seed such that all processes can create the input matrices filled
+# with the same random numbers. In practical applications, such matrices will
+# be filled with data that is appropriate to the use-case.
+np.random.seed(42)
+
+###############################################################################
+# Next we obtain the MPI parameters for each rank and check that the number
+# of processes (``size``) is a square number.
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank() # rank of current process
+size = comm.Get_size() # number of processes
+
+p_prime = math.isqrt(size)
+repl_factor = p_prime
+
+if (p_prime * repl_factor) != size:
+ print(f"Number of processes must be a square number, provided {size} instead...")
+ exit(-1)
+
+###############################################################################
+# We are now ready to create the input matrices: :math:`\mathbf{A}` of size
+# :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`.
+N, K, M = 4, 4, 4
+A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K)
+X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)
+
+################################################################################
+# The processes are now arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid,
+# where :math:`P` is the total number of processes.
+#
+# We define
+#
+# .. math::
+# P' = \bigl \lceil \sqrt{P} \bigr \rceil
+#
+# and the replication factor
+#
+# .. math::
+# R = \bigl\lceil \tfrac{P}{P'} \bigr\rceil.
+#
+# Each process is therefore assigned a pair of coordinates
+# :math:`(g, l)` within this grid:
+#
+# .. math::
+# g = \mathrm{rank} \bmod P',
+# \quad
+# l = \left\lfloor \frac{\mathrm{rank}}{P'} \right\rfloor.
+#
+# For example, when :math:`P = 4` we have :math:`P' = 2`, giving a 2×2 layout:
+#
+# .. code-block:: text
+#
+#
+# ┌────────────┬────────────┐
+# │ Rank 0 │ Rank 1 │
+# │ (g=0, l=0) │ (g=1, l=0) │
+# ├────────────┼────────────┤
+# │ Rank 2 │ Rank 3 │
+# │ (g=0, l=1) │ (g=1, l=1) │
+# └────────────┴────────────┘
+#
+
+my_group = rank % p_prime
+my_layer = rank // p_prime
+
+# Create sub-communicators
+layer_comm = comm.Split(color=my_layer, key=my_group) # all procs in same layer
+group_comm = comm.Split(color=my_group, key=my_layer) # all procs in same group
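+
+###############################################################################
+# As a small, purely illustrative check (not required for the computation),
+# we can print the grid coordinates assigned to each rank and compare them
+# with the layout sketched above.
+print(f"Rank {rank}: my_group = {my_group}, my_layer = {my_layer}")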
+
+################################################################################
+# At this point we divide the rows and columns of :math:`\mathbf{A}` and
+# :math:`\mathbf{X}`, respectively, such that each rank ends up with:
+#
+# - :math:`A_{p} \in \mathbb{R}^{\text{my_own_rows}\times K}`
+# - :math:`X_{p} \in \mathbb{R}^{K\times \text{my_own_cols}}`
+#
+# .. code-block:: text
+#
+#
+# Matrix A (4 x 4):
+# ┌─────────────────┐
+# │ a11 a12 a13 a14 │ <- Rows 0–1 (Group 0)
+# │ a21 a22 a23 a24 │
+# ├─────────────────┤
+#     │ a31 a32 a33 a34 │ <- Rows 2–3 (Group 1)
+#     │ a41 a42 a43 a44 │
+# └─────────────────┘
+#
+#
+# .. code-block:: text
+#
+#
+# Matrix X (4 x 4):
+# ┌─────────┬─────────┐
+#     │ x11 x12 │ x13 x14 │ <- Cols 0–1 (Layer 0), Cols 2–3 (Layer 1)
+#     │ x21 x22 │ x23 x24 │
+#     │ x31 x32 │ x33 x34 │
+#     │ x41 x42 │ x43 x44 │
+# └─────────┴─────────┘
+#
+#
+
+blk_rows = int(math.ceil(N / p_prime))
+blk_cols = int(math.ceil(M / p_prime))
+
+rs = my_group * blk_rows
+re = min(N, rs + blk_rows)
+my_own_rows = re - rs
+
+cs = my_layer * blk_cols
+ce = min(M, cs + blk_cols)
+my_own_cols = ce - cs
+
+A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
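+
+###############################################################################
+# As an optional sanity check (purely illustrative), we can inspect the shapes
+# of the local blocks owned by each rank.
+print(f"Rank {rank}: A_p shape = {A_p.shape}, X_p shape = {X_p.shape}")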
+
+################################################################################
+# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
+# operator and the input matrix :math:`\mathbf{X}`.
+Aop = MPIMatrixMult(A_p, M, dtype="float32")
+
+col_lens = comm.allgather(my_own_cols)
+total_cols = np.sum(col_lens)
+x = DistributedArray(global_shape=K * total_cols,
+ local_shapes=[K * col_len for col_len in col_lens],
+ partition=Partition.SCATTER,
+ mask=[i % p_prime for i in range(comm.Get_size())],
+ base_comm=comm,
+ dtype="float32")
+x[:] = X_p.flatten()
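+
+###############################################################################
+# As a simple consistency check (purely illustrative), the local portion of
+# ``x`` must contain exactly ``K * my_own_cols`` values, i.e., the flattened
+# block of columns owned by this rank.
+assert x.local_array.size == K * my_own_cols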
+
+################################################################################
+# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
+# effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{Y} = \mathbf{AX}`). Note that :math:`\mathbf{Y}` is distributed
+# in the same way as the input :math:`\mathbf{X}`.
+y = Aop @ x
+
+###############################################################################
+# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
+# (which effectively implements a distributed matrix-matrix multiplication
+# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
+# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
+# :math:`\mathbf{X}`.
+xadj = Aop.H @ y
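+
+###############################################################################
+# Both ``y`` and ``xadj`` are distributed in the same fashion as ``x``; as a
+# quick illustrative check, their local portions should hold
+# ``N * my_own_cols`` and ``K * my_own_cols`` values, respectively.
+assert y.local_array.size == N * my_own_cols
+assert xadj.local_array.size == K * my_own_cols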
+
+###############################################################################
+# To conclude, we verify our result against the equivalent serial version of
+# the operation by gathering the resulting matrices on rank 0 and reorganizing
+# the returned 1D arrays into 2D arrays.
+
+# Gather the distributed outputs and rearrange them into 2D matrices
+y = y.asarray(masked=True)
+col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
+y_blocks = []
+offset = 0
+for cnt in col_counts:
+ block_size = N * cnt
+ y_blocks.append(
+ y[offset: offset + block_size].reshape(N, cnt)
+ )
+ offset += block_size
+y = np.hstack(y_blocks)
+
+xadj = xadj.asarray(masked=True)
+xadj_blocks = []
+offset = 0
+for cnt in col_counts:
+ block_size = K * cnt
+ xadj_blocks.append(
+ xadj[offset: offset + block_size].reshape(K, cnt)
+ )
+ offset += block_size
+xadj = np.hstack(xadj_blocks)
+
+if rank == 0:
+ y_loc = (A @ X).squeeze()
+    xadj_loc = (A.conj().T @ y_loc).squeeze()
+
+    if not np.allclose(y, y_loc, rtol=1e-6):
+        print("FORWARD VERIFICATION FAILED")
+        print(f"distributed: {y}")
+        print(f"expected: {y_loc}")
+    else:
+        print("FORWARD VERIFICATION PASSED")
+
+    if not np.allclose(xadj, xadj_loc, rtol=1e-6):
+        print("ADJOINT VERIFICATION FAILED")
+        print(f"distributed: {xadj}")
+        print(f"expected: {xadj_loc}")
+    else:
+        print("ADJOINT VERIFICATION PASSED")
diff --git a/pylops_mpi/basicoperators/MatrixMult.py b/pylops_mpi/basicoperators/MatrixMult.py
new file mode 100644
index 00000000..ac15e2e5
--- /dev/null
+++ b/pylops_mpi/basicoperators/MatrixMult.py
@@ -0,0 +1,213 @@
+import numpy as np
+import math
+from mpi4py import MPI
+from pylops.utils.backend import get_module
+from pylops.utils.typing import DTypeLike, NDArray
+
+from pylops_mpi import (
+ DistributedArray,
+ MPILinearOperator,
+ Partition
+)
+
+
+class MPIMatrixMult(MPILinearOperator):
+ r"""MPI Matrix multiplication
+
+ Implement distributed matrix-matrix multiplication between a matrix
+ :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
+ over different ranks) and the input model and data vector, which are both to
+ be interpreted as matrices blocked over columns.
+
+ Parameters
+ ----------
+ A : :obj:`numpy.ndarray`
+        Local block of the matrix of shape :math:`[N_{loc} \times K]`,
+        where ``N_loc`` is the number of rows stored on this MPI rank and
+ ``K`` is the global number of columns.
+ M : :obj:`int`
+ Global leading dimension (i.e., number of columns) of the matrices
+ representing the input model and data vectors.
+ saveAt : :obj:`bool`, optional
+ Save ``A`` and ``A.H`` to speed up the computation of adjoint
+ (``True``) or create ``A.H`` on-the-fly (``False``)
+ Note that ``saveAt=True`` will double the amount of required memory.
+ Default is ``False``.
+ base_comm : :obj:`mpi4py.MPI.Comm`, optional
+ MPI Base Communicator. Defaults to ``mpi4py.MPI.COMM_WORLD``.
+ dtype : :obj:`str`, optional
+ Type of elements in input array.
+
+ Attributes
+ ----------
+ shape : :obj:`tuple`
+ Operator shape
+
+ Raises
+ ------
+ Exception
+        If the operator is created without a square number of MPI ranks.
+ ValueError
+ If input vector does not have the correct partition type.
+
+ Notes
+ -----
+ This operator performs a matrix-matrix multiplication, whose forward
+    operation can be described as :math:`\mathbf{Y} = \mathbf{A} \cdot \mathbf{X}` where:
+
+ - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
+ - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+ - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
+
+ whilst the adjoint operation is represented by
+ :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
+ :math:`\mathbf{A}^H` is the complex conjugate and transpose of :math:`\mathbf{A}`.
+
+    This implementation is based on a 1D block distribution of the operator
+    matrix and of the reshaped model and data vectors, which are replicated
+    across :math:`P` processes by a factor equivalent to :math:`\sqrt{P}`,
+    over a square process grid (:math:`\sqrt{P}\times\sqrt{P}`). More specifically:
+
+ - The matrix ``A`` is distributed across MPI processes in a block-row fashion
+ and each process holds a local block of ``A`` with shape
+ :math:`[N_{loc} \times K]`
+ - The operand matrix ``X`` is distributed in a block-column fashion and
+ each process holds a local block of ``X`` with shape
+ :math:`[K \times M_{loc}]`
+ - Communication is minimized by using a 2D process grid layout
+
+ **Forward Operation step-by-step**
+
+ 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
+ of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
+ is the number of columns assigned to the current process.
+
+    2. **Data Replication**: All processes within a layer (processes with the
+       same ``layer_id``) are expected to hold an identical copy of the same
+       operand columns, so every process in a layer has access to the operand
+       block it needs without further communication.
+
+    3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
+
+       - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
+       - ``X_local`` is the replicated operand block (shape ``K x M_local``)
+
+ 4. **Layer Gather**: Results from all processes in each layer are gathered
+ using ``allgather`` to reconstruct the full result matrix vertically.
+
+ **Adjoint Operation step-by-step**
+
+ The adjoint operation performs the conjugate transpose multiplication:
+
+ 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
+ representing the local columns of the input matrix.
+
+    2. **Local Adjoint Computation**: Each process computes
+       ``A_local.H @ X_tile``, where ``A_local.H`` is either:
+
+       - the pre-computed ``At`` (if ``saveAt=True``)
+       - computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
+
+       That is, each process multiplies its transposed local ``A`` block
+       ``A_local^H`` (shape ``K x N_block``) with the extracted ``X_tile``
+       (shape ``N_block x M_local``), producing a partial result of shape
+       ``(K, M_local)``. This is the local contribution of the columns of
+       ``A^H`` to the final result.
+
+    3. **Layer Reduction**: Since the full result :math:`\mathbf{A}^H \cdot \mathbf{X}`
+       is the sum of contributions from all column blocks of :math:`\mathbf{A}^H`,
+       processes in the same layer perform an ``allreduce`` sum to combine their
+       partial results. This gives the complete ``(K, M_local)`` result for their
+       assigned columns.
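+
+    Schematically, the core of the forward and adjoint passes on each rank can
+    be sketched as follows (variable names such as ``A_local``, ``X_local`` and
+    ``Y_tile`` are illustrative only; refer to the implementation of
+    ``_matvec`` and ``_rmatvec`` for details)::
+
+        # forward: stack the local products gathered within the layer
+        Y_local = np.vstack(layer_comm.allgather(A_local @ X_local))
+        # adjoint: sum the partial products of A^H computed within the layer
+        Xadj_local = layer_comm.allreduce(A_local.conj().T @ Y_tile, op=MPI.SUM)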
+
+ """
+ def __init__(
+ self,
+ A: NDArray,
+ M: int,
+ saveAt: bool = False,
+ base_comm: MPI.Comm = MPI.COMM_WORLD,
+ dtype: DTypeLike = "float64",
+ ) -> None:
+ rank = base_comm.Get_rank()
+ size = base_comm.Get_size()
+
+        # Determine grid dimensions (P_prime × C) such that P_prime * C == size
+ self._P_prime = math.isqrt(size)
+ self._C = self._P_prime
+ if self._P_prime * self._C != size:
+ raise Exception(f"Number of processes must be a square number, provided {size} instead...")
+
+ # Compute this process's group and layer indices
+ self._group_id = rank % self._P_prime
+ self._layer_id = rank // self._P_prime
+
+ # Split communicators by layer (rows) and by group (columns)
+ self.base_comm = base_comm
+ self._layer_comm = base_comm.Split(color=self._layer_id, key=self._group_id)
+ self._group_comm = base_comm.Split(color=self._group_id, key=self._layer_id)
+
+ self.A = A.astype(np.dtype(dtype))
+        if saveAt:
+            self.At = self.A.T.conj()
+
+ self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+ self.K = A.shape[1]
+ self.M = M
+
+ block_cols = int(math.ceil(self.M / self._P_prime))
+ blk_rows = int(math.ceil(self.N / self._P_prime))
+
+ self._row_start = self._group_id * blk_rows
+ self._row_end = min(self.N, self._row_start + blk_rows)
+
+ self._col_start = self._layer_id * block_cols
+ self._col_end = min(self.M, self._col_start + block_cols)
+
+ self._local_ncols = self._col_end - self._col_start
+ self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
+ total_ncols = np.sum(self._rank_col_lens)
+
+ self.dims = (self.K, total_ncols)
+ self.dimsd = (self.N, total_ncols)
+ shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
+ super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
+
+ def _matvec(self, x: DistributedArray) -> DistributedArray:
+ ncp = get_module(x.engine)
+ if x.partition != Partition.SCATTER:
+            raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead...")
+
+ y = DistributedArray(global_shape=(self.N * self.dimsd[1]),
+ local_shapes=[(self.N * c) for c in self._rank_col_lens],
+ mask=x.mask,
+ partition=Partition.SCATTER,
+ dtype=self.dtype)
+
+ my_own_cols = self._rank_col_lens[self.rank]
+ x_arr = x.local_array.reshape((self.dims[0], my_own_cols))
+ X_local = x_arr.astype(self.dtype)
+ Y_local = ncp.vstack(
+ self._layer_comm.allgather(
+ ncp.matmul(self.A, X_local)
+ )
+ )
+ y[:] = Y_local.flatten()
+ return y
+
+ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
+ ncp = get_module(x.engine)
+ if x.partition != Partition.SCATTER:
+ raise ValueError(f"x should have partition={Partition.SCATTER}. Got {x.partition} instead.")
+
+ y = DistributedArray(
+ global_shape=(self.K * self.dimsd[1]),
+ local_shapes=[self.K * c for c in self._rank_col_lens],
+ mask=x.mask,
+ partition=Partition.SCATTER,
+ dtype=self.dtype,
+ )
+
+ x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
+ X_tile = x_arr[self._row_start:self._row_end, :]
+ A_local = self.At if hasattr(self, "At") else self.A.T.conj()
+ Y_local = ncp.matmul(A_local, X_tile)
+ y_layer = self._layer_comm.allreduce(Y_local, op=MPI.SUM)
+ y[:] = y_layer.flatten()
+ return y
diff --git a/pylops_mpi/basicoperators/__init__.py b/pylops_mpi/basicoperators/__init__.py
index 566b59fb..8db5a988 100644
--- a/pylops_mpi/basicoperators/__init__.py
+++ b/pylops_mpi/basicoperators/__init__.py
@@ -7,6 +7,7 @@
functionalities using MPI.
A list of operators present in pylops_mpi.basicoperators:
+ MPIMatrixMult Matrix Multiplication operator
MPIBlockDiag Block Diagonal arrangement of PyLops operators
MPIStackedBlockDiag Block Diagonal arrangement of PyLops-MPI operators
MPIVStack Vertical Stacking of PyLops operators
@@ -19,6 +20,7 @@
"""
+from .MatrixMult import *
from .BlockDiag import *
from .VStack import *
from .HStack import *
@@ -28,6 +30,7 @@
from .Gradient import *
__all__ = [
+ "MPIMatrixMult",
"MPIBlockDiag",
"MPIStackedBlockDiag",
"MPIVStack",
diff --git a/tests/test_matrixmult.py b/tests/test_matrixmult.py
new file mode 100644
index 00000000..61f6ee7e
--- /dev/null
+++ b/tests/test_matrixmult.py
@@ -0,0 +1,135 @@
+import pytest
+import numpy as np
+from numpy.testing import assert_allclose
+from mpi4py import MPI
+import math
+import sys
+
+from pylops_mpi import DistributedArray, Partition
+from pylops_mpi.basicoperators.MatrixMult import MPIMatrixMult
+
+np.random.seed(42)
+
+comm = MPI.COMM_WORLD
+rank = comm.Get_rank()
+size = comm.Get_size()
+
+# Define test cases: (M, K, N, dtype_str)
+# M, K, N are matrix dimensions: A(N, K), X(K, M)
+# P_prime is sqrt(size); tests are skipped unless size is a perfect square.
+test_params = [
+ pytest.param(37, 37, 37, "float32", id="f32_37_37_37"),
+ pytest.param(50, 30, 40, "float64", id="f64_50_30_40"),
+ pytest.param(22, 20, 16, "complex64", id="c64_22_20_16"),
+ pytest.param( 3, 4, 5, "float32", id="f32_3_4_5"),
+ pytest.param( 1, 2, 1, "float64", id="f64_1_2_1",),
+ pytest.param( 2, 1, 3, "float32", id="f32_2_1_3",),
+]
+
+
+@pytest.mark.mpi(min_size=1) # SUMMA should also work for 1 process.
+@pytest.mark.parametrize("M, K, N, dtype_str", test_params)
+def test_SUMMAMatrixMult(N, K, M, dtype_str):
+ p_prime = math.isqrt(size)
+ C = p_prime
+ if p_prime * C != size:
+ pytest.skip(f"Number of processes must be a square number, provided {size} instead...")
+
+ dtype = np.dtype(dtype_str)
+
+ cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
+ base_float_dtype = np.float32 if dtype == np.complex64 else np.float64
+
+ my_group = rank % p_prime
+ my_layer = rank // p_prime
+
+ # Create sub-communicators
+ layer_comm = comm.Split(color=my_layer, key=my_group)
+ group_comm = comm.Split(color=my_group, key=my_layer)
+
+ # Calculate local matrix dimensions
+ blk_rows_A = int(math.ceil(N / p_prime))
+ row_start_A = my_group * blk_rows_A
+ row_end_A = min(N, row_start_A + blk_rows_A)
+
+ blk_cols_X = int(math.ceil(M / p_prime))
+ col_start_X = my_layer * blk_cols_X
+ col_end_X = min(M, col_start_X + blk_cols_X)
+ local_col_X_len = max(0, col_end_X - col_start_X)
+
+ A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K)
+ A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
+ A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
+
+ X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M)
+ X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
+ X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)
+
+ A_p = A_glob[row_start_A:row_end_A,:]
+ X_p = X_glob[:,col_start_X:col_end_X]
+
+ # Create MPIMatrixMult operator
+ Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str)
+
+    # Create DistributedArray for input x (representing X flattened)
+ all_local_col_len = comm.allgather(local_col_X_len)
+ total_cols = np.sum(all_local_col_len)
+
+ x_dist = DistributedArray(
+ global_shape=(K * total_cols),
+ local_shapes=[K * cl_b for cl_b in all_local_col_len],
+ partition=Partition.SCATTER,
+ base_comm=comm,
+ mask=[i % p_prime for i in range(size)],
+ dtype=dtype
+ )
+
+ x_dist.local_array[:] = X_p.ravel()
+
+ # Forward operation: y = A @ B (distributed)
+ y_dist = Aop @ x_dist
+ # Adjoint operation: xadj = A.H @ y (distributed)
+ xadj_dist = Aop.H @ y_dist
+
+ y = y_dist.asarray(masked=True)
+ col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)]
+ y_blocks = []
+ offset = 0
+ for cnt in col_counts:
+ block_size = N * cnt
+ y_blocks.append(
+ y[offset: offset + block_size].reshape(N, cnt)
+ )
+ offset += block_size
+ y = np.hstack(y_blocks)
+
+ xadj = xadj_dist.asarray(masked=True)
+ xadj_blocks = []
+ offset = 0
+ for cnt in col_counts:
+ block_size = K * cnt
+ xadj_blocks.append(
+ xadj[offset: offset + block_size].reshape(K, cnt)
+ )
+ offset += block_size
+ xadj = np.hstack(xadj_blocks)
+
+ if rank == 0:
+ y_loc = A_glob @ X_glob
+ assert_allclose(
+ y.squeeze(),
+ y_loc.squeeze(),
+ rtol=np.finfo(np.dtype(dtype)).resolution,
+ err_msg=f"Rank {rank}: Forward verification failed."
+ )
+
+ xadj_loc = A_glob.conj().T @ y_loc
+ assert_allclose(
+ xadj.squeeze(),
+ xadj_loc.squeeze(),
+ rtol=np.finfo(np.dtype(dtype)).resolution,
+            err_msg=f"Rank {rank}: Adjoint verification failed."
+ )
+
+ group_comm.Free()
+ layer_comm.Free()
\ No newline at end of file