
Commit b2c45c0

Accept torch.Tensor as input
1 parent f2da2f8 commit b2c45c0


2 files changed (+21, -10 lines)


whisperx/asr.py

+10 -5 lines
@@ -74,11 +74,13 @@ def decode_batch(tokens: List[List[int]]) -> str:

         return text

-    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+    def encode(self, features: Union[np.ndarray, torch.Tensor]) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
         # unsqueeze if batch size = 1
+        if isinstance(features, torch.Tensor):
+            features = features.cpu().numpy()
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = faster_whisper.transcribe.get_ctranslate2_storage(features)
@@ -171,19 +173,22 @@ def stack(items):
         return final_iterator

     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self, audio: Union[str, np.ndarray, torch.Tensor], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)

+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+
         def data(audio, segments):
             for seg in segments:
                 f1 = int(seg['start'] * SAMPLE_RATE)
                 f2 = int(seg['end'] * SAMPLE_RATE)
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}

-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model({"waveform": audio.unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
@@ -203,7 +208,7 @@ def data(audio, segments):
             self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
                                                                 self.model.model.is_multilingual, task=task,
                                                                 language=language)
-
+
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
@@ -242,7 +247,7 @@ def data(audio, segments):
         return {"segments": segments, "language": language}


-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: Union[np.ndarray, torch.Tensor]):
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
         model_n_mels = self.model.feat_kwargs.get("feature_size")
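With this change, transcribe(), encode() and detect_language() accept a torch.Tensor in addition to a file path or np.ndarray. A minimal usage sketch, not part of this commit; the model size, device, compute type and file name are illustrative:

import torch
import whisperx

# Illustrative values; any valid model/device/file would do.
model = whisperx.load_model("small", device="cpu", compute_type="int8")

audio_np = whisperx.load_audio("audio.wav")   # np.ndarray at 16 kHz, as before
audio_pt = torch.from_numpy(audio_np)         # a torch.Tensor is now accepted too

result = model.transcribe(audio_pt, batch_size=8)
for seg in result["segments"]:
    print(seg["start"], seg["end"], seg["text"])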

whisperx/diarize.py

+11 -5 lines
@@ -18,11 +18,17 @@ def __init__(
         device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)

-    def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
+    def __call__(self, audio: Union[str, np.ndarray, torch.Tensor], num_speakers=None, min_speakers=None, max_speakers=None):
         if isinstance(audio, str):
             audio = load_audio(audio)
+
+        audio = audio[None, :]
+
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+
         audio_data = {
-            'waveform': torch.from_numpy(audio[None, :]),
+            'waveform': audio,
             'sample_rate': SAMPLE_RATE
         }
         segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)
@@ -47,7 +53,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
             # sum over speakers
             speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
             seg["speaker"] = speaker
-
+
         # assign speaker to words
         if 'words' in seg:
             for word in seg['words']:
@@ -63,8 +69,8 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
                     # sum over speakers
                     speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
                     word["speaker"] = speaker
-
-    return transcript_result
+
+    return transcript_result


 class Segment:
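On the diarization side, DiarizationPipeline.__call__ now takes a torch.Tensor waveform as well as a path or np.ndarray. A short sketch, not part of this commit; the token placeholder, device and file name are illustrative assumptions:

import torch
import whisperx

# "hf_..." is a placeholder Hugging Face token; device and file are illustrative.
diarize_model = whisperx.DiarizationPipeline(use_auth_token="hf_...", device="cpu")

audio = torch.from_numpy(whisperx.load_audio("audio.wav"))   # tensor input now accepted
diarize_df = diarize_model(audio, min_speakers=1, max_speakers=2)
print(diarize_df.head())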
