@@ -16,6 +16,7 @@ def __init__(
16
16
self ,
17
17
A : NDArray ,
18
18
N : int ,
19
+ saveAt : bool = False ,
19
20
base_comm : MPI .Comm = MPI .COMM_WORLD ,
20
21
dtype : DTypeLike = "float64" ,
21
22
) -> None :
@@ -25,113 +26,91 @@ def __init__(
25
26
# Determine grid dimensions (P_prime × C) such that P_prime * C ≥ size
26
27
self ._P_prime = int (math .ceil (math .sqrt (size )))
27
28
self ._C = int (math .ceil (size / self ._P_prime ))
28
- if self ._P_prime * self ._C < size :
29
+ if self ._P_prime * self ._C != size :
29
30
raise Exception ("Number of Procs must be a square number" )
30
31
31
32
# Compute this process's group and layer indices
32
33
self ._group_id = rank % self ._P_prime
33
34
self ._layer_id = rank // self ._P_prime
34
35
35
36
# Split communicators by layer (rows) and by group (columns)
36
- self .base_comm = base_comm
37
+ self .base_comm = base_comm
37
38
self ._layer_comm = base_comm .Split (color = self ._layer_id , key = self ._group_id )
38
39
self ._group_comm = base_comm .Split (color = self ._group_id , key = self ._layer_id )
39
40
self .A = A .astype (np .dtype (dtype ))
41
+ if saveAt : self .At = A .T .conj ()
40
42
41
43
self .M = self ._layer_comm .allreduce (self .A .shape [0 ], op = MPI .SUM )
42
44
self .K = A .shape [1 ]
43
45
self .N = N
44
46
45
47
# Determine how many columns each group holds
46
48
block_cols = int (math .ceil (self .N / self ._P_prime ))
47
- local_col_start = self ._group_id * block_cols
48
- local_col_end = min (self .N , local_col_start + block_cols )
49
- local_ncols = local_col_end - local_col_start
49
+ blk_rows = int (math .ceil (self .M / self ._P_prime ))
50
50
51
- # Sum up the total number of input columns across all processes
52
- total_ncols = base_comm .allreduce (local_ncols , op = MPI .SUM )
53
- self .dims = (self .K , total_ncols )
51
+ self ._row_start = self ._group_id * blk_rows
52
+ self ._row_end = min (self .M , self ._row_start + blk_rows )
53
+
54
+ self ._col_start = self ._layer_id * block_cols
55
+ self ._col_end = min (self .N , self ._col_start + block_cols )
54
56
55
- # Recompute how many output columns each layer holds
56
- layer_col_start = self ._layer_id * block_cols
57
- layer_col_end = min (self .N , layer_col_start + block_cols )
58
- layer_ncols = layer_col_end - layer_col_start
59
- total_layer_cols = self .base_comm .allreduce (layer_ncols , op = MPI .SUM )
57
+ self ._local_ncols = self ._col_end - self ._col_start
58
+ self ._rank_col_lens = self .base_comm .allgather (self ._local_ncols )
59
+ total_ncols = np .sum (self ._rank_col_lens )
60
60
61
- self .dimsd = (self .M , total_layer_cols )
61
+ self .dims = (self .K , total_ncols )
62
+ self .dimsd = (self .M , total_ncols )
62
63
shape = (int (np .prod (self .dimsd )), int (np .prod (self .dims )))
63
64
super ().__init__ (shape = shape , dtype = np .dtype (dtype ), base_comm = base_comm )
64
-
65
+
65
66
def _matvec(self, x: DistributedArray) -> DistributedArray:
    """Forward distributed matrix-matrix product ``y = A x``.

    The operand block held by this layer's "diagonal" rank (the rank whose
    group index equals its layer index) is broadcast across the layer
    communicator; every rank multiplies it by its local ``A`` block, and the
    partial results from all layer members are stacked row-wise into the
    local slice of the output.

    Parameters
    ----------
    x : DistributedArray
        Input with ``Partition.SCATTER``; its flat local buffer is assumed
        to hold a ``(K, local_ncols)`` block in row-major order.

    Returns
    -------
    DistributedArray
        SCATTER-partitioned result of global length ``M * dimsd[1]``.

    Raises
    ------
    ValueError
        If ``x`` is not SCATTER-partitioned.
    """
    ncp = get_module(x.engine)
    if x.partition != Partition.SCATTER:
        raise ValueError(f"x should have partition={Partition.SCATTER} Got {x.partition} instead...")

    # Unflatten this rank's buffer into its (K, local columns) block.
    ncols_here = self._rank_col_lens[self.rank]
    local_block = x.local_array.reshape((self.dims[0], ncols_here)).astype(self.dtype)

    # Only the layer's diagonal rank (group_id == layer_id) contributes the
    # operand; everyone else passes None and receives it via bcast.
    operand = local_block if self._group_id == self._layer_id else None
    X_local = self._layer_comm.bcast(operand, root=self._layer_id)

    # Local partial product, then stack every layer member's rows together.
    partials = self._layer_comm.allgather(ncp.matmul(self.A, X_local))
    Y_local = ncp.vstack(partials)

    out = DistributedArray(
        global_shape=(self.M * self.dimsd[1]),
        local_shapes=[(self.M * c) for c in self._rank_col_lens],
        mask=x.mask,
        partition=Partition.SCATTER,
        dtype=self.dtype,
    )
    out[:] = Y_local.flatten()
    return out
96
89
97
90
def _rmatvec (self , x : DistributedArray ) -> DistributedArray :
98
91
ncp = get_module (x .engine )
99
92
if x .partition != Partition .SCATTER :
100
93
raise ValueError (f"x should have partition={ Partition .SCATTER } . Got { x .partition } instead." )
101
94
102
- # Determine local column block for this layer
103
- blk_cols = int (math .ceil (self .N / self ._P_prime ))
104
- layer_col_start = self ._layer_id * blk_cols
105
- layer_col_end = min (self .N , layer_col_start + blk_cols )
106
- layer_ncols = layer_col_end - layer_col_start
107
- layer_col_lens = self .base_comm .allgather (layer_ncols )
108
- x = x .local_array .reshape ((self .M , layer_ncols )).astype (self .dtype )
109
-
110
- # Determine local row block for this process group
111
- blk_rows = int (math .ceil (self .M / self ._P_prime ))
112
- row_start = self ._group_id * blk_rows
113
- row_end = min (self .M , row_start + blk_rows )
114
-
115
- B_tile = x [row_start :row_end , :].astype (self .dtype )
116
- A_local = self .A .T .conj ().astype (self .dtype )
117
-
118
- m , b = A_local .shape
119
- pad = (- m ) % self ._P_prime
120
- r = (m + pad ) // self ._P_prime
121
- A_pad = np .pad (A_local , ((0 , pad ), (0 , 0 )), mode = 'constant' , constant_values = self .dtype .type (0.0 ))
95
+ x_arr = x .local_array .reshape ((self .M , self ._local_ncols )).astype (self .dtype )
96
+ X_tile = x_arr [self ._row_start :self ._row_end , :]
97
+
98
+ A_local = self .At if hasattr (self , "At" ) else self .A .T .conj ()
99
+ m , b = A_local .shape
100
+ pad = (- m ) % self ._P_prime
101
+ r = (m + pad ) // self ._P_prime
102
+ A_pad = np .pad (A_local , ((0 , pad ), (0 , 0 )), mode = 'constant' , constant_values = self .dtype .type (0.0 ))
122
103
A_batch = A_pad .reshape (self ._P_prime , r , b )
123
104
124
- # Perform local matmul and unpad
125
- Y_batch = ncp .matmul (A_batch , B_tile ).astype (self .dtype )
126
- Y_pad = Y_batch .reshape (r * self ._P_prime , - 1 )
105
+ Y_batch = ncp .matmul (A_batch , X_tile )
106
+ Y_pad = Y_batch .reshape (r * self ._P_prime , - 1 )
127
107
y_local = Y_pad [:m , :]
128
108
y_layer = self ._layer_comm .allreduce (y_local , op = MPI .SUM )
129
109
130
- mask = [i // self ._P_prime for i in range (self .size )]
131
110
y = DistributedArray (
132
111
global_shape = (self .K * self .dimsd [1 ]),
133
- local_shapes = [self .K * c for c in layer_col_lens ],
134
- mask = mask ,
112
+ local_shapes = [self .K * c for c in self . _rank_col_lens ],
113
+ mask = x . mask ,
135
114
partition = Partition .SCATTER ,
136
115
dtype = self .dtype ,
137
116
)
0 commit comments