diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs index 3d81f23bf..d501b189b 100644 --- a/LLama.Unittest/Constants.cs +++ b/LLama.Unittest/Constants.cs @@ -7,6 +7,7 @@ internal static class Constants public static readonly string GenerativeModelPath = "Models/Llama-3.2-1B-Instruct-Q4_0.gguf"; public static readonly string GenerativeModelPath2 = "Models/smollm-360m-instruct-add-basics-q8_0.gguf"; public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf"; + public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf"; public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf"; public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf"; diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj index 2dd85e88f..6b0e0b8f4 100644 --- a/LLama.Unittest/LLama.Unittest.csproj +++ b/LLama.Unittest/LLama.Unittest.csproj @@ -46,6 +46,12 @@ smollm-360m-instruct-add-basics-q8_0.gguf + + https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-FP16.gguf + Models + jina-reranker-v1-tiny-en-FP16.gguf + + https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf Models @@ -130,6 +136,9 @@ PreserveNewest + + PreserveNewest + PreserveNewest diff --git a/LLama.Unittest/LLamaRerankerTests.cs b/LLama.Unittest/LLamaRerankerTests.cs new file mode 100644 index 000000000..b8dfcfa8d --- /dev/null +++ b/LLama.Unittest/LLamaRerankerTests.cs @@ -0,0 +1,79 @@ +using LLama.Common; +using LLama.Extensions; +using LLama.Native; +using Microsoft.Extensions.AI; +using System.Runtime.InteropServices; +using Xunit.Abstractions; + +namespace LLama.Unittest; + +public sealed class LLamaRerankerTests: IDisposable +{ + private readonly ITestOutputHelper _testOutputHelper; + private readonly LLamaReranker _reranker; + public 
LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 means "use the model's native context length".
            ContextSize = 0,
            // Rank pooling is required by LLamaReranker's constructor.
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        // The weights only need to live long enough for the reranker to create its context.
        using var weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(weights, @params);
    }

    public void Dispose()
    {
        _reranker.Dispose();
    }

    // Shared fixture: one query and three candidate documents, of which only the
    // last actually answers the query.
    private const string Input = "what is panda?";
    private static readonly string[] Documents = {
        "hi",
        "it's a bear",
        string.Join(", ",
            "The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
    };

    [Fact]
    public async Task CompareRerankingScore()
    {
        var scores = await _reranker.GetRelevanceScores(Input, Documents, normalize: false);

        // One raw score per document, in input order.
        Assert.Equal(Documents.Length, scores.Count);

        for (var i = 0; i < scores.Count; i++)
            _testOutputHelper.WriteLine($"Rerank score {i}: {scores[i]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        var scores = await _reranker.GetRelevanceScores(Input, Documents, normalize: true);

        Assert.NotNull(scores);
        Assert.Equal(Documents.Length, scores.Count);

        // The document that actually describes pandas must receive the highest score.
        var maxIndex = scores.Select((score, index) => (score, index))
                             .MaxBy(x => x.score)
                             .index;

        Assert.Equal(Documents[2], Documents[maxIndex]);
    }
}
using System.Text;
using System.Xml.Linq;
using LLama.Common;
using LLama.Extensions;
using Microsoft.Extensions.Logging;


namespace LLama.Unittest.Native;

/// <summary>
/// Tests for the public token-to-string helpers on the model vocabulary,
/// using the small reranking model as a fixture.
/// </summary>
public class SafeLlamaModelHandleVocabularyTests : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // 0 = use the model's native context length.
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void GetLLamaTokenString()
    {
        var bos = _model.Vocab.BOS;
        var eos = _model.Vocab.EOS;

        var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
        var eosStr = _model.Vocab.LLamaTokenToString(eos, true);

        // The jina reranker vocabulary uses RoBERTa-style special tokens.
        Assert.Equal("<s>", bosStr);
        Assert.Equal("</s>", eosStr);
    }
}
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama;

/// <summary>
/// Get rank scores between prompt and documents, using a model with
/// <see cref="LLamaPoolingType.Rank"/> pooling.
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights to create a ranking context from</param>
    /// <param name="params">Context parameters; PoolingType must be <see cref="LLamaPoolingType.Rank"/> and BatchSize must equal UBatchSize</param>
    /// <param name="logger">Optional logger</param>
    /// <exception cref="ArgumentException">Batch size differs from ubatch size</exception>
    /// <exception cref="NotSupportedException">Encoder-decoder model, or pooling type is not Rank</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        // Non-causal evaluation processes a whole logical batch in one physical
        // micro-batch, so the two sizes must agree.
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");

        Context = weights.CreateContext(@params, logger);
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and documents by reranking, execute once.
    /// </summary>
    /// <param name="input">Query text</param>
    /// <param name="documents">Candidate documents; a null entry is treated as an empty document</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken">Token used to cancel the work before encode/decode starts</param>
    /// <returns>One score per document, in the same order as <paramref name="documents"/></returns>
    /// <exception cref="RuntimeError">The native encode/decode call failed</exception>
    /// <exception cref="NotSupportedException">Unsupported model type</exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var scores = new List<float>(documents.Count);
        var inputTokens = Context.Tokenize(input);
        var batch = new LLamaBatch();

        // Index of the first document whose tokens are currently in `batch`; maps a
        // document index onto a zero-based sequence id within the current batch.
        var firstDocInBatch = 0;

        for (var idx = 0; idx < documents.Count; idx++)
        {
            // Each document is scored as the concatenation <query tokens><doc tokens>.
            var docTokens = Context.Tokenize(documents[idx] ?? "");
            LLamaToken[] tokens = [.. inputTokens, .. docTokens];

            // Flush the pending batch when adding this document would overflow the context.
            // NOTE(review): a single query+document pair longer than the context size will
            // still be submitted and fail at decode time — confirm this is the intended behavior.
            if (batch.TokenCount + tokens.Length > Context.ContextSize)
            {
                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
                batch.Clear();
                firstDocInBatch = idx;
            }

            for (var i = 0; i < tokens.Length; i++)
                batch.Add(tokens[i], i, (LLamaSeqId)(idx - firstDocInBatch), true);
        }

        // Score whatever remains in the final partial batch.
        if (batch.LogitPositionCount > 0)
        {
            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
            batch.Clear();
        }

        return scores;
    }

    /// <summary>
    /// Retrieve relevance score for input and document by reranking
    /// </summary>
    /// <param name="input">Query text</param>
    /// <param name="document">Candidate document</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken">Token used to cancel the work before encode/decode starts</param>
    /// <returns>The rank score and the number of tokens evaluated</returns>
    /// <exception cref="RuntimeError">The native encode/decode call failed</exception>
    /// <exception cref="NotSupportedException">Unsupported model type</exception>
    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var inputTokens = Context.Tokenize(input);
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [.. inputTokens, .. docTokens];

        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        // With Rank pooling, the first element of the sequence embedding is the score.
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (normalize ? Sigmoid(score) : score, tokens.Length);
    }

    /// <summary>
    /// Run the model over a prepared multi-sequence batch and read one rank score per sequence.
    /// </summary>
    /// <param name="batch">Batch containing one sequence per document; must have at least one logit position</param>
    /// <param name="normalize">Whether to normalize the scores to the range (0, 1)</param>
    /// <param name="cancellationToken">Token used to cancel the work before encode/decode starts</param>
    /// <returns>One score per sequence in the batch</returns>
    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
    {
        // The highest sequence id in the batch determines how many documents it holds.
        var (lastSeqId, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
        var seqNum = lastSeqId.Value + 1;
        var scores = new List<float>(seqNum);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        // With Rank pooling, the first element of each sequence embedding is the score.
        for (var seq = 0; seq < seqNum; seq++)
        {
            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
            scores.Add(normalize ? Sigmoid(score) : score);
        }

        Context.NativeHandle.KvCacheClear();

        return scores;
    }

    /// <summary>
    /// Logistic function mapping a raw score into (0, 1). Static: touches no instance state.
    /// </summary>
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}
Normalize(LLamaToken token)
+        {
+            // A token id of -1 is the native sentinel for "not present in this vocabulary".
+            return token == -1 ? null : token;
+        }
+
+        /// <summary>
+        /// Translate LLamaToken to String
+        /// </summary>
+        /// <param name="token">Token to decode; null returns null</param>
+        /// <param name="isSpecialToken">Whether to render special/control tokens as text</param>
+        /// <returns>The decoded text, or null</returns>
+        public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
         {
             if (!token.HasValue)
                 return null;
@@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model)
             return Encoding.UTF8.GetStringFromSpan(slice);
         }
 
-        private static LLamaToken? Normalize(LLamaToken token)
-        {
-            return token == -1 ? null : token;
-        }
-
         /// <summary>
         /// Total number of tokens in this vocabulary
         /// </summary>