2525import org .apache .flink .table .api .config .ExecutionConfigOptions ;
2626import org .apache .flink .table .catalog .Column ;
2727import org .apache .flink .table .catalog .ResolvedSchema ;
28+ import org .apache .flink .table .data .RowData ;
2829import org .apache .flink .table .factories .ModelProviderFactory ;
2930import org .apache .flink .table .functions .AsyncPredictFunction ;
3031import org .apache .flink .table .functions .FunctionContext ;
3536import org .slf4j .Logger ;
3637import org .slf4j .LoggerFactory ;
3738
39+ import javax .annotation .Nullable ;
40+
41+ import java .util .Collection ;
42+ import java .util .Collections ;
3843import java .util .List ;
44+ import java .util .concurrent .CompletableFuture ;
3945import java .util .stream .Collectors ;
4046
4147import static org .apache .flink .configuration .description .TextElement .code ;
@@ -73,11 +79,32 @@ public abstract class AbstractOpenAIModelFunction extends AsyncPredictFunction {
7379 code ("gpt-3.5-turbo" ), code ("text-embedding-ada-002" ))
7480 .build ());
7581
82+ public static final ConfigOption <Integer > MAX_CONTEXT_SIZE =
83+ ConfigOptions .key ("max-context-size" )
84+ .intType ()
85+ .noDefaultValue ()
86+ .withDescription (
87+ "Max number of tokens for context. context-overflow-action would be triggered if this threshold is exceeded." );
88+
89+ public static final ConfigOption <ContextOverflowAction > CONTEXT_OVERFLOW_ACTION =
90+ ConfigOptions .key ("context-overflow-action" )
91+ .enumType (ContextOverflowAction .class )
92+ .defaultValue (ContextOverflowAction .TRUNCATED_TAIL )
93+ .withDescription (
94+ Description .builder ()
95+ .text ("Action to handle context overflows. Supported actions:" )
96+ .linebreak ()
97+ .text (ContextOverflowAction .getAllValuesAndDescriptions ())
98+ .build ());
99+
76100 protected transient OpenAIClientAsync client ;
77101
78102 private final int numRetry ;
79103 private final String baseUrl ;
80104 private final String apiKey ;
105+ private final String model ;
106+ @ Nullable private final Integer maxContextSize ;
107+ private final ContextOverflowAction contextOverflowAction ;
81108
82109 public AbstractOpenAIModelFunction (
83110 ModelProviderFactory .Context factoryContext , ReadableConfig config ) {
@@ -94,6 +121,9 @@ public AbstractOpenAIModelFunction(
94121 // resilience while maintaining throughput efficiency.
95122 this .numRetry =
96123 config .get (ExecutionConfigOptions .TABLE_EXEC_ASYNC_LOOKUP_BUFFER_CAPACITY ) * 10 ;
124+ this .model = config .get (MODEL );
125+ this .maxContextSize = config .get (MAX_CONTEXT_SIZE );
126+ this .contextOverflowAction = config .get (CONTEXT_OVERFLOW_ACTION );
97127
98128 validateSingleColumnSchema (
99129 factoryContext .getCatalogModel ().getResolvedInputSchema (),
@@ -106,6 +136,24 @@ public void open(FunctionContext context) throws Exception {
106136 super .open (context );
107137 LOG .debug ("Creating an OpenAI client." );
108138 this .client = OpenAIUtils .createAsyncClient (baseUrl , apiKey , numRetry );
139+ this .contextOverflowAction .initializeEncodingForContextLimit (model , maxContextSize );
140+ }
141+
142+ @ Override
143+ public CompletableFuture <Collection <RowData >> asyncPredict (RowData rowData ) {
144+ if (rowData .isNullAt (0 )) {
145+ LOG .warn ("Input is null, skipping prediction." );
146+ return CompletableFuture .completedFuture (Collections .emptyList ());
147+ }
148+
149+ String input =
150+ contextOverflowAction .processTokensWithLimit (
151+ model , rowData .getString (0 ).toString (), maxContextSize );
152+ if (input == null ) {
153+ return CompletableFuture .completedFuture (Collections .emptyList ());
154+ }
155+
156+ return asyncPredictInternal (input );
109157 }
110158
111159 @ Override
@@ -120,6 +168,8 @@ public void close() throws Exception {
120168
121169 protected abstract String getEndpointSuffix ();
122170
171+ protected abstract CompletableFuture <Collection <RowData >> asyncPredictInternal (String input );
172+
123173 protected void validateSingleColumnSchema (
124174 ResolvedSchema schema , LogicalType expectedType , String inputOrOutput ) {
125175 List <Column > columns = schema .getColumns ();
0 commit comments