Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

WIP: Connecting to ChainRulesCore for JuliaLang AD compat #652

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion julia/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,18 @@ uuid = "bb22f25d-cb49-471c-b017-930e329a2928"
version = "0.1.0"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
CombinedParsers = "5ae71ed2-6f8a-4ed1-b94f-e14e8158f19e"

[compat]
ChainRulesCore = "^1.0"
CombinedParsers = "^0.2"
Zygote = "^0.6.22"
julia = "^1.6"

[extras]
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["Test"]
test = ["Test", "Zygote"]
6 changes: 4 additions & 2 deletions julia/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

DexCall provides a mechanism for calling dex-lang code from JuliaLang.
Three main mechanism are provided for this: `evaluate`, `DexModule` and the `dex_func` string macro.
Two helper methods are also provided: `juliaize` and `NativeFunction`.
Several helper methods are also provided: `juliaize`, `dexize`, and `NativeFunction`.

## `evaluate`: just run a single Dex expression.
`evaluate` takes in a Dex expression as a string and runs it, returning a `Atom` (see below).
Expand Down Expand Up @@ -53,7 +53,7 @@ julia> m.addTwo(m.y)
"[44., 44., 44.]"
```

## Atoms: `juliaize`, `NativeFunction`
## Atoms: `juliaize`, `dexize` and `NativeFunction`

`evaluate` and the contents of a `DexModule` are returned as `Atom`s.
These can be displayed, but not much else.
Expand Down Expand Up @@ -87,6 +87,8 @@ julia> typeof(convert(Int64, m.x))
Int64
```

The inverse of `juliaize` is `dexize`, it is currently very limited.
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I haven't given it proper docs yet, because it only does Float32.
It's mostly just for testing purposes.
We need a proper API for this.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, we can extend it if you'd like. Just let me know what would be helpful to have.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I guess it is something like create_literal that works like insert except instead of taking an Atom it takes a C compatible value for a Int/Float32/Float64/Array.
Possibly ctypes doesn't allow that directly, so maybe it needs to be wrapped into a tagged union?
I guess maybe accepting the same tagged union that comes out of the atom's pointer makes sense.
(Except right now that doesn't support arrays)

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Wouldn't a function that converts a CAtom into an Atom be sufficient? That's what I would imagine. And then we can add more cases to CAtom if that would be useful.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes, I think that is basically what I said badly.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cool, which cases would you like to have? I'm happy to add them for you

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Float32, Float64 (particularly since can't input those as literals #497)
Arrays would be nice, but given we can't currently convert Atom to CAtom for arrays anyway, that doesn't matter so much.
Integer types would be nice, for completeness, but not particularly interesting for AD.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Added that in #657


To convert function `Atom`s into something you can execute as if it was a regular julia function use `NativeFunction`.
This will compile it and handing inputs and outputs without needing to del with `Atom`s directly.

Expand Down
4 changes: 3 additions & 1 deletion julia/src/DexCall.jl
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
"Calling Dex from Julia"
module DexCall
using ChainRulesCore
using CombinedParsers
using CombinedParsers.Regexp

export evaluate, DexError, DexModule, juliaize, NativeFunction, @dex_func_str
export evaluate, DexError, DexModule, dexize, juliaize, NativeFunction, @dex_func_str

include("api_types.jl")
include("api.jl")
include("evaluate.jl")
include("native_function.jl")
include("chainrules.jl")
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

include("chainrules.jl")


# use this to disable free'ing haskell objects after we have closed the RTS
const NO_FREE = Ref(false)
Expand Down
8 changes: 0 additions & 8 deletions julia/src/api_types.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,11 +55,3 @@ function CAtom(atm::Ptr{HsAtom})
iszero(success) && throw_from_dex()
return result[]
end

"""
juliaize(x)

Get the corresponding Julia object from some output of Dex.
"""
juliaize(x::CAtom) = bust_union(x)
juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x))
51 changes: 51 additions & 0 deletions julia/src/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@

function ChainRulesCore.frule((_, ẋ), f::Atom, x::Atom)
ẋ isa Atom || throw(DomainError(ẋ, "Tangent to an Atom must be an Atom"))
env = f.ctx
env = insert(env, "f", f.ptr)
env = insert(env, "dx", ẋ.ptr)
env = insert(env, "x", x.ptr)

m = DexModule(raw"""
(y, pushforward) = linearize f x
dy = pushforward dx
""",
env
)
return m.y, m.dy
end

function ChainRulesCore.rrule(f::Atom, x::Atom)
env = f.ctx
env = insert(env, "f", f.ptr)
env = insert(env, "x", x.ptr)

m = DexModule(raw"""
(y, pushforward) = linearize f x
pullback = transposeLinear pushforward
""",
env
)

# It is important that we close over `m` as otherwise the env may be destroyed by GC
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer the case since right now that finalizers is commented out.
And before this is merged we may way move to working out how to attach the finalizer to the context more directly.

pullback(x̄::Atom)= (NoTangent(), m.pullback(x̄))
return m.y, pullback
end

ChainRulesCore.frule((_, ẋ), ::typeof(juliaize), x) = juliaize(x), juliaize(ẋ)
function ChainRulesCore.rrule(::typeof(juliaize), x::Atom)
env= x.ctx

# pullback must take a julia typed cotangent and give back a dex typed cotangent
juliaize_pullback(ȳ) = (NoTangent(), dexize(ȳ, env))
return juliaize(x), juliaize_pullback
end


ChainRulesCore.frule((_, ẋ), ::typeof(dexize), x) = dexize(x), dexize(ẋ)
function ChainRulesCore.rrule(::typeof(dexize), x)
# pullback must take a dex typed cotangent and give back a julia typed cotangent
dexize_pullback(ȳ) = (NoTangent(), juliaize(ȳ))
return dexize(x), dexize_pullback
end

32 changes: 30 additions & 2 deletions julia/src/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,37 @@ end

Base.show(io::IO, atom::Atom) = show(io, print(atom.ptr))

"""
juliaize(x)

Get the corresponding Julia object from some output of Dex.
"""
juliaize(x::CAtom) = bust_union(x)
juliaize(x::Ptr{HsAtom}) = juliaize(CAtom(x))
juliaize(x::Atom) = juliaize(x.ptr)
Base.convert(::Type{T}, atom::Atom) where {T<:Number} = convert(T, juliaize(atom))

"""
dexize(x)

Get the corresponding Dex object from some output of Julia.

NB: this is currently a hack that goes via string processing.
"""
function dexize(x::Float32, _module=PRELUDE, env=_module)
Copy link
Contributor Author

@oxinabox oxinabox Sep 25, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I suspect we do not need to take a env and a _module argument.
We are making literals, we just need one for where the literal will exist?
I don't really understand the difference between them

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The _module and env parameters are a bit of a hack that I added for Python bindings. The rough idea was that module is the module an output atom declares itself to be defined in, while env is the scope that's really used to evaluate the expression. This is used in the __call__ implementation where we temporarily extend the prelude with new names that refer to arguments, but then we want to pretend that the result is still defined in the original module that doesn't have those dummies. But now that I think about it, it's only well defined for non-dependent functions, so we should find a different workaround...

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

In any case, I think you can safely ignore env and just use the _module.

isnan(x) && return evaluate("nan", _module, env)
x === Inf32 && return evaluate("infinity", _module, env)
x === -Inf32 && return evaluate("-infinity", _module, env)

str = repr(x)
if endswith(str, "f0")
evaluate(str[1:end-2], _module, env)
else
# convert "123f45" into "123 * (intpow 10.0 45)"
evaluate(replace(str, "f"=> " * (intpow 10.0 ") * ")", _module, env)
end
end


function (self::Atom)(args...)
# TODO: Make those calls more hygenic
Expand Down Expand Up @@ -60,8 +88,8 @@ julia> m.y
"84"
```
"""
function DexModule(source::AbstractString)
ctx = dex_eval(PRELUDE, source)
function DexModule(source::AbstractString, parent_ctx=PRELUDE)
ctx = dex_eval(parent_ctx, source)
ctx == C_NULL && throw_from_dex()
m = DexModule(ctx)
finalizer(m) do _m
Expand Down
31 changes: 31 additions & 0 deletions julia/test/chainrules.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
const double_dex = evaluate(raw"\x:Float. 2.0 * x")

@testset "frule: dexize, evaluate, juliaize" begin
a, ȧ = frule((NoTangent(), 10f0), dexize, 1.5f0)
b, ḃ = frule((NoTangent(), ȧ), double_dex, a)
c, ċ = frule((NoTangent(), ḃ), juliaize, b)
@test c === 3.0f0
@test ċ === 20f0
end

@testset "rrule: dexize, evaluate, juliaize" begin
x = 1.5f0
a, a_pb = rrule(dexize, x)
b, b_pb = rrule(double_dex, a)
c, c_pb = rrule(juliaize, b)

@test c === 3.0f0
c̄ = 10f0
_, b̄ = c_pb(c̄)
_, ā = b_pb(b̄)
_, x̄ = a_pb(ā)

@test x̄ === 20f0
end

@testset "Integration Test: Zygote.jl" begin
double_via_dex = juliaize ∘ double_dex ∘ dexize
y, pb= Zygote.pullback(double_via_dex, 1.5f0)
@test y == 3f0
@test pb(1f0) == (2f0,)
end
10 changes: 10 additions & 0 deletions julia/test/evaluate.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,16 @@
@test juliaize(evaluate("IToW8 65")) === Int8(65)
end

@testset "dexize" begin
@test juliaize(dexize(0f0)) === 0f0
@test juliaize(dexize(42f0)) === 42f0
@test juliaize(dexize(123f15)) === 123f15
@test dexize(123f15) isa DexCall.Atom
@test isnan(juliaize(dexize(NaN32)))
@test (juliaize(dexize(Inf32))) == Inf32
@test (juliaize(dexize(-Inf32))) == -Inf32
end

@testset "Atom function call" begin
m = DexModule("""
def addOne (x: Float) : Float = x + 1.0
Expand Down
3 changes: 3 additions & 0 deletions julia/test/runtests.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
using Test
using DexCall
using ChainRulesCore
using Zygote # for integration tests

@testset "DexCall" begin
include("api.jl")
include("evaluate.jl")
include("native_function.jl")
include("chainrules.jl")
end