@@ -74,11 +74,13 @@ def decode_batch(tokens: List[List[int]]) -> str:
 
         return text
 
-    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+    def encode(self, features: Union[np.ndarray, torch.Tensor]) -> ctranslate2.StorageView:
         # When the model is running on multiple GPUs, the encoder output should be moved
         # to the CPU since we don't know which GPU will handle the next job.
         to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
         # unsqueeze if batch size = 1
+        if isinstance(features, torch.Tensor):
+            features = features.cpu().numpy()
         if len(features.shape) == 2:
             features = np.expand_dims(features, 0)
         features = faster_whisper.transcribe.get_ctranslate2_storage(features)
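Read on its own, the new `encode` path normalizes either input type to a batched NumPy array before building CTranslate2 storage. A standalone sketch of that conversion (the helper name is made up for illustration; this is not repository code):

    import numpy as np
    import torch

    def as_batched_numpy(features):
        # CTranslate2 storage is built from host NumPy arrays, so device
        # tensors are moved to the CPU first, mirroring the diff above.
        if isinstance(features, torch.Tensor):
            features = features.cpu().numpy()       # device tensor -> host ndarray
        if features.ndim == 2:
            features = np.expand_dims(features, 0)  # (n_mels, frames) -> (1, n_mels, frames)
        return features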
@@ -171,19 +173,22 @@ def stack(items):
         return final_iterator
 
     def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress=False, combined_progress=False
+        self, audio: Union[str, np.ndarray, torch.Tensor], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress=False, combined_progress=False
     ) -> TranscriptionResult:
         if isinstance(audio, str):
             audio = load_audio(audio)
 
+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+
         def data(audio, segments):
             for seg in segments:
                 f1 = int(seg['start'] * SAMPLE_RATE)
                 f2 = int(seg['end'] * SAMPLE_RATE)
                 # print(f2-f1)
                 yield {'inputs': audio[f1:f2]}
 
-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model({"waveform": audio.unsqueeze(0), "sample_rate": SAMPLE_RATE})
         vad_segments = merge_chunks(
             vad_segments,
             chunk_size,
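The rewritten VAD call no longer converts on the fly: `audio` is guaranteed to be a `torch.Tensor` by that point, so only the channel axis is added. A minimal shape sketch of that input dict (values are placeholders; the assumption is that `vad_model` accepts any `(1, num_samples)` float waveform, as in pyannote-style pipelines):

    import torch

    SAMPLE_RATE = 16000                   # Whisper features assume 16 kHz mono
    audio = torch.zeros(SAMPLE_RATE * 5)  # placeholder: 5 s of silence, shape (80000,)
    vad_input = {
        "waveform": audio.unsqueeze(0),   # (num_samples,) -> (1, num_samples)
        "sample_rate": SAMPLE_RATE,
    }
    print(vad_input["waveform"].shape)    # torch.Size([1, 80000])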
@@ -203,7 +208,7 @@ def data(audio, segments):
             self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
                                                                 self.model.model.is_multilingual, task=task,
                                                                 language=language)
-            
+
         if self.suppress_numerals:
             previous_suppress_tokens = self.options.suppress_tokens
             numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
@@ -242,7 +247,7 @@ def data(audio, segments):
         return {"segments": segments, "language": language}
 
 
-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: Union[np.ndarray, torch.Tensor]):
        if audio.shape[0] < N_SAMPLES:
            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
        model_n_mels = self.model.feat_kwargs.get("feature_size")
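Taken together, the three signature changes let a caller keep audio on the PyTorch side end to end. A hedged usage sketch, assuming this is WhisperX's `asr.py` and the usual `whisperx.load_model` entry point (the file name and options below are placeholders):

    import torch
    import torchaudio
    import whisperx

    model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")

    waveform, sr = torchaudio.load("speech.wav")  # shape: (channels, num_samples)
    waveform = waveform.mean(dim=0)               # downmix to a 1-D mono tensor
    if sr != 16000:
        waveform = torchaudio.functional.resample(waveform, sr, 16000)

    # Before this change a NumPy array was required here; the tensor now passes through.
    result = model.transcribe(waveform, batch_size=8)
    print(result["language"])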