@@ -62,6 +62,84 @@ static void dumpCmd(const Command* cmd) {
6262 MNN_PRINT (" }\n " );
6363}
6464
65+ void mergeConvolutionAndPrelu (Node* root, MNNForwardType forwardType){
66+ if (root->cmd ->op != nullptr && root->cmd ->op ->type () == OpType_Convolution && root->succ .size () == 1 ) {
67+ auto child = root->succ [0 ];
68+ if (child->cmd ->op ->type () == OpType_PReLU){
69+ if (root->cmd ->op ->externalPath () != nullptr ){
70+ return ;
71+ }
72+ std::shared_ptr<Command> cmdPlugin;
73+ auto inputs = root->cmd ->inputs ;
74+ auto outputs = root->cmd ->outputs ;
75+ auto convOp = root->cmd ->op ->main_as_Convolution2D ();
76+ if (convOp->quanParameter () != nullptr || convOp->symmetricQuan () != nullptr || convOp->sparseParameter () != nullptr || convOp->external () != nullptr || convOp->common ()->outputCount () != child->cmd ->op ->main_as_PRelu ()->slopeCount ()){
77+ return ;
78+ }
79+ std::unique_ptr<OpT> fuseOp (new OpT);
80+ fuseOp->type = OpType_Extra;
81+ fuseOp->name = root->cmd ->op ->name ()->str ();
82+ ExtraT* extra_param = new ExtraT;
83+ extra_param->type = " ExtraConvolution2DPrelu" ;
84+ extra_param->attr .resize (2 );
85+ // copy convolution2D param
86+ AttributeT* convAtr = new AttributeT;
87+ BlobT* convParamBlob = new BlobT;
88+ {
89+ std::unique_ptr<Convolution2DT> convolutionParam (convOp->UnPack ());
90+ flatbuffers::FlatBufferBuilder builder;
91+ auto lastOffset = Convolution2D::Pack (builder, convolutionParam.get ());
92+ builder.Finish (lastOffset);
93+
94+ const uint8_t * buffer_ptr = builder.GetBufferPointer ();
95+ const size_t size = builder.GetSize ();
96+ convParamBlob->uint8s .resize (size);
97+ ::memcpy (convParamBlob->uint8s.data(), buffer_ptr, size);
98+ }
99+ convAtr->tensor .reset (convParamBlob);
100+ extra_param->attr [0 ].reset (convAtr);
101+
102+ // copy prelu param
103+ AttributeT* preluAtr = new AttributeT;
104+ BlobT* preluParamBlob = new BlobT;
105+ {
106+ std::unique_ptr<PReluT> preluParam (child->cmd ->op ->main_as_PRelu ()->UnPack ());
107+ flatbuffers::FlatBufferBuilder builder;
108+ auto lastOffset = PRelu::Pack (builder, preluParam.get ());
109+ builder.Finish (lastOffset);
110+ const uint8_t * buffer_ptr = builder.GetBufferPointer ();
111+ const size_t size = builder.GetSize ();
112+ preluParamBlob->uint8s .resize (size);
113+ ::memcpy (preluParamBlob->uint8s.data(), buffer_ptr, size);
114+ }
115+ preluAtr->tensor .reset (preluParamBlob);
116+ extra_param->attr [1 ].reset (preluAtr);
117+
118+ fuseOp->main .type = OpParameter_Extra;
119+ fuseOp->main .value = extra_param;
120+ flatbuffers::FlatBufferBuilder builder;
121+ auto lastOffset = Op::Pack (builder, fuseOp.get ());
122+ builder.Finish (lastOffset);
123+ cmdPlugin = GeometryComputerUtils::makeCommand (builder, inputs, outputs);
124+
125+ root->cmd ->op = cmdPlugin->op ;
126+ root->cmd ->inputs = cmdPlugin->inputs ;
127+ root->cmd ->outputs = cmdPlugin->outputs ;
128+ root->cmd ->buffer = cmdPlugin->buffer ;
129+ child->cmd ->op = nullptr ;
130+ child->cmd ->buffer .reset ();
131+ for (auto &childNode : child->succ ){
132+ for (auto &input : childNode->cmd ->inputs ){
133+ if (input == child->cmd ->outputs [0 ]){
134+ input = root->cmd ->outputs [0 ];
135+ }
136+ }
137+ }
138+ root->succ = child->succ ;
139+ }
140+ }
141+ }
142+
65143// is legal fused type
66144bool isLegal (Command* cmd, MNNForwardType forwardType) {
67145 auto type = cmd->op ->type ();
@@ -369,6 +447,20 @@ bool opFuse(std::vector<Schedule::OpCacheInfo>& infos, MNNForwardType type, Back
369447 graph.push_back (std::move (node));
370448 }
371449 }
450+
451+ if (type == MNN_FORWARD_OPENCL){
452+ for (int i = 0 ; i < graph.size (); ++i){
453+ mergeConvolutionAndPrelu (graph[i].get (), type);
454+ }
455+ for (auto iter = graph.begin (); iter != graph.end ();){
456+ if (iter->get ()->cmd ->op == nullptr ){
457+ iter = graph.erase (iter);
458+ }else {
459+ ++iter;
460+ }
461+ }
462+ }
463+
372464 std::queue<Node*> postDominateNodeQueue;
373465 // build dominate tree
374466 for (int i = static_cast <int >(graph.size ()) - 1 ; i >= 0 ; i--) {
0 commit comments