diff --git a/Cargo.lock b/Cargo.lock index 4de1b34..50839e9 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -102,6 +102,23 @@ dependencies = [ "windows-sys 0.61.2", ] +[[package]] +name = "antlr4rust" +version = "0.5.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "093d520274bfff7278d776f7ea12981a0a0a6f96db90964658e0f38fc6e9a6a6" +dependencies = [ + "better_any", + "bit-set", + "byteorder", + "lazy_static", + "murmur3", + "once_cell", + "parking_lot", + "typed-arena", + "uuid", +] + [[package]] name = "anyhow" version = "1.0.102" @@ -315,6 +332,12 @@ version = "1.8.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "2af50177e190e07a26ab74f8b1efbfe2ef87da2116221318cb1c2e82baf7de06" +[[package]] +name = "better_any" +version = "0.2.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4372b9543397a4b86050cc5e7ee36953edf4bac9518e8a774c2da694977fb6e4" + [[package]] name = "bit-set" version = "0.8.0" @@ -527,6 +550,22 @@ dependencies = [ "shlex", ] +[[package]] +name = "cel" +version = "0.13.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "47a40f338a8c3505921000b609279775792c07cc21f97a3011578c0c5e1738ae" +dependencies = [ + "antlr4rust", + "chrono", + "lazy_static", + "nom 7.1.3", + "pastey", + "regex", + "serde", + "thiserror 1.0.69", +] + [[package]] name = "cfg-if" version = "1.0.4" @@ -2236,6 +2275,19 @@ dependencies = [ "wasm-bindgen", ] +[[package]] +name = "jsonpath-rust" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "633a7320c4bb672863a3782e89b9094ad70285e097ff6832cddd0ec615beadfa" +dependencies = [ + "pest", + "pest_derive", + "regex", + "serde_json", + "thiserror 2.0.18", +] + [[package]] name = "jsonschema" version = "0.42.2" @@ -2473,6 +2525,15 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "murmur3" +version = "0.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "a198f9589efc03f544388dfc4a19fe8af4323662b62f598b8dcfdac62c14771c" +dependencies = [ + "byteorder", +] + [[package]] name = "nkeys" version = "0.4.5" @@ -2801,6 +2862,12 @@ dependencies = [ "syn 2.0.117", ] +[[package]] +name = "pastey" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "c5a797f0e07bdf071d15742978fc3128ec6c22891c31a3a931513263904c982a" + [[package]] name = "pem-rfc7468" version = "0.7.0" @@ -3708,6 +3775,7 @@ dependencies = [ "arc-swap", "async-nats", "async-trait", + "cel", "chrono", "criterion", "csv", @@ -3715,6 +3783,7 @@ dependencies = [ "futures", "jaq-interpret", "jaq-parse", + "jsonpath-rust", "notify", "opentelemetry-proto", "parking_lot", @@ -5011,6 +5080,12 @@ dependencies = [ "tokio", ] +[[package]] +name = "typed-arena" +version = "2.0.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6af6ae20167a9ece4bcb41af5b80f8a1f1df981f6391189ce00fd257af04126a" + [[package]] name = "typenum" version = "1.20.0" diff --git a/crates/rsigma-cli/src/commands/resolve.rs b/crates/rsigma-cli/src/commands/resolve.rs index 4434dec..0539ac9 100644 --- a/crates/rsigma-cli/src/commands/resolve.rs +++ b/crates/rsigma-cli/src/commands/resolve.rs @@ -7,7 +7,12 @@ use rsigma_eval::parse_pipeline_file; use rsigma_runtime::DefaultSourceResolver; use rsigma_runtime::sources::SourceResolver; -pub fn cmd_resolve(pipeline_paths: Vec, source_filter: Option, pretty: bool) { +pub fn cmd_resolve( + pipeline_paths: Vec, + source_filter: Option, + pretty: bool, + dry_run: bool, +) { let rt = tokio::runtime::Builder::new_current_thread() .enable_all() .build() @@ -16,10 +21,15 @@ pub fn cmd_resolve(pipeline_paths: Vec, source_filter: Option, std::process::exit(crate::exit_code::CONFIG_ERROR); }); - rt.block_on(async { resolve_async(pipeline_paths, source_filter, pretty).await }); + rt.block_on(async { resolve_async(pipeline_paths, source_filter, pretty, dry_run).await }); } -async fn resolve_async(pipeline_paths: Vec, source_filter: Option, pretty: bool) { +async fn resolve_async( + pipeline_paths: Vec, + source_filter: Option, + pretty: bool, + dry_run: bool, +) { let mut all_sources = Vec::new(); for path in &pipeline_paths { @@ -58,6 +68,35 @@ async fn resolve_async(pipeline_paths: Vec, source_filter: Option = all_sources + .iter() + .map(|(pipeline_name, source)| { + serde_json::json!({ + "pipeline": pipeline_name, + "source_id": &source.id, + "source_type": format!("{:?}", source.source_type).split('{').next().unwrap_or("unknown").trim(), + "required": source.required, + "refresh": format!("{:?}", source.refresh), + }) + }) + .collect(); + + let output = if items.len() == 1 { + items.into_iter().next().unwrap() + } else { + serde_json::Value::Array(items) + }; + + let json_str = if pretty { + serde_json::to_string_pretty(&output).unwrap() + } else { + serde_json::to_string(&output).unwrap() + }; + println!("{json_str}"); + return; + } + let resolver = Arc::new(DefaultSourceResolver::new()); let mut results = Vec::new(); let mut had_error = false; diff --git a/crates/rsigma-cli/src/commands/validate.rs b/crates/rsigma-cli/src/commands/validate.rs index cb6b214..a94e5b8 100644 --- a/crates/rsigma-cli/src/commands/validate.rs +++ b/crates/rsigma-cli/src/commands/validate.rs @@ -4,8 +4,65 @@ use std::process; use rsigma_eval::Engine; use rsigma_parser::parse_sigma_directory; -pub(crate) fn cmd_validate(path: PathBuf, verbose: bool, pipeline_paths: Vec) { - let pipelines = crate::load_pipelines(&pipeline_paths); +pub(crate) fn cmd_validate( + path: PathBuf, + verbose: bool, + pipeline_paths: Vec, + resolve_sources: bool, +) { + let mut pipelines = crate::load_pipelines(&pipeline_paths); + + if resolve_sources { + let has_dynamic = pipelines.iter().any(|p| p.is_dynamic()); + if has_dynamic { + let rt = tokio::runtime::Builder::new_current_thread() + .enable_all() + .build() + .unwrap_or_else(|e| { + eprintln!("Failed to create async runtime for source resolution: {e}"); + process::exit(crate::exit_code::CONFIG_ERROR); + }); + + let resolver = rsigma_runtime::DefaultSourceResolver::new(); + let mut resolved_pipelines = Vec::with_capacity(pipelines.len()); + let mut source_errors: Vec = Vec::new(); + + for pipeline in &pipelines { + if pipeline.is_dynamic() { + match rt.block_on(rsigma_runtime::sources::resolve_all( + &resolver, + &pipeline.sources, + )) { + Ok(resolved_data) => { + let expanded = + rsigma_runtime::sources::template::TemplateExpander::expand( + pipeline, + &resolved_data, + ); + resolved_pipelines.push(expanded); + } + Err(e) => { + source_errors.push(format!("pipeline '{}': {e}", pipeline.name)); + resolved_pipelines.push(pipeline.clone()); + } + } + } else { + resolved_pipelines.push(pipeline.clone()); + } + } + + if !source_errors.is_empty() { + eprintln!("Source resolution errors:"); + for err in &source_errors { + eprintln!(" - {err}"); + } + process::exit(crate::exit_code::CONFIG_ERROR); + } + + pipelines = resolved_pipelines; + println!(" Sources resolved: OK"); + } + } match parse_sigma_directory(&path) { Ok(collection) => { diff --git a/crates/rsigma-cli/src/daemon/instrumented_resolver.rs b/crates/rsigma-cli/src/daemon/instrumented_resolver.rs index c11fcb3..fa1d8e3 100644 --- a/crates/rsigma-cli/src/daemon/instrumented_resolver.rs +++ b/crates/rsigma-cli/src/daemon/instrumented_resolver.rs @@ -22,6 +22,11 @@ impl InstrumentedResolver { metrics, } } + + /// Access the underlying cache for invalidation operations. + pub fn cache(&self) -> &rsigma_runtime::sources::cache::SourceCache { + self.inner.cache() + } } #[async_trait::async_trait] @@ -44,6 +49,15 @@ impl SourceResolver for InstrumentedResolver { if value.from_cache { self.metrics.source_cache_hits.inc(); } + self.metrics + .source_last_resolved + .with_label_values(&[source.id.as_str()]) + .set( + std::time::SystemTime::now() + .duration_since(std::time::UNIX_EPOCH) + .unwrap_or_default() + .as_secs_f64(), + ); } Err(e) => { let error_kind = match &e.kind { diff --git a/crates/rsigma-cli/src/daemon/metrics.rs b/crates/rsigma-cli/src/daemon/metrics.rs index 4511c62..f5dc363 100644 --- a/crates/rsigma-cli/src/daemon/metrics.rs +++ b/crates/rsigma-cli/src/daemon/metrics.rs @@ -1,5 +1,5 @@ use prometheus::{ - Gauge, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, Opts, Registry, + Gauge, GaugeVec, Histogram, HistogramOpts, IntCounter, IntCounterVec, IntGauge, Opts, Registry, TextEncoder, }; use rsigma_runtime::MetricsHook; @@ -30,6 +30,7 @@ pub struct Metrics { pub source_resolve_errors: IntCounterVec, pub source_resolve_latency: Histogram, pub source_cache_hits: IntCounter, + pub source_last_resolved: GaugeVec, #[cfg(feature = "daemon-otlp")] pub otlp_requests: IntCounterVec, #[cfg(feature = "daemon-otlp")] @@ -235,6 +236,14 @@ impl Metrics { "Times cached source data was served on resolution failure", )) .unwrap(); + let source_last_resolved = GaugeVec::new( + Opts::new( + "rsigma_source_last_resolved_timestamp", + "Unix timestamp of last successful resolution per source", + ), + &["source_id"], + ) + .unwrap(); registry .register(Box::new(source_resolves_total.clone())) @@ -248,6 +257,9 @@ impl Metrics { registry .register(Box::new(source_cache_hits.clone())) .unwrap(); + registry + .register(Box::new(source_last_resolved.clone())) + .unwrap(); #[cfg(feature = "daemon-otlp")] let otlp_requests = IntCounterVec::new( @@ -305,6 +317,7 @@ impl Metrics { source_resolve_errors, source_resolve_latency, source_cache_hits, + source_last_resolved, #[cfg(feature = "daemon-otlp")] otlp_requests, #[cfg(feature = "daemon-otlp")] diff --git a/crates/rsigma-cli/src/daemon/reload.rs b/crates/rsigma-cli/src/daemon/reload.rs index d5942a8..33a924d 100644 --- a/crates/rsigma-cli/src/daemon/reload.rs +++ b/crates/rsigma-cli/src/daemon/reload.rs @@ -60,9 +60,12 @@ pub fn spawn_file_watcher( Some(watcher) } -/// Set up a SIGHUP handler that sends reload signals. +/// Set up a SIGHUP handler that sends reload signals and source re-resolution triggers. #[cfg(unix)] -pub async fn sighup_listener(reload_tx: mpsc::Sender<()>) { +pub async fn sighup_listener( + reload_tx: mpsc::Sender<()>, + sources_trigger_tx: Option>, +) { use tokio::signal::unix::{SignalKind, signal}; let mut sig = match signal(SignalKind::hangup()) { @@ -75,13 +78,18 @@ pub async fn sighup_listener(reload_tx: mpsc::Sender<()>) { loop { sig.recv().await; - tracing::info!("SIGHUP received, triggering reload"); + tracing::info!("SIGHUP received, triggering reload and source re-resolution"); let _ = reload_tx.try_send(()); + if let Some(tx) = &sources_trigger_tx { + let _ = tx.try_send(rsigma_runtime::sources::refresh::RefreshTrigger::All); + } } } #[cfg(not(unix))] -pub async fn sighup_listener(_reload_tx: mpsc::Sender<()>) { - // No-op on non-Unix platforms +pub async fn sighup_listener( + _reload_tx: mpsc::Sender<()>, + _sources_trigger_tx: Option>, +) { std::future::pending::<()>().await; } diff --git a/crates/rsigma-cli/src/daemon/server.rs b/crates/rsigma-cli/src/daemon/server.rs index a4d7dd5..de80ba5 100644 --- a/crates/rsigma-cli/src/daemon/server.rs +++ b/crates/rsigma-cli/src/daemon/server.rs @@ -7,7 +7,7 @@ use std::time::Instant; use axum::extract::State; use axum::http::StatusCode; use axum::response::{IntoResponse, Response}; -use axum::routing::{get, post}; +use axum::routing::{delete, get, post}; use axum::{Json, Router}; use rsigma_eval::{CorrelationConfig, Pipeline, ProcessResult}; use rsigma_runtime::{ @@ -55,6 +55,8 @@ struct AppState { event_tx: Option>, /// Channel for on-demand source resolution triggers. sources_trigger_tx: Option>, + /// The instrumented source resolver (provides cache access for invalidation API). + source_resolver: Option>, /// Channel for OTLP event ingestion. Always set when daemon-otlp is compiled in. #[cfg(feature = "daemon-otlp")] otlp_event_tx: mpsc::Sender, @@ -86,6 +88,7 @@ pub struct DaemonConfig { pub state_restore_mode: StateRestoreMode, pub drain_timeout: u64, pub input_format: InputFormat, + pub allow_remote_include: bool, } pub async fn run_daemon(config: DaemonConfig) { @@ -109,6 +112,7 @@ pub async fn run_daemon(config: DaemonConfig) { config.include_event, ); engine.set_pipeline_paths(config.pipeline_paths.clone()); + engine.set_allow_remote_include(config.allow_remote_include); // Set up dynamic source resolver if any pipeline has dynamic sources let has_dynamic = config.pipelines.iter().any(|p| p.is_dynamic()); @@ -116,10 +120,15 @@ pub async fn run_daemon(config: DaemonConfig) { mpsc::Sender, > = None; + let mut source_resolver_val: Option> = + None; + if has_dynamic { - let resolver: Arc = Arc::new( - super::instrumented_resolver::InstrumentedResolver::new(metrics.clone()), - ); + let instrumented = Arc::new(super::instrumented_resolver::InstrumentedResolver::new( + metrics.clone(), + )); + source_resolver_val = Some(instrumented.clone()); + let resolver: Arc = instrumented; engine.set_source_resolver(resolver.clone()); // Resolve dynamic sources at startup (blocks on required sources) @@ -139,7 +148,48 @@ pub async fn run_daemon(config: DaemonConfig) { if !all_sources.is_empty() { let scheduler = rsigma_runtime::sources::refresh::RefreshScheduler::new(); sources_trigger_tx_val = Some(scheduler.trigger_sender()); + + // Spawn NATS control subject listener for remote re-resolution triggers + #[cfg(feature = "daemon-nats")] + { + let nats_url = config.nats_config.url.clone(); + let trigger_tx = scheduler.trigger_sender(); + tokio::spawn(async move { + let subject = rsigma_runtime::sources::refresh::NATS_CONTROL_SUBJECT; + if let Err(e) = rsigma_runtime::sources::refresh::nats_control_loop( + &nats_url, subject, trigger_tx, + ) + .await + { + tracing::warn!( + error = %e, + "NATS control subject listener failed" + ); + } + }); + } + + // Collect optional source IDs for background retry + let optional_source_ids: Vec = all_sources + .iter() + .filter(|s| !s.required) + .map(|s| s.id.clone()) + .collect(); + + let bg_trigger_tx = scheduler.trigger_sender(); scheduler.run(all_sources, resolver); + + // Spawn background retry for optional sources that may have failed at startup + if !optional_source_ids.is_empty() { + tokio::spawn(async move { + tokio::time::sleep(std::time::Duration::from_secs(5)).await; + for id in optional_source_ids { + let _ = bg_trigger_tx + .send(rsigma_runtime::sources::refresh::RefreshTrigger::Single(id)) + .await; + } + }); + } } } @@ -253,7 +303,8 @@ pub async fn run_daemon(config: DaemonConfig) { reload_tx: reload_tx.clone(), start_time, event_tx: http_event_tx, - sources_trigger_tx: sources_trigger_tx_val, + sources_trigger_tx: sources_trigger_tx_val.clone(), + source_resolver: source_resolver_val, #[cfg(feature = "daemon-otlp")] otlp_event_tx, }; @@ -271,6 +322,10 @@ pub async fn run_daemon(config: DaemonConfig) { .route( "/api/v1/sources/resolve/{source_id}", post(resolve_source_by_id), + ) + .route( + "/api/v1/sources/cache/{source_id}", + delete(invalidate_source_cache), ); #[cfg(feature = "daemon-otlp")] @@ -302,10 +357,11 @@ pub async fn run_daemon(config: DaemonConfig) { #[cfg(not(feature = "daemon-otlp"))] tracing::info!(addr = %actual_addr, "API server listening"); - // Spawn SIGHUP listener - let sighup_tx = reload_tx.clone(); + // Spawn SIGHUP listener (triggers both rule reload and source re-resolution) + let sighup_reload_tx = reload_tx.clone(); + let sighup_sources_tx = sources_trigger_tx_val.clone(); tokio::spawn(async move { - reload::sighup_listener(sighup_tx).await; + reload::sighup_listener(sighup_reload_tx, sighup_sources_tx).await; }); // Spawn reload handler — uses LogProcessor::reload_rules for atomic hot-reload @@ -993,10 +1049,64 @@ struct StatusResponse { detection_matches: u64, correlation_matches: u64, uptime_seconds: f64, + #[serde(skip_serializing_if = "Option::is_none")] + dynamic_sources: Option, +} + +#[derive(Serialize)] +struct DynamicSourcesSummary { + total: usize, + resolves_total: u64, + errors_total: u64, + cache_hits: u64, } async fn status(State(state): State) -> impl IntoResponse { let stats = state.processor.stats(); + + let dynamic_sources = state.source_resolver.as_ref().map(|_| { + use prometheus::core::Collector; + let resolves: u64 = state + .metrics + .source_resolves_total + .collect() + .first() + .map(|mf| { + mf.get_metric() + .iter() + .map(|m| m.get_counter().get_value() as u64) + .sum() + }) + .unwrap_or(0); + let errors: u64 = state + .metrics + .source_resolve_errors + .collect() + .first() + .map(|mf| { + mf.get_metric() + .iter() + .map(|m| m.get_counter().get_value() as u64) + .sum() + }) + .unwrap_or(0); + let cache_hits = state.metrics.source_cache_hits.get(); + let total = state + .metrics + .source_last_resolved + .collect() + .first() + .map(|mf| mf.get_metric().len()) + .unwrap_or(0); + + DynamicSourcesSummary { + total, + resolves_total: resolves, + errors_total: errors, + cache_hits, + } + }); + let resp = StatusResponse { status: if state.health.is_ready() { "running".to_string() @@ -1010,6 +1120,7 @@ async fn status(State(state): State) -> impl IntoResponse { detection_matches: state.metrics.detection_matches.get(), correlation_matches: state.metrics.correlation_matches.get(), uptime_seconds: state.start_time.elapsed().as_secs_f64(), + dynamic_sources, }; Json(resp) } @@ -1095,6 +1206,24 @@ async fn resolve_source_by_id( } } +async fn invalidate_source_cache( + State(state): State, + axum::extract::Path(source_id): axum::extract::Path, +) -> impl IntoResponse { + let Some(resolver) = &state.source_resolver else { + return ( + StatusCode::NOT_FOUND, + Json(serde_json::json!({ "error": "no dynamic sources configured" })), + ); + }; + + resolver.cache().invalidate(&source_id); + ( + StatusCode::OK, + Json(serde_json::json!({ "status": "invalidated", "source_id": source_id })), + ) +} + /// Accept events via HTTP POST for processing. /// Each non-empty line in the request body is parsed using the configured /// `--input-format` and forwarded to the engine. diff --git a/crates/rsigma-cli/src/main.rs b/crates/rsigma-cli/src/main.rs index 09594d6..3d84b77 100644 --- a/crates/rsigma-cli/src/main.rs +++ b/crates/rsigma-cli/src/main.rs @@ -50,6 +50,11 @@ enum Commands { /// Processing pipeline(s) to apply. Accepts builtin names (ecs_windows, sysmon) or YAML file paths #[arg(short = 'p', long = "pipeline")] pipelines: Vec, + + /// Also resolve dynamic pipeline sources during validation. + /// Sources must be reachable (file/command/HTTP) for validation to pass. + #[arg(long = "resolve-sources")] + resolve_sources: bool, }, /// Parse a condition expression and print the AST @@ -309,6 +314,12 @@ enum Commands { #[cfg(feature = "daemon-nats")] #[arg(long = "consumer-group", env = "RSIGMA_CONSUMER_GROUP")] consumer_group: Option, + + /// Allow include directives to reference remote (HTTP/NATS) sources. + /// By default, includes are restricted to local sources (file/command) + /// for security. Use this flag to opt in to remote include resolution. + #[arg(long = "allow-remote-include")] + allow_remote_include: bool, }, /// Evaluate events against Sigma rules @@ -468,6 +479,10 @@ enum Commands { /// Pretty-print JSON output #[arg(long)] pretty: bool, + + /// Show what would be resolved without performing resolution + #[arg(long = "dry-run")] + dry_run: bool, }, /// List all fields referenced by Sigma rules @@ -551,6 +566,7 @@ fn main() { timestamp_fallback, #[cfg(feature = "daemon-nats")] consumer_group, + allow_remote_include, } => { #[cfg(feature = "daemon-nats")] let nats_auth = NatsAuthArgs { @@ -620,6 +636,7 @@ fn main() { replay_policy, #[cfg(feature = "daemon-nats")] consumer_group, + allow_remote_include, ) } Commands::Parse { path, pretty } => commands::cmd_parse(path, pretty), @@ -627,7 +644,8 @@ fn main() { path, verbose, pipelines, - } => commands::cmd_validate(path, verbose, pipelines), + resolve_sources, + } => commands::cmd_validate(path, verbose, pipelines, resolve_sources), Commands::Lint { path, schema, @@ -730,7 +748,8 @@ fn main() { pipelines, source, pretty, - } => commands::cmd_resolve(pipelines, source, pretty), + dry_run, + } => commands::cmd_resolve(pipelines, source, pretty, dry_run), } } @@ -781,6 +800,7 @@ fn cmd_daemon( #[cfg(feature = "daemon-nats")] nats_auth: NatsAuthArgs, #[cfg(feature = "daemon-nats")] replay_policy: rsigma_runtime::ReplayPolicy, #[cfg(feature = "daemon-nats")] consumer_group: Option, + allow_remote_include: bool, ) { // Set up structured logging tracing_subscriber::fmt() @@ -854,6 +874,7 @@ fn cmd_daemon( #[cfg(feature = "daemon-nats")] consumer_group, state_restore_mode, + allow_remote_include, }; let rt = tokio::runtime::Builder::new_multi_thread() diff --git a/crates/rsigma-eval/src/pipeline/parsing.rs b/crates/rsigma-eval/src/pipeline/parsing.rs index 4d95d7b..c592ae9 100644 --- a/crates/rsigma-eval/src/pipeline/parsing.rs +++ b/crates/rsigma-eval/src/pipeline/parsing.rs @@ -13,7 +13,8 @@ use super::conditions::{ }; use super::finalizers::Finalizer; use super::sources::{ - DataFormat, DynamicSource, ErrorPolicy, RefLocation, RefreshPolicy, SourceRef, SourceType, + DataFormat, DynamicSource, ErrorPolicy, ExtractExpr, RefLocation, RefreshPolicy, SourceRef, + SourceType, }; use super::transformations::Transformation; use super::{Pipeline, TransformationItem}; @@ -873,10 +874,7 @@ fn parse_dynamic_source(value: &serde_yaml::Value) -> Result { })?; let format = parse_data_format(obj.get(ykey("format"))); - let extract = obj - .get(ykey("extract")) - .and_then(|v| v.as_str()) - .map(String::from); + let extract = parse_extract_expr(obj.get(ykey("extract")), &id)?; let source_type = match type_str { "http" => { @@ -927,6 +925,7 @@ fn parse_dynamic_source(value: &serde_yaml::Value) -> Result { SourceType::File { path: PathBuf::from(path), format, + extract, } } "nats" => { @@ -978,6 +977,56 @@ fn parse_dynamic_source(value: &serde_yaml::Value) -> Result { }) } +/// Parse an `extract` field which can be either: +/// - A plain string (always treated as jq): `extract: ".emails[]"` +/// - A structured mapping: `extract: { expr: "$.emails[*]", type: jsonpath }` +fn parse_extract_expr( + value: Option<&serde_yaml::Value>, + source_id: &str, +) -> Result> { + let Some(val) = value else { + return Ok(None); + }; + + if let Some(s) = val.as_str() { + return Ok(Some(ExtractExpr::Jq(s.to_string()))); + } + + if let Some(map) = val.as_mapping() { + let expr = map + .get(ykey("expr")) + .and_then(|v| v.as_str()) + .ok_or_else(|| { + EvalError::InvalidModifiers(format!( + "source '{source_id}': extract object must have an 'expr' field" + )) + })? + .to_string(); + + let extract_type = map + .get(ykey("type")) + .and_then(|v| v.as_str()) + .unwrap_or("jq"); + + let extract_expr = match extract_type { + "jq" => ExtractExpr::Jq(expr), + "jsonpath" => ExtractExpr::JsonPath(expr), + "cel" => ExtractExpr::Cel(expr), + other => { + return Err(EvalError::InvalidModifiers(format!( + "source '{source_id}': unknown extract type '{other}' (expected: jq, jsonpath, cel)" + ))); + } + }; + + return Ok(Some(extract_expr)); + } + + Err(EvalError::InvalidModifiers(format!( + "source '{source_id}': 'extract' must be a string or mapping" + ))) +} + fn parse_data_format(value: Option<&serde_yaml::Value>) -> DataFormat { match value.and_then(|v| v.as_str()) { Some("json") => DataFormat::Json, diff --git a/crates/rsigma-eval/src/pipeline/sources.rs b/crates/rsigma-eval/src/pipeline/sources.rs index 351a5b5..397ce26 100644 --- a/crates/rsigma-eval/src/pipeline/sources.rs +++ b/crates/rsigma-eval/src/pipeline/sources.rs @@ -43,25 +43,44 @@ pub enum SourceType { method: Option, headers: HashMap, format: DataFormat, - extract: Option, + extract: Option, }, /// Run a local command and capture its stdout. Command { command: Vec, format: DataFormat, - extract: Option, + extract: Option, }, /// Read data from a local file. - File { path: PathBuf, format: DataFormat }, + File { + path: PathBuf, + format: DataFormat, + extract: Option, + }, /// Subscribe to a NATS subject for push-based updates. Nats { url: String, subject: String, format: DataFormat, - extract: Option, + extract: Option, }, } +/// An extraction expression applied to source data after parsing. +/// +/// Supports two syntax forms in YAML: +/// - Plain string: always jq (the common case): `extract: ".emails[]"` +/// - Structured object: explicit language: `extract: { expr: "$.emails[*]", type: jsonpath }` +#[derive(Debug, Clone, PartialEq)] +pub enum ExtractExpr { + /// A jq expression (default). Evaluated via jaq. + Jq(String), + /// A JSONPath expression. Evaluated via serde_json_path. + JsonPath(String), + /// A CEL (Common Expression Language) expression. Evaluated via cel-interpreter. + Cel(String), +} + /// How often a source should be refreshed. #[derive(Debug, Clone, PartialEq)] pub enum RefreshPolicy { diff --git a/crates/rsigma-eval/src/pipeline/tests.rs b/crates/rsigma-eval/src/pipeline/tests.rs index 545fa12..a5561c3 100644 --- a/crates/rsigma-eval/src/pipeline/tests.rs +++ b/crates/rsigma-eval/src/pipeline/tests.rs @@ -1052,7 +1052,10 @@ transformations: } => { assert_eq!(url, "https://api.internal/v1/admin-emails"); assert_eq!(*format, sources::DataFormat::Json); - assert_eq!(extract.as_deref(), Some(".emails[]")); + assert_eq!( + *extract, + Some(sources::ExtractExpr::Jq(".emails[]".to_string())) + ); } other => panic!("expected Http, got {other:?}"), } @@ -1117,9 +1120,14 @@ transformations: assert_eq!(src.refresh, sources::RefreshPolicy::Watch); match &src.source_type { - sources::SourceType::File { path, format } => { + sources::SourceType::File { + path, + format, + extract, + } => { assert_eq!(path, std::path::Path::new("/etc/rsigma/watchlist.json")); assert_eq!(*format, sources::DataFormat::Json); + assert_eq!(*extract, None); } other => panic!("expected File, got {other:?}"), } @@ -1155,6 +1163,115 @@ transformations: } } +#[test] +fn test_parse_extract_structured_jsonpath() { + let yaml = r#" +name: JSONPath Extract Pipeline +sources: + - id: config + type: http + url: https://api.internal/v1/config + format: json + extract: + expr: "$.settings[*]" + type: jsonpath +transformations: + - type: value_placeholders +"#; + let pipeline = parse_pipeline(yaml).unwrap(); + let src = &pipeline.sources[0]; + match &src.source_type { + sources::SourceType::Http { extract, .. } => { + assert_eq!( + *extract, + Some(sources::ExtractExpr::JsonPath("$.settings[*]".to_string())) + ); + } + other => panic!("expected Http, got {other:?}"), + } +} + +#[test] +fn test_parse_extract_structured_cel() { + let yaml = r#" +name: CEL Extract Pipeline +sources: + - id: emails + type: file + path: /etc/rsigma/emails.json + format: json + extract: + expr: "data.emails.filter(e, e.endsWith('@corp.com'))" + type: cel +transformations: + - type: value_placeholders +"#; + let pipeline = parse_pipeline(yaml).unwrap(); + let src = &pipeline.sources[0]; + match &src.source_type { + sources::SourceType::File { extract, .. } => { + assert_eq!( + *extract, + Some(sources::ExtractExpr::Cel( + "data.emails.filter(e, e.endsWith('@corp.com'))".to_string() + )) + ); + } + other => panic!("expected File, got {other:?}"), + } +} + +#[test] +fn test_parse_extract_structured_jq_explicit() { + let yaml = r#" +name: Explicit JQ Extract Pipeline +sources: + - id: data + type: http + url: https://api.internal/v1/data + format: json + extract: + expr: ".items[] | select(.active)" + type: jq +transformations: + - type: value_placeholders +"#; + let pipeline = parse_pipeline(yaml).unwrap(); + let src = &pipeline.sources[0]; + match &src.source_type { + sources::SourceType::Http { extract, .. } => { + assert_eq!( + *extract, + Some(sources::ExtractExpr::Jq( + ".items[] | select(.active)".to_string() + )) + ); + } + other => panic!("expected Http, got {other:?}"), + } +} + +#[test] +fn test_parse_extract_unknown_type_errors() { + let yaml = r#" +name: Bad Extract Pipeline +sources: + - id: data + type: http + url: https://api.internal/v1/data + format: json + extract: + expr: "something" + type: xpath +transformations: + - type: value_placeholders +"#; + let result = parse_pipeline(yaml); + assert!(result.is_err()); + let err = result.unwrap_err().to_string(); + assert!(err.contains("xpath"), "error should mention 'xpath': {err}"); +} + #[test] fn test_parse_on_demand_refresh() { let yaml = r#" diff --git a/crates/rsigma-runtime/Cargo.toml b/crates/rsigma-runtime/Cargo.toml index 017077a..33157f7 100644 --- a/crates/rsigma-runtime/Cargo.toml +++ b/crates/rsigma-runtime/Cargo.toml @@ -34,6 +34,8 @@ regex = "1" csv = "1" jaq-interpret = "1.5.0" jaq-parse = "1.0.3" +jsonpath-rust = "1" +cel = "0.13" reqwest = { version = "0.12", default-features = false, features = ["rustls-tls", "json"] } rusqlite = { version = "0.39", features = ["bundled"] } notify = "8.2" diff --git a/crates/rsigma-runtime/src/engine.rs b/crates/rsigma-runtime/src/engine.rs index ae1d938..d887716 100644 --- a/crates/rsigma-runtime/src/engine.rs +++ b/crates/rsigma-runtime/src/engine.rs @@ -153,8 +153,10 @@ impl RuntimeEngine { let pipelines = std::mem::take(&mut self.pipelines); let resolver = self.source_resolver.clone().unwrap(); let allow_remote = self.allow_remote_include; - let resolved = handle.block_on(async { - resolve_pipelines_async(&resolver, &pipelines, allow_remote).await + let resolved = tokio::task::block_in_place(|| { + handle.block_on(async { + resolve_pipelines_async(&resolver, &pipelines, allow_remote).await + }) }); match resolved { Ok(p) => self.pipelines = p, diff --git a/crates/rsigma-runtime/src/sources/cache.rs b/crates/rsigma-runtime/src/sources/cache.rs index 25ed521..3b62ec0 100644 --- a/crates/rsigma-runtime/src/sources/cache.rs +++ b/crates/rsigma-runtime/src/sources/cache.rs @@ -1,26 +1,47 @@ //! Source resolution cache with in-memory and optional SQLite persistence. //! //! Stores last-known-good values so that `on_error: use_cached` can serve -//! stale data when a source fetch fails. +//! stale data when a source fetch fails. Supports optional TTL-based expiration. use std::collections::HashMap; use std::path::Path; use std::sync::Mutex; +use std::time::{Duration, Instant}; + +/// A cached entry with its stored timestamp. +#[derive(Clone)] +struct CacheEntry { + value: serde_json::Value, + stored_at: Instant, +} /// Thread-safe cache for resolved source data. /// -/// Provides an in-memory layer with optional SQLite-backed disk persistence. +/// Provides an in-memory layer with optional SQLite-backed disk persistence +/// and optional TTL-based expiration. pub struct SourceCache { - entries: Mutex>, + entries: Mutex>, db: Option>, + ttl: Option, } impl SourceCache { - /// Create a new in-memory-only cache. + /// Create a new in-memory-only cache (no TTL). pub fn new() -> Self { Self { entries: Mutex::new(HashMap::new()), db: None, + ttl: None, + } + } + + /// Create a new in-memory-only cache with a TTL. + /// Entries older than `ttl` are considered expired and will not be returned. + pub fn with_ttl(ttl: Duration) -> Self { + Self { + entries: Mutex::new(HashMap::new()), + db: None, + ttl: Some(ttl), } } @@ -29,6 +50,11 @@ impl SourceCache { /// The table is created if it does not exist. Existing cached values /// are loaded into memory on construction. pub fn with_sqlite(path: &Path) -> Result { + Self::with_sqlite_and_ttl(path, None) + } + + /// Create a SQLite-backed cache with an optional TTL. + pub fn with_sqlite_and_ttl(path: &Path, ttl: Option) -> Result { let conn = rusqlite::Connection::open(path) .map_err(|e| format!("failed to open source cache DB: {e}"))?; @@ -41,7 +67,6 @@ impl SourceCache { ) .map_err(|e| format!("failed to create source_cache table: {e}"))?; - // Load existing entries into memory let entries = { let mut map = HashMap::new(); let mut stmt = conn @@ -57,7 +82,13 @@ impl SourceCache { for (id, val_str) in rows.flatten() { if let Ok(val) = serde_json::from_str(&val_str) { - map.insert(id, val); + map.insert( + id, + CacheEntry { + value: val, + stored_at: Instant::now(), + }, + ); } } map @@ -66,6 +97,7 @@ impl SourceCache { Ok(Self { entries: Mutex::new(entries), db: Some(Mutex::new(conn)), + ttl, }) } @@ -73,7 +105,13 @@ impl SourceCache { pub fn store(&self, source_id: &str, value: &serde_json::Value) { { let mut entries = self.entries.lock().unwrap(); - entries.insert(source_id.to_string(), value.clone()); + entries.insert( + source_id.to_string(), + CacheEntry { + value: value.clone(), + stored_at: Instant::now(), + }, + ); } if let Some(db) = &self.db { @@ -87,9 +125,18 @@ impl SourceCache { } /// Retrieve a cached value for a source. + /// Returns `None` if no entry exists or if the entry has expired (when TTL is set). pub fn get(&self, source_id: &str) -> Option { let entries = self.entries.lock().unwrap(); - entries.get(source_id).cloned() + let entry = entries.get(source_id)?; + + if let Some(ttl) = self.ttl + && entry.stored_at.elapsed() > ttl + { + return None; + } + + Some(entry.value.clone()) } /// Remove a cached entry (memory + disk). @@ -121,7 +168,43 @@ impl SourceCache { } } - /// Returns the number of cached entries. + /// Remove all expired entries from the cache (memory + disk). + /// Only meaningful when a TTL is configured. + pub fn evict_expired(&self) { + let Some(ttl) = self.ttl else { return }; + + let expired_ids: Vec = { + let entries = self.entries.lock().unwrap(); + entries + .iter() + .filter(|(_, entry)| entry.stored_at.elapsed() > ttl) + .map(|(id, _)| id.clone()) + .collect() + }; + + if expired_ids.is_empty() { + return; + } + + { + let mut entries = self.entries.lock().unwrap(); + for id in &expired_ids { + entries.remove(id); + } + } + + if let Some(db) = &self.db { + let conn = db.lock().unwrap(); + for id in &expired_ids { + let _ = conn.execute( + "DELETE FROM source_cache WHERE source_id = ?1", + rusqlite::params![id], + ); + } + } + } + + /// Returns the number of cached entries (including potentially expired ones). pub fn len(&self) -> usize { let entries = self.entries.lock().unwrap(); entries.len() @@ -131,6 +214,11 @@ impl SourceCache { pub fn is_empty(&self) -> bool { self.len() == 0 } + + /// Returns the configured TTL, if any. + pub fn ttl(&self) -> Option { + self.ttl + } } impl Default for SourceCache { diff --git a/crates/rsigma-runtime/src/sources/command.rs b/crates/rsigma-runtime/src/sources/command.rs index bff35c2..1b04111 100644 --- a/crates/rsigma-runtime/src/sources/command.rs +++ b/crates/rsigma-runtime/src/sources/command.rs @@ -2,7 +2,7 @@ use std::time::Instant; -use rsigma_eval::pipeline::sources::DataFormat; +use rsigma_eval::pipeline::sources::{DataFormat, ExtractExpr}; use super::extract::apply_extract; use super::file::parse_data; @@ -12,7 +12,7 @@ use super::{ResolvedValue, SourceError, SourceErrorKind}; pub async fn resolve_command( command: &[String], format: DataFormat, - extract_expr: Option<&str>, + extract_expr: Option<&ExtractExpr>, ) -> Result { if command.is_empty() { return Err(SourceError { diff --git a/crates/rsigma-runtime/src/sources/extract.rs b/crates/rsigma-runtime/src/sources/extract.rs index 95f4e08..d0ea612 100644 --- a/crates/rsigma-runtime/src/sources/extract.rs +++ b/crates/rsigma-runtime/src/sources/extract.rs @@ -4,19 +4,22 @@ //! - Plain string: always jq (the common case) //! - Structured object `{ expr, type }`: explicit language selection //! -//! Supported types: `jq` (default), `jsonpath`, `cel`. +//! Supported types: `jq` (via jaq), `jsonpath` (via serde_json_path), `cel` (via cel-interpreter). + +use rsigma_eval::pipeline::sources::ExtractExpr; use super::{SourceError, SourceErrorKind}; -/// Apply an extract expression to parsed source data. -/// -/// The expression is always treated as jq in Phase 2a. JSONPath and CEL -/// support will be added in later sub-phases. +/// Apply a typed extract expression to parsed source data. pub fn apply_extract( data: &serde_json::Value, - expr: &str, + expr: &ExtractExpr, ) -> Result { - apply_jq(data, expr) + match expr { + ExtractExpr::Jq(e) => apply_jq(data, e), + ExtractExpr::JsonPath(e) => apply_jsonpath(data, e), + ExtractExpr::Cel(e) => apply_cel(data, e), + } } /// Apply a jq expression using jaq. @@ -56,6 +59,109 @@ fn apply_jq(data: &serde_json::Value, expr: &str) -> Result Result { + use jsonpath_rust::JsonPath; + + let results = data.query(expr).map_err(|e| SourceError { + source_id: String::new(), + kind: SourceErrorKind::Extract(format!("invalid JSONPath expression: {e}")), + })?; + + match results.len() { + 0 => Ok(serde_json::Value::Null), + 1 => Ok(results[0].clone()), + _ => { + let arr: Vec = results.into_iter().cloned().collect(); + Ok(serde_json::Value::Array(arr)) + } + } +} + +/// Apply a CEL expression using the `cel` crate (cel-rust). +/// +/// The resolved source data is bound as the CEL variable `data`. +fn apply_cel(data: &serde_json::Value, expr: &str) -> Result { + use cel::{Context, Program}; + + let program = Program::compile(expr).map_err(|e| SourceError { + source_id: String::new(), + kind: SourceErrorKind::Extract(format!("invalid CEL expression: {e}")), + })?; + + let mut context = Context::default(); + let cel_value = json_to_cel(data); + let _ = context.add_variable("data", cel_value); + + let result = program.execute(&context).map_err(|e| SourceError { + source_id: String::new(), + kind: SourceErrorKind::Extract(format!("CEL execution error: {e}")), + })?; + + Ok(cel_to_json(&result)) +} + +/// Convert a serde_json::Value to a cel::Value. +fn json_to_cel(json: &serde_json::Value) -> cel::Value { + match json { + serde_json::Value::Null => cel::Value::Null, + serde_json::Value::Bool(b) => (*b).into(), + serde_json::Value::Number(n) => { + if let Some(i) = n.as_i64() { + i.into() + } else if let Some(u) = n.as_u64() { + u.into() + } else if let Some(f) = n.as_f64() { + f.into() + } else { + cel::Value::Null + } + } + serde_json::Value::String(s) => s.as_str().into(), + serde_json::Value::Array(arr) => { + let items: Vec = arr.iter().map(json_to_cel).collect(); + items.into() + } + serde_json::Value::Object(map) => { + let cel_map: std::collections::HashMap = map + .iter() + .map(|(k, v)| (k.as_str().into(), json_to_cel(v))) + .collect(); + cel_map.into() + } + } +} + +/// Convert a cel::Value back to serde_json::Value. +fn cel_to_json(val: &cel::Value) -> serde_json::Value { + match val { + cel::Value::Null => serde_json::Value::Null, + cel::Value::Bool(b) => serde_json::Value::Bool(*b), + cel::Value::Int(i) => serde_json::json!(i), + cel::Value::UInt(u) => serde_json::json!(u), + cel::Value::Float(f) => serde_json::json!(f), + cel::Value::String(s) => serde_json::Value::String(s.to_string()), + cel::Value::List(list) => { + let arr: Vec = list.iter().map(cel_to_json).collect(); + serde_json::Value::Array(arr) + } + cel::Value::Map(map) => { + let mut obj = serde_json::Map::new(); + for (k, v) in map.map.iter() { + let key = match k { + cel::objects::Key::String(s) => s.to_string(), + cel::objects::Key::Int(i) => i.to_string(), + cel::objects::Key::Uint(u) => u.to_string(), + cel::objects::Key::Bool(b) => b.to_string(), + }; + obj.insert(key, cel_to_json(v)); + } + serde_json::Value::Object(obj) + } + _ => serde_json::Value::String(format!("{val:?}")), + } +} + /// Convert a jaq `Val` to a `serde_json::Value`. fn val_to_json(val: &jaq_interpret::Val) -> serde_json::Value { match val { diff --git a/crates/rsigma-runtime/src/sources/file.rs b/crates/rsigma-runtime/src/sources/file.rs index df9e421..5c410bb 100644 --- a/crates/rsigma-runtime/src/sources/file.rs +++ b/crates/rsigma-runtime/src/sources/file.rs @@ -3,7 +3,7 @@ use std::path::Path; use std::time::Instant; -use rsigma_eval::pipeline::sources::DataFormat; +use rsigma_eval::pipeline::sources::{DataFormat, ExtractExpr}; use super::extract::apply_extract; use super::{ResolvedValue, SourceError, SourceErrorKind}; @@ -12,7 +12,7 @@ use super::{ResolvedValue, SourceError, SourceErrorKind}; pub async fn resolve_file( path: &Path, format: DataFormat, - extract_expr: Option<&str>, + extract_expr: Option<&ExtractExpr>, ) -> Result { let contents = tokio::fs::read_to_string(path) .await diff --git a/crates/rsigma-runtime/src/sources/http.rs b/crates/rsigma-runtime/src/sources/http.rs index 7d3368b..dd92c2c 100644 --- a/crates/rsigma-runtime/src/sources/http.rs +++ b/crates/rsigma-runtime/src/sources/http.rs @@ -3,7 +3,7 @@ use std::collections::HashMap; use std::time::{Duration, Instant}; -use rsigma_eval::pipeline::sources::DataFormat; +use rsigma_eval::pipeline::sources::{DataFormat, ExtractExpr}; use super::extract::apply_extract; use super::file::parse_data; @@ -15,7 +15,7 @@ pub async fn resolve_http( method: Option<&str>, headers: &HashMap, format: DataFormat, - extract_expr: Option<&str>, + extract_expr: Option<&ExtractExpr>, timeout: Option, ) -> Result { let client = reqwest::Client::builder() diff --git a/crates/rsigma-runtime/src/sources/include.rs b/crates/rsigma-runtime/src/sources/include.rs index 7b57bfe..fec1b50 100644 --- a/crates/rsigma-runtime/src/sources/include.rs +++ b/crates/rsigma-runtime/src/sources/include.rs @@ -10,6 +10,9 @@ use rsigma_eval::pipeline::sources::SourceType; use rsigma_eval::pipeline::transformations::Transformation; use rsigma_eval::{Pipeline, TransformationItem}; +/// Maximum include nesting depth (prevents cycles). +const MAX_INCLUDE_DEPTH: usize = 1; + /// Expand all `Include` transformations in a pipeline. /// /// For each `Include { template }`, the template references a source ID. @@ -18,11 +21,29 @@ use rsigma_eval::{Pipeline, TransformationItem}; /// /// Security: if `allow_remote_include` is false, includes referencing HTTP or NATS /// sources produce an error. +/// +/// Recursive includes are not allowed (max depth 1). If an included fragment +/// itself contains `Include` directives, expansion fails with an error. pub fn expand_includes( pipeline: &mut Pipeline, resolved: &HashMap, allow_remote_include: bool, ) -> Result<(), String> { + expand_includes_with_depth(pipeline, resolved, allow_remote_include, 0) +} + +fn expand_includes_with_depth( + pipeline: &mut Pipeline, + resolved: &HashMap, + allow_remote_include: bool, + depth: usize, +) -> Result<(), String> { + if depth > MAX_INCLUDE_DEPTH { + return Err( + "recursive includes are not allowed (max depth 1); included content cannot itself contain include directives".to_string() + ); + } + let mut expanded_transformations = Vec::new(); let mut had_include = false; @@ -48,6 +69,17 @@ pub fn expand_includes( if let Some(data) = resolved.get(&source_id) { let items = parse_transformation_array(data)?; + + // Check for nested includes (depth enforcement) + for parsed_item in &items { + if matches!(parsed_item.transformation, Transformation::Include { .. }) { + return Err(format!( + "included content from source '{}' contains nested include directives; recursive includes are not allowed (max depth 1)", + source_id + )); + } + } + expanded_transformations.extend(items); } else { return Err(format!( @@ -118,4 +150,42 @@ mod tests { fn extract_source_id_plain_string() { assert_eq!(extract_source_id("my_source"), "my_source"); } + + #[test] + fn nested_include_rejected() { + let mut pipeline = Pipeline { + name: "test".to_string(), + priority: 0, + vars: HashMap::new(), + transformations: vec![TransformationItem { + id: None, + transformation: Transformation::Include { + template: "${source.transforms}".to_string(), + }, + rule_conditions: vec![], + rule_cond_expr: None, + detection_item_conditions: vec![], + field_name_conditions: vec![], + field_name_cond_not: false, + }], + finalizers: vec![], + sources: vec![], + source_refs: vec![], + }; + + // The resolved source data contains an include directive itself + let nested_yaml = serde_json::json!([ + {"type": "include", "include": "${source.other}"} + ]); + let mut resolved = HashMap::new(); + resolved.insert("transforms".to_string(), nested_yaml); + + let result = expand_includes(&mut pipeline, &resolved, true); + assert!(result.is_err()); + let err = result.unwrap_err(); + assert!( + err.contains("nested include") || err.contains("recursive"), + "error should mention nesting: {err}" + ); + } } diff --git a/crates/rsigma-runtime/src/sources/mod.rs b/crates/rsigma-runtime/src/sources/mod.rs index 808752c..819ce4e 100644 --- a/crates/rsigma-runtime/src/sources/mod.rs +++ b/crates/rsigma-runtime/src/sources/mod.rs @@ -118,25 +118,29 @@ impl Default for DefaultSourceResolver { impl SourceResolver for DefaultSourceResolver { async fn resolve(&self, source: &DynamicSource) -> Result { let result = match &source.source_type { - SourceType::File { path, format } => file::resolve_file(path, *format, None).await, + SourceType::File { + path, + format, + extract, + } => file::resolve_file(path, *format, extract.as_ref()).await, SourceType::Command { command, format, - extract: extract_expr, - } => command::resolve_command(command, *format, extract_expr.as_deref()).await, + extract, + } => command::resolve_command(command, *format, extract.as_ref()).await, SourceType::Http { url, method, headers, format, - extract: extract_expr, + extract, } => { http::resolve_http( url, method.as_deref(), headers, *format, - extract_expr.as_deref(), + extract.as_ref(), source.timeout, ) .await @@ -146,8 +150,8 @@ impl SourceResolver for DefaultSourceResolver { url, subject, format, - extract: extract_expr, - } => nats::resolve_nats_initial(url, subject, *format, extract_expr.as_deref()).await, + extract, + } => nats::resolve_nats_initial(url, subject, *format, extract.as_ref()).await, #[cfg(not(feature = "nats"))] SourceType::Nats { .. } => { return Err(SourceError { @@ -214,14 +218,29 @@ impl SourceResolver for DefaultSourceResolver { pub async fn resolve_all( resolver: &dyn SourceResolver, sources: &[DynamicSource], +) -> Result, SourceError> { + resolve_all_with_state(resolver, sources, None).await +} + +/// Like [`resolve_all`] but also updates a [`PipelineState`] with source resolution status. +pub async fn resolve_all_with_state( + resolver: &dyn SourceResolver, + sources: &[DynamicSource], + mut state: Option<&mut rsigma_eval::pipeline::state::PipelineState>, ) -> Result, SourceError> { let mut resolved = std::collections::HashMap::new(); for source in sources { match resolver.resolve(source).await { Ok(value) => { resolved.insert(source.id.clone(), value.data); + if let Some(s) = state.as_deref_mut() { + s.mark_source_resolved(&source.id); + } } Err(e) => { + if let Some(s) = state.as_deref_mut() { + s.mark_source_failed(&source.id); + } if source.required { return Err(e); } diff --git a/crates/rsigma-runtime/src/sources/nats.rs b/crates/rsigma-runtime/src/sources/nats.rs index 7ad0889..a81478f 100644 --- a/crates/rsigma-runtime/src/sources/nats.rs +++ b/crates/rsigma-runtime/src/sources/nats.rs @@ -2,7 +2,7 @@ use std::time::Instant; -use rsigma_eval::pipeline::sources::DataFormat; +use rsigma_eval::pipeline::sources::{DataFormat, ExtractExpr}; use super::extract::apply_extract; use super::file::parse_data; @@ -19,7 +19,7 @@ pub async fn resolve_nats_initial( url: &str, subject: &str, format: DataFormat, - extract_expr: Option<&str>, + extract_expr: Option<&ExtractExpr>, ) -> Result { use futures::StreamExt; @@ -65,7 +65,7 @@ pub async fn resolve_nats_initial( pub fn parse_nats_message( payload: &[u8], format: DataFormat, - extract_expr: Option<&str>, + extract_expr: Option<&ExtractExpr>, ) -> Result { let raw = std::str::from_utf8(payload).map_err(|e| SourceError { source_id: String::new(), diff --git a/crates/rsigma-runtime/src/sources/refresh.rs b/crates/rsigma-runtime/src/sources/refresh.rs index 98dea24..7814650 100644 --- a/crates/rsigma-runtime/src/sources/refresh.rs +++ b/crates/rsigma-runtime/src/sources/refresh.rs @@ -148,7 +148,7 @@ impl RefreshScheduler { let extract_expr = extract_expr.clone(); tokio::spawn(async move { if let Err(e) = - nats_push_loop(&url, &subject, format, extract_expr.as_deref(), &id, &tx) + nats_push_loop(&url, &subject, format, extract_expr.as_ref(), &id, &tx) .await { tracing::error!( @@ -226,7 +226,7 @@ async fn nats_push_loop( url: &str, subject: &str, format: rsigma_eval::pipeline::sources::DataFormat, - extract_expr: Option<&str>, + extract_expr: Option<&rsigma_eval::pipeline::sources::ExtractExpr>, source_id: &str, trigger_tx: &mpsc::Sender, ) -> Result<(), String> { @@ -271,6 +271,56 @@ async fn nats_push_loop( Ok(()) } +/// The default NATS control subject for triggering source re-resolution. +pub const NATS_CONTROL_SUBJECT: &str = "rsigma.control.resolve"; + +/// Subscribe to the NATS control subject and forward re-resolution triggers. +/// +/// Messages with an empty payload trigger re-resolution of all sources. +/// Messages with a non-empty payload are treated as a source ID to re-resolve. +#[cfg(feature = "nats")] +pub async fn nats_control_loop( + url: &str, + subject: &str, + trigger_tx: mpsc::Sender, +) -> Result<(), String> { + use futures::StreamExt; + + let client = async_nats::connect(url) + .await + .map_err(|e| format!("NATS control connect failed: {e}"))?; + + let mut subscriber = client + .subscribe(subject.to_string()) + .await + .map_err(|e| format!("NATS control subscribe failed: {e}"))?; + + tracing::info!( + subject = %subject, + "NATS control subscription active for source re-resolution" + ); + + while let Some(msg) = subscriber.next().await { + let payload = String::from_utf8_lossy(&msg.payload); + let payload = payload.trim(); + + let trigger = if payload.is_empty() { + tracing::debug!("NATS control: triggering all sources"); + RefreshTrigger::All + } else { + tracing::debug!(source_id = %payload, "NATS control: triggering single source"); + RefreshTrigger::Single(payload.to_string()) + }; + + if trigger_tx.send(trigger).await.is_err() { + tracing::debug!("NATS control loop: trigger channel closed, exiting"); + break; + } + } + + Ok(()) +} + /// Watch a file for changes and send refresh triggers. async fn file_watch_loop( path: &std::path::Path, diff --git a/crates/rsigma-runtime/tests/sources_integration.rs b/crates/rsigma-runtime/tests/sources_integration.rs index c5f7da8..8d88f0e 100644 --- a/crates/rsigma-runtime/tests/sources_integration.rs +++ b/crates/rsigma-runtime/tests/sources_integration.rs @@ -4,7 +4,7 @@ use std::collections::HashMap; use rsigma_eval::Pipeline; use rsigma_eval::pipeline::sources::{ - DataFormat, DynamicSource, ErrorPolicy, RefreshPolicy, SourceType, + DataFormat, DynamicSource, ErrorPolicy, ExtractExpr, RefreshPolicy, SourceType, }; use rsigma_runtime::sources::cache::SourceCache; use rsigma_runtime::sources::file::resolve_file; @@ -27,18 +27,105 @@ async fn file_source_json() { } #[tokio::test] -async fn file_source_json_with_extract() { +async fn file_source_json_with_extract_jq() { let dir = tempfile::tempdir().unwrap(); let path = dir.path().join("data.json"); std::fs::write(&path, r#"{"emails": ["a@b.com", "c@d.com"]}"#).unwrap(); - let result = resolve_file(&path, DataFormat::Json, Some(".emails[]")) + let extract = ExtractExpr::Jq(".emails[]".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)) .await .unwrap(); let expected = serde_json::json!(["a@b.com", "c@d.com"]); assert_eq!(result.data, expected); } +#[tokio::test] +async fn file_source_json_with_extract_jsonpath() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"emails": ["a@b.com", "c@d.com"]}"#).unwrap(); + + let extract = ExtractExpr::JsonPath("$.emails[*]".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)) + .await + .unwrap(); + let expected = serde_json::json!(["a@b.com", "c@d.com"]); + assert_eq!(result.data, expected); +} + +#[tokio::test] +async fn file_source_json_with_extract_jsonpath_single() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"name": "rsigma", "version": 9}"#).unwrap(); + + let extract = ExtractExpr::JsonPath("$.name".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)) + .await + .unwrap(); + assert_eq!(result.data, serde_json::json!("rsigma")); +} + +#[tokio::test] +async fn file_source_json_with_extract_cel() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"emails": ["a@b.com", "c@d.com"], "count": 2}"#).unwrap(); + + let extract = ExtractExpr::Cel("data.count".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)) + .await + .unwrap(); + assert_eq!(result.data, serde_json::json!(2)); +} + +#[tokio::test] +async fn file_source_json_with_extract_cel_list() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"items": [1, 2, 3, 4, 5]}"#).unwrap(); + + let extract = ExtractExpr::Cel("data.items.filter(x, x > 3)".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)) + .await + .unwrap(); + assert_eq!(result.data, serde_json::json!([4, 5])); +} + +#[tokio::test] +async fn extract_invalid_jq_returns_error() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"x": 1}"#).unwrap(); + + let extract = ExtractExpr::Jq("invalid[[[".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn extract_invalid_jsonpath_returns_error() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"x": 1}"#).unwrap(); + + let extract = ExtractExpr::JsonPath("$[invalid".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)).await; + assert!(result.is_err()); +} + +#[tokio::test] +async fn extract_invalid_cel_returns_error() { + let dir = tempfile::tempdir().unwrap(); + let path = dir.path().join("data.json"); + std::fs::write(&path, r#"{"x": 1}"#).unwrap(); + + let extract = ExtractExpr::Cel("invalid(((syntax".to_string()); + let result = resolve_file(&path, DataFormat::Json, Some(&extract)).await; + assert!(result.is_err()); +} + #[tokio::test] async fn file_source_lines() { let dir = tempfile::tempdir().unwrap(); @@ -142,8 +229,9 @@ async fn command_source_with_extract() { format!("type {}", path.to_str().unwrap()), ]; + let extract = ExtractExpr::Jq(".items[]".to_string()); let result = - rsigma_runtime::sources::command::resolve_command(&cmd, DataFormat::Json, Some(".items[]")) + rsigma_runtime::sources::command::resolve_command(&cmd, DataFormat::Json, Some(&extract)) .await .unwrap(); @@ -216,6 +304,7 @@ async fn resolver_file_source_end_to_end() { source_type: SourceType::File { path: path.clone(), format: DataFormat::Json, + extract: None, }, refresh: RefreshPolicy::Once, timeout: None, @@ -246,6 +335,7 @@ async fn resolver_use_cached_on_failure() { source_type: SourceType::File { path: "/nonexistent/file.json".into(), format: DataFormat::Json, + extract: None, }, refresh: RefreshPolicy::Once, timeout: None, @@ -270,6 +360,7 @@ async fn resolver_use_default_on_failure() { source_type: SourceType::File { path: "/nonexistent/file.json".into(), format: DataFormat::Json, + extract: None, }, refresh: RefreshPolicy::Once, timeout: None, @@ -291,6 +382,7 @@ async fn resolver_fail_policy_returns_error() { source_type: SourceType::File { path: "/nonexistent/file.json".into(), format: DataFormat::Json, + extract: None, }, refresh: RefreshPolicy::Once, timeout: None, @@ -341,6 +433,7 @@ async fn end_to_end_dynamic_pipeline_resolution() { source_type: SourceType::File { path: emails_path, format: DataFormat::Json, + extract: None, }, refresh: RefreshPolicy::Once, timeout: None, @@ -353,6 +446,7 @@ async fn end_to_end_dynamic_pipeline_resolution() { source_type: SourceType::File { path: config_path, format: DataFormat::Json, + extract: None, }, refresh: RefreshPolicy::Once, timeout: None, @@ -473,3 +567,46 @@ fn cache_sqlite_invalidate_persists() { assert!(cache.is_empty()); } } + +#[test] +fn cache_ttl_expiration() { + use std::thread; + use std::time::Duration; + + let cache = SourceCache::with_ttl(Duration::from_millis(50)); + cache.store("src1", &serde_json::json!("fresh")); + + // Immediately accessible + assert_eq!(cache.get("src1").unwrap(), serde_json::json!("fresh")); + + // Wait for TTL to expire + thread::sleep(Duration::from_millis(60)); + assert!(cache.get("src1").is_none()); +} + +#[test] +fn cache_ttl_evict_expired() { + use std::thread; + use std::time::Duration; + + let cache = SourceCache::with_ttl(Duration::from_millis(50)); + cache.store("src1", &serde_json::json!("a")); + cache.store("src2", &serde_json::json!("b")); + + thread::sleep(Duration::from_millis(60)); + + // Entries still in map (len counts all, including expired) + assert_eq!(cache.len(), 2); + + // Evict removes expired entries + cache.evict_expired(); + assert!(cache.is_empty()); +} + +#[test] +fn cache_no_ttl_never_expires() { + let cache = SourceCache::new(); + cache.store("src1", &serde_json::json!("persistent")); + assert_eq!(cache.ttl(), None); + assert_eq!(cache.get("src1").unwrap(), serde_json::json!("persistent")); +}