Skip to content

Commit 5b9e453

Browse files
yunfengzhou-hub authored and Sxnan committed
[FLINK-38549][model] Support limiting context window size
This closes #27139
1 parent 562883f commit 5b9e453

File tree

11 files changed

+474
-24
lines changed

11 files changed

+474
-24
lines changed

docs/content.zh/docs/connectors/models/openai.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,33 @@ FROM ML_PREDICT(
130130
<td>String</td>
131131
<td>模型名称,例如:<code>gpt-3.5-turbo</code>, <code>text-embedding-ada-002</code>。</td>
132132
</tr>
133+
<tr>
134+
<td>
135+
<h5>max-context-size</h5>
136+
</td>
137+
<td>可选</td>
138+
<td style="word-wrap: break-word;">(none)</td>
139+
<td>Integer</td>
140+
<td>单个请求的最大上下文长度,单位为Token数量。当长度超过该值时,将使用context-overflow-action指定的溢出行为。</td>
141+
</tr>
142+
<tr>
143+
<td>
144+
<h5>context-overflow-action</h5>
145+
</td>
146+
<td>可选</td>
147+
<td style="word-wrap: break-word;">(none)</td>
148+
<td>String</td>
149+
<td>处理上下文溢出的操作。支持的操作:
150+
<ul>
151+
<li><code>truncated-tail</code>(默认): 从上下文尾部截断超出的token。</li>
152+
<li><code>truncated-tail-log</code>: 从上下文尾部截断超出的token。记录截断日志。</li>
153+
<li><code>truncated-head</code>: 从上下文头部截断超出的token。</li>
154+
<li><code>truncated-head-log</code>: 从上下文头部截断超出的token。记录截断日志。</li>
155+
<li><code>skipped</code>: 跳过输入行。</li>
156+
<li><code>skipped-log</code>: 跳过输入行。记录跳过日志。</li>
157+
</ul>
158+
</td>
159+
</tr>
133160
</tbody>
134161
</table>
135162

docs/content/docs/connectors/models/openai.md

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -130,6 +130,33 @@ FROM ML_PREDICT(
130130
<td>String</td>
131131
<td>Model name, e.g. <code>gpt-3.5-turbo</code>, <code>text-embedding-ada-002</code>.</td>
132132
</tr>
133+
<tr>
134+
<td>
135+
<h5>max-context-size</h5>
136+
</td>
137+
<td>optional</td>
138+
<td style="word-wrap: break-word;">(none)</td>
139+
<td>Integer</td>
140+
<td>Maximum number of tokens in the context. The action specified by <code>context-overflow-action</code> is triggered when this threshold is exceeded.</td>
141+
</tr>
142+
<tr>
143+
<td>
144+
<h5>context-overflow-action</h5>
145+
</td>
146+
<td>optional</td>
147+
<td style="word-wrap: break-word;">(none)</td>
148+
<td>String</td>
149+
<td>Action to handle context overflows. Supported actions:
150+
<ul>
151+
<li><code>truncated-tail</code> (default): Truncates exceeded tokens from the tail of the context.</li>
152+
<li><code>truncated-tail-log</code>: Truncates exceeded tokens from the tail of the context. Records the truncation log.</li>
153+
<li><code>truncated-head</code>: Truncates exceeded tokens from the head of the context.</li>
154+
<li><code>truncated-head-log</code>: Truncates exceeded tokens from the head of the context. Records the truncation log.</li>
155+
<li><code>skipped</code>: Skips the input row.</li>
156+
<li><code>skipped-log</code>: Skips the input row. Records the skipping log.</li>
157+
</ul>
158+
</td>
159+
</tr>
133160
</tbody>
134161
</table>
135162

flink-models/flink-model-openai/pom.xml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,12 @@ under the License.
7272
<optional>${flink.markBundledAsOptional}</optional>
7373
</dependency>
7474

75+
<dependency>
76+
<groupId>com.knuddels</groupId>
77+
<artifactId>jtokkit</artifactId>
78+
<version>1.1.0</version>
79+
</dependency>
80+
7581
<!-- Core dependencies -->
7682
<dependency>
7783
<groupId>org.apache.flink</groupId>

flink-models/flink-model-openai/src/main/java/org/apache/flink/model/openai/AbstractOpenAIModelFunction.java

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,7 @@
2525
import org.apache.flink.table.api.config.ExecutionConfigOptions;
2626
import org.apache.flink.table.catalog.Column;
2727
import org.apache.flink.table.catalog.ResolvedSchema;
28+
import org.apache.flink.table.data.RowData;
2829
import org.apache.flink.table.factories.ModelProviderFactory;
2930
import org.apache.flink.table.functions.AsyncPredictFunction;
3031
import org.apache.flink.table.functions.FunctionContext;
@@ -35,7 +36,12 @@
3536
import org.slf4j.Logger;
3637
import org.slf4j.LoggerFactory;
3738

39+
import javax.annotation.Nullable;
40+
41+
import java.util.Collection;
42+
import java.util.Collections;
3843
import java.util.List;
44+
import java.util.concurrent.CompletableFuture;
3945
import java.util.stream.Collectors;
4046

4147
import static org.apache.flink.configuration.description.TextElement.code;
@@ -73,11 +79,32 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7379
code("gpt-3.5-turbo"), code("text-embedding-ada-002"))
7480
.build());
7581

82+
public static final ConfigOption<Integer> MAX_CONTEXT_SIZE =
83+
ConfigOptions.key("max-context-size")
84+
.intType()
85+
.noDefaultValue()
86+
.withDescription(
87+
"Max number of tokens for context. context-overflow-action would be triggered if this threshold is exceeded.");
88+
89+
public static final ConfigOption<ContextOverflowAction> CONTEXT_OVERFLOW_ACTION =
90+
ConfigOptions.key("context-overflow-action")
91+
.enumType(ContextOverflowAction.class)
92+
.defaultValue(ContextOverflowAction.TRUNCATED_TAIL)
93+
.withDescription(
94+
Description.builder()
95+
.text("Action to handle context overflows. Supported actions:")
96+
.linebreak()
97+
.text(ContextOverflowAction.getAllValuesAndDescriptions())
98+
.build());
99+
76100
protected transient OpenAIClientAsync client;
77101

78102
private final int numRetry;
79103
private final String baseUrl;
80104
private final String apiKey;
105+
private final String model;
106+
@Nullable private final Integer maxContextSize;
107+
private final ContextOverflowAction contextOverflowAction;
81108

82109
public AbstractOpenAIModelFunction(
83110
ModelProviderFactory.Context factoryContext, ReadableConfig config) {
@@ -94,6 +121,9 @@ public AbstractOpenAIModelFunction(
94121
// resilience while maintaining throughput efficiency.
95122
this.numRetry =
96123
config.get(ExecutionConfigOptions.TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY) * 10;
124+
this.model = config.get(MODEL);
125+
this.maxContextSize = config.get(MAX_CONTEXT_SIZE);
126+
this.contextOverflowAction = config.get(CONTEXT_OVERFLOW_ACTION);
97127

98128
validateSingleColumnSchema(
99129
factoryContext.getCatalogModel().getResolvedInputSchema(),
@@ -106,6 +136,24 @@ public void open(FunctionContext context) throws Exception {
106136
super.open(context);
107137
LOG.debug("Creating an OpenAI client.");
108138
this.client = OpenAIUtils.createAsyncClient(baseUrl, apiKey, numRetry);
139+
this.contextOverflowAction.initializeEncodingForContextLimit(model, maxContextSize);
140+
}
141+
142+
@Override
143+
public CompletableFuture<Collection<RowData>> asyncPredict(RowData rowData) {
144+
if (rowData.isNullAt(0)) {
145+
LOG.warn("Input is null, skipping prediction.");
146+
return CompletableFuture.completedFuture(Collections.emptyList());
147+
}
148+
149+
String input =
150+
contextOverflowAction.processTokensWithLimit(
151+
model, rowData.getString(0).toString(), maxContextSize);
152+
if (input == null) {
153+
return CompletableFuture.completedFuture(Collections.emptyList());
154+
}
155+
156+
return asyncPredictInternal(input);
109157
}
110158

111159
@Override
@@ -120,6 +168,8 @@ public void close() throws Exception {
120168

121169
protected abstract String getEndpointSuffix();
122170

171+
protected abstract CompletableFuture<Collection<RowData>> asyncPredictInternal(String input);
172+
123173
protected void validateSingleColumnSchema(
124174
ResolvedSchema schema, LogicalType expectedType, String inputOrOutput) {
125175
List<Column> columns = schema.getColumns();

0 commit comments

Comments
 (0)