From 9d35b638e2bd98d707d72e3034f2890a07cef615 Mon Sep 17 00:00:00 2001 From: Xianda Sun <5433119+sunxd3@users.noreply.github.com> Date: Sun, 7 Apr 2024 18:30:13 +0100 Subject: [PATCH] Fix `get` function (#93) * Fix `get` function * fix definition and tests * use `optic` instead of`.optic` * add tests for `set` * fix more test errors * fix doctest * version bump --- Project.toml | 2 +- src/varname.jl | 16 ++-------------- test/varname.jl | 42 ++++++++++++++++++++++++++++++------------ 3 files changed, 33 insertions(+), 27 deletions(-) diff --git a/Project.toml b/Project.toml index d5fe1a2..37cac96 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "7a57a42e-76ec-4ea3-a279-07e840d6d9cf" keywords = ["probablistic programming"] license = "MIT" desc = "Common interfaces for probabilistic programming" -version = "0.8.0" +version = "0.8.1" [deps] AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001" diff --git a/src/varname.jl b/src/varname.jl index adcd568..83d58ce 100644 --- a/src/varname.jl +++ b/src/varname.jl @@ -122,23 +122,11 @@ getoptic(vn::VarName) = vn.optic """ get(obj, vn::VarName{sym}) -Alias for `getoptic(vn)(obj)`. - -# Example - -```jldoctest; setup = :(nt = (a = 1, b = (c = [1, 2, 3],)); name = :nt) -julia> get(nt, @varname(nt.a)) -1 - -julia> get(nt, @varname(nt.b.c[1])) -1 - -julia> get(nt, @varname(\$name.b.c[1])) -1 +Alias for `(PropertyLens{sym}() ⨟ getoptic(vn))(obj)`. ``` """ function Base.get(obj, vn::VarName{sym}) where {sym} - return getoptic(vn)(obj) + return (PropertyLens{sym}() ⨟ getoptic(vn))(obj) end """ diff --git a/test/varname.jl b/test/varname.jl index b117f0e..6488c61 100644 --- a/test/varname.jl +++ b/test/varname.jl @@ -14,6 +14,19 @@ macro test_strict_subsumption(x, y) end end +function test_equal(o1::VarName{sym1}, o2::VarName{sym2}) where {sym1, sym2} + return sym1 === sym2 && test_equal(o1.optic, o2.optic) +end +function test_equal(o1::ComposedFunction, o2::ComposedFunction) + return test_equal(o1.inner, o2.inner) && test_equal(o1.outer, o2.outer) +end +function test_equal(o1::Accessors.IndexLens, o2::Accessors.IndexLens) + return test_equal(o1.indices, o2.indices) +end +function test_equal(o1, o2) + return o1 == o2 +end + @testset "varnames" begin @testset "construction & concretization" begin i = 1:10 @@ -27,14 +40,22 @@ end # concretization y = zeros(10, 10) - x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], ); + x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0],); @test @varname(y[begin, i], true) == @varname(y[1, 1:10]) - @test get(y, @varname(y[:], true)) == get(y, @varname(y[1:100])) - @test get(y, @varname(y[:, begin], true)) == get(y, @varname(y[1:10, 1])) - @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === + @test test_equal(@varname(y[:], true), @varname(y[1:100])) + @test test_equal(@varname(y[:, begin], true), @varname(y[1:10, 1])) + @test getoptic(AbstractPPL.concretize(@varname(y[:]), y)).indices[1] === AbstractPPL.ConcretizedSlice(to_indices(y, (:,))[1]) - @test get(x, @varname(x.a[1:end, end][:], true)) == get(x, @varname(x.a[1:3,2][1:3])) + @test test_equal(@varname(x.a[1:end, end][:], true), @varname(x.a[1:3,2][1:3])) + end + + @testset "get & set" begin + x = (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 1.0); + @test get(x, @varname(a[1, 2])) == 2.0 + @test get(x, @varname(b)) == 1.0 + @test set(x, @varname(a[1, 2]), 10) == (a = [1.0 10.0; 3.0 4.0; 5.0 6.0], b = 1.0) + @test set(x, @varname(b), 10) == (a = [1.0 2.0; 3.0 4.0; 5.0 6.0], b = 10.0) end @testset "subsumption with standard indexing" begin @@ -83,10 +104,10 @@ end @testset "non-standard indexing" begin A = rand(10, 10) - @test get(A, @varname(A[1, Not(3)], true)) == get(A, @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]])) + @test test_equal(@varname(A[1, Not(3)], true), @varname(A[1, [1, 2, 4, 5, 6, 7, 8, 9, 10]])) B = OffsetArray(A, -5, -5) # indices -4:5×-4:5 - @test collect(get(B, @varname(B[1, :], true))) == collect(get(B, @varname(B[1, -4:5]))) + @test test_equal(@varname(B[1, :], true), @varname(B[1, -4:5])) end @testset "type stability" begin @@ -96,15 +117,12 @@ end @inferred VarName{:a}(PropertyLens(:b)) @inferred VarName{:a}(Accessors.opcompose(IndexLens(1), PropertyLens(:b))) - a = [1, 2, 3] - @inferred get(a, @varname(a[1])) - b = (a=[1, 2, 3],) - @inferred get(b, @varname(b.a[1])) + @inferred get(b, @varname(a[1])) @inferred Accessors.set(b, @varname(a[1]), 10) c = (b=(a=[1, 2, 3],),) - @inferred get(c, @varname(c.b.a[1])) + @inferred get(c, @varname(b.a[1])) @inferred Accessors.set(c, @varname(b.a[1]), 10) end end