Skip to content

Commit

Permalink
fix rng argument position in tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Red-Portal committed Oct 24, 2023
1 parent 7a92708 commit 5dd434d
Show file tree
Hide file tree
Showing 5 changed files with 9 additions and 9 deletions.
2 changes: 1 addition & 1 deletion test/inference/advi_distributionsad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ using Test
seed = (0x38bef07cf9cc549d)
rng = StableRNG(seed)

modelstats = modelconstr(realtype; rng)
modelstats = modelconstr(rng, realtype)
@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

T, η = is_meanfield ? (5_000, 1e-2) : (30_000, 1e-3)
Expand Down
4 changes: 2 additions & 2 deletions test/interface/advi.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Test
rng = StableRNG(seed)

@testset "with bijector" begin
modelstats = normallognormal_meanfield(Float64; rng)
modelstats = normallognormal_meanfield(rng, Float64)

@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

Expand All @@ -31,7 +31,7 @@ using Test
end

@testset "without bijector" begin
modelstats = normal_meanfield(Float64; rng)
modelstats = normal_meanfield(rng, Float64)

@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

Expand Down
2 changes: 1 addition & 1 deletion test/interface/optimize.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ using Test
rng = StableRNG(seed)

T = 1000
modelstats = normallognormal_meanfield(Float64; rng)
modelstats = normallognormal_meanfield(rng, Float64)

@unpack model, μ_true, L_true, n_dims, is_meanfield = modelstats

Expand Down
4 changes: 2 additions & 2 deletions test/models/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ function LogDensityProblems.capabilities(::Type{<:TestNormal})
LogDensityProblems.LogDensityOrder{0}()
end

function normal_fullrank(realtype; rng = default_rng())
function normal_fullrank(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ = randn(rng, realtype, n_dims)
Expand All @@ -29,7 +29,7 @@ function normal_fullrank(realtype; rng = default_rng())
TestModel(model, μ, L, n_dims, false)
end

function normal_meanfield(realtype; rng = default_rng())
function normal_meanfield(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ = randn(rng, realtype, n_dims)
Expand Down
6 changes: 3 additions & 3 deletions test/models/normallognormal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ function Bijectors.bijector(model::NormalLogNormal)
[1:1, 2:1+length(μ_y)])
end

function normallognormal_fullrank(realtype; rng = default_rng())
function normallognormal_fullrank(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ_x = randn(rng, realtype)
Expand All @@ -43,12 +43,12 @@ function normallognormal_fullrank(realtype; rng = default_rng())
Σ = Σ |> Hermitian

μ = vcat(μ_x, μ_y)
L = cholesky(Σ).L |> LowerTriangular
L = cholesky(Σ).L

TestModel(model, μ, L, n_dims+1, false)
end

function normallognormal_meanfield(realtype; rng = default_rng())
function normallognormal_meanfield(rng::Random.AbstractRNG, realtype::Type)
n_dims = 5

μ_x = randn(rng, realtype)
Expand Down

0 comments on commit 5dd434d

Please sign in to comment.