Commit d99670c

fix Reranking if documents is too large

GetRelevanceScores previously queued every document into a single LLamaBatch, which fails once the combined query and document tokens exceed the context size. The batch is now flushed and scored through a new CalcRelevanceScores helper whenever the next document would overflow Context.ContextSize, and sequence ids are rebased after each flush so the returned scores stay aligned with document order.

1 parent c604359 commit d99670c

File tree

1 file changed: LLama/LLamaReranker.cs (+53, -29)
@@ -5,6 +5,7 @@
 using System.Text;
 using System.Threading;
 using System.Threading.Tasks;
+using System.Xml.Linq;
 using LLama.Abstractions;
 using LLama.Exceptions;
 using LLama.Native;
@@ -65,16 +66,52 @@ public void Dispose()
     public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
     {
         List<float> scores = new List<float>(documents.Count);
-        var batch = new LLamaBatch();
         var inputTokens = Context.Tokenize(input);
-        foreach (var (index, document) in documents.Select((item, index) => (index, item)))
+        var batch = new LLamaBatch();
+        var clearFlag = 0;
+
+        for (var idx = 0; idx < documents.Count; idx++)
         {
-            var docTokens = Context.Tokenize(document);
+            var docTokens = Context.Tokenize(documents[idx]);
             LLamaToken[] tokens = [.. inputTokens, .. docTokens];
+
+            if (batch.TokenCount + tokens.Length > Context.ContextSize)
+            {
+                scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
+                batch.Clear();
+                clearFlag = idx;
+            }
+
             for (var i = 0; i < tokens.Length; i++)
-                batch.Add(tokens[i], i, (LLamaSeqId)index, true);
+                batch.Add(tokens[i], i, (LLamaSeqId)(idx - clearFlag), true);
+        }
+        if (batch.LogitPositionCount > 0)
+        {
+            scores.AddRange(await CalcRelevanceScores(batch, normalize, cancellationToken));
+            batch.Clear();
         }

+        return scores;
+    }
+
+    /// <summary>
+    /// Retrieve relevance score for input and document by reranking
+    /// </summary>
+    /// <param name="input"></param>
+    /// <param name="document"></param>
+    /// <param name="cancellationToken"></param>
+    /// <returns></returns>
+    /// <exception cref="RuntimeError"></exception>
+    /// <exception cref="NotSupportedException"></exception>
+    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
+    {
+        var inputTokens = Context.Tokenize(input);
+        var docTokens = Context.Tokenize(document);
+        LLamaToken[] tokens = [..inputTokens, ..docTokens];
+        var batch = new LLamaBatch();
+        for (var i = 0; i < tokens.Length; i++)
+            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);
+
         // clear previous kv_cache values
         Context.NativeHandle.KvCacheClear();
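The subtle part of the hunk above is the sequence-id rebasing: after a flush, idx - clearFlag restarts sequence ids at zero for the new batch, while scores keeps accumulating in document order. A minimal standalone sketch of that bookkeeping with token counts as plain integers; the context size of 8 and the per-document lengths are made up for illustration:

    // Standalone illustration of the flush-and-rebase bookkeeping in this hunk.
    // The context size and per-document token counts are hypothetical.
    var contextSize = 8;
    var batchTokens = 0;
    var clearFlag = 0;
    var docLengths = new[] { 3, 3, 3, 3 };

    for (var idx = 0; idx < docLengths.Length; idx++)
    {
        if (batchTokens + docLengths[idx] > contextSize)
        {
            // Flush: sequences 0..(idx - clearFlag - 1) get scored, then the batch restarts.
            batchTokens = 0;
            clearFlag = idx;
        }
        var seqId = idx - clearFlag; // document idx occupies sequence (idx - clearFlag) of the current batch
        batchTokens += docLengths[idx];
        Console.WriteLine($"doc {idx} -> seq {seqId}");
    }
    // Prints: doc 0 -> seq 0, doc 1 -> seq 1, doc 2 -> seq 0, doc 3 -> seq 1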
@@ -104,35 +141,18 @@ public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
             throw new NotSupportedException("Unsupported model type");
         }

-        for (var i = 0; i < documents.Count; i++)
-        {
-            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)i)[0];
-            scores.Add(normalize ? Sigmoid(score) : score);
-        }
+        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];

         Context.NativeHandle.KvCacheClear();

-        return scores;
+        return (normalize ? Sigmoid(score) : score, tokens.Length);
     }

-    /// <summary>
-    /// Retrieve relevance score for input and document by reranking
-    /// </summary>
-    /// <param name="input"></param>
-    /// <param name="document"></param>
-    /// <param name="cancellationToken"></param>
-    /// <returns></returns>
-    /// <exception cref="RuntimeError"></exception>
-    /// <exception cref="NotSupportedException"></exception>
-    public async Task<(float Score, int Tokens)> GetRelevanceScoreWithTokenCount(string input, string document, bool normalize = false, CancellationToken cancellationToken = default)
+    private async Task<IReadOnlyList<float>> CalcRelevanceScores(LLamaBatch batch, bool normalize = false, CancellationToken cancellationToken = default)
     {
-        var inputTokens = Context.Tokenize(input);
-        var docTokens = Context.Tokenize(document);
-        LLamaToken[] tokens = [..inputTokens, ..docTokens];
-        var batch = new LLamaBatch();
-        for (var i = 0; i < tokens.Length; i++)
-            batch.Add(tokens[i], i, LLamaSeqId.Zero, true);
-
+        var (logicCap, _) = batch.GetLogitPositions()[batch.LogitPositionCount - 1];
+        var seqNum = logicCap.Value + 1;
+        List<float> scores = new List<float>(seqNum);
         // clear previous kv_cache values
         Context.NativeHandle.KvCacheClear();
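A note on the new helper's first two lines: CalcRelevanceScores does not take the sequence count as a parameter but recovers it from the batch. Judging from the usage here, GetLogitPositions() yields one (sequence id, position) pair per logit-bearing token in insertion order, so the sequence id of the final entry is the highest id in the batch and logicCap.Value + 1 is the number of sequences to score. (The name logicCap is the commit's own; "logitCap" was presumably intended.)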
@@ -162,11 +182,15 @@ public async Task<IReadOnlyList<float>> GetRelevanceScores(string input, IReadOnlyList<string> documents, bool normalize = false, CancellationToken cancellationToken = default)
             throw new NotSupportedException("Unsupported model type");
         }

-        var score = Context.NativeHandle.GetEmbeddingsSeq(LLamaSeqId.Zero)[0];
+        for (var seq = 0; seq < seqNum; seq++)
+        {
+            var score = Context.NativeHandle.GetEmbeddingsSeq((LLamaSeqId)seq)[0];
+            scores.Add(normalize ? Sigmoid(score) : score);
+        }

         Context.NativeHandle.KvCacheClear();

-        return (normalize ? Sigmoid(score) : score, tokens.Length);
+        return scores;
     }

     private float Sigmoid(float x)
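For context, a hypothetical usage sketch of the patched API (not part of the commit): the model file name, context size, and the constructor shape, assumed to mirror LLamaSharp's embedder-style setup, are placeholders.

    using LLama;
    using LLama.Common;

    // Hypothetical setup; the model path and parameter values are placeholders.
    var parameters = new ModelParams("bge-reranker-v2-m3.gguf")
    {
        ContextSize = 512, // small on purpose: a large document set now triggers mid-run flushes instead of failing
    };
    using var weights = LLamaWeights.LoadFromFile(parameters);
    using var reranker = new LLamaReranker(weights, parameters);

    var documents = new[] { "passage one ...", "passage two ...", "passage three ..." };

    // Scores come back in document order, even when the batch was flushed mid-loop.
    var scores = await reranker.GetRelevanceScores("what is llama.cpp?", documents, normalize: true);

    // The single-pair variant added by this commit also reports the token count consumed.
    var (score, tokens) = await reranker.GetRelevanceScoreWithTokenCount("what is llama.cpp?", documents[0], normalize: true);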
