3636from sgm .util import default , instantiate_from_config
3737
3838
39+ def load_module_gpu (model ):
40+ model .cuda ()
41+
42+
43+ def unload_module_gpu (model ):
44+ model .cpu ()
45+ torch .cuda .empty_cache ()
46+
47+
48+ def initial_model_load (model ):
49+ model .model .half ()
50+ return model
51+
52+
3953def get_resizing_factor (
4054 desired_shape : Tuple [int , int ], current_shape : Tuple [int , int ]
4155) -> float :
@@ -60,75 +74,11 @@ def get_resizing_factor(
6074 return factor
6175
6276
63- def load_img_for_prediction_no_st (
64- image_path : str ,
65- mask_path : str ,
66- W : int ,
67- H : int ,
68- crop_h : int ,
69- crop_w : int ,
70- device = "cuda" ,
71- ) -> torch .Tensor :
72- image = Image .open (image_path )
73- if image is None :
74- return None
75- image = np .array (image ).astype (np .float32 ) / 255
76- h , w = image .shape [:2 ]
77- rotated = 0
78-
79- mask = None
80- if mask_path is not None :
81- mask = Image .open (mask_path )
82- mask = np .array (mask ).astype (np .float32 ) / 255
83- mask = np .any (mask .reshape (h , w , - 1 ) > 0 , axis = 2 , keepdims = True ).astype (
84- np .float32
85- )
86- elif image .shape [- 1 ] == 4 :
87- mask = image [:, :, 3 :]
88-
89- if mask is not None :
90- image = image [:, :, :3 ] * mask + (1 - mask )
91- # if "DAVIS" in image_path:
92- # y, x, _ = np.where(mask > 0)
93- # x_mean, y_mean = np.mean(x), np.mean(y)
94- # else:
95- # x_mean, y_mean = w//2, h//2
96- # h_new = int(max(crop_h, crop_w) * 1.33)
97- # x_min = max(int(x_mean - h_new//2), 0)
98- # y_min = max(int(y_mean - h_new//2), 0)
99- # image_cropped = image[y_min : y_min + h_new, x_min : x_min + h_new]
100- # h_crop, w_crop = image_cropped.shape[:2]
101- # h_new = max(h_crop, w_crop)
102- # top = max((h_new - h_crop) // 2, 0)
103- # left = max((h_new - w_crop) // 2, 0)
104- # image_padded = np.ones((h_new, h_new, 3)).astype(np.float32)
105- # image_padded[top : top + h_crop, left : left + w_crop, :] = image_cropped
106- # image = image_padded
107- # h, w = image.shape[:2]
108-
109- image = image .transpose (2 , 0 , 1 )
110- image = torch .from_numpy (image ).to (dtype = torch .float32 )
111- image = image .unsqueeze (0 )
112-
113- rfs = get_resizing_factor ((H , W ), (h , w ))
114- resize_size = [int (np .ceil (rfs * s )) for s in (h , w )]
115- top = (resize_size [0 ] - H ) // 2
116- left = (resize_size [1 ] - W ) // 2
117-
118- image = torch .nn .functional .interpolate (
119- image , resize_size , mode = "area" , antialias = False
120- )
121- image = TT .functional .crop (image , top = top , left = left , height = H , width = W )
122- return image .to (device ) * 2.0 - 1.0 , rotated
123-
124-
12577def read_gif (input_path , n_frames ):
12678 frames = []
12779 video = Image .open (input_path )
128- if video .n_frames < n_frames :
129- return frames
13080 for img in ImageSequence .Iterator (video ):
131- frames .append (img .convert ("RGB " ))
81+ frames .append (img .convert ("RGBA " ))
13282 if len (frames ) == n_frames :
13383 break
13484 return frames
@@ -206,16 +156,17 @@ def read_video(
206156 print (f"Loading { len (all_img_paths )} video frames..." )
207157 images = [Image .open (img_path ) for img_path in all_img_paths ]
208158
159+ if len (images ) < n_frames :
160+ images = (images + images [::- 1 ])[:n_frames ]
161+
209162 if len (images ) != n_frames :
210- raise ValueError ("Input video contains fewer than {n_frames} frames." )
163+ raise ValueError (f "Input video contains fewer than { n_frames } frames." )
211164
212165 # Remove background and crop video frames
213166 images_v0 = []
214- for image in images :
167+ for t , image in enumerate ( images ) :
215168 if remove_bg :
216- if image .mode == "RGBA" :
217- pass
218- else :
169+ if image .mode != "RGBA" :
219170 image .thumbnail ([W , H ], Image .Resampling .LANCZOS )
220171 image = remove (image .convert ("RGBA" ), alpha_matting = True )
221172 image_arr = np .array (image )
@@ -225,11 +176,12 @@ def read_video(
225176 )
226177 x , y , w , h = cv2 .boundingRect (mask )
227178 max_size = max (w , h )
228- side_len = (
229- int (max_size / image_frame_ratio )
230- if image_frame_ratio is not None
231- else in_w
232- )
179+ if t == 0 :
180+ side_len = (
181+ int (max_size / image_frame_ratio )
182+ if image_frame_ratio is not None
183+ else in_w
184+ )
233185 padded_image = np .zeros ((side_len , side_len , 4 ), dtype = np .uint8 )
234186 center = side_len // 2
235187 padded_image [
@@ -239,7 +191,9 @@ def read_video(
239191 rgba = Image .fromarray (padded_image ).resize ((W , H ), Image .LANCZOS )
240192 rgba_arr = np .array (rgba ) / 255.0
241193 rgb = rgba_arr [..., :3 ] * rgba_arr [..., - 1 :] + (1 - rgba_arr [..., - 1 :])
242- images = Image .fromarray ((rgb * 255 ).astype (np .uint8 ))
194+ image = Image .fromarray ((rgb * 255 ).astype (np .uint8 ))
195+ else :
196+ image = image .convert ("RGB" ).resize ((W , H ), Image .LANCZOS )
243197 image = ToTensor ()(image ).unsqueeze (0 ).to (device )
244198 images_v0 .append (image * 2.0 - 1.0 )
245199 return images_v0
@@ -341,11 +295,13 @@ def denoiser(input, sigma, c):
341295
342296
343297def decode_latents (model , samples_z , timesteps ):
298+ load_module_gpu (model .first_stage_model )
344299 if isinstance (model .first_stage_model .decoder , VideoDecoder ):
345300 samples_x = model .decode_first_stage (samples_z , timesteps = timesteps )
346301 else :
347302 samples_x = model .decode_first_stage (samples_z )
348303 samples = torch .clamp ((samples_x + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
304+ unload_module_gpu (model .first_stage_model )
349305 return samples
350306
351307
@@ -751,20 +707,21 @@ def do_sample(
751707 else :
752708 num_samples = [num_samples ]
753709
710+ load_module_gpu (model .conditioner )
754711 batch , batch_uc = get_batch (
755712 get_unique_embedder_keys_from_conditioner (model .conditioner ),
756713 value_dict ,
757714 num_samples ,
758715 T = T ,
759716 additional_batch_uc_fields = additional_batch_uc_fields ,
760717 )
761-
762718 c , uc = model .conditioner .get_unconditional_conditioning (
763719 batch ,
764720 batch_uc = batch_uc ,
765721 force_uc_zero_embeddings = force_uc_zero_embeddings ,
766722 force_cond_zero_embeddings = force_cond_zero_embeddings ,
767723 )
724+ unload_module_gpu (model .conditioner )
768725
769726 for k in c :
770727 if not k == "crossattn" :
@@ -805,15 +762,21 @@ def denoiser(input, sigma, c):
805762 model .model , input , sigma , c , ** additional_model_inputs
806763 )
807764
765+ load_module_gpu (model .model )
766+ load_module_gpu (model .denoiser )
808767 samples_z = sampler (denoiser , randn , cond = c , uc = uc )
768+ unload_module_gpu (model .model )
769+ unload_module_gpu (model .denoiser )
809770
771+ load_module_gpu (model .first_stage_model )
810772 if isinstance (model .first_stage_model .decoder , VideoDecoder ):
811773 samples_x = model .decode_first_stage (
812774 samples_z , timesteps = default (decoding_t , T )
813775 )
814776 else :
815777 samples_x = model .decode_first_stage (samples_z )
816778 samples = torch .clamp ((samples_x + 1.0 ) / 2.0 , min = 0.0 , max = 1.0 )
779+ unload_module_gpu (model .first_stage_model )
817780
818781 if filter is not None :
819782 samples = filter (samples )
@@ -850,20 +813,21 @@ def do_sample_per_step(
850813 else :
851814 num_samples = [num_samples ]
852815
816+ load_module_gpu (model .conditioner )
853817 batch , batch_uc = get_batch (
854818 get_unique_embedder_keys_from_conditioner (model .conditioner ),
855819 value_dict ,
856820 num_samples ,
857821 T = T ,
858822 additional_batch_uc_fields = additional_batch_uc_fields ,
859823 )
860-
861824 c , uc = model .conditioner .get_unconditional_conditioning (
862825 batch ,
863826 batch_uc = batch_uc ,
864827 force_uc_zero_embeddings = force_uc_zero_embeddings ,
865828 force_cond_zero_embeddings = force_cond_zero_embeddings ,
866829 )
830+ unload_module_gpu (model .conditioner )
867831
868832 for k in c :
869833 if not k == "crossattn" :
@@ -917,6 +881,9 @@ def denoiser(input, sigma, c):
917881 if sampler .s_tmin <= sigmas [step ] <= sampler .s_tmax
918882 else 0.0
919883 )
884+
885+ load_module_gpu (model .model )
886+ load_module_gpu (model .denoiser )
920887 samples_z = sampler .sampler_step (
921888 s_in * sigmas [step ],
922889 s_in * sigmas [step + 1 ],
@@ -926,6 +893,8 @@ def denoiser(input, sigma, c):
926893 uc ,
927894 gamma ,
928895 )
896+ unload_module_gpu (model .model )
897+ unload_module_gpu (model .denoiser )
929898
930899 return samples_z
931900
0 commit comments