diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java index ec0619ff06389..ab98408a5156a 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutor.java @@ -51,7 +51,12 @@ private static void ensureNoMorePlaceholdersExist(String substitutedString, Stri Matcher matcher = VARIABLE_PLACEHOLDER_PATTERN.matcher(substitutedString); if (matcher.find()) { throw new IllegalStateException( - Strings.format("Found placeholder [%s] in field [%s] after replacement call", matcher.group(), field) + Strings.format( + "Found placeholder [%s] in field [%s] after replacement call, " + + "please check that all templates have a corresponding field definition.", + matcher.group(), + field + ) ); } } diff --git a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java index 9a06c613e6e76..9bd696c08139b 100644 --- a/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java +++ b/x-pack/plugin/inference/src/main/java/org/elasticsearch/xpack/inference/services/custom/CustomService.java @@ -39,6 +39,7 @@ import org.elasticsearch.xpack.inference.services.SenderService; import org.elasticsearch.xpack.inference.services.ServiceComponents; import org.elasticsearch.xpack.inference.services.ServiceUtils; +import org.elasticsearch.xpack.inference.services.custom.request.CustomRequest; import java.util.EnumSet; import java.util.HashMap; @@ -55,6 +56,7 @@ import static org.elasticsearch.xpack.inference.services.ServiceUtils.throwUnsupportedUnifiedCompletionOperation; public class CustomService extends SenderService { + public static final String NAME = "custom"; private static final String SERVICE_NAME = "Custom"; @@ -101,12 +103,32 @@ public void parseRequestConfig( throwIfNotEmptyMap(serviceSettingsMap, NAME); throwIfNotEmptyMap(taskSettingsMap, NAME); + validateConfiguration(model); + parsedModelListener.onResponse(model); } catch (Exception e) { parsedModelListener.onFailure(e); } } + /** + * This does some initial validation with mock inputs to determine if any templates are missing a field to fill them. + */ + private static void validateConfiguration(CustomModel model) { + String query = null; + if (model.getTaskType() == TaskType.RERANK) { + query = "test query"; + } + + try { + new CustomRequest(query, List.of("test input"), model).createHttpRequest(); + } catch (IllegalStateException e) { + var validationException = new ValidationException(); + validationException.addValidationError(Strings.format("Failed to validate model configuration: %s", e.getMessage())); + throw validationException; + } + } + private static ChunkingSettings extractChunkingSettings(Map config, TaskType taskType) { if (TaskType.TEXT_EMBEDDING.equals(taskType)) { return ChunkingSettingsBuilder.fromMap(removeFromMap(config, ModelConfigurations.CHUNKING_SETTINGS)); diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutorTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutorTests.java index eb0224ff774dc..61563fd82de31 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutorTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/common/ValidatingSubstitutorTests.java @@ -37,20 +37,38 @@ public void testReplace_ThrowsException_WhenPlaceHolderStillExists() { var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "key2", "value2"), "${", "}"); var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super:${key}", "setting")); - assertThat(exception.getMessage(), is("Found placeholder [${key}] in field [setting] after replacement call")); + assertThat( + exception.getMessage(), + is( + "Found placeholder [${key}] in field [setting] after replacement call, " + + "please check that all templates have a corresponding field definition." + ) + ); } // only reports the first placeholder pattern { var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "some_key2", "value2"), "${", "}"); var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super, ${key}, ${key2}", "setting")); - assertThat(exception.getMessage(), is("Found placeholder [${key}] in field [setting] after replacement call")); + assertThat( + exception.getMessage(), + is( + "Found placeholder [${key}] in field [setting] after replacement call, " + + "please check that all templates have a corresponding field definition." + ) + ); } { var sub = new ValidatingSubstitutor(Map.of("some_key", "value", "key2", "value2"), "${", "}"); var exception = expectThrows(IllegalStateException.class, () -> sub.replace("super:${ \\/\tkey\"}", "setting")); - assertThat(exception.getMessage(), is("Found placeholder [${ \\/\tkey\"}] in field [setting] after replacement call")); + assertThat( + exception.getMessage(), + is( + "Found placeholder [${ \\/\tkey\"}] in field [setting] after replacement call," + + " please check that all templates have a corresponding field definition." + ) + ); } } } diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java index dc82d71df6503..d268b301dde8d 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/CustomServiceTests.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchStatusException; import org.elasticsearch.action.support.PlainActionFuture; import org.elasticsearch.common.Strings; +import org.elasticsearch.common.ValidationException; import org.elasticsearch.common.settings.SecureString; import org.elasticsearch.core.Nullable; import org.elasticsearch.index.mapper.vectors.DenseVectorFieldMapper; @@ -52,6 +53,7 @@ import java.util.Map; import static org.elasticsearch.xpack.inference.Utils.TIMEOUT; +import static org.elasticsearch.xpack.inference.Utils.getRequestConfigMap; import static org.elasticsearch.xpack.inference.external.http.Utils.entityAsMap; import static org.elasticsearch.xpack.inference.external.http.Utils.getUrl; import static org.elasticsearch.xpack.inference.services.ServiceComponentsTests.createWithEmptySettings; @@ -611,6 +613,42 @@ public void testInfer_HandlesSparseEmbeddingRequest_Alibaba_Format() throws IOEx } } + public void testParseRequestConfig_ThrowsAValidationError_WhenReplacementDoesNotFillTemplate() throws Exception { + try (var service = createService(threadPool, clientManager)) { + + var settingsMap = new HashMap<>( + Map.of( + CustomServiceSettings.URL, + "http://www.abc.com", + CustomServiceSettings.HEADERS, + Map.of("key", "value"), + QueryParameters.QUERY_PARAMETERS, + List.of(List.of("key", "value")), + CustomServiceSettings.REQUEST, + "request body ${some_template}", + CustomServiceSettings.RESPONSE, + new HashMap<>(Map.of(CustomServiceSettings.JSON_PARSER, createResponseParserMap(TaskType.COMPLETION))) + ) + ); + + var config = getRequestConfigMap(settingsMap, createTaskSettingsMap(), createSecretSettingsMap()); + + var listener = new PlainActionFuture(); + service.parseRequestConfig("id", TaskType.COMPLETION, config, listener); + + var exception = expectThrows(ValidationException.class, () -> listener.actionGet(TIMEOUT)); + + assertThat( + exception.getMessage(), + is( + "Validation Failed: 1: Failed to validate model configuration: Found placeholder " + + "[${some_template}] in field [request] after replacement call, please check that all " + + "templates have a corresponding field definition.;" + ) + ); + } + } + public void testChunkedInfer_ChunkingSettingsSet() throws IOException { var model = createInternalEmbeddingModel( SimilarityMeasure.DOT_PRODUCT, diff --git a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java index ca2726b043056..3ecacdb17cf93 100644 --- a/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java +++ b/x-pack/plugin/inference/src/test/java/org/elasticsearch/xpack/inference/services/custom/request/CustomRequestTests.java @@ -264,7 +264,13 @@ public void testCreateRequest_IgnoresNonStringFields_ForStringParams() throws IO var request = new CustomRequest(null, List.of("abc", "123"), model); var exception = expectThrows(IllegalStateException.class, request::createHttpRequest); - assertThat(exception.getMessage(), is("Found placeholder [${task.key}] in field [header.Accept] after replacement call")); + assertThat( + exception.getMessage(), + is( + "Found placeholder [${task.key}] in field [header.Accept] after replacement call, " + + "please check that all templates have a corresponding field definition." + ) + ); } public void testCreateRequest_ThrowsException_ForInvalidUrl() {