Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -20,3 +20,7 @@ htmlcov

.venv/
lightning_logs/

dataset
training
*.ckpt
119 changes: 112 additions & 7 deletions src/python/piper_train/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

import torch
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import ModelCheckpoint
from pytorch_lightning.callbacks import ModelCheckpoint, EarlyStopping, LearningRateMonitor

from .vits.lightning import VitsModel

Expand All @@ -14,6 +14,9 @@

def main():
logging.basicConfig(level=logging.DEBUG)
logging.getLogger("fsspec").setLevel(logging.WARNING)
logging.getLogger("matplotlib").setLevel(logging.WARNING)
logging.getLogger("PIL").setLevel(logging.WARNING)

parser = argparse.ArgumentParser()
parser.add_argument(
Expand All @@ -37,25 +40,76 @@ def main():
"--resume_from_single_speaker_checkpoint",
help="For multi-speaker models only. Converts a single-speaker checkpoint to multi-speaker and resumes training",
)
Trainer.add_argparse_args(parser)
VitsModel.add_model_specific_args(parser)
parser.add_argument(
"--accelerator",
type=str,
)
parser.add_argument(
"--devices",
type=int,
)
parser.add_argument(
"--log_every_n_steps",
type=int,
)
parser.add_argument(
"--max_epochs",
type=int,
)
parser.add_argument(
"--seed",
type=int,
default=1234
)
parser.add_argument(
"--random_seed",
type=bool,
default=False
)
parser.add_argument(
"--resume_from_checkpoint",
type=str,
)
parser.add_argument(
"--precision",
type=str,
)
parser.add_argument(
"--num_ckpt",
type=int,
default=1,
help="# of ckpts saved."
)
parser.add_argument(
"--default_root_dir",
type=str,
help="Default root dir for checkpoints and logs."
)
parser.add_argument(
"--save_last",
type=bool,
default=None,
help="Always save the last checkpoint."
)
parser.add_argument(
"--monitor",
type=str,
default="val_loss",
help="Metric to monitor."
)
parser.add_argument(
"--monitor_mode",
type=str,
default="min",
help="Mode to monitor."
)
parser.add_argument(
"--early_stop_patience",
type=int,
default=0,
help="Early stopping patience."
)
args = parser.parse_args()
_LOGGER.debug(args)

Expand All @@ -64,7 +118,24 @@ def main():
args.default_root_dir = args.dataset_dir

torch.backends.cudnn.benchmark = True
torch.manual_seed(args.seed)

if args.random_seed:
seed = torch.seed()
_LOGGER.debug("Using random seed: %s", seed)
else:
torch.manual_seed(args.seed)
_LOGGER.debug("Using manual seed: %s", args.seed)

# Function to check if the GPU supports Tensor Cores
def supports_tensor_cores():
# Assuming that Tensor Cores are supported if the compute capability is 7.0 or higher
# This is a simplification; you might need a more detailed check based on your specific requirements
return torch.cuda.get_device_capability(0)[0] >= 7

# Set the float32 matrix multiplication precision based on GPU support for Tensor Cores
if supports_tensor_cores():
# Set to 'high' or 'medium' based on your preference
torch.set_float32_matmul_precision('high')

config_path = args.dataset_dir / "config.json"
dataset_path = args.dataset_dir / "dataset.jsonl"
Expand All @@ -76,20 +147,54 @@ def main():
num_speakers = int(config["num_speakers"])
sample_rate = int(config["audio"]["sample_rate"])

trainer = Trainer.from_argparse_args(args)
# List of argument names to include
allowed_args = [
"accelerator",
"devices",
"log_every_n_steps",
"max_epochs",
"precision",
"default_root_dir",
]

# Filter the arguments
filtered_args = {key: value for key, value in vars(args).items() if key in allowed_args}

# Initialize callbacks
callbacks = []

if args.checkpoint_epochs is not None:
trainer.callbacks = [ModelCheckpoint(
checkpoint_callback = ModelCheckpoint(
every_n_epochs=args.checkpoint_epochs,
save_top_k=args.num_ckpt,
monitor=args.monitor,
mode=args.monitor_mode,
save_last=args.save_last
)]
)
callbacks.append(checkpoint_callback)
_LOGGER.debug(
"Checkpoints will be saved every %s epoch(s)", args.checkpoint_epochs
)
_LOGGER.debug(
"%s Checkpoints will be saved", args.num_ckpt
)

if args.early_stop_patience > 0:
# Early stopping callback
early_stopping_callback = EarlyStopping(
monitor='val_loss',
patience=args.early_stop_patience,
verbose=True,
mode='min'
)
callbacks.append(early_stopping_callback)

# Learning rate monitor callback
lr_monitor_callback = LearningRateMonitor(logging_interval='epoch')
callbacks.append(lr_monitor_callback)

trainer = Trainer(**filtered_args, callbacks=callbacks)

dict_args = vars(args)
if args.quality == "x-low":
dict_args["hidden_channels"] = 96
Expand Down Expand Up @@ -147,7 +252,7 @@ def main():
"Successfully converted single-speaker checkpoint to multi-speaker"
)

trainer.fit(model)
trainer.fit(model, ckpt_path=args.resume_from_checkpoint)


def load_state_dict(model, saved_state_dict):
Expand Down
2 changes: 1 addition & 1 deletion src/python/piper_train/clean_cached_audio.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ def check_file(pt_path: Path) -> None:

try:
_LOGGER.debug("Checking %s", pt_path)
torch.load(str(pt_path))
torch.load(str(pt_path), weights_only=True)
except Exception:
_LOGGER.error(pt_path)
if args.delete:
Expand Down
14 changes: 8 additions & 6 deletions src/python/piper_train/export_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,9 @@ def main() -> None:
with torch.no_grad():
model_g.dec.remove_weight_norm()

# old_forward = model_g.infer
# Define the device
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model_g.to(device)

def infer_forward(text, text_lengths, scales, sid=None):
noise_scale = scales[0]
Expand All @@ -73,15 +75,15 @@ def infer_forward(text, text_lengths, scales, sid=None):
dummy_input_length = 50
sequences = torch.randint(
low=0, high=num_symbols, size=(1, dummy_input_length), dtype=torch.long
)
sequence_lengths = torch.LongTensor([sequences.size(1)])
).to(device)
sequence_lengths = torch.LongTensor([sequences.size(1)]).to(device)

sid: Optional[torch.LongTensor] = None
if num_speakers > 1:
sid = torch.LongTensor([0])
sid = torch.LongTensor([0]).to(device)

# noise, noise_w, length
scales = torch.FloatTensor([0.667, 1.0, 0.8])
scales = torch.FloatTensor([0.667, 1.0, 0.8]).to(device)
dummy_input = (sequences, sequence_lengths, scales, sid)

# Export
Expand All @@ -106,4 +108,4 @@ def infer_forward(text, text_lengths, scales, sid=None):
# -----------------------------------------------------------------------------

if __name__ == "__main__":
main()
main()
2 changes: 1 addition & 1 deletion src/python/piper_train/infer_generator.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def main():
args.output_dir = Path(args.output_dir)
args.output_dir.mkdir(parents=True, exist_ok=True)

model = torch.load(args.model)
model = torch.load(args.model, weights_only=True)

# Inference only
model.eval()
Expand Down
2 changes: 1 addition & 1 deletion src/python/piper_train/norm_audio/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def cache_norm_audio(
if ignore_cache or (not audio_spec_path.exists()):
if audio_norm_tensor is None:
# Load pre-cached normalized audio
audio_norm_tensor = torch.load(audio_norm_path)
audio_norm_tensor = torch.load(audio_norm_path, weights_only=True)

audio_spec_tensor = spectrogram_torch(
y=audio_norm_tensor,
Expand Down
4 changes: 2 additions & 2 deletions src/python/piper_train/vits/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,8 +77,8 @@ def __getitem__(self, idx) -> UtteranceTensors:
utt = self.utterances[idx]
return UtteranceTensors(
phoneme_ids=LongTensor(utt.phoneme_ids),
audio_norm=torch.load(utt.audio_norm_path),
spectrogram=torch.load(utt.audio_spec_path),
audio_norm=torch.load(utt.audio_norm_path, weights_only=True),
spectrogram=torch.load(utt.audio_spec_path, weights_only=True),
speaker_id=LongTensor([utt.speaker_id])
if utt.speaker_id is not None
else None,
Expand Down
Loading