4 changes: 2 additions & 2 deletions torchrec/distributed/keyed_jagged_tensor_pool.py
@@ -608,7 +608,7 @@ def create_context(self) -> ObjectPoolShardingContext:
def _lookup_ids_dist(
self,
ids: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
return self._lookup_ids_dist_impl(ids)

# pyre-ignore
@@ -630,7 +630,7 @@ def _lookup_values_dist(

# pyre-ignore
def forward(self, ids: torch.Tensor) -> KeyedJaggedTensor:
dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
dist_input, unbucketize_permute, _, _ = self._lookup_ids_dist(ids)
lookup = self._lookup_local(dist_input)
# Here we are playing a trick to work around an fx tracing issue,
# as the proxy is not iterable.
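The comment in the hunk above refers to a standard torch.fx constraint: a `Proxy` cannot be iterated, so the module indexes the lookup result with a fixed `range(world_size)` loop instead. A minimal, standalone sketch of that pattern (the module below is made up for illustration, it is not torchrec code):

```python
from typing import List

import torch
from torch import fx


class GatherPerRank(torch.nn.Module):
    """Illustrative stand-in for a module that concatenates per-rank lookups."""

    def __init__(self, world_size: int) -> None:
        super().__init__()
        self._world_size = world_size

    def forward(self, lookup: List[torch.Tensor]) -> torch.Tensor:
        out = []
        # Indexing with a fixed Python range records one getitem node per rank
        # in the traced graph; `for t in lookup:` would try to iterate the fx
        # Proxy, which has no static length, and tracing would fail.
        for i in range(self._world_size):
            out.append(lookup[i])
        return torch.cat(out, dim=0)


gm = fx.symbolic_trace(GatherPerRank(world_size=2))
print(gm([torch.ones(2, 3), torch.zeros(1, 3)]).shape)  # torch.Size([3, 3])
```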
12 changes: 2 additions & 10 deletions torchrec/distributed/quant_embedding.py
@@ -84,6 +84,7 @@
from torchrec.modules.utils import (
_fx_trec_get_feature_length,
_get_batching_hinted_output,
_get_unbucketize_tensor_via_length_alignment,
)
from torchrec.quant.embedding_modules import (
EmbeddingCollection as QuantEmbeddingCollection,
@@ -96,6 +97,7 @@
torch.fx.wrap("len")
torch.fx.wrap("_get_batching_hinted_output")
torch.fx.wrap("_fx_trec_get_feature_length")
torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")

try:
torch.ops.load_library("//deeplearning/fbgemm/fbgemm_gpu:sparse_ops")
@@ -278,16 +280,6 @@ def _fx_trec_wrap_length_tolist(length: torch.Tensor) -> List[int]:
return length.long().tolist()


@torch.fx.wrap
def _get_unbucketize_tensor_via_length_alignment(
lengths: torch.Tensor,
bucketize_length: torch.Tensor,
bucketize_permute_tensor: torch.Tensor,
bucket_mapping_tensor: torch.Tensor,
) -> torch.Tensor:
return bucketize_permute_tensor


@torch.fx.wrap
def _fx_split_embeddings_per_feature_length(
embeddings: torch.Tensor,
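The helper removed here now lives in `torchrec.modules.utils` and is registered with `torch.fx.wrap` at the top of the file. For readers unfamiliar with that call: wrapping keeps a free function as a single opaque node in a symbolically traced graph instead of tracing into its body. A minimal sketch with a made-up helper (`pick_permute` is not a torchrec function):

```python
import torch
from torch import fx


@torch.fx.wrap  # recorded as one call_function node during symbolic tracing
def pick_permute(lengths: torch.Tensor, permute: torch.Tensor) -> torch.Tensor:
    # Data-dependent control flow like this would otherwise be burned into
    # (or break) the traced graph.
    return permute if bool(lengths.sum() > 0) else permute.flip(0)


class Wrapped(torch.nn.Module):
    def forward(self, lengths: torch.Tensor, permute: torch.Tensor) -> torch.Tensor:
        return pick_permute(lengths, permute)


gm = fx.symbolic_trace(Wrapped())
print(gm.graph)  # shows a call to pick_permute rather than its internals
```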
49 changes: 39 additions & 10 deletions torchrec/distributed/sharding/rw_pool_sharding.py
@@ -166,6 +169,9 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
block_size (torch.Tensor): tensor containing block sizes for each rank.
e.g. if block_size=torch.tensor(100), then IDs 0-99 will be assigned to rank
0, 100-199 to rank 1, and so on.
block_bucketize_row_pos (Optional[list[torch.Tensor]]): tensor containing shard/row offsets for each
rank in the case of uneven sharding of the tensor pool across ranks. If not provided,
block_size is used to route the IDs across ranks.

Example:
device = torch.device("cpu")
@@ -179,22 +182,27 @@ class InferRwObjectPoolInputDist(torch.nn.Module):
_world_size: int
_device: torch.device
_block_size: torch.Tensor
_block_bucketize_row_pos: list[torch.Tensor]

def __init__(
self,
env: ShardingEnv,
device: torch.device,
block_size: torch.Tensor,
block_bucketize_row_pos: Optional[list[torch.Tensor]] = None,
) -> None:
super().__init__()
self._world_size = env.world_size
self._device = device
self._block_size = block_size
self._block_bucketize_row_pos: list[torch.Tensor] = (
[] if block_bucketize_row_pos is None else block_bucketize_row_pos
)

def forward(
self,
ids: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
"""
Bucketizes ids tensor into a list of tensors each containing ids
for the corresponding rank. Places each tensor on the appropriate device.
@@ -203,24 +211,34 @@ def forward(
ids (torch.Tensor): Tensor with ids

Returns:
Tuple[List[torch.Tensor], torch.Tensor]: Tuple containing list of ids tensors
for each rank given the bucket sizes, and the tensor containing indices
to permute the ids to get the original order before bucketization.
Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
Tuple containing
1. the list of ids tensors for each rank given the bucket sizes
2. the tensor of indices that permute the bucketized ids back to their original order
3. the tensor containing the bucket mapping for each id
4. the tensor containing the bucketized lengths
"""
(
bucketized_lengths,
bucketized_indices,
_bucketized_weights,
_bucketize_permute,
_, # bucketized_weights
_, # _bucketize_permute
unbucketize_permute,
) = torch.ops.fbgemm.block_bucketize_sparse_features(
_get_bucketize_shape(ids, ids.device),
ids.long(),
bucket_mapping,
) = torch.ops.fbgemm.block_bucketize_sparse_features_inference(
lengths=_get_bucketize_shape(ids, ids.device),
indices=ids.long(),
bucketize_pos=False,
sequence=True,
block_sizes=self._block_size.long(),
my_size=self._world_size,
weights=None,
block_bucketize_pos=(
self._block_bucketize_row_pos
if len(self._block_bucketize_row_pos) > 0
else None
),
return_bucket_mapping=True,
)

id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
@@ -236,7 +254,13 @@ def forward(
)

assert unbucketize_permute is not None, "unbucketize permute must not be None"
return dist_ids, unbucketize_permute
assert bucket_mapping is not None, "bucket mapping must not be None"
return (
dist_ids,
unbucketize_permute,
bucket_mapping,
bucketized_lengths,
)

def update(
self,
@@ -270,6 +294,11 @@ def update(
block_sizes=self._block_size.long(),
my_size=self._world_size,
weights=None,
block_bucketize_pos=(
self._block_bucketize_row_pos
if len(self._block_bucketize_row_pos) > 0
else None
),
)

id_offsets = torch.ops.fbgemm.asynchronous_complete_cumsum(bucketized_lengths)
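The new return contract of `forward` (per-rank ids, unbucketize permute, bucket mapping, bucketized lengths) and the role of `block_bucketize_row_pos` can be sketched in plain PyTorch. This is a toy stand-in for the FBGEMM inference bucketizer with a made-up 2-rank boundary layout, not its actual implementation:

```python
from typing import List, Tuple

import torch


def toy_bucketize(
    ids: torch.Tensor, row_pos: torch.Tensor
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
    """Toy bucketizer: row_pos holds cumulative shard bounds, e.g. [0, 120, 200]
    means rank 0 owns ids [0, 120) and rank 1 owns ids [120, 200)."""
    bucket_mapping = torch.bucketize(ids, row_pos[1:-1], right=True)  # owning rank per id
    world_size = row_pos.numel() - 1
    dist_ids = [ids[bucket_mapping == r] for r in range(world_size)]
    bucketized_lengths = torch.tensor([t.numel() for t in dist_ids])
    # Permutation that maps the rank-ordered concatenation back to input order.
    order = torch.sort(bucket_mapping, stable=True).indices
    unbucketize_permute = torch.argsort(order)
    return dist_ids, unbucketize_permute, bucket_mapping, bucketized_lengths


ids = torch.tensor([150, 3, 119, 180])
dist_ids, unbucketize_permute, bucket_mapping, bucketized_lengths = toy_bucketize(
    ids, row_pos=torch.tensor([0, 120, 200])  # uneven shards: 120 rows vs 80 rows
)
# Concatenating per-rank results and indexing with unbucketize_permute restores
# the caller's original id order.
assert torch.equal(torch.cat(dist_ids)[unbucketize_permute], ids)
```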
15 changes: 13 additions & 2 deletions torchrec/distributed/sharding/rw_tensor_pool_sharding.py
@@ -213,6 +213,8 @@ class InferRwTensorPoolOutputDist(torch.nn.Module):
vals = torch.Tensor([1,2,3,4,5,6], device=device)
"""

__annotations__ = {"_device": Optional[torch.device]}

def __init__(
self,
env: ShardingEnv,
@@ -224,6 +226,11 @@ def __init__(
self._cat_dim = 0
self._placeholder: torch.Tensor = torch.ones(1, device=device)

@torch.jit.export
def set_device(self, device_str: str) -> None:
self._device = torch.device(device_str)
self._placeholder = torch.ones(1, device=self._device)

def forward(
self,
lookups: List[torch.Tensor],
@@ -256,12 +263,16 @@ def __init__(
pool_size: int,
env: ShardingEnv,
device: torch.device,
memory_capacity_per_rank: Optional[list[int]] = None,
) -> None:
super().__init__(pool_size, env, device)
super().__init__(pool_size, env, device, memory_capacity_per_rank)

def create_lookup_ids_dist(self) -> InferRwObjectPoolInputDist:
return InferRwObjectPoolInputDist(
self._env, device=self._device, block_size=self._block_size_t
self._env,
device=self._device,
block_size=self._block_size_t,
block_bucketize_row_pos=self._block_bucketize_row_pos,
)

def create_lookup_values_dist(
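The `set_device` methods added in this file and in `LocalShardPool` below follow a common TorchScript pattern: a method decorated with `@torch.jit.export` survives scripting even though `forward` never calls it, so a loaded model can be re-pointed at a device at serving time. A minimal standalone sketch (the class below is illustrative, not torchrec's):

```python
import torch


class DevicePlaceholder(torch.nn.Module):
    # Explicit attribute annotation, in the spirit of the __annotations__ entry
    # the diff adds for _device above.
    _device: torch.device

    def __init__(self) -> None:
        super().__init__()
        self._device = torch.device("cpu")
        self._placeholder = torch.ones(1, device=self._device)

    @torch.jit.export
    def set_device(self, device_str: str) -> None:
        # Exported methods are kept in the scripted module even though they are
        # unreachable from forward().
        self._device = torch.device(device_str)
        self._placeholder = torch.ones(1, device=self._device)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.to(self._device) + self._placeholder


scripted = torch.jit.script(DevicePlaceholder())
scripted.set_device("cpu")  # would be "cuda:0" on a GPU serving host
print(scripted(torch.zeros(3)))  # tensor([1., 1., 1.])
```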
82 changes: 75 additions & 7 deletions torchrec/distributed/tensor_pool.py
@@ -32,7 +32,14 @@
)
from torchrec.modules.object_pool_lookups import TensorLookup, TensorPoolLookup
from torchrec.modules.tensor_pool import TensorPool
from torchrec.modules.utils import deterministic_dedup
from torchrec.modules.utils import (
_get_batching_hinted_output,
_get_unbucketize_tensor_via_length_alignment,
deterministic_dedup,
)

torch.fx.wrap("_get_unbucketize_tensor_via_length_alignment")
torch.fx.wrap("_get_batching_hinted_output")


@torch.fx.wrap
@@ -44,6 +51,17 @@ def index_select_view(
return output[unbucketize_permute].view(-1, dim)


@torch.fx.wrap
def _fx_item_unwrap_optional_tensor(optional: Optional[torch.Tensor]) -> torch.Tensor:
assert optional is not None, "Expected optional to be non-None Tensor"
return optional


@torch.fx.wrap
def _get_id_length_sharded_tensor_pool(ids: torch.Tensor) -> torch.Tensor:
return torch.tensor([ids.size(dim=0)], device=ids.device, dtype=torch.long)


class TensorPoolAwaitable(LazyAwaitable[torch.Tensor]):
def __init__(
self,
@@ -271,6 +289,8 @@ class LocalShardPool(torch.nn.Module):
# out is tensor([1,2,3]) i.e. first row of the shard
"""

current_device: torch.device

def __init__(
self,
shard: torch.Tensor,
@@ -280,6 +300,12 @@ def __init__(
shard,
requires_grad=False,
)
self.current_device = self._shard.device

@torch.jit.export
def set_device(self, device_str: str) -> None:
self.current_device = torch.device(device_str)
self._shard.to(self.current_device)

def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
"""
@@ -291,7 +317,7 @@ def forward(self, rank_ids: torch.Tensor) -> torch.Tensor:
Returns:
torch.Tensor: Tensor of values corresponding to the given rank ids.
"""
return self._shard[rank_ids]
return self._shard[rank_ids.to(self.current_device)]

def update(self, rank_ids: torch.Tensor, values: torch.Tensor) -> None:
_ = update(self._shard, rank_ids, values)
@@ -337,6 +363,11 @@ def __init__(
env=self._sharding_env,
device=self._device,
pool_size=self._pool_size,
memory_capacity_per_rank=(
self._sharding_plan.memory_capacity_per_rank
if self._sharding_plan.memory_capacity_per_rank is not None
else None
),
)
else:
raise NotImplementedError(
@@ -356,6 +387,7 @@ def __init__(
if device == torch.device("cpu")
else torch.device("cuda", rank)
)

self._local_shard_pools.append(
LocalShardPool(
torch.empty(
@@ -409,7 +441,7 @@ def create_context(self) -> ObjectPoolShardingContext:
def _lookup_ids_dist(
self,
ids: torch.Tensor,
) -> Tuple[List[torch.Tensor], torch.Tensor]:
) -> Tuple[List[torch.Tensor], torch.Tensor, torch.Tensor, torch.Tensor]:
return self._lookup_ids_dist_impl(ids)

# pyre-ignore
@@ -439,18 +471,54 @@ def _lookup_values_dist(

# pyre-ignore
def forward(self, ids: torch.Tensor) -> torch.Tensor:
dist_input, unbucketize_permute = self._lookup_ids_dist(ids)
dist_input, unbucketize_permute, bucket_mapping, bucketized_lengths = (
self._lookup_ids_dist(ids)
)
unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
unbucketize_permute
)

lookup = self._lookup_local(dist_input)

# Here we are playing a trick to work around an fx tracing issue,
# as the proxy is not iterable.
lookup_list = []
for i in range(self._world_size):
lookup_list.append(lookup[i])
# In the case of non-heterogeneous (even) sharding, keep the behavior
# consistent with the existing logic so that the additional fx wrappers
# do not impact the model split logic during inference in any way.
if self._sharding_plan.memory_capacity_per_rank is None:
for i in range(self._world_size):
lookup_list.append(lookup[i])
else:
# Add fx wrappers in the case of uneven heterogeneous sharding to stay
# compatible with model split boundaries during inference.
for i in range(self._world_size):
lookup_list.append(
_get_batching_hinted_output(
_get_id_length_sharded_tensor_pool(dist_input[i]), lookup[i]
)
)

features_before_input_dist_length = _get_id_length_sharded_tensor_pool(ids)
bucketized_lengths_col_view = bucketized_lengths.view(self._world_size, -1)
unbucketize_permute_non_opt = _fx_item_unwrap_optional_tensor(
unbucketize_permute
)
bucket_mapping_non_opt = _fx_item_unwrap_optional_tensor(bucket_mapping)
unbucketize_permute_non_opt = _get_unbucketize_tensor_via_length_alignment(
features_before_input_dist_length,
bucketized_lengths_col_view,
unbucketize_permute_non_opt,
bucket_mapping_non_opt,
)

output = self._lookup_values_dist(lookup_list)

return index_select_view(output, unbucketize_permute, self._dim)
return index_select_view(
output,
unbucketize_permute_non_opt.to(device=output.device),
self._dim,
)

# pyre-ignore
def _update_values_dist(self, ctx: ObjectPoolShardingContext, values: torch.Tensor):
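Putting the tensor_pool.py pieces together: the lookup path routes ids to shards, looks rows up locally, concatenates the per-rank results, and finally applies the unbucketize permute via `index_select_view` to restore the caller's order. A toy end-to-end trace of that flow with made-up sizes (pure PyTorch, no distributed collectives or FBGEMM ops):

```python
import torch

dim = 2
# Two "ranks", each holding 4 rows of an 8-row pool.
shards = [torch.arange(8.0).view(4, dim), torch.arange(8.0, 16.0).view(4, dim)]

ids = torch.tensor([5, 1, 6, 0])                        # global row ids requested
bucket = ids // 4                                       # owning rank per id (even shards)
dist_input = [ids[bucket == r] % 4 for r in range(2)]   # local row ids per rank
order = torch.sort(bucket, stable=True).indices
unbucketize_permute = torch.argsort(order)

lookups = [shards[r][dist_input[r]] for r in range(2)]  # LocalShardPool.forward stand-in
output = torch.cat(lookups, dim=0)                      # values-dist / concat stand-in
restored = output[unbucketize_permute].view(-1, dim)    # what index_select_view does
assert torch.equal(restored, torch.cat(shards)[ids])
```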