Commit 5bd37b3

Luka Buskivadze authored and facebook-github-bot committed
add custom model for planner (pytorch#429)
Summary:
Pull Request resolved: pytorch#429

This diff allows creating custom models to run the planner and check the results. In contrast to previous models, running this model only exercises the planner and is simpler to use.

Reviewed By: joshuadeng

Differential Revision: D37090549

fbshipit-source-id: e71bdeea30d9250fc3e8d668260340a7f160c931
1 parent 778b47c commit 5bd37b3
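As a usage sketch, the script's defaults can be overridden from the command line (the file name below is hypothetical; this page does not show the script's path):

# hypothetical file name; the diff does not show where the script lives
python custom_planner_model.py --world_size 16 --num_tables 20 --planner_type non_parallelized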

File tree

1 file changed: 153 additions, 0 deletions
@@ -0,0 +1,153 @@
#!/usr/bin/env python3
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import logging
from typing import cast, List

import torch
from torch import nn

from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder
from torchrec.distributed.planner.parallelized_planners import (
    ParallelizedEmbeddingShardingPlanner,
)
from torchrec.distributed.planner.planners import EmbeddingShardingPlanner
from torchrec.distributed.planner.types import Topology
from torchrec.distributed.test_utils.test_model import TestSparseNN
from torchrec.distributed.types import ModuleSharder
from torchrec.modules.embedding_configs import EmbeddingBagConfig

parser = argparse.ArgumentParser(description="custom model for running planner")

parser.add_argument(
    "-lws",
    "--local_world_size",
    type=int,
    default=8,
    help="local_world_size: local world size used in topology. Defaults to 8",
    required=False,
)
parser.add_argument(
    "-ws",
    "--world_size",
    type=int,
    default=16,
    help="world_size: number of ranks used in topology. Defaults to 16",
    required=False,
)
parser.add_argument(
    "-bs",
    "--batch_size",
    type=int,
    default=32,
    help="batch_size: batch size used in topology. Defaults to 32",
    required=False,
)
parser.add_argument(
    "-hc",
    "--hbm_cap",
    type=int,
    default=16777216,
    help="hbm_cap: maximum HBM storage used in topology. Defaults to 1024 * 1024 * 16",
    required=False,
)
parser.add_argument(
    "-cd",
    "--compute_device",
    type=str,
    default="cuda",
    help="compute_device: compute device used in topology. Defaults to 'cuda'",
    required=False,
)
parser.add_argument(
    "-ne",
    "--num_embeddings",
    type=int,
    default=100,
    help="num_embeddings: number of embeddings used in creating tables. Defaults to 100",
    required=False,
)
parser.add_argument(
    "-ed",
    "--embedding_dim",
    type=int,
    default=64,
    help="embedding_dim: embedding dimension used in creating tables. Defaults to 64",
    required=False,
)
parser.add_argument(
    "-nt",
    "--num_tables",
    type=int,
    default=10,
    help="num_tables: number of tables used in creating tables. Defaults to 10",
    required=False,
)
parser.add_argument(
    "-pt",
    "--planner_type",
    type=str,
    default="parallelized",
    help="planner_type: type of embedding sharding planner to create. "
    "Pass 'non_parallelized' for the non-parallelized planner; defaults to 'parallelized'",
    required=False,
)

args: argparse.Namespace = parser.parse_args()

logging.basicConfig(level=logging.INFO)


def main() -> None:
    """
    Generates the sharding plan for a SparseNN model.

    The purpose of this function is to test planners quickly, by running the
    script with custom parameters such as local_world_size, num_embeddings,
    num_tables, and more.

    The program outputs the planner summary.
    """
    topology = Topology(
        local_world_size=args.local_world_size,
        world_size=args.world_size,
        batch_size=args.batch_size,
        hbm_cap=args.hbm_cap,
        compute_device=args.compute_device,
    )

    # Select the planner implementation based on --planner_type.
    if args.planner_type == "non_parallelized":
        planner = EmbeddingShardingPlanner(topology=topology)
    else:
        planner = ParallelizedEmbeddingShardingPlanner(topology=topology)

    tables: List[EmbeddingBagConfig] = [
        EmbeddingBagConfig(
            num_embeddings=args.num_embeddings,
            embedding_dim=args.embedding_dim,
            name="table_" + str(i),
            feature_names=["feature_" + str(i)],
        )
        for i in range(args.num_tables)
    ]
    # The meta device keeps the sparse parameters shape-only, so no real
    # embedding memory is allocated just to run the planner.
    model = TestSparseNN(tables=tables, sparse_device=torch.device("meta"))

    sharders: List[ModuleSharder[nn.Module]] = [
        cast(ModuleSharder[nn.Module], EmbeddingBagCollectionSharder()),
    ]

    planner.plan(
        module=model,
        sharders=sharders,
    )


if __name__ == "__main__":
    main()
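The call above relies on the planner's logged summary. If the plan itself is needed, plan() also returns it; a minimal sketch, assuming this replaces the bare planner.plan(...) call at the end of main():

    # capture the ShardingPlan that plan() returns and print it for inspection
    plan = planner.plan(
        module=model,
        sharders=sharders,
    )
    print(plan)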

0 commit comments