diff --git a/Project.toml b/Project.toml index 7097dba..e4cceb4 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "CTParser" uuid = "32681960-a1b1-40db-9bff-a1ca817385d1" -version = "0.8.9-beta" +version = "0.8.10-beta" authors = ["Jean-Baptiste Caillau "] [deps] diff --git a/src/initial_guess.jl b/src/initial_guess.jl index cb5647c..de71c13 100644 --- a/src/initial_guess.jl +++ b/src/initial_guess.jl @@ -108,7 +108,7 @@ or a time grid, based on whether `arg` matches `time_name(ocp)`. - `pref::Symbol`: backend module prefix (e.g. `:CTModels`). - `ocp`: symbolic OCP variable passed from the macro. -- `arg::Symbol`: argument symbol used in the specification (e.g. `:t`, `:s`, `:T`). +- `arg`: argument used in the specification (e.g. `:t`, `:s`, or a literal array after alias expansion). - `rhs`: right-hand side expression. - `arg_in_rhs::Bool`: whether `arg` appears in `rhs` (computed at parse-time via `has`). @@ -119,14 +119,25 @@ or a time grid, based on whether `arg` matches `time_name(ocp)`. # Notes -When `arg_in_rhs` is `true`, the specification is definitely a time-dependent -function, so we validate that `arg == Symbol(time_name(ocp))` and throw an -error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional +When `arg` is not a Symbol (e.g., a literal array after alias expansion like +`[0.0, 0.5, 1.0]`), it is always treated as a time grid specification. + +When `arg` is a Symbol and `arg_in_rhs` is `true`, the specification is definitely +a time-dependent function, so we validate that `arg == Symbol(time_name(ocp))` and +throw an error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional that checks whether `arg` matches the time name to decide between a constant function or a time grid. """ function __gen_temporal_value(pref, ocp, arg, rhs, arg_in_rhs) val_sym = __symgen(:init_val) + + # Early return: if arg is not a Symbol (e.g., literal array after alias expansion), + # it's always a grid specification + if !(arg isa Symbol) + code = :($val_sym = ($arg, $rhs)) + return val_sym, code + end + arg_quoted = QuoteNode(arg) if arg_in_rhs @@ -232,7 +243,15 @@ function __log_spec(key, spec) rhs end rhs_str = sprint(Base.show_unquoted, rhs_clean) - return string(key, " = ", arg, " -> ", rhs_str) + + # If arg is not a Symbol (e.g., literal array after alias expansion), + # format as grid specification + if arg isa Symbol + return string(key, " = ", arg, " -> ", rhs_str) + else + arg_str = sprint(Base.show_unquoted, arg) + return string(key, " = (", arg_str, ", ", rhs_str, ")") + end else return string(key, " = ???") end @@ -245,17 +264,21 @@ Internal helper that parses the body of an `@init` block. The function walks through the expression `ex` and splits it into -- *alias statements*, which are left as ordinary Julia assignments and - executed verbatim inside the generated block; +- *alias statements* of the form `lhs = rhs`, which are stored in a dictionary + and substituted into subsequent statements at parse-time using `subs`; - *initialisation specifications* of the form `lhs := rhs` or `lhs(arg) := rhs`, which are converted into structured specification - tuples. + tuples after alias expansion. For expressions of the form `lhs(arg) := rhs`, this function uses `has(rhs, arg)` to determine whether `arg` appears in the right-hand side. This information is stored in the specification tuple and used later to generate appropriate runtime code that distinguishes time-dependent functions from time grids. +Alias substitution happens before each statement is matched, enabling +time-dependent aliases like `phi = 2π * t` and accumulated aliases like +`a = t; s = a`. + # Arguments - `ex::Any`: expression or block coming from the body of `@init`. @@ -264,8 +287,6 @@ runtime code that distinguishes time-dependent functions from time grids. # Returns -- `alias_stmts::Vector{Expr}`: ordinary statements to execute before - building the initial guess. - `keys::Vector{Symbol}`: names of the components being initialised (e.g. `:q`, `:v`, `:u`, `:tf`). - `specs::Vector{Tuple}`: specification tuples, either `(:constant, rhs)` @@ -273,7 +294,7 @@ runtime code that distinguishes time-dependent functions from time grids. specifications where `arg_in_rhs` indicates whether `arg` appears in `rhs`. """ function _collect_init_specs(ex, lnum::Int, line_str::String) - alias_stmts = Expr[] # statements of the form a = ... or other Julia statements + aliases = OrderedCollections.OrderedDict{Union{Symbol,Expr}, Any}() keys = Symbol[] # keys of the NamedTuple (q, v, x, u, tf, ...) specs = Tuple[] # specification tuples @@ -286,20 +307,27 @@ function _collect_init_specs(ex, lnum::Int, line_str::String) for st in stmts st isa LineNumberNode && continue + # Substitute all known aliases before matching + for a in Base.keys(aliases) + st = subs(st, a, aliases[a]) + end + @match st begin - # Alias / ordinary Julia assignments left as-is + # Alias: store for future substitution :($lhs = $rhs) => begin - push!(alias_stmts, st) + lhs isa Symbol || error("Unsupported alias left-hand side in @init: $lhs (only Symbol allowed)") + aliases[lhs] = rhs end # Forms q(arg) := rhs - # Use has(rhs, arg) to determine if arg appears in rhs + # After alias expansion, arg may be a Symbol or an Expr (literal grid) :($lhs($arg) := $rhs) => begin lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs") - arg isa Symbol || error("Unsupported argument in @init: $arg must be a symbol") # Check if arg appears in rhs using has() from utils.jl - arg_in_rhs = has(rhs, arg) + # Note: if arg is not a Symbol (e.g., after alias expansion to a literal array), + # has() will return false, which is correct for grid specifications + arg_in_rhs = (arg isa Symbol) && has(rhs, arg) push!(keys, lhs) push!(specs, (:temporal, arg, rhs, arg_in_rhs)) @@ -312,14 +340,14 @@ function _collect_init_specs(ex, lnum::Int, line_str::String) push!(specs, (:constant, rhs)) end - # Fallback: any other line is treated as an ordinary Julia statement + # Fallback: strict mode - reject unrecognized statements _ => begin - push!(alias_stmts, st) + error("Unrecognized statement in @init block: $st. Only alias assignments (a = expr) and specifications (lhs := rhs or lhs(arg) := rhs) are allowed.") end end end - return alias_stmts, keys, specs + return keys, specs end """ @@ -352,24 +380,20 @@ macro level. initial guess when executed. """ function init_fun(ocp, e, lnum::Int, line_str::String) - alias_stmts, keys, specs = _collect_init_specs(e, lnum, line_str) + keys, specs = _collect_init_specs(e, lnum, line_str) pref = init_prefix() # If there is no init specification, delegate to build_initial_guess/validate_initial_guess if isempty(keys) - body_stmts = Any[] - append!(body_stmts, alias_stmts) build_call = :($pref.build_initial_guess($ocp, ())) validate_call = :($pref.validate_initial_guess($ocp, $build_call)) - push!(body_stmts, validate_call) - code_expr = Expr(:block, body_stmts...) + code_expr = validate_call log_str = "()" return log_str, code_expr end # Generate runtime code for each specification body_stmts = Any[] - append!(body_stmts, alias_stmts) val_syms = Symbol[] for spec in specs diff --git a/test/test_initial_guess.jl b/test/test_initial_guess.jl index 2425fe8..42e142b 100644 --- a/test/test_initial_guess.jl +++ b/test/test_initial_guess.jl @@ -692,4 +692,155 @@ function test_initial_guess() # debug @test occursin("\"s\"", err_msg) @test occursin(":s", err_msg) end + + @testset "time-dependent alias (phi = 2pi * t)" begin + ocp_circle = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R², control + x(0) == [0, 0] + x(1) == [1, 1] + ẋ(t) == u(t) + ∫(u(t)' * u(t)) → min + end + + ig = @init ocp_circle begin + phi = 2π * t # Alias depending on time variable + u(t) := [cos(phi), sin(phi)] + end + + @test ig isa CTModels.AbstractInitialGuess + CTModels.validate_initial_guess(ocp_circle, ig) + + ufun = CTModels.control(ig) + u0 = ufun(0.0) + u1 = ufun(0.5) + + @test u0[1] ≈ cos(0.0) + @test u0[2] ≈ sin(0.0) + @test u1[1] ≈ cos(π) + @test u1[2] ≈ sin(π) atol=1e-10 + end + + @testset "time variable substitution (s = t)" begin + ocp_circle = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R², control + x(0) == [0, 0] + x(1) == [1, 1] + ẋ(t) == u(t) + ∫(u(t)' * u(t)) → min + end + + ig = @init ocp_circle begin + s = t # Alias for time variable + u(s) := [cos(s), sin(s)] + end + + @test ig isa CTModels.AbstractInitialGuess + CTModels.validate_initial_guess(ocp_circle, ig) + + ufun = CTModels.control(ig) + u0 = ufun(0.0) + u1 = ufun(0.5) + + @test u0[1] ≈ cos(0.0) + @test u0[2] ≈ sin(0.0) + @test u1[1] ≈ cos(0.5) + @test u1[2] ≈ sin(0.5) + end + + @testset "grid aliases (T, X, U as local variables)" begin + ig = @init ocp_fixed begin + T = [0.0, 0.5, 1.0] + X = [[-1.0, 0.0], [0.0, 0.5], [0.0, 0.0]] + U = [0.0, 0.0, 1.0] + x(T) := X + u(T) := U + end + + @test ig isa CTModels.AbstractInitialGuess + CTModels.validate_initial_guess(ocp_fixed, ig) + + xfun = CTModels.state(ig) + ufun = CTModels.control(ig) + + x0 = xfun(0.0) + x1 = xfun(1.0) + u0 = ufun(0.0) + u1 = ufun(1.0) + + @test x0[1] ≈ -1.0 + @test x0[2] ≈ 0.0 + @test x1[1] ≈ 0.0 + @test x1[2] ≈ 0.0 + @test u0 ≈ 0.0 + @test u1 ≈ 1.0 + end + + @testset "accumulated aliases (a = t, s = a)" begin + ocp_circle = @def begin + t ∈ [0, 1], time + x ∈ R², state + u ∈ R², control + x(0) == [0, 0] + x(1) == [1, 1] + ẋ(t) == u(t) + ∫(u(t)' * u(t)) → min + end + + ig = @init ocp_circle begin + a = t + s = a + u(s) := [cos(s), sin(s)] + end + + @test ig isa CTModels.AbstractInitialGuess + CTModels.validate_initial_guess(ocp_circle, ig) + + ufun = CTModels.control(ig) + u0 = ufun(0.0) + u1 = ufun(0.5) + + @test u0[1] ≈ cos(0.0) + @test u0[2] ≈ sin(0.0) + @test u1[1] ≈ cos(0.5) + @test u1[2] ≈ sin(0.5) + end + + @testset "grid aliases with literal arrays" begin + ig = @init ocp_fixed begin + X = [[-1.0, 0.0], [0.0, 0.5], [0.0, 0.0]] + U = [0.0, 0.0, 1.0] + x([0.0, 0.5, 1.0]) := X + u([0.0, 0.5, 1.0]) := U + end + + @test ig isa CTModels.AbstractInitialGuess + CTModels.validate_initial_guess(ocp_fixed, ig) + + xfun = CTModels.state(ig) + ufun = CTModels.control(ig) + + x0 = xfun(0.0) + x1 = xfun(1.0) + u0 = ufun(0.0) + u1 = ufun(1.0) + + @test x0[1] ≈ -1.0 + @test x0[2] ≈ 0.0 + @test x1[1] ≈ 0.0 + @test x1[2] ≈ 0.0 + @test u0 ≈ 0.0 + @test u1 ≈ 1.0 + end + + @testset "strict mode: unrecognized statement error" begin + @test_throws CTBase.ParsingError Base.redirect_stdout(Base.devnull) do + @init ocp_fixed begin + println("This should fail") + end + end + end end