Skip to content

Commit 714014f

Browse files
jd7-trfacebook-github-bot
authored andcommitted
Empty shard support
Differential Revision: D80917317
1 parent 31cf49f commit 714014f

File tree

3 files changed

+48
-4
lines changed

3 files changed

+48
-4
lines changed

torchrec/distributed/dist_data.py

Lines changed: 26 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -173,9 +173,33 @@ def forward(self, tensors: List[torch.Tensor], cat_dim: int) -> torch.Tensor:
173173
Here we assume input tensors are:
174174
[TBE_output_0, ..., TBE_output_(n-1)]
175175
"""
176-
B = tensors[0].size(1 - cat_dim)
176+
# Handle empty shards case (can happen in column-wise sharding)
177+
if not tensors or len(tensors) == 0:
178+
# Return empty tensor if no tensors provided
179+
return torch.empty(0, 0, dtype=torch.float, device=self.current_device)
180+
181+
# Check if we are in TorchScript mode first to avoid global variable access issues
182+
if torch.jit.is_scripting() or torch.jit.is_tracing():
183+
# In TorchScript or JIT tracing mode, use all tensors and let FBGEMM handle empties
184+
tensors_to_use = tensors
185+
else:
186+
if torch.fx._symbolic_trace.is_fx_tracing():
187+
# During FX tracing, include all tensors to avoid control flow issues
188+
tensors_to_use = tensors
189+
else:
190+
# Normal execution: filter out empty tensors
191+
non_empty_tensors = []
192+
193+
for t in tensors:
194+
if t.numel() > 0 and t.size(cat_dim) > 0:
195+
non_empty_tensors.append(t)
196+
197+
tensors_to_use = non_empty_tensors if non_empty_tensors else tensors
198+
199+
# Use the first tensor to determine batch size
200+
B = tensors_to_use[0].size(1 - cat_dim)
177201
return torch.ops.fbgemm.merge_pooled_embeddings(
178-
tensors,
202+
tensors_to_use,
179203
B,
180204
self.current_device,
181205
cat_dim,

torchrec/distributed/quant_embedding_kernel.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -371,6 +371,26 @@ def _emb_module_forward(
371371
lengths_or_offsets: torch.Tensor,
372372
weights: Optional[torch.Tensor],
373373
) -> torch.Tensor:
374+
# Check if total embedding dimension is 0 (can happen in column-wise sharding)
375+
total_D = sum(table.local_cols for table in self._config.embedding_tables)
376+
377+
if total_D == 0:
378+
# For empty shards, return tensor with correct batch size but 0 embedding dimension
379+
# Use tensor operations that are FX symbolic tracing compatible
380+
if self.lengths_to_tbe:
381+
# For lengths format, batch size equals lengths tensor size
382+
# Create [B, 0] tensor using zeros_like and slicing
383+
dummy = torch.zeros_like(lengths_or_offsets, dtype=torch.float)
384+
return dummy.unsqueeze(-1)[:, :0] # [B, 0] tensor
385+
else:
386+
# For offsets format, batch size is one less than offset size
387+
# Use tensor slicing to create batch dimension
388+
batch_tensor = lengths_or_offsets[
389+
:-1
390+
] # Remove last element to get batch size
391+
dummy = torch.zeros_like(batch_tensor, dtype=torch.float)
392+
return dummy.unsqueeze(-1)[:, :0] # [B, 0] tensor
393+
374394
kwargs = {"indices": indices}
375395

376396
if self.lengths_to_tbe:

torchrec/distributed/tests/test_infer_shardings.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -571,9 +571,9 @@ def test_cw(
571571
def test_uneven_cw(self, weight_dtype: torch.dtype, device_type: str) -> None:
572572
num_embeddings = 64
573573
emb_dim = 512
574-
dim_1 = 63
574+
dim_1 = 0
575575
dim_2 = 128
576-
dim_3 = 65
576+
dim_3 = 128
577577
dim_4 = 256
578578
local_size = 4
579579
world_size = 4

0 commit comments

Comments
 (0)