Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "CTParser"
uuid = "32681960-a1b1-40db-9bff-a1ca817385d1"
version = "0.8.9-beta"
version = "0.8.10-beta"
authors = ["Jean-Baptiste Caillau <jean-baptiste.caillau@univ-cotedazur.fr>"]

[deps]
Expand Down
74 changes: 49 additions & 25 deletions src/initial_guess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ or a time grid, based on whether `arg` matches `time_name(ocp)`.

- `pref::Symbol`: backend module prefix (e.g. `:CTModels`).
- `ocp`: symbolic OCP variable passed from the macro.
- `arg::Symbol`: argument symbol used in the specification (e.g. `:t`, `:s`, `:T`).
- `arg`: argument used in the specification (e.g. `:t`, `:s`, or a literal array after alias expansion).
- `rhs`: right-hand side expression.
- `arg_in_rhs::Bool`: whether `arg` appears in `rhs` (computed at parse-time via `has`).

Expand All @@ -119,14 +119,25 @@ or a time grid, based on whether `arg` matches `time_name(ocp)`.

# Notes

When `arg_in_rhs` is `true`, the specification is definitely a time-dependent
function, so we validate that `arg == Symbol(time_name(ocp))` and throw an
error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional
When `arg` is not a Symbol (e.g., a literal array after alias expansion like
`[0.0, 0.5, 1.0]`), it is always treated as a time grid specification.

When `arg` is a Symbol and `arg_in_rhs` is `true`, the specification is definitely
a time-dependent function, so we validate that `arg == Symbol(time_name(ocp))` and
throw an error if not. When `arg_in_rhs` is `false`, we generate a runtime conditional
that checks whether `arg` matches the time name to decide between a constant
function or a time grid.
"""
function __gen_temporal_value(pref, ocp, arg, rhs, arg_in_rhs)
val_sym = __symgen(:init_val)

# Early return: if arg is not a Symbol (e.g., literal array after alias expansion),
# it's always a grid specification
if !(arg isa Symbol)
code = :($val_sym = ($arg, $rhs))
return val_sym, code
end

arg_quoted = QuoteNode(arg)

if arg_in_rhs
Expand Down Expand Up @@ -232,7 +243,15 @@ function __log_spec(key, spec)
rhs
end
rhs_str = sprint(Base.show_unquoted, rhs_clean)
return string(key, " = ", arg, " -> ", rhs_str)

# If arg is not a Symbol (e.g., literal array after alias expansion),
# format as grid specification
if arg isa Symbol
return string(key, " = ", arg, " -> ", rhs_str)
else
arg_str = sprint(Base.show_unquoted, arg)
return string(key, " = (", arg_str, ", ", rhs_str, ")")
end
else
return string(key, " = ???")
end
Expand All @@ -245,17 +264,21 @@ Internal helper that parses the body of an `@init` block.

The function walks through the expression `ex` and splits it into

- *alias statements*, which are left as ordinary Julia assignments and
executed verbatim inside the generated block;
- *alias statements* of the form `lhs = rhs`, which are stored in a dictionary
and substituted into subsequent statements at parse-time using `subs`;
- *initialisation specifications* of the form `lhs := rhs` or
`lhs(arg) := rhs`, which are converted into structured specification
tuples.
tuples after alias expansion.

For expressions of the form `lhs(arg) := rhs`, this function uses `has(rhs, arg)`
to determine whether `arg` appears in the right-hand side. This information
is stored in the specification tuple and used later to generate appropriate
runtime code that distinguishes time-dependent functions from time grids.

Alias substitution happens before each statement is matched, enabling
time-dependent aliases like `phi = 2π * t` and accumulated aliases like
`a = t; s = a`.

# Arguments

- `ex::Any`: expression or block coming from the body of `@init`.
Expand All @@ -264,16 +287,14 @@ runtime code that distinguishes time-dependent functions from time grids.

# Returns

- `alias_stmts::Vector{Expr}`: ordinary statements to execute before
building the initial guess.
- `keys::Vector{Symbol}`: names of the components being initialised
(e.g. `:q`, `:v`, `:u`, `:tf`).
- `specs::Vector{Tuple}`: specification tuples, either `(:constant, rhs)`
for constant values or `(:temporal, arg, rhs, arg_in_rhs)` for temporal
specifications where `arg_in_rhs` indicates whether `arg` appears in `rhs`.
"""
function _collect_init_specs(ex, lnum::Int, line_str::String)
alias_stmts = Expr[] # statements of the form a = ... or other Julia statements
aliases = OrderedCollections.OrderedDict{Union{Symbol,Expr}, Any}()
keys = Symbol[] # keys of the NamedTuple (q, v, x, u, tf, ...)
specs = Tuple[] # specification tuples

Expand All @@ -286,20 +307,27 @@ function _collect_init_specs(ex, lnum::Int, line_str::String)
for st in stmts
st isa LineNumberNode && continue

# Substitute all known aliases before matching
for a in Base.keys(aliases)
st = subs(st, a, aliases[a])
end

@match st begin
# Alias / ordinary Julia assignments left as-is
# Alias: store for future substitution
:($lhs = $rhs) => begin
push!(alias_stmts, st)
lhs isa Symbol || error("Unsupported alias left-hand side in @init: $lhs (only Symbol allowed)")
aliases[lhs] = rhs
end

# Forms q(arg) := rhs
# Use has(rhs, arg) to determine if arg appears in rhs
# After alias expansion, arg may be a Symbol or an Expr (literal grid)
:($lhs($arg) := $rhs) => begin
lhs isa Symbol || error("Unsupported left-hand side in @init: $lhs")
arg isa Symbol || error("Unsupported argument in @init: $arg must be a symbol")

# Check if arg appears in rhs using has() from utils.jl
arg_in_rhs = has(rhs, arg)
# Note: if arg is not a Symbol (e.g., after alias expansion to a literal array),
# has() will return false, which is correct for grid specifications
arg_in_rhs = (arg isa Symbol) && has(rhs, arg)

push!(keys, lhs)
push!(specs, (:temporal, arg, rhs, arg_in_rhs))
Expand All @@ -312,14 +340,14 @@ function _collect_init_specs(ex, lnum::Int, line_str::String)
push!(specs, (:constant, rhs))
end

# Fallback: any other line is treated as an ordinary Julia statement
# Fallback: strict mode - reject unrecognized statements
_ => begin
push!(alias_stmts, st)
error("Unrecognized statement in @init block: $st. Only alias assignments (a = expr) and specifications (lhs := rhs or lhs(arg) := rhs) are allowed.")
end
end
end

return alias_stmts, keys, specs
return keys, specs
end

"""
Expand Down Expand Up @@ -352,24 +380,20 @@ macro level.
initial guess when executed.
"""
function init_fun(ocp, e, lnum::Int, line_str::String)
alias_stmts, keys, specs = _collect_init_specs(e, lnum, line_str)
keys, specs = _collect_init_specs(e, lnum, line_str)
pref = init_prefix()

# If there is no init specification, delegate to build_initial_guess/validate_initial_guess
if isempty(keys)
body_stmts = Any[]
append!(body_stmts, alias_stmts)
build_call = :($pref.build_initial_guess($ocp, ()))
validate_call = :($pref.validate_initial_guess($ocp, $build_call))
push!(body_stmts, validate_call)
code_expr = Expr(:block, body_stmts...)
code_expr = validate_call
log_str = "()"
return log_str, code_expr
end

# Generate runtime code for each specification
body_stmts = Any[]
append!(body_stmts, alias_stmts)

val_syms = Symbol[]
for spec in specs
Expand Down
151 changes: 151 additions & 0 deletions test/test_initial_guess.jl
Original file line number Diff line number Diff line change
Expand Up @@ -692,4 +692,155 @@ function test_initial_guess() # debug
@test occursin("\"s\"", err_msg)
@test occursin(":s", err_msg)
end

@testset "time-dependent alias (phi = 2pi * t)" begin
ocp_circle = @def begin
t ∈ [0, 1], time
x ∈ R², state
u ∈ R², control
x(0) == [0, 0]
x(1) == [1, 1]
ẋ(t) == u(t)
∫(u(t)' * u(t)) → min
end

ig = @init ocp_circle begin
phi = 2π * t # Alias depending on time variable
u(t) := [cos(phi), sin(phi)]
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_circle, ig)

ufun = CTModels.control(ig)
u0 = ufun(0.0)
u1 = ufun(0.5)

@test u0[1] ≈ cos(0.0)
@test u0[2] ≈ sin(0.0)
@test u1[1] ≈ cos(π)
@test u1[2] ≈ sin(π) atol=1e-10
end

@testset "time variable substitution (s = t)" begin
ocp_circle = @def begin
t ∈ [0, 1], time
x ∈ R², state
u ∈ R², control
x(0) == [0, 0]
x(1) == [1, 1]
ẋ(t) == u(t)
∫(u(t)' * u(t)) → min
end

ig = @init ocp_circle begin
s = t # Alias for time variable
u(s) := [cos(s), sin(s)]
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_circle, ig)

ufun = CTModels.control(ig)
u0 = ufun(0.0)
u1 = ufun(0.5)

@test u0[1] ≈ cos(0.0)
@test u0[2] ≈ sin(0.0)
@test u1[1] ≈ cos(0.5)
@test u1[2] ≈ sin(0.5)
end

@testset "grid aliases (T, X, U as local variables)" begin
ig = @init ocp_fixed begin
T = [0.0, 0.5, 1.0]
X = [[-1.0, 0.0], [0.0, 0.5], [0.0, 0.0]]
U = [0.0, 0.0, 1.0]
x(T) := X
u(T) := U
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)

xfun = CTModels.state(ig)
ufun = CTModels.control(ig)

x0 = xfun(0.0)
x1 = xfun(1.0)
u0 = ufun(0.0)
u1 = ufun(1.0)

@test x0[1] ≈ -1.0
@test x0[2] ≈ 0.0
@test x1[1] ≈ 0.0
@test x1[2] ≈ 0.0
@test u0 ≈ 0.0
@test u1 ≈ 1.0
end

@testset "accumulated aliases (a = t, s = a)" begin
ocp_circle = @def begin
t ∈ [0, 1], time
x ∈ R², state
u ∈ R², control
x(0) == [0, 0]
x(1) == [1, 1]
ẋ(t) == u(t)
∫(u(t)' * u(t)) → min
end

ig = @init ocp_circle begin
a = t
s = a
u(s) := [cos(s), sin(s)]
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_circle, ig)

ufun = CTModels.control(ig)
u0 = ufun(0.0)
u1 = ufun(0.5)

@test u0[1] ≈ cos(0.0)
@test u0[2] ≈ sin(0.0)
@test u1[1] ≈ cos(0.5)
@test u1[2] ≈ sin(0.5)
end

@testset "grid aliases with literal arrays" begin
ig = @init ocp_fixed begin
X = [[-1.0, 0.0], [0.0, 0.5], [0.0, 0.0]]
U = [0.0, 0.0, 1.0]
x([0.0, 0.5, 1.0]) := X
u([0.0, 0.5, 1.0]) := U
end

@test ig isa CTModels.AbstractInitialGuess
CTModels.validate_initial_guess(ocp_fixed, ig)

xfun = CTModels.state(ig)
ufun = CTModels.control(ig)

x0 = xfun(0.0)
x1 = xfun(1.0)
u0 = ufun(0.0)
u1 = ufun(1.0)

@test x0[1] ≈ -1.0
@test x0[2] ≈ 0.0
@test x1[1] ≈ 0.0
@test x1[2] ≈ 0.0
@test u0 ≈ 0.0
@test u1 ≈ 1.0
end

@testset "strict mode: unrecognized statement error" begin
@test_throws CTBase.ParsingError Base.redirect_stdout(Base.devnull) do
@init ocp_fixed begin
println("This should fail")
end
end
end
end
Loading