From 5a90e058831f1e3cf330f29434d02a3ae4b4e03c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Jun 2023 21:22:42 +0200 Subject: [PATCH 1/2] Add CDFBijector and QuantileBijector --- src/bijectors/cdf_quantile.jl | 91 +++++++++++++++++++++++++++++++++++ src/interface.jl | 1 + 2 files changed, 92 insertions(+) create mode 100644 src/bijectors/cdf_quantile.jl diff --git a/src/bijectors/cdf_quantile.jl b/src/bijectors/cdf_quantile.jl new file mode 100644 index 00000000..17ac6831 --- /dev/null +++ b/src/bijectors/cdf_quantile.jl @@ -0,0 +1,91 @@ +""" + CDFBijector(dist::Distributions.ContinuousUnivariateDistribution) + +A [`Bijector`](@ref) that transforms the input from the support of the given distribution to +the unit interval using the cumulative distribution function of the distribution. + +The inverse is [`QuantileBijector`](@ref). + +# Example + +```jldoctest +julia> using Bijectors: CDFBijector + +julia> using Distributions: Normal + +julia> b = CDFBijector(Normal()); + +julia> p = [0.1, 0.5, 0.9]; + +julia> transform(b, quantile.(Normal(), p)) ≈ p +true +``` +""" +struct CDFBijector{D<:ContinuousUnivariateDistribution} <: Bijector + dist::D +end + +Base.:(==)(b1::CDFBijector, b2::CDFBijector) = b1.dist == b2.dist + +Functors.@functor CDFBijector + +function Base.show(io::IO, b::CDFBijector) + print(io, "CDFBijector(") + print(io, b.dist) + print(io, ")") + return nothing +end + +with_logabsdet_jacobian(b::CDFBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(b::CDFBijector, x) = Distributions.cdf.(b.dist, x) + +logabsdetjac(b::CDFBijector, x) = Distributions.logpdf.(b.dist, x) + + +""" + QuantileBijector(dist::Distributions.ContinuousUnivariateDistribution) + +A [`Bijector`](@ref) that transforms the input from the unit interval to the support of the +given distribution using the quantile function of the distribution. + +The inverse is [`CDFBijector`](@ref). + +# Example + +```jldoctest +julia> using Bijectors: QuantileBijector + +julia> using Distributions: Gamma + +julia> b = QuantileBijector(Gamma()); + +julia> p = [0.1, 0.5, 0.9]; + +julia> transform(b, p) ≈ quantile.(Gamma(), p) +true +``` +""" +struct QuantileBijector{D<:ContinuousUnivariateDistribution} <: Bijector + dist::D +end + +Base.:(==)(b1::QuantileBijector, b2::QuantileBijector) = b1.dist == b2.dist + +Functors.@functor QuantileBijector + +function Base.show(io::IO, b::QuantileBijector) + print(io, "QuantileBijector(") + print(io, b.dist) + print(io, ")") + return nothing +end + +with_logabsdet_jacobian(b::QuantileBijector, x) = transform(b, x), logabsdetjac(b, x) + +transform(b::QuantileBijector, x) = Distributions.quantile.(b.dist, x) + +logabsdetjac(b::QuantileBijector, x) = @. -Distributions.logpdf(b.dist, x) + +inverse(b::CDFBijector) = QuantileBijector(b.dist) +inverse(b::QuantileBijector) = CDFBijector(b.dist) diff --git a/src/interface.jl b/src/interface.jl index 099df1bb..7c5a9fea 100644 --- a/src/interface.jl +++ b/src/interface.jl @@ -249,6 +249,7 @@ include("bijectors/corr.jl") include("bijectors/truncated.jl") include("bijectors/named_bijector.jl") include("bijectors/ordered.jl") +include("bijectors/cdf_quantile.jl") # Normalizing flow related include("bijectors/planar_layer.jl") From 24277f463b627f10e1cf6eac77e4a421e17ce9b8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 19 Jun 2023 21:30:30 +0200 Subject: [PATCH 2/2] Update src/bijectors/cdf_quantile.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/bijectors/cdf_quantile.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/bijectors/cdf_quantile.jl b/src/bijectors/cdf_quantile.jl index 17ac6831..0697458d 100644 --- a/src/bijectors/cdf_quantile.jl +++ b/src/bijectors/cdf_quantile.jl @@ -42,7 +42,6 @@ transform(b::CDFBijector, x) = Distributions.cdf.(b.dist, x) logabsdetjac(b::CDFBijector, x) = Distributions.logpdf.(b.dist, x) - """ QuantileBijector(dist::Distributions.ContinuousUnivariateDistribution)