Problematic transformations while streamlining scaled dot-product attention #878
Comments
Another problematic transformation is the However, we still need a way to deal with these scalar adds to properly streamline the attention pattern. Currently I am trying the following idea: Instead of streamlining the If I ran the streamlining unit test (like
Re-post of insights I gained while looking at the related issue #892, to have this documented here as well: For me it seems like currently all occurrences of the
As we are gradually moving towards more realistic and complete models using the Brevitas quantized multi-head attention, we are seeing even more issues:
Hm, the issue regarding the inverse of the scale after the For now, this can be solved by detecting this situation (there are just two possible configurations) and flipping the indices if the quantized initializer feeds the first input.
Most of the issues seem to boil down to violations of input order assumptions regarding "dynamic" inputs (i.e., produced upstream) vs. initializer inputs. I will try to fix this with a new cleanup transformation in QONNX and a few conditions here and there, though I am not sure whether this will be a sustainable solution.
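The kind of input-order cleanup described in the comments above can be sketched on a toy graph representation. The Node class, COMMUTATIVE set, initializer_index and sort_initializer_last below are hypothetical illustrations, not the onnx or QONNX API: for commutative binary operators (Add, Mul), a single constant initializer input is moved to the second position, which is the order downstream transformations tend to assume.

```python
from dataclasses import dataclass, field


# Hypothetical, minimal stand-in for an ONNX node (not the onnx/QONNX API).
@dataclass
class Node:
    op_type: str
    inputs: list
    outputs: list = field(default_factory=list)


# Binary operators whose two inputs may be swapped without changing the result.
COMMUTATIVE = {"Add", "Mul"}


def initializer_index(node, initializers):
    """Return the index of the initializer input if exactly one input is an
    initializer, otherwise None (both dynamic, or both constant)."""
    indices = [i for i, name in enumerate(node.inputs) if name in initializers]
    return indices[0] if len(indices) == 1 else None


def sort_initializer_last(nodes, initializers):
    """Swap the inputs of commutative binary ops so that a single initializer
    input comes last. Returns the number of nodes that were modified."""
    changed = 0
    for node in nodes:
        # Skip (rather than assert on) nodes this cleanup does not apply to.
        if node.op_type not in COMMUTATIVE or len(node.inputs) != 2:
            continue
        if initializer_index(node, initializers) == 0:
            node.inputs.reverse()
            changed += 1
    return changed
```

For a non-commutative MatMul the inputs must not be swapped; there, initializer_index can instead tell a transformation which operand is the constant, i.e., the kind of index flipping mentioned above for the case where the quantized initializer feeds the first input.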
Quick summary
This is not really a bug report; it is more a collection of gaps in support and minor annoyances I am encountering while working towards streamlining and eventually mapping a scaled dot-product attention operator. This list might grow over time and is rather meant to start a discussion and to document these problems for others.
Details
I am currently playing around with QONNX and FINN graph transformations applied to a dummy single-head scaled dot-product attention operator stripped down to its bare minimum (i.e., it is tiny and has no input/output projections, no weights, no masking, ...). Essentially, it just comprises a chain of MatMul-Softmax-MatMul operators with some Brevitas quantizers in between. I want to understand the streamlining process and eventually work towards mapping this operator pattern to some custom HLS operators (this is all WIP). Doing so, I have encountered a few, probably small, problems in some of the transforms, mostly related to FINN assuming that a MatMul operator always involves one input and one weight initializer. This is not the case for the two-input, weightless MatMuls within scaled dot-product attention: in queries x keys, for example, both operands are inputs produced by the preceding layer of input projections. In the following, I will list the problematic transformations and some ideas on how to fix them (I might add to this list over time if I find more or gain further insights, either via edit or in the comments below):
- MoveScalarMulPastMatMul always expects a weight matrix initializer as the right-hand side. As mentioned before, scaled dot-product attention involves MatMuls with two dynamic inputs, both of which may have a scalar multiplication. This transformation is currently skipped due to the is_fork_node and is_join_node queries (a two-input MatMul is a join node) and due to testing for the presence of weight initializers. Moving either or both of the scalar multiplications past the MatMul should be a valid transformation for two-input MatMul operators as well. This probably requires some refactoring of the transformation, as simply removing these checks seems to lead to detached sub-graphs.
- Absorb1BitMulIntoMatMul and Absorb1BitMulIntoConv always test for the presence of weight initializers via assertions, causing the program to terminate instead of simply ignoring two-input MatMul operators without weights. In particular, this means the whole Streamline transformation (which contains these two, among others) is not applicable when a scaled dot-product attention operator (i.e., a two-input MatMul without weights) appears anywhere in a model graph. This can probably be fixed by turning the assertions into simple tests (i.e., if ...:) that skip the application of the transformation.
- InferShapes fails after FoldTransposeIntoQuantInit: This is probably a bug, but I am not sure whether the transpose of the keys input to the first MatMul of the attention operator should actually be folded, as it is probably just part of the pattern we want to detect and map to our new custom op. However, as both transforms (in this order) are part of the ConvertQONNXtoFINN transform, this needs to be fixed. I do not really know why this happens or how to fix it, but it fails with a ShapeInferenceError somewhere within onnx.shape_inference.infer_shapes (so not even within FINN or QONNX), though maybe the cause lies higher up?
Expected behavior
Eventually, I want to be able to streamline the scaled dot-product attention operator so that it only contains MatMul, Softmax and MultiThreshold operators, using FINN's ConvertQONNXtoFINN and Streamline transformations "out of the box".
Steps to Reproduce
You can have a look at https://github.com/iksnagreb/attention-dummy for the code I use to create my dummy operator and apply some transformations. Note that the code pulls from my fork's feature branch of FINN, but the current dev branch should work as well. I have also attached the ONNX export of the generated dummy operator (without any transforms) here: attention.onnx.zip.
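For readers who do not want to pull the linked repository, the numeric pattern of such a stripped-down operator can be sketched in plain Python. This is not the author's dummy model (which is built with Brevitas and exported to ONNX); it omits all quantizers and only illustrates softmax(Q Kᵀ / sqrt(d)) V. It also checks numerically the identity that moving scalar multiplications past a two-input MatMul relies on: (a·Q)(b·Kᵀ) = (a·b)(Q Kᵀ). All function names and the example values are assumptions for illustration.

```python
import math


def matmul(a, b):
    """Multiply two matrices given as lists of rows."""
    return [[sum(x * y for x, y in zip(row, col)) for col in zip(*b)] for row in a]


def transpose(a):
    return [list(col) for col in zip(*a)]


def scale(a, s):
    """Multiply every element of a matrix by the scalar s."""
    return [[s * x for x in row] for row in a]


def softmax_rows(a):
    out = []
    for row in a:
        m = max(row)  # subtract the row maximum for numerical stability
        exps = [math.exp(x - m) for x in row]
        total = sum(exps)
        out.append([e / total for e in exps])
    return out


def attention(q, k, v):
    """Single-head scaled dot-product attention: softmax(Q K^T / sqrt(d)) V.

    No projections, weights or masking, mirroring the MatMul-Softmax-MatMul chain.
    """
    d = len(q[0])
    scores = scale(matmul(q, transpose(k)), 1.0 / math.sqrt(d))
    return matmul(softmax_rows(scores), v)


def allclose(a, b, tol=1e-9):
    return all(abs(x - y) <= tol for ra, rb in zip(a, b) for x, y in zip(ra, rb))


# Scalar factors on *both* dynamic inputs of a MatMul can be moved past it:
# (a * Q) @ (b * K^T) == (a * b) * (Q @ K^T)
q = [[1.0, 2.0], [3.0, -1.0]]
k = [[0.5, 1.5], [-2.0, 1.0]]
a, b = 0.25, 4.0
lhs = matmul(scale(q, a), scale(transpose(k), b))
rhs = scale(matmul(q, transpose(k)), a * b)
print(allclose(lhs, rhs))  # True
```

The identity holds because a scalar commutes with matrix multiplication, so it does not matter which operand carries the factor, or whether both do.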
Possible fix
I have already described my current understanding of these problems and ideas to solve them above. It would be nice to get some input/guidance on how to solve them, or at least how to work around them. I will continue to work on these problems and might have to adjust or add some transformations anyway to support a scaled dot-product attention operator. I will add any new insights here and will be happy to contribute fixes via PR.
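The ideas about MoveScalarMulPastMatMul and the Absorb1BitMul* assertions can be illustrated with a toy rewrite: scalar Muls on either or both inputs of a MatMul are moved to its output, and any node that does not match is skipped rather than asserted on. This is a hypothetical sketch on a made-up node list, not FINN's transformation or the QONNX ModelWrapper API; it assumes scalar factors stored as plain floats, initializer-last input order, and invented tensor names for the new nodes. The single-consumer check shows one way a naive rewrite could detach sub-graphs: a Mul whose output also feeds other nodes cannot simply be deleted.

```python
from dataclasses import dataclass


# Hypothetical, minimal stand-in for an ONNX node (not the onnx/QONNX API).
@dataclass
class Node:
    op_type: str
    inputs: list
    outputs: list


def consumers(nodes, tensor):
    return [n for n in nodes if tensor in n.inputs]


def producer(nodes, tensor):
    return next((n for n in nodes if tensor in n.outputs), None)


def scalar_mul_factor(node, inits):
    """If node is Mul(dynamic, scalar initializer), return (dynamic_input, factor)."""
    if node is None or node.op_type != "Mul" or len(node.inputs) != 2:
        return None
    x, s = node.inputs
    if s not in inits or x in inits:  # assumes initializer-last input order
        return None
    return x, inits[s]


def move_scalar_muls_past_matmul(nodes, inits):
    """Move scalar Muls on either or both inputs of a MatMul to its output.

    Non-matching nodes are skipped instead of triggering an assertion.
    Returns True if the graph was modified.
    """
    for mm in nodes:
        if mm.op_type != "MatMul":
            continue
        factor, new_inputs, removed = 1.0, list(mm.inputs), []
        for i, name in enumerate(mm.inputs):
            mul = producer(nodes, name)
            match = scalar_mul_factor(mul, inits)
            # Only absorb a Mul whose output feeds this MatMul alone; removing a
            # forking Mul would detach its other consumers.
            if match is None or len(consumers(nodes, name)) != 1:
                continue
            new_inputs[i], f = match
            factor *= f
            if mul not in removed:
                removed.append(mul)
        if not removed:
            continue
        out = mm.outputs[0]
        tmp, scale_name = out + "_unscaled", out + "_scale"  # hypothetical names
        inits[scale_name] = factor
        for mul in removed:
            nodes.remove(mul)
        mm.inputs, mm.outputs = new_inputs, [tmp]
        nodes.append(Node("Mul", [tmp, scale_name], [out]))
        return True
    return False
```

Similar in spirit to the (model, graph_modified) convention of QONNX/FINN transformations, the function reports whether it changed the graph, so a caller would repeat it until it returns False.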