Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make ArrayPartition default instead of ProductRepr #612

Merged
merged 9 commits into from
May 30, 2023
6 changes: 3 additions & 3 deletions src/groups/product_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end

function identity_element(G::ProductGroup)
M = G.manifold
return ProductRepr(map(identity_element, M.manifolds))
return ArrayPartition(map(identity_element, M.manifolds))
end
function identity_element!(G::ProductGroup, p)
pes = submanifold_components(G, p)
Expand Down Expand Up @@ -194,7 +194,7 @@ end

function translate_diff(G::ProductGroup, p, q, X, conv::ActionDirection)
M = G.manifold
return ProductRepr(
return ArrayPartition(
map(
translate_diff,
M.manifolds,
Expand Down Expand Up @@ -282,7 +282,7 @@ end

function exp_lie(G::ProductGroup, X)
M = G.manifold
return ProductRepr(map(exp_lie, M.manifolds, submanifold_components(G, X))...)
return ArrayPartition(map(exp_lie, M.manifolds, submanifold_components(G, X))...)
end

function exp_lie!(G::ProductGroup, q, X)
Expand Down
6 changes: 3 additions & 3 deletions src/groups/semidirect_product_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -51,14 +51,14 @@ function allocate_result(G::SemidirectProductGroup, ::typeof(identity_element))
N, H = M.manifolds
np = allocate_result(N, identity_element)
hp = allocate_result(H, identity_element)
return ProductRepr(np, hp)
return ArrayPartition(np, hp)
end

"""
identity_element(G::SemidirectProductGroup)

Get the identity element of [`SemidirectProductGroup`](@ref) `G`. Uses [`ProductRepr`](@ref)
to represent the point.
Get the identity element of [`SemidirectProductGroup`](@ref) `G`. Uses `ArrayPartition`
from `RecursiveArrayTools.jl` to represent the point.
"""
identity_element(G::SemidirectProductGroup)

Expand Down
13 changes: 9 additions & 4 deletions src/groups/special_euclidean.jl
Original file line number Diff line number Diff line change
Expand Up @@ -129,8 +129,7 @@ matrix part of `p`, `r` is the translation part of `fX` and `ω` is the rotation
``×`` is the cross product and ``⋅`` is the matrix product.
"""
function adjoint_action(::SpecialEuclidean{3}, p, fX::TFVector{<:Any,VeeOrthogonalBasis{ℝ}})
t = p.parts[1]
R = p.parts[2]
t, R = submanifold_components(p)
r = fX.data[SA[1, 2, 3]]
ω = fX.data[SA[4, 5, 6]]
Rω = R * ω
Expand Down Expand Up @@ -553,6 +552,7 @@ end

"""
lie_bracket(G::SpecialEuclidean, X::ProductRepr, Y::ProductRepr)
lie_bracket(G::SpecialEuclidean, X::ArrayPartition, Y::ArrayPartition)
lie_bracket(G::SpecialEuclidean, X::AbstractMatrix, Y::AbstractMatrix)

Calculate the Lie bracket between elements `X` and `Y` of the special Euclidean Lie
Expand All @@ -565,6 +565,11 @@ function lie_bracket(G::SpecialEuclidean, X::ProductRepr, Y::ProductRepr)
nY, hY = submanifold_components(G, Y)
return ProductRepr(hX * nY - hY * nX, lie_bracket(G.manifold.manifolds[2], hX, hY))
end
function lie_bracket(G::SpecialEuclidean, X::ArrayPartition, Y::ArrayPartition)
nX, hX = submanifold_components(G, X)
nY, hY = submanifold_components(G, Y)
return ArrayPartition(hX * nY - hY * nX, lie_bracket(G.manifold.manifolds[2], hX, hY))
end
function lie_bracket(::SpecialEuclidean, X::AbstractMatrix, Y::AbstractMatrix)
return X * Y - Y * X
end
Expand Down Expand Up @@ -668,7 +673,7 @@ This is performed by extracting the rotation and translation part as in [`affine
function project(M::SpecialEuclideanInGeneralLinear, p)
G = M.manifold
np, hp = submanifold_components(G, p)
return ProductRepr(np, hp)
return ArrayPartition(np, hp)
end
"""
project(M::SpecialEuclideanInGeneralLinear, p, X)
Expand All @@ -681,7 +686,7 @@ function project(M::SpecialEuclideanInGeneralLinear, p, X)
G = M.manifold
np, hp = submanifold_components(G, p)
nX, hX = submanifold_components(G, X)
return ProductRepr(hp * nX, hX)
return ArrayPartition(hp * nX, hX)
end

function project!(M::SpecialEuclideanInGeneralLinear, q, p)
Expand Down
2 changes: 1 addition & 1 deletion src/manifolds/ProductManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -118,7 +118,7 @@ function InverseProductRetraction(inverse_retractions::AbstractInverseRetraction
end

@inline function allocate_result(M::ProductManifold, f)
return ProductRepr(map(N -> allocate_result(N, f), M.manifolds))
return ArrayPartition(map(N -> allocate_result(N, f), M.manifolds))
end

function allocation_promotion_function(M::ProductManifold, f, args::Tuple)
Expand Down
12 changes: 6 additions & 6 deletions src/manifolds/VectorBundle.jl
Original file line number Diff line number Diff line change
Expand Up @@ -306,7 +306,7 @@ base_manifold(B::VectorSpaceAtPoint) = base_manifold(B.fiber)
base_manifold(B::VectorBundle) = base_manifold(B.manifold)

"""
bundle_projection(B::VectorBundle, x::ProductRepr)
bundle_projection(B::VectorBundle, p::ArrayPartition)

Projection of point `p` from the bundle `M` to the base manifold.
Returns the point on the base manifold `B.manifold` at which the vector part
Expand Down Expand Up @@ -466,7 +466,7 @@ end
function get_vector(M::VectorBundle, p, X, B::AbstractBasis)
n = manifold_dimension(M.manifold)
xp1 = submanifold_component(p, Val(1))
return ProductRepr(
return ArrayPartition(
get_vector(M.manifold, xp1, X[1:n], B),
get_vector(M.fiber, xp1, X[(n + 1):end], B),
)
Expand All @@ -488,7 +488,7 @@ function get_vector(
) where {𝔽}
n = manifold_dimension(M.manifold)
xp1 = submanifold_component(p, Val(1))
return ProductRepr(
return ArrayPartition(
get_vector(M.manifold, xp1, X[1:n], B.data.base_basis),
get_vector(M.fiber, xp1, X[(n + 1):end], B.data.vec_basis),
)
Expand Down Expand Up @@ -1070,7 +1070,7 @@ end
return allocate_result(B.manifold, f, x...)
end
@inline function allocate_result(M::VectorBundle, f::TF) where {TF}
return ProductRepr(allocate_result(M.manifold, f), allocate_result(M.fiber, f))
return ArrayPartition(allocate_result(M.manifold, f), allocate_result(M.fiber, f))
end

"""
Expand Down Expand Up @@ -1123,7 +1123,7 @@ function _vector_transport_direction(
px, pVx = submanifold_components(M.manifold, p)
VXM, VXF = submanifold_components(M.manifold, X)
dx, dVx = submanifold_components(M.manifold, d)
return ProductRepr(
return ArrayPartition(
vector_transport_direction(M.manifold, px, VXM, dx, m.method_point),
vector_transport_direction(M.manifold, px, VXF, dx, m.method_vector),
)
Expand Down Expand Up @@ -1174,7 +1174,7 @@ function _vector_transport_to(
px, pVx = submanifold_components(M.manifold, p)
VXM, VXF = submanifold_components(M.manifold, X)
qx, qVx = submanifold_components(M.manifold, q)
return ProductRepr(
return ArrayPartition(
vector_transport_to(M.manifold, px, VXM, qx, m.method_point),
vector_transport_to(M.manifold, px, VXF, qx, m.method_vector),
)
Expand Down
25 changes: 21 additions & 4 deletions src/product_representations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,25 @@ created as

where `[1.0, 0.0, 0.0]` is the part corresponding to the sphere factor
and `[2.0, 3.0]` is the part corresponding to the euclidean manifold.


!!! warning

`ProductRepr` is deprecated and will be removed in a future release.
Please use `ArrayPartition` instead.
"""
struct ProductRepr{TM<:Tuple}
parts::TM
end

ProductRepr(points...) = ProductRepr{typeof(points)}(points)
function ProductRepr(points...)
Base.depwarn(
"`ProductRepr` will be deprecated in a future release. " *
"Please use `ArrayPartition` instead of `ProductRepr`.",
:ProductRepr,
)
return ProductRepr{typeof(points)}(points)
end

Base.:(==)(x::ProductRepr, y::ProductRepr) = x.parts == y.parts

Expand All @@ -89,13 +102,17 @@ allocate(x::ProductRepr) = ProductRepr(map(allocate, submanifold_components(x)).
function allocate(x::ProductRepr, ::Type{T}) where {T}
return ProductRepr(map(t -> allocate(t, T), submanifold_components(x))...)
end
allocate(p::ProductRepr, ::Type{T}, s::Size{S}) where {S,T} = Vector{T}(undef, S)
allocate(p::ProductRepr, ::Type{T}, s::Integer) where {T} = Vector{T}(undef, s)
allocate(::ProductRepr, ::Type{T}, s::Size{S}) where {S,T} = Vector{T}(undef, S)
allocate(::ProductRepr, ::Type{T}, s::Integer) where {T} = Vector{T}(undef, s)
allocate(a::AbstractArray{<:ProductRepr}) = map(allocate, a)

Base.copy(x::ProductRepr) = ProductRepr(map(copy, x.parts))

function Base.copyto!(x::ProductRepr, y::ProductRepr)
function Base.copyto!(x::ProductRepr, y::Union{ProductRepr,ArrayPartition})
map(copyto!, submanifold_components(x), submanifold_components(y))
return x
end
function Base.copyto!(x::ArrayPartition, y::ProductRepr)
map(copyto!, submanifold_components(x), submanifold_components(y))
return x
end
Expand Down
60 changes: 31 additions & 29 deletions test/groups/metric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -144,35 +144,37 @@ end
([-2.0, 1.0, 0.5], hat(Rn, p, [-1.0, -0.5, 1.1])),
]

pts = [ProductRepr(tp...) for tp in tuple_pts]
X_pts = [ProductRepr(tX...) for tX in tuple_X]

g1, g2 = pts[1:2]
t1, R1 = g1.parts
t2, R2 = g2.parts
g1g2 = ProductRepr(R1 * t2 + t1, R1 * R2)
@test isapprox(G, compose(G, g1, g2), g1g2)

test_group(
G,
pts,
X_pts,
X_pts;
test_diff=true,
test_lie_bracket=true,
test_adjoint_action=true,
diff_convs=[(), (LeftAction(),), (RightAction(),)],
)
test_manifold(
G,
pts;
#basis_types_vecs=basis_types,
basis_types_to_from=basis_types,
is_mutating=true,
#test_inplace=true,
test_vee_hat=false,
exp_log_atol_multiplier=50,
)
for prod_type in [ProductRepr, ArrayPartition]
pts = [prod_type(tp...) for tp in tuple_pts]
X_pts = [prod_type(tX...) for tX in tuple_X]

g1, g2 = pts[1:2]
t1, R1 = submanifold_components(g1)
t2, R2 = submanifold_components(g2)
g1g2 = prod_type(R1 * t2 + t1, R1 * R2)
@test isapprox(G, compose(G, g1, g2), g1g2)

test_group(
G,
pts,
X_pts,
X_pts;
test_diff=true,
test_lie_bracket=true,
test_adjoint_action=true,
diff_convs=[(), (LeftAction(),), (RightAction(),)],
)
test_manifold(
G,
pts;
#basis_types_vecs=basis_types,
basis_types_to_from=basis_types,
is_mutating=true,
#test_inplace=true,
test_vee_hat=false,
exp_log_atol_multiplier=50,
)
end
end
end
end
72 changes: 37 additions & 35 deletions test/groups/semidirect_product_group.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,40 +18,42 @@ include("group_utils.jl")
ts2 = Vector{Float64}.([1:2, 2:3, 3:4]) .* 10
tuple_pts = [zip(ts1, ts2)...]

pts = [ProductRepr(tp...) for tp in tuple_pts]

@testset "setindex! and getindex" begin
p1 = pts[1]
p2 = allocate(p1)
@test p1[G, 1] === p1[M, 1]
p2[G, 1] = p1[M, 1]
@test p2[G, 1] == p1[M, 1]
for prod_type in [ProductRepr, ArrayPartition]
pts = [prod_type(tp...) for tp in tuple_pts]

@testset "setindex! and getindex" begin
p1 = pts[1]
p2 = allocate(p1)
@test p1[G, 1] === p1[M, 1]
p2[G, 1] = p1[M, 1]
@test p2[G, 1] == p1[M, 1]
end

X = log(G, pts[1], pts[1])
Y = zero_vector(G, pts[1])
Z = Manifolds.allocate_result(G, zero_vector, pts[1])
Z = zero_vector!(M, Z, pts[1])
@test norm(G, pts[1], X) ≈ 0
@test norm(G, pts[1], Y) ≈ 0
@test norm(G, pts[1], Z) ≈ 0

e = Identity(G)
@test inv(G, e) === e

@test compose(G, e, pts[1]) == pts[1]
@test compose(G, pts[1], e) == pts[1]
@test compose(G, e, e) === e

# test in-place composition
o1 = copy(pts[1])
compose!(G, o1, o1, pts[2])
@test isapprox(G, o1, compose(G, pts[1], pts[2]))

eA = identity_element(G)
@test isapprox(G, eA, e)
@test isapprox(G, e, eA)
W = log(G, eA, pts[1])
Z = log(G, eA, pts[1])
@test isapprox(G, e, W, Z)
end

X = log(G, pts[1], pts[1])
Y = zero_vector(G, pts[1])
Z = Manifolds.allocate_result(G, zero_vector, pts[1])
Z = zero_vector!(M, Z, pts[1])
@test norm(G, pts[1], X) ≈ 0
@test norm(G, pts[1], Y) ≈ 0
@test norm(G, pts[1], Z) ≈ 0

e = Identity(G)
@test inv(G, e) === e

@test compose(G, e, pts[1]) == pts[1]
@test compose(G, pts[1], e) == pts[1]
@test compose(G, e, e) === e

# test in-place composition
o1 = copy(pts[1])
compose!(G, o1, o1, pts[2])
@test isapprox(G, o1, compose(G, pts[1], pts[2]))

eA = identity_element(G)
@test isapprox(G, eA, e)
@test isapprox(G, e, eA)
W = log(G, eA, pts[1])
Z = log(G, eA, pts[1])
@test isapprox(G, e, W, Z)
end
Loading