@@ -17,6 +17,7 @@ use std::net::SocketAddr;
1717use std:: sync:: Arc ;
1818use storage:: create_storage;
1919use tower_http:: cors:: { Any , CorsLayer } ;
20+ use tracing:: warn;
2021use tracing_subscriber:: { layer:: SubscriberExt , util:: SubscriberInitExt } ;
2122
2223use crate :: auth:: decode_key;
@@ -94,12 +95,7 @@ async fn main() -> anyhow::Result<()> {
9495 state. clone ( ) ,
9596 add_public_key_to_state,
9697 ) )
97- . layer (
98- CorsLayer :: new ( )
99- . allow_origin ( Any )
100- . allow_methods ( Any )
101- . allow_headers ( Any ) ,
102- )
98+ . layer ( build_cors_layer ( & config) )
10399 . with_state ( state) ;
104100
105101 // Start server
@@ -130,3 +126,43 @@ async fn add_public_key_to_state(
130126
131127 next. run ( request) . await
132128}
129+
130+ /// Build CORS layer based on configuration
131+ /// If CORS_ALLOWED_ORIGINS is set, use those specific origins
132+ /// Otherwise, allow all origins (for development)
133+ fn build_cors_layer ( config : & Config ) -> CorsLayer {
134+ if let Some ( ref allowed_origins) = config. cors_allowed_origins {
135+ // Parse comma-separated list of origins
136+ let origins: Vec < & str > = allowed_origins. split ( ',' ) . map ( |s| s. trim ( ) ) . collect ( ) ;
137+
138+ if origins. is_empty ( ) || ( origins. len ( ) == 1 && origins[ 0 ] == "*" ) {
139+ // Allow all origins
140+ tracing:: warn!( "CORS configured to allow all origins - this should not be used in production!" ) ;
141+ CorsLayer :: new ( )
142+ . allow_origin ( Any )
143+ . allow_methods ( Any )
144+ . allow_headers ( Any )
145+ } else {
146+ // Allow specific origins
147+ tracing:: info!( "CORS configured to allow specific origins: {:?}" , origins) ;
148+ let mut cors = CorsLayer :: new ( ) ;
149+
150+ for origin in origins {
151+ cors = cors. allow_origin ( origin. parse :: < axum:: http:: HeaderValue > ( ) . unwrap_or_else ( |_| {
152+ tracing:: warn!( "Invalid origin '{}', skipping" , origin) ;
153+ axum:: http:: HeaderValue :: from_static ( "" )
154+ } ) ) ;
155+ }
156+
157+ cors. allow_methods ( Any )
158+ . allow_headers ( Any )
159+ }
160+ } else {
161+ // Default: allow all origins (development mode)
162+ tracing:: warn!( "CORS_ALLOWED_ORIGINS not set, allowing all origins - configure this for production!" ) ;
163+ CorsLayer :: new ( )
164+ . allow_origin ( Any )
165+ . allow_methods ( Any )
166+ . allow_headers ( Any )
167+ }
168+ }
0 commit comments