Skip to content

Commit

Permalink
Merge pull request #26 from JuliaNLSolvers/mbaran/basis-improvements
Browse files Browse the repository at this point in the history
Moving Bases to Base
  • Loading branch information
mateuszbaran authored Mar 15, 2020
2 parents 54b6c86 + 526cfab commit b58ebf8
Show file tree
Hide file tree
Showing 14 changed files with 1,544 additions and 271 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "ManifoldsBase"
uuid = "3362f125-f0bb-47a3-aa74-596ffd7ef2fb"
authors = ["Seth Axen <[email protected]>", "Mateusz Baran <[email protected]>", "Ronny Bergmann <[email protected]>", "Antoine Levitt <[email protected]>"]
version = "0.5.2"
version = "0.6"

[deps]
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
Expand Down
415 changes: 270 additions & 145 deletions src/ArrayManifold.jl

Large diffs are not rendered by default.

104 changes: 56 additions & 48 deletions src/DecoratorManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ function _split_signature(sig::Expr)
argnames = argnames,
argtypes = argtypes,
kwargs_call = kwargs_call,
fname__parent = Symbol(string(fname) * "__parent"),
fname__transparent = Symbol(string(fname) * "__transparent"),
fname__intransparent = Symbol(string(fname) * "__intransparent"),
)
end

Expand Down Expand Up @@ -146,12 +149,11 @@ macro decorator_transparent_fallback(fallback_case, input_ex)
parts = _split_function(ex)
callargs = parts[:callargs]
where_exprs = parts[:where_exprs]
fname_fallback = Symbol(string(parts.fname) * "__" * string(fallback_case)[2:end])
return esc(
quote
function ($(parts[:fname]))(
$(callargs[1]),
::Val{$fallback_case},
$(callargs[2:end]...);
function ($(fname_fallback))(
$(callargs...);
$(parts[:kwargs_list]...),
) where {$(where_exprs...)}
($(parts[:body]))
Expand Down Expand Up @@ -213,6 +215,7 @@ macro decorator_transparent_function(fallback_case, input_ex)
argnames = parts[:argnames]
argtypes = parts[:argtypes]
kwargs_call = parts[:kwargs_call]
fname_fallback = Symbol(string(parts.fname) * "__" * string(fallback_case)[2:end])

return esc(
quote
Expand All @@ -221,29 +224,30 @@ macro decorator_transparent_function(fallback_case, input_ex)
$(callargs[2:end]...);
$(kwargs_list...),
) where {$(where_exprs...)}
return ($fname)(
$(argnames[1]),
ManifoldsBase._acts_transparently($fname, $(argnames...)),
$(argnames[2:end]...),
;
$(kwargs_call...),
)
transparency = ManifoldsBase._acts_transparently($fname, $(argnames...))
if transparency === Val(:parent)
return ($(parts.fname__parent))($(argnames...); $(kwargs_call...))
elseif transparency === Val(:transparent)
return ($(parts.fname__transparent))($(argnames...); $(kwargs_call...))
elseif transparency === Val(:intransparent)
return ($(parts.fname__intransparent))($(argnames...); $(kwargs_call...))
else
error("incorrect transparency: $transparency")
end
end
function ($fname)(
function ($(parts[:fname__transparent]))(
$(argnames[1])::AbstractDecoratorManifold,
::Val{:transparent},
$(callargs[2:end]...);
$(kwargs_list...),
) where {$(where_exprs...)}
return ($fname)(
decorated_manifold($(argnames[1])),
ManifoldsBase.decorated_manifold($(argnames[1])),
$(argnames[2:end]...);
$(kwargs_call...),
)
end
function ($fname)(
function ($(parts[:fname__intransparent]))(
$(argnames[1])::AbstractDecoratorManifold,
::Val{:intransparent},
$(callargs[2:end]...);
$(kwargs_list...),
) where {$(where_exprs...)}
Expand All @@ -270,9 +274,8 @@ macro decorator_transparent_function(fallback_case, input_ex)
". Maybe you missed to implement this function for a default?",
))
end
function ($fname)(
function ($(parts[:fname__parent]))(
$(argnames[1])::AbstractDecoratorManifold,
::Val{:parent},
$(callargs[2:end]...);
$(kwargs_list...),
) where {$(where_exprs...)}
Expand All @@ -283,9 +286,8 @@ macro decorator_transparent_function(fallback_case, input_ex)
$(kwargs_call...),
)
end
function ($fname)(
function ($fname_fallback)(
$(callargs[1]),
::Val{$fallback_case},
$(callargs[2:end]...);
$(kwargs_list...),
) where {$(where_exprs...)}
Expand Down Expand Up @@ -345,29 +347,29 @@ macro decorator_transparent_signature(ex)
return esc(
quote
function ($fname)($(callargs...); $(kwargs_list...)) where {$(where_exprs...)}
return ($fname)(
$(argnames[1]),
ManifoldsBase._acts_transparently($fname, $(argnames...)),
$(argnames[2:end]...);
$(kwargs_call...),
)
transparency = ManifoldsBase._acts_transparently($fname, $(argnames...))
if transparency === Val(:parent)
return ($(parts.fname__parent))($(argnames...); $(kwargs_call...))
elseif transparency === Val(:transparent)
return ($(parts.fname__transparent))($(argnames...); $(kwargs_call...))
elseif transparency === Val(:intransparent)
return ($(parts.fname__intransparent))($(argnames...); $(kwargs_call...))
else
error("incorrect transparency: $transparency")
end
end
function ($fname)(
$(callargs[1]),
::Val{:transparent},
$(callargs[2:end]...);
function ($(parts[:fname__transparent]))(
$(callargs...);
$(kwargs_list...),
) where {$(where_exprs...)}
return ($fname)(
decorated_manifold($(argnames[1])),
ManifoldsBase.decorated_manifold($(argnames[1])),
$(argnames[2:end]...);
$(kwargs_call...),
)
end
function ($fname)(
$(callargs[1]),
::Val{:intransparent},
$(callargs[2:end]...);
function ($(parts[:fname__intransparent]))(
$(callargs...);
$(kwargs_list...),
) where {$(where_exprs...)}
error_msg = ManifoldsBase.manifold_function_not_implemented_message(
Expand All @@ -377,10 +379,8 @@ macro decorator_transparent_signature(ex)
)
error(error_msg)
end
function ($fname)(
$(callargs[1]),
::Val{:parent},
$(callargs[2:end]...);
function ($(parts[:fname__parent]))(
$(callargs...);
$(kwargs_list...),
) where {$(where_exprs...)}
return invoke(
Expand Down Expand Up @@ -495,28 +495,32 @@ Return the manifold decorated by the decorator `M`. Defaults to `M.manifold`.
"""
decorated_manifold(M::Manifold) = M.manifold


@decorator_transparent_signature distance(M::AbstractDecoratorManifold, p, q)

@decorator_transparent_signature exp(M::AbstractDecoratorManifold, p, X)

@decorator_transparent_signature exp!(M::AbstractDecoratorManifold, q, p, X)

@decorator_transparent_signature hat(M::AbstractDecoratorManifold, p, Xⁱ)

@decorator_transparent_signature hat!(M::AbstractDecoratorManifold, X, p, Xⁱ)

@decorator_transparent_signature injectivity_radius(M::AbstractDecoratorManifold)
@decorator_transparent_signature injectivity_radius(M::AbstractDecoratorManifold, p)
@decorator_transparent_signature injectivity_radius(
M::AbstractDecoratorManifold,
m::AbstractRetractionMethod,
)
@decorator_transparent_signature injectivity_radius(
M::AbstractDecoratorManifold,
m::ExponentialRetraction,
)
@decorator_transparent_signature injectivity_radius(
M::AbstractDecoratorManifold,
p,
m::AbstractRetractionMethod,
)
@decorator_transparent_signature injectivity_radius(
M::AbstractDecoratorManifold,
p,
m::ExponentialRetraction,
)

@decorator_transparent_signature inner(M::AbstractDecoratorManifold, p, X, Y)

Expand Down Expand Up @@ -640,9 +644,13 @@ decorated_manifold(M::Manifold) = M.manifold
q,
m::AbstractVectorTransportMethod,
)

@decorator_transparent_signature vee!(M::AbstractDecoratorManifold, Xⁱ, p, X)

@decorator_transparent_signature vee(M::AbstractDecoratorManifold, p, X)
@decorator_transparent_signature vector_transport_to!(
M::AbstractDecoratorManifold,
Y,
p,
X,
q,
m::ProjectionTransport,
)

@decorator_transparent_signature zero_tangent_vector!(M::AbstractDecoratorManifold, X, p)
63 changes: 57 additions & 6 deletions src/DefaultManifold.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,23 +9,76 @@ This manifold further illustrates how to type your manifold points and tangent v
that the interface does not require this, but it might be handy in debugging and educative
situations to verify correctness of involved variabes.
"""
struct DefaultManifold{T<:Tuple} <: Manifold where {T} end
DefaultManifold(n::Vararg{Int,N}) where {N} = DefaultManifold{Tuple{n...}}()
struct DefaultManifold{T<:Tuple, 𝔽} <: Manifold where {T, 𝔽} end
DefaultManifold(n::Vararg{Int,N}; field = ℝ) where {N} = DefaultManifold{Tuple{n...}, field}()

function check_manifold_point(M::DefaultManifold, p; kwargs...)
if size(p) != representation_size(M)
return DomainError(
size(p),
"The point $(p) does not lie on $M, since its size is not $(representation_size(M)).",
)
end
return nothing
end

function check_tangent_vector(
M::DefaultManifold,
p,
X;
check_base_point = true,
kwargs...,
)
if check_base_point
perr = check_manifold_point(M, p)
perr === nothing || return perr
end
if size(X) != representation_size(M)
return DomainError(
size(X),
"The vector $(X) is not a tangent to a point on $M since its size does not match $(representation_size(M)).",
)
end
return nothing
end

distance(::DefaultManifold, x, y) = norm(x - y)

exp!(::DefaultManifold, y, x, v) = (y .= x .+ v)

hat!(M::DefaultManifold, X, p, Xⁱ) = copyto!(X, reshape(Xⁱ, representation_size(M)))
function get_basis(M::DefaultManifold, p, B::DefaultOrthonormalBasis)
return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)])
end
function get_basis(M::DefaultManifold, p, B::DefaultOrthogonalBasis)
return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)])
end
function get_basis(M::DefaultManifold, p, B::DefaultBasis)
return CachedBasis(B, [_euclidean_basis_vector(p, i) for i in eachindex(p)])
end
function get_basis(M::DefaultManifold, p, B::DiagonalizingOrthonormalBasis)
vecs = get_vectors(M, p, get_basis(M, p, DefaultOrthonormalBasis()))
eigenvalues = zeros(real(eltype(p)), manifold_dimension(M))
return CachedBasis(B, DiagonalizingBasisData(B.frame_direction, eigenvalues, vecs))
end

function get_coordinates!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis)
copyto!(Y, reshape(X, manifold_dimension(M)))
return Y
end

@generated manifold_dimension(::DefaultManifold{T}) where {T} = *(T.parameters...)
function get_vector!(M::DefaultManifold, Y, p, X, B::DefaultOrthonormalBasis)
copyto!(Y, reshape(X, representation_size(M)))
return Y
end

injectivity_radius(::DefaultManifold) = Inf

@inline inner(::DefaultManifold, x, v, w) = dot(v, w)

log!(::DefaultManifold, v, x, y) = (v .= y .- x)

@generated manifold_dimension(::DefaultManifold{T,𝔽}) where {T,𝔽} = *(T.parameters...)*real_dimension(𝔽)

norm(::DefaultManifold, x, v) = norm(v)

project_point!(::DefaultManifold, y, x) = copyto!(y, x)
Expand All @@ -49,6 +102,4 @@ function vector_transport_to!(::DefaultManifold, vto, x, v, y, ::ParallelTranspo
return copyto!(vto, v)
end

vee!(M::DefaultManifold, Xⁱ, p, X) = copyto!(Xⁱ, reshape(X, manifold_dimension(M)))

zero_tangent_vector!(::DefaultManifold, v, x) = fill!(v, 0)
Loading

2 comments on commit b58ebf8

@mateuszbaran
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/11004

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if Julia TagBot is installed, or can be done manually through the github interface, or via:

git tag -a v0.6.0 -m "<description of version>" b58ebf8732960085af4b2311221fd6965e641fa2
git push origin v0.6.0

Please sign in to comment.