Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
73 changes: 68 additions & 5 deletions gigaam/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

from .preprocess import SAMPLE_RATE, load_audio
from .utils import onnx_converter
from .timestamps_utils import decode_with_alignment_ctc, decode_with_alignment_rnnt, token_to_str, chars_to_words

LONGFORM_THRESHOLD = 25 * SAMPLE_RATE

Expand Down Expand Up @@ -145,27 +146,89 @@ def _to_onnx(self, dir_path: str = ".") -> None:
module=self.head.joint,
)

def _extract_word_timestamps(self, wav: Tensor, length: Tensor) -> Tuple[str, List[Dict[str, Union[str, float]]]]:
    """
    Run the model on a single waveform chunk and return the decoded transcript
    together with word-level time spans (in seconds) aligned to that chunk.

    Args:
        wav: batched waveform tensor, shape (1, num_samples) — assumes batch
            size 1 (only index 0 is read). TODO confirm against callers.
        length: 1-element tensor with the number of valid samples in ``wav``.

    Returns:
        ``(transcript, word_segments)`` where ``word_segments`` is a list of
        dicts with keys "word", "start", "end" (seconds within this chunk).
        Returns ``("", [])`` when the encoder produces no frames, instead of
        dividing by zero while computing the per-frame duration.
    """
    encoded, encoded_len = self.forward(wav, length)
    seq_len = int(encoded_len[0].item())
    if seq_len <= 0:
        # No encoder frames (e.g. empty/near-empty audio): nothing to align.
        # Without this guard the frame_shift division below raises
        # ZeroDivisionError.
        return "", []
    # Seconds covered by one encoder frame: chunk duration / frame count.
    frame_shift = int(length[0].item()) / SAMPLE_RATE / seq_len

    tokenizer = self.decoding.tokenizer
    blank_id = self.decoding.blank_id

    if hasattr(self.head, "decoder"):  # RNNT family (has a prediction network)
        # Re-layout encoder output to time-major (T, 1, D) steps for the joint.
        encoded_rnnt = encoded.transpose(1, 2)
        seq = encoded_rnnt[0, :, :].unsqueeze(1)
        max_symbols = getattr(self.decoding, "max_symbols", 3)
        token_ids, token_frames = decode_with_alignment_rnnt(
            self.head, seq, seq_len, blank_id, max_symbols
        )
    else:  # CTC family: frame-wise posteriors, greedy collapse
        token_ids, token_frames = decode_with_alignment_ctc(
            self.head, encoded, seq_len, blank_id
        )

    transcript = tokenizer.decode(token_ids)
    # Surface form of each token, used to group tokens into words with times.
    chars = [token_to_str(tokenizer, idx) for idx in token_ids]
    word_segments = chars_to_words(chars, token_frames, frame_shift)

    return transcript.strip(), word_segments

@torch.inference_mode()
def transcribe_longform(
    self,
    wav_file: str,
    word_timestamps: bool = False,
    **kwargs,
) -> List[Dict[str, Union[str, Tuple[float, float]]]]:
    """
    Transcribe a long audio file by splitting it into voice-activity segments
    and transcribing each segment independently.

    Args:
        wav_file: path to the audio file to transcribe.
        word_timestamps: if True, return one entry per *word* with absolute
            start/end times; otherwise one entry per VAD segment.
        **kwargs: forwarded to ``vad_utils.segment_audio_file``.

    Returns:
        A list of dicts, each with keys:
            "text": str — transcription of the word/segment,
            "start": float — start time in seconds from the file start,
            "end": float — end time in seconds from the file start.
    """
    from .vad_utils import segment_audio_file

    segments, boundaries = segment_audio_file(
        wav_file, SAMPLE_RATE, device=self._device, **kwargs
    )

    if word_timestamps:
        words_with_timestamps: List[Dict[str, float]] = []
        for segment, segment_boundaries in zip(segments, boundaries):
            # Offset of this segment (seconds from start of full audio);
            # word times from _extract_word_timestamps are chunk-relative.
            segment_offset = segment_boundaries[0]
            wav = segment.to(self._device).unsqueeze(0).to(self._dtype)
            length = torch.full([1], wav.shape[-1], device=self._device)
            _, words = self._extract_word_timestamps(wav, length)
            for word in words:
                words_with_timestamps.append(
                    {
                        "text": word["word"],
                        "start": round(word["start"] + segment_offset, 3),
                        "end": round(word["end"] + segment_offset, 3),
                    }
                )
        return words_with_timestamps

    transcribed_segments: List[Dict[str, Union[str, Tuple[float, float]]]] = []
    for segment, segment_boundaries in zip(segments, boundaries):
        wav = segment.to(self._device).unsqueeze(0).to(self._dtype)
        length = torch.full([1], wav.shape[-1], device=self._device)
        encoded, encoded_len = self.forward(wav, length)
        transcription = self.decoding.decode(self.head, encoded, encoded_len)[0]
        transcribed_segments.append(
            {
                "text": transcription,
                "start": segment_boundaries[0],
                "end": segment_boundaries[1],
            }
        )
    return transcribed_segments

Expand Down
104 changes: 104 additions & 0 deletions gigaam/timestamps_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from typing import Dict, List, Tuple, Union, Sequence
import torch
from torch import Tensor
from .preprocess import SAMPLE_RATE
from .decoder import CTCHead, RNNTHead
from .decoding import Tokenizer


def decode_with_alignment_ctc(
    head: CTCHead,
    encoded_seq: Tensor,
    seq_len: int, blank_id: int) -> Tuple[List[int], List[int]]:
    """
    Greedily decode CTC output, recording for every emitted token the index
    of the encoder frame it was emitted on.

    Returns a pair ``(token_ids, token_frames)`` of equal length.
    """
    log_probs = head(encoder_output=encoded_seq)
    # Best label per frame, truncated to the valid length of the sequence.
    best_labels = log_probs.argmax(dim=-1)[0, :seq_len].cpu().tolist()

    token_ids: List[int] = []
    token_frames: List[int] = []
    prev_label = blank_id
    for frame_idx, label in enumerate(best_labels):
        # CTC collapse rule: emit a non-blank label that starts a new run
        # (i.e. follows a blank or a different label).
        if label != blank_id and (prev_label == blank_id or label != prev_label):
            token_ids.append(int(label))
            token_frames.append(frame_idx)
        prev_label = label
    return token_ids, token_frames


def decode_with_alignment_rnnt(
    head: RNNTHead,
    encoded_seq: Tensor,
    seq_len: int,
    blank_id: int,
    max_symbols: int) -> Tuple[List[int], List[int]]:
    """
    Greedily decode RNNT output, recording for every emitted token the index
    of the encoder frame it was emitted on.

    At each encoder frame up to ``max_symbols`` tokens may be emitted before
    the decoder is forced to advance to the next frame.
    """
    token_ids: List[int] = []
    token_frames: List[int] = []
    state = None        # prediction-network state; advances only on emission
    prev_token = None   # last emitted label as a (1, 1) long tensor

    for frame_idx in range(seq_len):
        frame_enc = encoded_seq[frame_idx, :, :].unsqueeze(1)
        # Emit at most max_symbols tokens on this frame; blank moves on.
        for _ in range(max_symbols):
            dec_out, new_state = head.decoder.predict(prev_token, state)
            logits = head.joint.joint(frame_enc, dec_out)[0, 0, 0, :]
            best = int(logits.argmax().item())
            if best == blank_id:
                break
            token_ids.append(best)
            token_frames.append(frame_idx)
            state = new_state
            prev_token = torch.tensor(
                [[best]], dtype=torch.long, device=encoded_seq.device
            )
    return token_ids, token_frames


def token_to_str(tokenizer: Tokenizer, token_id: int) -> str:
    """Map one token id to its surface string for either tokenizer flavor."""
    # Char-level vocab is a direct table; otherwise ask the SentencePiece model.
    return (
        tokenizer.vocab[token_id]
        if tokenizer.charwise
        else tokenizer.model.IdToPiece(token_id)
    )


def chars_to_words(
    chars: Sequence[str],
    frames: Sequence[int],
    frame_shift: float) -> List[Dict[str, Union[str, float]]]:
    """
    Collapse a sequence of character (or subword) tokens with frame indices
    into contiguous word segments, emitting absolute start/end times in
    seconds (``frame_index * frame_shift``).
    """
    segments: List[Dict[str, Union[str, float]]] = []
    buf_text: List[str] = []
    buf_frames: List[int] = []

    def flush() -> None:
        # Emit the buffered word (if any non-whitespace text accumulated)
        # and reset the buffers either way.
        if not buf_text:
            return
        word = "".join(buf_text).strip()
        if word:
            segments.append(
                {
                    "word": word,
                    "start": buf_frames[0] * frame_shift,
                    # End is the frame *after* the last token's frame.
                    "end": (buf_frames[-1] + 1) * frame_shift,
                }
            )
        buf_text.clear()
        buf_frames.clear()

    for piece, frame in zip(chars, frames):
        if piece == " ":
            # Explicit space token (char-level vocab): pure word boundary.
            flush()
            continue
        if piece.startswith("▁"):
            # SentencePiece word-start marker: close the previous word, then
            # keep the remainder (possibly empty) as the new word's first piece.
            flush()
            piece = piece[1:]
        buf_text.append(piece)
        buf_frames.append(frame)

    flush()
    return segments
12 changes: 11 additions & 1 deletion gigaam/vad_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,9 @@

import torch
from pyannote.audio import Model, Pipeline
from pyannote.audio.core.task import Problem, Resolution, Specifications
from pyannote.audio.pipelines import VoiceActivityDetection
from torch.torch_version import TorchVersion

from .preprocess import load_audio

Expand All @@ -25,7 +27,15 @@ def get_pipeline(device: torch.device) -> Pipeline:
except KeyError as exc:
raise ValueError("HF_TOKEN environment variable is not set") from exc

model = Model.from_pretrained("pyannote/segmentation-3.0", token=hf_token)
with torch.serialization.safe_globals(
[
TorchVersion,
Problem,
Specifications,
Resolution,
]
):
model = Model.from_pretrained("pyannote/segmentation-3.0", token=hf_token)
_PIPELINE = VoiceActivityDetection(segmentation=model)
_PIPELINE.instantiate({"min_duration_on": 0.0, "min_duration_off": 0.0})

Expand Down