Skip to content

Commit 36c0374

Browse files
committed
cleanup oidc tests
Refactor OIDC login simulation by extracting query parameter logic into a separate function and renaming the callback function for clarity. This improves code readability and maintainability.
1 parent 01102d7 commit 36c0374

File tree

1 file changed

+48
-96
lines changed

1 file changed

+48
-96
lines changed

tests/oidc/mod.rs

Lines changed: 48 additions & 96 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,16 @@ use actix_web::{
44
body::MessageBody,
55
cookie::Cookie,
66
dev::ServiceResponse,
7-
http::header,
7+
http::{header, StatusCode},
8+
test,
89
web::{self, Data},
910
App, HttpResponse, HttpServer, Responder,
1011
};
1112
use base64::Engine;
13+
use openidconnect::url::Url;
1214
use serde::{Deserialize, Serialize};
1315
use serde_json::json;
14-
use sqlpage::AppState;
16+
use sqlpage::{webserver::http::create_app, AppState};
1517
use std::collections::HashMap;
1618
use std::sync::{Arc, Mutex};
1719
use tokio::sync::oneshot;
@@ -273,6 +275,14 @@ where
273275
resp
274276
}
275277

278+
fn get_query_param(url: &Url, name: &str) -> String {
279+
url.query_pairs()
280+
.find(|(k, _)| k == name)
281+
.unwrap()
282+
.1
283+
.to_string()
284+
}
285+
276286
async fn setup_oidc_test_state(
277287
provider_mutator: impl FnOnce(&mut ProviderState),
278288
) -> (Data<AppState>, FakeOidcProvider) {
@@ -311,63 +321,61 @@ async fn setup_oidc_test_state(
311321
(app_data, provider)
312322
}
313323

314-
async fn simulate_oidc_login<S, B>(
324+
async fn perform_oidc_callback<S, B>(
315325
app: &S,
316326
provider: &FakeOidcProvider,
317327
protected_path: &str,
318-
) -> ServiceResponse<B>
328+
state_override: Option<String>,
329+
) -> (ServiceResponse<B>, Vec<Cookie<'static>>)
319330
where
320331
S: Service<Request, Response = ServiceResponse<B>, Error = actix_web::Error>,
321332
B: MessageBody,
322333
{
323-
use actix_web::{http::StatusCode, test};
324-
use openidconnect::url::Url;
325-
326334
let mut cookies: Vec<Cookie<'static>> = Vec::new();
327335

328-
// 1. Request protected page
329336
let req = test::TestRequest::get().uri(protected_path);
330337
let resp = make_request_with_session(app, req, &mut cookies).await;
331338
assert_eq!(resp.status(), StatusCode::SEE_OTHER);
332339

333-
// 2. Extract info from redirect to OIDC provider
334340
let location = resp.headers().get("location").unwrap().to_str().unwrap();
335341
let auth_url = Url::parse(location).unwrap();
336342

337-
let state = auth_url
338-
.query_pairs()
339-
.find(|(k, _)| k == "state")
340-
.unwrap()
341-
.1
342-
.to_string();
343-
let nonce = auth_url
344-
.query_pairs()
345-
.find(|(k, _)| k == "nonce")
346-
.unwrap()
347-
.1
348-
.to_string();
349-
let redirect_uri = auth_url
350-
.query_pairs()
351-
.find(|(k, _)| k == "redirect_uri")
352-
.unwrap()
353-
.1
354-
.to_string();
343+
let state = get_query_param(&auth_url, "state");
344+
let nonce = get_query_param(&auth_url, "nonce");
345+
let redirect_uri = get_query_param(&auth_url, "redirect_uri");
355346

356-
// 3. Simulate user login at provider
357347
provider.store_auth_code("test_auth_code".to_string(), nonce);
358348

359-
// 4. Request the callback URL
360-
let callback_url = format!("{}?code=test_auth_code&state={}", redirect_uri, state);
349+
let callback_state = state_override.unwrap_or(state);
350+
let callback_url = format!(
351+
"{}?code=test_auth_code&state={}",
352+
redirect_uri, callback_state
353+
);
361354
let parsed_callback_url = Url::parse(&callback_url).unwrap();
362355
let callback_req = test::TestRequest::get().uri(&format!(
363356
"{}?{}",
364357
parsed_callback_url.path(),
365358
parsed_callback_url.query().unwrap_or_default()
366359
));
367360
let callback_resp = make_request_with_session(app, callback_req, &mut cookies).await;
361+
(callback_resp, cookies)
362+
}
363+
364+
async fn simulate_oidc_login<S, B>(
365+
app: &S,
366+
provider: &FakeOidcProvider,
367+
protected_path: &str,
368+
) -> ServiceResponse<B>
369+
where
370+
S: Service<Request, Response = ServiceResponse<B>, Error = actix_web::Error>,
371+
B: MessageBody,
372+
{
373+
use actix_web::{http::StatusCode, test};
374+
375+
let (callback_resp, mut cookies) =
376+
perform_oidc_callback(app, provider, protected_path, None).await;
368377
assert_eq!(callback_resp.status(), StatusCode::SEE_OTHER);
369378

370-
// 5. Follow the final redirect back to the protected page
371379
let final_location = callback_resp
372380
.headers()
373381
.get("location")
@@ -404,9 +412,6 @@ async fn test_fake_provider_discovery() {
404412

405413
#[actix_web::test]
406414
async fn test_oidc_happy_path() {
407-
use actix_web::http::StatusCode;
408-
use sqlpage::webserver::http::create_app;
409-
410415
let (app_data, provider) = setup_oidc_test_state(|_| {}).await;
411416
let app = actix_web::test::init_service(create_app(app_data.clone())).await;
412417
let final_resp = simulate_oidc_login(&app, &provider, "/").await;
@@ -417,61 +422,12 @@ async fn assert_oidc_login_fails(
417422
provider_mutator: impl FnOnce(&mut ProviderState),
418423
state_override: Option<String>,
419424
) {
420-
use actix_web::{http::StatusCode, test};
421-
use openidconnect::url::Url;
422-
use sqlpage::webserver::http::create_app;
423-
424425
let (app_data, provider) = setup_oidc_test_state(provider_mutator).await;
425426
let app = actix_web::test::init_service(create_app(app_data.clone())).await;
426427

427-
let mut cookies: Vec<Cookie<'static>> = Vec::new();
428-
429-
// 1. Request protected page
430-
let req = test::TestRequest::get().uri("/");
431-
let resp = make_request_with_session(&app, req, &mut cookies).await;
432-
assert_eq!(resp.status(), StatusCode::SEE_OTHER);
433-
434-
// 2. Extract info
435-
let location = resp.headers().get("location").unwrap().to_str().unwrap();
436-
let auth_url = Url::parse(location).unwrap();
437-
let state = auth_url
438-
.query_pairs()
439-
.find(|(k, _)| k == "state")
440-
.unwrap()
441-
.1
442-
.to_string();
443-
let nonce = auth_url
444-
.query_pairs()
445-
.find(|(k, _)| k == "nonce")
446-
.unwrap()
447-
.1
448-
.to_string();
449-
let redirect_uri = auth_url
450-
.query_pairs()
451-
.find(|(k, _)| k == "redirect_uri")
452-
.unwrap()
453-
.1
454-
.to_string();
455-
456-
provider.store_auth_code("test_auth_code".to_string(), nonce);
457-
458-
// 3. Request callback URL (with potential state override)
459-
let callback_state = state_override.unwrap_or(state);
460-
let callback_url = format!(
461-
"{}?code=test_auth_code&state={}",
462-
redirect_uri, callback_state
463-
);
464-
465-
let parsed_callback_url = Url::parse(&callback_url).unwrap();
466-
let callback_req = test::TestRequest::get().uri(&format!(
467-
"{}?{}",
468-
parsed_callback_url.path(),
469-
parsed_callback_url.query().unwrap_or_default()
470-
));
428+
let (callback_resp, cookies) =
429+
perform_oidc_callback(&app, &provider, "/", state_override).await;
471430

472-
let callback_resp = make_request_with_session(&app, callback_req, &mut cookies).await;
473-
474-
// 4. Assert failure
475431
assert_eq!(callback_resp.status(), StatusCode::SEE_OTHER);
476432
let location = callback_resp
477433
.headers()
@@ -507,16 +463,6 @@ async fn assert_oidc_callback_fails_with_bad_jwt(
507463
.await;
508464
}
509465

510-
async fn assert_oidc_callback_fails_with_bad_signature() {
511-
assert_oidc_login_fails(
512-
|state| {
513-
state.jwt_customizer = Some(Box::new(|claims, _| make_jwt(&claims, "wrong_secret")));
514-
},
515-
None,
516-
)
517-
.await;
518-
}
519-
520466
#[actix_web::test]
521467
async fn test_oidc_csrf_state_mismatch_is_rejected() {
522468
assert_oidc_login_fails(|_| {}, Some("wrong_state".to_string())).await;
@@ -532,7 +478,13 @@ async fn test_oidc_nonce_mismatch_is_rejected() {
532478

533479
#[actix_web::test]
534480
async fn test_oidc_bad_signature_is_rejected() {
535-
assert_oidc_callback_fails_with_bad_signature().await;
481+
assert_oidc_login_fails(
482+
|state| {
483+
state.jwt_customizer = Some(Box::new(|claims, _| make_jwt(&claims, "wrong_secret")));
484+
},
485+
None,
486+
)
487+
.await;
536488
}
537489

538490
#[actix_web::test]

0 commit comments

Comments
 (0)