RFC: strip most types from gradient
output
#1362
Draft
+121
−4
Add this suggestion to a batch that can be applied as a single commit.
This suggestion is invalid because no changes were made to the code.
Suggestions cannot be applied while the pull request is closed.
Suggestions cannot be applied while viewing a subset of changes.
Only one suggestion per line can be applied in a batch.
Add this suggestion to a batch that can be applied as a single commit.
Applying suggestions on deleted lines is not supported.
You must change the existing code in this line in order to create a valid suggestion.
Outdated suggestions cannot be applied.
This suggestion has been applied or marked resolved.
Suggestions cannot be applied from pending reviews.
Suggestions cannot be applied on multi-line comments.
Suggestions cannot be applied while the pull request is queued to merge.
Suggestion cannot be applied right now. Please check back later.
This is a draft of a way to start addressing #1334, for comment.
It implements what I called level 2 here: #1334 (comment)
On arrays like these, no change. Natural and structural representations agree:
On arrays like these, it does not know how to construct the natural representation, so doesn't try:
(I know how, but the fields of the result will not line up with the existing ones.)
Arrays of non-diff objects cannot be wrapped up in array structs:
make_zeros
uses an IdDict cache to preserve identity between different branches of the struct.At present this does not...
A simple Flux model, no functional change, just looks different to the model:
Comments:
I'm not sure what I think about the failure to preserve
===
relations between some mutable objects in the original gradient. Some of this could be solved by adding an IdDict cache likemake_zeros
does.The function called
strip_types
for now probably needs to be public, so that you can call it yourself after constructingdx = make_zero(x)
, and so that you can overload it for your array wrappers.Projecting things like Symmetric to their covariant representation probably needs to be opt-in, by somehow telling
gradient
that you want this. (That's level 4 here: Supporting covariant derivatives #1334 (comment) .) Could be implemented as additional methods of this function, something likestrip_types(x, dx, Val(true), cache)
?Surely all the code is in the wrong place, and needs tests.