 
 import copy
 import functools
-from typing import Any, Callable, Dict, Iterable, List
+from typing import Any, Callable, Dict, Generic, List, TypeVar
 
 import torch
 import torch.nn as nn
@@ -30,18 +30,10 @@
 ]
 
 
-def _create_optimizer(
-    parameters: Iterable[nn.Parameter], optimizer_kwargs: Dict[str, Any], name: str
-) -> Optimizer:
-    if name == "Adam":
-        return torch.optim.Adam(parameters, **optimizer_kwargs)
-    elif name == "AdamW":
-        return torch.optim.AdamW(parameters, **optimizer_kwargs)
-    else:
-        raise NotImplementedError(f"Optimizer {name} not added.")
+T = TypeVar("T", bound=Optimizer)
 
 
-class OptimizersContainer(Optimizer):
+class OptimizersContainer(Optimizer, Generic[T]):
     """A container for multiple optimizers.
 
     This class is used to wrap multiple optimizers into a single object that can be
@@ -67,18 +59,21 @@ class OptimizersContainer(Optimizer):
         name (str): Name of the optimizers.
     """
 
-    optimizers: List[Optimizer]
+    optimizers: List[T]
     model_parts: List[nn.Module]
 
     def __init__(
-        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
     ) -> None:
         all_params = []
-        self.optimizers: List[Optimizer] = []
+        self.optimizers: List[T] = []
         self.model_parts = model_parts
         for model in self.model_parts:
             params = [p for p in model.parameters() if p.requires_grad]
-            self.optimizers.append(_create_optimizer(params, optimizer_kwargs, name))
+            self.optimizers.append(optimizer_cls(params, **optimizer_kwargs))
             all_params.extend(params)
         self._validate_length(len(self.model_parts))
         self._post_init(all_params, optimizer_kwargs)
@@ -139,7 +134,10 @@ class OptimizersInBackwardContainer(OptimizersContainer):
     """
 
     def __init__(
-        self, model_parts: List[nn.Module], optimizer_kwargs: Dict[str, Any], name: str
+        self,
+        model_parts: List[nn.Module],
+        optimizer_cls: type[T],
+        optimizer_kwargs: Dict[str, Any],
     ) -> None:
         all_params = []
         self.model_parts = model_parts
@@ -148,7 +146,7 @@ def __init__(
         for model in self.model_parts:
             for p in model.parameters():
                 if p.requires_grad:
-                    optim_dict[p] = _create_optimizer([p], optimizer_kwargs, name)
+                    optim_dict[p] = optimizer_cls([p], **optimizer_kwargs)
                     all_params.append(p)
 
         def optim_hook(param) -> None:
@@ -218,11 +216,17 @@ def build_optimizers(
         "fused": fused,
         "foreach": foreach,
     }
-
+    optimizer_classes = {
+        "Adam": torch.optim.Adam,
+        "AdamW": torch.optim.AdamW,
+    }
+    if name not in optimizer_classes:
+        raise NotImplementedError(f"Optimizer {name} not added.")
+    optimizer_cls = optimizer_classes[name]
     return (
-        OptimizersContainer(model_parts, optimizer_kwargs, name)
+        OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)
         if not optim_in_bwd
-        else OptimizersInBackwardContainer(model_parts, optimizer_kwargs, name)
+        else OptimizersInBackwardContainer(model_parts, optimizer_cls, optimizer_kwargs)
     )
 
 
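For reference, a minimal usage sketch under the new signatures, not code from this change: the caller maps the configured optimizer name to a class (as build_optimizers now does) and hands the class plus its kwargs to the container. The model parts, hyperparameters, and training step below are illustrative only; OptimizersContainer is assumed to be importable from the module patched above and to expose the usual Optimizer step()/zero_grad() interface.

import torch
import torch.nn as nn

# Hypothetical model parts standing in for real model/pipeline chunks.
model_parts = [nn.Linear(16, 16), nn.Linear(16, 4)]

# Resolve the optimizer class by name, mirroring the name-to-class mapping
# that build_optimizers now performs before constructing a container.
optimizer_classes = {"Adam": torch.optim.Adam, "AdamW": torch.optim.AdamW}
optimizer_cls = optimizer_classes["AdamW"]
optimizer_kwargs = {"lr": 3e-4, "betas": (0.9, 0.95), "weight_decay": 0.1}

# The container receives the class itself and instantiates one optimizer per
# model part with the shared kwargs (OptimizersContainer is the class defined
# in the patched module; its full interface is not shown in this diff).
optimizers = OptimizersContainer(model_parts, optimizer_cls, optimizer_kwargs)

# Illustrative training step.
out = model_parts[1](model_parts[0](torch.randn(2, 16)))
out.sum().backward()
optimizers.step()
optimizers.zero_grad()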