diff --git a/code/config/config.py b/code/config/config.py index d4779830b..974e4da0a 100644 --- a/code/config/config.py +++ b/code/config/config.py @@ -65,6 +65,7 @@ class StaticConfig: class ServerConfig: host: str = "localhost" enable_cors: bool = True + cors_trusted_origins: List[str] = field(default_factory=list) max_connections: int = 100 timeout: int = 30 ssl: Optional[SSLConfig] = None @@ -315,6 +316,7 @@ def load_webserver_config(self, path: str = "config_webserver.yaml"): self.server = ServerConfig( host=self._get_config_value(server_data.get("host"), "localhost"), enable_cors=self._get_config_value(server_data.get("enable_cors"), True), + cors_trusted_origins=self._get_config_value(server_data.get("cors_trusted_origins"), ['*']), max_connections=self._get_config_value(server_data.get("max_connections"), 100), timeout=self._get_config_value(server_data.get("timeout"), 30), ssl=ssl_config, diff --git a/code/config/config_webserver.yaml b/code/config/config_webserver.yaml index 74075e0ca..2db1edb5c 100644 --- a/code/config/config_webserver.yaml +++ b/code/config/config_webserver.yaml @@ -5,13 +5,16 @@ static_directory: ../../ # if development, various config params can be overridden with query params # obviously, this cannot be allowed in production. # in testing mode, exceptions are raised instead of being caught for better error visibility -mode: development # or production or testing. +mode: development # or production or testing. # Additional optional configurations server: host: 0.0.0.0 enable_cors: true + # List of trusted origins. Set "*" for wildcard. + cors_trusted_origins: + - '*' max_connections: 100 timeout: 30 # seconds diff --git a/code/webserver/WebServer.py b/code/webserver/WebServer.py index 29a280a02..1cd3b5b71 100644 --- a/code/webserver/WebServer.py +++ b/code/webserver/WebServer.py @@ -129,7 +129,16 @@ async def send_response(status_code, response_headers, end_response=False): # Add CORS headers if enabled if CONFIG.server.enable_cors and 'origin' in headers: - response_headers['Access-Control-Allow-Origin'] = '*' + + if CONFIG.server.cors_trusted_origins: + # If the origin header matches one of the defined origins in server.cors_trusted_origins + origin = headers.get('origin', '') + if origin in CONFIG.server.cors_trusted_origins: + response_headers['Access-Control-Allow-Origin'] = origin + # If the wildcard is set we use the wildcard anyways + if '*' in CONFIG.server.cors_trusted_origins: + response_headers['Access-Control-Allow-Origin'] = '*' + response_headers['Access-Control-Allow-Methods'] = 'GET, POST, OPTIONS' response_headers['Access-Control-Allow-Headers'] = 'Content-Type'