Skip to content
This repository was archived by the owner on Mar 17, 2022. It is now read-only.
This repository was archived by the owner on Mar 17, 2022. It is now read-only.

Batchnorm from pytorch to keras #7

@dathath

Description

@dathath

I'm having issues transferring models with batchnorm layers from PyTorch to Keras. The other way round (Keras to PyTorch) works perfectly fine. Any thoughts? I appreciate the help!

Here are the two architectures I am testing:
Keras Model:
# LeNet-style Keras classifier: two conv/batch-norm/pool stages, then three
# dense layers (the first two followed by batch-norm) and a softmax output.
#
# NOTE(review): input_shape=(1, 28, 28) combined with axis=1 on the
# batch-norm layers suggests a channels-first layout; that depends on the
# Keras image data format setting, which is not shown here -- confirm.
# NOTE(review): Keras and PyTorch may not define batch-norm `momentum` (or
# the default epsilon) the same way, so momentum=0.1 here is not necessarily
# equivalent to momentum=0.1 on the PyTorch side -- verify against both docs.
model = Sequential()

# Layers in the exact order they are appended to the model.
lenet_layers = [
    # Convolutional stage 1.
    Conv2D(
        6,
        kernel_size=(5, 5),
        activation='relu',
        input_shape=(1, 28, 28),
        name='conv1',
    ),
    BatchNormalization(axis=1, name='bnm1', momentum=0.1),
    MaxPooling2D(pool_size=(2, 2)),
    # Convolutional stage 2.
    Conv2D(16, (5, 5), activation='relu', name='conv2'),
    BatchNormalization(axis=1, name='bnm2', momentum=0.1),
    MaxPooling2D(pool_size=(2, 2)),
    # Fully connected head.
    Flatten(),
    Dense(120, activation='relu', name='fc1'),
    BatchNormalization(axis=1, name='bnm3', momentum=0.1),
    Dense(84, activation='relu', name='fc2'),
    BatchNormalization(axis=1, name='bnm4', momentum=0.1),
    # Final layer emits raw scores; softmax is applied as a separate layer.
    Dense(10, activation=None, name='fc3'),
    Activation('softmax'),
]
for layer in lenet_layers:
    model.add(layer)

model.compile(
    loss=cross_entropy,
    optimizer='adadelta',
    metrics=['accuracy'],
)

Pytorch Model:
class LeNet(nn.Module):
    """LeNet-style CNN: two conv/pool stages followed by three linear layers.

    Four batch-norm layers (``bnm1`` .. ``bnm4``) are constructed as
    submodules, but ``forward`` does not apply them: the calls are left
    commented out, exactly as in the original snippet.
    """

    def __init__(self):
        """Create the convolutional, batch-norm and linear submodules."""
        super(LeNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 6, 5)
        self.bnm1 = nn.BatchNorm2d(6, momentum=0.1)
        self.conv2 = nn.Conv2d(6, 16, 5)
        self.bnm2 = nn.BatchNorm2d(16, momentum=0.1)
        # 256 = 16 * 4 * 4, which is what a 1x28x28 input flattens to after
        # the two 5x5 conv / 2x2 max-pool stages in ``forward``.
        self.fc1 = nn.Linear(256, 120)
        self.bnm3 = nn.BatchNorm1d(120, momentum=0.1)
        self.fc2 = nn.Linear(120, 84)
        self.bnm4 = nn.BatchNorm1d(84, momentum=0.1)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        """Return the raw (pre-softmax) output of ``fc3`` for input ``x``.

        NOTE(review): the batch-norm calls below are disabled in the original
        snippet, so this network does not use ``bnm1`` .. ``bnm4`` even though
        they are defined. Presumably this was done while debugging -- confirm
        before re-enabling.
        """
        out = F.relu(self.conv1(x))
        # out = self.bnm1(out)
        out = F.max_pool2d(out, 2)
        out = F.relu(self.conv2(out))
        # out = self.bnm2(out)
        out = F.max_pool2d(out, 2)
        # Flatten everything except the batch dimension.
        out = out.view(out.size(0), -1)
        out = F.relu(self.fc1(out))
        # out = self.bnm3(out)
        out = F.relu(self.fc2(out))
        # out = self.bnm4(out)
        out = self.fc3(out)
        return out

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions