@@ -80,6 +80,51 @@ pnnx.Output output 1 0 out
REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS(F_scaled_dot_product_attention_1, 10)
83+ class F_scaled_dot_product_attention_2 : public GraphRewriterPass
84+ {
85+ public:
86+ const char * match_pattern_graph () const
87+ {
88+ return R"PNNXIR( 7767517
89+ 10 9
90+ pnnx.Input input_0 0 1 query
91+ pnnx.Input input_1 0 1 key
92+ pnnx.Input input_2 0 1 value
93+ pnnx.Input input_3 0 1 attn_mask
94+ prim::Constant op_0 0 1 dropout_p value=%dropout_p
95+ prim::Constant op_1 0 1 is_causal value=%is_causal
96+ prim::Constant op_2 0 1 scale value=%scale
97+ prim::Constant op_3 0 1 enable_gqa value=%enable_gqa
98+ aten::scaled_dot_product_attention op_4 8 1 query key value attn_mask dropout_p is_causal scale enable_gqa out
99+ pnnx.Output output 1 0 out
100+ )PNNXIR" ;
101+ }
102+
103+ const char * type_str () const
104+ {
105+ return " F.scaled_dot_product_attention" ;
106+ }
107+
108+ void write (Operator* op, const std::map<std::string, Parameter>& captured_params, const std::map<std::string, Attribute>& captured_attrs) const
109+ {
110+ GraphRewriterPass::write (op, captured_params, captured_attrs);
111+
112+ if (captured_params.at (" scale" ).type == 0 )
113+ {
114+ // drop scale=None for compatibility with old torch
115+ op->params .erase (" scale" );
116+ }
117+
118+ if (captured_params.at (" enable_gqa" ).type == 1 && captured_params.at (" enable_gqa" ).b == false )
119+ {
120+ // drop enable_gqa=False for compatibility with old torch
121+ op->params .erase (" enable_gqa" );
122+ }
123+ }
124+ };
125+
126+ REGISTER_GLOBAL_PNNX_GRAPH_REWRITER_PASS (F_scaled_dot_product_attention_2, 10 )
127+
83128static bool NearlyEqual (float a, float b, float epsilon)
84129{
85130 if (a == b)
0 commit comments