Commit 4dd2794

Authored Mar 27, 2024
Refactor the code in train.py (#213)
1 parent c6b6874 commit 4dd2794

File tree

1 file changed: wetts/vits/train.py (+60, −41 lines)

@@ -1,10 +1,11 @@
 import os
 
 import torch
+import torch.distributed as dist
+
 from torch.nn import functional as F
 from torch.utils.data import DataLoader
 from torch.utils.tensorboard import SummaryWriter
-import torch.distributed as dist
 from torch.nn.parallel import DistributedDataParallel as DDP
 from torch.cuda.amp import autocast, GradScaler
 
@@ -30,30 +31,17 @@
 
 def main():
     hps = task.get_hparams()
+    # Set random seed
     torch.manual_seed(hps.train.seed)
     global global_step
+    # Initialize distributed
     world_size = int(os.environ.get('WORLD_SIZE', 1))
     local_rank = int(os.environ.get('LOCAL_RANK', 0))
     rank = int(os.environ.get('RANK', 0))
     torch.cuda.set_device(local_rank)
     dist.init_process_group("nccl")
-    if rank == 0:
-        logger = task.get_logger(hps.model_dir)
-        logger.info(hps)
-        writer = SummaryWriter(log_dir=hps.model_dir)
-        writer_eval = SummaryWriter(
-            log_dir=os.path.join(hps.model_dir, "eval"))
-
-    if ("use_mel_posterior_encoder" in hps.model.keys()
-            and hps.model.use_mel_posterior_encoder):
-        print("Using mel posterior encoder for VITS2")
-        posterior_channels = hps.data.n_mel_channels  # vits2
-        hps.data.use_mel_posterior_encoder = True
-    else:
-        print("Using lin posterior encoder for VITS1")
-        posterior_channels = hps.data.filter_length // 2 + 1
-        hps.data.use_mel_posterior_encoder = False
 
+    # Get the dataset and data loader
     train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
     train_sampler = DistributedBucketSampler(
         train_dataset,
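
Note on the distributed setup above (not part of the diff): WORLD_SIZE, RANK and LOCAL_RANK are the environment variables that PyTorch's launchers export to every worker process, e.g. `torchrun --nnodes=2 --nproc_per_node=4 train.py`. A minimal sketch of what each variable means, with an assumed 2-node x 4-GPU layout that is only illustrative:

import os

# Assumed layout for illustration: 2 nodes, 4 GPUs each, launched with
# torchrun --nnodes=2 --nproc_per_node=4 train.py
world_size = int(os.environ.get('WORLD_SIZE', 1))  # 8: total workers in the job
rank = int(os.environ.get('RANK', 0))              # 0..7: unique across all nodes
local_rank = int(os.environ.get('LOCAL_RANK', 0))  # 0..3: GPU index on this node
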
@@ -85,6 +73,17 @@ def main():
         collate_fn=collate_fn,
     )
 
+    # Get the tts model
+    if ("use_mel_posterior_encoder" in hps.model.keys()
+            and hps.model.use_mel_posterior_encoder):
+        print("Using mel posterior encoder for VITS2")
+        posterior_channels = hps.data.n_mel_channels  # vits2
+        hps.data.use_mel_posterior_encoder = True
+    else:
+        print("Using lin posterior encoder for VITS1")
+        posterior_channels = hps.data.filter_length // 2 + 1
+        hps.data.use_mel_posterior_encoder = False
+
     # some of these flags are not being used in the code and directly set in hps
     # json file. they are kept here for reference and prototyping.
     if ("use_transformer_flows" in hps.model.keys()
@@ -144,7 +143,7 @@ def main():
             0.1,
             gin_channels=hps.model.gin_channels
             if hps.data.n_speakers != 0 else 0,
-        ).cuda(rank)
+        ).cuda(local_rank)
     elif duration_discriminator_type == "dur_disc_2":
         net_dur_disc = DurationDiscriminatorV2(
             hps.model.hidden_channels,
@@ -153,7 +152,7 @@ def main():
             0.1,
             gin_channels=hps.model.gin_channels
             if hps.data.n_speakers != 0 else 0,
-        ).cuda(rank)
+        ).cuda(local_rank)
     else:
         print("NOT using any duration discriminator like VITS1")
         net_dur_disc = None
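
Why this and the surrounding hunks change .cuda(rank) to .cuda(local_rank): rank is the global process index across the whole job, while CUDA device ordinals only go up to the number of GPUs on the local node. Single-node jobs never notice the difference (rank == local_rank there), but on multi-node jobs .cuda(rank) would request a device that does not exist. A sketch with an assumed 2-node x 4-GPU job; the numbers are illustrative:

import torch

# Assumed: 2 nodes x 4 GPUs, so ranks 0..7 but only devices cuda:0..cuda:3
rank = 5        # global index of this worker (second node, second GPU)
local_rank = 1  # what LOCAL_RANK would be for this worker
# model.cuda(rank)       -> would request cuda:5, which does not exist here
# model.cuda(local_rank) -> cuda:1, the GPU this process actually owns
device = torch.device(f"cuda:{local_rank}")
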
@@ -164,15 +163,33 @@ def main():
         n_speakers=hps.data.n_speakers,
         mas_noise_scale_initial=mas_noise_scale_initial,
         noise_scale_delta=noise_scale_delta,
-        **hps.model).cuda(rank)
+        **hps.model).cuda(local_rank)
     if ("use_mrd_disc" in hps.model.keys()
             and hps.model.use_mrd_disc):
         print("Using MultiPeriodMultiResolutionDiscriminator")
         net_d = MultiPeriodMultiResolutionDiscriminator(
-            hps.model.use_spectral_norm).cuda(rank)
+            hps.model.use_spectral_norm).cuda(local_rank)
     else:
         print("Using MPD")
-        net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+        net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
+
+    # Dispatch the model from cpu to gpu
+    # comment - choihkk
+    # if we comment out unused parameter like DurationDiscriminator's
+    # self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
+    # self.post_transformer we don't have to set find_unused_parameters=True
+    # but I will not proceed with commenting out for compatibility with the
+    # latest work for others
+    net_g = DDP(net_g, device_ids=[local_rank], find_unused_parameters=True)
+    net_d = DDP(net_d, device_ids=[local_rank], find_unused_parameters=True)
+    if net_dur_disc:
+        net_dur_disc = DDP(
+            net_dur_disc,
+            device_ids=[local_rank],
+            find_unused_parameters=True
+        )
+
+    # Get the optimizer
     optim_g = torch.optim.AdamW(
         net_g.parameters(),
         hps.train.learning_rate,
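
Background on find_unused_parameters=True (carried over from the original code, see the choihkk comment in the hunk above): DDP expects every registered parameter to receive a gradient in each backward pass and errors out when some never do. The flag makes DDP traverse the autograd graph after forward and mark unused parameters as ready, at the cost of an extra graph walk per step. A toy illustration, not the project's models:

import torch.nn as nn

class ToyNet(nn.Module):
    """A module with a parameter that forward() never touches."""
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)  # gets no gradient in backward

    def forward(self, x):
        return self.used(x)

# Wrapped as DDP(net, device_ids=[local_rank]) without the flag, training
# soon fails with "Expected to have finished reduction in the prior
# iteration..."; with find_unused_parameters=True, DDP scans for parameters
# like self.unused and marks them ready so the all-reduce can complete.
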
@@ -195,17 +212,7 @@ def main():
     else:
         optim_dur_disc = None
 
-    # comment - choihkk
-    # if we comment out unused parameter like DurationDiscriminator's
-    # self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
-    # self.post_transformer we don't have to set find_unused_parameters=True
-    # but I will not proceed with commenting out for compatibility with the
-    # latest work for others
-    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
-    if net_dur_disc:
-        net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
-
+    # Load the checkpoint
     try:
         _, _, _, epoch_str = task.load_checkpoint(
             task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
@@ -224,6 +231,7 @@ def main():
         epoch_str = 1
         global_step = 0
 
+    # Get the scheduler
     scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
         optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
     scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
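
On last_epoch=epoch_str - 2 (unchanged by this commit, but easy to misread): epoch_str is the 1-based epoch the loop will run next, so the most recently completed epoch is epoch_str - 1, which in the scheduler's 0-based counting is index epoch_str - 2. A fresh run has epoch_str = 1 and therefore last_epoch = -1, ExponentialLR's documented start-from-scratch value. Worked out with an assumed resume point:

# Assumed resume point for illustration: 6 epochs already completed.
epoch_str = 7               # next 1-based epoch to train
last_epoch = epoch_str - 2  # 5: 0-based index of the last completed step
# fresh start: epoch_str = 1 -> last_epoch = -1 (ExponentialLR default)
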
@@ -234,12 +242,22 @@ def main():
     else:
         scheduler_dur_disc = None
 
+    # Get the tensorboard summary
+    writer = None
+    if rank == 0:
+        logger = task.get_logger(hps.model_dir)
+        logger.info(hps)
+        writer = SummaryWriter(log_dir=hps.model_dir)
+        writer_eval = SummaryWriter(
+            log_dir=os.path.join(hps.model_dir, "eval"))
+
     scaler = GradScaler(enabled=hps.train.fp16_run)
 
     for epoch in range(epoch_str, hps.train.epochs + 1):
         if rank == 0:
             train_and_evaluate(
                 rank,
+                local_rank,
                 epoch,
                 hps,
                 [net_g, net_d, net_dur_disc],
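
The writer = None default introduced above is the usual rank-0-only logging pattern: only one process creates TensorBoard event files, and every other rank keeps a defined-but-None handle so the variable exists on all paths. A minimal sketch of the pattern; the log directory and scalar names are illustrative:

import os
from torch.utils.tensorboard import SummaryWriter

rank = int(os.environ.get('RANK', 0))

writer = None
if rank == 0:                          # one process owns the event files
    writer = SummaryWriter(log_dir="exp/illustrative")

if writer is not None:                 # safe no-op on every other rank
    writer.add_scalar("loss/g_total", 0.0, 0)  # illustrative values
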
@@ -253,6 +271,7 @@ def main():
         else:
             train_and_evaluate(
                 rank,
+                local_rank,
                 epoch,
                 hps,
                 [net_g, net_d, net_dur_disc],
@@ -269,7 +288,7 @@ def main():
             scheduler_dur_disc.step()
 
 
-def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
+def train_and_evaluate(rank, local_rank, epoch, hps, nets, optims, schedulers, scaler,
                        loaders, logger, writers):
     net_g, net_d, net_dur_disc = nets
     optim_g, optim_d, optim_dur_disc = optims
@@ -301,14 +320,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
                 net_g.module.noise_scale_delta * global_step)
             net_g.module.current_mas_noise_scale = max(current_mas_noise_scale,
                                                        0.0)
-        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
-            rank, non_blocking=True)
+        x, x_lengths = x.cuda(local_rank, non_blocking=True), x_lengths.cuda(
+            local_rank, non_blocking=True)
         spec, spec_lengths = spec.cuda(
-            rank, non_blocking=True), spec_lengths.cuda(rank,
-                                                        non_blocking=True)
-        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-            rank, non_blocking=True)
-        speakers = speakers.cuda(rank, non_blocking=True)
+            local_rank, non_blocking=True), spec_lengths.cuda(local_rank,
+                                                              non_blocking=True)
+        y, y_lengths = y.cuda(local_rank, non_blocking=True), y_lengths.cuda(
+            local_rank, non_blocking=True)
+        speakers = speakers.cuda(local_rank, non_blocking=True)
 
         with autocast(enabled=hps.train.fp16_run):
             (
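
For reference on the non_blocking=True copies above: the host-to-device transfer only overlaps with computation when the source tensor sits in pinned (page-locked) memory, which the DataLoader provides when constructed with pin_memory=True. A self-contained sketch, assuming a CUDA device is available; shapes and sizes are arbitrary:

import torch
from torch.utils.data import DataLoader, TensorDataset

dataset = TensorDataset(torch.randn(64, 10))
loader = DataLoader(dataset, batch_size=8, pin_memory=True)  # page-locked batches

for (batch,) in loader:
    # Asynchronous with respect to the host: the copy is queued on the CUDA
    # stream and the Python thread continues immediately.
    batch = batch.cuda(0, non_blocking=True)
    break
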
