From 333b279acbd575bea51f40c792e3bf83a2ce8b3f Mon Sep 17 00:00:00 2001 From: "Killian Q. Zhuo" Date: Thu, 20 Jan 2022 19:46:22 +0800 Subject: [PATCH] Merge Tape and TapedFunction (#105) * merge Tape and TapedFunction * Tape => RawTape * trivias update * remove gettape * merge run(::RawTape) and (tf::TapedFunction)(args...) * minor update --- perf/p0.jl | 6 +- perf/p2.jl | 4 +- src/tapedfunction.jl | 191 +++++++++++++++++++------------------------ src/tapedtask.jl | 50 ++++++----- test/tf.jl | 2 +- 5 files changed, 114 insertions(+), 139 deletions(-) diff --git a/perf/p0.jl b/perf/p0.jl index b2ae3c57..02f20015 100644 --- a/perf/p0.jl +++ b/perf/p0.jl @@ -29,9 +29,9 @@ args = m.evaluator[2:end]; @show "CTask construction..." t = @btime Libtask.CTask(f, args...) # schedule(t.task) # work fine! -# @show Libtask.result(t.tf.tape) +# @show Libtask.result(t.tf) @show "Step in a tape..." -@btime Libtask.step_in(t.tf.tape, args) +@btime Libtask.step_in(t.tf, args) # Case 2: SMC sampler @@ -44,4 +44,4 @@ t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...); # schedule(t.task) # @show Libtask.result(t.tf.tape) @show "Step in a tape..." -@btime Libtask.step_in(t.tf.tape, m.evaluator[2:end]) +@btime Libtask.step_in(t.tf, m.evaluator[2:end]) diff --git a/perf/p2.jl b/perf/p2.jl index 44fd61a7..5d1a4454 100644 --- a/perf/p2.jl +++ b/perf/p2.jl @@ -58,6 +58,6 @@ args = m.evaluator[2:end] t = Libtask.CTask(f, args...) 
-Libtask.step_in(t.tf.tape, args) +Libtask.step_in(t.tf, args) -@show Libtask.result(t.tf.tape) +@show Libtask.result(t.tf) diff --git a/src/tapedfunction.jl b/src/tapedfunction.jl index 10d27dc3..9ed039ac 100644 --- a/src/tapedfunction.jl +++ b/src/tapedfunction.jl @@ -1,72 +1,105 @@ abstract type AbstractInstruction end - -mutable struct Tape - tape::Vector{<:AbstractInstruction} - counter::Int - owner -end +abstract type Taped end +const RawTape = Vector{AbstractInstruction} """ Instruction An `Instruction` stands for a function call """ -mutable struct Instruction{F} <: AbstractInstruction - fun::F +mutable struct Instruction{F, T<:Taped} <: AbstractInstruction + func::F input::Tuple output - tape::Tape + tape::T end -Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing) -Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner) -MacroTools.@forward Tape.tape Base.iterate, Base.length -MacroTools.@forward Tape.tape Base.push!, Base.getindex, Base.lastindex -const NULL_TAPE = Tape() - -function setowner!(tape::Tape, owner) - tape.owner = owner - return tape +mutable struct TapedFunction{F} <: Taped + func::F # maybe a function or a callable object + arity::Int + ir::Union{Nothing, IRTools.IR} + tape::RawTape + counter::Int + owner + function TapedFunction(f::F; arity::Int=-1) where {F} + new{F}(f, arity, nothing, RawTape(), 1, nothing) + end end mutable struct Box{T} val::T end +## methods for Box val(x) = x val(x::Box) = x.val +val(x::TapedFunction) = x.func box(x) = Box(x) box(x::Box) = x +Base.show(io::IO, box::Box) = print(io, "Box(", box.val, ")") -gettape(x) = nothing -gettape(x::Instruction) = x.tape -function gettape(x::Tuple) - for i in x - gettape(i) != nothing && return gettape(i) - end +## methods for RawTape and Taped +MacroTools.@forward TapedFunction.tape Base.iterate, Base.length +MacroTools.@forward TapedFunction.tape Base.push!, Base.getindex, Base.lastindex + +result(t::RawTape) = isempty(t) ? 
nothing : val(t[end].output) +result(t::TapedFunction) = result(t.tape) + +function increase_counter!(t::TapedFunction) + t.counter > length(t) && return + # instr = t[t.counter] + t.counter += 1 + return t end -result(t::Tape) = isempty(t) ? nothing : val(t[end].output) -function Base.show(io::IO, box::Box) - println(io, "Box($(box.val))") +function reset!(tf::TapedFunction, ir::IRTools.IR, tape::RawTape) + tf.ir = ir + tf.tape = tape + return tf end -function Base.show(io::IO, instruction::AbstractInstruction) - println(io, "A $(typeof(instruction))") +function (tf::TapedFunction)(args...) + if isempty(tf.tape) + ir = IRTools.@code_ir tf.func(args...) + ir = intercept(ir; recorder=:track!) + tf.ir = ir + tf.tape = RawTape() + tf2 = IRTools.evalir(ir, tf, args...) + @assert tf === tf2 + else + # run the raw tape + if length(args) > 0 + input = map(box, args) + tf.tape[1].input = input + end + for instruction in tf.tape + instruction() + end + end + return result(tf) end -function Base.show(io::IO, instruction::Instruction) - fun = instruction.fun - tape = instruction.tape - println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))") +function Base.show(io::IO, tf::TapedFunction) + buf = IOBuffer() + println(buf, "TapedFunction:") + println(buf, "* .func => $(tf.func)") + println(buf, "* .ir =>") + println(buf, "------------------") + println(buf, tf.ir) + println(buf, "------------------") + println(buf, "* .tape =>") + println(buf, "------------------") + println(buf, tf.tape) + println(buf, "------------------") + print(io, String(take!(buf))) end -function Base.show(io::IO, tp::Tape) +function Base.show(io::IO, tp::RawTape) # we use an extra IOBuffer to collect all the data and then # output it once to avoid output interrupt during task context # switching buf = IOBuffer() - print(buf, "$(length(tp))-element Tape") + print(buf, "$(length(tp))-element RawTape") isempty(tp) || println(buf, ":") i = 1 for instruction in tp @@ -77,10 
+110,19 @@ function Base.show(io::IO, tp::Tape) print(io, String(take!(buf))) end +## methods for Instruction +Base.show(io::IO, instruction::AbstractInstruction) = print(io, "A ", typeof(instruction)) + +function Base.show(io::IO, instruction::Instruction) + func = instruction.func + tape = instruction.tape + println(io, "Instruction($(func)$(map(val, instruction.input)), tape=$(objectid(tape)))") +end + function (instr::Instruction{F})() where F # catch run-time exceptions / errors. try - output = instr.fun(map(val, instr.input)...) + output = instr.func(map(val, instr.input)...) instr.output.val = output catch e println(e, catch_backtrace()); @@ -101,26 +143,9 @@ function (instr::Instruction{typeof(_new)})() end end +## internal functions -function increase_counter!(t::Tape) - t.counter > length(t) && return - # instr = t[t.counter] - t.counter += 1 - return t -end - -function run(tape::Tape, args...) - if length(args) > 0 - input = map(box, args) - tape[1].input = input - end - for instruction in tape - instruction() - increase_counter!(tape) - end -end - -function run_and_record!(tape::Tape, f, args...) +function track!(tape::Taped, f, args...) f = val(f) # f maybe a Boxed closure output = try box(f(map(val, args)...)) @@ -133,7 +158,7 @@ function run_and_record!(tape::Tape, f, args...) return output end -function run_and_record!(tape::Tape, ::typeof(_new), args...) +function track!(tape::Taped, ::typeof(_new), args...) output = try expr = Expr(:new, map(val, args)...) box(eval(expr)) @@ -171,9 +196,11 @@ function _replace_args(args, pairs::Dict) end end -function intercept(ir; recorder=:run_and_record!) +function intercept(ir; recorder=:track!) 
ir == nothing && return - tape = pushfirst!(ir, IRTools.xcall(@__MODULE__, :Tape)) + # we use tf instead of the original function as the first argument + # get the TapedFunction + tape = pushfirst!(ir, IRTools.xcall(Base, :identity, IRTools.arguments(ir)[1])) # box the args first_blk = IRTools.blocks(ir)[1] @@ -229,51 +256,3 @@ function intercept(ir; recorder=:run_and_record!) unbox_condition(ir) return ir end - -mutable struct TapedFunction - func # ::Function # maybe a callable obejct - arity::Int - ir::Union{Nothing, IRTools.IR} - tape::Tape - owner - function TapedFunction(f; arity::Int=-1) - new(f, arity, nothing, NULL_TAPE, nothing) - end -end - -function reset!(tf::TapedFunction, ir::IRTools.IR, tape::Tape) - tf.ir = ir - tf.tape = tape - setowner!(tape, tf) - return tf -end - -function (tf::TapedFunction)(args...) - if isempty(tf.tape) - ir = IRTools.@code_ir tf.func(args...) - ir = intercept(ir; recorder=:run_and_record!) - tape = IRTools.evalir(ir, tf.func, args...) - tf.ir = ir - tf.tape = tape - setowner!(tape, tf) - return result(tape) - end - # TODO: use cache - run(tf.tape, args...) - return result(tf.tape) -end - -function Base.show(io::IO, tf::TapedFunction) - buf = IOBuffer() - println(buf, "TapedFunction:") - println(buf, "* .func => $(tf.func)") - println(buf, "* .ir =>") - println(buf, "------------------") - println(buf, tf.ir) - println(buf, "------------------") - println(buf, "* .tape =>") - println(buf, "------------------") - println(buf, tf.tape) - println(buf, "------------------") - print(io, String(take!(buf))) -end diff --git a/src/tapedtask.jl b/src/tapedtask.jl index 72af7d2e..cb7799fb 100644 --- a/src/tapedtask.jl +++ b/src/tapedtask.jl @@ -26,7 +26,7 @@ function TapedTask(tf::TapedFunction, args...) ir, tape = TRCache[cache_key] # Here we don't need change the initial arguments of the tape, # it will be set when we `step_in` to the tape. 
- reset!(tf, ir, copy(tape, Dict{UInt64, Any}(); partial=false)) + reset!(tf, ir, copy(tape, tf, Dict{UInt64, Any}(); start=1)) else tf(args...) TRCache[cache_key] = (tf.ir, tf.tape) @@ -35,7 +35,7 @@ function TapedTask(tf::TapedFunction, args...) produce_ch = Channel() consume_ch = Channel{Int}() task = @task try - step_in(tf.tape, args) + step_in(tf, args) catch e bt = catch_backtrace() put!(produce_ch, TapedTaskException(e, bt)) @@ -65,28 +65,27 @@ TapedTask(f, args...) = TapedTask(TapedFunction(f, arity=length(args)), args...) TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...) func(t::TapedTask) = t.tf.func - -function step_in(t::Tape, args) - len = length(t) - if(t.counter <= 1 && length(args) > 0) +function step_in(tf::TapedFunction, args) + len = length(tf) + if(tf.counter <= 1 && length(args) > 0) input = map(box, args) - t[1].input = input + tf[1].input = input end - while t.counter <= len - t[t.counter]() + while tf.counter <= len + tf[tf.counter]() # produce and wait after an instruction is done - ttask = t.owner.owner + ttask = tf.owner if length(ttask.produced_val) > 0 val = pop!(ttask.produced_val) put!(ttask.produce_ch, val) take!(ttask.consume_ch) # wait for next consumer end - increase_counter!(t) + increase_counter!(tf) end end function next_step!(t::TapedTask) - increase_counter!(t.tf.tape) + increase_counter!(t.tf) return t end @@ -95,8 +94,7 @@ end # Make`produce` a standalone instturction. 
This approach does NOT # support `produce` in a nested call function internal_produce(instr::Instruction, val) - tape = gettape(instr) - tf = tape.owner + tf = instr.tape ttask = tf.owner put!(ttask.produce_ch, val) take!(ttask.consume_ch) # wait for next consumer @@ -125,7 +123,7 @@ end ct.storage === nothing && return false haskey(ct.storage, :tapedtask) || return false # check if we are recording a tape - ct.storage[:tapedtask].tf.tape === NULL_TAPE && return false + isempty(ct.storage[:tapedtask].tf.tape) && return false return true end @@ -204,36 +202,34 @@ function copy_box(old_box::Box{T}, roster::Dict{UInt64, Any}) where T end copy_box(o, roster::Dict{UInt64, Any}) = o -function Base.copy(x::Instruction, on_tape::Tape, roster::Dict{UInt64, Any}) +function Base.copy(x::Instruction, on_tape::Taped, roster::Dict{UInt64, Any}) input = map(x.input) do ob copy_box(ob, roster) end output = copy_box(x.output, roster) - Instruction(x.fun, input, output, on_tape) + Instruction(x.func, input, output, on_tape) end -function Base.copy(t::Tape, roster::Dict{UInt64, Any}; partial=true) - old_data = t.tape - len = partial ? length(old_data) - t.counter + 1 : length(old_data) - start = partial ? 
t.counter : 1 - new_data = Vector{AbstractInstruction}(undef, len) - new_tape = Tape(new_data, 1, t.owner) +function Base.copy(t::RawTape, on_tape::Taped, roster::Dict{UInt64, Any}; start::Int=1) + old_data = t + len = length(old_data) - start + 1 + new_data = RawTape(undef, len) for (i, x) in enumerate(old_data[start:end]) - new_ins = copy(x, new_tape, roster) + new_ins = copy(x, on_tape, roster) new_data[i] = new_ins end - return new_tape + return new_data end function Base.copy(tf::TapedFunction) new_tf = TapedFunction(tf.func; arity=tf.arity) new_tf.ir = tf.ir roster = Dict{UInt64, Any}() - new_tape = copy(tf.tape, roster) - setowner!(new_tape, new_tf) + new_tape = copy(tf.tape, new_tf, roster; start=tf.counter) new_tf.tape = new_tape + new_tf.counter = 1 return new_tf end diff --git a/test/tf.jl b/test/tf.jl index 53ac57ca..8a3822ac 100644 --- a/test/tf.jl +++ b/test/tf.jl @@ -11,7 +11,7 @@ using Libtask tf = Libtask.TapedFunction(S) s1 = tf(1, 2) @test s1.i == 3 - newins = findall(x -> isa(x, Libtask.Instruction{typeof(Libtask._new)}), tf.tape.tape) + newins = findall(x -> isa(x, Libtask.Instruction{typeof(Libtask._new)}), tf.tape) @test length(newins) == 1 end end