Skip to content

[ML] Remove error parsing functionality for custom service #128778

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
2 changes: 2 additions & 0 deletions server/src/main/java/org/elasticsearch/TransportVersions.java
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,7 @@ static TransportVersion def(int id) {
public static final TransportVersion NONE_CHUNKING_STRATEGY_8_19 = def(8_841_0_49);
public static final TransportVersion IDP_CUSTOM_SAML_ATTRIBUTES_ALLOW_LIST_8_19 = def(8_841_0_50);
public static final TransportVersion SETTINGS_IN_DATA_STREAMS_8_19 = def(8_841_0_51);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19 = def(8_841_0_52);
public static final TransportVersion V_9_0_0 = def(9_000_0_09);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_1 = def(9_000_0_10);
public static final TransportVersion INITIAL_ELASTICSEARCH_9_0_2 = def(9_000_0_11);
Expand Down Expand Up @@ -302,6 +303,7 @@ static TransportVersion def(int id) {
public static final TransportVersion SECURITY_CLOUD_API_KEY_REALM_AND_TYPE = def(9_099_0_00);
public static final TransportVersion STATE_PARAM_GET_SNAPSHOT = def(9_100_0_00);
public static final TransportVersion PROJECT_ID_IN_SNAPSHOTS_DELETIONS_AND_REPO_CLEANUP = def(9_101_0_00);
public static final TransportVersion ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING = def(9_102_0_00);

/*
* STOP! READ THIS FIRST! No, really,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,8 @@ public static RateLimitGrouping of(CustomModel model) {
}
}

private static ResponseHandler createCustomHandler(CustomModel model) {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse, model.getServiceSettings().getErrorParser());
private static ResponseHandler createCustomHandler() {
return new CustomResponseHandler("custom model", CustomResponseEntity::fromResponse);
}

public static CustomRequestManager of(CustomModel model, ThreadPool threadPool) {
Expand All @@ -55,7 +55,7 @@ public static CustomRequestManager of(CustomModel model, ThreadPool threadPool)
private CustomRequestManager(CustomModel model, ThreadPool threadPool) {
super(threadPool, model.getInferenceEntityId(), RateLimitGrouping.of(model), model.rateLimitServiceSettings().rateLimitSettings());
this.model = model;
this.handler = createCustomHandler(model);
this.handler = createCustomHandler();
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,34 @@
package org.elasticsearch.xpack.inference.services.custom;

import org.elasticsearch.ElasticsearchStatusException;
import org.elasticsearch.common.Strings;
import org.elasticsearch.inference.InferenceServiceResults;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.xpack.inference.external.http.HttpResult;
import org.elasticsearch.xpack.inference.external.http.retry.BaseResponseHandler;
import org.elasticsearch.xpack.inference.external.http.retry.ErrorResponse;
import org.elasticsearch.xpack.inference.external.http.retry.ResponseParser;
import org.elasticsearch.xpack.inference.external.http.retry.RetryException;
import org.elasticsearch.xpack.inference.external.request.Request;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;

import java.nio.charset.StandardCharsets;
import java.util.function.Function;

/**
* Defines how to handle various response types returned from the custom integration.
*/
public class CustomResponseHandler extends BaseResponseHandler {
public CustomResponseHandler(String requestType, ResponseParser parseFunction, ErrorResponseParser errorParser) {
super(requestType, parseFunction, errorParser);
// Error parser shared by every CustomResponseHandler instance: it wraps the HTTP result body,
// decoded as UTF-8 text, in an ErrorResponse.
// Declared with default (package-private) visibility for testing.
static final Function<HttpResult, ErrorResponse> ERROR_PARSER = (httpResult) -> {
try {
// Decode the response body as UTF-8 and hand the resulting text to ErrorResponse.
return new ErrorResponse(new String(httpResult.body(), StandardCharsets.UTF_8));
} catch (Exception e) {
// Any failure while reading the body is reported as an ErrorResponse carrying the
// exception message, rather than being propagated to the caller.
return new ErrorResponse(Strings.format("Failed to parse error response body: %s", e.getMessage()));
}
};

/**
 * Creates a response handler for the custom service that always uses {@link #ERROR_PARSER}
 * as its error parser.
 *
 * @param requestType   request type value forwarded unchanged to the base handler
 * @param parseFunction response parser forwarded unchanged to the base handler
 */
public CustomResponseHandler(String requestType, ResponseParser parseFunction) {
super(requestType, parseFunction, ERROR_PARSER);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -249,8 +249,7 @@ private static CustomServiceSettings getCustomServiceSettings(CustomModel custom
serviceSettings.getQueryParameters(),
serviceSettings.getRequestContentString(),
serviceSettings.getResponseJsonParser(),
serviceSettings.rateLimitSettings(),
serviceSettings.getErrorParser()
serviceSettings.rateLimitSettings()
);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@
import org.elasticsearch.xpack.inference.services.ConfigurationParseContext;
import org.elasticsearch.xpack.inference.services.custom.response.CompletionResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.NoopResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.SparseEmbeddingResponseParser;
Expand Down Expand Up @@ -59,7 +58,6 @@ public class CustomServiceSettings extends FilteredXContentObject implements Ser
public static final String REQUEST = "request";
public static final String RESPONSE = "response";
public static final String JSON_PARSER = "json_parser";
public static final String ERROR_PARSER = "error_parser";

private static final RateLimitSettings DEFAULT_RATE_LIMIT_SETTINGS = new RateLimitSettings(10_000);
private static final String RESPONSE_SCOPE = String.join(".", ModelConfigurations.SERVICE_SETTINGS, RESPONSE);
Expand Down Expand Up @@ -100,15 +98,6 @@ public static CustomServiceSettings fromMap(

var responseJsonParser = extractResponseParser(taskType, jsonParserMap, validationException);

Map<String, Object> errorParserMap = extractRequiredMap(
Objects.requireNonNullElse(responseParserMap, new HashMap<>()),
ERROR_PARSER,
RESPONSE_SCOPE,
validationException
);

var errorParser = ErrorResponseParser.fromMap(errorParserMap, RESPONSE_SCOPE, inferenceId, validationException);

RateLimitSettings rateLimitSettings = RateLimitSettings.of(
map,
DEFAULT_RATE_LIMIT_SETTINGS,
Expand All @@ -117,13 +106,12 @@ public static CustomServiceSettings fromMap(
context
);

if (responseParserMap == null || jsonParserMap == null || errorParserMap == null) {
if (responseParserMap == null || jsonParserMap == null) {
throw validationException;
}

throwIfNotEmptyMap(jsonParserMap, JSON_PARSER, NAME);
throwIfNotEmptyMap(responseParserMap, RESPONSE, NAME);
throwIfNotEmptyMap(errorParserMap, ERROR_PARSER, NAME);

if (validationException.validationErrors().isEmpty() == false) {
throw validationException;
Expand All @@ -136,8 +124,7 @@ public static CustomServiceSettings fromMap(
queryParams,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down Expand Up @@ -209,7 +196,6 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
private final String requestContentString;
private final CustomResponseParser responseJsonParser;
private final RateLimitSettings rateLimitSettings;
private final ErrorResponseParser errorParser;

public CustomServiceSettings(
TextEmbeddingSettings textEmbeddingSettings,
Expand All @@ -218,8 +204,7 @@ public CustomServiceSettings(
@Nullable QueryParameters queryParameters,
String requestContentString,
CustomResponseParser responseJsonParser,
@Nullable RateLimitSettings rateLimitSettings,
ErrorResponseParser errorParser
@Nullable RateLimitSettings rateLimitSettings
) {
this.textEmbeddingSettings = Objects.requireNonNull(textEmbeddingSettings);
this.url = Objects.requireNonNull(url);
Expand All @@ -228,7 +213,6 @@ public CustomServiceSettings(
this.requestContentString = Objects.requireNonNull(requestContentString);
this.responseJsonParser = Objects.requireNonNull(responseJsonParser);
this.rateLimitSettings = Objects.requireNonNullElse(rateLimitSettings, DEFAULT_RATE_LIMIT_SETTINGS);
this.errorParser = Objects.requireNonNull(errorParser);
}

public CustomServiceSettings(StreamInput in) throws IOException {
Expand All @@ -239,7 +223,12 @@ public CustomServiceSettings(StreamInput in) throws IOException {
requestContentString = in.readString();
responseJsonParser = in.readNamedWriteable(CustomResponseParser.class);
rateLimitSettings = new RateLimitSettings(in);
errorParser = new ErrorResponseParser(in);
if (in.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& in.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Read the error parsing fields for backwards compatibility
in.readString();
in.readString();
}
}

@Override
Expand Down Expand Up @@ -287,10 +276,6 @@ public CustomResponseParser getResponseJsonParser() {
return responseJsonParser;
}

public ErrorResponseParser getErrorParser() {
return errorParser;
}

@Override
public RateLimitSettings rateLimitSettings() {
return rateLimitSettings;
Expand Down Expand Up @@ -331,7 +316,6 @@ public XContentBuilder toXContentFragmentOfExposedFields(XContentBuilder builder
builder.startObject(RESPONSE);
{
responseJsonParser.toXContent(builder, params);
errorParser.toXContent(builder, params);
}
builder.endObject();

Expand Down Expand Up @@ -359,7 +343,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeString(requestContentString);
out.writeNamedWriteable(responseJsonParser);
rateLimitSettings.writeTo(out);
errorParser.writeTo(out);
if (out.getTransportVersion().before(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING)
&& out.getTransportVersion().isPatchFrom(TransportVersions.ML_INFERENCE_CUSTOM_SERVICE_REMOVE_ERROR_PARSING_8_19) == false) {
// Write empty strings for backwards compatibility for the error parsing fields
out.writeString("");
out.writeString("");
}
}

@Override
Expand All @@ -373,8 +362,7 @@ public boolean equals(Object o) {
&& Objects.equals(queryParameters, that.queryParameters)
&& Objects.equals(requestContentString, that.requestContentString)
&& Objects.equals(responseJsonParser, that.responseJsonParser)
&& Objects.equals(rateLimitSettings, that.rateLimitSettings)
&& Objects.equals(errorParser, that.errorParser);
&& Objects.equals(rateLimitSettings, that.rateLimitSettings);
}

@Override
Expand All @@ -386,8 +374,7 @@ public int hashCode() {
queryParameters,
requestContentString,
responseJsonParser,
rateLimitSettings,
errorParser
rateLimitSettings
);
}

Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
import org.elasticsearch.inference.TaskType;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.xpack.inference.services.custom.response.CustomResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.TextEmbeddingResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.hamcrest.MatcherAssert;
Expand Down Expand Up @@ -120,8 +119,7 @@ public static CustomModel getTestModel(TaskType taskType, CustomResponseParser r
QueryParameters.EMPTY,
requestContentString,
responseParser,
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

CustomTaskSettings taskSettings = new CustomTaskSettings(Map.of(taskSettingsKey, taskSettingsValue));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,6 @@
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.xpack.inference.external.http.retry.RequestSender;
import org.elasticsearch.xpack.inference.external.http.sender.EmbeddingsInput;
import org.elasticsearch.xpack.inference.services.custom.response.ErrorResponseParser;
import org.elasticsearch.xpack.inference.services.custom.response.RerankResponseParser;
import org.elasticsearch.xpack.inference.services.settings.RateLimitSettings;
import org.junit.After;
Expand Down Expand Up @@ -64,8 +63,7 @@ public void testCreateRequest_ThrowsException_ForInvalidUrl() {
null,
requestContentString,
new RerankResponseParser("$.result.score"),
new RateLimitSettings(10_000),
new ErrorResponseParser("$.error.message", inferenceId)
new RateLimitSettings(10_000)
);

var model = CustomModelTests.createModel(
Expand Down
Loading