Commit 115412b

lizhouyu authored and facebook-github-bot committed
Zero Collision Hash Benchmark Framework (#3127)
Summary: Pull Request resolved: #3127 Differential Revision: D77033290
1 parent fc37e53 commit 115412b


41 files changed: +8088 −2 lines changed
Lines changed: 98 additions & 0 deletions
# Zero Collision Hashing (ZCH) Benchmarking Testbed

This testbed benchmarks the performance of ZCH algorithms in terms of efficiency, accuracy, and collision management. Specifically, the testbed collects the following metrics:

- QPS: queries per second, the number of input feature values the model can process in a second.
- Collision rate: the percentage of collisions in the hash table. A high collision rate means that many potentially irrelevant features are mapped to the same hash value, which can lead to information loss and decreased accuracy.
- NE: normalized entropy, a measure of the confidence of models in the prediction results of classification tasks.
- AUC: area under the curve, a metric used to evaluate the performance of classification models.
- MAE: mean absolute error, a measure of the average magnitude of errors in regression tasks.
- MSE: mean squared error, a measure of the average squared error in regression tasks.
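
As a point of reference, the sketch below shows one way the collision rate and NE metrics can be computed. It is illustrative only, not the testbed's own code; the function names `collision_rate` and `normalized_entropy` are made up here.

```python
# Illustrative sketch (not the testbed's implementation) of two of the
# metrics above: collision rate and normalized entropy (NE).
import numpy as np


def collision_rate(raw_ids: np.ndarray, remapped_ids: np.ndarray) -> float:
    # Ideally every distinct raw feature id lands in its own slot; each
    # "missing" slot indicates a collision between distinct raw ids.
    n_raw = np.unique(raw_ids).size
    n_slots = np.unique(remapped_ids).size
    return (n_raw - n_slots) / n_raw


def normalized_entropy(labels: np.ndarray, probs: np.ndarray) -> float:
    # Cross-entropy of the model's predicted probabilities, normalized by the
    # entropy of a baseline that always predicts the empirical positive rate.
    eps = 1e-12
    p = np.clip(probs, eps, 1.0 - eps)
    ce = -np.mean(labels * np.log(p) + (1 - labels) * np.log(1 - p))
    base = labels.mean()
    base_ce = -(base * np.log(base + eps) + (1 - base) * np.log(1 - base + eps))
    return ce / base_ce
```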

## Prerequisites

Before running the benchmark, make sure the environment is properly set up. The following steps should be taken:
1. Prepare a Python environment (Python 3.9+).
2. Install the necessary dependencies:

```bash
# Install torch and fbgemm_gpu following the instructions in
# https://docs.pytorch.org/FBGEMM/fbgemm_gpu/development/InstallationInstructions.html
pip install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu126/
pip install --pre fbgemm-gpu --index-url https://download.pytorch.org/whl/nightly/cu126/
# Install torchrec
pip install torchrec --index-url https://download.pytorch.org/whl/nightly/cu126
# Install generative recommenders
git clone https://github.com/meta-recsys/generative-recommenders.git
cd generative-recommenders
pip install -e .
```
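
Optionally, a quick import check can confirm the installation. This is a suggestion, not part of the upstream setup steps:

```python
# Optional sanity check (a suggestion, not part of the original instructions):
# verify that the core dependencies import and that CUDA is visible.
import torch
import fbgemm_gpu  # noqa: F401  # importing registers the FBGEMM ops
import torchrec  # noqa: F401

print("torch:", torch.__version__)
print("CUDA available:", torch.cuda.is_available())
```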

## Running the benchmark

To run the benchmark, use the following command:

```bash
WORLD_SIZE=1 python benchmark_zch.py -- --profiling_result_folder result_tbsize_10000_nonzch_dlrmv3_kuairand1k --dataset_name kuairand_1k --batch_size 16 --learning_rate 0.001 --dense_optim adam --sparse_optim adagrad --epochs 5 --num_embeddings 10000
```

More options can be found in the [arguments.py](arguments.py) file.

## Repository Structure

- [benchmark_zch.py](benchmark_zch.py): the main script for running the benchmark.
- [arguments.py](arguments.py): contains the arguments for the benchmark.
- [benchmark_zch_utils.py](benchmark_zch_utils.py): utility functions for the benchmark.
- [count_dataset_distributions.py](count_dataset_distributions.py): script for counting the distribution of features in the dataset.
- [data](data): directory containing the datasets used in the benchmark.
- [models](models): directory containing the models used in the benchmark.
- [plots](plots): directory containing the plotting notebooks for the benchmark.
- [figures](figures): directory containing the figures generated by the plotting notebooks.

## To add a new model

To add a new model to the benchmark, follow these steps (an illustrative wrapper skeleton follows the list):

1. Create a new configuration YAML file named `<new_model_name>.yaml` in the [models/configs](models/configs) directory.
    - Besides basic configurations such as embedding dimensions, number of embeddings, etc., the YAML file must also contain the following two fields:
        - `embedding_module_attribute_path`: the path to the embedding module in the model, either the `EmbeddingCollection` or the `EmbeddingBagCollection`.
        - `managed_collision_module_attribute_path`: the path to the managed collision module in the model, if one is applied. It should be in the following format: `"module.<embedding_module_attribute_path>.mc_embedding_collection._managed_collision_collection._managed_collision_modules"`.
2. Create a new model class in the [models/models](models/models) directory, named `<new_model_name>.py`.
    - The model class should act as a wrapper for the new model, and it should
        - contain the following attributes:
            - `eval_flag` (bool): whether the model is in evaluation or training mode.
            - `table_configs` (List[Dict[str, EmbeddingConfig]]): a list of dictionaries containing the configuration of each embedding table.
        - override the following methods:
            - `forward(self, batch: Dict[str, Any]) -> torch.Tensor`: the forward method of the model. It should make the model compatible with input from the `Batch` dataclass, and produce output in the format `summed_loss, (prediction_logits, prediction_labels, prediction_weights)`.
            - `eval(self) -> None`: set the model to evaluation mode.
    - Implement the `make_model_<new_model_name>` function in the [models/make_model.py](models/make_model.py) file. The function takes three parameters:
        - `args`: the arguments passed to the benchmark.
        - `configs`: the configuration of the model and dataset.
        - `device`: the device to run the model on.

      The function should return an instance of the new model class. It also contains the code to replace the model's embedding module with the ZCH embedding module using an `mc_adapter` object.
3. Add the new model to the [models/__init__.py](models/__init__.py) file with `from .<new_model_name> import make_model_<new_model_name>`.
4. Register the new model in the [models/make_model.py](models/make_model.py) file:
    - Add `make_model_<new_model_name>` to the `from .models import` line.
    - Add a conditional branch `elif model_name == "<new_model_name>":` to the `make_model` function, in which you
        - read the model configuration file from `os.path.join(os.path.dirname(__file__), "configs", "<new_model_name>.yaml")`,
        - read the dataset configuration from `os.path.join(os.path.dirname(__file__), "..", "data", "configs", f"{args.dataset_name}.yaml")`, and
        - call the `make_model_<new_model_name>` function with the model configuration and dataset configuration.
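
A minimal skeleton of such a wrapper class is sketched below. Only the attributes and method signatures listed above come from the steps; the class name `MyNewModel`, the `nn.Module` base class, and the constructor arguments are assumptions for illustration:

```python
# Hypothetical skeleton of a model wrapper class. Only the attributes and
# method signatures listed in the steps above are prescribed; the class name,
# base class, and constructor are illustrative assumptions.
from typing import Any, Dict, List

import torch
from torchrec.modules.embedding_configs import EmbeddingConfig


class MyNewModel(torch.nn.Module):
    def __init__(self, table_configs: List[Dict[str, EmbeddingConfig]]) -> None:
        super().__init__()
        self.eval_flag: bool = False  # evaluation vs. training mode
        self.table_configs = table_configs  # per-table embedding configurations
        # ... construct the wrapped model (embedding modules, dense arch, ...) ...

    def forward(self, batch: Dict[str, Any]) -> torch.Tensor:
        # Adapt the fields of the Batch dataclass to the wrapped model's inputs
        # and return:
        #   summed_loss, (prediction_logits, prediction_labels, prediction_weights)
        raise NotImplementedError

    def eval(self) -> None:
        # Set the model to evaluation mode.
        self.eval_flag = True
```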

## To add a new dataset

To add a new dataset to the benchmark, follow these steps (an illustrative `Batch` sketch follows the list):

1. Create a new configuration YAML file named `<new_dataset_name>.yaml` in the [data/configs](data/configs) directory.
    - The YAML file must contain the following fields:
        - `dataset_path`: the path to the dataset.
        - `batch_size`: the batch size of the dataset.
        - `num_workers`: the number of workers to load the dataset.
    - Besides the three required fields, the YAML file should also contain any fields necessary for loading and ingesting the dataset.
2. Create a new dataset preprocessing script in the [data/preprocess](data/preprocess) directory, named `<new_dataset_name>.py`.
    - The script should define the corresponding `Batch` dataclass, which contains the necessary attributes and overrides the following methods:
        - `to(self, device: torch.device, non_blocking: bool = False) -> Batch`: move the data to the specified device.
        - `pin_memory(self) -> Batch`: pin the data in memory.
        - `record_stream(self, stream: torch.cuda.streams.Stream) -> None`: record the data on the given CUDA stream.
        - `get_dict(self) -> Dict[str, Any]`: return the data as a dictionary of `{<attribute_name>: <attribute_value>}`.
    - The script should also include a dataset class. The dataset class should act as a wrapper for the new dataset, and it should at least override the following methods:
        - `__init__(self, config: Dict[str, Any], device: torch.device) -> None`: the constructor of the dataset class. It takes a configuration dictionary and a device as input and initializes the dataset. When initializing the dataset, it must create an `items_in_memory` attribute holding a list of `Batch` dataclass instances.
        - `__len__(self) -> int`: the length of the dataset.
        - `__getitem__(self, idx: int) -> Dict[str, Any]`: get an item from the dataset. It takes an index as input and returns the data as a `Batch` dataclass.
        - `load_item(self, idx: int) -> Dict[str, Any]`: load an item from the dataset. It takes an index as input and returns the data as a `Batch` dataclass.
        - `get_sample(self, idx: int) -> Dict[str, Any]`: get a sample from the dataset. It takes an index as input and returns the data from the `items_in_memory` list.
        - `__getitems__(self, idxs: List[int]) -> List[Dict[str, Any]]`: get a list of items from the dataset. It takes a list of indices as input and returns the data as a list of `Batch` dataclasses.
    - The script should include a `collate_fn` that takes a list of `Batch` dataclasses and returns a single `Batch`.
    - The script should finally include a `get_<new_dataset_name>_dataloader` function that takes three parameters:
        - `args`: the arguments passed to the benchmark.
        - `configs`: the configuration of the model and dataset.
        - `stage`: the stage of the benchmark, either `"train"` or `"val"`.

      The function should return a dataloader for the new dataset.
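
A minimal sketch of such a `Batch` dataclass follows. The method signatures come from the list above; the field set (`dense_features`, `sparse_features`, `labels`) is an assumption for illustration:

```python
# Hypothetical sketch of a Batch dataclass. The field names are illustrative;
# only the method signatures are prescribed by the steps above.
from dataclasses import dataclass
from typing import Any, Dict

import torch
from torchrec.sparse.jagged_tensor import KeyedJaggedTensor


@dataclass
class Batch:
    dense_features: torch.Tensor
    sparse_features: KeyedJaggedTensor
    labels: torch.Tensor

    def to(self, device: torch.device, non_blocking: bool = False) -> "Batch":
        # Move every field to the target device.
        return Batch(
            dense_features=self.dense_features.to(device, non_blocking=non_blocking),
            sparse_features=self.sparse_features.to(device, non_blocking=non_blocking),
            labels=self.labels.to(device, non_blocking=non_blocking),
        )

    def pin_memory(self) -> "Batch":
        # Pin host memory so host-to-device copies can be asynchronous.
        return Batch(
            dense_features=self.dense_features.pin_memory(),
            sparse_features=self.sparse_features.pin_memory(),
            labels=self.labels.pin_memory(),
        )

    def record_stream(self, stream: torch.cuda.streams.Stream) -> None:
        # Inform the CUDA caching allocator that these tensors are used
        # on the given stream.
        self.dense_features.record_stream(stream)
        self.sparse_features.record_stream(stream)
        self.labels.record_stream(stream)

    def get_dict(self) -> Dict[str, Any]:
        # Return the data as {<attribute_name>: <attribute_value>}.
        return {
            "dense_features": self.dense_features,
            "sparse_features": self.sparse_features,
            "labels": self.labels,
        }
```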

Lines changed: 157 additions & 0 deletions
import argparse
from typing import List


def parse_args(argv: List[str]) -> argparse.Namespace:
    parser = argparse.ArgumentParser(description="torchrec dlrm example trainer")

    # Dataset related arguments
    parser.add_argument(
        "--dataset_name",
        type=str,
        choices=["movielens_1m", "criteo_kaggle", "kuairand_1k"],
        default="movielens_1m",
        help="dataset for experiment, currently supports movielens_1m, criteo_kaggle, kuairand_1k",
    )

    # Model related arguments
    parser.add_argument(
        "--model_name",
        type=str,
        choices=["dlrmv2", "dlrmv3"],
        default="dlrmv3",
        help="model for experiment, currently supports dlrmv2 and dlrmv3; dlrmv3 is the default",
    )
    parser.add_argument(
        # Ratio of feature ids to embedding table size. 3 axes:
        # x - batch_idx; y - collisions; z - embedding table sizes.
        "--num_embeddings",
        type=int,
        default=100_000,
        help="max_ind_size. The number of embeddings in each embedding table. Defaults"
        " to 100_000 if num_embeddings_per_feature is not supplied.",
    )
    parser.add_argument(
        "--embedding_dim",
        type=int,
        default=64,
        help="Size of each embedding.",
    )
    parser.add_argument(
        "--seed",
        type=int,
        help="Random seed for reproducibility.",
        default=0,
    )

    # Training related arguments
    parser.add_argument(
        "--epochs",
        type=int,
        default=1,
        help="number of epochs to train",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=4096,
        help="batch size to use for training",
    )
    parser.add_argument(
        "--sparse_optim",
        type=str,
        default="adagrad",
        help="The optimizer to use for sparse parameters.",
    )
    parser.add_argument(
        "--dense_optim",
        type=str,
        default="adagrad",
        help="The optimizer to use for dense parameters.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1.0,
        help="Learning rate.",
    )
    parser.add_argument(
        "--eps",
        type=float,
        default=1e-8,
        help="Epsilon for Adagrad optimizer.",
    )
    parser.add_argument(
        "--weight_decay",
        type=float,
        default=0,
        help="Weight decay for Adagrad optimizer.",
    )
    parser.add_argument(
        "--beta1",
        type=float,
        default=0.95,
        help="Beta1 for Adam optimizer.",
    )
    parser.add_argument(
        "--beta2",
        type=float,
        default=0.999,
        help="Beta2 for Adam optimizer.",
    )
    parser.add_argument(
        "--shuffle_batches",
        dest="shuffle_batches",
        action="store_true",
        help="Shuffle each batch during training.",
    )
    parser.add_argument(
        "--validation_freq_within_epoch",
        type=int,
        default=None,
        help="Frequency at which validation will be run within an epoch.",
    )
    parser.set_defaults(
        pin_memory=None,
        mmap_mode=None,
        drop_last=None,
        shuffle_batches=None,
        shuffle_training_set=None,
    )
    parser.add_argument(
        "--input_hash_size",
        type=int,
        default=100_000,
        help="Input feature value range",
    )
    parser.add_argument(
        "--profiling_result_folder",
        type=str,
        default="profiling_result",
        help="Folder to save profiling results",
    )
    parser.add_argument(
        "--zch_method",
        type=str,
        help="The method to use for zero collision hashing, blank for no zch",
        default="",
    )
    parser.add_argument(
        "--num_buckets",
        type=int,
        default=4,
        help="Number of buckets for identity table. Only used for MPZCH. The number of ranks WORLD_SIZE must be a factor of num_buckets, and num_buckets must be a factor of input_hash_size",
    )
    parser.add_argument(
        "--max_probe",
        type=int,
        default=None,
        help="Number of probes for identity table. Only used for MPZCH",
    )

    # Testbed related arguments
    parser.add_argument(
        "--log_path",
        type=str,
        default="log",
        help="Path to save log file without the suffix",
    )
    return parser.parse_args(argv)
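
For illustration, `parse_args` can be exercised as in the hypothetical snippet below; the snippet is not part of this commit and only assumes that `arguments.py` is importable from the working directory:

```python
# Hypothetical usage sketch (not part of this commit): parse benchmark
# arguments from the command line and inspect a few of them.
import sys

from arguments import parse_args

if __name__ == "__main__":
    args = parse_args(sys.argv[1:])
    print("dataset:", args.dataset_name)
    print("model:", args.model_name)
    print("num_embeddings:", args.num_embeddings)
    print("zch_method:", args.zch_method or "none")
```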
