Skip to content

Commit

Permalink
Merge pull request #29 from TuringLang/mt/remove_broadcast
Browse files Browse the repository at this point in the history
broadcast -> map
  • Loading branch information
mohamed82008 authored Jul 21, 2019
2 parents ed45d7c + 4b53a87 commit 57241d2
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions src/Bijectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,18 +366,18 @@ end
using Distributions: UnivariateDistribution

link(d::UnivariateDistribution, x::Real) = x
link(d::UnivariateDistribution, x::AbstractVector{<:Real}) = link.(Ref(d), x)
link(d::UnivariateDistribution, x::AbstractVector{<:Real}) = map(x -> link(d, x), x)

invlink(d::UnivariateDistribution, y::Real) = y
invlink(d::UnivariateDistribution, y::AbstractVector{<:Real}) = invlink.(Ref(d), y)
invlink(d::UnivariateDistribution, y::AbstractVector{<:Real}) = map(y -> invlink(d, y), y)

logpdf_with_trans(d::UnivariateDistribution, x::Real, ::Bool) = logpdf(d, x)
function logpdf_with_trans(
d::UnivariateDistribution,
x::AbstractVector{<:Real},
transform::Bool,
)
return logpdf_with_trans.(Ref(d), x, transform)
return map(x -> logpdf_with_trans(d, x, transform), x)
end

# MultivariateDistributions
Expand All @@ -404,11 +404,11 @@ end
using Distributions: MatrixDistribution

link(d::MatrixDistribution, X::AbstractMatrix{<:Real}) = copy(X)
link(d::MatrixDistribution, X::AbstractVector{<:AbstractMatrix{<:Real}}) = link.(Ref(d), X)
link(d::MatrixDistribution, X::AbstractVector{<:AbstractMatrix{<:Real}}) = map(x -> link(d, x), X)

invlink(d::MatrixDistribution, Y::AbstractMatrix{<:Real}) = copy(Y)
function invlink(d::MatrixDistribution, Y::AbstractVector{<:AbstractMatrix{<:Real}})
return invlink.(Ref(d), Y)
return map(y -> invlink(d, y), Y)
end

logpdf_with_trans(d::MatrixDistribution, X::AbstractMatrix{<:Real}, ::Bool) = logpdf(d, X)
Expand All @@ -417,7 +417,7 @@ function logpdf_with_trans(
X::AbstractVector{<:AbstractMatrix{<:Real}},
transform::Bool,
)
return logpdf_with_trans.(Ref(d), X, Ref(transform))
return map(x -> logpdf_with_trans(d, x, transform), X)
end

end # module

0 comments on commit 57241d2

Please sign in to comment.