Skip to content
This repository was archived by the owner on Feb 14, 2025. It is now read-only.

Commit 1609e9b

Browse files
authored
Merge VQGAN v2 to dev (myshell-ai#56)
* squash vqgan v2 changes * Merge pretrain stage 1 and 2 * Optimize vqgan inference (remove redundant code) * Implement data mixing * Optimize vqgan v2 config * Add support to freeze discriminator * Add stft loss & larger segment size
1 parent 39f6902 commit 1609e9b

17 files changed

+1737
-495
lines changed

.pre-commit-config.yaml

+1
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ repos:
1818
hooks:
1919
- id: codespell
2020
files: ^.*\.(py|md|rst|yml)$
21+
args: [-L=fro]
2122

2223
- repo: https://github.com/pre-commit/pre-commit-hooks
2324
rev: v4.5.0

fish_speech/configs/vqgan_pretrain_v2.yaml

+59-25
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,8 @@ defaults:
33
- _self_
44

55
project: vqgan_pretrain_v2
6+
ckpt_path: checkpoints/hifigan-base-comb-mix-lb-020/step_001200000_weights_only.ckpt
7+
resume_weights_only: true
68

79
# Lightning Trainer
810
trainer:
@@ -15,22 +17,36 @@ trainer:
1517

1618
sample_rate: 44100
1719
hop_length: 512
18-
num_mels: 128
20+
num_mels: 160
1921
n_fft: 2048
2022
win_length: 2048
2123
segment_size: 256
2224

2325
# Dataset Configuration
2426
train_dataset:
25-
_target_: fish_speech.datasets.vqgan.VQGANDataset
26-
filelist: data/Genshin/vq_train_filelist.txt
27-
sample_rate: ${sample_rate}
28-
hop_length: ${hop_length}
29-
slice_frames: ${segment_size}
27+
_target_: fish_speech.datasets.vqgan.MixDatast
28+
datasets:
29+
high-quality-441:
30+
prob: 0.5
31+
dataset:
32+
_target_: fish_speech.datasets.vqgan.VQGANDataset
33+
filelist: data/vocoder_data_441/vq_train_filelist.txt
34+
sample_rate: ${sample_rate}
35+
hop_length: ${hop_length}
36+
slice_frames: ${segment_size}
37+
38+
common-voice:
39+
prob: 0.5
40+
dataset:
41+
_target_: fish_speech.datasets.vqgan.VQGANDataset
42+
filelist: data/cv-corpus-16.0-2023-12-06/vq_train_filelist.txt
43+
sample_rate: ${sample_rate}
44+
hop_length: ${hop_length}
45+
slice_frames: ${segment_size}
3046

3147
val_dataset:
3248
_target_: fish_speech.datasets.vqgan.VQGANDataset
33-
filelist: data/Genshin/vq_val_filelist.txt
49+
filelist: data/vocoder_data_441/vq_val_filelist.txt
3450
sample_rate: ${sample_rate}
3551
hop_length: ${hop_length}
3652

@@ -47,8 +63,9 @@ model:
4763
_target_: fish_speech.models.vqgan.VQGAN
4864
sample_rate: ${sample_rate}
4965
hop_length: ${hop_length}
50-
segment_size: 8192
51-
mode: pretrain-stage1
66+
segment_size: 32768
67+
mode: pretrain
68+
freeze_discriminator: true
5269

5370
downsample:
5471
_target_: fish_speech.models.vqgan.modules.encoders.ConvDownSampler
@@ -67,8 +84,8 @@ model:
6784
_target_: fish_speech.models.vqgan.modules.encoders.VQEncoder
6885
in_channels: 256
6986
vq_channels: 256
70-
codebook_size: 1024
71-
codebook_layers: 4
87+
codebook_size: 256
88+
codebook_groups: 4
7289
downsample: 1
7390

7491
decoder:
@@ -80,33 +97,50 @@ model:
8097
n_layers: 6
8198

8299
generator:
83-
_target_: fish_speech.models.vqgan.modules.decoder.Generator
84-
initial_channel: ${num_mels}
85-
resblock: "1"
100+
_target_: fish_speech.models.vqgan.modules.decoder_v2.HiFiGANGenerator
101+
hop_length: ${hop_length}
102+
upsample_rates: [8, 8, 2, 2, 2] # aka. strides
103+
upsample_kernel_sizes: [16, 16, 4, 4, 4]
86104
resblock_kernel_sizes: [3, 7, 11]
87105
resblock_dilation_sizes: [[1, 3, 5], [1, 3, 5], [1, 3, 5]]
88-
upsample_rates: [8, 8, 2, 2, 2]
106+
num_mels: ${num_mels}
89107
upsample_initial_channel: 512
90-
upsample_kernel_sizes: [16, 16, 4, 4, 4]
91-
92-
discriminator:
93-
_target_: fish_speech.models.vqgan.modules.discriminator.EnsembleDiscriminator
94-
periods: [2, 3, 5, 7, 11, 17, 23, 37]
95-
108+
use_template: true
109+
pre_conv_kernel_size: 7
110+
post_conv_kernel_size: 7
111+
112+
discriminators:
113+
_target_: torch.nn.ModuleDict
114+
modules:
115+
mpd:
116+
_target_: fish_speech.models.vqgan.modules.discriminators.mpd.MultiPeriodDiscriminator
117+
periods: [2, 3, 5, 7, 11, 17, 23, 37]
118+
119+
mrd:
120+
_target_: fish_speech.models.vqgan.modules.discriminators.mrd.MultiResolutionDiscriminator
121+
resolutions:
122+
- ["${n_fft}", "${hop_length}", "${win_length}"]
123+
- [1024, 120, 600]
124+
- [2048, 240, 1200]
125+
- [4096, 480, 2400]
126+
- [512, 50, 240]
127+
128+
multi_resolution_stft_loss:
129+
_target_: fish_speech.models.vqgan.losses.MultiResolutionSTFTLoss
130+
resolutions: ${model.discriminators.modules.mrd.resolutions}
131+
96132
mel_transform:
97133
_target_: fish_speech.models.vqgan.spectrogram.LogMelSpectrogram
98134
sample_rate: ${sample_rate}
99135
n_fft: ${n_fft}
100136
hop_length: ${hop_length}
101137
win_length: ${win_length}
102138
n_mels: ${num_mels}
103-
f_min: 0
104-
f_max: 16000
105139

106140
optimizer:
107141
_target_: torch.optim.AdamW
108142
_partial_: true
109-
lr: 2e-4
143+
lr: 1e-4
110144
betas: [0.8, 0.99]
111145
eps: 1e-5
112146

@@ -119,7 +153,7 @@ callbacks:
119153
grad_norm_monitor:
120154
sub_module:
121155
- generator
122-
- discriminator
156+
- discriminators
123157
- mel_encoder
124158
- vq_encoder
125159
- decoder

fish_speech/datasets/vqgan.py

+29-2
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import numpy as np
77
import torch
88
from lightning import LightningDataModule
9-
from torch.utils.data import DataLoader, Dataset
9+
from torch.utils.data import DataLoader, Dataset, IterableDataset
1010

1111
from fish_speech.utils import RankedLogger
1212

@@ -72,6 +72,33 @@ def __getitem__(self, idx):
7272
return None
7373

7474

75+
class MixDatast(IterableDataset):
    """Infinite stream that interleaves samples from several datasets.

    Each entry of ``datasets`` maps an arbitrary name to a dict with keys
    ``prob`` (relative sampling weight) and ``dataset`` (the iterable to
    draw from). Weights are normalized internally, so they need not sum
    to 1. Exhausted member datasets are restarted, so iteration never
    terminates on its own.

    NOTE(review): the class name looks like a typo for ``MixDataset``; it
    is kept as-is because the Hydra config targets the exact dotted path
    ``fish_speech.datasets.vqgan.MixDatast``.
    """

    def __init__(self, datasets: dict[str, dict], seed: int = 42) -> None:
        """
        Args:
            datasets: name -> {"prob": float, "dataset": iterable} mapping.
            seed: RNG seed for the mixing order; identical seeds reproduce
                the same stream. NOTE(review): with num_workers > 1 every
                worker uses the same seed and therefore yields the same
                sample order — confirm whether per-worker seeding is
                intended.
        """
        entries = list(datasets.values())
        weights = [entry["prob"] for entry in entries]
        self.datasets = [entry["dataset"] for entry in entries]

        # Normalize so that `Generator.choice` receives a valid
        # probability distribution even if weights don't sum to 1.
        total_weight = sum(weights)
        self.probs = [w / total_weight for w in weights]
        self.seed = seed

    def __iter__(self):
        rng = np.random.default_rng(self.seed)
        iterators = [iter(ds) for ds in self.datasets]

        while True:
            # Pick a source dataset according to the normalized weights.
            idx = rng.choice(len(self.datasets), p=self.probs)

            try:
                yield next(iterators[idx])
            except StopIteration:
                # The chosen dataset is exhausted: restart it. Guard the
                # retry explicitly — a StopIteration escaping a generator
                # body surfaces as an opaque RuntimeError (PEP 479).
                iterators[idx] = iter(self.datasets[idx])
                try:
                    yield next(iterators[idx])
                except StopIteration:
                    raise ValueError(
                        f"Dataset at index {idx} yielded no items"
                    ) from None
100+
101+
75102
@dataclass
76103
class VQGANCollator:
77104
def __call__(self, batch):
@@ -116,7 +143,7 @@ def train_dataloader(self):
116143
batch_size=self.batch_size,
117144
collate_fn=VQGANCollator(),
118145
num_workers=self.num_workers,
119-
shuffle=True,
146+
shuffle=not isinstance(self.train_dataset, IterableDataset),
120147
)
121148

122149
def val_dataloader(self):

0 commit comments

Comments
 (0)