
Commit 166280b

[ML] Custom service adding support for the semantic text field (#129558) (#129658)
* Adding chunking tests
* Adjusting default batch size

(cherry picked from commit daf4fca)

# Conflicts:
#	server/src/main/java/org/elasticsearch/TransportVersions.java
#	x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java
1 parent e92de38 · commit 166280b

6 files changed: +378 -14 lines changed


server/src/main/java/org/elasticsearch/TransportVersions.java

Lines changed: 1 addition & 0 deletions
@@ -243,6 +243,7 @@ static TransportVersion def(int id) {
     public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
     public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
     public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
+    public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 = def(8_841_0_53);

     /*
      * STOP! READ THIS FIRST! No, really,
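
The new transport version exists so the batch size introduced below is only put on the wire when both nodes understand it; older readers fall back to a default instead. A rough, self-contained sketch of that gating idea, using plain java.io streams and a bare int in place of Elasticsearch's StreamInput/StreamOutput and TransportVersion (all names and version numbers below are illustrative stand-ins, not the real API):

import java.io.*;

public class VersionGatedFieldSketch {
    // Stand-ins for the real transport version constants; numbers are illustrative only.
    static final int BATCH_SIZE_VERSION = 8_841_0_53;
    static final int DEFAULT_BATCH_SIZE = 10;

    // Write the new field only when the receiving node understands it.
    static void writeSettings(DataOutput out, int peerVersion, int batchSize) throws IOException {
        out.writeUTF("https://example.com/embed"); // earlier fields (url, etc.) elided
        if (peerVersion >= BATCH_SIZE_VERSION) {
            out.writeInt(batchSize);
        }
    }

    // Read the field only when the sender could have written it; otherwise use the default.
    static int readBatchSize(DataInput in, int peerVersion) throws IOException {
        in.readUTF(); // earlier fields elided
        return peerVersion >= BATCH_SIZE_VERSION ? in.readInt() : DEFAULT_BATCH_SIZE;
    }

    public static void main(String[] args) throws IOException {
        var bytes = new ByteArrayOutputStream();
        writeSettings(new DataOutputStream(bytes), 8_841_0_52, 25); // older peer: field omitted
        int batchSize = readBatchSize(new DataInputStream(new ByteArrayInputStream(bytes.toByteArray())), 8_841_0_52);
        System.out.println(batchSize); // prints 10 (the default), since the field was never sent
    }
}

The actual gating on ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19 appears in the CustomServiceSettings diff further down.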

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomModel.java

Lines changed: 39 additions & 0 deletions
@@ -8,6 +8,7 @@
 package org.elasticsearch.xpack.inference.services.custom;
 
 import org.elasticsearch.core.Nullable;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.Model;
 import org.elasticsearch.inference.ModelConfigurations;
 import org.elasticsearch.inference.ModelSecrets;
@@ -51,6 +52,27 @@ public CustomModel(
         );
     }
 
+    public CustomModel(
+        String inferenceId,
+        TaskType taskType,
+        String service,
+        Map<String, Object> serviceSettings,
+        Map<String, Object> taskSettings,
+        @Nullable Map<String, Object> secrets,
+        @Nullable ChunkingSettings chunkingSettings,
+        ConfigurationParseContext context
+    ) {
+        this(
+            inferenceId,
+            taskType,
+            service,
+            CustomServiceSettings.fromMap(serviceSettings, context, taskType, inferenceId),
+            CustomTaskSettings.fromMap(taskSettings),
+            CustomSecretSettings.fromMap(secrets),
+            chunkingSettings
+        );
+    }
+
     // should only be used for testing
     CustomModel(
         String inferenceId,
@@ -67,6 +89,23 @@ public CustomModel(
         );
     }
 
+    // should only be used for testing
+    CustomModel(
+        String inferenceId,
+        TaskType taskType,
+        String service,
+        CustomServiceSettings serviceSettings,
+        CustomTaskSettings taskSettings,
+        @Nullable CustomSecretSettings secretSettings,
+        @Nullable ChunkingSettings chunkingSettings
+    ) {
+        this(
+            new ModelConfigurations(inferenceId, taskType, service, serviceSettings, taskSettings, chunkingSettings),
+            new ModelSecrets(secretSettings),
+            serviceSettings
+        );
+    }
+
     protected CustomModel(CustomModel model, TaskSettings taskSettings) {
         super(model, taskSettings);
         rateLimitServiceSettings = model.rateLimitServiceSettings();

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java

Lines changed: 61 additions & 5 deletions
@@ -17,6 +17,7 @@
 import org.elasticsearch.core.Strings;
 import org.elasticsearch.core.TimeValue;
 import org.elasticsearch.inference.ChunkedInference;
+import org.elasticsearch.inference.ChunkingSettings;
 import org.elasticsearch.inference.InferenceServiceConfiguration;
 import org.elasticsearch.inference.InferenceServiceResults;
 import org.elasticsearch.inference.InputType;
@@ -27,6 +28,8 @@
 import org.elasticsearch.inference.SimilarityMeasure;
 import org.elasticsearch.inference.TaskType;
 import org.elasticsearch.rest.RestStatus;
+import org.elasticsearch.xpack.inference.chunking.ChunkingSettingsBuilder;
+import org.elasticsearch.xpack.inference.chunking.EmbeddingRequestChunker;
 import org.elasticsearch.xpack.inference.external.action.SenderExecutableAction;
 import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
 import org.elasticsearch.xpack.inference.external.http.sender.HttpRequestSender;
@@ -45,6 +48,7 @@
 import static org.elasticsearch.inference.TaskType.unsupportedTaskTypeErrorMsg;
 import static org.elasticsearch.xpack.inference.external.action.ActionUtils.constructFailedToSendRequestMessage;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.createInvalidModelException;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrDefaultEmpty;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.removeFromMapOrThrowIfNull;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwIfNotEmptyMap;
@@ -81,12 +85,15 @@ public void parseRequestConfig(
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrDefaultEmpty(config, ModelConfigurations.TASK_SETTINGS);
 
+        var chunkingSettings = extractChunkingSettings(config, taskType);
+
         CustomModel model = createModel(
             inferenceEntityId,
             taskType,
             serviceSettingsMap,
             taskSettingsMap,
             serviceSettingsMap,
+            chunkingSettings,
             ConfigurationParseContext.REQUEST
         );
 
@@ -100,6 +107,14 @@ public void parseRequestConfig(
         }
     }
 
+    private static ChunkingSettings extractChunkingSettings(Map<String, Object> config, TaskType taskType) {
+        if (TaskType.TEXT_EMBEDDING.equals(taskType)) {
+            return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS));
+        }
+
+        return null;
+    }
+
     @Override
     public InferenceServiceConfiguration getConfiguration() {
         return Configuration.get();
@@ -125,14 +140,16 @@ private static CustomModel createModelWithoutLoggingDeprecations(
         TaskType taskType,
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
-        @Nullable Map<String, Object> secretSettings
+        @Nullable Map<String, Object> secretSettings,
+        @Nullable ChunkingSettings chunkingSettings
     ) {
         return createModel(
             inferenceEntityId,
             taskType,
             serviceSettings,
             taskSettings,
             secretSettings,
+            chunkingSettings,
             ConfigurationParseContext.PERSISTENT
         );
     }
@@ -143,12 +160,13 @@ private static CustomModel createModel(
         Map<String, Object> serviceSettings,
         Map<String, Object> taskSettings,
         @Nullable Map<String, Object> secretSettings,
+        @Nullable ChunkingSettings chunkingSettings,
         ConfigurationParseContext context
     ) {
         if (supportedTaskTypes.contains(taskType) == false) {
             throw new ElasticsearchStatusException(unsupportedTaskTypeErrorMsg(taskType, NAME), RestStatus.BAD_REQUEST);
         }
-        return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, context);
+        return new CustomModel(inferenceEntityId, taskType, NAME, serviceSettings, taskSettings, secretSettings, chunkingSettings, context);
     }
 
     @Override
@@ -162,15 +180,33 @@ public CustomModel parsePersistedConfigWithSecrets(
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
         Map<String, Object> secretSettingsMap = removeFromMapOrThrowIfNull(secrets, ModelSecrets.SECRET_SETTINGS);
 
-        return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, secretSettingsMap);
+        var chunkingSettings = extractChunkingSettings(config, taskType);
+
+        return createModelWithoutLoggingDeprecations(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            taskSettingsMap,
+            secretSettingsMap,
+            chunkingSettings
+        );
     }
 
     @Override
     public CustomModel parsePersistedConfig(String inferenceEntityId, TaskType taskType, Map<String, Object> config) {
         Map<String, Object> serviceSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.SERVICE_SETTINGS);
         Map<String, Object> taskSettingsMap = removeFromMapOrThrowIfNull(config, ModelConfigurations.TASK_SETTINGS);
 
-        return createModelWithoutLoggingDeprecations(inferenceEntityId, taskType, serviceSettingsMap, taskSettingsMap, null);
+        var chunkingSettings = extractChunkingSettings(config, taskType);
+
+        return createModelWithoutLoggingDeprecations(
+            inferenceEntityId,
+            taskType,
+            serviceSettingsMap,
+            taskSettingsMap,
+            null,
+            chunkingSettings
+        );
     }
 
     @Override
@@ -211,7 +247,27 @@ protected void doChunkedInfer(
         TimeValue timeout,
         ActionListener<List<ChunkedInference>> listener
     ) {
-        listener.onFailure(new ElasticsearchStatusException("Chunking not supported by the {} service", RestStatus.BAD_REQUEST, NAME));
+        if (model instanceof CustomModel == false) {
+            listener.onFailure(createInvalidModelException(model));
+            return;
+        }
+
+        var customModel = (CustomModel) model;
+        var overriddenModel = CustomModel.of(customModel, taskSettings);
+
+        var failedToSendRequestErrorMessage = constructFailedToSendRequestMessage(SERVICE_NAME);
+        var manager = CustomRequestManager.of(overriddenModel, getServiceComponents().threadPool());
+
+        List<EmbeddingRequestChunker.BatchRequestAndListener> batchedRequests = new EmbeddingRequestChunker<>(
+            inputs.getInputs(),
+            customModel.getServiceSettings().getBatchSize(),
+            customModel.getConfigurations().getChunkingSettings()
+        ).batchRequestsWithListeners(listener);
+
+        for (var request : batchedRequests) {
+            var action = new SenderExecutableAction(getSender(), manager, failedToSendRequestErrorMessage);
+            action.execute(EmbeddingsInput.fromStrings(request.batch().inputs().get(), inputType), timeout, request.listener());
        }
     }
 
     @Override
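
With this change doChunkedInfer no longer rejects the custom service: inputs are chunked according to the model's ChunkingSettings, and the chunks are grouped into batches of getBatchSize(), with one request sent per batch via EmbeddingRequestChunker.batchRequestsWithListeners. A simplified, self-contained sketch of that chunk-then-batch step (a naive word chunker standing in for the real chunking strategies; this is an illustration of the idea, not the actual EmbeddingRequestChunker):

import java.util.ArrayList;
import java.util.List;

public class ChunkThenBatchSketch {

    // Naive word-count chunker standing in for the configured ChunkingSettings.
    static List<String> chunk(String input, int maxWordsPerChunk) {
        var words = List.of(input.split("\\s+"));
        var chunks = new ArrayList<String>();
        for (int i = 0; i < words.size(); i += maxWordsPerChunk) {
            chunks.add(String.join(" ", words.subList(i, Math.min(i + maxWordsPerChunk, words.size()))));
        }
        return chunks;
    }

    // Group all chunks into batches of batchSize; each batch would become one HTTP request.
    static List<List<String>> batch(List<String> chunks, int batchSize) {
        var batches = new ArrayList<List<String>>();
        for (int i = 0; i < chunks.size(); i += batchSize) {
            batches.add(chunks.subList(i, Math.min(i + batchSize, chunks.size())));
        }
        return batches;
    }

    public static void main(String[] args) {
        List<String> inputs = List.of(
            "a fairly long document that needs to be split before embedding",
            "a short one"
        );
        var allChunks = new ArrayList<String>();
        for (var input : inputs) {
            allChunks.addAll(chunk(input, 5)); // e.g. at most 5 words per chunk
        }
        var batches = batch(allChunks, 2);     // e.g. batch_size = 2 -> 2 chunks per request
        System.out.println(batches.size() + " requests for " + allChunks.size() + " chunks");
    }
}

In the real code each batch also carries its own listener, so batchRequestsWithListeners can stitch the per-batch embeddings back together and report the combined result through the single caller-supplied listener.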

x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceSettings.java

Lines changed: 45 additions & 4 deletions
@@ -43,6 +43,7 @@
 import static org.elasticsearch.xpack.inference.services.ServiceFields.MAX_INPUT_TOKENS;
 import static org.elasticsearch.xpack.inference.services.ServiceFields.SIMILARITY;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalMap;
+import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractOptionalPositiveInteger;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredMap;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractRequiredString;
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.extractSimilarity;
@@ -52,15 +53,18 @@
 import static org.elasticsearch.xpack.inference.services.ServiceUtils.validateMapStringValues;
 
 public class CustomServiceSettings extends FilteredXContentObject implements ServiceSettings, CustomRateLimitServiceSettings {
+
     public static final String NAME = "custom_service_settings";
     public static final String URL = "url";
+    public static final String BATCH_SIZE = "batch_size";
     public static final String HEADERS = "headers";
     public static final String REQUEST = "request";
     public static final String RESPONSE = "response";
     public static final String JSON_PARSER = "json_parser";
 
     private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
     private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
+    private static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;
 
     public static CustomServiceSettings fromMap(
         Map<String, Object> map,
@@ -106,6 +110,8 @@ public static CustomServiceSettings fromMap(
             context
         );
 
+        var batchSize = extractOptionalPositiveInteger(map, BATCH_SIZE, ModelConfigurations.SERVICE_SETTINGS, validationException);
+
         if (responseParserMap == null || jsonParserMap == null) {
             throw validationException;
         }
@@ -124,7 +130,8 @@ public static CustomServiceSettings fromMap(
             queryParams,
             requestContentString,
             responseJsonParser,
-            rateLimitSettings
+            rateLimitSettings,
+            batchSize
         );
     }
 
@@ -142,7 +149,6 @@ public record TextEmbeddingSettings(
             null,
             DenseVectorFieldMapper.ElementType.FLOAT
         );
-
         // This refers to settings that are not related to the text embedding task type (all the settings should be null)
         public static final TextEmbeddingSettings NON_TEXT_EMBEDDING_TASK_TYPE_SETTINGS = new TextEmbeddingSettings(null, null, null, null);
 
@@ -196,6 +202,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
     private final String requestContentString;
     private final CustomResponseParser responseJsonParser;
     private final RateLimitSettings rateLimitSettings;
+    private final int batchSize;
 
     public CustomServiceSettings(
         TextEmbeddingSettings textEmbeddingSettings,
@@ -205,6 +212,19 @@ public CustomServiceSettings(
         String requestContentString,
         CustomResponseParser responseJsonParser,
         @Nullable RateLimitSettings rateLimitSettings
+    ) {
+        this(textEmbeddingSettings, url, headers, queryParameters, requestContentString, responseJsonParser, rateLimitSettings, null);
+    }
+
+    public CustomServiceSettings(
+        TextEmbeddingSettings textEmbeddingSettings,
+        String url,
+        @Nullable Map<String, String> headers,
+        @Nullable QueryParameters queryParameters,
+        String requestContentString,
+        CustomResponseParser responseJsonParser,
+        @Nullable RateLimitSettings rateLimitSettings,
+        @Nullable Integer batchSize
     ) {
         this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
         this.url = Objects.requireNonNull(url);
@@ -213,6 +233,7 @@ public CustomServiceSettings(
         this.requestContentString = Objects.requireNonNull(requestContentString);
         this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
         this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
+        this.batchSize = Objects.requireNonNullElse(batchSize, DEFAULT_EMBEDDING_BATCH_SIZE);
     }
 
     public CustomServiceSettings(StreamInput in) throws IOException {
@@ -223,11 +244,18 @@ public CustomServiceSettings(StreamInput in) throws IOException {
         requestContentString = in.readString();
         responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
         rateLimitSettings = new RateLimitSettings(in);
+
         if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19)) {
             // Read the error parsing fields for backwards compatibility
             in.readString();
             in.readString();
         }
+
+        if (in.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
+            batchSize = in.readVInt();
+        } else {
+            batchSize = DEFAULT_EMBEDDING_BATCH_SIZE;
+        }
     }
 
     @Override
@@ -275,6 +303,10 @@ public CustomResponseParser getResponseJsonParser() {
         return responseJsonParser;
     }
 
+    public int getBatchSize() {
+        return batchSize;
+    }
+
     @Override
     public RateLimitSettings rateLimitSettings() {
         return rateLimitSettings;
@@ -320,6 +352,8 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
 
         rateLimitSettings.toXContent(builder, params);
 
+        builder.field(BATCH_SIZE, batchSize);
+
         return builder;
     }
 
@@ -342,11 +376,16 @@ public void writeTo(StreamOutput out) throws IOException {
         out.writeString(requestContentString);
         out.writeNamedWriteable(responseJsonParser);
         rateLimitSettings.writeTo(out);
+
        if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19)) {
             // Write empty strings for backwards compatibility for the error parsing fields
             out.writeString("");
             out.writeString("");
         }
+
+        if (out.getTransportVersion().onOrAfter(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_EMBEDDING_BATCH_SIZE_8_19)) {
+            out.writeVInt(batchSize);
+        }
     }
 
     @Override
@@ -360,7 +399,8 @@ public boolean equals(Object o) {
             && Objects.equals(queryParameters, that.queryParameters)
             && Objects.equals(requestContentString, that.requestContentString)
             && Objects.equals(responseJsonParser, that.responseJsonParser)
-            && Objects.equals(rateLimitSettings, that.rateLimitSettings);
+            && Objects.equals(rateLimitSettings, that.rateLimitSettings)
+            && Objects.equals(batchSize, that.batchSize);
     }
 
     @Override
@@ -372,7 +412,8 @@ public int hashCode() {
             queryParameters,
             requestContentString,
             responseJsonParser,
-            rateLimitSettings
+            rateLimitSettings,
+            batchSize
         );
     }
 
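
batch_size is an optional service setting: fromMap reads it with extractOptionalPositiveInteger (recording problems in the ValidationException scoped to service_settings), the constructor defaults it to DEFAULT_EMBEDDING_BATCH_SIZE (10) when it is absent, and toXContentFragmentOfExposedFields reports it back when the endpoint configuration is fetched. A small, self-contained sketch of that parse-with-default behaviour, with a simplified stand-in for the real ServiceUtils helper:

import java.util.HashMap;
import java.util.Map;
import java.util.Objects;

public class BatchSizeParsingSketch {

    static final String BATCH_SIZE = "batch_size";
    static final int DEFAULT_EMBEDDING_BATCH_SIZE = 10;

    // Simplified stand-in for extractOptionalPositiveInteger: null if absent, error if not a positive int.
    static Integer extractOptionalPositiveInteger(Map<String, Object> map, String key) {
        Object value = map.remove(key);
        if (value == null) {
            return null;
        }
        if (value instanceof Integer i && i > 0) {
            return i;
        }
        throw new IllegalArgumentException("[" + key + "] must be a positive integer, got [" + value + "]");
    }

    public static void main(String[] args) {
        Map<String, Object> serviceSettings = new HashMap<>(Map.of("url", "https://example.com/embed"));
        Integer parsed = extractOptionalPositiveInteger(serviceSettings, BATCH_SIZE);
        int batchSize = Objects.requireNonNullElse(parsed, DEFAULT_EMBEDDING_BATCH_SIZE);
        System.out.println(batchSize); // 10: falls back to the default when batch_size is not supplied
    }
}

Supplying a value such as "batch_size": 25 in the service settings would override the default and feed directly into the batching shown in CustomService.doChunkedInfer above.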
