diff --git a/README.md b/README.md index 8ce90977..79ffc1ff 100644 --- a/README.md +++ b/README.md @@ -618,28 +618,51 @@ swift test Tests for different language model backends have varying requirements: -- **CoreML tests**: `swift test --traits CoreML` + `ENABLE_COREML_TESTS=1` + `HF_TOKEN` (downloads model from HuggingFace) -- **MLX tests**: `swift test --traits MLX` + `ENABLE_MLX_TESTS=1` + `HF_TOKEN` (uses pre-defined model) -- **Llama tests**: `swift test --traits Llama` + `LLAMA_MODEL_PATH` (points to local GGUF file) -- **Anthropic tests**: `ANTHROPIC_API_KEY` (no traits needed) -- **OpenAI tests**: `OPENAI_API_KEY` (no traits needed) -- **Ollama tests**: No setup needed (skips in CI) +| Backend | Traits | Environment Variables | +|---------|--------|----------------------| +| CoreML | `CoreML` | `HF_TOKEN` | +| MLX | `MLX` | `HF_TOKEN` | +| Llama | `Llama` | `LLAMA_MODEL_PATH` | +| Anthropic | — | `ANTHROPIC_API_KEY` | +| OpenAI | — | `OPENAI_API_KEY` | +| Ollama | — | — | -Example setup for all backends: +Example setup for running multiple tests at once: ```bash -# Environment variables -export ENABLE_COREML_TESTS=1 -export ENABLE_MLX_TESTS=1 export HF_TOKEN=your_huggingface_token export LLAMA_MODEL_PATH=/path/to/model.gguf export ANTHROPIC_API_KEY=your_anthropic_key export OPENAI_API_KEY=your_openai_key -# Run all tests with traits enabled -swift test --traits CoreML,MLX,Llama +swift test --traits CoreML,Llama ``` +> [!TIP] +> Tests that perform generation are skipped in CI environments (when `CI` is set). +> Override this by setting `ENABLE_COREML_TESTS=1` or `ENABLE_MLX_TESTS=1`. + +> [!NOTE] +> MLX tests must be run with `xcodebuild` rather than `swift test` +> due to Metal library loading requirements. +> Since `xcodebuild` doesn't support package traits directly, +> you'll first need to update `Package.swift` to enable the MLX trait by default. +> +> ```diff +> - .default(enabledTraits: []), +> + .default(enabledTraits: ["MLX"]), +> ``` +> +> Pass environment variables with `TEST_RUNNER_` prefix: +> +> ```bash +> export TEST_RUNNER_HF_TOKEN=your_huggingface_token +> xcodebuild test \ +> -scheme AnyLanguageModel \ +> -destination 'platform=macOS' \ +> -only-testing:AnyLanguageModelTests/MLXLanguageModelTests +> ``` + ## License This project is available under the MIT license. diff --git a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift index b1c7b564..379aa6a7 100644 --- a/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift +++ b/Sources/AnyLanguageModel/Models/MLXLanguageModel.swift @@ -83,10 +83,20 @@ import Foundation // Map AnyLanguageModel GenerationOptions to MLX GenerateParameters let generateParameters = toGenerateParameters(options) - // Start with user prompt + // Build chat history starting with system message if instructions are present + var chat: [MLXLMCommon.Chat.Message] = [] + + // Add system message if instructions are present + if let instructionSegments = extractInstructionSegments(from: session) { + let systemMessage = convertSegmentsToMLXSystemMessage(instructionSegments) + chat.append(systemMessage) + } + + // Add user prompt let userSegments = extractPromptSegments(from: session, fallbackText: prompt.description) let userMessage = convertSegmentsToMLXMessage(userSegments) - var chat: [MLXLMCommon.Chat.Message] = [userMessage] + chat.append(userMessage) + var allTextChunks: [String] = [] var allEntries: [Transcript.Entry] = [] @@ -211,6 +221,20 @@ import Foundation return [.text(.init(content: fallbackText))] } + private func extractInstructionSegments(from session: LanguageModelSession) -> [Transcript.Segment]? { + // Prefer the first Transcript.Instructions entry if present + for entry in session.transcript { + if case .instructions(let i) = entry { + return i.segments + } + } + // Fallback to session.instructions + if let instructions = session.instructions?.description, !instructions.isEmpty { + return [.text(.init(content: instructions))] + } + return nil + } + private func convertSegmentsToMLXMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { var textParts: [String] = [] var images: [MLXLMCommon.UserInput.Image] = [] @@ -248,6 +272,43 @@ import Foundation return MLXLMCommon.Chat.Message(role: .user, content: content, images: images) } + private func convertSegmentsToMLXSystemMessage(_ segments: [Transcript.Segment]) -> MLXLMCommon.Chat.Message { + var textParts: [String] = [] + var images: [MLXLMCommon.UserInput.Image] = [] + + for segment in segments { + switch segment { + case .text(let text): + textParts.append(text.content) + case .structure(let structured): + textParts.append(structured.content.jsonString) + case .image(let imageSegment): + switch imageSegment.source { + case .url(let url): + images.append(.url(url)) + case .data(let data, _): + #if canImport(UIKit) + if let uiImage = UIKit.UIImage(data: data), + let ciImage = CIImage(image: uiImage) + { + images.append(.ciImage(ciImage)) + } + #elseif canImport(AppKit) + if let nsImage = AppKit.NSImage(data: data), + let cgImage = nsImage.cgImage(forProposedRect: nil, context: nil, hints: nil) + { + let ciImage = CIImage(cgImage: cgImage) + images.append(.ciImage(ciImage)) + } + #endif + } + } + } + + let content = textParts.joined(separator: "\n") + return MLXLMCommon.Chat.Message(role: .system, content: content, images: images) + } + // MARK: - Tool Conversion private func convertToolToMLXSpec(_ tool: any Tool) -> ToolSpec {