Commit 2b31193

TroyGarden authored and facebook-github-bot committed
add pipeline benchmark for kvzch
Summary:

# context
* modify the kvzch benchmark configs to better represent the real use case
* add config pass-in to the test models
* fix small bugs and minor refactoring

# changes
* the previous kv-zch embedding table was too small, so the prefetch process was too short; after this change (increased table size) the prefetch process is longer

{F1983784711} {F1983784733}

# benchmark

| short name | GPU Runtime (P90) | CPU Runtime (P90) | GPU Peak Mem alloc (P90) | GPU Peak Mem reserved (P90) | GPU Mem used (P90) | Malloc retries (P50/P90/P100) | CPU Peak RSS (P90) |
|--|--|--|--|--|--|--|--|
| regular-base | 9864.51 ms | 9403.68 ms | 33.77 GB | 49.66 GB | 50.71 GB | 0.0 / 0.0 / 0.0 | 30.65 GB |
| kvzch-base | 18804.26 ms | 44245.82 ms | 25.28 GB | 36.33 GB | 37.38 GB | 0.0 / 0.0 / 0.0 | 31.18 GB |
| base-inplace | 20141.71 ms | 46805.58 ms | 25.28 GB | 34.39 GB | 35.44 GB | 0.0 / 0.0 / 0.0 | 31.19 GB |
| kvzch-sdd | 20382.59 ms | 45647.02 ms | 33.42 GB | 47.52 GB | 48.56 GB | 0.0 / 0.0 / 0.0 | 31.13 GB |
| kvzch-prefetch | 17951.19 ms | 38598.57 ms | 33.45 GB | 47.16 GB | 48.21 GB | 0.0 / 0.0 / 0.0 | 30.83 GB |

Differential Revision: D84268361
1 parent 7b3effd · commit 2b31193

File tree: 4 files changed, +34 −15 lines


torchrec/distributed/benchmark/base.py

Lines changed: 0 additions & 9 deletions
```diff
@@ -504,14 +504,6 @@ def wrapper() -> Any:  # pyre-ignore [3]
             help="JSON config file for benchmarking",
         )
 
-        # Add loglevel argument with current logger level as default
-        parser.add_argument(
-            "--loglevel",
-            type=str,
-            default=logging._levelToName[logger.level],
-            help="Set the logging level (e.g. info, debug, warning, error)",
-        )
-
         pre_args, _ = parser.parse_known_args()
 
         yaml_defaults: Dict[str, Any] = (
@@ -531,7 +523,6 @@ def wrapper() -> Any:  # pyre-ignore [3]
         seen_args = {
             "json_config",
             "yaml_config",
-            "loglevel",
         }
 
         for _name, param in sig.parameters.items():
```
torchrec/distributed/benchmark/yaml/prefetch_kvzch.yml

Lines changed: 19 additions & 4 deletions
```diff
@@ -10,18 +10,33 @@ RunOptions:
   sharding_type: table_wise
   profile_dir: "."
   name: "sparsenn_prefetch_kvzch_dram"
+  memory_snapshot: True
+  loglevel: "info"
+  num_float_features: 1000
 PipelineConfig:
   pipeline: "prefetch"
+  # inplace_copy_batch_to_gpu: True
 ModelInputConfig:
-  feature_pooling_avg: 30
+  num_float_features: 1000
+  feature_pooling_avg: 60
+ModelSelectionConfig:
+  model_name: "test_sparse_nn"
+  model_config:
+    num_float_features: 1000
+    submodule_kwargs:
+      dense_arch_out_size: 1024
+      over_arch_out_size: 4096
+      over_arch_hidden_layers: 10
+      dense_arch_hidden_sizes: [128, 128, 128]
+
 EmbeddingTablesConfig:
-  num_unweighted_features: 10
-  num_weighted_features: 10
+  num_unweighted_features: 50
+  num_weighted_features: 50
   embedding_feature_dim: 256
   additional_tables:
     - - name: FP16_table
         embedding_dim: 512
-        num_embeddings: 100_000 # Both feature hashsize and virtual table size
+        num_embeddings: 1_000_000 # Both feature hashsize and virtual table size
         feature_names: ["additional_0_0"]
         data_type: FP16
         total_num_buckets: 100 # num_embedding should be divisible by total_num_buckets
```
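The trailing comment on `total_num_buckets` encodes a hard constraint: the virtual table size must split evenly into buckets. A tiny sketch of that check with the values from this config (the helper name is made up for illustration, not a torchrec API):

```python
# Minimal sketch of the constraint noted in the YAML comment: num_embeddings
# must be divisible by total_num_buckets so every bucket holds the same number
# of virtual-table rows.
def rows_per_bucket(num_embeddings: int, total_num_buckets: int) -> int:
    if num_embeddings % total_num_buckets != 0:
        raise ValueError(
            f"num_embeddings={num_embeddings} is not divisible by "
            f"total_num_buckets={total_num_buckets}"
        )
    return num_embeddings // total_num_buckets


# With the updated config: 1_000_000 rows across 100 buckets -> 10_000 rows each.
assert rows_per_bucket(1_000_000, 100) == 10_000
```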

torchrec/distributed/test_utils/model_config.py

Lines changed: 2 additions & 0 deletions
```diff
@@ -88,6 +88,7 @@ class TestSparseNNConfig(BaseModelConfig):
     over_arch_clazz: Type[nn.Module] = TestOverArchLarge
     postproc_module: Optional[nn.Module] = None
     zch: bool = False
+    submodule_kwargs: Optional[Dict[str, Any]] = None
 
     def generate_model(
         self,
@@ -108,6 +109,7 @@ def generate_model(
             postproc_module=self.postproc_module,
             embedding_groups=self.embedding_groups,
             zch=self.zch,
+            submodule_kwargs=self.submodule_kwargs,
         )
 
 
```
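For reference, with the YAML above, the dict carried by the new `submodule_kwargs` field would look roughly like the following (shown purely for illustration):

```python
# Hypothetical illustration: the submodule_kwargs value produced by the
# ModelSelectionConfig.model_config section of the YAML above, which
# TestSparseNNConfig now forwards into generate_model().
submodule_kwargs = {
    "dense_arch_out_size": 1024,
    "over_arch_out_size": 4096,
    "over_arch_hidden_layers": 10,
    "dense_arch_hidden_sizes": [128, 128, 128],
}
```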

torchrec/distributed/test_utils/test_model.py

Lines changed: 13 additions & 2 deletions
```diff
@@ -881,6 +881,7 @@ def __init__(
         device: Optional[torch.device] = None,
         dense_arch_out_size: Optional[int] = None,
         dense_arch_hidden_sizes: Optional[List[int]] = None,
+        **_kwargs: Any,
     ) -> None:
         """
         Args:
@@ -1191,6 +1192,8 @@ def __init__(
         dense_arch_out_size: Optional[int] = None,
         over_arch_out_size: Optional[int] = None,
         over_arch_hidden_layers: Optional[int] = None,
+        over_arch_hidden_repeat: Optional[int] = None,
+        **_kwargs: Any,
     ) -> None:
         """
         Args:
@@ -1237,7 +1240,8 @@ def __init__(
             ),
             SwishLayerNorm([out_features]),
         ]
-
+        for _ in range(over_arch_hidden_repeat or 0):
+            layers += layers[1:]
         self.overarch = torch.nn.Sequential(*layers)
 
         self.regroup_module = KTRegroupAsDict(
@@ -1398,6 +1402,7 @@ def __init__(
         weighted_tables: List[EmbeddingBagConfig],
         device: Optional[torch.device] = None,
         max_feature_lengths: Optional[Dict[str, int]] = None,
+        **_kwargs: Any,
     ) -> None:
         """
         Args:
@@ -1547,6 +1552,7 @@ def __init__(
         over_arch_clazz: Optional[Type[nn.Module]] = None,
         postproc_module: Optional[nn.Module] = None,
         zch: bool = False,
+        submodule_kwargs: Optional[Dict[str, Any]] = None,
     ) -> None:
         super().__init__(
             tables=cast(List[BaseEmbeddingConfig], tables),
@@ -1559,7 +1565,9 @@ def __init__(
             over_arch_clazz = TestOverArch
         if weighted_tables is None:
             weighted_tables = []
-        self.dense = TestDenseArch(num_float_features, device=dense_device)
+        self.dense = TestDenseArch(
+            num_float_features, device=dense_device, **(submodule_kwargs or {})
+        )
         if zch:
             self.sparse: nn.Module = TestEBCSparseArchZCH(
                 tables,  # pyre-ignore
@@ -1571,13 +1579,15 @@ def __init__(
             self.sparse = TestECSparseArch(
                 tables,  # pyre-ignore [6]
                 sparse_device,
+                **(submodule_kwargs or {}),
             )
         else:
             self.sparse = TestEBCSparseArch(
                 tables,  # pyre-ignore
                 weighted_tables,
                 sparse_device,
                 max_feature_lengths,
+                **(submodule_kwargs or {}),
             )
 
         embedding_names = (
@@ -1596,6 +1606,7 @@ def __init__(
             weighted_tables,
             embedding_names,
             dense_device,
+            **(submodule_kwargs or {}),
         )
         self.register_buffer(
             "dummy_ones",
```
