diff --git a/apps/desktop/src/i18n/locales/en.json b/apps/desktop/src/i18n/locales/en.json index 54887349..d8c002cf 100644 --- a/apps/desktop/src/i18n/locales/en.json +++ b/apps/desktop/src/i18n/locales/en.json @@ -763,7 +763,23 @@ }, "providers": { "openRouter": "OpenRouter", - "ollama": "Ollama" + "ollama": "Ollama", + "appleIntelligence": "Apple Intelligence" + }, + "appleIntelligence": { + "checking": "Checking...", + "available": "Available", + "unavailable": "Not Available", + "sync": "Sync Model", + "syncing": "Syncing...", + "descriptionAvailable": "On-device language model powered by Apple Intelligence. No API key required.", + "descriptionUnavailable": "Apple Intelligence is not available: {{reason}}", + "descriptionUnavailableGeneric": "Apple Intelligence is not available on this device. Requires macOS 26 or later with Apple Silicon.", + "toast": { + "synced": "Apple Intelligence model synced successfully!", + "notAvailable": "Apple Intelligence is not available on this device.", + "syncFailed": "Failed to sync Apple Intelligence model." + } }, "provider": { "status": { diff --git a/apps/desktop/src/pipeline/providers/formatting/apple-intelligence-formatter.ts b/apps/desktop/src/pipeline/providers/formatting/apple-intelligence-formatter.ts new file mode 100644 index 00000000..e57f0307 --- /dev/null +++ b/apps/desktop/src/pipeline/providers/formatting/apple-intelligence-formatter.ts @@ -0,0 +1,70 @@ +import { FormattingProvider, FormatParams } from "../../core/pipeline-types"; +import { logger } from "../../../main/logger"; +import { constructFormatterPrompt } from "./formatter-prompt"; +import type { NativeBridge } from "../../../services/platform/native-bridge-service"; + +export class AppleIntelligenceFormatter implements FormattingProvider { + readonly name = "apple-intelligence"; + + constructor(private nativeBridge: NativeBridge) {} + + async format(params: FormatParams): Promise { + try { + const { text, context } = params; + // Use amical-notes formatting for on-device models to ensure + // consistent Markdown output with smart structure detection. + const { systemPrompt } = constructFormatterPrompt(context, { + overrideAppType: "amical-notes", + }); + + logger.pipeline.debug("Apple Intelligence formatting request", { + systemPrompt, + userPrompt: text, + }); + + // Wrap user text explicitly so the on-device model treats it as + // text to format rather than a conversational query to respond to. + const userPrompt = `Format the following transcribed text:\n\n${text}`; + + const result = await this.nativeBridge.call( + "generateWithFoundationModel", + { + systemPrompt, + userPrompt, + temperature: 0.1, + }, + 30000, + ); + + logger.pipeline.debug("Apple Intelligence formatting raw response", { + rawResponse: result.content, + }); + + // Extract formatted text from XML tags (same pattern as Ollama/OpenRouter) + const match = result.content.match( + /([\s\S]*?)<\/formatted_text>/, + ); + const formattedText = match ? match[1] : result.content; + + logger.pipeline.debug("Apple Intelligence formatting completed", { + original: text, + formatted: formattedText, + hadXmlTags: !!match, + }); + + // If formatted text is empty, fall back to original text + // On-device models may return empty tags for short inputs + if (!formattedText || formattedText.trim().length === 0) { + logger.pipeline.warn( + "Apple Intelligence returned empty formatted text, using original", + ); + return text; + } + + return formattedText; + } catch (error) { + logger.pipeline.error("Apple Intelligence formatting failed:", error); + return params.text; + } + } +} diff --git a/apps/desktop/src/pipeline/providers/formatting/formatter-prompt.ts b/apps/desktop/src/pipeline/providers/formatting/formatter-prompt.ts index 15a10b66..4ca41ba1 100644 --- a/apps/desktop/src/pipeline/providers/formatting/formatter-prompt.ts +++ b/apps/desktop/src/pipeline/providers/formatting/formatter-prompt.ts @@ -12,7 +12,6 @@ const BASE_INSTRUCTIONS = [ "Maintain the original meaning and tone", "Use the custom vocabulary to correct domain-specific terms", "Remove unnecessary filler words (um, uh, etc.) but keep natural speech patterns", - "If the text is empty, return ", "Return ONLY the formatted text enclosed in tags", "Do not include any commentary, explanations, or text outside the XML tags", ]; @@ -54,10 +53,28 @@ const APPLICATION_TYPE_RULES: Record = { "Use bullet points (-) for unordered lists of items, ideas, or notes", "Use numbered lists (1. 2. 3.) for sequential steps, priorities, or ranked items", "Use headers for distinct topics or sections (## for main sections, ### for subsections)", - "Use bold (**text**) for emphasis on key terms or action items", - "Use code blocks (```) for technical content, commands, or code snippets", + "Do NOT use bold (**text**) or italic (*text*) markup - output plain text with list and header formatting only", "Keep formatting minimal and purposeful - don't over-format simple content", "Preserve natural speech flow while adding structure where it improves clarity", + "Detect implicit structure in speech: when someone lists items (e.g. 'A, B, and C'), format them as a bullet list", + "", + "Examples:", + 'Input: "My favorite foods are ramen, curry, and oyakodon."', + "Output:", + "My favorite foods are:", + "- Ramen", + "- Curry", + "- Oyakodon", + "", + 'Input: "First you need to install Node then run npm install and finally start the server"', + "Output:", + "1. Install Node", + "2. Run `npm install`", + "3. Start the server", + "", + 'Input: "The meeting went well we discussed the budget and the timeline"', + "Output:", + 'The meeting went well. We discussed the budget and the timeline.', ], default: [ "Apply standard formatting for general text", @@ -131,13 +148,17 @@ const URL_PATTERNS: Partial> = { ], }; -export function constructFormatterPrompt(context: FormatParams["context"]): { +export function constructFormatterPrompt( + context: FormatParams["context"], + options?: { overrideAppType?: AppType }, +): { systemPrompt: string; } { const { accessibilityContext, vocabulary } = context; - // Detect application type - const applicationType = detectApplicationType(accessibilityContext); + // Use override if provided, otherwise detect from accessibility context + const applicationType = + options?.overrideAppType ?? detectApplicationType(accessibilityContext); // Build instructions array const instructions = [ diff --git a/apps/desktop/src/renderer/main/pages/settings/ai-models/components/apple-intelligence-provider.tsx b/apps/desktop/src/renderer/main/pages/settings/ai-models/components/apple-intelligence-provider.tsx new file mode 100644 index 00000000..03f5fa05 --- /dev/null +++ b/apps/desktop/src/renderer/main/pages/settings/ai-models/components/apple-intelligence-provider.tsx @@ -0,0 +1,115 @@ +"use client"; +import { useState } from "react"; +import { Button } from "@/components/ui/button"; +import { Badge } from "@/components/ui/badge"; +import { Loader2 } from "lucide-react"; +import { cn } from "@/lib/utils"; +import { api } from "@/trpc/react"; +import { toast } from "sonner"; +import { useTranslation } from "react-i18next"; + +export default function AppleIntelligenceProvider() { + const { t } = useTranslation(); + const [isSyncing, setIsSyncing] = useState(false); + + const isMac = window.electronAPI?.platform === "darwin"; + + const availabilityQuery = + api.models.checkAppleIntelligenceAvailability.useQuery(undefined, { + enabled: isMac, + }); + + const utils = api.useUtils(); + const syncMutation = api.models.syncAppleIntelligenceModel.useMutation({ + onMutate: () => setIsSyncing(true), + onSuccess: (result) => { + setIsSyncing(false); + if (result.available) { + toast.success(t("settings.aiModels.appleIntelligence.toast.synced")); + utils.models.getSyncedProviderModels.invalidate(); + utils.models.getDefaultLanguageModel.invalidate(); + utils.models.getModels.invalidate(); + } else { + toast.error( + t("settings.aiModels.appleIntelligence.toast.notAvailable"), + ); + } + }, + onError: () => { + setIsSyncing(false); + toast.error(t("settings.aiModels.appleIntelligence.toast.syncFailed")); + }, + }); + + if (!isMac) return null; + + const available = availabilityQuery.data?.available ?? false; + const reason = availabilityQuery.data?.reason; + const isLoading = availabilityQuery.isLoading; + + return ( +
+
+
+ + {t("settings.aiModels.providers.appleIntelligence")} + + {isLoading ? ( + + + {t("settings.aiModels.appleIntelligence.checking")} + + ) : ( + + + {available + ? t("settings.aiModels.appleIntelligence.available") + : t("settings.aiModels.appleIntelligence.unavailable")} + + )} +
+ {available && ( + + )} +
+

+ {available + ? t("settings.aiModels.appleIntelligence.descriptionAvailable") + : reason + ? t("settings.aiModels.appleIntelligence.descriptionUnavailable", { + reason, + }) + : t( + "settings.aiModels.appleIntelligence.descriptionUnavailableGeneric", + )} +

+
+ ); +} diff --git a/apps/desktop/src/renderer/main/pages/settings/ai-models/tabs/LanguageTab.tsx b/apps/desktop/src/renderer/main/pages/settings/ai-models/tabs/LanguageTab.tsx index 6f84f300..01f7d29b 100644 --- a/apps/desktop/src/renderer/main/pages/settings/ai-models/tabs/LanguageTab.tsx +++ b/apps/desktop/src/renderer/main/pages/settings/ai-models/tabs/LanguageTab.tsx @@ -4,6 +4,7 @@ import { Accordion } from "@/components/ui/accordion"; import SyncedModelsList from "../components/synced-models-list"; import DefaultModelCombobox from "../components/default-model-combobox"; import ProviderAccordion from "../components/provider-accordion"; +import AppleIntelligenceProvider from "../components/apple-intelligence-provider"; import { useTranslation } from "react-i18next"; export default function LanguageTab() { @@ -17,6 +18,9 @@ export default function LanguageTab() { title={t("settings.aiModels.defaultModels.language")} /> + {/* Apple Intelligence (macOS only, auto-detected) */} + + {/* Providers Accordions */} diff --git a/apps/desktop/src/services/model-service.ts b/apps/desktop/src/services/model-service.ts index 06f88440..7c5848a4 100644 --- a/apps/desktop/src/services/model-service.ts +++ b/apps/desktop/src/services/model-service.ts @@ -31,6 +31,7 @@ import { } from "../types/providers"; import { SettingsService } from "./settings-service"; import { AuthService } from "./auth-service"; +import type { NativeBridge } from "./platform/native-bridge-service"; import { logger } from "../main/logger"; import { getUserAgent } from "../utils/http-client"; @@ -822,6 +823,67 @@ class ModelService extends EventEmitter { } } + // ============================================ + // Apple Intelligence Model Sync + // ============================================ + + /** + * Sync Apple Intelligence model based on Foundation Model availability. + * Registers the model if available, removes it if not. + */ + async syncAppleIntelligenceModel( + nativeBridge: NativeBridge, + ): Promise<{ available: boolean; reason?: string }> { + if (process.platform !== "darwin") { + return { available: false, reason: "notMacOS" }; + } + + try { + const result = await nativeBridge.call( + "checkFoundationModelAvailability", + {}, + ); + + if (result.available) { + await upsertModel({ + id: "apple-intelligence", + provider: "AppleIntelligence", + name: "Apple Intelligence", + type: "language", + description: "On-device Apple Intelligence model", + size: null, + context: null, + checksum: null, + speed: null, + accuracy: null, + localPath: null, + sizeBytes: null, + downloadedAt: null, + originalModel: null, + }); + logger.main.info( + "Apple Intelligence model registered (Foundation Model available)", + ); + } else { + // Remove from DB if previously registered + await removeModel("AppleIntelligence", "apple-intelligence").catch( + () => {}, + ); + logger.main.info( + "Apple Intelligence model not available, removed from DB", + { reason: result.reason }, + ); + } + + return { available: result.available, reason: result.reason }; + } catch (error) { + logger.main.warn("Failed to check Apple Intelligence availability", { + error: error instanceof Error ? error.message : String(error), + }); + return { available: false, reason: "checkFailed" }; + } + } + // ============================================ // Provider Model Methods (OpenRouter, Ollama) // ============================================ diff --git a/apps/desktop/src/services/platform/native-bridge-service.ts b/apps/desktop/src/services/platform/native-bridge-service.ts index 0acd28f1..c41b2bea 100644 --- a/apps/desktop/src/services/platform/native-bridge-service.ts +++ b/apps/desktop/src/services/platform/native-bridge-service.ts @@ -44,6 +44,12 @@ import { RecheckPressedKeysParams, RecheckPressedKeysResult, RecheckPressedKeysResultSchema, + CheckFoundationModelAvailabilityParams, + CheckFoundationModelAvailabilityResult, + CheckFoundationModelAvailabilityResultSchema, + GenerateWithFoundationModelParams, + GenerateWithFoundationModelResult, + GenerateWithFoundationModelResultSchema, AppContext, } from "@amical/types"; @@ -85,6 +91,14 @@ interface RPCMethods { params: RecheckPressedKeysParams; result: RecheckPressedKeysResult; }; + checkFoundationModelAvailability: { + params: CheckFoundationModelAvailabilityParams; + result: CheckFoundationModelAvailabilityResult; + }; + generateWithFoundationModel: { + params: GenerateWithFoundationModelParams; + result: GenerateWithFoundationModelResult; + }; } type PendingRpc = { @@ -108,6 +122,8 @@ const RPC_RESULT_SCHEMAS: Record = { restoreSystemAudio: RestoreSystemAudioResultSchema, setShortcuts: SetShortcutsResultSchema, recheckPressedKeys: RecheckPressedKeysResultSchema, + checkFoundationModelAvailability: CheckFoundationModelAvailabilityResultSchema, + generateWithFoundationModel: GenerateWithFoundationModelResultSchema, }; class NativeBridgeTimeoutError extends Error { diff --git a/apps/desktop/src/services/transcription-service.ts b/apps/desktop/src/services/transcription-service.ts index 17f59718..a1009731 100644 --- a/apps/desktop/src/services/transcription-service.ts +++ b/apps/desktop/src/services/transcription-service.ts @@ -10,6 +10,7 @@ import { WhisperProvider } from "../pipeline/providers/transcription/whisper-pro import { AmicalCloudProvider } from "../pipeline/providers/transcription/amical-cloud-provider"; import { OpenRouterProvider } from "../pipeline/providers/formatting/openrouter-formatter"; import { OllamaFormatter } from "../pipeline/providers/formatting/ollama-formatter"; +import { AppleIntelligenceFormatter } from "../pipeline/providers/formatting/apple-intelligence-formatter"; import { ModelService } from "../services/model-service"; import { SettingsService } from "../services/settings-service"; import { TelemetryService } from "../services/telemetry-service"; @@ -784,6 +785,29 @@ export class TranscriptionService { formattingModel = modelId; } } + } else if (model.provider === "AppleIntelligence") { + if (!this.nativeBridge) { + logger.transcription.warn( + "Formatting skipped: NativeBridge not available for Apple Intelligence", + ); + } else { + logger.transcription.info("Starting formatting", { + provider: model.provider, + model: modelId, + }); + const provider = new AppleIntelligenceFormatter(this.nativeBridge); + const result = await this.formatWithProvider(provider, text, { + style: options.formattingStyle, + vocabulary: options.vocabulary, + accessibilityContext: options.accessibilityContext, + }); + if (result) { + text = result.text; + formattingDuration = result.duration; + formattingUsed = true; + formattingModel = modelId; + } + } } else { logger.transcription.warn( "Formatting skipped: unsupported provider", diff --git a/apps/desktop/src/trpc/routers/models.ts b/apps/desktop/src/trpc/routers/models.ts index e65f5669..f8d74da1 100644 --- a/apps/desktop/src/trpc/routers/models.ts +++ b/apps/desktop/src/trpc/routers/models.ts @@ -459,6 +459,34 @@ export const modelsRouter = createRouter({ return true; }), + // Apple Intelligence + checkAppleIntelligenceAvailability: procedure.query(async ({ ctx }) => { + const nativeBridge = ctx.serviceManager.getService("nativeBridge"); + if (!nativeBridge) { + return { available: false, reason: "nativeBridgeUnavailable" }; + } + try { + return await nativeBridge.call( + "checkFoundationModelAvailability", + {}, + ); + } catch { + return { available: false, reason: "checkFailed" }; + } + }), + + syncAppleIntelligenceModel: procedure.mutation(async ({ ctx }) => { + const modelService = ctx.serviceManager.getService("modelService"); + if (!modelService) { + throw new Error("Model manager service not initialized"); + } + const nativeBridge = ctx.serviceManager.getService("nativeBridge"); + if (!nativeBridge) { + return { available: false, reason: "nativeBridgeUnavailable" }; + } + return await modelService.syncAppleIntelligenceModel(nativeBridge); + }), + removeOllamaProvider: procedure.mutation(async ({ ctx }) => { const modelService = ctx.serviceManager.getService("modelService"); if (!modelService) { diff --git a/apps/desktop/tests/pipeline/apple-intelligence-formatter.test.ts b/apps/desktop/tests/pipeline/apple-intelligence-formatter.test.ts new file mode 100644 index 00000000..4fae783b --- /dev/null +++ b/apps/desktop/tests/pipeline/apple-intelligence-formatter.test.ts @@ -0,0 +1,171 @@ +import { describe, it, expect, vi, beforeEach } from "vitest"; +import { AppleIntelligenceFormatter } from "../../src/pipeline/providers/formatting/apple-intelligence-formatter"; + +// Mock the logger +vi.mock("../../src/main/logger", () => ({ + logger: { + pipeline: { + debug: vi.fn(), + info: vi.fn(), + error: vi.fn(), + warn: vi.fn(), + }, + }, +})); + +describe("AppleIntelligenceFormatter", () => { + let mockNativeBridge: { + call: ReturnType; + isHelperRunning: ReturnType; + }; + + beforeEach(() => { + mockNativeBridge = { + call: vi.fn(), + isHelperRunning: vi.fn(() => true), + }; + }); + + it("should have name 'apple-intelligence'", () => { + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + expect(formatter.name).toBe("apple-intelligence"); + }); + + describe("format", () => { + it("should call generateWithFoundationModel via NativeBridge", async () => { + mockNativeBridge.call.mockResolvedValue({ + content: "Hello world", + }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + const result = await formatter.format({ + text: "hello world", + context: {}, + }); + + expect(result).toBe("Hello world"); + expect(mockNativeBridge.call).toHaveBeenCalledWith( + "generateWithFoundationModel", + expect.objectContaining({ + userPrompt: expect.stringContaining("hello world"), + }), + expect.any(Number), + ); + }); + + it("should extract text from tags", async () => { + mockNativeBridge.call.mockResolvedValue({ + content: + "Some preamble Formatted output trailing", + }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + const result = await formatter.format({ text: "test", context: {} }); + expect(result).toBe("Formatted output"); + }); + + it("should return raw content when no tags present", async () => { + mockNativeBridge.call.mockResolvedValue({ content: "Raw response" }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + const result = await formatter.format({ text: "test", context: {} }); + expect(result).toBe("Raw response"); + }); + + it("should return original text on NativeBridge error", async () => { + mockNativeBridge.call.mockRejectedValue(new Error("Helper crashed")); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + const result = await formatter.format({ + text: "original text", + context: {}, + }); + expect(result).toBe("original text"); + }); + + it("should pass system prompt from constructFormatterPrompt", async () => { + mockNativeBridge.call.mockResolvedValue({ content: "formatted" }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + await formatter.format({ + text: "test", + context: { vocabulary: ["Amical"] }, + }); + + expect(mockNativeBridge.call).toHaveBeenCalledWith( + "generateWithFoundationModel", + expect.objectContaining({ + systemPrompt: expect.stringContaining("Markdown"), + }), + expect.any(Number), + ); + }); + + it("should set temperature to 0.1 for consistent formatting", async () => { + mockNativeBridge.call.mockResolvedValue({ content: "formatted" }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + await formatter.format({ text: "test", context: {} }); + + expect(mockNativeBridge.call).toHaveBeenCalledWith( + "generateWithFoundationModel", + expect.objectContaining({ temperature: 0.1 }), + expect.any(Number), + ); + }); + + it("should use 30 second timeout for Foundation Model calls", async () => { + mockNativeBridge.call.mockResolvedValue({ content: "formatted" }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + await formatter.format({ text: "test", context: {} }); + + expect(mockNativeBridge.call).toHaveBeenCalledWith( + "generateWithFoundationModel", + expect.any(Object), + 30000, + ); + }); + + it("should fall back to original text when formatted_text tags are empty", async () => { + mockNativeBridge.call.mockResolvedValue({ + content: "", + }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + const result = await formatter.format({ text: "こんにちは", context: {} }); + expect(result).toBe("こんにちは"); + }); + + it("should handle multiline formatted text", async () => { + mockNativeBridge.call.mockResolvedValue({ + content: + "Line 1\nLine 2\nLine 3", + }); + + const formatter = new AppleIntelligenceFormatter( + mockNativeBridge as any, + ); + const result = await formatter.format({ text: "test", context: {} }); + expect(result).toBe("Line 1\nLine 2\nLine 3"); + }); + }); +}); diff --git a/apps/desktop/tests/services/model-service-apple-intelligence.test.ts b/apps/desktop/tests/services/model-service-apple-intelligence.test.ts new file mode 100644 index 00000000..ced7504e --- /dev/null +++ b/apps/desktop/tests/services/model-service-apple-intelligence.test.ts @@ -0,0 +1,190 @@ +import { describe, it, expect, vi, beforeEach, afterEach } from "vitest"; + +// Mock DB operations before importing ModelService +const mockUpsertModel = vi.fn(); +const mockRemoveModel = vi.fn().mockResolvedValue(undefined); +const mockGetModelsByProvider = vi.fn().mockResolvedValue([]); +const mockSyncLocalWhisperModels = vi + .fn() + .mockResolvedValue({ added: 0, updated: 0, removed: 0 }); +const mockGetDownloadedWhisperModels = vi.fn().mockResolvedValue([]); +const mockGetAllModels = vi.fn().mockResolvedValue([]); +const mockModelExists = vi.fn().mockResolvedValue(false); +const mockSyncModelsForProvider = vi.fn(); +const mockRemoveModelsForProvider = vi.fn(); +const mockGetModelById = vi.fn(); + +vi.mock("../../src/db/models", () => ({ + upsertModel: (...args: any[]) => mockUpsertModel(...args), + removeModel: (...args: any[]) => mockRemoveModel(...args), + getModelsByProvider: (...args: any[]) => mockGetModelsByProvider(...args), + syncLocalWhisperModels: (...args: any[]) => + mockSyncLocalWhisperModels(...args), + getDownloadedWhisperModels: () => mockGetDownloadedWhisperModels(), + getAllModels: () => mockGetAllModels(), + modelExists: (...args: any[]) => mockModelExists(...args), + syncModelsForProvider: (...args: any[]) => + mockSyncModelsForProvider(...args), + removeModelsForProvider: (...args: any[]) => + mockRemoveModelsForProvider(...args), + getModelById: (...args: any[]) => mockGetModelById(...args), +})); + +import { ModelService } from "../../src/services/model-service"; + +describe("ModelService - Apple Intelligence", () => { + let modelService: ModelService; + let mockNativeBridge: { + call: ReturnType; + isHelperRunning: ReturnType; + }; + let mockSettingsService: any; + + const originalPlatform = process.platform; + + beforeEach(() => { + vi.clearAllMocks(); + + mockNativeBridge = { + call: vi.fn(), + isHelperRunning: vi.fn(() => true), + }; + + mockSettingsService = { + getDefaultSpeechModel: vi.fn().mockResolvedValue(null), + getDefaultLanguageModel: vi.fn().mockResolvedValue(null), + getDefaultEmbeddingModel: vi.fn().mockResolvedValue(null), + setDefaultSpeechModel: vi.fn().mockResolvedValue(undefined), + setDefaultLanguageModel: vi.fn().mockResolvedValue(undefined), + setDefaultEmbeddingModel: vi.fn().mockResolvedValue(undefined), + getModelProvidersConfig: vi.fn().mockResolvedValue({}), + setModelProvidersConfig: vi.fn().mockResolvedValue(undefined), + getFormatterConfig: vi.fn().mockResolvedValue({ enabled: false }), + setFormatterConfig: vi.fn().mockResolvedValue(undefined), + }; + + modelService = new ModelService(mockSettingsService); + }); + + afterEach(() => { + Object.defineProperty(process, "platform", { value: originalPlatform }); + }); + + describe("syncAppleIntelligenceModel", () => { + it("should register model when Foundation Model is available", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }); + mockNativeBridge.call.mockResolvedValue({ available: true }); + + const result = await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + expect(result).toEqual({ available: true, reason: undefined }); + expect(mockUpsertModel).toHaveBeenCalledWith( + expect.objectContaining({ + id: "apple-intelligence", + provider: "AppleIntelligence", + name: "Apple Intelligence", + type: "language", + }), + ); + }); + + it("should remove model when Foundation Model is not available", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }); + mockNativeBridge.call.mockResolvedValue({ + available: false, + reason: "deviceNotEligible", + }); + + const result = await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + expect(result).toEqual({ + available: false, + reason: "deviceNotEligible", + }); + expect(mockRemoveModel).toHaveBeenCalledWith( + "AppleIntelligence", + "apple-intelligence", + ); + expect(mockUpsertModel).not.toHaveBeenCalled(); + }); + + it("should skip on non-macOS platforms", async () => { + Object.defineProperty(process, "platform", { value: "win32" }); + + const result = await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + expect(result).toEqual({ available: false, reason: "notMacOS" }); + expect(mockNativeBridge.call).not.toHaveBeenCalled(); + }); + + it("should not throw on NativeBridge errors", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }); + mockNativeBridge.call.mockRejectedValue(new Error("Helper crashed")); + + const result = await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + expect(result).toEqual({ available: false, reason: "checkFailed" }); + }); + + it("should register with correct model metadata", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }); + mockNativeBridge.call.mockResolvedValue({ available: true }); + + await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + expect(mockUpsertModel).toHaveBeenCalledWith({ + id: "apple-intelligence", + provider: "AppleIntelligence", + name: "Apple Intelligence", + type: "language", + description: "On-device Apple Intelligence model", + size: null, + context: null, + checksum: null, + speed: null, + accuracy: null, + localPath: null, + sizeBytes: null, + downloadedAt: null, + originalModel: null, + }); + }); + + it("should call checkFoundationModelAvailability on NativeBridge", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }); + mockNativeBridge.call.mockResolvedValue({ available: false }); + + await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + expect(mockNativeBridge.call).toHaveBeenCalledWith( + "checkFoundationModelAvailability", + {}, + ); + }); + + it("should not throw when removeModel fails (model not previously registered)", async () => { + Object.defineProperty(process, "platform", { value: "darwin" }); + mockNativeBridge.call.mockResolvedValue({ available: false }); + mockRemoveModel.mockRejectedValue(new Error("Not found")); + + const result = await modelService.syncAppleIntelligenceModel( + mockNativeBridge as any, + ); + + // Should succeed without throwing + expect(result).toEqual({ available: false, reason: undefined }); + }); + }); +}); diff --git a/packages/native-helpers/swift-helper/Sources/SwiftHelper/RpcHandler.swift b/packages/native-helpers/swift-helper/Sources/SwiftHelper/RpcHandler.swift index 5132fb74..9ff2ff9a 100644 --- a/packages/native-helpers/swift-helper/Sources/SwiftHelper/RpcHandler.swift +++ b/packages/native-helpers/swift-helper/Sources/SwiftHelper/RpcHandler.swift @@ -13,12 +13,14 @@ class IOBridge: NSObject { let jsonDecoder: JSONDecoder private let accessibilityService: AccessibilityService private let audioService: AudioService + private let foundationModelService: FoundationModelService init(jsonEncoder: JSONEncoder, jsonDecoder: JSONDecoder) { self.jsonEncoder = jsonEncoder self.jsonDecoder = jsonDecoder self.accessibilityService = AccessibilityService() self.audioService = AudioService() // Audio preloaded here at startup + self.foundationModelService = FoundationModelService() super.init() } @@ -229,6 +231,51 @@ class IOBridge: NSObject { rpcResponse = RPCResponseSchema(error: errPayload, id: request.id, result: nil) } + case .checkFoundationModelAvailability: + logToStderr("[IOBridge] Handling checkFoundationModelAvailability for ID: \(request.id)") + let result = foundationModelService.checkAvailability() + sendResult(id: request.id, result: result) + return + + case .generateWithFoundationModel: + logToStderr("[IOBridge] Handling generateWithFoundationModel for ID: \(request.id)") + guard let paramsAnyCodable = request.params else { + let errPayload = Error( + code: -32602, data: nil, message: "Missing params for generateWithFoundationModel") + rpcResponse = RPCResponseSchema(error: errPayload, id: request.id, result: nil) + sendRpcResponse(rpcResponse) + return + } + + do { + let paramsData = try jsonEncoder.encode(paramsAnyCodable) + let generateParams = try jsonDecoder.decode( + GenerateWithFoundationModelParamsSchema.self, from: paramsData) + + // Run async Foundation Model call on a background queue + let requestId = request.id + Task { + do { + let result = try await self.foundationModelService.generate(params: generateParams) + self.sendResult(id: requestId, result: result) + } catch { + self.logToStderr( + "[IOBridge] Error in generateWithFoundationModel: \(error.localizedDescription) for ID: \(requestId)" + ) + self.sendError(id: requestId, code: -32603, + message: "Foundation Model error: \(error.localizedDescription)") + } + } + return + } catch { + logToStderr( + "[IOBridge] Error decoding generateWithFoundationModel params: \(error.localizedDescription) for ID: \(request.id)" + ) + sendError(id: request.id, code: -32602, + message: "Invalid params: \(error.localizedDescription)") + return + } + default: logToStderr("[IOBridge] Method not found: \(request.method) for ID: \(request.id)") let errPayload = Error( diff --git a/packages/native-helpers/swift-helper/Sources/SwiftHelper/services/FoundationModelService.swift b/packages/native-helpers/swift-helper/Sources/SwiftHelper/services/FoundationModelService.swift new file mode 100644 index 00000000..3251eee1 --- /dev/null +++ b/packages/native-helpers/swift-helper/Sources/SwiftHelper/services/FoundationModelService.swift @@ -0,0 +1,45 @@ +import Foundation + +#if canImport(FoundationModels) +import FoundationModels +#endif + +class FoundationModelService { + + func checkAvailability() -> CheckFoundationModelAvailabilityResultSchema { + #if canImport(FoundationModels) + if #available(macOS 26, *) { + let model = SystemLanguageModel.default + switch model.availability { + case .available: + return CheckFoundationModelAvailabilityResultSchema(available: true, reason: nil) + case .unavailable(let reason): + return CheckFoundationModelAvailabilityResultSchema(available: false, reason: String(describing: reason)) + @unknown default: + return CheckFoundationModelAvailabilityResultSchema(available: false, reason: "unknown") + } + } + #endif + return CheckFoundationModelAvailabilityResultSchema(available: false, reason: "deviceNotEligible") + } + + func generate(params: GenerateWithFoundationModelParamsSchema) async throws -> GenerateWithFoundationModelResultSchema { + #if canImport(FoundationModels) + if #available(macOS 26, *) { + let instructions = params.systemPrompt + let session = LanguageModelSession(instructions: instructions) + var options = GenerationOptions() + if let temperature = params.temperature { + options.temperature = temperature + } + if let maxTokens = params.maxTokens { + options.maximumResponseTokens = Int(maxTokens) + } + let response = try await session.respond(to: params.userPrompt, options: options) + return GenerateWithFoundationModelResultSchema(content: response.content) + } + #endif + throw NSError(domain: "FoundationModelService", code: -1, + userInfo: [NSLocalizedDescriptionKey: "Foundation Models not available on this device"]) + } +} diff --git a/packages/native-helpers/windows-helper/src/Models/Generated/Models.cs b/packages/native-helpers/windows-helper/src/Models/Generated/Models.cs index 9abebf9d..c04aa43e 100644 --- a/packages/native-helpers/windows-helper/src/Models/Generated/Models.cs +++ b/packages/native-helpers/windows-helper/src/Models/Generated/Models.cs @@ -537,7 +537,7 @@ public partial class HelperEventPayload public bool? ShiftKey { get; set; } } - public enum Method { GetAccessibilityContext, GetAccessibilityStatus, GetAccessibilityTreeDetails, MuteSystemAudio, PasteText, RecheckPressedKeys, RequestAccessibilityPermission, RestoreSystemAudio, SetShortcuts }; + public enum Method { CheckFoundationModelAvailability, GenerateWithFoundationModel, GetAccessibilityContext, GetAccessibilityStatus, GetAccessibilityTreeDetails, MuteSystemAudio, PasteText, RecheckPressedKeys, RequestAccessibilityPermission, RestoreSystemAudio, SetShortcuts }; public enum The0 { ClipboardCopy, None, SelectedTextRange, SelectedTextRanges, StringForRange, TextMarkerRange, ValueAttribute }; @@ -703,6 +703,10 @@ public override Method Read(ref Utf8JsonReader reader, Type typeToConvert, JsonS var value = reader.GetString(); switch (value) { + case "checkFoundationModelAvailability": + return Method.CheckFoundationModelAvailability; + case "generateWithFoundationModel": + return Method.GenerateWithFoundationModel; case "getAccessibilityContext": return Method.GetAccessibilityContext; case "getAccessibilityStatus": @@ -729,6 +733,12 @@ public override void Write(Utf8JsonWriter writer, Method value, JsonSerializerOp { switch (value) { + case Method.CheckFoundationModelAvailability: + JsonSerializer.Serialize(writer, "checkFoundationModelAvailability", options); + return; + case Method.GenerateWithFoundationModel: + JsonSerializer.Serialize(writer, "generateWithFoundationModel", options); + return; case Method.GetAccessibilityContext: JsonSerializer.Serialize(writer, "getAccessibilityContext", options); return; diff --git a/packages/types/scripts/generate-json-schemas.ts b/packages/types/scripts/generate-json-schemas.ts index 7f8510cc..ca47b82f 100644 --- a/packages/types/scripts/generate-json-schemas.ts +++ b/packages/types/scripts/generate-json-schemas.ts @@ -33,6 +33,14 @@ import { RecheckPressedKeysParamsSchema, RecheckPressedKeysResultSchema, } from "../src/schemas/methods/recheck-pressed-keys.js"; +import { + CheckFoundationModelAvailabilityParamsSchema, + CheckFoundationModelAvailabilityResultSchema, +} from "../src/schemas/methods/check-foundation-model-availability.js"; +import { + GenerateWithFoundationModelParamsSchema, + GenerateWithFoundationModelResultSchema, +} from "../src/schemas/methods/generate-with-foundation-model.js"; import { KeyDownEventSchema, KeyUpEventSchema, @@ -116,6 +124,26 @@ const schemasToGenerate = [ name: "RecheckPressedKeysResult", category: "methods", }, + { + zod: CheckFoundationModelAvailabilityParamsSchema, + name: "CheckFoundationModelAvailabilityParams", + category: "methods", + }, + { + zod: CheckFoundationModelAvailabilityResultSchema, + name: "CheckFoundationModelAvailabilityResult", + category: "methods", + }, + { + zod: GenerateWithFoundationModelParamsSchema, + name: "GenerateWithFoundationModelParams", + category: "methods", + }, + { + zod: GenerateWithFoundationModelResultSchema, + name: "GenerateWithFoundationModelResult", + category: "methods", + }, ]; schemasToGenerate.forEach(({ zod, name, category }) => { diff --git a/packages/types/scripts/generate-swift-models.ts b/packages/types/scripts/generate-swift-models.ts index af73d0da..626ab246 100644 --- a/packages/types/scripts/generate-swift-models.ts +++ b/packages/types/scripts/generate-swift-models.ts @@ -36,7 +36,11 @@ try { "generated/json-schemas/events/key-down-event.schema.json " + "generated/json-schemas/events/key-up-event.schema.json " + "generated/json-schemas/events/flags-changed-event.schema.json " + - "generated/json-schemas/events/helper-event.schema.json", + "generated/json-schemas/events/helper-event.schema.json " + + "generated/json-schemas/methods/check-foundation-model-availability-params.schema.json " + + "generated/json-schemas/methods/check-foundation-model-availability-result.schema.json " + + "generated/json-schemas/methods/generate-with-foundation-model-params.schema.json " + + "generated/json-schemas/methods/generate-with-foundation-model-result.schema.json", ]; commands.forEach((command) => { diff --git a/packages/types/src/index.ts b/packages/types/src/index.ts index d42c7c76..7e373bea 100644 --- a/packages/types/src/index.ts +++ b/packages/types/src/index.ts @@ -10,6 +10,8 @@ export * from "./schemas/methods/mute-system-audio.js"; export * from "./schemas/methods/restore-system-audio.js"; export * from "./schemas/methods/set-shortcuts.js"; export * from "./schemas/methods/recheck-pressed-keys.js"; +export * from "./schemas/methods/check-foundation-model-availability.js"; +export * from "./schemas/methods/generate-with-foundation-model.js"; // Event Schemas export * from "./schemas/events/key-events.js"; diff --git a/packages/types/src/schemas/methods/check-foundation-model-availability.ts b/packages/types/src/schemas/methods/check-foundation-model-availability.ts new file mode 100644 index 00000000..074ed6bc --- /dev/null +++ b/packages/types/src/schemas/methods/check-foundation-model-availability.ts @@ -0,0 +1,18 @@ +import { z } from "zod"; + +// Request params +export const CheckFoundationModelAvailabilityParamsSchema = z + .object({}) + .optional(); +export type CheckFoundationModelAvailabilityParams = z.infer< + typeof CheckFoundationModelAvailabilityParamsSchema +>; + +// Response result +export const CheckFoundationModelAvailabilityResultSchema = z.object({ + available: z.boolean(), + reason: z.string().optional(), +}); +export type CheckFoundationModelAvailabilityResult = z.infer< + typeof CheckFoundationModelAvailabilityResultSchema +>; diff --git a/packages/types/src/schemas/methods/generate-with-foundation-model.ts b/packages/types/src/schemas/methods/generate-with-foundation-model.ts new file mode 100644 index 00000000..618be382 --- /dev/null +++ b/packages/types/src/schemas/methods/generate-with-foundation-model.ts @@ -0,0 +1,20 @@ +import { z } from "zod"; + +// Request params +export const GenerateWithFoundationModelParamsSchema = z.object({ + systemPrompt: z.string(), + userPrompt: z.string(), + temperature: z.number().optional(), + maxTokens: z.number().optional(), +}); +export type GenerateWithFoundationModelParams = z.infer< + typeof GenerateWithFoundationModelParamsSchema +>; + +// Response result +export const GenerateWithFoundationModelResultSchema = z.object({ + content: z.string(), +}); +export type GenerateWithFoundationModelResult = z.infer< + typeof GenerateWithFoundationModelResultSchema +>; diff --git a/packages/types/src/schemas/rpc/request.ts b/packages/types/src/schemas/rpc/request.ts index 9037f535..b44b7d6a 100644 --- a/packages/types/src/schemas/rpc/request.ts +++ b/packages/types/src/schemas/rpc/request.ts @@ -15,6 +15,8 @@ const RPCMethodNameSchema = z.union([ z.literal("restoreSystemAudio"), z.literal("setShortcuts"), z.literal("recheckPressedKeys"), + z.literal("checkFoundationModelAvailability"), + z.literal("generateWithFoundationModel"), ]); export const RpcRequestSchema = z.object({ diff --git a/packages/types/tests/foundation-model-schemas.test.ts b/packages/types/tests/foundation-model-schemas.test.ts new file mode 100644 index 00000000..c620dc69 --- /dev/null +++ b/packages/types/tests/foundation-model-schemas.test.ts @@ -0,0 +1,127 @@ +import { describe, it, expect } from "vitest"; +import { + CheckFoundationModelAvailabilityResultSchema, +} from "../src/schemas/methods/check-foundation-model-availability"; +import { + GenerateWithFoundationModelParamsSchema, + GenerateWithFoundationModelResultSchema, +} from "../src/schemas/methods/generate-with-foundation-model"; + +describe("Foundation Model Schemas", () => { + describe("CheckFoundationModelAvailabilityResultSchema", () => { + it("should accept available result", () => { + const result = CheckFoundationModelAvailabilityResultSchema.parse({ + available: true, + }); + expect(result).toEqual({ available: true }); + }); + + it("should accept unavailable with reason", () => { + const result = CheckFoundationModelAvailabilityResultSchema.parse({ + available: false, + reason: "deviceNotEligible", + }); + expect(result).toEqual({ + available: false, + reason: "deviceNotEligible", + }); + }); + + it("should accept unavailable without reason", () => { + const result = CheckFoundationModelAvailabilityResultSchema.parse({ + available: false, + }); + expect(result).toEqual({ available: false }); + }); + + it("should reject missing available field", () => { + expect(() => + CheckFoundationModelAvailabilityResultSchema.parse({}), + ).toThrow(); + }); + + it("should reject non-boolean available", () => { + expect(() => + CheckFoundationModelAvailabilityResultSchema.parse({ + available: "yes", + }), + ).toThrow(); + }); + }); + + describe("GenerateWithFoundationModelParamsSchema", () => { + it("should accept required fields", () => { + const result = GenerateWithFoundationModelParamsSchema.parse({ + systemPrompt: "sys", + userPrompt: "user", + }); + expect(result).toBeDefined(); + expect(result.systemPrompt).toBe("sys"); + expect(result.userPrompt).toBe("user"); + }); + + it("should accept optional temperature and maxTokens", () => { + const result = GenerateWithFoundationModelParamsSchema.parse({ + systemPrompt: "sys", + userPrompt: "user", + temperature: 0.5, + maxTokens: 1000, + }); + expect(result.temperature).toBe(0.5); + expect(result.maxTokens).toBe(1000); + }); + + it("should reject missing systemPrompt", () => { + expect(() => + GenerateWithFoundationModelParamsSchema.parse({ + userPrompt: "user", + }), + ).toThrow(); + }); + + it("should reject missing userPrompt", () => { + expect(() => + GenerateWithFoundationModelParamsSchema.parse({ + systemPrompt: "sys", + }), + ).toThrow(); + }); + + it("should reject non-string systemPrompt", () => { + expect(() => + GenerateWithFoundationModelParamsSchema.parse({ + systemPrompt: 123, + userPrompt: "user", + }), + ).toThrow(); + }); + }); + + describe("GenerateWithFoundationModelResultSchema", () => { + it("should accept content string", () => { + const result = GenerateWithFoundationModelResultSchema.parse({ + content: "hello", + }); + expect(result).toEqual({ content: "hello" }); + }); + + it("should accept empty content string", () => { + const result = GenerateWithFoundationModelResultSchema.parse({ + content: "", + }); + expect(result).toEqual({ content: "" }); + }); + + it("should reject missing content", () => { + expect(() => + GenerateWithFoundationModelResultSchema.parse({}), + ).toThrow(); + }); + + it("should reject non-string content", () => { + expect(() => + GenerateWithFoundationModelResultSchema.parse({ content: 123 }), + ).toThrow(); + }); + }); +});