Commit 82629f8

Generalize Optimizers container type by passing base internal optimizer class (#884)

Pass `optimizer_cls` to the `OptimizersContainer` and `OptimizersInBackwardContainer` constructors instead of `name`.
1 parent 291ace6 commit 82629f8
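
As an illustration of the new call site, here is a minimal hypothetical sketch (not part of the commit; it assumes `torchtitan` is installed and uses a toy model part):

```python
import torch
import torch.nn as nn

from torchtitan.components.optimizer import OptimizersContainer

# The container now receives the optimizer class itself instead of a string name.
optimizers = OptimizersContainer(
    model_parts=[nn.Linear(8, 8)],
    optimizer_cls=torch.optim.AdamW,  # previously: name="AdamW"
    optimizer_kwargs={"lr": 3e-4},
)

assert len(optimizers.optimizers) == 1  # one optimizer instance per model part
```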

File tree

2 files changed: +25 −21 lines changed


tests/unit_tests/test_train_spec.py

Lines changed: 1 addition & 1 deletion
```diff
@@ -47,8 +47,8 @@ def fake_build_optimizers(
     }
     return OptimizersContainer(
         model_parts=model_parts,
+        optimizer_cls=torch.optim.Adam,
         optimizer_kwargs=optimizer_kwargs,
-        name="Adam",
     )
 
 
```

torchtitan/components/optimizer.py

Lines changed: 24 additions & 20 deletions
```diff
@@ -6,7 +6,7 @@
 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Generic, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -30,18 +30,10 @@
 ]
 
 
-def _create_optimizer(
-    parameters: Iterable[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str
-) -> Optimizer:
-    if name == "Adam":
-        return torch.optim.Adam(parameters, **optimizer_kwargs)
-    elif name == "AdamW":
-        return torch.optim.AdamW(parameters, **optimizer_kwargs)
-    else:
-        raise NotImplementedError(f"Optimizer {name} not added.")
+T = TypeVar("T", bound=Optimizer)
 
 
-class OptimizersContainer(Optimizer):
+class OptimizersContainer(Optimizer, Generic[T]):
     """A container for multiple optimizers.
 
     This class is used to wrap multiple optimizers into a single object that can be
@@ -67,18 +59,21 @@ class OptimizersContainer(Optimizer):
         name (str): Name of the optimizers.
     """
 
-    optimizers: List[Optimizer]
+    optimizers: List[T]
     model_parts: List[nn.Module]
 
     def __init__(
-        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
     ) -> None:
         all_params = []
-        self.optimizers: List[Optimizer] = []
+        self.optimizers: List[T] = []
         self.model_parts = model_parts
         for model in self.model_parts:
             params = [p for p in model.parameters() if p.requires_grad]
-            self.optimizers.append(_create_optimizer(params, optimizer_kwargs, name))
+            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
             all_params.extend(params)
         self._validate_length(len(self.model_parts))
         self._post_init(all_params, optimizer_kwargs)
@@ -139,7 +134,10 @@ class OptimizersInBackwardContainer(OptimizersContainer):
     """
 
     def __init__(
-        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
     ) -> None:
         all_params = []
         self.model_parts = model_parts
@@ -148,7 +146,7 @@ def __init__(
         for model in self.model_parts:
             for p in model.parameters():
                 if p.requires_grad:
-                    optim_dict[p] = _create_optimizer([p], optimizer_kwargs, name)
+                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
                     all_params.append(p)
 
         def optim_hook(param) -> None:
@@ -218,11 +216,17 @@ def build_optimizers(
         "fused": fused,
         "foreach": foreach,
     }
-
+    optimizer_classes = {
+        "Adam": torch.optim.Adam,
+        "AdamW": torch.optim.AdamW,
+    }
+    if name not in optimizer_classes:
+        raise NotImplementedError(f"Optimizer {name} not added.")
+    optimizer_cls = optimizer_classes[name]
     return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
+        OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
         if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+        else OptimizersInBackwardContainer(model_parts, optimizer_cls, optimizer_kwargs)
     )
 
 
```
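The net effect is that the containers are no longer hard-wired to Adam and AdamW: `build_optimizers` still maps config names to classes, but the containers themselves accept any `torch.optim.Optimizer` subclass, and `Generic[T]` lets static checkers track the concrete optimizer type. A minimal hypothetical sketch (not part of the commit; it assumes `torchtitan` is importable and uses toy model parts):

```python
import torch.nn as nn
from torch.optim import SGD

from torchtitan.components.optimizer import OptimizersContainer

# Any Optimizer subclass now works without touching a name-to-class table;
# one optimizer instance is created per model part.
container = OptimizersContainer(
    model_parts=[nn.Linear(4, 4), nn.Linear(4, 2)],
    optimizer_cls=SGD,
    optimizer_kwargs={"lr": 0.01, "momentum": 0.9},
)

# Static checkers can infer `container.optimizers` as List[SGD].
first_opt: SGD = container.optimizers[0]
```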