diff --git a/Project.toml b/Project.toml index 9dbb81c..85c389e 100644 --- a/Project.toml +++ b/Project.toml @@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f" license = "MIT" desc = "Tape based task copying in Turing" repo = "https://github.com/TuringLang/Libtask.jl.git" -version = "0.8" +version = "0.8.1" [deps] FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e" diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl index 88a7b90..1f78c42 100644 --- a/src/tapedfunction.jl +++ b/src/tapedfunction.jl @@ -58,9 +58,9 @@ mutable struct TapedFunction{F, TapeType} binding_values::Bindings arg_binding_slots::Vector{Int} # arg indices in binding_values retval_binding_slot::Int # 0 indicates the function has not returned - deepcopy_types::Vector{Any} + deepcopy_types::Type # use a Union type for multiple types - function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=[]) where {F, T} + function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T} args_type = _accurate_typeof.(args) cache_key = (f, args_type...) @@ -78,7 +78,7 @@ mutable struct TapedFunction{F, TapeType} return tf end - TapedFunction(f, args...; cache=false, deepcopy_types=[]) = + TapedFunction(f, args...; cache=false, deepcopy_types=Union{}) = TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types) function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1} @@ -472,7 +472,7 @@ tape_shallowcopy(x::Core.Box) = Core.Box(tape_shallowcopy(x.contents)) tape_deepcopy(x::Core.Box) = Core.Box(tape_deepcopy(x.contents)) function _tape_copy(v, deepcopy_types) - if any(t -> isa(v, t), deepcopy_types) + if isa(v, deepcopy_types) tape_deepcopy(v) else tape_shallowcopy(v) diff --git a/src/tapedtask.jl b/src/tapedtask.jl index 150cdfe..c5e5d22 100644 --- a/src/tapedtask.jl +++ b/src/tapedtask.jl @@ -65,12 +65,13 @@ end # NOTE: evaluating model without a trace, see # https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329 -function TapedTask(f, args...; deepcopy_types=[Array, Ref]) # deepcoy Array and Ref by default. +function TapedTask(f, args...; deepcopy_types=Union{Array, Ref}) # deepcoy Array and Ref by default. tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy_types) TapedTask(tf, args...) end -TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...) +TapedTask(finfo::Tuple{Any, Type}, args...) = TapedTask(finfo[1], args...; deepcopy_types=finfo[2]) +TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...; deepcopy_types=t.tf.deepcopy_types) func(t::TapedTask) = t.tf.func #= diff --git a/test/tapedtask.jl b/test/tapedtask.jl index e49d5bb..f55f83e 100644 --- a/test/tapedtask.jl +++ b/test/tapedtask.jl @@ -1,4 +1,20 @@ @testset "tapedtask" begin + @testset "construction" begin + function f() + t = 1 + while true + produce(t) + t = 1 + t + end + end + + ttask = TapedTask(f) + @test consume(ttask) == 1 + + ttask = TapedTask((f, Union{})) + @test consume(ttask) == 1 + end + @testset "iteration" begin function f() t = 1