@@ -25,7 +25,7 @@ class MPIMatrixMult(MPILinearOperator):
Local block of the matrix of shape :math:`[M_{loc} \times K]`
where ``M_loc`` is the number of rows stored on this MPI rank and
``K`` is the global number of columns.
- N : :obj:`int`
+ M : :obj:`int`
Global leading dimension (i.e., number of columns) of the matrices
representing the input model and data vectors.
saveAt : :obj:`bool`, optional
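To make the expected per-rank inputs concrete, here is a small hypothetical setup sketch: each rank builds its own ``[N_loc x K]`` row block and passes the shared global column count ``M``. The commented constructor call only mirrors the signature shown in this diff; the process-grid requirements of the class are not reproduced here.

```python
import numpy as np
from mpi4py import MPI

# Hypothetical per-rank setup matching the parameter description: each rank
# owns a row block of A with K columns; M is the global number of columns of X/Y.
comm = MPI.COMM_WORLD
K, M = 5, 3
N_loc = 4                                        # rows owned by this rank (may vary per rank)
A_local = np.ones((N_loc, K)) * comm.Get_rank()  # this rank's [N_loc x K] block

# op = MPIMatrixMult(A_local, M)                 # construction sketch only; the class
#                                                # may impose process-grid constraints
```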
@@ -55,9 +55,9 @@ class MPIMatrixMult(MPILinearOperator):
This operator performs a matrix-matrix multiplication, whose forward
operation can be described as :math:`Y = A \cdot X` where:

- - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[M \times K]`
- - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times N]`
- - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[M \times N]`
+ - :math:`\mathbf{A}` is the distributed matrix operator of shape :math:`[N \times K]`
+ - :math:`\mathbf{X}` is the distributed operand matrix of shape :math:`[K \times M]`
+ - :math:`\mathbf{Y}` is the resulting distributed matrix of shape :math:`[N \times M]`

whilst the adjoint operation is represented by
:math:`\mathbf{X}_{adj} = \mathbf{A}^H \cdot \mathbf{Y}` where
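As a quick single-process sanity check of the shape convention introduced by this change (no MPI involved), the forward and adjoint products line up as follows:

```python
import numpy as np

# Single-process sketch of the shape convention used by the operator:
# A is (N x K), X is (K x M), so Y = A @ X is (N x M) and X_adj = A^H @ Y is (K x M).
N, K, M = 6, 4, 3
A = np.random.randn(N, K) + 1j * np.random.randn(N, K)
X = np.random.randn(K, M)

Y = A @ X               # forward: (N, K) @ (K, M) -> (N, M)
X_adj = A.conj().T @ Y  # adjoint: (K, N) @ (N, M) -> (K, M)

assert Y.shape == (N, M) and X_adj.shape == (K, M)
```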
@@ -70,16 +70,16 @@ class MPIMatrixMult(MPILinearOperator):

- The matrix ``A`` is distributed across MPI processes in a block-row fashion
and each process holds a local block of ``A`` with shape
- :math:`[M_{loc} \times K]`
+ :math:`[N_{loc} \times K]`
- The operand matrix ``X`` is distributed in a block-column fashion and
- and each process holds a local block of ``X`` with shape
- :math:`[K \times N_{loc}]`
+ each process holds a local block of ``X`` with shape
+ :math:`[K \times M_{loc}]`
- Communication is minimized by using a 2D process grid layout

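For orientation, here is a sketch of how ranks could be arranged on a square process grid with per-layer and per-group communicators. The exact construction of ``_P_prime``, ``_layer_id``, ``_group_id`` and the sub-communicators is not part of this hunk, so the layout below is only an assumption for illustration.

```python
import math
from mpi4py import MPI

comm = MPI.COMM_WORLD
rank, size = comm.Get_rank(), comm.Get_size()

# Illustrative square grid; the class' actual grid construction may differ.
P_prime = math.isqrt(size)      # assume the number of ranks is a perfect square
assert P_prime * P_prime == size

layer_id = rank // P_prime      # which block of X columns this rank works on
group_id = rank % P_prime       # which block of A rows this rank holds

# One communicator per layer (same layer_id) and one per group (same group_id)
layer_comm = comm.Split(color=layer_id, key=group_id)
group_comm = comm.Split(color=group_id, key=layer_id)
```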
**Forward Operation step-by-step**

1. **Input Preparation**: The input vector ``x`` (flattened from matrix ``X``
- of shape ``(K, N)``) is reshaped to ``(K, N_local)`` where ``N_local``
+ of shape ``(K, M)``) is reshaped to ``(K, M_local)`` where ``M_local``
is the number of columns assigned to the current process.

2. **Data Broadcasting**: Within each layer (processes with same ``layer_id``),
@@ -88,8 +88,8 @@ class MPIMatrixMult(MPILinearOperator):
the same operand columns.

3. **Local Computation**: Each process computes ``A_local @ X_local`` where:
- - ``A_local`` is the local block of matrix ``A`` (shape ``M_local x K``)
- - ``X_local`` is the broadcasted operand (shape ``K x N_local``)
+ - ``A_local`` is the local block of matrix ``A`` (shape ``N_local x K``)
+ - ``X_local`` is the broadcasted operand (shape ``K x M_local``)

4. **Layer Gather**: Results from all processes in each layer are gathered
using ``allgather`` to reconstruct the full result matrix vertically.
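Steps 1-4 can be mimicked outside the class with plain ``mpi4py``/NumPy. ``layer_comm``, ``N_local`` and ``M_local`` below are stand-ins for the operator's internal communicator and block sizes, not its actual attributes.

```python
import numpy as np
from mpi4py import MPI

# Standalone sketch of the forward steps within a single layer of the grid.
layer_comm = MPI.COMM_WORLD                   # stand-in for the per-layer communicator
K, N_local, M_local = 4, 3, 2
rng = np.random.default_rng(layer_comm.Get_rank())

A_local = rng.standard_normal((N_local, K))   # this rank's row block of A

# 2. Data broadcasting: all ranks in the layer see the same operand columns
X_local = rng.standard_normal((K, M_local)) if layer_comm.Get_rank() == 0 else None
X_local = layer_comm.bcast(X_local, root=0)

# 3. Local computation: this rank's rows of the result
Y_partial = A_local @ X_local                 # (N_local, M_local)

# 4. Layer gather: stack row blocks vertically to rebuild the full (N, M_local) result
Y_local = np.vstack(layer_comm.allgather(Y_partial))
```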
@@ -98,7 +98,7 @@ class MPIMatrixMult(MPILinearOperator):

The adjoint operation performs the conjugate transpose multiplication:

- 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(M, N_local)``
+ 1. **Input Reshaping**: The input vector ``x`` is reshaped to ``(N, M_local)``
representing the local columns of the input matrix.

2. **Local Adjoint Computation**:
@@ -107,21 +107,21 @@ class MPIMatrixMult(MPILinearOperator):
- Pre-computed ``At`` (if ``saveAt=True``)
- Computed on-the-fly as ``A.T.conj()`` (if ``saveAt=False``)
Each process multiplies its transposed local ``A`` block ``A_local^H``
- (shape ``K x M_block``)
- with the extracted ``X_tile`` (shape ``M_block x N_local``),
- producing a partial result of shape ``(K, N_local)``.
+ (shape ``K x N_block``)
+ with the extracted ``X_tile`` (shape ``N_block x M_local``),
+ producing a partial result of shape ``(K, M_local)``.
This computes the local contribution of columns of ``A^H`` to the final result.

3. **Layer Reduction**: Since the full result ``Y = A^H \cdot X`` is the
sum of contributions from all column blocks of ``A^H``, processes in the
same layer perform an ``allreduce`` sum to combine their partial results.
- This gives the complete ``(K, N_local)`` result for their assigned columns.
+ This gives the complete ``(K, M_local)`` result for their assigned columns.

"""
def __init__(
self,
A: NDArray,
- N: int,
+ M: int,
saveAt: bool = False,
base_comm: MPI.Comm = MPI.COMM_WORLD,
dtype: DTypeLike = "float64",
@@ -147,25 +147,25 @@ def __init__(
self.A = A.astype(np.dtype(dtype))
if saveAt: self.At = A.T.conj()

- self.M = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
+ self.N = self._layer_comm.allreduce(self.A.shape[0], op=MPI.SUM)
self.K = A.shape[1]
- self.N = N
+ self.M = M

- block_cols = int(math.ceil(self.N / self._P_prime))
- blk_rows = int(math.ceil(self.M / self._P_prime))
+ block_cols = int(math.ceil(self.M / self._P_prime))
+ blk_rows = int(math.ceil(self.N / self._P_prime))

self._row_start = self._group_id * blk_rows
- self._row_end = min(self.M, self._row_start + blk_rows)
+ self._row_end = min(self.N, self._row_start + blk_rows)

self._col_start = self._layer_id * block_cols
- self._col_end = min(self.N, self._col_start + block_cols)
+ self._col_end = min(self.M, self._col_start + block_cols)

self._local_ncols = self._col_end - self._col_start
self._rank_col_lens = self.base_comm.allgather(self._local_ncols)
total_ncols = np.sum(self._rank_col_lens)

self.dims = (self.K, total_ncols)
- self.dimsd = (self.M, total_ncols)
+ self.dimsd = (self.N, total_ncols)
shape = (int(np.prod(self.dimsd)), int(np.prod(self.dims)))
super().__init__(shape=shape, dtype=np.dtype(dtype), base_comm=base_comm)

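Under the renamed convention (``N`` global rows of ``A``, ``M`` global columns of ``X``/``Y``), the block bookkeeping in the constructor reduces to the following standalone sketch; ``block_bounds`` is a hypothetical helper, not part of the class.

```python
import math

# Standalone rendition of the block bookkeeping shown above.
def block_bounds(N, M, P_prime, group_id, layer_id):
    blk_rows = int(math.ceil(N / P_prime))    # rows of A per group
    block_cols = int(math.ceil(M / P_prime))  # columns of X/Y per layer
    row_start = group_id * blk_rows
    row_end = min(N, row_start + blk_rows)
    col_start = layer_id * block_cols
    col_end = min(M, col_start + block_cols)
    return (row_start, row_end), (col_start, col_end)

# e.g. N=10 rows, M=7 columns on a 2x2 grid: the last blocks are shorter
print(block_bounds(10, 7, 2, group_id=1, layer_id=1))   # ((5, 10), (4, 7))
```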
@@ -174,8 +174,8 @@ def _matvec(self, x: DistributedArray) -> DistributedArray:
if x.partition != Partition.SCATTER:
    raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

- y = DistributedArray(global_shape=(self.M * self.dimsd[1]),
- local_shapes=[(self.M * c) for c in self._rank_col_lens],
+ y = DistributedArray(global_shape=(self.N * self.dimsd[1]),
+ local_shapes=[(self.N * c) for c in self._rank_col_lens],
mask=x.mask,
partition=Partition.SCATTER,
dtype=self.dtype)
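The ``local_shapes`` passed to ``DistributedArray`` are just the flattened per-rank block sizes; a toy illustration with made-up values:

```python
# Flattened output sizes per rank for the forward pass (illustrative values only).
N = 6
rank_col_lens = [2, 2, 2, 1]                     # M_local on each rank, as gathered above
local_shapes = [N * c for c in rank_col_lens]    # flattened block size per rank
global_size = N * sum(rank_col_lens)             # == N * total_ncols
assert sum(local_shapes) == global_size
```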
@@ -204,7 +204,7 @@ def _rmatvec(self, x: DistributedArray) -> DistributedArray:
dtype=self.dtype,
)

- x_arr = x.local_array.reshape((self.M, self._local_ncols)).astype(self.dtype)
+ x_arr = x.local_array.reshape((self.N, self._local_ncols)).astype(self.dtype)
X_tile = x_arr[self._row_start:self._row_end, :]
A_local = self.At if hasattr(self, "At") else self.A.T.conj()
Y_local = ncp.matmul(A_local, X_tile)
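The layer reduction described in the docstring (not shown in this hunk) would then sum these ``(K, M_local)`` partial products across the layer; a hypothetical sketch with ``mpi4py``:

```python
import numpy as np
from mpi4py import MPI

# Hypothetical continuation of the adjoint: partial (K, M_local) products from all
# ranks in a layer are summed so each rank holds A^H @ X for its assigned columns.
layer_comm = MPI.COMM_WORLD          # stand-in for the per-layer communicator
Y_local = np.zeros((4, 2))           # placeholder partial result (K=4, M_local=2)
layer_comm.Allreduce(MPI.IN_PLACE, Y_local, op=MPI.SUM)
```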