Merge pull request #121 from allenai/106-prompt-at-top
New input shown at top. Resized scroll button.
jonryser authored Dec 18, 2024
2 parents 9662975 + ce08768 commit 2fdc455
Showing 3 changed files with 118 additions and 90 deletions.
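
The gist of the change: when the user submits a new prompt, ChatView now scrolls that message to the top of the viewport and reserves extra height below it for the incoming response, instead of keeping the list pinned to the bottom (see the `proxy.scrollTo(lastMessage.id, anchor: .top)` call in the ChatView diff below). A minimal, self-contained sketch of that scroll behavior follows; the `Message` type and `PromptAtTopList` view are illustrative stand-ins, not code from the repository.

import SwiftUI

// Hypothetical stand-in for the app's Chat model (the real type lives in LLM.swift).
struct Message: Identifiable, Equatable {
    let id = UUID()
    let isUser: Bool
    let text: String
}

struct PromptAtTopList: View {
    let messages: [Message]

    var body: some View {
        ScrollViewReader { proxy in
            ScrollView {
                LazyVStack(alignment: .leading, spacing: 10) {
                    ForEach(messages) { message in
                        Text(message.text)
                            .id(message.id) // anchor target for scrollTo
                    }
                }
            }
            // When a new user message arrives, pin it to the top of the viewport
            // rather than letting the view stay scrolled to the bottom.
            .onChange(of: messages) { _, newMessages in
                guard let lastUser = newMessages.last(where: { $0.isUser }) else { return }
                withAnimation {
                    proxy.scrollTo(lastUser.id, anchor: .top)
                }
            }
        }
    }
}
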
64 changes: 29 additions & 35 deletions OLMoE.swift/Model/LLM.swift
@@ -4,7 +4,7 @@ import llama
public typealias Token = llama_token
public typealias Model = OpaquePointer

public struct Chat: Identifiable {
public struct Chat: Identifiable, Equatable {
public var id: UUID? // Optional unique identifier
public var role: Role
public var content: String
@@ -51,7 +51,7 @@ open class LLM: ObservableObject {
public var path: [CChar]
public var loopBackTestResponse: Bool = false
public var savedState: Data?

@Published public private(set) var output = ""
@MainActor public func setOutput(to newOutput: consuming String) {
output = newOutput.trimmingCharacters(in: .whitespaces)
@@ -74,7 +74,7 @@ open class LLM: ObservableObject {
private var updateProgress: (Double) -> Void = { _ in }
private var nPast: Int32 = 0 // Track number of tokens processed
private var inputTokenCount: Int32 = 0

public init(
from path: String,
stopSequence: String? = nil,
@@ -153,11 +153,11 @@ open class LLM: ObservableObject {
self.preprocess = template.preprocess
self.template = template
}

@InferenceActor
public func stop() {
guard self.inferenceTask != nil else { return }

self.inferenceTask?.cancel()
self.inferenceTask = nil
self.batch.clear()
@@ -177,13 +177,13 @@ open class LLM: ObservableObject {
print("Error: Batch is empty or invalid.")
return model.endToken
}

// Check if the batch size is within limits
guard self.batch.n_tokens < self.maxTokenCount else {
print("Error: Batch token limit exceeded.")
return model.endToken
}

guard let sampler = self.sampler else {
fatalError("Sampler not initialized")
}
@@ -210,7 +210,7 @@ open class LLM: ObservableObject {
// For example, if you have a variable tracking the current conversation context:
// currentContext = nil
}

@InferenceActor
private func tokenizeAndBatchInput(message input: borrowing String) -> Bool {
guard self.inferenceTask != nil else { return false }
@@ -224,14 +224,14 @@ open class LLM: ObservableObject {
}
for (i, token) in tokens.enumerated() {
let isLastToken = i == tokens.count - 1

self.batch.add(token, self.nPast, [0], isLastToken)
nPast += 1
}

// Check batch has not been cleared by a side effect (stop button) at the time of decoding
guard self.batch.n_tokens > 0 else { return false }

self.context.decode(self.batch)
return true
}
@@ -286,16 +286,16 @@ open class LLM: ObservableObject {
}
return true
}

@InferenceActor
private func generateResponseStream(from input: String) -> AsyncStream<String> {
AsyncStream<String> { output in
Task { [weak self] in
guard let self = self else { return output.finish() } // Safely unwrap `self`
// Use `self` safely now that it's unwrapped

guard self.inferenceTask != nil else { return output.finish() }

defer {
if !FeatureFlags.useLLMCaching {
self.context = nil
@@ -305,7 +305,7 @@ open class LLM: ObservableObject {
guard self.tokenizeAndBatchInput(message: input) else {
return output.finish()
}

var token = await self.predictNextToken()
while self.emitDecoded(token: token, to: output) {
if self.nPast >= self.maxTokenCount {
@@ -317,7 +317,7 @@ open class LLM: ObservableObject {
}
}
}

/**
Halves the llama_kv_cache by removing the oldest half of tokens and shifting the newer half to the beginning.
Updates `nPast` to reflect the reduced cache size.
@@ -327,18 +327,18 @@ open class LLM: ObservableObject {
let seq_id: Int32 = 0
let beginning: Int32 = 0
let middle = Int32(self.maxTokenCount / 2)

// Remove the oldest half
llama_kv_cache_seq_rm(self.context.pointer, seq_id, beginning, middle)

// Shift the newer half to the start
llama_kv_cache_seq_add(
self.context.pointer,
seq_id,
middle,
Int32(self.maxTokenCount), -middle
)

// Update nPast
let kvCacheTokenCount: Int32 = llama_get_kv_cache_token_count(self.context.pointer)
self.nPast = kvCacheTokenCount
@@ -364,14 +364,8 @@ open class LLM: ObservableObject {
self.inferenceTask = Task { [weak self] in
guard let self = self else { return }

let historyBeforeInput = self.history
await MainActor.run {
// Append user's message to history prior to response generation
self.history.append(Chat(role: .user, content: input))
}

self.input = input
let processedInput = self.preprocess(input, historyBeforeInput, self)
let processedInput = self.preprocess(input, self.history, self)
let responseStream = self.loopBackTestResponse
? self.getTestLoopbackResponse()
: self.generateResponseStream(from: processedInput)
Expand All @@ -392,21 +386,21 @@ open class LLM: ObservableObject {

self.postprocess(output)
}

self.inputTokenCount = 0
// Save the state after generating a response
if FeatureFlags.useLLMCaching {
self.savedState = saveState()
}

if Task.isCancelled {
return
}
}

await inferenceTask?.value
}

/**
Entry point to generate a model response from the input message
*/
@@ -415,7 +409,7 @@
if let savedState = FeatureFlags.useLLMCaching ? self.savedState : nil {
restoreState(from: savedState)
}

await performInference(to: input) { [self] response in
await setOutput(to: "")
for await responseDelta in response {
@@ -424,15 +418,15 @@
}
update(nil)
let trimmedOutput = output.trimmingCharacters(in: .whitespacesAndNewlines)


self.rollbackLastUserInputIfEmptyResponse(trimmedOutput)

await setOutput(to: trimmedOutput.isEmpty ? "..." : trimmedOutput)
return output
}
}

/**
If the model fails to produce a response (empty output), remove the last user input’s tokens
from the KV cache to prevent the model’s internal state from being "poisoned" by bad input.
@@ -498,7 +492,7 @@ extension LLM {
assert(bytesRead == stateData.count, "Error: Read state size does not match expected size.")
}
}

let beginningOfSequenceOffset: Int32 = 1
self.nPast = llama_get_kv_cache_token_count(self.context.pointer) + beginningOfSequenceOffset
}
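
An aside on the `rollbackLastUserInputIfEmptyResponse` doc comment above (its body is collapsed in this diff): a hypothetical sketch of the described idea, reusing the `llama_kv_cache_seq_rm` call that the cache-halving code already uses. The free-standing function, its parameters, and the exact range arithmetic here are assumptions for illustration, not the repository's implementation.

import llama

/// Illustrative sketch only: remove the last user input's tokens from the KV cache
/// and rewind the running token counter, so a failed (empty) generation does not
/// leave the model's internal state "poisoned" by that input.
func rollbackLastUserInput(
    context: OpaquePointer,     // llama context (self.context.pointer in LLM.swift)
    nPast: inout Int32,         // number of tokens processed so far
    inputTokenCount: Int32,     // tokens contributed by the last user input
    seqId: Int32 = 0
) {
    guard inputTokenCount > 0, inputTokenCount <= nPast else { return }
    let start = nPast - inputTokenCount
    // Drop the range [start, nPast) for this sequence from the KV cache.
    llama_kv_cache_seq_rm(context, seqId, start, nPast)
    // Rewind so the next prompt is written over the removed range.
    nPast = start
}
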
110 changes: 75 additions & 35 deletions OLMoE.swift/Views/ChatView.swift
@@ -68,7 +68,7 @@ public struct TypingIndicator: View {
}

struct ScrollState {
static let BottomScrollThreshold = 120.0
static let BottomScrollThreshold = 40.0
static let ScrollSpaceName: String = "scrollSpace"

public var scrollViewHeight: CGFloat = 0
@@ -102,51 +102,87 @@ public struct ChatView: View {
public var output: String
@Binding var isGenerating: Bool
@Binding var isScrolledToBottom: Bool
@State private var contentHeight: CGFloat = 0
@State private var newHeight: CGFloat = 0
@State private var previousHeight: CGFloat = 0
@State private var outerHeight: CGFloat = 0
@State private var scrollState = ScrollState()
@State private var lastAdjustedUserCount: Int = 0
@StateObject private var keyboardResponder = KeyboardResponder()
@State var id = UUID()


public var body: some View {
ScrollViewReader { proxy in
ScrollView {
VStack(alignment: .leading, spacing: 10) {
// History
ForEach(history) { chat in
if !chat.content.isEmpty {
switch chat.role {
case .user:
UserChatBubble(text: chat.content)
case .bot:
BotChatBubble(text: chat.content)
GeometryReader { geometry in
ScrollViewReader { proxy in
ScrollView {
VStack(alignment: .leading, spacing: 10) {
// History
ForEach(history) { chat in
if !chat.content.isEmpty {
switch chat.role {
case .user:
UserChatBubble(text: chat.content)
.id(chat.id)
case .bot:
BotChatBubble(text: chat.content)
}
}
}

// Current output
if isGenerating {
BotChatBubble(text: output, isGenerating: isGenerating)
}

Color.clear.frame(height: 1).id(ChatView.BottomID)
}

// Current output
if isGenerating {
BotChatBubble(text: output, isGenerating: isGenerating)
.font(.body.monospaced())
.foregroundColor(Color("TextColor"))
.background(scrollTracker())
.frame(minHeight: newHeight, alignment: .top)
}
.background(scrollHeightTracker())
.coordinateSpace(name: ScrollState.ScrollSpaceName)
.preferredColorScheme(.dark)
.onChange(of: history) { oldHistory, newHistory in
if let lastMessage = getLatestUserChat() {
if oldHistory.count < newHistory.count && lastMessage.role == .user {
let userMessagesCount = newHistory.filter { $0.role == .user }.count

// Only adjust height if this is a new user message count we haven't handled yet
if userMessagesCount > 1 && userMessagesCount > lastAdjustedUserCount {
// Set new height based on current content plus outer height
self.newHeight = self.contentHeight + self.outerHeight
self.lastAdjustedUserCount = userMessagesCount

DispatchQueue.main.asyncAfter(deadline: .now() + 0.2) {
withAnimation {
proxy.scrollTo(lastMessage.id, anchor: .top)
}
}
}
}
}

Color.clear.frame(height: 1).id(ChatView.BottomID)
}
.font(.body.monospaced())
.foregroundColor(Color("TextColor"))
.background(scrollTracker())
}
.background(scrollHeightTracker())
.coordinateSpace(name: ScrollState.ScrollSpaceName)
.preferredColorScheme(.dark)
.onChange(of: keyboardResponder.keyboardHeight) { _,newHeight in
let keyboardIsVisible = newHeight > 0
if keyboardIsVisible {
id = UUID() // Trigger refresh by changing the id
.onChange(of: keyboardResponder.keyboardHeight) { oldKeyboardHeight, newKeyboardHeight in
self.previousHeight = self.newHeight
self.contentHeight = scrollState.contentHeight
let keyboardIsVisible = newKeyboardHeight > 0
if keyboardIsVisible {
let newHeight = self.newHeight - newKeyboardHeight
self.newHeight = max(newHeight, self.outerHeight)
DispatchQueue.main.asyncAfter(deadline: .now() + 0.1) {
withAnimation {
proxy.scrollTo(ChatView.BottomID, anchor: .bottom)
}
}
} else {
self.newHeight = self.previousHeight
}
}
}
.onAppear() {
// Scroll on refresh
proxy.scrollTo(ChatView.BottomID, anchor: .bottom)
.onAppear {
self.outerHeight = geometry.size.height
}
.id(id)
}
}

@@ -180,6 +216,10 @@
}
}
}

private func getLatestUserChat() -> Chat? {
return self.history.last(where: { $0.role == .user })
}
}

#Preview("Replying") {