From 4c0c82f32be60565ebcebf5f1acff31c505625e1 Mon Sep 17 00:00:00 2001
From: Carl Lange <carl@flax.ie>
Date: Thu, 16 Jan 2025 22:46:17 +0000
Subject: [PATCH] add tenuous sqlite-vec support

---
 langchainrb_rails.gemspec                     |   5 +-
 .../vectorsearch/sqlite_vec.rb                | 129 ++++++++++++++++++
 lib/langchainrb_rails.rb                      |   1 +
 lib/langchainrb_rails/active_record/hooks.rb  |  13 +-
 .../langchainrb_rails/sqlite_vec_generator.rb |  76 +++++++++++
 .../add_sqlite_vector_column_template.rb.tt   |  10 ++
 ...ble_sqlite_vector_extension_template.rb.tt |   5 +
 .../templates/sqlite_vec_initializer.rb.tt    |  20 +++
 lib/langchainrb_rails/railtie.rb              |   4 +-
 .../sqlite_vec_generator_spec.rb              |  39 ++++++
 10 files changed, 298 insertions(+), 4 deletions(-)
 create mode 100644 lib/langchainrb_overrides/vectorsearch/sqlite_vec.rb
 create mode 100644 lib/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator.rb
 create mode 100644 lib/langchainrb_rails/generators/langchainrb_rails/templates/add_sqlite_vector_column_template.rb.tt
 create mode 100644 lib/langchainrb_rails/generators/langchainrb_rails/templates/enable_sqlite_vector_extension_template.rb.tt
 create mode 100644 lib/langchainrb_rails/generators/langchainrb_rails/templates/sqlite_vec_initializer.rb.tt
 create mode 100644 spec/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator_spec.rb

diff --git a/langchainrb_rails.gemspec b/langchainrb_rails.gemspec
index 3c20f8d..a3d8fed 100644
--- a/langchainrb_rails.gemspec
+++ b/langchainrb_rails.gemspec
@@ -32,8 +32,9 @@ Gem::Specification.new do |spec|
 
   spec.add_dependency "langchainrb", ">= 0.19"
 
+  spec.add_development_dependency "generator_spec"
   spec.add_development_dependency "pry-byebug", "~> 3.10.0"
-  spec.add_development_dependency "yard", "~> 0.9.34"
   spec.add_development_dependency "rails", "> 6.0.0"
-  spec.add_development_dependency "generator_spec"
+  spec.add_development_dependency "yard", "~> 0.9.34"
+  spec.metadata['rubygems_mfa_required'] = 'true'
 end
diff --git a/lib/langchainrb_overrides/vectorsearch/sqlite_vec.rb b/lib/langchainrb_overrides/vectorsearch/sqlite_vec.rb
new file mode 100644
index 0000000..337f5d2
--- /dev/null
+++ b/lib/langchainrb_overrides/vectorsearch/sqlite_vec.rb
@@ -0,0 +1,129 @@
+# frozen_string_literal: true
+
+# Overriding Langchain.rb's SqliteVec implementation to use ActiveRecord.
+# Original implementation: https://github.com/andreibondarev/langchainrb/blob/main/lib/langchain/vectorsearch/sqlite_vec.rb
+
+require "sqlite_vec"
+
+module Langchain::Vectorsearch
+  class SqliteVec < Base
+    #
+    # The SQLite vector search adapter
+    #
+    # Gem requirements:
+    #     gem "sqlite_vec", "~> 0.1"
+    #
+    # Usage:
+    #     sqlite_vec = Langchain::Vectorsearch::SqliteVec.new(
+    #       llm: llm
+    #     )
+    #
+    # No url/index_name/namespace is accepted: the adapter reuses the
+    # existing ActiveRecord connection instead of opening its own database.
+    #
+
+    attr_reader :llm
+    attr_accessor :model
+
+    # @param llm [Object] The LLM client used to generate embeddings and chat completions
+    def initialize(llm:)
+      depends_on "sqlite3"
+      depends_on "sqlite_vec"
+
+      # Load the sqlite-vec extension into the raw connection behind ActiveRecord.
+      # TODO: this doesn't appear to take effect, hence the configure_connection override in the generated initializer
+      @db = ActiveRecord::Base.connection.raw_connection
+      @db.enable_load_extension(true)
+      ::SqliteVec.load(@db)
+      @db.enable_load_extension(false)
+
+      @llm = llm
+
+      super
+    end
+
+    # Embed each text and store the result on the record with the matching id.
+    # Embeddings are packed as native single-precision floats before being written.
+    #
+    # @param texts [Array<String>] The texts to embed
+    # @param ids [Array<String>] The ids of the records to update, in the same order as the texts
+    # @return [Object] NOTE(review): the result of `each.with_index`, presumably the found records rather than ids — confirm
+    def add_texts(texts:, ids:)
+      embeddings = texts.map do |text|
+        llm.embed(text: text).embedding
+      end
+
+      model.find(ids).each.with_index do |record, i|
+        record.update_column(:embedding, embeddings[i].pack("f*"))
+      end
+    end
+
+    def update_texts(texts:, ids:)
+      add_texts(texts: texts, ids: ids)
+    end
+
+    # Remove vectors from the index
+    #
+    # @param ids [Array<String>] The ids of the vectors to remove
+    # @return [Boolean] true
+    def remove_texts(ids:)
+      # Since the record is being destroyed and the `embedding` is a column on the record,
+      # we don't need to do anything here.
+      true
+    end
+
+    # Create default schema
+    def create_default_schema
+      Rake::Task["sqlite_vec"].invoke
+    end
+
+    # Destroy default schema
+    def destroy_default_schema
+      # Tell the user to rollback the migration
+    end
+
+    # Embed the query with the LLM and search for the records closest to it
+    # @param query [String] The text to search for
+    # @param k [Integer] The number of top results to return
+    # @return [Object] Whatever #similarity_search_by_vector returns (a `model` query limited to k rows)
+    def similarity_search(query:, k: 4)
+      embedding = llm.embed(text: query).embedding
+
+      similarity_search_by_vector(
+        embedding: embedding,
+        k: k
+      )
+    end
+
+    # Search for similar records in the index by the passed in vector.
+    # You must generate your own vector using the same LLM that generated the embeddings stored in the Vectorsearch DB.
+    # @param embedding [Array<Float>] The vector to search for
+    # @param k [Integer] The number of top results to return
+    # @return [ActiveRecord::Relation] presumably a relation: the model's `nearest_neighbors` scope limited to k rows
+    def similarity_search_by_vector(embedding:, k: 4)
+      model
+        .nearest_neighbors(:embedding, embedding)
+        .limit(k)
+    end
+
+    # Ask a question, using the k most similar records as context for the LLM
+    # @param question [String] The question to ask
+    # @param k [Integer] The number of results to have in context
+    # @yield [String] Stream responses back one String at a time
+    # @return [Object] presumably the result of `llm.chat` (NOTE(review): may be a response object, not a bare String — confirm)
+    def ask(question:, k: 4, &block)
+      # Silence the ActiveRecord logger: the embedding column makes the query logs very noisy
+      ActiveRecord::Base.logger.silence do
+        search_results = similarity_search(query: question, k: k)
+
+        context = search_results.map do |result|
+          result.as_vector
+        end
+        context = context.join("\n---\n")
+        prompt = generate_rag_prompt(question: question, context: context)
+        messages = [{ role: "user", content: prompt }]
+        llm.chat(messages: messages, &block)
+      end
+    end
+  end
+end
diff --git a/lib/langchainrb_rails.rb b/lib/langchainrb_rails.rb
index e109881..c1161ad 100644
--- a/lib/langchainrb_rails.rb
+++ b/lib/langchainrb_rails.rb
@@ -10,6 +10,7 @@
 require "langchainrb_rails/version"
 
 require_relative "langchainrb_overrides/vectorsearch/pgvector"
+require_relative "langchainrb_overrides/vectorsearch/sqlite_vec"
 require_relative "langchainrb_overrides/assistant"
 require_relative "langchainrb_overrides/message"
 
diff --git a/lib/langchainrb_rails/active_record/hooks.rb b/lib/langchainrb_rails/active_record/hooks.rb
index fca1e39..d4d0569 100644
--- a/lib/langchainrb_rails/active_record/hooks.rb
+++ b/lib/langchainrb_rails/active_record/hooks.rb
@@ -91,6 +91,16 @@ def vectorsearch
             has_neighbors(:embedding)
             class_variable_get(:@@provider).model = self
           end
+
+          # SQLite-Vec-specific configuration
+          return unless LangchainrbRails.config.vectorsearch.is_a?(Langchain::Vectorsearch::SqliteVec)
+
+          # Define nearest_neighbors scope for SQLite-Vec
+          scope :nearest_neighbors, lambda { |column, vector|
+            unscoped.select("#{table_name}.*, vec_distance_cosine(#{column}, '#{vector.to_json}') as distance")
+                    .order("distance")
+          }
+          class_variable_get(:@@provider).model = self
         end
 
         # Iterates over records and generate embeddings.
@@ -112,7 +122,8 @@ def similarity_search(query, k: 1)
             k: k
           )
 
-          return records if LangchainrbRails.config.vectorsearch.is_a?(Langchain::Vectorsearch::Pgvector)
+          return records if LangchainrbRails.config.vectorsearch.is_a?(Langchain::Vectorsearch::Pgvector) ||
+                            LangchainrbRails.config.vectorsearch.is_a?(Langchain::Vectorsearch::SqliteVec)
 
           # We use "__id" when Weaviate is the provider
           ids = records.map { |record| record.try("id") || record.dig("__id") }
diff --git a/lib/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator.rb b/lib/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator.rb
new file mode 100644
index 0000000..09105b9
--- /dev/null
+++ b/lib/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator.rb
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+module LangchainrbRails
+  module Generators
+    #
+    # Usage:
+    #     rails generate langchainrb_rails:sqlite_vec --model=Product --llm=openai
+    #
+    class SqliteVecGenerator < LangchainrbRails::Generators::BaseGenerator
+      desc "This generator adds sqlite-vec vectorsearch integration to your ActiveRecord model"
+      source_root File.join(__dir__, "templates")
+
+      def copy_migration
+        migration_template "enable_sqlite_vector_extension_template.rb", "db/migrate/enable_sqlite_vector_extension.rb",
+                           migration_version: migration_version
+        migration_template "add_sqlite_vector_column_template.rb", "db/migrate/add_sqlite_vector_column_to_#{table_name}.rb",
+                           migration_version: migration_version
+      end
+
+      def create_initializer_file
+        template "sqlite_vec_initializer.rb", "config/initializers/langchainrb_rails.rb"
+      end
+
+      def migration_version
+        "[#{::ActiveRecord::VERSION::MAJOR}.#{::ActiveRecord::VERSION::MINOR}]"
+      end
+
+      def add_to_model
+        inject_into_class "app/models/#{model_name.underscore}.rb", model_name do
+          "  vectorsearch\n\n  after_save :upsert_to_vectorsearch\n\n  after_destroy :destroy_from_vectorsearch\n\n"
+        end
+      end
+
+      # def add_to_gemfile
+      #   # Dependency for Sqlite-vec
+      #   gem_group :default do
+      #     gem "sqlite3", "> 2.1" unless gem_exists?("sqlite3")
+      #     gem "sqlite-vec", "~> 0.1.7.alpha.2" unless gem_exists?("sqlite-vec")
+
+      #     if options["llm"] == "ollama"
+      #       gem "faraday" unless gem_exists?("faraday")
+      #     elsif options["llm"] == "openai"
+      #       gem "ruby-openai" unless gem_exists?("ruby-openai")
+      #     end
+      #   end
+      # end
+
+      private
+
+      # @return [String] Name of the model
+      def model_name
+        options["model"]
+      end
+
+      # @return [String] Table name of the model
+      def table_name
+        model_name.tableize
+      end
+
+      # @return [String] LLM provider to use
+      def llm
+        options["llm"]
+      end
+
+      # @return [Langchain::LLM::*] LLM class
+      def llm_class
+        Langchain::LLM.const_get(LLMS[llm])
+      end
+
+      # @return [Integer] Dimension of the vector to be used
+      def vector_dimensions
+        llm_class.default_dimensions
+      end
+    end
+  end
+end
diff --git a/lib/langchainrb_rails/generators/langchainrb_rails/templates/add_sqlite_vector_column_template.rb.tt b/lib/langchainrb_rails/generators/langchainrb_rails/templates/add_sqlite_vector_column_template.rb.tt
new file mode 100644
index 0000000..c0faede
--- /dev/null
+++ b/lib/langchainrb_rails/generators/langchainrb_rails/templates/add_sqlite_vector_column_template.rb.tt
@@ -0,0 +1,10 @@
+class <%= migration_class_name %> < ActiveRecord::Migration<%= migration_version %>
+  def change
+    add_column :<%= table_name %>, :embedding, :blob,
+      limit: LangchainrbRails
+        .config
+        .vectorsearch
+        .llm
+        .default_dimensions
+  end
+end
diff --git a/lib/langchainrb_rails/generators/langchainrb_rails/templates/enable_sqlite_vector_extension_template.rb.tt b/lib/langchainrb_rails/generators/langchainrb_rails/templates/enable_sqlite_vector_extension_template.rb.tt
new file mode 100644
index 0000000..b8c6940
--- /dev/null
+++ b/lib/langchainrb_rails/generators/langchainrb_rails/templates/enable_sqlite_vector_extension_template.rb.tt
@@ -0,0 +1,5 @@
+class <%= migration_class_name %> < ActiveRecord::Migration<%= migration_version %>
+  def change
+    SqliteVec.load(ActiveRecord::Base.connection)
+  end
+end
diff --git a/lib/langchainrb_rails/generators/langchainrb_rails/templates/sqlite_vec_initializer.rb.tt b/lib/langchainrb_rails/generators/langchainrb_rails/templates/sqlite_vec_initializer.rb.tt
new file mode 100644
index 0000000..62da4db
--- /dev/null
+++ b/lib/langchainrb_rails/generators/langchainrb_rails/templates/sqlite_vec_initializer.rb.tt
@@ -0,0 +1,20 @@
+# frozen_string_literal: true
+
+LangchainrbRails.configure do |config|
+  config.vectorsearch = Langchain::Vectorsearch::SqliteVec.new(
+    llm: <%= llm_class %>.new(api_key: ENV["<%= llm.upcase %>_API_KEY"])
+  )
+end 
+
+ActiveRecord::ConnectionAdapters::SQLite3Adapter.class_eval do
+  alias_method :before_configure_connection, :configure_connection
+
+  def configure_connection
+    before_configure_connection
+
+    @raw_connection.enable_load_extension(true)
+    SqliteVec.load(@raw_connection)
+    @raw_connection.enable_load_extension(false)
+  end
+end
+
diff --git a/lib/langchainrb_rails/railtie.rb b/lib/langchainrb_rails/railtie.rb
index cb2acf6..4101167 100644
--- a/lib/langchainrb_rails/railtie.rb
+++ b/lib/langchainrb_rails/railtie.rb
@@ -2,8 +2,9 @@
 
 module LangchainrbRails
   class Railtie < Rails::Railtie
-    initializer "langchain" do
+    initializer "langchainrb_rails" do
       ActiveSupport.on_load(:active_record) do
+        require "sqlite_vec" if defined?(SqliteVec)
         ::ActiveRecord::Base.include LangchainrbRails::ActiveRecord::Hooks
       end
     end
@@ -15,6 +16,7 @@ class Railtie < Rails::Railtie
       require_relative "generators/langchainrb_rails/pgvector_generator"
       require_relative "generators/langchainrb_rails/qdrant_generator"
       require_relative "generators/langchainrb_rails/prompt_generator"
+      require_relative "generators/langchainrb_rails/sqlite_vec_generator"
     end
   end
 end
diff --git a/spec/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator_spec.rb b/spec/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator_spec.rb
new file mode 100644
index 0000000..cdfeee2
--- /dev/null
+++ b/spec/langchainrb_rails/generators/langchainrb_rails/sqlite_vec_generator_spec.rb
@@ -0,0 +1,39 @@
+# frozen_string_literal: true
+
+require "spec_helper"
+
+TMP_ROOT = Pathname.new(Dir.mktmpdir("tmp"))
+
+RSpec.describe LangchainrbRails::Generators::SqliteVecGenerator, type: :generator do
+  destination TMP_ROOT
+  arguments %w[--model=SupportArticle --llm=openai]
+
+  before(:all) do
+    prepare_destination
+    model_file = TMP_ROOT.join("app/models/support_article.rb")
+    gemfile = TMP_ROOT.join("Gemfile")
+    write_file(model_file, "class SupportArticle < ApplicationRecord\nend")
+    write_file(gemfile, "")
+    run_generator
+  end
+
+  after(:all) { delete_directory(TMP_ROOT) }
+
+  it "creates the initializer file" do
+    assert_file "config/initializers/langchainrb_rails.rb"
+  end
+
+  it "adds the vectorsearch module to the model" do
+    assert_file "app/models/support_article.rb" do |model|
+      assert_match(/vectorsearch/, model)
+      assert_match(/after_save :upsert_to_vectorsearch/, model)
+    end
+  end
+
+  it "adds the necessary gems to the Gemfile" do
+    assert_file "Gemfile" do |gemfile|
+      assert_match(/gem "sqlite_vec"/, gemfile)
+      assert_match(/gem "ruby-openai"/, gemfile)
+    end
+  end
+end