Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
85 changes: 84 additions & 1 deletion Sources/AnyLanguageModel/GeneratedContent.swift
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ import CoreFoundation
/// A type that represents structured, generated content.
///
/// Generated content may contain a single value, an array, or key-value pairs with unique keys.
public struct GeneratedContent: Sendable, Equatable, Generable, CustomDebugStringConvertible {
public struct GeneratedContent: Sendable, Equatable, Generable, CustomDebugStringConvertible, Codable {
/// An instance of the generation schema.
public static var generationSchema: GenerationSchema {
// GeneratedContent is self-describing, it doesn't have a fixed schema
Expand Down Expand Up @@ -391,3 +391,86 @@ public enum GeneratedContentError: Error {
case typeMismatch
case neverCannotBeInstantiated
}

// MARK: - Codable

extension GeneratedContent {
    /// Keys of the keyed container that holds a serialized `GeneratedContent`.
    private enum CodingKeys: String, CodingKey {
        case id
        case kind
    }

    /// Creates generated content by reading it from `decoder`.
    ///
    /// The `id` entry is optional and is read with `decodeIfPresent`;
    /// the `kind` entry is required.
    ///
    /// - Parameter decoder: The decoder to read from.
    /// - Throws: Any error raised while opening the keyed container
    ///   or while decoding one of its entries.
    public init(from decoder: Decoder) throws {
        let values = try decoder.container(keyedBy: CodingKeys.self)

        // Read in the same order the entries are written: identifier first, then payload.
        let decodedID = try values.decodeIfPresent(GenerationID.self, forKey: .id)
        let decodedKind = try values.decode(Kind.self, forKey: .kind)

        self.id = decodedID
        self.kind = decodedKind
    }

    /// Writes this generated content into `encoder` as a keyed container.
    ///
    /// The `id` entry is written with `encodeIfPresent`; the `kind` entry is always written.
    ///
    /// - Parameter encoder: The encoder to write to.
    /// - Throws: Any error raised while encoding one of the entries.
    public func encode(to encoder: Encoder) throws {
        var values = encoder.container(keyedBy: CodingKeys.self)
        try values.encodeIfPresent(self.id, forKey: .id)
        try values.encode(self.kind, forKey: .kind)
    }
}

extension GeneratedContent.Kind: Codable {
    /// Keys of the keyed container that holds a serialized `Kind`.
    ///
    /// `type` carries the case discriminator. Scalar and array payloads are stored
    /// under `value`; a structure stores `properties` and `orderedKeys` instead.
    private enum CodingKeys: String, CodingKey {
        case type
        case value
        case properties
        case orderedKeys
    }

    /// The string written under the `type` key for each case.
    ///
    /// Raw values are spelled out explicitly because they are part of the serialized format.
    private enum Discriminator: String {
        case null = "null"
        case bool = "bool"
        case number = "number"
        case string = "string"
        case array = "array"
        case structure = "structure"
    }

    /// The discriminator that identifies this case in serialized form.
    private var discriminator: Discriminator {
        switch self {
        case .null: return .null
        case .bool: return .bool
        case .number: return .number
        case .string: return .string
        case .array: return .array
        case .structure: return .structure
        }
    }

    /// Creates a kind by reading its discriminator and the matching payload from `decoder`.
    ///
    /// - Parameter decoder: The decoder to read from.
    /// - Throws: `DecodingError.dataCorrupted` when the `type` entry holds an unrecognized
    ///   string, or any error raised while decoding the discriminator or the payload.
    public init(from decoder: Decoder) throws {
        let keyed = try decoder.container(keyedBy: CodingKeys.self)
        let type = try keyed.decode(String.self, forKey: .type)

        guard let discriminator = Discriminator(rawValue: type) else {
            throw DecodingError.dataCorruptedError(
                forKey: .type,
                in: keyed,
                debugDescription: "Unknown kind type: \(type)"
            )
        }

        switch discriminator {
        case .null:
            self = .null
        case .bool:
            let flag = try keyed.decode(Bool.self, forKey: .value)
            self = .bool(flag)
        case .number:
            let number = try keyed.decode(Double.self, forKey: .value)
            self = .number(number)
        case .string:
            let text = try keyed.decode(String.self, forKey: .value)
            self = .string(text)
        case .array:
            let elements = try keyed.decode([GeneratedContent].self, forKey: .value)
            self = .array(elements)
        case .structure:
            // Properties are read before the key ordering, matching the order they are written.
            let properties = try keyed.decode([String: GeneratedContent].self, forKey: .properties)
            let orderedKeys = try keyed.decode([String].self, forKey: .orderedKeys)
            self = .structure(properties: properties, orderedKeys: orderedKeys)
        }
    }

    /// Writes this kind into `encoder` as a discriminator followed by the case's payload.
    ///
    /// - Parameter encoder: The encoder to write to.
    /// - Throws: Any error raised while encoding the discriminator or the payload.
    public func encode(to encoder: Encoder) throws {
        var keyed = encoder.container(keyedBy: CodingKeys.self)

        // The discriminator is always written first, regardless of the case.
        try keyed.encode(discriminator.rawValue, forKey: .type)

        switch self {
        case .null:
            // No payload accompanies the null case.
            break
        case .bool(let flag):
            try keyed.encode(flag, forKey: .value)
        case .number(let number):
            try keyed.encode(number, forKey: .value)
        case .string(let text):
            try keyed.encode(text, forKey: .value)
        case .array(let elements):
            try keyed.encode(elements, forKey: .value)
        case .structure(let properties, let orderedKeys):
            try keyed.encode(properties, forKey: .properties)
            try keyed.encode(orderedKeys, forKey: .orderedKeys)
        }
    }
}
2 changes: 1 addition & 1 deletion Sources/AnyLanguageModel/GenerationID.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ import struct Foundation.UUID
/// }
/// }
/// ```
public struct GenerationID: Sendable, Hashable {
public struct GenerationID: Sendable, Hashable, Codable {
private let uuid: UUID

/// Create a new, unique `GenerationID`.
Expand Down
6 changes: 3 additions & 3 deletions Sources/AnyLanguageModel/GenerationOptions.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
/// perform various adjustments on how the model chooses output tokens,
/// to specify the penalties for repeating tokens or generating
/// longer responses.
public struct GenerationOptions: Sendable, Equatable {
public struct GenerationOptions: Sendable, Equatable, Codable {

/// A sampling strategy for how the model picks tokens when generating a
/// response.
Expand Down Expand Up @@ -83,8 +83,8 @@ extension GenerationOptions {
/// loop the model produces a probability distribution for all the tokens in its
/// vocabulary. The sampling mode controls how a token is selected from that
/// distribution.
public struct SamplingMode: Sendable, Equatable {
enum Mode: Equatable {
public struct SamplingMode: Sendable, Equatable, Codable {
enum Mode: Equatable, Codable {
case greedy
case topK(Int, seed: UInt64?)
case nucleus(Double, seed: UInt64?)
Expand Down
145 changes: 106 additions & 39 deletions Sources/AnyLanguageModel/LanguageModelSession.swift
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,28 @@ public final class LanguageModelSession: @unchecked Sendable {
self.model = model
self.tools = tools
self.instructions = instructions
self.transcript = transcript

// Build transcript with instructions if provided and not already in transcript
var finalTranscript = transcript
if let instructions = instructions {
// Only add instructions if transcript doesn't already start with instructions
let hasInstructions =
finalTranscript.first.map { entry in
if case .instructions = entry { return true } else { return false }
} ?? false

if !hasInstructions {
let instructionsEntry = Transcript.Entry.instructions(
Transcript.Instructions(
segments: [.text(.init(content: instructions.description))],
toolDefinitions: tools.map { Transcript.ToolDefinition(tool: $0) }
)
)
finalTranscript.append(instructionsEntry)
}
}

self.transcript = finalTranscript
}

public func prewarm(promptPrefix: Prompt? = nil) {
Expand Down Expand Up @@ -89,18 +110,47 @@ public final class LanguageModelSession: @unchecked Sendable {
}

nonisolated private func wrapStream<Content>(
_ upstream: sending ResponseStream<Content>
_ upstream: sending ResponseStream<Content>,
promptEntry: Transcript.Entry
) -> ResponseStream<Content> where Content: Generable, Content.PartiallyGenerated: Sendable {
let session = self
let relay = AsyncThrowingStream<ResponseStream<Content>.Snapshot, any Error> { continuation in
let stream = upstream
Task {
// Add prompt to transcript when stream starts
await MainActor.run {
session.transcript.append(promptEntry)
}

await session.beginResponding()
var lastSnapshot: ResponseStream<Content>.Snapshot?
do {
for try await snapshot in stream {
lastSnapshot = snapshot
continuation.yield(snapshot)
}
continuation.finish()

// Add response to transcript after stream completes
if let lastSnapshot {
// Extract text content from the generated content
let textContent: String
if case .string(let str) = lastSnapshot.rawContent.kind {
textContent = str
} else {
textContent = lastSnapshot.rawContent.jsonString
}
Comment thread
mattt marked this conversation as resolved.

let responseEntry = Transcript.Entry.response(
Transcript.Response(
assetIDs: [],
segments: [.text(.init(content: textContent))]
)
)
await MainActor.run {
session.transcript.append(responseEntry)
}
}
} catch {
continuation.finish(throwing: error)
}
Expand All @@ -121,15 +171,12 @@ public final class LanguageModelSession: @unchecked Sendable {
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<String> {
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
}
try await respond(
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
}

@discardableResult
Expand All @@ -155,15 +202,12 @@ public final class LanguageModelSession: @unchecked Sendable {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) async throws -> Response<GeneratedContent> {
try await wrapRespond {
try await model.respond(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}
try await respond(
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}

@discardableResult
Expand Down Expand Up @@ -204,13 +248,32 @@ public final class LanguageModelSession: @unchecked Sendable {
options: GenerationOptions = GenerationOptions()
) async throws -> Response<Content> where Content: Generable {
try await wrapRespond {
try await model.respond(
// Add prompt to transcript
let promptEntry = Transcript.Entry.prompt(
Transcript.Prompt(
segments: [.text(.init(content: prompt.description))],
options: options,
responseFormat: nil
)
)
await MainActor.run {
self.transcript.append(promptEntry)
}

let response = try await model.respond(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)

// Add response entries to transcript
await MainActor.run {
self.transcript.append(contentsOf: response.transcriptEntries)
}

return response
}
}

Expand Down Expand Up @@ -250,14 +313,11 @@ public final class LanguageModelSession: @unchecked Sendable {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<GeneratedContent> {
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
streamResponse(
to: prompt,
generating: GeneratedContent.self,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
}

Expand Down Expand Up @@ -290,14 +350,24 @@ public final class LanguageModelSession: @unchecked Sendable {
includeSchemaInPrompt: Bool = true,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<Content> where Content: Generable {
wrapStream(
// Create prompt entry that will be added when stream starts
let promptEntry = Transcript.Entry.prompt(
Transcript.Prompt(
segments: [.text(.init(content: prompt.description))],
options: options,
responseFormat: nil
)
)

return wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: type,
includeSchemaInPrompt: includeSchemaInPrompt,
options: options
)
),
promptEntry: promptEntry
)
}

Expand Down Expand Up @@ -333,14 +403,11 @@ public final class LanguageModelSession: @unchecked Sendable {
to prompt: Prompt,
options: GenerationOptions = GenerationOptions()
) -> sending ResponseStream<String> {
wrapStream(
model.streamResponse(
within: self,
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
streamResponse(
to: prompt,
generating: String.self,
includeSchemaInPrompt: true,
options: options
)
}

Expand Down
Loading