diff --git a/core/src/main/kotlin/org/evomaster/core/EMConfig.kt b/core/src/main/kotlin/org/evomaster/core/EMConfig.kt
index b29a7ec84d..61f3ff3d83 100644
--- a/core/src/main/kotlin/org/evomaster/core/EMConfig.kt
+++ b/core/src/main/kotlin/org/evomaster/core/EMConfig.kt
@@ -591,6 +591,14 @@ class EMConfig {
             throw ConfigProblemException("The use of 'security' requires 'minimize'")
         }
 
+        if (languageModelConnector && languageModelServerURL.isNullOrEmpty()) {
+            throw ConfigProblemException("Language model server URL cannot be empty.")
+        }
+
+        if (languageModelConnector && languageModelName.isNullOrEmpty()) {
+            throw ConfigProblemException("Language model name cannot be empty.")
+        }
+
         if(prematureStop.isNotEmpty() && stoppingCriterion != StoppingCriterion.TIME){
             throw ConfigProblemException("The use of 'prematureStop' is meaningful only if the stopping criterion" +
                     " 'stoppingCriterion' is based on time")
@@ -2240,16 +2248,16 @@ class EMConfig {
         RANDOM
     }
 
-    @Cfg("Specify a method to select the first external service spoof IP address.")
     @Experimental
+    @Cfg("Specify a method to select the first external service spoof IP address.")
     var externalServiceIPSelectionStrategy = ExternalServiceIPSelectionStrategy.NONE
 
+    @Experimental
     @Cfg("User provided external service IP." +
             " When EvoMaster mocks external services, mock server instances will run on local addresses starting from" +
             " this provided address." +
             " Min value is ${defaultExternalServiceIP}." +
             " Lower values like ${ExternalServiceSharedUtils.RESERVED_RESOLVED_LOCAL_IP} and ${ExternalServiceSharedUtils.DEFAULT_WM_LOCAL_IP} are reserved.")
-    @Experimental
     @Regex(externalServiceIPRegex)
     var externalServiceIP : String = defaultExternalServiceIP
 
@@ -2280,26 +2288,24 @@ class EMConfig {
     @Probability(true)
     var useExtraSqlDbConstraintsProbability = 0.9
 
-
-    @Cfg("a probability of harvesting actual responses from external services as seeds.")
     @Experimental
+    @Cfg("a probability of harvesting actual responses from external services as seeds.")
     @Probability(activating = true)
     var probOfHarvestingResponsesFromActualExternalServices = 0.0
 
-
-    @Cfg("a probability of prioritizing to employ successful harvested actual responses from external services as seeds (e.g., 2xx from HTTP external service).")
     @Experimental
+    @Cfg("a probability of prioritizing to employ successful harvested actual responses from external services as seeds (e.g., 2xx from HTTP external service).")
     @Probability(activating = true)
     var probOfPrioritizingSuccessfulHarvestedActualResponses = 0.0
 
-    @Cfg("a probability of mutating mocked responses based on actual responses")
     @Experimental
+    @Cfg("a probability of mutating mocked responses based on actual responses")
     @Probability(activating = true)
     var probOfMutatingResponsesBasedOnActualResponse = 0.0
 
+    @Experimental
     @Cfg("Number of threads for external request harvester. No more threads than numbers of processors will be used.")
     @Min(1.0)
-    @Experimental
     var externalRequestHarvesterNumberOfThreads: Int = 2
@@ -2328,8 +2334,8 @@ class EMConfig {
         RANDOM
     }
 
-    @Cfg("Harvested external request response selection strategy")
     @Experimental
+    @Cfg("Harvested external request response selection strategy")
     var externalRequestResponseSelectionStrategy = ExternalRequestResponseSelectionStrategy.EXACT
 
     @Cfg("Whether to employ constraints specified in API schema (e.g., OpenAPI) in test generation")
@@ -2378,6 +2384,23 @@ class EMConfig {
     @Cfg("Apply a security testing phase after functional test cases have been generated.")
     var security = true
 
+    @Experimental
+    @Cfg("Enable language model connector")
+    var languageModelConnector = false
+
+    @Experimental
+    @Cfg("Large-language model external service URL. Default is set to Ollama local instance URL.")
+    var languageModelServerURL: String = "http://localhost:11434/"
+
+    @Experimental
+    @Cfg("Large-language model name as listed in Ollama")
+    var languageModelName: String = "llama3.2:latest"
+
+    @Experimental
+    @Cfg("Number of threads for language model connector. No more threads than the number of processors will be used.")
+    @Min(1.0)
+    var languageModelConnectorNumberOfThreads: Int = 2
+
     @Cfg("If there is no configuration file, create a default template at given configPath location." +
             " However this is done only on the 'default' location. If you change 'configPath', no new file will be" +
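As a quick orientation for the new options: a minimal sketch, assuming a plain EMConfig instance (the same pattern the new test below uses); all values shown are the defaults introduced in this diff.

    import org.evomaster.core.EMConfig

    fun main() {
        val config = EMConfig()
        config.languageModelConnector = true
        config.languageModelServerURL = "http://localhost:11434/" // default: local Ollama instance
        config.languageModelName = "llama3.2:latest"              // as listed by Ollama
        config.languageModelConnectorNumberOfThreads = 2          // capped at available processors
        // With the connector enabled, an empty URL or model name now fails fast
        // with a ConfigProblemException (see the validation checks added above).
    }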
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/AnsweredPrompt.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/AnsweredPrompt.kt
new file mode 100644
index 0000000000..db29f96949
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/AnsweredPrompt.kt
@@ -0,0 +1,7 @@
+package org.evomaster.core.languagemodel.data
+
+class AnsweredPrompt (
+    val prompt: Prompt,
+    val answer: String,
+) {
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/Prompt.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/Prompt.kt
new file mode 100644
index 0000000000..d3027ea355
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/Prompt.kt
@@ -0,0 +1,10 @@
+package org.evomaster.core.languagemodel.data
+
+import java.util.UUID
+
+class Prompt(
+    val id: UUID,
+
+    val prompt: String
+) {
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaEndpoints.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaEndpoints.kt
new file mode 100644
index 0000000000..e661e5808b
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaEndpoints.kt
@@ -0,0 +1,30 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+class OllamaEndpoints {
+
+    companion object {
+        /**
+         * API URL to generate a response for a given prompt with a provided model.
+         */
+        const val GENERATE_ENDPOINT = "/api/generate"
+
+        /**
+         * API URL to list models that are available locally.
+         */
+        const val TAGS_ENDPOINT = "/api/tags"
+
+        fun getGenerateEndpoint(serverURL: String): String {
+            return cleanURL(serverURL) + GENERATE_ENDPOINT
+        }
+
+        fun getTagEndpoint(serverURL: String): String {
+            return cleanURL(serverURL) + TAGS_ENDPOINT
+        }
+
+        private fun cleanURL(serverURL: String): String {
+            return if (serverURL.endsWith("/")) serverURL.dropLast(1) else serverURL
+        }
+    }
+
+
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModel.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModel.kt
new file mode 100644
index 0000000000..0d2fcf610a
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModel.kt
@@ -0,0 +1,17 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+class OllamaModel {
+
+    val name: String = ""
+
+    val model: String = ""
+
+    val modified_at: String = ""
+
+    val size: Int = 0
+
+    val digest: String = ""
+
+    val details: OllamaModelDetail = OllamaModelDetail()
+
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModelDetail.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModelDetail.kt
new file mode 100644
index 0000000000..9396001292
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModelDetail.kt
@@ -0,0 +1,15 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+class OllamaModelDetail {
+    val parent_model: String = ""
+
+    val format: String = ""
+
+    val family: String = ""
+
+    val families: List<String> = listOf()
+
+    val parameter_size: String = ""
+
+    val quantization_level: String = ""
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModelResponse.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModelResponse.kt
new file mode 100644
index 0000000000..67521af4f7
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaModelResponse.kt
@@ -0,0 +1,5 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+class OllamaModelResponse {
+    val models: List<OllamaModel> = listOf()
+}
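These four DTOs mirror the JSON returned by Ollama's /api/tags endpoint. A hedged sketch of the deserialization, mirroring the connector's own ObjectMapper usage further below; the payload values here are illustrative, not taken from the diff:

    import com.fasterxml.jackson.databind.ObjectMapper
    import org.evomaster.core.languagemodel.data.ollama.OllamaModelResponse

    fun main() {
        // Illustrative /api/tags payload (field values are made up for the example):
        val json = """{"models": [{"name": "llama3.2:latest", "model": "llama3.2:latest",
            "modified_at": "2025-01-01T00:00:00Z", "size": 815000000, "digest": "abc123",
            "details": {"parent_model": "", "format": "gguf", "family": "llama",
            "families": ["llama"], "parameter_size": "1B", "quantization_level": "Q4_K_M"}}]}"""

        val parsed = ObjectMapper().readValue(json, OllamaModelResponse::class.java)
        // isModelAvailable() below performs this same check against config.languageModelName:
        println(parsed.models.any { it.name == "llama3.2:latest" })
    }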
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaRequest.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaRequest.kt
new file mode 100644
index 0000000000..6fc1aebca1
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaRequest.kt
@@ -0,0 +1,21 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+/**
+ * DTO to represent the Ollama request schema.
+ */
+class OllamaRequest (
+    val model: String,
+
+    /**
+     * Contains the string of the prompt for the language model.
+     */
+    val prompt: String,
+
+    /**
+     * When false, the response is returned as a single object;
+     * when true, a stream of objects is returned.
+     */
+    val stream: Boolean = false
+) {
+
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaRequestVerb.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaRequestVerb.kt
new file mode 100644
index 0000000000..73b098a383
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaRequestVerb.kt
@@ -0,0 +1,6 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+enum class OllamaRequestVerb {
+    GET,
+    POST;
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaResponse.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaResponse.kt
new file mode 100644
index 0000000000..8ab638bfd4
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/data/ollama/OllamaResponse.kt
@@ -0,0 +1,37 @@
+package org.evomaster.core.languagemodel.data.ollama
+
+/**
+ * DTO to represent the Ollama response schema.
+ */
+class OllamaResponse {
+
+    /**
+     * Used model name
+     */
+    val model: String = ""
+
+    val created_at: String = ""
+
+    /**
+     * Contains the response string for non-stream output
+     */
+    val response: String = ""
+
+    val done: Boolean = false
+
+    val done_reason: String = ""
+
+    val context: List<Int> = emptyList()
+
+    val total_duration: Int = 0
+
+    val load_duration: Int = 0
+
+    val prompt_eval_count: Int = 0
+
+    val prompt_eval_duration: Int = 0
+
+    val eval_count: Int = 0
+
+    val eval_duration: Int = 0
+}
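OllamaRequest and OllamaResponse model the two halves of a non-streaming /api/generate exchange. A sketch of the round trip, reusing the connector's ObjectMapper pattern; the JSON follows Ollama's documented format, with illustrative values:

    import com.fasterxml.jackson.databind.ObjectMapper
    import org.evomaster.core.languagemodel.data.ollama.OllamaRequest
    import org.evomaster.core.languagemodel.data.ollama.OllamaResponse

    fun main() {
        val mapper = ObjectMapper()

        // Serialized request body, roughly: {"model":"llama3.2:latest","prompt":"Say hello","stream":false}
        val requestBody = mapper.writeValueAsString(
            OllamaRequest("llama3.2:latest", "Say hello", stream = false)
        )
        println(requestBody)

        // With stream=false, Ollama replies with one JSON object whose "response"
        // field carries the whole answer (illustrative payload):
        val reply = """{"model": "llama3.2:latest", "created_at": "2025-01-01T00:00:00Z",
            "response": "Hello!", "done": true, "done_reason": "stop"}"""
        println(mapper.readValue(reply, OllamaResponse::class.java).response)
    }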
diff --git a/core/src/main/kotlin/org/evomaster/core/languagemodel/service/LanguageModelConnector.kt b/core/src/main/kotlin/org/evomaster/core/languagemodel/service/LanguageModelConnector.kt
new file mode 100644
index 0000000000..36dbd5f240
--- /dev/null
+++ b/core/src/main/kotlin/org/evomaster/core/languagemodel/service/LanguageModelConnector.kt
@@ -0,0 +1,332 @@
+package org.evomaster.core.languagemodel.service
+
+import com.fasterxml.jackson.databind.ObjectMapper
+import com.google.inject.Inject
+import org.evomaster.core.EMConfig
+import org.evomaster.core.languagemodel.data.AnsweredPrompt
+import org.evomaster.core.languagemodel.data.ollama.OllamaModelResponse
+import org.evomaster.core.languagemodel.data.ollama.OllamaRequest
+import org.evomaster.core.languagemodel.data.ollama.OllamaResponse
+import org.evomaster.core.languagemodel.data.Prompt
+import org.evomaster.core.languagemodel.data.ollama.OllamaEndpoints
+import org.evomaster.core.languagemodel.data.ollama.OllamaRequestVerb
+import org.evomaster.core.logging.LoggingUtil
+import org.evomaster.core.remote.HttpClientFactory
+import org.slf4j.Logger
+import org.slf4j.LoggerFactory
+import java.util.UUID
+import java.util.concurrent.CompletableFuture
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.ExecutorService
+import java.util.concurrent.Executors
+import javax.annotation.PostConstruct
+import javax.annotation.PreDestroy
+import javax.ws.rs.client.Client
+import javax.ws.rs.client.Entity
+import javax.ws.rs.core.MediaType
+import javax.ws.rs.core.Response
+import kotlin.math.min
+
+/**
+ * A utility service designed to handle large language model server
+ * related functions.
+ *
+ * Designed to work with Ollama (version 0.7.0).
+ */
+class LanguageModelConnector {
+
+    @Inject
+    private lateinit var config: EMConfig
+
+    /**
+     * Holds the prompts made through [query], [queryAsync], and [addPrompt],
+     * together with their answers, as [AnsweredPrompt].
+     * Keys are the prompt ids of type [UUID]; values are of type [AnsweredPrompt].
+     */
+    private var prompts: MutableMap<UUID, AnsweredPrompt> = mutableMapOf()
+
+    private val objectMapper = ObjectMapper()
+
+    private var actualFixedThreadPool = 0
+
+    private lateinit var workerPool: ExecutorService
+
+    private val httpClients: ConcurrentHashMap<Long, Client> = ConcurrentHashMap()
+
+    private var isLanguageModelAvailable: Boolean = false
+
+    companion object {
+        private val log: Logger = LoggerFactory.getLogger(LanguageModelConnector::class.java)
+    }
+
+    @PostConstruct
+    fun init() {
+        if (config.languageModelConnector) {
+            LoggingUtil.Companion.getInfoLogger().info("Initializing {}", LanguageModelConnector::class.simpleName)
+
+            if (!this.isModelAvailable()) {
+                LoggingUtil.uniqueWarn(
+                    log, "${config.languageModelName} is not available in the provided " +
+                            "language model server URL: ${config.languageModelServerURL}. " +
+                            "Language Model Connector will be disabled."
+                )
+                return
+            } else {
+                LoggingUtil.getInfoLogger().info("Language model ${config.languageModelName} is available.")
+                isLanguageModelAvailable = true
+            }
+
+            actualFixedThreadPool = min(
+                config.languageModelConnectorNumberOfThreads,
+                Runtime.getRuntime().availableProcessors()
+            )
+            workerPool = Executors.newFixedThreadPool(
+                actualFixedThreadPool
+            )
+        }
+    }
+
+    @PreDestroy
+    private fun preDestroy() {
+        if (config.languageModelConnector) {
+            httpClients.values.forEach { it.close() }
+            workerPool.shutdown()
+            prompts.clear()
+        }
+    }
+
+    /**
+     * For testing purposes.
+     * @return number of [Client] in [httpClients]
+     */
+    fun getHttpClientCount() = httpClients.size
+
+    /**
+     * Makes the prompt request asynchronously, without blocking the calling thread.
+     * @return the [CompletableFuture] for the prompt.
+     */
+    fun queryAsync(prompt: String): CompletableFuture<AnsweredPrompt?> {
+        if (!config.languageModelConnector) {
+            throw IllegalStateException("Language Model Connector is disabled")
+        }
+
+        if (!isLanguageModelAvailable) {
+            throw IllegalStateException("Specified Language Model (${config.languageModelName}) is not available in the server.")
+        }
+
+        val promptDto = Prompt(getIdForPrompt(), prompt)
+
+        val client = httpClients.getOrPut(Thread.currentThread().id) {
+            getHttpClient()
+        }
+
+        val future = CompletableFuture.supplyAsync {
+            makeQueryWithClient(client, promptDto)
+        }
+
+        return future
+    }
+
+    /**
+     * Added prompt will be queried in a separate thread without
+     * blocking the main thread.
+     * [getAnswerByPrompt] and [getAnswerById] can be used to retrieve the
+     * answers.
+     * @return unique prompt identifier as [UUID]
+     * @throws [IllegalStateException] if the connector is disabled in [EMConfig]
+     */
+    fun addPrompt(prompt: String): UUID {
+        if (!config.languageModelConnector) {
+            throw IllegalStateException("Language Model Connector is disabled")
+        }
+
+        if (!isLanguageModelAvailable) {
+            throw IllegalStateException("Specified Language Model (${config.languageModelName}) is not available in the server.")
+        }
+
+        val promptId = getIdForPrompt()
+
+        val promptDto = Prompt(promptId, prompt)
+
+        val task = Runnable {
+            val id = Thread.currentThread().id
+            val httpClient = httpClients.getOrPut(id) {
+                getHttpClient()
+            }
+            makeQueryWithClient(httpClient, promptDto)
+        }
+
+        workerPool.submit(task)
+
+        return promptId
+    }
+
+    /**
+     * @return the answer for the prompt as [AnsweredPrompt], if it exists
+     * @return null if there is no answer for the prompt
+     */
+    fun getAnswerByPrompt(prompt: String): AnsweredPrompt? {
+        return prompts.filter { it.value.prompt.prompt == prompt && it.value.answer.isNotEmpty() }.values.firstOrNull()
+    }
+
+    /**
+     * @param id unique identifier returned when [addPrompt] was invoked.
+     * @return answer for the UUID of the prompt
+     */
+    fun getAnswerById(id: UUID): AnsweredPrompt? {
+        return prompts[id]
+    }
+
+    /**
+     * Queries the large language model server with a simple prompt.
+     * @return the answer as [AnsweredPrompt] from the language model server.
+     * @return null if the request failed.
+     */
+    fun query(prompt: String): AnsweredPrompt? {
+        if (!config.languageModelConnector) {
+            throw IllegalStateException("Language Model Connector is disabled")
+        }
+
+        if (!isLanguageModelAvailable) {
+            throw IllegalStateException("Specified Language Model (${config.languageModelName}) is not available in the server.")
+        }
+
+        val promptDto = Prompt(getIdForPrompt(), prompt)
+
+        val client = httpClients.getOrPut(Thread.currentThread().id) {
+            getHttpClient()
+        }
+
+        val response = makeQueryWithClient(client, promptDto)
+
+        return response
+    }
+
+    /**
+     * Intended to return a structured response for the prompt. Not yet implemented.
+     */
+    fun queryStructured(prompt: String) {
+        if (!config.languageModelConnector) {
+            throw IllegalStateException("Language Model Connector is disabled")
+        }
+
+        TODO("Requires more time to implement this.")
+    }
+
+    private fun isModelAvailable(): Boolean {
+        val url = OllamaEndpoints
+            .getTagEndpoint(config.languageModelServerURL)
+
+        val client = httpClients.getOrPut(Thread.currentThread().id) {
+            getHttpClient()
+        }
+
+        val response = callWithClient(client, url, OllamaRequestVerb.GET)
+
+        if (response != null && response.status == 200 && response.hasEntity()) {
+            val body = response.readEntity(String::class.java)
+
+            val bodyObject = objectMapper.readValue(
+                body,
+                OllamaModelResponse::class.java
+            )
+
+            if (bodyObject.models.any { it.name == config.languageModelName }) {
+                return true
+            }
+        }
+
+        return false
+    }
+
+    /**
+     * @return [AnsweredPrompt] if the request is successfully completed.
+     * @return null if the request failed.
+     */
+    private fun makeQueryWithClient(httpClient: Client, prompt: Prompt): AnsweredPrompt? {
+        val languageModelServerURL = OllamaEndpoints
+            .getGenerateEndpoint(config.languageModelServerURL)
+
+        val requestBody = objectMapper.writeValueAsString(
+            OllamaRequest(
+                config.languageModelName,
+                prompt.prompt,
+                false
+            )
+        )
+
+        val response = callWithClient(httpClient, languageModelServerURL, OllamaRequestVerb.POST, requestBody)
+
+        if (response != null && response.status == 200 && response.hasEntity()) {
+            val body = response.readEntity(String::class.java)
+            val bodyObject = objectMapper.readValue(
+                body,
+                OllamaResponse::class.java
+            )
+
+            val answer = AnsweredPrompt(
+                prompt,
+                bodyObject.response
+            )
+
+            prompts[prompt.id] = answer
+
+            return answer
+        }
+
+        return null
+    }
+
+    /**
+     * Private method to make the call to the large language model server.
+     * @return [Response] for the request.
+     *
+     * Note: If you are using Ollama as a server, please make sure to set the
+     * CORS origin for Ollama on the host operating system.
+     *
+     * Reference:
+     * https://medium.com/dcoderai/how-to-handle-cors-settings-in-ollama-a-comprehensive-guide-ee2a5a1beef0
+     */
+    private fun callWithClient(
+        httpClient: Client,
+        languageModelServerURL: String,
+        requestMethod: OllamaRequestVerb,
+        requestBody: String? = "",
+    ): Response? {
+        val bodyEntity = Entity.entity(requestBody, MediaType.APPLICATION_JSON_TYPE)
+
+        val builder = httpClient.target(languageModelServerURL)
+            .request("application/json")
+
+        val invocation = when (requestMethod) {
+            OllamaRequestVerb.GET -> builder.buildGet()
+            OllamaRequestVerb.POST -> builder.buildPost(bodyEntity)
+        }
+
+        val response = try {
+            invocation.invoke()
+        } catch (e: Exception) {
+            LoggingUtil.uniqueWarn(log, "Failed to connect to the language model server. Error: ${e.message}")
+
+            return null
+        }
+
+        return response
+    }
+
+    /**
+     * @return unique prompt identifier as [UUID]
+     */
+    private fun getIdForPrompt(): UUID {
+        return UUID.randomUUID()
+    }
+
+    /**
+     * @return new [Client] from [HttpClientFactory]
+     */
+    private fun getHttpClient(): Client {
+        return HttpClientFactory
+            .createTrustingJerseyClient(false, 60_000)
+    }
+
+}
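Putting the connector's public API together: a hedged usage sketch of its three query styles, assuming an injected instance with the connector enabled and the model available (as in the test further below); the prompt strings are placeholders.

    import org.evomaster.core.languagemodel.service.LanguageModelConnector

    fun demo(languageModelConnector: LanguageModelConnector) {
        // Blocking query: returns the AnsweredPrompt, or null if the request failed.
        val answer = languageModelConnector.query("Suggest a realistic username")
        println(answer?.answer)

        // Asynchronous query: returns a CompletableFuture<AnsweredPrompt?>.
        languageModelConnector.queryAsync("Suggest a valid email address")
            .thenAccept { println(it?.answer) }

        // Fire-and-forget: runs on the worker pool; fetch the answer later by id.
        val promptId = languageModelConnector.addPrompt("Suggest an invalid date string")
        val later = languageModelConnector.getAnswerById(promptId) // may still be null
    }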
diff --git a/core/src/main/kotlin/org/evomaster/core/problem/enterprise/service/EnterpriseModule.kt b/core/src/main/kotlin/org/evomaster/core/problem/enterprise/service/EnterpriseModule.kt
index 5bad3865a1..a6e4aa6097 100644
--- a/core/src/main/kotlin/org/evomaster/core/problem/enterprise/service/EnterpriseModule.kt
+++ b/core/src/main/kotlin/org/evomaster/core/problem/enterprise/service/EnterpriseModule.kt
@@ -1,6 +1,7 @@
 package org.evomaster.core.problem.enterprise.service
 
 import com.google.inject.AbstractModule
+import org.evomaster.core.languagemodel.service.LanguageModelConnector
 
 abstract class EnterpriseModule : AbstractModule() {
 
@@ -9,5 +10,8 @@ abstract class EnterpriseModule : AbstractModule() {
 
         bind(WFCReportWriter::class.java)
             .asEagerSingleton()
+
+        bind(LanguageModelConnector::class.java)
+            .asEagerSingleton()
     }
-}
\ No newline at end of file
+}
diff --git a/core/src/main/kotlin/org/evomaster/core/problem/rest/service/module/RestModule.kt b/core/src/main/kotlin/org/evomaster/core/problem/rest/service/module/RestModule.kt
index a143562389..92c9f1289d 100644
--- a/core/src/main/kotlin/org/evomaster/core/problem/rest/service/module/RestModule.kt
+++ b/core/src/main/kotlin/org/evomaster/core/problem/rest/service/module/RestModule.kt
@@ -1,6 +1,7 @@
 package org.evomaster.core.problem.rest.service.module
 
 import com.google.inject.TypeLiteral
+import org.evomaster.core.languagemodel.service.LanguageModelConnector
 import org.evomaster.core.problem.externalservice.httpws.service.HarvestActualHttpWsResponseHandler
 import org.evomaster.core.problem.externalservice.httpws.service.HttpWsExternalServiceHandler
 import org.evomaster.core.problem.rest.data.RestIndividual
@@ -18,61 +19,58 @@ import org.evomaster.core.search.service.mutator.Mutator
 import org.evomaster.core.search.service.mutator.StructureMutator
 
-
-class RestModule(private val bindRemote : Boolean = true) : RestBaseModule(){
+class RestModule(private val bindRemote: Boolean = true) : RestBaseModule() {
 
     override fun configure() {
         super.configure()
 
-        if (bindRemote){
+        if (bindRemote) {
            bind(RemoteController::class.java)
                .to(RemoteControllerImplementation::class.java)
                .asEagerSingleton()
        }
 
         bind(object : TypeLiteral<Sampler<RestIndividual>>() {})
-                .to(RestSampler::class.java)
-                .asEagerSingleton()
+            .to(RestSampler::class.java)
+            .asEagerSingleton()
 
         bind(object : TypeLiteral<Sampler<*>>() {})
-                .to(RestSampler::class.java)
-                .asEagerSingleton()
+            .to(RestSampler::class.java)
+            .asEagerSingleton()
 
         bind(AbstractRestSampler::class.java)
-                .to(RestSampler::class.java)
-                .asEagerSingleton()
+            .to(RestSampler::class.java)
+            .asEagerSingleton()
 
         bind(RestSampler::class.java)
-                .asEagerSingleton()
+            .asEagerSingleton()
 
         bind(object : TypeLiteral<FitnessFunction<RestIndividual>>() {})
-                .to(RestFitness::class.java)
-                .asEagerSingleton()
+            .to(RestFitness::class.java)
+            .asEagerSingleton()
 
         bind(object : TypeLiteral<AbstractRestFitness>() {})
             .to(ResourceRestFitness::class.java)
             .asEagerSingleton()
 
-
         bind(object : TypeLiteral<FitnessFunction<*>>() {})
-                .to(RestFitness::class.java)
-                .asEagerSingleton()
-
+            .to(RestFitness::class.java)
+            .asEagerSingleton()
 
         bind(object : TypeLiteral<Mutator<RestIndividual>>() {})
-                .to(object : TypeLiteral<StandardMutator<RestIndividual>>(){})
-                .asEagerSingleton()
+            .to(object : TypeLiteral<StandardMutator<RestIndividual>>() {})
+            .asEagerSingleton()
 
         bind(StructureMutator::class.java)
-                .to(RestStructureMutator::class.java)
-                .asEagerSingleton()
-
+            .to(RestStructureMutator::class.java)
+            .asEagerSingleton()
 
         bind(HttpWsExternalServiceHandler::class.java)
-                .asEagerSingleton()
+            .asEagerSingleton()
 
         bind(HarvestActualHttpWsResponseHandler::class.java)
             .asEagerSingleton()
     }
-}
\ No newline at end of file
+}
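Because EnterpriseModule binds LanguageModelConnector as an eager singleton, any Guice-managed service can receive it through field injection, the same way the connector itself receives EMConfig. A minimal sketch; the consumer class and method here are hypothetical:

    import com.google.inject.Inject
    import org.evomaster.core.languagemodel.service.LanguageModelConnector

    // Hypothetical consumer; mirrors how the connector itself receives EMConfig.
    class MySmartService {

        @Inject
        private lateinit var languageModelConnector: LanguageModelConnector

        fun suggestValue(): String? {
            return languageModelConnector.query("Suggest a realistic username")?.answer
        }
    }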
diff --git a/core/src/test/kotlin/org/evomaster/core/languagemodel/LanguageModelConnectorTest.kt b/core/src/test/kotlin/org/evomaster/core/languagemodel/LanguageModelConnectorTest.kt
new file mode 100644
index 0000000000..f46510b3fa
--- /dev/null
+++ b/core/src/test/kotlin/org/evomaster/core/languagemodel/LanguageModelConnectorTest.kt
@@ -0,0 +1,129 @@
+package org.evomaster.core.languagemodel
+
+import com.google.inject.Injector
+import com.netflix.governator.guice.LifecycleInjector
+import org.evomaster.ci.utils.CIUtils
+import org.evomaster.core.BaseModule
+import org.evomaster.core.EMConfig
+import org.evomaster.core.KGenericContainer
+import org.evomaster.core.languagemodel.service.LanguageModelConnector
+import org.junit.jupiter.api.Test
+import org.junit.jupiter.api.AfterAll
+import org.junit.jupiter.api.Assertions
+import org.junit.jupiter.api.BeforeAll
+import org.junit.jupiter.api.BeforeEach
+import org.testcontainers.containers.wait.strategy.LogMessageWaitStrategy
+
+class LanguageModelConnectorTest {
+
+    private lateinit var config: EMConfig
+
+    val injector: Injector = LifecycleInjector.builder()
+        .withModules(BaseModule())
+        .build().createInjector()
+
+    private lateinit var languageModelConnector: LanguageModelConnector
+
+    companion object {
+
+        /**
+         * This model was chosen based on two parameters, size and accuracy,
+         * after multiple manual trials with other smaller models.
+         * The model size is 815MB, so it might take a while to execute the test.
+         */
+        private const val LANGUAGE_MODEL_NAME: String = "gemma3:1b"
+
+        private const val PROMPT = "Is A is the first letter in english alphabet? say YES or NO"
+
+        private const val EXPECTED_ANSWER = "YES\n"
+
+        private val ollama = KGenericContainer("ollama/ollama:latest")
+            .withExposedPorts(11434)
+            .withEnv("OLLAMA_ORIGINS", "*") // This allows avoiding CORS filtering.
+
+        private var ollamaURL: String = ""
+
+        @BeforeAll
+        @JvmStatic
+        fun initClass() {
+
+            // This test takes time to download the LLM model inside
+            // docker. So it's wise to avoid running it on CI
+            // to reduce execution time.
+            CIUtils.skipIfOnGA()
+
+            ollama.start()
+
+            val host = ollama.host
+            val port = ollama.getMappedPort(11434)!!
+
+            ollamaURL = "http://$host:$port/"
+
+            ollama.execInContainer("ollama", "pull", LANGUAGE_MODEL_NAME)
+
+            ollama.waitingFor(
+                LogMessageWaitStrategy()
+                    .withRegEx(".*writing manifest \n success.*")
+                    .withTimes(5)
+            )
+        }
+
+        @AfterAll
+        @JvmStatic
+        fun cleanClass() {
+            ollama.stop()
+        }
+    }
+
+    @BeforeEach
+    fun prepareForTest() {
+        if (!ollama.isRunning) {
+            throw IllegalStateException("Ollama container is not running")
+        }
+
+        config = injector.getInstance(EMConfig::class.java)
+        config.languageModelConnector = true
+        // If languageModelName or languageModelServerURL is set to empty,
+        // an exception will be thrown.
+        config.languageModelName = LANGUAGE_MODEL_NAME
+        config.languageModelServerURL = ollamaURL
+
+        languageModelConnector = injector.getInstance(LanguageModelConnector::class.java)
+    }
+
+    @Test
+    fun testLocalOllamaConnection() {
+        // gemma3:1b returns with a newline character
+        val answer = languageModelConnector.query(PROMPT)
+
+        Assertions.assertEquals(EXPECTED_ANSWER, answer!!.answer)
+        // The HttpClient is used for two purposes when making a query:
+        // first, the connector checks for model availability;
+        // second, it makes the prompt query.
+        // This check validates that the same client is reused for the second query.
+        Assertions.assertEquals(1, languageModelConnector.getHttpClientCount())
+    }
+
+    @Test
+    fun testConcurrentRequests() {
+        // gemma3:1b returns with a newline character
+        val future = languageModelConnector.queryAsync(PROMPT)
+
+        future.thenAccept { result ->
+            Assertions.assertEquals(EXPECTED_ANSWER, result!!.answer)
+            Assertions.assertEquals(1, languageModelConnector.getHttpClientCount())
+        }.join() // wait for completion, so the assertions actually run
+    }
+
+    @Test
+    fun testQueriedPrompts() {
+        val promptId = languageModelConnector.addPrompt(PROMPT)
+
+        Thread.sleep(3000)
+
+        val result = languageModelConnector.getAnswerById(promptId)
+
+        Assertions.assertEquals(EXPECTED_ANSWER, result!!.answer)
+        Assertions.assertEquals(2, languageModelConnector.getHttpClientCount())
+    }
+}
diff --git a/docs/options.md b/docs/options.md
index faf452dcf8..c3b1dd516e 100644
--- a/docs/options.md
+++ b/docs/options.md
@@ -255,6 +255,10 @@ There are 3 types of options:
 |`httpOracles`| __Boolean__. Extra checks on HTTP properties in returned responses, used as automated oracles to detect faults. *Default value*: `false`.|
 |`initStructureMutationProbability`| __Double__. Probability of applying a mutation that can change the structure of test's initialization if it has. *Constraints*: `probability 0.0-1.0`. *Default value*: `0.0`.|
 |`instrumentMR_NET`| __Boolean__. Execute instrumentation for method replace with category NET. Note: this applies only for languages in which instrumentation is applied at runtime, like Java/Kotlin on the JVM. *Default value*: `false`.|
+|`languageModelConnector`| __Boolean__. Enable language model connector. *Default value*: `false`.|
+|`languageModelConnectorNumberOfThreads`| __Int__. Number of threads for language model connector. No more threads than the number of processors will be used. *Constraints*: `min=1.0`. *Default value*: `2`.|
+|`languageModelName`| __String__. Large-language model name as listed in Ollama. *Default value*: `llama3.2:latest`.|
+|`languageModelServerURL`| __String__. Large-language model external service URL. Default is set to Ollama local instance URL. *Default value*: `http://localhost:11434/`.|
 |`maxResourceSize`| __Int__. Specify a max size of resources in a test. 0 means the there is no specified restriction on a number of resources. *Constraints*: `min=0.0`. *Default value*: `0`.|
 |`maxSizeDataPool`| __Int__. How much data elements, per key, can be stored in the Data Pool. Once limit is reached, new old will replace old data. *Constraints*: `min=1.0`. *Default value*: `100`.|
 |`maxSizeOfExistingDataToSample`| __Int__. Specify a maximum number of existing data in the database to sample in a test when SQL handling is enabled. Note that a negative number means all existing data would be sampled. *Default value*: `-1`.|