@@ -74,13 +74,16 @@ def decode_batch(tokens: List[List[int]]) -> str:

        return text

-    def encode(self, features: np.ndarray) -> ctranslate2.StorageView:
+    def encode(self, features: Union[np.ndarray, torch.Tensor]) -> ctranslate2.StorageView:
        # When the model is running on multiple GPUs, the encoder output should be moved
        # to the CPU since we don't know which GPU will handle the next job.
        to_cpu = self.model.device == "cuda" and len(self.model.device_index) > 1
        # unsqueeze if batch size = 1
        if len(features.shape) == 2:
-            features = np.expand_dims(features, 0)
+            if isinstance(features, np.ndarray):
+                features = np.expand_dims(features, 0)
+            else:
+                features = features.unsqueeze(0)
        features = faster_whisper.transcribe.get_ctranslate2_storage(features)

        return self.model.encode(features, to_cpu=to_cpu)
@@ -171,19 +174,22 @@ def stack(items):
        return final_iterator

    def transcribe(
-        self, audio: Union[str, np.ndarray], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress=False, combined_progress=False
+        self, audio: Union[str, np.ndarray, torch.Tensor], batch_size=None, num_workers=0, language=None, task=None, chunk_size=30, print_progress=False, combined_progress=False
    ) -> TranscriptionResult:
        if isinstance(audio, str):
            audio = load_audio(audio)

+        if isinstance(audio, np.ndarray):
+            audio = torch.from_numpy(audio)
+
        def data(audio, segments):
            for seg in segments:
                f1 = int(seg['start'] * SAMPLE_RATE)
                f2 = int(seg['end'] * SAMPLE_RATE)
                # print(f2-f1)
                yield {'inputs': audio[f1:f2]}

-        vad_segments = self.vad_model({"waveform": torch.from_numpy(audio).unsqueeze(0), "sample_rate": SAMPLE_RATE})
+        vad_segments = self.vad_model({"waveform": audio.unsqueeze(0), "sample_rate": SAMPLE_RATE})
        vad_segments = merge_chunks(
            vad_segments,
            chunk_size,
@@ -203,7 +209,7 @@ def data(audio, segments):
            self.tokenizer = faster_whisper.tokenizer.Tokenizer(self.model.hf_tokenizer,
                                                                self.model.model.is_multilingual, task=task,
                                                                language=language)
-
+        
        if self.suppress_numerals:
            previous_suppress_tokens = self.options.suppress_tokens
            numeral_symbol_tokens = find_numeral_symbol_tokens(self.tokenizer)
@@ -242,7 +248,7 @@ def data(audio, segments):
        return {"segments": segments, "language": language}


-    def detect_language(self, audio: np.ndarray):
+    def detect_language(self, audio: Union[np.ndarray, torch.Tensor]):
        if audio.shape[0] < N_SAMPLES:
            print("Warning: audio is shorter than 30s, language detection may be inaccurate.")
        model_n_mels = self.model.feat_kwargs.get("feature_size")
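
A minimal usage sketch of what the widened Union[np.ndarray, torch.Tensor] signatures allow; the whisperx.load_model arguments, file name, and batch size below are illustrative and not part of this diff:

    import torch
    import whisperx

    # Build the batched pipeline (arguments are placeholders, adjust to your setup).
    model = whisperx.load_model("large-v2", device="cuda", compute_type="float16")

    # Previously transcribe() expected a file path or a NumPy array.
    audio_np = whisperx.load_audio("sample.wav")        # np.ndarray, 16 kHz mono

    # With this change a torch.Tensor is accepted too; transcribe() converts NumPy
    # input to a tensor internally, so both inputs end up on the same code path.
    audio_pt = torch.from_numpy(audio_np)

    result = model.transcribe(audio_pt, batch_size=16)  # tensor input, no manual round-trip
    print(result["language"], result["segments"][0]["text"])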