
Commit 2a4424c

Support Stable Diffusion V2
1 parent a4eea6a commit 2a4424c

File tree: 4 files changed (+177 -0 lines)


docs/train.md (+7)

@@ -119,6 +119,8 @@ Do not ask us why we use these three names - this is related to the dark history
 
 Then you need to decide which Stable Diffusion Model you want to control. In this example, we will just use standard SD1.5. You can download it from the [official page of Stability](https://huggingface.co/runwayml/stable-diffusion-v1-5/tree/main). You want the file "v1-5-pruned.ckpt".
 
+(Or ["v2-1_512-ema-pruned.ckpt"](https://huggingface.co/stabilityai/stable-diffusion-2-1-base/tree/main) if you are using SD2)
+
 Then you need to attach a control net to the SD model. The architecture is
 
 ![img](../github_page/sd.png)
@@ -129,6 +131,10 @@ We provide a simple script for you to achieve this easily. If your SD filename i
 
 python tool_add_control.py ./models/v1-5-pruned.ckpt ./models/control_sd15_ini.ckpt
 
+Or if you are using SD2:
+
+python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt
+
 You may also use other filenames as long as the command is "python tool_add_control.py input_path output_path".
 
 This is the correct output from my machine:
@@ -177,6 +183,7 @@ trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])
 trainer.fit(model, dataloader)
 
 ```
+(or "tutorial_train_sd21.py" if you are using SD2)
 
 Thanks to our organized dataset pytorch object and the power of pytorch_lightning, the entire code is just super short.
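A quick way to sanity-check either converted checkpoint is to load it and confirm the control branch exists. A minimal sketch, assuming the output path used above and the key layout produced by tool_add_control_sd21.py:

import torch

# The converted file is a plain state dict saved with torch.save (see
# tool_add_control_sd21.py below), so no model code is needed to inspect it.
state_dict = torch.load('./models/control_sd21_ini.ckpt', map_location='cpu')

# Keys under 'control_model.' belong to the ControlNet branch; the locked
# SD branch keeps its original 'model.diffusion_model.' keys.
control_keys = [k for k in state_dict if k.startswith('control_model.')]
print(f'{len(state_dict)} tensors total, {len(control_keys)} in the control branch')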

models/cldm_v21.yaml (new file, +85)

model:
  target: cldm.cldm.ControlLDM
  params:
    linear_start: 0.00085
    linear_end: 0.0120
    num_timesteps_cond: 1
    log_every_t: 200
    timesteps: 1000
    first_stage_key: "jpg"
    cond_stage_key: "txt"
    control_key: "hint"
    image_size: 64
    channels: 4
    cond_stage_trainable: false
    conditioning_key: crossattn
    monitor: val/loss_simple_ema
    scale_factor: 0.18215
    use_ema: False
    only_mid_control: False

    control_stage_config:
      target: cldm.cldm.ControlNet
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        hint_channels: 3
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    unet_config:
      target: cldm.cldm.ControlledUnetModel
      params:
        use_checkpoint: True
        image_size: 32 # unused
        in_channels: 4
        out_channels: 4
        model_channels: 320
        attention_resolutions: [ 4, 2, 1 ]
        num_res_blocks: 2
        channel_mult: [ 1, 2, 4, 4 ]
        num_head_channels: 64 # need to fix for flash-attn
        use_spatial_transformer: True
        use_linear_in_transformer: True
        transformer_depth: 1
        context_dim: 1024
        legacy: False

    first_stage_config:
      target: ldm.models.autoencoder.AutoencoderKL
      params:
        embed_dim: 4
        monitor: val/rec_loss
        ddconfig:
          #attn_type: "vanilla-xformers"
          double_z: true
          z_channels: 4
          resolution: 256
          in_channels: 3
          out_ch: 3
          ch: 128
          ch_mult:
          - 1
          - 2
          - 4
          - 4
          num_res_blocks: 2
          attn_resolutions: []
          dropout: 0.0
        lossconfig:
          target: torch.nn.Identity

    cond_stage_config:
      target: ldm.modules.encoders.modules.FrozenOpenCLIPEmbedder
      params:
        freeze: True
        layer: "penultimate"
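Two settings mark this config as SD2-specific: the text encoder is FrozenOpenCLIPEmbedder with context_dim: 1024 (the SD1.5 config uses the original CLIP text encoder with context_dim: 768), and the attention blocks set use_linear_in_transformer: True. The file is consumed through cldm.model.create_model, as the conversion script below does; a minimal sketch of instantiating the model from it:

from cldm.model import create_model

# Build an untrained ControlLDM from the SD2.1 config; parameters stay randomly
# initialized until a checkpoint such as control_sd21_ini.ckpt is loaded.
model = create_model('./models/cldm_v21.yaml').cpu()
print(type(model).__name__)  # ControlLDM
print(f'{sum(p.numel() for p in model.parameters()) / 1e6:.0f}M parameters')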

tool_add_control_sd21.py (new file, +50)

import sys
import os

assert len(sys.argv) == 3, 'Args are wrong.'

input_path = sys.argv[1]
output_path = sys.argv[2]

assert os.path.exists(input_path), 'Input model does not exist.'
assert not os.path.exists(output_path), 'Output filename already exists.'
assert os.path.exists(os.path.dirname(output_path)), 'Output path is not valid.'

import torch
from share import *
from cldm.model import create_model


def get_node_name(name, parent_name):
    # Return (True, suffix) if `name` starts with `parent_name`, else (False, '').
    if len(name) <= len(parent_name):
        return False, ''
    p = name[:len(parent_name)]
    if p != parent_name:
        return False, ''
    return True, name[len(parent_name):]


model = create_model(config_path='./models/cldm_v21.yaml')

pretrained_weights = torch.load(input_path)
if 'state_dict' in pretrained_weights:
    pretrained_weights = pretrained_weights['state_dict']

scratch_dict = model.state_dict()

# Fill every parameter: 'control_*' keys are initialized from the matching
# 'model.diffusion_*' weights of the pretrained SD model; keys with no
# pretrained counterpart keep their fresh initialization.
target_dict = {}
for k in scratch_dict.keys():
    is_control, name = get_node_name(k, 'control_')
    if is_control:
        copy_k = 'model.diffusion_' + name
    else:
        copy_k = k
    if copy_k in pretrained_weights:
        target_dict[k] = pretrained_weights[copy_k].clone()
    else:
        target_dict[k] = scratch_dict[k].clone()
        print(f'These weights are newly added: {k}')

model.load_state_dict(target_dict, strict=True)
torch.save(model.state_dict(), output_path)
print('Done.')
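Invocation matches the documented command: python tool_add_control_sd21.py ./models/v2-1_512-ema-pruned.ckpt ./models/control_sd21_ini.ckpt. To make the key renaming concrete, here is what get_node_name does with a control-branch parameter (the key name is illustrative):

# A 'control_model.*' key is initialized from the matching SD UNet weight:
is_control, name = get_node_name('control_model.input_blocks.0.0.weight', 'control_')
print(is_control, name)           # True model.input_blocks.0.0.weight
print('model.diffusion_' + name)  # model.diffusion_model.input_blocks.0.0.weight

# Keys with no pretrained counterpart (e.g. the hint blocks that ingest the
# conditioning image) keep their fresh initialization and are printed as
# "These weights are newly added".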

tutorial_train_sd21.py (new file, +35)

from share import *

import pytorch_lightning as pl
from torch.utils.data import DataLoader
from tutorial_dataset import MyDataset
from cldm.logger import ImageLogger
from cldm.model import create_model, load_state_dict


# Configs
resume_path = './models/control_sd21_ini.ckpt'
batch_size = 4
logger_freq = 300
learning_rate = 1e-5
sd_locked = True
only_mid_control = False


# Load the model on CPU first; PyTorch Lightning will move it to the GPU automatically.
model = create_model('./models/cldm_v21.yaml').cpu()
model.load_state_dict(load_state_dict(resume_path, location='cpu'))
model.learning_rate = learning_rate
model.sd_locked = sd_locked
model.only_mid_control = only_mid_control


# Misc
dataset = MyDataset()
dataloader = DataLoader(dataset, num_workers=0, batch_size=batch_size, shuffle=True)
logger = ImageLogger(batch_frequency=logger_freq)
trainer = pl.Trainer(gpus=1, precision=32, callbacks=[logger])


# Train!
trainer.fit(model, dataloader)
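Run it with python tutorial_train_sd21.py, as noted in docs/train.md. The dataset must yield dicts whose keys match cldm_v21.yaml above: first_stage_key "jpg", cond_stage_key "txt", and control_key "hint". A minimal sketch for inspecting one batch before a full run, assuming MyDataset follows the tutorial's HWC image layout:

batch = next(iter(dataloader))
print(batch['jpg'].shape)   # target images, e.g. torch.Size([4, 512, 512, 3])
print(batch['hint'].shape)  # conditioning images, same spatial size
print(batch['txt'][:2])     # prompts as a list of strings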
