From adf63b33743557fcc5f17fe3623354e93597f895 Mon Sep 17 00:00:00 2001 From: Christian Schmitt Date: Tue, 3 Jun 2025 07:22:18 +0200 Subject: [PATCH 1/2] # Improve CORS configuration to only allow trusted origins feat: improve security by enabling CORS with trusted origins git diff -U0 --stat code/webserver/WebServer.py config/config_webserver.yaml # Conflicts: # code/config/config_webserver.yaml # code/webserver/WebServer.py --- code/config/config_webserver.yaml | 3 ++- code/webserver/WebServer.py | 20 ++++++++++++++------ 2 files changed, 16 insertions(+), 7 deletions(-) diff --git a/code/config/config_webserver.yaml b/code/config/config_webserver.yaml index 74075e0ca..94c6303b9 100644 --- a/code/config/config_webserver.yaml +++ b/code/config/config_webserver.yaml @@ -5,13 +5,14 @@ static_directory: ../../ # if development, various config params can be overridden with query params # obviously, this cannot be allowed in production. # in testing mode, exceptions are raised instead of being caught for better error visibility -mode: development # or production or testing. +mode: development # or production or testing. # Additional optional configurations server: host: 0.0.0.0 enable_cors: true + cors_trusted_origins: "*" # Comma-separated list of trusted origins. Set "*" for wildcard. https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin max_connections: 100 timeout: 30 # seconds diff --git a/code/webserver/WebServer.py b/code/webserver/WebServer.py index 224cb2962..55b4cbf54 100644 --- a/code/webserver/WebServer.py +++ b/code/webserver/WebServer.py @@ -106,8 +106,16 @@ async def send_response(status_code, response_headers, end_response=False): writer.write(status_line.encode('utf-8')) # Add CORS headers if enabled - if CONFIG.server.enable_cors and 'Origin' in headers: - response_headers['Access-Control-Allow-Origin'] = '*' + if CONFIG.server.enable_cors and 'origin' in headers: + origin = headers.get('origin', '') + + # If trusted origins list is empty or the origin is in the trusted origins list + if CONFIG.server.cors_trusted_origins: + if origin in CONFIG.server.cors_trusted_origins: + response_headers['Access-Control-Allow-Origin'] = origin + if '*' in CONFIG.server.cors_trusted_origins: + response_headers['Access-Control-Allow-Origin'] = '*' + response_headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS' response_headers['Access-Control-Allow-Headers'] = 'Content-Type' @@ -427,19 +435,19 @@ def get_port(): if __name__ == "__main__": import argparse - + # Parse command line arguments parser = argparse.ArgumentParser(description="NLWeb Server") - parser.add_argument('--mode', choices=['development', 'production', 'testing'], + parser.add_argument('--mode', choices=['development', 'production', 'testing'], help='Override the application mode from config') parser.add_argument('command', nargs='?', help='Optional command (e.g., https)') args = parser.parse_args() - + # Override mode if specified if args.mode: CONFIG.set_mode(args.mode) print(f"Mode overridden to: {args.mode}") - + try: port = get_port() From c55e858a5657ab433ff6bb6f7b5bb70210b2cc35 Mon Sep 17 00:00:00 2001 From: Christian Schmitt Date: Tue, 10 Jun 2025 06:39:44 +0200 Subject: [PATCH 2/2] fix: enhance CORS configuration to only allow trusted origins fix fix: enhance CORS configuration to only allow trusted origins fix: enhance CORS configuration to only allow trusted origins --- code/config/config.py | 2 ++ code/config/config_webserver.yaml | 4 +++- code/webserver/WebServer.py | 26 +++++++++++++++----------- 3 files changed, 20 insertions(+), 12 deletions(-) diff --git a/code/config/config.py b/code/config/config.py index c7cdf5b15..d4f189a09 100644 --- a/code/config/config.py +++ b/code/config/config.py @@ -64,6 +64,7 @@ class StaticConfig: class ServerConfig: host: str = "localhost" enable_cors: bool = True + cors_trusted_origins: List[str] = field(default_factory=list) max_connections: int = 100 timeout: int = 30 ssl: Optional[SSLConfig] = None @@ -305,6 +306,7 @@ def load_webserver_config(self, path: str = "config_webserver.yaml"): self.server = ServerConfig( host=self._get_config_value(server_data.get("host"), "localhost"), enable_cors=self._get_config_value(server_data.get("enable_cors"), True), + cors_trusted_origins=self._get_config_value(server_data.get("cors_trusted_origins"), ['*']), max_connections=self._get_config_value(server_data.get("max_connections"), 100), timeout=self._get_config_value(server_data.get("timeout"), 30), ssl=ssl_config, diff --git a/code/config/config_webserver.yaml b/code/config/config_webserver.yaml index 94c6303b9..2db1edb5c 100644 --- a/code/config/config_webserver.yaml +++ b/code/config/config_webserver.yaml @@ -12,7 +12,9 @@ mode: development # or production or testing. server: host: 0.0.0.0 enable_cors: true - cors_trusted_origins: "*" # Comma-separated list of trusted origins. Set "*" for wildcard. https://developer.mozilla.org/en-US/docs/Web/HTTP/Reference/Headers/Access-Control-Allow-Origin + # List of trusted origins. Set "*" for wildcard. + cors_trusted_origins: + - '*' max_connections: 100 timeout: 30 # seconds diff --git a/code/webserver/WebServer.py b/code/webserver/WebServer.py index 7bfd763b6..add9b3b24 100644 --- a/code/webserver/WebServer.py +++ b/code/webserver/WebServer.py @@ -40,10 +40,10 @@ async def handle_client(reader, writer, fulfill_request): if not request_line: connection_alive = False return - + # Debug logging to see what we're receiving logger.debug(f"[{request_id}] Raw request bytes: {request_line[:100]}") - + try: request_line = request_line.decode('utf-8', errors='replace').rstrip('\r\n') except Exception as decode_error: @@ -71,13 +71,13 @@ async def handle_client(reader, writer, fulfill_request): header_line = await reader.readline() if not header_line or header_line == b'\r\n': break - + try: hdr = header_line.decode('utf-8', errors='replace').rstrip('\r\n') except Exception as decode_error: logger.error(f"[{request_id}] Failed to decode header: {decode_error}, raw bytes: {header_line[:100]}") continue - + if ":" not in hdr: continue name, value = hdr.split(":", 1) @@ -129,15 +129,19 @@ async def send_response(status_code, response_headers, end_response=False): # Add CORS headers if enabled if CONFIG.server.enable_cors and 'origin' in headers: - origin = headers.get('origin', '') - # If trusted origins list is empty or the origin is in the trusted origins list if CONFIG.server.cors_trusted_origins: + # If the origin header matches one of the defined origins in server.cors_trusted_origins + origin = headers.get('origin', '') if origin in CONFIG.server.cors_trusted_origins: response_headers['Access-Control-Allow-Origin'] = origin + # If the wildcard is set we use the wildcard anyways if '*' in CONFIG.server.cors_trusted_origins: response_headers['Access-Control-Allow-Origin'] = '*' + response_headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS' + response_headers['Access-Control-Allow-Headers'] = 'Content-Type' + # Send headers for header_name, header_value in response_headers.items(): header_line = f"{header_name}: {header_value}\r\n" @@ -376,16 +380,16 @@ async def fulfill_request(method, path, headers, query_params, body, send_respon try: # Create a retriever client retriever = get_vector_db_client(query_params=query_params) - + # Get the list of sites sites = await retriever.get_sites() - + # Prepare the response with message-type response_data = { "message-type": "sites", "sites": sites } - + if streaming: # Set proper headers for server-sent events (SSE) response_headers = { @@ -394,10 +398,10 @@ async def fulfill_request(method, path, headers, query_params, body, send_respon 'Connection': 'keep-alive', 'X-Accel-Buffering': 'no' # Disable proxy buffering } - + # Send SSE headers await send_response(200, response_headers) - + # Send the sites data as an SSE event await send_chunk(f"data: {json.dumps(response_data)}\n\n", end_response=True) else: