@@ -89,7 +89,7 @@ public AiTab() {
}

public void initialize() {
this.viewModel = new AiTabViewModel(preferences);
this.viewModel = new AiTabViewModel(preferences, taskExecutor);

initializeEnableAi();
initializeAiProvider();
@@ -22,11 +22,14 @@
import org.jabref.gui.preferences.PreferenceTabViewModel;
import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.logic.ai.AiPreferences;
import org.jabref.logic.ai.models.AiModelService;
import org.jabref.logic.ai.models.FetchAiModelsBackgroundTask;
import org.jabref.logic.ai.templates.AiTemplate;
import org.jabref.logic.l10n.Localization;
import org.jabref.logic.preferences.CliPreferences;
import org.jabref.logic.util.LocalizedNumbers;
import org.jabref.logic.util.OptionalObjectProperty;
import org.jabref.logic.util.TaskExecutor;
import org.jabref.logic.util.strings.StringUtil;
import org.jabref.model.ai.AiProvider;
import org.jabref.model.ai.EmbeddingModel;
@@ -107,6 +110,8 @@ AiTemplate.CITATION_PARSING_USER_MESSAGE, new SimpleStringProperty()
private final BooleanProperty disableExpertSettings = new SimpleBooleanProperty(true);

private final AiPreferences aiPreferences;
private final AiModelService aiModelService;
private final TaskExecutor taskExecutor;

private final Validator apiKeyValidator;
private final Validator chatModelValidator;
@@ -121,10 +126,12 @@ AiTemplate.CITATION_PARSING_USER_MESSAGE, new SimpleStringProperty()
private final Validator ragMinScoreTypeValidator;
private final Validator ragMinScoreRangeValidator;

public AiTabViewModel(CliPreferences preferences) {
public AiTabViewModel(CliPreferences preferences, TaskExecutor taskExecutor) {
this.oldLocale = Locale.getDefault();

this.aiPreferences = preferences.getAiPreferences();
this.aiModelService = new AiModelService();
this.taskExecutor = taskExecutor;

this.enableAi.addListener((_, _, newValue) -> {
disableBasicSettings.set(!newValue);
@@ -428,6 +435,48 @@ public void resetCurrentTemplate() {
});
}

/**
* Refreshes the list of available chat models for the currently selected AI provider.
* The hardcoded default models are shown immediately; a background task then queries the provider's API
* asynchronously and, if it returns a non-empty list, replaces the contents of chatModelsList while
* keeping the currently selected model.
*/
public void refreshAvailableModels() {
AiProvider provider = selectedAiProvider.get();
if (provider == null) {
return;
}

String apiKey = currentApiKey.get();
String apiBaseUrl = customizeExpertSettings.get() ? currentApiBaseUrl.get() : provider.getApiUrl();

List<String> staticModels = aiModelService.getStaticModels(provider);
chatModelsList.setAll(staticModels);

FetchAiModelsBackgroundTask fetchTask = getAiModelsBackgroundTask(provider, apiBaseUrl, apiKey);

fetchTask.executeWith(taskExecutor);
}

private FetchAiModelsBackgroundTask getAiModelsBackgroundTask(AiProvider provider, String apiBaseUrl, String apiKey) {
FetchAiModelsBackgroundTask fetchTask = new FetchAiModelsBackgroundTask(
aiModelService,
provider,
apiBaseUrl,
apiKey
);

fetchTask.onSuccess(dynamicModels -> {
if (!dynamicModels.isEmpty()) {
String currentModel = currentChatModel.get();
chatModelsList.setAll(dynamicModels);
if (currentModel != null && !currentModel.isBlank()) {
currentChatModel.set(currentModel);
}
}
});
return fetchTask;
}

@Override
public boolean validateSettings() {
if (enableAi.get()) {
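Editorial note: the refresh flow added above follows a two-phase pattern — the hardcoded defaults are pushed into chatModelsList immediately, and the list is replaced once the background fetch returns a non-empty result, with the current selection preserved. Below is a minimal, framework-free sketch of that same pattern; it swaps JabRef's BackgroundTask/TaskExecutor for a plain ExecutorService with CompletableFuture, and the class, method, and model names are hypothetical.

import java.util.List;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.TimeUnit;

public class ModelListRefreshSketch {

    private final List<String> chatModels = new CopyOnWriteArrayList<>();
    private volatile String selectedModel = "default-model"; // stands in for currentChatModel

    public void refresh(ExecutorService executor) {
        // Phase 1: show the hardcoded defaults right away so the model list is never empty.
        chatModels.clear();
        chatModels.addAll(List.of("default-model", "default-model-large"));

        // Phase 2: replace the list asynchronously; only a non-empty result wins,
        // and the previously selected model is restored afterwards.
        CompletableFuture
                .supplyAsync(this::fetchRemoteModels, executor)
                .thenAccept(remote -> {
                    if (!remote.isEmpty()) {
                        String current = selectedModel;
                        chatModels.clear();
                        chatModels.addAll(remote);
                        selectedModel = current;
                    }
                });
    }

    private List<String> fetchRemoteModels() {
        // Stand-in for AiModelService#fetchModelsSynchronously; returns placeholder ids.
        return List.of("remote-model-a", "remote-model-b");
    }

    public static void main(String[] args) throws InterruptedException {
        ExecutorService executor = Executors.newSingleThreadExecutor();
        ModelListRefreshSketch sketch = new ModelListRefreshSketch();
        sketch.refresh(executor);
        executor.shutdown();
        executor.awaitTermination(5, TimeUnit.SECONDS);
        System.out.println(sketch.chatModels);
    }
}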
1 change: 1 addition & 0 deletions jablib/src/main/java/module-info.java
@@ -39,6 +39,7 @@
exports org.jabref.model.groups.event;
exports org.jabref.logic.preview;
exports org.jabref.logic.ai;
exports org.jabref.logic.ai.models;
exports org.jabref.logic.pdf;
exports org.jabref.model.database.event;
exports org.jabref.model.entry.event;
@@ -0,0 +1,32 @@
package org.jabref.logic.ai.models;

import java.util.List;

import org.jabref.model.ai.AiProvider;

import org.jspecify.annotations.NullMarked;

/**
* Interface for fetching available AI models from different providers.
* Implementations should handle API calls to retrieve model lists dynamically.
*/
@NullMarked
public interface AiModelProvider {
/**
* Fetches the list of available models for the given AI provider.
*
* @param aiProvider The AI provider to fetch models from
* @param apiBaseUrl The base URL for the API
* @param apiKey The API key for authentication
* @return A list of available model names
*/
List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey);

/**
* Checks if this provider supports the given AI provider type.
*
* @param aiProvider The AI provider to check
* @return true if this provider can fetch models for the given AI provider
*/
boolean supports(AiProvider aiProvider);
}
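Editorial note: to make the contract concrete, here is a deliberately trivial implementation that serves a fixed catalogue instead of calling an API. It is illustrative only and not part of this change; the class name and model ids are invented.

package org.jabref.logic.ai.models;

import java.util.List;

import org.jabref.model.ai.AiProvider;

import org.jspecify.annotations.NullMarked;

// Hypothetical example implementation: no HTTP call, just a bundled catalogue.
@NullMarked
public class StaticCatalogueModelProvider implements AiModelProvider {

    @Override
    public List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
        // The base URL and API key are ignored because nothing is fetched remotely.
        return List.of("example-model-small", "example-model-large");
    }

    @Override
    public boolean supports(AiProvider aiProvider) {
        // Illustrative: claim responsibility for exactly one provider type.
        return aiProvider == AiProvider.GPT4ALL;
    }
}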
@@ -0,0 +1,80 @@
package org.jabref.logic.ai.models;

import java.util.List;

import org.jabref.logic.ai.AiDefaultPreferences;
import org.jabref.model.ai.AiProvider;

import org.jspecify.annotations.NullMarked;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Service for managing AI models from different providers.
* Provides both static (hardcoded) and dynamic (API-fetched) model lists.
*/
@NullMarked
public class AiModelService {
private static final Logger LOGGER = LoggerFactory.getLogger(AiModelService.class);

private final List<AiModelProvider> modelProviders = List.of(new OpenAiCompatibleModelProvider());

/**
* Gets the list of available models for the given provider.
* First attempts to fetch models dynamically from the API.
* If that fails or times out, falls back to the hardcoded list.
*
* @param aiProvider The AI provider
* @param apiBaseUrl The base URL for the API
* @param apiKey The API key for authentication
* @return A list of available model names
*/
public List<String> getAvailableModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
List<String> dynamicModels = fetchModelsSynchronously(aiProvider, apiBaseUrl, apiKey);

if (!dynamicModels.isEmpty()) {
LOGGER.info("Using {} dynamic models for {}", dynamicModels.size(), aiProvider.getLabel());
return dynamicModels;
}

List<String> staticModels = AiDefaultPreferences.getAvailableModels(aiProvider);
LOGGER.debug("Using {} hardcoded models for {}", staticModels.size(), aiProvider.getLabel());
return staticModels;
}

/**
* Gets the list of available models for the given provider, using only hardcoded values.
*
* @param aiProvider The AI provider
* @return A list of available model names
*/
public List<String> getStaticModels(AiProvider aiProvider) {
return AiDefaultPreferences.getAvailableModels(aiProvider);
}

/**
* Synchronously fetches the list of available models from the API.
* This method will block until the fetch completes or the HTTP client times out.
*
* @param aiProvider The AI provider
* @param apiBaseUrl The base URL for the API
* @param apiKey The API key for authentication
* @return A list of model names, or an empty list if the fetch fails
*/
public List<String> fetchModelsSynchronously(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
for (AiModelProvider provider : modelProviders) {
if (provider.supports(aiProvider)) {
try {
List<String> models = provider.fetchModels(aiProvider, apiBaseUrl, apiKey);
if (!models.isEmpty()) {
return models;
}
} catch (Exception e) {
LOGGER.debug("Failed to fetch models for {}", aiProvider.getLabel(), e);
}
}
}

return List.of();
}
}
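Editorial note: a short usage sketch of the service from a non-UI caller. The environment variable and the direct call are placeholders for illustration; in this PR the service is only invoked through FetchAiModelsBackgroundTask and the preferences tab.

import java.util.List;

import org.jabref.logic.ai.models.AiModelService;
import org.jabref.model.ai.AiProvider;

public class ListAvailableModelsSketch {
    public static void main(String[] args) {
        AiModelService service = new AiModelService();

        AiProvider provider = AiProvider.OPEN_AI;
        String baseUrl = provider.getApiUrl();
        String apiKey = System.getenv("OPENAI_API_KEY"); // hypothetical way of supplying the key

        // Blocking call: returns the API-reported models, or the hardcoded defaults if the fetch fails.
        List<String> models = service.getAvailableModels(provider, baseUrl, apiKey == null ? "" : apiKey);
        models.forEach(System.out::println);
    }
}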
@@ -0,0 +1,42 @@
package org.jabref.logic.ai.models;

import java.util.List;

import org.jabref.logic.l10n.Localization;
import org.jabref.logic.util.BackgroundTask;
import org.jabref.model.ai.AiProvider;

/**
* Background task for fetching AI models from a provider's API.
*/
public class FetchAiModelsBackgroundTask extends BackgroundTask<List<String>> {

private final AiModelService aiModelService;
private final AiProvider aiProvider;
private final String apiBaseUrl;
private final String apiKey;

public FetchAiModelsBackgroundTask(AiModelService aiModelService, AiProvider aiProvider, String apiBaseUrl, String apiKey) {
this.aiModelService = aiModelService;
this.aiProvider = aiProvider;
this.apiBaseUrl = apiBaseUrl;
this.apiKey = apiKey;

configure();
}

private void configure() {
showToUser(false);
titleProperty().set(Localization.lang("Fetching models for %0", aiProvider.getLabel()));
willBeRecoveredAutomatically(true);
}

@Override
public List<String> call() {
return aiModelService.fetchModelsSynchronously(
aiProvider,
apiBaseUrl,
apiKey
);
}
}
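Editorial note: for reference, the wiring used by AiTabViewModel earlier in this diff boils down to the following (condensed; the static-list pre-population and null checks are omitted here).

FetchAiModelsBackgroundTask fetchTask = new FetchAiModelsBackgroundTask(aiModelService, provider, apiBaseUrl, apiKey);
fetchTask.onSuccess(dynamicModels -> {
    if (!dynamicModels.isEmpty()) {
        chatModelsList.setAll(dynamicModels);
    }
});
fetchTask.executeWith(taskExecutor);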
@@ -0,0 +1,116 @@
package org.jabref.logic.ai.models;

import java.util.ArrayList;
import java.util.List;

import org.jabref.model.ai.AiProvider;

import kong.unirest.core.HttpResponse;
import kong.unirest.core.JsonNode;
import kong.unirest.core.Unirest;
import kong.unirest.core.UnirestException;
import kong.unirest.core.json.JSONArray;
import kong.unirest.core.json.JSONObject;
import org.jspecify.annotations.NullMarked;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* Model provider for OpenAI-compatible APIs.
* Fetches available models from the /v1/models endpoint.
* Mistral and GPT4All also expose OpenAI-compatible APIs, so this provider handles those providers as well.
*/
@NullMarked
public class OpenAiCompatibleModelProvider implements AiModelProvider {

private static final Logger LOGGER = LoggerFactory.getLogger(OpenAiCompatibleModelProvider.class);

@Override
public List<String> fetchModels(AiProvider aiProvider, String apiBaseUrl, String apiKey) {
if (apiKey.isBlank()) {
LOGGER.debug("API key is not provided for {}, skipping model fetch", aiProvider.getLabel());
return List.of();
}

List<String> models = List.of();

try {
String modelsEndpoint = buildModelsEndpoint(apiBaseUrl);
HttpResponse<JsonNode> response = Unirest.get(modelsEndpoint)
.header("Authorization", "Bearer " + apiKey)
.header("accept", "application/json")
.asJson();

if (response.getStatus() == 200) {
models = parseModelsFromResponse(response.getBody());
LOGGER.info("Successfully fetched {} models from {}", models.size(), aiProvider.getLabel());
} else {
LOGGER.debug("Failed to fetch models from {} (status: {})", aiProvider.getLabel(), response.getStatus());
}
} catch (UnirestException e) {
LOGGER.debug("Failed to fetch models from {}", aiProvider.getLabel(), e);
} catch (Exception e) {
LOGGER.debug("Unexpected error while fetching models from {}", aiProvider.getLabel(), e);
}

return models;
}

@Override
public boolean supports(AiProvider aiProvider) {
return aiProvider == AiProvider.OPEN_AI
|| aiProvider == AiProvider.MISTRAL_AI
|| aiProvider == AiProvider.GPT4ALL;
}

/**
* Builds the URL for the models endpoint from the given API base URL.
* <p>
* The OpenAI API specification defines the models endpoint at /v1/models.
* This method handles various URL formats:
* <ul>
* <li>If the URL already ends with /v1, appends /models</li>
* <li>If the URL doesn't end with /v1, appends /v1/models</li>
* <li>Removes trailing slashes before building the path</li>
* </ul>
*
* @param apiBaseUrl the base URL of the API (e.g., "https://api.openai.com" or "https://api.openai.com/v1")
* @return the complete URL for the models endpoint
*/
private String buildModelsEndpoint(String apiBaseUrl) {
String baseUrl = apiBaseUrl.trim();
if (baseUrl.endsWith("/")) {
baseUrl = baseUrl.substring(0, baseUrl.length() - 1);
}

if (baseUrl.endsWith("/v1")) {
return baseUrl + "/models";
} else {
return baseUrl + "/v1/models";
}
}

private List<String> parseModelsFromResponse(JsonNode jsonNode) {
List<String> models = new ArrayList<>();

try {
JSONObject jsonResponse = jsonNode.getObject();

if (jsonResponse.has("data")) {
JSONArray modelsArray = jsonResponse.getJSONArray("data");

for (int i = 0; i < modelsArray.length(); i++) {
JSONObject modelObject = modelsArray.getJSONObject(i);
if (modelObject.has("id")) {
String modelId = modelObject.getString("id");
models.add(modelId);
}
}
}
} catch (Exception e) {
LOGGER.warn("Failed to parse models response.", e);
}

return models;
}
}
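Editorial note: the endpoint construction and the expected payload are worth spelling out. The comment block below restates buildModelsEndpoint's behaviour on sample inputs and sketches the JSON shape parseModelsFromResponse consumes; the runnable part is a framework-free equivalent of the request using java.net.http instead of Unirest, with the endpoint and the environment variable chosen purely for illustration.

import java.net.URI;
import java.net.http.HttpClient;
import java.net.http.HttpRequest;
import java.net.http.HttpResponse;

public class ListModelsHttpSketch {

    // buildModelsEndpoint, applied to sample base URLs:
    //   "https://api.openai.com"      -> "https://api.openai.com/v1/models"
    //   "https://api.openai.com/v1"   -> "https://api.openai.com/v1/models"
    //   "https://api.openai.com/v1/"  -> "https://api.openai.com/v1/models"
    //
    // Expected (simplified) response shape; only data[*].id is read:
    //   { "object": "list", "data": [ { "id": "gpt-4o", "object": "model" }, ... ] }

    public static void main(String[] args) throws Exception {
        String endpoint = "https://api.openai.com/v1/models"; // illustrative endpoint
        String apiKey = System.getenv("OPENAI_API_KEY");       // illustrative key source; missing key yields 401

        HttpRequest request = HttpRequest.newBuilder(URI.create(endpoint))
                .header("Authorization", "Bearer " + apiKey)
                .header("Accept", "application/json")
                .GET()
                .build();

        HttpResponse<String> response = HttpClient.newHttpClient()
                .send(request, HttpResponse.BodyHandlers.ofString());

        // A 200 status means the body holds the JSON document sketched above.
        System.out.println(response.statusCode());
        System.out.println(response.body());
    }
}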
1 change: 1 addition & 0 deletions jablib/src/main/resources/l10n/JabRef_en.properties
@@ -1806,6 +1806,7 @@ Processing\ entry\ %0\ of\ %1=Processing entry %0 of %1
Fetching\ and\ merging\ entry(s)=Fetching and merging entry(s)
No\ updates\ found.=No updates found.
Fetching\ information\ using\ %0=Fetching information using %0
Fetching\ models\ for\ %0=Fetching models for %0
No\ information\ added=No information added
Updated\ entry\ with\ info\ from\ %0=Updated entry with info from %0
Add\ new\ list=Add new list