distributed/tensor_parallelism
1 file changed: +6 -3 lines

@@ -91,9 +91,6 @@ def forward(self, x):
 # create model and move it to GPU - init_device_mesh has already mapped GPU ids.
 tp_model = ToyModel().to("cuda")
 
-# Create an optimizer for the parallelized module.
-lr = 0.25
-optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
 
 # Custom parallelization plan for the model
 tp_model = parallelize_module(
@@ -104,6 +101,12 @@ def forward(self, x):
         "out_proj": RowwiseParallel(),
     },
 )
+
+# Create an optimizer for the parallelized module.
+lr = 0.25
+optimizer = torch.optim.AdamW(tp_model.parameters(), lr=lr, foreach=True)
+
+
 # Perform a number of iterations of forward/backward
 # and optimizer steps for the sharded module.
 num_iters = 10
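
Note on the change: parallelize_module replaces the module's in_proj/out_proj parameters with DTensor-backed ones in place, so an optimizer constructed before the call would hold references to the old, pre-shard parameters and its updates would never reach the parallelized model. Moving the AdamW construction after parallelize_module fixes that. Below is a minimal runnable sketch of the corrected ordering; the layer sizes, batch shape, and training loop are illustrative assumptions rather than copies of the file, and it should be launched with torchrun --nproc_per_node=<num_gpus> so the process group can initialize.

import os

import torch
import torch.nn as nn
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyModel(nn.Module):
    """Toy MLP; layer sizes are assumed for illustration."""

    def __init__(self):
        super().__init__()
        self.in_proj = nn.Linear(10, 32)
        self.relu = nn.ReLU()
        self.out_proj = nn.Linear(32, 5)

    def forward(self, x):
        return self.out_proj(self.relu(self.in_proj(x)))


# One GPU per rank; init_device_mesh sets up the process group if needed.
world_size = int(os.environ["WORLD_SIZE"])
device_mesh = init_device_mesh("cuda", (world_size,))

tp_model = ToyModel().to("cuda")

# parallelize_module swaps the parameters for DTensor-backed ones in place...
tp_model = parallelize_module(
    module=tp_model,
    device_mesh=device_mesh,
    parallelize_plan={
        "in_proj": ColwiseParallel(),
        "out_proj": RowwiseParallel(),
    },
)

# ...so the optimizer is built afterwards, over the DTensor parameters.
# Built before the call, it would step the stale, replaced tensors.
optimizer = torch.optim.AdamW(tp_model.parameters(), lr=0.25, foreach=True)

for i in range(10):
    # Seed identically on every rank: ColwiseParallel treats the input
    # as replicated, so all ranks must feed the same batch.
    torch.manual_seed(i)
    inp = torch.rand(4, 10, device="cuda")
    out = tp_model(inp)
    out.sum().backward()
    optimizer.step()
    optimizer.zero_grad()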