Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
27 changes: 23 additions & 4 deletions Sources/AnyLanguageModel/Models/MLXLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ import Foundation
import MLX
import MLXVLM
import Tokenizers
import Hub

/// A language model that runs locally using MLX.
///
Expand All @@ -29,14 +30,25 @@ import Foundation
/// This model is always available.
public typealias UnavailableReason = Never

/// The model identifier from the MLX community on Hugging Face.
/// The model identifier.
public let modelId: String

/// The Hub API instance for downloading models.
public let hub: HubApi?

/// The local directory containing the model files.
public let directory: URL?

/// Creates an MLX language model.
///
/// - Parameter modelId: The Hugging Face model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit").
public init(modelId: String) {
/// - Parameters:
/// - modelId: The model identifier (for example, "mlx-community/Llama-3.2-3B-Instruct-4bit").
/// - hub: An optional Hub API instance for downloading models. If not provided, the default Hub API is used.
/// - directory: An optional local directory URL containing the model files. If provided, the model is loaded from this directory instead of downloading.
public init(modelId: String, hub: HubApi? = nil, directory: URL? = nil) {
self.modelId = modelId
self.hub = hub
self.directory = directory
}

public func respond<Content>(
Expand All @@ -51,7 +63,14 @@ import Foundation
fatalError("MLXLanguageModel only supports generating String content")
}

let context = try await loadModel(id: modelId)
let context: ModelContext
if let directory {
context = try await loadModel(directory: directory)
} else if let hub {
context = try await loadModel(hub: hub, id: modelId)
} else {
context = try await loadModel(id: modelId)
}

// Convert session tools to MLX ToolSpec format
let toolSpecs: [ToolSpec]? =
Expand Down