diff --git a/Model/pom.xml b/Model/pom.xml
index cc13dc180..bf7d78569 100644
--- a/Model/pom.xml
+++ b/Model/pom.xml
@@ -135,6 +135,12 @@
openai-java
+
+ com.anthropic
+ anthropic-java
+ 2.9.0
+
+
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java
index 9b61a9f54..185e28ace 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/SingleGeneAiExpressionReporter.java
@@ -13,6 +13,7 @@
import org.apidb.apicommon.model.report.ai.expression.DailyCostMonitor;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.GeneSummaryInputs;
+import org.apidb.apicommon.model.report.ai.expression.ClaudeSummarizer;
import org.apidb.apicommon.model.report.ai.expression.Summarizer;
import org.gusdb.wdk.model.WdkModelException;
import org.gusdb.wdk.model.WdkServiceTemporarilyUnavailableException;
@@ -32,8 +33,11 @@ public class SingleGeneAiExpressionReporter extends AbstractReporter {
private static final int MAX_RESULT_SIZE = 1; // one gene at a time for now
private static final String POPULATION_MODE_PROP_KEY = "populateIfNotPresent";
+ private static final String AI_MAX_CONCURRENT_REQUESTS_PROP_KEY = "AI_MAX_CONCURRENT_REQUESTS";
+ private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10;
private boolean _populateIfNotPresent;
+ private int _maxConcurrentRequests;
private DailyCostMonitor _costMonitor;
@Override
@@ -42,6 +46,12 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk
// assign cache mode
_populateIfNotPresent = config.optBoolean(POPULATION_MODE_PROP_KEY, false);
+ // read max concurrent requests from model properties or use default
+ String maxConcurrentRequestsStr = _wdkModel.getProperties().get(AI_MAX_CONCURRENT_REQUESTS_PROP_KEY);
+ _maxConcurrentRequests = maxConcurrentRequestsStr != null
+ ? Integer.parseInt(maxConcurrentRequestsStr)
+ : DEFAULT_MAX_CONCURRENT_REQUESTS;
+
// instantiate cost monitor
_costMonitor = new DailyCostMonitor(_wdkModel);
@@ -52,7 +62,7 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk
" should only be assigned to " + geneRecordClass.getFullName());
}
- // check result size; limit to small results due to OpenAI cost
+ // check result size; limit to small results due to AI API cost
if (_baseAnswer.getResultSizeFactory().getResultSize() > MAX_RESULT_SIZE) {
throw new ReporterConfigException("This reporter cannot be called with results of size greater than " + MAX_RESULT_SIZE);
}
@@ -79,8 +89,8 @@ protected void write(OutputStream out) throws IOException, WdkModelException {
// open summary cache (manages persistence of expression data)
AiExpressionCache cache = AiExpressionCache.getInstance(_wdkModel);
- // create summarizer (interacts with OpenAI)
- Summarizer summarizer = new Summarizer(_wdkModel, _costMonitor);
+ // create summarizer (interacts with Claude)
+ ClaudeSummarizer summarizer = new ClaudeSummarizer(_wdkModel, _costMonitor);
// open record and output streams
try (RecordStream recordStream = RecordStreamFactory.getRecordStream(_baseAnswer, List.of(), tables);
@@ -93,12 +103,12 @@ protected void write(OutputStream out) throws IOException, WdkModelException {
// create summary inputs
GeneSummaryInputs summaryInputs =
- GeneRecordProcessor.getSummaryInputsFromRecord(record, Summarizer.OPENAI_CHAT_MODEL.toString(),
+ GeneRecordProcessor.getSummaryInputsFromRecord(record, ClaudeSummarizer.CLAUDE_MODEL.toString(),
Summarizer::getExperimentMessage, Summarizer::getFinalSummaryMessage);
// fetch summary, producing if necessary and requested
JSONObject expressionSummary = _populateIfNotPresent
- ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments)
+ ? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments, _maxConcurrentRequests)
: cache.readSummary(summaryInputs);
// join entries with commas
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java
index 3bc5768b7..4054eb1e6 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/AiExpressionCache.java
@@ -73,7 +73,6 @@ public class AiExpressionCache {
private static Logger LOG = Logger.getLogger(AiExpressionCache.class);
// parallel processing
- private static final int MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST = 10;
private static final long VISIT_ENTRY_LOCK_MAX_WAIT_MILLIS = 50;
// cache location
@@ -317,18 +316,20 @@ private static Optional readCachedData(Path entryDir) {
* @param summaryInputs gene summary inputs
* @param experimentDescriber function to describe an experiment
* @param experimentSummarizer function to summarize experiments into an expression summary
+ * @param maxConcurrentRequests maximum number of concurrent experiment lookups
* @return expression summary (will always be a cache hit)
*/
public JSONObject populateSummary(GeneSummaryInputs summaryInputs,
FunctionWithException> experimentDescriber,
- BiFunctionWithException, JSONObject> experimentSummarizer) {
+ BiFunctionWithException, JSONObject> experimentSummarizer,
+ int maxConcurrentRequests) {
try {
return _cache.populateAndProcessContent(summaryInputs.getGeneId(),
// populator
entryDir -> {
// first populate each dataset entry as needed and collect experiment descriptors
- List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber);
+ List experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber, maxConcurrentRequests);
// sort them most-interesting first so that the "Other" section will be filled
// in that order (and also to give the AI the data in a sensible order)
@@ -362,14 +363,16 @@ public JSONObject populateSummary(GeneSummaryInputs summaryInputs,
*
* @param experimentData experiment inputs
* @param experimentDescriber function to describe an experiment
+ * @param maxConcurrentRequests maximum number of concurrent experiment lookups
* @return list of cached experiment descriptions
* @throws Exception if unable to generate descriptions or store
*/
private List populateExperiments(List experimentData,
- FunctionWithException> experimentDescriber) throws Exception {
+ FunctionWithException> experimentDescriber,
+ int maxConcurrentRequests) throws Exception {
// use a thread for each experiment, up to a reasonable max
- int threadPoolSize = Math.min(MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST, experimentData.size());
+ int threadPoolSize = Math.min(maxConcurrentRequests, experimentData.size());
ExecutorService exec = Executors.newFixedThreadPool(threadPoolSize);
try {
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java
new file mode 100644
index 000000000..7a49f4870
--- /dev/null
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/ClaudeSummarizer.java
@@ -0,0 +1,129 @@
+package org.apidb.apicommon.model.report.ai.expression;
+
+import java.time.Duration;
+import java.util.concurrent.CompletableFuture;
+
+import org.gusdb.wdk.model.WdkModel;
+import org.gusdb.wdk.model.WdkModelException;
+
+import com.anthropic.client.AnthropicClientAsync;
+import com.anthropic.client.okhttp.AnthropicOkHttpClientAsync;
+import com.anthropic.models.messages.MessageCreateParams;
+import com.anthropic.models.messages.Model;
+import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema;
+
+public class ClaudeSummarizer extends Summarizer {
+
+ public static final Model CLAUDE_MODEL = Model.CLAUDE_SONNET_4_5_20250929;
+
+ private static final String CLAUDE_API_KEY_PROP_NAME = "CLAUDE_API_KEY";
+
+ private final AnthropicClientAsync _claudeClient;
+
+ public ClaudeSummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
+ super(costMonitor);
+
+ String apiKey = wdkModel.getProperties().get(CLAUDE_API_KEY_PROP_NAME);
+ if (apiKey == null) {
+ throw new WdkModelException("WDK property '" + CLAUDE_API_KEY_PROP_NAME + "' has not been set.");
+ }
+
+ _claudeClient = AnthropicOkHttpClientAsync.builder()
+ .apiKey(apiKey)
+ .maxRetries(32) // Handle 429 errors
+ .checkJacksonVersionCompatibility(false)
+ .build();
+ }
+
+ @Override
+ protected CompletableFuture callApiForJson(String prompt, Schema schema) {
+ // Convert JSON schema to natural language description for Claude
+ String jsonFormatInstructions = convertSchemaToPromptInstructions(schema);
+
+ String enhancedPrompt = prompt + "\n\n" + jsonFormatInstructions;
+
+ MessageCreateParams request = MessageCreateParams.builder()
+ .model(CLAUDE_MODEL)
+ .maxTokens((long) MAX_RESPONSE_TOKENS)
+ .system(SYSTEM_MESSAGE)
+ .addUserMessage(enhancedPrompt)
+ .build();
+
+ return retryOnOverload(
+ () -> _claudeClient.messages().create(request),
+ e -> e instanceof com.anthropic.errors.InternalServerException,
+ "Claude API call"
+ ).thenApply(response -> {
+ // Convert Claude usage to OpenAI format for cost monitoring
+ com.anthropic.models.messages.Usage claudeUsage = response.usage();
+ com.openai.models.completions.CompletionUsage openAiUsage = com.openai.models.completions.CompletionUsage.builder()
+ .promptTokens(claudeUsage.inputTokens())
+ .completionTokens(claudeUsage.outputTokens())
+ .totalTokens(claudeUsage.inputTokens() + claudeUsage.outputTokens())
+ .build();
+
+ _costMonitor.updateCost(java.util.Optional.of(openAiUsage));
+
+ // Extract text from content blocks using stream API
+ String rawText = response.content().stream()
+ .flatMap(contentBlock -> contentBlock.text().stream())
+ .map(textBlock -> textBlock.text())
+ .findFirst()
+ .orElseThrow(() -> new RuntimeException("No text content found in Claude response"));
+
+ // Strip JSON markdown formatting if present
+ return stripJsonMarkdown(rawText);
+ });
+ }
+
+ @Override
+ protected void updateCostMonitor(Object apiResponse) {
+ // Claude response handling is done in callApiForJson
+ }
+
+ private String stripJsonMarkdown(String text) {
+ String trimmed = text.trim();
+
+ // Remove ```json and ``` markdown formatting
+ if (trimmed.startsWith("```json")) {
+ trimmed = trimmed.substring(7); // Remove "```json"
+ } else if (trimmed.startsWith("```")) {
+ trimmed = trimmed.substring(3); // Remove "```"
+ }
+
+ if (trimmed.endsWith("```")) {
+ trimmed = trimmed.substring(0, trimmed.length() - 3); // Remove trailing "```"
+ }
+
+ return trimmed.trim();
+ }
+
+ private String convertSchemaToPromptInstructions(Schema schema) {
+ // Convert OpenAI JSON schema to Claude-friendly format instructions
+ if (schema == experimentResponseSchema) {
+ return "Respond in valid JSON format matching this exact structure:\n" +
+ "{\n" +
+ " \"one_sentence_summary\": \"string describing gene expression\",\n" +
+ " \"biological_importance\": \"integer 0-5\",\n" +
+ " \"confidence\": \"integer 0-5\",\n" +
+ " \"experiment_keywords\": [\"array\", \"of\", \"strings\"],\n" +
+ " \"notes\": \"string with additional context\"\n" +
+ "}";
+ } else if (schema == finalResponseSchema) {
+ return "Respond in valid JSON format matching this exact structure:\n" +
+ "{\n" +
+ " \"headline\": \"string summarizing key results\",\n" +
+ " \"one_paragraph_summary\": \"string with ~100 words\",\n" +
+ " \"topics\": [\n" +
+ " {\n" +
+ " \"headline\": \"string summarizing topic\",\n" +
+ " \"one_sentence_summary\": \"string describing topic results\",\n" +
+ " \"dataset_ids\": [\"array\", \"of\", \"dataset_id\", \"strings\"]\n" +
+ " }\n" +
+ " ]\n" +
+ "}";
+ } else {
+ return "Respond in valid JSON format.";
+ }
+ }
+}
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java
index 2185ee365..6cd7f3574 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/DailyCostMonitor.java
@@ -21,7 +21,7 @@
import org.json.JSONException;
import org.json.JSONObject;
-import com.openai.models.CompletionUsage;
+import com.openai.models.completions.CompletionUsage;
public class DailyCostMonitor {
@@ -31,10 +31,15 @@ public class DailyCostMonitor {
private static final String DAILY_COST_ACCUMULATION_FILE_DIR = "dailyCost";
private static final String DAILY_COST_ACCUMULATION_FILE = "daily_cost_accumulation.txt";
- // model prop keys
- private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
- private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
- private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
+ // model prop keys (new names without OPENAI_ prefix)
+ private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
+ private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
+ private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
+
+ // deprecated model prop keys (with OPENAI_ prefix)
+ private static final String DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
+ private static final String DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
+ private static final String DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
// lock characteristics
private static final long DEFAULT_TIMEOUT_MILLIS = 1000;
@@ -68,9 +73,25 @@ public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException {
_costMonitoringFile = _costMonitoringDir.resolve(DAILY_COST_ACCUMULATION_FILE);
- _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME);
- _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000;
- _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000;
+ _maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME, DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME);
+ _costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000;
+ _costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000;
+ }
+
+ private double getNumberProp(WdkModel wdkModel, String propName, String deprecatedPropName) throws WdkModelException {
+ // First try the new property name
+ if (wdkModel.getProperties().get(propName) != null) {
+ return getNumberProp(wdkModel, propName);
+ }
+
+ // Fall back to deprecated property name with warning
+ if (wdkModel.getProperties().get(deprecatedPropName) != null) {
+ LOG.warn("WDK property '" + deprecatedPropName + "' is deprecated. Please use '" + propName + "' instead.");
+ return getNumberProp(wdkModel, deprecatedPropName);
+ }
+
+ // Neither property is set
+ throw new WdkModelException("WDK property '" + propName + "' (or deprecated '" + deprecatedPropName + "') has not been set.");
}
private double getNumberProp(WdkModel wdkModel, String propName) throws WdkModelException {
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java
new file mode 100644
index 000000000..36811b32b
--- /dev/null
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/OpenAISummarizer.java
@@ -0,0 +1,71 @@
+package org.apidb.apicommon.model.report.ai.expression;
+
+import java.util.concurrent.CompletableFuture;
+
+import org.gusdb.wdk.model.WdkModel;
+import org.gusdb.wdk.model.WdkModelException;
+
+import com.openai.client.OpenAIClientAsync;
+import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
+import com.openai.models.chat.completions.ChatCompletionCreateParams;
+import com.openai.models.ChatModel;
+import com.openai.models.ResponseFormatJsonSchema;
+import com.openai.models.ResponseFormatJsonSchema.JsonSchema;
+
+public class OpenAISummarizer extends Summarizer {
+
+ // provide exact model number for semi-reproducibility
+ public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06;
+
+ private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY";
+
+ private final OpenAIClientAsync _openAIClient;
+
+ public OpenAISummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
+ super(costMonitor);
+
+ String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME);
+ if (apiKey == null) {
+ throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set.");
+ }
+
+ _openAIClient = OpenAIOkHttpClientAsync.builder()
+ .apiKey(apiKey)
+ .maxRetries(32) // Handle 429 errors
+ .build();
+ }
+
+ @Override
+ protected CompletableFuture callApiForJson(String prompt, com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema schema) {
+ ChatCompletionCreateParams request = ChatCompletionCreateParams.builder()
+ .model(OPENAI_CHAT_MODEL)
+ .maxCompletionTokens(MAX_RESPONSE_TOKENS)
+ .responseFormat(ResponseFormatJsonSchema.builder()
+ .jsonSchema(JsonSchema.builder()
+ .name("structured-response")
+ .schema(schema)
+ .strict(true)
+ .build())
+ .build())
+ .addSystemMessage(SYSTEM_MESSAGE)
+ .addUserMessage(prompt)
+ .build();
+
+ return retryOnOverload(
+ () -> _openAIClient.chat().completions().create(request),
+ e -> e instanceof com.openai.errors.InternalServerException,
+ "OpenAI API call"
+ ).thenApply(completion -> {
+ // update cost accumulator
+ _costMonitor.updateCost(completion.usage());
+
+ // return JSON string
+ return completion.choices().get(0).message().content().get();
+ });
+ }
+
+ @Override
+ protected void updateCostMonitor(Object apiResponse) {
+ // OpenAI response handling is done in callApiForJson
+ }
+}
diff --git a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java
index b62338699..f58b03e60 100644
--- a/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java
+++ b/Model/src/main/java/org/apidb/apicommon/model/report/ai/expression/Summarizer.java
@@ -17,28 +17,20 @@
import org.json.JSONException;
import org.json.JSONObject;
-import com.openai.client.OpenAIClientAsync;
-import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
import com.openai.core.JsonValue;
-import com.openai.models.ChatCompletionCreateParams;
-import com.openai.models.ChatModel;
-import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.ResponseFormatJsonSchema.JsonSchema;
import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema;
-public class Summarizer {
+public abstract class Summarizer {
- // provide exact model number for semi-reproducibility
- public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06;
-
- private static final int MAX_RESPONSE_TOKENS = 10000;
+ protected static final int MAX_RESPONSE_TOKENS = 10000;
private static final int MAX_MALFORMED_RESPONSE_RETRIES = 3;
- private static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data";
+ protected static final String SYSTEM_MESSAGE = "You are a bioinformatician working for VEuPathDB.org. You are an expert at providing biologist-friendly summaries of transcriptomic data";
// Prepare JSON schemas for structured responses
- private static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder()
+ protected static final JsonSchema.Schema experimentResponseSchema = JsonSchema.Schema.builder()
.putAdditionalProperty("type", JsonValue.from("object"))
.putAdditionalProperty("properties", JsonValue.from(Map.of(
"one_sentence_summary", Map.of("type", "string"),
@@ -57,7 +49,7 @@ public class Summarizer {
.putAdditionalProperty("additionalProperties", JsonValue.from(false))
.build();
- private static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder()
+ protected static final JsonSchema.Schema finalResponseSchema = JsonSchema.Schema.builder()
.putAdditionalProperty("type", JsonValue.from("object"))
.putAdditionalProperty("properties", JsonValue.from(Map.of(
"headline", Map.of("type", "string"),
@@ -81,26 +73,86 @@ public class Summarizer {
.putAdditionalProperty("additionalProperties", JsonValue.from(false))
.build();
- private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY";
-
- private final OpenAIClientAsync _openAIClient;
- private final DailyCostMonitor _costMonitor;
+ protected final DailyCostMonitor _costMonitor;
private static final Logger LOG = Logger.getLogger(Summarizer.class);
- public Summarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
+ public Summarizer(DailyCostMonitor costMonitor) {
+ _costMonitor = costMonitor;
+ }
- String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME);
- if (apiKey == null) {
- throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set.");
- }
+ /**
+ * Retries an operation with exponential backoff if it fails with a retriable error.
+ *
+ * @param the return type of the operation
+ * @param operation supplier that produces the CompletableFuture to execute
+ * @param shouldRetry predicate to determine if an exception should trigger a retry
+ * @param operationDescription description for logging purposes
+ * @return CompletableFuture with the result of the operation
+ */
+ protected CompletableFuture retryOnOverload(
+ java.util.function.Supplier> operation,
+ java.util.function.Predicate shouldRetry,
+ String operationDescription) {
+
+ final int maxRetries = 3;
+ final long[] backoffDelaysMs = {1000, 2000, 4000}; // 1s, 2s, 4s
+
+ return retryWithBackoff(operation, shouldRetry, operationDescription, 0, maxRetries, backoffDelaysMs);
+ }
- _openAIClient = OpenAIOkHttpClientAsync.builder()
- .apiKey(apiKey)
- .maxRetries(32) // Handle 429 errors
- .build();
+ private CompletableFuture retryWithBackoff(
+ java.util.function.Supplier> operation,
+ java.util.function.Predicate shouldRetry,
+ String operationDescription,
+ int attemptNumber,
+ int maxRetries,
+ long[] backoffDelaysMs) {
+
+ CompletableFuture result = new CompletableFuture<>();
+
+ operation.get().whenComplete((value, throwable) -> {
+ if (throwable == null) {
+ // Success case
+ result.complete(value);
+ } else {
+ // Error case - unwrap CompletionException to get the actual cause
+ Throwable actualCause = throwable instanceof java.util.concurrent.CompletionException && throwable.getCause() != null
+ ? throwable.getCause()
+ : throwable;
+
+ // Check if we should retry this exception and haven't exceeded max retries
+ if (shouldRetry.test(actualCause) && attemptNumber < maxRetries) {
+ long delayMs = backoffDelaysMs[attemptNumber];
+ LOG.warn(String.format(
+ "Retrying %s after error (attempt %d/%d, waiting %dms): %s",
+ operationDescription, attemptNumber + 1, maxRetries, delayMs, actualCause.getMessage()));
+
+ // Schedule retry after delay
+ new java.util.Timer().schedule(new java.util.TimerTask() {
+ @Override
+ public void run() {
+ retryWithBackoff(operation, shouldRetry, operationDescription, attemptNumber + 1, maxRetries, backoffDelaysMs)
+ .whenComplete((retryValue, retryError) -> {
+ if (retryError != null) {
+ result.completeExceptionally(retryError);
+ } else {
+ result.complete(retryValue);
+ }
+ });
+ }
+ }, delayMs);
+ } else {
+ // No more retries or non-retriable exception
+ if (attemptNumber >= maxRetries) {
+ LOG.error(String.format("Failed %s after %d retries: %s", operationDescription, maxRetries, actualCause.getMessage()));
+ }
+ result.completeExceptionally(throwable);
+ }
+ }
+ });
- _costMonitor = costMonitor;
+ return result;
}
public static String getExperimentMessage(JSONObject experiment) {
@@ -133,12 +185,9 @@ public static String getExperimentMessage(JSONObject experiment) {
public CompletableFuture describeExperiment(ExperimentInputs experimentInputs) {
- ChatCompletionCreateParams request = buildAiRequest(
- "experiment-summary",
- experimentResponseSchema,
- getExperimentMessage(experimentInputs.getExperimentData()));
+ String prompt = getExperimentMessage(experimentInputs.getExperimentData());
- return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), request, json -> {
+ return getValidatedAiResponse("dataset " + experimentInputs.getDatasetId(), prompt, experimentResponseSchema, json -> {
// add some fields to the result to aid the final summarization
return json
.put("dataset_id", experimentInputs.getDatasetId())
@@ -151,6 +200,7 @@ public static String getFinalSummaryMessage(List experiments) {
return "Below are AI-generated summaries of one gene's behavior in all the transcriptomics experiments available in VEuPathDB, provided in JSON format:\n\n" +
String.format("```json\n%s\n```\n\n", new JSONArray(experiments).toString(2)) +
"Generate a one-paragraph summary (~100 words) describing the gene's expression. Structure it using , , and - tags with no attributes. If relevant, briefly speculate on the gene's potential function, but only if justified by the data. Also, generate a short, specific headline for the summary. The headline must reflect this gene's expression and **must not** include generic phrases like \"comprehensive insights into\" or the word \"gene\".\n\n" +
+ "Use sentence case for all headlines: capitalize only the first word and proper nouns, not every word.\n\n" +
"Additionally, group the per-experiment summaries (identified by `dataset_id`) with `biological_importance > 3` and `confidence > 3` into sections by topic. For each topic, provide:\n" +
"- A headline summarizing the key experimental results within the topic\n" +
"- A concise one-sentence summary of the topic's experimental results\n\n" +
@@ -159,12 +209,9 @@ public static String getFinalSummaryMessage(List experiments) {
public JSONObject summarizeExperiments(String geneId, List experiments) {
- ChatCompletionCreateParams request = buildAiRequest(
- "expression-summary",
- finalResponseSchema,
- getFinalSummaryMessage(experiments));
+ String prompt = getFinalSummaryMessage(experiments);
- return getValidatedAiResponse("summary for gene " + geneId, request, json ->
+ return getValidatedAiResponse("summary for gene " + geneId, prompt, finalResponseSchema, json ->
// quality control (remove bad `dataset_id`s) and add 'Others' section for any experiments not listed by AI
consolidateSummary(json, experiments)
).join();
@@ -240,34 +287,17 @@ private static JSONObject consolidateSummary(JSONObject summaryResponse,
}
- private static ChatCompletionCreateParams buildAiRequest(String name, Schema schema, String userMessage) {
- return ChatCompletionCreateParams.builder()
- .model(OPENAI_CHAT_MODEL)
- .maxCompletionTokens(MAX_RESPONSE_TOKENS)
- .responseFormat(ResponseFormatJsonSchema.builder()
- .jsonSchema(JsonSchema.builder()
- .name(name)
- .schema(schema)
- .strict(true)
- .build())
- .build())
- .addSystemMessage(SYSTEM_MESSAGE)
- .addUserMessage(userMessage)
- .build();
- }
+ protected abstract CompletableFuture callApiForJson(String prompt, Schema schema);
+
+ protected abstract void updateCostMonitor(Object apiResponse);
private CompletableFuture getValidatedAiResponse(
String operationDescription,
- ChatCompletionCreateParams request,
+ String prompt,
+ Schema schema,
Function createFinalJson
) {
- return _openAIClient.chat().completions().create(request).thenApply(completion -> {
-
- // update cost accumulator
- _costMonitor.updateCost(completion.usage());
-
- // expect response to be a JSON string
- String jsonString = completion.choices().get(0).message().content().get();
+ return callApiForJson(prompt, schema).thenApply(jsonString -> {
int attempts = 1;
Exception mostRecentError;
@@ -281,12 +311,10 @@ private CompletableFuture getValidatedAiResponse(
}
catch (JSONException e) {
mostRecentError = e;
- LOG.warn("Malformed JSON from OpenAI (attempt " + attempts + ") for " + operationDescription + ". Retrying...");
+ LOG.warn("Malformed JSON from AI (attempt " + attempts + ") for " + operationDescription + ". Retrying...");
- // Re-request from OpenAI
- completion = _openAIClient.chat().completions().create(request).join();
- _costMonitor.updateCost(completion.usage());
- jsonString = completion.choices().get(0).message().content().get();
+ // Re-request from AI
+ jsonString = callApiForJson(prompt, schema).join();
attempts++;
}
}