43
43
import static org .elasticsearch .xpack .inference .services .ServiceFields .MAX_INPUT_TOKENS ;
44
44
import static org .elasticsearch .xpack .inference .services .ServiceFields .SIMILARITY ;
45
45
import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalMap ;
46
+ import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractOptionalPositiveInteger ;
46
47
import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredMap ;
47
48
import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractRequiredString ;
48
49
import static org .elasticsearch .xpack .inference .services .ServiceUtils .extractSimilarity ;
52
53
import static org .elasticsearch .xpack .inference .services .ServiceUtils .validateMapStringValues ;
53
54
54
55
public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings , CustomRateLimitServiceSettings {
56
+
55
57
public static final String NAME = "custom_service_settings" ;
56
58
public static final String URL = "url" ;
59
+ public static final String BATCH_SIZE = "batch_size" ;
57
60
public static final String HEADERS = "headers" ;
58
61
public static final String REQUEST = "request" ;
59
62
public static final String RESPONSE = "response" ;
60
63
public static final String JSON_PARSER = "json_parser" ;
61
64
62
65
private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings (10_000 );
63
66
private static final String RESPONSE_SCOPE = String .join ("." , ModelConfigurations .SERVICE_SETTINGS , RESPONSE );
67
+ private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10 ;
64
68
65
69
public static CustomServiceSettings fromMap (
66
70
Map <String , Object > map ,
@@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap(
106
110
context
107
111
);
108
112
113
+ var batchSize = extractOptionalPositiveInteger (map , BATCH_SIZE , ModelConfigurations .SERVICE_SETTINGS , validationException );
114
+
109
115
if (responseParserMap == null || jsonParserMap == null ) {
110
116
throw validationException ;
111
117
}
@@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap(
124
130
queryParams ,
125
131
requestContentString ,
126
132
responseJsonParser ,
127
- rateLimitSettings
133
+ rateLimitSettings ,
134
+ batchSize
128
135
);
129
136
}
130
137
@@ -142,7 +149,6 @@ public record TextEmbeddingSettings(
142
149
null ,
143
150
DenseVectorFieldMapper .ElementType .FLOAT
144
151
);
145
-
146
152
// This refers to settings that are not related to the text embedding task type (all the settings should be null)
147
153
public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings (null , null , null , null );
148
154
@@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
196
202
private final String requestContentString ;
197
203
private final CustomResponseParser responseJsonParser ;
198
204
private final RateLimitSettings rateLimitSettings ;
205
+ private final int batchSize ;
199
206
200
207
public CustomServiceSettings (
201
208
TextEmbeddingSettings textEmbeddingSettings ,
@@ -205,6 +212,19 @@ public CustomServiceSettings(
205
212
String requestContentString ,
206
213
CustomResponseParser responseJsonParser ,
207
214
@ Nullable RateLimitSettings rateLimitSettings
215
+ ) {
216
+ this (textEmbeddingSettings , url , headers , queryParameters , requestContentString , responseJsonParser , rateLimitSettings , null );
217
+ }
218
+
219
+ public CustomServiceSettings (
220
+ TextEmbeddingSettings textEmbeddingSettings ,
221
+ String url ,
222
+ @ Nullable Map <String , String > headers ,
223
+ @ Nullable QueryParameters queryParameters ,
224
+ String requestContentString ,
225
+ CustomResponseParser responseJsonParser ,
226
+ @ Nullable RateLimitSettings rateLimitSettings ,
227
+ @ Nullable Integer batchSize
208
228
) {
209
229
this .textEmbeddingSettings = Objects .requireNonNull (textEmbeddingSettings );
210
230
this .url = Objects .requireNonNull (url );
@@ -213,6 +233,7 @@ public CustomServiceSettings(
213
233
this .requestContentString = Objects .requireNonNull (requestContentString );
214
234
this .responseJsonParser = Objects .requireNonNull (responseJsonParser );
215
235
this .rateLimitSettings = Objects .requireNonNullElse (rateLimitSettings , DEFAULT_RATE_LIMIT_SETTINGS );
236
+ this .batchSize = Objects .requireNonNullElse (batchSize , DEFAULT_EMBEDDING_BATCH_SIZE );
216
237
}
217
238
218
239
public CustomServiceSettings (StreamInput in ) throws IOException {
@@ -223,11 +244,18 @@ public CustomServiceSettings(StreamInput in) throws IOException {
223
244
requestContentString = in .readString ();
224
245
responseJsonParser = in .readNamedWriteable (CustomResponseParser .class );
225
246
rateLimitSettings = new RateLimitSettings (in );
247
+
226
248
if (in .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 )) {
227
249
// Read the error parsing fields for backwards compatibility
228
250
in .readString ();
229
251
in .readString ();
230
252
}
253
+
254
+ if (in .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
255
+ batchSize = in .readVInt ();
256
+ } else {
257
+ batchSize = DEFAULT_EMBEDDING_BATCH_SIZE ;
258
+ }
231
259
}
232
260
233
261
@ Override
@@ -275,6 +303,10 @@ public CustomResponseParser getResponseJsonParser() {
275
303
return responseJsonParser ;
276
304
}
277
305
306
+ public int getBatchSize () {
307
+ return batchSize ;
308
+ }
309
+
278
310
@ Override
279
311
public RateLimitSettings rateLimitSettings () {
280
312
return rateLimitSettings ;
@@ -320,6 +352,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
320
352
321
353
rateLimitSettings .toXContent (builder , params );
322
354
355
+ builder .field (BATCH_SIZE , batchSize );
356
+
323
357
return builder ;
324
358
}
325
359
@@ -342,11 +376,16 @@ public void writeTo(StreamOutput out) throws IOException {
342
376
out .writeString (requestContentString );
343
377
out .writeNamedWriteable (responseJsonParser );
344
378
rateLimitSettings .writeTo (out );
379
+
345
380
if (out .getTransportVersion ().before (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 )) {
346
381
// Write empty strings for backwards compatibility for the error parsing fields
347
382
out .writeString ("" );
348
383
out .writeString ("" );
349
384
}
385
+
386
+ if (out .getTransportVersion ().onOrAfter (TransportVersions .ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 )) {
387
+ out .writeVInt (batchSize );
388
+ }
350
389
}
351
390
352
391
@ Override
@@ -360,7 +399,8 @@ public boolean equals(Object o) {
360
399
&& Objects .equals (queryParameters , that .queryParameters )
361
400
&& Objects .equals (requestContentString , that .requestContentString )
362
401
&& Objects .equals (responseJsonParser , that .responseJsonParser )
363
- && Objects .equals (rateLimitSettings , that .rateLimitSettings );
402
+ && Objects .equals (rateLimitSettings , that .rateLimitSettings )
403
+ && Objects .equals (batchSize , that .batchSize );
364
404
}
365
405
366
406
@ Override
@@ -372,7 +412,8 @@ public int hashCode() {
372
412
queryParameters ,
373
413
requestContentString ,
374
414
responseJsonParser ,
375
- rateLimitSettings
415
+ rateLimitSettings ,
416
+ batchSize
376
417
);
377
418
}
378
419
0 commit comments