@@ -1032,20 +1032,33 @@ def tearDown(self):
10321032 except OSError :
10331033 pass
10341034
1035- def test_distributed_debug_mode (self ):
1035+ def test_debug_level (self ):
1036+ try :
1037+ del os .environ ["TORCH_DISTRIBUTED_DEBUG" ]
1038+ except KeyError :
1039+ pass
1040+
1041+ dist .set_debug_level_from_env ()
10361042 # Default should be off
1037- default_debug_mode = dist ._get_debug_mode ()
1038- self .assertEqual (default_debug_mode , dist ._DistributedDebugLevel .OFF )
1043+ default_debug_mode = dist .get_debug_level ()
1044+ self .assertEqual (default_debug_mode , dist .DebugLevel .OFF )
10391045 mapping = {
1040- "OFF" : dist ._DistributedDebugLevel .OFF ,
1041- "INFO" : dist ._DistributedDebugLevel .INFO ,
1042- "DETAIL" : dist ._DistributedDebugLevel .DETAIL ,
1046+ "OFF" : dist .DebugLevel .OFF ,
1047+ "off" : dist .DebugLevel .OFF ,
1048+ "oFf" : dist .DebugLevel .OFF ,
1049+ "INFO" : dist .DebugLevel .INFO ,
1050+ "info" : dist .DebugLevel .INFO ,
1051+ "INfO" : dist .DebugLevel .INFO ,
1052+ "DETAIL" : dist .DebugLevel .DETAIL ,
1053+ "detail" : dist .DebugLevel .DETAIL ,
1054+ "DeTaIl" : dist .DebugLevel .DETAIL ,
10431055 }
10441056 invalid_debug_modes = ["foo" , 0 , 1 , - 1 ]
10451057
10461058 for mode in mapping .keys ():
10471059 os .environ ["TORCH_DISTRIBUTED_DEBUG" ] = str (mode )
1048- set_debug_mode = dist ._get_debug_mode ()
1060+ dist .set_debug_level_from_env ()
1061+ set_debug_mode = dist .get_debug_level ()
10491062 self .assertEqual (
10501063 set_debug_mode ,
10511064 mapping [mode ],
@@ -1054,8 +1067,8 @@ def test_distributed_debug_mode(self):
10541067
10551068 for mode in invalid_debug_modes :
10561069 os .environ ["TORCH_DISTRIBUTED_DEBUG" ] = str (mode )
1057- with self .assertRaisesRegex (RuntimeError , "to be one of " ):
1058- dist ._get_debug_mode ()
1070+ with self .assertRaisesRegex (RuntimeError , "The value of TORCH_DISTRIBUTED_DEBUG must " ):
1071+ dist .set_debug_level_from_env ()
10591072
10601073
10611074if __name__ == "__main__" :
0 commit comments