Skip to content

Commit

Permalink
Backport some features from #100 (#102)
Browse files Browse the repository at this point in the history
* Fix unbox condition function (ref #100)

* Port new produce mechanism from #100.

* Minor bugfixes.

* Fix new produce mechanism.

* Update src/tapedtask.jl

* Update src/tapedtask.jl

* Update src/tapedtask.jl

* Update src/tapedtask.jl
  • Loading branch information
yebai authored Jan 4, 2022
1 parent 48703aa commit 8323952
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 14 deletions.
14 changes: 9 additions & 5 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@ end
function Base.show(io::IO, instruction::Instruction)
fun = instruction.fun
tape = instruction.tape
println(io, "Instruction($(fun)), tape=$(objectid(tape)))")
println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))")
end

function Base.show(io::IO, tp::Tape)
Expand Down Expand Up @@ -75,7 +75,8 @@ function run_and_record!(tape::Tape, f, args...)
f = val(f) # f maybe a Boxed closure
output = try
box(f(map(val, args)...))
catch
catch e
@warn e
any_box(nothing)
end
ins = Instruction(f, args, output, tape)
Expand All @@ -94,11 +95,14 @@ end
function unbox_condition(ir)
for blk in IRTools.blocks(ir)
vars = keys(blk)
for br in IRTools.branches(blk)
brs = IRTools.branches(blk)
for (i, br) in enumerate(brs)
IRTools.isconditional(br) || continue
cond = br.condition
prev_cond = IRTools.insert!(ir, cond, ir[cond])
ir[cond] = IRTools.xcall(@__MODULE__, :val, prev_cond)
new_cond = IRTools.push!(
blk,
IRTools.xcall(@__MODULE__, :val, cond))
brs[i] = IRTools.Branch(br; condition=new_cond)
end
end
end
Expand Down
32 changes: 23 additions & 9 deletions src/tapedtask.jl
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,12 @@ struct TapedTask
counter::Ref{Int}
produce_ch::Channel{Any}
consume_ch::Channel{Int}
produced_val::Vector{Any}

function TapedTask(
t::Task, tf::TapedFunction, counter, pch::Channel{Any}, cch::Channel{Int})
new(t, tf, counter, pch, cch, Any[])
end
end

function TapedTask(tf::TapedFunction, args...)
Expand Down Expand Up @@ -35,24 +41,33 @@ function TapedTask(tf::TapedFunction, args...)
close(consume_ch)
end
t = TapedTask(task, tf, counter, produce_ch, consume_ch)
# task.storage === nothing && (task.storage = IdDict())
# task.storage[:tapedtask] = t
task.storage === nothing && (task.storage = IdDict())
task.storage[:tapedtask] = t
tf.owner = t
return t
end

TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
# Issue: evaluating model without a trace, see
# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...)
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
func(t::TapedTask) = t.tf.func

function step_in(tf::TapedFunction, counter::Ref{Int}, args)
len = length(tf.tape)
if(counter[] <= 1)
if(counter[] <= 1 && length(args) > 0)
input = map(box, args)
tf.tape[1].input = input
end
while counter[] <= len
tf.tape[counter[]]()
# produce and wait after an instruction is done
ttask = tf.owner
if length(ttask.produced_val) > 0
val = pop!(ttask.produced_val)
put!(ttask.produce_ch, val)
take!(ttask.consume_ch) # wait for next consumer
end
counter[] += 1
end
end
Expand All @@ -76,7 +91,7 @@ function (instr::Instruction{typeof(produce)})()
internal_produce(instr, args)
end

#=

# Another way to support `produce` in nested call. This way has its caveat:
# `produce` may deeply hide in an instruction, but not be an instruction
# itself, and when we copy a task, the newly copied task will resume from
Expand All @@ -95,11 +110,10 @@ end
function produce(val)
is_in_tapedtask() || return nothing
ttask = current_task().storage[:tapedtask]
put!(ttask.produce_ch, val)
take!(ttask.consume_ch) # wait for next consumer
return nothing
length(ttask.produced_val) > 1 &&
error("There is a produced value which is not consumed.")
push!(ttask.produced_val, val)
end
=#

function consume(ttask::TapedTask)
if istaskstarted(ttask.task)
Expand Down

0 comments on commit 8323952

Please sign in to comment.