
Commit 0f50c2b

nipung90 authored and facebook-github-bot committed
Enable logging for the plan() function, ShardEstimators and TrainingPipeline class constructors
Summary: This diff enables the static logging functionality to collect data for:

1) plan() - This will allow us to look at the inputs and outputs to the planner to help with user issue debugging.
2) ShardEstimators - This will allow us to look at the inputs and outputs to the ShardEstimators, which gives us the bandwidth inputs to verify whether the planner is generating expected values, as well as to help with debugging OOMs.
3) TrainingPipeline - The class type here will indicate which pipeline was used by the training job. The training pipeline has implications for memory usage and is an important data point to collect when investigating OOMs.

Reviewed By: kausv

Differential Revision: D86317910
1 parent 51be4ea · commit 0f50c2b

File tree: 6 files changed, +15 −0 lines changed
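The decorator applied throughout these diffs, `_torchrec_method_logger`, is defined in torchrec/distributed/logger.py and is not shown in this commit. As a rough, hypothetical sketch of the wrap-and-record pattern the summary describes (capture the wrapped callable's qualified name, arguments, and result), it might look something like the following; this is illustrative only, not the actual TorchRec implementation.

```python
# Hypothetical sketch only -- the real _torchrec_method_logger lives in
# torchrec/distributed/logger.py and is not part of this diff. This just
# illustrates the general wrap-and-record pattern described in the summary.
import functools
import logging
from typing import Any, Callable

logger: logging.Logger = logging.getLogger(__name__)


def _torchrec_method_logger(
    **_logger_kwargs: Any,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]:
    """Returns a decorator that records calls to the wrapped method."""

    def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
        @functools.wraps(func)
        def wrapper(*args: Any, **kwargs: Any) -> Any:
            # For __init__ methods, func.__qualname__ identifies the concrete
            # class being constructed (e.g. which TrainPipeline was used).
            logger.debug(
                "torchrec call: %s args=%s kwargs=%s",
                func.__qualname__,
                args,
                kwargs,
            )
            result = func(*args, **kwargs)
            logger.debug("torchrec return: %s -> %r", func.__qualname__, result)
            return result

        return wrapper

    return decorator
```

Applied as `@_torchrec_method_logger()` (note the call parentheses, matching the diffs below), such a decorator can accept keyword options while leaving the wrapped signature unchanged.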

torchrec/distributed/embeddingbag.py
Lines changed: 3 additions & 0 deletions

@@ -55,6 +55,7 @@
     FUSED_PARAM_IS_SSD_TABLE,
     FUSED_PARAM_SSD_TABLE_LIST,
 )
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.sharding.cw_sharding import CwPooledEmbeddingSharding
 from torchrec.distributed.sharding.dp_sharding import DpPooledEmbeddingSharding
 from torchrec.distributed.sharding.dynamic_sharding import (
@@ -466,6 +467,7 @@ class ShardedEmbeddingBagCollection(
     This is part of the public API to allow for manual data dist pipelining.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         module: EmbeddingBagCollectionInterface,
@@ -2021,6 +2023,7 @@ class ShardedEmbeddingBag(
     This is part of the public API to allow for manual data dist pipelining.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         module: nn.EmbeddingBag,

torchrec/distributed/planner/planners.py
Lines changed: 2 additions & 0 deletions

@@ -18,6 +18,7 @@
 from torch import nn
 from torchrec.distributed.collective_utils import invoke_on_rank_and_broadcast_result
 from torchrec.distributed.comm import get_local_size
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import BATCH_SIZE, MAX_SIZE
 from torchrec.distributed.planner.enumerators import EmbeddingEnumerator
 from torchrec.distributed.planner.partitioners import (
@@ -459,6 +460,7 @@ def collective_plan(
             sharders,
         )

+    @_torchrec_method_logger()
     def plan(
         self,
         module: nn.Module,
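With `plan()` now wrapped, an ordinary planning call becomes the point where the planner's inputs and resulting ShardingPlan would be captured. A minimal, illustrative call site is sketched below; the planner APIs shown are existing TorchRec APIs, and the table config and world size are arbitrary example values, not part of this diff.

```python
# Illustrative call site: what this diff adds is only that plan() is now
# wrapped by @_torchrec_method_logger(), so this call would also be logged.
import torch
from torchrec.distributed.model_parallel import get_default_sharders
from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
from torchrec.modules.embedding_configs import EmbeddingBagConfig
from torchrec.modules.embedding_modules import EmbeddingBagCollection

# Arbitrary example table; any EmbeddingBagCollection-based module works.
ebc = EmbeddingBagCollection(
    tables=[
        EmbeddingBagConfig(
            name="t1",
            embedding_dim=16,
            num_embeddings=100,
            feature_names=["f1"],
        )
    ],
    device=torch.device("meta"),
)

planner = EmbeddingShardingPlanner(
    topology=Topology(world_size=2, compute_device="cuda"),
)

# Runs the usual planning logic; the new decorator additionally records the
# inputs (module, sharders) and the returned ShardingPlan for debugging.
plan = planner.plan(module=ebc, sharders=get_default_sharders())
```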

torchrec/distributed/planner/shard_estimators.py
Lines changed: 2 additions & 0 deletions

@@ -16,6 +16,7 @@
 import torchrec.optim as trec_optim
 from torch import nn
 from torchrec.distributed.embedding_types import EmbeddingComputeKernel
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.planner.constants import (
     BATCHED_COPY_PERF_FACTOR,
     BIGINT_DTYPE,
@@ -955,6 +956,7 @@ class EmbeddingStorageEstimator(ShardEstimator):
         is_inference (bool): If the model is inference model. Default to False.
     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         topology: Topology,

torchrec/distributed/shard.py
Lines changed: 3 additions & 0 deletions

@@ -15,6 +15,7 @@
 from torch.distributed._composable.contract import contract
 from torchrec.distributed.comm import get_local_size
 from torchrec.distributed.global_settings import get_propogate_device
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import get_default_sharders
 from torchrec.distributed.planner import EmbeddingShardingPlanner, Topology
 from torchrec.distributed.sharding_plan import (
@@ -146,6 +147,7 @@ def _shard(

 # pyre-ignore
 @contract()
+@_torchrec_method_logger()
 def shard_modules(
     module: nn.Module,
     env: Optional[ShardingEnv] = None,
@@ -194,6 +196,7 @@ def init_weights(m):
     return _shard_modules(module, env, device, plan, sharders, init_params)


+@_torchrec_method_logger()
 def _shard_modules(  # noqa: C901
     module: nn.Module,
     # TODO: Consolidate to using Dict[str, ShardingEnv]

torchrec/distributed/train_pipeline/train_pipelines.py
Lines changed: 2 additions & 0 deletions

@@ -32,6 +32,7 @@
 import torch
 from torch.autograd.profiler import record_function
 from torchrec.distributed.dist_data import KJTAllToAllTensorsAwaitable
+from torchrec.distributed.logger import _torchrec_method_logger
 from torchrec.distributed.model_parallel import ShardedModule
 from torchrec.distributed.train_pipeline.pipeline_context import (
     EmbeddingTrainPipelineContext,
@@ -106,6 +107,7 @@ class TrainPipeline(abc.ABC, Generic[In, Out]):
     def progress(self, dataloader_iter: Iterator[In]) -> Out:
         pass

+    @_torchrec_method_logger()
     def __init__(self) -> None:
         # pipeline state such as in foward, in backward etc, used in training recover scenarios
         self._state: PipelineState = PipelineState.IDLE

torchrec/modules/mc_embedding_modules.py
Lines changed: 3 additions & 0 deletions

@@ -12,6 +12,7 @@

 import torch
 import torch.nn as nn
+from torchrec.distributed.logger import _torchrec_method_logger

 from torchrec.modules.embedding_modules import (
     EmbeddingBagCollection,
@@ -125,6 +126,7 @@ class ManagedCollisionEmbeddingCollection(BaseManagedCollisionEmbeddingCollectio

     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_collection: EmbeddingCollection,
@@ -164,6 +166,7 @@ class ManagedCollisionEmbeddingBagCollection(BaseManagedCollisionEmbeddingBagCollec

     """

+    @_torchrec_method_logger()
     def __init__(
         self,
         embedding_bag_collection: EmbeddingBagCollection,
