Add script and docs for pre-encoding latents #175

Open · wants to merge 4 commits into main
101 changes: 101 additions & 0 deletions docs/pre_encoding.md
# Pre-Encoding

When training models on latents from a pre-trained autoencoder, the encoder is typically frozen. Because of that, it is common to pre-encode audio to latents and store them on disk, instead of computing them on the fly during training. This can improve training throughput and free up GPU memory that would otherwise be used for encoding.

## Prerequisites

To pre-encode audio to latents, you'll need a dataset config file, an autoencoder model config file, and an **unwrapped** autoencoder checkpoint file. If you only have a wrapped training checkpoint, you can typically produce an unwrapped one with the repo's `unwrap_model.py` script.

## Run the Pre-Encoding Script

To pre-encode latents from an autoencoder model, you can use `pre_encode.py`. This script loads a pre-trained autoencoder, encodes your audio into latents (or discrete tokens), and saves them to disk in a format that can be loaded directly during training.

The `pre_encode.py` script accepts the following command line arguments:

- `--model-config`
  - Path to model config
  - Required
- `--ckpt-path`
  - Path to **unwrapped** autoencoder model checkpoint
  - Required
- `--model-half`
  - If set, uses half precision for model weights
- Optional
- `--dataset-config`
- Path to dataset config file
- Required
- `--output-path`
- Path to output folder
- Required
- `--batch-size`
- Batch size for processing
- Optional, defaults to 1
- `--sample-size`
- Number of audio samples to pad/crop to for pre-encoding
  - Optional, defaults to 1320960 (~30 seconds at 44.1 kHz)
- `--is-discrete`
  - If set, treats the model as discrete, saving discrete tokens instead of continuous latents
- Optional
- `--num-workers`
- Number of dataloader workers
- Optional, defaults to 4
- `--num-gpus`
- Number of GPUs to use
- Optional, defaults to 1
- `--strategy`
- PyTorch Lightning strategy
- Optional, defaults to 'auto'
- `--limit-batches`
- Limits the number of batches processed
- Optional
- `--shuffle`
  - If set, shuffles the dataset
- Optional

For example, to encode latents on a single GPU in half precision, padding or cropping each clip to roughly 30 seconds, you could run the following:

```bash
$ python3 ./pre_encode.py \
--model-config /path/to/model/config.json \
--ckpt-path /path/to/autoencoder/model.ckpt \
--model-half \
--dataset-config /path/to/dataset/config.json \
--output-path /path/to/output/dir \
--sample-size 1320960 \
--num-gpus 1
```

When you run the above, the `--output-path` directory will contain one numbered subdirectory per GPU process used to encode the latents, along with a `details.json` file that records the settings the script was run with.

Inside the numbered subdirectories, you will find the encoded latents as `.npy` files, along with associated `.json` metadata files.

```bash
/path/to/output/dir/
├── 0
│ ├── 0000000000000.json
│ ├── 0000000000000.npy
│ ├── 0000000000001.json
│ ├── 0000000000001.npy
│ ├── 0000000000002.json
│ ├── 0000000000002.npy
...
└── details.json
```
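
Each `.npy` file holds the encoded latents for a single sample, and the matching `.json` file holds its metadata, including a `padding_mask` downsampled to the latent time resolution. As a quick sanity check, you could load a pair like this (a minimal sketch; the file paths are illustrative, and any additional metadata keys depend on your dataset config):

```python
import json

import numpy as np

# Load one pre-encoded latent and its metadata (paths are illustrative)
latent = np.load("/path/to/output/dir/0/0000000000000.npy")
with open("/path/to/output/dir/0/0000000000000.json") as f:
    metadata = json.load(f)

print(latent.shape)                   # (latent_channels, latent_length)
print(sum(metadata["padding_mask"]))  # number of non-padded latent frames
```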

## Training on Pre-Encoded Latents

Once you have saved your latents to disk, you can train on them by passing `train.py` a dataset config file that points to the pre-encoded latents and sets `"dataset_type"` to `"pre_encoded"`. Under the hood, this configures a `stable_audio_tools.data.dataset.PreEncodedDataset`.

The dataset config file should look something like this:

```json
{
"dataset_type": "pre_encoded",
"datasets": [
{
"id": "my_audio",
"path": "/path/to/output/dir",
"latent_crop_length": 645
}
],
"random_crop": false
}
```
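
Note that `latent_crop_length` is specified in latent frames, not audio samples. A minimal sketch of the relationship, assuming an autoencoder that downsamples audio by a factor of 2048 (this assumption reproduces the `645` above for the default `--sample-size`):

```python
sample_size = 1320960      # audio samples per clip (--sample-size)
downsampling_ratio = 2048  # assumed autoencoder downsampling factor

# 1320960 // 2048 == 645 latent frames
print(sample_size // downsampling_ratio)
```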
189 changes: 189 additions & 0 deletions pre_encode.py
import argparse
import gc
import json
from pathlib import Path

import numpy as np
import pytorch_lightning as pl
import torch
from torch.nn import functional as F

from stable_audio_tools.data.dataset import create_dataloader_from_config
from stable_audio_tools.models.factory import create_model_from_config
from stable_audio_tools.models.pretrained import get_pretrained_model
from stable_audio_tools.models.utils import load_ckpt_state_dict
from stable_audio_tools.training.utils import copy_state_dict


def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, model_half=False):
if pretrained_name is not None:
print(f"Loading pretrained model {pretrained_name}")
model, model_config = get_pretrained_model(pretrained_name)

elif model_config is not None and model_ckpt_path is not None:
        print("Creating model from config")
model = create_model_from_config(model_config)

print(f"Loading model checkpoint from {model_ckpt_path}")
        copy_state_dict(model, load_ckpt_state_dict(model_ckpt_path))

    else:
        raise ValueError("Must provide either pretrained_name or both model_config and model_ckpt_path")

model.eval().requires_grad_(False)

if model_half:
model.to(torch.float16)

print("Done loading model")

return model, model_config


class PreEncodedLatentsInferenceWrapper(pl.LightningModule):
def __init__(
self,
model,
output_path,
is_discrete=False,
model_half=False,
model_config=None,
dataset_config=None,
sample_size=1920000,
args_dict=None
):
super().__init__()
self.save_hyperparameters(ignore=['model'])
self.model = model
self.output_path = Path(output_path)

def prepare_data(self):
# runs on rank 0
self.output_path.mkdir(parents=True, exist_ok=True)
details_path = self.output_path / "details.json"
if not details_path.exists(): # Only save if it doesn't exist
details = {
"model_config": self.hparams.model_config,
"dataset_config": self.hparams.dataset_config,
"sample_size": self.hparams.sample_size,
"args": self.hparams.args_dict
}
details_path.write_text(json.dumps(details))

def setup(self, stage=None):
# runs on each device
process_dir = self.output_path / str(self.global_rank)
process_dir.mkdir(parents=True, exist_ok=True)

def validation_step(self, batch, batch_idx):
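        # Encode one batch of audio and write each sample's latents and metadata to disk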
audio, metadata = batch

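        # Collapse an extra leading dimension some dataloaders add: (1, B, C, T) -> (B, C, T)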
if audio.ndim == 4 and audio.shape[0] == 1:
audio = audio[0]

if torch.cuda.is_available():
torch.cuda.empty_cache()
gc.collect()

if self.hparams.model_half:
audio = audio.to(torch.float16)

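        # Continuous autoencoders return latents directly; discrete models expose tokens via the encode info dict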
with torch.no_grad():
if not self.hparams.is_discrete:
latents = self.model.encode(audio)
else:
_, info = self.model.encode(audio, return_info=True)
latents = info[self.model.bottleneck.tokens_id]

latents = latents.cpu().numpy()

# Save each sample in the batch
for i, latent in enumerate(latents):
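            # Build a 13-digit ID: 3-digit global rank + 6-digit batch index + 4-digit sample index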
latent_id = f"{self.global_rank:03d}{batch_idx:06d}{i:04d}"

# Save latent as numpy file
latent_path = self.output_path / str(self.global_rank) / f"{latent_id}.npy"
with open(latent_path, "wb") as f:
np.save(f, latent)

md = metadata[i]
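            # Downsample the audio-sample-level padding mask to the latent time resolution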
padding_mask = F.interpolate(
md["padding_mask"].unsqueeze(0).unsqueeze(1).float(),
size=latent.shape[1],
mode="nearest"
).squeeze().int()
md["padding_mask"] = padding_mask.cpu().numpy().tolist()

# Convert tensors in md to serializable types
for k, v in md.items():
if isinstance(v, torch.Tensor):
md[k] = v.cpu().numpy().tolist()

# Save metadata to json file
metadata_path = self.output_path / str(self.global_rank) / f"{latent_id}.json"
with open(metadata_path, "w") as f:
json.dump(md, f)

def configure_optimizers(self):
return None


def main(args):
with open(args.model_config) as f:
model_config = json.load(f)

with open(args.dataset_config) as f:
dataset_config = json.load(f)

model, model_config = load_model(
model_config=model_config,
model_ckpt_path=args.ckpt_path,
model_half=args.model_half
)

data_loader = create_dataloader_from_config(
dataset_config,
batch_size=args.batch_size,
num_workers=args.num_workers,
sample_rate=model_config["sample_rate"],
sample_size=args.sample_size,
audio_channels=model_config.get("audio_channels", 2),
shuffle=args.shuffle
)

pl_module = PreEncodedLatentsInferenceWrapper(
model=model,
output_path=args.output_path,
is_discrete=args.is_discrete,
model_half=args.model_half,
model_config=args.model_config,
dataset_config=args.dataset_config,
sample_size=args.sample_size,
args_dict=vars(args)
)

trainer = pl.Trainer(
accelerator="gpu",
devices=args.num_gpus,
strategy=args.strategy,
precision="16-true" if args.model_half else "32",
        limit_val_batches=args.limit_batches if args.limit_batches is not None else 1.0,
logger=False, # Disable logging since we're just doing inference
enable_checkpointing=False,
)
trainer.validate(pl_module, data_loader)

if __name__ == "__main__":
parser = argparse.ArgumentParser(description='Encode audio dataset to VAE latents using PyTorch Lightning')
parser.add_argument('--model-config', type=str, help='Path to model config', required=False)
parser.add_argument('--ckpt-path', type=str, help='Path to unwrapped autoencoder model checkpoint', required=False)
parser.add_argument('--model-half', action='store_true', help='Whether to use half precision')
parser.add_argument('--dataset-config', type=str, help='Path to dataset config file', required=True)
parser.add_argument('--output-path', type=str, help='Path to output folder', required=True)
parser.add_argument('--batch-size', type=int, help='Batch size', default=1)
parser.add_argument('--sample-size', type=int, help='Number of audio samples to pad/crop to', default=1320960)
parser.add_argument('--is-discrete', action='store_true', help='Whether the model is discrete')
parser.add_argument('--num-workers', type=int, help='Number of dataloader workers', default=4)
parser.add_argument('--num-gpus', type=int, help='Number of GPUs to use', default=1)
parser.add_argument('--strategy', type=str, help='PyTorch Lightning strategy', default='auto')
parser.add_argument('--limit-batches', type=int, help='Limit number of batches (optional)', default=None)
parser.add_argument('--shuffle', action='store_true', help='Shuffle dataset')
args = parser.parse_args()
main(args)