Skip to content

Commit 477fc27

Browse files
authored
Expose all models (#2)
This pull request exposes a simple actor for each supported model. The actual implementation of the models is `LLM`, which is just a convenience wrapper around MLX's stuff. Each model conforms to a new protocol called `ModelProtocol`. This allows us to add extra functions for each model in just a single place: ModelProtocol.swift. The first example of this is `request(_:maxTokenCount:)`. Because of the reentrancy problem with Swift actors, `ModelProtocol.llm` is wrapped inside of `ActorLock`, which is taken from [Apple's swift-build/AsyncLock.swift](https://github.com/swiftlang/swift-build/blob/main/Sources/SWBUtil/AsyncLock.swift). According to MLX's documentation, the AI models themselves are not thread-safe, which means calls to them need to be serialized. However, because Swift actors are reentrant, calling `try await llm.request(_:maxTokenCount:)` could immediately suspend and allow another call to the same actor to proceed. This may not be a problem with the library today, but I think it may be in the future, especially when we add support for `KVCache`. I think it's better to ensure that every call to `ModelProtocol.someFunc` is transactional, which is what we are doing by wrapping the implementation of every method on `ModelProtocol` inside of an `ActorLock`.
1 parent fa48c3c commit 477fc27

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

42 files changed

+1119
-127
lines changed

Sources/SHLLM/ActorLock.swift

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
//===----------------------------------------------------------------------===//
2+
//
3+
// This source file is part of the Swift open source project
4+
//
5+
// Copyright (c) 2025 Apple Inc. and the Swift project authors
6+
// Licensed under Apache License v2.0 with Runtime Library Exception
7+
//
8+
// See http://swift.org/LICENSE.txt for license information
9+
// See http://swift.org/CONTRIBUTORS.txt for the list of Swift project authors
10+
//
11+
//===----------------------------------------------------------------------===//
12+
13+
/// Lock intended for use within an actor in order to prevent reentrancy in actor methods which
/// themselves contain suspension points.
///
/// Swift actors are reentrant: while an actor method is suspended at an `await`,
/// other calls on the same actor may interleave. Wrapping a critical section in
/// `withLock(_:)` serializes those sections: each body runs to completion before
/// the next waiting caller is allowed to proceed.
public actor ActorLock {
    // True while some caller is currently executing a `withLock(_:)` body.
    private var busy = false
    // FIFO queue of callers suspended waiting for the lock. `ArraySlice` makes
    // `popFirst()` O(1); the buffer is reallocated once fully drained (below).
    private var queue: ArraySlice<CheckedContinuation<Void, Never>> = []

    public init() {}

    /// Runs `body` while holding the lock, suspending first if another caller
    /// currently holds it.
    ///
    /// - Parameter body: The critical section; it may itself suspend and throw.
    /// - Returns: The value produced by `body`.
    /// - Throws: Rethrows whatever `body` throws (typed throws, `E`).
    public func withLock<
        T: Sendable,
        E
    >(_ body: @Sendable () async throws(E) -> T) async throws(E) -> T {
        // A loop, not an `if`: a resumed waiter must re-check `busy`, because
        // the actor is reentrant and a different caller could acquire the lock
        // in the window between `next.resume(...)` and this task running again.
        while busy {
            await withCheckedContinuation { cc in
                queue.append(cc)
            }
        }
        busy = true
        defer {
            // Release the lock and wake exactly one waiter (FIFO order).
            busy = false
            if let next = queue.popFirst() {
                next.resume(returning: ())
            } else {
                queue = [] // reallocate buffer if it's empty
            }
        }
        return try await body()
    }
}
42+
43+
/// Small concurrency-compatible wrapper to provide only locked, non-reentrant access to its
/// value.
///
/// The wrapped value is reachable only through `withLock(_:)`, which funnels
/// every access through an `ActorLock`, so callers get exclusive, serialized
/// `inout` access even across suspension points.
public final class AsyncLockedValue<Wrapped: Sendable> {
    @usableFromInline let lock = ActorLock()
    /// Don't use this from outside this class. Is internal to be inlinable.
    @usableFromInline var value: Wrapped
    public init(_ value: Wrapped) {
        self.value = value
    }

    /// Runs `block` with exclusive, non-reentrant `inout` access to the
    /// wrapped value.
    ///
    /// - Parameter block: Receives the wrapped value `inout`; may suspend and throw.
    /// - Returns: Whatever `block` returns.
    /// - Throws: Rethrows whatever `block` throws (typed throws, `E`).
    @discardableResult @inlinable
    public func withLock<
        Result: Sendable,
        E
    >(_ block: @Sendable (inout Wrapped) async throws(E) -> Result) async throws(E)
        -> Result
    {
        try await lock.withLock { () throws(E) -> Result in try await block(&value) }
    }
}

// `@unchecked` is justified because `value` is only ever touched via
// `withLock(_:)`, which serializes access through `ActorLock`.
// NOTE(review): the `where Wrapped: Sendable` clause is redundant — the class
// declaration already constrains `Wrapped: Sendable`; confirm and consider dropping.
extension AsyncLockedValue: @unchecked Sendable where Wrapped: Sendable {}

Sources/SHLLM/LLM.swift

+198
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,198 @@
1+
import Foundation
2+
import struct Hub.Config
3+
import MLX
4+
import MLXLLM
5+
import MLXLMCommon
6+
import Tokenizers
7+
8+
/// Convenience wrapper around MLX's model loading and generation machinery.
///
/// An `LLM` is created via one of the architecture-specific static factories
/// below, each of which loads configuration, weights, and tokenizer from a
/// model `directory` on disk.
public final class LLM {
    private let directory: URL
    private let context: ModelContext
    private let configuration: ModelConfiguration

    // MARK: - Factories
    //
    // One factory per supported architecture; each binds the matching MLXLLM
    // model initializer to the generic private init below.

    static func cohere(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: CohereModel.init
        )
    }

    static func gemma(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: GemmaModel.init
        )
    }

    static func gemma2(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: Gemma2Model.init
        )
    }

    static func internLM2(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: InternLM2Model.init
        )
    }

    static func llama(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: LlamaModel.init
        )
    }

    static func openELM(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: OpenELMModel.init
        )
    }

    static func phi(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: PhiModel.init
        )
    }

    static func phi3(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: Phi3Model.init
        )
    }

    static func phiMoE(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: PhiMoEModel.init
        )
    }

    static func qwen2(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: Qwen2Model.init
        )
    }

    // NOTE(review): SmolLM deliberately reuses the Llama architecture here —
    // confirm this is intended and not a copy/paste slip.
    static func smolLM(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: LlamaModel.init
        )
    }

    static func starcoder2(directory: URL) async throws -> LLM {
        try await Self(
            directory: directory,
            modelInit: Starcoder2Model.init
        )
    }

    // MARK: - Loading

    /// Reads the file `name` from `directory` and parses it as a top-level
    /// JSON dictionary.
    ///
    /// - Throws: `SHLLMError.invalidOrMissingConfig(name)` when the file does
    ///   not contain a top-level JSON object; file-system and JSON parse
    ///   errors propagate as-is.
    private static func jsonDictionary(
        ofConfig name: String,
        in directory: URL
    ) throws -> [NSString: Any] {
        let data = try Data(contentsOf: directory.appending(
            path: name,
            directoryHint: .notDirectory
        ))
        guard let json = try JSONSerialization.jsonObject(
            with: data
        ) as? [NSString: Any] else {
            throw SHLLMError.invalidOrMissingConfig(
                name
            )
        }
        return json
    }

    /// Loads configuration, model weights, and tokenizer from `directory`,
    /// then assembles the `ModelContext` used for generation.
    ///
    /// - Parameters:
    ///   - directory: Directory containing `config.json`,
    ///     `tokenizer_config.json`, `tokenizer.json`, and the weight files.
    ///   - modelInit: Initializer of the concrete MLXLLM model type.
    private init<Configuration: Decodable>(
        directory: URL,
        modelInit: (Configuration) -> some LanguageModel
    ) async throws {
        self.directory = directory
        let decoder = JSONDecoder()

        // config.json is decoded twice: once as the architecture-independent
        // base configuration (quantization etc.) and once as the
        // model-specific configuration type.
        let config = try Data(
            contentsOf: directory.appending(
                path: "config.json",
                directoryHint: .notDirectory
            )
        )

        let baseConfig = try decoder.decode(
            BaseConfiguration.self,
            from: config
        )

        let modelConfig = try decoder.decode(
            Configuration.self,
            from: config
        )
        let model = modelInit(modelConfig)

        try loadWeights(
            modelDirectory: directory,
            model: model,
            quantization: baseConfig.quantization
        )

        // Tokenizer configuration and data share an identical load path, so
        // both go through the jsonDictionary(ofConfig:in:) helper.
        let tokenizerConfig = Config(try Self.jsonDictionary(
            ofConfig: "tokenizer_config.json",
            in: directory
        ))

        let tokenizerData = Config(try Self.jsonDictionary(
            ofConfig: "tokenizer.json",
            in: directory
        ))

        let tokenizer = try PreTrainedTokenizer(
            tokenizerConfig: tokenizerConfig,
            tokenizerData: tokenizerData
        )

        configuration = ModelConfiguration(
            directory: directory,
            overrideTokenizer: nil,
            defaultPrompt: "You are a helpful assistant."
        )

        context = ModelContext(
            configuration: configuration,
            model: model,
            processor: LLMUserInputProcessor(
                tokenizer: tokenizer,
                configuration: configuration
            ),
            tokenizer: tokenizer
        )
    }
}
176+
177+
extension LLM {
    /// Generates a completion for `input`, stopping once `maxTokenCount`
    /// tokens have been produced.
    ///
    /// - Parameters:
    ///   - input: The user input to complete.
    ///   - maxTokenCount: Upper bound on generated tokens (default 1024 * 1024).
    /// - Returns: The generated output text.
    func request(
        _ input: UserInput,
        maxTokenCount: Int = 1024 * 1024
    ) async throws -> String {
        let prepared = try await context.processor.prepare(input: input)

        let generation = try MLXLMCommon.generate(
            input: prepared,
            parameters: .init(),
            context: context
        ) { tokens in
            // Halt generation once the token budget is reached.
            tokens.count >= maxTokenCount ? .stop : .more
        }

        return generation.output
    }
}

Sources/SHLLM/ModelProtocol.swift

+17
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
import Foundation
2+
import MLXLMCommon
3+
4+
/// A model wrapper that exposes its underlying `LLM` behind a non-reentrant lock.
public protocol ModelProtocol {
    var llm: AsyncLockedValue<LLM> { get async }
}

public extension ModelProtocol {
    /// Forwards a completion request to the underlying `LLM`, serialized so
    /// that exactly one request runs against the model at a time.
    ///
    /// - Parameters:
    ///   - input: The user input to complete.
    ///   - maxTokenCount: Upper bound on generated tokens (default 1024 * 1024).
    /// - Returns: The generated output text.
    func request(
        _ input: UserInput,
        maxTokenCount: Int = 1024 * 1024
    ) async throws -> String {
        try await llm.withLock { model in
            try await model.request(input, maxTokenCount: maxTokenCount)
        }
    }
}

Sources/SHLLM/Models/CodeLlama.swift

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import Foundation
2+
3+
/// CodeLlama, served through the Llama architecture.
public actor CodeLlama: ModelProtocol {
    public let llm: AsyncLockedValue<LLM>

    public init(directory: URL) async throws {
        llm = AsyncLockedValue(try await LLM.llama(directory: directory))
    }
}

extension CodeLlama {
    /// URL of the bundled model weights; throws when the resource directory
    /// is missing from the package bundle.
    static var bundleDirectory: URL {
        get throws {
            let resource = "CodeLlama-13b-Instruct-hf-4bit-MLX"
            if let url = Bundle.shllm.url(
                forResource: resource,
                withExtension: nil,
                subdirectory: "Resources"
            ) {
                return url
            }
            throw SHLLMError.directoryNotFound(resource)
        }
    }
}

Sources/SHLLM/Models/DeepSeekR1.swift

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import Foundation
2+
3+
/// DeepSeek-R1 (distilled onto the Qwen2 architecture).
public actor DeepSeekR1: ModelProtocol {
    public let llm: AsyncLockedValue<LLM>

    public init(directory: URL) async throws {
        llm = AsyncLockedValue(try await LLM.qwen2(directory: directory))
    }
}

extension DeepSeekR1 {
    /// URL of the bundled model weights; throws when the resource directory
    /// is missing from the package bundle.
    static var bundleDirectory: URL {
        get throws {
            let resource = "DeepSeek-R1-Distill-Qwen-7B-4bit"
            if let url = Bundle.shllm.url(
                forResource: resource,
                withExtension: nil,
                subdirectory: "Resources"
            ) {
                return url
            }
            throw SHLLMError.directoryNotFound(resource)
        }
    }
}

Sources/SHLLM/Models/Gemma.swift

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import Foundation
2+
3+
/// Gemma (first generation).
public actor Gemma: ModelProtocol {
    public let llm: AsyncLockedValue<LLM>

    public init(directory: URL) async throws {
        llm = AsyncLockedValue(try await LLM.gemma(directory: directory))
    }
}

extension Gemma {
    /// URL of the bundled model weights; throws when the resource directory
    /// is missing from the package bundle.
    static var bundleDirectory: URL {
        get throws {
            let resource = "quantized-gemma-2b-it"
            if let url = Bundle.shllm.url(
                forResource: resource,
                withExtension: nil,
                subdirectory: "Resources"
            ) {
                return url
            }
            throw SHLLMError.directoryNotFound(resource)
        }
    }
}

Sources/SHLLM/Models/Gemma2-2B.swift

+26
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,26 @@
1+
import Foundation
2+
3+
/// Gemma 2, 2B-parameter variant.
public actor Gemma2_2B: ModelProtocol {
    public let llm: AsyncLockedValue<LLM>

    public init(directory: URL) async throws {
        llm = AsyncLockedValue(try await LLM.gemma2(directory: directory))
    }
}

extension Gemma2_2B {
    /// URL of the bundled model weights; throws when the resource directory
    /// is missing from the package bundle.
    static var bundleDirectory: URL {
        get throws {
            let resource = "gemma-2-2b-it-4bit"
            if let url = Bundle.shllm.url(
                forResource: resource,
                withExtension: nil,
                subdirectory: "Resources"
            ) {
                return url
            }
            throw SHLLMError.directoryNotFound(resource)
        }
    }
}

0 commit comments

Comments
 (0)