diff --git a/LLama.Unittest/Constants.cs b/LLama.Unittest/Constants.cs
index 3d81f23bf..d501b189b 100644
--- a/LLama.Unittest/Constants.cs
+++ b/LLama.Unittest/Constants.cs
@@ -7,6 +7,7 @@ internal static class Constants
public static readonly string GenerativeModelPath = "Models/Llama-3.2-1B-Instruct-Q4_0.gguf";
public static readonly string GenerativeModelPath2 = "Models/smollm-360m-instruct-add-basics-q8_0.gguf";
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
+ public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf";
public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";
diff --git a/LLama.Unittest/LLama.Unittest.csproj b/LLama.Unittest/LLama.Unittest.csproj
index 2dd85e88f..6b0e0b8f4 100644
--- a/LLama.Unittest/LLama.Unittest.csproj
+++ b/LLama.Unittest/LLama.Unittest.csproj
@@ -46,6 +46,12 @@
smollm-360m-instruct-add-basics-q8_0.gguf
+
+ https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-FP16.gguf
+ Models
+ jina-reranker-v1-tiny-en-FP16.gguf
+
+
https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf
Models
@@ -130,6 +136,9 @@
PreserveNewest
+
+ PreserveNewest
+
PreserveNewest
diff --git a/LLama.Unittest/LLamaRerankerTests.cs b/LLama.Unittest/LLamaRerankerTests.cs
new file mode 100644
index 000000000..b8dfcfa8d
--- /dev/null
+++ b/LLama.Unittest/LLamaRerankerTests.cs
@@ -0,0 +1,79 @@
+using LLama.Common;
+using LLama.Extensions;
+using LLama.Native;
+using Microsoft.Extensions.AI;
+using System.Runtime.InteropServices;
+using Xunit.Abstractions;
+
+namespace LLama.Unittest;
+
+public sealed class LLamaRerankerTests : IDisposable
+{
+    // Query and candidate documents shared by both tests; only the third
+    // document is actually about pandas, so it should receive the top score.
+    private const string Query = "what is panda?";
+
+    private static readonly string[] Documents =
+    {
+        "hi",
+        "it's a bear",
+        string.Join(", ",
+            "The giant panda (Ailuropoda melanoleuca)",
+            "sometimes called a panda bear or simply panda",
+            "is a bear species endemic to China."),
+    };
+
+    private readonly ITestOutputHelper _testOutputHelper;
+    private readonly LLamaReranker _reranker;
+
+    /// <summary>
+    /// Loads the reranking model (xUnit constructs a fresh instance per test).
+    /// Reranking requires <c>PoolingType == LLamaPoolingType.Rank</c>.
+    /// </summary>
+    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
+    {
+        _testOutputHelper = testOutputHelper;
+
+        var @params = new ModelParams(Constants.RerankingModelPath)
+        {
+            ContextSize = 0,
+            PoolingType = LLamaPoolingType.Rank,
+            GpuLayerCount = Constants.CIGpuLayerCount,
+        };
+        using var weights = LLamaWeights.LoadFromFile(@params);
+        _reranker = new LLamaReranker(weights, @params);
+    }
+
+    public void Dispose()
+    {
+        _reranker.Dispose();
+    }
+
+    [Fact]
+    public async Task CompareRerankingScore()
+    {
+        // Raw (un-normalized) scores: one score per document, in input order.
+        var scores = await _reranker.GetRelevanceScores(Query, Documents, normalize: false);
+
+        Assert.Equal(Documents.Length, scores.Count);
+
+        for (var i = 0; i < scores.Count; i++)
+            _testOutputHelper.WriteLine($"Rerank score {i}: {scores[i]:F4}");
+    }
+
+    [Fact]
+    public async Task MostRelevantDocument()
+    {
+        // Normalized scores are mapped through a sigmoid into (0, 1);
+        // normalization is monotonic, so the ranking is unchanged.
+        var scores = await _reranker.GetRelevanceScores(Query, Documents, normalize: true);
+
+        Assert.NotNull(scores);
+        Assert.Equal(Documents.Length, scores.Count);
+
+        var maxIndex = scores.Select((score, index) => (score, index))
+                             .MaxBy(x => x.score)
+                             .index;
+
+        // The panda description must outrank the two unrelated documents.
+        Assert.Equal(Documents[2], Documents[maxIndex]);
+    }
+}
diff --git a/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs b/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs
new file mode 100644
index 000000000..1ce53f395
--- /dev/null
+++ b/LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs
@@ -0,0 +1,42 @@
+using System.Text;
+using System.Xml.Linq;
+using LLama.Common;
+using LLama.Extensions;
+using Microsoft.Extensions.Logging;
+
+
+namespace LLama.Unittest.Native;
+
+public class SafeLlamaModelHandleVocabularyTests: IDisposable
+{
+ // Weights for the reranking model whose vocabulary is inspected below.
+ private readonly LLamaWeights _model;
+
+ public SafeLlamaModelHandleVocabularyTests()
+ {
+ var @params = new ModelParams(Constants.RerankingModelPath)
+ {
+ ContextSize = 0,
+ PoolingType = LLama.Native.LLamaPoolingType.Rank,
+ GpuLayerCount = Constants.CIGpuLayerCount
+ };
+ _model = LLamaWeights.LoadFromFile(@params);
+ }
+
+ public void Dispose()
+ {
+ _model.Dispose();
+ }
+
+ // Decodes the BOS/EOS special tokens back to their string form via the
+ // newly-public Vocabulary.LLamaTokenToString (isSpecialToken: true renders
+ // special tokens as text instead of suppressing them).
+ [Fact]
+ public void GetLLamaTokenString()
+ {
+ var bos = _model.Vocab.BOS;
+ var eos = _model.Vocab.EOS;
+
+ var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
+ var eosStr = _model.Vocab.LLamaTokenToString(eos, true);
+
+ // NOTE(review): the empty-string expectations look like stripped markup —
+ // the original literals were probably the special-token texts (e.g. "<s>"
+ // and "</s>"). Confirm against the model's vocabulary before trusting
+ // these asserts.
+ Assert.Equal("", bosStr);
+ Assert.Equal("", eosStr);
+ }
+}
diff --git a/LLama/LLamaReranker.cs b/LLama/LLamaReranker.cs
new file mode 100644
index 000000000..fa42d7f35
--- /dev/null
+++ b/LLama/LLamaReranker.cs
@@ -0,0 +1,201 @@
+using System;
+using System.Collections.Generic;
+using System.IO;
+using System.Linq;
+using System.Text;
+using System.Threading;
+using System.Threading.Tasks;
+using System.Xml.Linq;
+using LLama.Abstractions;
+using LLama.Exceptions;
+using LLama.Native;
+using Microsoft.Extensions.Logging;
+
+namespace LLama;
+
+///
+/// Get rank scores between prompt and documents
+///
+public sealed partial class LLamaReranker
+ : IDisposable
+{
+ /// <summary>
+ /// Dimension of embedding vectors
+ /// </summary>
+ public int EmbeddingSize => Context.EmbeddingSize;
+
+ /// <summary>
+ /// LLama Context
+ /// </summary>
+ public LLamaContext Context { get; }
+
+ /// <summary>
+ /// Create a new reranker, using the given LLamaWeights
+ /// </summary>
+ /// <param name="weights">Model weights to create a context from</param>
+ /// <param name="params">Context parameters. PoolingType must be Rank and BatchSize must equal UBatchSize.</param>
+ /// <param name="logger">Optional logger</param>
+ /// <exception cref="ArgumentException">Thrown when batch size differs from ubatch size</exception>
+ /// <exception cref="NotSupportedException">Thrown for encoder-decoder models or non-Rank pooling</exception>
+ public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
+ {
+ if (@params.UBatchSize != @params.BatchSize)
+ throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
+ if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
+ throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
+ if (@params.PoolingType != LLamaPoolingType.Rank)
+ throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");
+ Context = weights.CreateContext(@params, logger);
+ // Rank scores are read back through the embeddings API (GetEmbeddingsSeq),
+ // so embeddings output must be enabled on the context.
+ NativeApi.llama_set_embeddings(Context.NativeHandle, true);
+ }
+
+ /// <inheritdoc />
+ public void Dispose()
+ {
+ Context.Dispose();
+ }
+
+ /// <summary>
+ /// Retrieve relevance scores for input and documents by reranking, execute once.
+ /// </summary>
+ /// <param name="input">The query text</param>
+ /// <param name="documents">Documents to score against the query</param>
+ /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
+ /// <param name="cancellationToken"></param>
+ /// <returns>One score per document, in input order</returns>
+ // NOTE(review): generic type arguments appear stripped in transit — this is
+ // presumably Task<IReadOnlyList<float>> / IReadOnlyList<string>; restore before compiling.
+ public async Task> GetRelevanceScores(string input, IReadOnlyList documents, bool normalize = false, CancellationToken cancellationToken = default)
+ {
+ List scores = new List(documents.Count);
+ var inputTokens = Context.Tokenize(input);
+ var batch = new LLamaBatch();
+ // Index of the first document in the current (possibly flushed) batch;
+ // sequence ids restart from zero after each flush via (idx - clearFlag).
+ var clearFlag = 0;
+
+ // Each (query, document) pair becomes its own sequence in the shared batch,
+ // with per-sequence positions starting at 0. When the next pair would not
+ // fit into the context, the batch is scored and cleared first.
+ for(var idx = 0; idx < documents.Count; idx++)
+ {
+ var docTokens = Context.Tokenize(documents[idx] ?? "");
+ LLamaToken[] tokens = [.. inputTokens, .. docTokens];
+
+ if (batch.TokenCount + tokens.Length > Context.ContextSize)
+ {
+ scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
+ batch.Clear();
+ clearFlag = idx;
+ }
+
+ // NOTE(review): a single pair longer than the context size is still added
+ // after the flush and would overflow the context — confirm upstream guards.
+ for (var i = 0; i < tokens.Length; i++)
+ batch.Add(tokens[i], i, (LLamaSeqId)(idx - clearFlag), true);
+ }
+ // Score whatever remains after the loop.
+ if (batch.LogitPositionCount > 0)
+ {
+ scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
+ batch.Clear();
+ }
+
+ return scores;
+ }
+
+ /// <summary>
+ /// Retrieve relevance score for input and document by reranking
+ /// </summary>
+ /// <param name="input">The query text</param>
+ /// <param name="document">The document to score against the query</param>
+ /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
+ /// <param name="cancellationToken"></param>
+ /// <returns>The relevance score and the total number of tokens submitted</returns>
+ public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
+ {
+ var inputTokens = Context.Tokenize(input);
+ var docTokens = Context.Tokenize(document);
+ LLamaToken[] tokens = [..inputTokens, ..docTokens];
+ var batch = new LLamaBatch();
+ for (var i = 0; i < tokens.Length; i++)
+ batch.Add(tokens[i], i, LLamaSeqId.Zero, true);
+
+ // clear previous kv_cache values
+ Context.NativeHandle.KvCacheClear();
+
+ // Check if we should cancel the work, just before doing anything expensive (encode/decode)
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Run model: encoder-only models are encoded, decoder-only models decoded;
+ // the constructor already rejected encoder-decoder models.
+ switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
+ {
+ case (true, false):
+ {
+ var result = await Context.EncodeAsync(batch, cancellationToken);
+ if (result != EncodeResult.Ok)
+ throw new RuntimeError($"Failed to encode: {result}");
+ break;
+ }
+
+ case (false, true):
+ {
+ var result = await Context.DecodeAsync(batch, cancellationToken);
+ if (result != DecodeResult.Ok)
+ throw new RuntimeError($"Failed to decode: {result}");
+ break;
+ }
+
+ default:
+ throw new NotSupportedException("Unsupported model type");
+ }
+
+ // Rank pooling yields a single value per sequence, stored as the first
+ // component of that sequence's "embedding".
+ var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];
+
+ Context.NativeHandle.KvCacheClear();
+
+ return (normalize ? Sigmoid(score) : score, tokens.Length);
+ }
+
+ /// <summary>
+ /// Run the accumulated batch through the model and read back one rank score
+ /// per sequence contained in the batch.
+ /// </summary>
+ private async Task> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
+ {
+ // The seq id of the last logit position is the highest sequence id in the
+ // batch, so (id + 1) is the number of sequences to read scores for.
+ var (logicCap, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
+ var seqNum = logicCap.Value + 1;
+ List scores = new List(seqNum);
+ // clear previous kv_cache values
+ Context.NativeHandle.KvCacheClear();
+
+ // Check if we should cancel the work, just before doing anything expensive (encode/decode)
+ cancellationToken.ThrowIfCancellationRequested();
+
+ // Run model (same encoder/decoder dispatch as GetRelevanceScoreWithTokenCount).
+ switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
+ {
+ case (true, false):
+ {
+ var result = await Context.EncodeAsync(batch, cancellationToken);
+ if (result != EncodeResult.Ok)
+ throw new RuntimeError($"Failed to encode: {result}");
+ break;
+ }
+
+ case (false, true):
+ {
+ var result = await Context.DecodeAsync(batch, cancellationToken);
+ if (result != DecodeResult.Ok)
+ throw new RuntimeError($"Failed to decode: {result}");
+ break;
+ }
+
+ default:
+ throw new NotSupportedException("Unsupported model type");
+ }
+
+ // One rank score per sequence, in sequence-id (i.e. document) order.
+ for (var seq = 0; seq < seqNum; seq++)
+ {
+ var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
+ scores.Add(normalize ? Sigmoid(score) : score);
+ }
+
+ Context.NativeHandle.KvCacheClear();
+
+ return scores;
+ }
+
+ // Logistic function: maps a raw rank score into (0, 1) when normalize is requested.
+ private float Sigmoid(float x)
+ {
+ return (float)(1 / (1 + Math.Exp(-x)));
+ }
+}
diff --git a/LLama/Native/SafeLlamaModelHandle.cs b/LLama/Native/SafeLlamaModelHandle.cs
index db198ec30..801d25167 100644
--- a/LLama/Native/SafeLlamaModelHandle.cs
+++ b/LLama/Native/SafeLlamaModelHandle.cs
@@ -651,7 +651,18 @@ internal Vocabulary(SafeLlamaModelHandle model)
_model = model;
}
- private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
+ private static LLamaToken? Normalize(LLamaToken token)
+ {
+ // -1 is the "no such token" sentinel in the vocab metadata; surface it as null.
+ return token == -1 ? null : token;
+ }
+
+ /// <summary>
+ /// Translate LLamaToken to String
+ /// </summary>
+ /// <param name="token">The token to decode; null yields null</param>
+ /// <param name="isSpecialToken">Whether to render special tokens as text</param>
+ /// <returns>The decoded string, or null</returns>
+ public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
{
if (!token.HasValue)
return null;
@@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model)
return Encoding.UTF8.GetStringFromSpan(slice);
}
- private static LLamaToken? Normalize(LLamaToken token)
- {
- return token == -1 ? null : token;
- }
-
///
/// Total number of tokens in this vocabulary
///