diff --git a/gigaam/decoding.py b/gigaam/decoding.py
index 6f9d239..150a166 100644
--- a/gigaam/decoding.py
+++ b/gigaam/decoding.py
@@ -61,17 +61,20 @@ def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]:
         labels = log_probs.argmax(dim=-1, keepdim=False)
 
         skip_mask = labels != self.blank_id
-        skip_mask[:, 1:] = torch.logical_and(
-            skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1]
-        )
+        # In-place &= on the bool mask is equivalent to the torch.logical_and form.
+        skip_mask[:, 1:] &= labels[:, 1:] != labels[:, :-1]
         for i, length in enumerate(lengths):
             skip_mask[i, length:] = 0
 
         pred_texts: List[str] = []
         for i in range(b):
-            pred_texts.append(
-                "".join(self.tokenizer.decode(labels[i][skip_mask[i]].cpu().tolist()))
-            )
+            valid_labels = labels[i][skip_mask[i]]
+            if len(valid_labels) > 0:
+                label_list = valid_labels.cpu().tolist()
+                # Keep the "".join so a tokenizer that returns token pieces still
+                # yields a single string, matching the previous behavior.
+                pred_texts.append("".join(self.tokenizer.decode(label_list)))
+            else:
+                pred_texts.append("")
 
         return pred_texts
diff --git a/gigaam/model.py b/gigaam/model.py
index ac38c1d..9329e62 100644
--- a/gigaam/model.py
+++ b/gigaam/model.py
@@ -5,7 +5,7 @@
 import torch
 from torch import Tensor, nn
 
-from .preprocess import SAMPLE_RATE, load_audio
+from .preprocess import SAMPLE_RATE, load_audio, load_audio_from_bytes
 from .utils import onnx_converter
 
 LONGFORM_THRESHOLD = 25 * SAMPLE_RATE
@@ -101,6 +101,24 @@ def transcribe(self, wav_file: str) -> str:
         encoded, encoded_len = self.forward(wav, length)
         return self.decoding.decode(self.head, encoded, encoded_len)[0]
 
+    @torch.inference_mode()
+    def transcribe_bytes(self, audio_bytes: bytes) -> str:
+        """
+        Transcribe raw PCM16 audio bytes directly into text.
+
+        Fastest path for in-memory processing; the bytes are assumed to be
+        mono 16-bit PCM at SAMPLE_RATE (not verified here).
+        """
+        wav = load_audio_from_bytes(audio_bytes)
+        wav = wav.to(self._device).to(self._dtype).unsqueeze(0)
+        length = torch.full([1], wav.shape[-1], device=self._device)
+
+        if length.item() > LONGFORM_THRESHOLD:
+            raise ValueError("Too long audio, use 'transcribe_longform' method.")
+
+        encoded, encoded_len = self.forward(wav, length)
+        return self.decoding.decode(self.head, encoded, encoded_len)[0]
+
     def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor:
         """
         Encoder-decoder forward to save model entirely in onnx format.
diff --git a/gigaam/preprocess.py b/gigaam/preprocess.py
index fb6ebde..3be0e8f 100644
--- a/gigaam/preprocess.py
+++ b/gigaam/preprocess.py
@@ -39,6 +39,26 @@ def load_audio(audio_path: str, sample_rate: int = SAMPLE_RATE) -> Tensor:
         warnings.simplefilter("ignore", category=UserWarning)
         return torch.frombuffer(audio, dtype=torch.int16).float() / 32768.0
 
+
+def load_audio_from_bytes(audio_bytes: bytes, device: "torch.device | None" = None) -> Tensor:
+    """
+    Load audio directly from raw 16-bit PCM bytes.
+
+    Avoids any external decoding tools, making it the fastest path for
+    in-memory processing. Samples are normalized to [-1, 1); the optional
+    ``device`` moves the resulting tensor there.
+    """
+    with warnings.catch_warnings():
+        # torch.frombuffer warns about read-only buffers; the tensor is not
+        # mutated in place, so the warning is safe to suppress.
+        warnings.simplefilter("ignore", category=UserWarning)
+        audio_tensor = torch.frombuffer(audio_bytes, dtype=torch.int16).float() / 32768.0
+
+    if device is not None:
+        audio_tensor = audio_tensor.to(device)
+
+    return audio_tensor
+
 
 class SpecScaler(nn.Module):
     """