
Implementation of MPI SUMMA matrix mul #160

Open · wants to merge 29 commits into main from actual-SUMMA

Commits (29)
75c815b
inital-example
astroC86 Jun 29, 2025
908d4b9
Merge remote-tracking branch 'origin/main' into actual-SUMMA
astroC86 Jun 29, 2025
069e5dd
working simple example
astroC86 Jun 29, 2025
5fcbad3
untransformed C
astroC86 Jun 29, 2025
d8d9463
Initial impl of SUMMA matmul
astroC86 Jun 29, 2025
a9e679e
matmul with padding
astroC86 Jun 29, 2025
8142d44
impl adjoint
astroC86 Jul 13, 2025
7363431
cleanedup adjoint impl
astroC86 Jul 13, 2025
4a94ac6
merge and new example
astroC86 Jul 13, 2025
dc00226
added handling for padding
astroC86 Jul 13, 2025
58d3ceb
Cleanup
astroC86 Jul 14, 2025
1ef09ab
converted Bcast into bcast
astroC86 Jul 23, 2025
56e9414
Added docstring
astroC86 Jul 23, 2025
f3d1918
removed block distribute function
astroC86 Jul 23, 2025
9d36b0c
removed unnecessary check on local matrix A
astroC86 Jul 23, 2025
66e3296
Added Generic MatMulOp with docstring
astroC86 Jul 23, 2025
fa07ae8
Merge branch 'PyLops:main' into actual-SUMMA
astroC86 Jul 23, 2025
0956e7b
Converted it to a function
astroC86 Jul 24, 2025
b053f5b
Added SUMMA tests and fixed dtype problem
astroC86 Jul 26, 2025
8851e05
Added documentation and example explination
astroC86 Jul 27, 2025
531873f
Fixed block_gather fn
astroC86 Jul 27, 2025
75900a7
Fixed np to ncp in forward and backward
astroC86 Jul 27, 2025
0c5cb7e
consistancy
astroC86 Jul 27, 2025
302bd4b
Merge branch 'main' into actual-SUMMA
astroC86 Jul 29, 2025
f9c0ca5
minor: clean-up of docstrings and code
mrava87 Jul 31, 2025
3738a06
minor: removed empty first line in MatrixMult
mrava87 Jul 31, 2025
9c04583
feat: ensure y arrays are created on same engine as x
mrava87 Jul 31, 2025
186af3a
fix: Fixed failing test
astroC86 Jul 31, 2025
25f30bc
minor: minor docstring edits
astroC86 Jul 31, 2025
9 changes: 9 additions & 0 deletions docs/source/api/index.rst
@@ -118,6 +118,15 @@ Utils
local_split


.. currentmodule:: pylops_mpi.basicoperators.MatrixMult

.. autosummary::
:toctree: generated/

block_gather
local_block_split
active_grid_comm

.. currentmodule:: pylops_mpi.utils

.. autosummary::
27 changes: 13 additions & 14 deletions examples/plot_matrixmult.py
@@ -1,13 +1,13 @@
r"""
Distributed Matrix Multiplication
=================================
Distributed Matrix Multiplication - Block-row-column decomposition
==================================================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
operator to perform matrix-matrix multiplication between a matrix :math:`\mathbf{A}`
blocked over rows (i.e., blocks of rows are stored over different ranks) and a
matrix :math:`\mathbf{X}` blocked over columns (i.e., blocks of columns are
stored over different ranks), with equal number of row and column blocks.
Similarly, the adjoint operation can be peformed with a matrix :math:`\mathbf{Y}`
blocked in the same fashion of matrix :math:`\mathbf{X}`.
operator with ``kind='blocked'`` to perform matrix-matrix multiplication between
a matrix :math:`\mathbf{A}` blocked over rows (i.e., blocks of rows are stored
over different ranks) and a matrix :math:`\mathbf{X}` blocked over columns
(i.e., blocks of columns are stored over different ranks), with equal number
of row and column blocks. Similarly, the adjoint operation can be performed with
a matrix :math:`\mathbf{Y}` blocked in the same fashion as matrix :math:`\mathbf{X}`.

Note that whilst the different blocks of the matrix :math:`\mathbf{A}` are directly
stored in the operator on different ranks, the matrix :math:`\mathbf{X}` is
@@ -19,15 +19,16 @@

"""

from matplotlib import pyplot as plt
import math
import numpy as np
from mpi4py import MPI
from matplotlib import pyplot as plt

import pylops

import pylops_mpi
from pylops_mpi import Partition
from pylops_mpi.basicoperators.MatrixMult import active_grid_comm, MPIMatrixMult

plt.close("all")

@@ -39,8 +40,7 @@

###############################################################################
# We are now ready to create the input matrices :math:`\mathbf{A}` of size
# :math:`M \times k` :math:`\mathbf{A}` of size and :math:`\mathbf{A}` of size
# :math:`K \times N`.
# :math:`N \times K` and :math:`\mathbf{X}` of size :math:`K \times M`.
N, K, M = 4, 4, 4
A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K)
X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)
@@ -88,8 +88,7 @@
# than the row or column ranks.

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = \
pylops_mpi.MPIMatrixMult.active_grid_comm(base_comm, N, M)
comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")
if not is_active: exit(0)

@@ -147,7 +146,7 @@
################################################################################
# We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`
Aop = pylops_mpi.MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32")
Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype="float32", kind="block")

col_lens = comm.allgather(my_own_cols)
total_cols = np.sum(col_lens)
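
A minimal serial NumPy sketch of the block-row/block-column decomposition used in the example above (matrix sizes and block count are chosen purely for illustration; the MPI operator handles the actual distribution of blocks and the associated communication). The point it illustrates: the (i, j) block of Y = A X only requires row-block i of A and column-block j of X.

import numpy as np

N, K, M = 4, 4, 4
nblocks = 2  # illustrative number of row/column blocks (one per rank in the MPI case)

A = np.random.rand(N, K).astype(np.float32)
X = np.random.rand(K, M).astype(np.float32)

Y = np.zeros((N, M), dtype=np.float32)
rb, cb = N // nblocks, M // nblocks  # rows of A / columns of X per block
for i in range(nblocks):
    for j in range(nblocks):
        # the (i, j) block of Y needs only row-block i of A and column-block j of X
        Y[i * rb:(i + 1) * rb, j * cb:(j + 1) * cb] = (
            A[i * rb:(i + 1) * rb, :] @ X[:, j * cb:(j + 1) * cb]
        )

assert np.allclose(Y, A @ X)
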
159 changes: 159 additions & 0 deletions examples/plot_summamatrixmult.py
@@ -0,0 +1,159 @@
r"""
Distributed Matrix Multiplication - SUMMA
=========================================
This example shows how to use the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
operator with ``kind='summa'`` to perform matrix-matrix multiplication between
a matrix :math:`\mathbf{A}` distributed in 2D blocks across a square process
grid and matrices :math:`\mathbf{X}` and :math:`\mathbf{Y}` distributed in 2D
blocks across the same grid. Similarly, the adjoint operation can be performed
with a matrix :math:`\mathbf{Y}` distributed in the same fashion as matrix
:math:`\mathbf{X}`.

Note that whilst the different blocks of matrix :math:`\mathbf{A}` are directly
stored in the operator on different ranks, the matrices :math:`\mathbf{X}` and
:math:`\mathbf{Y}` are effectively represented by 1-D :py:class:`pylops_mpi.DistributedArray`
objects where the different blocks are flattened and stored on different ranks.
Note that to optimize communications, the ranks are organized in a square grid and
blocks of :math:`\mathbf{A}` and :math:`\mathbf{X}` are systematically broadcast
across different ranks during computation - see below for details.
"""

import math
import numpy as np
from mpi4py import MPI
from matplotlib import pyplot as plt

import pylops_mpi
from pylops import Conj
from pylops_mpi.basicoperators.MatrixMult import \
local_block_split, MPIMatrixMult, active_grid_comm

plt.close("all")

###############################################################################
# We set the seed such that all processes can create the input matrices filled
# with the same random number. In practical applications, such matrices will be
# filled with data that is appropriate to the use-case.
np.random.seed(42)

###############################################################################
# We are now ready to create the input matrices for our distributed matrix
# multiplication example. We need to set up:
#
# - Matrix :math:`\mathbf{A}` of size :math:`N \times K` (the left operand)
# - Matrix :math:`\mathbf{X}` of size :math:`K \times M` (the right operand)
# - The result will be :math:`\mathbf{Y} = \mathbf{A} \mathbf{X}` of size
# :math:`N \times M`
#
# We create here global test matrices with sequential values for easy verification:
#
# - Matrix A: Each element :math:`A_{i,j} = i \cdot K + j` (row-major ordering)
# - Matrix X: Each element :math:`X_{i,j} = i \cdot M + j`

N, M, K = 6, 6, 6
A_shape, x_shape, y_shape = (N, K), (K, M), (N, M)

A_data = np.arange(int(A_shape[0] * A_shape[1])).reshape(A_shape)
x_data = np.arange(int(x_shape[0] * x_shape[1])).reshape(x_shape)

################################################################################
# For distributed computation, we arrange processes in a square grid of size
# :math:`P' \times P'` where :math:`P' = \sqrt{P}` and :math:`P` is the total
# number of MPI processes. Each process will own a block of each matrix
# according to this 2D grid layout.

base_comm = MPI.COMM_WORLD
comm, rank, row_id, col_id, is_active = active_grid_comm(base_comm, N, M)
print(f"Process {base_comm.Get_rank()} is {'active' if is_active else 'inactive'}")

p_prime = math.isqrt(comm.Get_size())
print(f"Process grid: {p_prime} x {p_prime} = {comm.Get_size()} processes")

if rank == 0:
print(f"Global matrix A shape: {A_shape} (N={A_shape[0]}, K={A_shape[1]})")
print(f"Global matrix X shape: {x_shape} (K={x_shape[0]}, M={x_shape[1]})")
print(f"Expected Global result Y shape: ({A_shape[0]}, {x_shape[1]}) = (N, M)")

################################################################################
# Next we must determine which block of each matrix each process should own.
#
# The 2D block distribution requires:
#
# - Process at grid position :math:`(i,j)` gets block
# :math:`\mathbf{A}[i_{start}:i_{end}, j_{start}:j_{end}]`
# - Block sizes are approximately :math:`\lceil N/P' \rceil \times \lceil K/P' \rceil`
# with edge processes handling remainder
#
# .. raw:: html
#
# <div style="text-align: left; font-family: monospace; white-space: pre;">
# <b>Example: 2x2 Process Grid with 6x6 Matrices</b>
#
# Matrix A (6x6): Matrix X (6x6):
# ┌───────────┬───────────┐ ┌───────────┬───────────┐
# │ 0 1 2 │ 3 4 5 │ │ 0 1 2 │ 3 4 5 │
# │ 6 7 8 │ 9 10 11 │ │ 6 7 8 │ 9 10 11 │
# │ 12 13 14 │ 15 16 17 │ │ 12 13 14 │ 15 16 17 │
# ├───────────┼───────────┤ ├───────────┼───────────┤
# │ 18 19 20 │ 21 22 23 │ │ 18 19 20 │ 21 22 23 │
# │ 24 25 26 │ 27 28 29 │ │ 24 25 26 │ 27 28 29 │
# │ 30 31 32 │ 33 34 35 │ │ 30 31 32 │ 33 34 35 │
# └───────────┴───────────┘ └───────────┴───────────┘
#
# Process (0,0): A[0:3, 0:3], X[0:3, 0:3]
# Process (0,1): A[0:3, 3:6], X[0:3, 3:6]
# Process (1,0): A[3:6, 0:3], X[3:6, 0:3]
# Process (1,1): A[3:6, 3:6], X[3:6, 3:6]
# </div>
#

A_slice = local_block_split(A_shape, rank, comm)
x_slice = local_block_split(x_shape, rank, comm)

################################################################################
# Extract the local portion of each matrix for this process
A_local = A_data[A_slice]
x_local = x_data[x_slice]

print(f"Process {rank}: A_local shape {A_local.shape}, X_local shape {x_local.shape}")
print(f"Process {rank}: A slice {A_slice}, X slice {x_slice}")

################################################################################

################################################################################
# We are now ready to create the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator and the input matrix :math:`\mathbf{X}`

Aop = MPIMatrixMult(A_local, M, base_comm=comm, kind="summa", dtype=A_local.dtype)

x_dist = pylops_mpi.DistributedArray(
global_shape=(K * M),
local_shapes=comm.allgather(x_local.shape[0] * x_local.shape[1]),
base_comm=comm,
partition=pylops_mpi.Partition.SCATTER,
dtype=x_local.dtype)
x_dist[:] = x_local.flatten()

################################################################################
# We can now apply the forward pass :math:`\mathbf{y} = \mathbf{Ax}` (which
# effectively implements a distributed matrix-matrix multiplication
# :math:`Y = \mathbf{AX}`). Note :math:`\mathbf{Y}` is distributed in the same
# way as the input :math:`\mathbf{X}` in a block-block fashion.
y_dist = Aop @ x_dist

###############################################################################
# Next we apply the adjoint pass :math:`\mathbf{x}_{adj} = \mathbf{A}^H \mathbf{y}`
# (which effectively implements a distributed SUMMA matrix-matrix multiplication
# :math:`\mathbf{X}_{adj} = \mathbf{A}^H \mathbf{Y}`). Note that
# :math:`\mathbf{X}_{adj}` is again distributed in the same way as the input
# :math:`\mathbf{X}` in a block-block fashion.
xadj_dist = Aop.H @ y_dist

###############################################################################
# Finally, we show that the SUMMA :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
# operator can be combined with any other PyLops-MPI operator. We are going to
# apply here a conjugate operator to the output of the matrix multiplication.
Dop = Conj(dims=(A_local.shape[0], x_local.shape[1]))
DBop = pylops_mpi.MPIBlockDiag(ops=[Dop, ])
Op = DBop @ Aop
y1 = Op @ x_dist
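
For reference, a minimal serial emulation of the SUMMA update that ``kind='summa'`` carries out in a distributed fashion: at step k, the blocks of A in block-column k are broadcast along their process-grid rows, the blocks of X in block-row k along their process-grid columns, and every (i, j) block of Y accumulates their product. The grid size and even block sizes below are assumptions for illustration; the operator itself handles padding for uneven blocks.

import numpy as np

N, K, M = 6, 6, 6
p_prime = 2  # illustrative 2 x 2 process grid
nb, kb, mb = N // p_prime, K // p_prime, M // p_prime  # even block sizes assumed

A = np.arange(N * K, dtype=np.float64).reshape(N, K)
X = np.arange(K * M, dtype=np.float64).reshape(K, M)
Y = np.zeros((N, M))

for k in range(p_prime):  # SUMMA step k
    # In the MPI version, A[:, k]-blocks are broadcast along process-grid rows
    # and X[k, :]-blocks along process-grid columns at this step.
    for i in range(p_prime):      # grid row owning block-row i of Y
        for j in range(p_prime):  # grid column owning block-column j of Y
            Y[i * nb:(i + 1) * nb, j * mb:(j + 1) * mb] += (
                A[i * nb:(i + 1) * nb, k * kb:(k + 1) * kb]
                @ X[k * kb:(k + 1) * kb, j * mb:(j + 1) * mb]
            )

assert np.allclose(Y, A @ X)
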
1 change: 0 additions & 1 deletion pylops_mpi/LinearOperator.py
@@ -76,7 +76,6 @@ def matvec(self, x: DistributedArray) -> DistributedArray:

"""
M, N = self.shape

if x.global_shape != (N,):
raise ValueError("dimension mismatch")
