diff --git a/optd-core/src/bridge/from_cir.rs b/optd-core/src/bridge/from_cir.rs index 8b4c8338..391fe2d0 100644 --- a/optd-core/src/bridge/from_cir.rs +++ b/optd-core/src/bridge/from_cir.rs @@ -14,7 +14,7 @@ pub(crate) fn partial_logical_to_value(plan: &PartialLogicalPlan) -> Value { match plan { PartialLogicalPlan::UnMaterialized(group_id) => { // For unmaterialized logical operators, we create a `Value` with the group ID. - Value(Logical(UnMaterialized(hir::GroupId(group_id.0)))) + Value::new(Logical(UnMaterialized(hir::GroupId(group_id.0)))) } PartialLogicalPlan::Materialized(node) => { // For materialized logical operators, we create a `Value` with the operator data. @@ -24,7 +24,7 @@ pub(crate) fn partial_logical_to_value(plan: &PartialLogicalPlan) -> Value { children: convert_children_to_values(&node.children, partial_logical_to_value), }; - Value(Logical(Materialized(LogicalOp::logical(operator)))) + Value::new(Logical(Materialized(LogicalOp::logical(operator)))) } } } @@ -35,7 +35,7 @@ pub(crate) fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { PartialPhysicalPlan::UnMaterialized(goal) => { // For unmaterialized physical operators, we create a `Value` with the goal let hir_goal = cir_goal_to_hir(goal); - Value(Physical(UnMaterialized(hir_goal))) + Value::new(Physical(UnMaterialized(hir_goal))) } PartialPhysicalPlan::Materialized(node) => { // For materialized physical operators, we create a Value with the operator data @@ -45,7 +45,7 @@ pub(crate) fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { children: convert_children_to_values(&node.children, partial_physical_to_value), }; - Value(Physical(Materialized(PhysicalOp::physical(operator)))) + Value::new(Physical(Materialized(PhysicalOp::physical(operator)))) } } } @@ -54,9 +54,9 @@ pub(crate) fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { /// Converts a [`PartialPhysicalPlan`] with its cost into a [`Value`]. 
pub(crate) fn costed_physical_to_value(plan: PartialPhysicalPlan, cost: Cost) -> Value { let operator = partial_physical_to_value(&plan); - Value(Tuple(vec![ + Value::new(Tuple(vec![ partial_physical_to_value(&plan), - Value(Literal(Float64(cost.0))), + Value::new(Literal(Float64(cost.0))), ])) } @@ -65,7 +65,7 @@ pub(crate) fn costed_physical_to_value(plan: PartialPhysicalPlan, cost: Cost) -> pub(crate) fn logical_properties_to_value(properties: &LogicalProperties) -> Value { match &properties.0 { Some(data) => properties_data_to_value(data), - Option::None => Value(None), + Option::None => Value::new(None), } } @@ -73,7 +73,7 @@ pub(crate) fn logical_properties_to_value(properties: &LogicalProperties) -> Val pub(crate) fn physical_properties_to_value(properties: &PhysicalProperties) -> Value { match &properties.0 { Some(data) => properties_data_to_value(data), - Option::None => Value(None), + Option::None => Value::new(None), } } @@ -104,7 +104,7 @@ where .map(|child| match child { Child::Singleton(item) => converter(item), Child::VarLength(items) => { - Value(Array(items.iter().map(|item| converter(item)).collect())) + Value::new(Array(items.iter().map(|item| converter(item)).collect())) } }) .collect() @@ -118,32 +118,34 @@ fn convert_operator_data_to_values(data: &[OperatorData]) -> Vec { /// Converts an [`OperatorData`] into a [`Value`]. 
fn operator_data_to_value(data: &OperatorData) -> Value { match data { - OperatorData::Int64(i) => Value(Literal(Int64(*i))), - OperatorData::Float64(f) => Value(Literal(Float64(**f))), - OperatorData::String(s) => Value(Literal(String(s.clone()))), - OperatorData::Bool(b) => Value(Literal(Bool(*b))), - OperatorData::Struct(name, elements) => Value(Struct( + OperatorData::Int64(i) => Value::new(Literal(Int64(*i))), + OperatorData::Float64(f) => Value::new(Literal(Float64(**f))), + OperatorData::String(s) => Value::new(Literal(String(s.clone()))), + OperatorData::Bool(b) => Value::new(Literal(Bool(*b))), + OperatorData::Struct(name, elements) => Value::new(Struct( name.clone(), convert_operator_data_to_values(elements), )), - OperatorData::Array(elements) => Value(Array(convert_operator_data_to_values(elements))), + OperatorData::Array(elements) => { + Value::new(Array(convert_operator_data_to_values(elements))) + } } } /// Converts a [`PropertiesData`] into a [`Value`]. fn properties_data_to_value(data: &PropertiesData) -> Value { match data { - PropertiesData::Int64(i) => Value(Literal(Int64(*i))), - PropertiesData::Float64(f) => Value(Literal(Float64(**f))), - PropertiesData::String(s) => Value(Literal(String(s.clone()))), - PropertiesData::Bool(b) => Value(Literal(Bool(*b))), + PropertiesData::Int64(i) => Value::new(Literal(Int64(*i))), + PropertiesData::Float64(f) => Value::new(Literal(Float64(**f))), + PropertiesData::String(s) => Value::new(Literal(String(s.clone()))), + PropertiesData::Bool(b) => Value::new(Literal(Bool(*b))), PropertiesData::Struct(name, elements) => { let values = elements.iter().map(properties_data_to_value).collect(); - Value(Struct(name.clone(), values)) + Value::new(Struct(name.clone(), values)) } PropertiesData::Array(elements) => { let values = elements.iter().map(properties_data_to_value).collect(); - Value(Array(values)) + Value::new(Array(values)) } } } diff --git a/optd-core/src/bridge/into_cir.rs 
b/optd-core/src/bridge/into_cir.rs index 113b314c..f209b907 100644 --- a/optd-core/src/bridge/into_cir.rs +++ b/optd-core/src/bridge/into_cir.rs @@ -14,7 +14,7 @@ use std::sync::Arc; /// /// Panics if the [`Value`] is not a [`Logical`] variant. pub(crate) fn value_to_partial_logical(value: &Value) -> PartialLogicalPlan { - match &value.0 { + match &value.data { Logical(logical_op) => match logical_op { UnMaterialized(group_id) => { PartialLogicalPlan::UnMaterialized(hir_group_id_to_cir(group_id)) @@ -28,7 +28,7 @@ pub(crate) fn value_to_partial_logical(value: &Value) -> PartialLogicalPlan { ), }), }, - _ => panic!("Expected Logical CoreData variant, found: {:?}", value.0), + _ => panic!("Expected Logical CoreData variant, found: {:?}", value.data), } } @@ -38,7 +38,7 @@ pub(crate) fn value_to_partial_logical(value: &Value) -> PartialLogicalPlan { /// /// Panics if the [`Value`] is not a [`Physical`] variant. pub(crate) fn value_to_partial_physical(value: &Value) -> PartialPhysicalPlan { - match &value.0 { + match &value.data { Physical(physical_op) => match physical_op { UnMaterialized(hir_goal) => { PartialPhysicalPlan::UnMaterialized(hir_goal_to_cir(hir_goal)) @@ -52,7 +52,10 @@ pub(crate) fn value_to_partial_physical(value: &Value) -> PartialPhysicalPlan { ), }), }, - _ => panic!("Expected Physical CoreData variant, found: {:?}", value.0), + _ => panic!( + "Expected Physical CoreData variant, found: {:?}", + value.data + ), } } @@ -62,15 +65,15 @@ pub(crate) fn value_to_partial_physical(value: &Value) -> PartialPhysicalPlan { /// /// Panics if the [`Value`] is not a [`Literal`] variant with a [`Float64`] value. pub(crate) fn value_to_cost(value: &Value) -> Cost { - match &value.0 { + match &value.data { Literal(Float64(f)) => Cost(*f), - _ => panic!("Expected Float64 literal, found: {:?}", value.0), + _ => panic!("Expected Float64 literal, found: {:?}", value.data), } } /// Converts an HIR properties [`Value`] into a CIR [`LogicalProperties`]. 
pub(crate) fn value_to_logical_properties(properties_value: &Value) -> LogicalProperties { - match &properties_value.0 { + match &properties_value.data { None => LogicalProperties(Option::None), _ => LogicalProperties(Some(value_to_properties_data(properties_value))), } @@ -78,7 +81,7 @@ pub(crate) fn value_to_logical_properties(properties_value: &Value) -> LogicalPr /// Convert an HIR properties [`Value`] into a CIR [`PhysicalProperties`]. fn value_to_physical_properties(properties_value: &Value) -> PhysicalProperties { - match &properties_value.0 { + match &properties_value.data { None => PhysicalProperties(Option::None), _ => PhysicalProperties(Some(value_to_properties_data(properties_value))), } @@ -108,7 +111,7 @@ pub(crate) fn hir_goal_to_cir(hir_goal: &hir::Goal) -> Goal { /// Panics if the [`Value`] is not a [`Logical`] variant or if the [`Logical`] variant is not a /// [`Materialized`] variant. fn value_to_logical(value: &Value) -> LogicalPlan { - match &value.0 { + match &value.data { Logical(logical_op) => match logical_op { UnMaterialized(_) => { panic!("Cannot convert UnMaterialized LogicalOperator to LogicalPlan") @@ -119,7 +122,7 @@ fn value_to_logical(value: &Value) -> LogicalPlan { children: convert_values_to_children(&log_op.operator.children, value_to_logical), }), }, - _ => panic!("Expected Logical CoreData variant, found: {:?}", value.0), + _ => panic!("Expected Logical CoreData variant, found: {:?}", value.data), } } @@ -132,7 +135,7 @@ where { values .iter() - .map(|value| match &value.0 { + .map(|value| match &value.data { Array(elements) => VarLength( elements .iter() @@ -160,7 +163,7 @@ fn convert_values_to_properties_data(values: &[Value]) -> Vec { /// /// Panics if the [`Value`] cannot be converted to [`OperatorData`], such as a [`Unit`] literal. 
fn value_to_operator_data(value: &Value) -> OperatorData { - match &value.0 { + match &value.data { Literal(constant) => match constant { Int64(i) => OperatorData::Int64(*i), Float64(f) => OperatorData::Float64((*f).into()), @@ -172,7 +175,7 @@ fn value_to_operator_data(value: &Value) -> OperatorData { Struct(name, elements) => { OperatorData::Struct(name.clone(), convert_values_to_operator_data(elements)) } - _ => panic!("Cannot convert {:?} to OperatorData", value.0), + _ => panic!("Cannot convert {:?} to OperatorData", value.data), } } @@ -182,7 +185,7 @@ fn value_to_operator_data(value: &Value) -> OperatorData { /// /// Panics if the [`Value`] cannot be converted to [`PropertiesData`], such as a [`Unit`] literal. fn value_to_properties_data(value: &Value) -> PropertiesData { - match &value.0 { + match &value.data { Literal(constant) => match constant { Int64(i) => PropertiesData::Int64(*i), Float64(f) => PropertiesData::Float64((*f).into()), @@ -194,6 +197,6 @@ fn value_to_properties_data(value: &Value) -> PropertiesData { Struct(name, elements) => { PropertiesData::Struct(name.clone(), convert_values_to_properties_data(elements)) } - _ => panic!("Cannot convert {:?} to PropertyData content", value.0), + _ => panic!("Cannot convert {:?} to PropertyData content", value.data), } } diff --git a/optd-dsl/src/analyzer/context.rs b/optd-dsl/src/analyzer/context.rs index 08fba5ec..621bd69e 100644 --- a/optd-dsl/src/analyzer/context.rs +++ b/optd-dsl/src/analyzer/context.rs @@ -1,4 +1,4 @@ -use super::hir::Identifier; +use super::hir::{ExprMetadata, Identifier, NoMetadata}; use crate::analyzer::hir::Value; use std::{collections::HashMap, sync::Arc}; @@ -12,15 +12,15 @@ use std::{collections::HashMap, sync::Arc}; /// The current (innermost) scope is mutable, while all previous scopes /// are immutable and stored as Arc for efficient cloning. 
#[derive(Debug, Clone, Default)] -pub struct Context { +pub struct Context { /// Previous scopes (outer lexical scopes), stored as immutable Arc references - previous_scopes: Vec>>, + previous_scopes: Vec>>>, /// Current scope (innermost) that can be directly modified - current_scope: HashMap, + current_scope: HashMap>, } -impl Context { +impl Context { /// Creates a new context with the given initial bindings as the global scope. /// /// # Arguments @@ -30,7 +30,7 @@ impl Context { /// # Returns /// /// A new `Context` instance with one scope containing the initial bindings - pub fn new(initial_bindings: HashMap) -> Self { + pub fn new(initial_bindings: HashMap>) -> Self { Self { previous_scopes: Vec::new(), current_scope: initial_bindings, @@ -59,7 +59,7 @@ impl Context { /// # Returns /// /// Some reference to the value if found, None otherwise - pub fn lookup(&self, name: &str) -> Option<&Value> { + pub fn lookup(&self, name: &str) -> Option<&Value> { // First check the current scope if let Some(value) = self.current_scope.get(name) { return Some(value); @@ -84,7 +84,7 @@ impl Context { /// # Arguments /// /// * `other` - The context to merge from (consumed by this operation) - pub fn merge(&mut self, other: Context) { + pub fn merge(&mut self, other: Context) { // Move bindings from other's current scope into our current scope for (name, val) in other.current_scope { self.current_scope.insert(name, val); @@ -100,7 +100,7 @@ impl Context { /// /// * `name` - The name of the variable to bind /// * `val` - The value to bind to the variable - pub fn bind(&mut self, name: String, val: Value) { + pub fn bind(&mut self, name: String, val: Value) { self.current_scope.insert(name, val); } } diff --git a/optd-dsl/src/analyzer/hir.rs b/optd-dsl/src/analyzer/hir.rs index 9f9a0f6c..fa974e03 100644 --- a/optd-dsl/src/analyzer/hir.rs +++ b/optd-dsl/src/analyzer/hir.rs @@ -16,6 +16,7 @@ //! intermediate representations through the bridge modules. 
use super::context::Context; +use super::map::Map; use super::r#type::Type; use crate::utils::span::Span; use std::fmt::Debug; @@ -37,15 +38,36 @@ pub enum Literal { Unit, } +/// Metadata that can be attached to expression nodes +/// +/// This trait allows for different types of metadata to be attached to +/// expression nodes while maintaining a common interface for access. +pub trait ExprMetadata: Debug + Clone {} + +/// Empty metadata implementation for cases where no additional data is needed +#[derive(Debug, Clone, Default)] +pub struct NoMetadata; +impl ExprMetadata for NoMetadata {} + +/// Combined span and type information for an expression +#[derive(Debug, Clone)] +pub struct TypedSpan { + /// Source code location. + pub span: Span, + /// Inferred type. + pub ty: Type, +} +impl ExprMetadata for TypedSpan {} + /// Types of functions in the system #[derive(Debug, Clone)] -pub enum FunKind { - Closure(Vec, Arc), - RustUDF(fn(Vec) -> Value), +pub enum FunKind { + Closure(Vec, Arc>), + RustUDF(fn(Vec>) -> Value), } /// Group identifier in the optimizer -#[derive(Debug, Clone, PartialEq, Copy)] +#[derive(Debug, Clone, PartialEq, Copy, Eq, Hash)] pub struct GroupId(pub i64); /// Either materialized or unmaterialized data @@ -68,7 +90,7 @@ pub struct Goal { /// The logical group to implement pub group_id: GroupId, /// Required physical properties - pub properties: Box, + pub properties: Box>, } /// Unified operator node structure for all operator types @@ -124,13 +146,13 @@ impl LogicalOp { /// Represents an executable implementation of a logical operation with specific /// physical properties, either materialized as a concrete operator or as a physical goal. 
#[derive(Debug, Clone)] -pub struct PhysicalOp { +pub struct PhysicalOp { pub operator: Operator, pub goal: Option, - pub cost: Option>, + pub cost: Option>>, } -impl PhysicalOp { +impl PhysicalOp { /// Creates a new physical operator without goal or cost information /// /// Used for representing physical operators that are not yet part of the @@ -158,18 +180,31 @@ impl PhysicalOp { /// Creates a new physical operator with both goal and cost information /// /// Used for representing fully optimized physical operators with computed cost. - pub fn costed_physical(operator: Operator, goal: Goal, cost: Value) -> Self { + pub fn costed_physical(operator: Operator, goal: Goal, cost: Value) -> Self { Self { operator, goal: Some(goal), - cost: Some(cost.into()), + cost: Some(Box::new(cost)), } } } +/// Evaluated expression result +#[derive(Debug, Clone)] +pub struct Value { + pub data: CoreData, M>, +} + +impl Value { + /// Creates a new value from core data + pub fn new(data: CoreData, M>) -> Self { + Self { data } + } +} + /// Core data structures shared across the system #[derive(Debug, Clone)] -pub enum CoreData { +pub enum CoreData { /// Primitive literal values Literal(Literal), /// Ordered collection of values @@ -177,42 +212,21 @@ pub enum CoreData { /// Fixed collection of possibly heterogeneous values Tuple(Vec), /// Key-value associations - Map(Vec<(T, T)>), + Map(Map), /// Named structure with fields Struct(Identifier, Vec), /// Function or closure - Function(FunKind), + Function(FunKind), /// Error representation Fail(Box), /// Logical query operators Logical(Materializable, GroupId>), /// Physical query operators - Physical(Materializable, Goal>), + Physical(Materializable, Goal>), /// The None value None, } -/// Metadata that can be attached to expression nodes -/// -/// This trait allows for different types of metadata to be attached to -/// expression nodes while maintaining a common interface for access. 
-pub trait ExprMetadata: Debug + Clone {} - -/// Empty metadata implementation for cases where no additional data is needed -#[derive(Debug, Clone, Default)] -pub struct NoMetadata; -impl ExprMetadata for NoMetadata {} - -/// Combined span and type information for an expression -#[derive(Debug, Clone)] -pub struct TypedSpan { - /// Source code location. - pub span: Span, - /// Inferred type. - pub ty: Type, -} -impl ExprMetadata for TypedSpan {} - /// Expression nodes in the HIR with optional metadata /// /// The M type parameter allows attaching different kinds of metadata to expressions, @@ -235,9 +249,12 @@ impl Expr { } } +/// Type alias for map entries to reduce type complexity +pub type MapEntries = Vec<(Arc>, Arc>)>; + /// Expression node kinds without metadata #[derive(Debug, Clone)] -pub enum ExprKind { +pub enum ExprKind { /// Pattern matching expression PatternMatch(Arc>, Vec>), /// Conditional expression @@ -250,18 +267,16 @@ pub enum ExprKind { Unary(UnaryOp, Arc>), /// Function call Call(Arc>, Vec>>), + /// Map expression + Map(MapEntries), /// Variable reference Ref(Identifier), /// Core expression - CoreExpr(CoreData>>), + CoreExpr(CoreData>, M>), /// Core value - CoreVal(Value), + CoreVal(Value), } -/// Evaluated expression result -#[derive(Debug, Clone)] -pub struct Value(pub CoreData); - /// Pattern for matching #[derive(Debug, Clone)] pub enum Pattern { @@ -315,10 +330,6 @@ pub enum UnaryOp { /// Program representation after the analysis phase #[derive(Debug)] pub struct HIR { - pub context: Context, + pub context: Context, pub annotations: HashMap>, - pub expressions: Vec>, } - -/// Type alias for HIR with both type and source location information -pub type TypedSpannedHIR = HIR; diff --git a/optd-dsl/src/analyzer/map.rs b/optd-dsl/src/analyzer/map.rs new file mode 100644 index 00000000..d874d587 --- /dev/null +++ b/optd-dsl/src/analyzer/map.rs @@ -0,0 +1,510 @@ +//! Custom Map implementation for the DSL. +//! +//! 
This module provides a Map type that enforces specific key constraints at runtime +//! rather than compile time. It bridges between the flexible HIR Value type and +//! a more restricted internal MapKey type that implements Hash and Eq. +//! +//! The Map implementation supports a subset of Value types as keys: +//! - Literals (except Float64, which doesn't implement Hash/Eq) +//! - Nested structures (Option, Logical, Physical, Fail, Struct, Tuple) +//! if all their contents are also valid key types +//! +//! The implementation provides: +//! - Conversion from Value to MapKey (with runtime validation) +//! - Efficient key lookup (O(1) via HashMap) +//! - Basic map operations (get, concat) + +use super::hir::ExprMetadata; +use super::hir::{ + CoreData, GroupId, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, +}; +use std::collections::HashMap; +use std::hash::Hash; + +/// Map key representation of an operator +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct OperatorMapKey { + pub tag: String, + pub data: Vec, + pub children: Vec, +} + +/// Map key representation of a logical operator +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct LogicalMapOpKey { + pub operator: OperatorMapKey, + pub group_id: Option, +} + +/// Map key representation of a physical operator +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct PhysicalMapOpKey { + pub operator: OperatorMapKey, + pub goal: Option, + pub cost: Option, +} + +/// Map key representation of a goal +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub struct GoalMapKey { + pub group_id: GroupId, + pub properties: MapKey, +} + +/// Map key representation of logical operators (materialized or unmaterialized) +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum LogicalMapKey { + Materialized(LogicalMapOpKey), + UnMaterialized(GroupId), +} + +/// Map key representation of physical operators (materialized or unmaterialized) +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum 
PhysicalMapKey { + Materialized(PhysicalMapOpKey), + UnMaterialized(GoalMapKey), +} + +/// Internal key type that implements Hash, Eq, PartialEq +/// These are the only types that can be used as keys in the Map +#[derive(Clone, PartialEq, Eq, Hash, Debug)] +pub enum MapKey { + /// Primitive literal values (except Float64) + Int64(i64), + String(String), + Bool(bool), + Unit, + + /// Nested structures + Tuple(Vec), + Struct(String, Vec), + + /// Query structures + Logical(Box), + Physical(Box), + + /// Error representation + Fail(Box), + + /// None value + None, +} + +/// Custom Map implementation that enforces key type constraints at runtime +#[derive(Clone, Debug, Default)] +pub struct Map { + inner: HashMap, +} + +impl Map { + /// Creates a Map from a collection of key-value pairs + pub fn from_pairs(pairs: Vec<(Value, Value)>) -> Self { + pairs.into_iter().fold(Self::default(), |mut map, (k, v)| { + let map_key = value_to_map_key(&k); + map.inner.insert(map_key, v); + map + }) + } + + /// Gets a value by key, returning None (as a Value) if not found + pub fn get(&self, key: &Value) -> Value { + self.inner + .get(&value_to_map_key(key)) + .unwrap_or(&Value::new(CoreData::None)) + .clone() + } + + /// Combines two maps, with values from other overriding values from self when keys collide + pub fn concat(&mut self, other: Map) { + self.inner.extend(other.inner); + } +} + +// Key conversion functions + +/// Converts a Value to a MapKey, enforcing valid key types +/// This performs runtime validation that the key type is supported +/// and will panic for invalid key types +fn value_to_map_key(value: &Value) -> MapKey { + match &value.data { + CoreData::Literal(lit) => match lit { + Literal::Int64(i) => MapKey::Int64(*i), + Literal::String(s) => MapKey::String(s.clone()), + Literal::Bool(b) => MapKey::Bool(*b), + Literal::Unit => MapKey::Unit, + Literal::Float64(_) => panic!("Invalid map key: Float64"), + }, + CoreData::Tuple(items) => { + let key_items = 
items.iter().map(value_to_map_key).collect(); + MapKey::Tuple(key_items) + } + CoreData::Struct(name, fields) => { + let key_fields = fields.iter().map(value_to_map_key).collect(); + MapKey::Struct(name.clone(), key_fields) + } + CoreData::Logical(materializable) => match materializable { + Materializable::UnMaterialized(group_id) => { + MapKey::Logical(Box::new(LogicalMapKey::UnMaterialized(*group_id))) + } + Materializable::Materialized(logical_op) => { + let map_op = value_to_logical_map_op(logical_op); + MapKey::Logical(Box::new(LogicalMapKey::Materialized(map_op))) + } + }, + CoreData::Physical(materializable) => match materializable { + Materializable::UnMaterialized(goal) => { + let properties = value_to_map_key(&goal.properties); + let map_goal = GoalMapKey { + group_id: goal.group_id, + properties, + }; + MapKey::Physical(Box::new(PhysicalMapKey::UnMaterialized(map_goal))) + } + Materializable::Materialized(physical_op) => { + let map_op = value_to_physical_map_op(physical_op); + MapKey::Physical(Box::new(PhysicalMapKey::Materialized(map_op))) + } + }, + CoreData::Fail(inner) => { + let inner_key = value_to_map_key(inner); + MapKey::Fail(Box::new(inner_key)) + } + CoreData::None => MapKey::None, + _ => panic!("Invalid map key: {:?}", value), + } +} + +/// Converts an Operator to a map key operator +fn value_to_operator_map_key( + operator: &Operator, + value_converter: &dyn Fn(&T) -> MapKey, +) -> OperatorMapKey { + let data = operator.data.iter().map(value_converter).collect(); + + let children = operator.children.iter().map(value_converter).collect(); + + OperatorMapKey { + tag: operator.tag.clone(), + data, + children, + } +} + +/// Converts a LogicalOp to a map key +fn value_to_logical_map_op(logical_op: &LogicalOp>) -> LogicalMapOpKey { + let operator = + value_to_operator_map_key(&logical_op.operator, &(|v: &Value| value_to_map_key(v))); + + LogicalMapOpKey { + operator, + group_id: logical_op.group_id, + } +} + +/// Converts a PhysicalOp to a map 
key +fn value_to_physical_map_op( + physical_op: &PhysicalOp, M>, +) -> PhysicalMapOpKey { + let operator = + value_to_operator_map_key(&physical_op.operator, &(|v: &Value| value_to_map_key(v))); + + let goal = physical_op.goal.as_ref().map(|g| GoalMapKey { + group_id: g.group_id, + properties: value_to_map_key(&g.properties), + }); + + let cost = physical_op.cost.as_ref().map(|c| value_to_map_key(c)); + + PhysicalMapOpKey { + operator, + goal, + cost, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::tests::{ + assert_values_equal, create_logical_operator, create_physical_operator, + }; + + // Helper to create Value literals + fn int_val(i: i64) -> Value { + Value::new(CoreData::Literal(Literal::Int64(i))) + } + + fn bool_val(b: bool) -> Value { + Value::new(CoreData::Literal(Literal::Bool(b))) + } + + fn string_val(s: &str) -> Value { + Value::new(CoreData::Literal(Literal::String(s.to_string()))) + } + + fn float_val(f: f64) -> Value { + Value::new(CoreData::Literal(Literal::Float64(f))) + } + + fn unit_val() -> Value { + Value::new(CoreData::Literal(Literal::Unit)) + } + + fn tuple_val(items: Vec) -> Value { + Value::new(CoreData::Tuple(items)) + } + + fn struct_val(name: &str, fields: Vec) -> Value { + Value::new(CoreData::Struct(name.to_string(), fields)) + } + + fn array_val(items: Vec) -> Value { + Value::new(CoreData::Array(items)) + } + + fn none_val() -> Value { + Value::new(CoreData::None) + } + + fn fail_val(inner: Value) -> Value { + Value::new(CoreData::Fail(Box::new(inner))) + } + + #[test] + fn test_simple_map_operations() { + let mut map = Map::default(); + let map_key1 = value_to_map_key(&int_val(1)); + let map_key2 = value_to_map_key(&int_val(2)); + + // Insert key-value pairs directly into inner HashMap + map.inner.insert(map_key1, string_val("one")); + map.inner.insert(map_key2, string_val("two")); + + // Check retrieval + assert_values_equal(&map.get(&int_val(1)), &string_val("one")); + 
assert_values_equal(&map.get(&int_val(2)), &string_val("two")); + assert_values_equal(&map.get(&int_val(3)), &none_val()); // Non-existent key + + // Check map size + assert_eq!(map.inner.len(), 2); + } + + #[test] + fn test_map_from_pairs() { + let pairs = vec![ + (int_val(1), string_val("one")), + (int_val(2), string_val("two")), + ]; + + let map = Map::from_pairs(pairs); + + assert_values_equal(&map.get(&int_val(1)), &string_val("one")); + assert_values_equal(&map.get(&int_val(2)), &string_val("two")); + assert_eq!(map.inner.len(), 2); + } + + #[test] + fn test_map_concat() { + let mut map1 = Map::default(); + map1.inner + .insert(value_to_map_key(&int_val(1)), string_val("one")); + map1.inner + .insert(value_to_map_key(&int_val(2)), string_val("two")); + + let mut map2 = Map::default(); + map2.inner + .insert(value_to_map_key(&int_val(2)), string_val("TWO")); // Overwrite key 2 + map2.inner + .insert(value_to_map_key(&int_val(3)), string_val("three")); + + map1.concat(map2); + + assert_values_equal(&map1.get(&int_val(1)), &string_val("one")); + assert_values_equal(&map1.get(&int_val(2)), &string_val("TWO")); // Overwritten + assert_values_equal(&map1.get(&int_val(3)), &string_val("three")); + assert_eq!(map1.inner.len(), 3); + } + + #[test] + fn test_various_key_types() { + let mut map = Map::default(); + + // Basic literals + map.inner + .insert(value_to_map_key(&int_val(1)), string_val("int")); + map.inner + .insert(value_to_map_key(&bool_val(true)), string_val("bool")); + map.inner + .insert(value_to_map_key(&string_val("key")), string_val("string")); + map.inner + .insert(value_to_map_key(&unit_val()), string_val("unit")); + + // Compound types + map.inner.insert( + value_to_map_key(&tuple_val(vec![int_val(1), string_val("test")])), + string_val("tuple"), + ); + map.inner.insert( + value_to_map_key(&struct_val("Person", vec![string_val("John"), int_val(30)])), + string_val("struct"), + ); + map.inner + .insert(value_to_map_key(&none_val()), 
string_val("none")); + map.inner.insert( + value_to_map_key(&fail_val(string_val("error"))), + string_val("fail"), + ); + + // Operator types + let logical_op = create_logical_operator("filter", vec![int_val(1)], vec![]); + let physical_op = create_physical_operator("scan", vec![string_val("table")], vec![]); + map.inner + .insert(value_to_map_key(&logical_op), string_val("logical")); + map.inner + .insert(value_to_map_key(&physical_op), string_val("physical")); + + // Verify all keys work + assert_values_equal(&map.get(&int_val(1)), &string_val("int")); + assert_values_equal(&map.get(&bool_val(true)), &string_val("bool")); + assert_values_equal(&map.get(&string_val("key")), &string_val("string")); + assert_values_equal(&map.get(&unit_val()), &string_val("unit")); + assert_values_equal( + &map.get(&tuple_val(vec![int_val(1), string_val("test")])), + &string_val("tuple"), + ); + assert_values_equal( + &map.get(&struct_val("Person", vec![string_val("John"), int_val(30)])), + &string_val("struct"), + ); + assert_values_equal(&map.get(&none_val()), &string_val("none")); + assert_values_equal( + &map.get(&fail_val(string_val("error"))), + &string_val("fail"), + ); + assert_values_equal(&map.get(&logical_op), &string_val("logical")); + assert_values_equal(&map.get(&physical_op), &string_val("physical")); + + assert_eq!(map.inner.len(), 10); + } + + #[test] + #[should_panic(expected = "Invalid map key: Float64")] + fn test_float_key_panics() { + value_to_map_key(&float_val(3.14)); + } + + #[test] + #[should_panic(expected = "Invalid map key: Float64")] + fn test_tuple_with_float_panics() { + let tuple_with_float = tuple_val(vec![int_val(1), float_val(2.5)]); + value_to_map_key(&tuple_with_float); + } + + #[test] + #[should_panic] + fn test_array_key_panics() { + let array_key = array_val(vec![int_val(1), int_val(2)]); + value_to_map_key(&array_key); + } + + #[test] + fn test_empty_map() { + let map = Map::default(); + assert_eq!(map.inner.len(), 0); + 
assert!(map.inner.is_empty()); + } + + #[test] + fn test_map_key_conversion() { + // Test conversion between Value and MapKey + let values = vec![ + int_val(42), + string_val("hello"), + bool_val(true), + unit_val(), + tuple_val(vec![int_val(1), bool_val(false)]), + struct_val("Test", vec![string_val("field"), int_val(123)]), + none_val(), + fail_val(int_val(404)), + ]; + + for value in values { + // Convert Value to MapKey + let map_key = value_to_map_key(&value); + + // Use the map_key in a map + let mut map = Map::default(); + map.inner.insert(map_key, string_val("value")); + + // Verify we can retrieve with the original value + assert_values_equal(&map.get(&value), &string_val("value")); + } + } + + #[test] + fn test_deep_nesting() { + // Create deeply nested structure to test recursive key conversion + let nested_value = struct_val( + "Root", + vec![ + tuple_val(vec![ + int_val(1), + struct_val( + "Inner", + vec![ + string_val("deep"), + tuple_val(vec![bool_val(true), unit_val()]), + ], + ), + ]), + fail_val(none_val()), + ], + ); + + // Should not panic + let key = value_to_map_key(&nested_value); + + // Use it as a key + let mut map = Map::default(); + map.inner.insert(key, string_val("complex")); + + // Verify we can retrieve it + assert_values_equal(&map.get(&nested_value), &string_val("complex")); + } + + #[test] + fn test_operator_key_conversion() { + // Test LogicalOp conversion + let logical_op = create_logical_operator( + "join", + vec![string_val("id"), int_val(10)], + vec![int_val(1), int_val(2)], + ); + + let logical_key = value_to_map_key(&logical_op); + assert!(matches!(logical_key, MapKey::Logical(_))); + + // Test PhysicalOp conversion + let physical_op = create_physical_operator( + "hash_join", + vec![string_val("id")], + vec![int_val(1), int_val(2)], + ); + + let physical_key = value_to_map_key(&physical_op); + assert!(matches!(physical_key, MapKey::Physical(_))); + } + + #[test] + fn test_key_equality() { + // Test that equal values create 
equal map keys + let key1 = value_to_map_key(&tuple_val(vec![int_val(1), string_val("test")])); + let key2 = value_to_map_key(&tuple_val(vec![int_val(1), string_val("test")])); + + assert_eq!(key1, key2); + + // Test that different values create different map keys + let key3 = value_to_map_key(&tuple_val(vec![int_val(2), string_val("test")])); + assert_ne!(key1, key3); + } +} diff --git a/optd-dsl/src/analyzer/mod.rs b/optd-dsl/src/analyzer/mod.rs index 5776abb3..3883e948 100644 --- a/optd-dsl/src/analyzer/mod.rs +++ b/optd-dsl/src/analyzer/mod.rs @@ -1,5 +1,6 @@ pub mod context; pub mod hir; +pub mod map; pub mod semantic_checker; pub mod r#type; pub mod type_checker; diff --git a/optd-dsl/src/analyzer/semantic_checker/error.rs b/optd-dsl/src/analyzer/semantic_checker/error.rs index 848e7c0e..3857972e 100644 --- a/optd-dsl/src/analyzer/semantic_checker/error.rs +++ b/optd-dsl/src/analyzer/semantic_checker/error.rs @@ -2,10 +2,33 @@ use ariadne::{Report, Source}; use crate::utils::{error::Diagnose, span::Span}; +/// Error types for semantic analysis #[derive(Debug)] -pub struct SemanticError {} +pub enum SemanticError { + /// Error for duplicate ADT names + DuplicateAdt { + /// Name of the duplicate ADT + name: String, + /// Span of the first declaration + first_span: Span, + /// Span of the duplicate declaration + duplicate_span: Span, + }, +} + +impl SemanticError { + /// Creates a new error for duplicate ADT names + pub fn new_duplicate_adt(name: String, first_span: Span, duplicate_span: Span) -> Box { + Self::DuplicateAdt { + name, + first_span, + duplicate_span, + } + .into() + } +} -impl Diagnose for SemanticError { +impl Diagnose for Box { fn report(&self) -> Report { todo!() } diff --git a/optd-dsl/src/analyzer/type.rs b/optd-dsl/src/analyzer/type.rs index cc396257..bcbdb363 100644 --- a/optd-dsl/src/analyzer/type.rs +++ b/optd-dsl/src/analyzer/type.rs @@ -1,8 +1,9 @@ +use super::semantic_checker::error::SemanticError; use crate::parser::ast::Adt; -use 
std::{ - collections::{HashMap, HashSet}, - ops::{Deref, DerefMut}, -}; +use crate::utils::error::CompileError; +use crate::utils::span::Span; +use Adt::*; +use std::collections::{HashMap, HashSet}; pub type Identifier = String; @@ -20,89 +21,24 @@ pub enum Type { // Special types Unit, - Universe, + Universe, // All types are subtypes of Universe + Never, // Inherits all types. + Unknown, - // Complex types + // User types + Adt(Identifier), + Generic(Identifier), + + // Composite types Array(Box), Closure(Box, Box), Tuple(Vec), Map(Box, Box), Optional(Box), - // User-defined types - Adt(Identifier), -} - -/// A typed value that carries both a value and its type information. -/// -/// This generic wrapper allows attaching type information to any value in the compiler pipeline. -/// It's particularly useful for tracking types through expressions, statements, and other AST nodes -/// during the type checking phase. -/// -/// # Type Parameters -/// -/// * `T` - The type of the wrapped value -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct Typed { - /// The wrapped value - pub value: Box, - - /// The type of the value - pub ty: Type, -} - -impl Deref for Typed { - /// The wrapped type T - type Target = T; - - /// Returns a reference to the wrapped value - fn deref(&self) -> &Self::Target { - &self.value - } -} - -/// Implements mutable dereferencing for Typed -impl DerefMut for Typed { - /// Returns a mutable reference to the wrapped value - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.value - } -} - -impl Typed { - /// Creates a new typed value with the given value and type. 
- /// - /// # Arguments - /// - /// * `value` - The value to wrap - /// * `ty` - The type to associate with the value - /// - /// # Returns - /// - /// A new `Typed` instance containing the value and its type - pub fn new(value: T, ty: Type) -> Self { - Self { - value: Box::new(value), - ty, - } - } - - /// Checks if this typed value's type is a subtype of another typed value's type. - /// - /// This is useful for type checking to ensure type compatibility between - /// expressions, assignments, function calls, etc. - /// - /// # Arguments - /// - /// * `other` - The other typed value to compare against - /// * `registry` - The type registry that defines subtyping relationships - /// - /// # Returns - /// - /// `true` if this type is a subtype of the other, `false` otherwise - pub fn subtype_of(&self, other: &Self, registry: &TypeRegistry) -> bool { - registry.is_subtype(&self.ty, &other.ty) - } + // For Logical & Physical: memo status + Stored(Box), + Costed(Box), } /// Manages the type hierarchy and subtyping relationships @@ -113,6 +49,7 @@ impl Typed { #[derive(Debug, Clone, Default)] pub struct TypeRegistry { subtypes: HashMap>, + adt_spans: HashMap, // Track spans for error reporting } impl TypeRegistry { @@ -125,26 +62,60 @@ impl TypeRegistry { /// # Arguments /// /// * `adt` - The ADT to register - pub fn register_adt(&mut self, adt: &Adt) { + /// + /// # Returns + /// + /// `Ok(())` if registration is successful, or a `CompileError` if a duplicate name is found + pub fn register_adt(&mut self, adt: &Adt) -> Result<(), CompileError> { match adt { - Adt::Product { name, .. } => { + Product { name, .. } => { let type_name = name.value.as_ref().clone(); + // Check for duplicate ADT names. 
+ if let Some(existing_span) = self.adt_spans.get(&type_name) { + return Err(SemanticError::new_duplicate_adt( + type_name, + existing_span.clone(), + name.span.clone(), + ) + .into()); + } + + self.adt_spans.insert(type_name.clone(), name.span.clone()); self.subtypes.entry(type_name).or_default(); + Ok(()) } - Adt::Sum { name, variants } => { + Sum { name, variants } => { let enum_name = name.value.as_ref().clone(); + + // Check for duplicate ADT names. + if let Some(existing_span) = self.adt_spans.get(&enum_name) { + return Err(SemanticError::new_duplicate_adt( + enum_name, + existing_span.clone(), + name.clone().span, + ) + .into()); + } + + self.adt_spans.insert(enum_name.clone(), name.clone().span); self.subtypes.entry(enum_name.clone()).or_default(); + for variant in variants { let variant_adt = variant.value.as_ref(); - self.register_adt(variant_adt); + // Register each variant. + self.register_adt(variant_adt)?; + let variant_name = match variant_adt { - Adt::Product { name, .. } => name.value.as_ref(), - Adt::Sum { name, .. } => name.value.as_ref(), + Product { name, .. } => name.value.as_ref(), + Sum { name, .. } => name.value.as_ref(), }; + + // Add variant as a subtype of the enum. 
if let Some(children) = self.subtypes.get_mut(&enum_name) { children.insert(variant_name.clone()); } } + Ok(()) } } } @@ -171,7 +142,31 @@ impl TypeRegistry { match (child, parent) { // Universe is the top type - everything is a subtype of Universe (_, Type::Universe) => true, - // Check transitive inheritance + + // Never is the bottom type - it is a subtype of everything + (Type::Never, _) => true, + + // Stored and Costed type handling + (Type::Stored(child_inner), Type::Stored(parent_inner)) => { + self.is_subtype(child_inner, parent_inner) + } + (Type::Costed(child_inner), Type::Costed(parent_inner)) => { + self.is_subtype(child_inner, parent_inner) + } + (Type::Costed(child_inner), Type::Stored(parent_inner)) => { + // Costed(A) is a subtype of Stored(A) + self.is_subtype(child_inner, parent_inner) + } + (Type::Costed(child_inner), parent_inner) => { + // Costed(A) is a subtype of A + self.is_subtype(child_inner, parent_inner) + } + (Type::Stored(child_inner), parent_inner) => { + // Stored(A) is a subtype of A + self.is_subtype(child_inner, parent_inner) + } + + // Check transitive inheritance for ADTs (Type::Adt(child_name), Type::Adt(parent_name)) => { if child_name == parent_name { return true; @@ -186,10 +181,12 @@ impl TypeRegistry { }) }) } + // Array covariance: Array[T] <: Array[U] if T <: U (Type::Array(child_elem), Type::Array(parent_elem)) => { self.is_subtype(child_elem, parent_elem) } + // Tuple covariance: (T1, T2, ...) <: (U1, U2, ...) if T1 <: U1, T2 <: U2, ... 
(Type::Tuple(child_types), Type::Tuple(parent_types)) => { if child_types.len() != parent_types.len() { @@ -200,19 +197,23 @@ impl TypeRegistry { .zip(parent_types.iter()) .all(|(c, p)| self.is_subtype(c, p)) } + // Map covariance: Map[K1, V1] <: Map[K2, V2] if K1 <: K2 and V1 <: V2 (Type::Map(child_key, child_val), Type::Map(parent_key, parent_val)) => { self.is_subtype(child_key, parent_key) && self.is_subtype(child_val, parent_val) } + // Function contravariance on args, covariance on return type: // (T1 -> U1) <: (T2 -> U2) if T2 <: T1 and U1 <: U2 (Type::Closure(child_param, child_ret), Type::Closure(parent_param, parent_ret)) => { self.is_subtype(parent_param, child_param) && self.is_subtype(child_ret, parent_ret) } + // Optional type covariance: Optional[T] <: Optional[U] if T <: U (Type::Optional(child_ty), Type::Optional(parent_ty)) => { self.is_subtype(child_ty, parent_ty) } + _ => false, } } @@ -245,21 +246,96 @@ mod type_registry_tests { }) .collect(); - Adt::Product { + Product { name: spanned(name.to_string()), fields: spanned_fields, } } fn create_sum_adt(name: &str, variants: Vec) -> Adt { - let spanned_variants: Vec> = variants.into_iter().map(spanned).collect(); - - Adt::Sum { + Sum { name: spanned(name.to_string()), - variants: spanned_variants, + variants: variants.into_iter().map(spanned).collect(), } } + #[test] + fn test_stored_and_costed_types() { + let registry = TypeRegistry::default(); + + // Test Stored type as a subtype of the inner type + assert!(registry.is_subtype(&Type::Stored(Box::new(Type::Int64)), &Type::Int64)); + + // Test Costed type as a subtype of Stored type + assert!(registry.is_subtype( + &Type::Costed(Box::new(Type::Int64)), + &Type::Stored(Box::new(Type::Int64)) + )); + + // Test Costed type as a subtype of the inner type (transitivity) + assert!(registry.is_subtype(&Type::Costed(Box::new(Type::Int64)), &Type::Int64)); + + // Test Stored type covariance + let mut adts_registry = TypeRegistry::default(); + let animal = 
create_product_adt("Animal", vec![]); + let dog = create_product_adt("Dog", vec![]); + let animals_enum = create_sum_adt("Animals", vec![animal, dog]); + adts_registry.register_adt(&animals_enum).unwrap(); + + assert!(adts_registry.is_subtype( + &Type::Stored(Box::new(Type::Adt("Dog".to_string()))), + &Type::Stored(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test Costed type covariance + assert!(adts_registry.is_subtype( + &Type::Costed(Box::new(Type::Adt("Dog".to_string()))), + &Type::Costed(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test the inheritance relationship: Costed(Dog) <: Stored(Animals) + assert!(adts_registry.is_subtype( + &Type::Costed(Box::new(Type::Adt("Dog".to_string()))), + &Type::Stored(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test nested Stored/Costed types + assert!(adts_registry.is_subtype( + &Type::Stored(Box::new(Type::Costed(Box::new(Type::Adt( + "Dog".to_string() + ))))), + &Type::Stored(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test with Array of Stored/Costed types + assert!(adts_registry.is_subtype( + &Type::Array(Box::new(Type::Costed(Box::new(Type::Adt( + "Dog".to_string() + ))))), + &Type::Array(Box::new(Type::Stored(Box::new(Type::Adt( + "Animals".to_string() + ))))) + )); + } + + #[test] + fn test_duplicate_adt_detection() { + let mut registry = TypeRegistry::default(); + + // First registration should succeed + let car1 = create_product_adt("Car", vec![]); + assert!(registry.register_adt(&car1).is_ok()); + + // Second registration with the same name should fail + let car2 = Product { + name: spanned("Car".to_string()), + fields: vec![], + }; + + let result = registry.register_adt(&car2); + assert!(result.is_err()); + } + #[test] fn test_primitive_type_equality() { let registry = TypeRegistry::default(); @@ -316,7 +392,7 @@ mod type_registry_tests { let vehicle = create_product_adt("Vehicle", vec![]); let car = create_product_adt("Car", vec![]); let vehicles_enum = 
create_sum_adt("Vehicles", vec![vehicle, car]); - adts_registry.register_adt(&vehicles_enum); + adts_registry.register_adt(&vehicles_enum).unwrap(); assert!(adts_registry.is_subtype( &Type::Array(Box::new(Type::Adt("Car".to_string()))), @@ -421,7 +497,7 @@ mod type_registry_tests { let animal = create_product_adt("Animal", vec![]); let dog = create_product_adt("Dog", vec![]); let animals_enum = create_sum_adt("Animals", vec![animal, dog]); - adts_registry.register_adt(&animals_enum); + adts_registry.register_adt(&animals_enum).unwrap(); // (Animals -> Bool) <: (Dog -> Bool) because Dog <: Animals (contravariance) assert!(adts_registry.is_subtype( @@ -469,7 +545,7 @@ mod type_registry_tests { let shapes_enum = create_sum_adt("Shapes", vec![shape, circle, rectangle]); // Register the ADT - registry.register_adt(&shapes_enum); + registry.register_adt(&shapes_enum).unwrap(); // Test subtypes relationship assert!(registry.is_subtype( @@ -531,12 +607,43 @@ mod type_registry_tests { &Type::Universe )); + // Check that Universe is a supertype of Stored and Costed types + assert!(registry.is_subtype(&Type::Stored(Box::new(Type::Int64)), &Type::Universe)); + assert!(registry.is_subtype(&Type::Costed(Box::new(Type::Int64)), &Type::Universe)); + // But Universe is not a subtype of any other type assert!(!registry.is_subtype(&Type::Universe, &Type::Int64)); assert!(!registry.is_subtype(&Type::Universe, &Type::String)); assert!(!registry.is_subtype(&Type::Universe, &Type::Array(Box::new(Type::Int64)))); } + #[test] + fn test_never_as_bottom_type() { + let registry = TypeRegistry::default(); + + // Never is a subtype of all primitive types + assert!(registry.is_subtype(&Type::Never, &Type::Int64)); + assert!(registry.is_subtype(&Type::Never, &Type::String)); + assert!(registry.is_subtype(&Type::Never, &Type::Bool)); + assert!(registry.is_subtype(&Type::Never, &Type::Float64)); + assert!(registry.is_subtype(&Type::Never, &Type::Unit)); + assert!(registry.is_subtype(&Type::Never, 
&Type::Universe)); + + // Never is a subtype of complex types + assert!(registry.is_subtype(&Type::Never, &Type::Array(Box::new(Type::Int64)))); + assert!(registry.is_subtype(&Type::Never, &Type::Tuple(vec![Type::Int64, Type::Bool]))); + assert!(registry.is_subtype( + &Type::Never, + &Type::Closure(Box::new(Type::Int64), Box::new(Type::Bool)) + )); + + // But no type is a subtype of Never (except Never itself) + assert!(!registry.is_subtype(&Type::Int64, &Type::Never)); + assert!(!registry.is_subtype(&Type::Bool, &Type::Never)); + assert!(!registry.is_subtype(&Type::Universe, &Type::Never)); + assert!(!registry.is_subtype(&Type::Array(Box::new(Type::Int64)), &Type::Never)); + } + #[test] fn test_complex_nested_type_hierarchy() { let mut registry = TypeRegistry::default(); @@ -573,7 +680,7 @@ mod type_registry_tests { let vehicles_enum = create_sum_adt("Vehicles", vec![vehicle, cars_enum, truck]); // Register the ADT - registry.register_adt(&vehicles_enum); + registry.register_adt(&vehicles_enum).unwrap(); // Test direct subtyping relationships assert!(registry.is_subtype( @@ -623,97 +730,4 @@ mod type_registry_tests { &Type::Adt("Cars".to_string()) )); } - - #[test] - fn test_combined_complex_types() { - let mut registry = TypeRegistry::default(); - - // Create and register a type hierarchy - let animal = create_product_adt("Animal", vec![]); - let dog = create_product_adt("Dog", vec![]); - let cat = create_product_adt("Cat", vec![]); - let animals_enum = create_sum_adt("Animals", vec![animal, dog, cat]); - registry.register_adt(&animals_enum); - - // Test array of ADTs - assert!(registry.is_subtype( - &Type::Array(Box::new(Type::Adt("Dog".to_string()))), - &Type::Array(Box::new(Type::Adt("Animals".to_string()))) - )); - - // Test tuple of ADTs - assert!(registry.is_subtype( - &Type::Tuple(vec![ - Type::Adt("Dog".to_string()), - Type::Adt("Cat".to_string()) - ]), - &Type::Tuple(vec![ - Type::Adt("Animals".to_string()), - Type::Adt("Animals".to_string()) - ]) - )); 
- - // Test map with ADT keys and values - assert!(registry.is_subtype( - &Type::Map( - Box::new(Type::Adt("Dog".to_string())), - Box::new(Type::Adt("Cat".to_string())) - ), - &Type::Map( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Animals".to_string())) - ) - )); - - // Test closures with ADTs - assert!(registry.is_subtype( - &Type::Closure( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Dog".to_string())) - ), - &Type::Closure( - Box::new(Type::Adt("Dog".to_string())), - Box::new(Type::Adt("Animals".to_string())) - ) - )); - - // Test deeply nested types - assert!(registry.is_subtype( - &Type::Map( - Box::new(Type::String), - Box::new(Type::Array(Box::new(Type::Tuple(vec![ - Type::Adt("Dog".to_string()), - Type::Closure( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Cat".to_string())) - ) - ])))) - ), - &Type::Map( - Box::new(Type::String), - Box::new(Type::Array(Box::new(Type::Tuple(vec![ - Type::Adt("Animals".to_string()), - Type::Closure( - Box::new(Type::Adt("Dog".to_string())), - Box::new(Type::Adt("Animals".to_string())) - ) - ])))) - ) - )); - - // Any complex nested type is a subtype of Universe - assert!(registry.is_subtype( - &Type::Map( - Box::new(Type::String), - Box::new(Type::Array(Box::new(Type::Tuple(vec![ - Type::Adt("Dog".to_string()), - Type::Closure( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Cat".to_string())) - ) - ])))) - ), - &Type::Universe - )); - } } diff --git a/optd-dsl/src/analyzer/type_checker/error.rs b/optd-dsl/src/analyzer/type_checker/error.rs index 5a9774fe..16f2edd8 100644 --- a/optd-dsl/src/analyzer/type_checker/error.rs +++ b/optd-dsl/src/analyzer/type_checker/error.rs @@ -5,7 +5,7 @@ use crate::utils::{error::Diagnose, span::Span}; #[derive(Debug)] pub struct TypeError {} -impl Diagnose for TypeError { +impl Diagnose for Box { fn report(&self) -> Report { todo!() } diff --git a/optd-dsl/src/cli/basic.op b/optd-dsl/src/cli/basic.op 
index 09dab1d7..4f18393e 100644 --- a/optd-dsl/src/cli/basic.op +++ b/optd-dsl/src/cli/basic.op @@ -93,6 +93,8 @@ data JoinType = fn map_get(map: {K: V}, key: K): V = None +fn empty: V + [rust] fn (expr: Scalar) apply_children(f: Scalar -> Scalar) = None @@ -102,10 +104,10 @@ fn (pred: Predicate) remap(map: {I64 : I64}) = \ _ -> predicate.apply_children(child -> rewrite_column_refs(child, map)) [rule] -fn (expr: Logical) join_commute: Logical? = match expr +fn (expr: Logical*) join_commute: Logical? = match expr \ Join(left, right, Inner, cond) -> let - right_indices = 0.right.schema_len, + right_indices = 0..right.schema_len, left_indices = 0..left.schema_len, remapping = left_indices.map(i -> (i, i + right_len)) ++ right_indices.map(i -> (left_len + i, i)).to_map, diff --git a/optd-dsl/src/engine/eval/binary.rs b/optd-dsl/src/engine/eval/binary.rs index e3e4412b..c5ea2901 100644 --- a/optd-dsl/src/engine/eval/binary.rs +++ b/optd-dsl/src/engine/eval/binary.rs @@ -23,27 +23,29 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { use BinOp::*; use CoreData::*; - match (left.0, op, right.0) { + match (left.data, op, right.data) { // Handle operations between two literals. (Literal(l), op, Literal(r)) => match (l, op, r) { // Integer operations (arithmetic, comparison). 
- (Int64(l), Add | Sub | Mul | Div | Eq | Lt, Int64(r)) => Value(Literal(match op { - Add => Int64(l + r), // Integer addition - Sub => Int64(l - r), // Integer subtraction - Mul => Int64(l * r), // Integer multiplication - Div => Int64(l / r), // Integer division (panics on divide by zero) - Eq => Bool(l == r), // Integer equality comparison - Lt => Bool(l < r), // Integer less-than comparison - _ => unreachable!(), // This branch is unreachable due to pattern guard - })), + (Int64(l), Add | Sub | Mul | Div | Eq | Lt, Int64(r)) => { + Value::new(Literal(match op { + Add => Int64(l + r), // Integer addition + Sub => Int64(l - r), // Integer subtraction + Mul => Int64(l * r), // Integer multiplication + Div => Int64(l / r), // Integer division (panics on divide by zero) + Eq => Bool(l == r), // Integer equality comparison + Lt => Bool(l < r), // Integer less-than comparison + _ => unreachable!(), // This branch is unreachable due to pattern guard + })) + } // Integer range operation (creates an array of sequential integers). - (Int64(l), Range, Int64(r)) => { - Value(Array((l..=r).map(|n| Value(Literal(Int64(n)))).collect())) - } + (Int64(l), Range, Int64(r)) => Value::new(Array( + (l..=r).map(|n| Value::new(Literal(Int64(n)))).collect(), + )), // Float operations (arithmetic, comparison). - (Float64(l), op, Float64(r)) => Value(Literal(match op { + (Float64(l), op, Float64(r)) => Value::new(Literal(match op { Add => Float64(l + r), // Float addition Sub => Float64(l - r), // Float subtraction Mul => Float64(l * r), // Float multiplication @@ -53,7 +55,7 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { })), // Boolean operations (logical, comparison). 
- (Bool(l), op, Bool(r)) => Value(Literal(match op { + (Bool(l), op, Bool(r)) => Value::new(Literal(match op { And => Bool(l && r), // Logical AND Or => Bool(l || r), // Logical OR Eq => Bool(l == r), // Boolean equality comparison @@ -61,7 +63,7 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { })), // String operations (comparison, concatenation). - (String(l), op, String(r)) => Value(Literal(match op { + (String(l), op, String(r)) => Value::new(Literal(match op { Eq => Bool(l == r), // String equality comparison Concat => String(format!("{l}{r}")), // String concatenation _ => panic!("Invalid string operation"), // Other operations not supported @@ -75,14 +77,13 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { (Array(l), Concat, Array(r)) => { let mut result = l.clone(); result.extend(r.iter().cloned()); - Value(Array(result)) + Value::new(Array(result)) } // Map concatenation (joins two maps). - (Map(l), Concat, Map(r)) => { - let mut result = l.clone(); - result.extend(r.iter().cloned()); - Value(Map(result)) + (Map(mut l), Concat, Map(r)) => { + l.concat(r); + Value::new(Map(l)) } // Any other combination of value types or operations is not supported. 
@@ -96,55 +97,54 @@ mod tests { use BinOp::*; use CoreData::*; use Literal::*; - use std::collections::HashMap; use super::eval_binary_op; // Helper function to create integer Value fn int(i: i64) -> Value { - Value(Literal(Int64(i))) + Value::new(Literal(Int64(i))) } // Helper function to create float Value fn float(f: f64) -> Value { - Value(Literal(Float64(f))) + Value::new(Literal(Float64(f))) } // Helper function to create boolean Value fn boolean(b: bool) -> Value { - Value(Literal(Bool(b))) + Value::new(Literal(Bool(b))) } // Helper function to create string Value fn string(s: &str) -> Value { - Value(Literal(String(s.to_string()))) + Value::new(Literal(String(s.to_string()))) } #[test] fn test_integer_arithmetic() { // Addition - if let Literal(Int64(result)) = eval_binary_op(int(5), &Add, int(7)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(5), &Add, int(7)).data { assert_eq!(result, 12); } else { panic!("Expected Int64"); } // Subtraction - if let Literal(Int64(result)) = eval_binary_op(int(10), &Sub, int(3)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(10), &Sub, int(3)).data { assert_eq!(result, 7); } else { panic!("Expected Int64"); } // Multiplication - if let Literal(Int64(result)) = eval_binary_op(int(4), &Mul, int(5)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(4), &Mul, int(5)).data { assert_eq!(result, 20); } else { panic!("Expected Int64"); } // Division - if let Literal(Int64(result)) = eval_binary_op(int(20), &Div, int(4)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(20), &Div, int(4)).data { assert_eq!(result, 5); } else { panic!("Expected Int64"); @@ -154,28 +154,28 @@ mod tests { #[test] fn test_integer_comparison() { // Equality - true case - if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(5)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(5)).data { assert!(result); } else { panic!("Expected Bool"); } // Equality - false case - if let 
Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(7)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(7)).data { assert!(!result); } else { panic!("Expected Bool"); } // Less than - true case - if let Literal(Bool(result)) = eval_binary_op(int(5), &Lt, int(10)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(5), &Lt, int(10)).data { assert!(result); } else { panic!("Expected Bool"); } // Less than - false case - if let Literal(Bool(result)) = eval_binary_op(int(10), &Lt, int(5)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(10), &Lt, int(5)).data { assert!(!result); } else { panic!("Expected Bool"); @@ -185,12 +185,12 @@ mod tests { #[test] fn test_integer_range() { // Range operation - if let Array(result) = eval_binary_op(int(1), &Range, int(5)).0 { + if let Array(result) = eval_binary_op(int(1), &Range, int(5)).data { assert_eq!(result.len(), 5); // Check individual elements for (i, val) in result.iter().enumerate() { - if let Literal(Int64(n)) = val.0 { + if let Literal(Int64(n)) = val.data { assert_eq!(n, (i as i64) + 1); } else { panic!("Expected Int64 in array"); @@ -204,35 +204,35 @@ mod tests { #[test] fn test_float_operations() { // Addition - if let Literal(Float64(result)) = eval_binary_op(float(3.5), &Add, float(2.25)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(3.5), &Add, float(2.25)).data { assert_eq!(result, 5.75); } else { panic!("Expected Float64"); } // Subtraction - if let Literal(Float64(result)) = eval_binary_op(float(10.5), &Sub, float(3.25)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(10.5), &Sub, float(3.25)).data { assert_eq!(result, 7.25); } else { panic!("Expected Float64"); } // Multiplication - if let Literal(Float64(result)) = eval_binary_op(float(4.0), &Mul, float(2.5)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(4.0), &Mul, float(2.5)).data { assert_eq!(result, 10.0); } else { panic!("Expected Float64"); } // Division - if let 
Literal(Float64(result)) = eval_binary_op(float(10.0), &Div, float(2.5)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(10.0), &Div, float(2.5)).data { assert_eq!(result, 4.0); } else { panic!("Expected Float64"); } // Less than - if let Literal(Bool(result)) = eval_binary_op(float(3.5), &Lt, float(3.6)).0 { + if let Literal(Bool(result)) = eval_binary_op(float(3.5), &Lt, float(3.6)).data { assert!(result); } else { panic!("Expected Bool"); @@ -242,35 +242,35 @@ mod tests { #[test] fn test_boolean_operations() { // AND - true case - if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(true)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(true)).data { assert!(result); } else { panic!("Expected Bool"); } // AND - false case - if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(false)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(false)).data { assert!(!result); } else { panic!("Expected Bool"); } // OR - true case - if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(true)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(true)).data { assert!(result); } else { panic!("Expected Bool"); } // OR - false case - if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(false)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(false)).data { assert!(!result); } else { panic!("Expected Bool"); } // Equality - if let Literal(Bool(result)) = eval_binary_op(boolean(true), &Eq, boolean(true)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(true), &Eq, boolean(true)).data { assert!(result); } else { panic!("Expected Bool"); @@ -280,14 +280,14 @@ mod tests { #[test] fn test_string_operations() { // Equality - true case - if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("hello")).0 { + if let Literal(Bool(result)) 
= eval_binary_op(string("hello"), &Eq, string("hello")).data { assert!(result); } else { panic!("Expected Bool"); } // Equality - false case - if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("world")).0 { + if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("world")).data { assert!(!result); } else { panic!("Expected Bool"); @@ -295,7 +295,7 @@ mod tests { // Concatenation if let Literal(String(result)) = - eval_binary_op(string("hello "), &Concat, string("world")).0 + eval_binary_op(string("hello "), &Concat, string("world")).data { assert_eq!(result, "hello world"); } else { @@ -306,17 +306,17 @@ mod tests { #[test] fn test_array_concatenation() { // Create two arrays - let array1 = Value(Array(vec![int(1), int(2), int(3)])); - let array2 = Value(Array(vec![int(4), int(5)])); + let array1 = Value::new(Array(vec![int(1), int(2), int(3)])); + let array2 = Value::new(Array(vec![int(4), int(5)])); // Concatenate arrays - if let Array(result) = eval_binary_op(array1, &Concat, array2).0 { + if let Array(result) = eval_binary_op(array1, &Concat, array2).data { assert_eq!(result.len(), 5); // Check the elements let expected = [1, 2, 3, 4, 5]; for (i, val) in result.iter().enumerate() { - if let Literal(Int64(n)) = val.0 { + if let Literal(Int64(n)) = val.data { assert_eq!(n, expected[i] as i64); } else { panic!("Expected Int64 in array"); @@ -329,32 +329,51 @@ mod tests { #[test] fn test_map_concatenation() { - // Create two maps - let map1 = Value(Map(vec![(string("a"), int(1)), (string("b"), int(2))])); - - let map2 = Value(Map(vec![(string("c"), int(3)), (string("d"), int(4))])); + use crate::analyzer::map::Map; + + // Create two maps using Map::from_pairs + let map1 = Value::new(Map(Map::from_pairs(vec![ + (string("a"), int(1)), + (string("b"), int(2)), + ]))); + let map2 = Value::new(Map(Map::from_pairs(vec![ + (string("c"), int(3)), + (string("d"), int(4)), + ]))); // Concatenate maps - if let Map(result) = 
eval_binary_op(map1, &Concat, map2).0 { - assert_eq!(result.len(), 4); - - // Convert to a HashMap for easier testing - let map: HashMap = result - .iter() - .map(|(k, v)| { - if let (Literal(String(key)), Literal(Int64(value))) = (&k.0, &v.0) { - (key.clone(), *value) - } else { - panic!("Expected String keys and Int64 values"); - } - }) - .collect(); - - // Check elements - assert_eq!(map.get("a"), Some(&1)); - assert_eq!(map.get("b"), Some(&2)); - assert_eq!(map.get("c"), Some(&3)); - assert_eq!(map.get("d"), Some(&4)); + if let Map(result) = eval_binary_op(map1, &Concat, map2).data { + // Check each key-value pair is accessible + if let Literal(Int64(v)) = result.get(&string("a")).data { + assert_eq!(v, 1); + } else { + panic!("Expected Int64 for key 'a'"); + } + + if let Literal(Int64(v)) = result.get(&string("b")).data { + assert_eq!(v, 2); + } else { + panic!("Expected Int64 for key 'b'"); + } + + if let Literal(Int64(v)) = result.get(&string("c")).data { + assert_eq!(v, 3); + } else { + panic!("Expected Int64 for key 'c'"); + } + + if let Literal(Int64(v)) = result.get(&string("d")).data { + assert_eq!(v, 4); + } else { + panic!("Expected Int64 for key 'd'"); + } + + // Check a non-existent key returns None + if let None = result.get(&string("z")).data { + // This is the expected behavior + } else { + panic!("Expected None for non-existent key"); + } } else { panic!("Expected Map"); } diff --git a/optd-dsl/src/engine/eval/core.rs b/optd-dsl/src/engine/eval/core.rs index 30d6a651..ce7729fc 100644 --- a/optd-dsl/src/engine/eval/core.rs +++ b/optd-dsl/src/engine/eval/core.rs @@ -27,24 +27,27 @@ where match data { Literal(lit) => { // Directly continue with the literal value. 
- k(Value(Literal(lit))).await + k(Value::new(Literal(lit))).await } Array(items) => evaluate_collection(items, Array, engine, k).await, Tuple(items) => evaluate_collection(items, Tuple, engine, k).await, Struct(name, items) => { evaluate_collection(items, move |values| Struct(name, values), engine, k).await } - Map(items) => evaluate_map(items, engine, k).await, + Map(items) => { + // Directly continue with the map value. + k(Value::new(Map(items))).await + } Function(fun_type) => { // Directly continue with the function value. - k(Value(Function(fun_type))).await + k(Value::new(Function(fun_type))).await } Fail(msg) => evaluate_fail(*msg, engine, k).await, Logical(op) => evaluate_logical_operator(op, engine, k).await, Physical(op) => evaluate_physical_operator(op, engine, k).await, None => { // Directly continue with null value. - k(Value(None)).await + k(Value::new(None)).await } } } @@ -72,7 +75,7 @@ where engine, Arc::new(move |values| { Box::pin(capture!([constructor, k], async move { - let result = Value(constructor(values)); + let result = Value::new(constructor(values)); k(result).await })) }), @@ -80,49 +83,6 @@ where .await } -/// Evaluates a map expression. -/// -/// # Parameters -/// -/// * `items` - The key-value pairs to evaluate. -/// * `engine` - The evaluation engine. -/// * `k` - The continuation to receive evaluation results. -async fn evaluate_map( - items: Vec<(Arc, Arc)>, - engine: Engine, - k: Continuation>, -) -> EngineResponse -where - O: Send + 'static, -{ - // Extract keys and values. - let (keys, values): (Vec>, Vec>) = items.into_iter().unzip(); - - // First evaluate all key expressions. - evaluate_sequence( - keys, - engine.clone(), - Arc::new(move |keys_values| { - Box::pin(capture!([values, engine, k], async move { - // Then evaluate all value expressions. - evaluate_sequence( - values, - engine, - Arc::new(move |values_values| { - Box::pin(capture!([keys_values, k], async move { - // Create a map from keys and values. 
- let map_items = keys_values.into_iter().zip(values_values).collect(); - k(Value(Map(map_items))).await - })) - }), - ) - .await - })) - }), - ) - .await -} - /// Evaluates a fail expression. /// /// # Parameters @@ -142,10 +102,9 @@ where .evaluate( msg, Arc::new(move |value| { - Box::pin(capture!( - [k], - async move { k(Value(Fail(value.into()))).await } - )) + Box::pin(capture!([k], async move { + k(Value::new(Fail(Box::new(value)))).await + })) }), ) .await @@ -154,7 +113,7 @@ where #[cfg(test)] mod tests { use crate::engine::Continuation; - use crate::engine::test_utils::{ + use crate::utils::tests::{ TestHarness, evaluate_and_collect, evaluate_and_collect_with_custom_k, int, lit_expr, string, }; @@ -181,7 +140,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 42); } @@ -207,18 +166,18 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Array(elements) => { assert_eq!(elements.len(), 3); - match &elements[0].0 { + match &elements[0].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 1), _ => panic!("Expected integer literal"), } - match &elements[1].0 { + match &elements[1].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 2), _ => panic!("Expected integer literal"), } - match &elements[2].0 { + match &elements[2].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 3), _ => panic!("Expected integer literal"), } @@ -245,18 +204,18 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Tuple(elements) => { assert_eq!(elements.len(), 3); - match &elements[0].0 { + match &elements[0].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 42), _ => panic!("Expected integer literal"), } - match &elements[1].0 { + match 
&elements[1].data { CoreData::Literal(Literal::String(value)) => assert_eq!(value, "hello"), _ => panic!("Expected string literal"), } - match &elements[2].0 { + match &elements[2].data { CoreData::Literal(Literal::Bool(value)) => assert!(*value), _ => panic!("Expected boolean literal"), } @@ -282,15 +241,15 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Struct(name, fields) => { assert_eq!(name, "Point"); assert_eq!(fields.len(), 2); - match &fields[0].0 { + match &fields[0].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 10), _ => panic!("Expected integer literal"), } - match &fields[1].0 { + match &fields[1].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 20), _ => panic!("Expected integer literal"), } @@ -299,71 +258,6 @@ mod tests { } } - /// Test evaluation of map expressions - #[tokio::test] - async fn test_map_evaluation() { - let harness = TestHarness::new(); - let ctx = Context::default(); - let engine = Engine::new(ctx); - - // Create a map expression - let map_expr = Arc::new(Expr::new(CoreExpr(CoreData::Map(vec![ - (lit_expr(string("a")), lit_expr(int(1))), - (lit_expr(string("b")), lit_expr(int(2))), - (lit_expr(string("c")), lit_expr(int(3))), - ])))); - - let results = evaluate_and_collect(map_expr, engine, harness).await; - - // Check result - assert_eq!(results.len(), 1); - match &results[0].0 { - CoreData::Map(items) => { - assert_eq!(items.len(), 3); - - // Find key "a" and check value - let a_found = items.iter().any(|(k, v)| { - if let CoreData::Literal(Literal::String(key)) = &k.0 { - if key == "a" { - if let CoreData::Literal(Literal::Int64(val)) = &v.0 { - return *val == 1; - } - } - } - false - }); - assert!(a_found, "Key 'a' with value 1 not found"); - - // Find key "b" and check value - let b_found = items.iter().any(|(k, v)| { - if let CoreData::Literal(Literal::String(key)) = &k.0 { - if key == "b" { - if let 
CoreData::Literal(Literal::Int64(val)) = &v.0 { - return *val == 2; - } - } - } - false - }); - assert!(b_found, "Key 'b' with value 2 not found"); - - // Find key "c" and check value - let c_found = items.iter().any(|(k, v)| { - if let CoreData::Literal(Literal::String(key)) = &k.0 { - if key == "c" { - if let CoreData::Literal(Literal::Int64(val)) = &v.0 { - return *val == 3; - } - } - } - false - }); - assert!(c_found, "Key 'c' with value 3 not found"); - } - _ => panic!("Expected map"), - } - } - /// Test evaluation of function expressions #[tokio::test] async fn test_function_evaluation() { @@ -381,7 +275,7 @@ mod tests { // Check that we got a function value assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Function(_) => { // Successfully evaluated to a function } @@ -403,7 +297,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::None => { // Successfully evaluated to null } @@ -416,12 +310,12 @@ mod tests { async fn test_fail_evaluation() { let return_k: Continuation> = Arc::new(move |value| { Box::pin(async move { - match value { - Value(CoreData::Fail(boxed_value)) => match boxed_value.0 { + match value.data { + CoreData::Fail(boxed_value) => match boxed_value.data { CoreData::Literal(Literal::String(msg)) => Err(msg), _ => panic!("Expected string message in fail"), }, - value => Ok(value), + _ => Ok(value), } }) }); @@ -432,7 +326,9 @@ mod tests { // Create a fail expression with a message let fail_expr = Arc::new(Expr::new(CoreExpr(CoreData::Fail(Box::new(Arc::new( - Expr::new(CoreVal(Value(CoreData::Literal(string("error message"))))), + Expr::new(CoreVal(Value::new(CoreData::Literal(string( + "error message", + ))))), )))))); let results = diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index 4c9536e7..e2d61b6b 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ b/optd-dsl/src/engine/eval/expr.rs @@ -1,12 
+1,15 @@ use super::{binary::eval_binary_op, unary::eval_unary_op}; -use crate::analyzer::hir::{BinOp, CoreData, Expr, FunKind, Identifier, Literal, UnaryOp, Value}; +use crate::analyzer::hir::{ + BinOp, CoreData, Expr, ExprKind, FunKind, Goal, GroupId, Identifier, Literal, LogicalOp, + Materializable, PhysicalOp, UnaryOp, Value, +}; +use crate::analyzer::map::Map; use crate::engine::{Continuation, EngineResponse}; use crate::{ capture, engine::{Engine, utils::evaluate_sequence}, }; -use CoreData::*; -use FunKind::*; +use ExprKind::*; use std::sync::Arc; /// Evaluates an if-then-else expression. @@ -31,15 +34,14 @@ pub(crate) async fn evaluate_if_then_else( where O: Send + 'static, { - // First evaluate the condition engine .clone() .evaluate( cond, Arc::new(move |value| { Box::pin(capture!([then_expr, else_expr, engine, k], async move { - match value.0 { - Literal(Literal::Bool(b)) => { + match value.data { + CoreData::Literal(Literal::Bool(b)) => { if b { engine.evaluate(then_expr, k).await } else { @@ -116,32 +118,6 @@ pub(crate) async fn evaluate_binary_expr( where O: Send + 'static, { - // Helper function to evaluate the right operand after the left is evaluated. - async fn evaluate_right( - left_val: Value, - right: Arc, - op: BinOp, - engine: Engine, - k: Continuation>, - ) -> EngineResponse - where - O: Send + 'static, - { - engine - .evaluate( - right, - Arc::new(move |right_val| { - Box::pin(capture!([left_val, op, k], async move { - // Apply the binary operation and pass result to continuation. - let result = eval_binary_op(left_val, &op, right_val); - k(result).await - })) - }), - ) - .await - } - - // First evaluate the left operand. engine .clone() .evaluate( @@ -155,6 +131,30 @@ where .await } +/// Helper function to evaluate the right operand after the left is evaluated. 
+async fn evaluate_right( + left_val: Value, + right: Arc, + op: BinOp, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + engine + .evaluate( + right, + Arc::new(move |right_val| { + Box::pin(capture!([left_val, op, k], async move { + let result = eval_binary_op(left_val, &op, right_val); + k(result).await + })) + }), + ) + .await +} + /// Evaluates a unary expression. /// /// Evaluates the operand, then applies the unary operation, passing the result to the continuation. @@ -174,13 +174,11 @@ pub(crate) async fn evaluate_unary_expr( where O: Send + 'static, { - // Evaluate the operand, then apply the unary operation engine .evaluate( expr, Arc::new(move |value| { Box::pin(capture!([op, k], async move { - // Apply the unary operation and pass result to continuation let result = eval_unary_op(&op, value); k(result).await })) @@ -189,19 +187,22 @@ where .await } -/// Evaluates a function call expression. +/// Evaluates a call expression. /// -/// First evaluates the function expression, then the arguments, and finally applies the function to +/// First evaluates the called expression, then the arguments, and finally applies the call to /// the arguments, passing results to the continuation. /// +/// Extended to support indexing into collections (Array, Tuple, Struct, Map, Logical, Physical) +/// when the called expression evaluates to one of these types and a single argument is provided. +/// /// # Parameters /// -/// * `fun` - The function expression to evaluate. +/// * `called` - The called expression to evaluate. /// * `args` - The argument expressions to evaluate. /// * `engine` - The evaluation engine. /// * `k` - The continuation to receive evaluation results. 
-pub(crate) async fn evaluate_function_call( - fun: Arc, +pub(crate) async fn evaluate_call( + called: Arc, args: Vec>, engine: Engine, k: Continuation>, @@ -213,20 +214,324 @@ where engine .clone() .evaluate( - fun, - Arc::new(move |fun_value| { + called, + Arc::new(move |called_value| { Box::pin(capture!([args, engine, k], async move { - match fun_value.0 { - // Handle closure (user-defined function). - Function(Closure(params, body)) => { + match called_value.data { + // Handle function calls. + CoreData::Function(FunKind::Closure(params, body)) => { evaluate_closure_call(params, body, args, engine, k).await } - // Handle Rust UDF (built-in function). - Function(RustUDF(udf)) => { + CoreData::Function(FunKind::RustUDF(udf)) => { evaluate_rust_udf_call(udf, args, engine, k).await } - // Value must be a function. - _ => panic!("Expected function value"), + + // Handle collection indexing. + CoreData::Array(_) | CoreData::Tuple(_) | CoreData::Struct(_, _) => { + evaluate_indexed_access(called_value, args, engine, k).await + } + CoreData::Map(_) => { + evaluate_map_lookup(called_value, args, engine, k).await + } + + // Handle operator field accesses. + CoreData::Logical(op) => { + evaluate_logical_operator_access(op, args, engine, k).await + } + CoreData::Physical(op) => { + evaluate_physical_operator_access(op, args, engine, k).await + } + + // Value must be a function or indexable collection/operator. + _ => panic!( + "Expected function or indexable value, got: {:?}", + called_value + ), + } + })) + }), + ) + .await +} + +/// Evaluates access to a logical operator. +/// +/// Handles both materialized and unmaterialized logical operators. +/// +/// # Parameters +/// +/// * `op` - The logical operator (materialized or unmaterialized). +/// * `args` - The argument expressions (should be a single index). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. 
+fn evaluate_logical_operator_access( + op: Materializable, GroupId>, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> impl Future> + Send +where + O: Send + 'static, +{ + Box::pin(async move { + validate_single_index_arg(&args); + + match op { + // For unmaterialized logical operators, yield the group and continue when it's expanded. + Materializable::UnMaterialized(group_id) => EngineResponse::YieldGroup( + group_id, + Arc::new(move |expanded_value| { + Box::pin(capture!([args, engine, k], async move { + evaluate_call(Expr::new(CoreVal(expanded_value)).into(), args, engine, k) + .await + })) + }), + ), + // For materialized logical operators, access the data or children directly. + Materializable::Materialized(log_op) => { + evaluate_index_on_materialized_operator( + args[0].clone(), + log_op.operator.data, + log_op.operator.children, + engine, + k, + ) + .await + } + } + }) +} + +/// Evaluates access to a physical operator. +/// +/// Handles both materialized and unmaterialized physical operators. +/// +/// # Parameters +/// +/// * `op` - The physical operator (materialized or unmaterialized). +/// * `args` - The argument expressions (should be a single index). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +fn evaluate_physical_operator_access( + op: Materializable, Goal>, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> impl Future> + Send +where + O: Send + 'static, +{ + Box::pin(async move { + validate_single_index_arg(&args); + + match op { + // For unmaterialized physical operators, yield the goal and continue when it's expanded. + Materializable::UnMaterialized(goal) => EngineResponse::YieldGoal( + goal, + Arc::new(move |expanded_value| { + Box::pin(capture!([args, engine, k], async move { + evaluate_call(Expr::new(CoreVal(expanded_value)).into(), args, engine, k) + .await + })) + }), + ), + // For materialized physical operators, access the data or children directly. 
+ Materializable::Materialized(phys_op) => { + evaluate_index_on_materialized_operator( + args[0].clone(), + phys_op.operator.data, + phys_op.operator.children, + engine, + k, + ) + .await + } + } + }) +} + +/// Validates that exactly one index argument is provided. +/// +/// # Parameters +/// +/// * `args` - The argument expressions to validate. +fn validate_single_index_arg(args: &[Arc]) { + if args.len() != 1 { + panic!("Operator access requires exactly one index argument"); + } +} + +/// Evaluates an index expression on a materialized operator. +/// +/// Treats the operator's data and children as a concatenated vector and accesses by index. +/// +/// # Parameters +/// +/// * `index_expr` - The index expression to evaluate. +/// * `data` - The operator's data fields. +/// * `children` - The operator's children. +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_index_on_materialized_operator( + index_expr: Arc, + data: Vec, + children: Vec, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + engine + .evaluate( + index_expr, + Arc::new(move |index_value| { + Box::pin(capture!([data, children, k], async move { + let index = extract_index(&index_value); + let result = access_operator_field(index, &data, &children); + + k(result).await + })) + }), + ) + .await +} + +/// Extracts an integer index from a value. +/// +/// # Parameters +/// +/// * `index_value` - The value containing the index. +/// +/// # Returns +/// +/// The extracted integer index. +fn extract_index(index_value: &Value) -> usize { + match &index_value.data { + CoreData::Literal(Literal::Int64(i)) => *i as usize, + _ => panic!("Index must be an integer, got: {:?}", index_value), + } +} + +/// Accesses a field in an operator by index. +/// +/// Treats data and children as a concatenated vector and accesses by index. +/// +/// # Parameters +/// +/// * `index` - The index to access. 
+/// * `data` - The operator's data fields. +/// * `children` - The operator's children. +/// +/// # Returns +/// +/// The value at the specified index. +fn access_operator_field(index: usize, data: &[Value], children: &[Value]) -> Value { + let data_len = data.len(); + let total_len = data_len + children.len(); + + if index >= total_len { + panic!("index out of bounds: {} >= {}", index, total_len); + } + + if index < data_len { + data[index].clone() + } else { + children[index - data_len].clone() + } +} + +/// Evaluates indexing into a collection (Array, Tuple, or Struct). +/// +/// # Parameters +/// +/// * `collection` - The collection value to index into. +/// * `args` - The argument expressions (should be a single index). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_indexed_access( + collection: Value, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + validate_single_index_arg(&args); + + engine + .evaluate( + args[0].clone(), + Arc::new(move |index_value| { + Box::pin(capture!([collection, k], async move { + let index = extract_index(&index_value); + + let result = match &collection.data { + CoreData::Array(items) => get_indexed_item(items, index), + CoreData::Tuple(items) => get_indexed_item(items, index), + CoreData::Struct(_, fields) => get_indexed_item(fields, index), + _ => panic!("Attempted to index a non-indexable value: {:?}", collection), + }; + + k(result).await + })) + }), + ) + .await +} + +/// Gets an item from a collection at the specified index. +/// +/// # Parameters +/// +/// * `items` - The collection items. +/// * `index` - The index to access. +/// +/// # Returns +/// +/// The value at the specified index. 
+fn get_indexed_item(items: &[Value], index: usize) -> Value { + if index < items.len() { + items[index].clone() + } else { + panic!("index out of bounds: {} >= {}", index, items.len()); + } +} + +/// Evaluates a map lookup. +/// +/// # Parameters +/// +/// * `map_value` - The map value to look up in. +/// * `args` - The argument expressions (should be a single key). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_map_lookup( + map_value: Value, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + validate_single_index_arg(&args); + + engine + .evaluate( + args[0].clone(), + Arc::new(move |key_value| { + Box::pin(capture!([map_value, k], async move { + match &map_value.data { + CoreData::Map(map) => { + let result = map.get(&key_value); + k(result).await + } + _ => panic!( + "Attempted to perform map lookup on non-map value: {:?}", + map_value + ), } })) }), @@ -261,7 +566,7 @@ where engine.clone(), Arc::new(move |arg_values| { Box::pin(capture!([params, body, engine, k], async move { - // Create a new context with parameters bound to arguments + // Create a new context with parameters bound to arguments. let mut new_ctx = engine.context.clone(); new_ctx.push_scope(); @@ -269,7 +574,7 @@ where new_ctx.bind(p.clone(), a); }); - // Evaluate the body in the new context + // Evaluate the body in the new context. engine.with_new_context(new_ctx).evaluate(body, k).await })) }), @@ -313,6 +618,49 @@ where .await } +/// Evaluates a map expression. +/// +/// # Parameters +/// +/// * `items` - The key-value pairs to evaluate. +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +pub(crate) async fn evaluate_map( + items: Vec<(Arc, Arc)>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + // Extract keys and values. 
+ let (keys, values): (Vec>, Vec>) = items.into_iter().unzip(); + + // First evaluate all key expressions. + evaluate_sequence( + keys, + engine.clone(), + Arc::new(move |keys_values| { + Box::pin(capture!([values, engine, k], async move { + // Then evaluate all value expressions. + evaluate_sequence( + values, + engine, + Arc::new(move |values_values| { + Box::pin(capture!([keys_values, k], async move { + // Create a map from keys and values. + let map_items = keys_values.into_iter().zip(values_values).collect(); + k(Value::new(CoreData::Map(Map::from_pairs(map_items)))).await + })) + }), + ) + .await + })) + }), + ) + .await +} + /// Evaluates a reference to a variable. /// /// Looks up the variable in the context and passes its value to the continuation. @@ -330,28 +678,30 @@ pub(crate) async fn evaluate_reference( where O: Send + 'static, { - // Look up the variable in the context. let value = engine .context .lookup(&ident) .unwrap_or_else(|| panic!("Variable not found: {}", ident)) .clone(); - // Pass the value to the continuation. 
k(value).await } #[cfg(test)] mod tests { - use crate::analyzer::{ - context::Context, - hir::{BinOp, CoreData, Expr, ExprKind, FunKind, Literal, Value}, + use crate::analyzer::hir::{Goal, GroupId, Materializable}; + use crate::engine::Engine; + use crate::utils::tests::{ + array_val, assert_values_equal, create_logical_operator, create_physical_operator, + ref_expr, struct_val, }; - use crate::engine::{ - Engine, - test_utils::{ - TestHarness, array_val, boolean, evaluate_and_collect, int, lit_expr, lit_val, - ref_expr, string, + use crate::{ + analyzer::{ + context::Context, + hir::{BinOp, CoreData, Expr, ExprKind, FunKind, Literal, Value}, + }, + utils::tests::{ + TestHarness, boolean, evaluate_and_collect, int, lit_expr, lit_val, string, }, }; use ExprKind::*; @@ -408,21 +758,21 @@ mod tests { let complex_results = evaluate_and_collect(complex_condition, engine_with_x, harness).await; // Check results - match &true_results[0].0 { + match &true_results[0].data { CoreData::Literal(Literal::String(value)) => { assert_eq!(value, "yes"); // true condition should select "yes" } _ => panic!("Expected string value"), } - match &false_results[0].0 { + match &false_results[0].data { CoreData::Literal(Literal::String(value)) => { assert_eq!(value, "no"); // false condition should select "no" } _ => panic!("Expected string value"), } - match &complex_results[0].0 { + match &complex_results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 40); // 20 * 2 = 40 (since x > 10) } @@ -451,7 +801,7 @@ mod tests { let results = evaluate_and_collect(let_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 15); // 10 + 5 = 15 } @@ -486,7 +836,7 @@ mod tests { let results = evaluate_and_collect(nested_let_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { 
assert_eq!(*value, 30); // 10 + (10 * 2) = 30 } @@ -501,7 +851,7 @@ mod tests { let mut ctx = Context::default(); // Define a function: fn(x, y) => x + y - let add_function = Value(CoreData::Function(FunKind::Closure( + let add_function = Value::new(CoreData::Function(FunKind::Closure( vec!["x".to_string(), "y".to_string()], Arc::new(Expr::new(Binary(ref_expr("x"), BinOp::Add, ref_expr("y")))), ))); @@ -518,7 +868,7 @@ mod tests { let results = evaluate_and_collect(call_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 30); // 10 + 20 = 30 } @@ -533,16 +883,16 @@ mod tests { let mut ctx = Context::default(); // Define a Rust UDF that calculates the sum of array elements - let sum_function = Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + let sum_function = Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Array(elements) => { let mut sum = 0; for elem in elements { - if let CoreData::Literal(Literal::Int64(value)) = &elem.0 { + if let CoreData::Literal(Literal::Int64(value)) = &elem.data { sum += value; } } - Value(CoreData::Literal(Literal::Int64(sum))) + Value::new(CoreData::Literal(Literal::Int64(sum))) } _ => panic!("Expected array argument"), } @@ -566,7 +916,7 @@ mod tests { let results = evaluate_and_collect(call_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 15); // 1 + 2 + 3 + 4 + 5 = 15 } @@ -574,6 +924,182 @@ mod tests { } } + /// Test to verify the Map implementation works correctly. 
+ #[tokio::test] + async fn test_map_creation() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a map with key-value pairs: { "a": 1, "b": 2, "c": 3 } + let map_expr = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("a")), lit_expr(int(1))), + (lit_expr(string("b")), lit_expr(int(2))), + (lit_expr(string("c")), lit_expr(int(3))), + ]))); + + // Evaluate the map expression + let results = evaluate_and_collect(map_expr, engine.clone(), harness.clone()).await; + + // Check that we got a Map value + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Map(map) => { + // Check that map has the correct key-value pairs + assert_values_equal(&map.get(&lit_val(string("a"))), &lit_val(int(1))); + assert_values_equal(&map.get(&lit_val(string("b"))), &lit_val(int(2))); + assert_values_equal(&map.get(&lit_val(string("c"))), &lit_val(int(3))); + + // Check that non-existent key returns None value + assert_values_equal(&map.get(&lit_val(string("d"))), &Value::new(CoreData::None)); + } + _ => panic!("Expected Map value"), + } + + // Test map with expressions that need evaluation as keys and values + // Map: { "x" + "y": 10 + 5, "a" + "b": 20 * 2 } + let complex_map_expr = Arc::new(Expr::new(Map(vec![ + ( + Arc::new(Expr::new(Binary( + lit_expr(string("x")), + BinOp::Concat, + lit_expr(string("y")), + ))), + Arc::new(Expr::new(Binary( + lit_expr(int(10)), + BinOp::Add, + lit_expr(int(5)), + ))), + ), + ( + Arc::new(Expr::new(Binary( + lit_expr(string("a")), + BinOp::Concat, + lit_expr(string("b")), + ))), + Arc::new(Expr::new(Binary( + lit_expr(int(20)), + BinOp::Mul, + lit_expr(int(2)), + ))), + ), + ]))); + + // Evaluate the complex map expression + let complex_results = evaluate_and_collect(complex_map_expr, engine, harness).await; + + // Check that we got a Map value with correctly evaluated keys and values + assert_eq!(complex_results.len(), 1); + match &complex_results[0].data { + 
CoreData::Map(map) => { + // Check that map has the correct key-value pairs after evaluation + assert_values_equal(&map.get(&lit_val(string("xy"))), &lit_val(int(15))); + assert_values_equal(&map.get(&lit_val(string("ab"))), &lit_val(int(40))); + } + _ => panic!("Expected Map value"), + } + } + + /// Test map operations with nested maps and lookup + #[tokio::test] + async fn test_map_nested_and_lookup() { + let harness = TestHarness::new(); + let mut ctx = Context::default(); + + // Add a map lookup function + ctx.bind( + "get".to_string(), + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + if args.len() != 2 { + panic!("get function requires 2 arguments"); + } + + match &args[0].data { + CoreData::Map(map) => map.get(&args[1]), + _ => panic!("First argument must be a map"), + } + }))), + ); + + let engine = Engine::new(ctx); + + // Create a nested map: + // { + // "user": { + // "name": "Alice", + // "age": 30, + // "address": { + // "city": "San Francisco", + // "zip": 94105 + // } + // }, + // "settings": { + // "theme": "dark", + // "notifications": true + // } + // } + + // First, create the address map + let address_map = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("city")), lit_expr(string("San Francisco"))), + (lit_expr(string("zip")), lit_expr(int(94105))), + ]))); + + // Then, create the user map with the nested address map + let user_map = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("name")), lit_expr(string("Alice"))), + (lit_expr(string("age")), lit_expr(int(30))), + (lit_expr(string("address")), address_map), + ]))); + + // Create the settings map + let settings_map = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("theme")), lit_expr(string("dark"))), + (lit_expr(string("notifications")), lit_expr(boolean(true))), + ]))); + + // Finally, create the top-level map + let nested_map_expr = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("user")), user_map), + (lit_expr(string("settings")), settings_map), + ]))); + + // First, evaluate 
the nested map to bind it to a variable + let program = Arc::new(Expr::new(Let( + "data".to_string(), + nested_map_expr, + // Extract user.address.city using get function + Arc::new(Expr::new(Call( + ref_expr("get"), + vec![ + Arc::new(Expr::new(Call( + ref_expr("get"), + vec![ + Arc::new(Expr::new(Call( + ref_expr("get"), + vec![ref_expr("data"), lit_expr(string("user"))], + ))), + lit_expr(string("address")), + ], + ))), + lit_expr(string("city")), + ], + ))), + ))); + + // Evaluate the program + let results = evaluate_and_collect(program, engine, harness).await; + + // Check that we got the correct value from the nested lookup + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Literal(Literal::String(value)) => { + assert_eq!(value, "San Francisco"); + } + _ => panic!("Expected string value"), + } + } + /// Test complex program with multiple expression types #[tokio::test] async fn test_complex_program() { @@ -581,7 +1107,7 @@ mod tests { let mut ctx = Context::default(); // Define a function to compute factorial: fn(n) => if n <= 1 then 1 else n * factorial(n-1) - let factorial_function = Value(CoreData::Function(FunKind::Closure( + let factorial_function = Value::new(CoreData::Function(FunKind::Closure( vec!["n".to_string()], Arc::new(Expr::new(IfThenElse( Arc::new(Expr::new(Binary( @@ -633,7 +1159,7 @@ mod tests { let results = evaluate_and_collect(program, engine, harness).await; // Check result: factorial(5) / 3 = 120 / 3 = 40 - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 40); } @@ -641,6 +1167,659 @@ mod tests { } } + /// Test array indexing with call syntax + #[tokio::test] + async fn test_array_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create an array [10, 20, 30, 40, 50] + let array_expr = Arc::new(Expr::new(CoreVal(array_val(vec![ + lit_val(int(10)), + lit_val(int(20)), + 
lit_val(int(30)), + lit_val(int(40)), + lit_val(int(50)), + ])))); + + // Access array[2] which should be 30 + let index_expr = Arc::new(Expr::new(Call(array_expr, vec![lit_expr(int(2))]))); + + let results = evaluate_and_collect(index_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(30)); + } + _ => panic!("Expected integer literal"), + } + } + + /// Test tuple indexing with call syntax + #[tokio::test] + async fn test_tuple_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a tuple (10, "hello", true) + let tuple_expr = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Tuple(vec![ + lit_val(int(10)), + lit_val(string("hello")), + lit_val(Literal::Bool(true)), + ]))))); + + // Access tuple[1] which should be "hello" + let index_expr = Arc::new(Expr::new(Call(tuple_expr, vec![lit_expr(int(1))]))); + + let results = evaluate_and_collect(index_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("hello".to_string())); + } + _ => panic!("Expected string literal"), + } + } + + /// Test struct field access with call syntax + #[tokio::test] + async fn test_struct_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a struct Point { x: 10, y: 20 } + let struct_expr = Arc::new(Expr::new(CoreVal(struct_val( + "Point", + vec![lit_val(int(10)), lit_val(int(20))], + )))); + + // Access struct[1] which should be 20 (the y field) + let index_expr = Arc::new(Expr::new(Call(struct_expr, vec![lit_expr(int(1))]))); + + let results = evaluate_and_collect(index_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Literal(lit) => { 
+ assert_eq!(lit, &Literal::Int64(20)); + } + _ => panic!("Expected integer literal"), + } + } + + /// Test map lookup with call syntax + #[tokio::test] + async fn test_map_lookup() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a map with key-value pairs: { "a": 1, "b": 2, "c": 3 } + // Use a let expression to bind the map and do lookups directly + let test_expr = Arc::new(Expr::new(Let( + "map".to_string(), + Arc::new(Expr::new(Map(vec![ + (lit_expr(string("a")), lit_expr(int(1))), + (lit_expr(string("b")), lit_expr(int(2))), + (lit_expr(string("c")), lit_expr(int(3))), + ]))), + // Create a tuple of map["b"] and map["d"] to test both existing and missing keys + Arc::new(Expr::new(CoreExpr(CoreData::Tuple(vec![ + // map["b"] - should be 2 + Arc::new(Expr::new(Call( + ref_expr("map"), + vec![lit_expr(string("b"))], + ))), + // map["d"] - should be None + Arc::new(Expr::new(Call( + ref_expr("map"), + vec![lit_expr(string("d"))], + ))), + ])))), + ))); + + let results = evaluate_and_collect(test_expr, engine, harness).await; + + // Check result - should be a tuple (2, None) + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Tuple(elements) => { + assert_eq!(elements.len(), 2); + + // Check first element: map["b"] should be 2 + match &elements[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(2)); + } + _ => panic!("Expected integer literal for existing key lookup"), + } + + // Check second element: map["d"] should be None + match &elements[1].data { + CoreData::None => {} + _ => panic!("Expected None for missing key lookup"), + } + } + _ => panic!("Expected tuple result"), + } + } + + /// Test complex expressions for both collection and index + #[tokio::test] + async fn test_complex_collection_and_index() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a let expression that binds an 
array and then accesses it + // let arr = [10, 20, 30, 40, 50] in + // let idx = 2 + 1 in + // arr[idx] // should be 40 + let complex_expr = Arc::new(Expr::new(Let( + "arr".to_string(), + Arc::new(Expr::new(CoreExpr(CoreData::Array(vec![ + lit_expr(int(10)), + lit_expr(int(20)), + lit_expr(int(30)), + lit_expr(int(40)), + lit_expr(int(50)), + ])))), + Arc::new(Expr::new(Let( + "idx".to_string(), + Arc::new(Expr::new(Binary( + lit_expr(int(2)), + BinOp::Add, + lit_expr(int(1)), + ))), + Arc::new(Expr::new(Call(ref_expr("arr"), vec![ref_expr("idx")]))), + ))), + ))); + + let results = evaluate_and_collect(complex_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(40)); + } + _ => panic!("Expected integer literal"), + } + } + + /// Test indexing into a logical operator + #[tokio::test] + async fn test_logical_operator_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a logical operator: LogicalJoin { joinType: "inner", condition: "x = y" } [TableScan("orders"), TableScan("lineitem")] + let join_op = create_logical_operator( + "LogicalJoin", + vec![lit_val(string("inner")), lit_val(string("x = y"))], + vec![ + create_logical_operator("TableScan", vec![lit_val(string("orders"))], vec![]), + create_logical_operator("TableScan", vec![lit_val(string("lineitem"))], vec![]), + ], + ); + + let logical_op_expr = Arc::new(Expr::new(CoreVal(join_op))); + + // Access join_type using indexing - should be "inner" + let join_type_expr = Arc::new(Expr::new(Call( + logical_op_expr.clone(), + vec![lit_expr(int(0))], + ))); + let join_type_results = + evaluate_and_collect(join_type_expr, engine.clone(), harness.clone()).await; + + // Access condition using indexing - should be "x = y" + let condition_expr = Arc::new(Expr::new(Call( + logical_op_expr.clone(), + vec![lit_expr(int(1))], + ))); 
+ let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + + // Access first child (orders table scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + logical_op_expr.clone(), + vec![lit_expr(int(2))], + ))); + let first_child_results = + evaluate_and_collect(first_child_expr, engine.clone(), harness.clone()).await; + + // Access second child (lineitem table scan) using indexing + let second_child_expr = Arc::new(Expr::new(Call(logical_op_expr, vec![lit_expr(int(3))]))); + let second_child_results = evaluate_and_collect(second_child_expr, engine, harness).await; + + // Check join_type result + assert_eq!(join_type_results.len(), 1); + match &join_type_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("inner".to_string())); + } + _ => panic!("Expected string literal for join type"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("x = y".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (orders table scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].data { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("orders".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for first child"), + } + + // Check second child result (lineitem table scan) + assert_eq!(second_child_results.len(), 1); + match &second_child_results[0].data { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + 
assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("lineitem".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for second child"), + } + } + + /// Test indexing into a physical operator + #[tokio::test] + async fn test_physical_operator_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a physical operator: HashJoin { method: "hash", condition: "id = id" } [IndexScan("customers"), ParallelScan("orders")] + let join_op = create_physical_operator( + "HashJoin", + vec![lit_val(string("hash")), lit_val(string("id = id"))], + vec![ + create_physical_operator("IndexScan", vec![lit_val(string("customers"))], vec![]), + create_physical_operator("ParallelScan", vec![lit_val(string("orders"))], vec![]), + ], + ); + + let physical_op_expr = Arc::new(Expr::new(CoreVal(join_op))); + + // Access join method using indexing - should be "hash" + let method_expr = Arc::new(Expr::new(Call( + physical_op_expr.clone(), + vec![lit_expr(int(0))], + ))); + let method_results = + evaluate_and_collect(method_expr, engine.clone(), harness.clone()).await; + + // Access condition using indexing - should be "id = id" + let condition_expr = Arc::new(Expr::new(Call( + physical_op_expr.clone(), + vec![lit_expr(int(1))], + ))); + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + + // Access first child (customers index scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + physical_op_expr.clone(), + vec![lit_expr(int(2))], + ))); + let first_child_results = + evaluate_and_collect(first_child_expr, engine.clone(), harness.clone()).await; + + // Access second child (orders parallel scan) using indexing + let second_child_expr = Arc::new(Expr::new(Call(physical_op_expr, 
vec![lit_expr(int(3))]))); + let second_child_results = evaluate_and_collect(second_child_expr, engine, harness).await; + + // Check join method result + assert_eq!(method_results.len(), 1); + match &method_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("hash".to_string())); + } + _ => panic!("Expected string literal for join method"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("id = id".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (customers index scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].data { + CoreData::Physical(Materializable::Materialized(phys_op)) => { + assert_eq!(phys_op.operator.tag, "IndexScan"); + assert_eq!(phys_op.operator.data.len(), 1); + match &phys_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected physical operator for first child"), + } + + // Check second child result (orders parallel scan) + assert_eq!(second_child_results.len(), 1); + match &second_child_results[0].data { + CoreData::Physical(Materializable::Materialized(phys_op)) => { + assert_eq!(phys_op.operator.tag, "ParallelScan"); + assert_eq!(phys_op.operator.data.len(), 1); + match &phys_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("orders".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected physical operator for second child"), + } + } + + /// Test indexing into an unmaterialized logical operator + #[tokio::test] + async fn test_unmaterialized_logical_operator_indexing() { + let harness = TestHarness::new(); + let test_group_id = GroupId(1); + + // 
Register a logical operator in the test harness + let materialized_join = create_logical_operator( + "LogicalJoin", + vec![ + lit_val(string("inner")), + lit_val(string("customer.id = order.id")), + ], + vec![ + create_logical_operator("TableScan", vec![lit_val(string("customers"))], vec![]), + create_logical_operator("TableScan", vec![lit_val(string("orders"))], vec![]), + ], + ); + + harness.register_group(test_group_id, materialized_join); + + // Create an unmaterialized logical operator + let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Logical( + Materializable::UnMaterialized(test_group_id), + ))))); + + // Access join type using indexing - should materialize and return "inner" + let join_type_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(0))], + ))); + + // Access condition using indexing - should materialize and return "customer.id = order.id" + let condition_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(1))], + ))); + + // Access first child (customers table scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(2))], + ))); + + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Evaluate the expressions + let join_type_results = + evaluate_and_collect(join_type_expr, engine.clone(), harness.clone()).await; + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + let first_child_results = evaluate_and_collect(first_child_expr, engine, harness).await; + + // Check join type result + assert_eq!(join_type_results.len(), 1); + match &join_type_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("inner".to_string())); + } + _ => panic!("Expected string literal for join type"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].data { + 
CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customer.id = order.id".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (customers table scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].data { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for first child"), + } + } + + /// Test indexing into an unmaterialized physical operator + #[tokio::test] + async fn test_unmaterialized_physical_operator_indexing() { + let harness = TestHarness::new(); + + // Create a physical goal + let test_group_id = GroupId(2); + let properties = Box::new(Value::new(CoreData::Literal(string("sorted")))); + let test_goal = Goal { + group_id: test_group_id, + properties, + }; + + // Register a physical operator to be returned when the goal is expanded + let materialized_join = create_physical_operator( + "MergeJoin", + vec![ + lit_val(string("merge")), + lit_val(string("customer.id = order.id")), + ], + vec![ + create_physical_operator("SortedScan", vec![lit_val(string("customers"))], vec![]), + create_physical_operator("SortedScan", vec![lit_val(string("orders"))], vec![]), + ], + ); + + harness.register_goal(&test_goal, materialized_join); + + // Create an unmaterialized physical operator + let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Physical( + Materializable::UnMaterialized(test_goal), + ))))); + + // Access join method using indexing - should materialize and return "merge" + let method_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(0))], + ))); + + // Access condition 
using indexing - should materialize and return "customer.id = order.id" + let condition_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(1))], + ))); + + // Access first child (customers scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(2))], + ))); + + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Evaluate the expressions + let method_results = + evaluate_and_collect(method_expr, engine.clone(), harness.clone()).await; + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + let first_child_results = evaluate_and_collect(first_child_expr, engine, harness).await; + + // Check join method result + assert_eq!(method_results.len(), 1); + match &method_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("merge".to_string())); + } + _ => panic!("Expected string literal for join method"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].data { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customer.id = order.id".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (customers scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].data { + CoreData::Physical(Materializable::Materialized(phys_op)) => { + assert_eq!(phys_op.operator.tag, "SortedScan"); + assert_eq!(phys_op.operator.data.len(), 1); + match &phys_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(*lit, Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected physical operator for first child"), + } + } + + /// Test accessing multiple levels of nested operators through indexing + #[tokio::test] + async fn test_nested_operator_indexing() { + let 
harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a nested operator: + // Project [col1, col2] ( + // Filter ("age > 30") ( + // Join ("inner", "t1.id = t2.id") ( + // TableScan ("customers"), + // TableScan ("orders") + // ) + // ) + // ) + let nested_op = create_logical_operator( + "Project", + vec![array_val(vec![ + lit_val(string("col1")), + lit_val(string("col2")), + ])], + vec![create_logical_operator( + "Filter", + vec![lit_val(string("age > 30"))], + vec![create_logical_operator( + "Join", + vec![lit_val(string("inner")), lit_val(string("t1.id = t2.id"))], + vec![ + create_logical_operator( + "TableScan", + vec![lit_val(string("customers"))], + vec![], + ), + create_logical_operator( + "TableScan", + vec![lit_val(string("orders"))], + vec![], + ), + ], + )], + )], + ); + + let nested_op_expr = Arc::new(Expr::new(CoreVal(nested_op))); + + // First, access the Filter child of Project (index 1) + let filter_expr = Arc::new(Expr::new(Call(nested_op_expr, vec![lit_expr(int(1))]))); + let filter_results = + evaluate_and_collect(filter_expr, engine.clone(), harness.clone()).await; + + // Now, from the Filter, access its Join child (index 1) + let filter_value = filter_results[0].clone(); + let join_expr = Arc::new(Expr::new(Call( + Arc::new(Expr::new(CoreVal(filter_value))), + vec![lit_expr(int(1))], + ))); + let join_results = evaluate_and_collect(join_expr, engine.clone(), harness.clone()).await; + + // Finally, from the Join, access its first TableScan child (index 2) + let join_value = join_results[0].clone(); + let table_scan_expr = Arc::new(Expr::new(Call( + Arc::new(Expr::new(CoreVal(join_value))), + vec![lit_expr(int(2))], + ))); + let table_scan_results = evaluate_and_collect(table_scan_expr, engine, harness).await; + + // Verify the final result is the "customers" TableScan + assert_eq!(table_scan_results.len(), 1); + match &table_scan_results[0].data { + 
CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].data { + CoreData::Literal(lit) => { + assert_eq!(*lit, Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for table scan"), + } + } + /// Test variable reference in various contexts #[tokio::test] async fn test_variable_references() { @@ -663,14 +1842,14 @@ mod tests { let outer_results = evaluate_and_collect(outer_ref, engine.clone(), harness).await; // Check results - match &inner_results[0].0 { + match &inner_results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 200); } _ => panic!("Expected integer value"), } - match &outer_results[0].0 { + match &outer_results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 100); } diff --git a/optd-dsl/src/engine/eval/match.rs b/optd-dsl/src/engine/eval/match.rs index 2ad64b8a..884d3aa9 100644 --- a/optd-dsl/src/engine/eval/match.rs +++ b/optd-dsl/src/engine/eval/match.rs @@ -131,7 +131,7 @@ where O: Send + 'static, { Box::pin(async move { - match (pattern, &value.0) { + match (pattern, &value.data) { // Simple patterns. (Wildcard, _) => k((value, Some(ctx))).await, (Literal(pattern_lit), CoreData::Literal(value_lit)) => { @@ -263,7 +263,7 @@ where // Split array into head and tail. let head = arr[0].clone(); let tail_elements = arr[1..].to_vec(); - let tail = Value(CoreData::Array(tail_elements)); + let tail = Value::new(CoreData::Array(tail_elements)); // Create components to match sequentially. let patterns = vec![head_pattern, tail_pattern]; @@ -284,13 +284,13 @@ where let tail_value = results[1].0.clone(); // Extract tail elements. 
- let tail_elements = match &tail_value.0 { + let tail_elements = match &tail_value.data { CoreData::Array(elements) => elements.clone(), _ => panic!("Expected Array in tail result"), }; // Create new array with matched head + tail elements. - let new_array = Value(CoreData::Array( + let new_array = Value::new(CoreData::Array( std::iter::once(head_value).chain(tail_elements).collect(), )); @@ -339,7 +339,7 @@ where // Reconstruct struct with matched field values. let matched_values = results.iter().map(|(v, _)| v.clone()).collect(); - let new_struct = Value(CoreData::Struct(pat_name, matched_values)); + let new_struct = Value::new(CoreData::Struct(pat_name, matched_values)); if all_matched { // Combine contexts by folding over the results, starting with the base context. @@ -415,10 +415,10 @@ where // Create appropriate value type based on original_value. let new_value = if is_logical { let log_op = LogicalOp::logical(new_op); - Value(CoreData::Logical(Materialized(log_op))) + Value::new(CoreData::Logical(Materialized(log_op))) } else { let phys_op = PhysicalOp::physical(new_op); - Value(CoreData::Physical(Materialized(phys_op))) + Value::new(CoreData::Physical(Materialized(phys_op))) }; if all_matched { @@ -520,14 +520,11 @@ where #[cfg(test)] mod tests { - use crate::engine::{ - Engine, - test_utils::{ - array_decomp_pattern, array_val, bind_pattern, create_logical_operator, - create_physical_operator, evaluate_and_collect, int, lit_expr, lit_val, - literal_pattern, match_arm, operator_pattern, pattern_match_expr, ref_expr, string, - struct_pattern, struct_val, wildcard_pattern, - }, + use crate::engine::Engine; + use crate::utils::tests::{ + array_decomp_pattern, array_val, bind_pattern, create_logical_operator, + create_physical_operator, evaluate_and_collect, lit_val, literal_pattern, match_arm, + operator_pattern, ref_expr, struct_pattern, struct_val, }; use crate::{ analyzer::{ @@ -537,7 +534,7 @@ mod tests { Materializable, Operator, Value, }, }, - 
engine::test_utils::TestHarness, + utils::tests::{TestHarness, int, lit_expr, pattern_match_expr, string, wildcard_pattern}, }; use ExprKind::*; use Materializable::*; @@ -564,7 +561,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("matched".to_string())); } @@ -608,7 +605,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(52)); } @@ -660,10 +657,10 @@ mod tests { let mut ctx = Context::default(); ctx.bind( "length".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Array(elements) => { - Value(CoreData::Literal(int(elements.len() as i64))) + Value::new(CoreData::Literal(int(elements.len() as i64))) } _ => panic!("Expected array"), } @@ -677,7 +674,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(5)); } @@ -725,7 +722,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(30)); } @@ -750,7 +747,7 @@ mod tests { ], }; - let logical_op_value = Value(CoreData::Logical(Materialized(LogicalOp::logical(op)))); + let logical_op_value = Value::new(CoreData::Logical(Materialized(LogicalOp::logical(op)))); let logical_op_expr = Arc::new(Expr::new(CoreVal(logical_op_value.clone()))); // Create a match expression: @@ -818,7 +815,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!( lit, @@ -852,7 +849,8 @@ mod tests { harness.register_group(test_group_id, 
materialized_join); // Create an unmaterialized logical operator - let unmaterialized_logical_op = Value(CoreData::Logical(UnMaterialized(test_group_id))); + let unmaterialized_logical_op = + Value::new(CoreData::Logical(UnMaterialized(test_group_id))); let unmaterialized_expr = Arc::new(Expr::new(CoreVal(unmaterialized_logical_op))); @@ -924,7 +922,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!( lit, @@ -944,7 +942,7 @@ mod tests { // Create a physical goal let test_group_id = GroupId(2); - let properties = Box::new(Value(CoreData::Literal(string("sorted")))); + let properties = Box::new(Value::new(CoreData::Literal(string("sorted")))); let test_goal = Goal { group_id: test_group_id, properties, @@ -963,7 +961,7 @@ mod tests { harness.register_goal(&test_goal, materialized_hash_join); // Create an unmaterialized physical operator with the goal - let unmaterialized_physical_op = Value(CoreData::Physical(UnMaterialized(test_goal))); + let unmaterialized_physical_op = Value::new(CoreData::Physical(UnMaterialized(test_goal))); let unmaterialized_expr = Arc::new(Expr::new(CoreVal(unmaterialized_physical_op))); @@ -994,12 +992,12 @@ mod tests { // Create formatted result string with binding values Arc::new(Expr::new(Let( "to_string".to_string(), - Arc::new(Expr::new(CoreVal(Value(CoreData::Function( - FunKind::RustUDF(|args| match &args[0].0 { + Arc::new(Expr::new(CoreVal(Value::new(CoreData::Function( + FunKind::RustUDF(|args| match &args[0].data { CoreData::Literal(lit) => { - Value(CoreData::Literal(string(&format!("{:?}", lit)))) + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) } - _ => Value(CoreData::Literal(string(""))), + _ => Value::new(CoreData::Literal(string(""))), }), ))))), Arc::new(Expr::new(Binary( @@ -1048,7 +1046,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { 
CoreData::Literal(lit) => { match lit { Literal::String(s) => { @@ -1077,10 +1075,10 @@ mod tests { // Add to_string function to convert numbers to strings ctx.bind( "to_string".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Literal(Literal::Int64(i)) => { - Value(CoreData::Literal(string(&i.to_string()))) + Value::new(CoreData::Literal(string(&i.to_string()))) } _ => panic!("Expected integer literal"), } @@ -1193,7 +1191,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("Fallthrough arm: 30, 40".to_string())); } @@ -1210,13 +1208,13 @@ mod tests { // Add to_string function to convert complex values to strings ctx.bind( "to_string".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Literal(lit) => { - Value(CoreData::Literal(string(&format!("{:?}", lit)))) + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) } - CoreData::Array(_) => Value(CoreData::Literal(string(""))), - _ => Value(CoreData::Literal(string(""))), + CoreData::Array(_) => Value::new(CoreData::Literal(string(""))), + _ => Value::new(CoreData::Literal(string(""))), } }))), ); @@ -1351,7 +1349,7 @@ mod tests { // Check result (this is a complex test of correct binding propagation through deeply nested patterns) assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert!(matches!(lit, Literal::String(_))); let result_str = match lit { @@ -1417,8 +1415,8 @@ mod tests { lit_val(string("left.id = right.id")), ], vec![ - Value(CoreData::Logical(UnMaterialized(group_id_1))), - Value(CoreData::Logical(UnMaterialized(group_id_2))), + 
Value::new(CoreData::Logical(UnMaterialized(group_id_1))), + Value::new(CoreData::Logical(UnMaterialized(group_id_2))), ], ); @@ -1488,7 +1486,7 @@ mod tests { ]; for result in &results { - match &result.0 { + match &result.data { CoreData::Literal(Literal::String(s)) => { assert!( expected_combinations @@ -1503,7 +1501,7 @@ mod tests { // Ensure we got each combination exactly once (no duplicates) let mut result_strings = Vec::new(); for result in &results { - if let CoreData::Literal(Literal::String(s)) = &result.0 { + if let CoreData::Literal(Literal::String(s)) = &result.data { result_strings.push(s.clone()); } } diff --git a/optd-dsl/src/engine/eval/operator.rs b/optd-dsl/src/engine/eval/operator.rs index f9168ac4..b6ed1449 100644 --- a/optd-dsl/src/engine/eval/operator.rs +++ b/optd-dsl/src/engine/eval/operator.rs @@ -27,7 +27,7 @@ where // For unmaterialized operators, directly call the continuation with the unmaterialized // value. UnMaterialized(group_id) => { - let result = Value(Logical(UnMaterialized(group_id))); + let result = Value::new(Logical(UnMaterialized(group_id))); k(result).await } // For materialized operators, evaluate all parts and construct the result. @@ -41,7 +41,7 @@ where Box::pin(capture!([k], async move { // Wrap the constructed operator in the logical operator structure let log_op = LogicalOp::logical(constructed_op); - let result = Value(Logical(Materialized(log_op))); + let result = Value::new(Logical(Materialized(log_op))); k(result).await })) }), @@ -69,7 +69,7 @@ where match op { // For unmaterialized operators, continue with the unmaterialized value. UnMaterialized(physical_goal) => { - let result = Value(Physical(UnMaterialized(physical_goal))); + let result = Value::new(Physical(UnMaterialized(physical_goal))); k(result).await } // For materialized operators, evaluate all parts and construct the result. 
@@ -82,7 +82,7 @@ where Arc::new(move |constructed_op| { Box::pin(capture!([k], async move { let phys_op = PhysicalOp::physical(constructed_op); - let result = Value(Physical(Materialized(phys_op))); + let result = Value::new(Physical(Materialized(phys_op))); k(result).await })) }), @@ -145,16 +145,16 @@ where #[cfg(test)] mod tests { - use crate::analyzer::{ - context::Context, - hir::{ - BinOp, CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, Materializable, - Operator, PhysicalOp, Value, + use crate::engine::Engine; + use crate::{ + analyzer::{ + context::Context, + hir::{ + BinOp, CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, Materializable, + Operator, PhysicalOp, Value, + }, }, - }; - use crate::engine::{ - Engine, - test_utils::{ + utils::tests::{ TestHarness, create_logical_operator, evaluate_and_collect, int, lit_expr, lit_val, string, }, @@ -196,20 +196,20 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { // Check tag assert_eq!(log_op.operator.tag, "Join"); // Check data - should have "inner" and 15 assert_eq!(log_op.operator.data.len(), 2); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } _ => panic!("Expected string literal"), } - match &log_op.operator.data[1].0 { + match &log_op.operator.data[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(15)); } @@ -218,13 +218,13 @@ mod tests { // Check children - should have "orders" and "lineitem" assert_eq!(log_op.operator.children.len(), 2); - match &log_op.operator.children[0].0 { + match &log_op.operator.children[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("orders".to_string())); } _ => panic!("Expected string literal"), } - match &log_op.operator.children[1].0 { + match 
&log_op.operator.children[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("lineitem".to_string())); } @@ -255,7 +255,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Logical(Materializable::UnMaterialized(id)) => { // Check that the group ID is preserved assert_eq!(*id, group_id); @@ -298,20 +298,20 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Physical(Materializable::Materialized(phys_op)) => { // Check tag assert_eq!(phys_op.operator.tag, "HashJoin"); // Check data - should have "inner" and 60 assert_eq!(phys_op.operator.data.len(), 2); - match &phys_op.operator.data[0].0 { + match &phys_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } _ => panic!("Expected string literal"), } - match &phys_op.operator.data[1].0 { + match &phys_op.operator.data[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(60)); } @@ -320,13 +320,13 @@ mod tests { // Check children - should have "IndexScan" and "FullScan" assert_eq!(phys_op.operator.children.len(), 2); - match &phys_op.operator.children[0].0 { + match &phys_op.operator.children[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("IndexScan".to_string())); } _ => panic!("Expected string literal"), } - match &phys_op.operator.children[1].0 { + match &phys_op.operator.children[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("FullScan".to_string())); } @@ -347,7 +347,7 @@ mod tests { // Create an unmaterialized physical operator with a goal let goal = Goal { group_id: GroupId(42), - properties: Box::new(Value(CoreData::Literal(Literal::String( + properties: Box::new(Value::new(CoreData::Literal(Literal::String( "sorted".to_string(), )))), }; @@ -362,13 +362,13 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match 
&results[0].0 { + match &results[0].data { CoreData::Physical(Materializable::UnMaterialized(result_goal)) => { // Check that the goal is preserved assert_eq!(result_goal.group_id, goal.group_id); // Check properties - match &result_goal.properties.0 { + match &result_goal.properties.data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("sorted".to_string())); } @@ -415,14 +415,14 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { // Check tag assert_eq!(log_op.operator.tag, "Join"); // Check data assert_eq!(log_op.operator.data.len(), 1); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } @@ -433,11 +433,11 @@ mod tests { assert_eq!(log_op.operator.children.len(), 2); // Check first child - match &log_op.operator.children[0].0 { + match &log_op.operator.children[0].data { CoreData::Logical(Materializable::Materialized(child_log_op)) => { assert_eq!(child_log_op.operator.tag, "Scan"); assert_eq!(child_log_op.operator.data.len(), 1); - match &child_log_op.operator.data[0].0 { + match &child_log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("orders".to_string())); } @@ -448,11 +448,11 @@ mod tests { } // Check second child - match &log_op.operator.children[1].0 { + match &log_op.operator.children[1].data { CoreData::Logical(Materializable::Materialized(child_log_op)) => { assert_eq!(child_log_op.operator.tag, "Scan"); assert_eq!(child_log_op.operator.data.len(), 1); - match &child_log_op.operator.data[0].0 { + match &child_log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("lineitem".to_string())); } diff --git a/optd-dsl/src/engine/eval/unary.rs b/optd-dsl/src/engine/eval/unary.rs index 619231f4..c6456a44 100644 --- 
a/optd-dsl/src/engine/eval/unary.rs +++ b/optd-dsl/src/engine/eval/unary.rs @@ -26,16 +26,13 @@ use UnaryOp::*; /// # Panics /// Panics when the operation is not defined for the given operand type pub(crate) fn eval_unary_op(op: &UnaryOp, expr: Value) -> Value { - match (op, &expr.0) { + match (op, &expr.data) { // Numeric negation for integers - (Neg, Literal(Int64(x))) => Value(Literal(Int64(-x))), - + (Neg, Literal(Int64(x))) => Value::new(Literal(Int64(-x))), // Numeric negation for floating-point numbers - (Neg, Literal(Float64(x))) => Value(Literal(Float64(-x))), - + (Neg, Literal(Float64(x))) => Value::new(Literal(Float64(-x))), // Logical NOT for boolean values - (Not, Literal(Bool(x))) => Value(Literal(Bool(!x))), - + (Not, Literal(Bool(x))) => Value::new(Literal(Bool(!x))), // Any other combination is invalid _ => panic!("Invalid unary operation or type mismatch"), } @@ -47,37 +44,37 @@ mod tests { // Helper function to create integer Value fn int(i: i64) -> Value { - Value(Literal(Int64(i))) + Value::new(Literal(Int64(i))) } // Helper function to create float Value fn float(f: f64) -> Value { - Value(Literal(Float64(f))) + Value::new(Literal(Float64(f))) } // Helper function to create boolean Value fn boolean(b: bool) -> Value { - Value(Literal(Bool(b))) + Value::new(Literal(Bool(b))) } #[test] fn test_integer_negation() { // Negating a positive integer - if let Literal(Int64(result)) = eval_unary_op(&Neg, int(5)).0 { + if let Literal(Int64(result)) = eval_unary_op(&Neg, int(5)).data { assert_eq!(result, -5); } else { panic!("Expected Int64"); } // Negating a negative integer - if let Literal(Int64(result)) = eval_unary_op(&Neg, int(-7)).0 { + if let Literal(Int64(result)) = eval_unary_op(&Neg, int(-7)).data { assert_eq!(result, 7); } else { panic!("Expected Int64"); } // Negating zero - if let Literal(Int64(result)) = eval_unary_op(&Neg, int(0)).0 { + if let Literal(Int64(result)) = eval_unary_op(&Neg, int(0)).data { assert_eq!(result, 0); } else { 
panic!("Expected Int64"); @@ -89,21 +86,21 @@ mod tests { use std::f64::consts::PI; // Negating a positive float - if let Literal(Float64(result)) = eval_unary_op(&Neg, float(PI)).0 { + if let Literal(Float64(result)) = eval_unary_op(&Neg, float(PI)).data { assert_eq!(result, -PI); } else { panic!("Expected Float64"); } // Negating a negative float - if let Literal(Float64(result)) = eval_unary_op(&Neg, float(-2.5)).0 { + if let Literal(Float64(result)) = eval_unary_op(&Neg, float(-2.5)).data { assert_eq!(result, 2.5); } else { panic!("Expected Float64"); } // Negating zero - if let Literal(Float64(result)) = eval_unary_op(&Neg, float(0.0)).0 { + if let Literal(Float64(result)) = eval_unary_op(&Neg, float(0.0)).data { assert_eq!(result, -0.0); // Checking sign bit for -0.0 assert!(result.to_bits() & 0x8000_0000_0000_0000 != 0); @@ -115,14 +112,14 @@ mod tests { #[test] fn test_boolean_not() { // NOT true - if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(true)).0 { + if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(true)).data { assert!(!result); } else { panic!("Expected Bool"); } // NOT false - if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(false)).0 { + if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(false)).data { assert!(result); } else { panic!("Expected Bool"); diff --git a/optd-dsl/src/engine/mod.rs b/optd-dsl/src/engine/mod.rs index 72ac7a0f..7c5d1436 100644 --- a/optd-dsl/src/engine/mod.rs +++ b/optd-dsl/src/engine/mod.rs @@ -3,12 +3,12 @@ use crate::analyzer::{ hir::{Expr, ExprKind, Goal, GroupId, Value}, }; use ExprKind::*; -use eval::core::evaluate_core_expr; use eval::expr::{ - evaluate_binary_expr, evaluate_function_call, evaluate_if_then_else, evaluate_let_binding, + evaluate_binary_expr, evaluate_call, evaluate_if_then_else, evaluate_let_binding, evaluate_reference, evaluate_unary_expr, }; use eval::r#match::evaluate_pattern_match; +use eval::{core::evaluate_core_expr, expr::evaluate_map}; use 
std::sync::Arc; mod eval; @@ -16,9 +16,6 @@ mod eval; mod utils; pub use utils::*; -#[cfg(test)] -mod test_utils; - /// The engine response type, which can be either a return value with a converter callback /// or a yielded group/goal with a continuation for further processing. pub enum EngineResponse { @@ -98,7 +95,8 @@ impl Engine { evaluate_binary_expr(left.clone(), op.clone(), right.clone(), self, k).await } Unary(op, expr) => evaluate_unary_expr(op.clone(), expr.clone(), self, k).await, - Call(fun, args) => evaluate_function_call(fun.clone(), args.clone(), self, k).await, + Call(fun, args) => evaluate_call(fun.clone(), args.clone(), self, k).await, + Map(map) => evaluate_map(map.clone(), self, k).await, Ref(ident) => evaluate_reference(ident.clone(), self, k).await, CoreExpr(expr) => evaluate_core_expr(expr.clone(), self, k).await, CoreVal(val) => k(val.clone()).await, diff --git a/optd-dsl/src/lexer/error.rs b/optd-dsl/src/lexer/error.rs index 215ab035..aef7c80c 100644 --- a/optd-dsl/src/lexer/error.rs +++ b/optd-dsl/src/lexer/error.rs @@ -16,12 +16,12 @@ pub struct LexerError { impl LexerError { /// Creates a new lexer error from source code and a Chumsky error. 
- pub fn new(src_code: String, error: Simple) -> Self { - Self { src_code, error } + pub fn new(src_code: String, error: Simple) -> Box { + Self { src_code, error }.into() } } -impl Diagnose for LexerError { +impl Diagnose for Box { fn report(&self) -> Report { match self.error.reason() { SimpleReason::Custom(msg) => { diff --git a/optd-dsl/src/lexer/lex.rs b/optd-dsl/src/lexer/lex.rs index d577dc67..2b2e4048 100644 --- a/optd-dsl/src/lexer/lex.rs +++ b/optd-dsl/src/lexer/lex.rs @@ -140,6 +140,8 @@ fn lexer() -> impl Parser, Error = Simple> just(".").to(Token::Dot), just(":").to(Token::Colon), just("?").to(Token::Question), + just("$").to(Token::Dollar), + just("#").to(Token::HashTag), )); let comments = just("//") diff --git a/optd-dsl/src/lexer/tokens.rs b/optd-dsl/src/lexer/tokens.rs index 2d416a3c..77803604 100644 --- a/optd-dsl/src/lexer/tokens.rs +++ b/optd-dsl/src/lexer/tokens.rs @@ -64,6 +64,8 @@ pub enum Token { Colon, // : UnderScore, // _ Question, // ? + Dollar, // $ + HashTag, // # } pub const ALL_DELIMITERS: [(Token, Token); 3] = [ @@ -137,6 +139,8 @@ impl std::fmt::Display for Token { Token::Colon => write!(f, ":"), Token::UnderScore => write!(f, "_"), Token::Question => write!(f, "?"), + Token::Dollar => write!(f, "$"), + Token::HashTag => write!(f, "#"), } } } diff --git a/optd-dsl/src/parser/ast.rs b/optd-dsl/src/parser/ast.rs index 87faff1f..e3d76ef1 100644 --- a/optd-dsl/src/parser/ast.rs +++ b/optd-dsl/src/parser/ast.rs @@ -35,12 +35,17 @@ pub enum Type { Tuple(Vec>), /// Map/dictionary with key and value types Map(Spanned, Spanned), - /// Optional type (represents a value that might be absent) - Optional(Spanned), - // User defined types - /// Algebraic Data Type reference by name - Adt(Identifier), + // Type wrappers + /// For types ending with ? 
+ Questioned(Spanned), + /// For types ending with * + Starred(Spanned), + /// For types ending with $ + Dollared(Spanned), + + // The type identified by a name (e.g., ADT, Logical, Physical, generics, etc.) + Identifier(Identifier), } /// Represents a field in a record or a parameter in a function @@ -96,7 +101,7 @@ pub enum Expr { Unary(UnaryOp, Spanned), // Function-related - /// Postfix operations (function call, member access) + /// Postfix operations (function call, member access, member call) Postfix(Spanned, PostfixOp), /// Anonymous function definition Closure(Vec>, Spanned), @@ -216,10 +221,12 @@ pub enum UnaryOp { /// Represents postfix operations (function call, member access, composition) #[derive(Debug, Clone)] pub enum PostfixOp { - /// Function or method call with arguments + /// Function or method call with arguments [?](..) Call(Vec>), - /// Member/field access - Member(Identifier), + /// Struct field access [?]#field + Field(Identifier), + /// Method access [?].method + Method(Identifier), } /// Represents a function definition diff --git a/optd-dsl/src/parser/error.rs b/optd-dsl/src/parser/error.rs index 0c708595..d0de0dd3 100644 --- a/optd-dsl/src/parser/error.rs +++ b/optd-dsl/src/parser/error.rs @@ -19,12 +19,12 @@ pub struct ParserError { impl ParserError { /// Creates a new parser error from source code and a Chumsky error. 
- pub fn new(src_code: String, error: Simple) -> Self { - Self { src_code, error } + pub fn new(src_code: String, error: Simple) -> Box { + Self { src_code, error }.into() } } -impl Diagnose for ParserError { +impl Diagnose for Box { fn report(&self) -> Report { match self.error.reason() { SimpleReason::Unclosed { span, delimiter } => { diff --git a/optd-dsl/src/parser/expr.rs b/optd-dsl/src/parser/expr.rs index 470486ae..28b5937b 100644 --- a/optd-dsl/src/parser/expr.rs +++ b/optd-dsl/src/parser/expr.rs @@ -198,9 +198,12 @@ pub fn expr_parser() -> impl Parser, Error = Simple name }) + .map(PostfixOp::Field), just(Token::Dot) .ignore_then(select! { Token::TermIdent(name) => name }) - .map(PostfixOp::Member), + .map(PostfixOp::Method), )) .map_with_span(|op, span| (op, span)) .repeated(), @@ -399,7 +402,7 @@ mod tests { assert_expr_eq(&a.value, &e.value); } } - (PostfixOp::Member(a_member), PostfixOp::Member(e_member)) => { + (PostfixOp::Field(a_member), PostfixOp::Field(e_member)) => { assert_eq!(a_member, e_member); } _ => panic!( @@ -695,7 +698,7 @@ mod tests { } // Complex map with expressions as keys and values - let (result, errors) = parse_expr("{x + 1: y * 2, \"key\": z.field}"); + let (result, errors) = parse_expr("{x + 1: y * 2, \"key\": z#field}"); assert!( result.is_some(), "Expected successful parse for complex map" @@ -720,13 +723,13 @@ mod tests { panic!("Expected binary multiplication in first value"); } - // Second entry: "key": z.field + // Second entry: "key": z#field assert_expr_eq( &entries[1].0.value, &Expr::Literal(Literal::String("key".to_string())), ); - if let Expr::Postfix(expr, PostfixOp::Member(member)) = &*entries[1].1.value { + if let Expr::Postfix(expr, PostfixOp::Field(member)) = &*entries[1].1.value { assert_expr_eq(&expr.value, &Expr::Ref("z".to_string())); assert_eq!(member, "field"); } else { @@ -824,7 +827,7 @@ mod tests { fn test_nested_expressions() { // Test deeply nested expressions with different expression types let 
(result, errors) = parse_expr( - "if ({[1 + 2, 3]: 5}.size > 0) then { let x = 42 in x } else fail(\"Empty map\")", + "if ({[1 + 2, 3]: 5}#size > 0) then { let x = 42 in x } else fail(\"Empty map\")", ); assert!( @@ -842,7 +845,7 @@ mod tests { // Test condition with member access and binary operation if let Expr::Binary(left, BinOp::Gt, right) = &*condition.value { // Test member access on map expression - if let Expr::Postfix(expr, PostfixOp::Member(member)) = &*left.value { + if let Expr::Postfix(expr, PostfixOp::Field(member)) = &*left.value { assert_eq!(member, "size"); // Check that expression is a Map @@ -881,7 +884,7 @@ mod tests { panic!("Expected Map expression"); } } else { - panic!("Expected member access in condition"); + panic!("Expected field access in condition"); } assert_expr_eq(&right.value, &Expr::Literal(Literal::Int64(0))); @@ -1268,11 +1271,11 @@ mod tests { assert_eq!(args2.len(), 1); if let Expr::Postfix(inner3, PostfixOp::Call(args3)) = &*inner2.value { assert_eq!(args3.len(), 1); - if let Expr::Postfix(obj, PostfixOp::Member(method)) = &*inner3.value { + if let Expr::Postfix(obj, PostfixOp::Method(method)) = &*inner3.value { assert_expr_eq(&obj.value, &Expr::Ref("obj".to_string())); assert_eq!(method, "method"); } else { - panic!("Expected member access at base"); + panic!("Expected method access at base"); } } else { panic!("Expected third call"); @@ -1286,15 +1289,15 @@ mod tests { } // Test function call followed by field access - let (result, errors) = parse_expr("func().field"); + let (result, errors) = parse_expr("func()#field"); assert!( result.is_some(), - "Expected successful parse for func().field" + "Expected successful parse for func()#field" ); - assert!(errors.is_empty(), "Expected no errors for func().field"); + assert!(errors.is_empty(), "Expected no errors for func()#field"); if let Some(expr) = result { - if let Expr::Postfix(inner, PostfixOp::Member(member)) = &*expr.value { + if let Expr::Postfix(inner, 
PostfixOp::Field(member)) = &*expr.value { assert_eq!(member, "field"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner.value { assert_eq!(args.len(), 0); @@ -1303,12 +1306,12 @@ mod tests { panic!("Expected call in function call"); } } else { - panic!("Expected member access"); + panic!("Expected field access"); } } - // Test complex chain of operations - let (result, errors) = parse_expr("obj.method().field.other_method(arg)"); + // Test complex chain of operations with mix of field and method + let (result, errors) = parse_expr("obj.method()#field.other_method(arg)"); assert!( result.is_some(), "Expected successful parse for complex chain" @@ -1320,47 +1323,47 @@ mod tests { assert_eq!(args.len(), 1); assert_expr_eq(&args[0].value, &Expr::Ref("arg".to_string())); - if let Expr::Postfix(inner2, PostfixOp::Member(member)) = &*inner1.value { + if let Expr::Postfix(inner2, PostfixOp::Method(member)) = &*inner1.value { assert_eq!(member, "other_method"); - if let Expr::Postfix(inner3, PostfixOp::Member(field)) = &*inner2.value { + if let Expr::Postfix(inner3, PostfixOp::Field(field)) = &*inner2.value { assert_eq!(field, "field"); if let Expr::Postfix(inner4, PostfixOp::Call(method_args)) = &*inner3.value { assert_eq!(method_args.len(), 0); - if let Expr::Postfix(obj, PostfixOp::Member(method)) = &*inner4.value { + if let Expr::Postfix(obj, PostfixOp::Method(method)) = &*inner4.value { assert_expr_eq(&obj.value, &Expr::Ref("obj".to_string())); assert_eq!(method, "method"); } else { - panic!("Expected member access for initial method"); + panic!("Expected method access for initial method"); } } else { panic!("Expected call for first method"); } } else { - panic!("Expected member access for field"); + panic!("Expected field access"); } } else { - panic!("Expected member access for final method"); + panic!("Expected method access for final method"); } } else { panic!("Expected call at top level"); } } - // Test compose operator + // Test method reference 
without call let (result, errors) = parse_expr("map(dat).filter"); assert!( result.is_some(), - "Expected successful parse for compose operator" + "Expected successful parse for method reference" ); - assert!(errors.is_empty(), "Expected no errors for compose operator"); + assert!(errors.is_empty(), "Expected no errors for method reference"); if let Some(expr) = result { - if let Expr::Postfix(inner, PostfixOp::Member(name)) = &*expr.value { + if let Expr::Postfix(inner, PostfixOp::Method(name)) = &*expr.value { assert_eq!(name, "filter"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner.value { @@ -1368,32 +1371,32 @@ mod tests { assert_eq!(args.len(), 1); assert_expr_eq(&args[0].value, &Expr::Ref("dat".to_string())); } else { - panic!("Expected function call before compose operator"); + panic!("Expected function call before method reference"); } } else { - panic!("Expected compose operator"); + panic!("Expected method reference"); } } - // Test chained compose operators + // Test chained method references let (result, errors) = parse_expr("transform(input).map.filter.reduce"); assert!( result.is_some(), - "Expected successful parse for chained compose operators" + "Expected successful parse for chained method references" ); assert!( errors.is_empty(), - "Expected no errors for chained compose operators" + "Expected no errors for chained method references" ); if let Some(expr) = result { - if let Expr::Postfix(inner1, PostfixOp::Member(name1)) = &*expr.value { + if let Expr::Postfix(inner1, PostfixOp::Method(name1)) = &*expr.value { assert_eq!(name1, "reduce"); - if let Expr::Postfix(inner2, PostfixOp::Member(name2)) = &*inner1.value { + if let Expr::Postfix(inner2, PostfixOp::Method(name2)) = &*inner1.value { assert_eq!(name2, "filter"); - if let Expr::Postfix(inner3, PostfixOp::Member(name3)) = &*inner2.value { + if let Expr::Postfix(inner3, PostfixOp::Method(name3)) = &*inner2.value { assert_eq!(name3, "map"); if let Expr::Postfix(func, 
PostfixOp::Call(args)) = &*inner3.value { @@ -1404,13 +1407,57 @@ mod tests { panic!("Expected function call at the beginning of chain"); } } else { - panic!("Expected first compose operation in chain"); + panic!("Expected first method reference in chain"); + } + } else { + panic!("Expected second method reference in chain"); + } + } else { + panic!("Expected third method reference in chain"); + } + } + + // Test mixed field access and method references + let (result, errors) = parse_expr("person#name.split(\"#\")#length"); + assert!( + result.is_some(), + "Expected successful parse for mixed field and method chain" + ); + assert!( + errors.is_empty(), + "Expected no errors for mixed field and method chain" + ); + + if let Some(expr) = result { + if let Expr::Postfix(inner1, PostfixOp::Field(field_name)) = &*expr.value { + assert_eq!(field_name, "length"); + + if let Expr::Postfix(inner2, PostfixOp::Call(args)) = &*inner1.value { + assert_eq!(args.len(), 1); + if let Expr::Literal(Literal::String(s)) = &*args[0].value { + assert_eq!(s, "#"); + } else { + panic!("Expected string literal argument"); + } + + if let Expr::Postfix(inner3, PostfixOp::Method(method_name)) = &*inner2.value { + assert_eq!(method_name, "split"); + + if let Expr::Postfix(inner4, PostfixOp::Field(field_name)) = &*inner3.value + { + assert_eq!(field_name, "name"); + assert_expr_eq(&inner4.value, &Expr::Ref("person".to_string())); + } else { + panic!("Expected field access at beginning of chain"); + } + } else { + panic!("Expected method reference in middle of chain"); } } else { - panic!("Expected second compose operation in chain"); + panic!("Expected method call in chain"); } } else { - panic!("Expected third compose operation in chain"); + panic!("Expected field access at end of chain"); } } } diff --git a/optd-dsl/src/parser/function.rs b/optd-dsl/src/parser/function.rs index aadc5d9c..5989904b 100644 --- a/optd-dsl/src/parser/function.rs +++ b/optd-dsl/src/parser/function.rs @@ -271,10 
+271,12 @@ mod tests { let params = func.value.params.as_ref().unwrap(); assert_eq!(params.len(), 1); assert_eq!(*params[0].value.name.value, "x"); - assert!(matches!(*params[0].clone().value.ty.value, Type::Adt(name) if name == "T")); + assert!( + matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "T") + ); // Check return type - assert!(matches!(*func.value.return_type.value, Type::Adt(name) if name == "T")); + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "T")); // Check body assert!(func.value.body.is_some()); @@ -305,19 +307,21 @@ mod tests { // Check first parameter (map: {K: V}) assert_eq!(*params[0].value.name.value, "map"); if let Type::Map(key_ty, val_ty) = &*params[0].value.ty.value { - assert!(matches!(*key_ty.clone().value, Type::Adt(name) if name == "K")); - assert!(matches!(*val_ty.clone().value, Type::Adt(name) if name == "V")); + assert!(matches!(*key_ty.clone().value, Type::Identifier(name) if name == "K")); + assert!(matches!(*val_ty.clone().value, Type::Identifier(name) if name == "V")); } else { panic!("Expected Map type for first parameter"); } // Check second parameter (key: K) assert_eq!(*params[1].value.name.value, "key"); - assert!(matches!(*params[1].clone().value.ty.value, Type::Adt(name) if name == "K")); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "K") + ); // Check return type (V?) 
- if let Type::Optional(inner) = &*func.value.return_type.value { - assert!(matches!(*inner.clone().value, Type::Adt(name) if name == "V")); + if let Type::Questioned(inner) = &*func.value.return_type.value { + assert!(matches!(*inner.clone().value, Type::Identifier(name) if name == "V")); } else { panic!("Expected Optional return type"); } @@ -346,12 +350,16 @@ mod tests { let params = func.value.params.as_ref().unwrap(); assert_eq!(params.len(), 2); assert_eq!(*params[0].value.name.value, "a"); - assert!(matches!(*params[0].clone().value.ty.value, Type::Adt(name) if name == "A")); + assert!( + matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "A") + ); assert_eq!(*params[1].value.name.value, "b"); - assert!(matches!(*params[1].clone().value.ty.value, Type::Adt(name) if name == "B")); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "B") + ); // Check return type - assert!(matches!(*func.value.return_type.value, Type::Adt(name) if name == "C")); + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "C")); // Check body is None assert!(func.value.body.is_none()); @@ -408,7 +416,9 @@ mod tests { assert!(func.value.receiver.is_some()); if let Some(receiver) = &func.value.receiver { assert_eq!(*receiver.value.name.value, "self"); - assert!(matches!(&*receiver.value.ty.value, Type::Adt(name) if name == "Person")); + assert!( + matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person") + ); } // Check params. @@ -438,7 +448,9 @@ mod tests { assert!(func.value.receiver.is_some()); if let Some(receiver) = &func.value.receiver { assert_eq!(*receiver.value.name.value, "self"); - assert!(matches!(&*receiver.value.ty.value, Type::Adt(name) if name == "Person")); + assert!( + matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person") + ); } // Check params. 
diff --git a/optd-dsl/src/parser/module.rs b/optd-dsl/src/parser/module.rs index 6b076df9..9abcc0e8 100644 --- a/optd-dsl/src/parser/module.rs +++ b/optd-dsl/src/parser/module.rs @@ -1,12 +1,11 @@ -use crate::lexer::tokens::Token; -use crate::utils::error::CompileError; -use crate::utils::span::Span; -use chumsky::{Stream, prelude::*}; - use super::adt::adt_parser; use super::ast::{Item, Module}; use super::error::ParserError; use super::function::function_parser; +use crate::lexer::tokens::Token; +use crate::utils::error::CompileError; +use crate::utils::span::Span; +use chumsky::{Stream, prelude::*}; /// Parses a vector of tokens into a module AST. /// Uses Chumsky for parsing and Ariadne for error reporting. diff --git a/optd-dsl/src/parser/type.rs b/optd-dsl/src/parser/type.rs index cbdd282f..681334e9 100644 --- a/optd-dsl/src/parser/type.rs +++ b/optd-dsl/src/parser/type.rs @@ -1,3 +1,8 @@ +use super::{ast::Type, utils::delimited_parser}; +use crate::{ + lexer::tokens::Token, + utils::span::{Span, Spanned}, +}; use chumsky::{ Parser, error::Simple, @@ -5,13 +10,6 @@ use chumsky::{ select, }; -use crate::{ - lexer::tokens::Token, - utils::span::{Span, Spanned}, -}; - -use super::{ast::Type, utils::delimited_parser}; - /// Creates a parser for type expressions. /// /// This parser supports: @@ -22,18 +20,8 @@ use super::{ast::Type, utils::delimited_parser}; /// - Function types: T1 -> T2 /// - User-defined types: TypeName /// - Optional types: T? -/// -/// Syntax examples: -/// - I64 - Integer type -/// - String - String type -/// - \[I64\] - Array of integers -/// - {String : I64} - Map from strings to integers -/// - (I64, String) - Tuple with integer and string -/// - I64 -> String - Function from integer to string -/// - I64? - Optional integer -/// - \[String\]? - Optional array of strings -/// - I64 -> String? - Function returning optional string -/// - (I64 -> String)? 
- Optional function +/// - Starred types: T* +/// - Dollared types: T$ /// /// The parser follows standard precedence rules and supports /// arbitrary nesting of type expressions. @@ -89,8 +77,8 @@ pub fn type_parser() -> impl Parser, Error = Simple Type::Adt(name) }.map_with_span(Spanned::new); + let data_type = select! { Token::TypeIdent(name) => Type::Identifier(name) } + .map_with_span(Spanned::new); // Note: cannot apply delimiter recovery, as its recovery // would block further successful parses (e.g. tuples). @@ -123,18 +111,23 @@ pub fn type_parser() -> impl Parser, Error = Simple param_type, }); - // Process optional types + // Process optional, starred, and dollared types function_type .then( - just(Token::Question) + choice((just(Token::Question), just(Token::Mul), just(Token::Dollar))) .repeated() .at_least(0) .collect::>(), ) - .map_with_span(|(base_type, question_marks), span| { + .map_with_span(|(base_type, modifiers), span| { let mut result = base_type; - for _ in question_marks { - result = Spanned::new(Type::Optional(result), span.clone()); + for modifier in modifiers { + result = match modifier { + Token::Question => Spanned::new(Type::Questioned(result), span.clone()), + Token::Mul => Spanned::new(Type::Starred(result), span.clone()), + Token::Dollar => Spanned::new(Type::Dollared(result), span.clone()), + _ => unreachable!("Invalid type modifier"), + }; } result }) @@ -258,33 +251,33 @@ mod tests { // Basic optional types let result = parse_type("I64?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Int64) + Type::Questioned(inner) if matches!(*inner.value, Type::Int64) )); let result = parse_type("String?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::String) + Type::Questioned(inner) if matches!(*inner.value, Type::String) )); // Nested optional types let result = parse_type("I64??").unwrap(); assert!(matches!(*result.value, - 
Type::Optional(inner) if matches!(*inner.clone().value, - Type::Optional(inner_inner) if matches!(*inner_inner.value, Type::Int64) + Type::Questioned(inner) if matches!(*inner.clone().value, + Type::Questioned(inner_inner) if matches!(*inner_inner.value, Type::Int64) ) )); // Complex types with optional let result = parse_type("[I64]?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.clone().value, + Type::Questioned(inner) if matches!(*inner.clone().value, Type::Array(arr_inner) if matches!(*arr_inner.value, Type::Int64) ) )); let result = parse_type("(I64, String)?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Tuple(_)) + Type::Questioned(inner) if matches!(*inner.value, Type::Tuple(_)) )); // Function return type is optional @@ -292,21 +285,156 @@ mod tests { assert!(matches!(*result.value, Type::Closure(param, ret) if matches!(*param.value, Type::Int64) - && matches!(*ret.value, Type::Optional(_)) + && matches!(*ret.value, Type::Questioned(_)) && matches!(*ret.value.clone(), - Type::Optional(inner) if matches!(*inner.value, Type::String)) + Type::Questioned(inner) if matches!(*inner.value, Type::String)) )); // Entire function type is optional let result = parse_type("(I64 -> String)?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Closure(_, _)) + Type::Questioned(inner) if matches!(*inner.value, Type::Closure(_, _)) )); // Optional map type let result = parse_type("{String : I64}?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Map(_, _)) + Type::Questioned(inner) if matches!(*inner.value, Type::Map(_, _)) + )); + } + + #[test] + fn test_starred_types() { + // Basic starred types + let result = parse_type("I64*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::Int64) + )); + + let result = 
parse_type("String*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::String) + )); + + // Nested starred types + let result = parse_type("I64**").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.clone().value, + Type::Starred(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + // Complex types with starred + let result = parse_type("[I64]*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.clone().value, + Type::Array(arr_inner) if matches!(*arr_inner.value, Type::Int64) + ) + )); + + let result = parse_type("(I64, String)*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::Tuple(_)) + )); + + // Function return type is starred + let result = parse_type("I64 -> String*").unwrap(); + assert!(matches!(*result.value, + Type::Closure(param, ret) + if matches!(*param.value, Type::Int64) + && matches!(*ret.value, Type::Starred(_)) + && matches!(*ret.value.clone(), + Type::Starred(inner) if matches!(*inner.value, Type::String)) + )); + + // Entire function type is starred + let result = parse_type("(I64 -> String)*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::Closure(_, _)) + )); + } + + #[test] + fn test_dollared_types() { + // Basic dollared types + let result = parse_type("I64$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::Int64) + )); + + let result = parse_type("String$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::String) + )); + + // Nested dollared types + let result = parse_type("I64$$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.clone().value, + Type::Dollared(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + // Complex types with dollared 
+ let result = parse_type("[I64]$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.clone().value, + Type::Array(arr_inner) if matches!(*arr_inner.value, Type::Int64) + ) + )); + + let result = parse_type("(I64, String)$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::Tuple(_)) + )); + + // Function return type is dollared + let result = parse_type("I64 -> String$").unwrap(); + assert!(matches!(*result.value, + Type::Closure(param, ret) + if matches!(*param.value, Type::Int64) + && matches!(*ret.value, Type::Dollared(_)) + && matches!(*ret.value.clone(), + Type::Dollared(inner) if matches!(*inner.value, Type::String)) + )); + + // Entire function type is dollared + let result = parse_type("(I64 -> String)$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::Closure(_, _)) + )); + } + + #[test] + fn test_mixed_modifiers() { + // Test combination of modifiers + let result = parse_type("I64?*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.clone().value, + Type::Questioned(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + let result = parse_type("I64*$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.clone().value, + Type::Starred(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + let result = parse_type("I64$?").unwrap(); + assert!(matches!(*result.value, + Type::Questioned(inner) if matches!(*inner.clone().value, + Type::Dollared(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + // Complex mixed type + let result = parse_type("(I64 -> String?)*$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(outer) if matches!(*outer.clone().value, + Type::Starred(inner) if matches!(*inner.value, Type::Closure(_, _)) + ) )); } @@ -334,7 +462,7 @@ mod tests { 
assert!(matches!(*map_array.value, Type::Map(_, _))); if let Type::Map(map_key, map_val) = &*map_array.value { assert!(matches!(*map_key.value, Type::String)); - assert!(matches!(*map_val.value, Type::Adt(_))); + assert!(matches!(*map_val.value, Type::Identifier(_))); } } } @@ -344,14 +472,14 @@ mod tests { assert!(matches!(*ret_tuple.value, Type::Tuple(_))); if let Type::Tuple(elements) = &*ret_tuple.value { assert_eq!(elements.len(), 3); - assert!(matches!(*elements[0].value, Type::Adt(_))); - assert!(matches!(*elements[1].value, Type::Adt(_))); + assert!(matches!(*elements[0].value, Type::Identifier(_))); + assert!(matches!(*elements[1].value, Type::Identifier(_))); assert!(matches!(*elements[2].value, Type::Closure(_, _))); if let Type::Closure(bool_param, scalar_arr) = &*elements[2].value { assert!(matches!(*bool_param.value, Type::Bool)); assert!(matches!(*scalar_arr.value, Type::Array(_))); if let Type::Array(scalar) = &*scalar_arr.value { - assert!(matches!(*scalar.value, Type::Adt(_))); + assert!(matches!(*scalar.value, Type::Identifier(_))); } } } @@ -360,9 +488,8 @@ mod tests { } } - // Test an even more insane nested type with optionals - let optional_insane = - "{String? : {I64 : [(Logical -> {String : [((Bool?, [Scalar]?) -> Physical?)]?}?)]}}?"; - assert!(parse_type(optional_insane).is_ok()); + // Test an even more insane nested type with optionals, starred, and dollared + let complex_type = "{String*? 
: {I64$ : [(Logical -> {String : [((Bool?, [Scalar]$) -> Physical?*)]?}$)]}}*$?"; + assert!(parse_type(complex_type).is_ok()); } } diff --git a/optd-dsl/src/utils/error.rs b/optd-dsl/src/utils/error.rs index f9d2dce7..e1cf7fb0 100644 --- a/optd-dsl/src/utils/error.rs +++ b/optd-dsl/src/utils/error.rs @@ -44,14 +44,14 @@ pub trait Diagnose { #[derive(Debug)] pub enum CompileError { /// Errors occurring during the lexing/tokenization phase - LexerError(LexerError), + LexerError(Box), /// Errors occurring during the parsing/syntax analysis phase - ParserError(ParserError), + ParserError(Box), /// Errors occurring during the semantic analysis phase - SemanticError(SemanticError), + SemanticError(Box), /// Errors occurring during the type analysis phase - TypeError(TypeError), + TypeError(Box), } diff --git a/optd-dsl/src/utils/mod.rs b/optd-dsl/src/utils/mod.rs index 3ab35050..e93ebb83 100644 --- a/optd-dsl/src/utils/mod.rs +++ b/optd-dsl/src/utils/mod.rs @@ -1,2 +1,4 @@ pub mod error; pub mod span; +#[cfg(test)] +pub mod tests; diff --git a/optd-dsl/src/engine/test_utils.rs b/optd-dsl/src/utils/tests.rs similarity index 81% rename from optd-dsl/src/engine/test_utils.rs rename to optd-dsl/src/utils/tests.rs index 8ee3d7ad..c173bcc9 100644 --- a/optd-dsl/src/engine/test_utils.rs +++ b/optd-dsl/src/utils/tests.rs @@ -1,9 +1,8 @@ -use super::{Continuation, EngineResponse}; use crate::analyzer::hir::{ CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, MatchArm, Materializable, Operator, Pattern, PhysicalOp, Value, }; -use crate::engine::Engine; +use crate::engine::{Continuation, Engine, EngineResponse}; use Materializable::*; use std::collections::{HashMap, VecDeque}; use std::sync::{Arc, Mutex}; @@ -82,6 +81,35 @@ impl TestHarness { } } +// Helper to compare Values +pub fn assert_values_equal(v1: &Value, v2: &Value) { + match (&v1.data, &v2.data) { + (CoreData::Literal(l1), CoreData::Literal(l2)) => match (l1, l2) { + (Literal::Int64(i1), Literal::Int64(i2)) => 
assert_eq!(i1, i2), + (Literal::Float64(f1), Literal::Float64(f2)) => assert_eq!(f1, f2), + (Literal::String(s1), Literal::String(s2)) => assert_eq!(s1, s2), + (Literal::Bool(b1), Literal::Bool(b2)) => assert_eq!(b1, b2), + (Literal::Unit, Literal::Unit) => {} + _ => panic!("Literals don't match: {:?} vs {:?}", l1, l2), + }, + (CoreData::None, CoreData::None) => {} + (CoreData::Tuple(t1), CoreData::Tuple(t2)) => { + assert_eq!(t1.len(), t2.len()); + for (v1, v2) in t1.iter().zip(t2.iter()) { + assert_values_equal(v1, v2); + } + } + (CoreData::Struct(n1, f1), CoreData::Struct(n2, f2)) => { + assert_eq!(n1, n2); + assert_eq!(f1.len(), f2.len()); + for (v1, v2) in f1.iter().zip(f2.iter()) { + assert_values_equal(v1, v2); + } + } + _ => panic!("Values don't match: {:?} vs {:?}", v1.data, v2.data), + } +} + /// Helper to create a literal expression. pub fn lit_expr(literal: Literal) -> Arc { Arc::new(Expr::new(ExprKind::CoreExpr(CoreData::Literal(literal)))) @@ -89,7 +117,7 @@ pub fn lit_expr(literal: Literal) -> Arc { /// Helper to create a literal value. pub fn lit_val(literal: Literal) -> Value { - Value(CoreData::Literal(literal)) + Value::new(CoreData::Literal(literal)) } /// Helper to create an integer literal. @@ -119,12 +147,12 @@ pub fn match_arm(pattern: Pattern, expr: Arc) -> MatchArm { /// Helper to create an array value. pub fn array_val(items: Vec) -> Value { - Value(CoreData::Array(items)) + Value::new(CoreData::Array(items)) } /// Helper to create a struct value. pub fn struct_val(name: &str, fields: Vec) -> Value { - Value(CoreData::Struct(name.to_string(), fields)) + Value::new(CoreData::Struct(name.to_string(), fields)) } /// Helper to create a pattern matching expression. 
@@ -174,7 +202,7 @@ pub fn create_logical_operator(tag: &str, data: Vec, children: Vec children, }; - Value(CoreData::Logical(Materialized(LogicalOp::logical(op)))) + Value::new(CoreData::Logical(Materialized(LogicalOp::logical(op)))) } /// Helper to create a simple physical operator value. @@ -185,7 +213,7 @@ pub fn create_physical_operator(tag: &str, data: Vec, children: Vec