import os
import torch
+import torch.distributed as dist
+
from torch.nn import functional as F
from torch.utils.data import DataLoader
from torch.utils.tensorboard import SummaryWriter
-import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.cuda.amp import autocast, GradScaler


def main():
    hps = task.get_hparams()
+    # Set random seed
    torch.manual_seed(hps.train.seed)
    global global_step
+    # Initialize distributed
    world_size = int(os.environ.get('WORLD_SIZE', 1))
    local_rank = int(os.environ.get('LOCAL_RANK', 0))
    rank = int(os.environ.get('RANK', 0))
    torch.cuda.set_device(local_rank)
    dist.init_process_group("nccl")
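Note: the initialization above assumes a torchrun-style launcher, which exports WORLD_SIZE, RANK and LOCAL_RANK for each worker; init_process_group("nccl") then picks the rank and world size up from the environment. A minimal sketch of that pattern (the torchrun command in the comment is the usual invocation and an assumption here, not something added by this PR):

import os
import torch
import torch.distributed as dist

# Typically launched with: torchrun --nproc_per_node=<num_gpus> train.py
# torchrun exports WORLD_SIZE, RANK and LOCAL_RANK for every worker process.
world_size = int(os.environ.get("WORLD_SIZE", 1))
local_rank = int(os.environ.get("LOCAL_RANK", 0))
rank = int(os.environ.get("RANK", 0))

torch.cuda.set_device(local_rank)   # bind this process to its own GPU
dist.init_process_group("nccl")     # rank/world size are read from the env

# ... build the model, wrap it in DDP, run the training loop ...

dist.destroy_process_group()        # clean shutdown of the process group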
-    if rank == 0:
-        logger = task.get_logger(hps.model_dir)
-        logger.info(hps)
-        writer = SummaryWriter(log_dir=hps.model_dir)
-        writer_eval = SummaryWriter(
-            log_dir=os.path.join(hps.model_dir, "eval"))
-
-    if ("use_mel_posterior_encoder" in hps.model.keys()
-            and hps.model.use_mel_posterior_encoder):
-        print("Using mel posterior encoder for VITS2")
-        posterior_channels = hps.data.n_mel_channels  # vits2
-        hps.data.use_mel_posterior_encoder = True
-    else:
-        print("Using lin posterior encoder for VITS1")
-        posterior_channels = hps.data.filter_length // 2 + 1
-        hps.data.use_mel_posterior_encoder = False

+    # Get the dataset and data loader
    train_dataset = TextAudioSpeakerLoader(hps.data.training_files, hps.data)
    train_sampler = DistributedBucketSampler(
        train_dataset,
@@ -85,6 +73,17 @@ def main():
        collate_fn=collate_fn,
    )

+    # Get the tts model
+    if ("use_mel_posterior_encoder" in hps.model.keys()
+            and hps.model.use_mel_posterior_encoder):
+        print("Using mel posterior encoder for VITS2")
+        posterior_channels = hps.data.n_mel_channels  # vits2
+        hps.data.use_mel_posterior_encoder = True
+    else:
+        print("Using lin posterior encoder for VITS1")
+        posterior_channels = hps.data.filter_length // 2 + 1
+        hps.data.use_mel_posterior_encoder = False
+
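For context on the channel counts chosen above, a small illustration with typical VITS-style values (the concrete numbers are assumptions for illustration, not taken from this PR's config):

# Illustrative values only (assumed, not from this PR's config)
filter_length = 1024
n_mel_channels = 80
lin_posterior_channels = filter_length // 2 + 1  # 513 linear-spectrogram bins (VITS1 path)
mel_posterior_channels = n_mel_channels          # 80 mel bins (VITS2 mel posterior encoder)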
    # some of these flags are not used in the code and are set directly in the
    # hps json file; they are kept here for reference and prototyping.
    if ("use_transformer_flows" in hps.model.keys()
@@ -144,7 +143,7 @@ def main():
            0.1,
            gin_channels=hps.model.gin_channels
            if hps.data.n_speakers != 0 else 0,
-        ).cuda(rank)
+        ).cuda(local_rank)
    elif duration_discriminator_type == "dur_disc_2":
        net_dur_disc = DurationDiscriminatorV2(
            hps.model.hidden_channels,
@@ -153,7 +152,7 @@ def main():
            0.1,
            gin_channels=hps.model.gin_channels
            if hps.data.n_speakers != 0 else 0,
-        ).cuda(rank)
+        ).cuda(local_rank)
    else:
        print("NOT using any duration discriminator like VITS1")
        net_dur_disc = None
@@ -164,15 +163,33 @@ def main():
        n_speakers=hps.data.n_speakers,
        mas_noise_scale_initial=mas_noise_scale_initial,
        noise_scale_delta=noise_scale_delta,
-        **hps.model).cuda(rank)
+        **hps.model).cuda(local_rank)
    if ("use_mrd_disc" in hps.model.keys()
            and hps.model.use_mrd_disc):
        print("Using MultiPeriodMultiResolutionDiscriminator")
        net_d = MultiPeriodMultiResolutionDiscriminator(
-            hps.model.use_spectral_norm).cuda(rank)
+            hps.model.use_spectral_norm).cuda(local_rank)
    else:
        print("Using MPD")
-        net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(rank)
+        net_d = MultiPeriodDiscriminator(hps.model.use_spectral_norm).cuda(local_rank)
+
+    # Wrap the models with DistributedDataParallel
+    # comment - choihkk
+    # if we comment out unused parameter like DurationDiscriminator's
+    # self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
+    # self.post_transformer we don't have to set find_unused_parameters=True
+    # but I will not proceed with commenting out for compatibility with the
+    # latest work for others
+    net_g = DDP(net_g, device_ids=[local_rank], find_unused_parameters=True)
+    net_d = DDP(net_d, device_ids=[local_rank], find_unused_parameters=True)
+    if net_dur_disc:
+        net_dur_disc = DDP(
+            net_dur_disc,
+            device_ids=[local_rank],
+            find_unused_parameters=True
+        )
+
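The find_unused_parameters=True flag referenced in the comment above matters because some submodules never contribute to the loss, so DDP's reducer has to be told to skip their parameters. A self-contained toy sketch of that behaviour (the single-process gloo group and the Toy module are hypothetical, purely for illustration):

import os
import torch
import torch.distributed as dist
from torch import nn
from torch.nn.parallel import DistributedDataParallel as DDP

class Toy(nn.Module):
    def __init__(self):
        super().__init__()
        self.used = nn.Linear(4, 4)
        self.unused = nn.Linear(4, 4)  # never called in forward, so it gets no gradient

    def forward(self, x):
        return self.used(x)

if __name__ == "__main__":
    # Single-process CPU/gloo group just for the demo; the training script
    # is instead launched with torchrun and uses the "nccl" backend.
    os.environ.setdefault("MASTER_ADDR", "127.0.0.1")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("gloo", rank=0, world_size=1)

    # find_unused_parameters=True lets the reducer skip parameters that
    # receive no gradient in a given iteration (here: Toy.unused).
    model = DDP(Toy(), find_unused_parameters=True)
    loss = model(torch.randn(2, 4)).sum()
    loss.backward()

    dist.destroy_process_group()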
+    # Get the optimizer
    optim_g = torch.optim.AdamW(
        net_g.parameters(),
        hps.train.learning_rate,
@@ -195,17 +212,7 @@ def main():
    else:
        optim_dur_disc = None

-    # comment - choihkk
-    # if we comment out unused parameter like DurationDiscriminator's
-    # self.pre_out_norm1,2 self.norm_1,2 and ResidualCouplingTransformersLayer's
-    # self.post_transformer we don't have to set find_unused_parameters=True
-    # but I will not proceed with commenting out for compatibility with the
-    # latest work for others
-    net_g = DDP(net_g, device_ids=[rank], find_unused_parameters=True)
-    net_d = DDP(net_d, device_ids=[rank], find_unused_parameters=True)
-    if net_dur_disc:
-        net_dur_disc = DDP(net_dur_disc, device_ids=[rank], find_unused_parameters=True)
-
+    # Load the checkpoint
    try:
        _, _, _, epoch_str = task.load_checkpoint(
            task.latest_checkpoint_path(hps.model_dir, "G_*.pth"), net_g,
@@ -224,6 +231,7 @@ def main():
        epoch_str = 1
        global_step = 0

+    # Get the scheduler
    scheduler_g = torch.optim.lr_scheduler.ExponentialLR(
        optim_g, gamma=hps.train.lr_decay, last_epoch=epoch_str - 2)
    scheduler_d = torch.optim.lr_scheduler.ExponentialLR(
@@ -234,12 +242,22 @@ def main():
    else:
        scheduler_dur_disc = None

+    # Get the tensorboard summary
+    writer = None
+    if rank == 0:
+        logger = task.get_logger(hps.model_dir)
+        logger.info(hps)
+        writer = SummaryWriter(log_dir=hps.model_dir)
+        writer_eval = SummaryWriter(
+            log_dir=os.path.join(hps.model_dir, "eval"))
+
    scaler = GradScaler(enabled=hps.train.fp16_run)

    for epoch in range(epoch_str, hps.train.epochs + 1):
        if rank == 0:
            train_and_evaluate(
                rank,
+                local_rank,
                epoch,
                hps,
                [net_g, net_d, net_dur_disc],
@@ -253,6 +271,7 @@ def main():
        else:
            train_and_evaluate(
                rank,
+                local_rank,
                epoch,
                hps,
                [net_g, net_d, net_dur_disc],
@@ -269,7 +288,7 @@ def main():
            scheduler_dur_disc.step()


-def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
+def train_and_evaluate(rank, local_rank, epoch, hps, nets, optims, schedulers, scaler,
                        loaders, logger, writers):
    net_g, net_d, net_dur_disc = nets
    optim_g, optim_d, optim_dur_disc = optims
@@ -301,14 +320,14 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler,
            net_g.module.noise_scale_delta * global_step)
        net_g.module.current_mas_noise_scale = max(current_mas_noise_scale,
                                                   0.0)
-        x, x_lengths = x.cuda(rank, non_blocking=True), x_lengths.cuda(
-            rank, non_blocking=True)
+        x, x_lengths = x.cuda(local_rank, non_blocking=True), x_lengths.cuda(
+            local_rank, non_blocking=True)
        spec, spec_lengths = spec.cuda(
-            rank, non_blocking=True), spec_lengths.cuda(rank,
-                                                        non_blocking=True)
-        y, y_lengths = y.cuda(rank, non_blocking=True), y_lengths.cuda(
-            rank, non_blocking=True)
-        speakers = speakers.cuda(rank, non_blocking=True)
+            local_rank, non_blocking=True), spec_lengths.cuda(local_rank,
+                                                              non_blocking=True)
+        y, y_lengths = y.cuda(local_rank, non_blocking=True), y_lengths.cuda(
+            local_rank, non_blocking=True)
+        speakers = speakers.cuda(local_rank, non_blocking=True)

        with autocast(enabled=hps.train.fp16_run):
            (
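The training step above runs its forward pass under autocast; for reference, a minimal, self-contained sketch of the autocast/GradScaler pattern it relies on (the toy model, optimizer and values are stand-ins, not taken from this PR; a CUDA device is required):

import torch
from torch import nn
from torch.cuda.amp import autocast, GradScaler

model = nn.Linear(16, 1).cuda()                        # stand-in for net_g
optim = torch.optim.AdamW(model.parameters(), lr=2e-4)
scaler = GradScaler(enabled=True)                      # enabled follows hps.train.fp16_run in the script

for step in range(10):
    x = torch.randn(8, 16, device="cuda")
    with autocast(enabled=True):                       # forward/loss in mixed precision
        loss = model(x).pow(2).mean()
    optim.zero_grad(set_to_none=True)
    scaler.scale(loss).backward()                      # scaled backward avoids fp16 underflow
    scaler.unscale_(optim)                             # unscale so clipping sees true grad norms
    nn.utils.clip_grad_norm_(model.parameters(), 1.0)
    scaler.step(optim)                                 # step is skipped if grads are inf/nan
    scaler.update()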