From 5d52c73234c9888da314df91039647b90741bbef Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 11 Mar 2020 18:42:17 +0100 Subject: [PATCH 01/27] mysterious vee/hat fixes --- src/DecoratorManifold.jl | 8 -------- src/ManifoldsBase.jl | 5 ----- 2 files changed, 13 deletions(-) diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index da770bfa..752b8be4 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -502,10 +502,6 @@ decorated_manifold(M::Manifold) = M.manifold @decorator_transparent_signature exp!(M::AbstractDecoratorManifold, q, p, X) -@decorator_transparent_signature hat(M::AbstractDecoratorManifold, p, Xⁱ) - -@decorator_transparent_signature hat!(M::AbstractDecoratorManifold, X, p, Xⁱ) - @decorator_transparent_signature injectivity_radius(M::AbstractDecoratorManifold) @decorator_transparent_signature injectivity_radius(M::AbstractDecoratorManifold, p) @decorator_transparent_signature injectivity_radius( @@ -641,8 +637,4 @@ decorated_manifold(M::Manifold) = M.manifold m::AbstractVectorTransportMethod, ) -@decorator_transparent_signature vee!(M::AbstractDecoratorManifold, Xⁱ, p, X) - -@decorator_transparent_signature vee(M::AbstractDecoratorManifold, p, X) - @decorator_transparent_signature zero_tangent_vector!(M::AbstractDecoratorManifold, X, p) diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index 8df0903f..fdcfaac6 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -835,11 +835,6 @@ function vee!(M::Manifold, Xⁱ, p, X) error(manifold_function_not_implemented_message(M, vee!, Xⁱ, p, X)) end -function allocate_result(M::Manifold, f::typeof(vee), p, X) - T = allocate_result_type(M, f, (p, X)) - return allocate(p, T, manifold_dimension(M)) -end - """ zero_tangent_vector!(M::Manifold, X, p) From 7738d08decdf2ed495eb51990f57482f161c33bd Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 11 Mar 2020 19:33:26 +0100 Subject: [PATCH 02/27] moving numbers.jl from Manifolds --- src/ManifoldsBase.jl | 3 +++ src/numbers.jl | 54 ++++++++++++++++++++++++++++++++++++++++++++ test/numbers.jl | 24 ++++++++++++++++++++ test/runtests.jl | 1 + 4 files changed, 82 insertions(+) create mode 100644 src/numbers.jl create mode 100644 test/numbers.jl diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index fdcfaac6..5f396ad5 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -854,6 +854,8 @@ function zero_tangent_vector(M::Manifold, p) zero_tangent_vector!(M, X, p) return X end + +include("numbers.jl") include("DecoratorManifold.jl") include("ArrayManifold.jl") include("DefaultManifold.jl") @@ -900,6 +902,7 @@ export allocate, project_point!, project_tangent, project_tangent!, + real_dimension, representation_size, retract, retract!, diff --git a/src/numbers.jl b/src/numbers.jl new file mode 100644 index 00000000..3913c2a7 --- /dev/null +++ b/src/numbers.jl @@ -0,0 +1,54 @@ +""" + AbstractNumbers + +An abstract type to represent the number system on which a manifold is built. + +This provides concrete number types for dispatch. The two most common number types are +the fields [`RealNumbers`](@ref) (`ℝ` for short) and [`ComplexNumbers`](@ref) (`ℂ`). +""" +abstract type AbstractNumbers end + +""" + ℝ = RealNumbers() + +The field of real numbers. +""" +struct RealNumbers <: AbstractNumbers end + +""" + ℂ = ComplexNumbers() + +The field of complex numbers. +""" +struct ComplexNumbers <: AbstractNumbers end + +""" + ℍ = QuaternionNumbers() + +The division algebra of quaternions. +""" +struct QuaternionNumbers <: AbstractNumbers end + +const ℝ = RealNumbers() +const ℂ = ComplexNumbers() +const ℍ = QuaternionNumbers() + +Base.show(io::IO, ::RealNumbers) = print(io, "ℝ") +Base.show(io::IO, ::ComplexNumbers) = print(io, "ℂ") +Base.show(io::IO, ::QuaternionNumbers) = print(io, "ℍ") + +@doc raw""" + real_dimension(𝔽::AbstractNumbers) + +Return the real dimension $\dim_ℝ 𝔽$ of the [`AbstractNumbers`] system `𝔽`. +The real dimension is the dimension of a real vector space with which a number in `𝔽` can be +identified. +For example, [`ComplexNumbers`](@ref) have a real dimension of 2, and +[`QuaternionNumbers`](@ref) have a real dimension of 4. +""" +function real_dimension(𝔽::AbstractNumbers) + error("real_dimension not defined for number system $(𝔽)") +end +real_dimension(::RealNumbers) = 1 +real_dimension(::ComplexNumbers) = 2 +real_dimension(::QuaternionNumbers) = 4 diff --git a/test/numbers.jl b/test/numbers.jl new file mode 100644 index 00000000..ee22607c --- /dev/null +++ b/test/numbers.jl @@ -0,0 +1,24 @@ +using Test +using ManifoldsBase +using ManifoldsBase: AbstractNumbers, ℝ, ℂ, ℍ + +struct NotImplementedNumbers <: ManifoldsBase.AbstractNumbers end + +@testset "Number systems" begin + @test_throws ErrorException real_dimension(NotImplementedNumbers()) + + @test ℝ isa ManifoldsBase.RealNumbers + @test ManifoldsBase.RealNumbers() === ℝ + @test real_dimension(ℝ) == 1 + @test repr(ℝ) == "ℝ" + + @test ℂ isa ManifoldsBase.ComplexNumbers + @test ManifoldsBase.ComplexNumbers() === ℂ + @test real_dimension(ℂ) == 2 + @test repr(ℂ) == "ℂ" + + @test ℍ isa ManifoldsBase.QuaternionNumbers + @test ManifoldsBase.QuaternionNumbers() === ℍ + @test real_dimension(ℍ) == 4 + @test repr(ℍ) == "ℍ" +end diff --git a/test/runtests.jl b/test/runtests.jl index 63307a52..0607337e 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,6 +2,7 @@ using Test @testset "ManifoldsBase" begin include("allocation.jl") + include("numbers.jl") include("decorator_manifold.jl") include("empty_manifold.jl") include("default_manifold.jl") From 28bab5417c5b2fea4d3d516eec6ba43ba7e62146 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Wed, 11 Mar 2020 21:16:57 +0100 Subject: [PATCH 03/27] the main part of moving bases to ManifoldsBase --- src/ArrayManifold.jl | 86 +++++++ src/DefaultManifold.jl | 18 +- src/ManifoldsBase.jl | 66 ++--- src/bases.jl | 541 +++++++++++++++++++++++++++++++++++++++++ test/bases.jl | 205 ++++++++++++++++ test/runtests.jl | 4 + 6 files changed, 868 insertions(+), 52 deletions(-) create mode 100644 src/bases.jl create mode 100644 test/bases.jl diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index 1286aa24..603720b7 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -130,6 +130,92 @@ function exp!(M::ArrayManifold, y, x, v; kwargs...) return y end +function get_basis( + M::ArrayManifold, + p, + B::CachedBasis{<:AbstractOrthonormalBasis{ℝ},T,ℝ}, +) where {T<:AbstractVector} + bvectors = get_vectors(M, p, B) + N = length(bvectors) + M_dim = manifold_dimension(M) + if N != M_dim + + throw(ArgumentError("Incorrect number of basis vectors; expected: $M_dim, given: $N")) + end + for i = 1:N + Xi_norm = norm(M, p, bvectors[i]) + if !isapprox(Xi_norm, 1) + throw(ArgumentError("vector number $i is not normalized (norm = $Xi_norm)")) + end + for j = i+1:N + dot_val = real(inner(M, p, bvectors[i], bvectors[j])) + if !isapprox(dot_val, 0; atol = eps(eltype(p))) + throw(ArgumentError("vectors number $i and $j are not orthonormal (inner product = $dot_val)")) + end + end + end + return B +end + + +# the following is not nice, can we do better when using decorators and a specific last part? +function get_coordinates(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) + _get_coordinates(M, p, X, B, kwargs...) +end +function get_coordinates(M::ArrayManifold, p, X, B::CachedBasis; kwargs...) + _get_coordinates(M, p, X, B, kwargs...) +end +function get_coordinates(M::ArrayManifold, p, X, B::DefaultBasis; kwargs...) + _get_coordinates(M, p, X, B, kwargs...) +end +function get_coordinates(M::ArrayManifold, p, X, B::DefaultOrthogonalBasis; kwargs...) + _get_coordinates(M, p, X, B, kwargs...) +end +function get_coordinates(M::ArrayManifold, p, X, B::DefaultOrthonormalBasis; kwargs...) + _get_coordinates(M, p, X, B, kwargs...) +end + +function _get_coordinates(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + return get_coordinates(M.manifold, p, X, B) +end +function get_coordinates!(M::ArrayManifold, Y, p, X, B::all_uncached_bases; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + get_coordinates!(M, Y, p, X, B) + return Y +end + +function get_vector(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) + return _get_vector(M, p, X, B, kwargs...) +end +function get_vector(M::ArrayManifold, p, X, B::CachedBasis; kwargs...) + return _get_vector(M, p, X, B, kwargs...) +end +function get_vector(M::ArrayManifold, p, X, B::DefaultBasis; kwargs...) + return _get_vector(M, p, X, B, kwargs...) +end +function get_vector(M::ArrayManifold, p, X, B::DefaultOrthogonalBasis; kwargs...) + return _get_vector(M, p, X, B, kwargs...) +end +function get_vector(M::ArrayManifold, p, X, B::DefaultOrthonormalBasis; kwargs...) + return _get_vector(M, p, X, B, kwargs...) +end + +function _get_vector(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) + is_manifold_point(M, p, true; kwargs...) + size(X) == (manifold_dimension(M),) || error("Incorrect size of coefficient vector X") + Y = get_vector(M.manifold, p, X, B) + size(Y) == representation_size(M) || error("Incorrect size of tangent vector Y") + return Y +end +function get_vector!(M::ArrayManifold, Y, p, X, B::all_uncached_bases; kwargs...) + is_manifold_point(M, p, true; kwargs...) + size(X) == (manifold_dimension(M),) || error("Incorrect size of coefficient vector X") + get_vector!(M.manifold, Y, p, X, B) + size(Y) == representation_size(M) || error("Incorrect size of tangent vector Y") + return Y +end + injectivity_radius(M::ArrayManifold) = injectivity_radius(M.manifold) function injectivity_radius(M::ArrayManifold, method::AbstractRetractionMethod) return injectivity_radius(M.manifold, method) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 5be7ebd4..bbaf8b42 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -16,9 +16,19 @@ distance(::DefaultManifold, x, y) = norm(x - y) exp!(::DefaultManifold, y, x, v) = (y .= x .+ v) -hat!(M::DefaultManifold, X, p, Xⁱ) = copyto!(X, reshape(Xⁱ, representation_size(M))) +function get_basis(M::DefaultManifold, p, B::DefaultOrthonormalBasis) + return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) +end -@generated manifold_dimension(::DefaultManifold{T}) where {T} = *(T.parameters...) +function get_coordinates!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis) + Y .= reshape(X, manifold_dimension(M)) + return Y +end + +function get_vector!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis) + Y .= reshape(X, representation_size(M)) + return Y +end injectivity_radius(::DefaultManifold) = Inf @@ -26,6 +36,8 @@ injectivity_radius(::DefaultManifold) = Inf log!(::DefaultManifold, v, x, y) = (v .= y .- x) +@generated manifold_dimension(::DefaultManifold{T}) where {T} = *(T.parameters...) + norm(::DefaultManifold, x, v) = norm(v) project_point!(::DefaultManifold, y, x) = copyto!(y, x) @@ -49,6 +61,4 @@ function vector_transport_to!(::DefaultManifold, vto, x, v, y, ::ParallelTranspo return copyto!(vto, v) end -vee!(M::DefaultManifold, Xⁱ, p, X) = copyto!(Xⁱ, reshape(X, manifold_dimension(M))) - zero_tangent_vector!(::DefaultManifold, v, x) = fill!(v, 0) diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index 5f396ad5..fe259a2e 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -1,6 +1,6 @@ module ManifoldsBase -import Base: isapprox, exp, log, convert, copyto!, angle, eltype, similar, +, -, * +import Base: isapprox, exp, log, convert, copyto!, angle, eltype, similar, show, +, -, * import LinearAlgebra: dot, norm, det, cross, I, UniformScaling, Diagonal import Markdown: @doc_str @@ -313,29 +313,6 @@ geodesic(M::Manifold, p, X) = t -> exp(M, p, X, t) geodesic(M::Manifold, p, X, t::Real) = exp(M, p, X, t) geodesic(M::Manifold, p, X, T::AbstractVector) = map(t -> exp(M, p, X, t), T) -@doc raw""" - hat(M::Manifold, p, Xⁱ) - -Given a basis $e_i$ on the tangent space at a point `p` and tangent -component vector $X^i$, compute the equivalent vector representation -$X=X^i e_i$, where Einstein summation notation is used: - -````math -∧ : X^i ↦ X^i e_i -```` - -For array manifolds, this converts a vector representation of the tangent -vector to an array representation. The [`vee`](@ref) map is the `hat` map's -inverse. -""" -function hat(M::Manifold, p, Xⁱ) - X = allocate_result(M, hat, p, Xⁱ) - return hat!(M, X, p, Xⁱ) -end -function hat!(M::Manifold, X, p, Xⁱ) - error(manifold_function_not_implemented_message(M, hat!, X, p, Xⁱ)) -end - @doc doc""" injectivity_radius(M::Manifold, p) @@ -811,30 +788,6 @@ function vector_transport_to!( )) end -@doc raw""" - vee(M::Manifold, p, X) - -Given a basis $e_i$ on the tangent space at a point `p` and tangent -vector `X`, compute the vector components $X^i$, such that $X = X^i e_i$, where -Einstein summation notation is used: - -````math -\vee : X^i e_i ↦ X^i -```` - -For array manifolds, this converts an array representation of the tangent -vector to a vector representation. The [`hat`](@ref) map is the `vee` map's -inverse. -""" -function vee(M::Manifold, p, X) - Xⁱ = allocate_result(M, vee, p, X) - return vee!(M, Xⁱ, p, X) -end - -function vee!(M::Manifold, Xⁱ, p, X) - error(manifold_function_not_implemented_message(M, vee!, Xⁱ, p, X)) -end - """ zero_tangent_vector!(M::Manifold, X, p) @@ -857,6 +810,7 @@ end include("numbers.jl") include("DecoratorManifold.jl") +include("bases.jl") include("ArrayManifold.jl") include("DefaultManifold.jl") @@ -875,6 +829,15 @@ export AbstractInverseRetractionMethod, export ParallelTransport, ProjectionTransport +export + CachedBasis, + DefaultBasis, + DefaultOrthogonalBasis, + DefaultOrthonormalBasis, + DiagonalizingOrthonormalBasis, + DefaultOrthonormalBasis, + ProjectedOrthonormalBasis + export allocate, base_manifold, check_manifold_point, @@ -883,6 +846,12 @@ export allocate, exp, exp!, geodesic, + get_basis, + get_coordinates, + get_coordinates!, + get_vector, + get_vector!, + get_vectors, hat, hat!, shortest_geodesic, @@ -898,6 +867,7 @@ export allocate, manifold_dimension, norm, number_eltype, + number_system, project_point, project_point!, project_tangent, diff --git a/src/bases.jl b/src/bases.jl new file mode 100644 index 00000000..9b337314 --- /dev/null +++ b/src/bases.jl @@ -0,0 +1,541 @@ +""" + AbstractBasis{𝔽} + +Abstract type that represents a basis on a manifold or a subset of it. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as scalars. +""" +abstract type AbstractBasis{𝔽} end + +""" + DefaultBasis{𝔽} + +An arbitrary basis on a manifold. This will usually +be the fastest basis available for a manifold. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as scalars. +""" +struct DefaultBasis{𝔽} <: AbstractBasis{𝔽} end +DefaultBasis(𝔽::AbstractNumbers = ℝ) = DefaultBasis{𝔽}() + +""" + AbstractOrthogonalBasis{𝔽} + +Abstract type that represents an orthonormal basis on a manifold or a subset of it. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as scalars. +""" +abstract type AbstractOrthogonalBasis{𝔽} <: AbstractBasis{𝔽} end + +""" + DefaultOrthogonalBasis{𝔽} + +An arbitrary orthogonal basis on a manifold. This will usually +be the fastest orthogonal basis available for a manifold. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as scalars. +""" +struct DefaultOrthogonalBasis{𝔽} <: AbstractOrthogonalBasis{𝔽} end +DefaultOrthogonalBasis(𝔽::AbstractNumbers = ℝ) = DefaultOrthogonalBasis{𝔽}() + + +struct VeeOrthogonalBasis{𝔽} <: AbstractOrthogonalBasis{𝔽} end +VeeOrthogonalBasis(𝔽::AbstractNumbers = ℝ) = VeeOrthogonalBasis{𝔽}() + +""" + AbstractOrthonormalBasis{𝔽} + +Abstract type that represents an orthonormal basis on a manifold or a subset of it. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as scalars. +""" +abstract type AbstractOrthonormalBasis{𝔽} <: AbstractOrthogonalBasis{𝔽} end + +""" + DefaultOrthonormalBasis(𝔽::AbstractNumbers = ℝ) + +An arbitrary orthonormal basis on a manifold. This will usually +be the fastest orthonormal basis available for a manifold. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as +scalars. +""" +struct DefaultOrthonormalBasis{𝔽} <: AbstractOrthonormalBasis{𝔽} end + +DefaultOrthonormalBasis(𝔽::AbstractNumbers = ℝ) = DefaultOrthonormalBasis{𝔽}() + +""" + ProjectedOrthonormalBasis(method::Symbol, 𝔽::AbstractNumbers = ℝ) + +An orthonormal basis that comes from orthonormalization of basis vectors +of the ambient space projected onto the subspace representing the tangent space +at a given point. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as +scalars. + +Available methods: + - `:gram_schmidt` uses a modified Gram-Schmidt orthonormalization. + - `:svd` uses SVD decomposition to orthogonalize projected vectors. + The SVD-based method should be more numerically stable at the cost of + an additional assumption (local metric tensor at a point where the + basis is calculated has to be diagonal). +""" +struct ProjectedOrthonormalBasis{Method,𝔽} <: AbstractOrthonormalBasis{𝔽} end + +function ProjectedOrthonormalBasis(method::Symbol, 𝔽::AbstractNumbers = ℝ) + return ProjectedOrthonormalBasis{method,𝔽}() +end + +@doc raw""" + DiagonalizingOrthonormalBasis(frame_direction, 𝔽::AbstractNumbers = ℝ) + +An orthonormal basis `Ξ` as a vector of tangent vectors (of length determined by +[`manifold_dimension`](@ref)) in the tangent space that diagonalizes the curvature +tensor $R(u,v)w$ and where the direction `frame_direction` $v$ has curvature `0`. + +The type parameter `𝔽` denotes the [`AbstractNumbers`](@ref) that will be used as +scalars. +""" +struct DiagonalizingOrthonormalBasis{TV,𝔽} <: AbstractOrthonormalBasis{𝔽} + frame_direction::TV +end +function DiagonalizingOrthonormalBasis(X, 𝔽::AbstractNumbers = ℝ) + return DiagonalizingOrthonormalBasis{typeof(X),𝔽}(X) +end +struct DiagonalizingBasisData{D,V,ET} + frame_direction::D + eigenvalues::ET + vectors::V +end + +const DefaultOrDiagonalizingBasis = + Union{DefaultOrthonormalBasis,DiagonalizingOrthonormalBasis} + + +struct CachedBasis{B,V,𝔽} <: AbstractBasis{𝔽} where {BT<:AbstractBasis,V} + data::V +end +function CachedBasis(basis::B, data::V, 𝔽::AbstractNumbers = ℝ) where {V,B<:AbstractBasis} + return CachedBasis{B,V,𝔽}(data) +end +function CachedBasis(basis::CachedBasis) # avoid double encapsulation + return basis +end +function CachedBasis( + basis::DiagonalizingOrthonormalBasis, + eigenvalues::ET, + vectors::T, + 𝔽::AbstractNumbers = ℝ, +) where {ET<:AbstractVector,T<:AbstractVector} + data = DiagonalizingBasisData(basis.frame_direction, eigenvalues, vectors) + return CachedBasis(basis, data, 𝔽) +end + +# forward declarations +function get_coordinates end +function get_vector end + +const all_uncached_bases = Union{AbstractBasis, DefaultBasis, DefaultOrthogonalBasis, DefaultOrthonormalBasis} +const DISAMBIGUATION_BASIS_TYPES = [CachedBasis, DefaultOrthogonalBasis, DefaultOrthonormalBasis, DefaultOrDiagonalizingBasis, VeeOrthogonalBasis] + +function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X) + T = allocate_result_type(M, f, (p, X)) + return allocate(p, T, manifold_dimension(M)) +end + +@inline function allocate_result_type( + M::Manifold, + f::Union{typeof(get_coordinates), typeof(get_vector)}, + args::Tuple, +) + apf = allocation_promotion_function(M, f, args) + return apf(invoke(allocate_result_type, Tuple{Manifold,Any,typeof(args)}, M, f, args)) +end + +""" + allocation_promotion_function(M::Manifold, f, args::Tuple) + +Determine the function that must be used to ensure that the allocated representation is of +the right type. This is needed for [`get_vector`](@ref) when a point on a complex manifold +is represented by a real-valued vectors with a real-coefficient basis, so that +a complex-valued vector representation is allocated. +""" +allocation_promotion_function(M::Manifold, f, args::Tuple) = identity + +function combine_allocation_promotion_functions(f::T, ::T) where {T} + return f +end +function combine_allocation_promotion_functions(::typeof(complex), ::typeof(identity)) + return complex +end +function combine_allocation_promotion_functions(::typeof(identity), ::typeof(complex)) + return complex +end + +function _euclidean_basis_vector(p, i) + X = zero(p) + X[i] = 1 + return X +end + +""" + get_basis(M::Manifold, p, B::AbstractBasis) -> CachedBasis + +Compute the basis vectors of the tangent space at a point on manifold `M` +represented by `p`. + +Returned object derives from [`AbstractBasis`](@ref) and may have a field `.vectors` +that stores tangent vectors or it may store them implicitly, in which case +the function [`get_vectors`](@ref) needs to be used to retrieve the basis vectors. + +See also: [`get_coordinates`](@ref), [`get_vector`](@ref) +""" +function get_basis(M::Manifold, p, B::AbstractBasis) + error("get_basis not implemented for manifold of type $(typeof(M)) a point of type $(typeof(p)) and basis of type $(typeof(B)).") +end + +function get_basis(M::Manifold, p, B::DefaultOrthonormalBasis) + dim = manifold_dimension(M) + return CachedBasis( + B, + [get_vector(M, p, [ifelse(i == j, 1, 0) for j = 1:dim], B) for i = 1:dim], + ) +end +function get_basis(M::Manifold, p, B::CachedBasis) + return B +end +function get_basis(M::Manifold, p, B::ProjectedOrthonormalBasis{:svd,ℝ}) + S = representation_size(M) + PS = prod(S) + dim = manifold_dimension(M) + # projection + # TODO: find a better way to obtain a basis of the ambient space + Xs = [ + convert(Vector, reshape(project_tangent(M, p, _euclidean_basis_vector(p, i)), PS)) + for i in eachindex(p) + ] + O = reduce(hcat, Xs) + # orthogonalization + # TODO: try using rank-revealing QR here + decomp = svd(O) + rotated = Diagonal(decomp.S) * decomp.Vt + vecs = [collect(reshape(rotated[i, :], S)) for i = 1:dim] + # normalization + for i = 1:dim + i_norm = norm(M, p, vecs[i]) + vecs[i] /= i_norm + end + return CachedBasis(B, vecs) +end +function get_basis(M::Manifold, p, B::ProjectedOrthonormalBasis{:gram_schmidt,ℝ}; kwargs...) + E = [_euclidean_basis_vector(p, i) for i in eachindex(p)] + N = length(E) + Ξ = empty(E) + dim = manifold_dimension(M) + N < dim && @warn "Input only has $(N) vectors, but manifold dimension is $(dim)." + K = 0 + @inbounds for n = 1:N + Ξₙ = project_tangent(M, p, E[n]) + for k = 1:K + Ξₙ .-= real(inner(M, p, Ξ[k], Ξₙ)) .* Ξ[k] + end + nrmΞₙ = norm(M, p, Ξₙ) + if nrmΞₙ == 0 + @warn "Input vector $(n) has length 0." + @goto skip + end + Ξₙ ./= nrmΞₙ + for k = 1:K + if !isapprox(real(inner(M, p, Ξ[k], Ξₙ)), 0; kwargs...) + @warn "Input vector $(n) is not linearly independent of output basis vector $(k)." + @goto skip + end + end + push!(Ξ, Ξₙ) + K += 1 + K * real_dimension(number_system(B)) == dim && return CachedBasis(B, Ξ, ℝ) + @label skip + end + @warn "get_basis with bases $(typeof(B)) only found $(K) orthonormal basis vectors, but manifold dimension is $(dim)." + return CachedBasis(B, Ξ) +end + +""" + get_coordinates(M::Manifold, p, X, B::AbstractBasis) + get_coordinates(M::Manifold, p, X, B::CachedBasis) + +Compute a one-dimensional vector of coefficients of the tangent vector `X` +at point denoted by `p` on manifold `M` in basis `B`. + +Depending on the basis, `p` may not directly represent a point on the manifold. +For example if a basis transported along a curve is used, `p` may be the coordinate +along the curve. If a [`CachedBasis`](@ref) is provided, their stored vectors are used, +otherwise the user has to provide a method to compute the coordinates. + +For the [`CachedBasis`](@ref) keep in mind that the reconstruction with [`get_vector`](@ref) +requires either a dual basis or the cached basis to be selfdual, for example orthonormal + +See also: [`get_vector`](@ref), [`get_basis`](@ref) +""" +function get_coordinates(M::Manifold, p, X, B::AbstractBasis) + Y = allocate_result(M, get_coordinates, p, X) + return get_coordinates!(M, Y, p, X, B) +end +@decorator_transparent_signature get_coordinates(M::AbstractDecoratorManifold, p, X, B::AbstractBasis) +function decorator_transparent_dispatch(::typeof(get_coordinates), ::Manifold, args...) + return Val(:parent) +end + +function get_coordinates!(M::Manifold, Y, p, X, B::AbstractBasis) + error("get_coordinates! not implemented for manifold of type $(typeof(M)) coordinates of type $(typeof(Y)), a point of type $(typeof(p)), tangent vector of type $(typeof(X)) and basis of type $(typeof(B)).") +end +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::AbstractBasis) +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis) +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis{BT,V,𝔽}) where {BT<:AbstractBasis{ℝ}, 𝔽, V} +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultBasis) +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::VeeOrthogonalBasis) +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthogonalBasis) +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthonormalBasis) +@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DiagonalizingOrthonormalBasis) +function decorator_transparent_dispatch(::typeof(get_coordinates!), ::Manifold, args...) + return Val(:transparent) +end + +function get_coordinates!(M::Manifold, Y, p, X, B::VeeOrthogonalBasis) + return get_coordinates!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +end +function get_coordinates!(M::Manifold, Y, p, X, B::DefaultBasis) + return get_coordinates!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +end +function get_coordinates!(M::Manifold, Y, p, X, B::DefaultOrthogonalBasis) + return get_coordinates!(M, Y, p, X, DefaultOrthonormalBasis(number_system(B))) +end +function get_coordinates!( + M::Manifold, + Y, + p, + X, + B::CachedBasis{BT}, +) where {BT<:AbstractBasis{ℝ}} + map!(vb -> real(inner(M, p, X, vb)), Y, get_vectors(M, p, B)) + return Y +end +function get_coordinates!(M::Manifold, Y, p, X, B::CachedBasis) + map!(vb -> inner(M, p, X, vb), Y, get_vectors(M, p, B)) + return Y +end + + +""" + get_vector(M::Manifold, p, X, B::AbstractBasis) + +Convert a one-dimensional vector of coefficients in a basis `B` of +the tangent space at `p` on manifold `M` to a tangent vector `X` at `p`. + +Depending on the basis, `p` may not directly represent a point on the manifold. +For example if a basis transported along a curve is used, `p` may be the coordinate +along the curve. + +For the [`CachedBasis`](@ref) keep in mind that the reconstruction from [`get_coordinates`](@ref) +requires either a dual basis or the cached basis to be selfdual, for example orthonormal + +See also: [`get_coordinates`](@ref), [`get_basis`](@ref) +""" +function get_vector(M::Manifold, p, X, B::AbstractBasis) + Y = allocate_result(M, get_vector, p, X) + return get_vector!(M, Y, p, X, B) +end +@decorator_transparent_signature get_vector(M::AbstractDecoratorManifold, p, X, B::AbstractBasis) +function decorator_transparent_dispatch(::typeof(get_vector), ::Manifold, args...) + return Val(:parent) +end + +function get_vector!(M::Manifold, Y, p, X, B::AbstractBasis) + error("get_vector! not implemented for manifold of type $(typeof(M)) vector of type $(typeof(Y)), a point of type $(typeof(p)), coordinates of type $(typeof(X)) and basis of type $(typeof(B)).") +end +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::AbstractBasis) +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis) +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis{BT,V,𝔽}) where {BT<:AbstractBasis{ℝ}, 𝔽, V} +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultBasis) +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::VeeOrthogonalBasis) +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthogonalBasis) +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthonormalBasis) +@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DiagonalizingOrthonormalBasis) +function decorator_transparent_dispatch(::typeof(get_vector!), ::Manifold, args...) + return Val(:transparent) +end + +_get_vector_cache_broadcast(::Any) = Val(true) + +function get_vector!(M::Manifold, Y, p, X, B::VeeOrthogonalBasis) + return get_vector!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +end +function get_vector!(M::Manifold, Y, p, X, B::DefaultBasis) + return get_vector!(M, Y, p, X, DefaultOrthogonalBasis(number_system(B))) +end +function get_vector!(M::Manifold, Y, p, X, B::DefaultOrthogonalBasis) + return get_vector!(M, Y, p, X, DefaultOrthonormalBasis(number_system(B))) +end +function get_vector!(M::Manifold, Y, p, X, B::CachedBasis) + # quite convoluted but: + # 1) preserves the correct `eltype` + # 2) guarantees a reasonable array type `Y` + # (for example scalar * `SizedArray` is an `SArray`) + bvectors = get_vectors(M, p, B) + if _get_vector_cache_broadcast(bvectors[1]) === Val(false) + Xt = X[1] * bvectors[1] + copyto!(Y, Xt) + for i = 2:length(X) + Y += X[i] * bvectors[i] + end + return Y + else + Xt = X[1] .* bvectors[1] + copyto!(Y, Xt) + for i = 2:length(X) + Y .+= X[i] .* bvectors[i] + end + return Y + end +end + +""" + get_vectors(M::Manifold, p, B::AbstractBasis) + +Get the basis vectors of basis `B` of the tangent space at point `p`. +""" +function get_vectors(M::Manifold, p, B::AbstractBasis) + error("get_vectors not implemented for manifold of type $(typeof(M)) a point of type $(typeof(p)) and basis of type $(typeof(B)).") +end +function get_vectors( + M::Manifold, + p, + B::CachedBasis{<:AbstractBasis,<:AbstractArray}, +) + return B.data +end +function get_vectors( + M::Manifold, + p, + B::CachedBasis{<:AbstractBasis,<:DiagonalizingBasisData}, +) + return B.data.vectors +end + +#internal for directly cached basis i.e. those that are just arrays – used in show +_get_vectors(B::CachedBasis{<:AbstractBasis,<:AbstractArray}) = B.data +_get_vectors(B::CachedBasis{<:AbstractBasis,<:DiagonalizingBasisData}) = B.data.vectors + +@doc raw""" + hat(M::Manifold, p, Xⁱ) + +Given a basis $e_i$ on the tangent space at a point `p` and tangent +component vector $X^i$, compute the equivalent vector representation +$X=X^i e_i$, where Einstein summation notation is used: + +````math +∧ : X^i ↦ X^i e_i +```` + +For array manifolds, this converts a vector representation of the tangent +vector to an array representation. The [`vee`](@ref) map is the `hat` map's +inverse. +""" +hat(M::Manifold, p, X) = get_vector(M, p, X, VeeOrthogonalBasis()) +hat!(M::Manifold, Y, p, X) = get_vector!(M, Y, p, X, VeeOrthogonalBasis()) + +""" + number_system(::AbstractBasis) + +The number system used as scalars in the given basis. +""" +number_system(::AbstractBasis{𝔽}) where {𝔽} = 𝔽 + +function _show_basis_vector(io::IO, X; pre = "", head = "") + sX = sprint(show, "text/plain", X, context = io, sizehint = 0) + sX = replace(sX, '\n' => "\n$(pre)") + print(io, head, pre, sX) +end +function _show_basis_vector_range(io::IO, Ξ, range; pre = "", sym = "E") + for i in range + _show_basis_vector(io, Ξ[i]; pre = pre, head = "\n$(sym)$(i) =\n") + end + return nothing +end +function _show_basis_vector_range_noheader(io::IO, Ξ; max_vectors = 4, pre = "", sym = "E") + nv = length(Ξ) + if nv ≤ max_vectors + _show_basis_vector_range(io, Ξ, 1:nv; pre = " ", sym = " E") + else + halfn = div(max_vectors, 2) + _show_basis_vector_range(io, Ξ, 1:halfn; pre = " ", sym = " E") + print(io, "\n ⋮") + _show_basis_vector_range(io, Ξ, (nv-halfn+1):nv; pre = " ", sym = " E") + end +end + +function show(io::IO, ::DefaultOrthonormalBasis{𝔽}) where {𝔽} + print(io, "DefaultOrthonormalBasis($(𝔽))") +end +function show(io::IO, ::ProjectedOrthonormalBasis{method,𝔽}) where {method,𝔽} + print(io, "ProjectedOrthonormalBasis($(repr(method)), $(𝔽))") +end +function show(io::IO, mime::MIME"text/plain", onb::DiagonalizingOrthonormalBasis) + println( + io, + "DiagonalizingOrthonormalBasis with coordinates in $(number_system(onb)) and eigenvalue 0 in direction:", + ) + sk = sprint(show, "text/plain", onb.frame_direction, context = io, sizehint = 0) + sk = replace(sk, '\n' => "\n ") + print(io, sk) +end +function show( + io::IO, + mime::MIME"text/plain", + B::CachedBasis{T,D,𝔽}, +) where {T<:AbstractBasis,D,𝔽} + vectors = _get_vectors(B) + nv = length(vectors) + print( + io, + "$(T()) with coordinates in $(number_system(B)) and $(nv) basis vector$(nv == 1 ? "" : "s"):", + ) + _show_basis_vector_range_noheader(io, vectors; max_vectors = 4, pre = " ", sym = " E") +end +function show( + io::IO, + mime::MIME"text/plain", + B::CachedBasis{T,D,𝔽}, +) where {T<:DiagonalizingOrthonormalBasis,D<:DiagonalizingBasisData,𝔽} + vectors = _get_vectors(B) + nv = length(vectors) + sk = sprint(show, "text/plain", T(B.data.frame_direction), context = io, sizehint = 0) + sk = replace(sk, '\n' => "\n ") + print(io, sk) + println(io, "\nand $(nv) basis vector$(nv == 1 ? "" : "s").") + print(io, "Basis vectors:") + _show_basis_vector_range_noheader(io, vectors; max_vectors = 4, pre = " ", sym = " E") + println(io, "\nEigenvalues:") + sk = sprint(show, "text/plain", B.data.eigenvalues, context = io, sizehint = 0) + sk = replace(sk, '\n' => "\n ") + print(io, ' ', sk) +end + +@doc raw""" + vee(M::Manifold, p, X) + +Given a basis $e_i$ on the tangent space at a point `p` and tangent +vector `X`, compute the vector components $X^i$, such that $X = X^i e_i$, where +Einstein summation notation is used: + +````math +\vee : X^i e_i ↦ X^i +```` + +For array manifolds, this converts an array representation of the tangent +vector to a vector representation. The [`hat`](@ref) map is the `vee` map's +inverse. +""" +vee(M::Manifold, p, X) = get_coordinates(M, p, X, VeeOrthogonalBasis()) +vee!(M::Manifold, Y, p, X) = get_coordinates!(M, Y, p, X, VeeOrthogonalBasis()) diff --git a/test/bases.jl b/test/bases.jl new file mode 100644 index 00000000..a79f6b0e --- /dev/null +++ b/test/bases.jl @@ -0,0 +1,205 @@ +using LinearAlgebra + +struct ProjManifold <: Manifold end + +ManifoldsBase.inner(::ProjManifold, x, w, v) = dot(w, v) +ManifoldsBase.project_tangent!(S::ProjManifold, w, x, v) = (w .= v .- dot(x, v) .* x) +ManifoldsBase.representation_size(::ProjManifold) = (2,3) +ManifoldsBase.manifold_dimension(::ProjManifold) = 5 +ManifoldsBase.get_vector(::ProjManifold, x, v, ::DefaultOrthonormalBasis) = reverse(v) + +@testset "Projected and arbitrary orthonormal basis" begin + M = ProjManifold() + x = [sqrt(2)/2 0.0 0.0; + 0.0 sqrt(2)/2 0.0] + + pb = get_basis(M, x, ProjectedOrthonormalBasis(:svd)) + @test number_system(pb) == ℝ + @test get_basis(M, x, pb) == pb + N = manifold_dimension(M) + @test isa(pb, CachedBasis) + @test length(get_vectors(M, x, pb)) == N + # test orthonormality + for i in 1:N + @test norm(M, x, get_vectors(M, x, pb)[i]) ≈ 1 + for j in i+1:N + @test inner(M, x, get_vectors(M, x, pb)[i], get_vectors(M, x, pb)[j]) ≈ 0 atol = 1e-15 + end + end + # check projection idempotency + for i in 1:N + @test project_tangent(M, x, get_vectors(M, x, pb)[i]) ≈ get_vectors(M, x, pb)[i] + end + + aonb = get_basis(M, x, DefaultOrthonormalBasis()) + @test size(get_vectors(M, x, aonb)) == (5,) + @test get_vectors(M, x, aonb)[1] ≈ [0, 0, 0, 0, 1] +end + +struct NonManifold <: Manifold end +struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end + +@testset "ManifoldsBase.jl stuff" begin + + @testset "Errors" begin + m = NonManifold() + onb = DefaultOrthonormalBasis() + + @test_throws ErrorException get_basis(m, [0], onb) + @test_throws ErrorException get_basis(m, [0], NonBasis()) + @test_throws ErrorException get_coordinates(m, [0], [0], onb) + @test_throws ErrorException get_coordinates!(m, [0], [0], [0], onb) + @test_throws ErrorException get_vector(m, [0], [0], onb) + @test_throws ErrorException get_vector!(m, [0], [0], [0], onb) + @test_throws ErrorException get_vectors(m, [0], NonBasis()) + end + + M = ManifoldsBase.DefaultManifold(3) + pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + @testset "basis representation" begin + v1 = log(M, pts[1], pts[2]) + + vb = get_coordinates(M, pts[1], v1, DefaultOrthonormalBasis()) + @test isa(vb, AbstractVector) + vbi = get_vector(M, pts[1], vb, DefaultOrthonormalBasis()) + @test isapprox(M, pts[1], v1, vbi) + + b = get_basis(M, pts[1], DefaultOrthonormalBasis()) + @test isa(b, CachedBasis{DefaultOrthonormalBasis{ℝ},Array{Array{Float64,1},1},ℝ}) + N = manifold_dimension(M) + @test length(get_vectors(M, pts[1], b)) == N + # check orthonormality + for i in 1:N + @test norm(M, pts[1], get_vectors(M, pts[1], b)[i]) ≈ 1 + for j in i+1:N + @test inner( + M, + pts[1], + get_vectors(M, pts[1], b)[i], + get_vectors(M, pts[1], b)[j] + ) ≈ 0 + end + end + # check that the coefficients correspond to the basis + for i in 1:N + @test inner(M, pts[1], v1, get_vectors(M, pts[1], b)[i]) ≈ vb[i] + end + + @test get_coordinates(M, pts[1], v1, b) ≈ get_coordinates(M, pts[1], v1, DefaultOrthonormalBasis()) + @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, DefaultOrthonormalBasis()) + + v1c = allocate(v1) + get_coordinates!(M, v1c, pts[1], v1, b) + @test v1c ≈ get_coordinates(M, pts[1], v1, b) + + v1cv = allocate(v1) + get_vector!(M, v1cv, pts[1], v1c, b) + @test isapprox(M, pts[1], v1, v1cv) + end + + @testset "ArrayManifold basis" begin + A = ArrayManifold(M) + aonb = DefaultOrthonormalBasis() + b = get_basis(A, pts[1], aonb) + @test_throws ErrorException get_vector(A, pts[1], [], aonb) + @test_throws ArgumentError get_basis(A, pts[1], CachedBasis(aonb,[pts[1]])) + @test_throws ArgumentError get_basis(A, pts[1], CachedBasis(aonb,[pts[1], pts[1], pts[1]])) + @test_throws ArgumentError get_basis(A, pts[1], CachedBasis(aonb,[2*pts[1], pts[1], pts[1]])) + end +end + +@testset "Basis show methods" begin + @test sprint(show, DefaultOrthonormalBasis()) == "DefaultOrthonormalBasis(ℝ)" + @test sprint(show, DefaultOrthonormalBasis(ℂ)) == "DefaultOrthonormalBasis(ℂ)" + @test sprint(show, ProjectedOrthonormalBasis(:svd)) == "ProjectedOrthonormalBasis(:svd, ℝ)" + @test sprint(show, ProjectedOrthonormalBasis(:gram_schmidt, ℂ)) == "ProjectedOrthonormalBasis(:gram_schmidt, ℂ)" + + @test sprint(show, "text/plain", DiagonalizingOrthonormalBasis(Float64[1, 2, 3])) == """ + DiagonalizingOrthonormalBasis with coordinates in ℝ and eigenvalue 0 in direction: + 3-element Array{Float64,1}: + 1.0 + 2.0 + 3.0""" + + M = ManifoldsBase.DefaultManifold(2, 3) + x = collect(reshape(1.0:6.0, (2, 3))) + pb = get_basis(M, x, DefaultOrthonormalBasis()) + @test sprint(show, "text/plain", pb) == """ + DefaultOrthonormalBasis(ℝ) with coordinates in ℝ and 6 basis vectors: + E1 = + 2×3 Array{Float64,2}: + 1.0 0.0 0.0 + 0.0 0.0 0.0 + E2 = + 2×3 Array{Float64,2}: + 0.0 0.0 0.0 + 1.0 0.0 0.0 + ⋮ + E5 = + 2×3 Array{Float64,2}: + 0.0 0.0 1.0 + 0.0 0.0 0.0 + E6 = + 2×3 Array{Float64,2}: + 0.0 0.0 0.0 + 0.0 0.0 1.0""" + b = DiagonalizingOrthonormalBasis(get_vectors(M, x, pb)[1]) + dpb = CachedBasis(b, Float64[1, 2, 3, 4, 5, 6], get_vectors(M, x, pb)) + @test sprint(show, "text/plain", dpb) == """ + DiagonalizingOrthonormalBasis with coordinates in ℝ and eigenvalue 0 in direction: + 2×3 Array{Float64,2}: + 1.0 0.0 0.0 + 0.0 0.0 0.0 + and 6 basis vectors. + Basis vectors: + E1 = + 2×3 Array{Float64,2}: + 1.0 0.0 0.0 + 0.0 0.0 0.0 + E2 = + 2×3 Array{Float64,2}: + 0.0 0.0 0.0 + 1.0 0.0 0.0 + ⋮ + E5 = + 2×3 Array{Float64,2}: + 0.0 0.0 1.0 + 0.0 0.0 0.0 + E6 = + 2×3 Array{Float64,2}: + 0.0 0.0 0.0 + 0.0 0.0 1.0 + Eigenvalues: + 6-element Array{Float64,1}: + 1.0 + 2.0 + 3.0 + 4.0 + 5.0 + 6.0""" + + M = ManifoldsBase.DefaultManifold(1, 1, 1) + x = reshape(Float64[1], (1, 1, 1)) + pb = get_basis(M, x, DefaultOrthonormalBasis()) + @test sprint(show, "text/plain", pb) == """ + DefaultOrthonormalBasis(ℝ) with coordinates in ℝ and 1 basis vector: + E1 = + 1×1×1 Array{Float64,3}: + [:, :, 1] = + 1.0""" + + dpb = CachedBasis(DiagonalizingOrthonormalBasis(get_vectors(M, x, pb)), Float64[1], get_vectors(M, x, pb)) + @test sprint(show, "text/plain", dpb) == """ + DiagonalizingOrthonormalBasis with coordinates in ℝ and eigenvalue 0 in direction: + 1-element Array{Array{Float64,3},1}: + [1.0] + and 1 basis vector. + Basis vectors: + E1 = + 1×1×1 Array{Float64,3}: + [:, :, 1] = + 1.0 + Eigenvalues: + 1-element Array{Float64,1}: + 1.0""" +end diff --git a/test/runtests.jl b/test/runtests.jl index 0607337e..3c96759c 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -1,8 +1,12 @@ using Test +using ManifoldsBase @testset "ManifoldsBase" begin + # TODO: decrease the number of ambiguities + @test length(Test.detect_ambiguities(ManifoldsBase)) <= 86 include("allocation.jl") include("numbers.jl") + include("bases.jl") include("decorator_manifold.jl") include("empty_manifold.jl") include("default_manifold.jl") From a0d92d7d1ed1d0a94fa636aa1eb048349fe2ec06 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 12 Mar 2020 12:55:11 +0100 Subject: [PATCH 04/27] version set to 0.6 --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 3054a973..ce3de0ea 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "ManifoldsBase" uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb" authors = ["Seth Axen ", "Mateusz Baran ", "Ronny Bergmann ", "Antoine Levitt "] -version = "0.5.2" +version = "0.6" [deps] LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" From aa4e1f2bd661b0295f7535d67d09289f2e9bb03e Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 12 Mar 2020 13:47:07 +0100 Subject: [PATCH 05/27] new function-based decorator dispatch (has less method ambiguities) --- src/DecoratorManifold.jl | 76 +++++++++++++++++++------------------- test/decorator_manifold.jl | 4 +- test/runtests.jl | 2 +- 3 files changed, 41 insertions(+), 41 deletions(-) diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index 752b8be4..c4c544fb 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -59,6 +59,9 @@ function _split_signature(sig::Expr) argnames = argnames, argtypes = argtypes, kwargs_call = kwargs_call, + fname__parent = Symbol(string(fname) * "__parent"), + fname__transparent = Symbol(string(fname) * "__transparent"), + fname__intransparent = Symbol(string(fname) * "__intransparent"), ) end @@ -146,12 +149,11 @@ macro decorator_transparent_fallback(fallback_case, input_ex) parts = _split_function(ex) callargs = parts[:callargs] where_exprs = parts[:where_exprs] + fname_fallback = Symbol(string(parts.fname) * "__" * string(fallback_case)[2:end]) return esc( quote - function ($(parts[:fname]))( - $(callargs[1]), - ::Val{$fallback_case}, - $(callargs[2:end]...); + function ($(fname_fallback))( + $(callargs...); $(parts[:kwargs_list]...), ) where {$(where_exprs...)} ($(parts[:body])) @@ -213,6 +215,7 @@ macro decorator_transparent_function(fallback_case, input_ex) argnames = parts[:argnames] argtypes = parts[:argtypes] kwargs_call = parts[:kwargs_call] + fname_fallback = Symbol(string(parts.fname) * "__" * string(fallback_case)[2:end]) return esc( quote @@ -221,29 +224,30 @@ macro decorator_transparent_function(fallback_case, input_ex) $(callargs[2:end]...); $(kwargs_list...), ) where {$(where_exprs...)} - return ($fname)( - $(argnames[1]), - ManifoldsBase._acts_transparently($fname, $(argnames...)), - $(argnames[2:end]...), - ; - $(kwargs_call...), - ) + transparency = ManifoldsBase._acts_transparently($fname, $(argnames...)) + if transparency === Val(:parent) + return ($(parts.fname__parent))($(argnames...); $(kwargs_call...)) + elseif transparency === Val(:transparent) + return ($(parts.fname__transparent))($(argnames...); $(kwargs_call...)) + elseif transparency === Val(:intransparent) + return ($(parts.fname__intransparent))($(argnames...); $(kwargs_call...)) + else + error("incorrect transparency: $transparency") + end end - function ($fname)( + function ($(parts[:fname__transparent]))( $(argnames[1])::AbstractDecoratorManifold, - ::Val{:transparent}, $(callargs[2:end]...); $(kwargs_list...), ) where {$(where_exprs...)} return ($fname)( - decorated_manifold($(argnames[1])), + ManifoldsBase.decorated_manifold($(argnames[1])), $(argnames[2:end]...); $(kwargs_call...), ) end - function ($fname)( + function ($(parts[:fname__intransparent]))( $(argnames[1])::AbstractDecoratorManifold, - ::Val{:intransparent}, $(callargs[2:end]...); $(kwargs_list...), ) where {$(where_exprs...)} @@ -270,9 +274,8 @@ macro decorator_transparent_function(fallback_case, input_ex) ". Maybe you missed to implement this function for a default?", )) end - function ($fname)( + function ($(parts[:fname__parent]))( $(argnames[1])::AbstractDecoratorManifold, - ::Val{:parent}, $(callargs[2:end]...); $(kwargs_list...), ) where {$(where_exprs...)} @@ -283,9 +286,8 @@ macro decorator_transparent_function(fallback_case, input_ex) $(kwargs_call...), ) end - function ($fname)( + function ($fname_fallback)( $(callargs[1]), - ::Val{$fallback_case}, $(callargs[2:end]...); $(kwargs_list...), ) where {$(where_exprs...)} @@ -345,17 +347,19 @@ macro decorator_transparent_signature(ex) return esc( quote function ($fname)($(callargs...); $(kwargs_list...)) where {$(where_exprs...)} - return ($fname)( - $(argnames[1]), - ManifoldsBase._acts_transparently($fname, $(argnames...)), - $(argnames[2:end]...); - $(kwargs_call...), - ) + transparency = ManifoldsBase._acts_transparently($fname, $(argnames...)) + if transparency === Val(:parent) + return ($(parts.fname__parent))($(argnames...); $(kwargs_call...)) + elseif transparency === Val(:transparent) + return ($(parts.fname__transparent))($(argnames...); $(kwargs_call...)) + elseif transparency === Val(:intransparent) + return ($(parts.fname__intransparent))($(argnames...); $(kwargs_call...)) + else + error("incorrect transparency: $transparency") + end end - function ($fname)( - $(callargs[1]), - ::Val{:transparent}, - $(callargs[2:end]...); + function ($(parts[:fname__transparent]))( + $(callargs...); $(kwargs_list...), ) where {$(where_exprs...)} return ($fname)( @@ -364,10 +368,8 @@ macro decorator_transparent_signature(ex) $(kwargs_call...), ) end - function ($fname)( - $(callargs[1]), - ::Val{:intransparent}, - $(callargs[2:end]...); + function ($(parts[:fname__intransparent]))( + $(callargs...); $(kwargs_list...), ) where {$(where_exprs...)} error_msg = ManifoldsBase.manifold_function_not_implemented_message( @@ -377,10 +379,8 @@ macro decorator_transparent_signature(ex) ) error(error_msg) end - function ($fname)( - $(callargs[1]), - ::Val{:parent}, - $(callargs[2:end]...); + function ($(parts[:fname__parent]))( + $(callargs...); $(kwargs_list...), ) where {$(where_exprs...)} return invoke( diff --git a/test/decorator_manifold.jl b/test/decorator_manifold.jl index 9769aa34..8d004627 100644 --- a/test/decorator_manifold.jl +++ b/test/decorator_manifold.jl @@ -105,8 +105,8 @@ end p = [1.0, 0.0, 0.0] X = [2.0, 1.0, 3.0] - @test inner(A, p, X, X) ≈ inner(A, Val(:transparent), p, X, X) - @test_throws ErrorException inner(A, Val(:intransparent), p, X, X) + @test inner(A, p, X, X) ≈ ManifoldsBase.inner__transparent(A, p, X, X) + @test_throws ErrorException ManifoldsBase.inner__intransparent(A, p, X, X) TD = TestDecorator(M) diff --git a/test/runtests.jl b/test/runtests.jl index 3c96759c..514cff92 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using ManifoldsBase @testset "ManifoldsBase" begin # TODO: decrease the number of ambiguities - @test length(Test.detect_ambiguities(ManifoldsBase)) <= 86 + @test length(Test.detect_ambiguities(ManifoldsBase)) <= 35 include("allocation.jl") include("numbers.jl") include("bases.jl") From fe58ae7fdb31fc60e7676c023837c7c3e5818f81 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 12 Mar 2020 15:01:53 +0100 Subject: [PATCH 06/27] less strict type bounds in allocate --- src/ManifoldsBase.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index fe259a2e..5df9ba07 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -164,10 +164,10 @@ abstract type CoTVector end """ allocate(a) - allocate(a, dims::Int...) + allocate(a, dims::Integer...) allocate(a, dims::Tuple) allocate(a, T::Type) - allocate(a, T::Type, dims::Int...) + allocate(a, T::Type, dims::Integer...) allocate(a, T::Type, dims::Tuple) Allocate an object similar to `a`. It is similar to function `similar`, although @@ -179,10 +179,10 @@ allocation and is forwarded to the function `similar`. """ allocate(a, args...) allocate(a) = similar(a) -allocate(a, dims::Int...) = similar(a, dims...) +allocate(a, dims::Integer...) = similar(a, dims...) allocate(a, dims::Tuple) = similar(a, dims) allocate(a, T::Type) = similar(a, T) -allocate(a, T::Type, dims::Int...) = similar(a, T, dims...) +allocate(a, T::Type, dims::Integer...) = similar(a, T, dims...) allocate(a, T::Type, dims::Tuple) = similar(a, T, dims) allocate(a::AbstractArray{<:AbstractArray}) = map(allocate, a) allocate(a::AbstractArray{<:AbstractArray}, T::Type) = map(t -> allocate(t, T), a) From 027ca59142133e446d7c4234497ad684a2372f61 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 12 Mar 2020 19:00:39 +0100 Subject: [PATCH 07/27] improving coverage --- src/DefaultManifold.jl | 6 ++++++ src/bases.jl | 6 ++++++ test/allocation.jl | 6 ++++++ test/bases.jl | 16 +++++++++------- 4 files changed, 27 insertions(+), 7 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index bbaf8b42..e5633efb 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -19,6 +19,12 @@ exp!(::DefaultManifold, y, x, v) = (y .= x .+ v) function get_basis(M::DefaultManifold, p, B::DefaultOrthonormalBasis) return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) end +function get_basis(M::DefaultManifold, p, B::DefaultOrthogonalBasis) + return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) +end +function get_basis(M::DefaultManifold, p, B::DefaultBasis) + return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) +end function get_coordinates!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis) Y .= reshape(X, manifold_dimension(M)) diff --git a/src/bases.jl b/src/bases.jl index 9b337314..e5c51b07 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -475,6 +475,12 @@ function _show_basis_vector_range_noheader(io::IO, Ξ; max_vectors = 4, pre = "" end end +function show(io::IO, ::DefaultBasis{𝔽}) where {𝔽} + print(io, "DefaultBasis($(𝔽))") +end +function show(io::IO, ::DefaultOrthogonalBasis{𝔽}) where {𝔽} + print(io, "DefaultOrthogonalBasis($(𝔽))") +end function show(io::IO, ::DefaultOrthonormalBasis{𝔽}) where {𝔽} print(io, "DefaultOrthonormalBasis($(𝔽))") end diff --git a/test/allocation.jl b/test/allocation.jl index 859d8f22..1b3ea49e 100644 --- a/test/allocation.jl +++ b/test/allocation.jl @@ -1,5 +1,6 @@ using ManifoldsBase using Test +using ManifoldsBase: combine_allocation_promotion_functions struct AllocManifold <: Manifold end @@ -34,6 +35,11 @@ end @test a4 isa Matrix{Float64} @test size(a4) == (2, 3) + @test combine_allocation_promotion_functions(identity, identity) === identity + @test combine_allocation_promotion_functions(identity, complex) === complex + @test combine_allocation_promotion_functions(complex, identity) === complex + @test combine_allocation_promotion_functions(complex, complex) === complex + @test number_eltype([2.0]) == Float64 @test number_eltype([[2.0], [3]]) == Float64 @test number_eltype([[2], [3.0]]) == Float64 diff --git a/test/bases.jl b/test/bases.jl index a79f6b0e..152d327d 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -56,16 +56,16 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end M = ManifoldsBase.DefaultManifold(3) pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] - @testset "basis representation" begin + @testset "basis representation" for BT in (DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis) v1 = log(M, pts[1], pts[2]) - vb = get_coordinates(M, pts[1], v1, DefaultOrthonormalBasis()) + vb = get_coordinates(M, pts[1], v1, BT()) @test isa(vb, AbstractVector) - vbi = get_vector(M, pts[1], vb, DefaultOrthonormalBasis()) + vbi = get_vector(M, pts[1], vb, BT()) @test isapprox(M, pts[1], v1, vbi) - b = get_basis(M, pts[1], DefaultOrthonormalBasis()) - @test isa(b, CachedBasis{DefaultOrthonormalBasis{ℝ},Array{Array{Float64,1},1},ℝ}) + b = get_basis(M, pts[1], BT()) + @test isa(b, CachedBasis{BT{ℝ},Array{Array{Float64,1},1},ℝ}) N = manifold_dimension(M) @test length(get_vectors(M, pts[1], b)) == N # check orthonormality @@ -85,8 +85,8 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end @test inner(M, pts[1], v1, get_vectors(M, pts[1], b)[i]) ≈ vb[i] end - @test get_coordinates(M, pts[1], v1, b) ≈ get_coordinates(M, pts[1], v1, DefaultOrthonormalBasis()) - @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, DefaultOrthonormalBasis()) + @test get_coordinates(M, pts[1], v1, b) ≈ get_coordinates(M, pts[1], v1, BT()) + @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, BT()) v1c = allocate(v1) get_coordinates!(M, v1c, pts[1], v1, b) @@ -109,6 +109,8 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end end @testset "Basis show methods" begin + @test sprint(show, DefaultBasis()) == "DefaultBasis(ℝ)" + @test sprint(show, DefaultOrthogonalBasis()) == "DefaultOrthogonalBasis(ℝ)" @test sprint(show, DefaultOrthonormalBasis()) == "DefaultOrthonormalBasis(ℝ)" @test sprint(show, DefaultOrthonormalBasis(ℂ)) == "DefaultOrthonormalBasis(ℂ)" @test sprint(show, ProjectedOrthonormalBasis(:svd)) == "ProjectedOrthonormalBasis(:svd, ℝ)" From 4ec11b8a02574f1d494e2c6c507ed4c60edd4abb Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 12 Mar 2020 19:31:28 +0100 Subject: [PATCH 08/27] reducing ambiguities --- src/ArrayManifold.jl | 60 ++++++++++++++++++-------------------------- src/bases.jl | 59 ++++++++++++++++++++++++++++++++----------- test/runtests.jl | 2 +- 3 files changed, 69 insertions(+), 52 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index 603720b7..2d2cdd00 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -157,64 +157,52 @@ function get_basis( return B end - -# the following is not nice, can we do better when using decorators and a specific last part? function get_coordinates(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) - _get_coordinates(M, p, X, B, kwargs...) -end -function get_coordinates(M::ArrayManifold, p, X, B::CachedBasis; kwargs...) - _get_coordinates(M, p, X, B, kwargs...) -end -function get_coordinates(M::ArrayManifold, p, X, B::DefaultBasis; kwargs...) - _get_coordinates(M, p, X, B, kwargs...) -end -function get_coordinates(M::ArrayManifold, p, X, B::DefaultOrthogonalBasis; kwargs...) - _get_coordinates(M, p, X, B, kwargs...) -end -function get_coordinates(M::ArrayManifold, p, X, B::DefaultOrthonormalBasis; kwargs...) - _get_coordinates(M, p, X, B, kwargs...) -end - -function _get_coordinates(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) is_tangent_vector(M, p, X, true; kwargs...) return get_coordinates(M.manifold, p, X, B) end -function get_coordinates!(M::ArrayManifold, Y, p, X, B::all_uncached_bases; kwargs...) +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @invoke_maker 4 AbstractBasis get_coordinates(M::ArrayManifold, p, X, B::$BT; kwargs...) + end) +end + +function get_coordinates!(M::ArrayManifold, Y, p, X, B::AbstractBasis; kwargs...) is_tangent_vector(M, p, X, true; kwargs...) get_coordinates!(M, Y, p, X, B) return Y end - -function get_vector(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) - return _get_vector(M, p, X, B, kwargs...) -end -function get_vector(M::ArrayManifold, p, X, B::CachedBasis; kwargs...) - return _get_vector(M, p, X, B, kwargs...) -end -function get_vector(M::ArrayManifold, p, X, B::DefaultBasis; kwargs...) - return _get_vector(M, p, X, B, kwargs...) -end -function get_vector(M::ArrayManifold, p, X, B::DefaultOrthogonalBasis; kwargs...) - return _get_vector(M, p, X, B, kwargs...) -end -function get_vector(M::ArrayManifold, p, X, B::DefaultOrthonormalBasis; kwargs...) - return _get_vector(M, p, X, B, kwargs...) +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @invoke_maker 5 AbstractBasis get_coordinates!(M::ArrayManifold, Y, p, X, B::$BT; kwargs...) + end) end -function _get_vector(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) +function get_vector(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) is_manifold_point(M, p, true; kwargs...) size(X) == (manifold_dimension(M),) || error("Incorrect size of coefficient vector X") Y = get_vector(M.manifold, p, X, B) size(Y) == representation_size(M) || error("Incorrect size of tangent vector Y") return Y end -function get_vector!(M::ArrayManifold, Y, p, X, B::all_uncached_bases; kwargs...) +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @invoke_maker 4 AbstractBasis get_vector(M::ArrayManifold, p, X, B::$BT; kwargs...) + end) +end + +function get_vector!(M::ArrayManifold, Y, p, X, B::AbstractBasis; kwargs...) is_manifold_point(M, p, true; kwargs...) size(X) == (manifold_dimension(M),) || error("Incorrect size of coefficient vector X") get_vector!(M.manifold, Y, p, X, B) size(Y) == representation_size(M) || error("Incorrect size of tangent vector Y") return Y end +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @invoke_maker 5 AbstractBasis get_vector!(M::ArrayManifold, Y, p, X, B::$BT; kwargs...) + end) +end injectivity_radius(M::ArrayManifold) = injectivity_radius(M.manifold) function injectivity_radius(M::ArrayManifold, method::AbstractRetractionMethod) diff --git a/src/bases.jl b/src/bases.jl index e5c51b07..f8324259 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -137,7 +137,15 @@ function get_coordinates end function get_vector end const all_uncached_bases = Union{AbstractBasis, DefaultBasis, DefaultOrthogonalBasis, DefaultOrthonormalBasis} -const DISAMBIGUATION_BASIS_TYPES = [CachedBasis, DefaultOrthogonalBasis, DefaultOrthonormalBasis, DefaultOrDiagonalizingBasis, VeeOrthogonalBasis] +const DISAMBIGUATION_BASIS_TYPES = [ + CachedBasis, + CachedBasis{<:AbstractBasis{ℝ}}, + DefaultBasis, + DefaultOrthonormalBasis, + DefaultOrthogonalBasis, + DiagonalizingOrthonormalBasis, + VeeOrthogonalBasis, +] function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X) T = allocate_result_type(M, f, (p, X)) @@ -291,13 +299,11 @@ function get_coordinates!(M::Manifold, Y, p, X, B::AbstractBasis) error("get_coordinates! not implemented for manifold of type $(typeof(M)) coordinates of type $(typeof(Y)), a point of type $(typeof(p)), tangent vector of type $(typeof(X)) and basis of type $(typeof(B)).") end @decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::AbstractBasis) -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis) -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis{BT,V,𝔽}) where {BT<:AbstractBasis{ℝ}, 𝔽, V} -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultBasis) -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::VeeOrthogonalBasis) -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthogonalBasis) -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthonormalBasis) -@decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::DiagonalizingOrthonormalBasis) +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @decorator_transparent_signature get_coordinates!(M::AbstractDecoratorManifold, Y, p, X, B::$BT) + end) +end function decorator_transparent_dispatch(::typeof(get_coordinates!), ::Manifold, args...) return Val(:transparent) end @@ -355,13 +361,11 @@ function get_vector!(M::Manifold, Y, p, X, B::AbstractBasis) error("get_vector! not implemented for manifold of type $(typeof(M)) vector of type $(typeof(Y)), a point of type $(typeof(p)), coordinates of type $(typeof(X)) and basis of type $(typeof(B)).") end @decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::AbstractBasis) -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis) -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::CachedBasis{BT,V,𝔽}) where {BT<:AbstractBasis{ℝ}, 𝔽, V} -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultBasis) -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::VeeOrthogonalBasis) -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthogonalBasis) -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DefaultOrthonormalBasis) -@decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::DiagonalizingOrthonormalBasis) +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @decorator_transparent_signature get_vector!(M::AbstractDecoratorManifold, Y, p, X, B::$BT) + end) +end function decorator_transparent_dispatch(::typeof(get_vector!), ::Manifold, args...) return Val(:transparent) end @@ -545,3 +549,28 @@ inverse. """ vee(M::Manifold, p, X) = get_coordinates(M, p, X, VeeOrthogonalBasis()) vee!(M::Manifold, Y, p, X) = get_coordinates!(M, Y, p, X, VeeOrthogonalBasis()) + +macro invoke_maker(argnum, type, sig) + parts = ManifoldsBase._split_signature(sig) + kwargs_list = parts[:kwargs_list] + callargs = parts[:callargs] + fname = parts[:fname] + where_exprs = parts[:where_exprs] + argnames = parts[:argnames] + argtypes = parts[:argtypes] + kwargs_call = parts[:kwargs_call] + + return esc(quote + function ($fname)( + $(callargs...); + $(kwargs_list...), + ) where {$(where_exprs...)} + return invoke( + $fname, + Tuple{$(argtypes[1:argnum-1]...),$type,$(argtypes[argnum+1:end]...)}, + $(argnames...); + $(kwargs_call...), + ) + end + end) +end diff --git a/test/runtests.jl b/test/runtests.jl index 514cff92..49c90a5a 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using ManifoldsBase @testset "ManifoldsBase" begin # TODO: decrease the number of ambiguities - @test length(Test.detect_ambiguities(ManifoldsBase)) <= 35 + @test length(Test.detect_ambiguities(ManifoldsBase)) <= 12 include("allocation.jl") include("numbers.jl") include("bases.jl") From c4a8c410e600216954c65538e9793e404f06848c Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Thu, 12 Mar 2020 22:15:08 +0100 Subject: [PATCH 09/27] trying to increase coverage further --- src/ArrayManifold.jl | 2 +- src/DefaultManifold.jl | 30 ++++++++++++++++++++++++++++++ test/allocation.jl | 3 ++- test/array_manifold.jl | 19 +++++++++++++++++++ test/bases.jl | 12 ++---------- 5 files changed, 54 insertions(+), 12 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index 2d2cdd00..cf4337da 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -169,7 +169,7 @@ end function get_coordinates!(M::ArrayManifold, Y, p, X, B::AbstractBasis; kwargs...) is_tangent_vector(M, p, X, true; kwargs...) - get_coordinates!(M, Y, p, X, B) + get_coordinates!(M.manifold, Y, p, X, B) return Y end for BT in DISAMBIGUATION_BASIS_TYPES diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index e5633efb..39703527 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -12,6 +12,36 @@ situations to verify correctness of involved variabes. struct DefaultManifold{T<:Tuple} <: Manifold where {T} end DefaultManifold(n::Vararg{Int,N}) where {N} = DefaultManifold{Tuple{n...}}() +function check_manifold_point(M::DefaultManifold, p; kwargs...) + if size(p) != representation_size(M) + return DomainError( + size(p), + "The point $(p) does not lie on $M, since its size is not $(N+1).", + ) + end + return nothing +end + +function check_tangent_vector( + M::DefaultManifold, + p, + X; + check_base_point = true, + kwargs..., +) + if check_base_point + perr = check_manifold_point(M, p) + perr === nothing || return perr + end + if size(X) != representation_size(M) + return DomainError( + size(X), + "The vector $(X) is not a tangent to a point on $M since its size does not match $(N+1).", + ) + end + return nothing +end + distance(::DefaultManifold, x, y) = norm(x - y) exp!(::DefaultManifold, y, x, v) = (y .= x .+ v) diff --git a/test/allocation.jl b/test/allocation.jl index 1b3ea49e..1abd9a4f 100644 --- a/test/allocation.jl +++ b/test/allocation.jl @@ -1,6 +1,6 @@ using ManifoldsBase using Test -using ManifoldsBase: combine_allocation_promotion_functions +using ManifoldsBase: combine_allocation_promotion_functions, allocation_promotion_function struct AllocManifold <: Manifold end @@ -35,6 +35,7 @@ end @test a4 isa Matrix{Float64} @test size(a4) == (2, 3) + @test allocation_promotion_function(M, exp, (a, b)) === identity @test combine_allocation_promotion_functions(identity, identity) === identity @test combine_allocation_promotion_functions(identity, complex) === complex @test combine_allocation_promotion_functions(complex, identity) === complex diff --git a/test/array_manifold.jl b/test/array_manifold.jl index 4dd32f3f..5e7249a5 100644 --- a/test/array_manifold.jl +++ b/test/array_manifold.jl @@ -117,4 +117,23 @@ end @test injectivity_radius(A, CustomArrayManifoldRetraction()) == 10 @test injectivity_radius(A, x, CustomArrayManifoldRetraction()) == 11 end + + @testset "ArrayManifold basis" begin + for BT in (DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis) + cb = BT() + @test_broken b = get_basis(A, x, cb) + v = similar(x) + @test_throws ErrorException get_vector(A, x, [1.0], cb) + @test_throws ErrorException get_coordinates(A, x, [1.0], cb) + @test_throws ErrorException get_vector!(A, v, x, [], cb) + @test_throws ErrorException get_coordinates!(A, v, x, [], cb) + @test get_vector(A, x, [1, 2, 3], cb) ≈ get_vector(M, x, [1, 2, 3], cb) + @test get_coordinates(A, x, [1, 2, 3], cb) ≈ get_coordinates(M, x, [1, 2, 3], cb) + + + @test_throws ArgumentError get_basis(A, x, CachedBasis(cb, [x])) + @test_throws ArgumentError get_basis(A, x, CachedBasis(cb, [x, x, x])) + @test_throws ArgumentError get_basis(A, x, CachedBasis(cb, [2*x, x, x])) + end + end end diff --git a/test/bases.jl b/test/bases.jl index 152d327d..c97c3d1d 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -55,6 +55,7 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end end M = ManifoldsBase.DefaultManifold(3) + pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] @testset "basis representation" for BT in (DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis) v1 = log(M, pts[1], pts[2]) @@ -66,6 +67,7 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end b = get_basis(M, pts[1], BT()) @test isa(b, CachedBasis{BT{ℝ},Array{Array{Float64,1},1},ℝ}) + @test get_basis(M, pts[1], b) === b N = manifold_dimension(M) @test length(get_vectors(M, pts[1], b)) == N # check orthonormality @@ -96,16 +98,6 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end get_vector!(M, v1cv, pts[1], v1c, b) @test isapprox(M, pts[1], v1, v1cv) end - - @testset "ArrayManifold basis" begin - A = ArrayManifold(M) - aonb = DefaultOrthonormalBasis() - b = get_basis(A, pts[1], aonb) - @test_throws ErrorException get_vector(A, pts[1], [], aonb) - @test_throws ArgumentError get_basis(A, pts[1], CachedBasis(aonb,[pts[1]])) - @test_throws ArgumentError get_basis(A, pts[1], CachedBasis(aonb,[pts[1], pts[1], pts[1]])) - @test_throws ArgumentError get_basis(A, pts[1], CachedBasis(aonb,[2*pts[1], pts[1], pts[1]])) - end end @testset "Basis show methods" begin From 1a2b8456a1110aa1981b925c0572d8b7efb9736a Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Thu, 12 Mar 2020 23:37:07 +0100 Subject: [PATCH 10/27] fixes one of the test by TVector-decorating later. --- src/ArrayManifold.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index cf4337da..a9e1188d 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -251,9 +251,9 @@ end function log(M::ArrayManifold, x, y; kwargs...) is_manifold_point(M, x, true; kwargs...) is_manifold_point(M, y, true; kwargs...) - v = ArrayTVector(log(M.manifold, array_value(x), array_value(y))) + v = log(M.manifold, array_value(x), array_value(y)) is_tangent_vector(M, x, v, true; kwargs...) - return v + return ArrayTVector(v) end function log!(M::ArrayManifold, v, x, y; kwargs...) From 362786cc87d982eb586a4fe590f87afd5430a92e Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Fri, 13 Mar 2020 09:19:59 +0100 Subject: [PATCH 11/27] fixes exP7log by loosening the type of checks of ArrayManifolds, which were too strict to hit for nonencapsulated data. --- src/ArrayManifold.jl | 290 +++++++++++++++++++-------------------- src/DecoratorManifold.jl | 1 - 2 files changed, 145 insertions(+), 146 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index a9e1188d..2beec340 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -46,88 +46,88 @@ struct ArrayCoTVector{V<:AbstractArray{<:Number}} <: TVector value::V end -(+)(v1::ArrayCoTVector, v2::ArrayCoTVector) = ArrayCoTVector(v1.value + v2.value) -(-)(v1::ArrayCoTVector, v2::ArrayCoTVector) = ArrayCoTVector(v1.value - v2.value) -(-)(v::ArrayCoTVector) = ArrayCoTVector(-v.value) -(*)(a::Number, v::ArrayCoTVector) = ArrayCoTVector(a * v.value) +(+)(X::ArrayCoTVector, Y::ArrayCoTVector) = ArrayCoTVector(X.value + Y.value) +(-)(X::ArrayCoTVector, Y::ArrayCoTVector) = ArrayCoTVector(X.value - Y.value) +(-)(X::ArrayCoTVector) = ArrayCoTVector(-X.value) +(*)(a::Number, X::ArrayCoTVector) = ArrayCoTVector(a * X.value) -(+)(v1::ArrayTVector, v2::ArrayTVector) = ArrayTVector(v1.value + v2.value) -(-)(v1::ArrayTVector, v2::ArrayTVector) = ArrayTVector(v1.value - v2.value) -(-)(v::ArrayTVector) = ArrayTVector(-v.value) -(*)(a::Number, v::ArrayTVector) = ArrayTVector(a * v.value) +(+)(X::ArrayTVector, Y::ArrayTVector) = ArrayTVector(X.value + Y.value) +(-)(X::ArrayTVector, Y::ArrayTVector) = ArrayTVector(X.value - Y.value) +(-)(X::ArrayTVector) = ArrayTVector(-X.value) +(*)(a::Number, X::ArrayTVector) = ArrayTVector(a * X.value) -allocate(x::ArrayMPoint) = ArrayMPoint(allocate(x.value)) -allocate(x::ArrayMPoint, ::Type{T}) where {T} = ArrayMPoint(allocate(x.value, T)) -allocate(x::ArrayTVector) = ArrayTVector(allocate(x.value)) -allocate(x::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(allocate(x.value, T)) +allocate(p::ArrayMPoint) = ArrayMPoint(allocate(p.value)) +allocate(p::ArrayMPoint, ::Type{T}) where {T} = ArrayMPoint(allocate(p.value, T)) +allocate(p::ArrayTVector) = ArrayTVector(allocate(p.value)) +allocate(p::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(allocate(p.value, T)) """ - array_value(x) + array_value(p) -Return the internal array value of a [`ArrayMPoint`](@ref), [`ArrayTVector`](@ref), or -[`ArrayCoTVector`](@ref) if the value `x` is encapsulated as such. Return `x` if it is +Return the internal array value of an [`ArrayMPoint`](@ref), [`ArrayTVector`](@ref), or +[`ArrayCoTVector`](@ref) if the value `p` is encapsulated as such. Return `p` if it is already an array. """ -array_value(x::AbstractArray) = x -array_value(x::ArrayMPoint) = x.value -array_value(v::ArrayTVector) = v.value -array_value(v::ArrayCoTVector) = v.value +array_value(p::AbstractArray) = p +array_value(p::ArrayMPoint) = p.value +array_value(X::ArrayTVector) = X.value +array_value(ξ::ArrayCoTVector) = ξ.value -function check_manifold_point(M::ArrayManifold, x::MPoint; kwargs...) - return check_manifold_point(M.manifold, array_value(x); kwargs...) +function check_manifold_point(M::ArrayManifold, p; kwargs...) + return check_manifold_point(M.manifold, array_value(p); kwargs...) end -function check_tangent_vector(M::ArrayManifold, x::MPoint, v::TVector; kwargs...) - return check_tangent_vector(M.manifold, array_value(x), array_value(v); kwargs...) +function check_tangent_vector(M::ArrayManifold, p, X; kwargs...) + return check_tangent_vector(M.manifold, array_value(p), array_value(X); kwargs...) end -convert(::Type{V}, v::ArrayCoTVector{V}) where {V<:AbstractArray{<:Number}} = v.value -function convert(::Type{ArrayCoTVector{V}}, v::V) where {V<:AbstractArray{<:Number}} - return ArrayCoTVector{V}(v) +convert(::Type{V}, X::ArrayCoTVector{V}) where {V<:AbstractArray{<:Number}} = X.value +function convert(::Type{ArrayCoTVector{V}}, X::V) where {V<:AbstractArray{<:Number}} + return ArrayCoTVector{V}(X) end convert(::Type{M}, m::ArrayManifold{M}) where {M<:Manifold} = m.manifold convert(::Type{ArrayManifold{M}}, m::M) where {M<:Manifold} = ArrayManifold(m) -convert(::Type{V}, x::ArrayMPoint{V}) where {V<:AbstractArray{<:Number}} = x.value +convert(::Type{V}, p::ArrayMPoint{V}) where {V<:AbstractArray{<:Number}} = p.value convert(::Type{ArrayMPoint{V}}, x::V) where {V<:AbstractArray{<:Number}} = ArrayMPoint{V}(x) -convert(::Type{V}, v::ArrayTVector{V}) where {V<:AbstractArray{<:Number}} = v.value -function convert(::Type{ArrayTVector{V}}, v::V) where {V<:AbstractArray{<:Number}} - return ArrayTVector{V}(v) +convert(::Type{V}, X::ArrayTVector{V}) where {V<:AbstractArray{<:Number}} = X.value +function convert(::Type{ArrayTVector{V}}, X::V) where {V<:AbstractArray{<:Number}} + return ArrayTVector{V}(X) end -function copyto!(x::ArrayMPoint, y::ArrayMPoint) - copyto!(x.value, y.value) - return x +function copyto!(p::ArrayMPoint, q::ArrayMPoint) + copyto!(p.value, q.value) + return p end -function copyto!(x::ArrayCoTVector, y::ArrayCoTVector) - copyto!(x.value, y.value) - return x +function copyto!(p::ArrayCoTVector, q::ArrayCoTVector) + copyto!(p.value, q.value) + return p end -function copyto!(x::ArrayTVector, y::ArrayTVector) - copyto!(x.value, y.value) - return x +function copyto!(p::ArrayTVector, q::ArrayTVector) + copyto!(p.value, q.value) + return p end -function distance(M::ArrayManifold, x, y; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_manifold_point(M, y, true; kwargs...) - return distance(M.manifold, array_value(x), array_value(y)) +function distance(M::ArrayManifold, p, q; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_manifold_point(M, q, true; kwargs...) + return distance(M.manifold, array_value(p), array_value(q)) end -function exp(M::ArrayManifold, x, v; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) - y = ArrayMPoint(exp(M.manifold, array_value(x), array_value(v))) +function exp(M::ArrayManifold, p, X; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + y = exp(M.manifold, array_value(p), array_value(X)) is_manifold_point(M, y, true; kwargs...) - return y + return ArrayMPoint(y) end -function exp!(M::ArrayManifold, y, x, v; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) - exp!(M.manifold, array_value(y), array_value(x), array_value(v)) - is_manifold_point(M, y, true; kwargs...) - return y +function exp!(M::ArrayManifold, q, p, X; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + exp!(M.manifold, array_value(q), array_value(p), array_value(X)) + is_manifold_point(M, q, true; kwargs...) + return q end function get_basis( @@ -208,161 +208,161 @@ injectivity_radius(M::ArrayManifold) = injectivity_radius(M.manifold) function injectivity_radius(M::ArrayManifold, method::AbstractRetractionMethod) return injectivity_radius(M.manifold, method) end -function injectivity_radius(M::ArrayManifold, x; kwargs...) - is_manifold_point(M, x, true; kwargs...) - return injectivity_radius(M.manifold, array_value(x)) +function injectivity_radius(M::ArrayManifold, p; kwargs...) + is_manifold_point(M, p, true; kwargs...) + return injectivity_radius(M.manifold, array_value(p)) end function injectivity_radius( M::ArrayManifold, - x, + p, method::AbstractRetractionMethod; kwargs..., ) - is_manifold_point(M, x, true; kwargs...) - return injectivity_radius(M.manifold, array_value(x), method) + is_manifold_point(M, p, true; kwargs...) + return injectivity_radius(M.manifold, array_value(p), method) end function injectivity_radius(M::ArrayManifold, method::ExponentialRetraction) return injectivity_radius(M.manifold, method) end -function injectivity_radius(M::ArrayManifold, x, method::ExponentialRetraction; kwargs...) - is_manifold_point(M, x, true; kwargs...) - return injectivity_radius(M.manifold, array_value(x), method) +function injectivity_radius(M::ArrayManifold, p, method::ExponentialRetraction; kwargs...) + is_manifold_point(M, p, true; kwargs...) + return injectivity_radius(M.manifold, array_value(p), method) end -function inner(M::ArrayManifold, x, v, w; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) - is_tangent_vector(M, x, w, true; kwargs...) - return inner(M.manifold, array_value(x), array_value(v), array_value(w)) +function inner(M::ArrayManifold, p, X, Y; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + is_tangent_vector(M, p, Y, true; kwargs...) + return inner(M.manifold, array_value(p), array_value(X), array_value(Y)) end -function isapprox(M::ArrayManifold, x, y; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_manifold_point(M, y, true; kwargs...) - return isapprox(M.manifold, array_value(x), array_value(y); kwargs...) +function isapprox(M::ArrayManifold, p, q; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_manifold_point(M, q, true; kwargs...) + return isapprox(M.manifold, array_value(p), array_value(q); kwargs...) end -function isapprox(M::ArrayManifold, x, v, w; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) - is_tangent_vector(M, x, w, true; kwargs...) - return isapprox(M.manifold, array_value(x), array_value(v), array_value(w); kwargs...) +function isapprox(M::ArrayManifold, p, X, Y; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + is_tangent_vector(M, p, Y, true; kwargs...) + return isapprox(M.manifold, array_value(p), array_value(X), array_value(Y); kwargs...) end -function log(M::ArrayManifold, x, y; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_manifold_point(M, y, true; kwargs...) - v = log(M.manifold, array_value(x), array_value(y)) - is_tangent_vector(M, x, v, true; kwargs...) - return ArrayTVector(v) +function log(M::ArrayManifold, p, q; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_manifold_point(M, q, true; kwargs...) + X = log(M.manifold, array_value(p), array_value(q)) + is_tangent_vector(M, p, X, true; kwargs...) + return ArrayTVector(X) end -function log!(M::ArrayManifold, v, x, y; kwargs...) - is_manifold_point(M, x, true; kwargs...) - is_manifold_point(M, y, true; kwargs...) - log!(M.manifold, array_value(v), array_value(x), array_value(y)) - is_tangent_vector(M, x, v, true; kwargs...) - return v +function log!(M::ArrayManifold, X, p, q; kwargs...) + is_manifold_point(M, p, true; kwargs...) + is_manifold_point(M, q, true; kwargs...) + log!(M.manifold, array_value(X), array_value(p), array_value(q)) + is_tangent_vector(M, p, X, true; kwargs...) + return X end number_eltype(::Type{ArrayMPoint{V}}) where {V} = number_eltype(V) -number_eltype(x::ArrayMPoint) = number_eltype(x.value) +number_eltype(p::ArrayMPoint) = number_eltype(p.value) number_eltype(::Type{ArrayCoTVector{V}}) where {V} = number_eltype(V) -number_eltype(x::ArrayCoTVector) = number_eltype(x.value) +number_eltype(p::ArrayCoTVector) = number_eltype(p.value) number_eltype(::Type{ArrayTVector{V}}) where {V} = number_eltype(V) -number_eltype(x::ArrayTVector) = number_eltype(x.value) +number_eltype(p::ArrayTVector) = number_eltype(p.value) -function project_tangent!(M::ArrayManifold, w, x, v; kwargs...) - is_manifold_point(M, x, true; kwargs...) - project_tangent!(M.manifold, w.value, array_value(x), array_value(v)) - is_tangent_vector(M, x, w, true; kwargs...) - return w +function project_tangent!(M::ArrayManifold, Y, p, X; kwargs...) + is_manifold_point(M, p, true; kwargs...) + project_tangent!(M.manifold, array_value(Y), array_value(p), array_value(X)) + is_tangent_vector(M, p, Y, true; kwargs...) + return Y end -similar(x::ArrayMPoint) = ArrayMPoint(similar(x.value)) -similar(x::ArrayMPoint, ::Type{T}) where {T} = ArrayMPoint(similar(x.value, T)) -similar(x::ArrayCoTVector) = ArrayCoTVector(similar(x.value)) -similar(x::ArrayCoTVector, ::Type{T}) where {T} = ArrayCoTVector(similar(x.value, T)) -similar(x::ArrayTVector) = ArrayTVector(similar(x.value)) -similar(x::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(similar(x.value, T)) +similar(p::ArrayMPoint) = ArrayMPoint(similar(p.value)) +similar(p::ArrayMPoint, ::Type{T}) where {T} = ArrayMPoint(similar(p.value, T)) +similar(p::ArrayCoTVector) = ArrayCoTVector(similar(p.value)) +similar(p::ArrayCoTVector, ::Type{T}) where {T} = ArrayCoTVector(similar(p.value, T)) +similar(p::ArrayTVector) = ArrayTVector(similar(p.value)) +similar(p::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(similar(p.value, T)) function vector_transport_along!( M::ArrayManifold, - vto, - x, - v, + Y, + p, + X, c, m::AbstractVectorTransportMethod; kwargs..., ) - is_tangent_vector(M, x, v, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) vector_transport_along!( M.manifold, - array_value(vto), - array_value(x), - array_value(v), + array_value(Y), + array_value(p), + array_value(X), c, m, ) - is_tangent_vector(M, c(1), vto, true; kwargs...) - return vto + is_tangent_vector(M, c(1), Y, true; kwargs...) + return Y end function vector_transport_to!( M::ArrayManifold, - vto, - x, - v, - y, + Y, + p, + X, + q, m::AbstractVectorTransportMethod; kwargs..., ) - is_manifold_point(M, y, true; kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) + is_manifold_point(M, q, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) vector_transport_to!( M.manifold, - array_value(vto), - array_value(x), - array_value(v), - array_value(y), + array_value(Y), + array_value(p), + array_value(X), + array_value(q), m, ) - is_tangent_vector(M, y, vto, true; kwargs...) - return vto + is_tangent_vector(M, q, Y, true; kwargs...) + return Y end function vector_transport_to!( M::ArrayManifold, - vto, - x, - v, - y, + Y, + p, + X, + q, m::ProjectionTransport; kwargs..., ) - is_manifold_point(M, y, true; kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) + is_manifold_point(M, q, true; kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) vector_transport_to!( M.manifold, - array_value(vto), - array_value(x), - array_value(v), - array_value(y), + array_value(Y), + array_value(p), + array_value(X), + array_value(q), m, ) - is_tangent_vector(M, y, vto, true; kwargs...) - return vto + is_tangent_vector(M, q, Y, true; kwargs...) + return Y end -function zero_tangent_vector(M::ArrayManifold, x; kwargs...) - is_manifold_point(M, x, true; kwargs...) - w = zero_tangent_vector(M.manifold, array_value(x)) - is_tangent_vector(M, x, w, true; kwargs...) +function zero_tangent_vector(M::ArrayManifold, p; kwargs...) + is_manifold_point(M, p, true; kwargs...) + w = zero_tangent_vector(M.manifold, array_value(p)) + is_tangent_vector(M, p, w, true; kwargs...) return w end -function zero_tangent_vector!(M::ArrayManifold, v, x; kwargs...) - is_manifold_point(M, x, true; kwargs...) - zero_tangent_vector!(M.manifold, array_value(v), array_value(x); kwargs...) - is_tangent_vector(M, x, v, true; kwargs...) - return v +function zero_tangent_vector!(M::ArrayManifold, X, p; kwargs...) + is_manifold_point(M, p, true; kwargs...) + zero_tangent_vector!(M.manifold, array_value(X), array_value(p); kwargs...) + is_tangent_vector(M, p, X, true; kwargs...) + return X end diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index c4c544fb..a2cc249c 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -495,7 +495,6 @@ Return the manifold decorated by the decorator `M`. Defaults to `M.manifold`. """ decorated_manifold(M::Manifold) = M.manifold - @decorator_transparent_signature distance(M::AbstractDecoratorManifold, p, q) @decorator_transparent_signature exp(M::AbstractDecoratorManifold, p, X) From 4ded0f13fbe3b3b5fe192b19196b798b96b92cec Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Fri, 13 Mar 2020 10:21:52 +0100 Subject: [PATCH 12/27] Fixes all tests in ArrayManifold concerning bases. --- src/ArrayManifold.jl | 46 ++++++++++++++++++++---------------------- src/DefaultManifold.jl | 2 +- test/array_manifold.jl | 29 ++++++++++++++------------ 3 files changed, 39 insertions(+), 38 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index 2beec340..c3de701a 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -77,10 +77,16 @@ array_value(ξ::ArrayCoTVector) = ξ.value function check_manifold_point(M::ArrayManifold, p; kwargs...) return check_manifold_point(M.manifold, array_value(p); kwargs...) end +function check_manifold_point(M::ArrayManifold, p::MPoint; kwargs...) + return check_manifold_point(M.manifold, array_value(p); kwargs...) +end function check_tangent_vector(M::ArrayManifold, p, X; kwargs...) return check_tangent_vector(M.manifold, array_value(p), array_value(X); kwargs...) end +function check_tangent_vector(M::ArrayManifold, p::MPoint, X::TVector; kwargs...) + return check_tangent_vector(M.manifold, array_value(p), array_value(X); kwargs...) +end convert(::Type{V}, X::ArrayCoTVector{V}) where {V<:AbstractArray{<:Number}} = X.value function convert(::Type{ArrayCoTVector{V}}, X::V) where {V<:AbstractArray{<:Number}} @@ -130,34 +136,26 @@ function exp!(M::ArrayManifold, q, p, X; kwargs...) return q end -function get_basis( - M::ArrayManifold, - p, - B::CachedBasis{<:AbstractOrthonormalBasis{ℝ},T,ℝ}, -) where {T<:AbstractVector} - bvectors = get_vectors(M, p, B) - N = length(bvectors) - M_dim = manifold_dimension(M) - if N != M_dim - - throw(ArgumentError("Incorrect number of basis vectors; expected: $M_dim, given: $N")) - end - for i = 1:N - Xi_norm = norm(M, p, bvectors[i]) - if !isapprox(Xi_norm, 1) - throw(ArgumentError("vector number $i is not normalized (norm = $Xi_norm)")) - end - for j = i+1:N - dot_val = real(inner(M, p, bvectors[i], bvectors[j])) - if !isapprox(dot_val, 0; atol = eps(eltype(p))) - throw(ArgumentError("vectors number $i and $j are not orthonormal (inner product = $dot_val)")) - end - end +function get_basis(M::ArrayManifold, p, B::AbstractBasis; kwargs...) + is_manifold_point(M, p, true; kwargs...) + Ξ = get_basis(M.manifold, array_value(p), B) + nV = length(get_vectors(M.manifold, array_value(p), Ξ)) + if nV != manifold_dimension(M.manifold) + return ErrorException( + "For a basis of the tangent space at $(p) of $(M.manifold), $(manifold_dimension(M)) vectors are required, but get_basis $(B) computed $(nV)" + ) end - return B + map(X -> is_tangent_vector(M, p, X, true; kwargs...), get_vectors(M.manifold, array_value(p), Ξ)) + return Ξ +end +for BT in DISAMBIGUATION_BASIS_TYPES + eval(quote + @invoke_maker 3 AbstractBasis get_basis(M::ArrayManifold, p, B::$BT; kwargs...) + end) end function get_coordinates(M::ArrayManifold, p, X, B::AbstractBasis; kwargs...) + is_manifold_point(M, p, true; kwargs...) is_tangent_vector(M, p, X, true; kwargs...) return get_coordinates(M.manifold, p, X, B) end diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 39703527..7332ddbe 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -36,7 +36,7 @@ function check_tangent_vector( if size(X) != representation_size(M) return DomainError( size(X), - "The vector $(X) is not a tangent to a point on $M since its size does not match $(N+1).", + "The vector $(X) is not a tangent to a point on $M since its size does not match $(representation_size(M)).", ) end return nothing diff --git a/test/array_manifold.jl b/test/array_manifold.jl index 5e7249a5..a8c18b45 100644 --- a/test/array_manifold.jl +++ b/test/array_manifold.jl @@ -119,21 +119,24 @@ end end @testset "ArrayManifold basis" begin + b = [Matrix(I,3,3)[:,i] for i=1:3] for BT in (DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis) - cb = BT() - @test_broken b = get_basis(A, x, cb) - v = similar(x) - @test_throws ErrorException get_vector(A, x, [1.0], cb) - @test_throws ErrorException get_coordinates(A, x, [1.0], cb) - @test_throws ErrorException get_vector!(A, v, x, [], cb) - @test_throws ErrorException get_coordinates!(A, v, x, [], cb) - @test get_vector(A, x, [1, 2, 3], cb) ≈ get_vector(M, x, [1, 2, 3], cb) - @test get_coordinates(A, x, [1, 2, 3], cb) ≈ get_coordinates(M, x, [1, 2, 3], cb) + @testset "Basis $(BT)" begin + cb = BT() + @test b == get_vectors(M, x, get_basis(A,x,cb)) + v = similar(x) + @test_throws ErrorException get_vector(A, x, [1.0], cb) + @test_throws DomainError get_coordinates(A, x, [1.0], cb) + @test_throws ErrorException get_vector!(A, v, x, [], cb) + @test_throws DomainError get_coordinates!(A, v, x, [], cb) + @test get_vector(A, x, [1, 2, 3], cb) ≈ get_vector(M, x, [1, 2, 3], cb) + @test get_coordinates(A, x, [1, 2, 3], cb) ≈ get_coordinates(M, x, [1, 2, 3], cb) - - @test_throws ArgumentError get_basis(A, x, CachedBasis(cb, [x])) - @test_throws ArgumentError get_basis(A, x, CachedBasis(cb, [x, x, x])) - @test_throws ArgumentError get_basis(A, x, CachedBasis(cb, [2*x, x, x])) + @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [x])) + @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [x, x, x])) + @test_throws ErrorException + get_basis(A, x, CachedBasis(cb, [2*x, x, x])) + end end end end From c0c8128629b582b7388a444f635c97f81a6a0c9d Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 13 Mar 2020 12:17:36 +0100 Subject: [PATCH 13/27] actually fixing basis-array manifold interactions --- src/ArrayManifold.jl | 67 +++++++++++++++++++++++++++++++++++----- src/DecoratorManifold.jl | 13 ++++++++ src/bases.jl | 4 +++ test/array_manifold.jl | 8 +++-- test/runtests.jl | 2 +- 5 files changed, 84 insertions(+), 10 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index c3de701a..c677078d 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -139,18 +139,71 @@ end function get_basis(M::ArrayManifold, p, B::AbstractBasis; kwargs...) is_manifold_point(M, p, true; kwargs...) Ξ = get_basis(M.manifold, array_value(p), B) - nV = length(get_vectors(M.manifold, array_value(p), Ξ)) - if nV != manifold_dimension(M.manifold) - return ErrorException( - "For a basis of the tangent space at $(p) of $(M.manifold), $(manifold_dimension(M)) vectors are required, but get_basis $(B) computed $(nV)" - ) + bvectors = get_vectors(M, p, Ξ) + N = length(bvectors) + if N != manifold_dimension(M.manifold) + throw(ErrorException( + "For a basis of the tangent space at $(p) of $(M.manifold), $(manifold_dimension(M)) vectors are required, but get_basis $(B) computed $(N)" + )) + end + # check that the vectors are linearly independent\ + bv_rank = rank(reduce(hcat, bvectors)) + if N != bv_rank + throw(ErrorException( + "For a basis of the tangent space at $(p) of $(M.manifold), $(manifold_dimension(M)) linearly independent vectors are required, but get_basis $(B) computed $(bv_rank)" + )) + end + map(X -> is_tangent_vector(M, p, X, true; kwargs...), bvectors) + return Ξ +end +function get_basis( + M::ArrayManifold, + p, + B::Union{AbstractOrthogonalBasis,CachedBasis{<:AbstractOrthogonalBasis}}; + kwargs..., +) + is_manifold_point(M, p, true; kwargs...) + Ξ = invoke(get_basis, Tuple{ArrayManifold,Any,AbstractBasis}, M, p, B; kwargs...) + bvectors = get_vectors(M, p, Ξ) + N = length(bvectors) + for i = 1:N + for j = i+1:N + dot_val = real(inner(M, p, bvectors[i], bvectors[j])) + if !isapprox(dot_val, 0; atol = eps(eltype(p))) + throw(ArgumentError("vectors number $i and $j are not orthonormal (inner product = $dot_val)")) + end + end + end + return Ξ +end +function get_basis( + M::ArrayManifold, + p, + B::Union{AbstractOrthonormalBasis,CachedBasis{<:AbstractOrthonormalBasis}}; + kwargs..., +) + is_manifold_point(M, p, true; kwargs...) + Ξ = invoke(get_basis, Tuple{ArrayManifold,Any,AbstractOrthogonalBasis}, M, p, B; kwargs...) + bvectors = get_vectors(M, p, Ξ) + N = length(bvectors) + for i = 1:N + Xi_norm = norm(M, p, bvectors[i]) + if !isapprox(Xi_norm, 1) + throw(ArgumentError("vector number $i is not normalized (norm = $Xi_norm)")) + end end - map(X -> is_tangent_vector(M, p, X, true; kwargs...), get_vectors(M.manifold, array_value(p), Ξ)) return Ξ end for BT in DISAMBIGUATION_BASIS_TYPES + if BT <: Union{AbstractOrthonormalBasis,CachedBasis{<:AbstractOrthonormalBasis}} + CT = AbstractOrthonormalBasis + elseif BT <: Union{AbstractOrthogonalBasis,CachedBasis{<:AbstractOrthogonalBasis}} + CT = AbstractOrthogonalBasis + else + CT = AbstractBasis + end eval(quote - @invoke_maker 3 AbstractBasis get_basis(M::ArrayManifold, p, B::$BT; kwargs...) + @invoke_maker 3 $CT get_basis(M::ArrayManifold, p, B::$BT; kwargs...) end) end diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index a2cc249c..7d53b521 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -481,6 +481,12 @@ end kwargs..., ) +@decorator_transparent_signature check_manifold_point( + M::AbstractDecoratorManifold, + p::MPoint; + kwargs..., +) + @decorator_transparent_signature check_tangent_vector( M::AbstractDecoratorManifold, p, @@ -488,6 +494,13 @@ end kwargs..., ) +@decorator_transparent_signature check_tangent_vector( + M::AbstractDecoratorManifold, + p::MPoint, + X::TVector; + kwargs..., +) + """ decorated_manifold(M::AbstractDecoratorManifold) diff --git a/src/bases.jl b/src/bases.jl index f8324259..127c133a 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -140,10 +140,14 @@ const all_uncached_bases = Union{AbstractBasis, DefaultBasis, DefaultOrthogonalB const DISAMBIGUATION_BASIS_TYPES = [ CachedBasis, CachedBasis{<:AbstractBasis{ℝ}}, + CachedBasis{<:AbstractOrthogonalBasis{ℝ}}, + CachedBasis{<:AbstractOrthonormalBasis{ℝ}}, DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis, DiagonalizingOrthonormalBasis, + ProjectedOrthonormalBasis{:svd,ℝ}, + ProjectedOrthonormalBasis{:gram_schmidt,ℝ}, VeeOrthogonalBasis, ] diff --git a/test/array_manifold.jl b/test/array_manifold.jl index a8c18b45..af12093b 100644 --- a/test/array_manifold.jl +++ b/test/array_manifold.jl @@ -134,8 +134,12 @@ end @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [x])) @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [x, x, x])) - @test_throws ErrorException - get_basis(A, x, CachedBasis(cb, [2*x, x, x])) + @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [2*x, x, x])) + if BT <: ManifoldsBase.AbstractOrthogonalBasis + @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [[1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) + elseif BT <: ManifoldsBase.AbstractOrthonormalBasis + @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [[2.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])) + end end end end diff --git a/test/runtests.jl b/test/runtests.jl index 49c90a5a..0ab7ef33 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using ManifoldsBase @testset "ManifoldsBase" begin # TODO: decrease the number of ambiguities - @test length(Test.detect_ambiguities(ManifoldsBase)) <= 12 + @test length(Test.detect_ambiguities(ManifoldsBase)) <= 10 include("allocation.jl") include("numbers.jl") include("bases.jl") From 88cf3536c7f51478062faf5a66dfeb300191affe Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 13 Mar 2020 12:42:02 +0100 Subject: [PATCH 14/27] improving coverage --- src/DefaultManifold.jl | 2 +- test/array_manifold.jl | 7 ++++-- test/bases.jl | 48 ++++++++++++++++++++++++++++-------------- 3 files changed, 38 insertions(+), 19 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 7332ddbe..536f4d99 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -16,7 +16,7 @@ function check_manifold_point(M::DefaultManifold, p; kwargs...) if size(p) != representation_size(M) return DomainError( size(p), - "The point $(p) does not lie on $M, since its size is not $(N+1).", + "The point $(p) does not lie on $M, since its size is not $(representation_size(M)).", ) end return nothing diff --git a/test/array_manifold.jl b/test/array_manifold.jl index af12093b..05caa4f2 100644 --- a/test/array_manifold.jl +++ b/test/array_manifold.jl @@ -123,12 +123,15 @@ end for BT in (DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis) @testset "Basis $(BT)" begin cb = BT() - @test b == get_vectors(M, x, get_basis(A,x,cb)) + @test b == get_vectors(M, x, get_basis(A, x, cb)) v = similar(x) @test_throws ErrorException get_vector(A, x, [1.0], cb) - @test_throws DomainError get_coordinates(A, x, [1.0], cb) + @test_throws DomainError get_vector(A, [1.0], [1.0, 0.0, 0.0], cb) @test_throws ErrorException get_vector!(A, v, x, [], cb) + @test_throws DomainError get_vector!(A, v, [1.0], [1.0, 0.0, 0.0], cb) + @test_throws DomainError get_coordinates(A, x, [1.0], cb) @test_throws DomainError get_coordinates!(A, v, x, [], cb) + @test_throws DomainError get_coordinates!(A, v, [1.0], [1.0, 0.0, 0.0], cb) @test get_vector(A, x, [1, 2, 3], cb) ≈ get_vector(M, x, [1, 2, 3], cb) @test get_coordinates(A, x, [1, 2, 3], cb) ≈ get_coordinates(M, x, [1, 2, 3], cb) diff --git a/test/bases.jl b/test/bases.jl index c97c3d1d..df0ef779 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -1,4 +1,5 @@ using LinearAlgebra +using ManifoldsBase struct ProjManifold <: Manifold end @@ -8,29 +9,44 @@ ManifoldsBase.representation_size(::ProjManifold) = (2,3) ManifoldsBase.manifold_dimension(::ProjManifold) = 5 ManifoldsBase.get_vector(::ProjManifold, x, v, ::DefaultOrthonormalBasis) = reverse(v) +@testset "Dispatch" begin + @test ManifoldsBase.decorator_transparent_dispatch( + get_coordinates, + ManifoldsBase.DefaultManifold(3), + [0.0, 0.0, 0.0], + ) === Val(:parent) + @test ManifoldsBase.decorator_transparent_dispatch( + get_coordinates!, + ManifoldsBase.DefaultManifold(3), + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ) === Val(:transparent) +end + @testset "Projected and arbitrary orthonormal basis" begin M = ProjManifold() x = [sqrt(2)/2 0.0 0.0; 0.0 sqrt(2)/2 0.0] - pb = get_basis(M, x, ProjectedOrthonormalBasis(:svd)) - @test number_system(pb) == ℝ - @test get_basis(M, x, pb) == pb - N = manifold_dimension(M) - @test isa(pb, CachedBasis) - @test length(get_vectors(M, x, pb)) == N - # test orthonormality - for i in 1:N - @test norm(M, x, get_vectors(M, x, pb)[i]) ≈ 1 - for j in i+1:N - @test inner(M, x, get_vectors(M, x, pb)[i], get_vectors(M, x, pb)[j]) ≈ 0 atol = 1e-15 + for pb in (ProjectedOrthonormalBasis(:svd), ProjectedOrthonormalBasis(:gram_schmidt)) + pb = get_basis(M, x, pb) + @test number_system(pb) == ℝ + @test get_basis(M, x, pb) == pb + N = manifold_dimension(M) + @test isa(pb, CachedBasis) + @test length(get_vectors(M, x, pb)) == N + # test orthonormality + for i in 1:N + @test norm(M, x, get_vectors(M, x, pb)[i]) ≈ 1 + for j in i+1:N + @test inner(M, x, get_vectors(M, x, pb)[i], get_vectors(M, x, pb)[j]) ≈ 0 atol = 1e-15 + end + end + # check projection idempotency + for i in 1:N + @test project_tangent(M, x, get_vectors(M, x, pb)[i]) ≈ get_vectors(M, x, pb)[i] end end - # check projection idempotency - for i in 1:N - @test project_tangent(M, x, get_vectors(M, x, pb)[i]) ≈ get_vectors(M, x, pb)[i] - end - aonb = get_basis(M, x, DefaultOrthonormalBasis()) @test size(get_vectors(M, x, aonb)) == (5,) @test get_vectors(M, x, aonb)[1] ≈ [0, 0, 0, 0, 1] From 94fe008b1bd4c6f8883b72d3596da0c3fcb3e35d Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 13 Mar 2020 12:54:38 +0100 Subject: [PATCH 15/27] improving coverage again --- test/bases.jl | 12 ++++++++++++ test/decorator_manifold.jl | 6 ++++++ 2 files changed, 18 insertions(+) diff --git a/test/bases.jl b/test/bases.jl index df0ef779..189cf2ab 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -21,6 +21,17 @@ ManifoldsBase.get_vector(::ProjManifold, x, v, ::DefaultOrthonormalBasis) = reve [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], ) === Val(:transparent) + @test ManifoldsBase.decorator_transparent_dispatch( + get_vector, + ManifoldsBase.DefaultManifold(3), + [0.0, 0.0, 0.0], + ) === Val(:parent) + @test ManifoldsBase.decorator_transparent_dispatch( + get_vector!, + ManifoldsBase.DefaultManifold(3), + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ) === Val(:transparent) end @testset "Projected and arbitrary orthonormal basis" begin @@ -34,6 +45,7 @@ end @test get_basis(M, x, pb) == pb N = manifold_dimension(M) @test isa(pb, CachedBasis) + @test CachedBasis(pb) === pb @test length(get_vectors(M, x, pb)) == N # test orthonormality for i in 1:N diff --git a/test/decorator_manifold.jl b/test/decorator_manifold.jl index 8d004627..2fda4db3 100644 --- a/test/decorator_manifold.jl +++ b/test/decorator_manifold.jl @@ -78,6 +78,11 @@ decorator_transparent_dispatch(::typeof(test10), M::TestDecorator3, args...) = V return 15*a end +@decorator_transparent_function function test12(M::ManifoldsBase.DefaultManifold, p) + return 12*p +end +ManifoldsBase._acts_transparently(test12, TestDecorator3, p) = Val(:foo) + @testset "Testing decorator manifold functions" begin M = ManifoldsBase.DefaultManifold(3) A = ArrayManifold(M) @@ -140,4 +145,5 @@ end @test test9(TestDecorator3(TD), p; a = 1000, b = 10000) == 11109 @test test10(TestDecorator3(TD), p; a = 11) == 110 @test test11(TestDecorator3(TD), p; a = 12) == 180 + @test_throws ErrorException test12(TestDecorator3(TD), p) end From 9b9a960e6392ec0a434deebf4bcb4c7bb9dcf334 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 13 Mar 2020 13:26:38 +0100 Subject: [PATCH 16/27] updating naming convention --- src/ArrayManifold.jl | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/src/ArrayManifold.jl b/src/ArrayManifold.jl index c677078d..3862793a 100644 --- a/src/ArrayManifold.jl +++ b/src/ArrayManifold.jl @@ -59,8 +59,8 @@ end allocate(p::ArrayMPoint) = ArrayMPoint(allocate(p.value)) allocate(p::ArrayMPoint, ::Type{T}) where {T} = ArrayMPoint(allocate(p.value, T)) -allocate(p::ArrayTVector) = ArrayTVector(allocate(p.value)) -allocate(p::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(allocate(p.value, T)) +allocate(X::ArrayTVector) = ArrayTVector(allocate(X.value)) +allocate(X::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(allocate(X.value, T)) """ array_value(p) @@ -109,9 +109,9 @@ function copyto!(p::ArrayCoTVector, q::ArrayCoTVector) copyto!(p.value, q.value) return p end -function copyto!(p::ArrayTVector, q::ArrayTVector) - copyto!(p.value, q.value) - return p +function copyto!(Y::ArrayTVector, X::ArrayTVector) + copyto!(Y.value, X.value) + return Y end function distance(M::ArrayManifold, p, q; kwargs...) @@ -320,7 +320,7 @@ number_eltype(p::ArrayMPoint) = number_eltype(p.value) number_eltype(::Type{ArrayCoTVector{V}}) where {V} = number_eltype(V) number_eltype(p::ArrayCoTVector) = number_eltype(p.value) number_eltype(::Type{ArrayTVector{V}}) where {V} = number_eltype(V) -number_eltype(p::ArrayTVector) = number_eltype(p.value) +number_eltype(X::ArrayTVector) = number_eltype(X.value) function project_tangent!(M::ArrayManifold, Y, p, X; kwargs...) is_manifold_point(M, p, true; kwargs...) @@ -333,8 +333,8 @@ similar(p::ArrayMPoint) = ArrayMPoint(similar(p.value)) similar(p::ArrayMPoint, ::Type{T}) where {T} = ArrayMPoint(similar(p.value, T)) similar(p::ArrayCoTVector) = ArrayCoTVector(similar(p.value)) similar(p::ArrayCoTVector, ::Type{T}) where {T} = ArrayCoTVector(similar(p.value, T)) -similar(p::ArrayTVector) = ArrayTVector(similar(p.value)) -similar(p::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(similar(p.value, T)) +similar(X::ArrayTVector) = ArrayTVector(similar(X.value)) +similar(X::ArrayTVector, ::Type{T}) where {T} = ArrayTVector(similar(X.value, T)) function vector_transport_along!( M::ArrayManifold, From 54e2510247f478aeebefe49c7aba6ce3d9ad52a3 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 13 Mar 2020 14:11:48 +0100 Subject: [PATCH 17/27] no more ambiguities --- src/DecoratorManifold.jl | 17 +++++++++++++++++ src/ManifoldsBase.jl | 6 ------ test/array_manifold.jl | 3 +++ test/decorator_manifold.jl | 2 ++ test/empty_manifold.jl | 12 ++++++------ test/runtests.jl | 2 +- 6 files changed, 29 insertions(+), 13 deletions(-) diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index 7d53b521..0ac17a55 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -520,11 +520,20 @@ decorated_manifold(M::Manifold) = M.manifold M::AbstractDecoratorManifold, m::AbstractRetractionMethod, ) +@decorator_transparent_signature injectivity_radius( + M::AbstractDecoratorManifold, + m::ExponentialRetraction, +) @decorator_transparent_signature injectivity_radius( M::AbstractDecoratorManifold, p, m::AbstractRetractionMethod, ) +@decorator_transparent_signature injectivity_radius( + M::AbstractDecoratorManifold, + p, + m::ExponentialRetraction, +) @decorator_transparent_signature inner(M::AbstractDecoratorManifold, p, X, Y) @@ -648,5 +657,13 @@ decorated_manifold(M::Manifold) = M.manifold q, m::AbstractVectorTransportMethod, ) +@decorator_transparent_signature vector_transport_to!( + M::AbstractDecoratorManifold, + Y, + p, + X, + q, + m::ProjectionTransport, +) @decorator_transparent_signature zero_tangent_vector!(M::AbstractDecoratorManifold, X, p) diff --git a/src/ManifoldsBase.jl b/src/ManifoldsBase.jl index 5df9ba07..c6e0baac 100644 --- a/src/ManifoldsBase.jl +++ b/src/ManifoldsBase.jl @@ -242,9 +242,6 @@ By default, `check_manifold_point` returns `nothing`, i.e. if no checks are impl assumption is to be optimistic for a point not deriving from the [`MPoint`](@ref) type. """ check_manifold_point(M::Manifold, p; kwargs...) = nothing -function check_manifold_point(M::Manifold, p::MPoint; kwargs...) - error(manifold_function_not_implemented_message(M, check_manifold_point, p)) -end """ check_tangent_vector(M::Manifold, p, X; kwargs...) -> Union{Nothing,String} @@ -259,9 +256,6 @@ assumption is to be optimistic for tangent vectors not deriving from the [`TVect type. """ check_tangent_vector(M::Manifold, p, X; kwargs...) = nothing -function check_tangent_vector(M::Manifold, p::MPoint, X::TVector; kwargs...) - error(manifold_function_not_implemented_message(M, check_tangent_vector, p, X)) -end """ distance(M::Manifold, p, q) diff --git a/test/array_manifold.jl b/test/array_manifold.jl index 05caa4f2..4bd5ae1d 100644 --- a/test/array_manifold.jl +++ b/test/array_manifold.jl @@ -125,6 +125,7 @@ end cb = BT() @test b == get_vectors(M, x, get_basis(A, x, cb)) v = similar(x) + v2 = similar(x) @test_throws ErrorException get_vector(A, x, [1.0], cb) @test_throws DomainError get_vector(A, [1.0], [1.0, 0.0, 0.0], cb) @test_throws ErrorException get_vector!(A, v, x, [], cb) @@ -133,7 +134,9 @@ end @test_throws DomainError get_coordinates!(A, v, x, [], cb) @test_throws DomainError get_coordinates!(A, v, [1.0], [1.0, 0.0, 0.0], cb) @test get_vector(A, x, [1, 2, 3], cb) ≈ get_vector(M, x, [1, 2, 3], cb) + @test get_vector!(A, v2, x, [1, 2, 3], cb) ≈ get_vector!(M, v, x, [1, 2, 3], cb) @test get_coordinates(A, x, [1, 2, 3], cb) ≈ get_coordinates(M, x, [1, 2, 3], cb) + @test get_coordinates!(A, v2, x, [1, 2, 3], cb) ≈ get_coordinates!(M, v, x, [1, 2, 3], cb) @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [x])) @test_throws ErrorException get_basis(A, x, CachedBasis(cb, [x, x, x])) diff --git a/test/decorator_manifold.jl b/test/decorator_manifold.jl index 2fda4db3..4c27af28 100644 --- a/test/decorator_manifold.jl +++ b/test/decorator_manifold.jl @@ -118,6 +118,8 @@ ManifoldsBase._acts_transparently(test12, TestDecorator3, p) = Val(:foo) @test (@inferred ManifoldsBase.default_decorator_dispatch(M)) === Val(false) @test ManifoldsBase.is_default_decorator(M) === false + @test injectivity_radius(TD, ManifoldsBase.ExponentialRetraction()) == Inf + @test test1(TD, p) == 1 @test test1(TD, p; a = 1000) == 1001 @test test2(TD, p) == 102 diff --git a/test/empty_manifold.jl b/test/empty_manifold.jl index a91d993f..c89ac51a 100644 --- a/test/empty_manifold.jl +++ b/test/empty_manifold.jl @@ -118,15 +118,15 @@ struct NonCoTVector <: CoTVector end @test_throws ErrorException zero_tangent_vector(M, [0]) @test check_manifold_point(M, [0]) === nothing - @test_throws ErrorException check_manifold_point(M, p) + @test check_manifold_point(M, p) === nothing @test is_manifold_point(M, [0]) - @test check_manifold_point(M, [0]) == nothing + @test check_manifold_point(M, [0]) === nothing @test check_tangent_vector(M, [0], [0]) === nothing - @test_throws ErrorException check_tangent_vector(M, p, v) + @test check_tangent_vector(M, p, v) === nothing @test is_tangent_vector(M, [0], [0]) - @test check_tangent_vector(M, [0], [0]) == nothing + @test check_tangent_vector(M, [0], [0]) === nothing - @test_throws ErrorException hat!(M,[0],[0],[0]) - @test_throws ErrorException vee!(M,[0],[0],[0]) + @test_throws ErrorException hat!(M, [0], [0], [0]) + @test_throws ErrorException vee!(M, [0], [0], [0]) end diff --git a/test/runtests.jl b/test/runtests.jl index 0ab7ef33..11551836 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -3,7 +3,7 @@ using ManifoldsBase @testset "ManifoldsBase" begin # TODO: decrease the number of ambiguities - @test length(Test.detect_ambiguities(ManifoldsBase)) <= 10 + @test length(Test.detect_ambiguities(ManifoldsBase)) == 0 include("allocation.jl") include("numbers.jl") include("bases.jl") From 25c64639d4ad9d5ddb06f57babe5d5c5bf65189b Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Fri, 13 Mar 2020 16:00:58 +0100 Subject: [PATCH 18/27] removing two unnecessary methods --- src/DecoratorManifold.jl | 13 ------------- 1 file changed, 13 deletions(-) diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index 0ac17a55..c502a6c6 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -481,12 +481,6 @@ end kwargs..., ) -@decorator_transparent_signature check_manifold_point( - M::AbstractDecoratorManifold, - p::MPoint; - kwargs..., -) - @decorator_transparent_signature check_tangent_vector( M::AbstractDecoratorManifold, p, @@ -494,13 +488,6 @@ end kwargs..., ) -@decorator_transparent_signature check_tangent_vector( - M::AbstractDecoratorManifold, - p::MPoint, - X::TVector; - kwargs..., -) - """ decorated_manifold(M::AbstractDecoratorManifold) From 2230e68d155daef029d529ccf21c18177ee8d02d Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sat, 14 Mar 2020 08:51:07 +0100 Subject: [PATCH 19/27] test fallbacks and fix a bufg in the signature decorator. --- src/DecoratorManifold.jl | 2 +- test/decorator_manifold.jl | 24 ++++++++++++++++++++++++ 2 files changed, 25 insertions(+), 1 deletion(-) diff --git a/src/DecoratorManifold.jl b/src/DecoratorManifold.jl index c502a6c6..5e6d327f 100644 --- a/src/DecoratorManifold.jl +++ b/src/DecoratorManifold.jl @@ -363,7 +363,7 @@ macro decorator_transparent_signature(ex) $(kwargs_list...), ) where {$(where_exprs...)} return ($fname)( - decorated_manifold($(argnames[1])), + ManifoldsBase.decorated_manifold($(argnames[1])), $(argnames[2:end]...); $(kwargs_call...), ) diff --git a/test/decorator_manifold.jl b/test/decorator_manifold.jl index 4c27af28..6b48a681 100644 --- a/test/decorator_manifold.jl +++ b/test/decorator_manifold.jl @@ -83,6 +83,22 @@ end end ManifoldsBase._acts_transparently(test12, TestDecorator3, p) = Val(:foo) +@decorator_transparent_function :none function test13(M::TestDecorator3, p) + return 13.5*p +end +decorator_transparent_dispatch(::typeof(test13), M::TestDecorator, args...) = Val(:intransparent) +decorator_transparent_dispatch(::typeof(test13), M::TestDecorator2, args...) = Val(:transparent) +test13(::ManifoldsBase.DefaultManifold,a) = 13*a + +function test14(M::AbstractDecoratorManifold, p) + return 14.5*p +end +@decorator_transparent_signature test14(M::AbstractDecoratorManifold,p) +decorator_transparent_dispatch(::typeof(test14), M::TestDecorator3, args...) = Val(:none) +decorator_transparent_dispatch(::typeof(test14), M::TestDecorator, args...) = Val(:intransparent) +decorator_transparent_dispatch(::typeof(test14), M::TestDecorator2, args...) = Val(:transparent) +test14(::ManifoldsBase.DefaultManifold,a) = 14*a + @testset "Testing decorator manifold functions" begin M = ManifoldsBase.DefaultManifold(3) A = ArrayManifold(M) @@ -148,4 +164,12 @@ ManifoldsBase._acts_transparently(test12, TestDecorator3, p) = Val(:foo) @test test10(TestDecorator3(TD), p; a = 11) == 110 @test test11(TestDecorator3(TD), p; a = 12) == 180 @test_throws ErrorException test12(TestDecorator3(TD), p) + + @test_throws ErrorException test13(TestDecorator3(M),1) # :none nonexistent + @test_throws ErrorException test13(TestDecorator(M),1) # not implemented + @test test13(TestDecorator2(M),2) == 26 # from parent + + @test_throws ErrorException test14(TestDecorator3(M),1) # :none nonexistent + @test_throws ErrorException test14(TestDecorator(M),1) # not implemented + @test test14(TestDecorator2(M),2) == 28 # from parent end From 99fafe982d08c2abb6e58cee281685b615bc8649 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sat, 14 Mar 2020 14:25:32 +0100 Subject: [PATCH 20/27] increasing coverage and a bugfix --- src/DefaultManifold.jl | 9 ++- src/bases.jl | 2 +- test/bases.jl | 166 +++++++++++++++++++++++++++++++++-------- test/runtests.jl | 2 +- 4 files changed, 145 insertions(+), 34 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 536f4d99..1b0560d5 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -55,14 +55,19 @@ end function get_basis(M::DefaultManifold, p, B::DefaultBasis) return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)]) end +function get_basis(M::DefaultManifold, p, B::DiagonalizingOrthonormalBasis) + vecs = get_vectors(M, p, get_basis(M, p, DefaultOrthonormalBasis())) + eigenvalues = zeros(real(eltype(p)), manifold_dimension(M)) + return CachedBasis(B, DiagonalizingBasisData(B.frame_direction, eigenvalues, vecs)) +end function get_coordinates!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis) - Y .= reshape(X, manifold_dimension(M)) + copyto!(Y, reshape(X, manifold_dimension(M))) return Y end function get_vector!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis) - Y .= reshape(X, representation_size(M)) + copyto!(Y, reshape(X, representation_size(M))) return Y end diff --git a/src/bases.jl b/src/bases.jl index 127c133a..2c38f798 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -395,7 +395,7 @@ function get_vector!(M::Manifold, Y, p, X, B::CachedBasis) Xt = X[1] * bvectors[1] copyto!(Y, Xt) for i = 2:length(X) - Y += X[i] * bvectors[i] + copyto!(Y, Y + X[i] * bvectors[i]) end return Y else diff --git a/test/bases.jl b/test/bases.jl index 189cf2ab..cfa0fbdd 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -1,5 +1,8 @@ using LinearAlgebra using ManifoldsBase +using ManifoldsBase: DefaultManifold +using Test +import Base: +, -, *, copyto!, isapprox struct ProjManifold <: Manifold end @@ -12,23 +15,23 @@ ManifoldsBase.get_vector(::ProjManifold, x, v, ::DefaultOrthonormalBasis) = reve @testset "Dispatch" begin @test ManifoldsBase.decorator_transparent_dispatch( get_coordinates, - ManifoldsBase.DefaultManifold(3), + DefaultManifold(3), [0.0, 0.0, 0.0], ) === Val(:parent) @test ManifoldsBase.decorator_transparent_dispatch( get_coordinates!, - ManifoldsBase.DefaultManifold(3), + DefaultManifold(3), [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], ) === Val(:transparent) @test ManifoldsBase.decorator_transparent_dispatch( get_vector, - ManifoldsBase.DefaultManifold(3), + DefaultManifold(3), [0.0, 0.0, 0.0], ) === Val(:parent) @test ManifoldsBase.decorator_transparent_dispatch( get_vector!, - ManifoldsBase.DefaultManifold(3), + DefaultManifold(3), [0.0, 0.0, 0.0], [0.0, 0.0, 0.0], ) === Val(:transparent) @@ -67,6 +70,89 @@ end struct NonManifold <: Manifold end struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end +struct NonBroadcastBasisThing{T} + v::T +end + ++(a::NonBroadcastBasisThing, b::NonBroadcastBasisThing) = NonBroadcastBasisThing(a.v + b.v) +*(α, a::NonBroadcastBasisThing) = NonBroadcastBasisThing(α * a.v) +-(a::NonBroadcastBasisThing, b::NonBroadcastBasisThing) = NonBroadcastBasisThing(a.v - b.v) + +isapprox(a::NonBroadcastBasisThing, b::NonBroadcastBasisThing) = isapprox(a.v, b.v) + +function ManifoldsBase.number_eltype(a::NonBroadcastBasisThing) + return typeof(reduce(+, one(number_eltype(eti)) for eti ∈ a.v)) +end + +import ManifoldsBase.allocate + +allocate(a::NonBroadcastBasisThing) = NonBroadcastBasisThing(allocate(a.v)) +function allocate(a::NonBroadcastBasisThing, ::Type{T}) where {T} + return NonBroadcastBasisThing(allocate(a.v, T)) +end +allocate(::NonBroadcastBasisThing, ::Type{T}, s::Integer) where {S,T} = Vector{T}(undef, s) + +function copyto!(a::NonBroadcastBasisThing, b::NonBroadcastBasisThing) + copyto!(a.v, b.v) + return a +end + +function ManifoldsBase.log!( + ::DefaultManifold, + v::NonBroadcastBasisThing, + x::NonBroadcastBasisThing, + y::NonBroadcastBasisThing, +) + return copyto!(v, y - x) +end + +function ManifoldsBase.exp!( + ::DefaultManifold, + y::NonBroadcastBasisThing, + x::NonBroadcastBasisThing, + v::NonBroadcastBasisThing, +) + return copyto!(y, x + v) +end + +function ManifoldsBase.get_basis(M::DefaultManifold, p::NonBroadcastBasisThing, B::DefaultOrthonormalBasis) + return CachedBasis(B, [NonBroadcastBasisThing(ManifoldsBase._euclidean_basis_vector(p.v, i)) for i in eachindex(p.v)]) +end +function ManifoldsBase.get_basis(M::DefaultManifold, p::NonBroadcastBasisThing, B::DefaultOrthogonalBasis) + return CachedBasis(B, [NonBroadcastBasisThing(ManifoldsBase._euclidean_basis_vector(p.v, i)) for i in eachindex(p.v)]) +end +function ManifoldsBase.get_basis(M::DefaultManifold, p::NonBroadcastBasisThing, B::DefaultBasis) + return CachedBasis(B, [NonBroadcastBasisThing(ManifoldsBase._euclidean_basis_vector(p.v, i)) for i in eachindex(p.v)]) +end + +function ManifoldsBase.get_coordinates!( + M::DefaultManifold, + Y, + p::NonBroadcastBasisThing, + X::NonBroadcastBasisThing, + B::DefaultOrthonormalBasis, +) + copyto!(Y, reshape(X.v, manifold_dimension(M))) + return Y +end + +function ManifoldsBase.get_vector!( + M::DefaultManifold, + Y::NonBroadcastBasisThing, + p::NonBroadcastBasisThing, + X, + B::DefaultOrthonormalBasis, +) + copyto!(Y.v, reshape(X, representation_size(M))) + return Y +end + +ManifoldsBase.inner(::DefaultManifold, x::NonBroadcastBasisThing, v::NonBroadcastBasisThing, w::NonBroadcastBasisThing) = dot(v.v, w.v) + +ManifoldsBase._get_vector_cache_broadcast(::NonBroadcastBasisThing) = Val(false) + +DiagonalizingBasisProxy() = DiagonalizingOrthonormalBasis([1.0, 0.0, 0.0]) + @testset "ManifoldsBase.jl stuff" begin @testset "Errors" begin @@ -82,43 +168,63 @@ struct NonBasis <: ManifoldsBase.AbstractBasis{ℝ} end @test_throws ErrorException get_vectors(m, [0], NonBasis()) end - M = ManifoldsBase.DefaultManifold(3) + M = DefaultManifold(3) - pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] - @testset "basis representation" for BT in (DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis) + _pts = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]] + @testset "basis representation" for BT in ( + DefaultBasis, + DefaultOrthonormalBasis, + DefaultOrthogonalBasis, + DiagonalizingBasisProxy, + ), pts in (_pts, map(NonBroadcastBasisThing, _pts)) + if BT == DiagonalizingBasisProxy && pts !== _pts + continue + end v1 = log(M, pts[1], pts[2]) - vb = get_coordinates(M, pts[1], v1, BT()) - @test isa(vb, AbstractVector) - vbi = get_vector(M, pts[1], vb, BT()) - @test isapprox(M, pts[1], v1, vbi) + if BT != DiagonalizingBasisProxy + vb = get_coordinates(M, pts[1], v1, BT()) + @test isa(vb, AbstractVector) + vbi = get_vector(M, pts[1], vb, BT()) + @test isapprox(M, pts[1], v1, vbi) + end b = get_basis(M, pts[1], BT()) - @test isa(b, CachedBasis{BT{ℝ},Array{Array{Float64,1},1},ℝ}) + if BT != DiagonalizingBasisProxy + if pts[1] isa Array + @test isa(b, CachedBasis{BT{ℝ},Vector{Vector{Float64}},ℝ}) + else + @test isa(b, CachedBasis{BT{ℝ},Vector{NonBroadcastBasisThing{Vector{Float64}}},ℝ}) + end + end @test get_basis(M, pts[1], b) === b N = manifold_dimension(M) @test length(get_vectors(M, pts[1], b)) == N # check orthonormality - for i in 1:N - @test norm(M, pts[1], get_vectors(M, pts[1], b)[i]) ≈ 1 - for j in i+1:N - @test inner( - M, - pts[1], - get_vectors(M, pts[1], b)[i], - get_vectors(M, pts[1], b)[j] - ) ≈ 0 + if BT isa DefaultOrthonormalBasis && pts[1] isa Vector + for i in 1:N + @test norm(M, pts[1], get_vectors(M, pts[1], b)[i]) ≈ 1 + for j in i+1:N + @test inner( + M, + pts[1], + get_vectors(M, pts[1], b)[i], + get_vectors(M, pts[1], b)[j] + ) ≈ 0 + end + end + # check that the coefficients correspond to the basis + for i in 1:N + @test inner(M, pts[1], v1, get_vectors(M, pts[1], b)[i]) ≈ vb[i] end - end - # check that the coefficients correspond to the basis - for i in 1:N - @test inner(M, pts[1], v1, get_vectors(M, pts[1], b)[i]) ≈ vb[i] end - @test get_coordinates(M, pts[1], v1, b) ≈ get_coordinates(M, pts[1], v1, BT()) - @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, BT()) + if BT != DiagonalizingBasisProxy + @test get_coordinates(M, pts[1], v1, b) ≈ get_coordinates(M, pts[1], v1, BT()) + @test get_vector(M, pts[1], vb, b) ≈ get_vector(M, pts[1], vb, BT()) + end - v1c = allocate(v1) + v1c = Vector{Float64}(undef, 3) get_coordinates!(M, v1c, pts[1], v1, b) @test v1c ≈ get_coordinates(M, pts[1], v1, b) @@ -143,7 +249,7 @@ end 2.0 3.0""" - M = ManifoldsBase.DefaultManifold(2, 3) + M = DefaultManifold(2, 3) x = collect(reshape(1.0:6.0, (2, 3))) pb = get_basis(M, x, DefaultOrthonormalBasis()) @test sprint(show, "text/plain", pb) == """ @@ -200,7 +306,7 @@ end 5.0 6.0""" - M = ManifoldsBase.DefaultManifold(1, 1, 1) + M = DefaultManifold(1, 1, 1) x = reshape(Float64[1], (1, 1, 1)) pb = get_basis(M, x, DefaultOrthonormalBasis()) @test sprint(show, "text/plain", pb) == """ diff --git a/test/runtests.jl b/test/runtests.jl index 11551836..2ba241a9 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -2,7 +2,7 @@ using Test using ManifoldsBase @testset "ManifoldsBase" begin - # TODO: decrease the number of ambiguities + # This should remain at 0 @test length(Test.detect_ambiguities(ManifoldsBase)) == 0 include("allocation.jl") include("numbers.jl") From 8ebe5c57ac7c05dd7e35a3342bf602c1c57230aa Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sun, 15 Mar 2020 15:03:53 +0100 Subject: [PATCH 21/27] extedns decoratormanifold and adds a test for complex bases. --- src/DefaultManifold.jl | 6 +++--- test/bases.jl | 11 +++++++++++ 2 files changed, 14 insertions(+), 3 deletions(-) diff --git a/src/DefaultManifold.jl b/src/DefaultManifold.jl index 1b0560d5..bc8b6727 100644 --- a/src/DefaultManifold.jl +++ b/src/DefaultManifold.jl @@ -9,8 +9,8 @@ This manifold further illustrates how to type your manifold points and tangent v that the interface does not require this, but it might be handy in debugging and educative situations to verify correctness of involved variabes. """ -struct DefaultManifold{T<:Tuple} <: Manifold where {T} end -DefaultManifold(n::Vararg{Int,N}) where {N} = DefaultManifold{Tuple{n...}}() +struct DefaultManifold{T<:Tuple, 𝔽} <: Manifold where {T, 𝔽} end +DefaultManifold(n::Vararg{Int,N}; field = ℝ) where {N} = DefaultManifold{Tuple{n...}, field}() function check_manifold_point(M::DefaultManifold, p; kwargs...) if size(p) != representation_size(M) @@ -77,7 +77,7 @@ injectivity_radius(::DefaultManifold) = Inf log!(::DefaultManifold, v, x, y) = (v .= y .- x) -@generated manifold_dimension(::DefaultManifold{T}) where {T} = *(T.parameters...) +@generated manifold_dimension(::DefaultManifold{T,𝔽}) where {T,𝔽} = *(T.parameters...)*real_dimension(𝔽) norm(::DefaultManifold, x, v) = norm(v) diff --git a/test/bases.jl b/test/bases.jl index cfa0fbdd..b34439de 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -234,6 +234,17 @@ DiagonalizingBasisProxy() = DiagonalizingOrthonormalBasis([1.0, 0.0, 0.0]) end end +@testset "Complex Cached Basis" begin + M = ManifoldsBase.DefaultManifold(3; field = ManifoldsBase.ℂ) + p = [1.0, 2.0im, 3.0] + X = [1.2, 2.2im, 2.3im] + b = [ [Matrix{Float64}(I,3,3)[:,i] for i=1:3]..., [Matrix{Float64}(I,3,3)[:,i]im for i=1:3]... ] + B = CachedBasis(DefaultOrthonormalBasis(),b,ManifoldsBase.ℂ) + a = get_coordinates(M,p,X,B) + Y = get_vector(M,p,a,B) + @test Y ≈ X +end + @testset "Basis show methods" begin @test sprint(show, DefaultBasis()) == "DefaultBasis(ℝ)" @test sprint(show, DefaultOrthogonalBasis()) == "DefaultOrthogonalBasis(ℝ)" From 444346d3ec946e57ba3e11d45ba8d768ee9c83f3 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sun, 15 Mar 2020 15:32:29 +0100 Subject: [PATCH 22/27] adds a test for :parent dispatch. --- test/decorator_manifold.jl | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/test/decorator_manifold.jl b/test/decorator_manifold.jl index 6b48a681..c9a34b1c 100644 --- a/test/decorator_manifold.jl +++ b/test/decorator_manifold.jl @@ -21,6 +21,12 @@ struct TestDecorator3{M<:Manifold} <: AbstractTestDecorator manifold::M end +abstract type AbstractParentDecorator <: AbstractDecoratorManifold end + +struct ChildDecorator{M<:Manifold} <: AbstractParentDecorator + manifold::M +end + test1(M::Manifold, p; a = 0) = 101 + a test2(M::Manifold, p; a = 0) = 102 + a test3(M::Manifold, p; a = 0) = 103 + a @@ -99,6 +105,20 @@ decorator_transparent_dispatch(::typeof(test14), M::TestDecorator, args...) = Va decorator_transparent_dispatch(::typeof(test14), M::TestDecorator2, args...) = Val(:transparent) test14(::ManifoldsBase.DefaultManifold,a) = 14*a +test15(::ManifoldsBase.DefaultManifold,a) = 15.5*a +@decorator_transparent_function function test15(M::AbstractDecoratorManifold, p) + error("Not yet implemented") +end +test15(::AbstractParentDecorator,p) = 15*p +decorator_transparent_dispatch(::typeof(test15), M::ChildDecorator, args...) = Val(:parent) + +function test16(::AbstractParentDecorator, p) + return 16*p +end +test16(::ManifoldsBase.DefaultManifold, a) = 16.5*a +@decorator_transparent_signature test16(M::AbstractDecoratorManifold, p) +decorator_transparent_dispatch(::typeof(test16), M::ChildDecorator, args...) = Val(:parent) + @testset "Testing decorator manifold functions" begin M = ManifoldsBase.DefaultManifold(3) A = ArrayManifold(M) @@ -172,4 +192,7 @@ test14(::ManifoldsBase.DefaultManifold,a) = 14*a @test_throws ErrorException test14(TestDecorator3(M),1) # :none nonexistent @test_throws ErrorException test14(TestDecorator(M),1) # not implemented @test test14(TestDecorator2(M),2) == 28 # from parent + + @test test15(ChildDecorator(M),1) == 15 + @test test16(ChildDecorator(M),1) == 16 end From e4b421b275919371e2bc7ee788b87cfdd7f613ce Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sun, 15 Mar 2020 16:18:38 +0100 Subject: [PATCH 23/27] Distinguishes between a real basis and a complex basis for complex manifolds (returning complex values for the first case and real for the second). --- src/bases.jl | 14 ++++++++++++-- test/bases.jl | 4 ++-- 2 files changed, 14 insertions(+), 4 deletions(-) diff --git a/src/bases.jl b/src/bases.jl index 2c38f798..d3cde4c8 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -151,6 +151,16 @@ const DISAMBIGUATION_BASIS_TYPES = [ VeeOrthogonalBasis, ] +function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X, B) + T = allocate_result_type(M, f, (p, X)) + return allocate(p, T, manifold_dimension(M)) +end + +function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X, B::CachedBasis) + T = allocate_result_type(M, f, (p, X)) + return allocate(p, T, length(get_vectors(M, p, B))) +end + function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X) T = allocate_result_type(M, f, (p, X)) return allocate(p, T, manifold_dimension(M)) @@ -291,7 +301,7 @@ requires either a dual basis or the cached basis to be selfdual, for example ort See also: [`get_vector`](@ref), [`get_basis`](@ref) """ function get_coordinates(M::Manifold, p, X, B::AbstractBasis) - Y = allocate_result(M, get_coordinates, p, X) + Y = allocate_result(M, get_coordinates, p, X, B) return get_coordinates!(M, Y, p, X, B) end @decorator_transparent_signature get_coordinates(M::AbstractDecoratorManifold, p, X, B::AbstractBasis) @@ -332,7 +342,7 @@ function get_coordinates!( return Y end function get_coordinates!(M::Manifold, Y, p, X, B::CachedBasis) - map!(vb -> inner(M, p, X, vb), Y, get_vectors(M, p, B)) + map!(vb -> conj(inner(M, p, X, vb)), Y, get_vectors(M, p, B)) return Y end diff --git a/test/bases.jl b/test/bases.jl index b34439de..e8b17aff 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -238,8 +238,8 @@ end M = ManifoldsBase.DefaultManifold(3; field = ManifoldsBase.ℂ) p = [1.0, 2.0im, 3.0] X = [1.2, 2.2im, 2.3im] - b = [ [Matrix{Float64}(I,3,3)[:,i] for i=1:3]..., [Matrix{Float64}(I,3,3)[:,i]im for i=1:3]... ] - B = CachedBasis(DefaultOrthonormalBasis(),b,ManifoldsBase.ℂ) + b = [Matrix{Float64}(I,3,3)[:,i] for i=1:3] + B = CachedBasis(DefaultOrthonormalBasis{ManifoldsBase.ℂ}(),b,ManifoldsBase.ℂ) a = get_coordinates(M,p,X,B) Y = get_vector(M,p,a,B) @test Y ≈ X From 93d0a3d78155642b6869a02a9384e0646395260f Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sun, 15 Mar 2020 16:27:42 +0100 Subject: [PATCH 24/27] improved get_basis with projected Gram-Schmidt method --- src/bases.jl | 20 +++++++++++++++----- test/bases.jl | 18 ++++++++++++++++++ 2 files changed, 33 insertions(+), 5 deletions(-) diff --git a/src/bases.jl b/src/bases.jl index 2c38f798..32b20036 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -240,7 +240,14 @@ function get_basis(M::Manifold, p, B::ProjectedOrthonormalBasis{:svd,ℝ}) end return CachedBasis(B, vecs) end -function get_basis(M::Manifold, p, B::ProjectedOrthonormalBasis{:gram_schmidt,ℝ}; kwargs...) +function get_basis( + M::Manifold, + p, + B::ProjectedOrthonormalBasis{:gram_schmidt,ℝ}; + warn_linearly_dependent = false, + return_incomplete_set = false, + kwargs..., +) E = [_euclidean_basis_vector(p, i) for i in eachindex(p)] N = length(E) Ξ = empty(E) @@ -254,13 +261,13 @@ function get_basis(M::Manifold, p, B::ProjectedOrthonormalBasis{:gram_schmidt, end nrmΞₙ = norm(M, p, Ξₙ) if nrmΞₙ == 0 - @warn "Input vector $(n) has length 0." + warn_linearly_dependent && @warn "Input vector $(n) has length 0." @goto skip end Ξₙ ./= nrmΞₙ for k = 1:K if !isapprox(real(inner(M, p, Ξ[k], Ξₙ)), 0; kwargs...) - @warn "Input vector $(n) is not linearly independent of output basis vector $(k)." + warn_linearly_dependent && @warn "Input vector $(n) is not linearly independent of output basis vector $(k)." @goto skip end end @@ -269,8 +276,11 @@ function get_basis(M::Manifold, p, B::ProjectedOrthonormalBasis{:gram_schmidt, K * real_dimension(number_system(B)) == dim && return CachedBasis(B, Ξ, ℝ) @label skip end - @warn "get_basis with bases $(typeof(B)) only found $(K) orthonormal basis vectors, but manifold dimension is $(dim)." - return CachedBasis(B, Ξ) + if return_incomplete_set + return CachedBasis(B, Ξ, ℝ) + else + error("get_basis with bases $(typeof(B)) only found $(K) orthonormal basis vectors, but manifold dimension is $(dim).") + end end """ diff --git a/test/bases.jl b/test/bases.jl index b34439de..0b3169a9 100644 --- a/test/bases.jl +++ b/test/bases.jl @@ -37,6 +37,15 @@ ManifoldsBase.get_vector(::ProjManifold, x, v, ::DefaultOrthonormalBasis) = reve ) === Val(:transparent) end +struct ProjectionTestManifold <: Manifold end +ManifoldsBase.inner(::ProjectionTestManifold, ::Any, X, Y) = dot(X, Y) +function ManifoldsBase.project_tangent!(::ProjectionTestManifold, Y, p, X) + Y .= X .- dot(p, X) .* p + Y[end] = 0 + return Y +end +ManifoldsBase.manifold_dimension(::ProjectionTestManifold) = 100 + @testset "Projected and arbitrary orthonormal basis" begin M = ProjManifold() x = [sqrt(2)/2 0.0 0.0; @@ -65,6 +74,15 @@ end aonb = get_basis(M, x, DefaultOrthonormalBasis()) @test size(get_vectors(M, x, aonb)) == (5,) @test get_vectors(M, x, aonb)[1] ≈ [0, 0, 0, 0, 1] + + @testset "Gram-Schmidt special cases" begin + tm = ProjectionTestManifold() + bt = ProjectedOrthonormalBasis(:gram_schmidt) + p = [sqrt(2)/2, 0.0, sqrt(2)/2, 0.0, 0.0] + @test_throws ErrorException get_basis(tm, p, bt) + b = get_basis(tm, p, bt; return_incomplete_set = true, warn_linearly_dependent = true) + @test length(get_vectors(tm, p, b)) == 3 + end end struct NonManifold <: Manifold end From 6183651634076625c0a32a4bb95cd7bb37a20389 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sun, 15 Mar 2020 16:27:57 +0100 Subject: [PATCH 25/27] removes an unnecessary case. --- src/bases.jl | 5 ----- 1 file changed, 5 deletions(-) diff --git a/src/bases.jl b/src/bases.jl index d3cde4c8..d7295a15 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -161,11 +161,6 @@ function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X, B::Cache return allocate(p, T, length(get_vectors(M, p, B))) end -function allocate_result(M::Manifold, f::typeof(get_coordinates), p, X) - T = allocate_result_type(M, f, (p, X)) - return allocate(p, T, manifold_dimension(M)) -end - @inline function allocate_result_type( M::Manifold, f::Union{typeof(get_coordinates), typeof(get_vector)}, From 8a0f8fe8d492fc35b546611af6f661ad3998ae25 Mon Sep 17 00:00:00 2001 From: Ronny Bergmann Date: Sun, 15 Mar 2020 16:32:56 +0100 Subject: [PATCH 26/27] reduce code redundancy. --- src/bases.jl | 12 ++---------- 1 file changed, 2 insertions(+), 10 deletions(-) diff --git a/src/bases.jl b/src/bases.jl index e40e4bfc..e250fe4c 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -434,18 +434,10 @@ end function get_vectors( M::Manifold, p, - B::CachedBasis{<:AbstractBasis,<:AbstractArray}, + B::CachedBasis, ) - return B.data + return _get_vectors(B) end -function get_vectors( - M::Manifold, - p, - B::CachedBasis{<:AbstractBasis,<:DiagonalizingBasisData}, -) - return B.data.vectors -end - #internal for directly cached basis i.e. those that are just arrays – used in show _get_vectors(B::CachedBasis{<:AbstractBasis,<:AbstractArray}) = B.data _get_vectors(B::CachedBasis{<:AbstractBasis,<:DiagonalizingBasisData}) = B.data.vectors From f0f3faea893e756ec089989487dd46acb33b0052 Mon Sep 17 00:00:00 2001 From: Mateusz Baran Date: Sun, 15 Mar 2020 16:37:36 +0100 Subject: [PATCH 27/27] making the complex-coefficient get_coordinates! for CachedBasis signature a bit more narrow --- src/bases.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/bases.jl b/src/bases.jl index e40e4bfc..6cd3b1d0 100644 --- a/src/bases.jl +++ b/src/bases.jl @@ -142,6 +142,7 @@ const DISAMBIGUATION_BASIS_TYPES = [ CachedBasis{<:AbstractBasis{ℝ}}, CachedBasis{<:AbstractOrthogonalBasis{ℝ}}, CachedBasis{<:AbstractOrthonormalBasis{ℝ}}, + CachedBasis{<:AbstractBasis{ℂ}}, DefaultBasis, DefaultOrthonormalBasis, DefaultOrthogonalBasis, @@ -346,7 +347,7 @@ function get_coordinates!( map!(vb -> real(inner(M, p, X, vb)), Y, get_vectors(M, p, B)) return Y end -function get_coordinates!(M::Manifold, Y, p, X, B::CachedBasis) +function get_coordinates!(M::Manifold, Y, p, X, B::CachedBasis{<:AbstractBasis{ℂ}}) map!(vb -> conj(inner(M, p, X, vb)), Y, get_vectors(M, p, B)) return Y end