Skip to content

Commit 223db0d

Browse files
TroyGarden and meta-codesync[bot]
authored and committed
add pipeline benchmark for kvzch (#3604)
Summary: Pull Request resolved: #3604 # context * modify the kvzch benchmark configs to better represent the real use case * add config pass-in to the test models * fix small bugs and minor refactoring # changes * the previous kv-zch embedding table is too small, so the prefetch process is too short; after this change (increased table size) the prefetch process is longer {F1983784711} {F1983784733} # benchmark |short name |GPU Runtime (P90)|CPU Runtime (P90)|GPU Peak Mem alloc (P90)|GPU Peak Mem reserved (P90)|GPU Mem used (P90)|Malloc retries (P50/P90/P100)|CPU Peak RSS (P90)| |--|--|--|--|--|--|--|--| |regular-base |9864.51 ms |9403.68 ms |33.77 GB |49.66 GB |50.71 GB |0.0 / 0.0 / 0.0 |30.65 GB | |kvzch-base |18804.26 ms |44245.82 ms |25.28 GB |36.33 GB |37.38 GB |0.0 / 0.0 / 0.0 |31.18 GB | |base-inplace |20141.71 ms |46805.58 ms |25.28 GB |34.39 GB |35.44 GB |0.0 / 0.0 / 0.0 |31.19 GB | |kvzch-sdd |20382.59 ms |45647.02 ms |33.42 GB |47.52 GB |48.56 GB |0.0 / 0.0 / 0.0 |31.13 GB | |kvzch-prefetch |17951.19 ms |38598.57 ms |33.45 GB |47.16 GB |48.21 GB |0.0 / 0.0 / 0.0 |30.83 GB | |regular-base |49710.51 ms |74880.50 ms |43.14 GB |50.63 GB |51.68 GB |0.0 / 0.0 / 0.0 |33.57 GB | Reviewed By: spmex Differential Revision: D84268361 fbshipit-source-id: e28abb6fb6ccb1121dcf4ae778e26520b454da8c
1 parent 5cf0f0d commit 223db0d

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed

torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -10,18 +10,33 @@ RunOptions:
1010
sharding_type: table_wise
1111
profile_dir: "."
1212
name: "sparsenn_prefetch_kvzch_dram"
13+
memory_snapshot: True
14+
loglevel: "info"
15+
num_float_features: 1000
1316
PipelineConfig:
1417
pipeline: "prefetch"
18+
# inplace_copy_batch_to_gpu: True
1519
ModelInputConfig:
16-
feature_pooling_avg: 30
20+
num_float_features: 1000
21+
feature_pooling_avg: 60
22+
ModelSelectionConfig:
23+
model_name: "test_sparse_nn"
24+
model_config:
25+
num_float_features: 1000
26+
submodule_kwargs:
27+
dense_arch_out_size: 1024
28+
over_arch_out_size: 4096
29+
over_arch_hidden_layers: 10
30+
dense_arch_hidden_sizes: [128, 128, 128]
31+
1732
EmbeddingTablesConfig:
18-
num_unweighted_features: 10
19-
num_weighted_features: 10
33+
num_unweighted_features: 50
34+
num_weighted_features: 50
2035
embedding_feature_dim: 256
2136
additional_tables:
2237
- - name: FP16_table
2338
embedding_dim: 512
24-
num_embeddings: 100_000 # Both feature hashsize and virtual table size
39+
num_embeddings: 1_000_000 # Both feature hashsize and virtual table size
2540
feature_names: ["additional_0_0"]
2641
data_type: FP16
2742
total_num_buckets: 100 # num_embedding should be divisible by total_num_buckets

torchrec/distributed/test_utils/model_config.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,7 @@ class TestSparseNNConfig(BaseModelConfig):
8888
over_arch_clazz: Type[nn.Module] = TestOverArchLarge
8989
postproc_module: Optional[nn.Module] = None
9090
zch: bool = False
91+
submodule_kwargs: Optional[Dict[str, Any]] = None
9192

9293
def generate_model(
9394
self,
@@ -108,6 +109,7 @@ def generate_model(
108109
postproc_module=self.postproc_module,
109110
embedding_groups=self.embedding_groups,
110111
zch=self.zch,
112+
submodule_kwargs=self.submodule_kwargs,
111113
)
112114

113115

torchrec/distributed/test_utils/test_model.py

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -881,6 +881,7 @@ def __init__(
881881
device: Optional[torch.device] = None,
882882
dense_arch_out_size: Optional[int] = None,
883883
dense_arch_hidden_sizes: Optional[List[int]] = None,
884+
**_kwargs: Any,
884885
) -> None:
885886
"""
886887
Args:
@@ -1191,6 +1192,8 @@ def __init__(
11911192
dense_arch_out_size: Optional[int] = None,
11921193
over_arch_out_size: Optional[int] = None,
11931194
over_arch_hidden_layers: Optional[int] = None,
1195+
over_arch_hidden_repeat: Optional[int] = None,
1196+
**_kwargs: Any,
11941197
) -> None:
11951198
"""
11961199
Args:
@@ -1237,7 +1240,8 @@ def __init__(
12371240
),
12381241
SwishLayerNorm([out_features]),
12391242
]
1240-
1243+
for _ in range(over_arch_hidden_repeat or 0):
1244+
layers += layers[1:]
12411245
self.overarch = torch.nn.Sequential(*layers)
12421246

12431247
self.regroup_module = KTRegroupAsDict(
@@ -1398,6 +1402,7 @@ def __init__(
13981402
weighted_tables: List[EmbeddingBagConfig],
13991403
device: Optional[torch.device] = None,
14001404
max_feature_lengths: Optional[Dict[str, int]] = None,
1405+
**_kwargs: Any,
14011406
) -> None:
14021407
"""
14031408
Args:
@@ -1547,6 +1552,7 @@ def __init__(
15471552
over_arch_clazz: Optional[Type[nn.Module]] = None,
15481553
postproc_module: Optional[nn.Module] = None,
15491554
zch: bool = False,
1555+
submodule_kwargs: Optional[Dict[str, Any]] = None,
15501556
) -> None:
15511557
super().__init__(
15521558
tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1559,7 +1565,9 @@ def __init__(
15591565
over_arch_clazz = TestOverArch
15601566
if weighted_tables is None:
15611567
weighted_tables = []
1562-
self.dense = TestDenseArch(num_float_features, device=dense_device)
1568+
self.dense = TestDenseArch(
1569+
num_float_features, device=dense_device, **(submodule_kwargs or {})
1570+
)
15631571
if zch:
15641572
self.sparse: nn.Module = TestEBCSparseArchZCH(
15651573
tables, # pyre-ignore
@@ -1571,13 +1579,15 @@ def __init__(
15711579
self.sparse = TestECSparseArch(
15721580
tables, # pyre-ignore [6]
15731581
sparse_device,
1582+
**(submodule_kwargs or {}),
15741583
)
15751584
else:
15761585
self.sparse = TestEBCSparseArch(
15771586
tables, # pyre-ignore
15781587
weighted_tables,
15791588
sparse_device,
15801589
max_feature_lengths,
1590+
**(submodule_kwargs or {}),
15811591
)
15821592

15831593
embedding_names = (
@@ -1596,6 +1606,7 @@ def __init__(
15961606
weighted_tables,
15971607
embedding_names,
15981608
dense_device,
1609+
**(submodule_kwargs or {}),
15991610
)
16001611
self.register_buffer(
16011612
"dummy_ones",

0 commit comments

Comments
 (0)