Skip to content

Commit

Permalink
Optimize Enzyme gradient (#515)
Browse files Browse the repository at this point in the history
* Optimize Enzyme gradient

* Active
  • Loading branch information
gdalle authored Sep 30, 2024
1 parent 92ccd1c commit 73f7314
Showing 1 changed file with 74 additions and 18 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -210,82 +210,138 @@ end

## Gradient

### Without preparation

# Out-of-place gradient of `f` at `x` for a reverse-mode (or unspecified-mode) Enzyme
# backend, without a preparation object.
#
# A zeroed buffer shaped like `x` is created with `make_zero` and handed to `autodiff`
# as the shadow of a `Duplicated(x, ...)` argument, with the return annotated `Active`.
# That buffer is returned as the gradient.
# NOTE(review): relies on Enzyme writing the derivative into the shadow of `Duplicated`.
function DI.gradient(
    f::F,
    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    f_and_df = get_f_and_df(f, backend)
    shadow = make_zero(x)
    mode = reverse_noprimal(backend)
    annotated_x = Duplicated(x, shadow)
    # Context arguments are converted one by one with `translate` before being splatted.
    autodiff(mode, f_and_df, Active, annotated_x, map(translate, contexts)...)
    return shadow
end

# Value and out-of-place gradient of `f` at `x` for a reverse-mode (or unspecified-mode)
# Enzyme backend, without a preparation object.
#
# Same scheme as the unprepared `DI.gradient`, but the mode comes from
# `reverse_withprimal(backend)` and the second element of the `autodiff` result is
# returned as the function value alongside the gradient buffer.
function DI.value_and_gradient(
    f::F,
    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    f_and_df = get_f_and_df(f, backend)
    shadow = make_zero(x)
    mode = reverse_withprimal(backend)
    annotated_x = Duplicated(x, shadow)
    # The first element of the result is ignored; the second is taken as the value.
    _, primal = autodiff(mode, f_and_df, Active, annotated_x, map(translate, contexts)...)
    return primal, shadow
end

### With preparation

# Preparation object for reverse-mode Enzyme gradients.
#
# `grad_righttype` is a buffer built by `make_zero(x)` in `DI.prepare_gradient`. The
# in-place gradient methods use it as the shadow whenever the user-provided `grad` is
# not of type `typeof(x)`, and then copy its contents into `grad`.
struct EnzymeGradientPrep{G} <: GradientPrep
    grad_righttype::G
end

# Build the preparation object for reverse-mode Enzyme gradients.
#
# A zeroed buffer shaped like `x` is created once with `make_zero` and stored in an
# `EnzymeGradientPrep`, so the in-place gradient methods can reuse it when the
# user-provided gradient storage is not of type `typeof(x)`.
#
# `f`, the backend and `contexts` take part in dispatch only; they are not inspected.
#
# The pre-change `return NoGradientPrep()` that preceded these lines has been removed:
# it made the buffer construction below unreachable.
function DI.prepare_gradient(
    f::F,
    ::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    grad_righttype = make_zero(x)
    return EnzymeGradientPrep(grad_righttype)
end

# Out-of-place gradient of `f` at `x` with an `EnzymeGradientPrep`.
#
# The preparation object is used for dispatch only: a fresh zeroed buffer is allocated
# with `make_zero` on every call and returned as the gradient.
#
# Two fixes relative to the interleaved pre/post-change text:
# - the stale `::NoGradientPrep` argument and the old `gradient(...)`-based body
#   (which returned before the `autodiff` call could run) are removed;
# - `Active` is passed explicitly as the return annotation, matching the unprepared
#   `DI.gradient` method and the in-place prepared methods of this file.
function DI.gradient(
    f::F,
    ::EnzymeGradientPrep,
    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    f_and_df = get_f_and_df(f, backend)
    grad = make_zero(x)
    autodiff(
        reverse_noprimal(backend),
        f_and_df,
        Active,
        Duplicated(x, grad),
        map(translate, contexts)...,
    )
    return grad
end

# In-place gradient of `f` at `x` with an `EnzymeGradientPrep`; the result is written
# into `grad`, which is also returned.
#
# If `grad` already has type `typeof(x)` it is used directly as the shadow of `x`.
# Otherwise the buffer stored in `prep` is used as the shadow and its contents are
# copied into `grad` afterwards. The shadow is reset with `make_zero!` before the
# `autodiff` call.
#
# The pre-change lines are removed: the `::NoGradientPrep` argument, the
# `convert(typeof(x), grad)` buffer, the second `Duplicated` argument built from it, and
# the `===`-based copy-back. Left interleaved, they passed two shadows to a single
# `autodiff` call and copied into `grad` twice.
function DI.gradient!(
    f::F,
    grad,
    prep::EnzymeGradientPrep,
    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    f_and_df = get_f_and_df(f, backend)
    # Pick the shadow buffer: user storage when its type matches `x`, else the prep's.
    grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
    make_zero!(grad_righttype)
    autodiff(
        reverse_noprimal(backend),
        f_and_df,
        Active,
        Duplicated(x, grad_righttype),
        map(translate, contexts)...,
    )
    # Copy back only when the prep buffer was used instead of `grad` itself.
    grad isa typeof(x) || copyto!(grad, grad_righttype)
    return grad
end

# Value and out-of-place gradient of `f` at `x` with an `EnzymeGradientPrep`.
#
# The preparation object is used for dispatch only: a fresh zeroed buffer is allocated
# with `make_zero` on every call. The second element of the `autodiff` result is
# returned as the function value, together with that buffer as the gradient.
#
# The pre-change lines are removed: the `::NoGradientPrep` argument, the unterminated
# `gradient(...)` call and the `return val, first(derivs)` line. Left interleaved, they
# made the method unparseable.
function DI.value_and_gradient(
    f::F,
    ::EnzymeGradientPrep,
    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    f_and_df = get_f_and_df(f, backend)
    grad = make_zero(x)
    _, y = autodiff(
        reverse_withprimal(backend),
        f_and_df,
        Active,
        Duplicated(x, grad),
        map(translate, contexts)...,
    )
    return y, grad
end

# Value and in-place gradient of `f` at `x` with an `EnzymeGradientPrep`; the gradient
# is written into `grad`, and `(y, grad)` is returned.
#
# If `grad` already has type `typeof(x)` it is used directly as the shadow of `x`.
# Otherwise the buffer stored in `prep` is used as the shadow and its contents are
# copied into `grad` afterwards. The shadow is reset with `make_zero!` before the
# `autodiff` call, and the second element of the `autodiff` result is taken as the value.
#
# The pre-change lines are removed: the `::NoGradientPrep` argument, the
# `convert(typeof(x), grad)` buffer, the second `Duplicated` argument built from it, and
# the `===`-based copy-back. Left interleaved, they passed two shadows to a single
# `autodiff` call and copied into `grad` twice.
function DI.value_and_gradient!(
    f::F,
    grad,
    prep::EnzymeGradientPrep,
    backend::AutoEnzyme{<:Union{ReverseMode,Nothing},<:Union{Nothing,Const}},
    x,
    contexts::Vararg{Context,C},
) where {F,C}
    f_and_df = get_f_and_df(f, backend)
    # Pick the shadow buffer: user storage when its type matches `x`, else the prep's.
    grad_righttype = grad isa typeof(x) ? grad : prep.grad_righttype
    make_zero!(grad_righttype)
    _, y = autodiff(
        reverse_withprimal(backend),
        f_and_df,
        Active,
        Duplicated(x, grad_righttype),
        map(translate, contexts)...,
    )
    # Copy back only when the prep buffer was used instead of `grad` itself.
    grad isa typeof(x) || copyto!(grad, grad_righttype)
    return y, grad
end

Expand Down

0 comments on commit 73f7314

Please sign in to comment.