Commit 3be4987

Accept torch.Tensor as input
1 parent f2da2f8 commit 3be4987

2 files changed: +23 -11 lines changed

whisperx/asr.py

+12 -6
@@ -74,13 +74,16 @@ def decode_batch(tokens: List[List[int]]) -> str:
 
         return text
 
-    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+    def encode(self, features: Union[np.ndarray, torch.Tensor]) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
         # unsqueeze if batch size = 1
         if len(features.shape) == 2:
-            features = np.expand_dims(features, 0)
+            if isinstance(features, np.ndarray):
+                features = np.expand_dims(features, 0)
+            else:
+                features = features.unsqueeze(0)
         features = faster_whisper.transcribe.get_ctranslate2_storage(features)
 
         return self.model.encode(features, to_cpu=to_cpu)

@@ -171,19 +174,22 @@ def stack(items):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
+        self, audio: Union[str, np.ndarray, torch.Tensor], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress = False, combined_progress=False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
 
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+
         def data(audio, segments):
             for seg in segments:
                 f1 = int(seg['start'] * SAMPLE_RATE)
                 f2 = int(seg['end'] * SAMPLE_RATE)
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model({"waveform": audio.unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,

@@ -203,7 +209,7 @@ def data(audio, segments):
             self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
                                                                  self.model.model.is_multilingual, task=task,
                                                                  language=language)
-
+
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)

@@ -242,7 +248,7 @@ def data(audio, segments):
         return {"segments": segments, "language": language}
 
 
-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: Union[np.ndarray, torch.Tensor]):
         if audio.shape[0] < N_SAMPLES:
             print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
         model_n_mels = self.model.feat_kwargs.get("feature_size")
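
With this change, transcribe() and detect_language() accept a torch.Tensor waveform as well as a NumPy array. A minimal usage sketch, not part of the commit: the checkpoint name, device, compute_type, batch_size and file path below are placeholder choices, and the tensor is assumed to be a 1-D mono float32 waveform at 16 kHz, i.e. the same layout whisperx.load_audio returns.

import torch
import whisperx

# Placeholder model/device settings; any checkpoint supported by whisperx.load_model works.
model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")

# A 1-D, 16 kHz, float32 waveform. Here it still comes from load_audio, but after this
# commit any torch.Tensor with that layout can be passed to transcribe() directly,
# since np.ndarray inputs are converted with torch.from_numpy and the VAD call uses
# audio.unsqueeze(0) on the tensor itself.
waveform = torch.from_numpy(whisperx.load_audio("audio.wav"))

result = model.transcribe(waveform, batch_size=16)
print(result["language"])
print(result["segments"])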

whisperx/diarize.py

+11 -5
@@ -18,11 +18,17 @@ def __init__(
         device = torch.device(device)
         self.model = Pipeline.from_pretrained(model_name, use_auth_token=use_auth_token).to(device)
 
-    def __call__(self, audio: Union[str, np.ndarray], num_speakers=None, min_speakers=None, max_speakers=None):
+    def __call__(self, audio: Union[str, np.ndarray, torch.Tensor], num_speakers=None, min_speakers=None, max_speakers=None):
         if isinstance(audio, str):
             audio = load_audio(audio)
+
+        audio = audio[None, :]
+
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+
         audio_data = {
-            'waveform': torch.from_numpy(audio[None, :]),
+            'waveform': audio,
             'sample_rate': SAMPLE_RATE
         }
         segments = self.model(audio_data, num_speakers = num_speakers, min_speakers=min_speakers, max_speakers=max_speakers)

@@ -47,7 +53,7 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
             # sum over speakers
             speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
             seg["speaker"] = speaker
-
+
         # assign speaker to words
         if 'words' in seg:
             for word in seg['words']:

@@ -63,8 +69,8 @@ def assign_word_speakers(diarize_df, transcript_result, fill_nearest=False):
                         # sum over speakers
                         speaker = dia_tmp.groupby("speaker")["intersection"].sum().sort_values(ascending=False).index[0]
                         word["speaker"] = speaker
-
-    return transcript_result
+
+    return transcript_result
 
 
 class Segment:
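
The diarization pipeline likewise takes a torch.Tensor in addition to a path or NumPy array. A minimal sketch, assuming the __call__ above belongs to whisperx.DiarizationPipeline and that a Hugging Face access token is available; the token, device and speaker bounds are placeholders.

import torch
import whisperx

# "YOUR_HF_TOKEN" is a placeholder for a real Hugging Face access token.
diarize_model = whisperx.DiarizationPipeline(use_auth_token="YOUR_HF_TOKEN", device="cuda")

# Same 1-D 16 kHz waveform as before; __call__ now adds the batch dimension and
# converts NumPy input itself, so a Tensor can be passed without manual conversion.
waveform = torch.from_numpy(whisperx.load_audio("audio.wav"))

diarize_segments = diarize_model(waveform, min_speakers=1, max_speakers=3)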
