Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 19 additions & 4 deletions torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml
Original file line number Diff line number Diff line change
Expand Up @@ -10,18 +10,33 @@ RunOptions:
sharding_type: table_wise
profile_dir: "."
name: "sparsenn_prefetch_kvzch_dram"
memory_snapshot: True
loglevel: "info"
num_float_features: 1000
PipelineConfig:
pipeline: "prefetch"
# inplace_copy_batch_to_gpu: True
ModelInputConfig:
feature_pooling_avg: 30
num_float_features: 1000
feature_pooling_avg: 60
ModelSelectionConfig:
model_name: "test_sparse_nn"
model_config:
num_float_features: 1000
submodule_kwargs:
dense_arch_out_size: 1024
over_arch_out_size: 4096
over_arch_hidden_layers: 10
dense_arch_hidden_sizes: [128, 128, 128]

EmbeddingTablesConfig:
num_unweighted_features: 10
num_weighted_features: 10
num_unweighted_features: 50
num_weighted_features: 50
embedding_feature_dim: 256
additional_tables:
- - name: FP16_table
embedding_dim: 512
num_embeddings: 100_000 # Both feature hashsize and virtual table size
num_embeddings: 1_000_000 # Both feature hashsize and virtual table size
feature_names: ["additional_0_0"]
data_type: FP16
total_num_buckets: 100 # num_embeddings should be divisible by total_num_buckets
Expand Down
2 changes: 2 additions & 0 deletions torchrec/distributed/test_utils/model_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,7 @@ class TestSparseNNConfig(BaseModelConfig):
over_arch_clazz: Type[nn.Module] = TestOverArchLarge
postproc_module: Optional[nn.Module] = None
zch: bool = False
submodule_kwargs: Optional[Dict[str, Any]] = None

def generate_model(
self,
Expand All @@ -108,6 +109,7 @@ def generate_model(
postproc_module=self.postproc_module,
embedding_groups=self.embedding_groups,
zch=self.zch,
submodule_kwargs=self.submodule_kwargs,
)


Expand Down
15 changes: 13 additions & 2 deletions torchrec/distributed/test_utils/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -881,6 +881,7 @@ def __init__(
device: Optional[torch.device] = None,
dense_arch_out_size: Optional[int] = None,
dense_arch_hidden_sizes: Optional[List[int]] = None,
**_kwargs: Any,
) -> None:
"""
Args:
Expand Down Expand Up @@ -1191,6 +1192,8 @@ def __init__(
dense_arch_out_size: Optional[int] = None,
over_arch_out_size: Optional[int] = None,
over_arch_hidden_layers: Optional[int] = None,
over_arch_hidden_repeat: Optional[int] = None,
**_kwargs: Any,
) -> None:
"""
Args:
Expand Down Expand Up @@ -1237,7 +1240,8 @@ def __init__(
),
SwishLayerNorm([out_features]),
]

for _ in range(over_arch_hidden_repeat or 0):
layers += layers[1:]
self.overarch = torch.nn.Sequential(*layers)

self.regroup_module = KTRegroupAsDict(
Expand Down Expand Up @@ -1398,6 +1402,7 @@ def __init__(
weighted_tables: List[EmbeddingBagConfig],
device: Optional[torch.device] = None,
max_feature_lengths: Optional[Dict[str, int]] = None,
**_kwargs: Any,
) -> None:
"""
Args:
Expand Down Expand Up @@ -1547,6 +1552,7 @@ def __init__(
over_arch_clazz: Optional[Type[nn.Module]] = None,
postproc_module: Optional[nn.Module] = None,
zch: bool = False,
submodule_kwargs: Optional[Dict[str, Any]] = None,
) -> None:
super().__init__(
tables=cast(List[BaseEmbeddingConfig], tables),
Expand All @@ -1559,7 +1565,9 @@ def __init__(
over_arch_clazz = TestOverArch
if weighted_tables is None:
weighted_tables = []
self.dense = TestDenseArch(num_float_features, device=dense_device)
self.dense = TestDenseArch(
num_float_features, device=dense_device, **(submodule_kwargs or {})
)
if zch:
self.sparse: nn.Module = TestEBCSparseArchZCH(
tables, # pyre-ignore
Expand All @@ -1571,13 +1579,15 @@ def __init__(
self.sparse = TestECSparseArch(
tables, # pyre-ignore [6]
sparse_device,
**(submodule_kwargs or {}),
)
else:
self.sparse = TestEBCSparseArch(
tables, # pyre-ignore
weighted_tables,
sparse_device,
max_feature_lengths,
**(submodule_kwargs or {}),
)

embedding_names = (
Expand All @@ -1596,6 +1606,7 @@ def __init__(
weighted_tables,
embedding_names,
dense_device,
**(submodule_kwargs or {}),
)
self.register_buffer(
"dummy_ones",
Expand Down
Loading