Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ Step 3: Training
```shell
CUDA_VISIBLE_DEVICES="0,1,2,3,4,5,6,7" python train_recammaster.py --task train --dataset_path recam_train_data --output_path ./models/train --dit_path "models/Wan-AI/Wan2.1-T2V-1.3B/diffusion_pytorch_model.safetensors" --steps_per_epoch 8000 --max_epochs 100 --learning_rate 1e-4 --accumulate_grad_batches 1 --use_gradient_checkpointing --dataloader_num_workers 4
```
We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate and increasing the batch size.
We do not explore the optimal set of hyper-parameters and train with a batch size of 1 on each GPU. You may achieve better model performance by adjusting hyper-parameters such as the learning rate ~~and increasing the batch size~~. We only support batch size=1; see more discussion here: [Wan2.1 finetuning script seems to only support bs = 1](https://github.com/modelscope/DiffSynth-Studio/issues/600)

Step 4: Test the model

Expand Down
12 changes: 7 additions & 5 deletions train_recammaster.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,9 @@ def crop_and_resize(self, image):
)
return image


def __len__(self):
return len(self.path)

def load_frames_using_imageio(self, file_path, max_num_frames, start_frame_id, interval, num_frames, frame_process):
reader = imageio.get_reader(file_path)
if reader.count_frames() < max_num_frames or reader.count_frames() - 1 < start_frame_id + (num_frames - 1) * interval:
Expand Down Expand Up @@ -115,8 +117,9 @@ def __getitem__(self, data_id):
else:
data = {"text": text, "video": video, "path": path}
break
except:
data_id += 1
except Exception as e:
print(f"ERROR WHEN LOADING: {e}")
self.__getitem__(data_id+1 if data_id+1 < len(self.path) else 0)
return data


Expand Down Expand Up @@ -203,7 +206,6 @@ def get_relative_pose(self, cam_params):
ret_poses = np.array(ret_poses, dtype=np.float32)
return ret_poses


def __getitem__(self, index):
# Return:
# data['latents']: torch.Size([16, 21*2, 60, 104])
Expand Down Expand Up @@ -263,7 +265,7 @@ def __getitem__(self, index):


def __len__(self):
return self.steps_per_epoch
return min(len(self.path), self.steps_per_epoch)



Expand Down