
Commit 7c7daaf

faran928 authored and meta-codesync[bot] committed
Handling sequence embedding table-wise sharding onto subset of world size
Summary: While doing table-wise sharding, we may have input cases where there are not enough tables to shard across all the ranks. In those cases, some embedding modules may not have any embeddings placed on a few ranks. For table-wise sequence sharding using the usharding approach it fails correctly, as we modified the split boundary for usharding. This change handles empty ranks for those embedding modules: we can simply skip those ranks while collecting the results from all the shards.

Differential Revision: D80360860

fbshipit-source-id: 50fd076b194e3426f1ecdcb6e0e8cc5e9ddab43c
1 parent 51be4ea · commit 7c7daaf
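As a toy illustration of the scenario described in the summary (not part of the diff; the helper name plan_table_wise is hypothetical): under table-wise sharding each table lives entirely on one rank, so with fewer tables than ranks some ranks end up owning no embeddings at all.

from typing import List


def plan_table_wise(tables: List[str], world_size: int) -> List[List[str]]:
    # Hypothetical planner: place each table wholly on one rank, round-robin.
    plan: List[List[str]] = [[] for _ in range(world_size)]
    for idx, table in enumerate(tables):
        plan[idx % world_size].append(table)
    return plan


# 2 tables over a world size of 4 leaves ranks 2 and 3 with no embeddings.
print(plan_table_wise(["table_a", "table_b"], world_size=4))
# [['table_a'], ['table_b'], [], []]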

File tree: 3 files changed, +56 −30 lines

torchrec/distributed/dist_data.py

Lines changed: 3 additions & 1 deletion
@@ -1240,7 +1240,9 @@ def forward(self, tensors: List[torch.Tensor]) -> List[torch.Tensor]:
             Awaitable[torch.Tensor]: awaitable of the merged pooled embeddings.
         """

-        assert len(tensors) == self._world_size
+        assert (
+            len(tensors) == self._world_size
+        ), f"length of input tensor {len(tensors)} must match with world size {self._world_size}"
         return torch.ops.fbgemm.all_to_one_device(
             tensors,
             self._device,

torchrec/distributed/quant_embedding.py

Lines changed: 40 additions & 28 deletions
@@ -305,37 +305,49 @@ def _construct_jagged_tensors_tw(
     storage_device_type: str,
 ) -> Dict[str, JaggedTensor]:
     ret: Dict[str, JaggedTensor] = {}
+    index = 0
     for i in range(len(embedding_names_per_rank)):
-        embeddings_i = embeddings[i]
-        features_i: KeyedJaggedTensor = features[i]
-        if storage_device_type in ["ssd", "cpu"]:
-            embeddings_i = _get_batching_hinted_output(
-                _fx_trec_get_feature_length(features_i, embedding_names_per_rank[i]),
-                embeddings_i,
-            )
-
-        lengths = features_i.lengths().view(-1, features_i.stride())
-        values = features_i.values()
-        embeddings_list = _fx_split_embeddings_per_feature_length(
-            embeddings_i, features_i
-        )
-        stride = features_i.stride()
-        lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
-        if need_indices:
-            values_list = _fx_split_embeddings_per_feature_length(values, features_i)
-            for j, key in enumerate(embedding_names_per_rank[i]):
-                ret[key] = JaggedTensor(
-                    lengths=lengths_tuple[j],
-                    values=embeddings_list[j],
-                    weights=values_list[j],
+        if len(embedding_names_per_rank[i]) > 0:
+            embeddings_i = embeddings[index]
+            features_i: KeyedJaggedTensor = features[i]
+            if storage_device_type in ["ssd", "cpu"]:
+                embeddings_i = _get_batching_hinted_output(
+                    _fx_trec_get_feature_length(
+                        features_i, embedding_names_per_rank[index]
+                    ),
+                    embeddings_i,
                 )
-        else:
-            for j, key in enumerate(embedding_names_per_rank[i]):
-                ret[key] = JaggedTensor(
-                    lengths=lengths_tuple[j],
-                    values=embeddings_list[j],
-                    weights=None,
+
+            lengths = features_i.lengths().view(-1, features_i.stride())
+            values = features_i.values()
+            embeddings_list = _fx_split_embeddings_per_feature_length(
+                embeddings_i, features_i
+            )
+            stride = features_i.stride()
+            lengths_tuple = torch.unbind(lengths.view(-1, stride), dim=0)
+            if need_indices:
+                values_list = _fx_split_embeddings_per_feature_length(
+                    values, features_i
                 )
+                for j, key in enumerate(embedding_names_per_rank[i]):
+                    ret[key] = JaggedTensor(
+                        lengths=lengths_tuple[j],
+                        values=embeddings_list[j],
+                        weights=values_list[j],
+                    )
+            else:
+                for j, key in enumerate(embedding_names_per_rank[i]):
+                    ret[key] = JaggedTensor(
+                        lengths=lengths_tuple[j],
+                        values=embeddings_list[j],
+                        weights=None,
+                    )
+            index += 1
+        # for cuda storage device, empty embeddding per rank is already skipped
+        # as part of tw_sequence_sharding output dist before executing
+        # SeqEmbeddingsAllToOne (for cpu / ssd SeqEmbeddingsAllToOne is not required)
+        elif storage_device_type in ["cpu", "ssd"]:
+            index += 1
     return ret
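A hedged pure-Python reading of the control flow above (no TorchRec types; collect_outputs and the string stand-ins are hypothetical): the rank index i walks the per-rank feature/name lists, while a separate index walks the received embeddings, which for cuda no longer contain entries for empty ranks but, judging by the elif branch, still do for cpu/ssd.

from typing import Dict, List


def collect_outputs(
    names_per_rank: List[List[str]],
    outputs: List[str],
    upstream_skips_empty_ranks: bool,
) -> Dict[str, str]:
    # Dual-index walk: `i` over ranks, `index` over the received outputs.
    ret: Dict[str, str] = {}
    index = 0
    for i, names in enumerate(names_per_rank):
        if names:
            for name in names:
                ret[name] = outputs[index]
            index += 1
        elif not upstream_skips_empty_ranks:
            # cpu/ssd-style path: a placeholder entry still exists for this
            # empty rank, so step over it to stay aligned.
            index += 1
    return ret


names = [["t0"], [], ["t1"]]
# cuda-style: the empty rank was already dropped upstream, so outputs has 2 entries.
print(collect_outputs(names, ["out_0", "out_1"], upstream_skips_empty_ranks=True))
# cpu/ssd-style: outputs keeps one entry per rank; the empty one is skipped here.
print(collect_outputs(names, ["out_0", "placeholder", "out_1"], upstream_skips_empty_ranks=False))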

torchrec/distributed/sharding/tw_sequence_sharding.py

Lines changed: 13 additions & 1 deletion
@@ -175,12 +175,21 @@ def __init__(
         device: torch.device,
         world_size: int,
         storage_device_type_from_sharding_infos: Optional[str] = None,
+        embedding_names_per_rank: Optional[List[List[str]]] = None,
     ) -> None:
         super().__init__()
-        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(device, world_size)
+        self._adjusted_world_size: int = (
+            world_size
+            if embedding_names_per_rank is None
+            else sum(1 for sublist in embedding_names_per_rank if len(sublist) > 0)
+        )
+        self._dist: SeqEmbeddingsAllToOne = SeqEmbeddingsAllToOne(
+            device, self._adjusted_world_size
+        )
         self._storage_device_type_from_sharding_infos: Optional[str] = (
             storage_device_type_from_sharding_infos
         )
+        self._embedding_names_per_rank = embedding_names_per_rank

     def forward(
         self,
@@ -216,6 +225,8 @@ def forward(
                     local_emb,
                 )
                 for i, local_emb in enumerate(local_embs)
+                if self._embedding_names_per_rank is not None
+                and len(self._embedding_names_per_rank[i]) > 0
             ]
             return self._dist(local_embs)
         else:
@@ -269,4 +280,5 @@ def create_output_dist(
             device,
             self._world_size,
             self._storage_device_type_from_sharding_infos,
+            self.embedding_names_per_rank(),
         )
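To see how the three files fit together, here is a minimal sketch (plain Python, no TorchRec imports; the example values are made up): the output dist counts only ranks that own tables and filters their outputs, so the list handed to SeqEmbeddingsAllToOne has exactly adjusted-world-size entries, which is what the strengthened assert in dist_data.py now checks.

from typing import List

embedding_names_per_rank: List[List[str]] = [["t0"], [], ["t1", "t2"], []]
world_size = len(embedding_names_per_rank)  # 4 ranks in the sharding plan

# Mirrors InferTwSequenceEmbeddingDist.__init__: count only ranks owning tables.
adjusted_world_size = sum(1 for names in embedding_names_per_rank if len(names) > 0)

# Stand-ins for per-rank embedding outputs before filtering.
local_embs = [f"emb_from_rank_{i}" for i in range(world_size)]

# Mirrors the new filter in forward(): drop outputs of ranks with no tables.
local_embs = [
    emb
    for i, emb in enumerate(local_embs)
    if len(embedding_names_per_rank[i]) > 0
]

# Mirrors the assert added to SeqEmbeddingsAllToOne.forward in dist_data.py.
assert len(local_embs) == adjusted_world_size, (
    f"length of input tensor {len(local_embs)} must match with "
    f"world size {adjusted_world_size}"
)
print(local_embs, adjusted_world_size)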
