diff --git a/.env.run-docker b/.env.run-docker index ee3f1c22..9e77d8d7 100644 --- a/.env.run-docker +++ b/.env.run-docker @@ -8,6 +8,7 @@ export LOG_DIAGNOSE=true export LOG_STANDARD_LOGGER__root=INFO export LOG_STANDARD_LOGGER__sqlalchemy.engine=INFO export LOG_STANDARD_LOGGER__sqlalchemy.pool=INFO +export LOG_STANDARD_LOGGER__uvicorn.access=WARNING export AWS_ACCESS_KEY_ID=entitycore export AWS_SECRET_ACCESS_KEY=entitycore diff --git a/.env.run-local b/.env.run-local index 384087d5..f0847e33 100644 --- a/.env.run-local +++ b/.env.run-local @@ -8,6 +8,7 @@ export LOG_DIAGNOSE=true export LOG_STANDARD_LOGGER__root=INFO export LOG_STANDARD_LOGGER__sqlalchemy.engine=INFO export LOG_STANDARD_LOGGER__sqlalchemy.pool=INFO +export LOG_STANDARD_LOGGER__uvicorn.access=WARNING export UVICORN_HOST=127.0.0.1 export UVICORN_PORT=8000 diff --git a/.env.test-docker b/.env.test-docker index 2f309234..6cf894b2 100644 --- a/.env.test-docker +++ b/.env.test-docker @@ -1,5 +1,7 @@ # for testing the application with Docker +export LOG_STANDARD_LOGGER__root=WARNING + export DB_HOST=db-test export DB_PORT=5432 export DB_USER=test diff --git a/.env.test-local b/.env.test-local index 82dbf71c..b7cbfcf2 100644 --- a/.env.test-local +++ b/.env.test-local @@ -2,6 +2,7 @@ export APP_VERSION=${APP_VERSION:-1} export COMMIT_SHA=${COMMIT_SHA:-deadbeef} +export LOG_STANDARD_LOGGER__root=WARNING export DB_HOST=127.0.0.1 export DB_PORT=5434 diff --git a/app/application.py b/app/application.py index 63f249dc..b197bc21 100644 --- a/app/application.py +++ b/app/application.py @@ -1,6 +1,5 @@ import asyncio import os -import time from collections.abc import AsyncIterator from contextlib import asynccontextmanager from http import HTTPStatus @@ -20,10 +19,9 @@ from app.dependencies.common import forbid_extra_query_params from app.errors import ApiError, ApiErrorCode from app.logger import L +from app.middleware import RequestContextMiddleware from app.routers import router from app.schemas.api import ErrorResponse -from app.schemas.types import HeaderKey -from app.utils.uuid import create_uuid @asynccontextmanager @@ -141,26 +139,7 @@ async def http_exception_handler(request: Request, exception: StarletteHTTPExcep allow_methods=["*"], allow_headers=["*"], ) - - -@app.middleware("http") -async def add_request_id_header(request: Request, call_next): - """Generate a unique request-id and add it to the response headers.""" - request_id = str(create_uuid()) - request.state.request_id = request_id - response = await call_next(request) - response.headers[HeaderKey.request_id] = request_id - return response - - -@app.middleware("http") -async def add_process_time_header(request: Request, call_next): - """Calculate the process time and add it to the response headers.""" - start_time = time.perf_counter() - response = await call_next(request) - process_time = time.perf_counter() - start_time - response.headers[HeaderKey.process_time] = f"{process_time:.3f}" - return response +app.add_middleware(RequestContextMiddleware) app.include_router( diff --git a/app/config.py b/app/config.py index 590450b5..e2b6cd4e 100644 --- a/app/config.py +++ b/app/config.py @@ -39,7 +39,7 @@ class Settings(BaseSettings): LOG_DIAGNOSE: bool = False LOG_ENQUEUE: bool = False LOG_CATCH: bool = True - LOG_STANDARD_LOGGER: dict[str, str] = {"root": "INFO"} + LOG_STANDARD_LOGGER: dict[str, str] = {"root": "INFO", "uvicorn.access": "WARNING"} KEYCLOAK_URL: str = "https://staging.cell-a.openbraininstitute.org/auth/realms/SBO" AUTH_CACHE_MAXSIZE: int = 128 # items diff --git a/app/context.py b/app/context.py new file mode 100644 index 00000000..5b9b0ea7 --- /dev/null +++ b/app/context.py @@ -0,0 +1,12 @@ +from contextvars import ContextVar +from typing import TypedDict + + +class RequestContext(TypedDict, total=False): + """Request context dictionary.""" + + request_id: str # Unique identifier for the current request + user_id: str # Keycloak identifier of the user making the request + + +request_context_provider: ContextVar[RequestContext] = ContextVar("request_context") diff --git a/app/dependencies/auth.py b/app/dependencies/auth.py index 2594041a..e02a863b 100644 --- a/app/dependencies/auth.py +++ b/app/dependencies/auth.py @@ -12,6 +12,7 @@ from starlette.requests import Request from app.config import settings +from app.context import request_context_provider from app.errors import ApiError, ApiErrorCode, AuthErrorReason from app.logger import L from app.schemas.auth import ( @@ -168,6 +169,12 @@ def _check_user_info( return user_context +def _enrich_request_context(user_context: UserContext) -> None: + """Store user_id in contextvar for logging purposes.""" + ctx = request_context_provider.get() + ctx["user_id"] = str(user_context.profile.subject) + + def user_verified( project_context: Annotated[OptionalProjectContext, Header()], token: Annotated[HTTPAuthorizationCredentials | None, Depends(AuthHeader)], @@ -181,7 +188,7 @@ def user_verified( """ if settings.APP_DISABLE_AUTH: L.warning("Authentication is disabled: admin role granted, vlab and proj not verified") - return UserContext( + user_context = UserContext( profile=UserProfile( subject=UUID(int=0), name="Admin User", @@ -192,6 +199,9 @@ def user_verified( virtual_lab_id=project_context.virtual_lab_id, project_id=project_context.project_id, ) + # enrich request context even when the authentication is disabled + _enrich_request_context(user_context) + return user_context if not token: raise ApiError( @@ -213,6 +223,9 @@ def user_verified( http_client=request.state.http_client, ) + # enrich request context before potentially raising an exception + _enrich_request_context(user_context) + if not user_context.is_authorized: match user_context.auth_error_reason: case AuthErrorReason.NOT_AUTHORIZED_USER | AuthErrorReason.NOT_AUTHORIZED_PROJECT: diff --git a/app/dependencies/logger.py b/app/dependencies/logger.py deleted file mode 100644 index c5bb67b5..00000000 --- a/app/dependencies/logger.py +++ /dev/null @@ -1,35 +0,0 @@ -from collections.abc import AsyncIterator - -from starlette.requests import Request - -from app.dependencies.auth import UserContextDep -from app.logger import L -from app.schemas.types import HeaderKey - - -async def logger_context(request: Request, user_context: UserContextDep) -> AsyncIterator[None]: - """Add context information to each log message in authenticated endpoints. - - It shold be used only in authenticated endpoints, since it depends on `user_context`. - - These additional keys are added to the extra dict: - - - sub_id: the subject_id from Keycloak - - request_id: id that can be used to correlate multiple logs in the same request - - forwarded_for: the originating IP address of the client, from the X-Forwarded-For HTTP header - - Must be async because FastAPI wraps sync generator dependencies with - contextmanager_in_threadpool, which runs __enter__ and __exit__ - in different contexts, invalidating ContextVar tokens. - See: https://github.com/fastapi/fastapi/blob/c441583/fastapi/concurrency.py#L28-L41 - """ - sub_id = str(user_context.profile.subject) - request_id = request.state.request_id - forwarded_for = request.headers.get(HeaderKey.forwarded_for) - - with L.contextualize( - sub_id=sub_id, - request_id=request_id, - forwarded_for=forwarded_for, - ): - yield diff --git a/app/logger.py b/app/logger.py index 0015ad96..f833fb70 100644 --- a/app/logger.py +++ b/app/logger.py @@ -12,6 +12,7 @@ from loguru import logger from app.config import settings +from app.context import request_context_provider L = logger @@ -84,6 +85,16 @@ def str_formatter(record: "loguru.Record") -> str: def configure_logging() -> int: """Configure logging.""" + + def patcher(record: "loguru.Record") -> None: + """Add request context (request_id, user_id) to all log records. + + This function is automatically applied to every log message across all modules, + enriching them with contextual information from the current request. + """ + ctx = request_context_provider.get({}) + record["extra"].update(ctx) + L.remove() handler_id = L.add( sink=sys.stderr, @@ -94,6 +105,7 @@ def configure_logging() -> int: enqueue=settings.LOG_ENQUEUE, catch=settings.LOG_CATCH, ) + L.configure(patcher=patcher) L.enable("app") logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET, force=True) for logger_name, logger_level in settings.LOG_STANDARD_LOGGER.items(): diff --git a/app/middleware.py b/app/middleware.py new file mode 100644 index 00000000..8983f474 --- /dev/null +++ b/app/middleware.py @@ -0,0 +1,61 @@ +"""Request context middleware.""" + +import time +from collections.abc import Awaitable, Callable + +from starlette.middleware.base import BaseHTTPMiddleware +from starlette.requests import Request +from starlette.responses import Response + +from app.context import RequestContext, request_context_provider +from app.logger import L +from app.schemas.types import HeaderKey +from app.utils.uuid import create_uuid + +RequestResponseEndpoint = Callable[[Request], Awaitable[Response]] + + +class RequestContextMiddleware(BaseHTTPMiddleware): + """Middleware to initialize request context and log access.""" + + async def dispatch( # noqa: PLR6301 + self, + request: Request, + call_next: RequestResponseEndpoint, + ) -> Response: + """Set request context and log access.""" + start_time = time.perf_counter() + request_id = str(create_uuid()) + ctx = RequestContext(request_id=request_id) + request_context_provider.set(ctx) + + try: + response = await call_next(request) + except Exception: + process_time = time.perf_counter() - start_time + L.error( + "{} {}", + request.method, + str(request.url), + client=request.client.host if request.client else "-", + status_code=500, + process_time=f"{process_time:.3f}", + forwarded_for=request.headers.get(HeaderKey.forwarded_for, ""), + ) + raise + + process_time = time.perf_counter() - start_time + response.headers[HeaderKey.process_time] = f"{process_time:.3f}" + response.headers[HeaderKey.request_id] = request_id + + L.info( + "{} {}", + request.method, + str(request.url), + client=request.client.host if request.client else "-", + status_code=response.status_code, + process_time=f"{process_time:.3f}", + forwarded_for=request.headers.get(HeaderKey.forwarded_for, ""), + ) + + return response diff --git a/app/routers/__init__.py b/app/routers/__init__.py index 6ddaf5f2..a62c2b03 100644 --- a/app/routers/__init__.py +++ b/app/routers/__init__.py @@ -3,7 +3,6 @@ from fastapi import APIRouter, Depends from app.dependencies.auth import user_verified, user_with_service_admin_role -from app.dependencies.logger import logger_context from app.routers import ( admin, analysis_notebook_environment, @@ -85,7 +84,6 @@ admin.router, dependencies=[ Depends(user_with_service_admin_role), - Depends(logger_context), ], ) @@ -167,6 +165,5 @@ r, dependencies=[ Depends(user_verified), - Depends(logger_context), ], ) diff --git a/tests/conftest.py b/tests/conftest.py index cf3437da..b7b7eb4f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -52,6 +52,7 @@ from app.db.session import DatabaseSessionManager, configure_database_session_manager from app.db.types import CellMorphologyGenerationType, EntityType, StorageType from app.dependencies import auth +from app.logger import configure_logging from app.schemas.auth import UserContext, UserProfile, UserProjectGroup from app.schemas.external_url import ExternalUrlCreate @@ -100,6 +101,11 @@ def _setup_env_variables(): os.environ["AWS_SESSION_TOKEN"] = "testing" # noqa: S105 +@pytest.fixture(scope="session", autouse=True) +def _configure_logging(): + configure_logging() + + @pytest.fixture(scope="session") def s3(): """Return a mocked S3 client.""" diff --git a/tests/dependencies/test_logger.py b/tests/dependencies/test_logger.py deleted file mode 100644 index 6691efe1..00000000 --- a/tests/dependencies/test_logger.py +++ /dev/null @@ -1,40 +0,0 @@ -from unittest.mock import ANY - -from fastapi import Depends - -from app.dependencies.logger import logger_context -from app.logger import L - -from tests.utils import ADMIN_SUB_ID - - -def test_logger_context(client_admin): - """Test that logger_context adds sub_id, request_id, and forwarded_for to logs.""" - logs = [] - - def capture_sink(message): - logs.append(message.record) - - @client_admin.app.get("/test-logger", dependencies=[Depends(logger_context)]) - def test_endpoint(): - L.info("test message") - return {"ok": True} - - handler_id = L.add(capture_sink, level="INFO") - - try: - headers = {"x-forwarded-for": "127.1.2.3"} - result = client_admin.get("/test-logger", headers=headers) - - assert result.status_code == 200 - assert len(logs) == 1 - - record = logs[0] - assert record["message"] == "test message" - assert record["extra"] == { - "sub_id": ADMIN_SUB_ID, - "request_id": ANY, - "forwarded_for": "127.1.2.3", - } - finally: - L.remove(handler_id) diff --git a/tests/test_auth.py b/tests/test_auth.py index 65b1b70b..1475138c 100644 --- a/tests/test_auth.py +++ b/tests/test_auth.py @@ -8,6 +8,7 @@ from fastapi.security import HTTPAuthorizationCredentials from app.config import settings +from app.context import request_context_provider from app.dependencies import auth as test_module from app.errors import ApiError, ApiErrorCode, AuthErrorReason from app.schemas.auth import UserContext, UserContextWithProjectId, UserProfile, UserProjectGroup @@ -33,6 +34,17 @@ def _clear_cache(): test_module._check_user_info.cache_clear() +@pytest.fixture(autouse=True) +def _set_request_context(): + """Set the request context variable, needed in user_verified. + + At runtime, this is set by RequestContextMiddleware. + """ + token = request_context_provider.set({}) + yield + request_context_provider.reset(token) + + @pytest.fixture def http_client(): with httpx.Client() as client: diff --git a/tests/test_middleware.py b/tests/test_middleware.py new file mode 100644 index 00000000..a1e6db52 --- /dev/null +++ b/tests/test_middleware.py @@ -0,0 +1,173 @@ +from unittest.mock import ANY + +import pytest +from fastapi import Depends + +from app.application import app +from app.dependencies.auth import user_verified +from app.logger import L + +from tests.utils import ADMIN_SUB_ID + + +@pytest.fixture +def logs(): + """Fixture to capture logs.""" + logs = [] + + def capture_sink(message): + logs.append(message.record) + + handler_id = L.add(capture_sink, level="INFO") + yield logs + L.remove(handler_id) + + +@pytest.fixture(scope="session") +def test_authenticated_endpoint() -> str: + """Fixture to add an authenticated test endpoint to the app. + + Use the session scope because the endpoint isn't removed until the end of the tests. + """ + path = "/test-authenticated-endpoint" + + @app.get(path, dependencies=[Depends(user_verified)]) + def test_endpoint(): + L.info("test message") + return {"ok": True} + + return path + + +@pytest.fixture(scope="session") +def test_public_endpoint() -> str: + """Fixture to add a public test endpoint to the app. + + Use the session scope because the endpoint isn't removed until the end of the tests. + """ + path = "/test-public-endpoint" + + @app.get(path) + def test_endpoint(): + L.info("test message") + return {"ok": True} + + return path + + +@pytest.fixture(scope="session") +def test_error_endpoint() -> str: + """Fixture to add a test endpoint that raises an unhandled error. + + Use the session scope because the endpoint isn't removed until the end of the tests. + """ + path = "/test-error-endpoint" + + @app.get(path) + def test_endpoint(): + L.info("test message") + msg = "test error" + raise RuntimeError(msg) + + return path + + +def _filter_logs(logs, keys=("message", "extra")): + return [{k: rec[k] for k in keys} for rec in logs] + + +def test_authenticated_request_context(logs, client_admin, test_authenticated_endpoint): + """Test that the request context middleware adds user_id and request_id to logs.""" + client = client_admin + endpoint = test_authenticated_endpoint + headers = {"x-forwarded-for": "127.1.2.3"} + result = client.get(test_authenticated_endpoint, headers=headers) + + assert result.status_code == 200 + + expected = [ + { + "message": "test message", + "extra": { + "request_id": ANY, + "user_id": ADMIN_SUB_ID, + "serialized": ANY, + }, + }, + { + "message": f"GET http://testserver{endpoint}", + "extra": { + "request_id": ANY, + "user_id": ADMIN_SUB_ID, + "forwarded_for": "127.1.2.3", + "process_time": ANY, + "status_code": 200, + "client": "testclient", + "serialized": ANY, + }, + }, + ] + assert _filter_logs(logs) == expected + + +def test_public_request_context(logs, client_no_auth, test_public_endpoint): + """Test that the request context middleware adds request_id to logs.""" + client = client_no_auth + endpoint = test_public_endpoint + headers = {"x-forwarded-for": "127.1.2.3"} + result = client.get(endpoint, headers=headers) + + assert result.status_code == 200 + + expected = [ + { + "message": "test message", + "extra": { + "request_id": ANY, + "serialized": ANY, + }, + }, + { + "message": f"GET http://testserver{endpoint}", + "extra": { + "request_id": ANY, + "forwarded_for": "127.1.2.3", + "process_time": ANY, + "status_code": 200, + "client": "testclient", + "serialized": ANY, + }, + }, + ] + assert _filter_logs(logs) == expected + + +def test_error_request_context(logs, client_no_auth, test_error_endpoint): + """Test that the request context middleware handles errors.""" + client = client_no_auth + endpoint = test_error_endpoint + headers = {"x-forwarded-for": "127.1.2.3"} + with pytest.raises(RuntimeError, match="test error"): + client.get(endpoint, headers=headers) + + expected = [ + { + "message": "test message", + "extra": { + "request_id": ANY, + "serialized": ANY, + }, + }, + { + "message": f"GET http://testserver{endpoint}", + "extra": { + "request_id": ANY, + "forwarded_for": "127.1.2.3", + "process_time": ANY, + "status_code": 500, + "client": "testclient", + "serialized": ANY, + }, + }, + ] + assert _filter_logs(logs) == expected