3232)
3333from torchrec .modules .object_pool_lookups import TensorLookup , TensorPoolLookup
3434from torchrec .modules .tensor_pool import TensorPool
35- from torchrec .modules .utils import deterministic_dedup
35+ from torchrec .modules .utils import (
36+ _get_batching_hinted_output ,
37+ _get_unbucketize_tensor_via_length_alignment ,
38+ deterministic_dedup ,
39+ )
40+
41+ torch .fx .wrap ("_get_unbucketize_tensor_via_length_alignment" )
42+ torch .fx .wrap ("_get_batching_hinted_output" )
3643
3744
3845@torch .fx .wrap
@@ -44,6 +51,17 @@ def index_select_view(
4451 return output [unbucketize_permute ].view (- 1 , dim )
4552
4653
54+ @torch .fx .wrap
55+ def _fx_item_unwrap_optional_tensor (optional : Optional [torch .Tensor ]) -> torch .Tensor :
56+ assert optional is not None , "Expected optional to be non-None Tensor"
57+ return optional
58+
59+
60+ @torch .fx .wrap
61+ def _get_id_length_sharded_tensor_pool (ids : torch .Tensor ) -> torch .Tensor :
62+ return torch .tensor ([ids .size (dim = 0 )], device = ids .device , dtype = torch .long )
63+
64+
4765class TensorPoolAwaitable (LazyAwaitable [torch .Tensor ]):
4866 def __init__ (
4967 self ,
@@ -271,6 +289,8 @@ class LocalShardPool(torch.nn.Module):
271289 # out is tensor([1,2,3]) i.e. first row of the shard
272290 """
273291
292+ current_device : torch .device
293+
274294 def __init__ (
275295 self ,
276296 shard : torch .Tensor ,
@@ -280,6 +300,12 @@ def __init__(
280300 shard ,
281301 requires_grad = False ,
282302 )
303+ self .current_device = self ._shard .device
304+
305+ @torch .jit .export
306+ def set_device (self , device_str : str ) -> None :
307+ self .current_device = torch .device (device_str )
308+ self ._shard .to (self .current_device )
283309
284310 def forward (self , rank_ids : torch .Tensor ) -> torch .Tensor :
285311 """
@@ -291,7 +317,7 @@ def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
291317 Returns:
292318 torch.Tensor: Tensor of values corresponding to the given rank ids.
293319 """
294- return self ._shard [rank_ids ]
320+ return self ._shard [rank_ids . to ( self . current_device ) ]
295321
296322 def update (self , rank_ids : torch .Tensor , values : torch .Tensor ) -> None :
297323 _ = update (self ._shard , rank_ids , values )
@@ -337,6 +363,11 @@ def __init__(
337363 env = self ._sharding_env ,
338364 device = self ._device ,
339365 pool_size = self ._pool_size ,
366+ memory_capacity_per_rank = (
367+ self ._sharding_plan .memory_capacity_per_rank
368+ if self ._sharding_plan .memory_capacity_per_rank is not None
369+ else None
370+ ),
340371 )
341372 else :
342373 raise NotImplementedError (
@@ -356,6 +387,7 @@ def __init__(
356387 if device == torch .device ("cpu" )
357388 else torch .device ("cuda" , rank )
358389 )
390+
359391 self ._local_shard_pools .append (
360392 LocalShardPool (
361393 torch .empty (
@@ -409,7 +441,7 @@ def create_context(self) -> ObjectPoolShardingContext:
409441 def _lookup_ids_dist (
410442 self ,
411443 ids : torch .Tensor ,
412- ) -> Tuple [List [torch .Tensor ], torch .Tensor ]:
444+ ) -> Tuple [List [torch .Tensor ], torch .Tensor , torch . Tensor , torch . Tensor ]:
413445 return self ._lookup_ids_dist_impl (ids )
414446
415447 # pyre-ignore
@@ -439,18 +471,54 @@ def _lookup_values_dist(
439471
440472 # pyre-ignore
441473 def forward (self , ids : torch .Tensor ) -> torch .Tensor :
442- dist_input , unbucketize_permute = self ._lookup_ids_dist (ids )
474+ dist_input , unbucketize_permute , bucket_mapping , bucketized_lengths = (
475+ self ._lookup_ids_dist (ids )
476+ )
477+ unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor (
478+ unbucketize_permute
479+ )
480+
443481 lookup = self ._lookup_local (dist_input )
444482
445483 # Here we are playing a trick to workaround a fx tracing issue,
446484 # as proxy is not iteratable.
447485 lookup_list = []
448- for i in range (self ._world_size ):
449- lookup_list .append (lookup [i ])
486+ # In case of non-heterogenous even sharding keeping the behavior
487+ # consistent with existing logic to ensure that additional fx wrappers
488+ # do not impact the model split logic during inference in anyway
489+ if self ._sharding_plan .memory_capacity_per_rank is None :
490+ for i in range (self ._world_size ):
491+ lookup_list .append (lookup [i ])
492+ else :
493+ # Adding fx wrappers in case of uneven heterogenous sharding to
494+ # make it compatible with model split boundaries during inference
495+ for i in range (self ._world_size ):
496+ lookup_list .append (
497+ _get_batching_hinted_output (
498+ _get_id_length_sharded_tensor_pool (dist_input [i ]), lookup [i ]
499+ )
500+ )
501+
502+ features_before_input_dist_length = _get_id_length_sharded_tensor_pool (ids )
503+ bucketized_lengths_col_view = bucketized_lengths .view (self ._world_size , - 1 )
504+ unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor (
505+ unbucketize_permute
506+ )
507+ bucket_mapping_non_opt = _fx_item_unwrap_optional_tensor (bucket_mapping )
508+ unbucketize_permute_non_opt = _get_unbucketize_tensor_via_length_alignment (
509+ features_before_input_dist_length ,
510+ bucketized_lengths_col_view ,
511+ unbucketize_permute_non_opt ,
512+ bucket_mapping_non_opt ,
513+ )
450514
451515 output = self ._lookup_values_dist (lookup_list )
452516
453- return index_select_view (output , unbucketize_permute , self ._dim )
517+ return index_select_view (
518+ output ,
519+ unbucketize_permute_non_opt .to (device = output .device ),
520+ self ._dim ,
521+ )
454522
455523 # pyre-ignore
456524 def _update_values_dist (self , ctx : ObjectPoolShardingContext , values : torch .Tensor ):
0 commit comments