Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions src-tauri/Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 3 additions & 0 deletions src-tauri/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,6 @@ lto = true
codegen-units = 1
strip = true
panic = "abort"

[patch.crates-io]
transcribe-rs = { path = "../../transcribe-rs" }
8 changes: 8 additions & 0 deletions src-tauri/src/commands/transcription.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,11 @@ pub fn unload_model_manually(
.unload_model()
.map_err(|e| format!("Failed to unload model: {}", e))
}

/// Tauri command: ask the transcription manager to move Whisper back onto the GPU
/// after a prior CPU fallback.
///
/// Returns a user-facing error string on failure so the frontend can surface it.
#[tauri::command]
#[specta::specta]
pub fn retry_whisper_gpu(transcription_manager: State<TranscriptionManager>) -> Result<(), String> {
    match transcription_manager.retry_whisper_gpu() {
        Ok(()) => Ok(()),
        Err(e) => Err(format!("Failed to return Whisper to GPU: {}", e)),
    }
}
3 changes: 3 additions & 0 deletions src-tauri/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ mod signal_handle;
mod tray;
mod tray_i18n;
mod utils;
pub mod whisper_worker;
use specta_typescript::{BigIntExportBehavior, Typescript};
use tauri_specta::{collect_commands, Builder};

Expand Down Expand Up @@ -255,6 +256,7 @@ pub fn run() {
shortcut::change_clipboard_handling_setting,
shortcut::change_post_process_enabled_setting,
shortcut::change_experimental_enabled_setting,
shortcut::change_whisper_compute_mode_setting,
shortcut::change_post_process_base_url_setting,
shortcut::change_post_process_api_key_setting,
shortcut::change_post_process_model_setting,
Expand Down Expand Up @@ -315,6 +317,7 @@ pub fn run() {
commands::transcription::set_model_unload_timeout,
commands::transcription::get_model_load_status,
commands::transcription::unload_model_manually,
commands::transcription::retry_whisper_gpu,
commands::history::get_history_entries,
commands::history::toggle_history_entry_saved,
commands::history::get_audio_file_path,
Expand Down
5 changes: 5 additions & 0 deletions src-tauri/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
#![cfg_attr(not(debug_assertions), windows_subsystem = "windows")]

fn main() {
if std::env::args().any(|arg| arg == "--whisper-worker") {
let _ = handy_app_lib::whisper_worker::run_worker_process();
return;
}

#[cfg(target_os = "linux")]
{
if std::path::Path::new("/dev/dri").exists()
Expand Down
165 changes: 139 additions & 26 deletions src-tauri/src/managers/transcription.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
use crate::audio_toolkit::{apply_custom_words, filter_transcription_output};
use crate::managers::model::{EngineType, ModelManager};
use crate::settings::{get_settings, ModelUnloadTimeout};
use crate::settings::{get_settings, ModelUnloadTimeout, WhisperComputeMode};
use crate::whisper_worker::{
WhisperRuntimeMode, WhisperWorkerClient, WhisperWorkerInferenceParams,
};
use anyhow::Result;
use log::{debug, error, info, warn};
use serde::Serialize;
Expand All @@ -15,7 +18,6 @@ use transcribe_rs::{
parakeet::{
ParakeetEngine, ParakeetInferenceParams, ParakeetModelParams, TimestampGranularity,
},
whisper::{WhisperEngine, WhisperInferenceParams},
},
TranscriptionEngine,
};
Expand All @@ -28,8 +30,16 @@ pub struct ModelStateEvent {
pub error: Option<String>,
}

/// Payload for the `whisper-compute-fallback` event emitted to the frontend
/// when Whisper transcription switches compute modes (e.g. GPU -> CPU).
#[derive(Clone, Debug, Serialize)]
pub struct WhisperComputeFallbackEvent {
    // Compute mode the engine was running in before the fallback (e.g. "gpu").
    pub from: String,
    // Compute mode the engine switched to (e.g. "cpu").
    pub to: String,
    // Human-readable description of the error that triggered the fallback.
    pub reason: String,
    // Whether the frontend may offer a "retry on GPU" action.
    pub can_retry_gpu: bool,
}

enum LoadedEngine {
Whisper(WhisperEngine),
Whisper(WhisperWorkerClient),
Parakeet(ParakeetEngine),
Moonshine(MoonshineEngine),
}
Expand Down Expand Up @@ -142,7 +152,10 @@ impl TranscriptionManager {
let mut engine = self.engine.lock().unwrap();
if let Some(ref mut loaded_engine) = *engine {
match loaded_engine {
LoadedEngine::Whisper(ref mut e) => e.unload_model(),
LoadedEngine::Whisper(ref mut e) => {
let _ = e.unload();
e.terminate();
}
LoadedEngine::Parakeet(ref mut e) => e.unload_model(),
LoadedEngine::Moonshine(ref mut e) => e.unload_model(),
}
Expand Down Expand Up @@ -225,21 +238,35 @@ impl TranscriptionManager {
// Create appropriate engine based on model type
let loaded_engine = match model_info.engine_type {
EngineType::Whisper => {
let mut engine = WhisperEngine::new();
engine.load_model(&model_path).map_err(|e| {
let error_msg = format!("Failed to load whisper model {}: {}", model_id, e);
let _ = self.app_handle.emit(
"model-state-changed",
ModelStateEvent {
event_type: "loading_failed".to_string(),
model_id: Some(model_id.to_string()),
model_name: Some(model_info.name.clone()),
error: Some(error_msg.clone()),
},
);
anyhow::anyhow!(error_msg)
})?;
LoadedEngine::Whisper(engine)
let settings = get_settings(&self.app_handle);
let preferred_mode =
whisper_runtime_mode_from_setting(settings.whisper_compute_mode);
let worker = WhisperWorkerClient::spawn_for_model(&model_path, preferred_mode)
.or_else(|first_err| {
if settings.whisper_compute_mode == WhisperComputeMode::Auto {
WhisperWorkerClient::spawn_for_model(
&model_path,
WhisperRuntimeMode::Cpu,
)
.map_err(|cpu_err| anyhow::anyhow!("{}; {}", first_err, cpu_err))
} else {
Err(first_err)
}
})
.map_err(|e| {
let error_msg = format!("Failed to load whisper model {}: {}", model_id, e);
let _ = self.app_handle.emit(
"model-state-changed",
ModelStateEvent {
event_type: "loading_failed".to_string(),
model_id: Some(model_id.to_string()),
model_name: Some(model_info.name.clone()),
error: Some(error_msg.clone()),
},
);
anyhow::anyhow!(error_msg)
})?;
LoadedEngine::Whisper(worker)
}
EngineType::Parakeet => {
let mut engine = ParakeetEngine::new();
Expand Down Expand Up @@ -341,6 +368,31 @@ impl TranscriptionManager {
current_model.clone()
}

/// Re-spawn the Whisper worker on the GPU (used to undo an earlier CPU fallback).
///
/// The GPU worker is spawned *before* the currently loaded engine is touched,
/// so a failed GPU spawn leaves the existing (CPU) worker fully usable.
///
/// # Errors
/// Fails when no model is active, the active model is not a Whisper model,
/// spawning the GPU worker fails, or no Whisper engine is currently loaded.
pub fn retry_whisper_gpu(&self) -> Result<()> {
    let model_id = self
        .get_current_model()
        .ok_or_else(|| anyhow::anyhow!("No active model"))?;
    let model_info = self
        .model_manager
        .get_model_info(&model_id)
        .ok_or_else(|| anyhow::anyhow!("Model not found: {}", model_id))?;
    if !matches!(model_info.engine_type, EngineType::Whisper) {
        return Err(anyhow::anyhow!("Current model is not Whisper"));
    }

    let model_path = self.model_manager.get_model_path(&model_id)?;
    // Spawn first so the current worker keeps serving requests if this fails.
    let mut gpu_worker =
        WhisperWorkerClient::spawn_for_model(&model_path, WhisperRuntimeMode::Gpu)?;
    let mut engine_guard = self.engine.lock().unwrap();
    if let Some(LoadedEngine::Whisper(existing)) = engine_guard.as_mut() {
        // Shut down the old worker process before swapping in the new one.
        existing.terminate();
        *existing = gpu_worker;
        Ok(())
    } else {
        // The model changed (or was unloaded) while we were spawning; tear the
        // fresh worker down so its child process is not leaked.
        gpu_worker.terminate();
        Err(anyhow::anyhow!("Whisper engine is not loaded"))
    }
}

pub fn transcribe(&self, audio: Vec<f32>) -> Result<String> {
// Update last activity timestamp
self.last_activity.store(
Expand Down Expand Up @@ -389,8 +441,6 @@ impl TranscriptionManager {

match engine {
LoadedEngine::Whisper(whisper_engine) => {
// Normalize language code for Whisper
// Convert zh-Hans and zh-Hant to zh since Whisper uses ISO 639-1 codes
let whisper_language = if settings.selected_language == "auto" {
None
} else {
Expand All @@ -404,15 +454,70 @@ impl TranscriptionManager {
Some(normalized)
};

let params = WhisperInferenceParams {
let params = WhisperWorkerInferenceParams {
language: whisper_language,
translate: settings.translate_to_english,
..Default::default()
};

whisper_engine
.transcribe_samples(audio, Some(params))
.map_err(|e| anyhow::anyhow!("Whisper transcription failed: {}", e))?
match whisper_engine.transcribe(audio.clone(), params.clone()) {
Ok(text) => transcribe_rs::TranscriptionResult {
text,
segments: None,
},
Err(err) => {
let can_fallback = matches!(
settings.whisper_compute_mode,
WhisperComputeMode::Auto | WhisperComputeMode::Gpu
) && whisper_engine.runtime_mode()
== WhisperRuntimeMode::Gpu;

if can_fallback {
let current_model_id = self
.get_current_model()
.ok_or_else(|| anyhow::anyhow!("No active whisper model"))?;
let model_path =
self.model_manager.get_model_path(&current_model_id)?;
whisper_engine.terminate();
let mut cpu_worker = WhisperWorkerClient::spawn_for_model(
&model_path,
WhisperRuntimeMode::Cpu,
)
.map_err(|e| {
anyhow::anyhow!(
"Whisper GPU failed and CPU fallback failed: {}; {}",
err,
e
)
})?;
let retried_text =
cpu_worker.transcribe(audio, params).map_err(|e| {
anyhow::anyhow!(
"Whisper CPU fallback transcription failed: {}",
e
)
})?;
*whisper_engine = cpu_worker;
let _ = self.app_handle.emit(
"whisper-compute-fallback",
WhisperComputeFallbackEvent {
from: "gpu".to_string(),
to: "cpu".to_string(),
reason: err.to_string(),
can_retry_gpu: true,
},
);
transcribe_rs::TranscriptionResult {
text: retried_text,
segments: None,
}
} else {
return Err(anyhow::anyhow!(
"Whisper transcription failed: {}",
err
));
}
}
}
}
LoadedEngine::Parakeet(parakeet_engine) => {
let params = ParakeetInferenceParams {
Expand Down Expand Up @@ -486,3 +591,11 @@ impl Drop for TranscriptionManager {
}
}
}

/// Map the user-facing compute-mode setting onto the worker runtime mode.
///
/// `Auto` starts on the GPU; runtime failures are handled by the CPU fallback
/// path elsewhere in this module.
fn whisper_runtime_mode_from_setting(mode: WhisperComputeMode) -> WhisperRuntimeMode {
    match mode {
        WhisperComputeMode::Auto | WhisperComputeMode::Gpu => WhisperRuntimeMode::Gpu,
        WhisperComputeMode::Cpu => WhisperRuntimeMode::Cpu,
    }
}
20 changes: 20 additions & 0 deletions src-tauri/src/settings.rs
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,14 @@ pub enum KeyboardImplementation {
HandyKeys,
}

/// User-selectable compute backend for Whisper transcription.
/// Serialized in snake_case ("auto" / "gpu" / "cpu") for the frontend.
#[derive(Serialize, Deserialize, Debug, Clone, Copy, PartialEq, Eq, Type)]
#[serde(rename_all = "snake_case")]
pub enum WhisperComputeMode {
    // Try the GPU first and automatically fall back to CPU on failure.
    Auto,
    // Run on the GPU.
    Gpu,
    // Run on the CPU only.
    Cpu,
}

impl Default for KeyboardImplementation {
fn default() -> Self {
// Default to HandyKeys only on macOS where it's well-tested.
Expand All @@ -169,6 +177,15 @@ impl Default for KeyboardImplementation {
}
}

impl Default for WhisperComputeMode {
    /// Default to `Auto` on every platform: GPU speed when available, with an
    /// automatic CPU fallback. (The previous `#[cfg(windows)]` split returned
    /// the same value on both branches and was redundant.)
    fn default() -> Self {
        WhisperComputeMode::Auto
    }
}

impl Default for ModelUnloadTimeout {
fn default() -> Self {
ModelUnloadTimeout::Never
Expand Down Expand Up @@ -315,6 +332,8 @@ pub struct AppSettings {
pub experimental_enabled: bool,
#[serde(default)]
pub keyboard_implementation: KeyboardImplementation,
#[serde(default)]
pub whisper_compute_mode: WhisperComputeMode,
#[serde(default = "default_paste_delay_ms")]
pub paste_delay_ms: u64,
}
Expand Down Expand Up @@ -631,6 +650,7 @@ pub fn get_default_settings() -> AppSettings {
app_language: default_app_language(),
experimental_enabled: false,
keyboard_implementation: KeyboardImplementation::default(),
whisper_compute_mode: WhisperComputeMode::default(),
paste_delay_ms: default_paste_delay_ms(),
}
}
Expand Down
20 changes: 18 additions & 2 deletions src-tauri/src/shortcut/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -21,8 +21,8 @@ use tauri_plugin_autostart::ManagerExt;

use crate::settings::{
self, get_settings, ClipboardHandling, KeyboardImplementation, LLMPrompt, OverlayPosition,
PasteMethod, ShortcutBinding, SoundTheme, APPLE_INTELLIGENCE_DEFAULT_MODEL_ID,
APPLE_INTELLIGENCE_PROVIDER_ID,
PasteMethod, ShortcutBinding, SoundTheme, WhisperComputeMode,
APPLE_INTELLIGENCE_DEFAULT_MODEL_ID, APPLE_INTELLIGENCE_PROVIDER_ID,
};
use crate::tray;

Expand Down Expand Up @@ -730,6 +730,22 @@ pub fn change_experimental_enabled_setting(app: AppHandle, enabled: bool) -> Res
Ok(())
}

/// Tauri command: persist the Whisper compute-mode setting from its string form.
///
/// Accepts "auto", "gpu", or "cpu"; any other value is rejected with a
/// user-facing error string and the settings are left untouched.
#[tauri::command]
#[specta::specta]
pub fn change_whisper_compute_mode_setting(app: AppHandle, mode: String) -> Result<(), String> {
    let parsed = match mode.as_str() {
        "auto" => WhisperComputeMode::Auto,
        "gpu" => WhisperComputeMode::Gpu,
        "cpu" => WhisperComputeMode::Cpu,
        other => return Err(format!("Invalid whisper compute mode '{}'", other)),
    };
    let mut settings = settings::get_settings(&app);
    settings.whisper_compute_mode = parsed;
    settings::write_settings(&app, settings);
    Ok(())
}

#[tauri::command]
#[specta::specta]
pub fn change_post_process_base_url_setting(
Expand Down
Loading