diff --git a/README.md b/README.md
index c07aa9b..c81dc92 100644
--- a/README.md
+++ b/README.md
@@ -136,7 +136,7 @@ Step 3: Training
 ```shell
 CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task train --dataset_path recam_train_data --output_path ./models/train --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" --steps_per_epoch 8000 --max_epochs 100 --learning_rate 1e-4 --accumulate_grad_batches 1 --use_gradient_checkpointing --dataloader_num_workers 4
 ```
-We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate and increasing the batch size.
+We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate ~~and increasing the batch size~~. Only a batch size of 1 is supported; see the discussion in [Wan2.1 finetuning script seems to only support bs = 1](https://github.com/modelscope/DiffSynth-Studio/issues/600).
 
 Step 4: Test the model
diff --git a/train_recammaster.py b/train_recammaster.py
index d687f21..ac60ef7 100644
--- a/train_recammaster.py
+++ b/train_recammaster.py
@@ -48,7 +48,9 @@ def crop_and_resize(self, image):
         )
         return image
-
+    def __len__(self):
+        return len(self.path)
+
     def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
         reader = imageio.get_reader(file_path)
         if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
@@ -115,8 +117,9 @@ def __getitem__(self, data_id):
                 else:
                     data = {"text": text, "video": video, "path": path}
                 break
-            except:
-                data_id += 1
+            except Exception as e:
+                print(f"ERROR WHEN LOADING: {e}")
+                return self.__getitem__(data_id+1 if data_id+1 < len(self.path) else 0)
         return data
@@ -203,7 +206,6 @@ def get_relative_pose(self, cam_params):
         ret_poses = np.array(ret_poses, dtype=np.float32)
         return ret_poses
-
     def __getitem__(self, index):
         # Return:
         # data['latents']: torch.Size([16, 21*2, 60, 104])
@@ -263,7 +265,7 @@ def __getitem__(self, index):
     def __len__(self):
-        return self.steps_per_epoch
+        return min(len(self.path), self.steps_per_epoch)
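
For anyone applying this patch by hand, the subtle point is that the fallback in `__getitem__` only works if the recursive call's result is propagated with `return`; otherwise the recovered sample is discarded and the loop retries the same broken index forever. The sketch below isolates that pattern together with the capped `__len__`. It is a minimal stand-in, not the repository's actual class: the names `SkipBrokenDataset` and `load_sample` are hypothetical, and it assumes `self.path` is a list of file paths.

```python
# Minimal sketch of the skip-on-failure pattern from the patch above.
# SkipBrokenDataset and load_sample are hypothetical names, not repo code.
from torch.utils.data import Dataset


class SkipBrokenDataset(Dataset):
    def __init__(self, paths, steps_per_epoch):
        self.path = paths
        self.steps_per_epoch = steps_per_epoch

    def load_sample(self, path):
        # Stand-in for the real frame loading (imageio reads, cropping, ...);
        # anything raised here is treated as a corrupt/unreadable sample.
        with open(path, "rb") as f:
            return f.read()

    def __getitem__(self, data_id):
        try:
            return self.load_sample(self.path[data_id])
        except Exception as e:
            print(f"ERROR WHEN LOADING: {e}")
            # The recursive call must be *returned*; otherwise the recovered
            # sample is lost and the DataLoader receives None. Note this still
            # recurses indefinitely if every sample in the dataset is broken.
            return self.__getitem__(data_id + 1 if data_id + 1 < len(self.path) else 0)

    def __len__(self):
        # Cap the epoch length at the true dataset size so the DataLoader
        # never requests an index beyond len(self.path).
        return min(len(self.path), self.steps_per_epoch)
```

An iterative scan over candidate indices would avoid deep recursion when many consecutive files are unreadable, but the recursive form mirrors the patch above.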