diff --git a/Model/pom.xml b/Model/pom.xml index cc13dc180..bf7d78569 100644 --- a/Model/pom.xml +++ b/Model/pom.xml @@ -135,6 +135,12 @@ openai-java + + com.anthropic + anthropic-java + 2.9.0 + + diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java index 9b61a9f54..185e28ace 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java @@ -13,6 +13,7 @@ import org.apidb.apicommon.model.report.ai.expression.DailyCostMonitor; import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor; import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.GeneSummaryInputs; +import org.apidb.apicommon.model.report.ai.expression.ClaudeSummarizer; import org.apidb.apicommon.model.report.ai.expression.Summarizer; import org.gusdb.wdk.model.WdkModelException; import org.gusdb.wdk.model.WdkServiceTemporarilyUnavailableException; @@ -32,8 +33,11 @@ public class SingleGeneAiExpressionReporter extends AbstractReporter { private static final int MAX_RESULT_SIZE = 1; // one gene at a time for now private static final String POPULATION_MODE_PROP_KEY = "populateIfNotPresent"; + private static final String AI_MAX_CONCURRENT_REQUESTS_PROP_KEY = "AI_MAX_CONCURRENT_REQUESTS"; + private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10; private boolean _populateIfNotPresent; + private int _maxConcurrentRequests; private DailyCostMonitor _costMonitor; @Override @@ -42,6 +46,12 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk // assign cache mode _populateIfNotPresent = config.optBoolean(POPULATION_MODE_PROP_KEY, false); + // read max concurrent requests from model properties or use default + String maxConcurrentRequestsStr = _wdkModel.getProperties().get(AI_MAX_CONCURRENT_REQUESTS_PROP_KEY); + _maxConcurrentRequests = maxConcurrentRequestsStr != null + ? Integer.parseInt(maxConcurrentRequestsStr) + : DEFAULT_MAX_CONCURRENT_REQUESTS; + // instantiate cost monitor _costMonitor = new DailyCostMonitor(_wdkModel); @@ -52,7 +62,7 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk " should only be assigned to " + geneRecordClass.getFullName()); } - // check result size; limit to small results due to OpenAI cost + // check result size; limit to small results due to AI API cost if (_baseAnswer.getResultSizeFactory().getResultSize() > MAX_RESULT_SIZE) { throw new ReporterConfigException("This reporter cannot be called with results of size greater than " + MAX_RESULT_SIZE); } @@ -79,8 +89,8 @@ protected void write(OutputStream out) throws IOException, WdkModelException { // open summary cache (manages persistence of expression data) AiExpressionCache cache = AiExpressionCache.getInstance(_wdkModel); - // create summarizer (interacts with OpenAI) - Summarizer summarizer = new Summarizer(_wdkModel, _costMonitor); + // create summarizer (interacts with Claude) + ClaudeSummarizer summarizer = new ClaudeSummarizer(_wdkModel, _costMonitor); // open record and output streams try (RecordStream recordStream = RecordStreamFactory.getRecordStream(_baseAnswer, List.of(), tables); @@ -93,12 +103,12 @@ protected void write(OutputStream out) throws IOException, WdkModelException { // create summary inputs GeneSummaryInputs summaryInputs = - GeneRecordProcessor.getSummaryInputsFromRecord(record, Summarizer.OPENAI_CHAT_MODEL.toString(), + GeneRecordProcessor.getSummaryInputsFromRecord(record, ClaudeSummarizer.CLAUDE_MODEL.toString(), Summarizer::getExperimentMessage, Summarizer::getFinalSummaryMessage); // fetch summary, producing if necessary and requested JSONObject expressionSummary = _populateIfNotPresent - ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments) + ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments, _maxConcurrentRequests) : cache.readSummary(summaryInputs); // join entries with commas diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java index 3bc5768b7..4054eb1e6 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java @@ -73,7 +73,6 @@ public class AiExpressionCache { private static Logger LOG = Logger.getLogger(AiExpressionCache.class); // parallel processing - private static final int MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST = 10; private static final long VISIT_ENTRY_LOCK_MAX_WAIT_MILLIS = 50; // cache location @@ -317,18 +316,20 @@ private static Optional readCachedData(Path entryDir) { * @param summaryInputs gene summary inputs * @param experimentDescriber function to describe an experiment * @param experimentSummarizer function to summarize experiments into an expression summary + * @param maxConcurrentRequests maximum number of concurrent experiment lookups * @return expression summary (will always be a cache hit) */ public JSONObject populateSummary(GeneSummaryInputs summaryInputs, FunctionWithException> experimentDescriber, - BiFunctionWithException, JSONObject> experimentSummarizer) { + BiFunctionWithException, JSONObject> experimentSummarizer, + int maxConcurrentRequests) { try { return _cache.populateAndProcessContent(summaryInputs.getGeneId(), // populator entryDir -> { // first populate each dataset entry as needed and collect experiment descriptors - List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber); + List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber, maxConcurrentRequests); // sort them most-interesting first so that the "Other" section will be filled // in that order (and also to give the AI the data in a sensible order) @@ -362,14 +363,16 @@ public JSONObject populateSummary(GeneSummaryInputs summaryInputs, * * @param experimentData experiment inputs * @param experimentDescriber function to describe an experiment + * @param maxConcurrentRequests maximum number of concurrent experiment lookups * @return list of cached experiment descriptions * @throws Exception if unable to generate descriptions or store */ private List populateExperiments(List experimentData, - FunctionWithException> experimentDescriber) throws Exception { + FunctionWithException> experimentDescriber, + int maxConcurrentRequests) throws Exception { // use a thread for each experiment, up to a reasonable max - int threadPoolSize = Math.min(MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST, experimentData.size()); + int threadPoolSize = Math.min(maxConcurrentRequests, experimentData.size()); ExecutorService exec = Executors.newFixedThreadPool(threadPoolSize); try { diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java new file mode 100644 index 000000000..7a49f4870 --- /dev/null +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java @@ -0,0 +1,129 @@ +package org.apidb.apicommon.model.report.ai.expression; + +import java.time.Duration; +import java.util.concurrent.CompletableFuture; + +import org.gusdb.wdk.model.WdkModel; +import org.gusdb.wdk.model.WdkModelException; + +import com.anthropic.client.AnthropicClientAsync; +import com.anthropic.client.okhttp.AnthropicOkHttpClientAsync; +import com.anthropic.models.messages.MessageCreateParams; +import com.anthropic.models.messages.Model; +import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema; + +public class ClaudeSummarizer extends Summarizer { + + public static final Model CLAUDE_MODEL = Model.CLAUDE_SONNET_4_5_20250929; + + private static final String CLAUDE_API_KEY_PROP_NAME = "CLAUDE_API_KEY"; + + private final AnthropicClientAsync _claudeClient; + + public ClaudeSummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException { + super(costMonitor); + + String apiKey = wdkModel.getProperties().get(CLAUDE_API_KEY_PROP_NAME); + if (apiKey == null) { + throw new WdkModelException("WDK property '" + CLAUDE_API_KEY_PROP_NAME + "' has not been set."); + } + + _claudeClient = AnthropicOkHttpClientAsync.builder() + .apiKey(apiKey) + .maxRetries(32) // Handle 429 errors + .checkJacksonVersionCompatibility(false) + .build(); + } + + @Override + protected CompletableFuture callApiForJson(String prompt, Schema schema) { + // Convert JSON schema to natural language description for Claude + String jsonFormatInstructions = convertSchemaToPromptInstructions(schema); + + String enhancedPrompt = prompt + "\n\n" + jsonFormatInstructions; + + MessageCreateParams request = MessageCreateParams.builder() + .model(CLAUDE_MODEL) + .maxTokens((long) MAX_RESPONSE_TOKENS) + .system(SYSTEM_MESSAGE) + .addUserMessage(enhancedPrompt) + .build(); + + return retryOnOverload( + () -> _claudeClient.messages().create(request), + e -> e instanceof com.anthropic.errors.InternalServerException, + "Claude API call" + ).thenApply(response -> { + // Convert Claude usage to OpenAI format for cost monitoring + com.anthropic.models.messages.Usage claudeUsage = response.usage(); + com.openai.models.completions.CompletionUsage openAiUsage = com.openai.models.completions.CompletionUsage.builder() + .promptTokens(claudeUsage.inputTokens()) + .completionTokens(claudeUsage.outputTokens()) + .totalTokens(claudeUsage.inputTokens() + claudeUsage.outputTokens()) + .build(); + + _costMonitor.updateCost(java.util.Optional.of(openAiUsage)); + + // Extract text from content blocks using stream API + String rawText = response.content().stream() + .flatMap(contentBlock -> contentBlock.text().stream()) + .map(textBlock -> textBlock.text()) + .findFirst() + .orElseThrow(() -> new RuntimeException("No text content found in Claude response")); + + // Strip JSON markdown formatting if present + return stripJsonMarkdown(rawText); + }); + } + + @Override + protected void updateCostMonitor(Object apiResponse) { + // Claude response handling is done in callApiForJson + } + + private String stripJsonMarkdown(String text) { + String trimmed = text.trim(); + + // Remove ```json and ``` markdown formatting + if (trimmed.startsWith("```json")) { + trimmed = trimmed.substring(7); // Remove "```json" + } else if (trimmed.startsWith("```")) { + trimmed = trimmed.substring(3); // Remove "```" + } + + if (trimmed.endsWith("```")) { + trimmed = trimmed.substring(0, trimmed.length() - 3); // Remove trailing "```" + } + + return trimmed.trim(); + } + + private String convertSchemaToPromptInstructions(Schema schema) { + // Convert OpenAI JSON schema to Claude-friendly format instructions + if (schema == experimentResponseSchema) { + return "Respond in valid JSON format matching this exact structure:\n" + + "{\n" + + " \"one_sentence_summary\": \"string describing gene expression\",\n" + + " \"biological_importance\": \"integer 0-5\",\n" + + " \"confidence\": \"integer 0-5\",\n" + + " \"experiment_keywords\": [\"array\", \"of\", \"strings\"],\n" + + " \"notes\": \"string with additional context\"\n" + + "}"; + } else if (schema == finalResponseSchema) { + return "Respond in valid JSON format matching this exact structure:\n" + + "{\n" + + " \"headline\": \"string summarizing key results\",\n" + + " \"one_paragraph_summary\": \"string with ~100 words\",\n" + + " \"topics\": [\n" + + " {\n" + + " \"headline\": \"string summarizing topic\",\n" + + " \"one_sentence_summary\": \"string describing topic results\",\n" + + " \"dataset_ids\": [\"array\", \"of\", \"dataset_id\", \"strings\"]\n" + + " }\n" + + " ]\n" + + "}"; + } else { + return "Respond in valid JSON format."; + } + } +} diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java index 2185ee365..6cd7f3574 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java @@ -21,7 +21,7 @@ import org.json.JSONException; import org.json.JSONObject; -import com.openai.models.CompletionUsage; +import com.openai.models.completions.CompletionUsage; public class DailyCostMonitor { @@ -31,10 +31,15 @@ public class DailyCostMonitor { private static final String DAILY_COST_ACCUMULATION_FILE_DIR = "dailyCost"; private static final String DAILY_COST_ACCUMULATION_FILE = "daily_cost_accumulation.txt"; - // model prop keys - private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST"; - private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS"; - private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS"; + // model prop keys (new names without OPENAI_ prefix) + private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "MAX_DAILY_AI_EXPRESSION_DOLLAR_COST"; + private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_INPUT_TOKENS"; + private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS"; + + // deprecated model prop keys (with OPENAI_ prefix) + private static final String DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST"; + private static final String DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS"; + private static final String DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS"; // lock characteristics private static final long DEFAULT_TIMEOUT_MILLIS = 1000; @@ -68,9 +73,25 @@ public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException { _costMonitoringFile = _costMonitoringDir.resolve(DAILY_COST_ACCUMULATION_FILE); - _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME); - _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000; - _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000; + _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME, DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME); + _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000; + _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000; + } + + private double getNumberProp(WdkModel wdkModel, String propName, String deprecatedPropName) throws WdkModelException { + // First try the new property name + if (wdkModel.getProperties().get(propName) != null) { + return getNumberProp(wdkModel, propName); + } + + // Fall back to deprecated property name with warning + if (wdkModel.getProperties().get(deprecatedPropName) != null) { + LOG.warn("WDK property '" + deprecatedPropName + "' is deprecated. Please use '" + propName + "' instead."); + return getNumberProp(wdkModel, deprecatedPropName); + } + + // Neither property is set + throw new WdkModelException("WDK property '" + propName + "' (or deprecated '" + deprecatedPropName + "') has not been set."); } private double getNumberProp(WdkModel wdkModel, String propName) throws WdkModelException { diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java new file mode 100644 index 000000000..36811b32b --- /dev/null +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java @@ -0,0 +1,71 @@ +package org.apidb.apicommon.model.report.ai.expression; + +import java.util.concurrent.CompletableFuture; + +import org.gusdb.wdk.model.WdkModel; +import org.gusdb.wdk.model.WdkModelException; + +import com.openai.client.OpenAIClientAsync; +import com.openai.client.okhttp.OpenAIOkHttpClientAsync; +import com.openai.models.chat.completions.ChatCompletionCreateParams; +import com.openai.models.ChatModel; +import com.openai.models.ResponseFormatJsonSchema; +import com.openai.models.ResponseFormatJsonSchema.JsonSchema; + +public class OpenAISummarizer extends Summarizer { + + // provide exact model number for semi-reproducibility + public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06; + + private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY"; + + private final OpenAIClientAsync _openAIClient; + + public OpenAISummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException { + super(costMonitor); + + String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME); + if (apiKey == null) { + throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set."); + } + + _openAIClient = OpenAIOkHttpClientAsync.builder() + .apiKey(apiKey) + .maxRetries(32) // Handle 429 errors + .build(); + } + + @Override + protected CompletableFuture callApiForJson(String prompt, com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema schema) { + ChatCompletionCreateParams request = ChatCompletionCreateParams.builder() + .model(OPENAI_CHAT_MODEL) + .maxCompletionTokens(MAX_RESPONSE_TOKENS) + .responseFormat(ResponseFormatJsonSchema.builder() + .jsonSchema(JsonSchema.builder() + .name("structured-response") + .schema(schema) + .strict(true) + .build()) + .build()) + .addSystemMessage(SYSTEM_MESSAGE) + .addUserMessage(prompt) + .build(); + + return retryOnOverload( + () -> _openAIClient.chat().completions().create(request), + e -> e instanceof com.openai.errors.InternalServerException, + "OpenAI API call" + ).thenApply(completion -> { + // update cost accumulator + _costMonitor.updateCost(completion.usage()); + + // return JSON string + return completion.choices().get(0).message().content().get(); + }); + } + + @Override + protected void updateCostMonitor(Object apiResponse) { + // OpenAI response handling is done in callApiForJson + } +} diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java index b62338699..f58b03e60 100644 --- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java +++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java @@ -17,28 +17,20 @@ import org.json.JSONException; import org.json.JSONObject; -import com.openai.client.OpenAIClientAsync; -import com.openai.client.okhttp.OpenAIOkHttpClientAsync; import com.openai.core.JsonValue; -import com.openai.models.ChatCompletionCreateParams; -import com.openai.models.ChatModel; -import com.openai.models.ResponseFormatJsonSchema; import com.openai.models.ResponseFormatJsonSchema.JsonSchema; import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema; -public class Summarizer { +public abstract class Summarizer { - // provide exact model number for semi-reproducibility - public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06; - - private static final int MAX_RESPONSE_TOKENS = 10000; + protected static final int MAX_RESPONSE_TOKENS = 10000; private static final int MAX_MALFORMED_RESPONSE_RETRIES = 3; - private static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data"; + protected static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data"; // Prepare JSON schemas for structured responses - private static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder() + protected static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder() .putAdditionalProperty("type", JsonValue.from("object")) .putAdditionalProperty("properties", JsonValue.from(Map.of( "one_sentence_summary", Map.of("type", "string"), @@ -57,7 +49,7 @@ public class Summarizer { .putAdditionalProperty("additionalProperties", JsonValue.from(false)) .build(); - private static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder() + protected static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder() .putAdditionalProperty("type", JsonValue.from("object")) .putAdditionalProperty("properties", JsonValue.from(Map.of( "headline", Map.of("type", "string"), @@ -81,26 +73,86 @@ public class Summarizer { .putAdditionalProperty("additionalProperties", JsonValue.from(false)) .build(); - private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY"; - - private final OpenAIClientAsync _openAIClient; - private final DailyCostMonitor _costMonitor; + protected final DailyCostMonitor _costMonitor; private static final Logger LOG = Logger.getLogger(Summarizer.class); - public Summarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException { + public Summarizer(DailyCostMonitor costMonitor) { + _costMonitor = costMonitor; + } - String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME); - if (apiKey == null) { - throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set."); - } + /** + * Retries an operation with exponential backoff if it fails with a retriable error. + * + * @param the return type of the operation + * @param operation supplier that produces the CompletableFuture to execute + * @param shouldRetry predicate to determine if an exception should trigger a retry + * @param operationDescription description for logging purposes + * @return CompletableFuture with the result of the operation + */ + protected CompletableFuture retryOnOverload( + java.util.function.Supplier> operation, + java.util.function.Predicate shouldRetry, + String operationDescription) { + + final int maxRetries = 3; + final long[] backoffDelaysMs = {1000, 2000, 4000}; // 1s, 2s, 4s + + return retryWithBackoff(operation, shouldRetry, operationDescription, 0, maxRetries, backoffDelaysMs); + } - _openAIClient = OpenAIOkHttpClientAsync.builder() - .apiKey(apiKey) - .maxRetries(32) // Handle 429 errors - .build(); + private CompletableFuture retryWithBackoff( + java.util.function.Supplier> operation, + java.util.function.Predicate shouldRetry, + String operationDescription, + int attemptNumber, + int maxRetries, + long[] backoffDelaysMs) { + + CompletableFuture result = new CompletableFuture<>(); + + operation.get().whenComplete((value, throwable) -> { + if (throwable == null) { + // Success case + result.complete(value); + } else { + // Error case - unwrap CompletionException to get the actual cause + Throwable actualCause = throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null + ? throwable.getCause() + : throwable; + + // Check if we should retry this exception and haven't exceeded max retries + if (shouldRetry.test(actualCause) && attemptNumber < maxRetries) { + long delayMs = backoffDelaysMs[attemptNumber]; + LOG.warn(String.format( + "Retrying %s after error (attempt %d/%d, waiting %dms): %s", + operationDescription, attemptNumber + 1, maxRetries, delayMs, actualCause.getMessage())); + + // Schedule retry after delay + new java.util.Timer().schedule(new java.util.TimerTask() { + @Override + public void run() { + retryWithBackoff(operation, shouldRetry, operationDescription, attemptNumber + 1, maxRetries, backoffDelaysMs) + .whenComplete((retryValue, retryError) -> { + if (retryError != null) { + result.completeExceptionally(retryError); + } else { + result.complete(retryValue); + } + }); + } + }, delayMs); + } else { + // No more retries or non-retriable exception + if (attemptNumber >= maxRetries) { + LOG.error(String.format("Failed %s after %d retries: %s", operationDescription, maxRetries, actualCause.getMessage())); + } + result.completeExceptionally(throwable); + } + } + }); - _costMonitor = costMonitor; + return result; } public static String getExperimentMessage(JSONObject experiment) { @@ -133,12 +185,9 @@ public static String getExperimentMessage(JSONObject experiment) { public CompletableFuture describeExperiment(ExperimentInputs experimentInputs) { - ChatCompletionCreateParams request = buildAiRequest( - "experiment-summary", - experimentResponseSchema, - getExperimentMessage(experimentInputs.getExperimentData())); + String prompt = getExperimentMessage(experimentInputs.getExperimentData()); - return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), request, json -> { + return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), prompt, experimentResponseSchema, json -> { // add some fields to the result to aid the final summarization return json .put("dataset_id", experimentInputs.getDatasetId()) @@ -151,6 +200,7 @@ public static String getFinalSummaryMessage(List experiments) { return "Below are AI-generated summaries of one gene's behavior in all the transcriptomics experiments available in VEuPathDB, provided in JSON format:\n\n" + String.format("```json\n%s\n```\n\n", new JSONArray(experiments).toString(2)) + "Generate a one-paragraph summary (~100 words) describing the gene's expression. Structure it using ,
    , and
  • tags with no attributes. If relevant, briefly speculate on the gene's potential function, but only if justified by the data. Also, generate a short, specific headline for the summary. The headline must reflect this gene's expression and **must not** include generic phrases like \"comprehensive insights into\" or the word \"gene\".\n\n" + + "Use sentence case for all headlines: capitalize only the first word and proper nouns, not every word.\n\n" + "Additionally, group the per-experiment summaries (identified by `dataset_id`) with `biological_importance > 3` and `confidence > 3` into sections by topic. For each topic, provide:\n" + "- A headline summarizing the key experimental results within the topic\n" + "- A concise one-sentence summary of the topic's experimental results\n\n" + @@ -159,12 +209,9 @@ public static String getFinalSummaryMessage(List experiments) { public JSONObject summarizeExperiments(String geneId, List experiments) { - ChatCompletionCreateParams request = buildAiRequest( - "expression-summary", - finalResponseSchema, - getFinalSummaryMessage(experiments)); + String prompt = getFinalSummaryMessage(experiments); - return getValidatedAiResponse("summary for gene " + geneId, request, json -> + return getValidatedAiResponse("summary for gene " + geneId, prompt, finalResponseSchema, json -> // quality control (remove bad `dataset_id`s) and add 'Others' section for any experiments not listed by AI consolidateSummary(json, experiments) ).join(); @@ -240,34 +287,17 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse, } - private static ChatCompletionCreateParams buildAiRequest(String name, Schema schema, String userMessage) { - return ChatCompletionCreateParams.builder() - .model(OPENAI_CHAT_MODEL) - .maxCompletionTokens(MAX_RESPONSE_TOKENS) - .responseFormat(ResponseFormatJsonSchema.builder() - .jsonSchema(JsonSchema.builder() - .name(name) - .schema(schema) - .strict(true) - .build()) - .build()) - .addSystemMessage(SYSTEM_MESSAGE) - .addUserMessage(userMessage) - .build(); - } + protected abstract CompletableFuture callApiForJson(String prompt, Schema schema); + + protected abstract void updateCostMonitor(Object apiResponse); private CompletableFuture getValidatedAiResponse( String operationDescription, - ChatCompletionCreateParams request, + String prompt, + Schema schema, Function createFinalJson ) { - return _openAIClient.chat().completions().create(request).thenApply(completion -> { - - // update cost accumulator - _costMonitor.updateCost(completion.usage()); - - // expect response to be a JSON string - String jsonString = completion.choices().get(0).message().content().get(); + return callApiForJson(prompt, schema).thenApply(jsonString -> { int attempts = 1; Exception mostRecentError; @@ -281,12 +311,10 @@ private CompletableFuture getValidatedAiResponse( } catch (JSONException e) { mostRecentError = e; - LOG.warn("Malformed JSON from OpenAI (attempt " + attempts + ") for " + operationDescription + ". Retrying..."); + LOG.warn("Malformed JSON from AI (attempt " + attempts + ") for " + operationDescription + ". Retrying..."); - // Re-request from OpenAI - completion = _openAIClient.chat().completions().create(request).join(); - _costMonitor.updateCost(completion.usage()); - jsonString = completion.choices().get(0).message().content().get(); + // Re-request from AI + jsonString = callApiForJson(prompt, schema).join(); attempts++; } }