From 4f3e534a30523d7a42a459388f2e8d9be3cad3dc Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 17:24:31 -0400 Subject: [PATCH 01/11] Add $ and *, and # --- .../src/analyzer/semantic_checker/error.rs | 24 +- optd-dsl/src/analyzer/type.rs | 365 +++++++++--------- optd-dsl/src/cli/basic.op | 6 +- optd-dsl/src/lexer/lex.rs | 2 + optd-dsl/src/lexer/tokens.rs | 4 + optd-dsl/src/parser/ast.rs | 25 +- optd-dsl/src/parser/expr.rs | 123 ++++-- optd-dsl/src/parser/function.rs | 24 +- optd-dsl/src/parser/type.rs | 215 ++++++++--- 9 files changed, 494 insertions(+), 294 deletions(-) diff --git a/optd-dsl/src/analyzer/semantic_checker/error.rs b/optd-dsl/src/analyzer/semantic_checker/error.rs index 848e7c0e..143edc31 100644 --- a/optd-dsl/src/analyzer/semantic_checker/error.rs +++ b/optd-dsl/src/analyzer/semantic_checker/error.rs @@ -2,8 +2,30 @@ use ariadne::{Report, Source}; use crate::utils::{error::Diagnose, span::Span}; +/// Error types for semantic analysis #[derive(Debug)] -pub struct SemanticError {} +pub enum SemanticError { + /// Error for duplicate ADT names + DuplicateAdt { + /// Name of the duplicate ADT + name: String, + /// Span of the first declaration + first_span: Span, + /// Span of the duplicate declaration + duplicate_span: Span, + }, +} + +impl SemanticError { + /// Creates a new error for duplicate ADT names + pub fn new_duplicate_adt(name: String, first_span: Span, duplicate_span: Span) -> Self { + Self::DuplicateAdt { + name, + first_span, + duplicate_span, + } + } +} impl Diagnose for SemanticError { fn report(&self) -> Report { diff --git a/optd-dsl/src/analyzer/type.rs b/optd-dsl/src/analyzer/type.rs index cc396257..88d490af 100644 --- a/optd-dsl/src/analyzer/type.rs +++ b/optd-dsl/src/analyzer/type.rs @@ -1,8 +1,9 @@ +use super::semantic_checker::error::SemanticError; use crate::parser::ast::Adt; -use std::{ - collections::{HashMap, HashSet}, - ops::{Deref, DerefMut}, -}; +use crate::utils::error::CompileError; +use 
crate::utils::span::Span; +use Adt::*; +use std::collections::{HashMap, HashSet}; pub type Identifier = String; @@ -21,88 +22,22 @@ pub enum Type { // Special types Unit, Universe, + Unknown, - // Complex types + // User types + Adt(Identifier), + Generic(Identifier), + + // Composite types Array(Box), Closure(Box, Box), Tuple(Vec), Map(Box, Box), Optional(Box), - // User-defined types - Adt(Identifier), -} - -/// A typed value that carries both a value and its type information. -/// -/// This generic wrapper allows attaching type information to any value in the compiler pipeline. -/// It's particularly useful for tracking types through expressions, statements, and other AST nodes -/// during the type checking phase. -/// -/// # Type Parameters -/// -/// * `T` - The type of the wrapped value -#[derive(Clone, PartialEq, Eq, Debug)] -pub struct Typed { - /// The wrapped value - pub value: Box, - - /// The type of the value - pub ty: Type, -} - -impl Deref for Typed { - /// The wrapped type T - type Target = T; - - /// Returns a reference to the wrapped value - fn deref(&self) -> &Self::Target { - &self.value - } -} - -/// Implements mutable dereferencing for Typed -impl DerefMut for Typed { - /// Returns a mutable reference to the wrapped value - fn deref_mut(&mut self) -> &mut Self::Target { - &mut self.value - } -} - -impl Typed { - /// Creates a new typed value with the given value and type. - /// - /// # Arguments - /// - /// * `value` - The value to wrap - /// * `ty` - The type to associate with the value - /// - /// # Returns - /// - /// A new `Typed` instance containing the value and its type - pub fn new(value: T, ty: Type) -> Self { - Self { - value: Box::new(value), - ty, - } - } - - /// Checks if this typed value's type is a subtype of another typed value's type. - /// - /// This is useful for type checking to ensure type compatibility between - /// expressions, assignments, function calls, etc. 
- /// - /// # Arguments - /// - /// * `other` - The other typed value to compare against - /// * `registry` - The type registry that defines subtyping relationships - /// - /// # Returns - /// - /// `true` if this type is a subtype of the other, `false` otherwise - pub fn subtype_of(&self, other: &Self, registry: &TypeRegistry) -> bool { - registry.is_subtype(&self.ty, &other.ty) - } + // For Logical & Physical: memo status + Stored(Box), + Costed(Box), } /// Manages the type hierarchy and subtyping relationships @@ -113,6 +48,7 @@ impl Typed { #[derive(Debug, Clone, Default)] pub struct TypeRegistry { subtypes: HashMap>, + adt_spans: HashMap, // Track spans for error reporting } impl TypeRegistry { @@ -125,26 +61,60 @@ impl TypeRegistry { /// # Arguments /// /// * `adt` - The ADT to register - pub fn register_adt(&mut self, adt: &Adt) { + /// + /// # Returns + /// + /// `Ok(())` if registration is successful, or a `CompileError` if a duplicate name is found + pub fn register_adt(&mut self, adt: &Adt) -> Result<(), CompileError> { match adt { - Adt::Product { name, .. } => { + Product { name, .. } => { let type_name = name.value.as_ref().clone(); + // Check for duplicate ADT names. + if let Some(existing_span) = self.adt_spans.get(&type_name) { + return Err(SemanticError::new_duplicate_adt( + type_name, + existing_span.clone(), + name.span.clone(), + ) + .into()); + } + + self.adt_spans.insert(type_name.clone(), name.span.clone()); self.subtypes.entry(type_name).or_default(); + Ok(()) } - Adt::Sum { name, variants } => { + Sum { name, variants } => { let enum_name = name.value.as_ref().clone(); + + // Check for duplicate ADT names. 
+ if let Some(existing_span) = self.adt_spans.get(&enum_name) { + return Err(SemanticError::new_duplicate_adt( + enum_name, + existing_span.clone(), + name.clone().span, + ) + .into()); + } + + self.adt_spans.insert(enum_name.clone(), name.clone().span); self.subtypes.entry(enum_name.clone()).or_default(); + for variant in variants { let variant_adt = variant.value.as_ref(); - self.register_adt(variant_adt); + // Register each variant. + self.register_adt(variant_adt)?; + let variant_name = match variant_adt { - Adt::Product { name, .. } => name.value.as_ref(), - Adt::Sum { name, .. } => name.value.as_ref(), + Product { name, .. } => name.value.as_ref(), + Sum { name, .. } => name.value.as_ref(), }; + + // Add variant as a subtype of the enum. if let Some(children) = self.subtypes.get_mut(&enum_name) { children.insert(variant_name.clone()); } } + Ok(()) } } } @@ -171,7 +141,28 @@ impl TypeRegistry { match (child, parent) { // Universe is the top type - everything is a subtype of Universe (_, Type::Universe) => true, - // Check transitive inheritance + + // Stored and Costed type handling + (Type::Stored(child_inner), Type::Stored(parent_inner)) => { + self.is_subtype(child_inner, parent_inner) + } + (Type::Costed(child_inner), Type::Costed(parent_inner)) => { + self.is_subtype(child_inner, parent_inner) + } + (Type::Costed(child_inner), Type::Stored(parent_inner)) => { + // Costed(A) is a subtype of Stored(A) + self.is_subtype(child_inner, parent_inner) + } + (Type::Costed(child_inner), parent_inner) => { + // Costed(A) is a subtype of A + self.is_subtype(child_inner, parent_inner) + } + (Type::Stored(child_inner), parent_inner) => { + // Stored(A) is a subtype of A + self.is_subtype(child_inner, parent_inner) + } + + // Check transitive inheritance for ADTs (Type::Adt(child_name), Type::Adt(parent_name)) => { if child_name == parent_name { return true; @@ -186,10 +177,12 @@ impl TypeRegistry { }) }) } + // Array covariance: Array[T] <: Array[U] if T <: U 
(Type::Array(child_elem), Type::Array(parent_elem)) => { self.is_subtype(child_elem, parent_elem) } + // Tuple covariance: (T1, T2, ...) <: (U1, U2, ...) if T1 <: U1, T2 <: U2, ... (Type::Tuple(child_types), Type::Tuple(parent_types)) => { if child_types.len() != parent_types.len() { @@ -200,19 +193,23 @@ impl TypeRegistry { .zip(parent_types.iter()) .all(|(c, p)| self.is_subtype(c, p)) } + // Map covariance: Map[K1, V1] <: Map[K2, V2] if K1 <: K2 and V1 <: V2 (Type::Map(child_key, child_val), Type::Map(parent_key, parent_val)) => { self.is_subtype(child_key, parent_key) && self.is_subtype(child_val, parent_val) } + // Function contravariance on args, covariance on return type: // (T1 -> U1) <: (T2 -> U2) if T2 <: T1 and U1 <: U2 (Type::Closure(child_param, child_ret), Type::Closure(parent_param, parent_ret)) => { self.is_subtype(parent_param, child_param) && self.is_subtype(child_ret, parent_ret) } + // Optional type covariance: Optional[T] <: Optional[U] if T <: U (Type::Optional(child_ty), Type::Optional(parent_ty)) => { self.is_subtype(child_ty, parent_ty) } + _ => false, } } @@ -245,18 +242,99 @@ mod type_registry_tests { }) .collect(); - Adt::Product { + Product { name: spanned(name.to_string()), fields: spanned_fields, } } fn create_sum_adt(name: &str, variants: Vec) -> Adt { - let spanned_variants: Vec> = variants.into_iter().map(spanned).collect(); - - Adt::Sum { + Sum { name: spanned(name.to_string()), - variants: spanned_variants, + variants: variants.into_iter().map(spanned).collect(), + } + } + + #[test] + fn test_stored_and_costed_types() { + let registry = TypeRegistry::default(); + + // Test Stored type as a subtype of the inner type + assert!(registry.is_subtype(&Type::Stored(Box::new(Type::Int64)), &Type::Int64)); + + // Test Costed type as a subtype of Stored type + assert!(registry.is_subtype( + &Type::Costed(Box::new(Type::Int64)), + &Type::Stored(Box::new(Type::Int64)) + )); + + // Test Costed type as a subtype of the inner type (transitivity) 
+ assert!(registry.is_subtype(&Type::Costed(Box::new(Type::Int64)), &Type::Int64)); + + // Test Stored type covariance + let mut adts_registry = TypeRegistry::default(); + let animal = create_product_adt("Animal", vec![]); + let dog = create_product_adt("Dog", vec![]); + let animals_enum = create_sum_adt("Animals", vec![animal, dog]); + adts_registry.register_adt(&animals_enum).unwrap(); + + assert!(adts_registry.is_subtype( + &Type::Stored(Box::new(Type::Adt("Dog".to_string()))), + &Type::Stored(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test Costed type covariance + assert!(adts_registry.is_subtype( + &Type::Costed(Box::new(Type::Adt("Dog".to_string()))), + &Type::Costed(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test the inheritance relationship: Costed(Dog) <: Stored(Animals) + assert!(adts_registry.is_subtype( + &Type::Costed(Box::new(Type::Adt("Dog".to_string()))), + &Type::Stored(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test nested Stored/Costed types + assert!(adts_registry.is_subtype( + &Type::Stored(Box::new(Type::Costed(Box::new(Type::Adt( + "Dog".to_string() + ))))), + &Type::Stored(Box::new(Type::Adt("Animals".to_string()))) + )); + + // Test with Array of Stored/Costed types + assert!(adts_registry.is_subtype( + &Type::Array(Box::new(Type::Costed(Box::new(Type::Adt( + "Dog".to_string() + ))))), + &Type::Array(Box::new(Type::Stored(Box::new(Type::Adt( + "Animals".to_string() + ))))) + )); + } + + #[test] + fn test_duplicate_adt_detection() { + let mut registry = TypeRegistry::default(); + + // First registration should succeed + let car1 = create_product_adt("Car", vec![]); + assert!(registry.register_adt(&car1).is_ok()); + + // Second registration with the same name should fail + let car2 = Product { + name: spanned("Car".to_string()), + fields: vec![], + }; + + let result = registry.register_adt(&car2); + assert!(result.is_err()); + + if let Err(CompileError::SemanticError(SemanticError::DuplicateAdt { 
name, .. })) = result { + assert_eq!(name, "Car"); + } else { + panic!("Expected DuplicateAdt error"); } } @@ -316,7 +394,7 @@ mod type_registry_tests { let vehicle = create_product_adt("Vehicle", vec![]); let car = create_product_adt("Car", vec![]); let vehicles_enum = create_sum_adt("Vehicles", vec![vehicle, car]); - adts_registry.register_adt(&vehicles_enum); + adts_registry.register_adt(&vehicles_enum).unwrap(); assert!(adts_registry.is_subtype( &Type::Array(Box::new(Type::Adt("Car".to_string()))), @@ -421,7 +499,7 @@ mod type_registry_tests { let animal = create_product_adt("Animal", vec![]); let dog = create_product_adt("Dog", vec![]); let animals_enum = create_sum_adt("Animals", vec![animal, dog]); - adts_registry.register_adt(&animals_enum); + adts_registry.register_adt(&animals_enum).unwrap(); // (Animals -> Bool) <: (Dog -> Bool) because Dog <: Animals (contravariance) assert!(adts_registry.is_subtype( @@ -469,7 +547,7 @@ mod type_registry_tests { let shapes_enum = create_sum_adt("Shapes", vec![shape, circle, rectangle]); // Register the ADT - registry.register_adt(&shapes_enum); + registry.register_adt(&shapes_enum).unwrap(); // Test subtypes relationship assert!(registry.is_subtype( @@ -531,6 +609,10 @@ mod type_registry_tests { &Type::Universe )); + // Check that Universe is a supertype of Stored and Costed types + assert!(registry.is_subtype(&Type::Stored(Box::new(Type::Int64)), &Type::Universe)); + assert!(registry.is_subtype(&Type::Costed(Box::new(Type::Int64)), &Type::Universe)); + // But Universe is not a subtype of any other type assert!(!registry.is_subtype(&Type::Universe, &Type::Int64)); assert!(!registry.is_subtype(&Type::Universe, &Type::String)); @@ -573,7 +655,7 @@ mod type_registry_tests { let vehicles_enum = create_sum_adt("Vehicles", vec![vehicle, cars_enum, truck]); // Register the ADT - registry.register_adt(&vehicles_enum); + registry.register_adt(&vehicles_enum).unwrap(); // Test direct subtyping relationships 
assert!(registry.is_subtype( @@ -623,97 +705,4 @@ mod type_registry_tests { &Type::Adt("Cars".to_string()) )); } - - #[test] - fn test_combined_complex_types() { - let mut registry = TypeRegistry::default(); - - // Create and register a type hierarchy - let animal = create_product_adt("Animal", vec![]); - let dog = create_product_adt("Dog", vec![]); - let cat = create_product_adt("Cat", vec![]); - let animals_enum = create_sum_adt("Animals", vec![animal, dog, cat]); - registry.register_adt(&animals_enum); - - // Test array of ADTs - assert!(registry.is_subtype( - &Type::Array(Box::new(Type::Adt("Dog".to_string()))), - &Type::Array(Box::new(Type::Adt("Animals".to_string()))) - )); - - // Test tuple of ADTs - assert!(registry.is_subtype( - &Type::Tuple(vec![ - Type::Adt("Dog".to_string()), - Type::Adt("Cat".to_string()) - ]), - &Type::Tuple(vec![ - Type::Adt("Animals".to_string()), - Type::Adt("Animals".to_string()) - ]) - )); - - // Test map with ADT keys and values - assert!(registry.is_subtype( - &Type::Map( - Box::new(Type::Adt("Dog".to_string())), - Box::new(Type::Adt("Cat".to_string())) - ), - &Type::Map( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Animals".to_string())) - ) - )); - - // Test closures with ADTs - assert!(registry.is_subtype( - &Type::Closure( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Dog".to_string())) - ), - &Type::Closure( - Box::new(Type::Adt("Dog".to_string())), - Box::new(Type::Adt("Animals".to_string())) - ) - )); - - // Test deeply nested types - assert!(registry.is_subtype( - &Type::Map( - Box::new(Type::String), - Box::new(Type::Array(Box::new(Type::Tuple(vec![ - Type::Adt("Dog".to_string()), - Type::Closure( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Cat".to_string())) - ) - ])))) - ), - &Type::Map( - Box::new(Type::String), - Box::new(Type::Array(Box::new(Type::Tuple(vec![ - Type::Adt("Animals".to_string()), - Type::Closure( - 
Box::new(Type::Adt("Dog".to_string())), - Box::new(Type::Adt("Animals".to_string())) - ) - ])))) - ) - )); - - // Any complex nested type is a subtype of Universe - assert!(registry.is_subtype( - &Type::Map( - Box::new(Type::String), - Box::new(Type::Array(Box::new(Type::Tuple(vec![ - Type::Adt("Dog".to_string()), - Type::Closure( - Box::new(Type::Adt("Animals".to_string())), - Box::new(Type::Adt("Cat".to_string())) - ) - ])))) - ), - &Type::Universe - )); - } } diff --git a/optd-dsl/src/cli/basic.op b/optd-dsl/src/cli/basic.op index 09dab1d7..4f18393e 100644 --- a/optd-dsl/src/cli/basic.op +++ b/optd-dsl/src/cli/basic.op @@ -93,6 +93,8 @@ data JoinType = fn map_get(map: {K: V}, key: K): V = None +fn empty: V + [rust] fn (expr: Scalar) apply_children(f: Scalar -> Scalar) = None @@ -102,10 +104,10 @@ fn (pred: Predicate) remap(map: {I64 : I64}) = \ _ -> predicate.apply_children(child -> rewrite_column_refs(child, map)) [rule] -fn (expr: Logical) join_commute: Logical? = match expr +fn (expr: Logical*) join_commute: Logical? = match expr \ Join(left, right, Inner, cond) -> let - right_indices = 0.right.schema_len, + right_indices = 0..right.schema_len, left_indices = 0..left.schema_len, remapping = left_indices.map(i -> (i, i + right_len)) ++ right_indices.map(i -> (left_len + i, i)).to_map, diff --git a/optd-dsl/src/lexer/lex.rs b/optd-dsl/src/lexer/lex.rs index d577dc67..2b2e4048 100644 --- a/optd-dsl/src/lexer/lex.rs +++ b/optd-dsl/src/lexer/lex.rs @@ -140,6 +140,8 @@ fn lexer() -> impl Parser, Error = Simple> just(".").to(Token::Dot), just(":").to(Token::Colon), just("?").to(Token::Question), + just("$").to(Token::Dollar), + just("#").to(Token::HashTag), )); let comments = just("//") diff --git a/optd-dsl/src/lexer/tokens.rs b/optd-dsl/src/lexer/tokens.rs index 2d416a3c..77803604 100644 --- a/optd-dsl/src/lexer/tokens.rs +++ b/optd-dsl/src/lexer/tokens.rs @@ -64,6 +64,8 @@ pub enum Token { Colon, // : UnderScore, // _ Question, // ? 
+ Dollar, // $ + HashTag, // # } pub const ALL_DELIMITERS: [(Token, Token); 3] = [ @@ -137,6 +139,8 @@ impl std::fmt::Display for Token { Token::Colon => write!(f, ":"), Token::UnderScore => write!(f, "_"), Token::Question => write!(f, "?"), + Token::Dollar => write!(f, "$"), + Token::HashTag => write!(f, "#"), } } } diff --git a/optd-dsl/src/parser/ast.rs b/optd-dsl/src/parser/ast.rs index 87faff1f..e3d76ef1 100644 --- a/optd-dsl/src/parser/ast.rs +++ b/optd-dsl/src/parser/ast.rs @@ -35,12 +35,17 @@ pub enum Type { Tuple(Vec>), /// Map/dictionary with key and value types Map(Spanned, Spanned), - /// Optional type (represents a value that might be absent) - Optional(Spanned), - // User defined types - /// Algebraic Data Type reference by name - Adt(Identifier), + // Type wrappers + /// For types ending with ? + Questioned(Spanned), + /// For types ending with * + Starred(Spanned), + /// For types ending with $ + Dollared(Spanned), + + // The type identified by a name (e.g., ADT, Logical, Physical, generics, etc.) + Identifier(Identifier), } /// Represents a field in a record or a parameter in a function @@ -96,7 +101,7 @@ pub enum Expr { Unary(UnaryOp, Spanned), // Function-related - /// Postfix operations (function call, member access) + /// Postfix operations (function call, member access, member call) Postfix(Spanned, PostfixOp), /// Anonymous function definition Closure(Vec>, Spanned), @@ -216,10 +221,12 @@ pub enum UnaryOp { /// Represents postfix operations (function call, member access, composition) #[derive(Debug, Clone)] pub enum PostfixOp { - /// Function or method call with arguments + /// Function or method call with arguments [?](..) 
Call(Vec>), - /// Member/field access - Member(Identifier), + /// Struct field access [?]#field + Field(Identifier), + /// Method access [?].method + Method(Identifier), } /// Represents a function definition diff --git a/optd-dsl/src/parser/expr.rs b/optd-dsl/src/parser/expr.rs index 470486ae..28b5937b 100644 --- a/optd-dsl/src/parser/expr.rs +++ b/optd-dsl/src/parser/expr.rs @@ -198,9 +198,12 @@ pub fn expr_parser() -> impl Parser, Error = Simple name }) + .map(PostfixOp::Field), just(Token::Dot) .ignore_then(select! { Token::TermIdent(name) => name }) - .map(PostfixOp::Member), + .map(PostfixOp::Method), )) .map_with_span(|op, span| (op, span)) .repeated(), @@ -399,7 +402,7 @@ mod tests { assert_expr_eq(&a.value, &e.value); } } - (PostfixOp::Member(a_member), PostfixOp::Member(e_member)) => { + (PostfixOp::Field(a_member), PostfixOp::Field(e_member)) => { assert_eq!(a_member, e_member); } _ => panic!( @@ -695,7 +698,7 @@ mod tests { } // Complex map with expressions as keys and values - let (result, errors) = parse_expr("{x + 1: y * 2, \"key\": z.field}"); + let (result, errors) = parse_expr("{x + 1: y * 2, \"key\": z#field}"); assert!( result.is_some(), "Expected successful parse for complex map" @@ -720,13 +723,13 @@ mod tests { panic!("Expected binary multiplication in first value"); } - // Second entry: "key": z.field + // Second entry: "key": z#field assert_expr_eq( &entries[1].0.value, &Expr::Literal(Literal::String("key".to_string())), ); - if let Expr::Postfix(expr, PostfixOp::Member(member)) = &*entries[1].1.value { + if let Expr::Postfix(expr, PostfixOp::Field(member)) = &*entries[1].1.value { assert_expr_eq(&expr.value, &Expr::Ref("z".to_string())); assert_eq!(member, "field"); } else { @@ -824,7 +827,7 @@ mod tests { fn test_nested_expressions() { // Test deeply nested expressions with different expression types let (result, errors) = parse_expr( - "if ({[1 + 2, 3]: 5}.size > 0) then { let x = 42 in x } else fail(\"Empty map\")", + "if ({[1 + 2, 3]: 
5}#size > 0) then { let x = 42 in x } else fail(\"Empty map\")", ); assert!( @@ -842,7 +845,7 @@ mod tests { // Test condition with member access and binary operation if let Expr::Binary(left, BinOp::Gt, right) = &*condition.value { // Test member access on map expression - if let Expr::Postfix(expr, PostfixOp::Member(member)) = &*left.value { + if let Expr::Postfix(expr, PostfixOp::Field(member)) = &*left.value { assert_eq!(member, "size"); // Check that expression is a Map @@ -881,7 +884,7 @@ mod tests { panic!("Expected Map expression"); } } else { - panic!("Expected member access in condition"); + panic!("Expected field access in condition"); } assert_expr_eq(&right.value, &Expr::Literal(Literal::Int64(0))); @@ -1268,11 +1271,11 @@ mod tests { assert_eq!(args2.len(), 1); if let Expr::Postfix(inner3, PostfixOp::Call(args3)) = &*inner2.value { assert_eq!(args3.len(), 1); - if let Expr::Postfix(obj, PostfixOp::Member(method)) = &*inner3.value { + if let Expr::Postfix(obj, PostfixOp::Method(method)) = &*inner3.value { assert_expr_eq(&obj.value, &Expr::Ref("obj".to_string())); assert_eq!(method, "method"); } else { - panic!("Expected member access at base"); + panic!("Expected method access at base"); } } else { panic!("Expected third call"); @@ -1286,15 +1289,15 @@ mod tests { } // Test function call followed by field access - let (result, errors) = parse_expr("func().field"); + let (result, errors) = parse_expr("func()#field"); assert!( result.is_some(), - "Expected successful parse for func().field" + "Expected successful parse for func()#field" ); - assert!(errors.is_empty(), "Expected no errors for func().field"); + assert!(errors.is_empty(), "Expected no errors for func()#field"); if let Some(expr) = result { - if let Expr::Postfix(inner, PostfixOp::Member(member)) = &*expr.value { + if let Expr::Postfix(inner, PostfixOp::Field(member)) = &*expr.value { assert_eq!(member, "field"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner.value { 
assert_eq!(args.len(), 0); @@ -1303,12 +1306,12 @@ mod tests { panic!("Expected call in function call"); } } else { - panic!("Expected member access"); + panic!("Expected field access"); } } - // Test complex chain of operations - let (result, errors) = parse_expr("obj.method().field.other_method(arg)"); + // Test complex chain of operations with mix of field and method + let (result, errors) = parse_expr("obj.method()#field.other_method(arg)"); assert!( result.is_some(), "Expected successful parse for complex chain" @@ -1320,47 +1323,47 @@ mod tests { assert_eq!(args.len(), 1); assert_expr_eq(&args[0].value, &Expr::Ref("arg".to_string())); - if let Expr::Postfix(inner2, PostfixOp::Member(member)) = &*inner1.value { + if let Expr::Postfix(inner2, PostfixOp::Method(member)) = &*inner1.value { assert_eq!(member, "other_method"); - if let Expr::Postfix(inner3, PostfixOp::Member(field)) = &*inner2.value { + if let Expr::Postfix(inner3, PostfixOp::Field(field)) = &*inner2.value { assert_eq!(field, "field"); if let Expr::Postfix(inner4, PostfixOp::Call(method_args)) = &*inner3.value { assert_eq!(method_args.len(), 0); - if let Expr::Postfix(obj, PostfixOp::Member(method)) = &*inner4.value { + if let Expr::Postfix(obj, PostfixOp::Method(method)) = &*inner4.value { assert_expr_eq(&obj.value, &Expr::Ref("obj".to_string())); assert_eq!(method, "method"); } else { - panic!("Expected member access for initial method"); + panic!("Expected method access for initial method"); } } else { panic!("Expected call for first method"); } } else { - panic!("Expected member access for field"); + panic!("Expected field access"); } } else { - panic!("Expected member access for final method"); + panic!("Expected method access for final method"); } } else { panic!("Expected call at top level"); } } - // Test compose operator + // Test method reference without call let (result, errors) = parse_expr("map(dat).filter"); assert!( result.is_some(), - "Expected successful parse for compose operator" 
+ "Expected successful parse for method reference" ); - assert!(errors.is_empty(), "Expected no errors for compose operator"); + assert!(errors.is_empty(), "Expected no errors for method reference"); if let Some(expr) = result { - if let Expr::Postfix(inner, PostfixOp::Member(name)) = &*expr.value { + if let Expr::Postfix(inner, PostfixOp::Method(name)) = &*expr.value { assert_eq!(name, "filter"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner.value { @@ -1368,32 +1371,32 @@ mod tests { assert_eq!(args.len(), 1); assert_expr_eq(&args[0].value, &Expr::Ref("dat".to_string())); } else { - panic!("Expected function call before compose operator"); + panic!("Expected function call before method reference"); } } else { - panic!("Expected compose operator"); + panic!("Expected method reference"); } } - // Test chained compose operators + // Test chained method references let (result, errors) = parse_expr("transform(input).map.filter.reduce"); assert!( result.is_some(), - "Expected successful parse for chained compose operators" + "Expected successful parse for chained method references" ); assert!( errors.is_empty(), - "Expected no errors for chained compose operators" + "Expected no errors for chained method references" ); if let Some(expr) = result { - if let Expr::Postfix(inner1, PostfixOp::Member(name1)) = &*expr.value { + if let Expr::Postfix(inner1, PostfixOp::Method(name1)) = &*expr.value { assert_eq!(name1, "reduce"); - if let Expr::Postfix(inner2, PostfixOp::Member(name2)) = &*inner1.value { + if let Expr::Postfix(inner2, PostfixOp::Method(name2)) = &*inner1.value { assert_eq!(name2, "filter"); - if let Expr::Postfix(inner3, PostfixOp::Member(name3)) = &*inner2.value { + if let Expr::Postfix(inner3, PostfixOp::Method(name3)) = &*inner2.value { assert_eq!(name3, "map"); if let Expr::Postfix(func, PostfixOp::Call(args)) = &*inner3.value { @@ -1404,13 +1407,57 @@ mod tests { panic!("Expected function call at the beginning of chain"); } } else { - 
panic!("Expected first compose operation in chain"); + panic!("Expected first method reference in chain"); + } + } else { + panic!("Expected second method reference in chain"); + } + } else { + panic!("Expected third method reference in chain"); + } + } + + // Test mixed field access and method references + let (result, errors) = parse_expr("person#name.split(\"#\")#length"); + assert!( + result.is_some(), + "Expected successful parse for mixed field and method chain" + ); + assert!( + errors.is_empty(), + "Expected no errors for mixed field and method chain" + ); + + if let Some(expr) = result { + if let Expr::Postfix(inner1, PostfixOp::Field(field_name)) = &*expr.value { + assert_eq!(field_name, "length"); + + if let Expr::Postfix(inner2, PostfixOp::Call(args)) = &*inner1.value { + assert_eq!(args.len(), 1); + if let Expr::Literal(Literal::String(s)) = &*args[0].value { + assert_eq!(s, "#"); + } else { + panic!("Expected string literal argument"); + } + + if let Expr::Postfix(inner3, PostfixOp::Method(method_name)) = &*inner2.value { + assert_eq!(method_name, "split"); + + if let Expr::Postfix(inner4, PostfixOp::Field(field_name)) = &*inner3.value + { + assert_eq!(field_name, "name"); + assert_expr_eq(&inner4.value, &Expr::Ref("person".to_string())); + } else { + panic!("Expected field access at beginning of chain"); + } + } else { + panic!("Expected method reference in middle of chain"); } } else { - panic!("Expected second compose operation in chain"); + panic!("Expected method call in chain"); } } else { - panic!("Expected third compose operation in chain"); + panic!("Expected field access at end of chain"); } } } diff --git a/optd-dsl/src/parser/function.rs b/optd-dsl/src/parser/function.rs index aadc5d9c..8d48a07c 100644 --- a/optd-dsl/src/parser/function.rs +++ b/optd-dsl/src/parser/function.rs @@ -271,10 +271,10 @@ mod tests { let params = func.value.params.as_ref().unwrap(); assert_eq!(params.len(), 1); assert_eq!(*params[0].value.name.value, "x"); - 
assert!(matches!(*params[0].clone().value.ty.value, Type::Adt(name) if name == "T")); + assert!(matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "T")); // Check return type - assert!(matches!(*func.value.return_type.value, Type::Adt(name) if name == "T")); + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "T")); // Check body assert!(func.value.body.is_some()); @@ -305,19 +305,19 @@ mod tests { // Check first parameter (map: {K: V}) assert_eq!(*params[0].value.name.value, "map"); if let Type::Map(key_ty, val_ty) = &*params[0].value.ty.value { - assert!(matches!(*key_ty.clone().value, Type::Adt(name) if name == "K")); - assert!(matches!(*val_ty.clone().value, Type::Adt(name) if name == "V")); + assert!(matches!(*key_ty.clone().value, Type::Identifier(name) if name == "K")); + assert!(matches!(*val_ty.clone().value, Type::Identifier(name) if name == "V")); } else { panic!("Expected Map type for first parameter"); } // Check second parameter (key: K) assert_eq!(*params[1].value.name.value, "key"); - assert!(matches!(*params[1].clone().value.ty.value, Type::Adt(name) if name == "K")); + assert!(matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "K")); // Check return type (V?) 
- if let Type::Optional(inner) = &*func.value.return_type.value { - assert!(matches!(*inner.clone().value, Type::Adt(name) if name == "V")); + if let Type::Questioned(inner) = &*func.value.return_type.value { + assert!(matches!(*inner.clone().value, Type::Identifier(name) if name == "V")); } else { panic!("Expected Optional return type"); } @@ -346,12 +346,12 @@ mod tests { let params = func.value.params.as_ref().unwrap(); assert_eq!(params.len(), 2); assert_eq!(*params[0].value.name.value, "a"); - assert!(matches!(*params[0].clone().value.ty.value, Type::Adt(name) if name == "A")); + assert!(matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "A")); assert_eq!(*params[1].value.name.value, "b"); - assert!(matches!(*params[1].clone().value.ty.value, Type::Adt(name) if name == "B")); + assert!(matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "B")); // Check return type - assert!(matches!(*func.value.return_type.value, Type::Adt(name) if name == "C")); + assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "C")); // Check body is None assert!(func.value.body.is_none()); @@ -408,7 +408,7 @@ mod tests { assert!(func.value.receiver.is_some()); if let Some(receiver) = &func.value.receiver { assert_eq!(*receiver.value.name.value, "self"); - assert!(matches!(&*receiver.value.ty.value, Type::Adt(name) if name == "Person")); + assert!(matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person")); } // Check params. @@ -438,7 +438,7 @@ mod tests { assert!(func.value.receiver.is_some()); if let Some(receiver) = &func.value.receiver { assert_eq!(*receiver.value.name.value, "self"); - assert!(matches!(&*receiver.value.ty.value, Type::Adt(name) if name == "Person")); + assert!(matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person")); } // Check params. 
diff --git a/optd-dsl/src/parser/type.rs b/optd-dsl/src/parser/type.rs index cbdd282f..681334e9 100644 --- a/optd-dsl/src/parser/type.rs +++ b/optd-dsl/src/parser/type.rs @@ -1,3 +1,8 @@ +use super::{ast::Type, utils::delimited_parser}; +use crate::{ + lexer::tokens::Token, + utils::span::{Span, Spanned}, +}; use chumsky::{ Parser, error::Simple, @@ -5,13 +10,6 @@ use chumsky::{ select, }; -use crate::{ - lexer::tokens::Token, - utils::span::{Span, Spanned}, -}; - -use super::{ast::Type, utils::delimited_parser}; - /// Creates a parser for type expressions. /// /// This parser supports: @@ -22,18 +20,8 @@ use super::{ast::Type, utils::delimited_parser}; /// - Function types: T1 -> T2 /// - User-defined types: TypeName /// - Optional types: T? -/// -/// Syntax examples: -/// - I64 - Integer type -/// - String - String type -/// - \[I64\] - Array of integers -/// - {String : I64} - Map from strings to integers -/// - (I64, String) - Tuple with integer and string -/// - I64 -> String - Function from integer to string -/// - I64? - Optional integer -/// - \[String\]? - Optional array of strings -/// - I64 -> String? - Function returning optional string -/// - (I64 -> String)? - Optional function +/// - Starred types: T* +/// - Dollared types: T$ /// /// The parser follows standard precedence rules and supports /// arbitrary nesting of type expressions. @@ -89,8 +77,8 @@ pub fn type_parser() -> impl Parser, Error = Simple Type::Adt(name) }.map_with_span(Spanned::new); + let data_type = select! { Token::TypeIdent(name) => Type::Identifier(name) } + .map_with_span(Spanned::new); // Note: cannot apply delimiter recovery, as its recovery // would block further successful parses (e.g. tuples). 
@@ -123,18 +111,23 @@ pub fn type_parser() -> impl Parser, Error = Simple param_type, }); - // Process optional types + // Process optional, starred, and dollared types function_type .then( - just(Token::Question) + choice((just(Token::Question), just(Token::Mul), just(Token::Dollar))) .repeated() .at_least(0) .collect::>(), ) - .map_with_span(|(base_type, question_marks), span| { + .map_with_span(|(base_type, modifiers), span| { let mut result = base_type; - for _ in question_marks { - result = Spanned::new(Type::Optional(result), span.clone()); + for modifier in modifiers { + result = match modifier { + Token::Question => Spanned::new(Type::Questioned(result), span.clone()), + Token::Mul => Spanned::new(Type::Starred(result), span.clone()), + Token::Dollar => Spanned::new(Type::Dollared(result), span.clone()), + _ => unreachable!("Invalid type modifier"), + }; } result }) @@ -258,33 +251,33 @@ mod tests { // Basic optional types let result = parse_type("I64?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Int64) + Type::Questioned(inner) if matches!(*inner.value, Type::Int64) )); let result = parse_type("String?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::String) + Type::Questioned(inner) if matches!(*inner.value, Type::String) )); // Nested optional types let result = parse_type("I64??").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.clone().value, - Type::Optional(inner_inner) if matches!(*inner_inner.value, Type::Int64) + Type::Questioned(inner) if matches!(*inner.clone().value, + Type::Questioned(inner_inner) if matches!(*inner_inner.value, Type::Int64) ) )); // Complex types with optional let result = parse_type("[I64]?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.clone().value, + Type::Questioned(inner) if matches!(*inner.clone().value, Type::Array(arr_inner) if 
matches!(*arr_inner.value, Type::Int64) ) )); let result = parse_type("(I64, String)?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Tuple(_)) + Type::Questioned(inner) if matches!(*inner.value, Type::Tuple(_)) )); // Function return type is optional @@ -292,21 +285,156 @@ mod tests { assert!(matches!(*result.value, Type::Closure(param, ret) if matches!(*param.value, Type::Int64) - && matches!(*ret.value, Type::Optional(_)) + && matches!(*ret.value, Type::Questioned(_)) && matches!(*ret.value.clone(), - Type::Optional(inner) if matches!(*inner.value, Type::String)) + Type::Questioned(inner) if matches!(*inner.value, Type::String)) )); // Entire function type is optional let result = parse_type("(I64 -> String)?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Closure(_, _)) + Type::Questioned(inner) if matches!(*inner.value, Type::Closure(_, _)) )); // Optional map type let result = parse_type("{String : I64}?").unwrap(); assert!(matches!(*result.value, - Type::Optional(inner) if matches!(*inner.value, Type::Map(_, _)) + Type::Questioned(inner) if matches!(*inner.value, Type::Map(_, _)) + )); + } + + #[test] + fn test_starred_types() { + // Basic starred types + let result = parse_type("I64*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::Int64) + )); + + let result = parse_type("String*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::String) + )); + + // Nested starred types + let result = parse_type("I64**").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.clone().value, + Type::Starred(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + // Complex types with starred + let result = parse_type("[I64]*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.clone().value, 
+ Type::Array(arr_inner) if matches!(*arr_inner.value, Type::Int64) + ) + )); + + let result = parse_type("(I64, String)*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::Tuple(_)) + )); + + // Function return type is starred + let result = parse_type("I64 -> String*").unwrap(); + assert!(matches!(*result.value, + Type::Closure(param, ret) + if matches!(*param.value, Type::Int64) + && matches!(*ret.value, Type::Starred(_)) + && matches!(*ret.value.clone(), + Type::Starred(inner) if matches!(*inner.value, Type::String)) + )); + + // Entire function type is starred + let result = parse_type("(I64 -> String)*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.value, Type::Closure(_, _)) + )); + } + + #[test] + fn test_dollared_types() { + // Basic dollared types + let result = parse_type("I64$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::Int64) + )); + + let result = parse_type("String$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::String) + )); + + // Nested dollared types + let result = parse_type("I64$$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.clone().value, + Type::Dollared(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + // Complex types with dollared + let result = parse_type("[I64]$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.clone().value, + Type::Array(arr_inner) if matches!(*arr_inner.value, Type::Int64) + ) + )); + + let result = parse_type("(I64, String)$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::Tuple(_)) + )); + + // Function return type is dollared + let result = parse_type("I64 -> String$").unwrap(); + assert!(matches!(*result.value, + Type::Closure(param, ret) + if 
matches!(*param.value, Type::Int64) + && matches!(*ret.value, Type::Dollared(_)) + && matches!(*ret.value.clone(), + Type::Dollared(inner) if matches!(*inner.value, Type::String)) + )); + + // Entire function type is dollared + let result = parse_type("(I64 -> String)$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.value, Type::Closure(_, _)) + )); + } + + #[test] + fn test_mixed_modifiers() { + // Test combination of modifiers + let result = parse_type("I64?*").unwrap(); + assert!(matches!(*result.value, + Type::Starred(inner) if matches!(*inner.clone().value, + Type::Questioned(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + let result = parse_type("I64*$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(inner) if matches!(*inner.clone().value, + Type::Starred(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + let result = parse_type("I64$?").unwrap(); + assert!(matches!(*result.value, + Type::Questioned(inner) if matches!(*inner.clone().value, + Type::Dollared(inner_inner) if matches!(*inner_inner.value, Type::Int64) + ) + )); + + // Complex mixed type + let result = parse_type("(I64 -> String?)*$").unwrap(); + assert!(matches!(*result.value, + Type::Dollared(outer) if matches!(*outer.clone().value, + Type::Starred(inner) if matches!(*inner.value, Type::Closure(_, _)) + ) )); } @@ -334,7 +462,7 @@ mod tests { assert!(matches!(*map_array.value, Type::Map(_, _))); if let Type::Map(map_key, map_val) = &*map_array.value { assert!(matches!(*map_key.value, Type::String)); - assert!(matches!(*map_val.value, Type::Adt(_))); + assert!(matches!(*map_val.value, Type::Identifier(_))); } } } @@ -344,14 +472,14 @@ mod tests { assert!(matches!(*ret_tuple.value, Type::Tuple(_))); if let Type::Tuple(elements) = &*ret_tuple.value { assert_eq!(elements.len(), 3); - assert!(matches!(*elements[0].value, Type::Adt(_))); - assert!(matches!(*elements[1].value, Type::Adt(_))); + 
assert!(matches!(*elements[0].value, Type::Identifier(_))); + assert!(matches!(*elements[1].value, Type::Identifier(_))); assert!(matches!(*elements[2].value, Type::Closure(_, _))); if let Type::Closure(bool_param, scalar_arr) = &*elements[2].value { assert!(matches!(*bool_param.value, Type::Bool)); assert!(matches!(*scalar_arr.value, Type::Array(_))); if let Type::Array(scalar) = &*scalar_arr.value { - assert!(matches!(*scalar.value, Type::Adt(_))); + assert!(matches!(*scalar.value, Type::Identifier(_))); } } } @@ -360,9 +488,8 @@ mod tests { } } - // Test an even more insane nested type with optionals - let optional_insane = - "{String? : {I64 : [(Logical -> {String : [((Bool?, [Scalar]?) -> Physical?)]?}?)]}}?"; - assert!(parse_type(optional_insane).is_ok()); + // Test an even more insane nested type with optionals, starred, and dollared + let complex_type = "{String*? : {I64$ : [(Logical -> {String : [((Bool?, [Scalar]$) -> Physical?*)]?}$)]}}*$?"; + assert!(parse_type(complex_type).is_ok()); } } From eb8f5a7372f93a23323ccf967b722db38db70d1e Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 18:48:46 -0400 Subject: [PATCH 02/11] Add map in HIR --- optd-dsl/src/analyzer/hir.rs | 2 +- optd-dsl/src/analyzer/mod.rs | 1 + optd-dsl/src/engine/eval/binary.rs | 2 +- optd-dsl/src/engine/eval/core.rs | 2 +- optd-dsl/src/engine/eval/expr.rs | 18 +- optd-dsl/src/engine/eval/match.rs | 15 +- optd-dsl/src/engine/eval/operator.rs | 18 +- optd-dsl/src/engine/mod.rs | 3 - optd-dsl/src/engine/test_utils.rs | 243 --------------------------- optd-dsl/src/utils/mod.rs | 2 + 10 files changed, 30 insertions(+), 276 deletions(-) delete mode 100644 optd-dsl/src/engine/test_utils.rs diff --git a/optd-dsl/src/analyzer/hir.rs b/optd-dsl/src/analyzer/hir.rs index 9f9a0f6c..76b8020d 100644 --- a/optd-dsl/src/analyzer/hir.rs +++ b/optd-dsl/src/analyzer/hir.rs @@ -45,7 +45,7 @@ pub enum FunKind { } /// Group identifier in the optimizer -#[derive(Debug, Clone, PartialEq, 
Copy)] +#[derive(Debug, Clone, PartialEq, Copy, Eq, Hash)] pub struct GroupId(pub i64); /// Either materialized or unmaterialized data diff --git a/optd-dsl/src/analyzer/mod.rs b/optd-dsl/src/analyzer/mod.rs index 5776abb3..3883e948 100644 --- a/optd-dsl/src/analyzer/mod.rs +++ b/optd-dsl/src/analyzer/mod.rs @@ -1,5 +1,6 @@ pub mod context; pub mod hir; +pub mod map; pub mod semantic_checker; pub mod r#type; pub mod type_checker; diff --git a/optd-dsl/src/engine/eval/binary.rs b/optd-dsl/src/engine/eval/binary.rs index e3e4412b..106b66c3 100644 --- a/optd-dsl/src/engine/eval/binary.rs +++ b/optd-dsl/src/engine/eval/binary.rs @@ -81,7 +81,7 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { // Map concatenation (joins two maps). (Map(l), Concat, Map(r)) => { let mut result = l.clone(); - result.extend(r.iter().cloned()); + result.extend(r.into_iter()); Value(Map(result)) } diff --git a/optd-dsl/src/engine/eval/core.rs b/optd-dsl/src/engine/eval/core.rs index 30d6a651..78ce54e9 100644 --- a/optd-dsl/src/engine/eval/core.rs +++ b/optd-dsl/src/engine/eval/core.rs @@ -154,7 +154,7 @@ where #[cfg(test)] mod tests { use crate::engine::Continuation; - use crate::engine::test_utils::{ + use crate::utils::tests::{ TestHarness, evaluate_and_collect, evaluate_and_collect_with_custom_k, int, lit_expr, string, }; diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index 4c9536e7..7f892ad0 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ b/optd-dsl/src/engine/eval/expr.rs @@ -343,15 +343,15 @@ where #[cfg(test)] mod tests { - use crate::analyzer::{ - context::Context, - hir::{BinOp, CoreData, Expr, ExprKind, FunKind, Literal, Value}, - }; - use crate::engine::{ - Engine, - test_utils::{ - TestHarness, array_val, boolean, evaluate_and_collect, int, lit_expr, lit_val, - ref_expr, string, + use crate::engine::Engine; + use crate::utils::tests::{array_val, ref_expr}; + use crate::{ + analyzer::{ + context::Context, + 
hir::{BinOp, CoreData, Expr, ExprKind, FunKind, Literal, Value}, + }, + utils::tests::{ + TestHarness, boolean, evaluate_and_collect, int, lit_expr, lit_val, string, }, }; use ExprKind::*; diff --git a/optd-dsl/src/engine/eval/match.rs b/optd-dsl/src/engine/eval/match.rs index 2ad64b8a..90d53a64 100644 --- a/optd-dsl/src/engine/eval/match.rs +++ b/optd-dsl/src/engine/eval/match.rs @@ -520,14 +520,11 @@ where #[cfg(test)] mod tests { - use crate::engine::{ - Engine, - test_utils::{ - array_decomp_pattern, array_val, bind_pattern, create_logical_operator, - create_physical_operator, evaluate_and_collect, int, lit_expr, lit_val, - literal_pattern, match_arm, operator_pattern, pattern_match_expr, ref_expr, string, - struct_pattern, struct_val, wildcard_pattern, - }, + use crate::engine::Engine; + use crate::utils::tests::{ + array_decomp_pattern, array_val, bind_pattern, create_logical_operator, + create_physical_operator, evaluate_and_collect, lit_val, literal_pattern, match_arm, + operator_pattern, ref_expr, struct_pattern, struct_val, }; use crate::{ analyzer::{ @@ -537,7 +534,7 @@ mod tests { Materializable, Operator, Value, }, }, - engine::test_utils::TestHarness, + utils::tests::{TestHarness, int, lit_expr, pattern_match_expr, string, wildcard_pattern}, }; use ExprKind::*; use Materializable::*; diff --git a/optd-dsl/src/engine/eval/operator.rs b/optd-dsl/src/engine/eval/operator.rs index f9168ac4..2064a62d 100644 --- a/optd-dsl/src/engine/eval/operator.rs +++ b/optd-dsl/src/engine/eval/operator.rs @@ -145,16 +145,16 @@ where #[cfg(test)] mod tests { - use crate::analyzer::{ - context::Context, - hir::{ - BinOp, CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, Materializable, - Operator, PhysicalOp, Value, + use crate::engine::Engine; + use crate::{ + analyzer::{ + context::Context, + hir::{ + BinOp, CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, Materializable, + Operator, PhysicalOp, Value, + }, }, - }; - use crate::engine::{ - Engine, 
- test_utils::{ + utils::tests::{ TestHarness, create_logical_operator, evaluate_and_collect, int, lit_expr, lit_val, string, }, diff --git a/optd-dsl/src/engine/mod.rs b/optd-dsl/src/engine/mod.rs index 72ac7a0f..196d4f29 100644 --- a/optd-dsl/src/engine/mod.rs +++ b/optd-dsl/src/engine/mod.rs @@ -16,9 +16,6 @@ mod eval; mod utils; pub use utils::*; -#[cfg(test)] -mod test_utils; - /// The engine response type, which can be either a return value with a converter callback /// or a yielded group/goal with a continuation for further processing. pub enum EngineResponse { diff --git a/optd-dsl/src/engine/test_utils.rs b/optd-dsl/src/engine/test_utils.rs deleted file mode 100644 index 8ee3d7ad..00000000 --- a/optd-dsl/src/engine/test_utils.rs +++ /dev/null @@ -1,243 +0,0 @@ -use super::{Continuation, EngineResponse}; -use crate::analyzer::hir::{ - CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, MatchArm, Materializable, - Operator, Pattern, PhysicalOp, Value, -}; -use crate::engine::Engine; -use Materializable::*; -use std::collections::{HashMap, VecDeque}; -use std::sync::{Arc, Mutex}; - -/// A test harness for the evaluation engine. -#[derive(Clone)] -pub struct TestHarness { - /// Maps group IDs to their materialized values. - group_mappings: Arc>>>, - - /// Maps goals to their implementations. - goal_mappings: Arc>>>, -} - -impl TestHarness { - /// Creates a new test harness. - pub fn new() -> Self { - Self { - group_mappings: Arc::new(Mutex::new(HashMap::new())), - goal_mappings: Arc::new(Mutex::new(HashMap::new())), - } - } - - /// Registers a logical operator value to be returned when a specific group is requested. - pub fn register_group(&self, group_id: GroupId, value: Value) { - let key = format!("{:?}", group_id); - let mut mappings = self.group_mappings.lock().unwrap(); - mappings.entry(key).or_default().push(value); - } - - /// Registers a physical operator value to be returned when a specific goal is requested. 
- pub fn register_goal(&self, goal: &Goal, value: Value) { - let key = format!("{:?}:{:?}", goal.group_id, goal.properties); - let mut mappings = self.goal_mappings.lock().unwrap(); - mappings.entry(key).or_default().push(value); - } - - /// Forks the evaluation at a specific group ID and collect the responses. - async fn fork_at_group( - &self, - group_id: GroupId, - k: Continuation>, - queue: &mut VecDeque>, - ) where - T: Send + 'static, - { - let key = format!("{:?}", group_id); - let values = { - let mappings = self.group_mappings.lock().unwrap(); - mappings.get(&key).cloned().unwrap_or_default() - }; - - for value in values { - queue.push_back(k(value).await); - } - } - - /// Forks the evaluation at a specific goal and collect the responses. - async fn fork_at_goal( - &self, - goal: &Goal, - k: Continuation>, - queue: &mut VecDeque>, - ) where - T: Send + 'static, - { - let key = format!("{:?}:{:?}", goal.group_id, goal.properties); - let values = { - let mappings = self.goal_mappings.lock().unwrap(); - mappings.get(&key).cloned().unwrap_or_default() - }; - - for value in values { - queue.push_back(k(value).await); - } - } -} - -/// Helper to create a literal expression. -pub fn lit_expr(literal: Literal) -> Arc { - Arc::new(Expr::new(ExprKind::CoreExpr(CoreData::Literal(literal)))) -} - -/// Helper to create a literal value. -pub fn lit_val(literal: Literal) -> Value { - Value(CoreData::Literal(literal)) -} - -/// Helper to create an integer literal. -pub fn int(i: i64) -> Literal { - Literal::Int64(i) -} - -/// Helper to create a string literal. -pub fn string(s: &str) -> Literal { - Literal::String(s.to_string()) -} - -/// Helper to create a boolean literal. -pub fn boolean(b: bool) -> Literal { - Literal::Bool(b) -} - -/// Helper to create a reference expression. -pub fn ref_expr(name: &str) -> Arc { - Arc::new(Expr::new(ExprKind::Ref(name.to_string()))) -} - -/// Helper to create a pattern match arm. 
-pub fn match_arm(pattern: Pattern, expr: Arc) -> MatchArm { - MatchArm { pattern, expr } -} - -/// Helper to create an array value. -pub fn array_val(items: Vec) -> Value { - Value(CoreData::Array(items)) -} - -/// Helper to create a struct value. -pub fn struct_val(name: &str, fields: Vec) -> Value { - Value(CoreData::Struct(name.to_string(), fields)) -} - -/// Helper to create a pattern matching expression. -pub fn pattern_match_expr(expr: Arc, arms: Vec) -> Arc { - Arc::new(Expr::new(ExprKind::PatternMatch(expr, arms))) -} - -/// Helper to create a bind pattern. -pub fn bind_pattern(name: &str, inner: Pattern) -> Pattern { - Pattern::Bind(name.to_string(), Box::new(inner)) -} - -/// Helper to create a wildcard pattern. -pub fn wildcard_pattern() -> Pattern { - Pattern::Wildcard -} - -/// Helper to create a literal pattern. -pub fn literal_pattern(lit: Literal) -> Pattern { - Pattern::Literal(lit) -} - -/// Helper to create a struct pattern. -pub fn struct_pattern(name: &str, fields: Vec) -> Pattern { - Pattern::Struct(name.to_string(), fields) -} - -/// Helper to create an array decomposition pattern. -pub fn array_decomp_pattern(head: Pattern, tail: Pattern) -> Pattern { - Pattern::ArrayDecomp(Box::new(head), Box::new(tail)) -} - -/// Helper to create an operator pattern. -pub fn operator_pattern(tag: &str, data: Vec, children: Vec) -> Pattern { - Pattern::Operator(Operator { - tag: tag.to_string(), - data, - children, - }) -} - -/// Helper to create a simple logical operator value. -pub fn create_logical_operator(tag: &str, data: Vec, children: Vec) -> Value { - let op = Operator { - tag: tag.to_string(), - data, - children, - }; - - Value(CoreData::Logical(Materialized(LogicalOp::logical(op)))) -} - -/// Helper to create a simple physical operator value. 
-pub fn create_physical_operator(tag: &str, data: Vec, children: Vec) -> Value { - let op = Operator { - tag: tag.to_string(), - data, - children, - }; - - Value(CoreData::Physical(Materialized(PhysicalOp::physical(op)))) -} - -/// Runs a test by evaluating the expression and collecting all results with a custom continuation. -pub async fn evaluate_and_collect_with_custom_k( - expr: Arc, - engine: Engine, - harness: TestHarness, - return_k: Continuation, -) -> Vec -where - T: Send + 'static, -{ - let mut results = Vec::new(); - let mut queue = VecDeque::new(); - let response = engine - .evaluate( - expr, - Arc::new(move |value| { - let return_k = return_k.clone(); - Box::pin(async move { EngineResponse::Return(value, return_k) }) - }), - ) - .await; - - queue.push_back(response); - - while let Some(response) = queue.pop_front() { - match response { - EngineResponse::Return(value, return_k) => { - results.push(return_k(value).await); - } - EngineResponse::YieldGroup(group_id, continue_k) => { - harness - .fork_at_group(group_id, continue_k, &mut queue) - .await; - } - EngineResponse::YieldGoal(goal, continue_k) => { - harness.fork_at_goal(&goal, continue_k, &mut queue).await; - } - } - } - - results -} - -/// Runs a test by evaluating the expression and collecting all results. 
-pub async fn evaluate_and_collect( - expr: Arc, - engine: Engine, - harness: TestHarness, -) -> Vec { - let return_k: Continuation = Arc::new(|value| Box::pin(async move { value })); - - evaluate_and_collect_with_custom_k(expr, engine, harness, return_k).await -} diff --git a/optd-dsl/src/utils/mod.rs b/optd-dsl/src/utils/mod.rs index 3ab35050..e93ebb83 100644 --- a/optd-dsl/src/utils/mod.rs +++ b/optd-dsl/src/utils/mod.rs @@ -1,2 +1,4 @@ pub mod error; pub mod span; +#[cfg(test)] +pub mod tests; From a39c31c09c6915704d2654cd87932af5ee095865 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 18:49:31 -0400 Subject: [PATCH 03/11] Add missing file --- optd-dsl/src/analyzer/map.rs | 502 +++++++++++++++++++++++++++++++++++ optd-dsl/src/utils/tests.rs | 271 +++++++++++++++++++ 2 files changed, 773 insertions(+) create mode 100644 optd-dsl/src/analyzer/map.rs create mode 100644 optd-dsl/src/utils/tests.rs diff --git a/optd-dsl/src/analyzer/map.rs b/optd-dsl/src/analyzer/map.rs new file mode 100644 index 00000000..e7d5eb82 --- /dev/null +++ b/optd-dsl/src/analyzer/map.rs @@ -0,0 +1,502 @@ +//! Custom Map implementation for the DSL. +//! +//! This module provides a Map type that enforces specific key constraints at runtime +//! rather than compile time. It bridges between the flexible HIR Value type and +//! a more restricted internal MapKey type that implements Hash and Eq. +//! +//! The Map implementation supports a subset of Value types as keys: +//! - Literals (except Float64, which doesn't implement Hash/Eq) +//! - Nested structures (Option, Logical, Physical, Fail, Struct, Tuple) +//! if all their contents are also valid key types +//! +//! The implementation provides: +//! - Conversion from Value to MapKey (with runtime validation) +//! - Efficient key lookup (O(1) via HashMap) +//! 
- Basic map operations (get, concat) + +use super::hir::{ + CoreData, Goal, GroupId, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, +}; +use CoreData::*; +use std::collections::HashMap; +use std::hash::Hash; + +/// Map key representation of an operator +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct OperatorMapKey { + pub tag: String, + pub data: Vec, + pub children: Vec, +} + +/// Map key representation of a logical operator +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct LogicalMapOpKey { + pub operator: OperatorMapKey, + pub group_id: Option, +} + +/// Map key representation of a physical operator +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct PhysicalMapOpKey { + pub operator: OperatorMapKey, + pub goal: Option, + pub cost: Option, +} + +/// Map key representation of a goal +#[derive(Clone, PartialEq, Eq, Hash)] +pub struct GoalMapKey { + pub group_id: GroupId, + pub properties: MapKey, +} + +/// Map key representation of logical operators (materialized or unmaterialized) +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum LogicalMapKey { + Materialized(LogicalMapOpKey), + UnMaterialized(GroupId), +} + +/// Map key representation of physical operators (materialized or unmaterialized) +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum PhysicalMapKey { + Materialized(PhysicalMapOpKey), + UnMaterialized(GoalMapKey), +} + +/// Internal key type that implements Hash, Eq, PartialEq +/// These are the only types that can be used as keys in the Map +#[derive(Clone, PartialEq, Eq, Hash)] +pub enum MapKey { + /// Primitive literal values (except Float64) + Int64(i64), + String(String), + Bool(bool), + Unit, + + /// Nested structures + Tuple(Vec), + Struct(String, Vec), + + /// Query structures + Logical(Box), + Physical(Box), + + /// Fail representation + Fail(Box), + + /// None value + None, +} + +/// Custom Map implementation that enforces key type constraints at runtime +#[derive(Clone)] +pub struct Map { + inner: HashMap, +} + +impl Map {
+ /// Creates a new empty Map + pub fn new() -> Self { + Self { + inner: HashMap::new(), + } + } + + /// Creates a Map from a collection of key-value pairs + pub fn from_pairs(pairs: Vec<(Value, Value)>) -> Self { + pairs.into_iter().fold(Self::new(), |mut map, (k, v)| { + map.insert(k, v); + map + }) + } + + /// Gets a value by key, returning None (as a Value) if not found + pub fn get(&self, key: &Value) -> Value { + self.inner + .get(&value_to_map_key(key)) + .unwrap_or(&Value(None)) + .clone() + } + + /// Combines two maps, with values from other overriding values from self when keys collide + pub fn concat(&self, other: &Map) -> Self { + let mut result = self.clone(); + result.inner.extend(other.inner.clone()); + result + } + + /// Checks if the map is empty + pub fn is_empty(&self) -> bool { + self.inner.is_empty() + } + + /// Inserts a key-value pair into the map + /// Panics if the key type is not supported (e.g., Float64 or Array) + pub fn insert(&mut self, key: Value, value: Value) { + let map_key = value_to_map_key(&key); + self.inner.insert(map_key, value); + } + + /// Converts the map into a Value representation + pub fn to_value(&self) -> Value { + let pairs = self + .inner + .iter() + .map(|(k, v)| (map_key_to_value(k), v.clone())) + .collect(); + Value(Map(pairs)) + } +} + +// Key conversion functions + +/// Converts a Value to a MapKey, enforcing valid key types +/// This performs runtime validation that the key type is supported +/// and will panic for invalid key types (e.g., Float64 or Array) +fn value_to_map_key(value: &Value) -> MapKey { + match &value.0 { + Literal(lit) => match lit { + Literal::Int64(i) => MapKey::Int64(*i), + Literal::String(s) => MapKey::String(s.clone()), + Literal::Bool(b) => MapKey::Bool(*b), + Literal::Unit => MapKey::Unit, + Literal::Float64(_) => panic!("Invalid map key: Float64"), + }, + Tuple(items) => { + let key_items = items.iter().map(value_to_map_key).collect(); + MapKey::Tuple(key_items) + } + Struct(name, fields) => {
let key_fields = fields.iter().map(value_to_map_key).collect(); + MapKey::Struct(name.clone(), key_fields) + } + Logical(materializable) => match materializable { + Materializable::UnMaterialized(group_id) => { + MapKey::Logical(Box::new(LogicalMapKey::UnMaterialized(*group_id))) + } + Materializable::Materialized(logical_op) => { + let map_op = value_to_logical_map_op(logical_op); + MapKey::Logical(Box::new(LogicalMapKey::Materialized(map_op))) + } + }, + Physical(materializable) => match materializable { + Materializable::UnMaterialized(goal) => { + let properties = value_to_map_key(&goal.properties); + let map_goal = GoalMapKey { + group_id: goal.group_id, + properties, + }; + MapKey::Physical(Box::new(PhysicalMapKey::UnMaterialized(map_goal))) + } + Materializable::Materialized(physical_op) => { + let map_op = value_to_physical_map_op(physical_op); + MapKey::Physical(Box::new(PhysicalMapKey::Materialized(map_op))) + } + }, + Fail(inner) => { + let inner_key = value_to_map_key(inner); + MapKey::Fail(Box::new(inner_key)) + } + None => MapKey::None, + _ => panic!("Invalid map key: {:?}", value), + } +} + +/// Converts an Operator to a map key operator +fn value_to_operator_map_key( + operator: &Operator, + value_converter: &dyn Fn(&T) -> MapKey, +) -> OperatorMapKey { + let data = operator.data.iter().map(value_converter).collect(); + + let children = operator.children.iter().map(value_converter).collect(); + + OperatorMapKey { + tag: operator.tag.clone(), + data, + children, + } +} + +/// Converts a LogicalOp to a map key +fn value_to_logical_map_op(logical_op: &LogicalOp) -> LogicalMapOpKey { + let operator = value_to_operator_map_key(&logical_op.operator, &value_to_map_key); + + LogicalMapOpKey { + operator, + group_id: logical_op.group_id, + } +} + +/// Converts a PhysicalOp to a map key +fn value_to_physical_map_op(physical_op: &PhysicalOp) -> PhysicalMapOpKey { + let operator = value_to_operator_map_key(&physical_op.operator, &value_to_map_key); + + let goal 
= physical_op.goal.as_ref().map(|g| GoalMapKey { + group_id: g.group_id, + properties: value_to_map_key(&g.properties), + }); + + let cost = physical_op.cost.as_ref().map(|c| value_to_map_key(c)); + + PhysicalMapOpKey { + operator, + goal, + cost, + } +} + +/// Converts a MapKey back to a Value +fn map_key_to_value(key: &MapKey) -> Value { + match key { + MapKey::Int64(i) => Value(Literal(Literal::Int64(*i))), + MapKey::String(s) => Value(Literal(Literal::String(s.clone()))), + MapKey::Bool(b) => Value(Literal(Literal::Bool(*b))), + MapKey::Unit => Value(Literal(Literal::Unit)), + MapKey::Tuple(items) => { + let values = items.iter().map(map_key_to_value).collect(); + Value(Tuple(values)) + } + MapKey::Struct(name, fields) => { + let values = fields.iter().map(map_key_to_value).collect(); + Value(Struct(name.clone(), values)) + } + MapKey::Logical(logical_key) => match &**logical_key { + LogicalMapKey::Materialized(op) => { + let operator_value = logical_map_op_to_value(op); + Value(Logical(Materializable::Materialized(operator_value))) + } + LogicalMapKey::UnMaterialized(id) => { + Value(Logical(Materializable::UnMaterialized(*id))) + } + }, + MapKey::Physical(physical_key) => match &**physical_key { + PhysicalMapKey::Materialized(op) => { + let operator_value = physical_map_op_to_value(op); + Value(Physical(Materializable::Materialized(operator_value))) + } + PhysicalMapKey::UnMaterialized(g) => { + let goal = Goal { + group_id: g.group_id, + properties: Box::new(map_key_to_value(&g.properties)), + }; + Value(Physical(Materializable::UnMaterialized(goal))) + } + }, + MapKey::Fail(inner) => { + let inner_value = map_key_to_value(inner); + Value(Fail(Box::new(inner_value))) + } + MapKey::None => Value(None), + } +} + +/// Converts an operator map key back to a Value operator +fn operator_map_key_to_value( + operator: &OperatorMapKey, + key_converter: &dyn Fn(&MapKey) -> T, +) -> Operator { + let data = operator.data.iter().map(key_converter).collect(); + let 
children = operator.children.iter().map(key_converter).collect(); + + Operator { + tag: operator.tag.clone(), + data, + children, + } +} + +/// Converts a logical map op key back to a Value logical op +fn logical_map_op_to_value(logical_op: &LogicalMapOpKey) -> LogicalOp { + let operator = operator_map_key_to_value(&logical_op.operator, &map_key_to_value); + + LogicalOp { + operator, + group_id: logical_op.group_id, + } +} + +/// Converts a physical map op key back to a Value physical op +fn physical_map_op_to_value(physical_op: &PhysicalMapOpKey) -> PhysicalOp { + let operator = operator_map_key_to_value(&physical_op.operator, &map_key_to_value); + + let goal = physical_op.goal.as_ref().map(|g| Goal { + group_id: g.group_id, + properties: Box::new(map_key_to_value(&g.properties)), + }); + + let cost = physical_op + .cost + .as_ref() + .map(|c| Box::new(map_key_to_value(c))); + + PhysicalOp { + operator, + goal, + cost, + } +} + +#[cfg(test)] +mod tests { + use super::*; + use crate::utils::tests::{ + array_val, assert_values_equal, boolean, int, lit_val, string, struct_val, + }; + + // Helper to create Value literals + fn int_val(i: i64) -> Value { + lit_val(int(i)) + } + + fn bool_val(b: bool) -> Value { + lit_val(boolean(b)) + } + + fn string_val(s: &str) -> Value { + lit_val(string(s)) + } + + fn float_val(f: f64) -> Value { + Value(Literal(Literal::Float64(f))) + } + + fn tuple_val(items: Vec) -> Value { + Value(Tuple(items)) + } + + #[test] + fn test_simple_map_operations() { + let mut map = Map::new(); + + // Insert key-value pairs + map.insert(int_val(1), string_val("one")); + map.insert(int_val(2), string_val("two")); + + // Check retrieval + assert_values_equal(&map.get(&int_val(1)), &string_val("one")); + assert_values_equal(&map.get(&int_val(2)), &string_val("two")); + assert_values_equal(&map.get(&int_val(3)), &Value(None)); // Non-existent key + } + + #[test] + fn test_map_from_pairs() { + let pairs = vec![ + (int_val(1), string_val("one")), + 
(int_val(2), string_val("two")), + ]; + + let map = Map::from_pairs(pairs); + + assert_values_equal(&map.get(&int_val(1)), &string_val("one")); + assert_values_equal(&map.get(&int_val(2)), &string_val("two")); + } + + #[test] + fn test_map_concat() { + let mut map1 = Map::new(); + map1.insert(int_val(1), string_val("one")); + map1.insert(int_val(2), string_val("two")); + + let mut map2 = Map::new(); + map2.insert(int_val(2), string_val("TWO")); // Overwrite key 2 + map2.insert(int_val(3), string_val("three")); + + let combined = map1.concat(&map2); + + assert_values_equal(&combined.get(&int_val(1)), &string_val("one")); + assert_values_equal(&combined.get(&int_val(2)), &string_val("TWO")); // Overwritten + assert_values_equal(&combined.get(&int_val(3)), &string_val("three")); + } + + #[test] + fn test_complex_keys() { + let mut map = Map::new(); + + // Tuple key + let tuple_key = tuple_val(vec![int_val(1), string_val("test")]); + map.insert(tuple_key.clone(), string_val("tuple value")); + + // Struct key + let struct_key = struct_val("Person", vec![string_val("John"), int_val(30)]); + map.insert(struct_key.clone(), string_val("struct value")); + + // Retrieve values + assert_values_equal(&map.get(&tuple_key), &string_val("tuple value")); + assert_values_equal(&map.get(&struct_key), &string_val("struct value")); + } + + #[test] + #[should_panic(expected = "Invalid map key")] + fn test_float_key_panics() { + let mut map = Map::new(); + map.insert(float_val(3.14), string_val("pi")); + } + + #[test] + #[should_panic(expected = "Invalid map key")] + fn test_tuple_with_float_panics() { + let mut map = Map::new(); + let tuple_with_float = tuple_val(vec![int_val(1), float_val(2.5)]); + map.insert(tuple_with_float, string_val("invalid")); + } + + #[test] + #[should_panic(expected = "Invalid map key")] + fn test_array_key_panics() { + let mut map = Map::new(); + let array_key = array_val(vec![int_val(1), int_val(2)]); + map.insert(array_key, string_val("invalid")); + } + + 
#[test] + fn test_get_with_invalid_key() { + let mut map = Map::new(); + map.insert(int_val(1), string_val("one")); + + // Getting with a float key should return None + assert_values_equal(&map.get(&float_val(1.0)), &Value(None)); + + // Getting with an array key should return None + let array_key = array_val(vec![int_val(1)]); + assert_values_equal(&map.get(&array_key), &Value(None)); + } + + #[test] + fn test_to_value() { + let mut map = Map::new(); + map.insert(int_val(1), string_val("one")); + map.insert(bool_val(true), int_val(42)); + + let value = map.to_value(); + + if let Map(pairs) = value.0 { + assert_eq!(pairs.len(), 2); + + // Check that the pairs contain our expected key-value pairs + let mut found_int_key = false; + let mut found_bool_key = false; + + for (k, v) in pairs { + match k.0 { + Literal(Literal::Int64(1)) => { + found_int_key = true; + assert_values_equal(&v, &string_val("one")); + } + Literal(Literal::Bool(true)) => { + found_bool_key = true; + assert_values_equal(&v, &int_val(42)); + } + _ => panic!("Unexpected key in map value"), + } + } + + assert!(found_int_key, "Integer key not found in map value"); + assert!(found_bool_key, "Boolean key not found in map value"); + } else { + panic!("to_value() did not return a Map CoreData"); + } + } +} diff --git a/optd-dsl/src/utils/tests.rs b/optd-dsl/src/utils/tests.rs new file mode 100644 index 00000000..7e020d90 --- /dev/null +++ b/optd-dsl/src/utils/tests.rs @@ -0,0 +1,271 @@ +use crate::analyzer::hir::{ + CoreData, Expr, ExprKind, Goal, GroupId, Literal, LogicalOp, MatchArm, Materializable, + Operator, Pattern, PhysicalOp, Value, +}; +use crate::engine::{Continuation, Engine, EngineResponse}; +use Materializable::*; +use std::collections::{HashMap, VecDeque}; +use std::sync::{Arc, Mutex}; + +/// A test harness for the evaluation engine. +#[derive(Clone)] +pub struct TestHarness { + /// Maps group IDs to their materialized values. 
+ group_mappings: Arc>>>, + + /// Maps goals to their implementations. + goal_mappings: Arc>>>, +} + +impl TestHarness { + /// Creates a new test harness. + pub fn new() -> Self { + Self { + group_mappings: Arc::new(Mutex::new(HashMap::new())), + goal_mappings: Arc::new(Mutex::new(HashMap::new())), + } + } + + /// Registers a logical operator value to be returned when a specific group is requested. + pub fn register_group(&self, group_id: GroupId, value: Value) { + let key = format!("{:?}", group_id); + let mut mappings = self.group_mappings.lock().unwrap(); + mappings.entry(key).or_default().push(value); + } + + /// Registers a physical operator value to be returned when a specific goal is requested. + pub fn register_goal(&self, goal: &Goal, value: Value) { + let key = format!("{:?}:{:?}", goal.group_id, goal.properties); + let mut mappings = self.goal_mappings.lock().unwrap(); + mappings.entry(key).or_default().push(value); + } + + /// Forks the evaluation at a specific group ID and collect the responses. + async fn fork_at_group( + &self, + group_id: GroupId, + k: Continuation>, + queue: &mut VecDeque>, + ) where + T: Send + 'static, + { + let key = format!("{:?}", group_id); + let values = { + let mappings = self.group_mappings.lock().unwrap(); + mappings.get(&key).cloned().unwrap_or_default() + }; + + for value in values { + queue.push_back(k(value).await); + } + } + + /// Forks the evaluation at a specific goal and collect the responses. 
+ async fn fork_at_goal( + &self, + goal: &Goal, + k: Continuation>, + queue: &mut VecDeque>, + ) where + T: Send + 'static, + { + let key = format!("{:?}:{:?}", goal.group_id, goal.properties); + let values = { + let mappings = self.goal_mappings.lock().unwrap(); + mappings.get(&key).cloned().unwrap_or_default() + }; + + for value in values { + queue.push_back(k(value).await); + } + } +} + +// Helper to compare Values +pub fn assert_values_equal(v1: &Value, v2: &Value) { + match (&v1.0, &v2.0) { + (CoreData::Literal(l1), CoreData::Literal(l2)) => match (l1, l2) { + (Literal::Int64(i1), Literal::Int64(i2)) => assert_eq!(i1, i2), + (Literal::Float64(f1), Literal::Float64(f2)) => assert_eq!(f1, f2), + (Literal::String(s1), Literal::String(s2)) => assert_eq!(s1, s2), + (Literal::Bool(b1), Literal::Bool(b2)) => assert_eq!(b1, b2), + (Literal::Unit, Literal::Unit) => {} + _ => panic!("Literals don't match: {:?} vs {:?}", l1, l2), + }, + (CoreData::None, CoreData::None) => {} + (CoreData::Tuple(t1), CoreData::Tuple(t2)) => { + assert_eq!(t1.len(), t2.len()); + for (v1, v2) in t1.iter().zip(t2.iter()) { + assert_values_equal(v1, v2); + } + } + (CoreData::Struct(n1, f1), CoreData::Struct(n2, f2)) => { + assert_eq!(n1, n2); + assert_eq!(f1.len(), f2.len()); + for (v1, v2) in f1.iter().zip(f2.iter()) { + assert_values_equal(v1, v2); + } + } + _ => panic!("Values don't match: {:?} vs {:?}", v1.0, v2.0), + } +} + +/// Helper to create a literal expression. +pub fn lit_expr(literal: Literal) -> Arc { + Arc::new(Expr::new(ExprKind::CoreExpr(CoreData::Literal(literal)))) +} + +/// Helper to create a literal value. +pub fn lit_val(literal: Literal) -> Value { + Value(CoreData::Literal(literal)) +} + +/// Helper to create an integer literal. +pub fn int(i: i64) -> Literal { + Literal::Int64(i) +} + +/// Helper to create a string literal. +pub fn string(s: &str) -> Literal { + Literal::String(s.to_string()) +} + +/// Helper to create a boolean literal. 
+pub fn boolean(b: bool) -> Literal { + Literal::Bool(b) +} + +/// Helper to create a reference expression. +pub fn ref_expr(name: &str) -> Arc { + Arc::new(Expr::new(ExprKind::Ref(name.to_string()))) +} + +/// Helper to create a pattern match arm. +pub fn match_arm(pattern: Pattern, expr: Arc) -> MatchArm { + MatchArm { pattern, expr } +} + +/// Helper to create an array value. +pub fn array_val(items: Vec) -> Value { + Value(CoreData::Array(items)) +} + +/// Helper to create a struct value. +pub fn struct_val(name: &str, fields: Vec) -> Value { + Value(CoreData::Struct(name.to_string(), fields)) +} + +/// Helper to create a pattern matching expression. +pub fn pattern_match_expr(expr: Arc, arms: Vec) -> Arc { + Arc::new(Expr::new(ExprKind::PatternMatch(expr, arms))) +} + +/// Helper to create a bind pattern. +pub fn bind_pattern(name: &str, inner: Pattern) -> Pattern { + Pattern::Bind(name.to_string(), Box::new(inner)) +} + +/// Helper to create a wildcard pattern. +pub fn wildcard_pattern() -> Pattern { + Pattern::Wildcard +} + +/// Helper to create a literal pattern. +pub fn literal_pattern(lit: Literal) -> Pattern { + Pattern::Literal(lit) +} + +/// Helper to create a struct pattern. +pub fn struct_pattern(name: &str, fields: Vec) -> Pattern { + Pattern::Struct(name.to_string(), fields) +} + +/// Helper to create an array decomposition pattern. +pub fn array_decomp_pattern(head: Pattern, tail: Pattern) -> Pattern { + Pattern::ArrayDecomp(Box::new(head), Box::new(tail)) +} + +/// Helper to create an operator pattern. +pub fn operator_pattern(tag: &str, data: Vec, children: Vec) -> Pattern { + Pattern::Operator(Operator { + tag: tag.to_string(), + data, + children, + }) +} + +/// Helper to create a simple logical operator value. 
+pub fn create_logical_operator(tag: &str, data: Vec, children: Vec) -> Value { + let op = Operator { + tag: tag.to_string(), + data, + children, + }; + + Value(CoreData::Logical(Materialized(LogicalOp::logical(op)))) +} + +/// Helper to create a simple physical operator value. +pub fn create_physical_operator(tag: &str, data: Vec, children: Vec) -> Value { + let op = Operator { + tag: tag.to_string(), + data, + children, + }; + + Value(CoreData::Physical(Materialized(PhysicalOp::physical(op)))) +} + +/// Runs a test by evaluating the expression and collecting all results with a custom continuation. +pub async fn evaluate_and_collect_with_custom_k( + expr: Arc, + engine: Engine, + harness: TestHarness, + return_k: Continuation, +) -> Vec +where + T: Send + 'static, +{ + let mut results = Vec::new(); + let mut queue = VecDeque::new(); + let response = engine + .evaluate( + expr, + Arc::new(move |value| { + let return_k = return_k.clone(); + Box::pin(async move { EngineResponse::Return(value, return_k) }) + }), + ) + .await; + + queue.push_back(response); + + while let Some(response) = queue.pop_front() { + match response { + EngineResponse::Return(value, return_k) => { + results.push(return_k(value).await); + } + EngineResponse::YieldGroup(group_id, continue_k) => { + harness + .fork_at_group(group_id, continue_k, &mut queue) + .await; + } + EngineResponse::YieldGoal(goal, continue_k) => { + harness.fork_at_goal(&goal, continue_k, &mut queue).await; + } + } + } + + results +} + +/// Runs a test by evaluating the expression and collecting all results. 
+pub async fn evaluate_and_collect( + expr: Arc, + engine: Engine, + harness: TestHarness, +) -> Vec { + let return_k: Continuation = Arc::new(|value| Box::pin(async move { value })); + + evaluate_and_collect_with_custom_k(expr, engine, harness, return_k).await +} From 6186e75442e8a8ffa047f99548215b320fd3d54b Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 19:00:40 -0400 Subject: [PATCH 04/11] Finish implementation of Map --- optd-dsl/src/analyzer/map.rs | 404 ++++++++++++++++++----------------- 1 file changed, 209 insertions(+), 195 deletions(-) diff --git a/optd-dsl/src/analyzer/map.rs b/optd-dsl/src/analyzer/map.rs index e7d5eb82..f6d4e069 100644 --- a/optd-dsl/src/analyzer/map.rs +++ b/optd-dsl/src/analyzer/map.rs @@ -15,14 +15,14 @@ //! - Basic map operations (get, concat) use super::hir::{ - CoreData, Goal, GroupId, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, + CoreData, GroupId, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, }; use CoreData::*; use std::collections::HashMap; use std::hash::Hash; /// Map key representation of a logical operator -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct OperatorMapKey { pub tag: String, pub data: Vec, @@ -30,14 +30,14 @@ pub struct OperatorMapKey { } /// Map key representation of a logical operator -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct LogicalMapOpKey { pub operator: OperatorMapKey, pub group_id: Option, } /// Map key representation of a physical operator -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct PhysicalMapOpKey { pub operator: OperatorMapKey, pub goal: Option, @@ -45,21 +45,21 @@ pub struct PhysicalMapOpKey { } /// Map key representation of a goal -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub struct GoalMapKey { pub group_id: GroupId, pub properties: MapKey, } 
/// Map key representation of logical operators (materialized or unmaterialized) -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum LogicalMapKey { Materialized(LogicalMapOpKey), UnMaterialized(GroupId), } /// Map key representation of physical operators (materialized or unmaterialized) -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum PhysicalMapKey { Materialized(PhysicalMapOpKey), UnMaterialized(GoalMapKey), @@ -67,7 +67,7 @@ pub enum PhysicalMapKey { /// Internal key type that implements Hash, Eq, PartialEq /// These are the only types that can be used as keys in the Map -#[derive(Clone, PartialEq, Eq, Hash)] +#[derive(Clone, PartialEq, Eq, Hash, Debug)] pub enum MapKey { /// Primitive literal values (except Float64) Int64(i64), @@ -107,7 +107,8 @@ impl Map { /// Creates a Map from a collection of key-value pairs pub fn from_pairs(pairs: Vec<(Value, Value)>) -> Self { pairs.into_iter().fold(Self::new(), |mut map, (k, v)| { - map.insert(k, v); + let map_key = value_to_map_key(&k); + map.inner.insert(map_key, v); map }) } @@ -126,28 +127,6 @@ impl Map { result.inner.extend(other.inner.clone()); result } - - /// Checks if the map is empty - pub fn is_empty(&self) -> bool { - self.inner.is_empty() - } - - /// Inserts a key-value pair into the map - /// Panics if the key type is not supported (e.g., Float64 or Array) - pub fn insert(&mut self, key: Value, value: Value) { - let map_key = value_to_map_key(&key); - self.inner.insert(map_key, value); - } - - /// Converts the map into a Value representation - pub fn to_value(&self) -> Value { - let pairs = self - .inner - .iter() - .map(|(k, v)| (map_key_to_value(k), v.clone())) - .collect(); - Value(Map(pairs)) - } } // Key conversion functions @@ -248,137 +227,72 @@ fn value_to_physical_map_op(physical_op: &PhysicalOp) -> PhysicalMapOpKey } } -/// Converts a MapKey back to a Value -fn map_key_to_value(key: &MapKey) -> Value { - match 
key { - MapKey::Int64(i) => Value(Literal(Literal::Int64(*i))), - MapKey::String(s) => Value(Literal(Literal::String(s.clone()))), - MapKey::Bool(b) => Value(Literal(Literal::Bool(*b))), - MapKey::Unit => Value(Literal(Literal::Unit)), - MapKey::Tuple(items) => { - let values = items.iter().map(map_key_to_value).collect(); - Value(Tuple(values)) - } - MapKey::Struct(name, fields) => { - let values = fields.iter().map(map_key_to_value).collect(); - Value(Struct(name.clone(), values)) - } - MapKey::Logical(logical_key) => match &**logical_key { - LogicalMapKey::Materialized(op) => { - let operator_value = logical_map_op_to_value(op); - Value(Logical(Materializable::Materialized(operator_value))) - } - LogicalMapKey::UnMaterialized(id) => { - Value(Logical(Materializable::UnMaterialized(*id))) - } - }, - MapKey::Physical(physical_key) => match &**physical_key { - PhysicalMapKey::Materialized(op) => { - let operator_value = physical_map_op_to_value(op); - Value(Physical(Materializable::Materialized(operator_value))) - } - PhysicalMapKey::UnMaterialized(g) => { - let goal = Goal { - group_id: g.group_id, - properties: Box::new(map_key_to_value(&g.properties)), - }; - Value(Physical(Materializable::UnMaterialized(goal))) - } - }, - MapKey::Fail(inner) => { - let inner_value = map_key_to_value(inner); - Value(Fail(Box::new(inner_value))) - } - MapKey::None => Value(None), - } -} - -/// Converts an operator map key back to a Value operator -fn operator_map_key_to_value( - operator: &OperatorMapKey, - key_converter: &dyn Fn(&MapKey) -> T, -) -> Operator { - let data = operator.data.iter().map(key_converter).collect(); - let children = operator.children.iter().map(key_converter).collect(); - - Operator { - tag: operator.tag.clone(), - data, - children, - } -} - -/// Converts a logical map op key back to a Value logical op -fn logical_map_op_to_value(logical_op: &LogicalMapOpKey) -> LogicalOp { - let operator = operator_map_key_to_value(&logical_op.operator, 
&map_key_to_value); - - LogicalOp { - operator, - group_id: logical_op.group_id, - } -} - -/// Converts a physical map op key back to a Value physical op -fn physical_map_op_to_value(physical_op: &PhysicalMapOpKey) -> PhysicalOp { - let operator = operator_map_key_to_value(&physical_op.operator, &map_key_to_value); - - let goal = physical_op.goal.as_ref().map(|g| Goal { - group_id: g.group_id, - properties: Box::new(map_key_to_value(&g.properties)), - }); - - let cost = physical_op - .cost - .as_ref() - .map(|c| Box::new(map_key_to_value(c))); - - PhysicalOp { - operator, - goal, - cost, - } -} - #[cfg(test)] mod tests { - use super::*; use crate::utils::tests::{ - array_val, assert_values_equal, boolean, int, lit_val, string, struct_val, + assert_values_equal, create_logical_operator, create_physical_operator, }; + use super::*; + // Helper to create Value literals fn int_val(i: i64) -> Value { - lit_val(int(i)) + Value(Literal(Literal::Int64(i))) } fn bool_val(b: bool) -> Value { - lit_val(boolean(b)) + Value(Literal(Literal::Bool(b))) } fn string_val(s: &str) -> Value { - lit_val(string(s)) + Value(Literal(Literal::String(s.to_string()))) } fn float_val(f: f64) -> Value { Value(Literal(Literal::Float64(f))) } + fn unit_val() -> Value { + Value(Literal(Literal::Unit)) + } + fn tuple_val(items: Vec) -> Value { Value(Tuple(items)) } + fn struct_val(name: &str, fields: Vec) -> Value { + Value(Struct(name.to_string(), fields)) + } + + fn array_val(items: Vec) -> Value { + Value(Array(items)) + } + + fn none_val() -> Value { + Value(None) + } + + fn fail_val(inner: Value) -> Value { + Value(Fail(Box::new(inner))) + } + #[test] fn test_simple_map_operations() { let mut map = Map::new(); + let map_key1 = value_to_map_key(&int_val(1)); + let map_key2 = value_to_map_key(&int_val(2)); - // Insert key-value pairs - map.insert(int_val(1), string_val("one")); - map.insert(int_val(2), string_val("two")); + // Insert key-value pairs directly into inner HashMap + 
map.inner.insert(map_key1, string_val("one")); + map.inner.insert(map_key2, string_val("two")); // Check retrieval assert_values_equal(&map.get(&int_val(1)), &string_val("one")); assert_values_equal(&map.get(&int_val(2)), &string_val("two")); - assert_values_equal(&map.get(&int_val(3)), &Value(None)); // Non-existent key + assert_values_equal(&map.get(&int_val(3)), &none_val()); // Non-existent key + + // Check map size + assert_eq!(map.inner.len(), 2); } #[test] @@ -392,111 +306,211 @@ mod tests { assert_values_equal(&map.get(&int_val(1)), &string_val("one")); assert_values_equal(&map.get(&int_val(2)), &string_val("two")); + assert_eq!(map.inner.len(), 2); } #[test] fn test_map_concat() { let mut map1 = Map::new(); - map1.insert(int_val(1), string_val("one")); - map1.insert(int_val(2), string_val("two")); + map1.inner + .insert(value_to_map_key(&int_val(1)), string_val("one")); + map1.inner + .insert(value_to_map_key(&int_val(2)), string_val("two")); let mut map2 = Map::new(); - map2.insert(int_val(2), string_val("TWO")); // Overwrite key 2 - map2.insert(int_val(3), string_val("three")); + map2.inner + .insert(value_to_map_key(&int_val(2)), string_val("TWO")); // Overwrite key 2 + map2.inner + .insert(value_to_map_key(&int_val(3)), string_val("three")); let combined = map1.concat(&map2); assert_values_equal(&combined.get(&int_val(1)), &string_val("one")); assert_values_equal(&combined.get(&int_val(2)), &string_val("TWO")); // Overwritten assert_values_equal(&combined.get(&int_val(3)), &string_val("three")); + assert_eq!(combined.inner.len(), 3); } #[test] - fn test_complex_keys() { + fn test_various_key_types() { let mut map = Map::new(); - // Tuple key - let tuple_key = tuple_val(vec![int_val(1), string_val("test")]); - map.insert(tuple_key.clone(), string_val("tuple value")); - - // Struct key - let struct_key = struct_val("Person", vec![string_val("John"), int_val(30)]); - map.insert(struct_key.clone(), string_val("struct value")); - - // Retrieve values - 
assert_values_equal(&map.get(&tuple_key), &string_val("tuple value")); - assert_values_equal(&map.get(&struct_key), &string_val("struct value")); + // Basic literals + map.inner + .insert(value_to_map_key(&int_val(1)), string_val("int")); + map.inner + .insert(value_to_map_key(&bool_val(true)), string_val("bool")); + map.inner + .insert(value_to_map_key(&string_val("key")), string_val("string")); + map.inner + .insert(value_to_map_key(&unit_val()), string_val("unit")); + + // Compound types + map.inner.insert( + value_to_map_key(&tuple_val(vec![int_val(1), string_val("test")])), + string_val("tuple"), + ); + map.inner.insert( + value_to_map_key(&struct_val("Person", vec![string_val("John"), int_val(30)])), + string_val("struct"), + ); + map.inner + .insert(value_to_map_key(&none_val()), string_val("none")); + map.inner.insert( + value_to_map_key(&fail_val(string_val("error"))), + string_val("fail"), + ); + + // Operator types + let logical_op = create_logical_operator("filter", vec![int_val(1)], vec![]); + let physical_op = create_physical_operator("scan", vec![string_val("table")], vec![]); + map.inner + .insert(value_to_map_key(&logical_op), string_val("logical")); + map.inner + .insert(value_to_map_key(&physical_op), string_val("physical")); + + // Verify all keys work + assert_values_equal(&map.get(&int_val(1)), &string_val("int")); + assert_values_equal(&map.get(&bool_val(true)), &string_val("bool")); + assert_values_equal(&map.get(&string_val("key")), &string_val("string")); + assert_values_equal(&map.get(&unit_val()), &string_val("unit")); + assert_values_equal( + &map.get(&tuple_val(vec![int_val(1), string_val("test")])), + &string_val("tuple"), + ); + assert_values_equal( + &map.get(&struct_val("Person", vec![string_val("John"), int_val(30)])), + &string_val("struct"), + ); + assert_values_equal(&map.get(&none_val()), &string_val("none")); + assert_values_equal( + &map.get(&fail_val(string_val("error"))), + &string_val("fail"), + ); + 
assert_values_equal(&map.get(&logical_op), &string_val("logical")); + assert_values_equal(&map.get(&physical_op), &string_val("physical")); + + assert_eq!(map.inner.len(), 10); } #[test] - #[should_panic(expected = "Invalid map key")] + #[should_panic(expected = "Invalid map key: Float64")] fn test_float_key_panics() { - let mut map = Map::new(); - map.insert(float_val(3.14), string_val("pi")); + value_to_map_key(&float_val(3.14)); } #[test] - #[should_panic(expected = "Invalid map key")] + #[should_panic(expected = "Invalid map key: Float64")] fn test_tuple_with_float_panics() { - let mut map = Map::new(); let tuple_with_float = tuple_val(vec![int_val(1), float_val(2.5)]); - map.insert(tuple_with_float, string_val("invalid")); + value_to_map_key(&tuple_with_float); } #[test] - #[should_panic(expected = "Invalid map key")] + #[should_panic] fn test_array_key_panics() { - let mut map = Map::new(); let array_key = array_val(vec![int_val(1), int_val(2)]); - map.insert(array_key, string_val("invalid")); + value_to_map_key(&array_key); } #[test] - fn test_get_with_invalid_key() { - let mut map = Map::new(); - map.insert(int_val(1), string_val("one")); + fn test_empty_map() { + let map = Map::new(); + assert_eq!(map.inner.len(), 0); + assert!(map.inner.is_empty()); + } + + #[test] + fn test_map_key_conversion() { + // Test conversion between Value and MapKey + let values = vec![ + int_val(42), + string_val("hello"), + bool_val(true), + unit_val(), + tuple_val(vec![int_val(1), bool_val(false)]), + struct_val("Test", vec![string_val("field"), int_val(123)]), + none_val(), + fail_val(int_val(404)), + ]; + + for value in values { + // Convert Value to MapKey + let map_key = value_to_map_key(&value); - // Getting with a float key should return None - assert_values_equal(&map.get(&float_val(1.0)), &Value(None)); + // Use the map_key in a map + let mut map = Map::new(); + map.inner.insert(map_key, string_val("value")); - // Getting with an array key should return None - let 
array_key = array_val(vec![int_val(1)]); - assert_values_equal(&map.get(&array_key), &Value(None)); + // Verify we can retrieve with the original value + assert_values_equal(&map.get(&value), &string_val("value")); + } } #[test] - fn test_to_value() { + fn test_deep_nesting() { + // Create deeply nested structure to test recursive key conversion + let nested_value = struct_val( + "Root", + vec![ + tuple_val(vec![ + int_val(1), + struct_val( + "Inner", + vec![ + string_val("deep"), + tuple_val(vec![bool_val(true), unit_val()]), + ], + ), + ]), + fail_val(none_val()), + ], + ); + + // Should not panic + let key = value_to_map_key(&nested_value); + + // Use it as a key let mut map = Map::new(); - map.insert(int_val(1), string_val("one")); - map.insert(bool_val(true), int_val(42)); - - let value = map.to_value(); - - if let Map(pairs) = value.0 { - assert_eq!(pairs.len(), 2); - - // Check that the pairs contain our expected key-value pairs - let mut found_int_key = false; - let mut found_bool_key = false; - - for (k, v) in pairs { - match k.0 { - Literal(Literal::Int64(1)) => { - found_int_key = true; - assert_values_equal(&v, &string_val("one")); - } - Literal(Literal::Bool(true)) => { - found_bool_key = true; - assert_values_equal(&v, &int_val(42)); - } - _ => panic!("Unexpected key in map value"), - } - } + map.inner.insert(key, string_val("complex")); - assert!(found_int_key, "Integer key not found in map value"); - assert!(found_bool_key, "Boolean key not found in map value"); - } else { - panic!("to_value() did not return a Map CoreData"); - } + // Verify we can retrieve it + assert_values_equal(&map.get(&nested_value), &string_val("complex")); + } + + #[test] + fn test_operator_key_conversion() { + // Test LogicalOp conversion + let logical_op = create_logical_operator( + "join", + vec![string_val("id"), int_val(10)], + vec![int_val(1), int_val(2)], + ); + + let logical_key = value_to_map_key(&logical_op); + assert!(matches!(logical_key, MapKey::Logical(_))); + 
+ // Test PhysicalOp conversion + let physical_op = create_physical_operator( + "hash_join", + vec![string_val("id")], + vec![int_val(1), int_val(2)], + ); + + let physical_key = value_to_map_key(&physical_op); + assert!(matches!(physical_key, MapKey::Physical(_))); + } + + #[test] + fn test_key_equality() { + // Test that equal values create equal map keys + let key1 = value_to_map_key(&tuple_val(vec![int_val(1), string_val("test")])); + let key2 = value_to_map_key(&tuple_val(vec![int_val(1), string_val("test")])); + + assert_eq!(key1, key2); + + // Test that different values create different map keys + let key3 = value_to_map_key(&tuple_val(vec![int_val(2), string_val("test")])); + assert_ne!(key1, key3); } } From 33a790f766dc140cf82b1ddef3eff7c00e7b647b Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 20:01:44 -0400 Subject: [PATCH 05/11] Integrate hashable maps --- optd-dsl/src/analyzer/hir.rs | 8 +- optd-dsl/src/analyzer/map.rs | 41 ++-- .../src/analyzer/semantic_checker/error.rs | 5 +- optd-dsl/src/analyzer/type.rs | 6 - optd-dsl/src/analyzer/type_checker/error.rs | 2 +- optd-dsl/src/engine/eval/binary.rs | 73 +++--- optd-dsl/src/engine/eval/core.rs | 113 +-------- optd-dsl/src/engine/eval/expr.rs | 229 +++++++++++++++++- optd-dsl/src/engine/mod.rs | 3 +- optd-dsl/src/lexer/error.rs | 6 +- optd-dsl/src/parser/error.rs | 6 +- optd-dsl/src/parser/function.rs | 24 +- optd-dsl/src/parser/module.rs | 9 +- optd-dsl/src/utils/error.rs | 8 +- 14 files changed, 334 insertions(+), 199 deletions(-) diff --git a/optd-dsl/src/analyzer/hir.rs b/optd-dsl/src/analyzer/hir.rs index 76b8020d..97967766 100644 --- a/optd-dsl/src/analyzer/hir.rs +++ b/optd-dsl/src/analyzer/hir.rs @@ -16,6 +16,7 @@ //! intermediate representations through the bridge modules. 
use super::context::Context; +use super::map::Map; use super::r#type::Type; use crate::utils::span::Span; use std::fmt::Debug; @@ -177,7 +178,7 @@ pub enum CoreData { /// Fixed collection of possibly heterogeneous values Tuple(Vec), /// Key-value associations - Map(Vec<(T, T)>), + Map(Map), /// Named structure with fields Struct(Identifier, Vec), /// Function or closure @@ -235,6 +236,9 @@ impl Expr { } } +/// Type alias for map entries to reduce type complexity +pub type MapEntries = Vec<(Arc>, Arc>)>; + /// Expression node kinds without metadata #[derive(Debug, Clone)] pub enum ExprKind { @@ -250,6 +254,8 @@ pub enum ExprKind { Unary(UnaryOp, Arc>), /// Function call Call(Arc>, Vec>>), + /// Map expression + Map(MapEntries), /// Variable reference Ref(Identifier), /// Core expression diff --git a/optd-dsl/src/analyzer/map.rs b/optd-dsl/src/analyzer/map.rs index f6d4e069..3094c97f 100644 --- a/optd-dsl/src/analyzer/map.rs +++ b/optd-dsl/src/analyzer/map.rs @@ -91,22 +91,15 @@ pub enum MapKey { } /// Custom Map implementation that enforces key type constraints at runtime -#[derive(Clone)] +#[derive(Clone, Debug, Default)] pub struct Map { inner: HashMap, } impl Map { - /// Creates a new empty Map - pub fn new() -> Self { - Self { - inner: HashMap::new(), - } - } - /// Creates a Map from a collection of key-value pairs pub fn from_pairs(pairs: Vec<(Value, Value)>) -> Self { - pairs.into_iter().fold(Self::new(), |mut map, (k, v)| { + pairs.into_iter().fold(Self::default(), |mut map, (k, v)| { let map_key = value_to_map_key(&k); map.inner.insert(map_key, v); map @@ -122,10 +115,8 @@ impl Map { } /// Combines two maps, with values from other overriding values from self when keys collide - pub fn concat(&self, other: &Map) -> Self { - let mut result = self.clone(); - result.inner.extend(other.inner.clone()); - result + pub fn concat(&mut self, other: Map) { + self.inner.extend(other.inner); } } @@ -278,7 +269,7 @@ mod tests { #[test] fn test_simple_map_operations() { - 
let mut map = Map::new(); + let mut map = Map::default(); let map_key1 = value_to_map_key(&int_val(1)); let map_key2 = value_to_map_key(&int_val(2)); @@ -311,29 +302,29 @@ mod tests { #[test] fn test_map_concat() { - let mut map1 = Map::new(); + let mut map1 = Map::default(); map1.inner .insert(value_to_map_key(&int_val(1)), string_val("one")); map1.inner .insert(value_to_map_key(&int_val(2)), string_val("two")); - let mut map2 = Map::new(); + let mut map2 = Map::default(); map2.inner .insert(value_to_map_key(&int_val(2)), string_val("TWO")); // Overwrite key 2 map2.inner .insert(value_to_map_key(&int_val(3)), string_val("three")); - let combined = map1.concat(&map2); + map1.concat(map2); - assert_values_equal(&combined.get(&int_val(1)), &string_val("one")); - assert_values_equal(&combined.get(&int_val(2)), &string_val("TWO")); // Overwritten - assert_values_equal(&combined.get(&int_val(3)), &string_val("three")); - assert_eq!(combined.inner.len(), 3); + assert_values_equal(&map1.get(&int_val(1)), &string_val("one")); + assert_values_equal(&map1.get(&int_val(2)), &string_val("TWO")); // Overwritten + assert_values_equal(&map1.get(&int_val(3)), &string_val("three")); + assert_eq!(map1.inner.len(), 3); } #[test] fn test_various_key_types() { - let mut map = Map::new(); + let mut map = Map::default(); // Basic literals map.inner @@ -415,7 +406,7 @@ mod tests { #[test] fn test_empty_map() { - let map = Map::new(); + let map = Map::default(); assert_eq!(map.inner.len(), 0); assert!(map.inner.is_empty()); } @@ -439,7 +430,7 @@ mod tests { let map_key = value_to_map_key(&value); // Use the map_key in a map - let mut map = Map::new(); + let mut map = Map::default(); map.inner.insert(map_key, string_val("value")); // Verify we can retrieve with the original value @@ -471,7 +462,7 @@ mod tests { let key = value_to_map_key(&nested_value); // Use it as a key - let mut map = Map::new(); + let mut map = Map::default(); map.inner.insert(key, string_val("complex")); // Verify we 
can retrieve it diff --git a/optd-dsl/src/analyzer/semantic_checker/error.rs b/optd-dsl/src/analyzer/semantic_checker/error.rs index 143edc31..3857972e 100644 --- a/optd-dsl/src/analyzer/semantic_checker/error.rs +++ b/optd-dsl/src/analyzer/semantic_checker/error.rs @@ -18,16 +18,17 @@ pub enum SemanticError { impl SemanticError { /// Creates a new error for duplicate ADT names - pub fn new_duplicate_adt(name: String, first_span: Span, duplicate_span: Span) -> Self { + pub fn new_duplicate_adt(name: String, first_span: Span, duplicate_span: Span) -> Box { Self::DuplicateAdt { name, first_span, duplicate_span, } + .into() } } -impl Diagnose for SemanticError { +impl Diagnose for Box { fn report(&self) -> Report { todo!() } diff --git a/optd-dsl/src/analyzer/type.rs b/optd-dsl/src/analyzer/type.rs index 88d490af..77b907de 100644 --- a/optd-dsl/src/analyzer/type.rs +++ b/optd-dsl/src/analyzer/type.rs @@ -330,12 +330,6 @@ mod type_registry_tests { let result = registry.register_adt(&car2); assert!(result.is_err()); - - if let Err(CompileError::SemanticError(SemanticError::DuplicateAdt { name, .. })) = result { - assert_eq!(name, "Car"); - } else { - panic!("Expected DuplicateAdt error"); - } } #[test] diff --git a/optd-dsl/src/analyzer/type_checker/error.rs b/optd-dsl/src/analyzer/type_checker/error.rs index 5a9774fe..16f2edd8 100644 --- a/optd-dsl/src/analyzer/type_checker/error.rs +++ b/optd-dsl/src/analyzer/type_checker/error.rs @@ -5,7 +5,7 @@ use crate::utils::{error::Diagnose, span::Span}; #[derive(Debug)] pub struct TypeError {} -impl Diagnose for TypeError { +impl Diagnose for Box { fn report(&self) -> Report { todo!() } diff --git a/optd-dsl/src/engine/eval/binary.rs b/optd-dsl/src/engine/eval/binary.rs index 106b66c3..2949779a 100644 --- a/optd-dsl/src/engine/eval/binary.rs +++ b/optd-dsl/src/engine/eval/binary.rs @@ -79,10 +79,9 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { } // Map concatenation (joins two maps). 
- (Map(l), Concat, Map(r)) => { - let mut result = l.clone(); - result.extend(r.into_iter()); - Value(Map(result)) + (Map(mut l), Concat, Map(r)) => { + l.concat(r); + Value(Map(l)) } // Any other combination of value types or operations is not supported. @@ -96,7 +95,6 @@ mod tests { use BinOp::*; use CoreData::*; use Literal::*; - use std::collections::HashMap; use super::eval_binary_op; @@ -329,32 +327,51 @@ mod tests { #[test] fn test_map_concatenation() { - // Create two maps - let map1 = Value(Map(vec![(string("a"), int(1)), (string("b"), int(2))])); - - let map2 = Value(Map(vec![(string("c"), int(3)), (string("d"), int(4))])); + use crate::analyzer::map::Map; + + // Create two maps using Map::from_pairs + let map1 = Value(Map(Map::from_pairs(vec![ + (string("a"), int(1)), + (string("b"), int(2)), + ]))); + let map2 = Value(Map(Map::from_pairs(vec![ + (string("c"), int(3)), + (string("d"), int(4)), + ]))); // Concatenate maps if let Map(result) = eval_binary_op(map1, &Concat, map2).0 { - assert_eq!(result.len(), 4); - - // Convert to a HashMap for easier testing - let map: HashMap = result - .iter() - .map(|(k, v)| { - if let (Literal(String(key)), Literal(Int64(value))) = (&k.0, &v.0) { - (key.clone(), *value) - } else { - panic!("Expected String keys and Int64 values"); - } - }) - .collect(); - - // Check elements - assert_eq!(map.get("a"), Some(&1)); - assert_eq!(map.get("b"), Some(&2)); - assert_eq!(map.get("c"), Some(&3)); - assert_eq!(map.get("d"), Some(&4)); + // Check each key-value pair is accessible + if let Literal(Int64(v)) = result.get(&string("a")).0 { + assert_eq!(v, 1); + } else { + panic!("Expected Int64 for key 'a'"); + } + + if let Literal(Int64(v)) = result.get(&string("b")).0 { + assert_eq!(v, 2); + } else { + panic!("Expected Int64 for key 'b'"); + } + + if let Literal(Int64(v)) = result.get(&string("c")).0 { + assert_eq!(v, 3); + } else { + panic!("Expected Int64 for key 'c'"); + } + + if let Literal(Int64(v)) = 
result.get(&string("d")).0 { + assert_eq!(v, 4); + } else { + panic!("Expected Int64 for key 'd'"); + } + + // Check a non-existent key returns None + if let None = result.get(&string("z")).0 { + // This is the expected behavior + } else { + panic!("Expected None for non-existent key"); + } } else { panic!("Expected Map"); } diff --git a/optd-dsl/src/engine/eval/core.rs b/optd-dsl/src/engine/eval/core.rs index 78ce54e9..fbbebed5 100644 --- a/optd-dsl/src/engine/eval/core.rs +++ b/optd-dsl/src/engine/eval/core.rs @@ -34,7 +34,10 @@ where Struct(name, items) => { evaluate_collection(items, move |values| Struct(name, values), engine, k).await } - Map(items) => evaluate_map(items, engine, k).await, + Map(items) => { + // Directly continue with the map value. + k(Value(Map(items))).await + } Function(fun_type) => { // Directly continue with the function value. k(Value(Function(fun_type))).await @@ -80,49 +83,6 @@ where .await } -/// Evaluates a map expression. -/// -/// # Parameters -/// -/// * `items` - The key-value pairs to evaluate. -/// * `engine` - The evaluation engine. -/// * `k` - The continuation to receive evaluation results. -async fn evaluate_map( - items: Vec<(Arc, Arc)>, - engine: Engine, - k: Continuation>, -) -> EngineResponse -where - O: Send + 'static, -{ - // Extract keys and values. - let (keys, values): (Vec>, Vec>) = items.into_iter().unzip(); - - // First evaluate all key expressions. - evaluate_sequence( - keys, - engine.clone(), - Arc::new(move |keys_values| { - Box::pin(capture!([values, engine, k], async move { - // Then evaluate all value expressions. - evaluate_sequence( - values, - engine, - Arc::new(move |values_values| { - Box::pin(capture!([keys_values, k], async move { - // Create a map from keys and values. - let map_items = keys_values.into_iter().zip(values_values).collect(); - k(Value(Map(map_items))).await - })) - }), - ) - .await - })) - }), - ) - .await -} - /// Evaluates a fail expression. 
/// /// # Parameters @@ -299,71 +259,6 @@ mod tests { } } - /// Test evaluation of map expressions - #[tokio::test] - async fn test_map_evaluation() { - let harness = TestHarness::new(); - let ctx = Context::default(); - let engine = Engine::new(ctx); - - // Create a map expression - let map_expr = Arc::new(Expr::new(CoreExpr(CoreData::Map(vec![ - (lit_expr(string("a")), lit_expr(int(1))), - (lit_expr(string("b")), lit_expr(int(2))), - (lit_expr(string("c")), lit_expr(int(3))), - ])))); - - let results = evaluate_and_collect(map_expr, engine, harness).await; - - // Check result - assert_eq!(results.len(), 1); - match &results[0].0 { - CoreData::Map(items) => { - assert_eq!(items.len(), 3); - - // Find key "a" and check value - let a_found = items.iter().any(|(k, v)| { - if let CoreData::Literal(Literal::String(key)) = &k.0 { - if key == "a" { - if let CoreData::Literal(Literal::Int64(val)) = &v.0 { - return *val == 1; - } - } - } - false - }); - assert!(a_found, "Key 'a' with value 1 not found"); - - // Find key "b" and check value - let b_found = items.iter().any(|(k, v)| { - if let CoreData::Literal(Literal::String(key)) = &k.0 { - if key == "b" { - if let CoreData::Literal(Literal::Int64(val)) = &v.0 { - return *val == 2; - } - } - } - false - }); - assert!(b_found, "Key 'b' with value 2 not found"); - - // Find key "c" and check value - let c_found = items.iter().any(|(k, v)| { - if let CoreData::Literal(Literal::String(key)) = &k.0 { - if key == "c" { - if let CoreData::Literal(Literal::Int64(val)) = &v.0 { - return *val == 3; - } - } - } - false - }); - assert!(c_found, "Key 'c' with value 3 not found"); - } - _ => panic!("Expected map"), - } - } - /// Test evaluation of function expressions #[tokio::test] async fn test_function_evaluation() { diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index 7f892ad0..8cbf5bca 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ b/optd-dsl/src/engine/eval/expr.rs @@ -1,11 +1,11 @@ use 
super::{binary::eval_binary_op, unary::eval_unary_op}; use crate::analyzer::hir::{BinOp, CoreData, Expr, FunKind, Identifier, Literal, UnaryOp, Value}; +use crate::analyzer::map::Map; use crate::engine::{Continuation, EngineResponse}; use crate::{ capture, engine::{Engine, utils::evaluate_sequence}, }; -use CoreData::*; use FunKind::*; use std::sync::Arc; @@ -39,7 +39,7 @@ where Arc::new(move |value| { Box::pin(capture!([then_expr, else_expr, engine, k], async move { match value.0 { - Literal(Literal::Bool(b)) => { + CoreData::Literal(Literal::Bool(b)) => { if b { engine.evaluate(then_expr, k).await } else { @@ -218,11 +218,11 @@ where Box::pin(capture!([args, engine, k], async move { match fun_value.0 { // Handle closure (user-defined function). - Function(Closure(params, body)) => { + CoreData::Function(Closure(params, body)) => { evaluate_closure_call(params, body, args, engine, k).await } // Handle Rust UDF (built-in function). - Function(RustUDF(udf)) => { + CoreData::Function(RustUDF(udf)) => { evaluate_rust_udf_call(udf, args, engine, k).await } // Value must be a function. @@ -313,6 +313,49 @@ where .await } +/// Evaluates a map expression. +/// +/// # Parameters +/// +/// * `items` - The key-value pairs to evaluate. +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +pub(crate) async fn evaluate_map( + items: Vec<(Arc, Arc)>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + // Extract keys and values. + let (keys, values): (Vec>, Vec>) = items.into_iter().unzip(); + + // First evaluate all key expressions. + evaluate_sequence( + keys, + engine.clone(), + Arc::new(move |keys_values| { + Box::pin(capture!([values, engine, k], async move { + // Then evaluate all value expressions. + evaluate_sequence( + values, + engine, + Arc::new(move |values_values| { + Box::pin(capture!([keys_values, k], async move { + // Create a map from keys and values. 
+ let map_items = keys_values.into_iter().zip(values_values).collect(); + k(Value(CoreData::Map(Map::from_pairs(map_items)))).await + })) + }), + ) + .await + })) + }), + ) + .await +} + /// Evaluates a reference to a variable. /// /// Looks up the variable in the context and passes its value to the continuation. @@ -344,7 +387,7 @@ where #[cfg(test)] mod tests { use crate::engine::Engine; - use crate::utils::tests::{array_val, ref_expr}; + use crate::utils::tests::{array_val, assert_values_equal, ref_expr}; use crate::{ analyzer::{ context::Context, @@ -574,6 +617,182 @@ mod tests { } } + /// Test to verify the Map implementation works correctly. + #[tokio::test] + async fn test_map_creation() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a map with key-value pairs: { "a": 1, "b": 2, "c": 3 } + let map_expr = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("a")), lit_expr(int(1))), + (lit_expr(string("b")), lit_expr(int(2))), + (lit_expr(string("c")), lit_expr(int(3))), + ]))); + + // Evaluate the map expression + let results = evaluate_and_collect(map_expr, engine.clone(), harness.clone()).await; + + // Check that we got a Map value + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Map(map) => { + // Check that map has the correct key-value pairs + assert_values_equal(&map.get(&lit_val(string("a"))), &lit_val(int(1))); + assert_values_equal(&map.get(&lit_val(string("b"))), &lit_val(int(2))); + assert_values_equal(&map.get(&lit_val(string("c"))), &lit_val(int(3))); + + // Check that non-existent key returns None value + assert_values_equal(&map.get(&lit_val(string("d"))), &Value(CoreData::None)); + } + _ => panic!("Expected Map value"), + } + + // Test map with expressions that need evaluation as keys and values + // Map: { "x" + "y": 10 + 5, "a" + "b": 20 * 2 } + let complex_map_expr = Arc::new(Expr::new(Map(vec![ + ( + Arc::new(Expr::new(Binary( + lit_expr(string("x")), + 
BinOp::Concat, + lit_expr(string("y")), + ))), + Arc::new(Expr::new(Binary( + lit_expr(int(10)), + BinOp::Add, + lit_expr(int(5)), + ))), + ), + ( + Arc::new(Expr::new(Binary( + lit_expr(string("a")), + BinOp::Concat, + lit_expr(string("b")), + ))), + Arc::new(Expr::new(Binary( + lit_expr(int(20)), + BinOp::Mul, + lit_expr(int(2)), + ))), + ), + ]))); + + // Evaluate the complex map expression + let complex_results = evaluate_and_collect(complex_map_expr, engine, harness).await; + + // Check that we got a Map value with correctly evaluated keys and values + assert_eq!(complex_results.len(), 1); + match &complex_results[0].0 { + CoreData::Map(map) => { + // Check that map has the correct key-value pairs after evaluation + assert_values_equal(&map.get(&lit_val(string("xy"))), &lit_val(int(15))); + assert_values_equal(&map.get(&lit_val(string("ab"))), &lit_val(int(40))); + } + _ => panic!("Expected Map value"), + } + } + + /// Test map operations with nested maps and lookup + #[tokio::test] + async fn test_map_nested_and_lookup() { + let harness = TestHarness::new(); + let mut ctx = Context::default(); + + // Add a map lookup function + ctx.bind( + "get".to_string(), + Value(CoreData::Function(FunKind::RustUDF(|args| { + if args.len() != 2 { + panic!("get function requires 2 arguments"); + } + + match &args[0].0 { + CoreData::Map(map) => map.get(&args[1]), + _ => panic!("First argument must be a map"), + } + }))), + ); + + let engine = Engine::new(ctx); + + // Create a nested map: + // { + // "user": { + // "name": "Alice", + // "age": 30, + // "address": { + // "city": "San Francisco", + // "zip": 94105 + // } + // }, + // "settings": { + // "theme": "dark", + // "notifications": true + // } + // } + + // First, create the address map + let address_map = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("city")), lit_expr(string("San Francisco"))), + (lit_expr(string("zip")), lit_expr(int(94105))), + ]))); + + // Then, create the user map with the nested address map + 
let user_map = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("name")), lit_expr(string("Alice"))), + (lit_expr(string("age")), lit_expr(int(30))), + (lit_expr(string("address")), address_map), + ]))); + + // Create the settings map + let settings_map = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("theme")), lit_expr(string("dark"))), + (lit_expr(string("notifications")), lit_expr(boolean(true))), + ]))); + + // Finally, create the top-level map + let nested_map_expr = Arc::new(Expr::new(Map(vec![ + (lit_expr(string("user")), user_map), + (lit_expr(string("settings")), settings_map), + ]))); + + // First, evaluate the nested map to bind it to a variable + let program = Arc::new(Expr::new(Let( + "data".to_string(), + nested_map_expr, + // Extract user.address.city using get function + Arc::new(Expr::new(Call( + ref_expr("get"), + vec![ + Arc::new(Expr::new(Call( + ref_expr("get"), + vec![ + Arc::new(Expr::new(Call( + ref_expr("get"), + vec![ref_expr("data"), lit_expr(string("user"))], + ))), + lit_expr(string("address")), + ], + ))), + lit_expr(string("city")), + ], + ))), + ))); + + // Evaluate the program + let results = evaluate_and_collect(program, engine, harness).await; + + // Check that we got the correct value from the nested lookup + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Literal(Literal::String(value)) => { + assert_eq!(value, "San Francisco"); + } + _ => panic!("Expected string value"), + } + } + /// Test complex program with multiple expression types #[tokio::test] async fn test_complex_program() { diff --git a/optd-dsl/src/engine/mod.rs b/optd-dsl/src/engine/mod.rs index 196d4f29..e8a80fe9 100644 --- a/optd-dsl/src/engine/mod.rs +++ b/optd-dsl/src/engine/mod.rs @@ -3,12 +3,12 @@ use crate::analyzer::{ hir::{Expr, ExprKind, Goal, GroupId, Value}, }; use ExprKind::*; -use eval::core::evaluate_core_expr; use eval::expr::{ evaluate_binary_expr, evaluate_function_call, evaluate_if_then_else, evaluate_let_binding, 
evaluate_reference, evaluate_unary_expr, }; use eval::r#match::evaluate_pattern_match; +use eval::{core::evaluate_core_expr, expr::evaluate_map}; use std::sync::Arc; mod eval; @@ -96,6 +96,7 @@ impl Engine { } Unary(op, expr) => evaluate_unary_expr(op.clone(), expr.clone(), self, k).await, Call(fun, args) => evaluate_function_call(fun.clone(), args.clone(), self, k).await, + Map(map) => evaluate_map(map.clone(), self, k).await, Ref(ident) => evaluate_reference(ident.clone(), self, k).await, CoreExpr(expr) => evaluate_core_expr(expr.clone(), self, k).await, CoreVal(val) => k(val.clone()).await, diff --git a/optd-dsl/src/lexer/error.rs b/optd-dsl/src/lexer/error.rs index 215ab035..aef7c80c 100644 --- a/optd-dsl/src/lexer/error.rs +++ b/optd-dsl/src/lexer/error.rs @@ -16,12 +16,12 @@ pub struct LexerError { impl LexerError { /// Creates a new lexer error from source code and a Chumsky error. - pub fn new(src_code: String, error: Simple) -> Self { - Self { src_code, error } + pub fn new(src_code: String, error: Simple) -> Box { + Self { src_code, error }.into() } } -impl Diagnose for LexerError { +impl Diagnose for Box { fn report(&self) -> Report { match self.error.reason() { SimpleReason::Custom(msg) => { diff --git a/optd-dsl/src/parser/error.rs b/optd-dsl/src/parser/error.rs index 0c708595..d0de0dd3 100644 --- a/optd-dsl/src/parser/error.rs +++ b/optd-dsl/src/parser/error.rs @@ -19,12 +19,12 @@ pub struct ParserError { impl ParserError { /// Creates a new parser error from source code and a Chumsky error. 
- pub fn new(src_code: String, error: Simple) -> Self { - Self { src_code, error } + pub fn new(src_code: String, error: Simple) -> Box { + Self { src_code, error }.into() } } -impl Diagnose for ParserError { +impl Diagnose for Box { fn report(&self) -> Report { match self.error.reason() { SimpleReason::Unclosed { span, delimiter } => { diff --git a/optd-dsl/src/parser/function.rs b/optd-dsl/src/parser/function.rs index 8d48a07c..5989904b 100644 --- a/optd-dsl/src/parser/function.rs +++ b/optd-dsl/src/parser/function.rs @@ -271,7 +271,9 @@ mod tests { let params = func.value.params.as_ref().unwrap(); assert_eq!(params.len(), 1); assert_eq!(*params[0].value.name.value, "x"); - assert!(matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "T")); + assert!( + matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "T") + ); // Check return type assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "T")); @@ -313,7 +315,9 @@ mod tests { // Check second parameter (key: K) assert_eq!(*params[1].value.name.value, "key"); - assert!(matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "K")); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "K") + ); // Check return type (V?) 
if let Type::Questioned(inner) = &*func.value.return_type.value { @@ -346,9 +350,13 @@ mod tests { let params = func.value.params.as_ref().unwrap(); assert_eq!(params.len(), 2); assert_eq!(*params[0].value.name.value, "a"); - assert!(matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "A")); + assert!( + matches!(*params[0].clone().value.ty.value, Type::Identifier(name) if name == "A") + ); assert_eq!(*params[1].value.name.value, "b"); - assert!(matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "B")); + assert!( + matches!(*params[1].clone().value.ty.value, Type::Identifier(name) if name == "B") + ); // Check return type assert!(matches!(*func.value.return_type.value, Type::Identifier(name) if name == "C")); @@ -408,7 +416,9 @@ mod tests { assert!(func.value.receiver.is_some()); if let Some(receiver) = &func.value.receiver { assert_eq!(*receiver.value.name.value, "self"); - assert!(matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person")); + assert!( + matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person") + ); } // Check params. @@ -438,7 +448,9 @@ mod tests { assert!(func.value.receiver.is_some()); if let Some(receiver) = &func.value.receiver { assert_eq!(*receiver.value.name.value, "self"); - assert!(matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person")); + assert!( + matches!(&*receiver.value.ty.value, Type::Identifier(name) if name == "Person") + ); } // Check params. 
diff --git a/optd-dsl/src/parser/module.rs b/optd-dsl/src/parser/module.rs index 6b076df9..9abcc0e8 100644 --- a/optd-dsl/src/parser/module.rs +++ b/optd-dsl/src/parser/module.rs @@ -1,12 +1,11 @@ -use crate::lexer::tokens::Token; -use crate::utils::error::CompileError; -use crate::utils::span::Span; -use chumsky::{Stream, prelude::*}; - use super::adt::adt_parser; use super::ast::{Item, Module}; use super::error::ParserError; use super::function::function_parser; +use crate::lexer::tokens::Token; +use crate::utils::error::CompileError; +use crate::utils::span::Span; +use chumsky::{Stream, prelude::*}; /// Parses a vector of tokens into a module AST. /// Uses Chumsky for parsing and Ariadne for error reporting. diff --git a/optd-dsl/src/utils/error.rs b/optd-dsl/src/utils/error.rs index f9d2dce7..e1cf7fb0 100644 --- a/optd-dsl/src/utils/error.rs +++ b/optd-dsl/src/utils/error.rs @@ -44,14 +44,14 @@ pub trait Diagnose { #[derive(Debug)] pub enum CompileError { /// Errors occurring during the lexing/tokenization phase - LexerError(LexerError), + LexerError(Box), /// Errors occurring during the parsing/syntax analysis phase - ParserError(ParserError), + ParserError(Box), /// Errors occurring during the semantic analysis phase - SemanticError(SemanticError), + SemanticError(Box), /// Errors occurring during the type analysis phase - TypeError(TypeError), + TypeError(Box), } From 55f09c4b7fd1c88fba7c0b2e5a6681a6e871255c Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 20:37:22 -0400 Subject: [PATCH 06/11] Enrich call evaluate --- optd-dsl/src/engine/eval/expr.rs | 337 +++++++++++++++++++++++++++++-- optd-dsl/src/engine/mod.rs | 4 +- 2 files changed, 323 insertions(+), 18 deletions(-) diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index 8cbf5bca..affc34fa 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ b/optd-dsl/src/engine/eval/expr.rs @@ -6,7 +6,6 @@ use crate::{ capture, engine::{Engine, 
utils::evaluate_sequence}, }; -use FunKind::*; use std::sync::Arc; /// Evaluates an if-then-else expression. @@ -174,13 +173,13 @@ pub(crate) async fn evaluate_unary_expr( where O: Send + 'static, { - // Evaluate the operand, then apply the unary operation + // Evaluate the operand, then apply the unary operation. engine .evaluate( expr, Arc::new(move |value| { Box::pin(capture!([op, k], async move { - // Apply the unary operation and pass result to continuation + // Apply the unary operation and pass result to continuation. let result = eval_unary_op(&op, value); k(result).await })) @@ -189,19 +188,22 @@ where .await } -/// Evaluates a function call expression. +/// Evaluates a call expression. /// -/// First evaluates the function expression, then the arguments, and finally applies the function to +/// First evaluates the called expression, then the arguments, and finally applies the call to /// the arguments, passing results to the continuation. /// +/// Extended to support indexing into collections (Array, Tuple, Struct, Map) when the called +/// expression evaluates to one of these types and a single argument is provided. +/// /// # Parameters /// -/// * `fun` - The function expression to evaluate. +/// * `called` - The called expression to evaluate. /// * `args` - The argument expressions to evaluate. /// * `engine` - The evaluation engine. /// * `k` - The continuation to receive evaluation results. -pub(crate) async fn evaluate_function_call( - fun: Arc, +pub(crate) async fn evaluate_call( + called: Arc, args: Vec>, engine: Engine, k: Continuation>, @@ -213,20 +215,135 @@ where engine .clone() .evaluate( - fun, - Arc::new(move |fun_value| { + called, + Arc::new(move |called_value| { Box::pin(capture!([args, engine, k], async move { - match fun_value.0 { + match called_value.0 { // Handle closure (user-defined function). 
- CoreData::Function(Closure(params, body)) => { + CoreData::Function(FunKind::Closure(params, body)) => { evaluate_closure_call(params, body, args, engine, k).await } // Handle Rust UDF (built-in function). - CoreData::Function(RustUDF(udf)) => { + CoreData::Function(FunKind::RustUDF(udf)) => { evaluate_rust_udf_call(udf, args, engine, k).await } - // Value must be a function. - _ => panic!("Expected function value"), + // Handle indexing into collections (Array, Tuple, Struct). + CoreData::Array(_) | CoreData::Tuple(_) | CoreData::Struct(_, _) => { + evaluate_indexed_access(called_value, args, engine, k).await + } + // Handle Map lookup. + CoreData::Map(_) => { + evaluate_map_lookup(called_value, args, engine, k).await + } + // Value must be a function or indexable collection. + _ => panic!( + "Expected function or indexable collection, got: {:?}", + called_value + ), + } + })) + }), + ) + .await +} + +/// Evaluates indexing into a collection (Array, Tuple, or Struct). +/// +/// # Parameters +/// +/// * `collection` - The collection value to index into. +/// * `args` - The argument expressions (should be a single index). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_indexed_access( + collection: Value, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + // Check that there's exactly one argument. + if args.len() != 1 { + panic!("Indexed access requires exactly one index argument"); + } + + // Evaluate the index expression. + engine + .evaluate( + args[0].clone(), + Arc::new(move |index_value| { + Box::pin(capture!([collection, k], async move { + // Extract the index as an integer. + let index = match &index_value.0 { + CoreData::Literal(Literal::Int64(i)) => *i as usize, + _ => panic!("Index must be an integer, got: {:?}", index_value), + }; + + // Index into the collection based on its type. 
+ let result = match &collection.0 { + CoreData::Array(items) => get_indexed_item(items, index), + CoreData::Tuple(items) => get_indexed_item(items, index), + CoreData::Struct(_, fields) => get_indexed_item(fields, index), + _ => panic!("Attempted to index a non-indexable value: {:?}", collection), + }; + + fn get_indexed_item(items: &[Value], index: usize) -> Value { + if index < items.len() { + items[index].clone() + } else { + panic!("index out of bounds: {} >= {}", index, items.len()); + } + } + + // Pass the indexed value to the continuation. + k(result).await + })) + }), + ) + .await +} + +/// Evaluates a map lookup. +/// +/// # Parameters +/// +/// * `map_value` - The map value to look up in. +/// * `args` - The argument expressions (should be a single key). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_map_lookup( + map_value: Value, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + // Check that there's exactly one argument + if args.len() != 1 { + panic!("Map lookup requires exactly one key argument"); + } + + // Evaluate the key expression + engine + .evaluate( + args[0].clone(), + Arc::new(move |key_value| { + Box::pin(capture!([map_value, k], async move { + // Extract the map + match &map_value.0 { + CoreData::Map(map) => { + // Look up the key in the map, returning None if not found + let result = map.get(&key_value); + k(result).await + } + _ => panic!( + "Attempted to perform map lookup on non-map value: {:?}", + map_value + ), } })) }), @@ -387,7 +504,7 @@ where #[cfg(test)] mod tests { use crate::engine::Engine; - use crate::utils::tests::{array_val, assert_values_equal, ref_expr}; + use crate::utils::tests::{array_val, assert_values_equal, ref_expr, struct_val}; use crate::{ analyzer::{ context::Context, @@ -860,6 +977,194 @@ mod tests { } } + /// Test array indexing with call syntax + #[tokio::test] + async fn 
test_array_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create an array [10, 20, 30, 40, 50] + let array_expr = Arc::new(Expr::new(CoreVal(array_val(vec![ + lit_val(int(10)), + lit_val(int(20)), + lit_val(int(30)), + lit_val(int(40)), + lit_val(int(50)), + ])))); + + // Access array[2] which should be 30 + let index_expr = Arc::new(Expr::new(Call(array_expr, vec![lit_expr(int(2))]))); + + let results = evaluate_and_collect(index_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(30)); + } + _ => panic!("Expected integer literal"), + } + } + + /// Test tuple indexing with call syntax + #[tokio::test] + async fn test_tuple_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a tuple (10, "hello", true) + let tuple_expr = Arc::new(Expr::new(CoreVal(Value(CoreData::Tuple(vec![ + lit_val(int(10)), + lit_val(string("hello")), + lit_val(Literal::Bool(true)), + ]))))); + + // Access tuple[1] which should be "hello" + let index_expr = Arc::new(Expr::new(Call(tuple_expr, vec![lit_expr(int(1))]))); + + let results = evaluate_and_collect(index_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("hello".to_string())); + } + _ => panic!("Expected string literal"), + } + } + + /// Test struct field access with call syntax + #[tokio::test] + async fn test_struct_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a struct Point { x: 10, y: 20 } + let struct_expr = Arc::new(Expr::new(CoreVal(struct_val( + "Point", + vec![lit_val(int(10)), lit_val(int(20))], + )))); + + // Access struct[1] which should be 20 (the y field) + 
let index_expr = Arc::new(Expr::new(Call(struct_expr, vec![lit_expr(int(1))]))); + + let results = evaluate_and_collect(index_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(20)); + } + _ => panic!("Expected integer literal"), + } + } + + /// Test map lookup with call syntax + #[tokio::test] + async fn test_map_lookup() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a map with key-value pairs: { "a": 1, "b": 2, "c": 3 } + // Use a let expression to bind the map and do lookups directly + let test_expr = Arc::new(Expr::new(Let( + "map".to_string(), + Arc::new(Expr::new(Map(vec![ + (lit_expr(string("a")), lit_expr(int(1))), + (lit_expr(string("b")), lit_expr(int(2))), + (lit_expr(string("c")), lit_expr(int(3))), + ]))), + // Create a tuple of map["b"] and map["d"] to test both existing and missing keys + Arc::new(Expr::new(CoreExpr(CoreData::Tuple(vec![ + // map["b"] - should be 2 + Arc::new(Expr::new(Call( + ref_expr("map"), + vec![lit_expr(string("b"))], + ))), + // map["d"] - should be None + Arc::new(Expr::new(Call( + ref_expr("map"), + vec![lit_expr(string("d"))], + ))), + ])))), + ))); + + let results = evaluate_and_collect(test_expr, engine, harness).await; + + // Check result - should be a tuple (2, None) + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Tuple(elements) => { + assert_eq!(elements.len(), 2); + + // Check first element: map["b"] should be 2 + match &elements[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(2)); + } + _ => panic!("Expected integer literal for existing key lookup"), + } + + // Check second element: map["d"] should be None + match &elements[1].0 { + CoreData::None => {} + _ => panic!("Expected None for missing key lookup"), + } + } + _ => panic!("Expected tuple result"), + } + } + + /// Test complex 
expressions for both collection and index + #[tokio::test] + async fn test_complex_collection_and_index() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a let expression that binds an array and then accesses it + // let arr = [10, 20, 30, 40, 50] in + // let idx = 2 + 1 in + // arr[idx] // should be 40 + let complex_expr = Arc::new(Expr::new(Let( + "arr".to_string(), + Arc::new(Expr::new(CoreExpr(CoreData::Array(vec![ + lit_expr(int(10)), + lit_expr(int(20)), + lit_expr(int(30)), + lit_expr(int(40)), + lit_expr(int(50)), + ])))), + Arc::new(Expr::new(Let( + "idx".to_string(), + Arc::new(Expr::new(Binary( + lit_expr(int(2)), + BinOp::Add, + lit_expr(int(1)), + ))), + Arc::new(Expr::new(Call(ref_expr("arr"), vec![ref_expr("idx")]))), + ))), + ))); + + let results = evaluate_and_collect(complex_expr, engine, harness).await; + + // Check result + assert_eq!(results.len(), 1); + match &results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::Int64(40)); + } + _ => panic!("Expected integer literal"), + } + } + /// Test variable reference in various contexts #[tokio::test] async fn test_variable_references() { diff --git a/optd-dsl/src/engine/mod.rs b/optd-dsl/src/engine/mod.rs index e8a80fe9..7c5d1436 100644 --- a/optd-dsl/src/engine/mod.rs +++ b/optd-dsl/src/engine/mod.rs @@ -4,7 +4,7 @@ use crate::analyzer::{ }; use ExprKind::*; use eval::expr::{ - evaluate_binary_expr, evaluate_function_call, evaluate_if_then_else, evaluate_let_binding, + evaluate_binary_expr, evaluate_call, evaluate_if_then_else, evaluate_let_binding, evaluate_reference, evaluate_unary_expr, }; use eval::r#match::evaluate_pattern_match; @@ -95,7 +95,7 @@ impl Engine { evaluate_binary_expr(left.clone(), op.clone(), right.clone(), self, k).await } Unary(op, expr) => evaluate_unary_expr(op.clone(), expr.clone(), self, k).await, - Call(fun, args) => evaluate_function_call(fun.clone(), args.clone(), self, k).await, + 
Call(fun, args) => evaluate_call(fun.clone(), args.clone(), self, k).await, Map(map) => evaluate_map(map.clone(), self, k).await, Ref(ident) => evaluate_reference(ident.clone(), self, k).await, CoreExpr(expr) => evaluate_core_expr(expr.clone(), self, k).await, From 168ed2fc6047935165541c4fc68da70d2e7e3a6f Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 20:59:15 -0400 Subject: [PATCH 07/11] Add nothing type --- optd-dsl/src/analyzer/type.rs | 33 ++++++++++++++++++++++++++++++++- 1 file changed, 32 insertions(+), 1 deletion(-) diff --git a/optd-dsl/src/analyzer/type.rs b/optd-dsl/src/analyzer/type.rs index 77b907de..9d0f8372 100644 --- a/optd-dsl/src/analyzer/type.rs +++ b/optd-dsl/src/analyzer/type.rs @@ -21,7 +21,8 @@ pub enum Type { // Special types Unit, - Universe, + Universe, // All types are subtypes of Universe + Nothing, // Inherits all types. Unknown, // User types @@ -142,6 +143,9 @@ impl TypeRegistry { // Universe is the top type - everything is a subtype of Universe (_, Type::Universe) => true, + // Nothing is the bottom type - it is a subtype of everything + (Type::Nothing, _) => true, + // Stored and Costed type handling (Type::Stored(child_inner), Type::Stored(parent_inner)) => { self.is_subtype(child_inner, parent_inner) @@ -613,6 +617,33 @@ mod type_registry_tests { assert!(!registry.is_subtype(&Type::Universe, &Type::Array(Box::new(Type::Int64)))); } + #[test] + fn test_nothing_as_bottom_type() { + let registry = TypeRegistry::default(); + + // Nothing is a subtype of all primitive types + assert!(registry.is_subtype(&Type::Nothing, &Type::Int64)); + assert!(registry.is_subtype(&Type::Nothing, &Type::String)); + assert!(registry.is_subtype(&Type::Nothing, &Type::Bool)); + assert!(registry.is_subtype(&Type::Nothing, &Type::Float64)); + assert!(registry.is_subtype(&Type::Nothing, &Type::Unit)); + assert!(registry.is_subtype(&Type::Nothing, &Type::Universe)); + + // Nothing is a subtype of complex types + 
assert!(registry.is_subtype(&Type::Nothing, &Type::Array(Box::new(Type::Int64)))); + assert!(registry.is_subtype(&Type::Nothing, &Type::Tuple(vec![Type::Int64, Type::Bool]))); + assert!(registry.is_subtype( + &Type::Nothing, + &Type::Closure(Box::new(Type::Int64), Box::new(Type::Bool)) + )); + + // But no type is a subtype of Nothing (except Nothing itself) + assert!(!registry.is_subtype(&Type::Int64, &Type::Nothing)); + assert!(!registry.is_subtype(&Type::Bool, &Type::Nothing)); + assert!(!registry.is_subtype(&Type::Universe, &Type::Nothing)); + assert!(!registry.is_subtype(&Type::Array(Box::new(Type::Int64)), &Type::Nothing)); + } + #[test] fn test_complex_nested_type_hierarchy() { let mut registry = TypeRegistry::default(); From 73e76d070018819aadc89b95a537043c74c4fb27 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Wed, 2 Apr 2025 21:16:07 -0400 Subject: [PATCH 08/11] Almost done with engine operator field accesses --- optd-dsl/src/engine/eval/expr.rs | 369 ++++++++++++++++++++++++++----- 1 file changed, 317 insertions(+), 52 deletions(-) diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index affc34fa..e233309b 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ b/optd-dsl/src/engine/eval/expr.rs @@ -1,11 +1,15 @@ use super::{binary::eval_binary_op, unary::eval_unary_op}; -use crate::analyzer::hir::{BinOp, CoreData, Expr, FunKind, Identifier, Literal, UnaryOp, Value}; +use crate::analyzer::hir::{ + BinOp, CoreData, Expr, ExprKind, FunKind, Goal, GroupId, Identifier, Literal, LogicalOp, + Materializable, PhysicalOp, UnaryOp, Value, +}; use crate::analyzer::map::Map; use crate::engine::{Continuation, EngineResponse}; use crate::{ capture, engine::{Engine, utils::evaluate_sequence}, }; +use ExprKind::*; use std::sync::Arc; /// Evaluates an if-then-else expression. 
@@ -115,31 +119,6 @@ pub(crate) async fn evaluate_binary_expr( where O: Send + 'static, { - // Helper function to evaluate the right operand after the left is evaluated. - async fn evaluate_right( - left_val: Value, - right: Arc, - op: BinOp, - engine: Engine, - k: Continuation>, - ) -> EngineResponse - where - O: Send + 'static, - { - engine - .evaluate( - right, - Arc::new(move |right_val| { - Box::pin(capture!([left_val, op, k], async move { - // Apply the binary operation and pass result to continuation. - let result = eval_binary_op(left_val, &op, right_val); - k(result).await - })) - }), - ) - .await - } - // First evaluate the left operand. engine .clone() @@ -154,6 +133,31 @@ where .await } +/// Helper function to evaluate the right operand after the left is evaluated. +async fn evaluate_right( + left_val: Value, + right: Arc, + op: BinOp, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + engine + .evaluate( + right, + Arc::new(move |right_val| { + Box::pin(capture!([left_val, op, k], async move { + // Apply the binary operation and pass result to continuation. + let result = eval_binary_op(left_val, &op, right_val); + k(result).await + })) + }), + ) + .await +} + /// Evaluates a unary expression. /// /// Evaluates the operand, then applies the unary operation, passing the result to the continuation. @@ -193,8 +197,8 @@ where /// First evaluates the called expression, then the arguments, and finally applies the call to /// the arguments, passing results to the continuation. /// -/// Extended to support indexing into collections (Array, Tuple, Struct, Map) when the called -/// expression evaluates to one of these types and a single argument is provided. +/// Extended to support indexing into collections (Array, Tuple, Struct, Map, Logical, Physical) +/// when the called expression evaluates to one of these types and a single argument is provided. 
/// /// # Parameters /// @@ -219,25 +223,33 @@ where Arc::new(move |called_value| { Box::pin(capture!([args, engine, k], async move { match called_value.0 { - // Handle closure (user-defined function). + // Handle function calls CoreData::Function(FunKind::Closure(params, body)) => { evaluate_closure_call(params, body, args, engine, k).await } - // Handle Rust UDF (built-in function). CoreData::Function(FunKind::RustUDF(udf)) => { evaluate_rust_udf_call(udf, args, engine, k).await } - // Handle indexing into collections (Array, Tuple, Struct). + + // Handle collection indexing CoreData::Array(_) | CoreData::Tuple(_) | CoreData::Struct(_, _) => { evaluate_indexed_access(called_value, args, engine, k).await } - // Handle Map lookup. CoreData::Map(_) => { evaluate_map_lookup(called_value, args, engine, k).await } - // Value must be a function or indexable collection. + + // Handle operator field accesses + CoreData::Logical(op) => { + evaluate_logical_operator_access(op, args, engine, k).await + } + CoreData::Physical(op) => { + evaluate_physical_operator_access(op, args, engine, k).await + } + + // Value must be a function or indexable collection/operator _ => panic!( - "Expected function or indexable collection, got: {:?}", + "Expected function or indexable value, got: {:?}", called_value ), } @@ -247,6 +259,256 @@ where .await } +/// Evaluates access to a logical operator. +/// +/// Handles both materialized and unmaterialized logical operators. +/// +/// # Parameters +/// +/// * `op` - The logical operator (materialized or unmaterialized). +/// * `args` - The argument expressions (should be a single index). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. 
+async fn evaluate_logical_operator_access( + op: Materializable, GroupId>, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + validate_single_index_arg(&args); + + match op { + // For unmaterialized logical operators, yield the group and continue when it's expanded. + Materializable::UnMaterialized(group_id) => { + yield_group_and_continue(group_id, args, engine, k).await + } + // For materialized logical operators, access the data or children directly. + Materializable::Materialized(log_op) => { + evaluate_index_on_materialized_operator( + args[0].clone(), + log_op.operator.data, + log_op.operator.children, + engine, + k, + ) + .await + } + } +} + +/// Evaluates access to a physical operator. +/// +/// Handles both materialized and unmaterialized physical operators. +/// +/// # Parameters +/// +/// * `op` - The physical operator (materialized or unmaterialized). +/// * `args` - The argument expressions (should be a single index). +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_physical_operator_access( + op: Materializable, Goal>, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + validate_single_index_arg(&args); + + match op { + // For unmaterialized physical operators, yield the goal and continue when it's expanded + Materializable::UnMaterialized(goal) => { + yield_goal_and_continue(goal, args, engine, k).await + } + // For materialized physical operators, access the data or children directly + Materializable::Materialized(phys_op) => { + evaluate_index_on_materialized_operator( + args[0].clone(), + phys_op.operator.data, + phys_op.operator.children, + engine, + k, + ) + .await + } + } +} + +/// Validates that exactly one index argument is provided. +/// +/// # Parameters +/// +/// * `args` - The argument expressions to validate. 
+fn validate_single_index_arg(args: &[Arc]) { + if args.len() != 1 { + panic!("Operator access requires exactly one index argument"); + } +} + +/// Yields a group ID to be materialized and continues evaluation once expanded. +/// +/// # Parameters +/// +/// * `group_id` - The group ID to yield. +/// * `args` - The argument expressions for indexed access. +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +fn yield_group_and_continue( + group_id: GroupId, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> impl Future> + Send +where + O: Send + 'static, +{ + Box::pin(async move { + EngineResponse::YieldGroup( + group_id, + Arc::new(move |expanded_value| { + Box::pin(capture!([args, engine, k], async move { + // Once the group is expanded, perform the indexed access. + evaluate_call( + Arc::new(Expr::new(CoreVal(expanded_value))), + args, + engine, + k, + ) + .await + })) + }), + ) + }) +} + +/// Yields a goal to be materialized and continues evaluation once expanded. +/// +/// # Parameters +/// +/// * `goal` - The goal to yield. +/// * `args` - The argument expressions for indexed access. +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +fn yield_goal_and_continue( + goal: Goal, + args: Vec>, + engine: Engine, + k: Continuation>, +) -> impl Future> + Send +where + O: Send + 'static, +{ + Box::pin(async move { + EngineResponse::YieldGoal( + goal, + Arc::new(move |expanded_value| { + Box::pin(capture!([args, engine, k], async move { + // Once the goal is expanded, perform the indexed access + evaluate_call( + Arc::new(Expr::new(CoreVal(expanded_value))), + args, + engine, + k, + ) + .await + })) + }), + ) + }) +} + +/// Evaluates an index expression on a materialized operator. +/// +/// Treats the operator's data and children as a concatenated vector and accesses by index. +/// +/// # Parameters +/// +/// * `index_expr` - The index expression to evaluate. 
+/// * `data` - The operator's data fields. +/// * `children` - The operator's children. +/// * `engine` - The evaluation engine. +/// * `k` - The continuation to receive evaluation results. +async fn evaluate_index_on_materialized_operator( + index_expr: Arc, + data: Vec, + children: Vec, + engine: Engine, + k: Continuation>, +) -> EngineResponse +where + O: Send + 'static, +{ + // Evaluate the index expression + engine + .evaluate( + index_expr, + Arc::new(move |index_value| { + Box::pin(capture!([data, children, k], async move { + // Extract the index as an integer + let index = extract_index(&index_value); + + // Access the concatenated vector of data and children + let result = access_operator_field(index, &data, &children); + + // Pass the indexed value to the continuation + k(result).await + })) + }), + ) + .await +} + +/// Extracts an integer index from a value. +/// +/// # Parameters +/// +/// * `index_value` - The value containing the index. +/// +/// # Returns +/// +/// The extracted integer index. +fn extract_index(index_value: &Value) -> usize { + match &index_value.0 { + CoreData::Literal(Literal::Int64(i)) => *i as usize, + _ => panic!("Index must be an integer, got: {:?}", index_value), + } +} + +/// Accesses a field in an operator by index. +/// +/// Treats data and children as a concatenated vector and accesses by index. +/// +/// # Parameters +/// +/// * `index` - The index to access. +/// * `data` - The operator's data fields. +/// * `children` - The operator's children. +/// +/// # Returns +/// +/// The value at the specified index. 
+fn access_operator_field(index: usize, data: &[Value], children: &[Value]) -> Value { + let data_len = data.len(); + let total_len = data_len + children.len(); + + if index >= total_len { + panic!("index out of bounds: {} >= {}", index, total_len); + } + + if index < data_len { + // Access data + data[index].clone() + } else { + // Access children + children[index - data_len].clone() + } +} + /// Evaluates indexing into a collection (Array, Tuple, or Struct). /// /// # Parameters @@ -265,9 +527,7 @@ where O: Send + 'static, { // Check that there's exactly one argument. - if args.len() != 1 { - panic!("Indexed access requires exactly one index argument"); - } + validate_single_index_arg(&args); // Evaluate the index expression. engine @@ -276,10 +536,7 @@ where Arc::new(move |index_value| { Box::pin(capture!([collection, k], async move { // Extract the index as an integer. - let index = match &index_value.0 { - CoreData::Literal(Literal::Int64(i)) => *i as usize, - _ => panic!("Index must be an integer, got: {:?}", index_value), - }; + let index = extract_index(&index_value); // Index into the collection based on its type. let result = match &collection.0 { @@ -289,14 +546,6 @@ where _ => panic!("Attempted to index a non-indexable value: {:?}", collection), }; - fn get_indexed_item(items: &[Value], index: usize) -> Value { - if index < items.len() { - items[index].clone() - } else { - panic!("index out of bounds: {} >= {}", index, items.len()); - } - } - // Pass the indexed value to the continuation. k(result).await })) @@ -305,6 +554,24 @@ where .await } +/// Gets an item from a collection at the specified index. +/// +/// # Parameters +/// +/// * `items` - The collection items. +/// * `index` - The index to access. +/// +/// # Returns +/// +/// The value at the specified index. 
+fn get_indexed_item(items: &[Value], index: usize) -> Value { + if index < items.len() { + items[index].clone() + } else { + panic!("index out of bounds: {} >= {}", index, items.len()); + } +} + /// Evaluates a map lookup. /// /// # Parameters @@ -323,9 +590,7 @@ where O: Send + 'static, { // Check that there's exactly one argument - if args.len() != 1 { - panic!("Map lookup requires exactly one key argument"); - } + validate_single_index_arg(&args); // Evaluate the key expression engine From fc5d8e2c471197debf5cf0e244b11eae3bceafc4 Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Thu, 3 Apr 2025 08:15:50 -0400 Subject: [PATCH 09/11] Add logical & physical operator testing --- optd-dsl/src/engine/eval/expr.rs | 674 ++++++++++++++++++++++++------- 1 file changed, 532 insertions(+), 142 deletions(-) diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index e233309b..8dc52ca5 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ b/optd-dsl/src/engine/eval/expr.rs @@ -34,7 +34,6 @@ pub(crate) async fn evaluate_if_then_else( where O: Send + 'static, { - // First evaluate the condition engine .clone() .evaluate( @@ -119,7 +118,6 @@ pub(crate) async fn evaluate_binary_expr( where O: Send + 'static, { - // First evaluate the left operand. engine .clone() .evaluate( @@ -149,7 +147,6 @@ where right, Arc::new(move |right_val| { Box::pin(capture!([left_val, op, k], async move { - // Apply the binary operation and pass result to continuation. let result = eval_binary_op(left_val, &op, right_val); k(result).await })) @@ -177,13 +174,11 @@ pub(crate) async fn evaluate_unary_expr( where O: Send + 'static, { - // Evaluate the operand, then apply the unary operation. engine .evaluate( expr, Arc::new(move |value| { Box::pin(capture!([op, k], async move { - // Apply the unary operation and pass result to continuation. 
let result = eval_unary_op(&op, value); k(result).await })) @@ -223,7 +218,7 @@ where Arc::new(move |called_value| { Box::pin(capture!([args, engine, k], async move { match called_value.0 { - // Handle function calls + // Handle function calls. CoreData::Function(FunKind::Closure(params, body)) => { evaluate_closure_call(params, body, args, engine, k).await } @@ -231,7 +226,7 @@ where evaluate_rust_udf_call(udf, args, engine, k).await } - // Handle collection indexing + // Handle collection indexing. CoreData::Array(_) | CoreData::Tuple(_) | CoreData::Struct(_, _) => { evaluate_indexed_access(called_value, args, engine, k).await } @@ -239,7 +234,7 @@ where evaluate_map_lookup(called_value, args, engine, k).await } - // Handle operator field accesses + // Handle operator field accesses. CoreData::Logical(op) => { evaluate_logical_operator_access(op, args, engine, k).await } @@ -247,7 +242,7 @@ where evaluate_physical_operator_access(op, args, engine, k).await } - // Value must be a function or indexable collection/operator + // Value must be a function or indexable collection/operator. _ => panic!( "Expected function or indexable value, got: {:?}", called_value @@ -269,34 +264,42 @@ where /// * `args` - The argument expressions (should be a single index). /// * `engine` - The evaluation engine. /// * `k` - The continuation to receive evaluation results. -async fn evaluate_logical_operator_access( +fn evaluate_logical_operator_access( op: Materializable, GroupId>, args: Vec>, engine: Engine, k: Continuation>, -) -> EngineResponse +) -> impl Future> + Send where O: Send + 'static, { - validate_single_index_arg(&args); - - match op { - // For unmaterialized logical operators, yield the group and continue when it's expanded. - Materializable::UnMaterialized(group_id) => { - yield_group_and_continue(group_id, args, engine, k).await - } - // For materialized logical operators, access the data or children directly. 
- Materializable::Materialized(log_op) => { - evaluate_index_on_materialized_operator( - args[0].clone(), - log_op.operator.data, - log_op.operator.children, - engine, - k, - ) - .await + Box::pin(async move { + validate_single_index_arg(&args); + + match op { + // For unmaterialized logical operators, yield the group and continue when it's expanded. + Materializable::UnMaterialized(group_id) => EngineResponse::YieldGroup( + group_id, + Arc::new(move |expanded_value| { + Box::pin(capture!([args, engine, k], async move { + evaluate_call(Expr::new(CoreVal(expanded_value)).into(), args, engine, k) + .await + })) + }), + ), + // For materialized logical operators, access the data or children directly. + Materializable::Materialized(log_op) => { + evaluate_index_on_materialized_operator( + args[0].clone(), + log_op.operator.data, + log_op.operator.children, + engine, + k, + ) + .await + } } - } + }) } /// Evaluates access to a physical operator. @@ -309,34 +312,42 @@ where /// * `args` - The argument expressions (should be a single index). /// * `engine` - The evaluation engine. /// * `k` - The continuation to receive evaluation results. 
-async fn evaluate_physical_operator_access( +fn evaluate_physical_operator_access( op: Materializable, Goal>, args: Vec>, engine: Engine, k: Continuation>, -) -> EngineResponse +) -> impl Future> + Send where O: Send + 'static, { - validate_single_index_arg(&args); - - match op { - // For unmaterialized physical operators, yield the goal and continue when it's expanded - Materializable::UnMaterialized(goal) => { - yield_goal_and_continue(goal, args, engine, k).await - } - // For materialized physical operators, access the data or children directly - Materializable::Materialized(phys_op) => { - evaluate_index_on_materialized_operator( - args[0].clone(), - phys_op.operator.data, - phys_op.operator.children, - engine, - k, - ) - .await + Box::pin(async move { + validate_single_index_arg(&args); + + match op { + // For unmaterialized physical operators, yield the goal and continue when it's expanded. + Materializable::UnMaterialized(goal) => EngineResponse::YieldGoal( + goal, + Arc::new(move |expanded_value| { + Box::pin(capture!([args, engine, k], async move { + evaluate_call(Expr::new(CoreVal(expanded_value)).into(), args, engine, k) + .await + })) + }), + ), + // For materialized physical operators, access the data or children directly. + Materializable::Materialized(phys_op) => { + evaluate_index_on_materialized_operator( + args[0].clone(), + phys_op.operator.data, + phys_op.operator.children, + engine, + k, + ) + .await + } } - } + }) } /// Validates that exactly one index argument is provided. @@ -350,78 +361,6 @@ fn validate_single_index_arg(args: &[Arc]) { } } -/// Yields a group ID to be materialized and continues evaluation once expanded. -/// -/// # Parameters -/// -/// * `group_id` - The group ID to yield. -/// * `args` - The argument expressions for indexed access. -/// * `engine` - The evaluation engine. -/// * `k` - The continuation to receive evaluation results. 
-fn yield_group_and_continue( - group_id: GroupId, - args: Vec>, - engine: Engine, - k: Continuation>, -) -> impl Future> + Send -where - O: Send + 'static, -{ - Box::pin(async move { - EngineResponse::YieldGroup( - group_id, - Arc::new(move |expanded_value| { - Box::pin(capture!([args, engine, k], async move { - // Once the group is expanded, perform the indexed access. - evaluate_call( - Arc::new(Expr::new(CoreVal(expanded_value))), - args, - engine, - k, - ) - .await - })) - }), - ) - }) -} - -/// Yields a goal to be materialized and continues evaluation once expanded. -/// -/// # Parameters -/// -/// * `goal` - The goal to yield. -/// * `args` - The argument expressions for indexed access. -/// * `engine` - The evaluation engine. -/// * `k` - The continuation to receive evaluation results. -fn yield_goal_and_continue( - goal: Goal, - args: Vec>, - engine: Engine, - k: Continuation>, -) -> impl Future> + Send -where - O: Send + 'static, -{ - Box::pin(async move { - EngineResponse::YieldGoal( - goal, - Arc::new(move |expanded_value| { - Box::pin(capture!([args, engine, k], async move { - // Once the goal is expanded, perform the indexed access - evaluate_call( - Arc::new(Expr::new(CoreVal(expanded_value))), - args, - engine, - k, - ) - .await - })) - }), - ) - }) -} - /// Evaluates an index expression on a materialized operator. /// /// Treats the operator's data and children as a concatenated vector and accesses by index. 
@@ -443,19 +382,14 @@ async fn evaluate_index_on_materialized_operator( where O: Send + 'static, { - // Evaluate the index expression engine .evaluate( index_expr, Arc::new(move |index_value| { Box::pin(capture!([data, children, k], async move { - // Extract the index as an integer let index = extract_index(&index_value); - - // Access the concatenated vector of data and children let result = access_operator_field(index, &data, &children); - // Pass the indexed value to the continuation k(result).await })) }), @@ -501,10 +435,8 @@ fn access_operator_field(index: usize, data: &[Value], children: &[Value]) -> Va } if index < data_len { - // Access data data[index].clone() } else { - // Access children children[index - data_len].clone() } } @@ -526,19 +458,15 @@ async fn evaluate_indexed_access( where O: Send + 'static, { - // Check that there's exactly one argument. validate_single_index_arg(&args); - // Evaluate the index expression. engine .evaluate( args[0].clone(), Arc::new(move |index_value| { Box::pin(capture!([collection, k], async move { - // Extract the index as an integer. let index = extract_index(&index_value); - // Index into the collection based on its type. let result = match &collection.0 { CoreData::Array(items) => get_indexed_item(items, index), CoreData::Tuple(items) => get_indexed_item(items, index), @@ -546,7 +474,6 @@ where _ => panic!("Attempted to index a non-indexable value: {:?}", collection), }; - // Pass the indexed value to the continuation. 
k(result).await })) }), @@ -589,19 +516,15 @@ async fn evaluate_map_lookup( where O: Send + 'static, { - // Check that there's exactly one argument validate_single_index_arg(&args); - // Evaluate the key expression engine .evaluate( args[0].clone(), Arc::new(move |key_value| { Box::pin(capture!([map_value, k], async move { - // Extract the map match &map_value.0 { CoreData::Map(map) => { - // Look up the key in the map, returning None if not found let result = map.get(&key_value); k(result).await } @@ -643,7 +566,7 @@ where engine.clone(), Arc::new(move |arg_values| { Box::pin(capture!([params, body, engine, k], async move { - // Create a new context with parameters bound to arguments + // Create a new context with parameters bound to arguments. let mut new_ctx = engine.context.clone(); new_ctx.push_scope(); @@ -651,7 +574,7 @@ where new_ctx.bind(p.clone(), a); }); - // Evaluate the body in the new context + // Evaluate the body in the new context. engine.with_new_context(new_ctx).evaluate(body, k).await })) }), @@ -755,21 +678,23 @@ pub(crate) async fn evaluate_reference( where O: Send + 'static, { - // Look up the variable in the context. let value = engine .context .lookup(&ident) .unwrap_or_else(|| panic!("Variable not found: {}", ident)) .clone(); - // Pass the value to the continuation. 
k(value).await } #[cfg(test)] mod tests { + use crate::analyzer::hir::{Goal, GroupId, Materializable}; use crate::engine::Engine; - use crate::utils::tests::{array_val, assert_values_equal, ref_expr, struct_val}; + use crate::utils::tests::{ + array_val, assert_values_equal, create_logical_operator, create_physical_operator, + ref_expr, struct_val, + }; use crate::{ analyzer::{ context::Context, @@ -1430,6 +1355,471 @@ mod tests { } } + /// Test indexing into a logical operator + #[tokio::test] + async fn test_logical_operator_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a logical operator: LogicalJoin { joinType: "inner", condition: "x = y" } [TableScan("orders"), TableScan("lineitem")] + let join_op = create_logical_operator( + "LogicalJoin", + vec![lit_val(string("inner")), lit_val(string("x = y"))], + vec![ + create_logical_operator("TableScan", vec![lit_val(string("orders"))], vec![]), + create_logical_operator("TableScan", vec![lit_val(string("lineitem"))], vec![]), + ], + ); + + let logical_op_expr = Arc::new(Expr::new(CoreVal(join_op))); + + // Access join_type using indexing - should be "inner" + let join_type_expr = Arc::new(Expr::new(Call( + logical_op_expr.clone(), + vec![lit_expr(int(0))], + ))); + let join_type_results = + evaluate_and_collect(join_type_expr, engine.clone(), harness.clone()).await; + + // Access condition using indexing - should be "x = y" + let condition_expr = Arc::new(Expr::new(Call( + logical_op_expr.clone(), + vec![lit_expr(int(1))], + ))); + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + + // Access first child (orders table scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + logical_op_expr.clone(), + vec![lit_expr(int(2))], + ))); + let first_child_results = + evaluate_and_collect(first_child_expr, engine.clone(), harness.clone()).await; + + // Access second child 
(lineitem table scan) using indexing + let second_child_expr = Arc::new(Expr::new(Call(logical_op_expr, vec![lit_expr(int(3))]))); + let second_child_results = evaluate_and_collect(second_child_expr, engine, harness).await; + + // Check join_type result + assert_eq!(join_type_results.len(), 1); + match &join_type_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("inner".to_string())); + } + _ => panic!("Expected string literal for join type"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("x = y".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (orders table scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].0 { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("orders".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for first child"), + } + + // Check second child result (lineitem table scan) + assert_eq!(second_child_results.len(), 1); + match &second_child_results[0].0 { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("lineitem".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for second child"), + } + } + + /// Test indexing into a physical operator + #[tokio::test] + async fn test_physical_operator_indexing() { + let harness = TestHarness::new(); + let 
ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a physical operator: HashJoin { method: "hash", condition: "id = id" } [IndexScan("customers"), ParallelScan("orders")] + let join_op = create_physical_operator( + "HashJoin", + vec![lit_val(string("hash")), lit_val(string("id = id"))], + vec![ + create_physical_operator("IndexScan", vec![lit_val(string("customers"))], vec![]), + create_physical_operator("ParallelScan", vec![lit_val(string("orders"))], vec![]), + ], + ); + + let physical_op_expr = Arc::new(Expr::new(CoreVal(join_op))); + + // Access join method using indexing - should be "hash" + let method_expr = Arc::new(Expr::new(Call( + physical_op_expr.clone(), + vec![lit_expr(int(0))], + ))); + let method_results = + evaluate_and_collect(method_expr, engine.clone(), harness.clone()).await; + + // Access condition using indexing - should be "id = id" + let condition_expr = Arc::new(Expr::new(Call( + physical_op_expr.clone(), + vec![lit_expr(int(1))], + ))); + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + + // Access first child (customers index scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + physical_op_expr.clone(), + vec![lit_expr(int(2))], + ))); + let first_child_results = + evaluate_and_collect(first_child_expr, engine.clone(), harness.clone()).await; + + // Access second child (orders parallel scan) using indexing + let second_child_expr = Arc::new(Expr::new(Call(physical_op_expr, vec![lit_expr(int(3))]))); + let second_child_results = evaluate_and_collect(second_child_expr, engine, harness).await; + + // Check join method result + assert_eq!(method_results.len(), 1); + match &method_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("hash".to_string())); + } + _ => panic!("Expected string literal for join method"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].0 
{ + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("id = id".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (customers index scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].0 { + CoreData::Physical(Materializable::Materialized(phys_op)) => { + assert_eq!(phys_op.operator.tag, "IndexScan"); + assert_eq!(phys_op.operator.data.len(), 1); + match &phys_op.operator.data[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected physical operator for first child"), + } + + // Check second child result (orders parallel scan) + assert_eq!(second_child_results.len(), 1); + match &second_child_results[0].0 { + CoreData::Physical(Materializable::Materialized(phys_op)) => { + assert_eq!(phys_op.operator.tag, "ParallelScan"); + assert_eq!(phys_op.operator.data.len(), 1); + match &phys_op.operator.data[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("orders".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected physical operator for second child"), + } + } + + /// Test indexing into an unmaterialized logical operator + #[tokio::test] + async fn test_unmaterialized_logical_operator_indexing() { + let harness = TestHarness::new(); + let test_group_id = GroupId(1); + + // Register a logical operator in the test harness + let materialized_join = create_logical_operator( + "LogicalJoin", + vec![ + lit_val(string("inner")), + lit_val(string("customer.id = order.id")), + ], + vec![ + create_logical_operator("TableScan", vec![lit_val(string("customers"))], vec![]), + create_logical_operator("TableScan", vec![lit_val(string("orders"))], vec![]), + ], + ); + + harness.register_group(test_group_id, materialized_join); + + // Create an unmaterialized logical operator + 
let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value(CoreData::Logical( + Materializable::UnMaterialized(test_group_id), + ))))); + + // Access join type using indexing - should materialize and return "inner" + let join_type_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(0))], + ))); + + // Access condition using indexing - should materialize and return "customer.id = order.id" + let condition_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(1))], + ))); + + // Access first child (customers table scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(2))], + ))); + + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Evaluate the expressions + let join_type_results = + evaluate_and_collect(join_type_expr, engine.clone(), harness.clone()).await; + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + let first_child_results = evaluate_and_collect(first_child_expr, engine, harness).await; + + // Check join type result + assert_eq!(join_type_results.len(), 1); + match &join_type_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("inner".to_string())); + } + _ => panic!("Expected string literal for join type"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customer.id = order.id".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (customers table scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].0 { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].0 { + CoreData::Literal(lit) 
=> { + assert_eq!(lit, &Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for first child"), + } + } + + /// Test indexing into an unmaterialized physical operator + #[tokio::test] + async fn test_unmaterialized_physical_operator_indexing() { + let harness = TestHarness::new(); + + // Create a physical goal + let test_group_id = GroupId(2); + let properties = Box::new(Value(CoreData::Literal(string("sorted")))); + let test_goal = Goal { + group_id: test_group_id, + properties, + }; + + // Register a physical operator to be returned when the goal is expanded + let materialized_join = create_physical_operator( + "MergeJoin", + vec![ + lit_val(string("merge")), + lit_val(string("customer.id = order.id")), + ], + vec![ + create_physical_operator("SortedScan", vec![lit_val(string("customers"))], vec![]), + create_physical_operator("SortedScan", vec![lit_val(string("orders"))], vec![]), + ], + ); + + harness.register_goal(&test_goal, materialized_join); + + // Create an unmaterialized physical operator + let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value(CoreData::Physical( + Materializable::UnMaterialized(test_goal), + ))))); + + // Access join method using indexing - should materialize and return "merge" + let method_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(0))], + ))); + + // Access condition using indexing - should materialize and return "customer.id = order.id" + let condition_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(1))], + ))); + + // Access first child (customers scan) using indexing + let first_child_expr = Arc::new(Expr::new(Call( + unmaterialized_expr.clone(), + vec![lit_expr(int(2))], + ))); + + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Evaluate the expressions + let method_results = + evaluate_and_collect(method_expr, engine.clone(), 
harness.clone()).await; + let condition_results = + evaluate_and_collect(condition_expr, engine.clone(), harness.clone()).await; + let first_child_results = evaluate_and_collect(first_child_expr, engine, harness).await; + + // Check join method result + assert_eq!(method_results.len(), 1); + match &method_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("merge".to_string())); + } + _ => panic!("Expected string literal for join method"), + } + + // Check condition result + assert_eq!(condition_results.len(), 1); + match &condition_results[0].0 { + CoreData::Literal(lit) => { + assert_eq!(lit, &Literal::String("customer.id = order.id".to_string())); + } + _ => panic!("Expected string literal for condition"), + } + + // Check first child result (customers scan) + assert_eq!(first_child_results.len(), 1); + match &first_child_results[0].0 { + CoreData::Physical(Materializable::Materialized(phys_op)) => { + assert_eq!(phys_op.operator.tag, "SortedScan"); + assert_eq!(phys_op.operator.data.len(), 1); + match &phys_op.operator.data[0].0 { + CoreData::Literal(lit) => { + assert_eq!(*lit, Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected physical operator for first child"), + } + } + + /// Test accessing multiple levels of nested operators through indexing + #[tokio::test] + async fn test_nested_operator_indexing() { + let harness = TestHarness::new(); + let ctx = Context::default(); + let engine = Engine::new(ctx); + + // Create a nested operator: + // Project [col1, col2] ( + // Filter ("age > 30") ( + // Join ("inner", "t1.id = t2.id") ( + // TableScan ("customers"), + // TableScan ("orders") + // ) + // ) + // ) + let nested_op = create_logical_operator( + "Project", + vec![array_val(vec![ + lit_val(string("col1")), + lit_val(string("col2")), + ])], + vec![create_logical_operator( + "Filter", + vec![lit_val(string("age > 30"))], + 
vec![create_logical_operator( + "Join", + vec![lit_val(string("inner")), lit_val(string("t1.id = t2.id"))], + vec![ + create_logical_operator( + "TableScan", + vec![lit_val(string("customers"))], + vec![], + ), + create_logical_operator( + "TableScan", + vec![lit_val(string("orders"))], + vec![], + ), + ], + )], + )], + ); + + let nested_op_expr = Arc::new(Expr::new(CoreVal(nested_op))); + + // First, access the Filter child of Project (index 1) + let filter_expr = Arc::new(Expr::new(Call(nested_op_expr, vec![lit_expr(int(1))]))); + let filter_results = + evaluate_and_collect(filter_expr, engine.clone(), harness.clone()).await; + + // Now, from the Filter, access its Join child (index 1) + let filter_value = filter_results[0].clone(); + let join_expr = Arc::new(Expr::new(Call( + Arc::new(Expr::new(CoreVal(filter_value))), + vec![lit_expr(int(1))], + ))); + let join_results = evaluate_and_collect(join_expr, engine.clone(), harness.clone()).await; + + // Finally, from the Join, access its first TableScan child (index 2) + let join_value = join_results[0].clone(); + let table_scan_expr = Arc::new(Expr::new(Call( + Arc::new(Expr::new(CoreVal(join_value))), + vec![lit_expr(int(2))], + ))); + let table_scan_results = evaluate_and_collect(table_scan_expr, engine, harness).await; + + // Verify the final result is the "customers" TableScan + assert_eq!(table_scan_results.len(), 1); + match &table_scan_results[0].0 { + CoreData::Logical(Materializable::Materialized(log_op)) => { + assert_eq!(log_op.operator.tag, "TableScan"); + assert_eq!(log_op.operator.data.len(), 1); + match &log_op.operator.data[0].0 { + CoreData::Literal(lit) => { + assert_eq!(*lit, Literal::String("customers".to_string())); + } + _ => panic!("Expected string literal for table name"), + } + } + _ => panic!("Expected logical operator for table scan"), + } + } + /// Test variable reference in various contexts #[tokio::test] async fn test_variable_references() { From 
9f1d93701b3158cd715141a8ce0df9cfbb88d6ee Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Thu, 3 Apr 2025 09:50:49 -0400 Subject: [PATCH 10/11] Rename nothing type into never --- optd-dsl/src/analyzer/type.rs | 42 +++++++++++++++++------------------ 1 file changed, 21 insertions(+), 21 deletions(-) diff --git a/optd-dsl/src/analyzer/type.rs b/optd-dsl/src/analyzer/type.rs index 9d0f8372..bcbdb363 100644 --- a/optd-dsl/src/analyzer/type.rs +++ b/optd-dsl/src/analyzer/type.rs @@ -22,7 +22,7 @@ pub enum Type { // Special types Unit, Universe, // All types are subtypes of Universe - Nothing, // Inherits all types. + Never, // Inherits all types. Unknown, // User types @@ -143,8 +143,8 @@ impl TypeRegistry { // Universe is the top type - everything is a subtype of Universe (_, Type::Universe) => true, - // Nothing is the bottom type - it is a subtype of everything - (Type::Nothing, _) => true, + // Never is the bottom type - it is a subtype of everything + (Type::Never, _) => true, // Stored and Costed type handling (Type::Stored(child_inner), Type::Stored(parent_inner)) => { @@ -618,30 +618,30 @@ mod type_registry_tests { } #[test] - fn test_nothing_as_bottom_type() { + fn test_never_as_bottom_type() { let registry = TypeRegistry::default(); - // Nothing is a subtype of all primitive types - assert!(registry.is_subtype(&Type::Nothing, &Type::Int64)); - assert!(registry.is_subtype(&Type::Nothing, &Type::String)); - assert!(registry.is_subtype(&Type::Nothing, &Type::Bool)); - assert!(registry.is_subtype(&Type::Nothing, &Type::Float64)); - assert!(registry.is_subtype(&Type::Nothing, &Type::Unit)); - assert!(registry.is_subtype(&Type::Nothing, &Type::Universe)); - - // Nothing is a subtype of complex types - assert!(registry.is_subtype(&Type::Nothing, &Type::Array(Box::new(Type::Int64)))); - assert!(registry.is_subtype(&Type::Nothing, &Type::Tuple(vec![Type::Int64, Type::Bool]))); + // Never is a subtype of all primitive types + 
assert!(registry.is_subtype(&Type::Never, &Type::Int64)); + assert!(registry.is_subtype(&Type::Never, &Type::String)); + assert!(registry.is_subtype(&Type::Never, &Type::Bool)); + assert!(registry.is_subtype(&Type::Never, &Type::Float64)); + assert!(registry.is_subtype(&Type::Never, &Type::Unit)); + assert!(registry.is_subtype(&Type::Never, &Type::Universe)); + + // Never is a subtype of complex types + assert!(registry.is_subtype(&Type::Never, &Type::Array(Box::new(Type::Int64)))); + assert!(registry.is_subtype(&Type::Never, &Type::Tuple(vec![Type::Int64, Type::Bool]))); assert!(registry.is_subtype( - &Type::Nothing, + &Type::Never, &Type::Closure(Box::new(Type::Int64), Box::new(Type::Bool)) )); - // But no type is a subtype of Nothing (except Nothing itself) - assert!(!registry.is_subtype(&Type::Int64, &Type::Nothing)); - assert!(!registry.is_subtype(&Type::Bool, &Type::Nothing)); - assert!(!registry.is_subtype(&Type::Universe, &Type::Nothing)); - assert!(!registry.is_subtype(&Type::Array(Box::new(Type::Int64)), &Type::Nothing)); + // But no type is a subtype of Never (except Never itself) + assert!(!registry.is_subtype(&Type::Int64, &Type::Never)); + assert!(!registry.is_subtype(&Type::Bool, &Type::Never)); + assert!(!registry.is_subtype(&Type::Universe, &Type::Never)); + assert!(!registry.is_subtype(&Type::Array(Box::new(Type::Int64)), &Type::Never)); } #[test] From e620ffb1a33b794de8246cd9e49e7555d50b60cc Mon Sep 17 00:00:00 2001 From: Alexis Schlomer Date: Thu, 3 Apr 2025 12:19:14 -0400 Subject: [PATCH 11/11] Add metadata to Value --- optd-core/src/bridge/from_cir.rs | 44 +++++----- optd-core/src/bridge/into_cir.rs | 33 ++++---- optd-dsl/src/analyzer/context.rs | 18 ++-- optd-dsl/src/analyzer/hir.rs | 95 +++++++++++---------- optd-dsl/src/analyzer/map.rs | 59 +++++++------ optd-dsl/src/engine/eval/binary.rs | 116 ++++++++++++------------- optd-dsl/src/engine/eval/core.rs | 55 ++++++------ optd-dsl/src/engine/eval/expr.rs | 122 +++++++++++++-------------- 
optd-dsl/src/engine/eval/match.rs | 79 ++++++++--------- optd-dsl/src/engine/eval/operator.rs | 48 +++++------ optd-dsl/src/engine/eval/unary.rs | 33 ++++---- optd-dsl/src/utils/tests.rs | 14 +-- 12 files changed, 365 insertions(+), 351 deletions(-) diff --git a/optd-core/src/bridge/from_cir.rs b/optd-core/src/bridge/from_cir.rs index 8b4c8338..391fe2d0 100644 --- a/optd-core/src/bridge/from_cir.rs +++ b/optd-core/src/bridge/from_cir.rs @@ -14,7 +14,7 @@ pub(crate) fn partial_logical_to_value(plan: &PartialLogicalPlan) -> Value { match plan { PartialLogicalPlan::UnMaterialized(group_id) => { // For unmaterialized logical operators, we create a `Value` with the group ID. - Value(Logical(UnMaterialized(hir::GroupId(group_id.0)))) + Value::new(Logical(UnMaterialized(hir::GroupId(group_id.0)))) } PartialLogicalPlan::Materialized(node) => { // For materialized logical operators, we create a `Value` with the operator data. @@ -24,7 +24,7 @@ pub(crate) fn partial_logical_to_value(plan: &PartialLogicalPlan) -> Value { children: convert_children_to_values(&node.children, partial_logical_to_value), }; - Value(Logical(Materialized(LogicalOp::logical(operator)))) + Value::new(Logical(Materialized(LogicalOp::logical(operator)))) } } } @@ -35,7 +35,7 @@ pub(crate) fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { PartialPhysicalPlan::UnMaterialized(goal) => { // For unmaterialized physical operators, we create a `Value` with the goal let hir_goal = cir_goal_to_hir(goal); - Value(Physical(UnMaterialized(hir_goal))) + Value::new(Physical(UnMaterialized(hir_goal))) } PartialPhysicalPlan::Materialized(node) => { // For materialized physical operators, we create a Value with the operator data @@ -45,7 +45,7 @@ pub(crate) fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { children: convert_children_to_values(&node.children, partial_physical_to_value), }; - Value(Physical(Materialized(PhysicalOp::physical(operator)))) + 
Value::new(Physical(Materialized(PhysicalOp::physical(operator)))) } } } @@ -54,9 +54,9 @@ pub(crate) fn partial_physical_to_value(plan: &PartialPhysicalPlan) -> Value { /// Converts a [`PartialPhysicalPlan`] with its cost into a [`Value`]. pub(crate) fn costed_physical_to_value(plan: PartialPhysicalPlan, cost: Cost) -> Value { let operator = partial_physical_to_value(&plan); - Value(Tuple(vec![ + Value::new(Tuple(vec![ partial_physical_to_value(&plan), - Value(Literal(Float64(cost.0))), + Value::new(Literal(Float64(cost.0))), ])) } @@ -65,7 +65,7 @@ pub(crate) fn costed_physical_to_value(plan: PartialPhysicalPlan, cost: Cost) -> pub(crate) fn logical_properties_to_value(properties: &LogicalProperties) -> Value { match &properties.0 { Some(data) => properties_data_to_value(data), - Option::None => Value(None), + Option::None => Value::new(None), } } @@ -73,7 +73,7 @@ pub(crate) fn logical_properties_to_value(properties: &LogicalProperties) -> Val pub(crate) fn physical_properties_to_value(properties: &PhysicalProperties) -> Value { match &properties.0 { Some(data) => properties_data_to_value(data), - Option::None => Value(None), + Option::None => Value::new(None), } } @@ -104,7 +104,7 @@ where .map(|child| match child { Child::Singleton(item) => converter(item), Child::VarLength(items) => { - Value(Array(items.iter().map(|item| converter(item)).collect())) + Value::new(Array(items.iter().map(|item| converter(item)).collect())) } }) .collect() @@ -118,32 +118,34 @@ fn convert_operator_data_to_values(data: &[OperatorData]) -> Vec { /// Converts an [`OperatorData`] into a [`Value`]. 
fn operator_data_to_value(data: &OperatorData) -> Value { match data { - OperatorData::Int64(i) => Value(Literal(Int64(*i))), - OperatorData::Float64(f) => Value(Literal(Float64(**f))), - OperatorData::String(s) => Value(Literal(String(s.clone()))), - OperatorData::Bool(b) => Value(Literal(Bool(*b))), - OperatorData::Struct(name, elements) => Value(Struct( + OperatorData::Int64(i) => Value::new(Literal(Int64(*i))), + OperatorData::Float64(f) => Value::new(Literal(Float64(**f))), + OperatorData::String(s) => Value::new(Literal(String(s.clone()))), + OperatorData::Bool(b) => Value::new(Literal(Bool(*b))), + OperatorData::Struct(name, elements) => Value::new(Struct( name.clone(), convert_operator_data_to_values(elements), )), - OperatorData::Array(elements) => Value(Array(convert_operator_data_to_values(elements))), + OperatorData::Array(elements) => { + Value::new(Array(convert_operator_data_to_values(elements))) + } } } /// Converts a [`PropertiesData`] into a [`Value`]. fn properties_data_to_value(data: &PropertiesData) -> Value { match data { - PropertiesData::Int64(i) => Value(Literal(Int64(*i))), - PropertiesData::Float64(f) => Value(Literal(Float64(**f))), - PropertiesData::String(s) => Value(Literal(String(s.clone()))), - PropertiesData::Bool(b) => Value(Literal(Bool(*b))), + PropertiesData::Int64(i) => Value::new(Literal(Int64(*i))), + PropertiesData::Float64(f) => Value::new(Literal(Float64(**f))), + PropertiesData::String(s) => Value::new(Literal(String(s.clone()))), + PropertiesData::Bool(b) => Value::new(Literal(Bool(*b))), PropertiesData::Struct(name, elements) => { let values = elements.iter().map(properties_data_to_value).collect(); - Value(Struct(name.clone(), values)) + Value::new(Struct(name.clone(), values)) } PropertiesData::Array(elements) => { let values = elements.iter().map(properties_data_to_value).collect(); - Value(Array(values)) + Value::new(Array(values)) } } } diff --git a/optd-core/src/bridge/into_cir.rs 
b/optd-core/src/bridge/into_cir.rs index 113b314c..f209b907 100644 --- a/optd-core/src/bridge/into_cir.rs +++ b/optd-core/src/bridge/into_cir.rs @@ -14,7 +14,7 @@ use std::sync::Arc; /// /// Panics if the [`Value`] is not a [`Logical`] variant. pub(crate) fn value_to_partial_logical(value: &Value) -> PartialLogicalPlan { - match &value.0 { + match &value.data { Logical(logical_op) => match logical_op { UnMaterialized(group_id) => { PartialLogicalPlan::UnMaterialized(hir_group_id_to_cir(group_id)) @@ -28,7 +28,7 @@ pub(crate) fn value_to_partial_logical(value: &Value) -> PartialLogicalPlan { ), }), }, - _ => panic!("Expected Logical CoreData variant, found: {:?}", value.0), + _ => panic!("Expected Logical CoreData variant, found: {:?}", value.data), } } @@ -38,7 +38,7 @@ pub(crate) fn value_to_partial_logical(value: &Value) -> PartialLogicalPlan { /// /// Panics if the [`Value`] is not a [`Physical`] variant. pub(crate) fn value_to_partial_physical(value: &Value) -> PartialPhysicalPlan { - match &value.0 { + match &value.data { Physical(physical_op) => match physical_op { UnMaterialized(hir_goal) => { PartialPhysicalPlan::UnMaterialized(hir_goal_to_cir(hir_goal)) @@ -52,7 +52,10 @@ pub(crate) fn value_to_partial_physical(value: &Value) -> PartialPhysicalPlan { ), }), }, - _ => panic!("Expected Physical CoreData variant, found: {:?}", value.0), + _ => panic!( + "Expected Physical CoreData variant, found: {:?}", + value.data + ), } } @@ -62,15 +65,15 @@ pub(crate) fn value_to_partial_physical(value: &Value) -> PartialPhysicalPlan { /// /// Panics if the [`Value`] is not a [`Literal`] variant with a [`Float64`] value. pub(crate) fn value_to_cost(value: &Value) -> Cost { - match &value.0 { + match &value.data { Literal(Float64(f)) => Cost(*f), - _ => panic!("Expected Float64 literal, found: {:?}", value.0), + _ => panic!("Expected Float64 literal, found: {:?}", value.data), } } /// Converts an HIR properties [`Value`] into a CIR [`LogicalProperties`]. 
pub(crate) fn value_to_logical_properties(properties_value: &Value) -> LogicalProperties { - match &properties_value.0 { + match &properties_value.data { None => LogicalProperties(Option::None), _ => LogicalProperties(Some(value_to_properties_data(properties_value))), } @@ -78,7 +81,7 @@ pub(crate) fn value_to_logical_properties(properties_value: &Value) -> LogicalPr /// Convert an HIR properties [`Value`] into a CIR [`PhysicalProperties`]. fn value_to_physical_properties(properties_value: &Value) -> PhysicalProperties { - match &properties_value.0 { + match &properties_value.data { None => PhysicalProperties(Option::None), _ => PhysicalProperties(Some(value_to_properties_data(properties_value))), } @@ -108,7 +111,7 @@ pub(crate) fn hir_goal_to_cir(hir_goal: &hir::Goal) -> Goal { /// Panics if the [`Value`] is not a [`Logical`] variant or if the [`Logical`] variant is not a /// [`Materialized`] variant. fn value_to_logical(value: &Value) -> LogicalPlan { - match &value.0 { + match &value.data { Logical(logical_op) => match logical_op { UnMaterialized(_) => { panic!("Cannot convert UnMaterialized LogicalOperator to LogicalPlan") @@ -119,7 +122,7 @@ fn value_to_logical(value: &Value) -> LogicalPlan { children: convert_values_to_children(&log_op.operator.children, value_to_logical), }), }, - _ => panic!("Expected Logical CoreData variant, found: {:?}", value.0), + _ => panic!("Expected Logical CoreData variant, found: {:?}", value.data), } } @@ -132,7 +135,7 @@ where { values .iter() - .map(|value| match &value.0 { + .map(|value| match &value.data { Array(elements) => VarLength( elements .iter() @@ -160,7 +163,7 @@ fn convert_values_to_properties_data(values: &[Value]) -> Vec { /// /// Panics if the [`Value`] cannot be converted to [`OperatorData`], such as a [`Unit`] literal. 
fn value_to_operator_data(value: &Value) -> OperatorData { - match &value.0 { + match &value.data { Literal(constant) => match constant { Int64(i) => OperatorData::Int64(*i), Float64(f) => OperatorData::Float64((*f).into()), @@ -172,7 +175,7 @@ fn value_to_operator_data(value: &Value) -> OperatorData { Struct(name, elements) => { OperatorData::Struct(name.clone(), convert_values_to_operator_data(elements)) } - _ => panic!("Cannot convert {:?} to OperatorData", value.0), + _ => panic!("Cannot convert {:?} to OperatorData", value.data), } } @@ -182,7 +185,7 @@ fn value_to_operator_data(value: &Value) -> OperatorData { /// /// Panics if the [`Value`] cannot be converted to [`PropertiesData`], such as a [`Unit`] literal. fn value_to_properties_data(value: &Value) -> PropertiesData { - match &value.0 { + match &value.data { Literal(constant) => match constant { Int64(i) => PropertiesData::Int64(*i), Float64(f) => PropertiesData::Float64((*f).into()), @@ -194,6 +197,6 @@ fn value_to_properties_data(value: &Value) -> PropertiesData { Struct(name, elements) => { PropertiesData::Struct(name.clone(), convert_values_to_properties_data(elements)) } - _ => panic!("Cannot convert {:?} to PropertyData content", value.0), + _ => panic!("Cannot convert {:?} to PropertyData content", value.data), } } diff --git a/optd-dsl/src/analyzer/context.rs b/optd-dsl/src/analyzer/context.rs index 08fba5ec..621bd69e 100644 --- a/optd-dsl/src/analyzer/context.rs +++ b/optd-dsl/src/analyzer/context.rs @@ -1,4 +1,4 @@ -use super::hir::Identifier; +use super::hir::{ExprMetadata, Identifier, NoMetadata}; use crate::analyzer::hir::Value; use std::{collections::HashMap, sync::Arc}; @@ -12,15 +12,15 @@ use std::{collections::HashMap, sync::Arc}; /// The current (innermost) scope is mutable, while all previous scopes /// are immutable and stored as Arc for efficient cloning. 
#[derive(Debug, Clone, Default)] -pub struct Context { +pub struct Context { /// Previous scopes (outer lexical scopes), stored as immutable Arc references - previous_scopes: Vec>>, + previous_scopes: Vec>>>, /// Current scope (innermost) that can be directly modified - current_scope: HashMap, + current_scope: HashMap>, } -impl Context { +impl Context { /// Creates a new context with the given initial bindings as the global scope. /// /// # Arguments @@ -30,7 +30,7 @@ impl Context { /// # Returns /// /// A new `Context` instance with one scope containing the initial bindings - pub fn new(initial_bindings: HashMap) -> Self { + pub fn new(initial_bindings: HashMap>) -> Self { Self { previous_scopes: Vec::new(), current_scope: initial_bindings, @@ -59,7 +59,7 @@ impl Context { /// # Returns /// /// Some reference to the value if found, None otherwise - pub fn lookup(&self, name: &str) -> Option<&Value> { + pub fn lookup(&self, name: &str) -> Option<&Value> { // First check the current scope if let Some(value) = self.current_scope.get(name) { return Some(value); @@ -84,7 +84,7 @@ impl Context { /// # Arguments /// /// * `other` - The context to merge from (consumed by this operation) - pub fn merge(&mut self, other: Context) { + pub fn merge(&mut self, other: Context) { // Move bindings from other's current scope into our current scope for (name, val) in other.current_scope { self.current_scope.insert(name, val); @@ -100,7 +100,7 @@ impl Context { /// /// * `name` - The name of the variable to bind /// * `val` - The value to bind to the variable - pub fn bind(&mut self, name: String, val: Value) { + pub fn bind(&mut self, name: String, val: Value) { self.current_scope.insert(name, val); } } diff --git a/optd-dsl/src/analyzer/hir.rs b/optd-dsl/src/analyzer/hir.rs index 97967766..fa974e03 100644 --- a/optd-dsl/src/analyzer/hir.rs +++ b/optd-dsl/src/analyzer/hir.rs @@ -38,11 +38,32 @@ pub enum Literal { Unit, } +/// Metadata that can be attached to expression nodes +/// 
+/// This trait allows for different types of metadata to be attached to +/// expression nodes while maintaining a common interface for access. +pub trait ExprMetadata: Debug + Clone {} + +/// Empty metadata implementation for cases where no additional data is needed +#[derive(Debug, Clone, Default)] +pub struct NoMetadata; +impl ExprMetadata for NoMetadata {} + +/// Combined span and type information for an expression +#[derive(Debug, Clone)] +pub struct TypedSpan { + /// Source code location. + pub span: Span, + /// Inferred type. + pub ty: Type, +} +impl ExprMetadata for TypedSpan {} + /// Types of functions in the system #[derive(Debug, Clone)] -pub enum FunKind { - Closure(Vec, Arc), - RustUDF(fn(Vec) -> Value), +pub enum FunKind { + Closure(Vec, Arc>), + RustUDF(fn(Vec>) -> Value), } /// Group identifier in the optimizer @@ -69,7 +90,7 @@ pub struct Goal { /// The logical group to implement pub group_id: GroupId, /// Required physical properties - pub properties: Box, + pub properties: Box>, } /// Unified operator node structure for all operator types @@ -125,13 +146,13 @@ impl LogicalOp { /// Represents an executable implementation of a logical operation with specific /// physical properties, either materialized as a concrete operator or as a physical goal. #[derive(Debug, Clone)] -pub struct PhysicalOp { +pub struct PhysicalOp { pub operator: Operator, pub goal: Option, - pub cost: Option>, + pub cost: Option>>, } -impl PhysicalOp { +impl PhysicalOp { /// Creates a new physical operator without goal or cost information /// /// Used for representing physical operators that are not yet part of the @@ -159,18 +180,31 @@ impl PhysicalOp { /// Creates a new physical operator with both goal and cost information /// /// Used for representing fully optimized physical operators with computed cost. 
- pub fn costed_physical(operator: Operator, goal: Goal, cost: Value) -> Self { + pub fn costed_physical(operator: Operator, goal: Goal, cost: Value) -> Self { Self { operator, goal: Some(goal), - cost: Some(cost.into()), + cost: Some(Box::new(cost)), } } } +/// Evaluated expression result +#[derive(Debug, Clone)] +pub struct Value { + pub data: CoreData, M>, +} + +impl Value { + /// Creates a new value from core data + pub fn new(data: CoreData, M>) -> Self { + Self { data } + } +} + /// Core data structures shared across the system #[derive(Debug, Clone)] -pub enum CoreData { +pub enum CoreData { /// Primitive literal values Literal(Literal), /// Ordered collection of values @@ -182,38 +216,17 @@ pub enum CoreData { /// Named structure with fields Struct(Identifier, Vec), /// Function or closure - Function(FunKind), + Function(FunKind), /// Error representation Fail(Box), /// Logical query operators Logical(Materializable, GroupId>), /// Physical query operators - Physical(Materializable, Goal>), + Physical(Materializable, Goal>), /// The None value None, } -/// Metadata that can be attached to expression nodes -/// -/// This trait allows for different types of metadata to be attached to -/// expression nodes while maintaining a common interface for access. -pub trait ExprMetadata: Debug + Clone {} - -/// Empty metadata implementation for cases where no additional data is needed -#[derive(Debug, Clone, Default)] -pub struct NoMetadata; -impl ExprMetadata for NoMetadata {} - -/// Combined span and type information for an expression -#[derive(Debug, Clone)] -pub struct TypedSpan { - /// Source code location. - pub span: Span, - /// Inferred type. 
- pub ty: Type, -} -impl ExprMetadata for TypedSpan {} - /// Expression nodes in the HIR with optional metadata /// /// The M type parameter allows attaching different kinds of metadata to expressions, @@ -241,7 +254,7 @@ pub type MapEntries = Vec<(Arc>, Arc>)>; /// Expression node kinds without metadata #[derive(Debug, Clone)] -pub enum ExprKind { +pub enum ExprKind { /// Pattern matching expression PatternMatch(Arc>, Vec>), /// Conditional expression @@ -259,15 +272,11 @@ pub enum ExprKind { /// Variable reference Ref(Identifier), /// Core expression - CoreExpr(CoreData>>), + CoreExpr(CoreData>, M>), /// Core value - CoreVal(Value), + CoreVal(Value), } -/// Evaluated expression result -#[derive(Debug, Clone)] -pub struct Value(pub CoreData); - /// Pattern for matching #[derive(Debug, Clone)] pub enum Pattern { @@ -321,10 +330,6 @@ pub enum UnaryOp { /// Program representation after the analysis phase #[derive(Debug)] pub struct HIR { - pub context: Context, + pub context: Context, pub annotations: HashMap>, - pub expressions: Vec>, } - -/// Type alias for HIR with both type and source location information -pub type TypedSpannedHIR = HIR; diff --git a/optd-dsl/src/analyzer/map.rs b/optd-dsl/src/analyzer/map.rs index 3094c97f..d874d587 100644 --- a/optd-dsl/src/analyzer/map.rs +++ b/optd-dsl/src/analyzer/map.rs @@ -14,10 +14,10 @@ //! - Efficient key lookup (O(1) via HashMap) //! 
- Basic map operations (get, concat) +use super::hir::ExprMetadata; use super::hir::{ CoreData, GroupId, Literal, LogicalOp, Materializable, Operator, PhysicalOp, Value, }; -use CoreData::*; use std::collections::HashMap; use std::hash::Hash; @@ -107,10 +107,10 @@ impl Map { } /// Gets a value by key, returning None (as a Value) if not found - pub fn get(&self, key: &Value) -> Value { + pub fn get(&self, key: &Value) -> Value { self.inner .get(&value_to_map_key(key)) - .unwrap_or(&Value(None)) + .unwrap_or(&Value::new(CoreData::None)) .clone() } @@ -125,24 +125,24 @@ impl Map { /// Converts a Value to a MapKey, enforcing valid key types /// This performs runtime validation that the key type is supported /// and will return an error for invalid key types -fn value_to_map_key(value: &Value) -> MapKey { - match &value.0 { - Literal(lit) => match lit { +fn value_to_map_key(value: &Value) -> MapKey { + match &value.data { + CoreData::Literal(lit) => match lit { Literal::Int64(i) => MapKey::Int64(*i), Literal::String(s) => MapKey::String(s.clone()), Literal::Bool(b) => MapKey::Bool(*b), Literal::Unit => MapKey::Unit, Literal::Float64(_) => panic!("Invalid map key: Float64"), }, - Tuple(items) => { + CoreData::Tuple(items) => { let key_items = items.iter().map(value_to_map_key).collect(); MapKey::Tuple(key_items) } - Struct(name, fields) => { + CoreData::Struct(name, fields) => { let key_fields = fields.iter().map(value_to_map_key).collect(); MapKey::Struct(name.clone(), key_fields) } - Logical(materializable) => match materializable { + CoreData::Logical(materializable) => match materializable { Materializable::UnMaterialized(group_id) => { MapKey::Logical(Box::new(LogicalMapKey::UnMaterialized(*group_id))) } @@ -151,7 +151,7 @@ fn value_to_map_key(value: &Value) -> MapKey { MapKey::Logical(Box::new(LogicalMapKey::Materialized(map_op))) } }, - Physical(materializable) => match materializable { + CoreData::Physical(materializable) => match materializable { 
Materializable::UnMaterialized(goal) => { let properties = value_to_map_key(&goal.properties); let map_goal = GoalMapKey { @@ -165,11 +165,11 @@ fn value_to_map_key(value: &Value) -> MapKey { MapKey::Physical(Box::new(PhysicalMapKey::Materialized(map_op))) } }, - Fail(inner) => { + CoreData::Fail(inner) => { let inner_key = value_to_map_key(inner); MapKey::Fail(Box::new(inner_key)) } - None => MapKey::None, + CoreData::None => MapKey::None, _ => panic!("Invalid map key: {:?}", value), } } @@ -191,8 +191,9 @@ fn value_to_operator_map_key( } /// Converts a LogicalOp to a map key -fn value_to_logical_map_op(logical_op: &LogicalOp) -> LogicalMapOpKey { - let operator = value_to_operator_map_key(&logical_op.operator, &value_to_map_key); +fn value_to_logical_map_op(logical_op: &LogicalOp>) -> LogicalMapOpKey { + let operator = + value_to_operator_map_key(&logical_op.operator, &(|v: &Value| value_to_map_key(v))); LogicalMapOpKey { operator, @@ -201,8 +202,11 @@ fn value_to_logical_map_op(logical_op: &LogicalOp) -> LogicalMapOpKey { } /// Converts a PhysicalOp to a map key -fn value_to_physical_map_op(physical_op: &PhysicalOp) -> PhysicalMapOpKey { - let operator = value_to_operator_map_key(&physical_op.operator, &value_to_map_key); +fn value_to_physical_map_op( + physical_op: &PhysicalOp, M>, +) -> PhysicalMapOpKey { + let operator = + value_to_operator_map_key(&physical_op.operator, &(|v: &Value| value_to_map_key(v))); let goal = physical_op.goal.as_ref().map(|g| GoalMapKey { group_id: g.group_id, @@ -220,51 +224,50 @@ fn value_to_physical_map_op(physical_op: &PhysicalOp) -> PhysicalMapOpKey #[cfg(test)] mod tests { + use super::*; use crate::utils::tests::{ assert_values_equal, create_logical_operator, create_physical_operator, }; - use super::*; - // Helper to create Value literals fn int_val(i: i64) -> Value { - Value(Literal(Literal::Int64(i))) + Value::new(CoreData::Literal(Literal::Int64(i))) } fn bool_val(b: bool) -> Value { - Value(Literal(Literal::Bool(b))) + 
Value::new(CoreData::Literal(Literal::Bool(b))) } fn string_val(s: &str) -> Value { - Value(Literal(Literal::String(s.to_string()))) + Value::new(CoreData::Literal(Literal::String(s.to_string()))) } fn float_val(f: f64) -> Value { - Value(Literal(Literal::Float64(f))) + Value::new(CoreData::Literal(Literal::Float64(f))) } fn unit_val() -> Value { - Value(Literal(Literal::Unit)) + Value::new(CoreData::Literal(Literal::Unit)) } fn tuple_val(items: Vec) -> Value { - Value(Tuple(items)) + Value::new(CoreData::Tuple(items)) } fn struct_val(name: &str, fields: Vec) -> Value { - Value(Struct(name.to_string(), fields)) + Value::new(CoreData::Struct(name.to_string(), fields)) } fn array_val(items: Vec) -> Value { - Value(Array(items)) + Value::new(CoreData::Array(items)) } fn none_val() -> Value { - Value(None) + Value::new(CoreData::None) } fn fail_val(inner: Value) -> Value { - Value(Fail(Box::new(inner))) + Value::new(CoreData::Fail(Box::new(inner))) } #[test] diff --git a/optd-dsl/src/engine/eval/binary.rs b/optd-dsl/src/engine/eval/binary.rs index 2949779a..c5ea2901 100644 --- a/optd-dsl/src/engine/eval/binary.rs +++ b/optd-dsl/src/engine/eval/binary.rs @@ -23,27 +23,29 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { use BinOp::*; use CoreData::*; - match (left.0, op, right.0) { + match (left.data, op, right.data) { // Handle operations between two literals. (Literal(l), op, Literal(r)) => match (l, op, r) { // Integer operations (arithmetic, comparison). 
- (Int64(l), Add | Sub | Mul | Div | Eq | Lt, Int64(r)) => Value(Literal(match op { - Add => Int64(l + r), // Integer addition - Sub => Int64(l - r), // Integer subtraction - Mul => Int64(l * r), // Integer multiplication - Div => Int64(l / r), // Integer division (panics on divide by zero) - Eq => Bool(l == r), // Integer equality comparison - Lt => Bool(l < r), // Integer less-than comparison - _ => unreachable!(), // This branch is unreachable due to pattern guard - })), + (Int64(l), Add | Sub | Mul | Div | Eq | Lt, Int64(r)) => { + Value::new(Literal(match op { + Add => Int64(l + r), // Integer addition + Sub => Int64(l - r), // Integer subtraction + Mul => Int64(l * r), // Integer multiplication + Div => Int64(l / r), // Integer division (panics on divide by zero) + Eq => Bool(l == r), // Integer equality comparison + Lt => Bool(l < r), // Integer less-than comparison + _ => unreachable!(), // This branch is unreachable due to pattern guard + })) + } // Integer range operation (creates an array of sequential integers). - (Int64(l), Range, Int64(r)) => { - Value(Array((l..=r).map(|n| Value(Literal(Int64(n)))).collect())) - } + (Int64(l), Range, Int64(r)) => Value::new(Array( + (l..=r).map(|n| Value::new(Literal(Int64(n)))).collect(), + )), // Float operations (arithmetic, comparison). - (Float64(l), op, Float64(r)) => Value(Literal(match op { + (Float64(l), op, Float64(r)) => Value::new(Literal(match op { Add => Float64(l + r), // Float addition Sub => Float64(l - r), // Float subtraction Mul => Float64(l * r), // Float multiplication @@ -53,7 +55,7 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { })), // Boolean operations (logical, comparison). 
- (Bool(l), op, Bool(r)) => Value(Literal(match op { + (Bool(l), op, Bool(r)) => Value::new(Literal(match op { And => Bool(l && r), // Logical AND Or => Bool(l || r), // Logical OR Eq => Bool(l == r), // Boolean equality comparison @@ -61,7 +63,7 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { })), // String operations (comparison, concatenation). - (String(l), op, String(r)) => Value(Literal(match op { + (String(l), op, String(r)) => Value::new(Literal(match op { Eq => Bool(l == r), // String equality comparison Concat => String(format!("{l}{r}")), // String concatenation _ => panic!("Invalid string operation"), // Other operations not supported @@ -75,13 +77,13 @@ pub(crate) fn eval_binary_op(left: Value, op: &BinOp, right: Value) -> Value { (Array(l), Concat, Array(r)) => { let mut result = l.clone(); result.extend(r.iter().cloned()); - Value(Array(result)) + Value::new(Array(result)) } // Map concatenation (joins two maps). (Map(mut l), Concat, Map(r)) => { l.concat(r); - Value(Map(l)) + Value::new(Map(l)) } // Any other combination of value types or operations is not supported. 
@@ -100,49 +102,49 @@ mod tests { // Helper function to create integer Value fn int(i: i64) -> Value { - Value(Literal(Int64(i))) + Value::new(Literal(Int64(i))) } // Helper function to create float Value fn float(f: f64) -> Value { - Value(Literal(Float64(f))) + Value::new(Literal(Float64(f))) } // Helper function to create boolean Value fn boolean(b: bool) -> Value { - Value(Literal(Bool(b))) + Value::new(Literal(Bool(b))) } // Helper function to create string Value fn string(s: &str) -> Value { - Value(Literal(String(s.to_string()))) + Value::new(Literal(String(s.to_string()))) } #[test] fn test_integer_arithmetic() { // Addition - if let Literal(Int64(result)) = eval_binary_op(int(5), &Add, int(7)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(5), &Add, int(7)).data { assert_eq!(result, 12); } else { panic!("Expected Int64"); } // Subtraction - if let Literal(Int64(result)) = eval_binary_op(int(10), &Sub, int(3)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(10), &Sub, int(3)).data { assert_eq!(result, 7); } else { panic!("Expected Int64"); } // Multiplication - if let Literal(Int64(result)) = eval_binary_op(int(4), &Mul, int(5)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(4), &Mul, int(5)).data { assert_eq!(result, 20); } else { panic!("Expected Int64"); } // Division - if let Literal(Int64(result)) = eval_binary_op(int(20), &Div, int(4)).0 { + if let Literal(Int64(result)) = eval_binary_op(int(20), &Div, int(4)).data { assert_eq!(result, 5); } else { panic!("Expected Int64"); @@ -152,28 +154,28 @@ mod tests { #[test] fn test_integer_comparison() { // Equality - true case - if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(5)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(5)).data { assert!(result); } else { panic!("Expected Bool"); } // Equality - false case - if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, int(7)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(5), &Eq, 
int(7)).data { assert!(!result); } else { panic!("Expected Bool"); } // Less than - true case - if let Literal(Bool(result)) = eval_binary_op(int(5), &Lt, int(10)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(5), &Lt, int(10)).data { assert!(result); } else { panic!("Expected Bool"); } // Less than - false case - if let Literal(Bool(result)) = eval_binary_op(int(10), &Lt, int(5)).0 { + if let Literal(Bool(result)) = eval_binary_op(int(10), &Lt, int(5)).data { assert!(!result); } else { panic!("Expected Bool"); @@ -183,12 +185,12 @@ mod tests { #[test] fn test_integer_range() { // Range operation - if let Array(result) = eval_binary_op(int(1), &Range, int(5)).0 { + if let Array(result) = eval_binary_op(int(1), &Range, int(5)).data { assert_eq!(result.len(), 5); // Check individual elements for (i, val) in result.iter().enumerate() { - if let Literal(Int64(n)) = val.0 { + if let Literal(Int64(n)) = val.data { assert_eq!(n, (i as i64) + 1); } else { panic!("Expected Int64 in array"); @@ -202,35 +204,35 @@ mod tests { #[test] fn test_float_operations() { // Addition - if let Literal(Float64(result)) = eval_binary_op(float(3.5), &Add, float(2.25)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(3.5), &Add, float(2.25)).data { assert_eq!(result, 5.75); } else { panic!("Expected Float64"); } // Subtraction - if let Literal(Float64(result)) = eval_binary_op(float(10.5), &Sub, float(3.25)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(10.5), &Sub, float(3.25)).data { assert_eq!(result, 7.25); } else { panic!("Expected Float64"); } // Multiplication - if let Literal(Float64(result)) = eval_binary_op(float(4.0), &Mul, float(2.5)).0 { + if let Literal(Float64(result)) = eval_binary_op(float(4.0), &Mul, float(2.5)).data { assert_eq!(result, 10.0); } else { panic!("Expected Float64"); } // Division - if let Literal(Float64(result)) = eval_binary_op(float(10.0), &Div, float(2.5)).0 { + if let Literal(Float64(result)) = 
eval_binary_op(float(10.0), &Div, float(2.5)).data { assert_eq!(result, 4.0); } else { panic!("Expected Float64"); } // Less than - if let Literal(Bool(result)) = eval_binary_op(float(3.5), &Lt, float(3.6)).0 { + if let Literal(Bool(result)) = eval_binary_op(float(3.5), &Lt, float(3.6)).data { assert!(result); } else { panic!("Expected Bool"); @@ -240,35 +242,35 @@ mod tests { #[test] fn test_boolean_operations() { // AND - true case - if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(true)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(true)).data { assert!(result); } else { panic!("Expected Bool"); } // AND - false case - if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(false)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(true), &And, boolean(false)).data { assert!(!result); } else { panic!("Expected Bool"); } // OR - true case - if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(true)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(true)).data { assert!(result); } else { panic!("Expected Bool"); } // OR - false case - if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(false)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(false), &Or, boolean(false)).data { assert!(!result); } else { panic!("Expected Bool"); } // Equality - if let Literal(Bool(result)) = eval_binary_op(boolean(true), &Eq, boolean(true)).0 { + if let Literal(Bool(result)) = eval_binary_op(boolean(true), &Eq, boolean(true)).data { assert!(result); } else { panic!("Expected Bool"); @@ -278,14 +280,14 @@ mod tests { #[test] fn test_string_operations() { // Equality - true case - if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("hello")).0 { + if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("hello")).data { assert!(result); } else { panic!("Expected 
Bool"); } // Equality - false case - if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("world")).0 { + if let Literal(Bool(result)) = eval_binary_op(string("hello"), &Eq, string("world")).data { assert!(!result); } else { panic!("Expected Bool"); @@ -293,7 +295,7 @@ mod tests { // Concatenation if let Literal(String(result)) = - eval_binary_op(string("hello "), &Concat, string("world")).0 + eval_binary_op(string("hello "), &Concat, string("world")).data { assert_eq!(result, "hello world"); } else { @@ -304,17 +306,17 @@ mod tests { #[test] fn test_array_concatenation() { // Create two arrays - let array1 = Value(Array(vec![int(1), int(2), int(3)])); - let array2 = Value(Array(vec![int(4), int(5)])); + let array1 = Value::new(Array(vec![int(1), int(2), int(3)])); + let array2 = Value::new(Array(vec![int(4), int(5)])); // Concatenate arrays - if let Array(result) = eval_binary_op(array1, &Concat, array2).0 { + if let Array(result) = eval_binary_op(array1, &Concat, array2).data { assert_eq!(result.len(), 5); // Check the elements let expected = [1, 2, 3, 4, 5]; for (i, val) in result.iter().enumerate() { - if let Literal(Int64(n)) = val.0 { + if let Literal(Int64(n)) = val.data { assert_eq!(n, expected[i] as i64); } else { panic!("Expected Int64 in array"); @@ -330,44 +332,44 @@ mod tests { use crate::analyzer::map::Map; // Create two maps using Map::from_pairs - let map1 = Value(Map(Map::from_pairs(vec![ + let map1 = Value::new(Map(Map::from_pairs(vec![ (string("a"), int(1)), (string("b"), int(2)), ]))); - let map2 = Value(Map(Map::from_pairs(vec![ + let map2 = Value::new(Map(Map::from_pairs(vec![ (string("c"), int(3)), (string("d"), int(4)), ]))); // Concatenate maps - if let Map(result) = eval_binary_op(map1, &Concat, map2).0 { + if let Map(result) = eval_binary_op(map1, &Concat, map2).data { // Check each key-value pair is accessible - if let Literal(Int64(v)) = result.get(&string("a")).0 { + if let Literal(Int64(v)) = 
result.get(&string("a")).data { assert_eq!(v, 1); } else { panic!("Expected Int64 for key 'a'"); } - if let Literal(Int64(v)) = result.get(&string("b")).0 { + if let Literal(Int64(v)) = result.get(&string("b")).data { assert_eq!(v, 2); } else { panic!("Expected Int64 for key 'b'"); } - if let Literal(Int64(v)) = result.get(&string("c")).0 { + if let Literal(Int64(v)) = result.get(&string("c")).data { assert_eq!(v, 3); } else { panic!("Expected Int64 for key 'c'"); } - if let Literal(Int64(v)) = result.get(&string("d")).0 { + if let Literal(Int64(v)) = result.get(&string("d")).data { assert_eq!(v, 4); } else { panic!("Expected Int64 for key 'd'"); } // Check a non-existent key returns None - if let None = result.get(&string("z")).0 { + if let None = result.get(&string("z")).data { // This is the expected behavior } else { panic!("Expected None for non-existent key"); diff --git a/optd-dsl/src/engine/eval/core.rs b/optd-dsl/src/engine/eval/core.rs index fbbebed5..ce7729fc 100644 --- a/optd-dsl/src/engine/eval/core.rs +++ b/optd-dsl/src/engine/eval/core.rs @@ -27,7 +27,7 @@ where match data { Literal(lit) => { // Directly continue with the literal value. - k(Value(Literal(lit))).await + k(Value::new(Literal(lit))).await } Array(items) => evaluate_collection(items, Array, engine, k).await, Tuple(items) => evaluate_collection(items, Tuple, engine, k).await, @@ -36,18 +36,18 @@ where } Map(items) => { // Directly continue with the map value. - k(Value(Map(items))).await + k(Value::new(Map(items))).await } Function(fun_type) => { // Directly continue with the function value. - k(Value(Function(fun_type))).await + k(Value::new(Function(fun_type))).await } Fail(msg) => evaluate_fail(*msg, engine, k).await, Logical(op) => evaluate_logical_operator(op, engine, k).await, Physical(op) => evaluate_physical_operator(op, engine, k).await, None => { // Directly continue with null value. 
- k(Value(None)).await + k(Value::new(None)).await } } } @@ -75,7 +75,7 @@ where engine, Arc::new(move |values| { Box::pin(capture!([constructor, k], async move { - let result = Value(constructor(values)); + let result = Value::new(constructor(values)); k(result).await })) }), @@ -102,10 +102,9 @@ where .evaluate( msg, Arc::new(move |value| { - Box::pin(capture!( - [k], - async move { k(Value(Fail(value.into()))).await } - )) + Box::pin(capture!([k], async move { + k(Value::new(Fail(Box::new(value)))).await + })) }), ) .await @@ -141,7 +140,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 42); } @@ -167,18 +166,18 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Array(elements) => { assert_eq!(elements.len(), 3); - match &elements[0].0 { + match &elements[0].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 1), _ => panic!("Expected integer literal"), } - match &elements[1].0 { + match &elements[1].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 2), _ => panic!("Expected integer literal"), } - match &elements[2].0 { + match &elements[2].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 3), _ => panic!("Expected integer literal"), } @@ -205,18 +204,18 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Tuple(elements) => { assert_eq!(elements.len(), 3); - match &elements[0].0 { + match &elements[0].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 42), _ => panic!("Expected integer literal"), } - match &elements[1].0 { + match &elements[1].data { CoreData::Literal(Literal::String(value)) => assert_eq!(value, "hello"), _ => panic!("Expected string literal"), } - match &elements[2].0 { + match &elements[2].data { 
CoreData::Literal(Literal::Bool(value)) => assert!(*value), _ => panic!("Expected boolean literal"), } @@ -242,15 +241,15 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Struct(name, fields) => { assert_eq!(name, "Point"); assert_eq!(fields.len(), 2); - match &fields[0].0 { + match &fields[0].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 10), _ => panic!("Expected integer literal"), } - match &fields[1].0 { + match &fields[1].data { CoreData::Literal(Literal::Int64(value)) => assert_eq!(*value, 20), _ => panic!("Expected integer literal"), } @@ -276,7 +275,7 @@ mod tests { // Check that we got a function value assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Function(_) => { // Successfully evaluated to a function } @@ -298,7 +297,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::None => { // Successfully evaluated to null } @@ -311,12 +310,12 @@ mod tests { async fn test_fail_evaluation() { let return_k: Continuation> = Arc::new(move |value| { Box::pin(async move { - match value { - Value(CoreData::Fail(boxed_value)) => match boxed_value.0 { + match value.data { + CoreData::Fail(boxed_value) => match boxed_value.data { CoreData::Literal(Literal::String(msg)) => Err(msg), _ => panic!("Expected string message in fail"), }, - value => Ok(value), + _ => Ok(value), } }) }); @@ -327,7 +326,9 @@ mod tests { // Create a fail expression with a message let fail_expr = Arc::new(Expr::new(CoreExpr(CoreData::Fail(Box::new(Arc::new( - Expr::new(CoreVal(Value(CoreData::Literal(string("error message"))))), + Expr::new(CoreVal(Value::new(CoreData::Literal(string( + "error message", + ))))), )))))); let results = diff --git a/optd-dsl/src/engine/eval/expr.rs b/optd-dsl/src/engine/eval/expr.rs index 8dc52ca5..e2d61b6b 100644 --- a/optd-dsl/src/engine/eval/expr.rs +++ 
b/optd-dsl/src/engine/eval/expr.rs @@ -40,7 +40,7 @@ where cond, Arc::new(move |value| { Box::pin(capture!([then_expr, else_expr, engine, k], async move { - match value.0 { + match value.data { CoreData::Literal(Literal::Bool(b)) => { if b { engine.evaluate(then_expr, k).await @@ -217,7 +217,7 @@ where called, Arc::new(move |called_value| { Box::pin(capture!([args, engine, k], async move { - match called_value.0 { + match called_value.data { // Handle function calls. CoreData::Function(FunKind::Closure(params, body)) => { evaluate_closure_call(params, body, args, engine, k).await @@ -407,7 +407,7 @@ where /// /// The extracted integer index. fn extract_index(index_value: &Value) -> usize { - match &index_value.0 { + match &index_value.data { CoreData::Literal(Literal::Int64(i)) => *i as usize, _ => panic!("Index must be an integer, got: {:?}", index_value), } @@ -467,7 +467,7 @@ where Box::pin(capture!([collection, k], async move { let index = extract_index(&index_value); - let result = match &collection.0 { + let result = match &collection.data { CoreData::Array(items) => get_indexed_item(items, index), CoreData::Tuple(items) => get_indexed_item(items, index), CoreData::Struct(_, fields) => get_indexed_item(fields, index), @@ -523,7 +523,7 @@ where args[0].clone(), Arc::new(move |key_value| { Box::pin(capture!([map_value, k], async move { - match &map_value.0 { + match &map_value.data { CoreData::Map(map) => { let result = map.get(&key_value); k(result).await @@ -650,7 +650,7 @@ where Box::pin(capture!([keys_values, k], async move { // Create a map from keys and values. 
let map_items = keys_values.into_iter().zip(values_values).collect(); - k(Value(CoreData::Map(Map::from_pairs(map_items)))).await + k(Value::new(CoreData::Map(Map::from_pairs(map_items)))).await })) }), ) @@ -758,21 +758,21 @@ mod tests { let complex_results = evaluate_and_collect(complex_condition, engine_with_x, harness).await; // Check results - match &true_results[0].0 { + match &true_results[0].data { CoreData::Literal(Literal::String(value)) => { assert_eq!(value, "yes"); // true condition should select "yes" } _ => panic!("Expected string value"), } - match &false_results[0].0 { + match &false_results[0].data { CoreData::Literal(Literal::String(value)) => { assert_eq!(value, "no"); // false condition should select "no" } _ => panic!("Expected string value"), } - match &complex_results[0].0 { + match &complex_results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 40); // 20 * 2 = 40 (since x > 10) } @@ -801,7 +801,7 @@ mod tests { let results = evaluate_and_collect(let_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 15); // 10 + 5 = 15 } @@ -836,7 +836,7 @@ mod tests { let results = evaluate_and_collect(nested_let_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 30); // 10 + (10 * 2) = 30 } @@ -851,7 +851,7 @@ mod tests { let mut ctx = Context::default(); // Define a function: fn(x, y) => x + y - let add_function = Value(CoreData::Function(FunKind::Closure( + let add_function = Value::new(CoreData::Function(FunKind::Closure( vec!["x".to_string(), "y".to_string()], Arc::new(Expr::new(Binary(ref_expr("x"), BinOp::Add, ref_expr("y")))), ))); @@ -868,7 +868,7 @@ mod tests { let results = evaluate_and_collect(call_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { 
CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 30); // 10 + 20 = 30 } @@ -883,16 +883,16 @@ mod tests { let mut ctx = Context::default(); // Define a Rust UDF that calculates the sum of array elements - let sum_function = Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + let sum_function = Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Array(elements) => { let mut sum = 0; for elem in elements { - if let CoreData::Literal(Literal::Int64(value)) = &elem.0 { + if let CoreData::Literal(Literal::Int64(value)) = &elem.data { sum += value; } } - Value(CoreData::Literal(Literal::Int64(sum))) + Value::new(CoreData::Literal(Literal::Int64(sum))) } _ => panic!("Expected array argument"), } @@ -916,7 +916,7 @@ mod tests { let results = evaluate_and_collect(call_expr, engine, harness).await; // Check result - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 15); // 1 + 2 + 3 + 4 + 5 = 15 } @@ -943,7 +943,7 @@ mod tests { // Check that we got a Map value assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Map(map) => { // Check that map has the correct key-value pairs assert_values_equal(&map.get(&lit_val(string("a"))), &lit_val(int(1))); @@ -951,7 +951,7 @@ mod tests { assert_values_equal(&map.get(&lit_val(string("c"))), &lit_val(int(3))); // Check that non-existent key returns None value - assert_values_equal(&map.get(&lit_val(string("d"))), &Value(CoreData::None)); + assert_values_equal(&map.get(&lit_val(string("d"))), &Value::new(CoreData::None)); } _ => panic!("Expected Map value"), } @@ -990,7 +990,7 @@ mod tests { // Check that we got a Map value with correctly evaluated keys and values assert_eq!(complex_results.len(), 1); - match &complex_results[0].0 { + match &complex_results[0].data { CoreData::Map(map) => { // Check that map has the correct key-value pairs after evaluation 
assert_values_equal(&map.get(&lit_val(string("xy"))), &lit_val(int(15))); @@ -1009,12 +1009,12 @@ mod tests { // Add a map lookup function ctx.bind( "get".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { if args.len() != 2 { panic!("get function requires 2 arguments"); } - match &args[0].0 { + match &args[0].data { CoreData::Map(map) => map.get(&args[1]), _ => panic!("First argument must be a map"), } @@ -1092,7 +1092,7 @@ mod tests { // Check that we got the correct value from the nested lookup assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::String(value)) => { assert_eq!(value, "San Francisco"); } @@ -1107,7 +1107,7 @@ mod tests { let mut ctx = Context::default(); // Define a function to compute factorial: fn(n) => if n <= 1 then 1 else n * factorial(n-1) - let factorial_function = Value(CoreData::Function(FunKind::Closure( + let factorial_function = Value::new(CoreData::Function(FunKind::Closure( vec!["n".to_string()], Arc::new(Expr::new(IfThenElse( Arc::new(Expr::new(Binary( @@ -1159,7 +1159,7 @@ mod tests { let results = evaluate_and_collect(program, engine, harness).await; // Check result: factorial(5) / 3 = 120 / 3 = 40 - match &results[0].0 { + match &results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 40); } @@ -1190,7 +1190,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(30)); } @@ -1206,7 +1206,7 @@ mod tests { let engine = Engine::new(ctx); // Create a tuple (10, "hello", true) - let tuple_expr = Arc::new(Expr::new(CoreVal(Value(CoreData::Tuple(vec![ + let tuple_expr = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Tuple(vec![ lit_val(int(10)), lit_val(string("hello")), lit_val(Literal::Bool(true)), @@ -1219,7 +1219,7 @@ mod tests { // Check result 
assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("hello".to_string())); } @@ -1247,7 +1247,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(20)); } @@ -1290,12 +1290,12 @@ mod tests { // Check result - should be a tuple (2, None) assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Tuple(elements) => { assert_eq!(elements.len(), 2); // Check first element: map["b"] should be 2 - match &elements[0].0 { + match &elements[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(2)); } @@ -1303,7 +1303,7 @@ mod tests { } // Check second element: map["d"] should be None - match &elements[1].0 { + match &elements[1].data { CoreData::None => {} _ => panic!("Expected None for missing key lookup"), } @@ -1347,7 +1347,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(40)); } @@ -1404,7 +1404,7 @@ mod tests { // Check join_type result assert_eq!(join_type_results.len(), 1); - match &join_type_results[0].0 { + match &join_type_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } @@ -1413,7 +1413,7 @@ mod tests { // Check condition result assert_eq!(condition_results.len(), 1); - match &condition_results[0].0 { + match &condition_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("x = y".to_string())); } @@ -1422,11 +1422,11 @@ mod tests { // Check first child result (orders table scan) assert_eq!(first_child_results.len(), 1); - match &first_child_results[0].0 { + match &first_child_results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { assert_eq!(log_op.operator.tag, "TableScan"); 
assert_eq!(log_op.operator.data.len(), 1); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("orders".to_string())); } @@ -1438,11 +1438,11 @@ mod tests { // Check second child result (lineitem table scan) assert_eq!(second_child_results.len(), 1); - match &second_child_results[0].0 { + match &second_child_results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { assert_eq!(log_op.operator.tag, "TableScan"); assert_eq!(log_op.operator.data.len(), 1); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("lineitem".to_string())); } @@ -1502,7 +1502,7 @@ mod tests { // Check join method result assert_eq!(method_results.len(), 1); - match &method_results[0].0 { + match &method_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("hash".to_string())); } @@ -1511,7 +1511,7 @@ mod tests { // Check condition result assert_eq!(condition_results.len(), 1); - match &condition_results[0].0 { + match &condition_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("id = id".to_string())); } @@ -1520,11 +1520,11 @@ mod tests { // Check first child result (customers index scan) assert_eq!(first_child_results.len(), 1); - match &first_child_results[0].0 { + match &first_child_results[0].data { CoreData::Physical(Materializable::Materialized(phys_op)) => { assert_eq!(phys_op.operator.tag, "IndexScan"); assert_eq!(phys_op.operator.data.len(), 1); - match &phys_op.operator.data[0].0 { + match &phys_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("customers".to_string())); } @@ -1536,11 +1536,11 @@ mod tests { // Check second child result (orders parallel scan) assert_eq!(second_child_results.len(), 1); - match &second_child_results[0].0 { + match &second_child_results[0].data { 
CoreData::Physical(Materializable::Materialized(phys_op)) => { assert_eq!(phys_op.operator.tag, "ParallelScan"); assert_eq!(phys_op.operator.data.len(), 1); - match &phys_op.operator.data[0].0 { + match &phys_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("orders".to_string())); } @@ -1573,7 +1573,7 @@ mod tests { harness.register_group(test_group_id, materialized_join); // Create an unmaterialized logical operator - let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value(CoreData::Logical( + let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Logical( Materializable::UnMaterialized(test_group_id), ))))); @@ -1607,7 +1607,7 @@ mod tests { // Check join type result assert_eq!(join_type_results.len(), 1); - match &join_type_results[0].0 { + match &join_type_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } @@ -1616,7 +1616,7 @@ mod tests { // Check condition result assert_eq!(condition_results.len(), 1); - match &condition_results[0].0 { + match &condition_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("customer.id = order.id".to_string())); } @@ -1625,11 +1625,11 @@ mod tests { // Check first child result (customers table scan) assert_eq!(first_child_results.len(), 1); - match &first_child_results[0].0 { + match &first_child_results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { assert_eq!(log_op.operator.tag, "TableScan"); assert_eq!(log_op.operator.data.len(), 1); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("customers".to_string())); } @@ -1647,7 +1647,7 @@ mod tests { // Create a physical goal let test_group_id = GroupId(2); - let properties = Box::new(Value(CoreData::Literal(string("sorted")))); + let properties = Box::new(Value::new(CoreData::Literal(string("sorted")))); let test_goal = Goal { 
group_id: test_group_id, properties, @@ -1669,7 +1669,7 @@ mod tests { harness.register_goal(&test_goal, materialized_join); // Create an unmaterialized physical operator - let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value(CoreData::Physical( + let unmaterialized_expr = Arc::new(Expr::new(CoreVal(Value::new(CoreData::Physical( Materializable::UnMaterialized(test_goal), ))))); @@ -1703,7 +1703,7 @@ mod tests { // Check join method result assert_eq!(method_results.len(), 1); - match &method_results[0].0 { + match &method_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("merge".to_string())); } @@ -1712,7 +1712,7 @@ mod tests { // Check condition result assert_eq!(condition_results.len(), 1); - match &condition_results[0].0 { + match &condition_results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("customer.id = order.id".to_string())); } @@ -1721,11 +1721,11 @@ mod tests { // Check first child result (customers scan) assert_eq!(first_child_results.len(), 1); - match &first_child_results[0].0 { + match &first_child_results[0].data { CoreData::Physical(Materializable::Materialized(phys_op)) => { assert_eq!(phys_op.operator.tag, "SortedScan"); assert_eq!(phys_op.operator.data.len(), 1); - match &phys_op.operator.data[0].0 { + match &phys_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(*lit, Literal::String("customers".to_string())); } @@ -1805,11 +1805,11 @@ mod tests { // Verify the final result is the "customers" TableScan assert_eq!(table_scan_results.len(), 1); - match &table_scan_results[0].0 { + match &table_scan_results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { assert_eq!(log_op.operator.tag, "TableScan"); assert_eq!(log_op.operator.data.len(), 1); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(*lit, Literal::String("customers".to_string())); } @@ -1842,14 +1842,14 @@ mod tests { let 
outer_results = evaluate_and_collect(outer_ref, engine.clone(), harness).await; // Check results - match &inner_results[0].0 { + match &inner_results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 200); } _ => panic!("Expected integer value"), } - match &outer_results[0].0 { + match &outer_results[0].data { CoreData::Literal(Literal::Int64(value)) => { assert_eq!(*value, 100); } diff --git a/optd-dsl/src/engine/eval/match.rs b/optd-dsl/src/engine/eval/match.rs index 90d53a64..884d3aa9 100644 --- a/optd-dsl/src/engine/eval/match.rs +++ b/optd-dsl/src/engine/eval/match.rs @@ -131,7 +131,7 @@ where O: Send + 'static, { Box::pin(async move { - match (pattern, &value.0) { + match (pattern, &value.data) { // Simple patterns. (Wildcard, _) => k((value, Some(ctx))).await, (Literal(pattern_lit), CoreData::Literal(value_lit)) => { @@ -263,7 +263,7 @@ where // Split array into head and tail. let head = arr[0].clone(); let tail_elements = arr[1..].to_vec(); - let tail = Value(CoreData::Array(tail_elements)); + let tail = Value::new(CoreData::Array(tail_elements)); // Create components to match sequentially. let patterns = vec![head_pattern, tail_pattern]; @@ -284,13 +284,13 @@ where let tail_value = results[1].0.clone(); // Extract tail elements. - let tail_elements = match &tail_value.0 { + let tail_elements = match &tail_value.data { CoreData::Array(elements) => elements.clone(), _ => panic!("Expected Array in tail result"), }; // Create new array with matched head + tail elements. - let new_array = Value(CoreData::Array( + let new_array = Value::new(CoreData::Array( std::iter::once(head_value).chain(tail_elements).collect(), )); @@ -339,7 +339,7 @@ where // Reconstruct struct with matched field values. 
let matched_values = results.iter().map(|(v, _)| v.clone()).collect(); - let new_struct = Value(CoreData::Struct(pat_name, matched_values)); + let new_struct = Value::new(CoreData::Struct(pat_name, matched_values)); if all_matched { // Combine contexts by folding over the results, starting with the base context. @@ -415,10 +415,10 @@ where // Create appropriate value type based on original_value. let new_value = if is_logical { let log_op = LogicalOp::logical(new_op); - Value(CoreData::Logical(Materialized(log_op))) + Value::new(CoreData::Logical(Materialized(log_op))) } else { let phys_op = PhysicalOp::physical(new_op); - Value(CoreData::Physical(Materialized(phys_op))) + Value::new(CoreData::Physical(Materialized(phys_op))) }; if all_matched { @@ -561,7 +561,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("matched".to_string())); } @@ -605,7 +605,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(52)); } @@ -657,10 +657,10 @@ mod tests { let mut ctx = Context::default(); ctx.bind( "length".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Array(elements) => { - Value(CoreData::Literal(int(elements.len() as i64))) + Value::new(CoreData::Literal(int(elements.len() as i64))) } _ => panic!("Expected array"), } @@ -674,7 +674,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(5)); } @@ -722,7 +722,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(30)); } 
@@ -747,7 +747,7 @@ mod tests { ], }; - let logical_op_value = Value(CoreData::Logical(Materialized(LogicalOp::logical(op)))); + let logical_op_value = Value::new(CoreData::Logical(Materialized(LogicalOp::logical(op)))); let logical_op_expr = Arc::new(Expr::new(CoreVal(logical_op_value.clone()))); // Create a match expression: @@ -815,7 +815,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!( lit, @@ -849,7 +849,8 @@ mod tests { harness.register_group(test_group_id, materialized_join); // Create an unmaterialized logical operator - let unmaterialized_logical_op = Value(CoreData::Logical(UnMaterialized(test_group_id))); + let unmaterialized_logical_op = + Value::new(CoreData::Logical(UnMaterialized(test_group_id))); let unmaterialized_expr = Arc::new(Expr::new(CoreVal(unmaterialized_logical_op))); @@ -921,7 +922,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!( lit, @@ -941,7 +942,7 @@ mod tests { // Create a physical goal let test_group_id = GroupId(2); - let properties = Box::new(Value(CoreData::Literal(string("sorted")))); + let properties = Box::new(Value::new(CoreData::Literal(string("sorted")))); let test_goal = Goal { group_id: test_group_id, properties, @@ -960,7 +961,7 @@ mod tests { harness.register_goal(&test_goal, materialized_hash_join); // Create an unmaterialized physical operator with the goal - let unmaterialized_physical_op = Value(CoreData::Physical(UnMaterialized(test_goal))); + let unmaterialized_physical_op = Value::new(CoreData::Physical(UnMaterialized(test_goal))); let unmaterialized_expr = Arc::new(Expr::new(CoreVal(unmaterialized_physical_op))); @@ -991,12 +992,12 @@ mod tests { // Create formatted result string with binding values Arc::new(Expr::new(Let( "to_string".to_string(), - Arc::new(Expr::new(CoreVal(Value(CoreData::Function( - 
FunKind::RustUDF(|args| match &args[0].0 { + Arc::new(Expr::new(CoreVal(Value::new(CoreData::Function( + FunKind::RustUDF(|args| match &args[0].data { CoreData::Literal(lit) => { - Value(CoreData::Literal(string(&format!("{:?}", lit)))) + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) } - _ => Value(CoreData::Literal(string(""))), + _ => Value::new(CoreData::Literal(string(""))), }), ))))), Arc::new(Expr::new(Binary( @@ -1045,7 +1046,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { match lit { Literal::String(s) => { @@ -1074,10 +1075,10 @@ mod tests { // Add to_string function to convert numbers to strings ctx.bind( "to_string".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Literal(Literal::Int64(i)) => { - Value(CoreData::Literal(string(&i.to_string()))) + Value::new(CoreData::Literal(string(&i.to_string()))) } _ => panic!("Expected integer literal"), } @@ -1190,7 +1191,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("Fallthrough arm: 30, 40".to_string())); } @@ -1207,13 +1208,13 @@ mod tests { // Add to_string function to convert complex values to strings ctx.bind( "to_string".to_string(), - Value(CoreData::Function(FunKind::RustUDF(|args| { - match &args[0].0 { + Value::new(CoreData::Function(FunKind::RustUDF(|args| { + match &args[0].data { CoreData::Literal(lit) => { - Value(CoreData::Literal(string(&format!("{:?}", lit)))) + Value::new(CoreData::Literal(string(&format!("{:?}", lit)))) } - CoreData::Array(_) => Value(CoreData::Literal(string(""))), - _ => Value(CoreData::Literal(string(""))), + CoreData::Array(_) => Value::new(CoreData::Literal(string(""))), + _ => 
Value::new(CoreData::Literal(string(""))), } }))), ); @@ -1348,7 +1349,7 @@ mod tests { // Check result (this is a complex test of correct binding propagation through deeply nested patterns) assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Literal(lit) => { assert!(matches!(lit, Literal::String(_))); let result_str = match lit { @@ -1414,8 +1415,8 @@ mod tests { lit_val(string("left.id = right.id")), ], vec![ - Value(CoreData::Logical(UnMaterialized(group_id_1))), - Value(CoreData::Logical(UnMaterialized(group_id_2))), + Value::new(CoreData::Logical(UnMaterialized(group_id_1))), + Value::new(CoreData::Logical(UnMaterialized(group_id_2))), ], ); @@ -1485,7 +1486,7 @@ mod tests { ]; for result in &results { - match &result.0 { + match &result.data { CoreData::Literal(Literal::String(s)) => { assert!( expected_combinations @@ -1500,7 +1501,7 @@ mod tests { // Ensure we got each combination exactly once (no duplicates) let mut result_strings = Vec::new(); for result in &results { - if let CoreData::Literal(Literal::String(s)) = &result.0 { + if let CoreData::Literal(Literal::String(s)) = &result.data { result_strings.push(s.clone()); } } diff --git a/optd-dsl/src/engine/eval/operator.rs b/optd-dsl/src/engine/eval/operator.rs index 2064a62d..b6ed1449 100644 --- a/optd-dsl/src/engine/eval/operator.rs +++ b/optd-dsl/src/engine/eval/operator.rs @@ -27,7 +27,7 @@ where // For unmaterialized operators, directly call the continuation with the unmaterialized // value. UnMaterialized(group_id) => { - let result = Value(Logical(UnMaterialized(group_id))); + let result = Value::new(Logical(UnMaterialized(group_id))); k(result).await } // For materialized operators, evaluate all parts and construct the result. 
@@ -41,7 +41,7 @@ where Box::pin(capture!([k], async move { // Wrap the constructed operator in the logical operator structure let log_op = LogicalOp::logical(constructed_op); - let result = Value(Logical(Materialized(log_op))); + let result = Value::new(Logical(Materialized(log_op))); k(result).await })) }), @@ -69,7 +69,7 @@ where match op { // For unmaterialized operators, continue with the unmaterialized value. UnMaterialized(physical_goal) => { - let result = Value(Physical(UnMaterialized(physical_goal))); + let result = Value::new(Physical(UnMaterialized(physical_goal))); k(result).await } // For materialized operators, evaluate all parts and construct the result. @@ -82,7 +82,7 @@ where Arc::new(move |constructed_op| { Box::pin(capture!([k], async move { let phys_op = PhysicalOp::physical(constructed_op); - let result = Value(Physical(Materialized(phys_op))); + let result = Value::new(Physical(Materialized(phys_op))); k(result).await })) }), @@ -196,20 +196,20 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { // Check tag assert_eq!(log_op.operator.tag, "Join"); // Check data - should have "inner" and 15 assert_eq!(log_op.operator.data.len(), 2); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } _ => panic!("Expected string literal"), } - match &log_op.operator.data[1].0 { + match &log_op.operator.data[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(15)); } @@ -218,13 +218,13 @@ mod tests { // Check children - should have "orders" and "lineitem" assert_eq!(log_op.operator.children.len(), 2); - match &log_op.operator.children[0].0 { + match &log_op.operator.children[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("orders".to_string())); } _ => panic!("Expected string literal"), } - 
match &log_op.operator.children[1].0 { + match &log_op.operator.children[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("lineitem".to_string())); } @@ -255,7 +255,7 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Logical(Materializable::UnMaterialized(id)) => { // Check that the group ID is preserved assert_eq!(*id, group_id); @@ -298,20 +298,20 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Physical(Materializable::Materialized(phys_op)) => { // Check tag assert_eq!(phys_op.operator.tag, "HashJoin"); // Check data - should have "inner" and 60 assert_eq!(phys_op.operator.data.len(), 2); - match &phys_op.operator.data[0].0 { + match &phys_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } _ => panic!("Expected string literal"), } - match &phys_op.operator.data[1].0 { + match &phys_op.operator.data[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::Int64(60)); } @@ -320,13 +320,13 @@ mod tests { // Check children - should have "IndexScan" and "FullScan" assert_eq!(phys_op.operator.children.len(), 2); - match &phys_op.operator.children[0].0 { + match &phys_op.operator.children[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("IndexScan".to_string())); } _ => panic!("Expected string literal"), } - match &phys_op.operator.children[1].0 { + match &phys_op.operator.children[1].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("FullScan".to_string())); } @@ -347,7 +347,7 @@ mod tests { // Create an unmaterialized physical operator with a goal let goal = Goal { group_id: GroupId(42), - properties: Box::new(Value(CoreData::Literal(Literal::String( + properties: Box::new(Value::new(CoreData::Literal(Literal::String( "sorted".to_string(), )))), }; @@ -362,13 +362,13 @@ mod tests { // Check 
result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Physical(Materializable::UnMaterialized(result_goal)) => { // Check that the goal is preserved assert_eq!(result_goal.group_id, goal.group_id); // Check properties - match &result_goal.properties.0 { + match &result_goal.properties.data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("sorted".to_string())); } @@ -415,14 +415,14 @@ mod tests { // Check result assert_eq!(results.len(), 1); - match &results[0].0 { + match &results[0].data { CoreData::Logical(Materializable::Materialized(log_op)) => { // Check tag assert_eq!(log_op.operator.tag, "Join"); // Check data assert_eq!(log_op.operator.data.len(), 1); - match &log_op.operator.data[0].0 { + match &log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("inner".to_string())); } @@ -433,11 +433,11 @@ mod tests { assert_eq!(log_op.operator.children.len(), 2); // Check first child - match &log_op.operator.children[0].0 { + match &log_op.operator.children[0].data { CoreData::Logical(Materializable::Materialized(child_log_op)) => { assert_eq!(child_log_op.operator.tag, "Scan"); assert_eq!(child_log_op.operator.data.len(), 1); - match &child_log_op.operator.data[0].0 { + match &child_log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("orders".to_string())); } @@ -448,11 +448,11 @@ mod tests { } // Check second child - match &log_op.operator.children[1].0 { + match &log_op.operator.children[1].data { CoreData::Logical(Materializable::Materialized(child_log_op)) => { assert_eq!(child_log_op.operator.tag, "Scan"); assert_eq!(child_log_op.operator.data.len(), 1); - match &child_log_op.operator.data[0].0 { + match &child_log_op.operator.data[0].data { CoreData::Literal(lit) => { assert_eq!(lit, &Literal::String("lineitem".to_string())); } diff --git a/optd-dsl/src/engine/eval/unary.rs b/optd-dsl/src/engine/eval/unary.rs index 
619231f4..c6456a44 100644 --- a/optd-dsl/src/engine/eval/unary.rs +++ b/optd-dsl/src/engine/eval/unary.rs @@ -26,16 +26,13 @@ use UnaryOp::*; /// # Panics /// Panics when the operation is not defined for the given operand type pub(crate) fn eval_unary_op(op: &UnaryOp, expr: Value) -> Value { - match (op, &expr.0) { + match (op, &expr.data) { // Numeric negation for integers - (Neg, Literal(Int64(x))) => Value(Literal(Int64(-x))), - + (Neg, Literal(Int64(x))) => Value::new(Literal(Int64(-x))), // Numeric negation for floating-point numbers - (Neg, Literal(Float64(x))) => Value(Literal(Float64(-x))), - + (Neg, Literal(Float64(x))) => Value::new(Literal(Float64(-x))), // Logical NOT for boolean values - (Not, Literal(Bool(x))) => Value(Literal(Bool(!x))), - + (Not, Literal(Bool(x))) => Value::new(Literal(Bool(!x))), // Any other combination is invalid _ => panic!("Invalid unary operation or type mismatch"), } @@ -47,37 +44,37 @@ mod tests { // Helper function to create integer Value fn int(i: i64) -> Value { - Value(Literal(Int64(i))) + Value::new(Literal(Int64(i))) } // Helper function to create float Value fn float(f: f64) -> Value { - Value(Literal(Float64(f))) + Value::new(Literal(Float64(f))) } // Helper function to create boolean Value fn boolean(b: bool) -> Value { - Value(Literal(Bool(b))) + Value::new(Literal(Bool(b))) } #[test] fn test_integer_negation() { // Negating a positive integer - if let Literal(Int64(result)) = eval_unary_op(&Neg, int(5)).0 { + if let Literal(Int64(result)) = eval_unary_op(&Neg, int(5)).data { assert_eq!(result, -5); } else { panic!("Expected Int64"); } // Negating a negative integer - if let Literal(Int64(result)) = eval_unary_op(&Neg, int(-7)).0 { + if let Literal(Int64(result)) = eval_unary_op(&Neg, int(-7)).data { assert_eq!(result, 7); } else { panic!("Expected Int64"); } // Negating zero - if let Literal(Int64(result)) = eval_unary_op(&Neg, int(0)).0 { + if let Literal(Int64(result)) = eval_unary_op(&Neg, int(0)).data { 
assert_eq!(result, 0); } else { panic!("Expected Int64"); @@ -89,21 +86,21 @@ mod tests { use std::f64::consts::PI; // Negating a positive float - if let Literal(Float64(result)) = eval_unary_op(&Neg, float(PI)).0 { + if let Literal(Float64(result)) = eval_unary_op(&Neg, float(PI)).data { assert_eq!(result, -PI); } else { panic!("Expected Float64"); } // Negating a negative float - if let Literal(Float64(result)) = eval_unary_op(&Neg, float(-2.5)).0 { + if let Literal(Float64(result)) = eval_unary_op(&Neg, float(-2.5)).data { assert_eq!(result, 2.5); } else { panic!("Expected Float64"); } // Negating zero - if let Literal(Float64(result)) = eval_unary_op(&Neg, float(0.0)).0 { + if let Literal(Float64(result)) = eval_unary_op(&Neg, float(0.0)).data { assert_eq!(result, -0.0); // Checking sign bit for -0.0 assert!(result.to_bits() & 0x8000_0000_0000_0000 != 0); @@ -115,14 +112,14 @@ mod tests { #[test] fn test_boolean_not() { // NOT true - if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(true)).0 { + if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(true)).data { assert!(!result); } else { panic!("Expected Bool"); } // NOT false - if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(false)).0 { + if let Literal(Bool(result)) = eval_unary_op(&Not, boolean(false)).data { assert!(result); } else { panic!("Expected Bool"); diff --git a/optd-dsl/src/utils/tests.rs b/optd-dsl/src/utils/tests.rs index 7e020d90..c173bcc9 100644 --- a/optd-dsl/src/utils/tests.rs +++ b/optd-dsl/src/utils/tests.rs @@ -83,7 +83,7 @@ impl TestHarness { // Helper to compare Values pub fn assert_values_equal(v1: &Value, v2: &Value) { - match (&v1.0, &v2.0) { + match (&v1.data, &v2.data) { (CoreData::Literal(l1), CoreData::Literal(l2)) => match (l1, l2) { (Literal::Int64(i1), Literal::Int64(i2)) => assert_eq!(i1, i2), (Literal::Float64(f1), Literal::Float64(f2)) => assert_eq!(f1, f2), @@ -106,7 +106,7 @@ pub fn assert_values_equal(v1: &Value, v2: &Value) { 
assert_values_equal(v1, v2); } } - _ => panic!("Values don't match: {:?} vs {:?}", v1.0, v2.0), + _ => panic!("Values don't match: {:?} vs {:?}", v1.data, v2.data), } } @@ -117,7 +117,7 @@ pub fn lit_expr(literal: Literal) -> Arc { /// Helper to create a literal value. pub fn lit_val(literal: Literal) -> Value { - Value(CoreData::Literal(literal)) + Value::new(CoreData::Literal(literal)) } /// Helper to create an integer literal. @@ -147,12 +147,12 @@ pub fn match_arm(pattern: Pattern, expr: Arc) -> MatchArm { /// Helper to create an array value. pub fn array_val(items: Vec) -> Value { - Value(CoreData::Array(items)) + Value::new(CoreData::Array(items)) } /// Helper to create a struct value. pub fn struct_val(name: &str, fields: Vec) -> Value { - Value(CoreData::Struct(name.to_string(), fields)) + Value::new(CoreData::Struct(name.to_string(), fields)) } /// Helper to create a pattern matching expression. @@ -202,7 +202,7 @@ pub fn create_logical_operator(tag: &str, data: Vec, children: Vec children, }; - Value(CoreData::Logical(Materialized(LogicalOp::logical(op)))) + Value::new(CoreData::Logical(Materialized(LogicalOp::logical(op)))) } /// Helper to create a simple physical operator value. @@ -213,7 +213,7 @@ pub fn create_physical_operator(tag: &str, data: Vec, children: Vec