Skip to content

Commit 61ba55a

Browse files
authored
feat(guard): support routing via x_rivet_* query params (#3238)
1 parent df52371 commit 61ba55a

File tree

3 files changed

+50
-11
lines changed

3 files changed

+50
-11
lines changed

packages/core/guard/server/src/routing/mod.rs

Lines changed: 30 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -35,14 +35,17 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
3535

3636
tracing::debug!("Routing request for hostname: {host}, path: {path}");
3737

38+
// Parse query parameters
39+
let query_params = parse_query_params(path);
40+
3841
// Check if this is a WebSocket upgrade request
3942
let is_websocket = headers
4043
.get("upgrade")
4144
.and_then(|v| v.to_str().ok())
4245
.map(|v| v.eq_ignore_ascii_case("websocket"))
4346
.unwrap_or(false);
4447

45-
// Extract target from WebSocket protocol or HTTP header
48+
// Extract target from WebSocket protocol, HTTP header, or query param
4649
let target = if is_websocket {
4750
// For WebSocket, parse the sec-websocket-protocol header
4851
headers
@@ -55,15 +58,21 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
5558
.map(|p| p.trim())
5659
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TARGET))
5760
})
61+
// Fallback to query parameter if protocol not provided
62+
.or_else(|| query_params.get("x_rivet_target").map(|s| s.as_str()))
5863
} else {
59-
// For HTTP, use the x-rivet-target header
60-
headers.get(X_RIVET_TARGET).and_then(|x| x.to_str().ok())
64+
// For HTTP, use the x-rivet-target header, fallback to query param
65+
headers
66+
.get(X_RIVET_TARGET)
67+
.and_then(|x| x.to_str().ok())
68+
.or_else(|| query_params.get("x_rivet_target").map(|s| s.as_str()))
6169
};
6270

6371
// Read target
6472
if let Some(target) = target {
6573
if let Some(routing_output) =
66-
runner::route_request(&ctx, target, host, path, headers).await?
74+
runner::route_request(&ctx, target, host, path, headers, &query_params)
75+
.await?
6776
{
6877
return Ok(routing_output);
6978
}
@@ -76,6 +85,7 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
7685
path,
7786
headers,
7887
is_websocket,
88+
&query_params,
7989
)
8090
.await?
8191
{
@@ -109,3 +119,19 @@ pub fn create_routing_function(ctx: StandaloneCtx, shared_state: SharedState) ->
109119
},
110120
)
111121
}
122+
123+
/// Parse query parameters from a path string
124+
fn parse_query_params(path: &str) -> std::collections::HashMap<String, String> {
125+
let mut params = std::collections::HashMap::new();
126+
127+
if let Some(query_start) = path.find('?') {
128+
// Strip fragment if present
129+
let query = &path[query_start + 1..].split('#').next().unwrap_or("");
130+
// Use url::form_urlencoded to properly decode query parameters
131+
for (key, value) in url::form_urlencoded::parse(query.as_bytes()) {
132+
params.insert(key.into_owned(), value.into_owned());
133+
}
134+
}
135+
136+
params
137+
}

packages/core/guard/server/src/routing/pegboard_gateway.rs

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,13 +22,14 @@ pub async fn route_request(
2222
path: &str,
2323
headers: &hyper::HeaderMap,
2424
is_websocket: bool,
25+
query_params: &std::collections::HashMap<String, String>,
2526
) -> Result<Option<RoutingOutput>> {
2627
// Check target
2728
if target != "actor" {
2829
return Ok(None);
2930
}
3031

31-
// Extract actor ID from WebSocket protocol or HTTP header
32+
// Extract actor ID from WebSocket protocol, HTTP header, or query param
3233
let actor_id_str = if is_websocket {
3334
// For WebSocket, parse the sec-websocket-protocol header
3435
headers
@@ -41,22 +42,26 @@ pub async fn route_request(
4142
.map(|p| p.trim())
4243
.find_map(|p| p.strip_prefix(WS_PROTOCOL_ACTOR))
4344
})
45+
// Fallback to query parameter if protocol not provided
46+
.or_else(|| query_params.get("x_rivet_actor").map(|s| s.as_str()))
4447
.ok_or_else(|| {
4548
crate::errors::MissingHeader {
46-
header: "`rivet_actor.*` protocol in sec-websocket-protocol".to_string(),
49+
header: "`rivet_actor.*` protocol in sec-websocket-protocol or x_rivet_actor query parameter".to_string(),
4750
}
4851
.build()
4952
})?
5053
} else {
51-
// For HTTP, use the x-rivet-actor header
54+
// For HTTP, use the x-rivet-actor header, fallback to query param
5255
headers
5356
.get(X_RIVET_ACTOR)
5457
.map(|x| x.to_str())
5558
.transpose()
5659
.context("invalid x-rivet-actor header")?
60+
// Fallback to query parameter if header not provided
61+
.or_else(|| query_params.get("x_rivet_actor").map(|s| s.as_str()))
5762
.ok_or_else(|| {
5863
crate::errors::MissingHeader {
59-
header: X_RIVET_ACTOR.to_string(),
64+
header: format!("{} header or x_rivet_actor query parameter", X_RIVET_ACTOR),
6065
}
6166
.build()
6267
})?

packages/core/guard/server/src/routing/runner.rs

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ pub async fn route_request(
1414
host: &str,
1515
path: &str,
1616
headers: &hyper::HeaderMap,
17+
query_params: &std::collections::HashMap<String, String>,
1718
) -> Result<Option<RoutingOutput>> {
1819
if target != "runner" {
1920
return Ok(None);
@@ -57,7 +58,7 @@ pub async fn route_request(
5758

5859
// Check auth (if enabled)
5960
if let Some(auth) = &ctx.config().auth {
60-
// Extract token
61+
// Extract token from protocol, header, or query param
6162
let token = if is_websocket {
6263
headers
6364
.get(SEC_WEBSOCKET_PROTOCOL)
@@ -68,19 +69,26 @@ pub async fn route_request(
6869
.map(|p| p.trim())
6970
.find_map(|p| p.strip_prefix(WS_PROTOCOL_TOKEN))
7071
})
72+
// Fallback to query parameter if protocol not provided
73+
.or_else(|| query_params.get("x_rivet_token").map(|s| s.as_str()))
7174
.ok_or_else(|| {
7275
crate::errors::MissingHeader {
73-
header: "`rivet_token.*` protocol in sec-websocket-protocol".to_string(),
76+
header: "`rivet_token.*` protocol in sec-websocket-protocol or x_rivet_token query parameter".to_string(),
7477
}
7578
.build()
7679
})?
7780
} else {
7881
headers
7982
.get(X_RIVET_TOKEN)
8083
.and_then(|x| x.to_str().ok())
84+
// Fallback to query parameter if header not provided
85+
.or_else(|| query_params.get("x_rivet_token").map(|s| s.as_str()))
8186
.ok_or_else(|| {
8287
crate::errors::MissingHeader {
83-
header: X_RIVET_TOKEN.to_string(),
88+
header: format!(
89+
"{} header or x_rivet_token query parameter",
90+
X_RIVET_TOKEN
91+
),
8492
}
8593
.build()
8694
})?

0 commit comments

Comments
 (0)