
Commit e8659ea

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent: 17ee5a7

File tree: 4 files changed (+57, −20 lines)


colossalai/checkpoint_io/__init__.py

Lines changed: 0 additions & 1 deletion

@@ -1,7 +1,6 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
-
 from .index_file import CheckpointIndexFile
 from .moe_checkpoint import MoECheckpointIO
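The deletion removes a stray blank line inside the relative-import block, so the import-sorting hook can keep the package's re-exports as one group. For orientation, these are the names downstream code imports from the package root; a minimal sketch (the no-argument constructor is an assumption, not shown in this diff):

    # Import classes re-exported by colossalai/checkpoint_io/__init__.py
    # (names taken from the import list above).
    from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO

    ckpt_io: CheckpointIO = GeneralCheckpointIO()  # assumed no-arg constructor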

colossalai/checkpoint_io/general_checkpoint_io.py

Lines changed: 1 addition & 1 deletion

@@ -309,4 +309,4 @@ def load_sharded_model(
         )

     def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
-        raise NotImplementedError
+        raise NotImplementedError

(The removed and added lines are textually identical; the change is whitespace-only, most likely the trailing end-of-file newline added by a pre-commit hook.)
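As the stub shows, GeneralCheckpointIO declares save_lora_as_pretrained but leaves the implementation to subclasses. A hypothetical override matching the signature above (the torch.save body and the "lora_" name filter are illustrative assumptions, not the library's implementation):

    import torch
    import torch.nn as nn

    from colossalai.checkpoint_io import GeneralCheckpointIO

    class LoRACheckpointIO(GeneralCheckpointIO):
        def save_lora_as_pretrained(self, model: nn.Module, checkpoint: str, use_safetensors: bool = False) -> None:
            # Illustrative: persist only LoRA adapter tensors, filtered by name.
            lora_state = {k: v for k, v in model.state_dict().items() if "lora_" in k}
            torch.save(lora_state, checkpoint)  # safetensors handling omitted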

colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py

Lines changed: 44 additions & 16 deletions

@@ -24,6 +24,13 @@
 from colossalai.utils import get_current_device, get_non_persistent_buffers_set
 from colossalai.utils.safetensors import _flatten_optim_state_dict, load_flat

+from .distributed_checkpoint_utils import (
+    create_model_metadata,
+    is_pytorch_model_meta_dist_file,
+    load_dist_model,
+    save_dist_sharded_model,
+    save_dist_unshard_model,
+)
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
@@ -47,14 +54,6 @@
     sharded_optimizer_loading_epilogue,
 )

-from .distributed_checkpoint_utils import (
-    save_dist_sharded_model,
-    save_dist_unshard_model,
-    load_dist_model,
-    is_pytorch_model_meta_dist_file,
-    create_model_metadata
-)
-
 try:
     from torch.nn.modules.module import _EXTRA_STATE_KEY_SUFFIX
 except ImportError:
@@ -244,9 +243,19 @@ def save_sharded_model(
                 return
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-            save_dist_sharded_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, prefix=prefix, size_per_shard=size_per_shard, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+            save_dist_sharded_model(
+                model=model,
+                model_metadata=model_metadata,
+                checkpoint=checkpoint,
+                prefix=prefix,
+                size_per_shard=size_per_shard,
+                use_safetensors=use_safetensors,
+                use_async=use_async,
+                dist_id=dist_id,
+                pinned_state_dicts=self.pinned_state_dicts,
+            )
             return
-
+
         model = model.unwrap()

         if os.path.isfile(checkpoint):
@@ -394,9 +403,15 @@ def load_sharded_model(

         if is_pytorch_model_meta_dist_file(checkpoint_index_file):
             model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-            load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint_index_file, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+            load_dist_model(
+                model=model,
+                model_metadata=model_metadata,
+                checkpoint=checkpoint_index_file,
+                low_cpu_mem_mode=low_cpu_mem_mode,
+                num_threads=num_threads,
+            )
             return
-
+
         model_before_wrapping = model  # backup for model before wrapping
         model = model.unwrap()

@@ -792,9 +807,17 @@ def save_unsharded_model(
             if self.dp_rank != 0 and self.sp_rank != 0:
                 return
             dist_id = self.tp_size * self.pp_rank + self.tp_rank
-            save_dist_unshard_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, use_safetensors=use_safetensors, use_async=use_async, dist_id = dist_id, pinned_state_dicts = self.pinned_state_dicts)
+            save_dist_unshard_model(
+                model=model,
+                model_metadata=model_metadata,
+                checkpoint=checkpoint,
+                use_safetensors=use_safetensors,
+                use_async=use_async,
+                dist_id=dist_id,
+                pinned_state_dicts=self.pinned_state_dicts,
+            )
             return
-
+
         model = model.unwrap()
         if self.dp_rank != 0:
             return
@@ -867,7 +890,13 @@ def load_unsharded_model(
         for filename in os.listdir(checkpoint):
             if is_pytorch_model_meta_dist_file(filename):
                 model_metadata = create_model_metadata(model, tp_size=self.tp_size, tp_rank=self.tp_rank)
-                load_dist_model(model=model, model_metadata=model_metadata, checkpoint=checkpoint, low_cpu_mem_mode=low_cpu_mem_mode, num_threads=num_threads)
+                load_dist_model(
+                    model=model,
+                    model_metadata=model_metadata,
+                    checkpoint=checkpoint,
+                    low_cpu_mem_mode=low_cpu_mem_mode,
+                    num_threads=num_threads,
+                )
                 return

         strict = False
@@ -1099,7 +1128,6 @@ def gather_from_sharded_optimizer_state(
                 dist.all_gather(gather_tensor, v, group=dp_group)
                 v = torch.stack(gather_tensor).view(-1)[: param.numel()].reshape_as(param)

-
             # Then gather TP shards.
             partition_dim = search_tp_partition_dim(current_shape, original_shape, tp_size)
             if partition_dim is not None:
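Both save paths above compute dist_id = self.tp_size * self.pp_rank + self.tp_rank before delegating to the distributed helpers. The formula flattens the (pipeline rank, tensor rank) pair into one unique shard id, row-major over tensor ranks. A standalone sketch with hypothetical 2x2 parallel degrees:

    # dist_id enumerates (pp_rank, tp_rank) pairs row-major, so each
    # pipeline/tensor rank combination writes to a distinct shard id.
    tp_size, pp_size = 2, 2  # hypothetical parallel degrees

    for pp_rank in range(pp_size):
        for tp_rank in range(tp_size):
            dist_id = tp_size * pp_rank + tp_rank
            print(f"pp_rank={pp_rank}, tp_rank={tp_rank} -> dist_id={dist_id}")
    # Prints dist_id 0, 1, 2, 3: one id per rank pair.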

tests/test_checkpoint_io/test_dist_checkpointio.py

Lines changed: 12 additions & 2 deletions

@@ -79,7 +79,12 @@ def _preprocess_data(data):
     model_ckpt_path_0 = f"{tempdir}/model_0"

     booster_0.save_model(
-        model_0, model_ckpt_path_0, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+        model_0,
+        model_ckpt_path_0,
+        shard=shard,
+        gather_dtensor=True,
+        size_per_shard=size_per_shard,
+        use_async=use_async,
     )
     booster_0.checkpoint_io._sync_d2h()
     booster_0.checkpoint_io._sync_io()
@@ -96,7 +101,12 @@ def _preprocess_data(data):

     model_ckpt_path_1 = f"{tempdir}/model_1"
     booster_1.save_model(
-        model_1, model_ckpt_path_1, shard=shard, gather_dtensor=True, size_per_shard=size_per_shard, use_async=use_async
+        model_1,
+        model_ckpt_path_1,
+        shard=shard,
+        gather_dtensor=True,
+        size_per_shard=size_per_shard,
+        use_async=use_async,
     )
     booster_1.checkpoint_io._sync_d2h()
     booster_1.checkpoint_io._sync_io()
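The reformatted calls preserve the test's save-then-drain pattern: with use_async=True the checkpoint may still be in flight when save_model returns, so the test flushes pending work before reading the file back. A hypothetical helper capturing that pattern (booster and model come from the test's fixtures; the _sync_* comments are inferred from the method names):

    def save_and_flush(booster, model, ckpt_path, shard, size_per_shard, use_async):
        """Save `model` via `booster`, then block until async checkpoint IO finishes."""
        booster.save_model(
            model,
            ckpt_path,
            shard=shard,
            gather_dtensor=True,
            size_per_shard=size_per_shard,
            use_async=use_async,
        )
        booster.checkpoint_io._sync_d2h()  # presumably waits on device-to-host copies
        booster.checkpoint_io._sync_io()   # presumably waits on background file writes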
