inference.py: 256 changes (241 additions, 15 deletions)
@@ -1,14 +1,11 @@
import argparse
from pathlib import Path
from typing import Generator

import torch
import torchaudio
from transformers import (
    AutoConfig,
    AutoModelForCausalLM,
    AutoTokenizer,
    WhisperFeatureExtractor,
)
from transformers import (AutoConfig, AutoModelForCausalLM, AutoTokenizer,
                          WhisperFeatureExtractor)

WHISPER_FEAT_CFG = {
    "chunk_length": 30,
@@ -26,8 +23,60 @@
}


def chunk_iter(
    inputs: torch.Tensor,
    chunk_len: int,
    stride_left: int,
    stride_right: int,
) -> Generator[dict, None, None]:
    """
    Iterate over audio in overlapping chunks, following the chunking scheme of
    the HuggingFace Transformers ASR pipeline.

    Args:
        inputs: Audio tensor of shape (1, samples) or (samples,)
        chunk_len: Number of samples per chunk
        stride_left: Number of overlap samples on the left side
        stride_right: Number of overlap samples on the right side

    Yields:
        dict with keys:
            - chunk: Audio chunk tensor
            - stride: Tuple of (chunk_length, left_stride, right_stride)
            - is_last: Whether this is the final chunk
            - start_sample: Starting sample index in the original audio
    """
    if inputs.dim() == 2:
        inputs = inputs.squeeze(0)

    inputs_len = inputs.shape[0]
    step = chunk_len - stride_left - stride_right

    for chunk_start_idx in range(0, inputs_len, step):
        chunk_end_idx = chunk_start_idx + chunk_len
        chunk = inputs[chunk_start_idx:chunk_end_idx]

        # First chunk: no left stride
        _stride_left = 0 if chunk_start_idx == 0 else stride_left
        # Last chunk: no right stride
        is_last = chunk_end_idx >= inputs_len
        _stride_right = 0 if is_last else stride_right

        # Skip if the chunk is too small (contains only stride content)
        if chunk.shape[0] > _stride_left:
            yield {
                "chunk": chunk.unsqueeze(0),  # (1, samples)
                "stride": (chunk.shape[0], _stride_left, _stride_right),
                "is_last": is_last,
                "start_sample": chunk_start_idx,
            }

        if is_last:
            break
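

# A quick worked example of chunk_iter's stride arithmetic (illustrative, not
# part of the module): with chunk_len=10, stride_left=2, stride_right=2 the
# step is 10 - 2 - 2 = 6 samples, so a 20-sample input yields three chunks:
#
#   for c in chunk_iter(torch.arange(20.0), 10, 2, 2):
#       print(c["start_sample"], c["stride"], c["is_last"])
#   # 0  (10, 0, 2) False
#   # 6  (10, 2, 2) False
#   # 12 (8, 2, 0)  True   <- final chunk is short and carries no right stride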


def get_audio_token_length(seconds, merge_factor=2):
    def get_T_after_cnn(L_in, dilation=1):
        L_out = L_in
        # Conv stack of the audio encoder: (padding, kernel_size, stride) per layer
        for padding, kernel_size, stride in [(1, 3, 1), (1, 3, 2)]:
            L_out = L_in + 2 * padding - dilation * (kernel_size - 1) - 1
            L_out = 1 + L_out // stride
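            # Worked example (illustrative): with L_in = 3000 mel frames, the
            # (1, 3, 1) layer gives 1 + (3000 + 2 - 2 - 1) // 1 = 3000, and the
            # (1, 3, 2) layer then halves that to 1 + 2999 // 2 = 1500.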
@@ -66,7 +115,7 @@ def build_prompt(
    audio_length = []
    chunk_size = chunk_seconds * feature_extractor.sampling_rate
    for start in range(0, wav.shape[1], chunk_size):
        chunk = wav[:, start : start + chunk_size]
        chunk = wav[:, start: start + chunk_size]
        mel = feature_extractor(
            chunk.numpy(),
            sampling_rate=feature_extractor.sampling_rate,
@@ -101,7 +150,45 @@ def build_prompt(
    return batch


def prepare_inputs(batch: dict, device: torch.device) -> tuple[dict, int]:
def build_single_chunk_prompt(
    chunk: torch.Tensor,
    tokenizer,
    feature_extractor: WhisperFeatureExtractor,
    merge_factor: int,
) -> dict:
    """Build the prompt for a single audio chunk."""
    mel = feature_extractor(
        chunk.numpy(),
        sampling_rate=feature_extractor.sampling_rate,
        return_tensors="pt",
        padding="max_length",
    )["input_features"]

    seconds = chunk.shape[1] / feature_extractor.sampling_rate
    num_tokens = get_audio_token_length(seconds, merge_factor)

    tokens = []
    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\n")
    tokens += tokenizer.encode("<|begin_of_audio|>")
    audio_offset = len(tokens)
    tokens += [0] * num_tokens  # placeholder ids for the audio embedding positions
    tokens += tokenizer.encode("<|end_of_audio|>")
    tokens += tokenizer.encode("<|user|>")
    tokens += tokenizer.encode("\nPlease transcribe this audio into text")
    tokens += tokenizer.encode("<|assistant|>")
    tokens += tokenizer.encode("\n")

    return {
        "input_ids": torch.tensor([tokens], dtype=torch.long),
        "audios": mel,
        "audio_offsets": [[audio_offset]],
        "audio_length": [[num_tokens]],
        "attention_mask": torch.ones(1, len(tokens), dtype=torch.long),
    }
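
# Sketch of the prompt assembled above (token counts are illustrative and
# depend on the tokenizer; the `0` ids are placeholders whose positions are
# recorded in audio_offsets so the model can substitute projected audio
# features there):
#
#   <|user|>\n<|begin_of_audio|> [0] * num_tokens <|end_of_audio|>
#   <|user|>\nPlease transcribe this audio into text<|assistant|>\n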


def prepare_inputs(batch: dict, device: str | torch.device) -> tuple[dict, int]:
    tokens = batch["input_ids"].to(device)
    attention_mask = batch["attention_mask"].to(device)
    audios = batch["audios"].to(device)
@@ -156,6 +243,117 @@ def transcribe(
    print(transcript or "[Empty transcription]")


def transcribe_sliding_window(
    checkpoint_dir: Path,
    audio_path: Path,
    tokenizer_path: str | None,
    max_new_tokens: int,
    device: str,
    chunk_length_s: float = 30.0,
    stride_length_s: float | None = None,
):
    """
    Transcribe long audio using a sliding-window approach.

    Following the chunking scheme of the HuggingFace Transformers ASR pipeline:
    - Audio is split into (optionally overlapping) chunks
    - Each chunk is transcribed independently
    - Transcriptions are concatenated. With the default stride of 0 the chunks
      do not overlap; a non-zero stride gives each chunk extra context at its
      edges, but overlapping regions may be transcribed more than once, since
      no token-level merging is performed.

    Args:
        checkpoint_dir: Path to model checkpoint
        audio_path: Path to audio file
        tokenizer_path: Path to tokenizer (defaults to checkpoint_dir)
        max_new_tokens: Maximum tokens to generate per chunk
        device: Device to run inference on
        chunk_length_s: Length of each chunk in seconds (default: 30)
        stride_length_s: Overlap on each side in seconds (default: 0, no overlap)
    """
    # Default stride: no overlap, to avoid duplicate transcriptions
    if stride_length_s is None:
        stride_length_s = 0.0

    tokenizer_source = tokenizer_path if tokenizer_path else checkpoint_dir
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_source)
    feature_extractor = WhisperFeatureExtractor(**WHISPER_FEAT_CFG)

    config = AutoConfig.from_pretrained(checkpoint_dir, trust_remote_code=True)
    model = AutoModelForCausalLM.from_pretrained(
        checkpoint_dir,
        config=config,
        torch_dtype=torch.bfloat16,
        trust_remote_code=True,
    ).to(device)
    model.eval()

    # Load audio
    audio_path = Path(audio_path)
    wav, sr = torchaudio.load(str(audio_path))
    wav = wav[:1, :]  # mono
    if sr != feature_extractor.sampling_rate:
        wav = torchaudio.transforms.Resample(sr, feature_extractor.sampling_rate)(wav)

    # Convert chunk parameters from seconds to samples
    sampling_rate = feature_extractor.sampling_rate
    chunk_len = int(chunk_length_s * sampling_rate)
    stride_left = int(stride_length_s * sampling_rate)
    stride_right = int(stride_length_s * sampling_rate)

    audio_duration = wav.shape[1] / sampling_rate
    print(f"Audio duration: {audio_duration:.1f}s")
    print(f"Chunk length: {chunk_length_s}s, Stride: {stride_length_s}s")
    print(f"Step size: {chunk_length_s - 2 * stride_length_s}s")
    print("----------")

    transcripts = []
    chunk_idx = 0

    for chunk_data in chunk_iter(wav, chunk_len, stride_left, stride_right):
        chunk = chunk_data["chunk"]
        stride_info = chunk_data["stride"]
        is_last = chunk_data["is_last"]
        start_sample = chunk_data["start_sample"]

        start_time = start_sample / sampling_rate
        chunk_duration = chunk.shape[1] / sampling_rate

        print(f"Processing chunk {chunk_idx + 1}: {start_time:.1f}s - {start_time + chunk_duration:.1f}s")

        # Build prompt for this chunk
        batch = build_single_chunk_prompt(
            chunk,
            tokenizer,
            feature_extractor,
            merge_factor=config.merge_factor,
        )

        model_inputs, prompt_len = prepare_inputs(batch, device)

        with torch.inference_mode():
            generated = model.generate(
                **model_inputs,
                max_new_tokens=max_new_tokens,
                do_sample=False,
            )

        transcript_ids = generated[0, prompt_len:].cpu().tolist()
        transcript = tokenizer.decode(transcript_ids, skip_special_tokens=True).strip()

        if transcript:
            transcripts.append(transcript)
            print(f" -> {transcript[:80]}{'...' if len(transcript) > 80 else ''}")

        chunk_idx += 1

    # Combine all transcriptions
    full_transcript = " ".join(transcripts)
    print("----------")
    print("Full transcription:")
    print(full_transcript or "[Empty transcription]")

    return full_transcript
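
# Hypothetical direct invocation (paths and values are illustrative):
#
#   transcribe_sliding_window(
#       checkpoint_dir=Path("./checkpoint"),
#       audio_path=Path("./long_recording.wav"),
#       tokenizer_path=None,
#       max_new_tokens=256,
#       device="cuda",
#       chunk_length_s=30.0,
#       stride_length_s=5.0,  # 5 s of context each side -> 20 s step between chunks
#   )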


def main():
    parser = argparse.ArgumentParser(description="Minimal ASR transcription demo.")
    parser.add_argument(
@@ -172,15 +370,43 @@ def main(
    parser.add_argument(
        "--device", type=str, default="cuda" if torch.cuda.is_available() else "cpu"
    )
    parser.add_argument(
        "--sliding_window",
        action="store_true",
        help="Use sliding window for long audio transcription.",
    )
    parser.add_argument(
        "--chunk_length_s",
        type=float,
        default=30.0,
        help="Chunk length in seconds for sliding window mode (default: 30).",
    )
    parser.add_argument(
        "--stride_length_s",
        type=float,
        default=None,
        help="Stride/overlap length in seconds (default: 0, no overlap).",
    )
    args = parser.parse_args()

    transcribe(
        checkpoint_dir=Path(args.checkpoint_dir),
        audio_path=Path(args.audio),
        tokenizer_path=args.tokenizer_path,
        max_new_tokens=args.max_new_tokens,
        device=args.device,
    )
    if args.sliding_window:
        transcribe_sliding_window(
            checkpoint_dir=Path(args.checkpoint_dir),
            audio_path=Path(args.audio),
            tokenizer_path=args.tokenizer_path,
            max_new_tokens=args.max_new_tokens,
            device=args.device,
            chunk_length_s=args.chunk_length_s,
            stride_length_s=args.stride_length_s,
        )
    else:
        transcribe(
            checkpoint_dir=Path(args.checkpoint_dir),
            audio_path=Path(args.audio),
            tokenizer_path=args.tokenizer_path,
            max_new_tokens=args.max_new_tokens,
            device=args.device,
        )
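
# Example command lines (paths are illustrative):
#
#   python inference.py --checkpoint_dir ./ckpt --audio short_clip.wav
#   python inference.py --checkpoint_dir ./ckpt --audio long_talk.wav \
#       --sliding_window --chunk_length_s 30 --stride_length_s 5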


if __name__ == "__main__":