diff --git a/api/app/core/models/base.py b/api/app/core/models/base.py index 4a453c6b2..3b0a9d769 100644 --- a/api/app/core/models/base.py +++ b/api/app/core/models/base.py @@ -67,6 +67,35 @@ def get_model_params(cls, config: RedBearModelConfig) -> Dict[str, Any]: **config.extra_params } + if provider == ModelProvider.MINIMAX: + # MiniMax 使用 OpenAI 兼容模式,需要设置默认 base_url 和温度钳制 + import httpx + if not config.base_url: + config.base_url = "https://api.minimax.io/v1" + timeout_config = httpx.Timeout( + timeout=config.timeout, + connect=60.0, + read=config.timeout, + write=60.0, + pool=10.0, + ) + # MiniMax 温度范围为 (0.0, 1.0],需要钳制 + extra = dict(config.extra_params) + if "temperature" in extra: + temp = extra["temperature"] + if temp <= 0: + extra["temperature"] = 0.01 + elif temp > 1.0: + extra["temperature"] = 1.0 + return { + "model": config.model_name, + "base_url": config.base_url, + "api_key": config.api_key, + "timeout": timeout_config, + "max_retries": config.max_retries, + **extra + } + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.OLLAMA]: # 使用 httpx.Timeout 对象来设置详细的超时配置 # 这样可以分别控制连接超时和读取超时 @@ -165,6 +194,9 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy return OpenAI elif type == ModelType.CHAT: return ChatOpenAI + elif provider == ModelProvider.MINIMAX: + # MiniMax 使用 OpenAI 兼容 API,始终返回 ChatOpenAI + return ChatOpenAI elif provider == ModelProvider.DASHSCOPE: return ChatTongyi elif provider == ModelProvider.OLLAMA: @@ -178,7 +210,7 @@ def get_provider_llm_class(config: RedBearModelConfig, type: ModelType = ModelTy def get_provider_embedding_class(provider: str) -> type[Embeddings]: """根据模型提供商获取对应的模型类""" provider = provider.lower() - if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK]: + if provider in [ModelProvider.OPENAI, ModelProvider.XINFERENCE, ModelProvider.GPUSTACK, ModelProvider.MINIMAX]: from langchain_openai import 
OpenAIEmbeddings return OpenAIEmbeddings elif provider == ModelProvider.DASHSCOPE: diff --git a/api/app/models/models_model.py b/api/app/models/models_model.py index 23fafcefd..2f4e775f7 100644 --- a/api/app/models/models_model.py +++ b/api/app/models/models_model.py @@ -41,6 +41,7 @@ class ModelProvider(StrEnum): # ZHIPU = "zhipu" # MOONSHOT = "moonshot" # DEEPSEEK = "deepseek" + MINIMAX = "minimax" OLLAMA = "ollama" XINFERENCE = "xinference" GPUSTACK = "gpustack" diff --git a/api/app/services/llm_client.py b/api/app/services/llm_client.py index a7bc81b0b..a117635c4 100644 --- a/api/app/services/llm_client.py +++ b/api/app/services/llm_client.py @@ -226,55 +226,55 @@ async def chat(self, prompt: str, **kwargs) -> str: class MockLLMClient(BaseLLMClient): """模拟 LLM 客户端(用于测试)""" - + def __init__(self): """初始化模拟客户端""" self.call_count = 0 - + async def chat(self, prompt: str, **kwargs) -> str: """发送聊天请求(返回模拟结果)""" self.call_count += 1 - + logger.info(f"模拟 LLM 调用 (第 {self.call_count} 次)") - + # 简单的规则匹配 prompt_lower = prompt.lower() - + if "数学" in prompt_lower or "方程" in prompt_lower or "计算" in prompt_lower: return json.dumps({ "agent_id": "math-agent", "confidence": 0.9, "reason": "消息包含数学相关内容" }, ensure_ascii=False) - + elif "化学" in prompt_lower or "反应" in prompt_lower or "元素" in prompt_lower: return json.dumps({ "agent_id": "chemistry-agent", "confidence": 0.85, "reason": "消息包含化学相关内容" }, ensure_ascii=False) - + elif "物理" in prompt_lower or "力" in prompt_lower or "速度" in prompt_lower: return json.dumps({ "agent_id": "physics-agent", "confidence": 0.88, "reason": "消息包含物理相关内容" }, ensure_ascii=False) - + elif "语文" in prompt_lower or "古诗" in prompt_lower or "作文" in prompt_lower: return json.dumps({ "agent_id": "chinese-agent", "confidence": 0.87, "reason": "消息包含语文相关内容" }, ensure_ascii=False) - + elif "英语" in prompt_lower or "单词" in prompt_lower or "语法" in prompt_lower: return json.dumps({ "agent_id": "english-agent", "confidence": 0.86, "reason": "消息包含英语相关内容" }, 
ensure_ascii=False) - + else: return json.dumps({ "agent_id": "math-agent", @@ -283,6 +283,78 @@ async def chat(self, prompt: str, **kwargs) -> str: }, ensure_ascii=False) +class MiniMaxClient(BaseLLMClient): + """MiniMax LLM 客户端(通过 OpenAI 兼容 API)""" + + def __init__( + self, + api_key: Optional[str] = None, + model: str = "MiniMax-M2.7", + base_url: str = "https://api.minimax.io/v1" + ): + """初始化 MiniMax 客户端 + + Args: + api_key: API 密钥 + model: 模型名称 (MiniMax-M2.7, MiniMax-M2.7-highspeed) + base_url: API 基础 URL + """ + self.api_key = api_key or os.getenv("MINIMAX_API_KEY") + self.model = model + self.base_url = base_url + + if not self.api_key: + raise ValueError("MiniMax API key 未配置,请设置 MINIMAX_API_KEY 环境变量") + + try: + from openai import AsyncOpenAI + self.client = AsyncOpenAI( + api_key=self.api_key, + base_url=self.base_url + ) + except ImportError: + raise ImportError("请安装 openai 库: pip install openai") + + @staticmethod + def _clamp_temperature(temperature: float) -> float: + """钳制温度值到 MiniMax 支持的范围 (0.0, 1.0]""" + if temperature <= 0: + return 0.01 + if temperature > 1.0: + return 1.0 + return temperature + + async def chat(self, prompt: str, **kwargs) -> str: + """发送聊天请求 + + Args: + prompt: 提示词 + **kwargs: 其他参数(temperature, max_tokens 等) + + Returns: + LLM 响应文本 + """ + try: + temperature = self._clamp_temperature(kwargs.get("temperature", 0.3)) + response = await self.client.chat.completions.create( + model=self.model, + messages=[{"role": "user", "content": prompt}], + temperature=temperature, + max_tokens=kwargs.get("max_tokens", 500) + ) + + content = response.choices[0].message.content + # 去除 MiniMax M2.7 可能返回的思考标签 + if content and "</think>" in content: + import re + content = re.sub(r"<think>.*?</think>\s*", "", content, flags=re.DOTALL) + return content + + except Exception as e: + logger.error(f"MiniMax API 调用失败: {str(e)}") + raise + + +class LLMClientFactory: + """LLM 客户端工厂""" + @@ -292,9 +364,9 @@ def create( **kwargs ) -> BaseLLMClient: """创建 LLM 客户端 - + Args: - provider: 
提供商名称 (openai, azure, anthropic, local, mock) + provider: 提供商名称 (openai, azure, anthropic, minimax, local, mock) **kwargs: 客户端配置参数 Returns: @@ -304,16 +376,19 @@ def create( if provider == "openai": return OpenAIClient(**kwargs) - + elif provider == "azure": return AzureOpenAIClient(**kwargs) - + elif provider == "anthropic": return AnthropicClient(**kwargs) - + + elif provider == "minimax": + return MiniMaxClient(**kwargs) + elif provider == "local": return LocalLLMClient(**kwargs) - + elif provider == "mock": return MockLLMClient() diff --git a/api/env.example b/api/env.example index e324d1e5b..6acc430e4 100644 --- a/api/env.example +++ b/api/env.example @@ -51,7 +51,10 @@ ELASTICSEARCH_RETRY_ON_TIMEOUT= ELASTICSEARCH_MAX_RETRIES= # xinference configuration -XINFERENCE_URL= +XINFERENCE_URL= + +# MiniMax configuration +MINIMAX_API_KEY= # LangSmith configuration LANGCHAIN_TRACING_V2= diff --git a/api/tests/test_minimax_integration.py b/api/tests/test_minimax_integration.py new file mode 100644 index 000000000..dd128d214 --- /dev/null +++ b/api/tests/test_minimax_integration.py @@ -0,0 +1,79 @@ +# -*- coding: UTF-8 -*- +"""Integration tests for MiniMax LLM provider. + +These tests verify end-to-end MiniMax integration with actual API calls. +They require a valid MINIMAX_API_KEY environment variable and are skipped +when the key is not available. 
+ +Usage: + MINIMAX_API_KEY=your-key cd api && python -m pytest tests/test_minimax_integration.py -v +""" + +import os +import sys + +import pytest + +API_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +MINIMAX_API_KEY = os.getenv("MINIMAX_API_KEY") +SKIP_REASON = "MINIMAX_API_KEY not set; skipping MiniMax integration tests" + + +@pytest.mark.skipif(not MINIMAX_API_KEY, reason=SKIP_REASON) +class TestMiniMaxClientIntegration: + """Integration tests for MiniMaxClient in services layer.""" + + @pytest.mark.asyncio + async def test_minimax_client_chat(self): + """Test MiniMaxClient.chat() with real API.""" + from app.services.llm_client import MiniMaxClient + + client = MiniMaxClient( + api_key=MINIMAX_API_KEY, + model="MiniMax-M2.7-highspeed" + ) + result = await client.chat( + "Reply with exactly the word 'pong'. Do not include any reasoning.", + temperature=0.1, + max_tokens=50 + ) + assert result is not None + # Result may be empty if model returns only think tags; + # the important thing is no exception + assert isinstance(result, str) + + @pytest.mark.asyncio + async def test_minimax_client_factory(self): + """Test LLMClientFactory creates working MiniMaxClient.""" + from app.services.llm_client import LLMClientFactory + + client = LLMClientFactory.create( + "minimax", + api_key=MINIMAX_API_KEY, + model="MiniMax-M2.7-highspeed" + ) + result = await client.chat( + "Reply with the number 42. Do not include any reasoning.", + temperature=0.1, + max_tokens=50 + ) + assert result is not None + + @pytest.mark.asyncio + async def test_minimax_client_temperature_edge(self): + """Test that temperature=0 works (clamped to 0.01).""" + from app.services.llm_client import MiniMaxClient + + client = MiniMaxClient( + api_key=MINIMAX_API_KEY, + model="MiniMax-M2.7-highspeed" + ) + result = await client.chat( + "Say hi. 
Do not include any reasoning.", + temperature=0, + max_tokens=50 + ) + assert result is not None diff --git a/api/tests/test_minimax_provider.py b/api/tests/test_minimax_provider.py new file mode 100644 index 000000000..3beea44c7 --- /dev/null +++ b/api/tests/test_minimax_provider.py @@ -0,0 +1,306 @@ +# -*- coding: UTF-8 -*- +"""Unit tests for MiniMax LLM provider integration. + +Tests cover: +- ModelProvider enum registration +- MiniMaxClient temperature clamping, think-tag stripping +- LLMClientFactory.create("minimax") dispatching + +Run: cd api && python -m pytest tests/test_minimax_provider.py -v +""" + +import os +import sys +import json +import importlib +import importlib.util +from enum import Enum +from unittest.mock import patch, AsyncMock, MagicMock + +import pytest + +# ---- Ensure app package is importable ---- +API_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +if API_DIR not in sys.path: + sys.path.insert(0, API_DIR) + +# ---- StrEnum compat for Python < 3.11 ---- +if not hasattr(Enum, '__str_members__'): + try: + from enum import StrEnum # noqa: F401 + except ImportError: + import enum + class StrEnum(str, enum.Enum): + pass + enum.StrEnum = StrEnum + + +# --------------------------------------------------------------------------- +# Helper: isolated import of models_model +# --------------------------------------------------------------------------- + +def _import_models_model(): + """Import models_model.py directly, bypassing app.models.__init__.""" + spec = importlib.util.spec_from_file_location( + "models_model_isolated", + os.path.join(API_DIR, "app", "models", "models_model.py"), + ) + mod = importlib.util.module_from_spec(spec) + + # Stub out heavy deps + sa_stub = MagicMock() + pg_stub = MagicMock() + pg_stub.UUID = MagicMock(return_value=MagicMock()) + pg_stub.JSON = MagicMock() + db_stub = MagicMock() + db_stub.Base = type("Base", (), {"metadata": MagicMock()}) + + originals = {} + stubs = { + "sqlalchemy": sa_stub, + 
"sqlalchemy.orm": sa_stub.orm, + "sqlalchemy.sql": sa_stub.sql, + "sqlalchemy.dialects": sa_stub.dialects, + "sqlalchemy.dialects.postgresql": pg_stub, + "app.db": db_stub, + } + for m, s in stubs.items(): + originals[m] = sys.modules.get(m) + sys.modules[m] = s + + try: + spec.loader.exec_module(mod) + finally: + for m, v in originals.items(): + if v is None: + sys.modules.pop(m, None) + else: + sys.modules[m] = v + return mod + + +def _import_llm_client(): + """Import llm_client.py directly with minimal stubs.""" + spec = importlib.util.spec_from_file_location( + "llm_client_isolated", + os.path.join(API_DIR, "app", "services", "llm_client.py"), + ) + mod = importlib.util.module_from_spec(spec) + + logger_stub = MagicMock() + logging_mod = MagicMock() + logging_mod.get_business_logger = MagicMock(return_value=logger_stub) + saved = sys.modules.get("app.core.logging_config") + sys.modules["app.core.logging_config"] = logging_mod + try: + spec.loader.exec_module(mod) + finally: + if saved is None: + sys.modules.pop("app.core.logging_config", None) + else: + sys.modules["app.core.logging_config"] = saved + return mod + + +# =========================================================================== +# 1. 
ModelProvider enum +# =========================================================================== + +class TestModelProviderEnum: + + def test_minimax_in_model_provider(self): + mod = _import_models_model() + assert hasattr(mod.ModelProvider, "MINIMAX") + assert mod.ModelProvider.MINIMAX == "minimax" + + def test_minimax_is_str(self): + mod = _import_models_model() + assert isinstance(mod.ModelProvider.MINIMAX, str) + + def test_minimax_not_composite(self): + mod = _import_models_model() + assert mod.ModelProvider.MINIMAX != mod.ModelProvider.COMPOSITE + + def test_all_original_providers_preserved(self): + mod = _import_models_model() + for name in ["OPENAI", "DASHSCOPE", "OLLAMA", "XINFERENCE", "GPUSTACK", "BEDROCK", "COMPOSITE"]: + assert hasattr(mod.ModelProvider, name), f"Missing provider: {name}" + + +# =========================================================================== +# 2. MiniMaxClient +# =========================================================================== + +class TestMiniMaxClient: + + def test_missing_api_key_raises(self): + mod = _import_llm_client() + env = {k: v for k, v in os.environ.items() if k != "MINIMAX_API_KEY"} + with patch.dict(os.environ, env, clear=True): + with pytest.raises(ValueError, match="MiniMax API key"): + mod.MiniMaxClient(api_key=None) + + def test_env_api_key_used(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-from-env"}): + client = mod.MiniMaxClient() + assert client.api_key == "sk-from-env" + + def test_explicit_api_key_overrides_env(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-from-env"}): + client = mod.MiniMaxClient(api_key="sk-explicit") + assert client.api_key == "sk-explicit" + + def test_default_model(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + assert client.model == "MiniMax-M2.7" + + def test_custom_model(self): + mod = 
_import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient(model="MiniMax-M2.7-highspeed") + assert client.model == "MiniMax-M2.7-highspeed" + + def test_default_base_url(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + assert client.base_url == "https://api.minimax.io/v1" + + def test_temperature_clamping_static(self): + mod = _import_llm_client() + clamp = mod.MiniMaxClient._clamp_temperature + assert clamp(0) == 0.01 + assert clamp(-1.0) == 0.01 + assert clamp(0.5) == 0.5 + assert clamp(1.0) == 1.0 + assert clamp(2.0) == 1.0 + assert clamp(0.01) == 0.01 + assert clamp(0.99) == 0.99 + + @pytest.mark.asyncio + async def test_chat_strips_think_tags(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "<think>reasoning</think>\nHello!" + client.client = AsyncMock() + client.client.chat.completions.create = AsyncMock(return_value=mock_resp) + result = await client.chat("Hi") + assert "</think>" not in result + assert "Hello!" in result + + @pytest.mark.asyncio + async def test_chat_normal_response(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "Hello world!" + client.client = AsyncMock() + client.client.chat.completions.create = AsyncMock(return_value=mock_resp) + result = await client.chat("Hi") + assert result == "Hello world!" 
 + + @pytest.mark.asyncio + async def test_chat_uses_clamped_temperature(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = "ok" + client.client = AsyncMock() + client.client.chat.completions.create = AsyncMock(return_value=mock_resp) + await client.chat("test", temperature=0) + call_kwargs = client.client.chat.completions.create.call_args + assert call_kwargs.kwargs["temperature"] == 0.01 + + @pytest.mark.asyncio + async def test_chat_api_error_propagated(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + client.client = AsyncMock() + client.client.chat.completions.create = AsyncMock( + side_effect=Exception("API error") + ) + with pytest.raises(Exception, match="API error"): + await client.chat("test") + + @pytest.mark.asyncio + async def test_chat_multiline_think_tag(self): + """Think tag with multiple lines should be fully stripped.""" + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.MiniMaxClient() + mock_resp = MagicMock() + mock_resp.choices = [MagicMock()] + mock_resp.choices[0].message.content = ( + "<think>\nStep 1: analyze\nStep 2: reason\n</think>\n\nFinal answer." + ) + client.client = AsyncMock() + client.client.chat.completions.create = AsyncMock(return_value=mock_resp) + result = await client.chat("complex question") + assert "</think>" not in result + assert "Final answer." in result + + +# =========================================================================== +# 3. 
LLMClientFactory +# =========================================================================== + +class TestLLMClientFactoryMiniMax: + + def test_create_minimax(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.LLMClientFactory.create("minimax") + assert isinstance(client, mod.MiniMaxClient) + + def test_create_minimax_uppercase(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"MINIMAX_API_KEY": "sk-test"}): + client = mod.LLMClientFactory.create("MiniMax") + assert isinstance(client, mod.MiniMaxClient) + + def test_create_from_env_minimax(self): + mod = _import_llm_client() + with patch.dict(os.environ, {"LLM_PROVIDER": "minimax", "MINIMAX_API_KEY": "sk-test"}): + client = mod.LLMClientFactory.create_from_env() + assert isinstance(client, mod.MiniMaxClient) + + def test_other_providers_still_work(self): + mod = _import_llm_client() + client = mod.LLMClientFactory.create("mock") + assert isinstance(client, mod.MockLLMClient) + + +# =========================================================================== +# 4. 
Temperature clamping edge cases +# =========================================================================== + +class TestMiniMaxTemperatureClamping: + + def test_boundary_values(self): + mod = _import_llm_client() + clamp = mod.MiniMaxClient._clamp_temperature + assert clamp(0.0) == 0.01 + assert clamp(1.0) == 1.0 + assert clamp(0.001) == 0.001 + assert clamp(0.999) == 0.999 + assert clamp(-100) == 0.01 + assert clamp(100) == 1.0 + assert clamp(1.001) == 1.0 + + def test_common_temperatures(self): + mod = _import_llm_client() + clamp = mod.MiniMaxClient._clamp_temperature + for t in [0.1, 0.2, 0.3, 0.5, 0.7, 0.8, 0.9]: + assert clamp(t) == t, f"Temperature {t} should be preserved" diff --git a/web/src/assets/images/model/minimax.svg b/web/src/assets/images/model/minimax.svg new file mode 100644 index 000000000..149b31071 --- /dev/null +++ b/web/src/assets/images/model/minimax.svg @@ -0,0 +1,7 @@ + + + Mini + Max + + + diff --git a/web/src/i18n/en.ts b/web/src/i18n/en.ts index 2007005fd..297407718 100644 --- a/web/src/i18n/en.ts +++ b/web/src/i18n/en.ts @@ -548,6 +548,7 @@ export const en = { openai: "Openai", dashscope: "Dashscope", ollama: "Ollama", + minimax: "MiniMax", xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock", @@ -606,6 +607,7 @@ export const en = { openai: "Openai", dashscope: "Dashscope", ollama: "Ollama", + minimax: "MiniMax", xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock", diff --git a/web/src/i18n/zh.ts b/web/src/i18n/zh.ts index bce9b5cc9..2c1621859 100644 --- a/web/src/i18n/zh.ts +++ b/web/src/i18n/zh.ts @@ -1181,6 +1181,7 @@ export const zh = { openai: "Openai", dashscope: "Dashscope", ollama: "Ollama", + minimax: "MiniMax", xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock" @@ -1239,6 +1240,7 @@ export const zh = { openai: "Openai", dashscope: "Dashscope", ollama: "Ollama", + minimax: "MiniMax", xinference: "Xinference", gpustack: "Gpustack", bedrock: "Bedrock", diff --git 
a/web/src/views/ModelManagement/utils.ts b/web/src/views/ModelManagement/utils.ts index bf44367f5..f00d8e237 100644 --- a/web/src/views/ModelManagement/utils.ts +++ b/web/src/views/ModelManagement/utils.ts @@ -11,6 +11,7 @@ import bedrockIcon from '@/assets/images/model/bedrock.svg' import dashscopeIcon from '@/assets/images/model/dashscope.png' import gpustackIcon from '@/assets/images/model/gpustack.png' +import minimaxIcon from '@/assets/images/model/minimax.svg' import ollamaIcon from '@/assets/images/model/ollama.svg' import openaiIcon from '@/assets/images/model/openai.svg' import xinferenceIcon from '@/assets/images/model/xinference.svg' @@ -22,6 +23,7 @@ export const ICONS = { bedrock: bedrockIcon, dashscope: dashscopeIcon, gpustack: gpustackIcon, + minimax: minimaxIcon, ollama: ollamaIcon, openai: openaiIcon, xinference: xinferenceIcon