From 41a002bf8cb1239b55e698242eaf7109a83bdf68 Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Wed, 6 Nov 2024 10:25:02 +1100 Subject: [PATCH 1/4] FEATURE: AI artifacts Initial implementation of an artifact system which allows users to generate HTML pages directly from the AI persona. FEATURE: support tool progress callbacks This is anthropic only for now, but we can get a callback as tool is completing, this gives us the ability to show progress to user as the function is populating. work in progress Revert "work in progress" This reverts commit 30ebe562ea4e5601701ff89af80b964e1cc6af5e. Revert "FEATURE: support tool progress callbacks" This reverts commit fd7ccfd0ab538d2845b096d3902dc06bdbf5cf1f. --- .../ai_bot/artifacts_controller.rb | 45 +++++ app/models/ai_artifact.rb | 22 +++ .../javascripts/initializers/ai-artifacts.js | 54 ++++++ .../lib/discourse-markdown/ai-tags.js | 5 + .../modules/ai-bot/common/ai-artifact.scss | 56 ++++++ config/locales/server.en.yml | 3 + config/routes.rb | 4 + db/migrate/20241104053017_add_ai_artifacts.rb | 16 ++ lib/ai_bot/personas/persona.rb | 1 + lib/ai_bot/tools/create_artifact.rb | 159 ++++++++++++++++++ plugin.rb | 2 + 11 files changed, 367 insertions(+) create mode 100644 app/controllers/discourse_ai/ai_bot/artifacts_controller.rb create mode 100644 app/models/ai_artifact.rb create mode 100644 assets/javascripts/initializers/ai-artifacts.js create mode 100644 assets/stylesheets/modules/ai-bot/common/ai-artifact.scss create mode 100644 db/migrate/20241104053017_add_ai_artifacts.rb create mode 100644 lib/ai_bot/tools/create_artifact.rb diff --git a/app/controllers/discourse_ai/ai_bot/artifacts_controller.rb b/app/controllers/discourse_ai/ai_bot/artifacts_controller.rb new file mode 100644 index 000000000..c9aa3c855 --- /dev/null +++ b/app/controllers/discourse_ai/ai_bot/artifacts_controller.rb @@ -0,0 +1,45 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class ArtifactsController < ApplicationController + + requires_plugin DiscourseAi::PLUGIN_NAME + + skip_before_action :preload_json, :check_xhr, only: %i[show] + + def show + artifact = AiArtifact.find(params[:id]) + + post = Post.find_by(id: artifact.post_id) + raise Discourse::NotFound unless post && guardian.can_see?(post) + + # Prepare the HTML document + html = <<~HTML + + + + + #{ERB::Util.html_escape(artifact.name)} + + + + #{artifact.html} + + + + HTML + + response.headers.delete("X-Frame-Options") + response.headers.delete("Content-Security-Policy") + + # Render the content + render html: html.html_safe, layout: false, content_type: "text/html" + end + end + end +end diff --git a/app/models/ai_artifact.rb b/app/models/ai_artifact.rb new file mode 100644 index 000000000..e44518275 --- /dev/null +++ b/app/models/ai_artifact.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true + +class AiArtifact < ActiveRecord::Base + belongs_to :user + belongs_to :post +end + +# == Schema Information +# +# Table name: ai_artifacts +# +# id :bigint not null, primary key +# user_id :integer not null +# post_id :integer not null +# name :string(255) not null +# html :string(65535) +# css :string(65535) +# js :string(65535) +# metadata :jsonb +# created_at :datetime not null +# updated_at :datetime not null +# diff --git a/assets/javascripts/initializers/ai-artifacts.js b/assets/javascripts/initializers/ai-artifacts.js new file mode 100644 index 000000000..566d18d6b --- /dev/null +++ b/assets/javascripts/initializers/ai-artifacts.js @@ -0,0 +1,54 @@ +import { withPluginApi } from "discourse/lib/plugin-api"; + +function initializeAiArtifactTabs(api) { + api.decorateCooked( + ($element) => { + const element = $element[0]; + const artifacts = element.querySelectorAll(".ai-artifact"); + if (!artifacts.length) { + return; + } + + artifacts.forEach((artifact) => { + const tabs = artifact.querySelectorAll(".ai-artifact-tab"); + const panels = artifact.querySelectorAll(".ai-artifact-panel"); + + tabs.forEach((tab) => { + tab.addEventListener("click", (e) => { + e.preventDefault(); + + if (tab.hasAttribute("data-selected")) { + return; + } + + const tabType = Object.keys(tab.dataset).find( + (key) => key !== "selected" + ); + + tabs.forEach((t) => t.removeAttribute("data-selected")); + panels.forEach((p) => p.removeAttribute("data-selected")); + + tab.setAttribute("data-selected", ""); + const targetPanel = artifact.querySelector( + `.ai-artifact-panel[data-${tabType}]` + ); + if (targetPanel) { + targetPanel.setAttribute("data-selected", ""); + } + }); + }); + }); + }, + { + id: "ai-artifact-tabs", + onlyStream: false, + } + ); +} + +export default { + name: "ai-artifact-tabs", + initialize() { + withPluginApi("0.8.7", initializeAiArtifactTabs); + }, +}; diff --git a/assets/javascripts/lib/discourse-markdown/ai-tags.js b/assets/javascripts/lib/discourse-markdown/ai-tags.js index c2d9b6726..65d71b5bc 100644 --- a/assets/javascripts/lib/discourse-markdown/ai-tags.js +++ b/assets/javascripts/lib/discourse-markdown/ai-tags.js @@ -1,3 +1,8 @@ export function setup(helper) { helper.allowList(["details[class=ai-quote]"]); + helper.allowList(["div[class=ai-artifact]"]); + helper.allowList(["div[class=ai-artifact-tab]"]); + helper.allowList(["div[class=ai-artifact-tabs]"]); + helper.allowList(["div[class=ai-artifact-panels]"]); + helper.allowList(["div[class=ai-artifact-panel]"]); } diff --git a/assets/stylesheets/modules/ai-bot/common/ai-artifact.scss b/assets/stylesheets/modules/ai-bot/common/ai-artifact.scss new file mode 100644 index 000000000..cf405159f --- /dev/null +++ b/assets/stylesheets/modules/ai-bot/common/ai-artifact.scss @@ -0,0 +1,56 @@ +.ai-artifact { + margin: 1em 0; + + .ai-artifact-tabs { + display: flex; + gap: 0.20em; + border-bottom: 2px solid var(--primary-low); + padding: 0 0.2em; + + .ai-artifact-tab { + margin-bottom: -2px; + + &[data-selected] { + a { + color: var(--tertiary); + font-weight: 500; + border-bottom: 2px solid var(--tertiary); + } + } + + &:hover:not([data-selected]) { + a { + color: var(--primary); + background: var(--primary-very-low); + } + } + + a { + display: block; + padding: 0.5em 1em; + color: var(--primary-medium); + text-decoration: none; + cursor: pointer; + border-bottom: 2px solid transparent; + } + } + } + + .ai-artifact-panels { + padding: 1em 0 0 0; + background: var(--blend-primary-secondary-5); + + .ai-artifact-panel { + display: none; + min-height: 400px; + + &[data-selected] { + display: block; + } + + pre { + margin: 0; + } + } + } +} diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index d18a55c8d..fbafba5c5 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -222,6 +222,7 @@ en: name: "Base Search Query" description: "Base query to use when searching. Example: '#urgent' will prepend '#urgent' to the search query and only include topics with the urgent category or tag." tool_summary: + create_artifact: "Create web artifact" web_browser: "Browse Web" github_search_files: "GitHub search files" github_search_code: "GitHub code search" @@ -243,6 +244,7 @@ en: search_meta_discourse: "Search Meta Discourse" javascript_evaluator: "Evaluate JavaScript" tool_help: + create_artifact: "Create a web artifact using the AI Bot" web_browser: "Browse web page using the AI Bot" github_search_code: "Search for code in a GitHub repository" github_search_files: "Search for files in a GitHub repository" @@ -264,6 +266,7 @@ en: search_meta_discourse: "Search Meta Discourse" javascript_evaluator: "Evaluate JavaScript" tool_description: + create_artifact: "Created a web artifact using the AI Bot" web_browser: "Reading %{url}" github_search_files: "Searched for '%{keywords}' in %{repo}/%{branch}" github_search_code: "Searched for '%{query}' in %{repo}" diff --git a/config/routes.rb b/config/routes.rb index 322e67ce1..6161ab176 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -33,6 +33,10 @@ get "/preview/:topic_id" => "shared_ai_conversations#preview" end + scope module: :ai_bot, path: "/ai-bot/artifacts" do + get "/:id" => "artifacts#show" + end + scope module: :summarization, path: "/summarization", defaults: { format: :json } do get "/t/:topic_id" => "summary#show", :constraints => { topic_id: /\d+/ } get "/channels/:channel_id" => "chat_summary#show" diff --git a/db/migrate/20241104053017_add_ai_artifacts.rb b/db/migrate/20241104053017_add_ai_artifacts.rb new file mode 100644 index 000000000..895692e6d --- /dev/null +++ b/db/migrate/20241104053017_add_ai_artifacts.rb @@ -0,0 +1,16 @@ +# frozen_string_literal: true +class AddAiArtifacts < ActiveRecord::Migration[7.1] + def change + create_table :ai_artifacts do |t| + t.integer :user_id, null: false + t.integer :post_id, null: false + t.string :name, null: false, limit: 255 + t.string :html, limit: 65535 # ~64KB limit + t.string :css, limit: 65535 # ~64KB limit + t.string :js, limit: 65535 # ~64KB limit + t.jsonb :metadata # For any additional properties + + t.timestamps + end + end +end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 63255a172..e49787b5e 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -96,6 +96,7 @@ def all_available_tools Tools::GithubSearchFiles, Tools::WebBrowser, Tools::JavascriptEvaluator, + Tools::CreateArtifact, ] tools << Tools::GithubSearchCode if SiteSetting.ai_bot_github_access_token.present? diff --git a/lib/ai_bot/tools/create_artifact.rb b/lib/ai_bot/tools/create_artifact.rb new file mode 100644 index 000000000..ebbf54aab --- /dev/null +++ b/lib/ai_bot/tools/create_artifact.rb @@ -0,0 +1,159 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + module Tools + class CreateArtifact < Tool + def self.name + "create_artifact" + end + + def self.signature + { + name: "create_artifact", + description: + "Creates a web artifact with HTML, CSS, and JavaScript that can be displayed in an iframe", + parameters: [ + { + name: "name", + description: "A name for the artifact (max 255 chars)", + type: "string", + required: true, + }, + { + name: "html_content", + description: "The HTML content for the artifact", + type: "string", + required: true, + }, + { name: "css", description: "Optional CSS styles for the artifact", type: "string" }, + { + name: "js", + description: + "Optional + JavaScript code for the artifact", + type: "string", + }, + ], + } + end + + def invoke + # Get the current post from context + post = Post.find_by(id: context[:post_id]) + return error_response("No post context found") unless post + + html = parameters[:html_content].to_s + css = parameters[:css].to_s + js = parameters[:js].to_s + + # Create the artifact + artifact = + AiArtifact.new( + user_id: bot_user.id, + post_id: post.id, + name: parameters[:name].to_s[0...255], + html: html, + css: css, + js: js, + metadata: parameters[:metadata], + ) + + if artifact.save + tabs = { + css: [css, "CSS"], + js: [js, "JavaScript"], + html: [html, "HTML"], + preview: [ + "", + "Preview", + ], + } + + first = true + html_tabs = + tabs.map do |tab, (content, name)| + selected = " data-selected" if first + first = false + (<<~HTML).strip +
+ #{name} +
+ HTML + end + + first = true + html_panels = + tabs.map do |tab, (content, name)| + selected = " data-selected" if first + first = false + inner_content = + if tab == :preview + content + else + <<~HTML + + ```#{tab} + #{content} + ``` + HTML + end + (<<~HTML).strip +
+ + #{inner_content} +
+ HTML + end + + self.custom_raw = <<~RAW +
+
+ #{html_tabs.join("\n")} +
+
+ #{html_panels.join("\n")} +
+
+ RAW + + success_response(artifact) + else + error_response(artifact.errors.full_messages.join(", ")) + end + end + + def chain_next_response? + @chain_next_response + end + + private + + def success_response(artifact) + @chain_next_response = false + iframe_url = "#{Discourse.base_url}/discourse-ai/ai-bot/artifacts/#{artifact.id}" + + { + status: "success", + artifact_id: artifact.id, + iframe_html: + "", + message: "Artifact created successfully and rendered to user.", + } + end + + def error_response(message) + @chain_next_response = false + + { status: "error", error: message } + end + + def help + "Creates a web artifact with HTML, CSS, and JavaScript that can be displayed in an iframe. " \ + "Requires a name and HTML content. CSS and JavaScript are optional. " \ + "The artifact will be associated with the current post and can be displayed using an iframe." + end + end + end + end +end diff --git a/plugin.rb b/plugin.rb index bb4a320a1..9d3baf75c 100644 --- a/plugin.rb +++ b/plugin.rb @@ -39,6 +39,8 @@ register_asset "stylesheets/modules/ai-bot/common/ai-tools.scss" +register_asset "stylesheets/modules/ai-bot/common/ai-artifact.scss" + module ::DiscourseAi PLUGIN_NAME = "discourse-ai" end From 1eb19933b5fb803b04bd081b0d4fb4ec37f58b3a Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Thu, 14 Nov 2024 11:37:29 +1100 Subject: [PATCH 2/4] Handle malformed gemini replies --- lib/ai_bot/tools/create_artifact.rb | 6 ++-- lib/completions/endpoints/gemini.rb | 33 +++++++++---------- spec/lib/completions/endpoints/gemini_spec.rb | 31 +++++++++++++++++ 3 files changed, 50 insertions(+), 20 deletions(-) diff --git a/lib/ai_bot/tools/create_artifact.rb b/lib/ai_bot/tools/create_artifact.rb index ebbf54aab..231945099 100644 --- a/lib/ai_bot/tools/create_artifact.rb +++ b/lib/ai_bot/tools/create_artifact.rb @@ -21,8 +21,8 @@ def self.signature required: true, }, { - name: "html_content", - description: "The HTML content for the artifact", + name: "html_body", + description: "The HTML content for the BODY tag (do not include the BODY tag)", type: "string", required: true, }, @@ -43,7 +43,7 @@ def invoke post = Post.find_by(id: context[:post_id]) return error_response("No post context found") unless post - html = parameters[:html_content].to_s + html = parameters[:html_body].to_s css = parameters[:css].to_s js = parameters[:js].to_s diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 2450dc99e..c3afc313d 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -173,25 +173,24 @@ def decode_chunk(chunk) .decode(chunk) .map do |parsed| update_usage(parsed) - parsed - .dig(:candidates, 0, :content, :parts) - .map do |part| - if part[:text] - part = part[:text] - if part != "" - part - else - nil - end - elsif part[:functionCall] - @tool_index += 1 - ToolCall.new( - id: "tool_#{@tool_index}", - name: part[:functionCall][:name], - parameters: part[:functionCall][:args], - ) + parts = parsed.dig(:candidates, 0, :content, :parts) + parts&.map do |part| + if part[:text] + part = part[:text] + if part != "" + part + else + nil end + elsif part[:functionCall] + @tool_index += 1 + ToolCall.new( + id: "tool_#{@tool_index}", + name: part[:functionCall][:name], + parameters: part[:functionCall][:args], + ) end + end end .flatten .compact diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 189338438..0c7b92088 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -324,6 +324,37 @@ def tool_response expect(log.response_tokens).to eq(4) end + it "Can correctly handle malformed responses" do + response = <<~TEXT + data: {"candidates": [{"content": {"parts": [{"text": "Certainly"}],"role": "model"}}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": "! I'll create a simple \\"Hello, World!\\" page where each letter"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": " has a different color using inline styles for simplicity. Each letter will be wrapped"}],"role": "model"},"safetyRatings": [{"category": "HARM_CATEGORY_HATE_SPEECH","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_DANGEROUS_CONTENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_HARASSMENT","probability": "NEGLIGIBLE"},{"category": "HARM_CATEGORY_SEXUALLY_EXPLICIT","probability": "NEGLIGIBLE"}]}],"usageMetadata": {"promptTokenCount": 399,"totalTokenCount": 399},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"content": {"parts": [{"text": ""}],"role": "model"},"finishReason": "STOP"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"} + + data: {"candidates": [{"finishReason": "MALFORMED_FUNCTION_CALL"}],"usageMetadata": {"promptTokenCount": 399,"candidatesTokenCount": 191,"totalTokenCount": 590},"modelVersion": "gemini-1.5-pro-002"} + + TEXT + + llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + url = "#{model.url}:streamGenerateContent?alt=sse&key=123" + + output = [] + + stub_request(:post, url).to_return(status: 200, body: response) + llm.generate("Hello", user: user) { |partial| output << partial } + + expect(output).to eq( + [ + "Certainly", + "! I'll create a simple \"Hello, World!\" page where each letter", + " has a different color using inline styles for simplicity. Each letter will be wrapped", + ], + ) + end + it "Can correctly handle streamed responses even if they are chunked badly" do data = +"" data << "da|ta: |" From 7300737e65bf2ae24dd6a913db4d3698b7ebcdfa Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Thu, 14 Nov 2024 17:24:39 +1100 Subject: [PATCH 3/4] - halt after tools - post streamer ensures we don't have half completed stuff on screen when a tool is slow - reimplemnt xml tools to have a more relaxed parse --- lib/ai_bot/bot.rb | 7 ++ lib/ai_bot/playground.rb | 23 +++---- lib/ai_bot/post_streamer.rb | 58 +++++++++++++++++ lib/ai_bot/tools/create_artifact.rb | 1 + lib/completions/xml_tool_processor.rb | 65 ++++++++++++------- .../completions/xml_tool_processor_spec.rb | 23 ++++++- 6 files changed, 140 insertions(+), 37 deletions(-) create mode 100644 lib/ai_bot/post_streamer.rb diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index c00d5b650..420c848ea 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -106,6 +106,8 @@ def reply(context, &update_blk) tool_found = false force_tool_if_needed(prompt, context) + tool_halted = false + result = llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context) @@ -122,7 +124,12 @@ def reply(context, &update_blk) process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) tools_ran += 1 ongoing_chain &&= tool.chain_next_response? + + if !tool.chain_next_response? + tool_halted = true + end else + next if tool_halted needs_newlines = true if partial.is_a?(DiscourseAi::Completions::ToolCall) Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}") diff --git a/lib/ai_bot/playground.rb b/lib/ai_bot/playground.rb index 3873ebc61..984dd3c63 100644 --- a/lib/ai_bot/playground.rb +++ b/lib/ai_bot/playground.rb @@ -399,7 +399,7 @@ def reply_to(post, custom_instructions: nil, &blk) PostCustomPrompt.none reply = +"" - start = Time.now + post_streamer = nil post_type = post.post_type == Post.types[:whisper] ? Post.types[:whisper] : Post.types[:regular] @@ -448,6 +448,8 @@ def reply_to(post, custom_instructions: nil, &blk) context[:skip_tool_details] ||= !bot.persona.class.tool_details + post_streamer = PostStreamer.new(delay: Rails.env.test? ? 0 : 0.5) if stream_reply + new_custom_prompts = bot.reply(context) do |partial, cancel, placeholder, type| reply << partial @@ -461,22 +463,20 @@ def reply_to(post, custom_instructions: nil, &blk) reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) end - if stream_reply - # Minor hack to skip the delay during tests. - if placeholder.blank? - next if (Time.now - start < 0.5) && !Rails.env.test? - start = Time.now - end - - Discourse.redis.expire(redis_stream_key, 60) - - publish_update(reply_post, { raw: raw }) + if post_streamer + post_streamer.run_later { + Discourse.redis.expire(redis_stream_key, 60) + publish_update(reply_post, { raw: raw }) + } end end return if reply.blank? if stream_reply + post_streamer.finish + post_streamer = nil + # land the final message prior to saving so we don't clash reply_post.cooked = PrettyText.cook(reply) publish_final_update(reply_post) @@ -514,6 +514,7 @@ def reply_to(post, custom_instructions: nil, &blk) reply_post ensure + post_streamer&.finish(skip_callback: true) publish_final_update(reply_post) if stream_reply if reply_post && post.post_number == 1 && post.topic.private_message? title_playground(reply_post) diff --git a/lib/ai_bot/post_streamer.rb b/lib/ai_bot/post_streamer.rb new file mode 100644 index 000000000..57ba3c408 --- /dev/null +++ b/lib/ai_bot/post_streamer.rb @@ -0,0 +1,58 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class PostStreamer + def initialize(delay: 0.5) + @mutex = Mutex.new + @callback = nil + @delay = delay + @done = false + end + + def run_later(&callback) + @mutex.synchronize { @callback = callback } + ensure_worker! + end + + def finish(skip_callback: false) + @mutex.synchronize do + @callback&.call if skip_callback + @callback = nil + @done = true + end + + begin + @worker_thread&.wakeup + rescue StandardError + ThreadError + end + @worker_thread&.join + @worker_thread = nil + end + + private + + def run + while !@done + @mutex.synchronize do + callback = @callback + @callback = nil + callback&.call + end + sleep @delay + end + end + + def ensure_worker! + return if @worker_thread + @mutex.synchronize do + return if @worker_thread + db = RailsMultisite::ConnectionManagement.current_db + @worker_thread = + Thread.new { RailsMultisite::ConnectionManagement.with_connection(db) { run } } + end + end + end + end +end diff --git a/lib/ai_bot/tools/create_artifact.rb b/lib/ai_bot/tools/create_artifact.rb index 231945099..5cf67ddcf 100644 --- a/lib/ai_bot/tools/create_artifact.rb +++ b/lib/ai_bot/tools/create_artifact.rb @@ -39,6 +39,7 @@ def self.signature end def invoke + yield parameters[:name] || "Web Artifact" # Get the current post from context post = Post.find_by(id: context[:post_id]) return error_response("No post context found") unless post diff --git a/lib/completions/xml_tool_processor.rb b/lib/completions/xml_tool_processor.rb index 1b42b333c..c1ae9d229 100644 --- a/lib/completions/xml_tool_processor.rb +++ b/lib/completions/xml_tool_processor.rb @@ -62,31 +62,14 @@ def <<(text) def finish return [] if @function_buffer.blank? - xml = Nokogiri::HTML5.fragment(@function_buffer) - normalize_function_ids!(xml) - last_invoke = xml.at("invoke:last") - if last_invoke - last_invoke.next_sibling.remove while last_invoke.next_sibling - xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling + idx = -1 + parse_malformed_xml(@function_buffer).map do |tool| + ToolCall.new( + id: "tool_#{idx += 1}", + name: tool[:tool_name], + parameters: tool[:parameters] + ) end - - xml - .css("invoke") - .map do |invoke| - tool_name = invoke.at("tool_name").content.force_encoding("UTF-8") - tool_id = invoke.at("tool_id").content.force_encoding("UTF-8") - parameters = {} - invoke - .at("parameters") - &.children - &.each do |node| - next if node.text? - name = node.name - value = node.content.to_s - parameters[name.to_sym] = value.to_s.force_encoding("UTF-8") - end - ToolCall.new(id: tool_id, name: tool_name, parameters: parameters) - end end def should_cancel? @@ -95,6 +78,40 @@ def should_cancel? private + def parse_malformed_xml(input) + input + .scan( + %r{ + + \s* + + ([^<]+) + + \s* + + (.*?) + + \s* + + }mx, + ) + .map do |tool_name, params| + { + tool_name: tool_name.strip, + parameters: + params + .scan(%r{ + <([^>]+)> + (.*?) + + }mx) + .each_with_object({}) do |(name, value), hash| + hash[name.to_sym] = value.gsub(/^$/, "") + end, + } + end + end + def normalize_function_ids!(function_buffer) function_buffer .css("invoke") diff --git a/spec/lib/completions/xml_tool_processor_spec.rb b/spec/lib/completions/xml_tool_processor_spec.rb index 003f4356c..ad9c1b477 100644 --- a/spec/lib/completions/xml_tool_processor_spec.rb +++ b/spec/lib/completions/xml_tool_processor_spec.rb @@ -12,6 +12,26 @@ expect(processor.should_cancel?).to eq(false) end + it "can handle mix and match xml cause tool llms may not encode" do + xml = (<<~XML).strip + + + hello + + world sam + \n\n]]> + + + XML + + result = [] + result << (processor << xml) + result << (processor.finish) + + tool_call = result.last.first + expect(tool_call.parameters).to eq(hello: "world sam", test: "\n\n") + end + it "is usable for simple single message mode" do xml = (<<~XML).strip hello @@ -149,8 +169,7 @@ result << (processor.finish) # Should just do its best to parse the XML - tool_call = - DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" }) + tool_call = DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: {}) expect(result).to eq([["text"], [tool_call]]) end From 5c8d62407b9ca0083785f16c9db057ec5cd68080 Mon Sep 17 00:00:00 2001 From: Sam Saffron Date: Thu, 14 Nov 2024 17:52:48 +1100 Subject: [PATCH 4/4] fix ollama tool support --- lib/completions/dialects/ollama.rb | 14 ++++++++++++-- 1 file changed, 12 insertions(+), 2 deletions(-) diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb index 3a32e5927..601b937cf 100644 --- a/lib/completions/dialects/ollama.rb +++ b/lib/completions/dialects/ollama.rb @@ -37,11 +37,21 @@ def model_msg(msg) end def tool_call_msg(msg) - tools_dialect.from_raw_tool_call(msg) + if enable_native_tool? + tools_dialect.from_raw_tool_call(msg) + else + translated = tools_dialect.from_raw_tool_call(msg) + { role: "assistant", content: translated } + end end def tool_msg(msg) - tools_dialect.from_raw_tool(msg) + if enable_native_tool? + tools_dialect.from_raw_tool(msg) + else + translated = tools_dialect.from_raw_tool(msg) + { role: "user", content: translated } + end end def system_msg(msg)