Skip to content

add LLamaReranker and tests #1150

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions LLama.Unittest/Constants.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@ internal static class Constants
// Primary model used by text-generation tests (Llama 3.2 1B Instruct, Q4_0 quantised)
public static readonly string GenerativeModelPath = "Models/Llama-3.2-1B-Instruct-Q4_0.gguf";
// Smaller secondary generative model (SmolLM 360M instruct, Q8_0 quantised)
public static readonly string GenerativeModelPath2 = "Models/smollm-360m-instruct-add-basics-q8_0.gguf";
// Model used by embedding tests (MiniLM sentence embedder, Q8_0 quantised)
public static readonly string EmbeddingModelPath = "Models/all-MiniLM-L12-v2.Q8_0.gguf";
// Model used by reranking tests (Jina tiny English reranker, FP16)
public static readonly string RerankingModelPath = "Models/jina-reranker-v1-tiny-en-FP16.gguf";

// LLaVA multimodal model and its matching multimodal projector, used by vision tests
public static readonly string LLavaModelPath = "Models/llava-v1.6-mistral-7b.Q3_K_XS.gguf";
public static readonly string LLavaMmpPath = "Models/mmproj-model-f16.gguf";
Expand Down
10 changes: 8 additions & 2 deletions LLama.Unittest/LLama.Unittest.csproj
Original file line number Diff line number Diff line change
Expand Up @@ -34,8 +34,11 @@

<!-- Secondary generative model (SmolLM 360M) used by generation tests -->
<DownloadFile SourceUrl="https://huggingface.co/HuggingFaceTB/smollm-360M-instruct-v0.2-Q8_0-GGUF/resolve/main/smollm-360m-instruct-add-basics-q8_0.gguf" DestinationFolder="Models" DestinationFileName="smollm-360m-instruct-add-basics-q8_0.gguf" SkipUnchangedFiles="true">
</DownloadFile>

<!-- Reranking model used by LLamaRerankerTests and SafeLlamaModelHandleVocabularyTests -->
<DownloadFile SourceUrl="https://huggingface.co/gpustack/jina-reranker-v1-tiny-en-GGUF/resolve/main/jina-reranker-v1-tiny-en-FP16.gguf" DestinationFolder="Models" DestinationFileName="jina-reranker-v1-tiny-en-FP16.gguf" SkipUnchangedFiles="true">
</DownloadFile>

<!-- LLaVA multimodal model used by vision tests -->
<DownloadFile SourceUrl="https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/llava-v1.6-mistral-7b.Q3_K_XS.gguf" DestinationFolder="Models" DestinationFileName="llava-v1.6-mistral-7b.Q3_K_XS.gguf" SkipUnchangedFiles="true">
</DownloadFile>

<DownloadFile SourceUrl="https://huggingface.co/cjpais/llava-1.6-mistral-7b-gguf/resolve/main/mmproj-model-f16.gguf" DestinationFolder="Models" DestinationFileName="mmproj-model-f16.gguf" SkipUnchangedFiles="true">
Expand Down Expand Up @@ -63,6 +66,9 @@
<!-- Copy downloaded model files next to the test binaries so tests can load them by relative path -->
<None Update="Models\Llama-3.2-1B-Instruct-Q4_0.gguf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Models\jina-reranker-v1-tiny-en-FP16.gguf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
<None Update="Models\smollm-360m-instruct-add-basics-q8_0.gguf">
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
</None>
Expand Down
74 changes: 74 additions & 0 deletions LLama.Unittest/LLamaRerankerTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
using LLama.Common;
using LLama.Extensions;
using LLama.Native;
using Microsoft.Extensions.AI;
using System.Runtime.InteropServices;
using Xunit.Abstractions;

namespace LLama.Unittest;

/// <summary>
/// Tests for <c>LLamaReranker</c>: scoring a query against a set of candidate documents.
/// Implements <see cref="IDisposable"/> so xUnit releases the native context after each test.
/// </summary>
public sealed class LLamaRerankerTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly LLamaReranker _reranker;

    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 lets the context size be taken from the model
            ContextSize = 0,
            // Rank pooling is required by LLamaReranker
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        // NOTE(review): weights are disposed at the end of this constructor while the
        // reranker's context stays alive — presumably the native model handle is
        // ref-counted by CreateContext; confirm against LLamaWeights semantics.
        using var weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(weights, @params);
    }

    /// <summary>
    /// Dispose the native context; xUnit creates one instance per test and calls this afterwards.
    /// </summary>
    public void Dispose()
    {
        _reranker.Dispose();
    }

    [Fact]
    public async Task CompareRerankingScore()
    {
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
        };

        // Raw (un-normalized) scores: one score per document, in order
        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: false);

        Assert.True(documents.Length == scores.Count);

        _testOutputHelper.WriteLine($"Rerank score 0: {scores[0]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 1: {scores[1]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 2: {scores[2]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        var input = "what is panda?";
        var documents = new string[] {
            "hi",
            "it's a bear",
            string.Join(", ","The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
        };

        // Normalized scores lie in (0, 1); ranking order is what matters here
        var scores = await _reranker.GetRelevanceScores(input, documents, normalize: true);

        Assert.NotNull(scores);
        Assert.True(documents.Length == scores.Count);

        int maxIndex = scores.Select((score, index) => (score, index))
            .MaxBy(x => x.score)
            .index;

        // The detailed panda description should outrank the two short distractors
        var maxScoreDocument = documents[maxIndex];
        Assert.Equal(documents[2], maxScoreDocument);
    }
}
37 changes: 37 additions & 0 deletions LLama.Unittest/Native/SafeLlamaModelHandleVocabularyTests.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
using System.Text;
using System.Xml.Linq;
using LLama.Common;
using LLama.Extensions;
using Microsoft.Extensions.Logging;


namespace LLama.Unittest.Native;

/// <summary>
/// Tests for <c>SafeLlamaModelHandle.Vocabulary</c> token-to-string conversion.
/// Implements <see cref="IDisposable"/> so xUnit releases the model weights after each test.
/// </summary>
public class SafeLlamaModelHandleVocabularyTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 lets the context size be taken from the model
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    /// <summary>
    /// Dispose the native model weights; xUnit creates one instance per test and calls this afterwards.
    /// </summary>
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void GetLLamaTokenString()
    {
        var bos = _model.Vocab.BOS;
        var eos = _model.Vocab.EOS;

        // Render special tokens with isSpecialToken = true so they appear literally
        var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
        var eosStr = _model.Vocab.LLamaTokenToString(eos, true);

        // Expected values are specific to the jina reranker model's vocabulary
        Assert.Equal("<s>", bosStr);
        Assert.Equal("</s>", eosStr);
    }
}
200 changes: 200 additions & 0 deletions LLama/LLamaReranker.cs
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Text;
using System.Threading;
using System.Threading.Tasks;
using System.Xml.Linq;
using LLama.Abstractions;
using LLama.Exceptions;
using LLama.Native;
using Microsoft.Extensions.Logging;

namespace LLama;

/// <summary>
/// Get rank scores between prompt and documents, using a model whose pooling type
/// is <see cref="LLamaPoolingType.Rank"/>.
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights to create the scoring context from</param>
    /// <param name="params">Context parameters; pooling type must be <see cref="LLamaPoolingType.Rank"/> and batch size must equal ubatch size</param>
    /// <param name="logger">Optional logger for the created context</param>
    /// <exception cref="ArgumentException">Thrown when batch size and ubatch size differ</exception>
    /// <exception cref="NotSupportedException">Thrown for encoder-decoder models or a pooling type other than Rank</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        // Non-causal models evaluate the whole batch at once, so ubatch must match batch
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");

        Context = weights.CreateContext(@params, logger);

        // Rank scores are read back through the embeddings API, so embeddings must be enabled
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and documents by reranking, execute once.
    /// </summary>
    /// <param name="input">The query text</param>
    /// <param name="documents">Documents to score against the query; scores are returned in the same order</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1) with a sigmoid</param>
    /// <param name="cancellationToken">Token used to cancel the work before each encode/decode</param>
    /// <returns>One relevance score per document</returns>
    /// <exception cref="RuntimeError">Thrown when the native encode/decode call fails</exception>
    /// <exception cref="NotSupportedException">Thrown for unsupported model types</exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var scores = new List<float>(documents.Count);
        var inputTokens = Context.Tokenize(input);
        var batch = new LLamaBatch();

        // Index of the first document in the current (not yet evaluated) batch; used to
        // map a document index onto a sequence id within that batch.
        var batchStart = 0;

        for (var idx = 0; idx < documents.Count; idx++)
        {
            var docTokens = Context.Tokenize(documents[idx] ?? "");
            LLamaToken[] tokens = [.. inputTokens, .. docTokens];

            // Flush the current batch when the next pair would overflow the context.
            // NOTE(review): a single (input + document) pair longer than the context size is
            // still added after the flush and will presumably fail in the native call — confirm.
            if (batch.TokenCount + tokens.Length > Context.ContextSize)
            {
                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
                batch.Clear();
                batchStart = idx;
            }

            // Each (input, document) pair occupies its own sequence within the batch
            for (var i = 0; i < tokens.Length; i++)
                batch.Add(tokens[i], i, (LLamaSeqId)(idx - batchStart), true);
        }

        // Evaluate whatever remains in the final partial batch
        if (batch.LogitPositionCount > 0)
        {
            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
            batch.Clear();
        }

        return scores;
    }

    /// <summary>
    /// Retrieve relevance score for input and document by reranking
    /// </summary>
    /// <param name="input">The query text</param>
    /// <param name="document">The document to score against the query</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1) with a sigmoid</param>
    /// <param name="cancellationToken">Token used to cancel the work before encode/decode starts</param>
    /// <returns>The relevance score and the total number of tokens evaluated</returns>
    /// <exception cref="RuntimeError">Thrown when the native encode/decode call fails</exception>
    /// <exception cref="NotSupportedException">Thrown for unsupported model types</exception>
    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var inputTokens = Context.Tokenize(input);
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [..inputTokens, ..docTokens];

        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model: encoder-only models are encoded, decoder-only models are decoded
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        // With Rank pooling the first element of the sequence embedding holds the rank score
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (normalize ? Sigmoid(score) : score, tokens.Length);
    }

    /// <summary>
    /// Evaluate a prepared batch and read back one rank score per sequence in it.
    /// </summary>
    /// <param name="batch">Batch containing one sequence per (input, document) pair</param>
    /// <param name="normalize">Whether to normalize the scores to (0, 1) with a sigmoid</param>
    /// <param name="cancellationToken">Token used to cancel the work before encode/decode starts</param>
    /// <returns>One score per sequence, in sequence-id order</returns>
    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
    {
        // Sequence ids are assigned contiguously from zero, so the id recorded with the
        // last logit position tells us how many sequences the batch contains.
        var (lastSeqId, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
        var seqNum = lastSeqId.Value + 1;
        var scores = new List<float>(seqNum);

        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        // Run model: encoder-only models are encoded, decoder-only models are decoded
        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }

        // With Rank pooling the first element of each sequence embedding holds the rank score
        for (var seq = 0; seq < seqNum; seq++)
        {
            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
            scores.Add(normalize ? Sigmoid(score) : score);
        }

        Context.NativeHandle.KvCacheClear();

        return scores;
    }

    /// <summary>
    /// Logistic sigmoid, used to squash raw rank scores into (0, 1).
    /// Stateless, so declared static.
    /// </summary>
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}
18 changes: 12 additions & 6 deletions LLama/Native/SafeLlamaModelHandle.cs
Original file line number Diff line number Diff line change
Expand Up @@ -651,7 +651,18 @@ internal Vocabulary(SafeLlamaModelHandle model)
_model = model;
}

private string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
/// <summary>
/// Map the sentinel value -1 (token not defined by the model) to null;
/// any other token is passed through unchanged.
/// </summary>
private static LLamaToken? Normalize(LLamaToken token)
{
    if (token == -1)
        return null;

    return token;
}

/// <summary>
/// Translate LLamaToken to String
/// </summary>
/// <param name="token"></param>
/// <param name="isSpecialToken"></param>
/// <returns></returns>
public string? LLamaTokenToString(LLamaToken? token, bool isSpecialToken)
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this still needed after the latest changes? It looks like it's not used any more

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

No longer needed in llamareranker, but I suggest that this can be opened as a public function

{
if (!token.HasValue)
return null;
Expand All @@ -676,11 +687,6 @@ internal Vocabulary(SafeLlamaModelHandle model)
return Encoding.UTF8.GetStringFromSpan(slice);
}

private static LLamaToken? Normalize(LLamaToken token)
{
return token == -1 ? null : token;
}

/// <summary>
/// Total number of tokens in this vocabulary
/// </summary>
Expand Down
Loading