-
Notifications
You must be signed in to change notification settings - Fork 109
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
WIP: Connecting to ChainRulesCore for JuliaLang AD compat #652
base: main
Are you sure you want to change the base?
Changes from 4 commits
c5104b4
42cbad5
37372cb
e6c5759
b2f188c
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,14 +1,16 @@ | ||
"Calling Dex from Julia" | ||
module DexCall | ||
using ChainRulesCore | ||
using CombinedParsers | ||
using CombinedParsers.Regexp | ||
|
||
export evaluate, DexError, DexModule, juliaize, NativeFunction, @dex_func_str | ||
export evaluate, DexError, DexModule, dexize, juliaize, NativeFunction, @dex_func_str | ||
|
||
include("api_types.jl") | ||
include("api.jl") | ||
include("evaluate.jl") | ||
include("native_function.jl") | ||
include("chainrules.jl") | ||
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more.
|
||
|
||
# use this to disable free'ing haskell objects after we have closed the RTS | ||
const NO_FREE = Ref(false) | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
|
||
# Forward-mode rule for applying the Dex function `f` to the Dex value `x`.
# `f`, the tangent `ẋ` (under the name `dx`) and `x` are bound in a context derived from
# `f.ctx`; the Dex snippet below is evaluated there and `(y, dy)` is read back from the
# resulting module. Throws a `DomainError` if `ẋ` is not an `Atom`.
function ChainRulesCore.frule((_, ẋ), f::Atom, x::Atom)
    if !(ẋ isa Atom)
        throw(DomainError(ẋ, "Tangent to an Atom must be an Atom"))
    end

    # Bindings are added innermost-first: "f", then "dx", then "x".
    ctx = insert(insert(insert(f.ctx, "f", f.ptr), "dx", ẋ.ptr), "x", x.ptr)

    program = raw"""
        (y, pushforward) = linearize f x
        dy = pushforward dx
        """
    linearized = DexModule(program, ctx)
    return linearized.y, linearized.dy
end
|
||
# Reverse-mode rule for applying the Dex function `f` to the Dex value `x`.
# `f` and `x` are bound in a context derived from `f.ctx`; the Dex snippet linearizes
# `f` at `x` and transposes the resulting linear map. Returns the primal `y` and a
# pullback that accepts an `Atom` cotangent and yields `(NoTangent(), x̄)`.
function ChainRulesCore.rrule(f::Atom, x::Atom)
    ctx = insert(insert(f.ctx, "f", f.ptr), "x", x.ptr)

    program = raw"""
        (y, pushforward) = linearize f x
        pullback = transposeLinear pushforward
        """
    linearized = DexModule(program, ctx)

    # It is important that we close over `linearized` as otherwise the env may be
    # destroyed by GC.
    # NOTE(review): a reviewer remarked that this is no longer the case because the
    # finalizer is currently commented out — confirm before relying on either claim.
    atom_pullback(x̄::Atom) = (NoTangent(), linearized.pullback(x̄))
    return linearized.y, atom_pullback
end
|
||
# Forward-mode rule for `juliaize`: apply `juliaize` to the primal and to its tangent.
function ChainRulesCore.frule((_, ẋ), ::typeof(juliaize), x)
    return juliaize(x), juliaize(ẋ)
end
# Reverse-mode rule for `juliaize` on an `Atom`.
function ChainRulesCore.rrule(::typeof(juliaize), x::Atom)
    ctx = x.ctx

    # The pullback receives a Julia-typed cotangent and must hand back a Dex-typed one,
    # so it is pushed through `dexize` using the context of the primal input.
    function juliaize_pullback(ȳ)
        return NoTangent(), dexize(ȳ, ctx)
    end

    return juliaize(x), juliaize_pullback
end
|
||
|
||
# Forward-mode rule for `dexize`: apply `dexize` to the primal and to its tangent.
function ChainRulesCore.frule((_, ẋ), ::typeof(dexize), x)
    return dexize(x), dexize(ẋ)
end
# Reverse-mode rule for `dexize`.
function ChainRulesCore.rrule(::typeof(dexize), x)
    # The pullback receives a Dex-typed cotangent and must hand back a Julia-typed one.
    function dexize_pullback(ȳ)
        return NoTangent(), juliaize(ȳ)
    end

    return dexize(x), dexize_pullback
end
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -13,9 +13,37 @@ end | |
|
||
# Display an `Atom` by `show`ing whatever `print(atom.ptr)` returns.
# NOTE(review): `print` is whichever `print` is in scope in this module — presumably a
# Dex-side printer that returns a string; confirm against the API definitions.
function Base.show(io::IO, atom::Atom)
    rendered = print(atom.ptr)
    return show(io, rendered)
end
|
||
""" | ||
juliaize(x) | ||
|
||
Get the corresponding Julia object from some output of Dex. | ||
""" | ||
function juliaize(x::CAtom)
    return bust_union(x)
end

# A raw Dex atom pointer is first wrapped in a `CAtom`.
function juliaize(x::Ptr{HsAtom})
    return juliaize(CAtom(x))
end

# An `Atom` is converted via the pointer it holds.
function juliaize(x::Atom)
    return juliaize(x.ptr)
end
# Convert an `Atom` to a Julia `Number` type by first turning it into a Julia value.
function Base.convert(::Type{T}, atom::Atom) where {T<:Number}
    julia_value = juliaize(atom)
    return convert(T, julia_value)
end
|
||
""" | ||
dexize(x) | ||
|
||
Get the corresponding Dex object from some output of Julia. | ||
|
||
NB: this is currently a hack that goes via string processing. | ||
""" | ||
function dexize(x::Float32, _module=PRELUDE, env=_module)
    # NOTE(review): an inline review thread at this point questioned whether one of these
    # arguments is needed at all; the identifier it referred to was lost — TODO confirm.

    # Non-finite values have no decimal spelling, so map them to Dex names directly.
    isnan(x) && return evaluate("nan", _module, env)
    x === Inf32 && return evaluate("infinity", _module, env)
    x === -Inf32 && return evaluate("-infinity", _module, env)

    # `repr` of a finite Float32 has the form "<mantissa>f<exponent>", e.g. "1.5f0",
    # "1.0f10" or "1.0f-5".
    str = repr(x)
    if endswith(str, "f0")
        # No scaling needed: the mantissa alone is the literal.
        return evaluate(str[1:end-2], _module, env)
    end

    mantissa, exponent = split(str, 'f')
    source = if startswith(exponent, '-')
        # Convert "123f-45" into "123 / (intpow 10.0 45)". Passing a bare negative
        # exponent ("intpow 10.0 -45") is avoided: it is not parenthesised, and an
        # integer power with a negative exponent is not something to rely on.
        # NOTE(review): for exponents below about -38 the power of ten presumably
        # overflows a 32-bit float, so such subnormal inputs would come out as zero.
        "$mantissa / (intpow 10.0 $(exponent[2:end]))"
    else
        # Convert "123f45" into "123 * (intpow 10.0 45)".
        "$mantissa * (intpow 10.0 $exponent)"
    end
    return evaluate(source, _module, env)
end
|
||
|
||
function (self::Atom)(args...) | ||
# TODO: Make those calls more hygienic | ||
|
@@ -60,8 +88,8 @@ julia> m.y | |
"84" | ||
``` | ||
""" | ||
function DexModule(source::AbstractString) | ||
ctx = dex_eval(PRELUDE, source) | ||
function DexModule(source::AbstractString, parent_ctx=PRELUDE) | ||
ctx = dex_eval(parent_ctx, source) | ||
ctx == C_NULL && throw_from_dex() | ||
m = DexModule(ctx) | ||
finalizer(m) do _m | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# Dex function that doubles a Float; shared by the test sets in this file.
const double_dex = evaluate(raw"\x:Float. 2.0 * x")

@testset "frule: dexize, evaluate, juliaize" begin
    # Push a tangent of 10 forward through Julia -> Dex -> (double) -> Julia.
    x_dex, ẋ_dex = frule((NoTangent(), 10f0), dexize, 1.5f0)
    y_dex, ẏ_dex = frule((NoTangent(), ẋ_dex), double_dex, x_dex)
    y, ẏ = frule((NoTangent(), ẏ_dex), juliaize, y_dex)

    @test y === 3.0f0
    @test ẏ === 20f0
end
|
||
@testset "rrule: dexize, evaluate, juliaize" begin
    input = 1.5f0

    # Forward pass: Julia -> Dex -> (double) -> Julia, collecting each pullback.
    x_dex, dexize_pb = rrule(dexize, input)
    y_dex, double_pb = rrule(double_dex, x_dex)
    y, juliaize_pb = rrule(juliaize, y_dex)
    @test y === 3.0f0

    # Reverse pass: feed a cotangent of 10 back through the pullbacks in reverse order.
    ȳ = 10f0
    _, ȳ_dex = juliaize_pb(ȳ)
    _, x̄_dex = double_pb(ȳ_dex)
    _, input_bar = dexize_pb(x̄_dex)
    @test input_bar === 20f0
end
|
||
@testset "Integration Test: Zygote.jl" begin
    # Compose the round trip Julia -> Dex -> (double) -> Julia and differentiate it.
    roundtrip_double = juliaize ∘ double_dex ∘ dexize
    primal, back = Zygote.pullback(roundtrip_double, 1.5f0)

    @test primal == 3f0
    @test back(1f0) == (2f0,)
end
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,8 +1,11 @@ | ||
using Test | ||
using DexCall | ||
using ChainRulesCore | ||
using Zygote # for integration tests | ||
|
||
@testset "DexCall" begin
    # Each file holds the tests for one part of the package; run them in this order.
    for test_file in ("api.jl", "evaluate.jl", "native_function.jl", "chainrules.jl")
        include(test_file)
    end
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I haven't given it proper docs yet, because it only does Float32.
It's mostly just for testing purposes.
We need a proper API for this.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sure, we can extend it if you'd like. Just let me know what would be helpful to have.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I guess it is something like `create_literal` that works like `insert`, except instead of taking an `Atom` it takes a C compatible value for an Int/Float32/Float64/Array.
Possibly `ctypes` doesn't allow that directly, so maybe it needs to be wrapped into a tagged union?
I guess maybe accepting the same tagged union that comes out of the atom's pointer makes sense.
(Except right now that doesn't support arrays)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Wouldn't a function that converts a `CAtom` into an `Atom` be sufficient? That's what I would imagine. And then we can add more cases to `CAtom` if that would be useful.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes, I think that is basically what I said badly.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool, which cases would you like to have? I'm happy to add them for you
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
`Float32`, `Float64` (particularly since we can't input those as literals, #497).
Arrays would be nice, but given we can't currently convert `Atom` to `CAtom` for arrays anyway, that doesn't matter so much.
Integer types would be nice, for completeness, but not particularly interesting for AD.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Added that in #657