Skip to content

Commit

Permalink
Merge remote-tracking branch 'origin/main' into pc-eval-#48
Browse files Browse the repository at this point in the history
  • Loading branch information
PavanChaggar committed Apr 17, 2022
2 parents da00722 + ed16f90 commit bf030c3
Show file tree
Hide file tree
Showing 2 changed files with 22 additions and 15 deletions.
14 changes: 7 additions & 7 deletions src/graphinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,11 +23,11 @@ adjacency matrix and topologically ordered vertex list and stored.
GraphInfo is instantiated using the `Model` constctor.
"""

struct GraphInfo{T} <: AbstractModelTrace
input::NamedTuple{T}
value::NamedTuple{T}
eval::NamedTuple{T}
kind::NamedTuple{T}
struct GraphInfo{Tnames, Tinput, Tvalue, Teval, Tkind} <: AbstractModelTrace
input::NamedTuple{Tnames, Tinput}
value::NamedTuple{Tnames, Tvalue}
eval::NamedTuple{Tnames, Teval}
kind::NamedTuple{Tnames, Tkind}
A::SparseMatrixCSC
sorted_vertices::Vector{Symbol}
end
Expand Down Expand Up @@ -58,8 +58,8 @@ y = (value = 0.0, input = (:μ, :s2), eval = var"#7#10"(), kind = :Stochastic)
```
"""

struct Model{T} <: AbstractProbabilisticProgram
g::GraphInfo{T}
struct Model{Tnames, Tinput, Tvalue, Teval, Tkind} <: AbstractProbabilisticProgram
g::GraphInfo{Tnames, Tinput, Tvalue, Teval, Tkind}
end

function Model(;kwargs...)
Expand Down
23 changes: 15 additions & 8 deletions test/graphinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,11 @@ model = (
m = Model(; zip(keys(model), values(model))...) # uses Model(; kwargs...) constructor

# test the type of the model is correct
@test typeof(m) <: Model
@test m isa Model
sorted_vertices = get_sorted_vertices(m)
@test typeof(m) == Model{Tuple(sorted_vertices)}
@test typeof(m.g) <: GraphInfo <: AbstractModelTrace
@test typeof(m.g) == GraphInfo{Tuple(sorted_vertices)}
@test m isa Model{Tuple(sorted_vertices)}
@test m.g isa GraphInfo <: AbstractModelTrace
@test m.g isa GraphInfo{Tuple(sorted_vertices)}

# test the dag is correct
A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
Expand All @@ -42,11 +42,18 @@ A = sparse([0 0 0 0 0; 0 0 0 0 0; 0 0 0 0 0; 0 1 1 0 0; 1 0 0 1 0])
@test length(m) == 5
@test eltype(m) == valtype(m)


# check the values from the NamedTuple match the values in the fields of GraphInfo
vals, evals, kinds = AbstractPPL.GraphPPL.getvals(NamedTuple{Tuple(sorted_vertices)}(model))
inputs = (s2 = (), xmat = (), β = (), μ = (:xmat, ), y = (, :s2))

for (i, vn) in enumerate(keys(m))
@inferred m[vn]
@inferred get_node_value(m, vn)
@inferred get_node_eval(m, vn)
@inferred get_nodekind(m, vn)
@inferred get_node_input(m, vn)

@test vn isa VarName
@test get_node_value(m, vn) == vals[i]
@test get_node_eval(m, vn) == evals[i]
Expand All @@ -55,16 +62,16 @@ for (i, vn) in enumerate(keys(m))
end

for node in m
@test typeof(node) <: NamedTuple{fieldnames(GraphInfo)[1:4]}
@test node isa NamedTuple{fieldnames(GraphInfo)[1:4]}
end

# test Model constructor for model with single parent node
single_parent_m = Model= (1.0, () -> 3, :Logical), y = (1.0, (μ) -> MvNormal(μ, sqrt(1)), :Stochastic))
@test typeof(single_parent_m) == Model{(, :y)}
@test typeof(single_parent_m.g) == GraphInfo{(, :y)}
@test single_parent_m isa Model{(, :y)}
@test single_parent_m.g isa GraphInfo{(, :y)}

# test setindex

# test setindex
@test_throws AssertionError set_node_value!(m, @varname(s2), [0.0])
@test_throws AssertionError set_node_value!(m, @varname(s2), (1.0,))
set_node_value!(m, @varname(s2), 2.0)
Expand Down

0 comments on commit bf030c3

Please sign in to comment.