diff --git a/torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml b/torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml index 32bed808f..bdc016398 100644 --- a/torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml +++ b/torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml @@ -10,18 +10,33 @@ RunOptions: sharding_type: table_wise profile_dir: "." name: "sparsenn_prefetch_kvzch_dram" + memory_snapshot: True + loglevel: "info" + num_float_features: 1000 PipelineConfig: pipeline: "prefetch" + # inplace_copy_batch_to_gpu: True ModelInputConfig: - feature_pooling_avg: 30 + num_float_features: 1000 + feature_pooling_avg: 60 +ModelSelectionConfig: + model_name: "test_sparse_nn" + model_config: + num_float_features: 1000 + submodule_kwargs: + dense_arch_out_size: 1024 + over_arch_out_size: 4096 + over_arch_hidden_layers: 10 + dense_arch_hidden_sizes: [128, 128, 128] + EmbeddingTablesConfig: - num_unweighted_features: 10 - num_weighted_features: 10 + num_unweighted_features: 50 + num_weighted_features: 50 embedding_feature_dim: 256 additional_tables: - - name: FP16_table embedding_dim: 512 - num_embeddings: 100_000 # Both feature hashsize and virtual table size + num_embeddings: 1_000_000 # Both feature hashsize and virtual table size feature_names: ["additional_0_0"] data_type: FP16 total_num_buckets: 100 # num_embedding should be divisible by total_num_buckets diff --git a/torchrec/distributed/test_utils/model_config.py b/torchrec/distributed/test_utils/model_config.py index 0341a00c6..5c9eefc82 100644 --- a/torchrec/distributed/test_utils/model_config.py +++ b/torchrec/distributed/test_utils/model_config.py @@ -88,6 +88,7 @@ class TestSparseNNConfig(BaseModelConfig): over_arch_clazz: Type[nn.Module] = TestOverArchLarge postproc_module: Optional[nn.Module] = None zch: bool = False + submodule_kwargs: Optional[Dict[str, Any]] = None def generate_model( self, @@ -108,6 +109,7 @@ def generate_model( postproc_module=self.postproc_module, embedding_groups=self.embedding_groups, zch=self.zch, + submodule_kwargs=self.submodule_kwargs, ) diff --git a/torchrec/distributed/test_utils/test_model.py b/torchrec/distributed/test_utils/test_model.py index cb7004670..01e4d51d8 100644 --- a/torchrec/distributed/test_utils/test_model.py +++ b/torchrec/distributed/test_utils/test_model.py @@ -881,6 +881,7 @@ def __init__( device: Optional[torch.device] = None, dense_arch_out_size: Optional[int] = None, dense_arch_hidden_sizes: Optional[List[int]] = None, + **_kwargs: Any, ) -> None: """ Args: @@ -1191,6 +1192,8 @@ def __init__( dense_arch_out_size: Optional[int] = None, over_arch_out_size: Optional[int] = None, over_arch_hidden_layers: Optional[int] = None, + over_arch_hidden_repeat: Optional[int] = None, + **_kwargs: Any, ) -> None: """ Args: @@ -1237,7 +1240,8 @@ def __init__( ), SwishLayerNorm([out_features]), ] - + for _ in range(over_arch_hidden_repeat or 0): + layers += layers[1:] self.overarch = torch.nn.Sequential(*layers) self.regroup_module = KTRegroupAsDict( @@ -1398,6 +1402,7 @@ def __init__( weighted_tables: List[EmbeddingBagConfig], device: Optional[torch.device] = None, max_feature_lengths: Optional[Dict[str, int]] = None, + **_kwargs: Any, ) -> None: """ Args: @@ -1547,6 +1552,7 @@ def __init__( over_arch_clazz: Optional[Type[nn.Module]] = None, postproc_module: Optional[nn.Module] = None, zch: bool = False, + submodule_kwargs: Optional[Dict[str, Any]] = None, ) -> None: super().__init__( tables=cast(List[BaseEmbeddingConfig], tables), @@ -1559,7 +1565,9 @@ def __init__( over_arch_clazz = TestOverArch if weighted_tables is None: weighted_tables = [] - self.dense = TestDenseArch(num_float_features, device=dense_device) + self.dense = TestDenseArch( + num_float_features, device=dense_device, **(submodule_kwargs or {}) + ) if zch: self.sparse: nn.Module = TestEBCSparseArchZCH( tables, # pyre-ignore @@ -1571,6 +1579,7 @@ def __init__( self.sparse = TestECSparseArch( tables, # pyre-ignore [6] sparse_device, + **(submodule_kwargs or {}), ) else: self.sparse = TestEBCSparseArch( @@ -1578,6 +1587,7 @@ def __init__( weighted_tables, sparse_device, max_feature_lengths, + **(submodule_kwargs or {}), ) embedding_names = ( @@ -1596,6 +1606,7 @@ def __init__( weighted_tables, embedding_names, dense_device, + **(submodule_kwargs or {}), ) self.register_buffer( "dummy_ones",