6 changes: 6 additions & 0 deletions Model/pom.xml
@@ -135,6 +135,12 @@
<artifactId>openai-java</artifactId>
</dependency>

<dependency>
<groupId>com.anthropic</groupId>
<artifactId>anthropic-java</artifactId>
<version>2.9.0</version>
</dependency>

</dependencies>

</project>
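
The new anthropic-java artifact supplies the asynchronous client used by the ClaudeSummarizer introduced below. A minimal standalone sketch of that client, assuming the key is read from an environment variable (the variable name and prompt are illustrative, not part of this change):

import com.anthropic.client.AnthropicClientAsync;
import com.anthropic.client.okhttp.AnthropicOkHttpClientAsync;
import com.anthropic.models.messages.MessageCreateParams;
import com.anthropic.models.messages.Model;

public class AnthropicClientSmokeTest {

  public static void main(String[] args) {
    // key source is illustrative; the reporter below reads it from WDK model properties instead
    AnthropicClientAsync client = AnthropicOkHttpClientAsync.builder()
        .apiKey(System.getenv("ANTHROPIC_API_KEY"))
        .build();

    MessageCreateParams request = MessageCreateParams.builder()
        .model(Model.CLAUDE_SONNET_4_5_20250929)
        .maxTokens(100L)
        .addUserMessage("Reply with the single word: ok")
        .build();

    // block on the async call only because this is a throwaway example
    client.messages().create(request)
        .thenAccept(response -> System.out.println(response.content()))
        .join();
  }
}
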
SingleGeneAiExpressionReporter.java
@@ -13,6 +13,7 @@
import org.apidb.apicommon.model.report.ai.expression.DailyCostMonitor;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor;
import org.apidb.apicommon.model.report.ai.expression.GeneRecordProcessor.GeneSummaryInputs;
import org.apidb.apicommon.model.report.ai.expression.ClaudeSummarizer;
import org.apidb.apicommon.model.report.ai.expression.Summarizer;
import org.gusdb.wdk.model.WdkModelException;
import org.gusdb.wdk.model.WdkServiceTemporarilyUnavailableException;
@@ -32,8 +33,11 @@ public class SingleGeneAiExpressionReporter extends AbstractReporter {
private static final int MAX_RESULT_SIZE = 1; // one gene at a time for now

private static final String POPULATION_MODE_PROP_KEY = "populateIfNotPresent";
private static final String AI_MAX_CONCURRENT_REQUESTS_PROP_KEY = "AI_MAX_CONCURRENT_REQUESTS";
private static final int DEFAULT_MAX_CONCURRENT_REQUESTS = 10;

private boolean _populateIfNotPresent;
private int _maxConcurrentRequests;
private DailyCostMonitor _costMonitor;

@Override
@@ -42,6 +46,12 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk
// assign cache mode
_populateIfNotPresent = config.optBoolean(POPULATION_MODE_PROP_KEY, false);

// read max concurrent requests from model properties or use default
String maxConcurrentRequestsStr = _wdkModel.getProperties().get(AI_MAX_CONCURRENT_REQUESTS_PROP_KEY);
_maxConcurrentRequests = maxConcurrentRequestsStr != null
? Integer.parseInt(maxConcurrentRequestsStr)
: DEFAULT_MAX_CONCURRENT_REQUESTS;

// instantiate cost monitor
_costMonitor = new DailyCostMonitor(_wdkModel);

@@ -52,7 +62,7 @@ public Reporter configure(JSONObject config) throws ReporterConfigException, Wdk
" should only be assigned to " + geneRecordClass.getFullName());
}

// check result size; limit to small results due to OpenAI cost
// check result size; limit to small results due to AI API cost
if (_baseAnswer.getResultSizeFactory().getResultSize() > MAX_RESULT_SIZE) {
throw new ReporterConfigException("This reporter cannot be called with results of size greater than " + MAX_RESULT_SIZE);
}
@@ -79,8 +89,8 @@ protected void write(OutputStream out) throws IOException, WdkModelException {
// open summary cache (manages persistence of expression data)
AiExpressionCache cache = AiExpressionCache.getInstance(_wdkModel);

// create summarizer (interacts with OpenAI)
Summarizer summarizer = new Summarizer(_wdkModel, _costMonitor);
// create summarizer (interacts with Claude)
ClaudeSummarizer summarizer = new ClaudeSummarizer(_wdkModel, _costMonitor);

// open record and output streams
try (RecordStream recordStream = RecordStreamFactory.getRecordStream(_baseAnswer, List.of(), tables);
@@ -93,12 +103,12 @@ protected void write(OutputStream out) throws IOException, WdkModelException {

// create summary inputs
GeneSummaryInputs summaryInputs =
GeneRecordProcessor.getSummaryInputsFromRecord(record, Summarizer.OPENAI_CHAT_MODEL.toString(),
GeneRecordProcessor.getSummaryInputsFromRecord(record, ClaudeSummarizer.CLAUDE_MODEL.toString(),
Summarizer::getExperimentMessage, Summarizer::getFinalSummaryMessage);

// fetch summary, producing if necessary and requested
JSONObject expressionSummary = _populateIfNotPresent
? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments)
? cache.populateSummary(summaryInputs, summarizer::describeExperiment, summarizer::summarizeExperiments, _maxConcurrentRequests)
: cache.readSummary(summaryInputs);

// join entries with commas
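
The configure method above parses AI_MAX_CONCURRENT_REQUESTS with Integer.parseInt, so a malformed property value would surface as a NumberFormatException at configure time. A minimal sketch of a more forgiving lookup, assuming getProperties() returns a Map<String, String> (the helper class and method names are hypothetical):

import java.util.Map;

final class ModelPropUtil {

  // hypothetical helper: read an integer model property, falling back on missing or malformed values
  static int getIntProp(Map<String, String> props, String key, int defaultValue) {
    String raw = props.get(key);
    if (raw == null) {
      return defaultValue;
    }
    try {
      return Integer.parseInt(raw.trim());
    }
    catch (NumberFormatException e) {
      return defaultValue;
    }
  }
}

Under that assumption, the assignment in configure() would read _maxConcurrentRequests = ModelPropUtil.getIntProp(_wdkModel.getProperties(), AI_MAX_CONCURRENT_REQUESTS_PROP_KEY, DEFAULT_MAX_CONCURRENT_REQUESTS).
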
AiExpressionCache.java
@@ -73,7 +73,6 @@ public class AiExpressionCache {
private static Logger LOG = Logger.getLogger(AiExpressionCache.class);

// parallel processing
private static final int MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST = 10;
private static final long VISIT_ENTRY_LOCK_MAX_WAIT_MILLIS = 50;

// cache location
@@ -317,18 +316,20 @@ private static Optional<JSONObject> readCachedData(Path entryDir) {
* @param summaryInputs gene summary inputs
* @param experimentDescriber function to describe an experiment
* @param experimentSummarizer function to summarize experiments into an expression summary
* @param maxConcurrentRequests maximum number of concurrent experiment lookups
* @return expression summary (will always be a cache hit)
*/
public JSONObject populateSummary(GeneSummaryInputs summaryInputs,
FunctionWithException<ExperimentInputs, CompletableFuture<JSONObject>> experimentDescriber,
BiFunctionWithException<String, List<JSONObject>, JSONObject> experimentSummarizer) {
BiFunctionWithException<String, List<JSONObject>, JSONObject> experimentSummarizer,
int maxConcurrentRequests) {
try {
return _cache.populateAndProcessContent(summaryInputs.getGeneId(),

// populator
entryDir -> {
// first populate each dataset entry as needed and collect experiment descriptors
List<JSONObject> experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber);
List<JSONObject> experiments = populateExperiments(summaryInputs.getExperimentsWithData(), experimentDescriber, maxConcurrentRequests);

// sort them most-interesting first so that the "Other" section will be filled
// in that order (and also to give the AI the data in a sensible order)
@@ -362,14 +363,16 @@ public JSONObject populateSummary(GeneSummaryInputs summaryInputs,
*
* @param experimentData experiment inputs
* @param experimentDescriber function to describe an experiment
* @param maxConcurrentRequests maximum number of concurrent experiment lookups
* @return list of cached experiment descriptions
* @throws Exception if unable to generate descriptions or store
*/
private List<JSONObject> populateExperiments(List<ExperimentInputs> experimentData,
FunctionWithException<ExperimentInputs, CompletableFuture<JSONObject>> experimentDescriber) throws Exception {
FunctionWithException<ExperimentInputs, CompletableFuture<JSONObject>> experimentDescriber,
int maxConcurrentRequests) throws Exception {

// use a thread for each experiment, up to a reasonable max
int threadPoolSize = Math.min(MAX_CONCURRENT_EXPERIMENT_LOOKUPS_PER_REQUEST, experimentData.size());
int threadPoolSize = Math.min(maxConcurrentRequests, experimentData.size());

ExecutorService exec = Executors.newFixedThreadPool(threadPoolSize);
try {
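
The executor created in populateExperiments is now sized by the caller-supplied maxConcurrentRequests rather than a class constant; the submission loop itself sits below the visible hunk. A minimal sketch of the bounded fan-out pattern that sizing implies, with simplified types and hypothetical names (not the actual method body):

import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.function.Function;

final class BoundedFanOutSketch {

  // describe each input on a pool capped at maxConcurrentRequests and collect results in input order
  static <I, O> List<O> describeAll(List<I> inputs, Function<I, O> describer, int maxConcurrentRequests) {
    // guard against an empty input list in this standalone sketch
    int threadPoolSize = Math.max(1, Math.min(maxConcurrentRequests, inputs.size()));
    ExecutorService exec = Executors.newFixedThreadPool(threadPoolSize);
    try {
      List<CompletableFuture<O>> futures = new ArrayList<>();
      for (I input : inputs) {
        futures.add(CompletableFuture.supplyAsync(() -> describer.apply(input), exec));
      }
      List<O> results = new ArrayList<>();
      for (CompletableFuture<O> future : futures) {
        results.add(future.join()); // failures propagate as CompletionException
      }
      return results;
    }
    finally {
      exec.shutdown();
    }
  }
}
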
ClaudeSummarizer.java (new file)
@@ -0,0 +1,129 @@
package org.apidb.apicommon.model.report.ai.expression;

import java.time.Duration;
import java.util.concurrent.CompletableFuture;

import org.gusdb.wdk.model.WdkModel;
import org.gusdb.wdk.model.WdkModelException;

import com.anthropic.client.AnthropicClientAsync;
import com.anthropic.client.okhttp.AnthropicOkHttpClientAsync;
import com.anthropic.models.messages.MessageCreateParams;
import com.anthropic.models.messages.Model;
import com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema;

public class ClaudeSummarizer extends Summarizer {

public static final Model CLAUDE_MODEL = Model.CLAUDE_SONNET_4_5_20250929;

private static final String CLAUDE_API_KEY_PROP_NAME = "CLAUDE_API_KEY";

private final AnthropicClientAsync _claudeClient;

public ClaudeSummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
super(costMonitor);

String apiKey = wdkModel.getProperties().get(CLAUDE_API_KEY_PROP_NAME);
if (apiKey == null) {
throw new WdkModelException("WDK property '" + CLAUDE_API_KEY_PROP_NAME + "' has not been set.");
}

_claudeClient = AnthropicOkHttpClientAsync.builder()
.apiKey(apiKey)
.maxRetries(32) // Handle 429 errors
.checkJacksonVersionCompatibility(false)
.build();
}

@Override
protected CompletableFuture<String> callApiForJson(String prompt, Schema schema) {
// Convert JSON schema to natural language description for Claude
String jsonFormatInstructions = convertSchemaToPromptInstructions(schema);

String enhancedPrompt = prompt + "\n\n" + jsonFormatInstructions;

MessageCreateParams request = MessageCreateParams.builder()
.model(CLAUDE_MODEL)
.maxTokens((long) MAX_RESPONSE_TOKENS)
.system(SYSTEM_MESSAGE)
.addUserMessage(enhancedPrompt)
.build();

return retryOnOverload(
() -> _claudeClient.messages().create(request),
e -> e instanceof com.anthropic.errors.InternalServerException,
"Claude API call"
).thenApply(response -> {
// Convert Claude usage to OpenAI format for cost monitoring
com.anthropic.models.messages.Usage claudeUsage = response.usage();
com.openai.models.completions.CompletionUsage openAiUsage = com.openai.models.completions.CompletionUsage.builder()
.promptTokens(claudeUsage.inputTokens())
.completionTokens(claudeUsage.outputTokens())
.totalTokens(claudeUsage.inputTokens() + claudeUsage.outputTokens())
.build();

_costMonitor.updateCost(java.util.Optional.of(openAiUsage));

// Extract text from content blocks using stream API
String rawText = response.content().stream()
.flatMap(contentBlock -> contentBlock.text().stream())
.map(textBlock -> textBlock.text())
.findFirst()
.orElseThrow(() -> new RuntimeException("No text content found in Claude response"));

// Strip JSON markdown formatting if present
return stripJsonMarkdown(rawText);
});
}

@Override
protected void updateCostMonitor(Object apiResponse) {
// Claude response handling is done in callApiForJson
}

private String stripJsonMarkdown(String text) {
String trimmed = text.trim();

// Remove ```json and ``` markdown formatting
if (trimmed.startsWith("```json")) {
trimmed = trimmed.substring(7); // Remove "```json"
} else if (trimmed.startsWith("```")) {
trimmed = trimmed.substring(3); // Remove "```"
}

if (trimmed.endsWith("```")) {
trimmed = trimmed.substring(0, trimmed.length() - 3); // Remove trailing "```"
}

return trimmed.trim();
}

private String convertSchemaToPromptInstructions(Schema schema) {
// Convert OpenAI JSON schema to Claude-friendly format instructions
if (schema == experimentResponseSchema) {
return "Respond in valid JSON format matching this exact structure:\n" +
"{\n" +
" \"one_sentence_summary\": \"string describing gene expression\",\n" +
" \"biological_importance\": \"integer 0-5\",\n" +
" \"confidence\": \"integer 0-5\",\n" +
" \"experiment_keywords\": [\"array\", \"of\", \"strings\"],\n" +
" \"notes\": \"string with additional context\"\n" +
"}";
} else if (schema == finalResponseSchema) {
return "Respond in valid JSON format matching this exact structure:\n" +
"{\n" +
" \"headline\": \"string summarizing key results\",\n" +
" \"one_paragraph_summary\": \"string with ~100 words\",\n" +
" \"topics\": [\n" +
" {\n" +
" \"headline\": \"string summarizing topic\",\n" +
" \"one_sentence_summary\": \"string describing topic results\",\n" +
" \"dataset_ids\": [\"array\", \"of\", \"dataset_id\", \"strings\"]\n" +
" }\n" +
" ]\n" +
"}";
} else {
return "Respond in valid JSON format.";
}
}
}
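
Both callApiForJson implementations delegate to retryOnOverload, which lives in the Summarizer base class and is not part of this diff. A minimal sketch of what such a helper might look like, assuming a supplier of futures, a predicate that selects retryable failures, and blocking exponential backoff (every detail here is an assumption, not the actual base-class implementation):

import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.function.Predicate;
import java.util.function.Supplier;

final class RetryOnOverloadSketch {

  static <T> CompletableFuture<T> retryOnOverload(Supplier<CompletableFuture<T>> call,
      Predicate<Throwable> isRetryable, String label) {
    return attempt(call, isRetryable, label, 0);
  }

  private static <T> CompletableFuture<T> attempt(Supplier<CompletableFuture<T>> call,
      Predicate<Throwable> isRetryable, String label, int attemptsSoFar) {
    return call.get().handle((result, error) -> {
      if (error == null) {
        return CompletableFuture.completedFuture(result);
      }
      // unwrap the CompletionException wrapper added by CompletableFuture
      Throwable cause = error instanceof CompletionException && error.getCause() != null
          ? error.getCause()
          : error;
      if (attemptsSoFar < 5 && isRetryable.test(cause)) {
        sleepQuietly((1L << attemptsSoFar) * 500); // 0.5s, 1s, 2s, ... (blocking backoff for simplicity)
        return attempt(call, isRetryable, label, attemptsSoFar + 1);
      }
      return CompletableFuture.<T>failedFuture(
          new RuntimeException(label + " failed after " + (attemptsSoFar + 1) + " attempt(s)", cause));
    }).thenCompose(future -> future);
  }

  private static void sleepQuietly(long millis) {
    try {
      Thread.sleep(millis);
    }
    catch (InterruptedException e) {
      Thread.currentThread().interrupt();
    }
  }
}
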
DailyCostMonitor.java
@@ -21,7 +21,7 @@
import org.json.JSONException;
import org.json.JSONObject;

import com.openai.models.CompletionUsage;
import com.openai.models.completions.CompletionUsage;

public class DailyCostMonitor {

@@ -31,10 +31,15 @@ public class DailyCostMonitor {
private static final String DAILY_COST_ACCUMULATION_FILE_DIR = "dailyCost";
private static final String DAILY_COST_ACCUMULATION_FILE = "daily_cost_accumulation.txt";

// model prop keys
private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";
// model prop keys (new names without OPENAI_ prefix)
private static final String MAX_DAILY_DOLLAR_COST_PROP_NAME = "MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
private static final String DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
private static final String DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";

// deprecated model prop keys (with OPENAI_ prefix)
private static final String DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME = "OPENAI_MAX_DAILY_AI_EXPRESSION_DOLLAR_COST";
private static final String DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_INPUT_TOKENS";
private static final String DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME = "OPENAI_DOLLAR_COST_PER_1M_AI_OUTPUT_TOKENS";

// lock characteristics
private static final long DEFAULT_TIMEOUT_MILLIS = 1000;
Expand Down Expand Up @@ -68,9 +73,25 @@ public DailyCostMonitor(WdkModel wdkModel) throws WdkModelException {

_costMonitoringFile = _costMonitoringDir.resolve(DAILY_COST_ACCUMULATION_FILE);

_maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME);
_costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000;
_costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000;
_maxDailyDollarCost = getNumberProp(wdkModel, MAX_DAILY_DOLLAR_COST_PROP_NAME, DEPRECATED_MAX_DAILY_DOLLAR_COST_PROP_NAME);
_costPerInputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_INPUT_TOKENS_PROP_NAME) / 1000000;
_costPerOutputToken = getNumberProp(wdkModel, DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME, DEPRECATED_DOLLAR_COST_PER_1M_OUTPUT_TOKENS_PROP_NAME) / 1000000;
}

private double getNumberProp(WdkModel wdkModel, String propName, String deprecatedPropName) throws WdkModelException {
// First try the new property name
if (wdkModel.getProperties().get(propName) != null) {
return getNumberProp(wdkModel, propName);
}

// Fall back to deprecated property name with warning
if (wdkModel.getProperties().get(deprecatedPropName) != null) {
LOG.warn("WDK property '" + deprecatedPropName + "' is deprecated. Please use '" + propName + "' instead.");
return getNumberProp(wdkModel, deprecatedPropName);
}

// Neither property is set
throw new WdkModelException("WDK property '" + propName + "' (or deprecated '" + deprecatedPropName + "') has not been set.");
}

private double getNumberProp(WdkModel wdkModel, String propName) throws WdkModelException {
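
For context on the unit conversion above, a small standalone example with hypothetical per-1M prices (placeholders only, not actual model pricing; how updateCost later combines token counts with these rates is outside this hunk):

public class CostArithmeticExample {

  public static void main(String[] args) {
    // hypothetical per-1M prices; real values come from the model properties above
    double dollarsPer1MInputTokens = 3.00;
    double dollarsPer1MOutputTokens = 15.00;

    // same conversion as the constructor: per-1M price -> per-token price
    double costPerInputToken = dollarsPer1MInputTokens / 1000000;
    double costPerOutputToken = dollarsPer1MOutputTokens / 1000000;

    // a request that used 12,000 input tokens and 800 output tokens
    double requestCost = 12000 * costPerInputToken + 800 * costPerOutputToken;

    System.out.printf("request cost: $%.4f%n", requestCost); // prints: request cost: $0.0480
  }
}
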
OpenAISummarizer.java (new file)
@@ -0,0 +1,71 @@
package org.apidb.apicommon.model.report.ai.expression;

import java.util.concurrent.CompletableFuture;

import org.gusdb.wdk.model.WdkModel;
import org.gusdb.wdk.model.WdkModelException;

import com.openai.client.OpenAIClientAsync;
import com.openai.client.okhttp.OpenAIOkHttpClientAsync;
import com.openai.models.chat.completions.ChatCompletionCreateParams;
import com.openai.models.ChatModel;
import com.openai.models.ResponseFormatJsonSchema;
import com.openai.models.ResponseFormatJsonSchema.JsonSchema;

public class OpenAISummarizer extends Summarizer {

// provide exact model number for semi-reproducibility
public static final ChatModel OPENAI_CHAT_MODEL = ChatModel.GPT_4O_2024_11_20; // GPT_4O_2024_08_06;

private static final String OPENAI_API_KEY_PROP_NAME = "OPENAI_API_KEY";

private final OpenAIClientAsync _openAIClient;

public OpenAISummarizer(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
super(costMonitor);

String apiKey = wdkModel.getProperties().get(OPENAI_API_KEY_PROP_NAME);
if (apiKey == null) {
throw new WdkModelException("WDK property '" + OPENAI_API_KEY_PROP_NAME + "' has not been set.");
}

_openAIClient = OpenAIOkHttpClientAsync.builder()
.apiKey(apiKey)
.maxRetries(32) // Handle 429 errors
.build();
}

@Override
protected CompletableFuture<String> callApiForJson(String prompt, com.openai.models.ResponseFormatJsonSchema.JsonSchema.Schema schema) {
ChatCompletionCreateParams request = ChatCompletionCreateParams.builder()
.model(OPENAI_CHAT_MODEL)
.maxCompletionTokens(MAX_RESPONSE_TOKENS)
.responseFormat(ResponseFormatJsonSchema.builder()
.jsonSchema(JsonSchema.builder()
.name("structured-response")
.schema(schema)
.strict(true)
.build())
.build())
.addSystemMessage(SYSTEM_MESSAGE)
.addUserMessage(prompt)
.build();

return retryOnOverload(
() -> _openAIClient.chat().completions().create(request),
e -> e instanceof com.openai.errors.InternalServerException,
"OpenAI API call"
).thenApply(completion -> {
// update cost accumulator
_costMonitor.updateCost(completion.usage());

// return JSON string
return completion.choices().get(0).message().content().get();
});
}

@Override
protected void updateCostMonitor(Object apiResponse) {
// OpenAI response handling is done in callApiForJson
}
}
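
OpenAISummarizer and ClaudeSummarizer now share the Summarizer base API, while the reporter above instantiates ClaudeSummarizer directly. A hedged sketch of how a caller could instead pick an implementation from whichever key property is configured (the selection logic is illustrative only and assumes getProperties() returns a Map<String, String>):

import java.util.Map;

import org.apidb.apicommon.model.report.ai.expression.ClaudeSummarizer;
import org.apidb.apicommon.model.report.ai.expression.DailyCostMonitor;
import org.apidb.apicommon.model.report.ai.expression.OpenAISummarizer;
import org.apidb.apicommon.model.report.ai.expression.Summarizer;
import org.gusdb.wdk.model.WdkModel;
import org.gusdb.wdk.model.WdkModelException;

final class SummarizerFactorySketch {

  // prefer Claude when its key is configured, otherwise fall back to OpenAI
  static Summarizer create(WdkModel wdkModel, DailyCostMonitor costMonitor) throws WdkModelException {
    Map<String, String> props = wdkModel.getProperties();
    if (props.get("CLAUDE_API_KEY") != null) {
      return new ClaudeSummarizer(wdkModel, costMonitor);
    }
    if (props.get("OPENAI_API_KEY") != null) {
      return new OpenAISummarizer(wdkModel, costMonitor);
    }
    throw new WdkModelException("Neither CLAUDE_API_KEY nor OPENAI_API_KEY has been set.");
  }
}
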