Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion src/framex/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,13 @@ def run(
builtin_plugins = settings.load_builtin_plugins if load_builtin_plugins is None else load_builtin_plugins
external_plugins = settings.load_plugins if load_plugins is None else load_plugins

reversion = reversion or VERSION
if reversion:
settings.server.reversion = reversion
elif settings.server.reversion:
reversion = settings.server.reversion
else:
reversion = VERSION
settings.server.reversion = VERSION

if test_mode and use_ray:
raise RuntimeError("FlameX can not run when `test_mode` == True, and `use_ray` == True")
Expand Down
1 change: 1 addition & 0 deletions src/framex/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,7 @@ class ServerConfig(BaseModel):
num_cpus: int = -1
excluded_log_paths: list[str] = Field(default_factory=list)
ingress_config: dict[str, Any] = Field(default_factory=lambda: {"max_ongoing_requests": 60})
reversion: str = ""


class TestConfig(BaseModel):
Expand Down
3 changes: 2 additions & 1 deletion src/framex/driver/application.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,8 @@ def build_openapi_description() -> str:
|--------|-------|
| Started At | `{started_at}` |
| Uptime | `{uptime}` |
| Version | `v{VERSION}` |
| Service-Version | `v{settings.server.reversion}` |
| FrameX-Version | `v{VERSION}` |

---
"""
Expand Down
50 changes: 41 additions & 9 deletions src/framex/driver/ingress.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import re
from collections.abc import Callable
from enum import Enum
from typing import Any
Expand All @@ -16,7 +17,7 @@
from framex.driver.decorator import api_ingress
from framex.log import setup_logger
from framex.plugin.model import ApiType, PluginApi
from framex.utils import escape_tag, safe_error_message
from framex.utils import escape_tag, safe_error_message, shorten_str

app = create_fastapi_application()

Expand Down Expand Up @@ -74,13 +75,18 @@ def register_route(
from framex.config import settings

auth_keys = settings.auth.get_auth_keys(path)
logger.debug(f"API({path}) with tags {tags} requires auth_keys {auth_keys}")
logger.trace(f"API({path}) with tags {tags} requires auth_keys {auth_keys}")
adapter = get_adapter()

try:
routes: list[str] = [route.path for route in app.routes if isinstance(route, Route | APIRoute)]
# logger.warning(f"API({path}) with tags {tags} is already registered, skipping duplicate registration.")
methods_str = ",".join(m.upper() for m in methods)

if path in routes:
logger.warning(f"API({path}) with tags {tags} is already registered, skipping duplicate registration.")
logger.opt(colors=True).warning(
f"API route already registered: {methods_str:<4} {path[:40] + '...':<45} ({handle.deployment_name})"
)
return False
if (not path) or (not methods):
raise RuntimeError(f"Api({path}) or methods({methods}) is empty")
Expand Down Expand Up @@ -111,33 +117,38 @@ async def route_handler(response: Response, model: Model = Depends()) -> Any: #
# Inject auth dependency if needed
dependencies = []
if auth_keys is not None:
logger.debug(f"API({path}) with tags {tags} requires auth.")
logger.trace(f"API({path}) with tags {tags} requires auth.")

def _verify_api_key(request: Request, api_key: str | None = Depends(api_key_header)) -> None:
if (api_key is None or api_key not in auth_keys) and (not auth_jwt(request)):
logger.error(f"Unauthorized access attempt with API Key({api_key}) for API({path})")
logger.opt(colors=True).error(
f"<r>Unauthorized access attempt with API Key({api_key}) for API({path})</r>"
)
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail=f"Invalid API Key({api_key}) for API({path})",
)

dependencies.append(Depends(_verify_api_key))

app.add_api_route(
self.add_api_route(
path,
route_handler,
methods=methods,
tags=tags,
response_class=StreamingResponse if stream else JSONResponse,
dependencies=dependencies,
)
methods_str = ",".join(m.upper() for m in methods)
short_path = shorten_str(path)
logger.opt(colors=True).success(
f"Succeeded to register api({methods}): {path} from {handle.deployment_name}, params: {params}"
f"API route registered: {methods_str:<4} <g>{short_path:<45}</g> ({handle.deployment_name})"
)
return True
except Exception as e:
logger.opt(exception=e).error(f'Failed to register api "{escape_tag(path)}" from {handle.deployment_name}')

logger.opt(exception=e, colors=True).error(
f'<r>Failed to register api "{escape_tag(path)}" from {handle.deployment_name}</r>'
)
return False

@app.get("/ping")
Expand All @@ -146,3 +157,24 @@ async def inner(self) -> str: # pragma: no cover

def __repr__(self):
return BACKEND_NAME

def add_api_route(
self,
path: str,
endpoint: Callable[..., Any],
*,
methods: list[str] | None = None,
**kwargs: Any,
) -> None:
method_set: set[str] = {m.upper() for m in methods} if methods else {"GET"}
norm_path = re.sub(r"\{[^}]+\}", "{}", path)

for route in app.routes:
if (
isinstance(route, APIRoute)
and re.sub(r"\{[^}]+\}", "{}", route.path) == norm_path
and route.methods & method_set
):
raise RuntimeError(f"Duplicate API route: {sorted(method_set)} {norm_path}")

app.add_api_route(path, endpoint, methods=list(method_set), **kwargs)
2 changes: 1 addition & 1 deletion src/framex/plugin/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -125,7 +125,7 @@ async def call_plugin_api(
res = result.get("data")
status = result.get("status")
if status not in settings.server.legal_proxy_code:
logger.opt(colors=True).error(f"Proxy API {api_name} call illegal: <r>{result}</r>")
logger.opt(colors=True).error(f"<>Proxy API {api_name} call illegal: <r>{result}</r>")
raise RuntimeError(f"Proxy API {api_name} returned status {status}")
if res is None:
logger.opt(colors=True).warning(f"API {api_name} returned empty data")
Expand Down
11 changes: 6 additions & 5 deletions src/framex/plugins/proxy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from framex.plugins.proxy.builder import create_pydantic_model, type_map
from framex.plugins.proxy.config import ProxyPluginConfig, settings
from framex.plugins.proxy.model import ProxyFunc, ProxyFuncHttpBody
from framex.utils import cache_decode, cache_encode
from framex.utils import cache_decode, cache_encode, shorten_str

__plugin_meta__ = PluginMetadata(
name="proxy",
Expand Down Expand Up @@ -85,13 +85,13 @@ async def _parse_openai_docs(self, url: str) -> None:
for path, details in paths.items():
# Check if the path is legal!
if not settings.is_white_url(path):
logger.warning(f"Proxy api({path}) not in white_list, skipping...")
logger.opt(colors=True).warning(f"Proxy api(<y>{path}</y>) not in white_list, skipping...")
continue

# Get auth api_keys
if auth_api_key := settings.auth.get_auth_keys(path):
headers = {"Authorization": auth_api_key[0]} # Use the first auth key set
logger.debug(f"Proxy api({path}) requires auth")
logger.trace(f"Proxy api({path}) requires auth")
else:
headers = None

Expand Down Expand Up @@ -119,7 +119,7 @@ async def _parse_openai_docs(self, url: str) -> None:

Model = create_pydantic_model(schema_name, model_schema, components) # noqa
params.append(("model", Model))
logger.opt(colors=True).debug(f"Found proxy api({method}) <g>{url}{path}</g>")
logger.opt(colors=True).trace(f"Found proxy api({method}) <g>{url}{path}</g>")
func_name = body.get("operationId")
is_stream = path in settings.force_stream_apis
func = self._create_dynamic_method(
Expand Down Expand Up @@ -234,7 +234,8 @@ def _create_dynamic_method(

# Construct dynamic methods
async def dynamic_method(**kwargs: Any) -> AsyncGenerator[str, None] | dict[str, Any] | str:
logger.info(f"Calling proxy url: {url} with kwargs: {kwargs}")
log_info = shorten_str(str(kwargs), 512)
logger.info(f"Calling proxy url: {url} with kwargs: {log_info}")
validated = RequestModel(**kwargs) # Type Validation
query = {}
json_body = None
Expand Down
4 changes: 4 additions & 0 deletions src/framex/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -156,3 +156,7 @@ def safe_error_message(e: Exception) -> str:
if e.args:
return str(e.args[0])
return "Internal Server Error"


def shorten_str(data: str, max_len: int = 45) -> str:
return data if len(data) <= max_len else data[: max_len - 3] + "..."
4 changes: 2 additions & 2 deletions tests/driver/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -213,6 +213,6 @@ def test_docs_accessible_with_valid_jwt(self):
"secret",
algorithm="HS256",
)

resp = client.get("/docs", cookies={"token": token}, follow_redirects=False)
client.cookies.set("token", token)
resp = client.get("/docs", follow_redirects=False)
assert resp.status_code == status.HTTP_200_OK
104 changes: 104 additions & 0 deletions tests/driver/test_ingress.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
from unittest.mock import Mock, patch

import pytest
from fastapi.routing import APIRoute
from starlette.routing import Route

from framex.driver.ingress import APIIngress

# ---------- helpers ----------


def make_route(path: str, methods: set[str]) -> APIRoute:
route = Mock(spec=APIRoute)
route.path = path
route.methods = methods
return route


# ---------- fixtures ----------


@pytest.fixture
def mock_app():
with patch("framex.driver.ingress.app") as app:
app.routes = []
app.add_api_route = Mock()
yield app


@pytest.fixture
def ingress():
return APIIngress.__new__(APIIngress)


# ---------- tests ----------


def test_add_first_route_success(ingress, mock_app):
endpoint = Mock()

ingress.add_api_route("/users", endpoint, methods=["GET"])

mock_app.add_api_route.assert_called_once_with("/users", endpoint, methods=["GET"])


@pytest.mark.parametrize(
("existing_path", "new_path"),
[
("/users/{id}", "/users/{id}"),
("/users/{id}", "/users/{user_id}"),
("/users/{uid}/posts/{pid}", "/users/{id}/posts/{post_id}"),
],
)
def test_duplicate_path_same_method_raises(ingress, mock_app, existing_path, new_path):
mock_app.routes = [make_route(existing_path, {"GET"})]

with pytest.raises(RuntimeError, match=r"Duplicate API route"):
ingress.add_api_route(new_path, Mock(), methods=["GET"])


def test_same_path_different_method_allowed(ingress, mock_app):
mock_app.routes = [make_route("/users/{id}", {"GET"})]

ingress.add_api_route("/users/{id}", Mock(), methods=["POST"])

mock_app.add_api_route.assert_called_once()


def test_overlapping_methods_raises(ingress, mock_app):
mock_app.routes = [make_route("/users", {"GET", "POST"})]

with pytest.raises(RuntimeError):
ingress.add_api_route("/users", Mock(), methods=["POST", "PUT"])


def test_case_insensitive_methods(ingress, mock_app):
mock_app.routes = [make_route("/users", {"GET"})]

with pytest.raises(RuntimeError):
ingress.add_api_route("/users", Mock(), methods=["get"])


def test_non_api_route_is_ignored(ingress, mock_app):
non_api_route = Mock(spec=Route)
non_api_route.path = "/users/{id}"
mock_app.routes = [non_api_route]

ingress.add_api_route("/users/{id}", Mock(), methods=["GET"])

mock_app.add_api_route.assert_called_once()


def test_kwargs_are_passed_through(ingress, mock_app):
ingress.add_api_route(
"/users",
Mock(),
methods=["GET"],
tags=["users"],
response_class=Mock(),
)

_, kwargs = mock_app.add_api_route.call_args
assert kwargs["tags"] == ["users"]
assert "response_class" in kwargs
Loading