Add support for CPU and MPS across the tools in this repo #153

Open
wants to merge 1 commit into base: main
9 changes: 8 additions & 1 deletion stable_audio_tools/inference/generation.py
@@ -8,13 +8,20 @@
 from .sampling import sample, sample_k, sample_rf
 from ..data.utils import PadCrop
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
 def generate_diffusion_uncond(
         model,
         steps: int = 250,
         batch_size: int = 1,
         sample_size: int = 2097152,
         seed: int = -1,
-        device: str = "cuda",
+        device: str = device.type,
        init_audio: tp.Optional[tp.Tuple[int, torch.Tensor]] = None,
        init_noise_level: float = 1.0,
        return_latents = False,
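Note: the module-level device fallback added above (and repeated in the files below) always prefers CUDA, then Apple's MPS backend, then CPU. A minimal standalone sketch of the same selection logic, assuming a PyTorch build recent enough to expose torch.backends.mps (>= 1.12); the helper name autodetect_device is illustrative and not part of this diff:

import torch

def autodetect_device() -> torch.device:
    # Prefer CUDA, then Apple Silicon (MPS), then plain CPU.
    if torch.cuda.is_available():
        return torch.device("cuda")
    if torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")

device = autodetect_device()
print(device)  # cuda, mps, or cpu depending on the host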
41 changes: 26 additions & 15 deletions stable_audio_tools/inference/sampling.py
@@ -4,6 +4,16 @@
 
 import k_diffusion as K
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
 # Define the noise schedule and sampling loop
 def get_alphas_sigmas(t):
     """Returns the scaling factors for the clean image (alpha) and for the
@@ -58,7 +68,7 @@ def sample(model, x, steps, eta, **extra_args):
     for i in trange(steps):
 
         # Get the model output (v, the predicted velocity)
-        with torch.cuda.amp.autocast():
+        with torch.amp.autocast(autocast_device_type):
             v = model(x, ts * t[i], **extra_args).float()
 
         # Predict the noise and the denoised image
@@ -109,16 +119,17 @@ def cond_model_fn(x, sigma, **kwargs):
 # For variations, set init_data
 # For inpainting, set both init_data & mask
 def sample_k(
-        model_fn,
-        noise,
+        model_fn,
+        noise,
         init_data=None,
         mask=None,
-        steps=100,
-        sampler_type="dpmpp-2m-sde",
-        sigma_min=0.5,
-        sigma_max=50,
-        rho=1.0, device="cuda",
-        callback=None,
+        steps=100,
+        sampler_type="dpmpp-2m-sde",
+        sigma_min=0.5,
+        sigma_max=50,
+        rho=1.0,
+        device=device.type,
+        callback=None,
         cond_fn=None,
         **extra_args
     ):
@@ -174,7 +185,7 @@ def inpainting_callback(args):
         x = noise
 
 
-    with torch.cuda.amp.autocast():
+    with torch.amp.autocast(autocast_device_type):
         if sampler_type == "k-heun":
             return K.sampling.sample_heun(denoiser, x, sigmas, disable=False, callback=wrapped_callback, extra_args=extra_args)
         elif sampler_type == "k-lms":
@@ -198,13 +209,13 @@ def inpainting_callback(args):
 # For variations, set init_data
 # For inpainting, set both init_data & mask
 def sample_rf(
-        model_fn,
-        noise,
+        model_fn,
+        noise,
         init_data=None,
-        steps=100,
+        steps=100,
         sigma_max=1,
-        device="cuda",
-        callback=None,
+        device=device.type,
+        callback=None,
         cond_fn=None,
         **extra_args
     ):
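Note: every autocast change in this PR follows the same pattern: torch.amp.autocast takes a device_type string, and only "cuda" and "cpu" are used here, so an MPS device quietly falls back to CPU autocast. A standalone sketch of that mapping (the helper name autocast_device_type_for is illustrative, not part of the diff):

import torch

def autocast_device_type_for(device: torch.device) -> str:
    # Only "cuda" and "cpu" are handed to torch.amp.autocast in this PR;
    # anything else (e.g. "mps") conservatively falls back to "cpu".
    return device.type if device.type in {"cuda", "cpu"} else "cpu"

x = torch.randn(64, 64)
with torch.amp.autocast(autocast_device_type_for(x.device)):
    y = x @ x.T
print(y.dtype)  # bfloat16 under CPU autocast; float16 if x lived on a CUDA device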
9 changes: 8 additions & 1 deletion stable_audio_tools/interface/gradio.py
@@ -24,7 +24,14 @@
 sample_rate = 32000
 sample_size = 1920000
 
-def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device="cuda", model_half=False):
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+def load_model(model_config=None, model_ckpt_path=None, pretrained_name=None, pretransform_ckpt_path=None, device=device, model_half=False):
     global model, sample_rate, sample_size
 
     if pretrained_name is not None:
14 changes: 12 additions & 2 deletions stable_audio_tools/models/blocks.py
@@ -5,11 +5,21 @@
 from torch import nn
 from torch.nn import functional as F
 
-from torch.backends.cuda import sdp_kernel
 from packaging import version
 
 from dac.nn.layers import Snake1d
 
+# Determine the device to use
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+if device.type == 'cuda':
+    from torch.backends.cuda import sdp_kernel
+
 class ResidualBlock(nn.Module):
     def __init__(self, main, skip=None):
         super().__init__()
@@ -41,7 +51,7 @@ def __init__(self, c_in, n_head=1, dropout_rate=0.):
         self.out_proj = nn.Conv1d(c_in, c_in, 1)
         self.dropout = nn.Dropout(dropout_rate, inplace=True)
 
-        self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+        self.use_flash = True if device.type == 'cuda' and version.parse(torch.__version__) >= version.parse('2.0.0') else False
 
         if not self.use_flash:
             return
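Note: the use_flash flag above (and use_pt_flash in transformer.py below) gates fused scaled-dot-product attention on two conditions: a CUDA device and PyTorch 2.0 or newer. A compact sketch of that check, with an illustrative helper name:

import torch
from packaging import version

def can_use_pt_flash(device: torch.device) -> bool:
    # Fused SDPA paths are only taken on CUDA with PyTorch >= 2.0.
    return device.type == "cuda" and version.parse(torch.__version__) >= version.parse("2.0.0")

print(can_use_pt_flash(torch.device("cpu")))   # always False
print(can_use_pt_flash(torch.device("cuda")))  # True on a PyTorch >= 2.0 install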
42 changes: 26 additions & 16 deletions stable_audio_tools/models/conditioners.py
@@ -15,6 +15,16 @@
 
 from torch import nn
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
 class Conditioner(nn.Module):
     def __init__(
             self,
@@ -71,8 +81,8 @@ def __init__(self,
 
         self.embedder = NumberEmbedder(features=output_dim)
 
-    def forward(self, floats: tp.List[float], device=None) -> tp.Any:
+    def forward(self, floats: tp.List[float], device=device) -> tp.Any:
 
         # Cast the inputs to floats
         floats = [float(x) for x in floats]
 
@@ -138,9 +148,10 @@ def __init__(self,
         del self.model.model.audio_branch
 
         gc.collect()
-        torch.cuda.empty_cache()
+        if device.type == 'cuda':
+            torch.cuda.empty_cache()
 
-    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
+    def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = device):
         prompt_tokens = self.model.tokenizer(prompts)
         attention_mask = prompt_tokens["attention_mask"].to(device=device, non_blocking=True)
         prompt_features = self.model.model.text_branch(
@@ -151,7 +162,7 @@ def get_clap_features(self, prompts, layer_ix=-2, device: tp.Any = "cuda"):
 
         return prompt_features, attention_mask
 
-    def forward(self, texts: tp.List[str], device: tp.Any = "cuda") -> tp.Any:
+    def forward(self, texts: tp.List[str], device: tp.Any = device) -> tp.Any:
         self.model.to(device)
 
         if self.use_text_features:
@@ -182,8 +193,6 @@ def __init__(self,
                  project_out: bool = False):
         super().__init__(512, output_dim, project_out=project_out)
 
-        device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
-
         # Suppress logging from transformers
         previous_level = logging.root.manager.disable
         logging.disable(logging.ERROR)
@@ -192,8 +201,8 @@
             try:
                 import laion_clap
                 from laion_clap.clap_module.factory import load_state_dict as clap_load_state_dict
-                model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device='cpu')
 
+                model = laion_clap.CLAP_Module(enable_fusion=enable_fusion, amodel=audio_model_type, device=device)
 
                 if self.finetune:
                     self.model = model
@@ -216,9 +225,10 @@
         del self.model.model.text_branch
 
         gc.collect()
-        torch.cuda.empty_cache()
+        if device.type == 'cuda':
+            torch.cuda.empty_cache()
 
-    def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = "cuda") -> tp.Any:
+    def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple[torch.Tensor]] , device: tp.Any = device) -> tp.Any:
 
         self.model.to(device)
 
@@ -228,7 +238,7 @@ def forward(self, audios: tp.Union[torch.Tensor, tp.List[torch.Tensor], tp.Tuple
         # Convert to mono
         mono_audios = audios.mean(dim=1)
 
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(autocast_device_type, enabled=False):
             audio_embedding = self.model.get_audio_embedding_from_data(mono_audios.float(), use_tensor=True)
 
         audio_embedding = audio_embedding.unsqueeze(1).to(device)
@@ -310,12 +320,12 @@ def forward(self, texts: tp.List[str], device: tp.Union[torch.device, str]) -> t
         attention_mask = encoded["attention_mask"].to(device).to(torch.bool)
 
         self.model.eval()
-        with torch.cuda.amp.autocast(dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
 
+        with torch.amp.autocast(autocast_device_type, dtype=torch.float16) and torch.set_grad_enabled(self.enable_grad):
             embeddings = self.model(
                 input_ids=input_ids, attention_mask=attention_mask
-            )["last_hidden_state"]
+            )["last_hidden_state"]
 
         embeddings = self.proj_out(embeddings.float())
 
         embeddings = embeddings * attention_mask.unsqueeze(-1).float()
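Note: conditioners.py also guards torch.cuda.empty_cache() behind a CUDA check after deleting the unused CLAP branches. A minimal sketch of that cleanup pattern (the function name release_memory is illustrative):

import gc
import torch

def release_memory(device: torch.device) -> None:
    # Drop Python references first, then ask CUDA's caching allocator to
    # release unused blocks; on CPU or MPS there is nothing CUDA-specific to free.
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()

release_memory(torch.device("cpu"))  # safe no-op on a machine without CUDA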
16 changes: 13 additions & 3 deletions stable_audio_tools/models/pretransforms.py
@@ -2,6 +2,16 @@
 from einops import rearrange
 from torch import nn
 
+if torch.cuda.is_available():
+    device = torch.device('cuda')
+elif torch.backends.mps.is_available():
+    device = torch.device('mps')
+else:
+    device = torch.device('cpu')
+
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
 class Pretransform(nn.Module):
     def __init__(self, enable_grad, io_channels, is_discrete):
         super().__init__()
@@ -250,9 +260,9 @@ def decode(self, z):
         # return self.model.decode(z)
 
     def tokenize(self, x):
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(autocast_device_type, enabled=False):
             return self.model.encode(x.to(torch.float16))[0]
 
     def decode_tokens(self, tokens):
-        with torch.cuda.amp.autocast(enabled=False):
+        with torch.amp.autocast(autocast_device_type, enabled=False):
             return self.model.decode(tokens)
50 changes: 34 additions & 16 deletions stable_audio_tools/models/transformer.py
@@ -6,9 +6,21 @@
 import torch
 import torch.nn.functional as F
 from torch import nn, einsum
-from torch.cuda.amp import autocast
+from torch.amp import autocast
 from typing import Callable, Literal
 
+if torch.cuda.is_available():
+    device = torch.device("cuda")
+elif torch.backends.mps.is_available():
+    device = torch.device("mps")
+else:
+    device = torch.device("cpu")
+
+# Ensure device.type is valid for autocast
+valid_autocast_device_types = {"cuda", "cpu"}
+autocast_device_type = device.type if device.type in valid_autocast_device_types else "cpu"
+
+
 try:
     from flash_attn import flash_attn_func, flash_attn_kvpacked_func
 except ImportError as e:
@@ -123,7 +135,7 @@ def forward_from_seq_len(self, seq_len):
         t = torch.arange(seq_len, device = device)
         return self.forward(t)
 
-    @autocast(enabled = False)
+    @autocast(device_type=autocast_device_type, enabled=False)
     def forward(self, t):
         device = self.inv_freq.device
 
@@ -148,8 +160,9 @@ def rotate_half(x):
     x1, x2 = x.unbind(dim = -2)
     return torch.cat((-x2, x1), dim = -1)
 
-@autocast(enabled = False)
-def apply_rotary_pos_emb(t, freqs, scale = 1):
+
+@autocast(device_type=autocast_device_type, enabled=False)
+def apply_rotary_pos_emb(t, freqs, scale=1):
     out_dtype = t.dtype
 
     # cast to float32 if necessary for numerical stability
@@ -311,15 +324,17 @@ def __init__(
         if natten_kernel_size is not None:
             return
 
-        self.use_pt_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0')
+        self.use_pt_flash = device.type == "cuda" and version.parse(
+            torch.__version__
+        ) >= version.parse("2.0.0")
 
-        self.use_fa_flash = torch.cuda.is_available() and flash_attn_func is not None
+        self.use_fa_flash = device.type == "cuda" and flash_attn_func is not None
 
-        self.sdp_kwargs = dict(
-            enable_flash = True,
-            enable_math = True,
-            enable_mem_efficient = True
-        )
+        self.sdp_backends = [
+            torch.nn.attention.SDPBackend.FLASH_ATTENTION,
+            torch.nn.attention.SDPBackend.MATH,
+            torch.nn.attention.SDPBackend.EFFICIENT_ATTENTION,
+        ]
 
     def flash_attn(
         self,
@@ -378,12 +393,15 @@ def flash_attn(
             mask[..., 0] = mask[..., 0] | row_is_entirely_masked
 
         causal = False
 
-        with torch.backends.cuda.sdp_kernel(**self.sdp_kwargs):
+        if device.type == "cuda":
+            with torch.nn.attention.sdpa_kernel(self.sdp_backends):
+                out = F.scaled_dot_product_attention(
+                    q, k, v, attn_mask=mask, is_causal=causal
+                )
+        else:
             out = F.scaled_dot_product_attention(
-                q, k, v,
-                attn_mask = mask,
-                is_causal = causal
+                q, k, v, attn_mask=mask, is_causal=causal
             )
 
         # for a row that is entirely masked out, should zero out the output of that row token
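Note: transformer.py replaces the deprecated torch.backends.cuda.sdp_kernel(**kwargs) context manager with torch.nn.attention.sdpa_kernel and an explicit backend list, and only enters it on CUDA. A standalone sketch of that call pattern, assuming PyTorch >= 2.3 (where sdpa_kernel accepts a list of backends):

import torch
import torch.nn.functional as F
from torch.nn.attention import SDPBackend, sdpa_kernel

# (batch, heads, seq_len, head_dim)
q = k = v = torch.randn(1, 4, 16, 32)

backends = [SDPBackend.FLASH_ATTENTION, SDPBackend.MATH, SDPBackend.EFFICIENT_ATTENTION]

if torch.cuda.is_available():
    # Restrict SDPA to the allowed backends only when running on CUDA.
    with sdpa_kernel(backends):
        out = F.scaled_dot_product_attention(q, k, v)
else:
    # On CPU or MPS, call SDPA directly and let PyTorch pick an implementation.
    out = F.scaled_dot_product_attention(q, k, v)

print(out.shape)  # torch.Size([1, 4, 16, 32])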