
feat: access overloaded inputs from preparation result #672

Merged · 14 commits · Jan 14, 2025

Conversation

@ErikQQY (Contributor) commented Jan 6, 2025

Fix: #668

This PR only adds the `overloaded_inputs` functionality for Jacobian preparation; I am not sure whether the derivative, gradient, and Hessian preparations need this API as well.

The tests for this feature may be too simple; suggestions for better test cases are welcome.
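As a rough illustration of the intended usage (the `overloaded_inputs` accessor name is taken from the PR title; the exact merged signature may differ, and this sketch is not DI internals):

```julia
using DifferentiationInterface
import ForwardDiff

f(x) = sum(abs2, x)
x = [1.0, 2.0, 3.0]
backend = AutoForwardDiff()

# Preparation builds the backend-specific caches for repeated Jacobians.
prep = prepare_jacobian(f, backend, x)

# Hypothetical accessor added by this PR: recover the inputs that f will
# actually be called with during differentiation (e.g. an array of Duals
# for ForwardDiff), so downstream code can pre-allocate compatible caches.
x_over = overloaded_inputs(prep)
```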

@ErikQQY ErikQQY requested a review from gdalle as a code owner January 6, 2025 15:24

codecov bot commented Jan 6, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 97.61%. Comparing base (240e7e8) to head (805b000).
Report is 1 commit behind head on main.

Additional details and impacted files
@@           Coverage Diff           @@
##             main     #672   +/-   ##
=======================================
  Coverage   97.61%   97.61%           
=======================================
  Files         112      114    +2     
  Lines        5610     5623   +13     
=======================================
+ Hits         5476     5489   +13     
  Misses        134      134           
Flag Coverage Δ
DI 98.76% <100.00%> (+<0.01%) ⬆️
DIT 95.35% <ø> (ø)

Flags with carried forward coverage won't be shown.


@ErikQQY (Contributor Author) commented Jan 6, 2025

The CI failures seem unrelated.

@gdalle (Member) commented Jan 6, 2025

Thanks for getting this started! In terms of design, I think we should not just return an eltype; instead we should return the full dualized array, because some backends (like Mooncake) don't take an elementwise dual.
Could you also make sure that code coverage does not decrease?
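For intuition, here is a minimal ForwardDiff-flavored sketch of the distinction being discussed (variable names are illustrative, not DI internals): exposing only an element type describes the overloaded inputs elementwise, while exposing the dualized array itself also covers backends whose tracked inputs are not elementwise duals.

```julia
using ForwardDiff: Dual

x = [1.0, 2.0, 3.0]

# Elementwise view: an overloaded *element type*, as ForwardDiff uses it
# (tag omitted for brevity).
T = Dual{Nothing,Float64,length(x)}

# Array view: the full dualized array that f would actually be called on.
# Returning this, rather than just T, generalizes to backends (like
# Mooncake) whose tracked inputs are not arrays of elementwise duals.
x_dual = similar(x, T)
```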

@ErikQQY (Contributor Author) commented Jan 7, 2025

All tests passed now.

@ErikQQY (Contributor Author) commented Jan 7, 2025

I also implemented `overloaded_inputs` for the derivative and gradient preparations.

@ErikQQY (Contributor Author) commented Jan 9, 2025

I believe this PR is ready now.

@gdalle (Member) commented Jan 11, 2025

Thanks, I'll take a look!

@gdalle (Member) commented Jan 13, 2025

@ErikQQY I have given this some more thought and I'm not sure the abstraction we picked is the correct one.

Symbolic backends also rely on operator overloading, but they will only call the function on "overloaded inputs" once, during preparation. This serves to compile a fast version of the derivative, which is then applied to Float64 inputs. So it is very different from ForwardDiff and ReverseDiff, where the "overloaded inputs" are used every time.

In general, we have no guarantee that the function f is called, either during preparation, or during prepared or unprepared differentiation.

@gdalle changed the title from "Add overloaded_inputs for preparations" to "feat:" on Jan 13, 2025
@gdalle (Member) commented Jan 13, 2025

I tried to express this in the docstring, but it may not be precise enough.

@gdalle changed the title from "feat:" to "feat: access overloaded inputs from preparation result" on Jan 13, 2025
@ErikQQY (Contributor Author) commented Jan 13, 2025

Oh, we are returning the eltype of the dualized array now 👍.

> Symbolic backends also rely on operator overloading, but they will only call the function on "overloaded inputs" once, during preparation.

So do we also need the `overload_input_type` for symbolic backend preparation?

@gdalle (Member) commented Jan 13, 2025

Not exactly, those actually are two different concepts.

  • For ForwardDiff and ReverseDiff, the overloaded type is used at every call to the differentiation routine, which goes through the function itself.
  • For Symbolics, the overloaded type is only used once during preparation. After that, the initial function is discarded and runtime-generated functions which directly encode the derivatives are used instead, so they can be applied to Float64 directly.
  • For source transformation backends like Zygote and Enzyme, the function may or may not be called, but there are no overloaded types to speak of.

I guess I'm wondering how to get this point across through a better choice of terminology. And I'm also wondering whether we can find yet another solution for this problem, because introducing something which behaves so differently across backends sounds like a footgun.
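The first two cases above can be sketched as follows, assuming the public ForwardDiff and Symbolics APIs (this is a conceptual illustration of when the function sees overloaded inputs, not DI code):

```julia
using ForwardDiff, Symbolics

f(x) = sum(abs2, x)
x = [1.0, 2.0]

# ForwardDiff: f itself is executed on Dual numbers at *every* gradient
# call, so the overloaded input type matters for each prepared evaluation.
g1 = ForwardDiff.gradient(f, x)

# Symbolics: f is traced on symbolic inputs *once*; afterwards a
# runtime-generated function encoding the derivative is applied directly
# to Float64 inputs, and f is never called again.
@variables y[1:2]
expr = f(collect(y))
dexpr = Symbolics.gradient(expr, collect(y))
g_fun = Symbolics.build_function(dexpr, collect(y); expression = Val(false))[1]
g2 = g_fun(x)  # plain Float64 input, no overloaded types involved
```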

@gdalle (Member) commented Jan 14, 2025

Let's merge this, see how it goes downstream, and then solidify the interface if necessary. It's explicitly marked as not public for the time being, so there's no harm.

@gdalle gdalle merged commit 4feb596 into JuliaDiff:main Jan 14, 2025
48 checks passed
Development

Successfully merging this pull request may close these issues.

Add API for overloaded inputs
2 participants