diff --git a/Cargo.lock b/Cargo.lock index 68d5569ca..1a6a8f7b7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2045,6 +2045,7 @@ dependencies = [ "tracing-subscriber", "tracing-tree", "ulid", + "vrl", "xxhash-rust", ] diff --git a/Cargo.toml b/Cargo.toml index f126f24ae..85b88ee75 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -58,3 +58,4 @@ reqwest = "0.12.23" retry-policies = "0.4.0" reqwest-retry = "0.7.0" reqwest-middleware = "0.4.2" +vrl = { version = "0.27.0", features = ["compiler", "parser", "value", "diagnostic", "stdlib", "core"] } diff --git a/bin/router/Cargo.toml b/bin/router/Cargo.toml index 8512dc578..39c1c34ac 100644 --- a/bin/router/Cargo.toml +++ b/bin/router/Cargo.toml @@ -43,6 +43,7 @@ jsonwebtoken = { workspace = true } retry-policies = { workspace = true} reqwest-retry = { workspace = true } reqwest-middleware = { workspace = true } +vrl = { workspace = true } mimalloc = { version = "0.1.48", features = ["v3"] } moka = { version = "0.12.10", features = ["future"] } diff --git a/bin/router/src/pipeline/error.rs b/bin/router/src/pipeline/error.rs index 49a6ee8a0..049a77c0d 100644 --- a/bin/router/src/pipeline/error.rs +++ b/bin/router/src/pipeline/error.rs @@ -17,7 +17,10 @@ use serde::{Deserialize, Serialize}; use crate::{ jwt::errors::JwtForwardingError, - pipeline::header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, + pipeline::{ + header::{RequestAccepts, APPLICATION_GRAPHQL_RESPONSE_JSON_STR}, + progressive_override::LabelEvaluationError, + }, }; #[derive(Debug)] @@ -79,6 +82,8 @@ pub enum PipelineErrorVariant { PlanExecutionError(PlanExecutionError), #[error("Failed to produce a plan: {0}")] PlannerError(Arc), + #[error(transparent)] + LabelEvaluationError(LabelEvaluationError), // HTTP Security-related errors #[error("Required CSRF header(s) not present")] @@ -95,6 +100,7 @@ impl PipelineErrorVariant { Self::UnsupportedHttpMethod(_) => "METHOD_NOT_ALLOWED", Self::PlannerError(_) => "QUERY_PLAN_BUILD_FAILED", Self::PlanExecutionError(_) => "QUERY_PLAN_EXECUTION_FAILED", + Self::LabelEvaluationError(_) => "OVERRIDE_LABEL_EVALUATION_FAILED", Self::FailedToParseOperation(_) => "GRAPHQL_PARSE_FAILED", Self::ValidationErrors(_) => "GRAPHQL_VALIDATION_FAILED", Self::VariablesCoercionError(_) => "BAD_USER_INPUT", @@ -122,6 +128,7 @@ impl PipelineErrorVariant { match (self, prefer_ok) { (Self::PlannerError(_), _) => StatusCode::INTERNAL_SERVER_ERROR, (Self::PlanExecutionError(_), _) => StatusCode::INTERNAL_SERVER_ERROR, + (Self::LabelEvaluationError(_), _) => StatusCode::INTERNAL_SERVER_ERROR, (Self::JwtForwardingError(_), _) => StatusCode::INTERNAL_SERVER_ERROR, (Self::UnsupportedHttpMethod(_), _) => StatusCode::METHOD_NOT_ALLOWED, (Self::InvalidHeaderValue(_), _) => StatusCode::BAD_REQUEST, diff --git a/bin/router/src/pipeline/mod.rs b/bin/router/src/pipeline/mod.rs index 459e85870..89e3e96f9 100644 --- a/bin/router/src/pipeline/mod.rs +++ b/bin/router/src/pipeline/mod.rs @@ -1,7 +1,11 @@ use std::{borrow::Cow, sync::Arc}; -use hive_router_plan_executor::execution::plan::PlanExecutionOutput; -use hive_router_query_planner::utils::cancellation::CancellationToken; +use hive_router_plan_executor::execution::plan::{ + ClientRequestDetails, OperationDetails, PlanExecutionOutput, +}; +use hive_router_query_planner::{ + state::supergraph_state::OperationKind, utils::cancellation::CancellationToken, +}; use http::{header::CONTENT_TYPE, HeaderValue, Method}; use ntex::{ util::Bytes, @@ -12,7 +16,7 @@ use crate::{ pipeline::{ coerce_variables::coerce_request_variables, csrf_prevention::perform_csrf_prevention, - error::PipelineError, + error::{PipelineError, PipelineErrorFromAcceptHeader, PipelineErrorVariant}, execution::execute_plan, execution_request::get_execution_request, header::{ @@ -111,7 +115,6 @@ pub async fn execute_pipeline( validate_operation_with_cache(req, supergraph, schema_state, shared_state, &parser_payload) .await?; - let progressive_override_ctx = request_override_context()?; let normalize_payload = normalize_request_with_cache( req, supergraph, @@ -120,13 +123,35 @@ pub async fn execute_pipeline( &parser_payload, ) .await?; - let query = Cow::Owned(execution_request.query.clone()); + let query: Cow<'_, str> = Cow::Owned(execution_request.query.clone()); let variable_payload = coerce_request_variables(req, supergraph, execution_request, &normalize_payload)?; let query_plan_cancellation_token = CancellationToken::with_timeout(shared_state.router_config.query_planner.timeout); + let progressive_override_ctx = + request_override_context(&shared_state.override_labels_evaluator, || { + ClientRequestDetails { + method: req.method().clone(), + url: req.uri().clone(), + headers: req.headers(), + operation: OperationDetails { + name: normalize_payload.operation_for_plan.name.clone(), + kind: match normalize_payload.operation_for_plan.operation_kind { + Some(OperationKind::Query) => "query", + Some(OperationKind::Mutation) => "mutation", + Some(OperationKind::Subscription) => "subscription", + None => "query", + }, + query: query.clone(), + }, + } + }) + .map_err(|error| { + req.new_pipeline_error(PipelineErrorVariant::LabelEvaluationError(error)) + })?; + let query_plan_payload = plan_operation_with_cache( req, supergraph, diff --git a/bin/router/src/pipeline/progressive_override.rs b/bin/router/src/pipeline/progressive_override.rs index 713d942e5..776828f50 100644 --- a/bin/router/src/pipeline/progressive_override.rs +++ b/bin/router/src/pipeline/progressive_override.rs @@ -1,12 +1,44 @@ -use std::collections::{BTreeMap, HashSet}; +use std::collections::{BTreeMap, HashMap, HashSet}; +use hive_router_config::override_labels::{LabelOverrideValue, OverrideLabelsConfig}; +use hive_router_plan_executor::execution::plan::ClientRequestDetails; use hive_router_query_planner::{ graph::{PlannerOverrideContext, PERCENTAGE_SCALE_FACTOR}, state::supergraph_state::SupergraphState, }; use rand::Rng; +use vrl::{ + compiler::{compile as vrl_compile, Program as VrlProgram, TargetValue as VrlTargetValue}, + core::Value as VrlValue, + prelude::{ + state::RuntimeState as VrlState, Context as VrlContext, ExpressionError, + TimeZone as VrlTimeZone, + }, + stdlib::all as vrl_build_functions, + value::Secrets as VrlSecrets, +}; + +#[derive(thiserror::Error, Debug)] +#[error("Failed to compile override label expression for label '{label}': {error}")] +pub struct OverrideLabelsCompileError { + pub label: String, + pub error: String, +} -use super::error::PipelineError; +#[derive(thiserror::Error, Debug)] +pub enum LabelEvaluationError { + #[error( + "Failed to resolve VRL expression for override label '{label}'. Runtime error: {source}" + )] + ExpressionResolutionFailure { + label: String, + source: ExpressionError, + }, + #[error( + "VRL expression for override label '{label}' did not evaluate to a boolean. Got: {got}" + )] + ExpressionWrongType { label: String, got: String }, +} /// Contains the request-specific context for progressive overrides. /// This is stored in the request extensions @@ -19,9 +51,14 @@ pub struct RequestOverrideContext { } #[inline] -pub fn request_override_context() -> Result { - // No active flags by default - until we implement it - let active_flags = HashSet::new(); +pub fn request_override_context<'req, F>( + override_labels_evaluator: &OverrideLabelsEvaluator, + get_client_request: F, +) -> Result +where + F: FnOnce() -> ClientRequestDetails<'req>, +{ + let active_flags = override_labels_evaluator.evaluate(get_client_request)?; // Generate the random percentage value for this request. // Percentage is 0 - 100_000_000_000 (100*PERCENTAGE_SCALE_FACTOR) @@ -77,3 +114,105 @@ impl StableOverrideContext { } } } + +/// Evaluator for override labels based on configuration. +/// This struct compiles and evaluates the override label expressions. +/// It's intended to be used as a shared state in the router. +pub struct OverrideLabelsEvaluator { + static_enabled_labels: HashSet, + expressions: HashMap, +} + +impl OverrideLabelsEvaluator { + pub(crate) fn from_config( + override_labels_config: &OverrideLabelsConfig, + ) -> Result { + let mut static_enabled_labels = HashSet::new(); + let mut expressions = HashMap::new(); + let vrl_functions = vrl_build_functions(); + + for (label, value) in override_labels_config.iter() { + match value { + LabelOverrideValue::Boolean(true) => { + static_enabled_labels.insert(label.clone()); + } + LabelOverrideValue::Expression { expression } => { + let compilation_result = + vrl_compile(expression, &vrl_functions).map_err(|diagnostics| { + OverrideLabelsCompileError { + label: label.clone(), + error: diagnostics + .errors() + .into_iter() + .map(|d| d.code.to_string() + ": " + &d.message) + .collect::>() + .join(", "), + } + })?; + expressions.insert(label.clone(), compilation_result.program); + } + _ => {} // Skip false booleans + } + } + + Ok(Self { + static_enabled_labels, + expressions, + }) + } + + pub(crate) fn evaluate<'req, F>( + &self, + get_client_request: F, + ) -> Result, LabelEvaluationError> + where + F: FnOnce() -> ClientRequestDetails<'req>, + { + let mut active_flags = self.static_enabled_labels.clone(); + + if self.expressions.is_empty() { + return Ok(active_flags); + } + + let client_request = get_client_request(); + let mut target = VrlTargetValue { + value: VrlValue::Object(BTreeMap::from([( + "request".into(), + (&client_request).into(), + )])), + metadata: VrlValue::Object(BTreeMap::new()), + secrets: VrlSecrets::default(), + }; + + let mut state = VrlState::default(); + let timezone = VrlTimeZone::default(); + let mut ctx = VrlContext::new(&mut target, &mut state, &timezone); + + for (label, expression) in &self.expressions { + match expression.resolve(&mut ctx) { + Ok(evaluated_value) => match evaluated_value { + VrlValue::Boolean(true) => { + active_flags.insert(label.clone()); + } + VrlValue::Boolean(false) => { + // Do nothing for false + } + invalid_value => { + return Err(LabelEvaluationError::ExpressionWrongType { + label: label.clone(), + got: format!("{:?}", invalid_value), + }); + } + }, + Err(err) => { + return Err(LabelEvaluationError::ExpressionResolutionFailure { + label: label.clone(), + source: err, + }); + } + } + } + + Ok(active_flags) + } +} diff --git a/bin/router/src/shared_state.rs b/bin/router/src/shared_state.rs index 97c9885f3..f36bda6cd 100644 --- a/bin/router/src/shared_state.rs +++ b/bin/router/src/shared_state.rs @@ -8,12 +8,14 @@ use std::sync::Arc; use crate::jwt::JwtAuthRuntime; use crate::pipeline::cors::{CORSConfigError, Cors}; +use crate::pipeline::progressive_override::{OverrideLabelsCompileError, OverrideLabelsEvaluator}; pub struct RouterSharedState { pub validation_plan: ValidationPlan, pub parse_cache: Cache>>, pub router_config: Arc, pub headers_plan: HeaderRulesPlan, + pub override_labels_evaluator: OverrideLabelsEvaluator, pub cors_runtime: Option, pub jwt_auth_runtime: Option, } @@ -29,6 +31,10 @@ impl RouterSharedState { parse_cache: moka::future::Cache::new(1000), cors_runtime: Cors::from_config(&router_config.cors).map_err(Box::new)?, router_config: router_config.clone(), + override_labels_evaluator: OverrideLabelsEvaluator::from_config( + &router_config.override_labels, + ) + .map_err(Box::new)?, jwt_auth_runtime, }) } @@ -37,7 +43,9 @@ impl RouterSharedState { #[derive(thiserror::Error, Debug)] pub enum SharedStateError { #[error("invalid headers config: {0}")] - HeaderRuleCompileError(#[from] Box), + HeaderRuleCompile(#[from] Box), #[error("invalid regex in CORS config: {0}")] - CORSConfigError(#[from] Box), + CORSConfig(#[from] Box), + #[error("invalid override labels config: {0}")] + OverrideLabelsCompile(#[from] Box), } diff --git a/docs/README.md b/docs/README.md index 3b2fd2716..9e6a73b12 100644 --- a/docs/README.md +++ b/docs/README.md @@ -11,6 +11,7 @@ |[**http**](#http)|`object`|Configuration for the HTTP server/listener.
Default: `{"host":"0.0.0.0","port":4000}`
|| |[**jwt**](#jwt)|`object`, `null`|Configuration for JWT authentication plugin.
|yes| |[**log**](#log)|`object`|The router logger configuration.
Default: `{"filter":null,"format":"json","level":"info"}`
|| +|[**override\_labels**](#override_labels)|`object`|Configuration for overriding labels.
|| |[**override\_subgraph\_urls**](#override_subgraph_urls)|`object`|Configuration for overriding subgraph URLs.
Default: `{}`
|| |[**query\_planner**](#query_planner)|`object`|Query planning configuration.
Default: `{"allow_expose":false,"timeout":"10s"}`
|| |[**supergraph**](#supergraph)|`object`|Configuration for the Federation supergraph source. By default, the router will use a local file-based supergraph source (`./supergraph.graphql`).
|| @@ -63,6 +64,7 @@ log: filter: null format: json level: info +override_labels: {} override_subgraph_urls: accounts: url: https://accounts.example.com/graphql @@ -1536,6 +1538,18 @@ level: info ``` + +## override\_labels: object + +Configuration for overriding labels. + + +**Additional Properties** + +|Name|Type|Description|Required| +|----|----|-----------|--------| +|**Additional Properties**||Defines the value for a label override.

It can be a simple boolean,
or an object containing the expression that evaluates to a boolean.
|| + ## override\_subgraph\_urls: object diff --git a/lib/executor/Cargo.toml b/lib/executor/Cargo.toml index 995e766e5..1f32a2ef7 100644 --- a/lib/executor/Cargo.toml +++ b/lib/executor/Cargo.toml @@ -29,11 +29,11 @@ thiserror = { workspace = true } xxhash-rust = { workspace = true } tokio = { workspace = true, features = ["sync"] } dashmap = { workspace = true } +vrl = { workspace = true } + ahash = "0.8.12" regex-automata = "0.4.10" -vrl = { version = "0.27.0", features = ["compiler", "parser", "value", "diagnostic", "stdlib", "core"] } strum = { version = "0.27.2", features = ["derive"] } - ntex-http = "0.1.15" hyper-tls = { version = "0.6.0", features = ["vendored"] } hyper-util = { version = "0.1.16", features = [ diff --git a/lib/router-config/src/lib.rs b/lib/router-config/src/lib.rs index 5f8b78fde..a9cde36b5 100644 --- a/lib/router-config/src/lib.rs +++ b/lib/router-config/src/lib.rs @@ -6,6 +6,7 @@ pub mod headers; pub mod http_server; pub mod jwt_auth; pub mod log; +pub mod override_labels; pub mod override_subgraph_urls; pub mod primitives; pub mod query_planner; @@ -16,6 +17,7 @@ use config::{Config, File, FileFormat, FileSourceFile}; use envconfig::Envconfig; use schemars::JsonSchema; use serde::{Deserialize, Serialize}; +use std::collections::HashMap; use std::path::PathBuf; use crate::{ @@ -23,6 +25,7 @@ use crate::{ graphiql::GraphiQLConfig, http_server::HttpServerConfig, log::LoggingConfig, + override_labels::OverrideLabelsConfig, primitives::file_path::with_start_path, query_planner::QueryPlannerConfig, supergraph::SupergraphSource, @@ -81,6 +84,10 @@ pub struct HiveRouterConfig { /// Configuration for overriding subgraph URLs. #[serde(default)] pub override_subgraph_urls: override_subgraph_urls::OverrideSubgraphUrlsConfig, + + /// Configuration for overriding labels. + #[serde(default, skip_serializing_if = "HashMap::is_empty")] + pub override_labels: OverrideLabelsConfig, } #[derive(Debug, thiserror::Error)] diff --git a/lib/router-config/src/override_labels.rs b/lib/router-config/src/override_labels.rs new file mode 100644 index 000000000..b3dc01a75 --- /dev/null +++ b/lib/router-config/src/override_labels.rs @@ -0,0 +1,28 @@ +use schemars::JsonSchema; +use serde::{Deserialize, Serialize}; +use std::collections::HashMap; + +/// A map of label names to their override configuration. +pub type OverrideLabelsConfig = HashMap; + +/// Defines the value for a label override. +/// +/// It can be a simple boolean, +/// or an object containing the expression that evaluates to a boolean. +#[derive(Debug, Clone, Deserialize, Serialize, JsonSchema)] +#[serde(untagged)] +pub enum LabelOverrideValue { + /// A static boolean value to enable or disable the label. + Boolean(bool), + /// A dynamic value computed by an expression. + Expression { + /// An expression that must evaluate to a boolean. If true, the label will be applied. + expression: String, + }, +} + +impl LabelOverrideValue { + pub fn is_bool_and_true(&self) -> bool { + matches!(self, LabelOverrideValue::Boolean(true)) + } +}