@penelopeysm discovered in a comment on TuringLang/docs#559 that Mooncake's current rule for matrix-matrix multiplication in LuxLib fails when the two input arrays contain numbers at different precisions (e.g. a `Float32` array multiplied by a `Float64` array).
It really shouldn't be too hard to handle this properly. We would just need to rewrite the current `from_rrule` implementations of the various variants of matrix-matrix multiplication found around here as proper Mooncake rules, which:

1. define `is_primitive` for arrays whose elements are subtypes of `IEEEFloat`,
2. define the `rrule!!`s in such a way that the correct element type is adhered to, and
3. add a method of `rrule!!` that is a catch-all for all other element types and always errors with a sensible message telling users how to modify their code.
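The three points above can be sketched in plain Julia. This is a minimal illustration, not Mooncake code: `matmul_rrule` and its inner `matmul_pullback` are hypothetical names, and the sketch deliberately skips Mooncake's real `rrule!!` machinery. It only shows the dispatch structure (an `IEEEFloat` method plus an erroring catch-all) and how each gradient can be allocated with its own input's element type.

```julia
using LinearAlgebra

# Hypothetical stand-in for the rule logic; this is NOT Mooncake's rrule!! API.
# Supported case: both element types are IEEE floats
# (Base.IEEEFloat == Union{Float16, Float32, Float64}).
function matmul_rrule(A::AbstractMatrix{<:Base.IEEEFloat},
                      B::AbstractMatrix{<:Base.IEEEFloat})
    C = A * B  # eltype(C) is the promotion of eltype(A) and eltype(B)
    function matmul_pullback(dC::AbstractMatrix)
        # Allocate each gradient with the element type of its own input, so a
        # Float32 input receives a Float32 gradient even when C is Float64.
        dA = similar(A)
        dB = similar(B)
        mul!(dA, dC, B')  # dA = dC * Bᵀ, converted to eltype(A) on store
        mul!(dB, A', dC)  # dB = Aᵀ * dC, converted to eltype(B) on store
        return dA, dB
    end
    return C, matmul_pullback
end

# Catch-all for every other element type: fail loudly with an actionable message.
function matmul_rrule(A::AbstractMatrix, B::AbstractMatrix)
    throw(ArgumentError(
        "matmul_rrule only supports IEEEFloat element types, but got " *
        "$(eltype(A)) and $(eltype(B)); convert the inputs explicitly, " *
        "e.g. Float64.(A)."))
end
```

Because the first method is strictly more specific than the second, a `Matrix{Float32}` times `Matrix{Float64}` call takes the supported path, while something like a `Matrix{Int}` input lands on the catch-all and gets the error instead of a silently wrong gradient type.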
Note: this is also a great opportunity to ensure excellent performance in Lux.jl. The current implementations of the rules allocate more than is really needed, because they do not increment the gradients in place.
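A minimal sketch of what incrementing in place could look like, assuming the caller already owns gradient buffers `dA` and `dB` that should be accumulated into. `matmul_pullback!` is a hypothetical name, not part of Mooncake or LuxLib; the only real API used is the five-argument `LinearAlgebra.mul!`.

```julia
using LinearAlgebra

# Hypothetical in-place pullback for C = A * B (not Mooncake's API): it increments
# the existing gradient buffers dA and dB instead of allocating new arrays.
function matmul_pullback!(dA::AbstractMatrix, dB::AbstractMatrix, dC::AbstractMatrix,
                          A::AbstractMatrix, B::AbstractMatrix)
    # Five-argument mul!(Y, X1, X2, α, β) overwrites Y with X1 * X2 * α + Y * β.
    mul!(dA, dC, B', true, true)  # dA += dC * Bᵀ, stored in eltype(dA)
    mul!(dB, A', dC, true, true)  # dB += Aᵀ * dC, stored in eltype(dB)
    return nothing
end
```

When all arrays share a BLAS-supported element type, these calls typically dispatch to BLAS `gemm!` with `β = 1`, so no temporary `dC * B'` matrix is created. Mixed-precision inputs fall back to Julia's generic matrix multiply, which still writes into the existing buffer but is slower.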
Could this be a duplicate of #196? In that PR I explicitly suggested supporting mixed precision, but I agree it is better not to perform any promotion implicitly.