3 changes: 3 additions & 0 deletions torchrec/distributed/embeddingbag.py
@@ -55,6 +55,7 @@
FUSED_PARAM_IS_SSD_TABLE,
FUSED_PARAM_SSD_TABLE_LIST,
)
from torchrec.distributed.logger import _torchrec_method_logger
from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
from torchrec.distributed.sharding.dynamic_sharding import (
@@ -466,6 +467,7 @@ class ShardedEmbeddingBagCollection(
This is part of the public API to allow for manual data dist pipelining.
"""

@_torchrec_method_logger()
def __init__(
self,
module: EmbeddingBagCollectionInterface,
@@ -2021,6 +2023,7 @@ class ShardedEmbeddingBag(
This is part of the public API to allow for manual data dist pipelining.
"""

@_torchrec_method_logger()
def __init__(
self,
module: nn.EmbeddingBag,
2 changes: 2 additions & 0 deletions torchrec/distributed/planner/planners.py
@@ -18,6 +18,7 @@
from torch import nn
from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.logger import _torchrec_method_logger
from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE
from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
from torchrec.distributed.planner.partitioners import (
@@ -498,6 +499,7 @@ def collective_plan(
sharders,
)

@_torchrec_method_logger()
def plan(
self,
module: nn.Module,
2 changes: 2 additions & 0 deletions torchrec/distributed/planner/shard_estimators.py
@@ -16,6 +16,7 @@
import torchrec.optim as trec_optim
from torch import nn
from torchrec.distributed.embedding_types import EmbeddingComputeKernel
from torchrec.distributed.logger import _torchrec_method_logger
from torchrec.distributed.planner.constants import (
BATCHED_COPY_PERF_FACTOR,
BIGINT_DTYPE,
@@ -955,6 +956,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
is_inference (bool): If the model is inference model. Default to False.
"""

@_torchrec_method_logger()
def __init__(
self,
topology: Topology,
3 changes: 3 additions & 0 deletions torchrec/distributed/shard.py
@@ -15,6 +15,7 @@
from torch.distributed._composable.contract import contract
from torchrec.distributed.comm import get_local_size
from torchrec.distributed.global_settings import get_propogate_device
from torchrec.distributed.logger import _torchrec_method_logger
from torchrec.distributed.model_parallel import get_default_sharders
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.distributed.sharding_plan import (
@@ -146,6 +147,7 @@ def _shard(

# pyre-ignore
@contract()
@_torchrec_method_logger()
def shard_modules(
module: nn.Module,
env: Optional[ShardingEnv] = None,
@@ -194,6 +196,7 @@ def init_weights(m):
return _shard_modules(module, env, device, plan, sharders, init_params)


@_torchrec_method_logger()
def _shard_modules( # noqa: C901
module: nn.Module,
# TODO: Consolidate to using Dict[str, ShardingEnv]
3 changes: 3 additions & 0 deletions torchrec/distributed/train_pipeline/train_pipelines.py
@@ -32,6 +32,7 @@
import torch
from torch.autograd.profiler import record_function
from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
from torchrec.distributed.logger import _torchrec_method_logger
from torchrec.distributed.model_parallel import ShardedModule
from torchrec.distributed.train_pipeline.pipeline_context import (
EmbeddingTrainPipelineContext,
@@ -106,6 +107,8 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
def progress(self, dataloader_iter: Iterator[In]) -> Out:
pass

# pyre-ignore [56]
@_torchrec_method_logger()
def __init__(self) -> None:
# pipeline state such as in forward, in backward etc., used in training recovery scenarios
self._state: PipelineState = PipelineState.IDLE
3 changes: 3 additions & 0 deletions torchrec/modules/mc_embedding_modules.py
@@ -12,6 +12,7 @@

import torch
import torch.nn as nn
from torchrec.distributed.logger import _torchrec_method_logger

from torchrec.modules.embedding_modules import (
EmbeddingBagCollection,
@@ -125,6 +126,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio

"""

@_torchrec_method_logger()
def __init__(
self,
embedding_collection: EmbeddingCollection,
@@ -164,6 +166,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingCollec

"""

@_torchrec_method_logger()
def __init__(
self,
embedding_bag_collection: EmbeddingBagCollection,
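For context: `@_torchrec_method_logger()` is a decorator factory imported from `torchrec.distributed.logger`; its implementation is not part of this diff. Below is a minimal sketch of how such a factory could work, assuming it simply wraps the target callable and emits one log record per invocation. The names, signature details, and logging behavior here are illustrative assumptions, not TorchRec's actual implementation.

```python
# Hypothetical sketch only -- the real implementation lives in
# torchrec/distributed/logger.py and is not shown in this diff.
import functools
import logging
from typing import Any, Callable, TypeVar

logger: logging.Logger = logging.getLogger(__name__)

F = TypeVar("F", bound=Callable[..., Any])


def _torchrec_method_logger(**metadata: Any) -> Callable[[F], F]:
    """Decorator factory: wrap a callable and log each invocation."""

    def decorator(func: F) -> F:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # Record the qualified name of the wrapped callable plus any
            # static metadata passed to the factory, then delegate.
            logger.debug(
                "torchrec call: %s metadata=%s", func.__qualname__, metadata
            )
            return func(*args, **kwargs)

        return wrapper  # type: ignore[return-value]

    return decorator


# Example usage (hypothetical), mirroring the call sites in this diff:
# @_torchrec_method_logger()
# def plan(self, module, sharders): ...
```

Because the factory is applied with parentheses at every call site in this diff (the `__init__` methods, `plan`, `shard_modules`, and `_shard_modules`), each decorated entry point is wrapped at import time and a log record is emitted whenever it is invoked.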