Skip to content

Commit 8af5cb6

Browse files
committed
[Fix] enable weights_only in torch.load to fix tests
1 parent 390ba2f commit 8af5cb6

File tree

3 files changed

+12
-12
lines changed

3 files changed

+12
-12
lines changed

mmengine/runner/checkpoint.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def load_from_local(filename, map_location):
344344
filename = osp.expanduser(filename)
345345
if not osp.isfile(filename):
346346
raise FileNotFoundError(f'{filename} can not be found.')
347-
checkpoint = torch.load(filename, map_location=map_location)
347+
checkpoint = torch.load(filename, map_location=map_location, weights_only = False)
348348
return checkpoint
349349

350350

@@ -412,7 +412,7 @@ def load_from_pavi(filename, map_location=None):
412412
with TemporaryDirectory() as tmp_dir:
413413
downloaded_file = osp.join(tmp_dir, model.name)
414414
model.download(downloaded_file)
415-
checkpoint = torch.load(downloaded_file, map_location=map_location)
415+
checkpoint = torch.load(downloaded_file, map_location=map_location, weights_only = False)
416416
return checkpoint
417417

418418

@@ -435,7 +435,7 @@ def load_from_ceph(filename, map_location=None, backend='petrel'):
435435
file_backend = get_file_backend(
436436
filename, backend_args={'backend': backend})
437437
with io.BytesIO(file_backend.get(filename)) as buffer:
438-
checkpoint = torch.load(buffer, map_location=map_location)
438+
checkpoint = torch.load(buffer, map_location=map_location, weights_only = False)
439439
return checkpoint
440440

441441

@@ -504,7 +504,7 @@ def load_from_openmmlab(filename, map_location=None):
504504
filename = osp.join(_get_mmengine_home(), model_url)
505505
if not osp.isfile(filename):
506506
raise FileNotFoundError(f'{filename} can not be found.')
507-
checkpoint = torch.load(filename, map_location=map_location)
507+
checkpoint = torch.load(filename, map_location=map_location, weights_only = False)
508508
return checkpoint
509509

510510

tests/test_hooks/test_ema_hook.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -230,7 +230,7 @@ def test_with_runner(self):
230230
self.assertTrue(
231231
isinstance(ema_hook.ema_model, ExponentialMovingAverage))
232232

233-
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
233+
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only = False)
234234
self.assertTrue('ema_state_dict' in checkpoint)
235235
self.assertTrue(checkpoint['ema_state_dict']['steps'] == 8)
236236

@@ -245,7 +245,7 @@ def test_with_runner(self):
245245
runner.test()
246246

247247
# Test load checkpoint without ema_state_dict
248-
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'))
248+
checkpoint = torch.load(osp.join(self.temp_dir.name, 'epoch_2.pth'), weights_only = False)
249249
checkpoint.pop('ema_state_dict')
250250
torch.save(checkpoint,
251251
osp.join(self.temp_dir.name, 'without_ema_state_dict.pth'))
@@ -274,7 +274,7 @@ def test_with_runner(self):
274274
runner = self.build_runner(cfg)
275275
runner.train()
276276
state_dict = torch.load(
277-
osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu')
277+
osp.join(self.temp_dir.name, 'epoch_4.pth'), map_location='cpu', weights_only = False)
278278
self.assertIn('ema_state_dict', state_dict)
279279
for k, v in state_dict['state_dict'].items():
280280
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
@@ -287,12 +287,12 @@ def test_with_runner(self):
287287
runner = self.build_runner(cfg)
288288
runner.train()
289289
state_dict = torch.load(
290-
osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu')
290+
osp.join(self.temp_dir.name, 'iter_4.pth'), map_location='cpu', weights_only = False)
291291
self.assertIn('ema_state_dict', state_dict)
292292
for k, v in state_dict['state_dict'].items():
293293
assert_allclose(v, state_dict['ema_state_dict']['module.' + k])
294294
state_dict = torch.load(
295-
osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu')
295+
osp.join(self.temp_dir.name, 'iter_5.pth'), map_location='cpu', weights_only = False)
296296
self.assertIn('ema_state_dict', state_dict)
297297

298298
def _test_swap_parameters(self, func_name, *args, **kwargs):

tests/test_runner/test_runner.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2272,7 +2272,7 @@ def test_checkpoint(self):
22722272
self.assertTrue(osp.exists(path))
22732273
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_4.pth')))
22742274

2275-
ckpt = torch.load(path)
2275+
ckpt = torch.load(path, weights_only = False)
22762276
self.assertEqual(ckpt['meta']['epoch'], 3)
22772277
self.assertEqual(ckpt['meta']['iter'], 12)
22782278
self.assertEqual(ckpt['meta']['experiment_name'],
@@ -2444,7 +2444,7 @@ def test_checkpoint(self):
24442444
self.assertTrue(osp.exists(path))
24452445
self.assertFalse(osp.exists(osp.join(self.temp_dir, 'epoch_13.pth')))
24462446

2447-
ckpt = torch.load(path)
2447+
ckpt = torch.load(path, weights_only = False)
24482448
self.assertEqual(ckpt['meta']['epoch'], 0)
24492449
self.assertEqual(ckpt['meta']['iter'], 12)
24502450
assert isinstance(ckpt['optimizer'], dict)
@@ -2455,7 +2455,7 @@ def test_checkpoint(self):
24552455
self.assertEqual(message_hub.get_info('iter'), 11)
24562456
# 2.1.2 check class attribute _statistic_methods can be saved
24572457
HistoryBuffer._statistics_methods.clear()
2458-
ckpt = torch.load(path)
2458+
ckpt = torch.load(path, weights_only = False)
24592459
self.assertIn('min', HistoryBuffer._statistics_methods)
24602460

24612461
# 2.2 test `load_checkpoint`

0 commit comments

Comments
 (0)