Skip to content

Commit

Permalink
Handle empty brackets in find_alpha (#205)
Browse files Browse the repository at this point in the history
* Handle empty brackets in `find_alpha`

* Bump version
  • Loading branch information
devmotion authored Oct 17, 2021
1 parent 0cc45d3 commit c50ce62
Show file tree
Hide file tree
Showing 3 changed files with 28 additions and 3 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.9.9"
version = "0.9.10"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
11 changes: 9 additions & 2 deletions src/bijectors/planar_layer.jl
Original file line number Diff line number Diff line change
Expand Up @@ -155,10 +155,17 @@ function find_alpha(wt_y::Real, wt_u_hat::Real, b::Real)
end
function find_alpha(wt_y::T, wt_u_hat::T, b::T) where {T<:Real}
# Compute the initial bracket (see above).
initial_bracket = (wt_y - abs(wt_u_hat), wt_y + abs(wt_u_hat))
abs_wt_u_hat = abs(wt_u_hat)
lower = float(wt_y - abs_wt_u_hat)
upper = float(wt_y + abs_wt_u_hat)

# Handle empty brackets (https://github.com/TuringLang/Bijectors.jl/issues/204)
if lower == upper
return lower
end

# Solve the root-finding problem
α0 = Roots.find_zero(initial_bracket) do α
α0 = Roots.find_zero((lower, upper)) do α
return α + wt_u_hat * tanh+ b) - wt_y
end

Expand Down
18 changes: 18 additions & 0 deletions test/norm_flows.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,24 @@ end

test_functor(flow, (w = w, u = u, b = b))
test_functor(inv(flow), (orig = flow,))

@testset "find_alpha" begin
for wt_y in (-20.3, -3, -3//2, 0.0, 5, 29//4, 12.3)
# the root finding algorithm assumes wt_u_hat ≥ -1 (satisfied for the flow)
# |wt_u_hat| < eps checks that empty brackets are handled correctly
# https://github.com/TuringLang/Bijectors.jl/issues/204
for wt_u_hat in (-1, -1//2, -1e-20, 0, 1e-20, 3, 11//3, 17.2)
for b in (-19.3, -8//3, -1, 0.0, 1//2, 3, 4.3)
# find α that solves wt_y = α + wt_u_hat * tanh(α + b)
α = @inferred(Bijectors.find_alpha(wt_y, wt_u_hat, b))

# check if α is an approximate solution to the considered equation
# have to set atol if wt_y is zero (otherwise only equality is checked)
@test wt_y α + wt_u_hat * tanh+ b) atol=iszero(wt_y) ? 1e-14 : 0.0
end
end
end
end
end

@testset "RadialLayer" begin
Expand Down

2 comments on commit c50ce62

@devmotion
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/46913

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.9.10 -m "<description of version>" c50ce621bbbd3d2edc81ae87907dfb4e87059dbc
git push origin v0.9.10

Please sign in to comment.