diff --git a/.github/workflows/gigaam.yml b/.github/workflows/gigaam.yml
index 79d88a1..36b166f 100644
--- a/.github/workflows/gigaam.yml
+++ b/.github/workflows/gigaam.yml
@@ -84,6 +84,10 @@ jobs:
         run: |
           pytest -v tests/test_onnx.py --tb=short

+      - name: Run timestamps tests
+        run: |
+          pytest -v tests/test_timestamps.py --tb=short
+
       - name: Run all tests with coverage
         if: matrix.python-version == '3.10'
         env:
@@ -126,7 +130,7 @@ jobs:
       - name: Install linters
         run: |
           pip install --upgrade pip
-          pip install flake8 black isort mypy types-requests
+          pip install .[lint]

       - name: Check code formatting with black
         run: |
diff --git a/README.md b/README.md
index 509abd6..358d287 100644
--- a/README.md
+++ b/README.md
@@ -37,7 +37,7 @@ git clone https://github.com/salute-developers/GigaAM.git
 cd GigaAM

 # Install the package requirements
-pip install -e .
+pip install -e .[torch]

 # (optionally) Verify the installation:
 pip install -e ".[tests]"
@@ -108,13 +108,17 @@ model = gigaam.load_model(model_name)
 transcription = model.transcribe(audio_path)
 print(transcription)

+# ASR with word-level timestamps
+result = model.transcribe(audio_path, word_timestamps=True)
+for word in result.words:
+    print(f" [{word.start:.2f} - {word.end:.2f}] {word.text}")
+
 # and long-form ASR
 import os
 os.environ["HF_TOKEN"] = "<YOUR_HF_TOKEN>"
-utterances = model.transcribe_longform(long_audio_path)
-for utt in utterances:
-    transcription, (start, end) = utt["transcription"], utt["boundaries"]
-    print(f"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}")
+result = model.transcribe_longform(long_audio_path)
+for segment in result:
+    print(f"[{gigaam.format_time(segment.start)} - {gigaam.format_time(segment.end)}]: {segment.text}")

 # Emotion recognition
 model = gigaam.load_model("emo")
diff --git a/README_ru.md b/README_ru.md
index 3d458af..a653fa6 100644
--- a/README_ru.md
+++ b/README_ru.md
@@ -36,7 +36,7 @@ git clone https://github.com/salute-developers/GigaAM.git
 cd GigaAM

 # Установить зависимости
-pip install -e .
+pip install -e .[torch] # (опционально) Проверить установку: pip install -e ".[tests]" @@ -107,13 +107,17 @@ model = gigaam.load_model(model_name) transcription = model.transcribe(audio_path) print(transcription) +# Распознавание речи с таймстемпами на уровне слов +result = model.transcribe(audio_path, word_timestamps=True) +for word in result.words: + print(f" [{word.start:.2f} - {word.end:.2f}] {word.text}") + # Распознавание на длинном аудио import os os.environ["HF_TOKEN"] = "" -utterances = model.transcribe_longform(long_audio_path) -for utt in utterances: - transcription, (start, end) = utt["transcription"], utt["boundaries"] - print(f"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}") +result = model.transcribe_longform(long_audio_path) +for segment in result: + print(f"[{gigaam.format_time(segment.start)} - {gigaam.format_time(segment.end)}]: {segment.text}") # Распознавание эмоций model = gigaam.load_model("emo") diff --git a/colab_example.ipynb b/colab_example.ipynb index 03d4b46..b567e36 100644 --- a/colab_example.ipynb +++ b/colab_example.ipynb @@ -150,26 +150,7 @@ "id": "D8vK2SEWVq0o", "outputId": "7f87497e-de4e-44a1-8394-6f735c292a34" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Embeds: tensor([[[-0.2824, 0.3641, 0.4503, ..., -0.4728, -0.4027, -0.2415],\n", - " [ 0.1607, -0.4995, -0.0565, ..., -0.6242, -0.2318, -0.2053],\n", - " [-1.1857, -1.0032, -0.6091, ..., -0.5142, -0.3736, -0.2652],\n", - " ...,\n", - " [ 0.0187, -0.3757, -0.8965, ..., 0.1718, 0.0567, 0.1307],\n", - " [ 0.2691, -0.0672, -0.5010, ..., -1.4433, -1.4832, -1.4515],\n", - " [-1.5648, -1.6692, -1.2828, ..., 0.5110, 0.4826, 0.0133]]],\n", - " device='cuda:0', grad_fn=)\n", - "\n", - "Transcription: Ничьих не требуя похвал, Счастлив уж я надеждой сладкой, Что дева с трепетом любви Посмотрит, может быть, украдкой На песни грешные мои. У лукоморья дуб зелёный.\n", - "\n", - "Emotions: angry: 0.000, sad: 0.002, neutral: 0.923, positive: 0.075\n" - ] - } - ], + "outputs": [], "source": [ "# Load test audio\n", "audio_path = gigaam.utils.download_short_audio()\n", @@ -186,6 +167,12 @@ "transcription = model.transcribe(audio_path)\n", "print(\"\\nTranscription:\", transcription)\n", "\n", + "# ASR with word-level timestamps\n", + "result = model.transcribe(audio_path, word_timestamps=True)\n", + "print(\"\\nWord timestamps:\")\n", + "for word in result.words:\n", + " print(f\" [{word.start:.2f} - {word.end:.2f}] {word.text}\")\n", + "\n", "# Emotion recognition\n", "model = gigaam.load_model(\"emo\")\n", "emotion2prob = model.get_probs(audio_path)\n", @@ -299,7 +286,7 @@ }, "outputs": [], "source": [ - "! pip install -e .[longform]" + "! pip install pyannote.audio==4.0" ] }, { @@ -374,19 +361,7 @@ "id": "ZFlpV4VXapk7", "outputId": "17a36be8-d1db-4ce8-cf33-5812b93c5f14" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[00:00:00 - 00:16:80]: Вечерня отошла давно, Но в кельях тихо и темно; Уже и сам игумен строгий Свои молитвы прекратил И кости ветхие склонил, Перекрестясь на одр убогий. Кругом и сон, и тишина; Но церкви дверь отворена.\n", - "[00:17:07 - 00:32:54]: Трепещет луч лампады, И тускло озаряет он И тёмную живопись икон, и возглащённые оклады. И раздаётся в тишине То тяжкий вздох, то шёпот важный, И мрачно дремлет в тишине старинный свод.\n", - "[00:32:95 - 00:49:30]: Глухой и влажный Стоят за клиросом чернец и грешник, Неподвижны оба. И шёпот их — Как глаз из гроба, И грешник бледен, как мертвец — Монах. 
Несчастный! Полно, перестань!\n", - "[00:49:81 - 01:05:65]: Ужасна исповедь злодея, Заплачена тобою дань Тому, Кто в злобе пламенея Лукавого грешника блюдёт И к вечной гибели ведёт. Смирись, опомнись. Время, время. Раскаянье, покров\n", - "[01:05:94 - 01:10:88]: Я разрешу тебя, грехов сложи мучительное бремя.\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "import warnings\n", @@ -399,10 +374,9 @@ "long_audio_path = gigaam.utils.download_long_audio()\n", "model = gigaam.load_model(\"v3_e2e_rnnt\")\n", "\n", - "utterances = model.transcribe_longform(long_audio_path)\n", - "for utt in utterances:\n", - " transcription, (start, end) = utt[\"transcription\"], utt[\"boundaries\"]\n", - " print(f\"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {transcription}\")" + "result = model.transcribe_longform(long_audio_path)\n", + "for segment in result:\n", + " print(f\"[{gigaam.format_time(segment.start)} - {gigaam.format_time(segment.end)}]: {segment.text}\")" ] }, { @@ -433,15 +407,7 @@ "id": "fmwFIE9qaf7y", "outputId": "9d85b25e-0346-4306-cf3c-131a6a63af1f" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "['Ничьих не требуя похвал, Счастлив уж я надеждой сладкой, Что дева с трепетом любви Посмотрит, может быть, украдкой На песни грешные мои. У лукоморья дуб зелёный.', 'Ничьих не требуя похвал, Счастлив уж я надеждой сладкой, Что дева с трепетом любви Посмотрит, может быть, украдкой На песни грешные мои. У лукоморья дуб зелёный.']\n" - ] - } - ], + "outputs": [], "source": [ "import librosa\n", "import torch\n", @@ -459,7 +425,9 @@ " encoded, encoded_len = model(\n", " wav_tns.to(model._device).to(model._dtype), lengths.to(model._device)\n", " )\n", - " print(model.decoding.decode(model.head, encoded, encoded_len))\n", + " results = model.decoding.decode(model.head, encoded, encoded_len)\n", + " for token_ids, _ in results:\n", + " print(model.decoding.tokenizer.decode(token_ids))\n", "\n", "# outputs expected to be equal" ] @@ -530,19 +498,7 @@ "id": "xGeCADORttRY", "outputId": "2cff3399-663c-497a-d41a-2d028a3de65c" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "[00:00:00 - 00:16:80]: Вечерня отошла давно, Но в кельях тихо и темно; Уже и сам игумен строгий Свои молитвы прекратил И кости ветхие склонил, Перекрестясь на одр убогий. Кругом и сон, и тишина; Но церкви дверь отворена.\n", - "[00:17:07 - 00:32:54]: Трепещет луч лампады, И тускло озаряет он И тёмную живопись икон, и возглащённые оклады. И раздаётся в тишине То тяжкий вздох, то шёпот важный, И мрачно дремлет в тишине старинный свод.\n", - "[00:32:95 - 00:49:30]: Глухой и влажный Стоят за клиросом чернец и грешник, Неподвижны оба. И шёпот их — Как глаз из гроба, И грешник бледен, как мертвец — Монах. Несчастный! Полно, перестань!\n", - "[00:49:81 - 01:05:65]: Ужасна исповедь злодея, Заплачена тобою дань Тому, Кто в злобе пламенея Лукавого грешника блюдёт И к вечной гибели ведёт. Смирись, опомнись. Время, время. 
Раскаянье, покров\n", - "[01:05:94 - 01:10:88]: Я разрешу тебя, грехов сложи мучительное бремя.\n" - ] - } - ], + "outputs": [], "source": [ "import torch\n", "\n", @@ -557,7 +513,8 @@ " wav_tns, lengths = wav_tns.to(model._device).to(model._dtype), lengths.to(model._device)\n", " with torch.no_grad():\n", " encoded, encoded_len = model(wav_tns, lengths)\n", - " pred_texts.extend(model.decoding.decode(model.head, encoded, encoded_len))\n", + " results = model.decoding.decode(model.head, encoded, encoded_len)\n", + " pred_texts.extend(model.decoding.tokenizer.decode(ids) for ids, _ in results)\n", "\n", "for (start, end), text in zip(boundaries, pred_texts):\n", " print(f\"[{gigaam.format_time(start)} - {gigaam.format_time(end)}]: {text}\")" @@ -810,31 +767,16 @@ "id": "Ek1gKS6_cW-G", "outputId": "39308dfb-f63d-4eb5-d944-30bcf6d7e462" }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Longform transcription:\n", - "\n", - "[0.0000 - 16.8047]: Вечерня отошла давно, Но в кельях тихо и темно; Уже и сам игумен строгий Свои молитвы прекратил И кости ветхие склонил, Перекрестясь на одр убогий. Кругом и сон, и тишина; Но церкви дверь отворена.\n", - "[17.0747 - 32.5491]: Трепещет луч лампады, И тускло озаряет он И тёмную живопись икон, и возглащённые оклады. И раздаётся в тишине То тяжкий вздох, то шёпот важный, И мрачно дремлет в тишине старинный свод.\n", - "[32.9541 - 49.3060]: Глухой и влажный Стоят за клиросом чернец и грешник, Неподвижны оба. И шёпот их — Как глаз из гроба, И грешник бледен, как мертвец — Монах. Несчастный! Полно, перестань!\n", - "[49.8122 - 65.6578]: Ужасна исповедь злодея, Заплачена тобою дань Тому, Кто в злобе пламенея Лукавого грешника блюдёт И к вечной гибели ведёт. Смирись, опомнись. Время, время. Раскаянье, покров\n", - "[65.9447 - 70.8891]: Я разрешу тебя, грехов сложи мучительное бремя.\n" - ] - } - ], + "outputs": [], "source": [ "import os\n", "os.environ[\"HF_TOKEN\"] = \"\"\n", "\n", "model = AutoModel.from_pretrained(repo_name, revision=\"e2e_rnnt\", trust_remote_code=True)\n", - "utterances = model.transcribe_longform(\"long_example.wav\")\n", + "result = model.transcribe_longform(\"long_example.wav\")\n", "print(\"Longform transcription:\\n\")\n", - "for utt in utterances:\n", - " transcription, (start, end) = utt[\"transcription\"], utt[\"boundaries\"]\n", - " print(f\"[{start:.4f} - {end:.4f}]: {transcription}\")" + "for segment in result:\n", + " print(f\"[{segment.start:.4f} - {segment.end:.4f}]: {segment.text}\")" ] }, { diff --git a/gigaam/decoding.py b/gigaam/decoding.py index 6f9d239..87c1b5d 100644 --- a/gigaam/decoding.py +++ b/gigaam/decoding.py @@ -1,4 +1,4 @@ -from typing import List, Optional +from typing import List, Optional, Tuple import torch from sentencepiece import SentencePieceProcessor @@ -35,6 +35,14 @@ def __len__(self): """ return len(self.vocab) if self.charwise else len(self.model) + def id_to_str(self, token_id: int) -> str: + """ + Convert a single token ID to its string representation. 
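+
+        For sentencepiece models this returns the raw piece, so a
+        word-initial piece keeps its "▁" prefix (illustrative: id_to_str(5)
+        might yield "▁при" rather than "при"); frames_to_words relies on
+        that marker to detect word boundaries.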
+ """ + if self.charwise: + return self.vocab[token_id] + return self.model.IdToPiece(token_id) + class CTCGreedyDecoding: """ @@ -46,33 +54,46 @@ def __init__(self, vocabulary: List[str], model_path: Optional[str] = None): self.blank_id = len(self.tokenizer) @torch.inference_mode() - def decode(self, head: CTCHead, encoded: Tensor, lengths: Tensor) -> List[str]: + def decode( + self, + head: "CTCHead", + encoded: Tensor, + lengths: Tensor, + ) -> List[Tuple[List[int], List[int]]]: """ - Decode the output of a CTC model into a list of hypotheses. + CTC greedy decode: returns (token_ids, token_frames) per sample. + Token frames are time indices (0..T-1) where a token is emitted. """ log_probs = head(encoder_output=encoded) assert ( - len(log_probs.shape) == 3 - ), f"Expected log_probs shape {log_probs.shape} == [B, T, C]" - b, _, c = log_probs.shape + log_probs.ndim == 3 + ), f"Expected log_probs [B,T,C], got {tuple(log_probs.shape)}" + B, T, C = log_probs.shape assert ( - c == len(self.tokenizer) + 1 - ), f"Num classes {c} != len(vocab) + 1 {len(self.tokenizer) + 1}" - labels = log_probs.argmax(dim=-1, keepdim=False) + C == len(self.tokenizer) + 1 + ), f"Num classes {C} != len(vocab)+1 {len(self.tokenizer)+1}" + + labels = log_probs.argmax(dim=-1) + device = labels.device + + lengths = lengths.to(device=device).clamp(min=0, max=T) skip_mask = labels != self.blank_id - skip_mask[:, 1:] = torch.logical_and( - skip_mask[:, 1:], labels[:, 1:] != labels[:, :-1] - ) - for i, length in enumerate(lengths): - skip_mask[i, length:] = 0 + skip_mask[:, 1:] &= labels[:, 1:] != labels[:, :-1] + + time = torch.arange(T, device=device)[None, :] + skip_mask &= time < lengths[:, None] + + idx = skip_mask.nonzero(as_tuple=False) + batch_idx = idx[:, 0] + token_frames_flat = idx[:, 1] + token_ids_flat = labels[skip_mask] - pred_texts: List[str] = [] - for i in range(b): - pred_texts.append( - "".join(self.tokenizer.decode(labels[i][skip_mask[i]].cpu().tolist())) - ) - return pred_texts + counts = torch.bincount(batch_idx, minlength=B).cpu().tolist() + ids_splits = token_ids_flat.cpu().split(counts) + fr_splits = token_frames_flat.cpu().split(counts) + + return [(ids.tolist(), fr.tolist()) for ids, fr in zip(ids_splits, fr_splits)] class RNNTGreedyDecoding: @@ -82,46 +103,68 @@ def __init__( model_path: Optional[str] = None, max_symbols_per_step: int = 10, ): - """ - Class for performing greedy decoding of RNN-T outputs. - """ self.tokenizer = Tokenizer(vocabulary, model_path) self.blank_id = len(self.tokenizer) self.max_symbols = max_symbols_per_step - def _greedy_decode(self, head: RNNTHead, x: Tensor, seqlen: Tensor) -> str: + def _greedy_decode( + self, + head: "RNNTHead", + x: Tensor, + seqlen: Tensor, + ) -> Tuple[List[int], List[int]]: """ - Internal helper function for performing greedy decoding on a single sequence. + Greedy decode a single sequence. + Returns (token_ids, token_frames). + Token frames are encoder time indices t where a token is emitted. 
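+
+        Illustrative example (made-up ids): tokens 12 and 7 emitted while
+        processing encoder frames 3 and 8 come back as ([12, 7], [3, 8]).
+        The same frame index may repeat when several tokens are emitted at
+        one t (bounded by max_symbols_per_step).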
""" + T = int(seqlen.item()) if torch.is_tensor(seqlen) else int(seqlen) + hyp: List[int] = [] + token_frames: List[int] = [] dec_state: Optional[Tensor] = None + last_label: Optional[Tensor] = None - for t in range(seqlen): + + last_label_buf = torch.empty((1, 1), device=x.device, dtype=torch.long) + + for t in range(T): f = x[t, :, :].unsqueeze(1) - not_blank = True new_symbols = 0 - while not_blank and new_symbols < self.max_symbols: + + while new_symbols < self.max_symbols: g, hidden = head.decoder.predict(last_label, dec_state) - k = head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item() + k = int(head.joint.joint(f, g)[0, 0, 0, :].argmax(0).item()) + if k == self.blank_id: - not_blank = False - else: - hyp.append(int(k)) - dec_state = hidden - last_label = torch.tensor([[hyp[-1]]]).to(x.device) - new_symbols += 1 + break + + hyp.append(k) + token_frames.append(t) + + dec_state = hidden + last_label_buf.fill_(k) + last_label = last_label_buf + new_symbols += 1 - return self.tokenizer.decode(hyp) + return hyp, token_frames @torch.inference_mode() - def decode(self, head: RNNTHead, encoded: Tensor, enc_len: Tensor) -> List[str]: + def decode( + self, + head: "RNNTHead", + encoded: Tensor, + enc_len: Tensor, + ) -> List[Tuple[List[int], List[int]]]: """ - Decode the output of an RNN-T model into a list of hypotheses. + Decode RNN-T outputs for a batch. + Returns (token_ids, token_frames) per sample. """ - b = encoded.shape[0] - pred_texts = [] + B = encoded.shape[0] encoded = encoded.transpose(1, 2) - for i in range(b): + + results: List[Tuple[List[int], List[int]]] = [] + for i in range(B): inseq = encoded[i, :, :].unsqueeze(1) - pred_texts.append(self._greedy_decode(head, inseq, enc_len[i])) - return pred_texts + results.append(self._greedy_decode(head, inseq, enc_len[i])) + return results diff --git a/gigaam/model.py b/gigaam/model.py index ac38c1d..aefe189 100644 --- a/gigaam/model.py +++ b/gigaam/model.py @@ -1,4 +1,4 @@ -from typing import Dict, List, Tuple, Union +from typing import Dict, List, Optional, Tuple import hydra import omegaconf @@ -6,6 +6,7 @@ from torch import Tensor, nn from .preprocess import SAMPLE_RATE, load_audio +from .types import LongformTranscriptionResult, Segment, TranscriptionResult, Word from .utils import onnx_converter LONGFORM_THRESHOLD = 25 * SAMPLE_RATE @@ -89,17 +90,59 @@ def __init__(self, cfg: omegaconf.DictConfig): self.head = hydra.utils.instantiate(self.cfg.head) self.decoding = hydra.utils.instantiate(self.cfg.decoding) + def _decode( + self, + encoded: Tensor, + encoded_len: Tensor, + audio_length: int, + word_timestamps: bool = False, + ) -> Tuple[str, Optional[List[Word]]]: + """ + Decode encoder output to text with optional word-level timestamps. 
+ + Args: + encoded: Encoder output tensor + encoded_len: Length of encoded sequence + audio_length: Original audio length in samples + word_timestamps: Whether to compute word-level timestamps + + Returns: + Tuple of (text, words) where words is None if word_timestamps=False + """ + token_ids, token_frames = self.decoding.decode(self.head, encoded, encoded_len)[ + 0 + ] + + text = self.decoding.tokenizer.decode(token_ids) + + if not word_timestamps: + return text, None + + from .timestamps_utils import compute_frame_shift, frames_to_words + + frame_shift = compute_frame_shift(audio_length, int(encoded_len[0].item())) + words = frames_to_words( + self.decoding.tokenizer, token_ids, token_frames, frame_shift + ) + return text, words + @torch.inference_mode() - def transcribe(self, wav_file: str) -> str: + def transcribe( + self, wav_file: str, word_timestamps: bool = False + ) -> TranscriptionResult: """ Transcribes a short audio file into text. + Returns TranscriptionResult with optional word-level timestamps. """ wav, length = self.prepare_wav(wav_file) if length.item() > LONGFORM_THRESHOLD: raise ValueError("Too long wav file, use 'transcribe_longform' method.") encoded, encoded_len = self.forward(wav, length) - return self.decoding.decode(self.head, encoded, encoded_len)[0] + text, words = self._decode( + encoded, encoded_len, int(length[0].item()), word_timestamps + ) + return TranscriptionResult(text=text, words=words) def forward_for_export(self, features: Tensor, feature_lengths: Tensor) -> Tensor: """ @@ -147,27 +190,50 @@ def _to_onnx(self, dir_path: str = ".") -> None: @torch.inference_mode() def transcribe_longform( - self, wav_file: str, **kwargs - ) -> List[Dict[str, Union[str, Tuple[float, float]]]]: + self, wav_file: str, word_timestamps: bool = False, **kwargs + ) -> LongformTranscriptionResult: """ Transcribes a long audio file by splitting it into segments and then transcribing each segment. + Returns LongformTranscriptionResult with segments containing optional word-level timestamps. 
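+
+        Usage sketch (segment contents depend on the VAD segmentation):
+
+            result = model.transcribe_longform("long.wav", word_timestamps=True)
+            for seg in result:
+                print(seg.start, seg.end, seg.text)
+                for w in seg.words or []:
+                    print(w.start, w.end, w.text)  # absolute times in seconds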
""" from .vad_utils import segment_audio_file - transcribed_segments = [] segments, boundaries = segment_audio_file( wav_file, SAMPLE_RATE, device=self._device, **kwargs ) + + result_segments: List[Segment] = [] for segment, segment_boundaries in zip(segments, boundaries): wav = segment.to(self._device).unsqueeze(0).to(self._dtype) length = torch.full([1], wav.shape[-1], device=self._device) encoded, encoded_len = self.forward(wav, length) - result = self.decoding.decode(self.head, encoded, encoded_len)[0] - transcribed_segments.append( - {"transcription": result, "boundaries": segment_boundaries} + + seg_start = segment_boundaries[0] + seg_end = segment_boundaries[1] + + text, words = self._decode( + encoded, encoded_len, int(length[0].item()), word_timestamps ) - return transcribed_segments + + if word_timestamps: + # Adjust word timestamps to absolute time positions + adjusted_words = [ + Word( + text=w.text, + start=round(w.start + seg_start, 3), + end=round(w.end + seg_start, 3), + ) + for w in words + ] + result_segments.append( + Segment( + text=text, start=seg_start, end=seg_end, words=adjusted_words + ) + ) + else: + result_segments.append(Segment(text=text, start=seg_start, end=seg_end)) + return LongformTranscriptionResult(segments=result_segments) class GigaAMEmo(GigaAM): diff --git a/gigaam/timestamps_utils.py b/gigaam/timestamps_utils.py new file mode 100644 index 0000000..472c989 --- /dev/null +++ b/gigaam/timestamps_utils.py @@ -0,0 +1,53 @@ +from typing import List + +from .decoding import Tokenizer +from .preprocess import SAMPLE_RATE +from .types import Word + + +def compute_frame_shift(audio_length_samples: int, seq_len: int) -> float: + """Compute frame shift (seconds per encoder frame).""" + return audio_length_samples / SAMPLE_RATE / seq_len + + +def frames_to_words( + tokenizer: Tokenizer, + token_ids: List[int], + token_frames: List[int], + frame_shift: float, +) -> List[Word]: + """ + Convert token-level frame indices to word-level timestamps. + Groups tokens into words at word boundaries (space or sentencepiece '▁' prefix). 
+ """ + words: List[Word] = [] + current_chars: List[str] = [] + current_frames: List[int] = [] + + def commit(): + if not current_chars: + return + text = "".join(current_chars).strip() + if not text: + current_chars.clear() + current_frames.clear() + return + start = current_frames[0] * frame_shift + end = (current_frames[-1] + 1) * frame_shift + words.append(Word(text=text, start=start, end=end)) + current_chars.clear() + current_frames.clear() + + for token_id, frame in zip(token_ids, token_frames): + char = tokenizer.id_to_str(token_id) + if char.startswith("▁"): + commit() + char = char[1:] + elif char == " ": + commit() + continue + current_chars.append(char) + current_frames.append(frame) + + commit() + return words diff --git a/gigaam/types.py b/gigaam/types.py new file mode 100644 index 0000000..d66b0e9 --- /dev/null +++ b/gigaam/types.py @@ -0,0 +1,57 @@ +from dataclasses import dataclass +from typing import List, Optional + + +@dataclass +class Word: + text: str + start: float + end: float + + +@dataclass +class TranscriptionResult: + text: str + words: Optional[List[Word]] = None + + def __str__(self) -> str: + return self.text + + +@dataclass +class Segment: + text: str + start: float + end: float + words: Optional[List[Word]] = None + + +@dataclass +class LongformTranscriptionResult: + segments: List[Segment] + + @property + def words(self) -> List[Word]: + """Flatten all words from all segments.""" + result = [] + for seg in self.segments: + if seg.words: + result.extend(seg.words) + return result + + @property + def has_word_timestamps(self) -> bool: + return bool(self.segments) and self.segments[0].words is not None + + @property + def text(self) -> str: + return " ".join(s.text for s in self.segments) + + def __str__(self) -> str: + return self.text + + def __iter__(self): + return iter(self.segments) + + def __len__(self) -> int: + return len(self.segments) diff --git a/pyproject.toml b/pyproject.toml index 37eaa3e..1cf1912 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -20,12 +20,14 @@ dependencies = [ "onnx==1.19.*", "onnxruntime==1.23.*", "sentencepiece", - "torch>=2.5,<2.9", - "torchaudio>=2.5,<2.9", "tqdm", ] [project.optional-dependencies] +torch = [ + "torch>=2.5,<2.9", + "torchaudio>=2.5,<2.9", +] longform = [ "torch==2.8.*", "torchaudio==2.8.*", @@ -40,9 +42,22 @@ tests = [ "soundfile", "librosa", ] +lint = [ + "black==26.1.0", + "isort==7.0.0", + "flake8", + "mypy", + "types-requests", +] [project.urls] Homepage = "https://github.com/salute-developers/GigaAM/" [tool.setuptools.packages.find] include = ["gigaam"] + +[tool.black] +line-length = 88 + +[tool.isort] +profile = "black" diff --git a/tests/test_batching.py b/tests/test_batching.py index bf41f45..e7efc73 100644 --- a/tests/test_batching.py +++ b/tests/test_batching.py @@ -80,9 +80,7 @@ def compare_outputs(output1, output2, atol=0.03): feat1, feat2 = feat1[:, :, :min_len], feat2[:, :, :min_len] abs_diff = torch.abs(feat1 - feat2).max().item() close = abs_diff < atol - return close, { - "max_absolute_difference": abs_diff, - } + return close, {"max_absolute_difference": abs_diff} @pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) diff --git a/tests/test_loading.py b/tests/test_loading.py index bdea44f..7a99c50 100644 --- a/tests/test_loading.py +++ b/tests/test_loading.py @@ -44,13 +44,13 @@ def run_model_method(model, revision, test_audio): else: result = model.transcribe(test_audio) if "e2e" in revision: - assert ( - _predictions[revision] == result - ), f"Transcription failed 
({revision}): {result}"
+            assert _predictions[revision] == str(
+                result
+            ), f"Transcription failed ({revision}): {str(result)}"
         else:
-            assert (
-                _predictions["asr"] == result
-            ), f"Transcription failed ({revision}): {result}"
+            assert _predictions["asr"] == str(
+                result
+            ), f"Transcription failed ({revision}): {str(result)}"
     logger.info(f"{revision}: Transcription completed")
diff --git a/tests/test_longform.py b/tests/test_longform.py
index 534f57a..835e373 100644
--- a/tests/test_longform.py
+++ b/tests/test_longform.py
@@ -82,8 +82,9 @@ def generate_long_audio(duration=60.0, sr=16000, include_silence=True):
         )
         envelope = signal.windows.tukey(len(segment), alpha=0.1)
         segment = segment * envelope
-        start_idx, end_idx = int(current_time * sr), int(current_time * sr) + len(
-            segment
+        start_idx, end_idx = (
+            int(current_time * sr),
+            int(current_time * sr) + len(segment),
         )
         audio[start_idx:end_idx] = segment
         if include_silence and i < len(segment_durations) - 1:
@@ -152,24 +153,30 @@ def test_segmentation_functionality(duration):
 @pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"])
 def test_transcribe_longform(revision):
     """Test longform transcription for different models"""
+    from gigaam.types import LongformTranscriptionResult, Segment
+
     model = gigaam.load_model(revision)
-    results = model.transcribe_longform(download_long_audio())
+    result = model.transcribe_longform(download_long_audio())
     ref = _predictions[revision]
-    assert isinstance(results, list), "Should return list of segments"
-    assert len(results) == len(ref), "Distinct results len from reference"
-
-    for segment, ref_segment in zip(results, ref):
-        assert "transcription" in segment, "Missing transcription key"
-        assert "boundaries" in segment, "Missing boundaries key"
-        start, end = segment["boundaries"]
+    assert isinstance(
+        result, LongformTranscriptionResult
+    ), "Should return LongformTranscriptionResult"
+    assert len(result.segments) == len(ref), "Segment count differs from reference"
+
+    for segment, ref_segment in zip(result.segments, ref):
+        assert isinstance(segment, Segment), "Should be Segment object"
+        assert hasattr(segment, "text"), "Missing text attribute"
+        assert hasattr(segment, "start"), "Missing start attribute"
+        assert hasattr(segment, "end"), "Missing end attribute"
+        start, end = segment.start, segment.end
         ref_start, ref_end = ref_segment["boundaries"]
         assert (
             abs(start - ref_start) < 0.1 and abs(end - ref_end) < 0.1
         ), f"Segments are not close {start, end} and {ref_start, ref_end}"
         assert (
-            segment["transcription"] == ref_segment["transcription"]
-        ), f"Different transcription: {segment['transcription']} and {ref_segment['transcription']}"
+            segment.text == ref_segment["transcription"]
+        ), f"Different transcription: {segment.text} and {ref_segment['transcription']}"

 @pytest.mark.parametrize("revision", ["v3_ctc"])
@@ -181,13 +188,16 @@ def test_longform_consistency(revision):
             sf.write(f.name, audio, 16000)

             model = gigaam.load_model(revision)
-            results1 = model.transcribe_longform(f.name)
-            results2 = model.transcribe_longform(f.name)
-
-            assert len(results1) == len(results2), "Inconsistent segment count"
-            for seg1, seg2 in zip(results1, results2):
-                assert (
-                    seg1["boundaries"] == seg2["boundaries"]
+            result1 = model.transcribe_longform(f.name)
+            result2 = model.transcribe_longform(f.name)
+
+            assert len(result1.segments) == len(
+                result2.segments
+            ), "Inconsistent segment count"
+            for seg1, seg2 in zip(result1.segments, result2.segments):
+                assert (seg1.start, seg1.end) == 
( + seg2.start, + seg2.end, ), "Inconsistent boundaries" finally: diff --git a/tests/test_onnx.py b/tests/test_onnx.py index 75715ff..948df1e 100644 --- a/tests/test_onnx.py +++ b/tests/test_onnx.py @@ -41,7 +41,7 @@ def test_onnx_converting(revision, test_audio): ), f"{revision}: ONNX emotions probs failed: {pred_probs}" else: - orig_text = model.transcribe(test_audio) + orig_text = model.transcribe(test_audio).text assert orig_text == result, f"{revision}: ONNX transcribe failed: {result}" diff --git a/tests/test_reading.py b/tests/test_reading.py index 257f65b..1179ef0 100644 --- a/tests/test_reading.py +++ b/tests/test_reading.py @@ -19,7 +19,7 @@ def test_audio(): return download_short_audio() -@pytest.mark.parametrize("revision", ["emo", "v2_ctc", "v3_e2e_rnnt"]) +@pytest.mark.parametrize("revision", ["emo"]) def test_librosa_loading(revision, test_audio): """Test the outputs with librosa.load are close to load_audio""" model = gigaam.load_model(revision) @@ -29,29 +29,15 @@ def test_librosa_loading(revision, test_audio): encoded, encoded_len = model( wav_tns.unsqueeze(0).to(model._device).to(model._dtype), lengths ) - if "emo" in revision: - orig_probs = model.get_probs(test_audio) - pred_probs = ( - softmax(model.head(encoded.mean(dim=-1)), dim=-1) - .squeeze() - .cpu() - .tolist() - ) - pred_probs = { - model.id2name[i]: pred_probs[i] for i in range(len(model.id2name)) - } - are_close = ( - max(abs(pred_probs[k] - orig_probs[k]) for k in orig_probs) < 1e-3 - ) - assert ( - are_close - ), f"Emotions with librosa failed: {orig_probs} != {pred_probs}" - else: - orig_text = model.transcribe(test_audio) - pred_text = model.decoding.decode(model.head, encoded, encoded_len)[0] - assert ( - orig_text == pred_text - ), f"Transcribe with librosa failed: {orig_text} != {pred_text}" + orig_probs = model.get_probs(test_audio) + pred_probs = ( + softmax(model.head(encoded.mean(dim=-1)), dim=-1).squeeze().cpu().tolist() + ) + pred_probs = { + model.id2name[i]: pred_probs[i] for i in range(len(model.id2name)) + } + are_close = max(abs(pred_probs[k] - orig_probs[k]) for k in orig_probs) < 1e-3 + assert are_close, f"Emotions with librosa failed: {orig_probs} != {pred_probs}" if __name__ == "__main__": diff --git a/tests/test_timestamps.py b/tests/test_timestamps.py new file mode 100644 index 0000000..347e061 --- /dev/null +++ b/tests/test_timestamps.py @@ -0,0 +1,218 @@ +import logging + +import pytest + +import gigaam +from gigaam.utils import download_long_audio, download_short_audio + +logging.basicConfig(level=logging.INFO) +logger = logging.getLogger(__name__) + +_predictions = { + "v3_e2e_rnnt": { + "text": "Ничьих не требуя похвал, Счастлив уж я надеждой сладкой, Что дева с трепетом любви Посмотрит, может быть, украдкой На песни грешные мои. 
У лукоморья дуб зелёный.", # noqa: E501 + "words": [ + {"word": "Ничьих", "start": 0.04, "end": 0.4}, + {"word": "не", "start": 0.52, "end": 0.56}, + {"word": "требуя", "start": 0.64, "end": 0.96}, + {"word": "похвал,", "start": 1.08, "end": 1.6}, + {"word": "Счастлив", "start": 1.72, "end": 2.16}, + {"word": "уж", "start": 2.24, "end": 2.4}, + {"word": "я", "start": 2.48, "end": 2.52}, + {"word": "надеждой", "start": 2.64, "end": 3.12}, + {"word": "сладкой,", "start": 3.16, "end": 3.68}, + {"word": "Что", "start": 3.72, "end": 3.76}, + {"word": "дева", "start": 3.88, "end": 4.08}, + {"word": "с", "start": 4.16, "end": 4.2}, + {"word": "трепетом", "start": 4.24, "end": 4.72}, + {"word": "любви", "start": 4.8, "end": 5.04}, + {"word": "Посмотрит,", "start": 5.32, "end": 6.0}, + {"word": "может", "start": 6.08, "end": 6.12}, + {"word": "быть,", "start": 6.28, "end": 6.48}, + {"word": "украдкой", "start": 6.52, "end": 6.96}, + {"word": "На", "start": 7.16, "end": 7.2}, + {"word": "песни", "start": 7.28, "end": 7.56}, + {"word": "грешные", "start": 7.68, "end": 8.08}, + {"word": "мои.", "start": 8.24, "end": 8.72}, + {"word": "У", "start": 9.2, "end": 9.24}, + {"word": "лукоморья", "start": 9.36, "end": 10.0}, + {"word": "дуб", "start": 10.12, "end": 10.36}, + {"word": "зелёный.", "start": 10.48, "end": 11.08}, + ], + }, + "v3_ctc": { + "text": "ничьих не требуя похвал счастлив уж я надеждой сладкой что дева с трепетом любви посмотрит может быть украдкой на песни грешные мои у лукоморья дуб зеленый", # noqa: E501 + "words": [ + {"word": "ничьих", "start": 0.08, "end": 0.44}, + {"word": "не", "start": 0.52, "end": 0.64}, + {"word": "требуя", "start": 0.72, "end": 1.0}, + {"word": "похвал", "start": 1.16, "end": 1.52}, + {"word": "счастлив", "start": 1.76, "end": 2.2}, + {"word": "уж", "start": 2.28, "end": 2.4}, + {"word": "я", "start": 2.48, "end": 2.52}, + {"word": "надеждой", "start": 2.72, "end": 3.12}, + {"word": "сладкой", "start": 3.2, "end": 3.6}, + {"word": "что", "start": 3.68, "end": 3.8}, + {"word": "дева", "start": 3.92, "end": 4.12}, + {"word": "с", "start": 4.2, "end": 4.24}, + {"word": "трепетом", "start": 4.32, "end": 4.72}, + {"word": "любви", "start": 4.84, "end": 5.12}, + {"word": "посмотрит", "start": 5.4, "end": 5.92}, + {"word": "может", "start": 6.04, "end": 6.24}, + {"word": "быть", "start": 6.32, "end": 6.48}, + {"word": "украдкой", "start": 6.6, "end": 7.08}, + {"word": "на", "start": 7.16, "end": 7.24}, + {"word": "песни", "start": 7.36, "end": 7.64}, + {"word": "грешные", "start": 7.72, "end": 8.12}, + {"word": "мои", "start": 8.28, "end": 8.48}, + {"word": "у", "start": 9.28, "end": 9.32}, + {"word": "лукоморья", "start": 9.44, "end": 10.04}, + {"word": "дуб", "start": 10.16, "end": 10.36}, + {"word": "зеленый", "start": 10.48, "end": 10.92}, + ], + }, +} + + +@pytest.fixture(scope="session") +def test_audio(): + return download_short_audio() + + +@pytest.fixture(scope="session") +def long_audio(): + return download_long_audio() + + +@pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) +def test_word_timestamps_predictions(revision, test_audio): + """Test word timestamps match expected values.""" + model = gigaam.load_model(revision, device="cpu") + result = model.transcribe(test_audio, word_timestamps=True) + expected = _predictions[revision] + + assert result.text == expected["text"], f"Text mismatch: {result.text}" + assert len(result.words) == len(expected["words"]), "Word count mismatch" + + for actual, exp in zip(result.words, expected["words"]): + 
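+        # exact text match; start/end compared with a 0.1 s tolerance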
assert actual.text == exp["word"], f"Word mismatch: {actual} vs {exp}" + assert abs(actual.start - exp["start"]) < 0.1, f"Start mismatch: {actual}" + assert abs(actual.end - exp["end"]) < 0.1, f"End mismatch: {actual}" + + logger.info(f"{revision}: Word timestamps predictions matched") + + +@pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) +def test_transcribe_word_timestamps_structure(revision, test_audio): + """Test that word_timestamps=True returns correct structure.""" + from gigaam.types import TranscriptionResult, Word + + model = gigaam.load_model(revision) + result = model.transcribe(test_audio, word_timestamps=True) + + assert isinstance(result, TranscriptionResult), "Should return TranscriptionResult" + assert hasattr(result, "text"), "Result should have 'text' attribute" + assert hasattr(result, "words"), "Result should have 'words' attribute" + assert isinstance(result.text, str), "'text' should be string" + assert isinstance(result.words, list), "'words' should be list" + assert all( + isinstance(w, Word) for w in result.words + ), "All words should be Word objects" + logger.info(f"{revision}: text={result.text[:50]}...") + + +@pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) +def test_transcribe_word_timestamps_values(revision, test_audio): + """Test that word timestamps have valid and ordered values.""" + model = gigaam.load_model(revision) + result = model.transcribe(test_audio, word_timestamps=True) + + words = result.words + assert len(words) > 0, "Should have at least one word" + + prev_end = 0.0 + for w in words: + assert hasattr(w, "text"), "Word entry should have 'text' attribute" + assert hasattr(w, "start"), "Word entry should have 'start' attribute" + assert hasattr(w, "end"), "Word entry should have 'end' attribute" + assert isinstance(w.text, str), "'text' should be string" + assert w.start < w.end, f"start should be < end: {w}" + assert w.start >= prev_end - 0.01, f"Words should be ordered: {w}" + prev_end = w.end + + logger.info(f"{revision}: {len(words)} words, last end={prev_end:.2f}s") + + +@pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) +def test_transcribe_default_returns_string(revision, test_audio): + """Test that default behavior (word_timestamps=False) returns TranscriptionResult with __str__.""" + from gigaam.types import TranscriptionResult + + model = gigaam.load_model(revision) + result = model.transcribe(test_audio) + + assert isinstance(result, TranscriptionResult), "Should return TranscriptionResult" + assert isinstance(str(result), str), "str(result) should return string" + assert len(str(result)) > 0, "Transcription should not be empty" + assert result.words is None, "Should not have words when word_timestamps=False" + + +@pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) +def test_transcribe_longform_word_timestamps(revision, long_audio): + """Test longform transcription with word_timestamps=True.""" + from gigaam.types import LongformTranscriptionResult, Segment, Word + + model = gigaam.load_model(revision) + result = model.transcribe_longform(long_audio, word_timestamps=True) + + assert isinstance( + result, LongformTranscriptionResult + ), "Should return LongformTranscriptionResult" + assert len(result.segments) > 0, "Should have at least one segment" + + # Check segments have words + for seg in result.segments: + assert isinstance(seg, Segment), "Should be Segment object" + assert ( + seg.words is not None + ), "Segment should have words when word_timestamps=True" + assert 
all(isinstance(w, Word) for w in seg.words), "All should be Word objects" + + # Check flattened words + all_words = result.words + assert len(all_words) > 0, "Should have at least one word" + prev_end = 0.0 + for w in all_words: + assert w.start < w.end, f"start should be < end: {w}" + prev_end = w.end + + logger.info( + f"{revision} longform: {len(all_words)} words, last end={prev_end:.2f}s" + ) + + +@pytest.mark.parametrize("revision", ["v3_ctc", "v3_e2e_rnnt"]) +def test_transcribe_longform_default(revision, long_audio): + """Test that default longform behavior returns segments with transcription.""" + from gigaam.types import LongformTranscriptionResult, Segment + + model = gigaam.load_model(revision) + result = model.transcribe_longform(long_audio) + + assert isinstance( + result, LongformTranscriptionResult + ), "Should return LongformTranscriptionResult" + assert len(result.segments) > 0, "Should have segments" + + for seg in result.segments: + assert isinstance(seg, Segment), "Should be Segment object" + assert hasattr(seg, "text"), "Segment should have 'text' attribute" + assert hasattr(seg, "start"), "Segment should have 'start' attribute" + assert hasattr(seg, "end"), "Segment should have 'end' attribute" + assert ( + seg.words is None + ), "Segment should not have words when word_timestamps=False" + + +if __name__ == "__main__": + pytest.main([__file__, "-v"])