Skip to content

Commit 3f268a1

Browse files
faran928 authored and facebook-github-bot committed
Support for uneven heterogeneous sharding for inference sharded tensor pool
Summary: A few changes in this diff: 1. Support proportionally sharding the tensor pool based on memory capacity per rank. 2. Use block_bucketize_sparse_features_inference to return bucket_mapping, which can be used during request batching in inference with a custom sigrid predictor engine. 3. Wrap some of the operations with fx wrappers to make them compatible with model split boundaries for DLRM serving, where embeddings are sharded and split onto different PyTorch modules. 4. Expose a set_device() API on some of the modules so that some shards can be placed on CPU while others are placed on CUDA. 5. Move _get_unbucketize_tensor_via_length_alignment to the common util files. Differential Revision: D79603009
1 parent 691d11f commit 3f268a1

File tree

9 files changed

+527
-37
lines changed

9 files changed

+527
-37
lines changed

torchrec/distributed/keyed_jagged_tensor_pool.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -630,7 +630,7 @@ def _lookup_values_dist(
630630

631631
# pyre-ignore
632632
def forward(self, ids: torch.Tensor) -> KeyedJaggedTensor:
633-
dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
633+
dist_input, unbucketize_permute, _, _ = self._lookup_ids_dist(ids)
634634
lookup = self._lookup_local(dist_input)
635635
# Here we are playing a trick to workaround a fx tracing issue,
636636
# as proxy is not iterable.

torchrec/distributed/quant_embedding.py

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,6 +84,7 @@
8484
from torchrec.modules.utils import (
8585
_fx_trec_get_feature_length,
8686
_get_batching_hinted_output,
87+
_get_unbucketize_tensor_via_length_alignment,
8788
)
8889
from torchrec.quant.embedding_modules import (
8990
EmbeddingCollection as QuantEmbeddingCollection,
@@ -96,6 +97,7 @@
9697
torch.fx.wrap("len")
9798
torch.fx.wrap("_get_batching_hinted_output")
9899
torch.fx.wrap("_fx_trec_get_feature_length")
100+
torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
99101

100102
try:
101103
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -278,16 +280,6 @@ def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]:
278280
return length.long().tolist()
279281

280282

281-
@torch.fx.wrap
282-
def _get_unbucketize_tensor_via_length_alignment(
283-
lengths: torch.Tensor,
284-
bucketize_length: torch.Tensor,
285-
bucketize_permute_tensor: torch.Tensor,
286-
bucket_mapping_tensor: torch.Tensor,
287-
) -> torch.Tensor:
288-
return bucketize_permute_tensor
289-
290-
291283
@torch.fx.wrap
292284
def _fx_split_embeddings_per_feature_length(
293285
embeddings: torch.Tensor,

torchrec/distributed/sharding/rw_pool_sharding.py

Lines changed: 39 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -166,6 +166,9 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
166166
block_size (torch.Tensor): tensor containing block sizes for each rank.
167167
e.g. if block_size=torch.tensor(100), then IDs 0-99 will be assigned to rank
168168
0, 100-199 to rank 1, and so on.
169+
block_bucketize_row_pos (Optional[list[torch.Tensor]]): tensors containing shard/row offsets for each
170+
rank in case of uneven sharding of the tensor pool across ranks. If not provided,
171+
then block_size will be used to permute the IDs across ranks.
169172
170173
Example:
171174
device = torch.device("cpu")
@@ -179,22 +182,27 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
179182
_world_size: int
180183
_device: torch.device
181184
_block_size: torch.Tensor
185+
_block_bucketize_row_pos: list[torch.Tensor]
182186

183187
def __init__(
184188
self,
185189
env: ShardingEnv,
186190
device: torch.device,
187191
block_size: torch.Tensor,
192+
block_bucketize_row_pos: Optional[list[torch.Tensor]] = None,
188193
) -> None:
189194
super().__init__()
190195
self._world_size = env.world_size
191196
self._device = device
192197
self._block_size = block_size
198+
self._block_bucketize_row_pos: list[torch.Tensor] = (
199+
[] if block_bucketize_row_pos is None else block_bucketize_row_pos
200+
)
193201

194202
def forward(
195203
self,
196204
ids: torch.Tensor,
197-
) -> Tuple[List[torch.Tensor], torch.Tensor]:
205+
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
198206
"""
199207
Bucketizes ids tensor into a list of tensors each containing ids
200208
for the corresponding rank. Places each tensor on the appropriate device.
@@ -203,24 +211,34 @@ def forward(
203211
ids (torch.Tensor): Tensor with ids
204212
205213
Returns:
206-
Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing list of ids tensors
207-
for each rank given the bucket sizes, and the tensor containing indices
208-
to permute the ids to get the original order before bucketization.
214+
Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
215+
Tuple containing
216+
1. list of ids tensors for each rank given the bucket sizes
217+
2. the tensor containing indices to permute the ids to get the original order before bucketization.
218+
3. the tensor containing the bucket mapping for each id
219+
4. the tensor containing the bucketized lengths
209220
"""
210221
(
211222
bucketized_lengths,
212223
bucketized_indices,
213-
_bucketized_weights,
214-
_bucketize_permute,
224+
_, # bucketized_weights
225+
_, # _bucketize_permute
215226
unbucketize_permute,
216-
) = torch.ops.fbgemm.block_bucketize_sparse_features(
217-
_get_bucketize_shape(ids, ids.device),
218-
ids.long(),
227+
bucket_mapping,
228+
) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
229+
lengths=_get_bucketize_shape(ids, ids.device),
230+
indices=ids.long(),
219231
bucketize_pos=False,
220232
sequence=True,
221233
block_sizes=self._block_size.long(),
222234
my_size=self._world_size,
223235
weights=None,
236+
block_bucketize_pos=(
237+
self._block_bucketize_row_pos
238+
if len(self._block_bucketize_row_pos) > 0
239+
else None
240+
),
241+
return_bucket_mapping=True,
224242
)
225243

226244
id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
@@ -236,7 +254,13 @@ def forward(
236254
)
237255

238256
assert unbucketize_permute is not None, "unbucketize permute must not be None"
239-
return dist_ids, unbucketize_permute
257+
assert bucket_mapping is not None, "bucket mapping must not be None"
258+
return (
259+
dist_ids,
260+
unbucketize_permute,
261+
bucket_mapping,
262+
bucketized_lengths,
263+
)
240264

241265
def update(
242266
self,
@@ -270,6 +294,11 @@ def update(
270294
block_sizes=self._block_size.long(),
271295
my_size=self._world_size,
272296
weights=None,
297+
block_bucketize_pos=(
298+
self._block_bucketize_row_pos
299+
if len(self._block_bucketize_row_pos) > 0
300+
else None
301+
),
273302
)
274303

275304
id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)

torchrec/distributed/sharding/rw_tensor_pool_sharding.py

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -219,11 +219,16 @@ def __init__(
219219
device: torch.device,
220220
) -> None:
221221
super().__init__()
222-
self._device: Optional[torch.device] = device
222+
self._device: torch.device = device
223223
self._world_size: int = env.world_size
224224
self._cat_dim = 0
225225
self._placeholder: torch.Tensor = torch.ones(1, device=device)
226226

227+
# Re-targets this module to the device named by `device_str`: stores the parsed
# torch.device in `_device` and rebuilds the one-element `_placeholder` tensor
# on that device. torch.jit.export marks the method as an entry point to be
# compiled when the module is scripted.
@torch.jit.export
228+
def set_device(self, device_str: str) -> None:
229+
self._device = torch.device(device_str)
230+
self._placeholder = torch.ones(1, device=self._device)
231+
227232
def forward(
228233
self,
229234
lookups: List[torch.Tensor],
@@ -256,12 +261,16 @@ def __init__(
256261
pool_size: int,
257262
env: ShardingEnv,
258263
device: torch.device,
264+
memory_capacity_per_rank: Optional[list[int]] = None,
259265
) -> None:
260-
super().__init__(pool_size, env, device)
266+
super().__init__(pool_size, env, device, memory_capacity_per_rank)
261267

262268
def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist:
263269
return InferRwObjectPoolInputDist(
264-
self._env, device=self._device, block_size=self._block_size_t
270+
self._env,
271+
device=self._device,
272+
block_size=self._block_size_t,
273+
block_bucketize_row_pos=self._block_bucketize_row_pos,
265274
)
266275

267276
def create_lookup_values_dist(

torchrec/distributed/tensor_pool.py

Lines changed: 75 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,14 @@
3232
)
3333
from torchrec.modules.object_pool_lookups import TensorLookup, TensorPoolLookup
3434
from torchrec.modules.tensor_pool import TensorPool
35-
from torchrec.modules.utils import deterministic_dedup
35+
from torchrec.modules.utils import (
36+
_get_batching_hinted_output,
37+
_get_unbucketize_tensor_via_length_alignment,
38+
deterministic_dedup,
39+
)
40+
41+
torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
42+
torch.fx.wrap("_get_batching_hinted_output")
3643

3744

3845
@torch.fx.wrap
@@ -44,6 +51,17 @@ def index_select_view(
4451
return output[unbucketize_permute].view(-1, dim)
4552

4653

54+
# Narrows an Optional[torch.Tensor] to torch.Tensor: asserts the argument is
# not None and returns it unchanged. Registered with torch.fx.wrap so symbolic
# tracing records a call to this function rather than tracing into its body.
@torch.fx.wrap
55+
def _fx_item_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
56+
assert optional is not None, "Expected optional to be non-None Tensor"
57+
return optional
58+
59+
60+
# Returns a one-element int64 tensor holding ids.size(0), created on the same
# device as `ids`. Registered with torch.fx.wrap so symbolic tracing records a
# call to this function rather than tracing into its body.
@torch.fx.wrap
61+
def _get_id_length_sharded_tensor_pool(ids: torch.Tensor) -> torch.Tensor:
62+
return torch.tensor([ids.size(dim=0)], device=ids.device, dtype=torch.long)
63+
64+
4765
class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
4866
def __init__(
4967
self,
@@ -271,6 +289,8 @@ class LocalShardPool(torch.nn.Module):
271289
# out is tensor([1,2,3]) i.e. first row of the shard
272290
"""
273291

292+
current_device: torch.device
293+
274294
def __init__(
275295
self,
276296
shard: torch.Tensor,
@@ -280,6 +300,12 @@ def __init__(
280300
shard,
281301
requires_grad=False,
282302
)
303+
self.current_device = self._shard.device
304+
305+
# Records the device parsed from `device_str` in `current_device`, then calls
# `.to(...)` on the shard. torch.jit.export marks the method as an entry point
# to be compiled when the module is scripted.
@torch.jit.export
306+
def set_device(self, device_str: str) -> None:
307+
self.current_device = torch.device(device_str)
308+
# NOTE(review): the result of `.to(...)` is discarded. If `_shard` is a
# Tensor/Parameter, `Tensor.to` returns a tensor instead of moving it in place,
# so the shard would stay on its old device while forward() moves `rank_ids`
# to `current_device` — confirm, and rebind the shard if so.
self._shard.to(self.current_device)
283309

284310
def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
285311
"""
@@ -291,7 +317,7 @@ def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
291317
Returns:
292318
torch.Tensor: Tensor of values corresponding to the given rank ids.
293319
"""
294-
return self._shard[rank_ids]
320+
return self._shard[rank_ids.to(self.current_device)]
295321

296322
def update(self, rank_ids: torch.Tensor, values: torch.Tensor) -> None:
297323
_ = update(self._shard, rank_ids, values)
@@ -337,6 +363,11 @@ def __init__(
337363
env=self._sharding_env,
338364
device=self._device,
339365
pool_size=self._pool_size,
366+
memory_capacity_per_rank=(
367+
self._sharding_plan.memory_capacity_per_rank
368+
if self._sharding_plan.memory_capacity_per_rank is not None
369+
else None
370+
),
340371
)
341372
else:
342373
raise NotImplementedError(
@@ -356,6 +387,7 @@ def __init__(
356387
if device == torch.device("cpu")
357388
else torch.device("cuda", rank)
358389
)
390+
359391
self._local_shard_pools.append(
360392
LocalShardPool(
361393
torch.empty(
@@ -409,7 +441,7 @@ def create_context(self) -> ObjectPoolShardingContext:
409441
def _lookup_ids_dist(
410442
self,
411443
ids: torch.Tensor,
412-
) -> Tuple[List[torch.Tensor], torch.Tensor]:
444+
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
413445
return self._lookup_ids_dist_impl(ids)
414446

415447
# pyre-ignore
@@ -439,18 +471,54 @@ def _lookup_values_dist(
439471

440472
# pyre-ignore
441473
def forward(self, ids: torch.Tensor) -> torch.Tensor:
442-
dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
474+
dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
475+
self._lookup_ids_dist(ids)
476+
)
477+
unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
478+
unbucketize_permute
479+
)
480+
443481
lookup = self._lookup_local(dist_input)
444482

445483
# Here we are playing a trick to workaround a fx tracing issue,
446484
# as proxy is not iterable.
447485
lookup_list = []
448-
for i in range(self._world_size):
449-
lookup_list.append(lookup[i])
486+
# In case of non-heterogeneous (even) sharding, keep the behavior
487+
# consistent with existing logic to ensure that additional fx wrappers
488+
# do not impact the model split logic during inference in any way
489+
if self._sharding_plan.memory_capacity_per_rank is None:
490+
for i in range(self._world_size):
491+
lookup_list.append(lookup[i])
492+
else:
493+
# Adding fx wrappers in case of uneven heterogeneous sharding to
494+
# make it compatible with model split boundaries during inference
495+
for i in range(self._world_size):
496+
lookup_list.append(
497+
_get_batching_hinted_output(
498+
_get_id_length_sharded_tensor_pool(dist_input[i]), lookup[i]
499+
)
500+
)
501+
502+
features_before_input_dist_length = _get_id_length_sharded_tensor_pool(ids)
503+
bucketized_lengths_col_view = bucketized_lengths.view(self._world_size, -1)
504+
unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
505+
unbucketize_permute
506+
)
507+
bucket_mapping_non_opt = _fx_item_unwrap_optional_tensor(bucket_mapping)
508+
unbucketize_permute_non_opt = _get_unbucketize_tensor_via_length_alignment(
509+
features_before_input_dist_length,
510+
bucketized_lengths_col_view,
511+
unbucketize_permute_non_opt,
512+
bucket_mapping_non_opt,
513+
)
450514

451515
output = self._lookup_values_dist(lookup_list)
452516

453-
return index_select_view(output, unbucketize_permute, self._dim)
517+
return index_select_view(
518+
output,
519+
unbucketize_permute_non_opt.to(device=output.device),
520+
self._dim,
521+
)
454522

455523
# pyre-ignore
456524
def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):

torchrec/distributed/tensor_sharding.py

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def __init__(
102102
pool_size: int,
103103
env: ShardingEnv,
104104
device: torch.device,
105+
memory_capacity_per_rank: Optional[list[int]] = None,
105106
) -> None:
106107
self._pool_size = pool_size
107108
self._env = env
@@ -117,13 +118,40 @@ def __init__(
117118
self._last_block_size: int = self._pool_size - self._block_size * (
118119
self._world_size - 1
119120
)
120-
self.local_pool_size_per_rank: List[int] = [self._block_size] * (
121-
self._world_size - 1
122-
) + [self._last_block_size]
123-
121+
# only used for uneven sharding case when memory_capacity_per_rank is provided
122+
row_offset_per_rank = []
123+
124+
if memory_capacity_per_rank is None:
125+
self.local_pool_size_per_rank: List[int] = [self._block_size] * (
126+
self._world_size - 1
127+
) + [self._last_block_size]
128+
else:
129+
row_offset_per_rank = [0]
130+
self.local_pool_size_per_rank: List[int] = []
131+
row_offset = 0
132+
assert (
133+
len(memory_capacity_per_rank) == self._world_size
134+
), "If memory_capacity_per_rank is provided for sharded tensor pool, it must have the same length as world_size"
135+
total_mem_cap = sum(memory_capacity_per_rank)
136+
for cap in memory_capacity_per_rank[:-1]:
137+
rows_per_shard = int(cap / total_mem_cap * self._pool_size)
138+
self.local_pool_size_per_rank.append(rows_per_shard)
139+
row_offset += rows_per_shard
140+
row_offset_per_rank.append(row_offset)
141+
self.local_pool_size_per_rank.append(
142+
self._pool_size - sum(self.local_pool_size_per_rank)
143+
)
144+
row_offset_per_rank.append(self._pool_size)
124145
self._block_size_t: torch.Tensor = torch.tensor(
125146
[self._block_size], device=self._device, dtype=torch.long
126147
)
148+
# for uneven sharding case, we get the row offsets for each rank to
149+
# enable input_dist and lookup of ids on the correct rank
150+
self._block_bucketize_row_pos: Optional[List[torch.Tensor]] = (
151+
None
152+
if memory_capacity_per_rank is None
153+
else [torch.tensor(row_offset_per_rank, device=self._device)]
154+
)
127155

128156
@abstractmethod
129157
def create_lookup_ids_dist(self) -> torch.nn.Module:

0 commit comments

Comments
 (0)