ASKCResNetFPN #26

Open
YimianDai opened this issue Jun 1, 2021 · 0 comments

from __future__ import division
import os
from mxnet.gluon.block import HybridBlock
from mxnet.gluon import nn
from mxnet.gluon.nn import BatchNorm
from gluoncv.model_zoo.fcn import _FCNHead
from mxnet import nd

from .askc import LCNASKCFuse

from model.atac.backbone import ATACBlockV1, conv1ATAC, DynamicCell
from model.atac.convolution import LearnedCell, ChaDyReFCell, SeqDyReFCell, SK_ChaDyReFCell, \
    SK_1x1DepthDyReFCell, SK_MSSpaDyReFCell, SK_SpaDyReFCell, Direct_AddCell, SKCell, \
    SK_SeqDyReFCell, Sub_MSSpaDyReFCell, SK_MSSeqDyReFCell, iAAMSSpaDyReFCell
from model.atac.convolution import \
    LearnedConv, ChaDyReFConv, SeqDyReFConv, SK_ChaDyReFConv, \
    SK_1x1DepthDyReFConv, SK_MSSpaDyReFConv, SK_SpaDyReFConv, Direct_AddConv, SKConv, \
    SK_SeqDyReFConv
    # , SK_MSSeqDyReFConv
from .activation import xUnit, SpaATAC, ChaATAC, SeqATAC, MSSeqATAC, MSSeqATACAdd, \
    MSSeqATACConcat, MSSeqAttentionMap, xUnitAttentionMap
from model.atac.fusion import Direct_AddFuse_Reduce, SK_MSSpaFuse, SKFuse_Reduce, LocalChaFuse, \
    GlobalChaFuse, \
    LocalGlobalChaFuse_Reduce, LocalLocalChaFuse_Reduce, GlobalGlobalChaFuse_Reduce, \
    AYforXplusYChaFuse_Reduce, XplusAYforYChaFuse_Reduce, IASKCChaFuse_Reduce,\
    GAUChaFuse_Reduce, SpaFuse_Reduce, ConcatFuse_Reduce, AXYforXplusYChaFuse_Reduce,\
    BiLocalChaFuse_Reduce, BiGlobalChaFuse_Reduce, LocalGAUChaFuse_Reduce, GlobalSpaFuse,\
    AsymBiLocalChaFuse_Reduce, BiSpaChaFuse_Reduce, AsymBiSpaChaFuse_Reduce, LocalSpaFuse, \
    BiGlobalLocalChaFuse_Reduce

# from gluoncv.model_zoo.resnetv1b import BasicBlockV1b
from gluoncv.model_zoo.cifarresnet import CIFARBasicBlockV1


class ASKCResNetFPN(HybridBlock):
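    """FPN-style segmentation network built on CIFAR-ResNet stages with
    selectable cross-layer feature fusion (ASKC-style).  ``layers`` and
    ``channels`` configure the residual stages, ``fuse_mode`` picks one of the
    *_Reduce fusion blocks in ``_fuse_layer``, and ``tinyFlag`` switches to a
    stride-1 stem (and matching up-sampling targets in ``hybrid_forward``)
    for small inputs.
    """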
    def __init__(self, layers, channels, fuse_mode, act_dilation, classes=1, tinyFlag=False,
                 norm_layer=BatchNorm, norm_kwargs=None, **kwargs):
        super(ASKCResNetFPN, self).__init__(**kwargs)

        self.layer_num = len(layers)
        self.tinyFlag = tinyFlag
        with self.name_scope():

            stem_width = int(channels[0])
            self.stem = nn.HybridSequential(prefix='stem')
            self.stem.add(norm_layer(scale=False, center=False,
                                     **({} if norm_kwargs is None else norm_kwargs)))
            if tinyFlag:
                self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width*2))
                self.stem.add(nn.Activation('relu'))
            else:
                self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=2,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width))
                self.stem.add(nn.Activation('relu'))
                self.stem.add(nn.Conv2D(channels=stem_width, kernel_size=3, strides=1,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width))
                self.stem.add(nn.Activation('relu'))
                self.stem.add(nn.Conv2D(channels=stem_width*2, kernel_size=3, strides=1,
                                         padding=1, use_bias=False))
                self.stem.add(norm_layer(in_channels=stem_width*2))
                self.stem.add(nn.Activation('relu'))
                self.stem.add(nn.MaxPool2D(pool_size=3, strides=2, padding=1))

            # self.head1 = _FCNHead(in_channels=channels[1], channels=classes)
            # self.head2 = _FCNHead(in_channels=channels[2], channels=classes)
            # self.head3 = _FCNHead(in_channels=channels[3], channels=classes)
            # self.head4 = _FCNHead(in_channels=channels[4], channels=classes)

            self.head = _FCNHead(in_channels=channels[1], channels=classes)

            self.layer1 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[0],
                                           channels=channels[1], stride=1, stage_index=1,
                                           in_channels=channels[1])

            self.layer2 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[1],
                                           channels=channels[2], stride=2, stage_index=2,
                                           in_channels=channels[1])

            self.layer3 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[2],
                                           channels=channels[3], stride=2, stage_index=3,
                                           in_channels=channels[2])

            if self.layer_num == 4:
                self.layer4 = self._make_layer(block=CIFARBasicBlockV1, layers=layers[3],
                                               channels=channels[4], stride=2, stage_index=4,
                                               in_channels=channels[3])

            if self.layer_num == 4:
                self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[3],
                                               act_dilation=act_dilation)  # channels[4]

            self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[2],
                                           act_dilation=act_dilation)  # 64
            self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[1],
                                           act_dilation=act_dilation)  # 32

            # if fuse_order == 'reverse':
            #     self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[2])  # channels[2]
            #     self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[3])  # channels[3]
            #     self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]
            # elif fuse_order == 'normal':
            #     self.fuse34 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]
            #     self.fuse23 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]
            #     self.fuse12 = self._fuse_layer(fuse_mode, channels=channels[4])  # channels[4]

    def _make_layer(self, block, layers, channels, stride, stage_index, in_channels=0,
                    norm_layer=BatchNorm, norm_kwargs=None):
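        """Stack ``layers`` residual blocks; the first block absorbs any
        stride/channel change and adds a downsample branch when needed."""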
        layer = nn.HybridSequential(prefix='stage%d_'%stage_index)
        with layer.name_scope():
            downsample = (channels != in_channels) or (stride != 1)
            layer.add(block(channels, stride, downsample, in_channels=in_channels,
                            prefix='', norm_layer=norm_layer, norm_kwargs=norm_kwargs))
            for _ in range(layers-1):
                layer.add(block(channels, 1, False, in_channels=channels, prefix='',
                                norm_layer=norm_layer, norm_kwargs=norm_kwargs))
        return layer

    def _fuse_layer(self, fuse_mode, channels, act_dilation):
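        """Map the ``fuse_mode`` string to the corresponding *_Reduce fusion
        block (direct add, concat, SK, channel-/spatial-attention variants)."""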
        if fuse_mode == 'Direct_Add':
            fuse_layer = Direct_AddFuse_Reduce(channels=channels)
        elif fuse_mode == 'Concat':
            fuse_layer = ConcatFuse_Reduce(channels=channels)
        elif fuse_mode == 'SK':
            fuse_layer = SKFuse_Reduce(channels=channels)
        # elif fuse_mode == 'LocalCha':
        #     fuse_layer = LocalChaFuse(channels=channels)
        # elif fuse_mode == 'GlobalCha':
        #     fuse_layer = GlobalChaFuse(channels=channels)
        elif fuse_mode == 'LocalGlobalCha':
            fuse_layer = LocalGlobalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'LocalLocalCha':
            fuse_layer = LocalLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'GlobalGlobalCha':
            fuse_layer = GlobalGlobalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'IASKCChaFuse':
            fuse_layer = IASKCChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AYforXplusY':
            fuse_layer = AYforXplusYChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AXYforXplusY':
            fuse_layer = AXYforXplusYChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'XplusAYforY':
            fuse_layer = XplusAYforYChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'GAU':
            fuse_layer = GAUChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'LocalGAU':
            fuse_layer = LocalGAUChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'SpaFuse':
            fuse_layer = SpaFuse_Reduce(channels=channels, act_dilation=act_dilation)
        elif fuse_mode == 'BiLocalCha':
            fuse_layer = BiLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'BiGlobalLocalCha':
            fuse_layer = BiGlobalLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AsymBiLocalCha':
            fuse_layer = AsymBiLocalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'BiGlobalCha':
            fuse_layer = BiGlobalChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'BiSpaCha':
            fuse_layer = BiSpaChaFuse_Reduce(channels=channels)
        elif fuse_mode == 'AsymBiSpaCha':
            fuse_layer = AsymBiSpaChaFuse_Reduce(channels=channels)
        # elif fuse_mode == 'LocalSpa':
        #     fuse_layer = LocalSpaFuse(channels=channels, act_dilation=act_dilation)
        # elif fuse_mode == 'GlobalSpa':
        #     fuse_layer = GlobalSpaFuse(channels=channels, act_dilation=act_dilation)
        # elif fuse_mode == 'SK_MSSpa':
        #     # fuse_layer.add(SK_MSSpaFuse(channels=channels, act_dilation=act_dilation))
        #     fuse_layer = SK_MSSpaFuse(channels=channels, act_dilation=act_dilation)
        else:
            raise ValueError('Unknown fuse_mode')

        return fuse_layer

    def hybrid_forward(self, F, x):

        # NOTE: reading x.shape keeps this block imperative; it will fail after
        # hybridize(), since symbolic inputs carry no static shape.
        _, _, hei, wid = x.shape

        x = self.stem(x)      # down 4, 32
        c1 = self.layer1(x)   # down 4, 32
        c2 = self.layer2(c1)  # down 8, 64
        out = self.layer3(c2)  # down 16, 128
        if self.layer_num == 4:
            c4 = self.layer4(out)  # down 32
            if self.tinyFlag:
                c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4)  # down 4
            else:
                c4 = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16)  # down 16
            out = self.fuse34(c4, out)
        if self.tinyFlag:
            out = F.contrib.BilinearResize2D(out, height=hei//2, width=wid//2)  # down 2, 128
        else:
            out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8)  # down 8, 128
        out = self.fuse23(out, c2)
        if self.tinyFlag:
            out = F.contrib.BilinearResize2D(out, height=hei, width=wid)  # down 1
        else:
            out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4)  # down 4
        out = self.fuse12(out, c1)

        pred = self.head(out)
        if self.tinyFlag:
            out = pred
        else:
            out = F.contrib.BilinearResize2D(pred, height=hei, width=wid)  # back to input resolution

        ######### reverse order ##########
        # up_c2 = F.contrib.BilinearResize2D(c2, height=hei//4, width=wid//4)  # down 4
        # fuse2 = self.fuse12(up_c2, c1)  # down 4, channels[2]
        #
        # up_c3 = F.contrib.BilinearResize2D(c3, height=hei//4, width=wid//4)  # down 4
        # fuse3 = self.fuse23(up_c3, fuse2)  # down 4, channels[3]
        #
        # up_c4 = F.contrib.BilinearResize2D(c4, height=hei//4, width=wid//4)  # down 4
        # fuse4 = self.fuse34(up_c4, fuse3)  # down 4, channels[4]
        #

        ######### normal order ##########
        # out = F.contrib.BilinearResize2D(c4, height=hei//16, width=wid//16)
        # out = self.fuse34(out, c3)
        # out = F.contrib.BilinearResize2D(out, height=hei//8, width=wid//8)
        # out = self.fuse23(out, c2)
        # out = F.contrib.BilinearResize2D(out, height=hei//4, width=wid//4)
        # out = self.fuse12(out, c1)
        # out = self.head(out)
        # out = F.contrib.BilinearResize2D(out, height=hei, width=wid)


        return out

    def evaluate(self, x):
        """evaluating network with inputs and targets"""
        return self.forward(x)
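

# Usage sketch (illustrative, not part of the original post): the layer/channel
# configuration, fuse_mode, act_dilation and the 480x480 input below are
# assumptions chosen only to show how ASKCResNetFPN is wired together.
if __name__ == '__main__':
    net = ASKCResNetFPN(layers=[4, 4, 4], channels=[8, 16, 32, 64],
                        fuse_mode='AsymBiLocalCha', act_dilation=16,
                        classes=1, tinyFlag=False)
    net.initialize()

    x = nd.random.uniform(shape=(1, 3, 480, 480))  # dummy NCHW batch
    out = net(x)
    print(out.shape)  # expected (1, 1, 480, 480): per-pixel logits at input size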