Skip to content

Commit e0596f1

Browse files
authored
Merge pull request #392 from Stability-AI/chunhan/sv4d
update sv4d sampling script and readme
2 parents 8636655 + ce1576b commit e0596f1

File tree

6 files changed

+58
-90
lines changed

6 files changed

+58
-90
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,4 +11,5 @@
1111
/dist
1212
/outputs
1313
/build
14-
/src
14+
/src
15+
/.vscode

README.md

100644100755
Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,8 @@ To run **SV4D** on a single input video of 21 frames:
2323
- `num_steps` : default is 20, can increase to 50 for better quality but longer sampling time.
2424
- `sv3d_version` : To specify the SV3D model to generate reference multi-views, set `--sv3d_version=sv3d_u` for SV3D_u or `--sv3d_version=sv3d_p` for SV3D_p.
2525
- `elevations_deg` : To generate novel-view videos at a specified elevation (default elevation is 10) using SV3D_p (default is SV3D_u), run `python scripts/sampling/simple_video_sample_4d.py --input_path test_video1.mp4 --sv3d_version sv3d_p --elevations_deg 30.0`
26-
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos (with noisy background), try segmenting the foreground object using [Cliipdrop](https://clipdrop.co/) before running SV4D.
26+
- **Background removal** : For input videos with plain background, (optionally) use [rembg](https://github.com/danielgatis/rembg) to remove background and crop video frames by setting `--remove_bg=True`. To obtain higher quality outputs on real-world input videos with noisy background, try segmenting the foreground object using [Clipdrop](https://clipdrop.co/) before running SV4D.
27+
- **Low VRAM environment** : To run on GPUs with low VRAM, try setting `--decoding_t=1` (the number of frames decoded at a time) or a lower video resolution such as `--img_size=512`.
2728

2829
![tile](assets/sv4d.gif)
2930

scripts/demo/sv4d_helpers.py

100644100755
Lines changed: 47 additions & 78 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,20 @@
3636
from sgm.util import default, instantiate_from_config
3737

3838

39+
def load_module_gpu(model):
    """Move ``model`` onto the GPU by invoking its ``cuda()`` method.

    The call is made purely for its side effect; nothing is returned.
    """
    model.cuda()
41+
42+
43+
def unload_module_gpu(model):
    """Move ``model`` back to the CPU and release cached CUDA memory.

    ``model.cpu()`` is called first, then ``torch.cuda.empty_cache()`` so the
    memory the module occupied can be handed back by the caching allocator.
    Nothing is returned.
    """
    model.cpu()
    torch.cuda.empty_cache()
46+
47+
48+
def initial_model_load(model):
    """Switch the wrapped ``model.model`` to half precision and return ``model``.

    Only the inner ``model.model`` attribute has ``half()`` called on it; the
    outer object is returned unchanged so callers can rebind it.
    """
    inner = model.model
    inner.half()
    return model
51+
52+
3953
def get_resizing_factor(
4054
desired_shape: Tuple[int, int], current_shape: Tuple[int, int]
4155
) -> float:
@@ -60,75 +74,11 @@ def get_resizing_factor(
6074
return factor
6175

6276

63-
def load_img_for_prediction_no_st(
64-
image_path: str,
65-
mask_path: str,
66-
W: int,
67-
H: int,
68-
crop_h: int,
69-
crop_w: int,
70-
device="cuda",
71-
) -> torch.Tensor:
72-
image = Image.open(image_path)
73-
if image is None:
74-
return None
75-
image = np.array(image).astype(np.float32) / 255
76-
h, w = image.shape[:2]
77-
rotated = 0
78-
79-
mask = None
80-
if mask_path is not None:
81-
mask = Image.open(mask_path)
82-
mask = np.array(mask).astype(np.float32) / 255
83-
mask = np.any(mask.reshape(h, w, -1) > 0, axis=2, keepdims=True).astype(
84-
np.float32
85-
)
86-
elif image.shape[-1] == 4:
87-
mask = image[:, :, 3:]
88-
89-
if mask is not None:
90-
image = image[:, :, :3] * mask + (1 - mask)
91-
# if "DAVIS" in image_path:
92-
# y, x, _ = np.where(mask > 0)
93-
# x_mean, y_mean = np.mean(x), np.mean(y)
94-
# else:
95-
# x_mean, y_mean = w//2, h//2
96-
# h_new = int(max(crop_h, crop_w) * 1.33)
97-
# x_min = max(int(x_mean - h_new//2), 0)
98-
# y_min = max(int(y_mean - h_new//2), 0)
99-
# image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
100-
# h_crop, w_crop = image_cropped.shape[:2]
101-
# h_new = max(h_crop, w_crop)
102-
# top = max((h_new - h_crop) // 2, 0)
103-
# left = max((h_new - w_crop) // 2, 0)
104-
# image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
105-
# image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
106-
# image = image_padded
107-
# h, w = image.shape[:2]
108-
109-
image = image.transpose(2, 0, 1)
110-
image = torch.from_numpy(image).to(dtype=torch.float32)
111-
image = image.unsqueeze(0)
112-
113-
rfs = get_resizing_factor((H, W), (h, w))
114-
resize_size = [int(np.ceil(rfs * s)) for s in (h, w)]
115-
top = (resize_size[0] - H) // 2
116-
left = (resize_size[1] - W) // 2
117-
118-
image = torch.nn.functional.interpolate(
119-
image, resize_size, mode="area", antialias=False
120-
)
121-
image = TT.functional.crop(image, top=top, left=left, height=H, width=W)
122-
return image.to(device) * 2.0 - 1.0, rotated
123-
124-
12577
def read_gif(input_path, n_frames):
    """Read at most ``n_frames`` frames from an animated image as RGBA images.

    Args:
        input_path: Path of the animated image file to open.
        n_frames: Number of frames after which reading stops.

    Returns:
        A list of frames converted to ``"RGBA"`` mode. If the file holds fewer
        than ``n_frames`` frames, every available frame is returned (the
        caller is responsible for padding or rejecting short inputs).
    """
    frames = []
    # Open via a context manager so the file handle is closed even if a frame
    # fails to decode; each converted frame is an independent copy.
    with Image.open(input_path) as video:
        for img in ImageSequence.Iterator(video):
            frames.append(img.convert("RGBA"))
            if len(frames) == n_frames:
                break
    return frames
@@ -206,16 +156,17 @@ def read_video(
206156
print(f"Loading {len(all_img_paths)} video frames...")
207157
images = [Image.open(img_path) for img_path in all_img_paths]
208158

159+
if len(images) < n_frames:
160+
images = (images + images[::-1])[:n_frames]
161+
209162
if len(images) != n_frames:
210-
raise ValueError("Input video contains fewer than {n_frames} frames.")
163+
raise ValueError(f"Input video contains fewer than {n_frames} frames.")
211164

212165
# Remove background and crop video frames
213166
images_v0 = []
214-
for image in images:
167+
for t, image in enumerate(images):
215168
if remove_bg:
216-
if image.mode == "RGBA":
217-
pass
218-
else:
169+
if image.mode != "RGBA":
219170
image.thumbnail([W, H], Image.Resampling.LANCZOS)
220171
image = remove(image.convert("RGBA"), alpha_matting=True)
221172
image_arr = np.array(image)
@@ -225,11 +176,12 @@ def read_video(
225176
)
226177
x, y, w, h = cv2.boundingRect(mask)
227178
max_size = max(w, h)
228-
side_len = (
229-
int(max_size / image_frame_ratio)
230-
if image_frame_ratio is not None
231-
else in_w
232-
)
179+
if t == 0:
180+
side_len = (
181+
int(max_size / image_frame_ratio)
182+
if image_frame_ratio is not None
183+
else in_w
184+
)
233185
padded_image = np.zeros((side_len, side_len, 4), dtype=np.uint8)
234186
center = side_len // 2
235187
padded_image[
@@ -239,7 +191,9 @@ def read_video(
239191
rgba = Image.fromarray(padded_image).resize((W, H), Image.LANCZOS)
240192
rgba_arr = np.array(rgba) / 255.0
241193
rgb = rgba_arr[..., :3] * rgba_arr[..., -1:] + (1 - rgba_arr[..., -1:])
242-
images = Image.fromarray((rgb * 255).astype(np.uint8))
194+
image = Image.fromarray((rgb * 255).astype(np.uint8))
195+
else:
196+
image = image.convert("RGB").resize((W, H), Image.LANCZOS)
243197
image = ToTensor()(image).unsqueeze(0).to(device)
244198
images_v0.append(image * 2.0 - 1.0)
245199
return images_v0
@@ -341,11 +295,13 @@ def denoiser(input, sigma, c):
341295

342296

343297
def decode_latents(model, samples_z, timesteps):
    """Decode latents with the first-stage model and map them to ``[0, 1]``.

    The first-stage model is moved to the GPU only for the duration of the
    decode and is always moved back afterwards, even when decoding raises,
    so a failed decode does not leave it occupying GPU memory.

    Args:
        model: Object exposing ``first_stage_model`` and ``decode_first_stage``.
        samples_z: Latents passed to ``model.decode_first_stage``.
        timesteps: Forwarded as ``timesteps=`` when the decoder is a
            ``VideoDecoder``; ignored otherwise.

    Returns:
        The decoded samples, rescaled from ``[-1, 1]`` via ``(x + 1) / 2`` and
        clamped to ``[0, 1]``.
    """
    load_module_gpu(model.first_stage_model)
    try:
        if isinstance(model.first_stage_model.decoder, VideoDecoder):
            samples_x = model.decode_first_stage(samples_z, timesteps=timesteps)
        else:
            samples_x = model.decode_first_stage(samples_z)
        samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
    finally:
        # Release the GPU regardless of whether decoding succeeded.
        unload_module_gpu(model.first_stage_model)
    return samples
350306

351307

@@ -751,20 +707,21 @@ def do_sample(
751707
else:
752708
num_samples = [num_samples]
753709

710+
load_module_gpu(model.conditioner)
754711
batch, batch_uc = get_batch(
755712
get_unique_embedder_keys_from_conditioner(model.conditioner),
756713
value_dict,
757714
num_samples,
758715
T=T,
759716
additional_batch_uc_fields=additional_batch_uc_fields,
760717
)
761-
762718
c, uc = model.conditioner.get_unconditional_conditioning(
763719
batch,
764720
batch_uc=batch_uc,
765721
force_uc_zero_embeddings=force_uc_zero_embeddings,
766722
force_cond_zero_embeddings=force_cond_zero_embeddings,
767723
)
724+
unload_module_gpu(model.conditioner)
768725

769726
for k in c:
770727
if not k == "crossattn":
@@ -805,15 +762,21 @@ def denoiser(input, sigma, c):
805762
model.model, input, sigma, c, **additional_model_inputs
806763
)
807764

765+
load_module_gpu(model.model)
766+
load_module_gpu(model.denoiser)
808767
samples_z = sampler(denoiser, randn, cond=c, uc=uc)
768+
unload_module_gpu(model.model)
769+
unload_module_gpu(model.denoiser)
809770

771+
load_module_gpu(model.first_stage_model)
810772
if isinstance(model.first_stage_model.decoder, VideoDecoder):
811773
samples_x = model.decode_first_stage(
812774
samples_z, timesteps=default(decoding_t, T)
813775
)
814776
else:
815777
samples_x = model.decode_first_stage(samples_z)
816778
samples = torch.clamp((samples_x + 1.0) / 2.0, min=0.0, max=1.0)
779+
unload_module_gpu(model.first_stage_model)
817780

818781
if filter is not None:
819782
samples = filter(samples)
@@ -850,20 +813,21 @@ def do_sample_per_step(
850813
else:
851814
num_samples = [num_samples]
852815

816+
load_module_gpu(model.conditioner)
853817
batch, batch_uc = get_batch(
854818
get_unique_embedder_keys_from_conditioner(model.conditioner),
855819
value_dict,
856820
num_samples,
857821
T=T,
858822
additional_batch_uc_fields=additional_batch_uc_fields,
859823
)
860-
861824
c, uc = model.conditioner.get_unconditional_conditioning(
862825
batch,
863826
batch_uc=batch_uc,
864827
force_uc_zero_embeddings=force_uc_zero_embeddings,
865828
force_cond_zero_embeddings=force_cond_zero_embeddings,
866829
)
830+
unload_module_gpu(model.conditioner)
867831

868832
for k in c:
869833
if not k == "crossattn":
@@ -917,6 +881,9 @@ def denoiser(input, sigma, c):
917881
if sampler.s_tmin <= sigmas[step] <= sampler.s_tmax
918882
else 0.0
919883
)
884+
885+
load_module_gpu(model.model)
886+
load_module_gpu(model.denoiser)
920887
samples_z = sampler.sampler_step(
921888
s_in * sigmas[step],
922889
s_in * sigmas[step + 1],
@@ -926,6 +893,8 @@ def denoiser(input, sigma, c):
926893
uc,
927894
gamma,
928895
)
896+
unload_module_gpu(model.model)
897+
unload_module_gpu(model.denoiser)
929898

930899
return samples_z
931900

scripts/sampling/configs/sv4d.yaml

100644100755
Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -93,12 +93,6 @@ model:
9393
sigma_sampler_config:
9494
target: sgm.modules.diffusionmodules.sigma_sampling.ZeroSampler
9595

96-
# - input_key: cond_aug
97-
# is_trainable: False
98-
# target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND
99-
# params:
100-
# outdim: 256
101-
10296
- input_key: polar_rad
10397
is_trainable: False
10498
target: sgm.modules.encoders.modules.ConcatTimestepEmbedderND

scripts/sampling/simple_video_sample_4d.py

100644100755
Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,7 @@
1313
from scripts.demo.sv4d_helpers import (
1414
decode_latents,
1515
load_model,
16+
initial_model_load,
1617
read_video,
1718
run_img2vid,
1819
run_img2vid_per_step,
@@ -26,6 +27,7 @@ def sample(
2627
output_folder: Optional[str] = "outputs/sv4d",
2728
num_steps: Optional[int] = 20,
2829
sv3d_version: str = "sv3d_u", # sv3d_u or sv3d_p
30+
img_size: int = 576, # image resolution
2931
fps_id: int = 6,
3032
motion_bucket_id: int = 127,
3133
cond_aug: float = 1e-5,
@@ -47,7 +49,7 @@ def sample(
4749
V = 8 # number of views per sample
4850
F = 8 # vae factor to downsize image->latent
4951
C = 4
50-
H, W = 576, 576
52+
H, W = img_size, img_size
5153
n_frames = 21 # number of input and output video frames
5254
n_views = V + 1 # number of output video views (1 input view + 8 novel views)
5355
n_views_sv3d = 21
@@ -64,7 +66,7 @@ def sample(
6466
"f": F,
6567
"options": {
6668
"discretization": 1,
67-
"cfg": 2.5,
69+
"cfg": 3.0,
6870
"sigma_min": 0.002,
6971
"sigma_max": 700.0,
7072
"rho": 7.0,
@@ -137,7 +139,7 @@ def sample(
137139
for t in range(n_frames):
138140
img_matrix[t][0] = images_v0[t]
139141

140-
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 10
142+
base_count = len(glob(os.path.join(output_folder, "*.mp4"))) // 11
141143
save_video(
142144
os.path.join(output_folder, f"{base_count:06d}_t000.mp4"),
143145
img_matrix[0],
@@ -155,6 +157,7 @@ def sample(
155157
num_steps,
156158
verbose,
157159
)
160+
model = initial_model_load(model)
158161

159162
# Interleaved sampling for anchor frames
160163
t0, v0 = 0, 0

sgm/modules/spacetime_attention.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -593,4 +593,4 @@ def forward(
593593
if not self.use_linear:
594594
x = self.proj_out(x)
595595
out = x + x_in
596-
return out
596+
return out

0 commit comments

Comments
 (0)