-
Notifications
You must be signed in to change notification settings - Fork 34
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
* Bump patch version and add Tapir ext * Make Tapir available at test time * Add Tapir runs to AD testing * Add single rule to Tapir to handle bisection * using Tapir * Run on 1.6 only and add Tapir to AD tests * Disable more tests * Update ext/BijectorsTapirExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Fix formatting * Restrict version * Remove Tapir from Project * Do not run Tapir CI on 1.6 * Enable 1.6 tests in general * Enable 1.6 on interface tests * Tweak versioning * Cancel when multiple things are pushed * Add Tapir to extras * Comment out tapir usage * Try allowing more versions of Tapir * Allow more versions of Tapir * More tweaks * Add Pkg to test deps * Refine CI * Use Tapir on 1.10 * Remove CI modifications * Formatting * add comment to Tapir installation * Support a range of types * Fix Project.toml * Fix formatting * Fix formatting * Fix formatting * Apply suggestions from code review Co-authored-by: David Widmann <[email protected]> * Sort out formatting --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> Co-authored-by: David Widmann <[email protected]>
- Loading branch information
1 parent
c3474b2
commit 2849aca
Showing
7 changed files
with
91 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,38 @@ | ||
module BijectorsTapirExt | ||
|
||
if isdefined(Base, :get_extension) | ||
using Tapir: @is_primitive, MinimalCtx, Tapir, CoDual, primal, tangent_type, @from_rrule | ||
using Bijectors: find_alpha, ChainRulesCore | ||
else | ||
using ..Tapir: @is_primitive, MinimalCtx, Tapir, primal, tangent_type, @from_rrule | ||
using ..Bijectors: find_alpha, ChainRulesCore | ||
end | ||
|
||
for P in [Float16, Float32, Float64] | ||
@from_rrule(MinimalCtx, Tuple{typeof(find_alpha),P,P,P}) | ||
end | ||
|
||
# The final argument could be an Integer of some kind. This should be fine provided that | ||
# it has tangent type equal to `NoTangent`, which means that it's non-differentiable and | ||
# can be safely dropped. We verify that the concrete type of the Integer satisfies this | ||
# constraint, and error if (for some reason) it does not. This should be fine unless a very | ||
# unusual Integer type is encountered. | ||
@is_primitive(MinimalCtx, Tuple{typeof(find_alpha),P,P,Integer} where {P<:Base.IEEEFloat}) | ||
|
||
function Tapir.rrule!!( | ||
::CoDual{typeof(find_alpha)}, x::CoDual{P}, y::CoDual{P}, z::CoDual{I} | ||
) where {P<:Base.IEEEFloat,I<:Integer} | ||
# Require that the integer is non-differentiable. | ||
if tangent_type(I) != Tapir.NoTangent | ||
msg = "Integer argument has tangent type $(tangent_type(I)), should be NoTangent." | ||
throw(ArgumentError(msg)) | ||
end | ||
out, pb = ChainRulesCore.rrule(find_alpha, primal(x), primal(y), primal(z)) | ||
function find_alpha_pb(dout::P) | ||
_, dx, dy, _ = pb(dout) | ||
return Tapir.NoRData(), P(dx), P(dy), Tapir.NoRData() | ||
end | ||
return Tapir.zero_fcodual(out), find_alpha_pb | ||
end | ||
|
||
end |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters