Integration with DCP #978
base: unflatten
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,6 +1,8 @@ | ||
# Copyright (c) Meta Platforms, Inc. and affiliates | ||
import torch | ||
from pippy import annotate_split_points, Pipe, SplitPoint | ||
import torch.distributed.checkpoint as dcp | ||
import tempfile | ||
|
||
|
||
d_hid = 16 | ||
|
@@ -66,6 +68,49 @@ def get_layers(module): | |
return layers | ||
|
||
|
||
def pipe_to_sd(pipe): | ||
sd = {} | ||
for stage_idx in range(pipe.num_stages): | ||
Review comment: Something a little fishy about this proposal (equally so for both Option 1 and Option 2) is that it's not likely you'd want to iterate over all the stages in the pipe and load/save them. Example 1: simple pipeline with 4 GPUs |
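The concern in the comment above can be illustrated with a small sketch in which each rank checkpoints only the stages it owns instead of looping over `pipe.num_stages`. The helpers `stages_for_rank` and `checkpoint_plan` are hypothetical, and the round-robin stage placement is an assumption made for illustration, not something this PR specifies. The path pattern mirrors the `f"{tmpdir}_{stage_idx}"` used in Option 1.

```python
def stages_for_rank(rank: int, world_size: int, num_stages: int) -> list:
    """Return the stage indices owned by `rank`.

    Assumes round-robin placement (stage i lives on rank i % world_size);
    this placement is an illustrative assumption, not part of the PR.
    """
    if not 0 <= rank < world_size:
        raise ValueError(f"rank {rank} is outside [0, {world_size})")
    return list(range(rank, num_stages, world_size))


def checkpoint_plan(rank: int, world_size: int, num_stages: int, base_dir: str) -> dict:
    """Map each locally owned stage index to its own checkpoint path."""
    return {
        stage_idx: f"{base_dir}_{stage_idx}"
        for stage_idx in stages_for_rank(rank, world_size, num_stages)
    }


# With 4 ranks and 4 stages, every rank saves or loads exactly one stage.
plan = checkpoint_plan(rank=2, world_size=4, num_stages=4, base_dir="/tmp/ckpt")
```

Each rank would then call `dcp.save` or `dcp.load` only for the entries in its own plan, which is closer to how a multi-GPU pipeline would likely run than iterating over every stage in one process.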
||
stage_mod = pipe.get_stage_module(stage_idx) | ||
sd[f"stage_{stage_idx}"] = stage_mod | ||
Review comment: Not really clear to me why we need to add a prefix at all.
There should be no duplication of FQNs between submods/stages. What are we doing about the 'submod_0' part in the FQN? When we do [...] If the former, can't we just save/load the keys as usual? If the latter, we can still save/load without a prefix of [...]
Reply: Former. @wconstab |
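If FQNs really are unique across stages, as the reviewer suggests, the per-stage state dicts could be merged into one flat dict with no `stage_{idx}` prefix. The sketch below shows that idea with plain dictionaries. `merge_stage_state_dicts` is a hypothetical helper, and the keys are made-up placeholders rather than names produced by PiPPy.

```python
def merge_stage_state_dicts(stage_dicts):
    """Merge per-stage state dicts into one flat dict without adding a prefix.

    Raises KeyError if two stages share an FQN, since that would break the
    assumption that FQNs are unique across stages.
    """
    merged = {}
    for stage_idx, stage_sd in enumerate(stage_dicts):
        for fqn, value in stage_sd.items():
            if fqn in merged:
                raise KeyError(f"duplicate FQN {fqn!r} found in stage {stage_idx}")
            merged[fqn] = value
    return merged


# Placeholder keys and values, for illustration only.
stage0 = {"layers.0.weight": 0.1, "layers.0.bias": 0.2}
stage1 = {"layers.1.weight": 0.3, "layers.1.bias": 0.4}
merged = merge_stage_state_dicts([stage0, stage1])
```

The duplicate check makes the uniqueness assumption explicit: if it ever fails, a prefix (or some other disambiguation) would be needed after all.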
||
return sd | ||
|
||
with tempfile.TemporaryDirectory() as tmpdir: | ||
#Simulate saving the pipe | ||
# Option 1: | ||
Review comment: I think Option 1 is more likely to be used than Option 2 in a realistic setting. Could you please uncomment this block of code? |
||
# for stage_idx in range(pipe.num_stages): | ||
# print(f"Saving pipeline stage {stage_idx}") | ||
# stage_mod = pipe.get_stage_module(stage_idx) | ||
# dcp.save( | ||
# {f"stage_{stage_idx}": stage_mod}, | ||
Review comment: Curious, is the dict required by the API of DCP? Can a user directly save [...]
Reply: Why does this matter? I think the DCP API had reasons for interfacing with a dict instead of a model. Adding a new variant that takes a model and gets its dict should be possible, but I think it's clearer this way that the only part of the model that gets saved is the dict.
Reply: Just to be clear: I like saving the state dict too (instead of the module). That's more composable to me. |
||
# checkpoint_id=f"{tmpdir}_{stage_idx}" | ||
# ) | ||
# Option 2: | ||
sd = pipe_to_sd(pipe) | ||
dcp.save(sd, checkpoint_id=tmpdir) | ||
|
||
|
||
#Simulate loading the pipe | ||
# Option 1: | ||
# for stage_idx in range(pipe.num_stages): | ||
# print(f"Loading pipeline stage {stage_idx}") | ||
# stage_mod = pipe.get_stage_module(stage_idx) | ||
# dcp.load( | ||
# {f"stage_{stage_idx}": stage_mod}, | ||
# checkpoint_id=f"{tmpdir}_{stage_idx}" | ||
# ) | ||
|
||
#Option 2: | ||
new_pipe = Pipe.from_tracing( | ||
transformer, | ||
1, | ||
(x,), | ||
) | ||
sd = pipe_to_sd(new_pipe) | ||
dcp.load(sd, checkpoint_id=tmpdir) | ||
|
||
pipe = new_pipe | ||
|
||
# Collect all layers in pipe | ||
layers = [] | ||
for stage_idx in range(pipe.num_stages): | ||
|
Review comment: @wz337, might be interested in dist state dict