
Commit ffb2b99

Fixed Notation
1 parent ed3b585 commit ffb2b99

File tree

3 files changed: +57 -58 lines changed


examples/plot_matrixmult.py

Lines changed: 11 additions & 12 deletions
@@ -52,9 +52,9 @@
 # We are now ready to create the input matrices :math:`\mathbf{A}` of size
 # :math:`N \times K` and :math:`\mathbf{X}` of size
 # :math:`K \times M`.
-M, K, N = 4, 4, 4
-A = np.random.rand(M * K).astype(dtype=np.float32).reshape(M, K)
-X = np.random.rand(K * N).astype(dtype=np.float32).reshape(K, N)
+N, K, M = 4, 4, 4
+A = np.random.rand(N * K).astype(dtype=np.float32).reshape(N, K)
+X = np.random.rand(K * M).astype(dtype=np.float32).reshape(K, M)
 
 ################################################################################
 # The processes are now arranged in a :math:`\sqrt{P} \times \sqrt{P}` grid,
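A minimal sketch of how each rank might locate itself in this :math:`\sqrt{P} \times \sqrt{P}` grid; the rank-to-coordinate mapping below (and the direct use of ``COMM_WORLD``) is an assumption for illustration, not necessarily the convention used by the example, though the names match the variables used further down.

# Hypothetical mapping of MPI ranks onto a p_prime x p_prime grid (assumed convention).
import math
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

p_prime = int(math.ceil(math.sqrt(size)))  # side length of the process grid
my_group = rank // p_prime                 # selects this rank's block of rows of A
my_layer = rank % p_prime                  # selects this rank's block of columns of X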
@@ -129,27 +129,26 @@
 # │ b31 b32 │ b33 b34 │
 # │ b41 b42 │ b43 b44 │
 # └─────────┴─────────┘
-#
 # </div>
 #
 
-blk_rows = int(math.ceil(M / p_prime))
-blk_cols = int(math.ceil(N / p_prime))
+blk_rows = int(math.ceil(N / p_prime))
+blk_cols = int(math.ceil(M / p_prime))
 
 rs = my_group * blk_rows
-re = min(M, rs + blk_rows)
+re = min(N, rs + blk_rows)
 my_own_rows = re - rs
 
 cs = my_layer * blk_cols
-ce = min(N, cs + blk_cols)
+ce = min(M, cs + blk_cols)
 my_own_cols = ce - cs
 
 A_p, X_p = A[rs:re, :].copy(), X[:, cs:ce].copy()
 
 ################################################################################
 # We are now ready to create the :py:class:`pylops_mpi.basicoperators.MPIMatrixMult`
 # operator and the input matrix :math:`\mathbf{X}`
-Aop = MPIMatrixMult(A_p, N, dtype="float32")
+Aop = MPIMatrixMult(A_p, M, dtype="float32")
 
 col_lens = comm.allgather(my_own_cols)
 total_cols = np.sum(col_lens)
@@ -183,13 +182,13 @@
 
 # Local benchmarks
 y = y.asarray(masked=True)
-col_counts = [min(blk_cols, N - j * blk_cols) for j in range(p_prime)]
+col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]
 y_blocks = []
 offset = 0
 for cnt in col_counts:
-    block_size = M * cnt
+    block_size = N * cnt
     y_blocks.append(
-        y[offset: offset + block_size].reshape(M, cnt)
+        y[offset: offset + block_size].reshape(N, cnt)
     )
     offset += block_size
 y = np.hstack(y_blocks)
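As a shape sanity check for the reassembly above, a minimal serial sketch (the N = K = M = 4, P = 4 setup matches the example; the column-block flattening order is an assumption about how the gathered array is laid out):

# Hypothetical serial check of the reassembly logic above (no MPI required).
import numpy as np

N, K, M, p_prime = 4, 4, 4, 2
blk_cols = int(np.ceil(M / p_prime))
col_counts = [min(blk_cols, M - j * blk_cols) for j in range(p_prime)]  # [2, 2]

A = np.arange(N * K, dtype=np.float32).reshape(N, K)
X = np.arange(K * M, dtype=np.float32).reshape(K, M)
Y = A @ X  # dense reference of shape (N, M)

# Flatten Y column block by column block (as the distributed result is laid out),
# then rebuild it with the same loop used in the example.
flat = np.concatenate(
    [Y[:, j * blk_cols:(j + 1) * blk_cols].ravel() for j in range(p_prime)]
)
y_blocks, offset = [], 0
for cnt in col_counts:
    y_blocks.append(flat[offset: offset + N * cnt].reshape(N, cnt))
    offset += N * cnt
assert np.allclose(np.hstack(y_blocks), Y)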

pylops_mpi/basicoperators/MatrixMult.py

Lines changed: 26 additions & 26 deletions
@@ -25,7 +25,7 @@ class MPIMatrixMult(MPILinearOperator):
     Local block of the matrix of shape :math:`[M_{loc} \times K]`
     where ``M_loc`` is the number of rows stored on this MPI rank and
     ``K`` is the global number of columns.
-N : :obj:`int`
+M : :obj:`int`
     Global leading dimension (i.e., number of columns) of the matrices
     representing the input model and data vectors.
 saveAt : :obj:`bool`, optional
@@ -55,9 +55,9 @@ class MPIMatrixMult(MPILinearOperator):
 This operator performs a matrix-matrix multiplication, whose forward
 operation can be described as :math:`Y = A \cdot X` where:
 
-- :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[M \times K]`
-- :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times N]`
-- :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[M \times N]`
+- :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
+- :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+- :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`
 
 whilst the adjoint operation is represented by
 :math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
@@ -70,16 +70,16 @@ class MPIMatrixMult(MPILinearOperator):
 
 - The matrix ``A`` is distributed across MPI processes in a block-row fashion
   and each process holds a local block of ``A`` with shape
-  :math:`[M_{loc} \times K]`
+  :math:`[N_{loc} \times K]`
 - The operand matrix ``X`` is distributed in a block-column fashion and
-  and each process holds a local block of ``X`` with shape
-  :math:`[K \times N_{loc}]`
+  each process holds a local block of ``X`` with shape
+  :math:`[K \times M_{loc}]`
 - Communication is minimized by using a 2D process grid layout
 
 **Forward Operation step-by-step**
 
 1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
-   of shape ``(K, N)``) is reshaped to ``(K, N_local)`` where ``N_local``
+   of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
    is the number of columns assigned to the current process.
 
 
@@ -88,8 +88,8 @@ class MPIMatrixMult(MPILinearOperator):
    the same operand columns.
 
 3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
-   - ``A_local`` is the local block of matrix ``A`` (shape ``M_local x K``)
-   - ``X_local`` is the broadcasted operand (shape ``K x N_local``)
+   - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
+   - ``X_local`` is the broadcasted operand (shape ``K x M_local``)
 
 4. **Layer Gather**: Results from all processes in each layer are gathered
    using ``allgather`` to reconstruct the full result matrix vertically.
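A minimal sketch of the forward pass described in steps 1-4, written with plain NumPy and mpi4py; the communicator name ``layer_comm``, its root choice, and the function signature are assumptions for illustration, not the operator's internals.

# Hypothetical sketch of the forward pass (steps 1-4); not the library's code.
# Assumes layer_comm is an mpi4py communicator grouping the processes that
# share the same layer_id (and hence the same columns of X).
import numpy as np

def summa_forward(A_local, x_local, K, M_local, layer_comm):
    # 1. Input preparation: reshape the flattened operand to (K, M_local)
    X_local = x_local.reshape(K, M_local)
    # 2. Data broadcasting: every rank in the layer sees the same columns
    X_local = layer_comm.bcast(X_local, root=0)
    # 3. Local computation: (N_local, K) @ (K, M_local) -> (N_local, M_local)
    Y_local = A_local @ X_local
    # 4. Layer gather: stack the row blocks to rebuild the full (N, M_local)
    return np.vstack(layer_comm.allgather(Y_local))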
@@ -98,7 +98,7 @@ class MPIMatrixMult(MPILinearOperator):
 
 The adjoint operation performs the conjugate transpose multiplication:
 
-1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``
+1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
    representing the local columns of the input matrix.
 
 2. **Local Adjoint Computation**:
@@ -107,21 +107,21 @@ class MPIMatrixMult(MPILinearOperator):
    - Pre-computed ``At`` (if ``saveAt=True``)
    - Computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
    Each process multiplies its transposed local ``A`` block ``A_local^H``
-   (shape ``K x M_block``)
-   with the extracted ``X_tile`` (shape ``M_block x N_local``),
-   producing a partial result of shape ``(K, N_local)``.
+   (shape ``K x N_block``)
+   with the extracted ``X_tile`` (shape ``N_block x M_local``),
+   producing a partial result of shape ``(K, M_local)``.
    This computes the local contribution of columns of ``A^H`` to the final result.
 
 3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the
    sum of contributions from all column blocks of ``A^H``, processes in the
    same layer perform an ``allreduce`` sum to combine their partial results.
-   This gives the complete ``(K, N_local)`` result for their assigned columns.
+   This gives the complete ``(K, M_local)`` result for their assigned columns.
 
     """
     def __init__(
         self,
         A: NDArray,
-        N: int,
+        M: int,
         saveAt: bool = False,
         base_comm: MPI.Comm = MPI.COMM_WORLD,
         dtype: DTypeLike = "float64",
@@ -147,25 +147,25 @@ def __init__(
         self.A = A.astype(np.dtype(dtype))
         if saveAt: self.At = A.T.conj()
 
-        self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+        self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
         self.K = A.shape[1]
-        self.N = N
+        self.M = M
 
-        block_cols = int(math.ceil(self.N / self._P_prime))
-        blk_rows = int(math.ceil(self.M / self._P_prime))
+        block_cols = int(math.ceil(self.M / self._P_prime))
+        blk_rows = int(math.ceil(self.N / self._P_prime))
 
         self._row_start = self._group_id * blk_rows
-        self._row_end = min(self.M, self._row_start + blk_rows)
+        self._row_end = min(self.N, self._row_start + blk_rows)
 
         self._col_start = self._layer_id * block_cols
-        self._col_end = min(self.N, self._col_start + block_cols)
+        self._col_end = min(self.M, self._col_start + block_cols)
 
         self._local_ncols = self._col_end - self._col_start
         self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
         total_ncols = np.sum(self._rank_col_lens)
 
         self.dims = (self.K, total_ncols)
-        self.dimsd = (self.M, total_ncols)
+        self.dimsd = (self.N, total_ncols)
         shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
         super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)
 
@@ -174,8 +174,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
         if x.partition != Partition.SCATTER:
             raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")
 
-        y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
-                             local_shapes=[(self.M * c) for c in self._rank_col_lens],
+        y = DistributedArray(global_shape=(self.N * self.dimsd[1]),
+                             local_shapes=[(self.N * c) for c in self._rank_col_lens],
                              mask=x.mask,
                              partition=Partition.SCATTER,
                              dtype=self.dtype)
@@ -204,7 +204,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
             dtype=self.dtype,
         )
 
-        x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
+        x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
         X_tile = x_arr[self._row_start:self._row_end, :]
         A_local = self.At if hasattr(self, "At") else self.A.T.conj()
         Y_local = ncp.matmul(A_local, X_tile)
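A minimal sketch of the adjoint pass described above (reshape, local product with the conjugate-transposed block on the owned row tile, layer-wise reduction); the communicator name ``layer_comm``, the plain-NumPy types, and the function signature are assumptions for illustration, not the operator's internals.

# Hypothetical sketch of the adjoint pass (steps 1-3); not the library's code.
# Assumes layer_comm groups processes that hold different row blocks of A but
# the same columns of the data, and row_start/row_end bound this rank's block.
import numpy as np
from mpi4py import MPI

def summa_adjoint(A_local, x_local, N, M_local, row_start, row_end, layer_comm):
    # 1. Input reshaping: flattened local data becomes (N, M_local)
    x_arr = x_local.reshape(N, M_local)
    # 2. Local adjoint computation on this rank's row tile:
    #    (K, N_block) @ (N_block, M_local) -> (K, M_local)
    X_tile = x_arr[row_start:row_end, :]
    partial = A_local.conj().T @ X_tile
    # 3. Layer reduction: sum the contributions of all row blocks
    return layer_comm.allreduce(partial, op=MPI.SUM)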

tests/test_matrixmult.py

Lines changed: 20 additions & 20 deletions
@@ -14,22 +14,22 @@
 rank = comm.Get_rank()
 size = comm.Get_size()
 
-# Define test cases: (M, K, N, dtype_str)
-# M, K, N are matrix dimensions A(M,K), B(K,N)
+# Define test cases: (N, K, M, dtype_str)
+# N, K, M are matrix dimensions A(N,K), B(K,M)
 # P_prime will be ceil(sqrt(size)).
 test_params = [
     pytest.param(37, 37, 37, "float32", id="f32_37_37_37"),
-    pytest.param(40, 30, 50, "float64", id="f64_40_30_50"),
-    pytest.param(16, 20, 22, "complex64", id="c64_16_20_22"),
-    pytest.param(5, 4, 3, "float32", id="f32_5_4_3"),
-    pytest.param(1, 2, 1, "float64", id="f64_1_2_1",),
-    pytest.param(3, 1, 2, "float32", id="f32_3_1_2",),
+    pytest.param(50, 30, 40, "float64", id="f64_40_30_50"),
+    pytest.param(22, 20, 16, "complex64", id="c64_16_20_22"),
+    pytest.param( 3, 4, 5, "float32", id="f32_5_4_3"),
+    pytest.param( 1, 2, 1, "float64", id="f64_1_2_1",),
+    pytest.param( 2, 1, 3, "float32", id="f32_3_1_2",),
 ]
 
 
 @pytest.mark.mpi(min_size=1) # SUMMA should also work for 1 process.
 @pytest.mark.parametrize("M, K, N, dtype_str", test_params)
-def test_SUMMAMatrixMult(M, K, N, dtype_str):
+def test_SUMMAMatrixMult(N, K, M, dtype_str):
     dtype = np.dtype(dtype_str)
 
     cmplx = 1j if np.issubdtype(dtype, np.complexfloating) else 0
@@ -47,28 +47,28 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
     group_comm = comm.Split(color=my_group, key=my_layer)
 
     # Calculate local matrix dimensions
-    blk_rows_A = int(math.ceil(M / p_prime))
+    blk_rows_A = int(math.ceil(N / p_prime))
     row_start_A = my_group * blk_rows_A
-    row_end_A = min(M, row_start_A + blk_rows_A)
+    row_end_A = min(N, row_start_A + blk_rows_A)
 
-    blk_cols_X = int(math.ceil(N / p_prime))
+    blk_cols_X = int(math.ceil(M / p_prime))
     col_start_X = my_layer * blk_cols_X
-    col_end_X = min(N, col_start_X + blk_cols_X)
+    col_end_X = min(M, col_start_X + blk_cols_X)
     local_col_X_len = max(0, col_end_X - col_start_X)
 
-    A_glob_real = np.arange(M * K, dtype=base_float_dtype).reshape(M, K)
-    A_glob_imag = np.arange(M * K, dtype=base_float_dtype).reshape(M, K) * 0.5
+    A_glob_real = np.arange(N * K, dtype=base_float_dtype).reshape(N, K)
+    A_glob_imag = np.arange(N * K, dtype=base_float_dtype).reshape(N, K) * 0.5
     A_glob = (A_glob_real + cmplx * A_glob_imag).astype(dtype)
 
-    X_glob_real = np.arange(K * N, dtype=base_float_dtype).reshape(K, N)
-    X_glob_imag = np.arange(K * N, dtype=base_float_dtype).reshape(K, N) * 0.7
+    X_glob_real = np.arange(K * M, dtype=base_float_dtype).reshape(K, M)
+    X_glob_imag = np.arange(K * M, dtype=base_float_dtype).reshape(K, M) * 0.7
     X_glob = (X_glob_real + cmplx * X_glob_imag).astype(dtype)
 
     A_p = A_glob[row_start_A:row_end_A,:]
     X_p = X_glob[:,col_start_X:col_end_X]
 
     # Create MPIMatrixMult operator
-    Aop = MPIMatrixMult(A_p, N, base_comm=comm, dtype=dtype_str)
+    Aop = MPIMatrixMult(A_p, M, base_comm=comm, dtype=dtype_str)
 
     # Create DistributedArray for input x (representing B flattened)
     all_local_col_len = comm.allgather(local_col_X_len)
@@ -91,13 +91,13 @@ def test_SUMMAMatrixMult(M, K, N, dtype_str):
     xadj_dist = Aop.H @ y_dist
 
     y = y_dist.asarray(masked=True)
-    col_counts = [min(blk_cols_X, N - j * blk_cols_X) for j in range(p_prime)]
+    col_counts = [min(blk_cols_X, M - j * blk_cols_X) for j in range(p_prime)]
     y_blocks = []
     offset = 0
     for cnt in col_counts:
-        block_size = M * cnt
+        block_size = N * cnt
         y_blocks.append(
-            y[offset: offset + block_size].reshape(M, cnt)
+            y[offset: offset + block_size].reshape(N, cnt)
        )
         offset += block_size
     y = np.hstack(y_blocks)
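The reassembled y can then be compared with a dense reference computed from the global matrices already built in the test; the exact assertions and tolerances used by the test are assumptions here, shown only as a sketch of the check.

    # Hypothetical verification sketch; not necessarily the test's exact assertions.
    y_expected = A_glob @ X_glob                          # dense (N, M) forward reference
    np.testing.assert_allclose(y, y_expected, rtol=1e-4)

    # The adjoint result gathered from xadj_dist could be checked the same way,
    # after reassembling its (K, cnt) column blocks, against the dense (K, M)
    # reference A_glob.conj().T @ y_expected.
    xadj = xadj_dist.asarray(masked=True)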
