Skip to content

Commit 729264d

Browse files
lizhouyufacebook-github-bot
authored andcommitted
Zero Collision Hash Benchmark Framework
Differential Revision: D77033290
1 parent 7bd2a6f commit 729264d

27 files changed

+4815
-0
lines changed
Lines changed: 125 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,125 @@
1+
import argparse
2+
from typing import List
3+
4+
5+
def parse_args(argv: List[str]) -> argparse.Namespace:
6+
parser = argparse.ArgumentParser(description="torchrec dlrm example trainer")
7+
8+
# Dataset related arguments
9+
parser.add_argument(
10+
"--dataset_name",
11+
type=str,
12+
choices=["movielens_1m", "criteo_kaggle"],
13+
default="movielens_1m",
14+
help="dataset for experiment, current support criteo_1tb, criteo_kaggle",
15+
)
16+
17+
# Model related arguments
18+
parser.add_argument(
19+
"--model_name",
20+
type=str,
21+
choices=["dlrmv2", "dlrmv3"],
22+
default="dlrmv3",
23+
help="model for experiment, current support dlrmv2, dlrmv3. Dlrmv3 is the default",
24+
)
25+
parser.add_argument(
26+
"--num_embeddings", # ratio of feature ids to embedding table size # 3 axis: x-bath_idx; y-collisions; zembedding table sizes
27+
type=int,
28+
default=100_000,
29+
help="max_ind_size. The number of embeddings in each embedding table. Defaults"
30+
" to 100_000 if num_embeddings_per_feature is not supplied.",
31+
)
32+
parser.add_argument(
33+
"--embedding_dim",
34+
type=int,
35+
default=64,
36+
help="Size of each embedding.",
37+
)
38+
parser.add_argument(
39+
"--seed",
40+
type=int,
41+
help="Random seed for reproducibility.",
42+
default=0,
43+
)
44+
45+
# Training related arguments
46+
parser.add_argument(
47+
"--epochs",
48+
type=int,
49+
default=1,
50+
help="number of epochs to train",
51+
)
52+
parser.add_argument(
53+
"--batch_size",
54+
type=int,
55+
default=4096,
56+
help="batch size to use for training",
57+
)
58+
parser.add_argument(
59+
"--sparse_optim",
60+
type=str,
61+
default="adagrad",
62+
help="The optimizer to use for sparse parameters.",
63+
)
64+
parser.add_argument(
65+
"--dense_optim",
66+
type=str,
67+
default="adagrad",
68+
help="The optimizer to use for sparse parameters.",
69+
)
70+
parser.add_argument(
71+
"--learning_rate",
72+
type=float,
73+
default=1.0,
74+
help="Learning rate.",
75+
)
76+
parser.add_argument(
77+
"--eps",
78+
type=float,
79+
default=1e-8,
80+
help="Epsilon for Adagrad optimizer.",
81+
)
82+
parser.add_argument(
83+
"--shuffle_batches",
84+
dest="shuffle_batches",
85+
action="store_true",
86+
help="Shuffle each batch during training.",
87+
)
88+
parser.add_argument(
89+
"--validation_freq_within_epoch",
90+
type=int,
91+
default=None,
92+
help="Frequency at which validation will be run within an epoch.",
93+
)
94+
parser.set_defaults(
95+
pin_memory=None,
96+
mmap_mode=None,
97+
drop_last=None,
98+
shuffle_batches=None,
99+
shuffle_training_set=None,
100+
)
101+
parser.add_argument(
102+
"--input_hash_size",
103+
type=int,
104+
default=100_000,
105+
help="Input feature value range",
106+
)
107+
parser.add_argument(
108+
"--profiling_result_folder",
109+
type=str,
110+
default="profiling_result",
111+
help="Folder to save profiling results",
112+
)
113+
parser.add_argument(
114+
"--zch_method",
115+
type=str,
116+
help="The method to use for zero collision hashing, blank for no zch",
117+
default="",
118+
)
119+
parser.add_argument(
120+
"--num_buckets",
121+
type=int,
122+
default=4,
123+
help="Number of buckets for identity table. Only used for MPZCH. The number of ranks WORLD_SIZE must be a factor of num_buckets, and the number of buckets must be a factor of input_hash_size",
124+
)
125+
return parser.parse_args(argv)

0 commit comments

Comments
 (0)