Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 7 additions & 6 deletions gigaam/decoding.py
Original file line number Diff line number Diff line change
Expand Up @@ -61,17 +61,18 @@ def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]:
labels = log_probs.argmax(dim=-1, keepdim=False)

skip_mask = labels != self.blank_id
skip_mask[:, 1:] = torch.logical_and(
skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1]
)
skip_mask[:, 1:] &= (labels[:, 1:] != labels[:, :-1])
for i, length in enumerate(lengths):
skip_mask[i, length:] = 0

pred_texts: List[str] = []
for i in range(b):
pred_texts.append(
"".join(self.tokenizer.decode(labels[i][skip_mask[i]].cpu().tolist()))
)
valid_labels = labels[i][skip_mask[i]]
if len(valid_labels) > 0:
label_list = valid_labels.cpu().tolist()
pred_texts.append(self.tokenizer.decode(label_list))
else:
pred_texts.append("")
return pred_texts


Expand Down
19 changes: 18 additions & 1 deletion gigaam/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from torch import Tensor, nn

from .preprocess import SAMPLE_RATE, load_audio
from .preprocess import SAMPLE_RATE, load_audio, load_audio_from_bytes
from .utils import onnx_converter

LONGFORM_THRESHOLD = 25 * SAMPLE_RATE
Expand Down Expand Up @@ -101,6 +101,23 @@ def transcribe(self, wav_file: str) -> str:
encoded, encoded_len = self.forward(wav, length)
return self.decoding.decode(self.head, encoded, encoded_len)[0]

@torch.inference_mode()
def transcribe_bytes(self, audio_bytes: bytes) -> str:
    """
    Transcribes raw PCM16 audio bytes directly into text.
    This is the fastest method for in-memory processing.
    """
    # Decode the raw PCM buffer, then move it to the model's device/dtype
    # and add a leading batch dimension of 1.
    samples = load_audio_from_bytes(audio_bytes)
    batch = samples.to(self._device).to(self._dtype).unsqueeze(0)
    length = torch.full([1], batch.shape[-1], device=self._device)

    # Short-form path only: long audio must go through the chunked method.
    if length.item() > LONGFORM_THRESHOLD:
        raise ValueError("Too long audio, use 'transcribe_longform' method.")

    encoded, encoded_len = self.forward(batch, length)
    texts = self.decoding.decode(self.head, encoded, encoded_len)
    return texts[0]

def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor:
"""
Encoder-decoder forward to save model entirely in onnx format.
Expand Down
13 changes: 13 additions & 0 deletions gigaam/preprocess.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,19 @@ def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
warnings.simplefilter("ignore", category=UserWarning)
return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0

def load_audio_from_bytes(audio_bytes: bytes, device: "torch.device | None" = None) -> Tensor:
    """
    Load audio directly from raw (native-endian) PCM16 bytes without any
    external tools. This is the fastest method for in-memory audio processing.

    Args:
        audio_bytes: Raw signed 16-bit PCM samples, no container or header.
            Must contain an even number of bytes (2 bytes per sample).
        device: Optional target device for the resulting tensor; kept on CPU
            when ``None``.

    Returns:
        A 1-D float32 tensor with samples scaled into ``[-1.0, 1.0)``.

    Raises:
        ValueError: If ``audio_bytes`` has an odd length.
    """
    if len(audio_bytes) % 2 != 0:
        # Fail early with a clear message instead of torch's opaque buffer error.
        raise ValueError("PCM16 audio requires an even number of bytes.")

    if not audio_bytes:
        # torch.frombuffer rejects zero-length buffers; return an empty tensor.
        audio_tensor = torch.empty(0, dtype=torch.float32)
    else:
        with warnings.catch_warnings():
            # frombuffer warns because `bytes` is read-only; the .float() copy
            # below means the shared memory is never mutated, so it is safe.
            warnings.simplefilter("ignore", category=UserWarning)
            audio_tensor = torch.frombuffer(audio_bytes, dtype=torch.int16).float() / 32768.0

    if device is not None:
        audio_tensor = audio_tensor.to(device)

    return audio_tensor

class SpecScaler(nn.Module):
"""
Expand Down