-
Notifications
You must be signed in to change notification settings - Fork 420
add LLamaReranker and tests #1150
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
6f4c53c
15c5247
c604359
d99670c
05677fe
4258cc1
8d61a92
49ae0a8
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,74 @@ | ||
using LLama.Common; | ||
using LLama.Extensions; | ||
using LLama.Native; | ||
using Microsoft.Extensions.AI; | ||
using System.Runtime.InteropServices; | ||
using Xunit.Abstractions; | ||
|
||
namespace LLama.Unittest; | ||
|
||
public sealed class LLamaRerankerTests
    : IDisposable
{
    private readonly ITestOutputHelper _testOutputHelper;
    private readonly LLamaReranker _reranker;

    // Shared fixture: one query and three candidate documents, where the
    // last document is the only genuinely relevant answer.
    private const string Query = "what is panda?";
    private static readonly string[] Documents =
    [
        "hi",
        "it's a bear",
        string.Join(", ", "The giant panda (Ailuropoda melanoleuca)",
            "sometimes called a panda bear or simply panda",
            "is a bear species endemic to China.")
    ];

    public LLamaRerankerTests(ITestOutputHelper testOutputHelper)
    {
        _testOutputHelper = testOutputHelper;

        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 means "use the model's trained context length"
            ContextSize = 0,
            // Rank pooling is required by LLamaReranker's constructor
            PoolingType = LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount,
        };

        // The weights can be released once the reranker has created its
        // own context from them.
        using var weights = LLamaWeights.LoadFromFile(@params);
        _reranker = new LLamaReranker(weights, @params);
    }

    // Release the native context owned by the reranker after each test
    // (xUnit creates a fresh instance per test and calls Dispose).
    public void Dispose()
    {
        _reranker.Dispose();
    }

    [Fact]
    public async Task CompareRerankingScore()
    {
        // Raw (un-normalized) scores: one per document, in document order.
        var scores = await _reranker.GetRelevanceScores(Query, Documents, normalize: false);

        Assert.Equal(Documents.Length, scores.Count);

        _testOutputHelper.WriteLine($"Rerank score 0: {scores[0]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 1: {scores[1]:F4}");
        _testOutputHelper.WriteLine($"Rerank score 2: {scores[2]:F4}");
    }

    [Fact]
    public async Task MostRelevantDocument()
    {
        // Normalized scores lie in (0, 1); ordering is what matters here.
        var scores = await _reranker.GetRelevanceScores(Query, Documents, normalize: true);

        Assert.NotNull(scores);
        Assert.Equal(Documents.Length, scores.Count);

        var maxIndex = scores.Select((score, index) => (score, index))
                             .MaxBy(x => x.score)
                             .index;

        // The detailed panda description must outrank the two distractors.
        Assert.Equal(Documents[2], Documents[maxIndex]);
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
using System.Text; | ||
using System.Xml.Linq; | ||
using LLama.Common; | ||
using LLama.Extensions; | ||
using Microsoft.Extensions.Logging; | ||
|
||
|
||
namespace LLama.Unittest.Native; | ||
|
||
public sealed class SafeLlamaModelHandleVocabularyTests
    : IDisposable
{
    private readonly LLamaWeights _model;

    public SafeLlamaModelHandleVocabularyTests()
    {
        var @params = new ModelParams(Constants.RerankingModelPath)
        {
            // ContextSize = 0 means "use the model's trained context length"
            ContextSize = 0,
            PoolingType = LLama.Native.LLamaPoolingType.Rank,
            GpuLayerCount = Constants.CIGpuLayerCount
        };
        _model = LLamaWeights.LoadFromFile(@params);
    }

    // Release the native model handle after each test (xUnit creates a
    // fresh instance per test and calls Dispose).
    public void Dispose()
    {
        _model.Dispose();
    }

    [Fact]
    public void GetLLamaTokenString()
    {
        var bos = _model.Vocab.BOS;
        var eos = _model.Vocab.EOS;

        var bosStr = _model.Vocab.LLamaTokenToString(bos, true);
        var eosStr = _model.Vocab.LLamaTokenToString(eos, true);

        // The reranking test model uses the classic SentencePiece
        // begin/end-of-sequence markers.
        Assert.Equal("<s>", bosStr);
        Assert.Equal("</s>", eosStr);
    }
}
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,200 @@ | ||
using System; | ||
using System.Collections.Generic; | ||
using System.IO; | ||
using System.Linq; | ||
using System.Text; | ||
using System.Threading; | ||
using System.Threading.Tasks; | ||
using System.Xml.Linq; | ||
using LLama.Abstractions; | ||
using LLama.Exceptions; | ||
using LLama.Native; | ||
using Microsoft.Extensions.Logging; | ||
|
||
namespace LLama; | ||
|
||
/// <summary> | ||
/// Get rank scores between prompt and documents | ||
/// </summary> | ||
/// <summary>
/// Computes relevance (rank) scores between a query prompt and candidate documents
/// using a reranking model (pooling type <see cref="LLamaPoolingType.Rank"/>).
/// </summary>
public sealed partial class LLamaReranker
    : IDisposable
{
    /// <summary>
    /// Dimension of embedding vectors
    /// </summary>
    public int EmbeddingSize => Context.EmbeddingSize;

    /// <summary>
    /// LLama Context
    /// </summary>
    public LLamaContext Context { get; }

    /// <summary>
    /// Create a new reranker, using the given LLamaWeights
    /// </summary>
    /// <param name="weights">Model weights to create the ranking context from</param>
    /// <param name="params">Context parameters; <see cref="IContextParams.PoolingType"/> must be <see cref="LLamaPoolingType.Rank"/></param>
    /// <param name="logger">Optional logger for the created context</param>
    /// <exception cref="ArgumentException">Thrown if batch size and ubatch size differ (required for non-causal models)</exception>
    /// <exception cref="NotSupportedException">Thrown for encoder-decoder models, or if the pooling type is not Rank</exception>
    public LLamaReranker(LLamaWeights weights, IContextParams @params, ILogger? logger = null)
    {
        // Non-causal (embedding/ranking) models process the whole batch at once,
        // so the micro-batch must cover the full batch.
        if (@params.UBatchSize != @params.BatchSize)
            throw new ArgumentException("For non-causal models, batch size must be equal to ubatch size", nameof(@params));
        if (weights.NativeHandle is { HasEncoder: true, HasDecoder: true })
            throw new NotSupportedException("Computing rank in encoder-decoder models is not supported");
        if (@params.PoolingType != LLamaPoolingType.Rank)
            throw new NotSupportedException("Computing rank score, PoolingType must be equal to LLamaPoolingType.Rank");

        Context = weights.CreateContext(@params, logger);
        // Rank scores are read back through the embeddings API.
        NativeApi.llama_set_embeddings(Context.NativeHandle, true);
    }

    /// <inheritdoc />
    public void Dispose()
    {
        Context.Dispose();
    }

    /// <summary>
    /// Retrieve relevance scores for input and documents by reranking, execute once.
    /// </summary>
    /// <param name="input">The query prompt</param>
    /// <param name="documents">Candidate documents to score against the query</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken">Token to cancel the encode/decode work</param>
    /// <returns>One score per document, in the same order as <paramref name="documents"/></returns>
    /// <exception cref="RuntimeError">Thrown if the native encode/decode call fails</exception>
    /// <exception cref="NotSupportedException">Thrown for unsupported model types</exception>
    public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var scores = new List<float>(documents.Count);
        var inputTokens = Context.Tokenize(input);
        var batch = new LLamaBatch();

        // Index of the first document in the current (possibly flushed) batch;
        // used to map document index -> sequence id within the batch.
        var batchStartIndex = 0;

        for (var idx = 0; idx < documents.Count; idx++)
        {
            var docTokens = Context.Tokenize(documents[idx] ?? "");
            LLamaToken[] tokens = [.. inputTokens, .. docTokens];

            // Flush the batch when the next pair would overflow the context.
            // The TokenCount > 0 guard avoids flushing an EMPTY batch when a
            // single query+document pair alone exceeds the context size
            // (CalcRelevanceScores would index position -1 on an empty batch).
            if (batch.TokenCount > 0 && batch.TokenCount + tokens.Length > Context.ContextSize)
            {
                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
                batch.Clear();
                batchStartIndex = idx;
            }

            for (var i = 0; i < tokens.Length; i++)
                batch.Add(tokens[i], i, (LLamaSeqId)(idx - batchStartIndex), true);
        }

        // Score whatever remains in the final partial batch.
        if (batch.LogitPositionCount > 0)
        {
            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
            batch.Clear();
        }

        return scores;
    }

    /// <summary>
    /// Retrieve relevance score for input and document by reranking
    /// </summary>
    /// <param name="input">The query prompt</param>
    /// <param name="document">The candidate document to score</param>
    /// <param name="normalize">Whether to normalize the score to the range (0, 1)</param>
    /// <param name="cancellationToken">Token to cancel the encode/decode work</param>
    /// <returns>The relevance score and the total number of tokens in the query+document pair</returns>
    /// <exception cref="RuntimeError">Thrown if the native encode/decode call fails</exception>
    /// <exception cref="NotSupportedException">Thrown for unsupported model types</exception>
    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
    {
        var inputTokens = Context.Tokenize(input);
        var docTokens = Context.Tokenize(document);
        LLamaToken[] tokens = [.. inputTokens, .. docTokens];

        var batch = new LLamaBatch();
        for (var i = 0; i < tokens.Length; i++)
            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);

        await RunModelAsync(batch, cancellationToken);

        // With Rank pooling the per-sequence "embedding" holds the score in slot 0.
        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

        Context.NativeHandle.KvCacheClear();

        return (normalize ? Sigmoid(score) : score, tokens.Length);
    }

    /// <summary>
    /// Score every sequence currently in <paramref name="batch"/> and return
    /// the scores in sequence-id order.
    /// </summary>
    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
    {
        // The highest sequence id in the batch tells us how many sequences it holds.
        var (lastSeqId, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
        var seqCount = lastSeqId.Value + 1;
        var scores = new List<float>(seqCount);

        await RunModelAsync(batch, cancellationToken);

        for (var seq = 0; seq < seqCount; seq++)
        {
            // With Rank pooling the per-sequence "embedding" holds the score in slot 0.
            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
            scores.Add(normalize ? Sigmoid(score) : score);
        }

        Context.NativeHandle.KvCacheClear();

        return scores;
    }

    /// <summary>
    /// Clear the KV cache and run the batch through the model, using the
    /// encoder or decoder depending on the model architecture.
    /// </summary>
    /// <exception cref="RuntimeError">Thrown if the native encode/decode call fails</exception>
    /// <exception cref="NotSupportedException">Thrown for unsupported model types</exception>
    private async Task RunModelAsync(LLamaBatch batch, CancellationToken cancellationToken)
    {
        // clear previous kv_cache values
        Context.NativeHandle.KvCacheClear();

        // Check if we should cancel the work, just before doing anything expensive (encode/decode)
        cancellationToken.ThrowIfCancellationRequested();

        switch (Context.NativeHandle.ModelHandle.HasEncoder, Context.NativeHandle.ModelHandle.HasDecoder)
        {
            case (true, false):
            {
                var result = await Context.EncodeAsync(batch, cancellationToken);
                if (result != EncodeResult.Ok)
                    throw new RuntimeError($"Failed to encode: {result}");
                break;
            }

            case (false, true):
            {
                var result = await Context.DecodeAsync(batch, cancellationToken);
                if (result != DecodeResult.Ok)
                    throw new RuntimeError($"Failed to decode: {result}");
                break;
            }

            default:
                throw new NotSupportedException("Unsupported model type");
        }
    }

    // Logistic function: maps a raw rank score into (0, 1).
    private static float Sigmoid(float x)
    {
        return (float)(1 / (1 + Math.Exp(-x)));
    }
}
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this still needed after the latest changes? It looks like it's not used any more
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
No longer needed in llamareranker, but I suggest that this can be opened as a public function