From daa8340a77f9c811ddab990817a3fa97dd0a68c7 Mon Sep 17 00:00:00 2001 From: Tor Erlend Fjelde Date: Mon, 29 Mar 2021 11:53:21 +0200 Subject: [PATCH] Fix and simplification of Logit (#172) * Update logit.jl Simplified Logit implementation a bit and fixed constructor * added back _logit thanks to @devmotion * removed specialized _logit to real arguments * version bump * Update src/bijectors/logit.jl Co-authored-by: David Widmann --- Project.toml | 2 +- src/bijectors/logit.jl | 22 ++++++++-------------- 2 files changed, 9 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 76911a9f..f6646751 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "Bijectors" uuid = "76274a88-744f-5084-9051-94815aaf08c4" -version = "0.8.15" +version = "0.8.16" [deps] ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197" diff --git a/src/bijectors/logit.jl b/src/bijectors/logit.jl index 945a1b19..ff1cf554 100644 --- a/src/bijectors/logit.jl +++ b/src/bijectors/logit.jl @@ -7,9 +7,11 @@ struct Logit{N, T<:Real} <: Bijector{N} a::T b::T end -function Logit(a, b) +Logit(a::Real, b::Real) = Logit{0}(a, b) +Logit(a::AbstractArray{<:Real, N}, b::AbstractArray{<:Real, N}) where {N} = Logit{N}(a, b) +function Logit{N}(a, b) where {N} T = promote_type(typeof(a), typeof(b)) - Logit{0, T}(a, b) + Logit{N, T}(a, b) end # fields are numerical parameters @@ -25,22 +27,14 @@ up1(b::Logit{N, T}) where {N, T} = Logit{N + 1, T}(b.a, b.b) # For equality of Logit with Float64 fields to one with Duals Base.:(==)(b1::Logit, b2::Logit) = b1.a == b2.a && b1.b == b2.b -(b::Logit{0})(x::Real) = _logit(x, b.a, b.b) -(b::Logit{0})(x) = _logit.(x, b.a, b.b) -(b::Logit{1})(x::AbstractVector) = _logit.(x, b.a, b.b) -(b::Logit{1})(x::AbstractMatrix) = _logit.(x, b.a, b.b) -(b::Logit{2})(x::AbstractMatrix) = _logit.(x, b.a, b.b) -(b::Logit{2})(x::AbstractArray{<:AbstractMatrix}) = map(b, x) +(b::Logit)(x) = _logit.(x, b.a, b.b) +(b::Logit)(x::AbstractArray{<:AbstractArray}) = map(b, x) _logit(x, a, b) = logit((x - a) / (b - a)) -(ib::Inverse{<:Logit{0}})(y::Real) = _ilogit(y, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit{0}})(y) = _ilogit.(y, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit{1}})(x::AbstractVecOrMat) = _ilogit.(x, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit{2}})(x::AbstractMatrix) = _ilogit.(x, ib.orig.a, ib.orig.b) -(ib::Inverse{<:Logit{2}})(x::AbstractArray{<:AbstractMatrix}) = map(ib, x) +(ib::Inverse{<:Logit})(y) = _ilogit.(y, ib.orig.a, ib.orig.b) +(ib::Inverse{<:Logit})(x::AbstractArray{<:AbstractArray}) = map(ib, x) _ilogit(y, a, b) = (b - a) * logistic(y) + a -logabsdetjac(b::Logit{0}, x::Real) = logit_logabsdetjac(x, b.a, b.b) logabsdetjac(b::Logit{0}, x) = logit_logabsdetjac.(x, b.a, b.b) logabsdetjac(b::Logit{1}, x::AbstractVector) = sum(logit_logabsdetjac.(x, b.a, b.b)) logabsdetjac(b::Logit{1}, x::AbstractMatrix) = vec(sum(logit_logabsdetjac.(x, b.a, b.b), dims = 1))