@@ -35,14 +35,17 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
3535
3636 tracing:: debug!( "Routing request for hostname: {host}, path: {path}" ) ;
3737
38+ // Parse query parameters
39+ let query_params = parse_query_params ( path) ;
40+
3841 // Check if this is a WebSocket upgrade request
3942 let is_websocket = headers
4043 . get ( "upgrade" )
4144 . and_then ( |v| v. to_str ( ) . ok ( ) )
4245 . map ( |v| v. eq_ignore_ascii_case ( "websocket" ) )
4346 . unwrap_or ( false ) ;
4447
45- // Extract target from WebSocket protocol or HTTP header
48+ // Extract target from WebSocket protocol, HTTP header, or query param
4649 let target = if is_websocket {
4750 // For WebSocket, parse the sec-websocket-protocol header
4851 headers
@@ -55,15 +58,21 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
5558 . map ( |p| p. trim ( ) )
5659 . find_map ( |p| p. strip_prefix ( WS_PROTOCOL_TARGET ) )
5760 } )
61+ // Fallback to query parameter if protocol not provided
62+ . or_else ( || query_params. get ( "x_rivet_target" ) . map ( |s| s. as_str ( ) ) )
5863 } else {
59- // For HTTP, use the x-rivet-target header
60- headers. get ( X_RIVET_TARGET ) . and_then ( |x| x. to_str ( ) . ok ( ) )
64+ // For HTTP, use the x-rivet-target header, fallback to query param
65+ headers
66+ . get ( X_RIVET_TARGET )
67+ . and_then ( |x| x. to_str ( ) . ok ( ) )
68+ . or_else ( || query_params. get ( "x_rivet_target" ) . map ( |s| s. as_str ( ) ) )
6169 } ;
6270
6371 // Read target
6472 if let Some ( target) = target {
6573 if let Some ( routing_output) =
66- runner:: route_request ( & ctx, target, host, path, headers) . await ?
74+ runner:: route_request ( & ctx, target, host, path, headers, & query_params)
75+ . await ?
6776 {
6877 return Ok ( routing_output) ;
6978 }
@@ -76,6 +85,7 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
7685 path,
7786 headers,
7887 is_websocket,
88+ & query_params,
7989 )
8090 . await ?
8191 {
@@ -109,3 +119,19 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
109119 } ,
110120 )
111121}
122+
123+ /// Parse query parameters from a path string
124+ fn parse_query_params ( path : & str ) -> std:: collections:: HashMap < String , String > {
125+ let mut params = std:: collections:: HashMap :: new ( ) ;
126+
127+ if let Some ( query_start) = path. find ( '?' ) {
128+ // Strip fragment if present
129+ let query = & path[ query_start + 1 ..] . split ( '#' ) . next ( ) . unwrap_or ( "" ) ;
130+ // Use url::form_urlencoded to properly decode query parameters
131+ for ( key, value) in url:: form_urlencoded:: parse ( query. as_bytes ( ) ) {
132+ params. insert ( key. into_owned ( ) , value. into_owned ( ) ) ;
133+ }
134+ }
135+
136+ params
137+ }
0 commit comments