@@ -141,6 +141,8 @@ def __init__(
141
141
self ._accelerator_flag = self ._choose_auto_accelerator ()
142
142
elif self ._accelerator_flag == "gpu" :
143
143
self ._accelerator_flag = self ._choose_gpu_accelerator_backend ()
144
+ elif isinstance (self ._accelerator_flag , Accelerator ):
145
+ pass # do nothing
144
146
145
147
self ._set_parallel_devices_and_init_accelerator ()
146
148
@@ -461,7 +463,7 @@ def _check_and_init_precision(self) -> Precision:
461
463
if isinstance (self .strategy , DeepSpeedStrategy ):
462
464
return DeepSpeedPrecision (self ._precision_input ) # type: ignore
463
465
if isinstance (self .strategy , FSDPStrategy ):
464
- return FSDPPrecision (precision = self ._precision_input ) # type: ignore[arg-type]
466
+ return FSDPPrecision (precision = self ._precision_input , device = self . _accelerator_flag . get_device () if isinstance ( self . _accelerator_flag , Accelerator ) else None ) # type: ignore[arg-type]
465
467
mp_precision_supported = ("32-true" , "bf16-mixed" , "bf16-true" , "16-true" )
466
468
if isinstance (self .strategy , ModelParallelStrategy ) and self ._precision_input not in mp_precision_supported :
467
469
raise ValueError (
@@ -493,6 +495,8 @@ def _check_and_init_precision(self) -> Precision:
493
495
else "Using bfloat16 Automatic Mixed Precision (AMP)"
494
496
)
495
497
device = "cpu" if self ._accelerator_flag == "cpu" else "cuda"
498
+ if isinstance (self ._accelerator_flag , Accelerator ):
499
+ device = self ._accelerator_flag .get_device ()
496
500
return MixedPrecision (precision = self ._precision_input , device = device ) # type: ignore[arg-type]
497
501
498
502
raise RuntimeError ("No precision set" )
0 commit comments