diff --git a/bot/tests/test_openapi_auth.py b/bot/tests/test_openapi_auth.py new file mode 100644 index 000000000..d01e53b8a --- /dev/null +++ b/bot/tests/test_openapi_auth.py @@ -0,0 +1,130 @@ +# Copyright (c) 2026 Beijing Volcano Engine Technology Co., Ltd. +# SPDX-License-Identifier: AGPL-3.0 +"""Regression tests for OpenAPI HTTP auth requirements.""" + +import tempfile +from pathlib import Path + +import pytest +from fastapi import FastAPI +from fastapi.testclient import TestClient +from vikingbot.bus.queue import MessageBus +from vikingbot.channels.openapi import OpenAPIChannel, OpenAPIChannelConfig +from vikingbot.channels.openapi_models import ChatResponse +from vikingbot.config.schema import BotChannelConfig + + +@pytest.fixture +def temp_workspace(): + with tempfile.TemporaryDirectory() as tmpdir: + yield Path(tmpdir) + + +@pytest.fixture +def message_bus(): + return MessageBus() + + +def _make_client(channel: OpenAPIChannel) -> TestClient: + app = FastAPI() + app.include_router(channel.get_router(), prefix="/bot/v1") + return TestClient(app) + + +class TestOpenAPIAuth: + def test_health_remains_available_without_api_key(self, message_bus, temp_workspace): + channel = OpenAPIChannel( + OpenAPIChannelConfig(api_key=""), + message_bus, + workspace_path=temp_workspace, + ) + client = _make_client(channel) + + response = client.get("/bot/v1/health") + + assert response.status_code == 200 + + def test_chat_rejects_requests_when_api_key_not_configured(self, message_bus, temp_workspace): + channel = OpenAPIChannel( + OpenAPIChannelConfig(api_key=""), + message_bus, + workspace_path=temp_workspace, + ) + client = _make_client(channel) + + response = client.post("/bot/v1/chat", json={"message": "hello"}) + + assert response.status_code == 503 + assert response.json()["detail"] == "OpenAPI channel API key is not configured" + + def test_chat_accepts_request_with_configured_valid_api_key( + self, message_bus, temp_workspace, monkeypatch + ): + channel = OpenAPIChannel( + OpenAPIChannelConfig(api_key="secret123"), + message_bus, + workspace_path=temp_workspace, + ) + + async def fake_handle_chat(request): + return ChatResponse( + session_id=request.session_id or "default", message="ok", events=None + ) + + monkeypatch.setattr(channel, "_handle_chat", fake_handle_chat) + client = _make_client(channel) + + response = client.post( + "/bot/v1/chat", + headers={"X-API-Key": "secret123"}, + json={"message": "hello"}, + ) + + assert response.status_code == 200 + assert response.json()["message"] == "ok" + + def test_bot_channel_rejects_requests_when_channel_api_key_not_configured( + self, message_bus, temp_workspace + ): + channel = OpenAPIChannel( + OpenAPIChannelConfig(api_key="gateway-secret"), + message_bus, + workspace_path=temp_workspace, + ) + channel._bot_configs["alpha"] = BotChannelConfig(id="alpha", api_key="") + client = _make_client(channel) + + response = client.post( + "/bot/v1/chat/channel", + json={"message": "hello", "channel_id": "alpha"}, + ) + + assert response.status_code == 503 + assert response.json()["detail"] == "Bot channel 'alpha' API key is not configured" + + def test_bot_channel_accepts_request_with_valid_api_key( + self, message_bus, temp_workspace, monkeypatch + ): + channel = OpenAPIChannel( + OpenAPIChannelConfig(api_key="gateway-secret"), + message_bus, + workspace_path=temp_workspace, + ) + channel._bot_configs["alpha"] = BotChannelConfig(id="alpha", api_key="bot-secret") + + async def fake_handle_bot_chat(channel_id, request): + return ChatResponse( + session_id=request.session_id or "default", message=f"ok:{channel_id}" + ) + + monkeypatch.setattr(channel, "_handle_bot_chat", fake_handle_bot_chat) + client = _make_client(channel) + + response = client.post( + "/bot/v1/chat/channel", + headers={"X-API-Key": "bot-secret"}, + json={"message": "hello", "channel_id": "alpha"}, + ) + + assert response.status_code == 200 + assert response.json()["message"] == "ok:alpha" diff --git a/bot/vikingbot/channels/openapi.py b/bot/vikingbot/channels/openapi.py index c97238a6e..16d3b305b 100644 --- a/bot/vikingbot/channels/openapi.py +++ b/bot/vikingbot/channels/openapi.py @@ -28,9 +28,9 @@ ) from vikingbot.config.schema import ( BaseChannelConfig, + BotChannelConfig, Config, SessionKey, - BotChannelConfig, ) @@ -179,10 +179,15 @@ async def send(self, msg: OutboundMessage) -> None: pending = self._bot_pending[channel_id].get(session_id) if not pending: - logger.warning(f"No pending request for BotChannel {channel_id} session: {session_id}") + logger.warning( + f"No pending request for BotChannel {channel_id} session: {session_id}" + ) return - if msg.event_type == OutboundEventType.RESPONSE or msg.event_type == OutboundEventType.NO_REPLY: + if ( + msg.event_type == OutboundEventType.RESPONSE + or msg.event_type == OutboundEventType.NO_REPLY + ): await pending.add_event("response", msg.content or "") pending.set_final(msg.content or "") await pending.close_stream() @@ -226,9 +231,12 @@ def _create_router(self) -> APIRouter: channel = self # Capture for closures async def verify_api_key(x_api_key: Optional[str] = Header(None)) -> bool: - """Verify API key if configured.""" + """Verify API key for privileged HTTP chat/session routes.""" if not channel.config.api_key: - return True # No auth required + raise HTTPException( + status_code=503, + detail="OpenAPI channel API key is not configured", + ) if not x_api_key: raise HTTPException(status_code=401, detail="X-API-Key header required") # Use secrets.compare_digest for timing-safe comparison @@ -332,10 +340,25 @@ async def delete_session( # ========== Bot Channel Routes ========== - async def verify_bot_channel_api_key(x_api_key: Optional[str] = Header(None)) -> Optional[str]: - """Verify API key and return it if valid.""" + async def verify_bot_channel_api_key( + x_api_key: Optional[str] = Header(None), + ) -> Optional[str]: + """Capture the raw bot-channel API key header for per-channel verification.""" return x_api_key + def ensure_bot_channel_api_key(channel_id: str, x_api_key: Optional[str]) -> None: + """Require an explicit per-channel API key for privileged bot HTTP routes.""" + bot_config = channel._bot_configs[channel_id] + if not bot_config.api_key: + raise HTTPException( + status_code=503, + detail=f"Bot channel '{channel_id}' API key is not configured", + ) + if not x_api_key: + raise HTTPException(status_code=401, detail="X-API-Key header required") + if not secrets.compare_digest(x_api_key, bot_config.api_key): + raise HTTPException(status_code=403, detail="Invalid API key") + @router.post("/chat/channel", response_model=ChatResponse) async def chat_channel( request: ChatRequest, @@ -348,13 +371,7 @@ async def chat_channel( if channel_id not in channel._bot_configs: raise HTTPException(status_code=404, detail=f"Channel '{channel_id}' not found") - # Verify API key for the specific channel - bot_config = channel._bot_configs[channel_id] - if bot_config.api_key: - if not x_api_key: - raise HTTPException(status_code=401, detail="X-API-Key header required") - if not secrets.compare_digest(x_api_key, bot_config.api_key): - raise HTTPException(status_code=403, detail="Invalid API key") + ensure_bot_channel_api_key(channel_id, x_api_key) return await channel._handle_bot_chat(channel_id, request) @@ -370,13 +387,7 @@ async def chat_channel_stream( if channel_id not in channel._bot_configs: raise HTTPException(status_code=404, detail=f"Channel '{channel_id}' not found") - # Verify API key for the specific channel - bot_config = channel._bot_configs[channel_id] - if bot_config.api_key: - if not x_api_key: - raise HTTPException(status_code=401, detail="X-API-Key header required") - if not secrets.compare_digest(x_api_key, bot_config.api_key): - raise HTTPException(status_code=403, detail="Invalid API key") + ensure_bot_channel_api_key(channel_id, x_api_key) if not request.stream: request.stream = True @@ -609,7 +620,9 @@ async def _handle_bot_chat(self, channel_id: str, request: ChatRequest) -> ChatR if channel_id in self._bot_pending: self._bot_pending[channel_id].pop(session_id, None) - async def _handle_bot_chat_stream(self, channel_id: str, request: ChatRequest) -> StreamingResponse: + async def _handle_bot_chat_stream( + self, channel_id: str, request: ChatRequest + ) -> StreamingResponse: """Handle a BotChannel streaming chat request.""" session_id = request.session_id or str(uuid.uuid4()) user_id = request.user_id or "anonymous" @@ -727,4 +740,4 @@ def get_openapi_router(bus: MessageBus, config: Config) -> APIRouter: ) logger.info(f"Subscribed to bot_api channel: {channel_id}") - return channel.get_router() \ No newline at end of file + return channel.get_router() diff --git a/bot/vikingbot/cli/commands.py b/bot/vikingbot/cli/commands.py index fc7c979ee..907f3263e 100644 --- a/bot/vikingbot/cli/commands.py +++ b/bot/vikingbot/cli/commands.py @@ -438,7 +438,7 @@ def prepare_channel( openapi_config = OpenAPIChannelConfig( enabled=True, port=openapi_port, - api_key="", # No auth required by default + api_key="", ) openapi_channel = OpenAPIChannel( openapi_config, @@ -447,7 +447,9 @@ def prepare_channel( global_config=config, ) channels.add_channel(openapi_channel) - logger.info(f"OpenAPI channel enabled on port {openapi_port}") + logger.info( + f"OpenAPI channel enabled on port {openapi_port}; configure an API key before using HTTP chat endpoints" + ) if channels.enabled_channels: console.print(f"[green]✓[/green] Channels enabled: {', '.join(channels.enabled_channels)}") diff --git a/bot/vikingbot/config/schema.py b/bot/vikingbot/config/schema.py index 438f1dc10..a4fed32af 100644 --- a/bot/vikingbot/config/schema.py +++ b/bot/vikingbot/config/schema.py @@ -41,8 +41,10 @@ class SandboxMode(str, Enum): SHARED = "shared" PER_CHANNEL = "per-channel" + class AgentMemoryMode(str, Enum): """Agent memory mode enumeration.""" + PER_SESSION = "per-session" SHARED = "shared" PER_CHANNEL = "per-channel" @@ -50,6 +52,7 @@ class AgentMemoryMode(str, Enum): class BotMode(str, Enum): """Bot running mode enumeration.""" + NORMAL = "normal" READONLY = "readonly" DEBUG = "debug" @@ -119,7 +122,10 @@ class FeishuChannelConfig(BaseChannelConfig): verification_token: str = "" allow_from: list[str] = Field(default_factory=list) allow_cmd_from: list[str] = Field(default_factory=list) ## 允许执行命令的Feishu用户ID列表 - thread_require_mention: bool = Field(default=True, description="话题群模式下是否需要@才响应:默认True=所有消息必须@才响应;False=新话题首条消息无需@,后续回复必须@") + thread_require_mention: bool = Field( + default=True, + description="话题群模式下是否需要@才响应:默认True=所有消息必须@才响应;False=新话题首条消息无需@,后续回复必须@", + ) def channel_id(self) -> str: # Use app_id directly as the ID @@ -266,7 +272,7 @@ class OpenAPIChannelConfig(BaseChannelConfig): type: ChannelType = ChannelType.OPENAPI enabled: bool = True - api_key: str = "" # If empty, no auth required + api_key: str = "" # Empty disables privileged HTTP routes until configured allow_from: list[str] = Field(default_factory=list) max_concurrent_requests: int = 100 _channel_id: str = "default" @@ -280,7 +286,7 @@ class BotChannelConfig(BaseChannelConfig): type: ChannelType = ChannelType.BOT_API enabled: bool = True - api_key: str = "" # If empty, no auth required + api_key: str = "" # Empty disables privileged HTTP routes until configured allow_from: list[str] = Field(default_factory=list) max_concurrent_requests: int = 100 need_mention: bool = False @@ -437,7 +443,9 @@ class ProviderConfig(BaseModel): api_key: str = "" api_base: Optional[str] = None - extra_headers: Optional[dict[str, str]] = Field(default_factory=dict) # Custom headers (e.g. APP-Code for AiHubMix) + extra_headers: Optional[dict[str, str]] = Field( + default_factory=dict + ) # Custom headers (e.g. APP-Code for AiHubMix) class ProvidersConfig(BaseModel): @@ -801,4 +809,4 @@ def from_safe_name(safe_name: str): file_name_split = safe_name.split("__") return SessionKey( type=file_name_split[0], channel_id=file_name_split[1], chat_id=file_name_split[2] - ) \ No newline at end of file + )