Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
174 changes: 125 additions & 49 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,16 @@ import Foundation
import Observation

@Observable
public final class LanguageModelSession {
public final class LanguageModelSession: @unchecked Sendable {
Comment thread
mattt marked this conversation as resolved.
public private(set) var isResponding: Bool = false
public private(set) var transcript: Transcript

private let model: any LanguageModel
public let tools: [any Tool]
public let instructions: Instructions?

@ObservationIgnored private let respondingState = RespondingState()

public convenience init(
model: any LanguageModel,
tools: [any Tool] = [],
Expand Down Expand Up @@ -58,7 +60,57 @@ public final class LanguageModelSession {
model.prewarm(for: self, promptPrefix: promptPrefix)
}

public struct Response<Content> where Content: Generable {
nonisolated private func beginResponding() async {
let count = await respondingState.increment()
let active = count > 0
await MainActor.run {
self.isResponding = active
}
}

nonisolated private func endResponding() async {
let count = await respondingState.decrement()
let active = count > 0
await MainActor.run {
self.isResponding = active
}
}

nonisolated private func wrapRespond<T>(_ operation: () async throws -> T) async throws -> T {
await beginResponding()
do {
let result = try await operation()
await endResponding()
return result
} catch {
await endResponding()
throw error
}
}

nonisolated private func wrapStream<Content>(
_ upstream: sending ResponseStream<Content>
) -> ResponseStream<Content> where Content: Generable, Content.PartiallyGenerated: Sendable {
let session = self
let relay = AsyncThrowingStream<ResponseStream<Content>.Snapshot, any Error> { continuation in
let stream = upstream
Task {
await session.beginResponding()
do {
for try await snapshot in stream {
continuation.yield(snapshot)
}
continuation.finish()
} catch {
continuation.finish(throwing: error)
}
await session.endResponding()
}
}
return ResponseStream(stream: relay)
}

public struct Response<Content>: Sendable where Content: Generable, Content: Sendable {
public let content: Content
public let rawContent: GeneratedContent
public let transcriptEntries: ArraySlice<Transcript.Entry>
Expand All @@ -69,13 +121,15 @@ public final class LanguageModelSession {
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<String> {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
}
}

@discardableResult
Expand All @@ -101,13 +155,15 @@ public final class LanguageModelSession {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<GeneratedContent> {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

@discardableResult
Expand Down Expand Up @@ -147,13 +203,15 @@ public final class LanguageModelSession {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<Content> where Content: Generable {
try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
}

@discardableResult
Expand Down Expand Up @@ -186,22 +244,24 @@ public final class LanguageModelSession {
)
}

public func streamResponse(
nonisolated public func streamResponse(
to prompt: Prompt,
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<GeneratedContent> {
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
)
}

public func streamResponse(
nonisolated public func streamResponse(
to prompt: String,
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
Expand All @@ -215,7 +275,7 @@ public final class LanguageModelSession {
)
}

public func streamResponse(
nonisolated public func streamResponse(
schema: GenerationSchema,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions(),
Expand All @@ -224,22 +284,24 @@ public final class LanguageModelSession {
streamResponse(to: try prompt(), schema: schema, includeSchemaInPrompt: includeSchemaInPrompt, options: options)
}

public func streamResponse<Content>(
nonisolated public func streamResponse<Content>(
to prompt: Prompt,
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<Content> where Content: Generable {
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
)
}

public func streamResponse<Content>(
nonisolated public func streamResponse<Content>(
to prompt: String,
generating type: Content.Type = Content.self,
includeSchemaInPrompt: Bool = true,
Expand Down Expand Up @@ -271,12 +333,14 @@ public final class LanguageModelSession {
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<String> {
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
)
}

Expand Down Expand Up @@ -309,7 +373,19 @@ public final class LanguageModelSession {
}
}

extension LanguageModelSession: @unchecked Sendable, Observable {}
private actor RespondingState {
private var count = 0

func increment() -> Int {
count += 1
return count
}

func decrement() -> Int {
count = max(0, count - 1)
return count
}
}

extension LanguageModelSession {
public enum GenerationError: Error, LocalizedError {
Expand Down Expand Up @@ -401,7 +477,7 @@ extension LanguageModelSession {
}

extension LanguageModelSession {
public struct ResponseStream<Content> where Content: Generable {
public struct ResponseStream<Content>: Sendable where Content: Generable, Content.PartiallyGenerated: Sendable {
private let content: Content
private let rawContent: GeneratedContent
private let streaming: AsyncThrowingStream<Snapshot, any Error>?
Expand All @@ -420,7 +496,7 @@ extension LanguageModelSession {
self.streaming = stream
}

public struct Snapshot {
public struct Snapshot: Sendable where Content.PartiallyGenerated: Sendable {
public var content: Content.PartiallyGenerated
public var rawContent: GeneratedContent
}
Expand Down
10 changes: 5 additions & 5 deletions Sources/AnyLanguageModel/Models/SystemLanguageModel.swift
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
/// let model = SystemLanguageModel()
/// ```
@available(macOS 26.0, iOS 26.0, watchOS 26.0, tvOS 26.0, visionOS 26.0, *)
public struct SystemLanguageModel: LanguageModel {
public actor SystemLanguageModel: LanguageModel {
/// The reason the model is unavailable.
public typealias UnavailableReason = FoundationModels.SystemLanguageModel.Availability.UnavailableReason

Expand Down Expand Up @@ -54,7 +54,7 @@
}

/// The availability status for the system language model.
public var availability: Availability<UnavailableReason> {
nonisolated public var availability: Availability<UnavailableReason> {
switch systemModel.availability {
case .available:
.available
Expand All @@ -63,7 +63,7 @@
}
}

public func respond<Content>(
nonisolated public func respond<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
Expand Down Expand Up @@ -100,7 +100,7 @@
}
}

public func streamResponse<Content>(
nonisolated public func streamResponse<Content>(
within session: LanguageModelSession,
to prompt: Prompt,
generating type: Content.Type,
Expand Down Expand Up @@ -180,7 +180,7 @@
return LanguageModelSession.ResponseStream(stream: stream)
}

public func logFeedbackAttachment(
nonisolated public func logFeedbackAttachment(
within session: LanguageModelSession,
sentiment: LanguageModelFeedback.Sentiment?,
issues: [LanguageModelFeedback.Issue],
Expand Down
48 changes: 48 additions & 0 deletions Tests/AnyLanguageModelTests/MockLanguageModelTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -55,4 +55,52 @@ struct MockLanguageModelTests {
#expect(model.availability == .unavailable(.custom("MockLanguageModel is unavailable")))
#expect(model.isAvailable == false)
}

@Test func isRespondingDuringAsyncResponse() async throws {
let model = MockLanguageModel { _, _ in
try await Task.sleep(for: .milliseconds(100))
return "Response"
}
let session = LanguageModelSession(model: model)

#expect(session.isResponding == false)

let task = Task {
try await session.respond(to: "Test")
}

try await Task.sleep(for: .milliseconds(50))
#expect(session.isResponding == true)

_ = try await task.value
try await Task.sleep(for: .milliseconds(10))
#expect(session.isResponding == false)
}

@Test func isRespondingDuringStreaming() async throws {
let model = MockLanguageModel.streamingMock()
let session = LanguageModelSession(model: model)

#expect(session.isResponding == false)

let stream = session.streamResponse(to: "Test")

// Start consuming the stream in a task
let task = Task {
for try await _ in stream {
// Just consume the stream
}
}

// Give the streaming task time to start and call beginResponding
try await Task.sleep(for: .milliseconds(50))
#expect(session.isResponding == true)

// Wait for stream to complete
_ = try await task.value

// Give time for endResponding to complete
try await Task.sleep(for: .milliseconds(10))
#expect(session.isResponding == false)
}
}
Loading