Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .env.run-docker
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export LOG_DIAGNOSE=true
export LOG_STANDARD_LOGGER__root=INFO
export LOG_STANDARD_LOGGER__sqlalchemy.engine=INFO
export LOG_STANDARD_LOGGER__sqlalchemy.pool=INFO
export LOG_STANDARD_LOGGER__uvicorn.access=WARNING

export AWS_ACCESS_KEY_ID=entitycore
export AWS_SECRET_ACCESS_KEY=entitycore
Expand Down
1 change: 1 addition & 0 deletions .env.run-local
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@ export LOG_DIAGNOSE=true
export LOG_STANDARD_LOGGER__root=INFO
export LOG_STANDARD_LOGGER__sqlalchemy.engine=INFO
export LOG_STANDARD_LOGGER__sqlalchemy.pool=INFO
export LOG_STANDARD_LOGGER__uvicorn.access=WARNING

export UVICORN_HOST=127.0.0.1
export UVICORN_PORT=8000
Expand Down
2 changes: 2 additions & 0 deletions .env.test-docker
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
# for testing the application with Docker

export LOG_STANDARD_LOGGER__root=WARNING

export DB_HOST=db-test
export DB_PORT=5432
export DB_USER=test
Expand Down
1 change: 1 addition & 0 deletions .env.test-local
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

export APP_VERSION=${APP_VERSION:-1}
export COMMIT_SHA=${COMMIT_SHA:-deadbeef}
export LOG_STANDARD_LOGGER__root=WARNING

export DB_HOST=127.0.0.1
export DB_PORT=5434
Expand Down
25 changes: 2 additions & 23 deletions app/application.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import asyncio
import os
import time
from collections.abc import AsyncIterator
from contextlib import asynccontextmanager
from http import HTTPStatus
Expand All @@ -20,10 +19,9 @@
from app.dependencies.common import forbid_extra_query_params
from app.errors import ApiError, ApiErrorCode
from app.logger import L
from app.middleware import RequestContextMiddleware
from app.routers import router
from app.schemas.api import ErrorResponse
from app.schemas.types import HeaderKey
from app.utils.uuid import create_uuid


@asynccontextmanager
Expand Down Expand Up @@ -141,26 +139,7 @@ async def http_exception_handler(request: Request, exception: StarletteHTTPExcep
allow_methods=["*"],
allow_headers=["*"],
)


@app.middleware("http")
async def add_request_id_header(request: Request, call_next):
"""Generate a unique request-id and add it to the response headers."""
request_id = str(create_uuid())
request.state.request_id = request_id
response = await call_next(request)
response.headers[HeaderKey.request_id] = request_id
return response


@app.middleware("http")
async def add_process_time_header(request: Request, call_next):
"""Calculate the process time and add it to the response headers."""
start_time = time.perf_counter()
response = await call_next(request)
process_time = time.perf_counter() - start_time
response.headers[HeaderKey.process_time] = f"{process_time:.3f}"
return response
app.add_middleware(RequestContextMiddleware)


app.include_router(
Expand Down
2 changes: 1 addition & 1 deletion app/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ class Settings(BaseSettings):
LOG_DIAGNOSE: bool = False
LOG_ENQUEUE: bool = False
LOG_CATCH: bool = True
LOG_STANDARD_LOGGER: dict[str, str] = {"root": "INFO"}
LOG_STANDARD_LOGGER: dict[str, str] = {"root": "INFO", "uvicorn.access": "WARNING"}

KEYCLOAK_URL: str = "https://staging.cell-a.openbraininstitute.org/auth/realms/SBO"
AUTH_CACHE_MAXSIZE: int = 128 # items
Expand Down
12 changes: 12 additions & 0 deletions app/context.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
from contextvars import ContextVar
from typing import TypedDict


class RequestContext(TypedDict, total=False):
"""Request context dictionary."""

request_id: str # Unique identifier for the current request
user_id: str # Keycloak identifier of the user making the request


request_context_provider: ContextVar[RequestContext] = ContextVar("request_context")
15 changes: 14 additions & 1 deletion app/dependencies/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from starlette.requests import Request

from app.config import settings
from app.context import request_context_provider
from app.errors import ApiError, ApiErrorCode, AuthErrorReason
from app.logger import L
from app.schemas.auth import (
Expand Down Expand Up @@ -168,6 +169,12 @@ def _check_user_info(
return user_context


def _enrich_request_context(user_context: UserContext) -> None:
"""Store user_id in contextvar for logging purposes."""
ctx = request_context_provider.get()
ctx["user_id"] = str(user_context.profile.subject)


def user_verified(
project_context: Annotated[OptionalProjectContext, Header()],
token: Annotated[HTTPAuthorizationCredentials | None, Depends(AuthHeader)],
Expand All @@ -181,7 +188,7 @@ def user_verified(
"""
if settings.APP_DISABLE_AUTH:
L.warning("Authentication is disabled: admin role granted, vlab and proj not verified")
return UserContext(
user_context = UserContext(
profile=UserProfile(
subject=UUID(int=0),
name="Admin User",
Expand All @@ -192,6 +199,9 @@ def user_verified(
virtual_lab_id=project_context.virtual_lab_id,
project_id=project_context.project_id,
)
# enrich request context even when the authentication is disabled
_enrich_request_context(user_context)
return user_context

if not token:
raise ApiError(
Expand All @@ -213,6 +223,9 @@ def user_verified(
http_client=request.state.http_client,
)

# enrich request context before potentially raising an exception
_enrich_request_context(user_context)

if not user_context.is_authorized:
match user_context.auth_error_reason:
case AuthErrorReason.NOT_AUTHORIZED_USER | AuthErrorReason.NOT_AUTHORIZED_PROJECT:
Expand Down
35 changes: 0 additions & 35 deletions app/dependencies/logger.py

This file was deleted.

12 changes: 12 additions & 0 deletions app/logger.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
from loguru import logger

from app.config import settings
from app.context import request_context_provider

L = logger

Expand Down Expand Up @@ -84,6 +85,16 @@ def str_formatter(record: "loguru.Record") -> str:

def configure_logging() -> int:
"""Configure logging."""

def patcher(record: "loguru.Record") -> None:
"""Add request context (request_id, user_id) to all log records.

This function is automatically applied to every log message across all modules,
enriching them with contextual information from the current request.
"""
ctx = request_context_provider.get({})
record["extra"].update(ctx)

L.remove()
handler_id = L.add(
sink=sys.stderr,
Expand All @@ -94,6 +105,7 @@ def configure_logging() -> int:
enqueue=settings.LOG_ENQUEUE,
catch=settings.LOG_CATCH,
)
L.configure(patcher=patcher)
L.enable("app")
logging.basicConfig(handlers=[InterceptHandler()], level=logging.NOTSET, force=True)
for logger_name, logger_level in settings.LOG_STANDARD_LOGGER.items():
Expand Down
61 changes: 61 additions & 0 deletions app/middleware.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
"""Request context middleware."""

import time
from collections.abc import Awaitable, Callable

from starlette.middleware.base import BaseHTTPMiddleware
from starlette.requests import Request
from starlette.responses import Response

from app.context import RequestContext, request_context_provider
from app.logger import L
from app.schemas.types import HeaderKey
from app.utils.uuid import create_uuid

RequestResponseEndpoint = Callable[[Request], Awaitable[Response]]


class RequestContextMiddleware(BaseHTTPMiddleware):
"""Middleware to initialize request context and log access."""

async def dispatch( # noqa: PLR6301
self,
request: Request,
call_next: RequestResponseEndpoint,
) -> Response:
"""Set request context and log access."""
start_time = time.perf_counter()
request_id = str(create_uuid())
ctx = RequestContext(request_id=request_id)
request_context_provider.set(ctx)

try:
response = await call_next(request)
except Exception:
process_time = time.perf_counter() - start_time
L.error(
"{} {}",
request.method,
str(request.url),
client=request.client.host if request.client else "-",
status_code=500,
process_time=f"{process_time:.3f}",
forwarded_for=request.headers.get(HeaderKey.forwarded_for, ""),
)
raise

process_time = time.perf_counter() - start_time
response.headers[HeaderKey.process_time] = f"{process_time:.3f}"
response.headers[HeaderKey.request_id] = request_id

L.info(
"{} {}",
request.method,
str(request.url),
client=request.client.host if request.client else "-",
status_code=response.status_code,
process_time=f"{process_time:.3f}",
forwarded_for=request.headers.get(HeaderKey.forwarded_for, ""),
)

return response
3 changes: 0 additions & 3 deletions app/routers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from fastapi import APIRouter, Depends

from app.dependencies.auth import user_verified, user_with_service_admin_role
from app.dependencies.logger import logger_context
from app.routers import (
admin,
analysis_notebook_environment,
Expand Down Expand Up @@ -85,7 +84,6 @@
admin.router,
dependencies=[
Depends(user_with_service_admin_role),
Depends(logger_context),
],
)

Expand Down Expand Up @@ -167,6 +165,5 @@
r,
dependencies=[
Depends(user_verified),
Depends(logger_context),
],
)
6 changes: 6 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
from app.db.session import DatabaseSessionManager, configure_database_session_manager
from app.db.types import CellMorphologyGenerationType, EntityType, StorageType
from app.dependencies import auth
from app.logger import configure_logging
from app.schemas.auth import UserContext, UserProfile, UserProjectGroup
from app.schemas.external_url import ExternalUrlCreate

Expand Down Expand Up @@ -100,6 +101,11 @@ def _setup_env_variables():
os.environ["AWS_SESSION_TOKEN"] = "testing" # noqa: S105


@pytest.fixture(scope="session", autouse=True)
def _configure_logging():
configure_logging()


@pytest.fixture(scope="session")
def s3():
"""Return a mocked S3 client."""
Expand Down
40 changes: 0 additions & 40 deletions tests/dependencies/test_logger.py

This file was deleted.

12 changes: 12 additions & 0 deletions tests/test_auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from fastapi.security import HTTPAuthorizationCredentials

from app.config import settings
from app.context import request_context_provider
from app.dependencies import auth as test_module
from app.errors import ApiError, ApiErrorCode, AuthErrorReason
from app.schemas.auth import UserContext, UserContextWithProjectId, UserProfile, UserProjectGroup
Expand All @@ -33,6 +34,17 @@ def _clear_cache():
test_module._check_user_info.cache_clear()


@pytest.fixture(autouse=True)
def _set_request_context():
"""Set the request context variable, needed in user_verified.

At runtime, this is set by RequestContextMiddleware.
"""
token = request_context_provider.set({})
yield
request_context_provider.reset(token)


@pytest.fixture
def http_client():
with httpx.Client() as client:
Expand Down
Loading