
Commit 708ddb1

faran928 authored and facebook-github-bot committed

uneven sharding for sharded tensor pool

Summary: torchrec uneven sharding changes

Differential Revision: D79603009

1 parent 7dcdccb commit 708ddb1

File tree

8 files changed: +525 -35 lines changed

8 files changed

+525
-35
lines changed

torchrec/distributed/quant_embedding.py
Lines changed: 2 additions & 10 deletions

@@ -84,6 +84,7 @@
 from torchrec.modules.utils import (
     _fx_trec_get_feature_length,
     _get_batching_hinted_output,
+    _get_unbucketize_tensor_via_length_alignment,
 )
 from torchrec.quant.embedding_modules import (
     EmbeddingCollection as QuantEmbeddingCollection,
@@ -96,6 +97,7 @@
 torch.fx.wrap("len")
 torch.fx.wrap("_get_batching_hinted_output")
 torch.fx.wrap("_fx_trec_get_feature_length")
+torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")

 try:
     torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -278,16 +280,6 @@ def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]:
     return length.long().tolist()


-@torch.fx.wrap
-def _get_unbucketize_tensor_via_length_alignment(
-    lengths: torch.Tensor,
-    bucketize_length: torch.Tensor,
-    bucketize_permute_tensor: torch.Tensor,
-    bucket_mapping_tensor: torch.Tensor,
-) -> torch.Tensor:
-    return bucketize_permute_tensor
-
-
 @torch.fx.wrap
 def _fx_split_embeddings_per_feature_length(
     embeddings: torch.Tensor,
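The `torch.fx.wrap` registration above matters because symbolic tracing would otherwise inline the helper's body; wrapping records it as a single opaque call node, which is what lets the length-alignment helper survive graph transformations and model splitting. A minimal sketch of the mechanism, with illustrative names (`_passthrough_permute`, `Demo`) that are not from this commit:

```python
import torch
import torch.fx


# torch.fx.wrap as a decorator marks a module-level function as a tracing
# "leaf": fx records a call to it instead of tracing through its body.
@torch.fx.wrap
def _passthrough_permute(permute: torch.Tensor) -> torch.Tensor:
    return permute


class Demo(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return _passthrough_permute(x) + 1


gm = torch.fx.symbolic_trace(Demo())
print(gm.graph)  # shows a single call_function node for _passthrough_permute
```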

torchrec/distributed/sharding/rw_pool_sharding.py
Lines changed: 39 additions & 10 deletions

@@ -166,6 +166,9 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
         block_size (torch.Tensor): tensor containing block sizes for each rank.
             e.g. if block_size=torch.tensor(100), then IDs 0-99 will be assigned to rank
             0, 100-199 to rank 1, and so on.
+        block_bucketize_row_pos (Optional[list[torch.Tensor]]): tensor containing shard/row offsets for each
+            rank in case of uneven sharding of the tensor pool across ranks. If not provided,
+            then block_size will be used to permute the IDs across ranks.

     Example:
         device = torch.device("cpu")
@@ -179,22 +182,27 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
     _world_size: int
     _device: torch.device
     _block_size: torch.Tensor
+    _block_bucketize_row_pos: list[torch.Tensor]

     def __init__(
         self,
         env: ShardingEnv,
         device: torch.device,
         block_size: torch.Tensor,
+        block_bucketize_row_pos: Optional[list[torch.Tensor]] = None,
     ) -> None:
         super().__init__()
         self._world_size = env.world_size
         self._device = device
         self._block_size = block_size
+        self._block_bucketize_row_pos: list[torch.Tensor] = (
+            [] if block_bucketize_row_pos is None else block_bucketize_row_pos
+        )

     def forward(
         self,
         ids: torch.Tensor,
-    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
         """
         Bucketizes ids tensor into a list of tensors each containing ids
         for the corresponding rank. Places each tensor on the appropriate device.
@@ -203,24 +211,34 @@ def forward(
             ids (torch.Tensor): Tensor with ids

         Returns:
-            Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing list of ids tensors
-                for each rank given the bucket sizes, and the tensor containing indices
-                to permute the ids to get the original order before bucketization.
+            Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
+                Tuple containing
+                1. list of ids tensors for each rank given the bucket sizes
+                2. the tensor containing indices to permute the ids to get the original order before bucketization
+                3. the tensor containing the bucket mapping for each id
+                4. the tensor containing the bucketized lengths
         """
         (
             bucketized_lengths,
             bucketized_indices,
-            _bucketized_weights,
-            _bucketize_permute,
+            _,  # bucketized_weights
+            _,  # _bucketize_permute
             unbucketize_permute,
-        ) = torch.ops.fbgemm.block_bucketize_sparse_features(
-            _get_bucketize_shape(ids, ids.device),
-            ids.long(),
+            bucket_mapping,
+        ) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
+            lengths=_get_bucketize_shape(ids, ids.device),
+            indices=ids.long(),
             bucketize_pos=False,
             sequence=True,
             block_sizes=self._block_size.long(),
             my_size=self._world_size,
             weights=None,
+            block_bucketize_pos=(
+                self._block_bucketize_row_pos
+                if len(self._block_bucketize_row_pos) > 0
+                else None
+            ),
+            return_bucket_mapping=True,
         )

         id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
@@ -236,7 +254,13 @@ def forward(
         )

         assert unbucketize_permute is not None, "unbucketize permute must not be None"
-        return dist_ids, unbucketize_permute
+        assert bucket_mapping is not None, "bucket mapping must not be None"
+        return (
+            dist_ids,
+            unbucketize_permute,
+            bucket_mapping,
+            bucketized_lengths,
+        )

     def update(
         self,
@@ -270,6 +294,11 @@ def update(
             block_sizes=self._block_size.long(),
             my_size=self._world_size,
             weights=None,
+            block_bucketize_pos=(
+                self._block_bucketize_row_pos
+                if len(self._block_bucketize_row_pos) > 0
+                else None
+            ),
         )

         id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
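The switch from `block_bucketize_sparse_features` to `block_bucketize_sparse_features_inference` with `block_bucketize_pos` is what enables uneven shards: rank boundaries come from explicit cumulative row offsets rather than a uniform `block_size`. A rough pure-PyTorch illustration of the id-to-rank mapping (the `assign_ranks` helper is a hypothetical stand-in, not the fbgemm kernel):

```python
import torch


def assign_ranks(ids: torch.Tensor, row_offsets: torch.Tensor) -> torch.Tensor:
    # row_offsets = [0, end_rank0, end_rank1, ..., pool_size]; searchsorted
    # with right=True finds the first offset strictly greater than each id,
    # so subtracting 1 yields the owning rank.
    return torch.searchsorted(row_offsets, ids, right=True) - 1


ids = torch.tensor([5, 120, 250, 399])

# Even sharding: block_size=100 over 4 ranks -> offsets [0, 100, 200, 300, 400]
print(assign_ranks(ids, torch.tensor([0, 100, 200, 300, 400])))  # tensor([0, 1, 2, 3])

# Uneven sharding: capacities in ratio 2:1:1:4 over pool_size=400
print(assign_ranks(ids, torch.tensor([0, 100, 150, 200, 400])))  # tensor([0, 1, 3, 3])
```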

torchrec/distributed/sharding/rw_tensor_pool_sharding.py
Lines changed: 11 additions & 2 deletions

@@ -224,6 +224,11 @@ def __init__(
         self._cat_dim = 0
         self._placeholder: torch.Tensor = torch.ones(1, device=device)

+    @torch.jit.export
+    def set_device(self, device_str: str) -> None:
+        self._device = torch.device(device_str)
+        self._placeholder = torch.ones(1, device=self._device)
+
     def forward(
         self,
         lookups: List[torch.Tensor],
@@ -256,12 +261,16 @@ def __init__(
         pool_size: int,
         env: ShardingEnv,
         device: torch.device,
+        memory_capacity_per_rank: Optional[list[int]] = None,
     ) -> None:
-        super().__init__(pool_size, env, device)
+        super().__init__(pool_size, env, device, memory_capacity_per_rank)

     def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist:
         return InferRwObjectPoolInputDist(
-            self._env, device=self._device, block_size=self._block_size_t
+            self._env,
+            device=self._device,
+            block_size=self._block_size_t,
+            block_bucketize_row_pos=self._block_bucketize_row_pos,
         )

     def create_lookup_values_dist(
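`set_device` is decorated with `@torch.jit.export` because TorchScript only compiles methods reachable from `forward` by default; exporting keeps the method callable on the scripted module, e.g. to re-point an inference module at a new device after scripting. A small sketch of the pattern, assuming a toy `Pool` module (not from this commit):

```python
import torch


class Pool(torch.nn.Module):
    def __init__(self) -> None:
        super().__init__()
        self._device = torch.device("cpu")

    @torch.jit.export
    def set_device(self, device_str: str) -> None:
        # Exported methods survive scripting even though forward() never
        # calls them, so placement logic can invoke this later.
        self._device = torch.device(device_str)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self._device)


scripted = torch.jit.script(Pool())
scripted.set_device("cpu")  # still callable after scripting
```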

torchrec/distributed/tensor_pool.py
Lines changed: 75 additions & 7 deletions

@@ -32,7 +32,14 @@
 )
 from torchrec.modules.object_pool_lookups import TensorLookup, TensorPoolLookup
 from torchrec.modules.tensor_pool import TensorPool
-from torchrec.modules.utils import deterministic_dedup
+from torchrec.modules.utils import (
+    _get_batching_hinted_output,
+    _get_unbucketize_tensor_via_length_alignment,
+    deterministic_dedup,
+)
+
+torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
+torch.fx.wrap("_get_batching_hinted_output")


 @torch.fx.wrap
@@ -44,6 +51,17 @@ def index_select_view(
     return output[unbucketize_permute].view(-1, dim)


+@torch.fx.wrap
+def _fx_item_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
+    assert optional is not None, "Expected optional to be non-None Tensor"
+    return optional
+
+
+@torch.fx.wrap
+def _get_id_length_sharded_tensor_pool(ids: torch.Tensor) -> torch.Tensor:
+    return torch.tensor([ids.size(dim=0)], device=ids.device, dtype=torch.long)
+
+
 class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
     def __init__(
         self,
@@ -271,6 +289,8 @@ class LocalShardPool(torch.nn.Module):
         # out is tensor([1,2,3]) i.e. first row of the shard
     """

+    current_device: torch.device
+
     def __init__(
         self,
         shard: torch.Tensor,
@@ -280,6 +300,12 @@ def __init__(
             shard,
             requires_grad=False,
         )
+        self.current_device = self._shard.device
+
+    @torch.jit.export
+    def set_device(self, device_str: str) -> None:
+        self.current_device = torch.device(device_str)
+        self._shard.to(self.current_device)

     def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
         """
@@ -291,7 +317,7 @@ def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
         Returns:
             torch.Tensor: Tensor of values corresponding to the given rank ids.
         """
-        return self._shard[rank_ids]
+        return self._shard[rank_ids.to(self.current_device)]

     def update(self, rank_ids: torch.Tensor, values: torch.Tensor) -> None:
         _ = update(self._shard, rank_ids, values)
@@ -337,6 +363,11 @@ def __init__(
                 env=self._sharding_env,
                 device=self._device,
                 pool_size=self._pool_size,
+                memory_capacity_per_rank=(
+                    self._sharding_plan.memory_capacity_per_rank
+                    if self._sharding_plan.memory_capacity_per_rank is not None
+                    else None
+                ),
             )
         else:
             raise NotImplementedError(
@@ -356,6 +387,7 @@ def __init__(
                 if device == torch.device("cpu")
                 else torch.device("cuda", rank)
             )
+
             self._local_shard_pools.append(
                 LocalShardPool(
                     torch.empty(
@@ -409,7 +441,7 @@ def create_context(self) -> ObjectPoolShardingContext:
     def _lookup_ids_dist(
         self,
         ids: torch.Tensor,
-    ) -> Tuple[List[torch.Tensor], torch.Tensor]:
+    ) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
         return self._lookup_ids_dist_impl(ids)

     # pyre-ignore
@@ -439,18 +471,54 @@ def _lookup_values_dist(

     # pyre-ignore
     def forward(self, ids: torch.Tensor) -> torch.Tensor:
-        dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
+        dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
+            self._lookup_ids_dist(ids)
+        )
+        unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
+            unbucketize_permute
+        )
+
         lookup = self._lookup_local(dist_input)

         # Here we are playing a trick to workaround a fx tracing issue,
         # as proxy is not iteratable.
         lookup_list = []
-        for i in range(self._world_size):
-            lookup_list.append(lookup[i])
+        # In case of non-heterogenous even sharding keeping the behavior
+        # consistent with existing logic to ensure that additional fx wrappers
+        # do not impact the model split logic during inference in anyway
+        if self._sharding_plan.memory_capacity_per_rank is None:
+            for i in range(self._world_size):
+                lookup_list.append(lookup[i])
+        else:
+            # Adding fx wrappers in case of uneven heterogenous sharding to
+            # make it compatible with model split boundaries during inference
+            for i in range(self._world_size):
+                lookup_list.append(
+                    _get_batching_hinted_output(
+                        _get_id_length_sharded_tensor_pool(dist_input[i]), lookup[i]
+                    )
+                )
+
+        features_before_input_dist_length = _get_id_length_sharded_tensor_pool(ids)
+        bucketized_lengths_col_view = bucketized_lengths.view(self._world_size, -1)
+        unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
+            unbucketize_permute
+        )
+        bucket_mapping_non_opt = _fx_item_unwrap_optional_tensor(bucket_mapping)
+        unbucketize_permute_non_opt = _get_unbucketize_tensor_via_length_alignment(
+            features_before_input_dist_length,
+            bucketized_lengths_col_view,
+            unbucketize_permute_non_opt,
+            bucket_mapping_non_opt,
+        )

         output = self._lookup_values_dist(lookup_list)

-        return index_select_view(output, unbucketize_permute, self._dim)
+        return index_select_view(
+            output,
+            unbucketize_permute_non_opt.to(device=output.device),
+            self._dim,
+        )

     # pyre-ignore
     def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):
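To make the data flow in `forward` concrete: ids are grouped by owning rank, values are looked up shard-locally, the per-rank results are concatenated, and the unbucketize permute restores the caller's original id order. A toy walk-through in plain PyTorch (stand-ins for the fbgemm bucketization and the all-to-all; all names here are illustrative):

```python
import torch

pool = torch.arange(8.0).view(8, 1) * 10        # global pool of 8 rows, dim=1
row_offsets = torch.tensor([0, 3, 8])           # uneven: rank0 owns rows 0-2, rank1 rows 3-7
ids = torch.tensor([6, 1, 4, 0])                # caller's lookup order

# Group ids by owning rank (bucketize) and remember how to undo the grouping.
rank = torch.searchsorted(row_offsets, ids, right=True) - 1
order = torch.argsort(rank, stable=True)        # stable keeps within-rank order
bucketized_ids = ids[order]                     # tensor([1, 0, 6, 4])
unbucketize_permute = torch.argsort(order)      # inverse permutation

# Indexing the global pool stands in for per-shard lookups plus the
# all-to-all concatenation performed by _lookup_values_dist.
output = pool[bucketized_ids]

# Restore the caller's order, as index_select_view does.
restored = output[unbucketize_permute].view(-1, 1)
assert torch.equal(restored.flatten(), torch.tensor([60.0, 10.0, 40.0, 0.0]))
```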

torchrec/distributed/tensor_sharding.py
Lines changed: 32 additions & 4 deletions

@@ -102,6 +102,7 @@ def __init__(
         pool_size: int,
         env: ShardingEnv,
         device: torch.device,
+        memory_capacity_per_rank: Optional[list[int]] = None,
     ) -> None:
         self._pool_size = pool_size
         self._env = env
@@ -117,13 +118,40 @@ def __init__(
         self._last_block_size: int = self._pool_size - self._block_size * (
             self._world_size - 1
         )
-        self.local_pool_size_per_rank: List[int] = [self._block_size] * (
-            self._world_size - 1
-        ) + [self._last_block_size]
-
+        # only used for uneven sharding case when memory_capacity_per_rank is provided
+        row_offset_per_rank = []
+
+        if memory_capacity_per_rank is None:
+            self.local_pool_size_per_rank: List[int] = [self._block_size] * (
+                self._world_size - 1
+            ) + [self._last_block_size]
+        else:
+            row_offset_per_rank = [0]
+            self.local_pool_size_per_rank: List[int] = []
+            row_offset = 0
+            assert (
+                len(memory_capacity_per_rank) == self._world_size
+            ), "If memory_capacity_per_rank is provided for sharded tensor pool, it must have the same length as world_size"
+            total_mem_cap = sum(memory_capacity_per_rank)
+            for cap in memory_capacity_per_rank[:-1]:
+                rows_per_shard = int(cap / total_mem_cap * self._pool_size)
+                self.local_pool_size_per_rank.append(rows_per_shard)
+                row_offset += rows_per_shard
+                row_offset_per_rank.append(row_offset)
+            self.local_pool_size_per_rank.append(
+                self._pool_size - sum(self.local_pool_size_per_rank)
+            )
+            row_offset_per_rank.append(self._pool_size)
         self._block_size_t: torch.Tensor = torch.tensor(
             [self._block_size], device=self._device, dtype=torch.long
         )
+        # for uneven sharding case, we get the row offsets for each rank to
+        # enable input_dist and lookup of ids to correct rank
+        self._block_bucketize_row_pos: Optional[List[torch.Tensor]] = (
+            None
+            if memory_capacity_per_rank is None
+            else [torch.tensor(row_offset_per_rank, device=self._device)]
+        )

     @abstractmethod
     def create_lookup_ids_dist(self) -> torch.nn.Module:
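The proportional split above sizes each rank's shard by its share of total memory capacity and gives the rounding remainder to the last rank; the cumulative offsets become `block_bucketize_row_pos`. A worked example with assumed capacities (the numbers are illustrative):

```python
# pool_size=1000 rows across 4 ranks with memory capacities 16, 16, 32, 64 (GB)
memory_capacity_per_rank = [16, 16, 32, 64]
pool_size = 1000

total_mem_cap = sum(memory_capacity_per_rank)  # 128
local_pool_size_per_rank = [
    int(cap / total_mem_cap * pool_size) for cap in memory_capacity_per_rank[:-1]
]
local_pool_size_per_rank.append(pool_size - sum(local_pool_size_per_rank))
# [125, 125, 250, 500] -- last rank absorbs the rounding remainder

row_offset_per_rank = [0]
for rows in local_pool_size_per_rank:
    row_offset_per_rank.append(row_offset_per_rank[-1] + rows)
# [0, 125, 250, 500, 1000] -> rank boundaries used as block_bucketize_row_pos
```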
