Skip to content

Commit

Permalink
Merge Tape and TapedFunction (#105)
Browse files Browse the repository at this point in the history
* merge Tape and TapedFunction

* Tape => RawTape

* trivias update

* remove gettape

* merge run(::RawTape) and (tf::TapedFunction)(args...)

* minor update
  • Loading branch information
KDr2 authored Jan 20, 2022
1 parent 64f90e6 commit 333b279
Show file tree
Hide file tree
Showing 5 changed files with 114 additions and 139 deletions.
6 changes: 3 additions & 3 deletions perf/p0.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,9 +29,9 @@ args = m.evaluator[2:end];
@show "CTask construction..."
t = @btime Libtask.CTask(f, args...)
# schedule(t.task) # work fine!
# @show Libtask.result(t.tf.tape)
# @show Libtask.result(t.tf)
@show "Step in a tape..."
@btime Libtask.step_in(t.tf.tape, args)
@btime Libtask.step_in(t.tf, args)

# Case 2: SMC sampler

Expand All @@ -44,4 +44,4 @@ t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...);
# schedule(t.task)
# @show Libtask.result(t.tf.tape)
@show "Step in a tape..."
@btime Libtask.step_in(t.tf.tape, m.evaluator[2:end])
@btime Libtask.step_in(t.tf, m.evaluator[2:end])
4 changes: 2 additions & 2 deletions perf/p2.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,6 @@ args = m.evaluator[2:end]

t = Libtask.CTask(f, args...)

Libtask.step_in(t.tf.tape, args)
Libtask.step_in(t.tf, args)

@show Libtask.result(t.tf.tape)
@show Libtask.result(t.tf)
191 changes: 85 additions & 106 deletions src/tapedfunction.jl
Original file line number Diff line number Diff line change
@@ -1,72 +1,105 @@
abstract type AbstractInstruction end

mutable struct Tape
tape::Vector{<:AbstractInstruction}
counter::Int
owner
end
abstract type Taped end
const RawTape = Vector{AbstractInstruction}

"""
Instruction
An `Instruction` stands for a function call
"""
mutable struct Instruction{F} <: AbstractInstruction
fun::F
mutable struct Instruction{F, T<:Taped} <: AbstractInstruction
func::F
input::Tuple
output
tape::Tape
tape::T
end

Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing)
Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner)
MacroTools.@forward Tape.tape Base.iterate, Base.length
MacroTools.@forward Tape.tape Base.push!, Base.getindex, Base.lastindex
const NULL_TAPE = Tape()

function setowner!(tape::Tape, owner)
tape.owner = owner
return tape
mutable struct TapedFunction{F} <: Taped
func::F # maybe a function or a callable obejct
arity::Int
ir::Union{Nothing, IRTools.IR}
tape::RawTape
counter::Int
owner
function TapedFunction(f::F; arity::Int=-1) where {F}
new{F}(f, arity, nothing, RawTape(), 1, nothing)
end
end

mutable struct Box{T}
val::T
end

## methods for Box
val(x) = x
val(x::Box) = x.val
val(x::TapedFunction) = x.func
box(x) = Box(x)
box(x::Box) = x
Base.show(io::IO, box::Box) = print(io, "Box(", box.val, ")")

gettape(x) = nothing
gettape(x::Instruction) = x.tape
function gettape(x::Tuple)
for i in x
gettape(i) != nothing && return gettape(i)
end
## methods for RawTape and Taped
MacroTools.@forward TapedFunction.tape Base.iterate, Base.length
MacroTools.@forward TapedFunction.tape Base.push!, Base.getindex, Base.lastindex

result(t::RawTape) = isempty(t) ? nothing : val(t[end].output)
result(t::TapedFunction) = result(t.tape)

function increase_counter!(t::TapedFunction)
t.counter > length(t) && return
# instr = t[t.counter]
t.counter += 1
return t
end
result(t::Tape) = isempty(t) ? nothing : val(t[end].output)

function Base.show(io::IO, box::Box)
println(io, "Box($(box.val))")
function reset!(tf::TapedFunction, ir::IRTools.IR, tape::RawTape)
tf.ir = ir
tf.tape = tape
return tf
end

function Base.show(io::IO, instruction::AbstractInstruction)
println(io, "A $(typeof(instruction))")
function (tf::TapedFunction)(args...)
if isempty(tf.tape)
ir = IRTools.@code_ir tf.func(args...)
ir = intercept(ir; recorder=:track!)
tf.ir = ir
tf.tape = RawTape()
tf2 = IRTools.evalir(ir, tf, args...)
@assert tf === tf2
else
# run the raw tape
if length(args) > 0
input = map(box, args)
tf.tape[1].input = input
end
for instruction in tf.tape
instruction()
end
end
return result(tf)
end

function Base.show(io::IO, instruction::Instruction)
fun = instruction.fun
tape = instruction.tape
println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))")
function Base.show(io::IO, tf::TapedFunction)
buf = IOBuffer()
println(buf, "TapedFunction:")
println(buf, "* .func => $(tf.func)")
println(buf, "* .ir =>")
println(buf, "------------------")
println(buf, tf.ir)
println(buf, "------------------")
println(buf, "* .tape =>")
println(buf, "------------------")
println(buf, tf.tape)
println(buf, "------------------")
print(io, String(take!(buf)))
end

function Base.show(io::IO, tp::Tape)
function Base.show(io::IO, tp::RawTape)
# we use an extra IOBuffer to collect all the data and then
# output it once to avoid output interrupt during task context
# switching
buf = IOBuffer()
print(buf, "$(length(tp))-element Tape")
print(buf, "$(length(tp))-element RawTape")
isempty(tp) || println(buf, ":")
i = 1
for instruction in tp
Expand All @@ -77,10 +110,19 @@ function Base.show(io::IO, tp::Tape)
print(io, String(take!(buf)))
end

## methods for Instruction
Base.show(io::IO, instruction::AbstractInstruction) = print(io, "A ", typeof(instruction))

function Base.show(io::IO, instruction::Instruction)
func = instruction.func
tape = instruction.tape
println(io, "Instruction($(func)$(map(val, instruction.input)), tape=$(objectid(tape)))")
end

function (instr::Instruction{F})() where F
# catch run-time exceptions / errors.
try
output = instr.fun(map(val, instr.input)...)
output = instr.func(map(val, instr.input)...)
instr.output.val = output
catch e
println(e, catch_backtrace());
Expand All @@ -101,26 +143,9 @@ function (instr::Instruction{typeof(_new)})()
end
end

## internal functions

function increase_counter!(t::Tape)
t.counter > length(t) && return
# instr = t[t.counter]
t.counter += 1
return t
end

function run(tape::Tape, args...)
if length(args) > 0
input = map(box, args)
tape[1].input = input
end
for instruction in tape
instruction()
increase_counter!(tape)
end
end

function run_and_record!(tape::Tape, f, args...)
function track!(tape::Taped, f, args...)
f = val(f) # f maybe a Boxed closure
output = try
box(f(map(val, args)...))
Expand All @@ -133,7 +158,7 @@ function run_and_record!(tape::Tape, f, args...)
return output
end

function run_and_record!(tape::Tape, ::typeof(_new), args...)
function track!(tape::Taped, ::typeof(_new), args...)
output = try
expr = Expr(:new, map(val, args)...)
box(eval(expr))
Expand Down Expand Up @@ -171,9 +196,11 @@ function _replace_args(args, pairs::Dict)
end
end

function intercept(ir; recorder=:run_and_record!)
function intercept(ir; recorder=:track!)
ir == nothing && return
tape = pushfirst!(ir, IRTools.xcall(@__MODULE__, :Tape))
# we use tf instead of the original function as the first argument
# get the TapedFunction
tape = pushfirst!(ir, IRTools.xcall(Base, :identity, IRTools.arguments(ir)[1]))

# box the args
first_blk = IRTools.blocks(ir)[1]
Expand Down Expand Up @@ -229,51 +256,3 @@ function intercept(ir; recorder=:run_and_record!)
unbox_condition(ir)
return ir
end

mutable struct TapedFunction
func # ::Function # maybe a callable obejct
arity::Int
ir::Union{Nothing, IRTools.IR}
tape::Tape
owner
function TapedFunction(f; arity::Int=-1)
new(f, arity, nothing, NULL_TAPE, nothing)
end
end

function reset!(tf::TapedFunction, ir::IRTools.IR, tape::Tape)
tf.ir = ir
tf.tape = tape
setowner!(tape, tf)
return tf
end

function (tf::TapedFunction)(args...)
if isempty(tf.tape)
ir = IRTools.@code_ir tf.func(args...)
ir = intercept(ir; recorder=:run_and_record!)
tape = IRTools.evalir(ir, tf.func, args...)
tf.ir = ir
tf.tape = tape
setowner!(tape, tf)
return result(tape)
end
# TODO: use cache
run(tf.tape, args...)
return result(tf.tape)
end

function Base.show(io::IO, tf::TapedFunction)
buf = IOBuffer()
println(buf, "TapedFunction:")
println(buf, "* .func => $(tf.func)")
println(buf, "* .ir =>")
println(buf, "------------------")
println(buf, tf.ir)
println(buf, "------------------")
println(buf, "* .tape =>")
println(buf, "------------------")
println(buf, tf.tape)
println(buf, "------------------")
print(io, String(take!(buf)))
end
Loading

0 comments on commit 333b279

Please sign in to comment.