Skip to content

Commit

Permalink
making rqs compatible with float32 input (#267)
Browse files Browse the repository at this point in the history
* making rqs compatible with float32 input

* minor format edit

* Update Format.yml

* fix format

* Update Format.yml

* Update Format.yml

* save additional allocations in rqs layer

Co-authored-by: David Widmann <[email protected]>

* rm allocations in rqs layer

Co-authored-by: David Widmann <[email protected]>

* bump version to 0.12.6

* add tests for rqs

---------

Co-authored-by: Tor Erlend Fjelde <[email protected]>
Co-authored-by: David Widmann <[email protected]>
  • Loading branch information
3 people authored Jun 10, 2023
1 parent dd8a24b commit 24fa396
Show file tree
Hide file tree
Showing 4 changed files with 49 additions and 5 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/Format.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,4 +35,4 @@ jobs:
if: github.event_name == 'pull_request'
with:
tool_name: JuliaFormatter
fail_on_error: true
fail_on_error: true
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "Bijectors"
uuid = "76274a88-744f-5084-9051-94815aaf08c4"
version = "0.12.5"
version = "0.12.6"

[deps]
ArgCheck = "dce04be8-c92d-5529-be00-80e4d2c0e197"
Expand Down
6 changes: 3 additions & 3 deletions src/bijectors/rational_quadratic_spline.jl
Original file line number Diff line number Diff line change
Expand Up @@ -100,8 +100,8 @@ function RationalQuadraticSpline(
widths::A, heights::A, derivatives::A, B::T2
) where {T1,T2,A<:AbstractVector{T1}}
return RationalQuadraticSpline(
(cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .- 0.5) * 2 * B,
(cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .- 0.5) * 2 * B,
cumsum(vcat([zero(T1)], LogExpFunctions.softmax(widths))) .* (2 * B) .- B,
cumsum(vcat([zero(T1)], LogExpFunctions.softmax(heights))) .* (2 * B) .- B,
vcat([one(T1)], LogExpFunctions.log1pexp.(derivatives), [one(T1)]),
)
end
Expand All @@ -118,7 +118,7 @@ function RationalQuadraticSpline(
)

return RationalQuadraticSpline(
(2 * B) .* (cumsum(ws; dims=2) .- 0.5), (2 * B) .* (cumsum(hs; dims=2) .- 0.5), ds
(2 * B) .* cumsum(ws; dims=2) .- B, (2 * B) .* cumsum(hs; dims=2) .- B, ds
)
end

Expand Down
44 changes: 44 additions & 0 deletions test/bijectors/rational_quadratic_spline.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Test
using Bijectors
using Bijectors: RationalQuadraticSpline
using LogExpFunctions

@testset "RationalQuadraticSpline" begin
# Monotonic spline on '[-B, B]' with `K` intermediate knots/"connection points".
Expand Down Expand Up @@ -59,4 +60,47 @@ using Bijectors: RationalQuadraticSpline
x = [-5.0, 5.0]
test_bijector(b, x; y=x, logjac=zero(eltype(x)))
end

@testset "Float32 support" begin
ws = randn(Float32, K)
hs = randn(Float32, K)
ds = randn(Float32, K - 1)

Ws = randn(Float32, d, K)
Hs = randn(Float32, d, K)
Ds = randn(Float32, d, K - 1)

# success of construction
b = RationalQuadraticSpline(ws, hs, ds, B)
bb = RationalQuadraticSpline(Ws, Hs, Ds, B)
end

@testset "consistency after commit" begin
ws = randn(K)
hs = randn(K)
ds = randn(K - 1)

Ws = randn(d, K)
Hs = randn(d, K)
Ds = randn(d, K - 1)

Ws_t = hcat(zeros(size(Ws, 1)), LogExpFunctions.softmax(Ws; dims=2))
Hs_t = hcat(zeros(size(Ws, 1)), LogExpFunctions.softmax(Hs; dims=2))

# success of construction
b = RationalQuadraticSpline(ws, hs, ds, B)
b_mv = RationalQuadraticSpline(Ws, Hs, Ds, B)

# consistency of evaluation
@test all(
(cumsum(vcat([zero(Float64)], LogExpFunctions.softmax(ws))) .- 0.5) * 2 * B .≈
b.widths,
)
@test all(
(cumsum(vcat([zero(Float64)], LogExpFunctions.softmax(hs))) .- 0.5) * 2 * B .≈
b.heights,
)
@test all((2 * B) .* (cumsum(Ws_t; dims=2) .- 0.5) .≈ b_mv.widths)
@test all((2 * B) .* (cumsum(Hs_t; dims=2) .- 0.5) .≈ b_mv.heights)
end
end

2 comments on commit 24fa396

@yebai
Copy link
Member

@yebai yebai commented on 24fa396 Jun 10, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/85252

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.12.6 -m "<description of version>" 24fa396f44084dae0f9751d740a8b01d4862a893
git push origin v0.12.6

Please sign in to comment.