diff --git a/backend/app/config.py b/backend/app/config.py index 06146d15..49d5f403 100644 --- a/backend/app/config.py +++ b/backend/app/config.py @@ -41,6 +41,17 @@ def _bool_env(name: str, default: bool) -> bool: return raw_value.strip().lower() in {"1", "true", "yes", "on"} +def _required_env(name: str) -> str: + """Get a required environment variable. Raise error if not set.""" + value = os.getenv(name) + if not value or not value.strip(): + raise ValueError( + f"Required environment variable '{name}' is not set. " + f"Please set it before starting the application." + ) + return value + + class Settings: """Application settings loaded from environment variables.""" @@ -50,6 +61,7 @@ class Settings: max_request_bytes: int = _int_env("MAX_REQUEST_BYTES", 1048576) rate_limit_requests: int = _int_env("RATE_LIMIT_REQUESTS", 120) rate_limit_window_seconds: int = _int_env("RATE_LIMIT_WINDOW_SECONDS", 60) + trust_proxy_headers: bool = _bool_env("TRUST_PROXY_HEADERS", False) cache_enabled: bool = _bool_env("CACHE_ENABLED", True) cache_ttl_seconds: int = _int_env("CACHE_TTL_SECONDS", 300) cache_max_entries: int = _int_env("CACHE_MAX_ENTRIES", 100) @@ -59,7 +71,7 @@ class Settings: enable_docs: bool = _bool_env("ENABLE_DOCS", False) public_root_info: bool = _bool_env("PUBLIC_ROOT_INFO", False) database_url: str = os.getenv("DATABASE_URL", "sqlite:///./assistant.db") - jwt_secret: str = os.getenv("JWT_SECRET", "change-this-in-production-min-32-bytes") + jwt_secret: str = _required_env("JWT_SECRET") jwt_algorithm: str = os.getenv("JWT_ALGORITHM", "HS256") access_token_minutes: int = _int_env("ACCESS_TOKEN_MINUTES", 720) llm_enabled: bool = _bool_env("LLM_ENABLED", False) diff --git a/backend/app/middleware.py b/backend/app/middleware.py index 9bcfbd09..5c9e682a 100644 --- a/backend/app/middleware.py +++ b/backend/app/middleware.py @@ -16,9 +16,15 @@ def get_client_key(request: Request) -> str: - xff = request.headers.get("x-forwarded-for", "").split(",")[0].strip() - if xff: - return xff + """Extract client IP for rate limiting. + + Only uses X-Forwarded-For if TRUST_PROXY_HEADERS is enabled. + Falls back to direct connection IP if proxy headers are not trusted. + """ + if settings.trust_proxy_headers: + xff = request.headers.get("x-forwarded-for", "").split(",")[-1].strip() + if xff and xff != "unknown": + return xff if request.client and request.client.host: return request.client.host return "unknown" diff --git a/backend/app/routers/analyze.py b/backend/app/routers/analyze.py index ede1fcd5..18c5200b 100644 --- a/backend/app/routers/analyze.py +++ b/backend/app/routers/analyze.py @@ -266,6 +266,7 @@ async def analyze_zip(request: Request, file: UploadFile = File(...)): results: list[dict] = [] skipped_files: list[str] = [] total_size = 0 + MAX_PER_FILE_BYTES = 2 * 1024 * 1024 # 2MB per file with archive: members = [ @@ -307,14 +308,22 @@ async def analyze_zip(request: Request, file: UploadFile = File(...)): ) continue - if total_size + info.file_size > MAX_ZIP_TOTAL_BYTES: + raw = archive.read(info) + decompressed_size = len(raw) + + if decompressed_size > MAX_PER_FILE_BYTES: + raise HTTPException( + status_code=400, + detail=f"File '{safe_name}' exceeds 2MB limit after decompression", + ) + + if total_size + decompressed_size > MAX_ZIP_TOTAL_BYTES: raise HTTPException( status_code=400, detail="ZIP source files exceed the 5MB total limit", ) - raw = archive.read(info) - total_size += len(raw) + total_size += decompressed_size try: code = raw.decode("utf-8") diff --git a/backend/app/routers/history.py b/backend/app/routers/history.py index 92b2a4fe..4cad7d8e 100644 --- a/backend/app/routers/history.py +++ b/backend/app/routers/history.py @@ -3,9 +3,11 @@ """ from __future__ import annotations -from fastapi import APIRouter, HTTPException, Query +from fastapi import APIRouter, Depends, HTTPException, Query from pydantic import BaseModel, Field +from ..security import get_current_user +from ..models import User from ..services import database router = APIRouter() @@ -29,8 +31,12 @@ class HistoryEntry(BaseModel): @router.post("/", response_model=dict, status_code=201) -async def save_history(body: HistorySaveRequest): +async def save_history( + body: HistorySaveRequest, + current_user: User = Depends(get_current_user), +): entry_id = await database.save_entry( + user_id=current_user.id, code=body.code, language=body.language, score=body.score, @@ -43,21 +49,34 @@ async def save_history(body: HistorySaveRequest): async def get_history( limit: int = Query(20, ge=1, le=100), offset: int = Query(0, ge=0), + current_user: User = Depends(get_current_user), ): - return await database.get_entries(limit=limit, offset=offset) + return await database.get_entries( + user_id=current_user.id, + limit=limit, + offset=offset, + ) @router.get("/search", response_model=list[HistoryEntry]) async def search_history( q: str = Query(..., min_length=1), limit: int = Query(20, ge=1, le=100), + current_user: User = Depends(get_current_user), ): - return await database.search_entries(q=q, limit=limit) + return await database.search_entries( + user_id=current_user.id, + q=q, + limit=limit, + ) @router.delete("/{entry_id}", response_model=dict) -async def delete_history(entry_id: int): - deleted = await database.delete_entry(entry_id) +async def delete_history( + entry_id: int, + current_user: User = Depends(get_current_user), +): + deleted = await database.delete_entry(entry_id, user_id=current_user.id) if not deleted: raise HTTPException(status_code=404, detail="History entry not found.") return {"id": entry_id, "status": "deleted"}