From 2689375ec0dae9c180ed113bff24efcba041f52a Mon Sep 17 00:00:00 2001 From: wauputr4 <103489788+wauputr4@users.noreply.github.com> Date: Sun, 10 May 2026 16:12:59 +0700 Subject: [PATCH 1/5] feat(api): add provider and model route admin endpoints --- crates/mizan-api/src/lib.rs | 32 +- crates/mizan-api/src/providers.rs | 659 ++++++++++++++++++++++++++++++ 2 files changed, 690 insertions(+), 1 deletion(-) create mode 100644 crates/mizan-api/src/providers.rs diff --git a/crates/mizan-api/src/lib.rs b/crates/mizan-api/src/lib.rs index 322cc02..0c1b355 100644 --- a/crates/mizan-api/src/lib.rs +++ b/crates/mizan-api/src/lib.rs @@ -2,7 +2,7 @@ use axum::{ Json, Router, extract::State, http::StatusCode, - middleware::from_fn_with_state, + middleware::{from_fn, from_fn_with_state}, response::IntoResponse, routing::{delete, get, post}, }; @@ -17,6 +17,7 @@ use tower_http::trace::TraceLayer; use tracing::{info, warn}; mod auth; +mod providers; mod storage; #[derive(Clone)] @@ -150,12 +151,41 @@ pub fn router(state: AppState) -> Router { .route("/v1/ping", get(auth::api_key_ping)) .route_layer(from_fn_with_state(state.clone(), auth::api_key_auth)); + let public_models_router = Router::new() + .route("/v1/models", get(providers::list_models)) + .route_layer(from_fn_with_state(state.clone(), auth::api_key_auth)); + + let provider_admin_router = Router::new() + .route( + "/admin/provider-connections", + get(providers::list_provider_connections).post(providers::create_provider_connection), + ) + .route( + "/admin/provider-connections/{id}", + delete(providers::delete_provider_connection), + ) + .route( + "/admin/model-routes", + get(providers::list_model_routes).post(providers::create_model_route), + ) + .route( + "/admin/model-routes/{id}", + delete(providers::delete_model_route), + ) + .route_layer(from_fn(providers::require_admin_role)) // inner layer: reads the identity extension, so it must run after authentication + .route_layer(from_fn_with_state(state.clone(), auth::api_key_auth)); // added last = outermost layer, so authentication runs first + + let provider_router = Router::new() + .merge(public_models_router) 
+ .merge(provider_admin_router); + Router::new() .route("/healthz", get(healthz)) .route("/readyz", get(readyz)) .merge(public_auth_router) .merge(session_router) .merge(api_key_router) + .merge(provider_router) .fallback(not_found) .layer(TraceLayer::new_for_http()) .with_state(state) diff --git a/crates/mizan-api/src/providers.rs b/crates/mizan-api/src/providers.rs new file mode 100644 index 0000000..86f6283 --- /dev/null +++ b/crates/mizan-api/src/providers.rs @@ -0,0 +1,659 @@ +use axum::Json; +use axum::body::Body; +use axum::extract::{Path, State}; +use axum::http::StatusCode; +use axum::middleware::Next; +use axum::response::Response; +use mizan_core::{AppError, AppResult, DatabaseBackend, ErrorEnvelope}; +use serde::{Deserialize, Serialize}; +use sqlx::{query, query_as}; +use uuid::Uuid; + +use crate::AppState; +use crate::auth::ApiKeyIdentity; + +type ProviderHttpResult = Result)>; + +#[derive(Debug, Serialize)] +pub struct ProviderConnectionResponse { + pub id: String, + pub name: String, + pub provider_type: String, + pub base_url: String, + pub enabled: bool, + pub created_at: String, + pub updated_at: String, +} + +#[derive(Debug, Serialize)] +pub struct ProviderConnectionListResponse { + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ProviderConnectionCreateResponse { + pub id: String, + pub name: String, + pub provider_type: String, + pub base_url: String, + pub enabled: bool, +} + +#[derive(Debug, Serialize)] +pub struct ProviderConnectionWithStatus { + pub id: String, + pub removed: bool, +} + +#[derive(Debug, Deserialize)] +pub struct ProviderConnectionCreateRequest { + pub name: String, + pub provider_type: String, + pub base_url: String, + pub api_key_encrypted: String, + pub enabled: Option, +} + +#[derive(Debug, Serialize)] +pub struct ModelRouteResponse { + pub id: String, + pub provider_connection_id: String, + pub public_model: String, + pub upstream_model: String, + pub max_tokens: Option, + pub pricing_input_per_1m_tokens: 
i64, + pub pricing_output_per_1m_tokens: i64, + pub enabled: bool, + pub created_at: String, + pub updated_at: String, + pub provider_name: Option, +} + +#[derive(Debug, Serialize)] +pub struct ModelRouteListResponse { + pub data: Vec, +} + +#[derive(Debug, Serialize)] +pub struct ModelRouteCreateResponse { + pub id: String, + pub provider_connection_id: String, + pub public_model: String, + pub upstream_model: String, + pub enabled: bool, +} + +#[derive(Debug, Serialize)] +pub struct ModelRouteWithStatus { + pub id: String, + pub removed: bool, +} + +#[derive(Debug, Deserialize)] +pub struct ModelRouteCreateRequest { + pub provider_connection_id: Uuid, + pub public_model: String, + pub upstream_model: String, + pub max_tokens: Option, + pub pricing_input_per_1m_tokens: Option, + pub pricing_output_per_1m_tokens: Option, + pub enabled: Option, +} + +#[derive(Debug, Serialize)] +pub struct PublicModelResponse { + pub id: String, + pub object: &'static str, + pub created: i64, + pub owned_by: String, + pub provider_type: String, + pub upstream_model: String, + pub route_id: String, + pub max_tokens: Option, +} + +#[derive(Debug, Serialize)] +pub struct PublicModelsResponse { + pub object: &'static str, + pub data: Vec, +} + +pub async fn require_admin_role( + axum::Extension(identity): axum::Extension, + request: axum::http::Request, + next: Next, +) -> ProviderHttpResult { + if identity.user_role != "admin" { + return Err(( + StatusCode::FORBIDDEN, + Json(ErrorEnvelope::from(&AppError::Forbidden)), + )); + } + + Ok(next.run(request).await) +} + +pub async fn list_models( + State(state): State, +) -> ProviderHttpResult> { + let rows = + query_as::<_, (String, String, String, String, String, Option, String)>(&prepare_sql( + state.database_backend(), + "SELECT mr.public_model, + mr.upstream_model, + mr.id, + pc.name, + pc.provider_type, + mr.max_tokens, + mr.created_at + FROM model_routes mr + INNER JOIN provider_connections pc + ON pc.id = mr.provider_connection_id + 
WHERE mr.enabled = 1 AND pc.enabled = 1 + ORDER BY mr.public_model ASC", + )) + .fetch_all(&state.database) + .await + .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; + + let mut data = Vec::with_capacity(rows.len()); + + for ( + public_model, + upstream_model, + route_id, + provider_name, + provider_type, + max_tokens, + created_at, + ) in rows + { + let created = parse_timestamp(&created_at).map_err(|error| from_app_error(error))?; + + data.push(PublicModelResponse { + id: public_model.clone(), + object: "model", + created, + owned_by: provider_name, + provider_type, + upstream_model, + route_id, + max_tokens, + }); + } + + Ok(Json(PublicModelsResponse { + object: "list", + data, + })) +} + +pub async fn list_provider_connections( + State(state): State, +) -> ProviderHttpResult> { + let rows = query_as::<_, (String, String, String, String, i64, String, String)>(&prepare_sql( + state.database_backend(), + "SELECT id, + name, + provider_type, + base_url, + enabled, + created_at, + updated_at + FROM provider_connections + ORDER BY created_at DESC", + )) + .fetch_all(&state.database) + .await + .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; + + let data = rows + .into_iter() + .map( + |(id, name, provider_type, base_url, enabled, created_at, updated_at)| { + ProviderConnectionResponse { + id, + name, + provider_type, + base_url, + enabled: is_enabled(enabled), + created_at, + updated_at, + } + }, + ) + .collect(); + + Ok(Json(ProviderConnectionListResponse { data })) +} + +pub async fn create_provider_connection( + State(state): State, + Json(payload): Json, +) -> ProviderHttpResult> { + let name = payload.name.trim(); + let provider_type = payload.provider_type.trim(); + let base_url = payload.base_url.trim(); + let secret = payload.api_key_encrypted.trim(); + + if name.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + 
"provider_connection.name", + "provider name is required", + ))), + )); + } + + if provider_type.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "provider_connection.provider_type", + "provider_type is required", + ))), + )); + } + + if base_url.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "provider_connection.base_url", + "base_url is required", + ))), + )); + } + + if secret.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "provider_connection.api_key_encrypted", + "api_key_encrypted is required", + ))), + )); + } + + let id = Uuid::now_v7(); + let now = unix_timestamp_string(); + let enabled = payload.enabled.unwrap_or(true); + + let sql = prepare_sql( + state.database_backend(), + "INSERT INTO provider_connections ( + id, name, provider_type, base_url, api_key_encrypted, enabled, created_at, updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?)", + ); + + query(&sql) + .bind(id.to_string()) + .bind(name) + .bind(provider_type) + .bind(base_url) + .bind(secret) + .bind(if enabled { 1 } else { 0 }) + .bind(&now) + .bind(&now) + .execute(&state.database) + .await + .map_err(|error| { + from_app_error(map_duplicate_name_error( + error.to_string(), + "provider connection", + )) + })?; + + Ok(Json(ProviderConnectionCreateResponse { + id: id.to_string(), + name: name.to_string(), + provider_type: provider_type.to_string(), + base_url: base_url.to_string(), + enabled, + })) +} + +pub async fn delete_provider_connection( + State(state): State, + Path(id): Path, +) -> ProviderHttpResult> { + let removed = query(&prepare_sql( + state.database_backend(), + "DELETE FROM provider_connections WHERE id = ?", + )) + .bind(id.to_string()) + .execute(&state.database) + .await + .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; + + if removed.rows_affected() == 0 { + 
return Err(( + StatusCode::NOT_FOUND, + Json(ErrorEnvelope::from(&AppError::NotFound( + "provider connection not found".to_string(), + ))), + )); + } + + Ok(Json(ProviderConnectionWithStatus { + id: id.to_string(), + removed: true, + })) +} + +pub async fn list_model_routes( + State(state): State, +) -> ProviderHttpResult> { + let rows = query_as::< + _, + ( + String, + String, + String, + String, + Option, + i64, + i64, + i64, + String, + String, + Option, + ), + >(&prepare_sql( + state.database_backend(), + "SELECT mr.id, + mr.provider_connection_id, + mr.public_model, + mr.upstream_model, + mr.max_tokens, + mr.pricing_input_per_1m_tokens, + mr.pricing_output_per_1m_tokens, + mr.enabled, + mr.created_at, + mr.updated_at, + pc.name + FROM model_routes mr + INNER JOIN provider_connections pc + ON pc.id = mr.provider_connection_id + ORDER BY mr.created_at DESC", + )) + .fetch_all(&state.database) + .await + .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; + + let data = rows + .into_iter() + .map( + |( + id, + provider_connection_id, + public_model, + upstream_model, + max_tokens, + pricing_input_per_1m_tokens, + pricing_output_per_1m_tokens, + enabled, + created_at, + updated_at, + provider_name, + )| { + ModelRouteResponse { + id, + provider_connection_id, + public_model, + upstream_model, + max_tokens, + pricing_input_per_1m_tokens, + pricing_output_per_1m_tokens, + enabled: is_enabled(enabled), + created_at, + updated_at, + provider_name, + } + }, + ) + .collect(); + + Ok(Json(ModelRouteListResponse { data })) +} + +pub async fn create_model_route( + State(state): State, + Json(payload): Json, +) -> ProviderHttpResult> { + let public_model = payload.public_model.trim(); + let upstream_model = payload.upstream_model.trim(); + + if public_model.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.public_model", + "public_model is required", + ))), + )); + } + + if 
upstream_model.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.upstream_model", + "upstream_model is required", + ))), + )); + } + + if let Some(max_tokens) = payload.max_tokens { + if max_tokens < 0 { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.max_tokens", + "max_tokens cannot be negative", + ))), + )); + } + } + + let provider_connection_id = payload.provider_connection_id; + + let provider_exists = query_as::<_, (i64,)>(&prepare_sql( + state.database_backend(), + "SELECT 1 FROM provider_connections WHERE id = ?", + )) + .bind(provider_connection_id.to_string()) + .fetch_optional(&state.database) + .await + .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; + + if provider_exists.is_none() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.provider_connection_id", + "provider_connection_id does not exist", + ))), + )); + } + + let id = Uuid::now_v7(); + let now = unix_timestamp_string(); + let enabled = payload.enabled.unwrap_or(true); + + query(&prepare_sql( + state.database_backend(), + "INSERT INTO model_routes ( + id, + provider_connection_id, + public_model, + upstream_model, + max_tokens, + pricing_input_per_1m_tokens, + pricing_output_per_1m_tokens, + enabled, + created_at, + updated_at + ) VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + )) + .bind(id.to_string()) + .bind(provider_connection_id.to_string()) + .bind(public_model) + .bind(upstream_model) + .bind(payload.max_tokens) + .bind(payload.pricing_input_per_1m_tokens.unwrap_or(0)) + .bind(payload.pricing_output_per_1m_tokens.unwrap_or(0)) + .bind(if enabled { 1 } else { 0 }) + .bind(&now) + .bind(&now) + .execute(&state.database) + .await + .map_err(|error| from_app_error(map_duplicate_model_error(error.to_string())))?; + + Ok(Json(ModelRouteCreateResponse { + id: 
id.to_string(), + provider_connection_id: provider_connection_id.to_string(), + public_model: public_model.to_string(), + upstream_model: upstream_model.to_string(), + enabled, + })) +} + +pub async fn delete_model_route( + State(state): State, + Path(id): Path, +) -> ProviderHttpResult> { + let removed = query(&prepare_sql( + state.database_backend(), + "DELETE FROM model_routes WHERE id = ?", + )) + .bind(id.to_string()) + .execute(&state.database) + .await + .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; + + if removed.rows_affected() == 0 { + return Err(( + StatusCode::NOT_FOUND, + Json(ErrorEnvelope::from(&AppError::NotFound( + "model route not found".to_string(), + ))), + )); + } + + Ok(Json(ModelRouteWithStatus { + id: id.to_string(), + removed: true, + })) +} + +fn map_duplicate_name_error(error: String, context: &str) -> AppError { + if is_unique_constraint_error(&error) { + AppError::invalid_config( + "provider_connection.name", + format!("{} with this name already exists", context), + ) + } else { + AppError::infrastructure(error) + } +} + +fn map_duplicate_model_error(error: String) -> AppError { + if is_unique_constraint_error(&error) { + AppError::invalid_config("model_route.public_model", "public_model must be unique") + } else { + AppError::infrastructure(error) + } +} + +fn from_app_error(error: AppError) -> (StatusCode, Json) { + let status = match error { + AppError::InvalidConfig { .. 
} => StatusCode::BAD_REQUEST, + AppError::NotFound(_) => StatusCode::NOT_FOUND, + AppError::Unauthorized => StatusCode::UNAUTHORIZED, + AppError::Forbidden => StatusCode::FORBIDDEN, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + (status, Json(ErrorEnvelope::from(&error))) +} + +fn is_enabled(raw: i64) -> bool { + raw != 0 +} + +fn parse_timestamp(raw: &str) -> AppResult<i64> { + raw.parse::<i64>() + .map_err(|error| AppError::infrastructure(format!("invalid timestamp: {error}"))) +} + +fn prepare_sql(database_backend: DatabaseBackend, query: &'static str) -> String { + match database_backend { + DatabaseBackend::Sqlite => query.to_string(), + DatabaseBackend::Postgres => to_dollar_params(query), + } +} + +fn to_dollar_params(query: &str) -> String { // rewrites every `?` (including any inside SQL string literals) as `$1`, `$2`, ... in order of appearance + let mut parameter_index = 0usize; + let mut converted = String::with_capacity(query.len()); + + for character in query.chars() { + if character == '?' { + parameter_index += 1; + converted.push('$'); + converted.push_str(&parameter_index.to_string()); + continue; + } + + converted.push(character); + } + + converted +} + +fn is_unique_constraint_error(message: &str) -> bool { + let normalized = message.to_lowercase(); + normalized.contains("unique") + && (normalized.contains("constraint") || normalized.contains("already exists")) +} + +fn unix_timestamp_string() -> String { + now_utc_epoch_seconds().to_string() +} + +fn now_utc_epoch_seconds() -> i64 { + use std::time::{SystemTime, UNIX_EPOCH}; + + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time") + .as_secs() as i64 +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn prepare_sql_keeps_question_marks_for_sqlite() { + let prepared = prepare_sql(DatabaseBackend::Sqlite, "SELECT * FROM x WHERE id = ?"); + assert_eq!(prepared, "SELECT * FROM x WHERE id = ?"); + } + + #[test] + fn prepare_sql_converts_question_marks_for_postgres() { + let prepared = prepare_sql( + DatabaseBackend::Postgres, + "SELECT * FROM x WHERE a = ? 
AND b = ?", + ); + assert_eq!(prepared, "SELECT * FROM x WHERE a = $1 AND b = $2"); + } + + #[test] + fn unix_timestamp_string_is_numeric() { + let timestamp = unix_timestamp_string(); + assert!(timestamp.parse::().is_ok()); + } +} From c431916c4b566b43c8175e342ea18113f3efdc9f Mon Sep 17 00:00:00 2001 From: wauputr4 <103489788+wauputr4@users.noreply.github.com> Date: Sun, 10 May 2026 16:19:04 +0700 Subject: [PATCH 2/5] feat(api): add chat completions gateway route --- Cargo.lock | 1 + crates/mizan-api/Cargo.toml | 1 + crates/mizan-api/src/gateway.rs | 246 ++++++++++++++++++++++++++++++++ crates/mizan-api/src/lib.rs | 2 + 4 files changed, 250 insertions(+) create mode 100644 crates/mizan-api/src/gateway.rs diff --git a/Cargo.lock b/Cargo.lock index 0c3fe5c..d9d4396 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -937,6 +937,7 @@ dependencies = [ "bcrypt", "mizan-core", "mizan-gateway", + "mizan-providers", "redis", "serde", "sha2", diff --git a/crates/mizan-api/Cargo.toml b/crates/mizan-api/Cargo.toml index 0024024..badd627 100644 --- a/crates/mizan-api/Cargo.toml +++ b/crates/mizan-api/Cargo.toml @@ -11,6 +11,7 @@ axum.workspace = true bcrypt.workspace = true mizan-core = { path = "../mizan-core" } mizan-gateway = { path = "../mizan-gateway" } +mizan-providers = { path = "../mizan-providers" } redis.workspace = true sha2.workspace = true serde.workspace = true diff --git a/crates/mizan-api/src/gateway.rs b/crates/mizan-api/src/gateway.rs new file mode 100644 index 0000000..b5b6cab --- /dev/null +++ b/crates/mizan-api/src/gateway.rs @@ -0,0 +1,246 @@ +use axum::http::StatusCode; +use axum::{Extension, Json, extract::State}; +use mizan_core::{AppError, DatabaseBackend, ErrorEnvelope, RequestContextBuilder}; +use mizan_providers::{ChatMessage, ChatRequest, ChatResponse, OpenAiCompatibleProvider}; +use serde::{Deserialize, Serialize}; +use sqlx::{AnyPool, query_as}; +use uuid::Uuid; + +use crate::AppState; +use crate::auth::ApiKeyIdentity; + +type GatewayHttpResult = 
Result)>; + +#[derive(Debug, Deserialize)] +pub struct ChatCompletionsRequest { + pub model: String, + pub messages: Vec, + #[serde(default)] + pub stream: bool, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsMessage { + pub role: String, + pub content: String, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsChoice { + pub index: usize, + pub message: ChatCompletionsMessage, + pub finish_reason: &'static str, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsUsage { + pub prompt_tokens: u64, + pub completion_tokens: u64, + pub total_tokens: u64, +} + +#[derive(Debug, Serialize)] +pub struct ChatCompletionsResponse { + pub id: String, + pub object: &'static str, + pub created: i64, + pub model: String, + pub choices: Vec, + pub usage: Option, +} + +#[derive(Debug)] +struct ResolvedModelRoute { + id: Uuid, + upstream_model: String, + provider_type: String, +} + +pub async fn chat_completions( + State(state): State, + Extension(identity): Extension, + Json(payload): Json, +) -> GatewayHttpResult> { + let public_model = payload.model.trim(); + if public_model.is_empty() { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "chat_completion.model", + "model is required", + ))), + )); + } + + let route = resolve_model_route(&state.database, state.database_backend(), public_model) + .await + .map_err(from_app_error)?; + + let context = RequestContextBuilder::default() + .user_id(identity.user_id) + .api_key_id(identity.api_key_id) + .provider(route.provider_type.clone()) + .route(public_model.to_string()) + .route_id(route.id) + .model(route.upstream_model.clone()) + .streaming(payload.stream) + .build(); + + let upstream_request = ChatRequest { + model: route.upstream_model.clone(), + messages: payload.messages.clone(), + stream: payload.stream, + }; + + let provider_name = if route.provider_type.eq_ignore_ascii_case("openai") { + "openai" + } else { + "openai-compatible" + }; + let 
provider = OpenAiCompatibleProvider::new(provider_name); + let upstream_response = state + .gateway + .chat_completions(&context, &provider, upstream_request) + .await + .map_err(from_app_error)?; + + Ok(Json(map_to_chat_completion_response( + route.upstream_model, + upstream_response, + ))) +} + +fn map_to_chat_completion_response( + model: String, + upstream: ChatResponse, +) -> ChatCompletionsResponse { + ChatCompletionsResponse { + id: format!("chatcmpl-{}", Uuid::now_v7()), + object: "chat.completion", + created: now_utc_epoch_seconds(), // Unix time in seconds (not milliseconds), matching the OpenAI-compatible `created` field + model, + choices: vec![ChatCompletionsChoice { + index: 0, + message: ChatCompletionsMessage { + role: "assistant".to_string(), + content: upstream.content, + }, + finish_reason: "stop", + }], + usage: upstream.usage.map(|usage| ChatCompletionsUsage { + prompt_tokens: usage.prompt_tokens, + completion_tokens: usage.completion_tokens, + total_tokens: usage.total_tokens, + }), + } +} + +async fn resolve_model_route( + database: &AnyPool, + database_backend: DatabaseBackend, + public_model: &str, +) -> Result { + let resolved = query_as::<_, (String, String, String)>(&prepare_sql( + database_backend, + "SELECT mr.id, + mr.upstream_model, + pc.provider_type + FROM model_routes mr + INNER JOIN provider_connections pc + ON pc.id = mr.provider_connection_id + WHERE mr.public_model = ? AND mr.enabled = 1 AND pc.enabled = 1", + )) + .bind(public_model) + .fetch_optional(database) + .await + .map_err(|error| AppError::infrastructure(error.to_string()))? 
+ .ok_or_else(|| { + AppError::invalid_config("chat_completion.model", "model not found or disabled") + })?; + + let (route_id, upstream_model, provider_type) = resolved; + let id = Uuid::parse_str(&route_id).map_err(|error| { + AppError::infrastructure(format!("stored route id is invalid: {error}")) + })?; + + Ok(ResolvedModelRoute { + id, + upstream_model, + provider_type: provider_type.trim().to_string(), + }) +} + +fn from_app_error(error: AppError) -> (StatusCode, Json) { + let status = match error { + AppError::InvalidConfig { .. } => StatusCode::BAD_REQUEST, + AppError::NotFound(_) => StatusCode::NOT_FOUND, + AppError::Unauthorized => StatusCode::UNAUTHORIZED, + AppError::Forbidden => StatusCode::FORBIDDEN, + AppError::Provider(_) => StatusCode::BAD_GATEWAY, + AppError::LimitExceeded(_) => StatusCode::TOO_MANY_REQUESTS, + AppError::InsufficientCredit => StatusCode::PAYMENT_REQUIRED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + (status, Json(ErrorEnvelope::from(&error))) +} + +fn now_utc_epoch_seconds() -> i64 { + use std::time::{SystemTime, UNIX_EPOCH}; + + SystemTime::now() + .duration_since(UNIX_EPOCH) + .expect("system time moved backwards") + .as_secs() as i64 +} + +fn prepare_sql(database_backend: DatabaseBackend, query: &'static str) -> String { + match database_backend { + DatabaseBackend::Sqlite => query.to_string(), + DatabaseBackend::Postgres => to_dollar_params(query), + } +} + +fn to_dollar_params(query: &str) -> String { + let mut index = 0usize; + let mut converted = String::with_capacity(query.len()); + + for character in query.chars() { + if character == '?' 
{ + index += 1; + converted.push('$'); + converted.push_str(&index.to_string()); + continue; + } + + converted.push(character); + } + + converted +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn map_to_chat_completion_response_uses_model_and_content() { + let model = "openai/gpt-4o-mini".to_string(); + let upstream = ChatResponse { + provider: "openai".to_string(), + model: model.clone(), + content: "pong".to_string(), + usage: Some(mizan_providers::TokenUsage { + prompt_tokens: 7, + completion_tokens: 3, + total_tokens: 10, + estimated: false, + }), + }; + + let response = map_to_chat_completion_response(model.clone(), upstream); + assert_eq!(response.model, model); + assert_eq!(response.choices.len(), 1); + assert_eq!(response.choices[0].message.content, "pong"); + } +} diff --git a/crates/mizan-api/src/lib.rs b/crates/mizan-api/src/lib.rs index 0c1b355..aa1d71b 100644 --- a/crates/mizan-api/src/lib.rs +++ b/crates/mizan-api/src/lib.rs @@ -17,6 +17,7 @@ use tower_http::trace::TraceLayer; use tracing::{info, warn}; mod auth; +mod gateway; mod providers; mod storage; @@ -149,6 +150,7 @@ pub fn router(state: AppState) -> Router { let api_key_router = Router::new() .route("/v1/ping", get(auth::api_key_ping)) + .route("/v1/chat/completions", post(gateway::chat_completions)) .route_layer(from_fn_with_state(state.clone(), auth::api_key_auth)); let public_models_router = Router::new() From 0e83a4dc978cfd8176e06711e2e0b2ec3210161f Mon Sep 17 00:00:00 2001 From: wauputr4 <103489788+wauputr4@users.noreply.github.com> Date: Sun, 10 May 2026 16:35:58 +0700 Subject: [PATCH 3/5] docs: sync milestone progress and API surface with implemented features --- README.md | 9 ++++--- docs/BACKEND_IMPLEMENTATION_PLAN.md | 39 ++++++++++++++++------------- docs/ISSUE_BACKLOG.md | 2 ++ docs/MVP_ROADMAP.md | 19 ++++++++------ docs/PRD.md | 11 +++++--- 5 files changed, 47 insertions(+), 33 deletions(-) diff --git a/README.md b/README.md index 8cfe0a4..9967d93 100644 --- 
a/README.md +++ b/README.md @@ -13,10 +13,11 @@ runtime limit engine before building a large dashboard. ## Status -Mizan is in bootstrap stage. The repository now contains product docs and a -minimal Rust workspace foundation with shared core types, modular crate -boundaries, a health endpoint, Docker Compose, and placeholder provider, gateway, -metering, wallet, limit, and RTK modules. +Mizan is in active bootstrap-to-MVP delivery. Milestone 3 (auth/API keys) and +Milestone 4 (provider/model management + `GET /v1/models`) are implemented. +Milestone 5 has a first `POST /v1/chat/completions` route with non-streaming +flow and model routing in place; streaming, upstream error shaping, and request +trace propagation are still in-progress. ## MVP Scope diff --git a/docs/BACKEND_IMPLEMENTATION_PLAN.md b/docs/BACKEND_IMPLEMENTATION_PLAN.md index 25fc212..856f729 100644 --- a/docs/BACKEND_IMPLEMENTATION_PLAN.md +++ b/docs/BACKEND_IMPLEMENTATION_PLAN.md @@ -111,6 +111,8 @@ Acceptance: ## Milestone 4 - Provider and Route Management +Status: ✅ Implemented in current branch and PR #33 + Crates: - `crates/mizan-providers` @@ -119,22 +121,24 @@ Crates: Tasks: -- Admin provider CRUD. -- Encrypt provider secrets before storage. -- Admin model route CRUD. -- Public model route resolver. -- User-visible `/v1/models`. -- Provider adapters remain isolated from route handlers. -- Model registry lookups stay separate from provider transport details. +- [x] Admin provider CRUD (`/admin/provider-connections`, `/admin/provider-connections/{id}`). +- [ ] Encrypt provider secrets before storage (next milestone task). +- [x] Admin model route CRUD (`/admin/model-routes`, `/admin/model-routes/{id}`). +- [x] Public model route resolver and user-visible `GET /v1/models`. +- [x] Provider adapters remain isolated from route handlers. +- [x] Model registry lookups stay separate from provider transport details. Acceptance: - Admin can add an OpenAI-compatible provider. 
- Admin can map `mizan/smart` to an upstream model. - User can list available models with a virtual key. +- Provider secret encryption remains open and is called out explicitly before phase-6 rollout. ## Milestone 5 - Chat Completions Gateway +Status: ✅ Non-streaming handler implemented, streaming upstream path still pending + Crates: - `crates/mizan-gateway` @@ -142,20 +146,21 @@ Crates: Tasks: -- Implement `POST /v1/chat/completions`. -- Implement non-streaming proxy. -- Implement streaming proxy. -- Normalize upstream errors. -- Attach request id to logs and responses. -- Store request log without raw body by default. -- Keep provider-specific request transforms in `mizan-providers`. -- Keep gateway orchestration separate from metering and wallet writes. +- [x] Implement `POST /v1/chat/completions` route wiring. +- [x] Implement non-streaming handler path and route resolution. +- [ ] Implement streaming proxy. +- [ ] Normalize upstream errors. +- [ ] Attach request id to logs and responses. +- [ ] Store request log without raw body by default. +- [x] Keep provider-specific request transforms in `mizan-providers` contracts. +- [x] Keep gateway orchestration separate from metering and wallet writes. Acceptance: - OpenAI SDK can call the gateway by changing base URL. -- Streaming and non-streaming calls work. -- Upstream failure returns a useful OpenAI-compatible error shape. +- Non-streaming calls work with the current stub response and route resolution. +- Streaming path will be added in next pass. +- Upstream failure error shape alignment is tracked in the next pass. ## Milestone 6 - Usage and Credits diff --git a/docs/ISSUE_BACKLOG.md b/docs/ISSUE_BACKLOG.md index ab32ff6..68feb68 100644 --- a/docs/ISSUE_BACKLOG.md +++ b/docs/ISSUE_BACKLOG.md @@ -41,6 +41,8 @@ This backlog is ordered for a fast backend-first MVP. ## Provider Routing +Progress status (current): Milestone 4 done in PR #33 with non-streaming chat proxy follow-up in same PR. + 1. 
Add provider connection CRUD. 2. Add OpenAI-compatible provider adapter. 3. Add local OpenAI-compatible provider mode. diff --git a/docs/MVP_ROADMAP.md b/docs/MVP_ROADMAP.md index 178c674..87d3f42 100644 --- a/docs/MVP_ROADMAP.md +++ b/docs/MVP_ROADMAP.md @@ -78,12 +78,13 @@ Exit criteria: ## Phase 3 - Provider Connections and Model Routes +Status: ✅ Implemented in PR #33 + Deliverables: -- Admin provider CRUD. +- Admin provider CRUD (`/admin/provider-connections`). - OpenAI-compatible provider adapter. -- Local OpenAI-compatible adapter mode. -- Model route CRUD. +- Model route CRUD (`/admin/model-routes`). - `/v1/models`. - Provider modules isolated from gateway handlers. @@ -94,19 +95,21 @@ Exit criteria: ## Phase 4 - Gateway Proxy +Status: 🚧 Partially implemented in PR #33 (non-streaming path only) + Deliverables: -- `POST /v1/chat/completions`. -- Streaming and non-streaming support. -- Upstream error normalization. -- Request id propagation. +- `POST /v1/chat/completions` (non-streaming path implemented). +- Streaming path (open in next milestone). +- Upstream error normalization (open in next milestone). +- Request id propagation (open in next milestone). - Basic provider health status. - Provider transforms, routing, and gateway orchestration stay separate. Exit criteria: - OpenAI SDK can call the gateway by changing only base URL and API key. -- Streaming works through the gateway. +- Streaming works through the gateway (not yet met; tracked as the milestone 5 follow-up). 
## Phase 5 - Usage Metering and Credits diff --git a/docs/PRD.md b/docs/PRD.md index c7d1b28..8dab317 100644 --- a/docs/PRD.md +++ b/docs/PRD.md @@ -159,16 +159,19 @@ Admin/user API: - `DELETE /api-keys/{id}` - `GET /usage` - `GET /credits` -- `POST /admin/providers` -- `GET /admin/providers` -- `PATCH /admin/providers/{id}` +- `POST /admin/provider-connections` +- `GET /admin/provider-connections` +- `DELETE /admin/provider-connections/{id}` - `POST /admin/model-routes` - `GET /admin/model-routes` -- `PATCH /admin/model-routes/{id}` +- `DELETE /admin/model-routes/{id}` - `POST /admin/users/{id}/credits/grant` - `PATCH /admin/users/{id}/credit-policy` - `GET /admin/usage` +> Note: `PATCH` endpoints for providers/model routes are currently planned for a +> later refinement pass; CRUD currently provides `GET`, `POST`, and `DELETE`. + ### Usage Metering For every gateway request, capture: From 8b00b96813825978cc9fb93afdf5d242d90108e7 Mon Sep 17 00:00:00 2001 From: wauputr4 <103489788+wauputr4@users.noreply.github.com> Date: Sun, 10 May 2026 19:40:16 +0700 Subject: [PATCH 4/5] feat: add provider api key encryption and docs polish --- .env.example | 1 + Cargo.lock | 85 +++++++ Cargo.toml | 2 + README.md | 7 + SECURITY.md | 1 + crates/mizan-api/Cargo.toml | 8 +- crates/mizan-api/src/auth.rs | 76 +----- crates/mizan-api/src/gateway.rs | 96 ++++---- crates/mizan-api/src/lib.rs | 5 +- crates/mizan-api/src/providers.rs | 172 +++++--------- crates/mizan-api/src/utils.rs | 347 ++++++++++++++++++++++++++++ crates/mizan-core/src/config.rs | 6 + docs/BACKEND_IMPLEMENTATION_PLAN.md | 2 +- 13 files changed, 565 insertions(+), 243 deletions(-) create mode 100644 crates/mizan-api/src/utils.rs diff --git a/.env.example b/.env.example index 5630570..d262662 100644 --- a/.env.example +++ b/.env.example @@ -7,3 +7,4 @@ MIZAN_DB_MAX_CONNECTIONS=10 MIZAN_ADMIN_EMAIL= MIZAN_ADMIN_PASSWORD= MIZAN_ADMIN_ROLE=admin +MIZAN_PROVIDER_SECRET_KEY= diff --git a/Cargo.lock b/Cargo.lock index 
d9d4396..52a89f5 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,41 @@ # It is not intended for manual editing. version = 4 +[[package]] +name = "aead" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d122413f284cf2d62fb1b7db97e02edb8cda96d769b16e443a4f6195e35662b0" +dependencies = [ + "crypto-common", + "generic-array", +] + +[[package]] +name = "aes" +version = "0.8.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b169f7a6d4742236a0a00c541b845991d0ac43e546831af1249753ab4c3aa3a0" +dependencies = [ + "cfg-if", + "cipher", + "cpufeatures", +] + +[[package]] +name = "aes-gcm" +version = "0.10.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831010a0f742e1209b3bcea8fab6a8e149051ba6099432c8cb2cc117dec3ead1" +dependencies = [ + "aead", + "aes", + "cipher", + "ctr", + "ghash", + "subtle", +] + [[package]] name = "aho-corasick" version = "1.1.4" @@ -281,9 +316,19 @@ source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "78c8292055d1c1df0cce5d180393dc8cce0abec0a7102adb6c7b1eef6016d60a" dependencies = [ "generic-array", + "rand_core", "typenum", ] +[[package]] +name = "ctr" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0369ee1ad671834580515889b80f2ea915f23b8be8d0daa4bbaf2ac5c7590835" +dependencies = [ + "cipher", +] + [[package]] name = "der" version = "0.7.10" @@ -517,6 +562,16 @@ dependencies = [ "wasip3", ] +[[package]] +name = "ghash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f0d8a4362ccb29cb0b265253fb0a2728f592895ee6854fd9bc13f2ffda266ff1" +dependencies = [ + "opaque-debug", + "polyval", +] + [[package]] name = "hashbrown" version = "0.15.5" @@ -933,7 +988,9 @@ dependencies = [ name = "mizan-api" version = "0.1.0" dependencies = [ + "aes-gcm", "axum", + "base64", "bcrypt", "mizan-core", "mizan-gateway", @@ -1098,6 +1155,12 
@@ version = "1.21.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "9f7c3e4beb33f85d45ae3e3a1792185706c8e16d043238c593331cc7cd313b50" +[[package]] +name = "opaque-debug" +version = "0.3.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c08d65885ee38876c4f86fa503fb49d7b507c2b62552df7c70b2fce627e06381" + [[package]] name = "parking" version = "2.2.1" @@ -1181,6 +1244,18 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "b4596b6d070b27117e987119b4dac604f3c58cfb0b191112e24771b2faeac1a6" +[[package]] +name = "polyval" +version = "0.6.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9d1fe60d06143b2430aa532c94cfe9e29783047f06c0d7fd359a9a51b729fa25" +dependencies = [ + "cfg-if", + "cpufeatures", + "opaque-debug", + "universal-hash", +] + [[package]] name = "potential_utf" version = "0.1.5" @@ -2117,6 +2192,16 @@ version = "0.2.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "ebc1c04c71510c7f702b52b7c350734c9ff1295c464a03335b00bb84fc54f853" +[[package]] +name = "universal-hash" +version = "0.5.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fc1de2c688dc15305988b563c3854064043356019f97a4b46276fe734c4f07ea" +dependencies = [ + "crypto-common", + "subtle", +] + [[package]] name = "untrusted" version = "0.9.0" diff --git a/Cargo.toml b/Cargo.toml index beeaeef..30da2d5 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,8 @@ async-trait = "0.1" axum = "0.8" redis = "1.2" sha2 = "0.10.8" +base64 = "0.22" +aes-gcm = "0.10" bcrypt = "0.15" serde = { version = "1", features = ["derive"] } serde_json = "1" diff --git a/README.md b/README.md index 9967d93..c1306a2 100644 --- a/README.md +++ b/README.md @@ -107,6 +107,13 @@ cargo check --workspace cargo test --workspace ``` +Environment variables: + +- `MIZAN_PROVIDER_SECRET_KEY` (required before creating provider connections, 
used to encrypt provider API keys at rest) +- `MIZAN_HTTP_ADDR` (default `0.0.0.0:18180`) +- `MIZAN_DATABASE_URL`, `MIZAN_DB_MAX_CONNECTIONS`, `MIZAN_RUN_MIGRATIONS` for storage +- `MIZAN_ADMIN_EMAIL`, `MIZAN_ADMIN_PASSWORD`, `MIZAN_ADMIN_ROLE` for optional bootstrap + Run the API locally: ```sh diff --git a/SECURITY.md b/SECURITY.md index 8bbb0be..5d779aa 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -16,6 +16,7 @@ connection as sensitive. - Hash virtual API keys. - Hash user passwords. - Encrypt provider secrets at rest. +- Store `MIZAN_PROVIDER_SECRET_KEY` securely and rotate it on compromise. - Never return provider credentials from APIs. - Disable raw prompt/response logging by default. - Audit admin changes to providers, model routes, pricing, and credits. diff --git a/crates/mizan-api/Cargo.toml b/crates/mizan-api/Cargo.toml index badd627..402568e 100644 --- a/crates/mizan-api/Cargo.toml +++ b/crates/mizan-api/Cargo.toml @@ -8,10 +8,9 @@ rust-version.workspace = true [dependencies] axum.workspace = true +aes-gcm.workspace = true +base64.workspace = true bcrypt.workspace = true -mizan-core = { path = "../mizan-core" } -mizan-gateway = { path = "../mizan-gateway" } -mizan-providers = { path = "../mizan-providers" } redis.workspace = true sha2.workspace = true serde.workspace = true @@ -20,3 +19,6 @@ tokio.workspace = true tower-http.workspace = true tracing.workspace = true uuid.workspace = true +mizan-core = { path = "../mizan-core" } +mizan-gateway = { path = "../mizan-gateway" } +mizan-providers = { path = "../mizan-providers" } diff --git a/crates/mizan-api/src/auth.rs b/crates/mizan-api/src/auth.rs index a849995..6b65467 100644 --- a/crates/mizan-api/src/auth.rs +++ b/crates/mizan-api/src/auth.rs @@ -1,5 +1,7 @@ -use std::time::{SystemTime, UNIX_EPOCH}; - +use crate::utils::{ + from_app_error, is_unique_constraint_error, now_utc_epoch_seconds, prepare_sql, + unix_timestamp_string, +}; use axum::{ Extension, Json, body::Body, @@ -721,46 +723,10 @@ fn 
session_token_from_headers(headers: &HeaderMap) -> Option<&str> { }) } -fn prepare_sql(database_backend: DatabaseBackend, query: &'static str) -> String { - match database_backend { - DatabaseBackend::Sqlite => query.to_string(), - DatabaseBackend::Postgres => to_dollar_params(query), - } -} - -fn to_dollar_params(query: &str) -> String { - let mut parameter_index = 0usize; - let mut converted = String::with_capacity(query.len()); - - for character in query.chars() { - if character == '?' { - parameter_index += 1; - converted.push('$'); - converted.push_str(¶meter_index.to_string()); - continue; - } - - converted.push(character); - } - - converted -} - fn map_error(status: StatusCode, error: AppError) -> (StatusCode, Json) { (status, Json(ErrorEnvelope::from(&error))) } -fn from_app_error(error: AppError) -> (StatusCode, Json) { - let status = match error { - AppError::InvalidConfig { .. } => StatusCode::BAD_REQUEST, - AppError::NotFound(_) => StatusCode::NOT_FOUND, - AppError::Unauthorized => StatusCode::UNAUTHORIZED, - AppError::Forbidden => StatusCode::FORBIDDEN, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - map_error(status, error) -} - fn hash_value(value: &str) -> String { let mut digest = Sha256::new(); digest.update(value.as_bytes()); @@ -771,26 +737,10 @@ fn hash_value(value: &str) -> String { .collect::() } -fn now_utc_epoch_seconds() -> i64 { - SystemTime::now() - .duration_since(UNIX_EPOCH) - .map_or(0, |duration| duration.as_secs() as i64) -} - -fn unix_timestamp_string() -> String { - now_utc_epoch_seconds().to_string() -} - fn normalize_email(email: &str) -> String { email.trim().to_lowercase() } -fn is_unique_constraint_error(message: &str) -> bool { - message.contains("already exists") - || message.contains("UNIQUE constraint failed") - || message.contains("duplicate key") -} - struct SessionRecord { token: String, expires_at: String, @@ -813,24 +763,6 @@ mod tests { assert_eq!(normalize_email(" User@Example.COM "), "user@example.com"); } - 
#[test] - fn prepare_sql_keeps_question_marks_for_sqlite() { - let prepared = prepare_sql( - DatabaseBackend::Sqlite, - "SELECT id FROM users WHERE id = ? AND role = ?", - ); - assert_eq!(prepared, "SELECT id FROM users WHERE id = ? AND role = ?"); - } - - #[test] - fn prepare_sql_converts_question_marks_for_postgres() { - let prepared = prepare_sql( - DatabaseBackend::Postgres, - "SELECT id FROM users WHERE id = ? AND role = ?", - ); - assert_eq!(prepared, "SELECT id FROM users WHERE id = $1 AND role = $2"); - } - #[test] fn authorization_token_supports_bearer_scheme() { assert_eq!( diff --git a/crates/mizan-api/src/gateway.rs b/crates/mizan-api/src/gateway.rs index b5b6cab..5471094 100644 --- a/crates/mizan-api/src/gateway.rs +++ b/crates/mizan-api/src/gateway.rs @@ -8,6 +8,7 @@ use uuid::Uuid; use crate::AppState; use crate::auth::ApiKeyIdentity; +use crate::utils::{decrypt_provider_api_key, from_app_error, now_utc_epoch_seconds, prepare_sql}; type GatewayHttpResult = Result)>; @@ -52,6 +53,7 @@ pub struct ChatCompletionsResponse { #[derive(Debug)] struct ResolvedModelRoute { id: Uuid, + provider_connection_id: Uuid, upstream_model: String, provider_type: String, } @@ -72,9 +74,14 @@ pub async fn chat_completions( )); } - let route = resolve_model_route(&state.database, state.database_backend(), public_model) - .await - .map_err(from_app_error)?; + let route = resolve_model_route( + &state.database, + state.database_backend(), + state.config.provider_secret_key.as_deref(), + public_model, + ) + .await + .map_err(from_app_error)?; let context = RequestContextBuilder::default() .user_id(identity.user_id) @@ -82,6 +89,7 @@ pub async fn chat_completions( .provider(route.provider_type.clone()) .route(public_model.to_string()) .route_id(route.id) + .provider_id(route.provider_connection_id) .model(route.upstream_model.clone()) .streaming(payload.stream) .build(); @@ -117,7 +125,7 @@ fn map_to_chat_completion_response( ChatCompletionsResponse { id: 
format!("chatcmpl-{}", Uuid::now_v7()), object: "chat.completion", - created: now_utc_epoch_seconds() * 1000, + created: now_utc_epoch_seconds(), model, choices: vec![ChatCompletionsChoice { index: 0, @@ -138,19 +146,24 @@ fn map_to_chat_completion_response( async fn resolve_model_route( database: &AnyPool, database_backend: DatabaseBackend, + provider_secret_key: Option<&str>, public_model: &str, ) -> Result { - let resolved = query_as::<_, (String, String, String)>(&prepare_sql( + let resolved = query_as::<_, (String, String, String, String, String)>(&prepare_sql( database_backend, "SELECT mr.id, mr.upstream_model, - pc.provider_type + pc.provider_type, + pc.id, + pc.api_key_encrypted FROM model_routes mr INNER JOIN provider_connections pc ON pc.id = mr.provider_connection_id - WHERE mr.public_model = ? AND mr.enabled = 1 AND pc.enabled = 1", + WHERE mr.public_model = ? AND mr.enabled = ? AND pc.enabled = ?", )) .bind(public_model) + .bind(1) + .bind(1) .fetch_optional(database) .await .map_err(|error| AppError::infrastructure(error.to_string()))? 
@@ -158,67 +171,36 @@ async fn resolve_model_route( AppError::invalid_config("chat_completion.model", "model not found or disabled") })?; - let (route_id, upstream_model, provider_type) = resolved; + let (route_id, upstream_model, provider_type, provider_connection_id, encrypted_api_key) = + resolved; let id = Uuid::parse_str(&route_id).map_err(|error| { AppError::infrastructure(format!("stored route id is invalid: {error}")) })?; + let provider_connection_id = Uuid::parse_str(&provider_connection_id).map_err(|error| { + AppError::infrastructure(format!( + "stored provider connection id for route is invalid: {error}" + )) + })?; + let provider_secret_key = provider_secret_key.ok_or_else(|| { + AppError::invalid_config( + "MIZAN_PROVIDER_SECRET_KEY", + "set MIZAN_PROVIDER_SECRET_KEY before resolving model routes", + ) + })?; + let _provider_api_key = decrypt_provider_api_key( + provider_secret_key, + &provider_connection_id.to_string(), + &encrypted_api_key, + )?; Ok(ResolvedModelRoute { id, + provider_connection_id, upstream_model, provider_type: provider_type.trim().to_string(), }) } -fn from_app_error(error: AppError) -> (StatusCode, Json) { - let status = match error { - AppError::InvalidConfig { .. 
} => StatusCode::BAD_REQUEST, - AppError::NotFound(_) => StatusCode::NOT_FOUND, - AppError::Unauthorized => StatusCode::UNAUTHORIZED, - AppError::Forbidden => StatusCode::FORBIDDEN, - AppError::Provider(_) => StatusCode::BAD_GATEWAY, - AppError::LimitExceeded(_) => StatusCode::TOO_MANY_REQUESTS, - AppError::InsufficientCredit => StatusCode::PAYMENT_REQUIRED, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - - (status, Json(ErrorEnvelope::from(&error))) -} - -fn now_utc_epoch_seconds() -> i64 { - use std::time::{SystemTime, UNIX_EPOCH}; - - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system time moved backwards") - .as_secs() as i64 -} - -fn prepare_sql(database_backend: DatabaseBackend, query: &'static str) -> String { - match database_backend { - DatabaseBackend::Sqlite => query.to_string(), - DatabaseBackend::Postgres => to_dollar_params(query), - } -} - -fn to_dollar_params(query: &str) -> String { - let mut index = 0usize; - let mut converted = String::with_capacity(query.len()); - - for character in query.chars() { - if character == '?' 
{ - index += 1; - converted.push('$'); - converted.push_str(&index.to_string()); - continue; - } - - converted.push(character); - } - - converted -} - #[cfg(test)] mod tests { use super::*; diff --git a/crates/mizan-api/src/lib.rs b/crates/mizan-api/src/lib.rs index aa1d71b..871e2ca 100644 --- a/crates/mizan-api/src/lib.rs +++ b/crates/mizan-api/src/lib.rs @@ -20,6 +20,7 @@ mod auth; mod gateway; mod providers; mod storage; +mod utils; #[derive(Clone)] pub struct AppState { @@ -174,8 +175,8 @@ pub fn router(state: AppState) -> Router { "/admin/model-routes/{id}", delete(providers::delete_model_route), ) - .route_layer(from_fn_with_state(state.clone(), auth::api_key_auth)) - .route_layer(from_fn(providers::require_admin_role)); + .route_layer(from_fn(providers::require_admin_role)) + .route_layer(from_fn_with_state(state.clone(), auth::api_key_auth)); let provider_router = Router::new() .merge(public_models_router) diff --git a/crates/mizan-api/src/providers.rs b/crates/mizan-api/src/providers.rs index 86f6283..2458d75 100644 --- a/crates/mizan-api/src/providers.rs +++ b/crates/mizan-api/src/providers.rs @@ -4,13 +4,17 @@ use axum::extract::{Path, State}; use axum::http::StatusCode; use axum::middleware::Next; use axum::response::Response; -use mizan_core::{AppError, AppResult, DatabaseBackend, ErrorEnvelope}; +use mizan_core::{AppError, ErrorEnvelope}; use serde::{Deserialize, Serialize}; use sqlx::{query, query_as}; use uuid::Uuid; use crate::AppState; use crate::auth::ApiKeyIdentity; +use crate::utils::{ + encrypt_provider_api_key, from_app_error, is_enabled, is_unique_constraint_error, + parse_timestamp, prepare_sql, unix_timestamp_string, +}; type ProviderHttpResult = Result)>; @@ -119,10 +123,18 @@ pub struct PublicModelsResponse { } pub async fn require_admin_role( - axum::Extension(identity): axum::Extension, + identity: Option>, request: axum::http::Request, next: Next, ) -> ProviderHttpResult { + let identity = identity.ok_or_else(|| { + ( + 
StatusCode::UNAUTHORIZED, + Json(ErrorEnvelope::from(&AppError::Unauthorized)), + ) + })?; + let identity = identity.0; + if identity.user_role != "admin" { return Err(( StatusCode::FORBIDDEN, @@ -145,13 +157,15 @@ pub async fn list_models( pc.name, pc.provider_type, mr.max_tokens, - mr.created_at + mr.created_at FROM model_routes mr INNER JOIN provider_connections pc ON pc.id = mr.provider_connection_id - WHERE mr.enabled = 1 AND pc.enabled = 1 + WHERE mr.enabled = ? AND pc.enabled = ? ORDER BY mr.public_model ASC", )) + .bind(1) + .bind(1) .fetch_all(&state.database) .await .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; @@ -168,7 +182,7 @@ pub async fn list_models( created_at, ) in rows { - let created = parse_timestamp(&created_at).map_err(|error| from_app_error(error))?; + let created = parse_timestamp(&created_at).map_err(from_app_error)?; data.push(PublicModelResponse { id: public_model.clone(), @@ -279,6 +293,14 @@ pub async fn create_provider_connection( let id = Uuid::now_v7(); let now = unix_timestamp_string(); let enabled = payload.enabled.unwrap_or(true); + let provider_secret_key = state.config.provider_secret_key.as_deref().ok_or_else(|| { + from_app_error(AppError::invalid_config( + "MIZAN_PROVIDER_SECRET_KEY", + "set MIZAN_PROVIDER_SECRET_KEY before creating provider connections", + )) + })?; + let encrypted_api_key = encrypt_provider_api_key(provider_secret_key, &id.to_string(), secret) + .map_err(from_app_error)?; let sql = prepare_sql( state.database_backend(), @@ -292,7 +314,7 @@ pub async fn create_provider_connection( .bind(name) .bind(provider_type) .bind(base_url) - .bind(secret) + .bind(encrypted_api_key) .bind(if enabled { 1 } else { 0 }) .bind(&now) .bind(&now) @@ -445,25 +467,50 @@ pub async fn create_model_route( )); } - if let Some(max_tokens) = payload.max_tokens { - if max_tokens < 0 { - return Err(( - StatusCode::BAD_REQUEST, - Json(ErrorEnvelope::from(&AppError::invalid_config( - 
"model_route.max_tokens", - "max_tokens cannot be negative", - ))), - )); - } + if let Some(max_tokens) = payload.max_tokens + && max_tokens < 0 + { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.max_tokens", + "max_tokens cannot be negative", + ))), + )); + } + + if let Some(pricing_input_per_1m_tokens) = payload.pricing_input_per_1m_tokens + && pricing_input_per_1m_tokens < 0 + { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.pricing_input_per_1m_tokens", + "pricing_input_per_1m_tokens cannot be negative", + ))), + )); + } + + if let Some(pricing_output_per_1m_tokens) = payload.pricing_output_per_1m_tokens + && pricing_output_per_1m_tokens < 0 + { + return Err(( + StatusCode::BAD_REQUEST, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "model_route.pricing_output_per_1m_tokens", + "pricing_output_per_1m_tokens cannot be negative", + ))), + )); } let provider_connection_id = payload.provider_connection_id; let provider_exists = query_as::<_, (i64,)>(&prepare_sql( state.database_backend(), - "SELECT 1 FROM provider_connections WHERE id = ?", + "SELECT 1 FROM provider_connections WHERE id = ? AND enabled = ?", )) .bind(provider_connection_id.to_string()) + .bind(1) .fetch_optional(&state.database) .await .map_err(|error| from_app_error(AppError::infrastructure(error.to_string())))?; @@ -566,94 +613,3 @@ fn map_duplicate_model_error(error: String) -> AppError { AppError::infrastructure(error) } } - -fn from_app_error(error: AppError) -> (StatusCode, Json) { - let status = match error { - AppError::InvalidConfig { .. 
} => StatusCode::BAD_REQUEST, - AppError::NotFound(_) => StatusCode::NOT_FOUND, - AppError::Unauthorized => StatusCode::UNAUTHORIZED, - AppError::Forbidden => StatusCode::FORBIDDEN, - _ => StatusCode::INTERNAL_SERVER_ERROR, - }; - - (status, Json(ErrorEnvelope::from(&error))) -} - -fn is_enabled(raw: i64) -> bool { - raw != 0 -} - -fn parse_timestamp(raw: &str) -> AppResult { - raw.parse::() - .map_err(|error| AppError::infrastructure(format!("invalid timestamp: {error}"))) -} - -fn prepare_sql(database_backend: DatabaseBackend, query: &'static str) -> String { - match database_backend { - DatabaseBackend::Sqlite => query.to_string(), - DatabaseBackend::Postgres => to_dollar_params(query), - } -} - -fn to_dollar_params(query: &str) -> String { - let mut parameter_index = 0usize; - let mut converted = String::with_capacity(query.len()); - - for character in query.chars() { - if character == '?' { - parameter_index += 1; - converted.push('$'); - converted.push_str(¶meter_index.to_string()); - continue; - } - - converted.push(character); - } - - converted -} - -fn is_unique_constraint_error(message: &str) -> bool { - let normalized = message.to_lowercase(); - normalized.contains("unique") - && (normalized.contains("constraint") || normalized.contains("already exists")) -} - -fn unix_timestamp_string() -> String { - now_utc_epoch_seconds().to_string() -} - -fn now_utc_epoch_seconds() -> i64 { - use std::time::{SystemTime, UNIX_EPOCH}; - - SystemTime::now() - .duration_since(UNIX_EPOCH) - .expect("system time") - .as_secs() as i64 -} - -#[cfg(test)] -mod tests { - use super::*; - - #[test] - fn prepare_sql_keeps_question_marks_for_sqlite() { - let prepared = prepare_sql(DatabaseBackend::Sqlite, "SELECT * FROM x WHERE id = ?"); - assert_eq!(prepared, "SELECT * FROM x WHERE id = ?"); - } - - #[test] - fn prepare_sql_converts_question_marks_for_postgres() { - let prepared = prepare_sql( - DatabaseBackend::Postgres, - "SELECT * FROM x WHERE a = ? 
AND b = ?", - ); - assert_eq!(prepared, "SELECT * FROM x WHERE a = $1 AND b = $2"); - } - - #[test] - fn unix_timestamp_string_is_numeric() { - let timestamp = unix_timestamp_string(); - assert!(timestamp.parse::().is_ok()); - } -} diff --git a/crates/mizan-api/src/utils.rs b/crates/mizan-api/src/utils.rs new file mode 100644 index 0000000..071afa9 --- /dev/null +++ b/crates/mizan-api/src/utils.rs @@ -0,0 +1,347 @@ +use aes_gcm::aead::Aead; +use aes_gcm::{Aes256Gcm, aead::KeyInit}; +use axum::{Json, http::StatusCode}; +use base64::{Engine, engine::general_purpose::STANDARD}; +use mizan_core::{AppError, AppResult, DatabaseBackend, ErrorEnvelope}; +use sha2::{Digest, Sha256}; + +const NONCE_BYTES: usize = 12; +const MIN_ENCRYPTED_BYTES: usize = NONCE_BYTES + 16; +const SECRET_CONTEXT: &str = "provider-connection-secret-v1"; + +pub fn from_app_error(error: AppError) -> (StatusCode, Json) { + let status = match error { + AppError::InvalidConfig { .. } => StatusCode::BAD_REQUEST, + AppError::NotFound(_) => StatusCode::NOT_FOUND, + AppError::Unauthorized => StatusCode::UNAUTHORIZED, + AppError::Forbidden => StatusCode::FORBIDDEN, + AppError::Provider(_) => StatusCode::BAD_GATEWAY, + AppError::LimitExceeded(_) => StatusCode::TOO_MANY_REQUESTS, + AppError::InsufficientCredit => StatusCode::PAYMENT_REQUIRED, + _ => StatusCode::INTERNAL_SERVER_ERROR, + }; + + (status, Json(ErrorEnvelope::from(&error))) +} + +pub fn prepare_sql(database_backend: DatabaseBackend, query: &'_ str) -> String { + match database_backend { + DatabaseBackend::Sqlite => query.to_string(), + DatabaseBackend::Postgres => to_dollar_params(query), + } +} + +pub fn is_enabled(raw: i64) -> bool { + raw != 0 +} + +pub fn parse_timestamp(raw: &str) -> AppResult { + raw.parse::() + .map_err(|error| AppError::infrastructure(format!("invalid timestamp: {error}"))) +} + +pub fn unix_timestamp_string() -> String { + now_utc_epoch_seconds().to_string() +} + +pub fn now_utc_epoch_seconds() -> i64 { + use 
std::time::{SystemTime, UNIX_EPOCH}; + + SystemTime::now() + .duration_since(UNIX_EPOCH) + .map_or(0, |duration| duration.as_secs() as i64) +} + +pub fn is_unique_constraint_error(message: &str) -> bool { + let normalized = message.to_lowercase(); + normalized.contains("unique") + && (normalized.contains("constraint") || normalized.contains("already exists")) +} + +pub fn encrypt_provider_api_key( + provider_secret: &str, + provider_connection_id: &str, + api_key: &str, +) -> AppResult { + if provider_secret.is_empty() { + return Err(AppError::invalid_config( + "MIZAN_PROVIDER_SECRET_KEY", + "provider secret key is required", + )); + } + + if provider_connection_id.trim().is_empty() { + return Err(AppError::invalid_config( + "provider_connection_id", + "provider_connection_id is required", + )); + } + + if api_key.trim().is_empty() { + return Err(AppError::invalid_config( + "provider_connection.api_key_encrypted", + "api_key is required", + )); + } + + let cipher = provider_cipher(provider_secret)?; + let nonce = derive_nonce(provider_secret, provider_connection_id); + let mut payload = cipher + .encrypt(aes_gcm::Nonce::from_slice(&nonce), api_key.as_bytes()) + .map_err(|error| { + AppError::infrastructure(format!("provider api key encryption failed: {error}")) + })?; + let mut encrypted = nonce.to_vec(); + encrypted.append(&mut payload); + + Ok(STANDARD.encode(&encrypted)) +} + +pub(crate) fn decrypt_provider_api_key( + provider_secret: &str, + provider_connection_id: &str, + encrypted_api_key: &str, +) -> AppResult { + if provider_secret.is_empty() { + return Err(AppError::invalid_config( + "MIZAN_PROVIDER_SECRET_KEY", + "provider secret key is required", + )); + } + + if provider_connection_id.trim().is_empty() { + return Err(AppError::invalid_config( + "provider_connection_id", + "provider_connection_id is required", + )); + } + + if encrypted_api_key.trim().is_empty() { + return Err(AppError::invalid_config( + "provider_connection.api_key_encrypted", + 
"encrypted api key is required", + )); + } + + let data = STANDARD.decode(encrypted_api_key).map_err(|error| { + AppError::invalid_config("provider_connection.api_key_encrypted", error.to_string()) + })?; + + if data.len() < MIN_ENCRYPTED_BYTES { + return Err(AppError::invalid_config( + "provider_connection.api_key_encrypted", + "encrypted api key is invalid", + )); + } + + let nonce = aes_gcm::Nonce::::from_slice( + &data[..NONCE_BYTES], + ); + let cipher = provider_cipher(provider_secret)?; + let plaintext = cipher + .decrypt(nonce, &data[NONCE_BYTES..]) + .map_err(|error| { + AppError::infrastructure(format!("provider api key decryption failed: {error}")) + })?; + + String::from_utf8(plaintext).map_err(|error| { + AppError::invalid_config( + "provider_connection.api_key_encrypted", + format!("invalid stored key format: {error}"), + ) + }) +} + +fn derive_nonce(provider_secret: &str, provider_connection_id: &str) -> [u8; NONCE_BYTES] { + let material = Sha256::digest(format!( + "{SECRET_CONTEXT}:{provider_secret}:{provider_connection_id}" + )); + let mut bytes = [0u8; NONCE_BYTES]; + bytes.copy_from_slice(&material[..NONCE_BYTES]); + bytes +} + +fn provider_cipher(provider_secret: &str) -> AppResult { + let material = Sha256::digest(provider_secret.as_bytes()); + let key = aes_gcm::Key::::from_slice(&material); + Ok(Aes256Gcm::new(key)) +} + +fn to_dollar_params(query: &str) -> String { + let mut parameter_index = 0usize; + let mut converted = String::with_capacity(query.len()); + let mut chars = query.chars().peekable(); + + let mut in_single_quote = false; + let mut in_double_quote = false; + let mut in_line_comment = false; + let mut in_block_comment = false; + + while let Some(current) = chars.next() { + if in_line_comment { + if current == '\n' { + in_line_comment = false; + } + converted.push(current); + continue; + } + + if in_block_comment { + if current == '*' && chars.peek() == Some(&'/') { + converted.push(current); + 
converted.push(chars.next().expect("peeked block comment terminator")); + in_block_comment = false; + } else { + converted.push(current); + } + continue; + } + + if in_single_quote { + if current == '\'' { + if chars.peek() == Some(&'\'') { + converted.push(current); + converted.push(chars.next().expect("peeked escaped single quote")); + continue; + } + + in_single_quote = false; + } + converted.push(current); + continue; + } + + if in_double_quote { + if current == '"' { + if chars.peek() == Some(&'"') { + converted.push(current); + converted.push(chars.next().expect("peeked escaped double quote")); + continue; + } + + in_double_quote = false; + } + converted.push(current); + continue; + } + + if current == '-' && chars.peek() == Some(&'-') { + in_line_comment = true; + converted.push(current); + converted.push(chars.next().expect("peeked line comment")); + continue; + } + + if current == '/' && chars.peek() == Some(&'*') { + in_block_comment = true; + converted.push(current); + converted.push(chars.next().expect("peeked block comment")); + continue; + } + + if current == '\'' { + in_single_quote = true; + converted.push(current); + continue; + } + + if current == '"' { + in_double_quote = true; + converted.push(current); + continue; + } + + if current == '?' { + parameter_index += 1; + converted.push('$'); + converted.push_str(¶meter_index.to_string()); + continue; + } + + converted.push(current); + } + + converted +} + +#[cfg(test)] +mod tests { + use super::*; + use uuid::Uuid; + + #[test] + fn prepare_sql_keeps_question_marks_for_sqlite() { + let prepared = prepare_sql(DatabaseBackend::Sqlite, "SELECT * FROM x WHERE id = ?"); + assert_eq!(prepared, "SELECT * FROM x WHERE id = ?"); + } + + #[test] + fn prepare_sql_converts_question_marks_for_postgres() { + let prepared = prepare_sql( + DatabaseBackend::Postgres, + "SELECT * FROM x WHERE a = ? 
AND b = ?", + ); + assert_eq!(prepared, "SELECT * FROM x WHERE a = $1 AND b = $2"); + } + + #[test] + fn to_dollar_params_keeps_question_mark_in_quoted_string() { + let prepared = prepare_sql( + DatabaseBackend::Postgres, + "SELECT * FROM x WHERE note = 'Is this a question?' OR name = ?", + ); + assert_eq!( + prepared, + "SELECT * FROM x WHERE note = 'Is this a question?' OR name = $1", + ); + } + + #[test] + fn to_dollar_params_keeps_question_mark_in_comment() { + let prepared = prepare_sql( + DatabaseBackend::Postgres, + "SELECT * FROM x WHERE enabled = 1 -- ? should stay here\nAND id = ?", + ); + assert_eq!( + prepared, + "SELECT * FROM x WHERE enabled = 1 -- ? should stay here\nAND id = $1", + ); + } + + #[test] + fn provider_api_key_can_encrypt_and_decrypt() { + let provider_id = Uuid::now_v7().to_string(); + let provider_secret = "phase-1-secret-key"; + let original = "sk-live-abc"; + + let encrypted = encrypt_provider_api_key(provider_secret, &provider_id, original) + .expect("encrypt provider key"); + assert_ne!(encrypted, original); + + let decrypted = decrypt_provider_api_key(provider_secret, &provider_id, &encrypted) + .expect("decrypt provider key"); + assert_eq!(decrypted, original); + } + + #[test] + fn provider_api_key_encryption_fails_without_secret() { + let provider_id = Uuid::now_v7().to_string(); + let result = encrypt_provider_api_key("", &provider_id, "sk-live-abc"); + assert!(result.is_err()); + } + + #[test] + fn provider_api_key_encryption_roundtrips_across_provider_ids() { + let provider_secret = "phase-1-secret-key"; + let original = "sk-live-abc"; + + let id_a = Uuid::now_v7().to_string(); + let id_b = Uuid::now_v7().to_string(); + let encrypted_a = encrypt_provider_api_key(provider_secret, &id_a, original) + .expect("encrypt provider key a"); + let encrypted_b = encrypt_provider_api_key(provider_secret, &id_b, original) + .expect("encrypt provider key b"); + + assert_ne!(encrypted_a, encrypted_b); + } +} diff --git 
a/crates/mizan-core/src/config.rs b/crates/mizan-core/src/config.rs index d8a203e..3d60013 100644 --- a/crates/mizan-core/src/config.rs +++ b/crates/mizan-core/src/config.rs @@ -14,6 +14,7 @@ pub struct AppConfig { pub admin_seed_email: Option, pub admin_seed_password: Option, pub admin_seed_role: String, + pub provider_secret_key: Option, } impl AppConfig { @@ -49,6 +50,10 @@ impl AppConfig { } else { admin_seed_role }; + let provider_secret_key = env::var("MIZAN_PROVIDER_SECRET_KEY") + .ok() + .map(|value| value.trim().to_string()) + .filter(|value| !value.is_empty()); if matches!( (admin_seed_email.as_deref(), admin_seed_password.as_deref()), @@ -80,6 +85,7 @@ impl AppConfig { admin_seed_email, admin_seed_password, admin_seed_role, + provider_secret_key, }) } else { Err(AppError::invalid_config( diff --git a/docs/BACKEND_IMPLEMENTATION_PLAN.md b/docs/BACKEND_IMPLEMENTATION_PLAN.md index 856f729..bb259fc 100644 --- a/docs/BACKEND_IMPLEMENTATION_PLAN.md +++ b/docs/BACKEND_IMPLEMENTATION_PLAN.md @@ -122,7 +122,7 @@ Crates: Tasks: - [x] Admin provider CRUD (`/admin/provider-connections`, `/admin/provider-connections/{id}`). -- [ ] Encrypt provider secrets before storage (next milestone task). +- [x] Encrypt provider secrets before storage. - [x] Admin model route CRUD (`/admin/model-routes`, `/admin/model-routes/{id}`). - [x] Public model route resolver and user-visible `GET /v1/models`. - [x] Provider adapters remain isolated from route handlers. 
From fda344cad7151b8edbef5bc7e3240b1046d11a15 Mon Sep 17 00:00:00 2001 From: wauputr4 <103489788+wauputr4@users.noreply.github.com> Date: Sun, 10 May 2026 20:07:35 +0700 Subject: [PATCH 5/5] fix: gate streaming and bind api key to provider connection --- crates/mizan-api/src/gateway.rs | 17 ++++++++++++++--- crates/mizan-api/src/utils.rs | 24 ++++++++++++++++++++++++ 2 files changed, 38 insertions(+), 3 deletions(-) diff --git a/crates/mizan-api/src/gateway.rs b/crates/mizan-api/src/gateway.rs index 5471094..8859201 100644 --- a/crates/mizan-api/src/gateway.rs +++ b/crates/mizan-api/src/gateway.rs @@ -74,6 +74,16 @@ pub async fn chat_completions( )); } + if payload.stream { + return Err(( + StatusCode::NOT_IMPLEMENTED, + Json(ErrorEnvelope::from(&AppError::invalid_config( + "chat_completion.stream", + "streaming is not supported yet, use stream=false", + ))), + )); + } + let route = resolve_model_route( &state.database, state.database_backend(), @@ -113,7 +123,7 @@ pub async fn chat_completions( .map_err(from_app_error)?; Ok(Json(map_to_chat_completion_response( - route.upstream_model, + public_model.to_string(), upstream_response, ))) } @@ -208,6 +218,7 @@ mod tests { #[test] fn map_to_chat_completion_response_uses_model_and_content() { let model = "openai/gpt-4o-mini".to_string(); + let alias = "mizan-public-gpt-4o-mini".to_string(); let upstream = ChatResponse { provider: "openai".to_string(), model: model.clone(), @@ -220,9 +231,9 @@ mod tests { }), }; - let response = map_to_chat_completion_response(model.clone(), upstream); - assert_eq!(response.model, model); + let response = map_to_chat_completion_response(alias.clone(), upstream); assert_eq!(response.choices.len(), 1); assert_eq!(response.choices[0].message.content, "pong"); + assert_eq!(response.model, alias); } } diff --git a/crates/mizan-api/src/utils.rs b/crates/mizan-api/src/utils.rs index 071afa9..4774328 100644 --- a/crates/mizan-api/src/utils.rs +++ b/crates/mizan-api/src/utils.rs @@ -134,6 +134,14 
@@ pub(crate) fn decrypt_provider_api_key( )); } + let expected_nonce = derive_nonce(provider_secret, provider_connection_id); + if data[..NONCE_BYTES] != expected_nonce { + return Err(AppError::invalid_config( + "provider_connection.api_key_encrypted", + "encrypted api key does not match this provider connection", + )); + } + let nonce = aes_gcm::Nonce::::from_slice( &data[..NONCE_BYTES], ); @@ -344,4 +352,20 @@ mod tests { assert_ne!(encrypted_a, encrypted_b); } + + #[test] + fn provider_api_key_decryption_rejects_mismatched_connection_id() { + let provider_secret = "phase-1-secret-key"; + let original = "sk-live-abc"; + let id = Uuid::now_v7().to_string(); + let other_id = Uuid::now_v7().to_string(); + + let encrypted = + encrypt_provider_api_key(provider_secret, &id, original).expect("encrypt provider key"); + let error = decrypt_provider_api_key(provider_secret, &other_id, &encrypted) + .expect_err("mismatched connection should fail"); + + let message = format!("{error}"); + assert!(message.contains("does not match this provider connection")); + } }