Skip to content
This repository was archived by the owner on Nov 3, 2022. It is now read-only.

Commit 235ca35

Browse files
committed
Update kernel_initializer for ResNet50
1 parent fbf035b commit 235ca35

File tree

1 file changed

+14
-3
lines changed

1 file changed

+14
-3
lines changed

keras_applications/resnet50.py

+14-3
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,21 @@ def identity_block(input_tensor, kernel_size, filters, stage, block):
5959
bn_name_base = 'bn' + str(stage) + block + '_branch'
6060

6161
x = layers.Conv2D(filters1, (1, 1),
62+
kernel_initializer='he_normal',
6263
name=conv_name_base + '2a')(input_tensor)
6364
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
6465
x = layers.Activation('relu')(x)
6566

6667
x = layers.Conv2D(filters2, kernel_size,
67-
padding='same', name=conv_name_base + '2b')(x)
68+
padding='same',
69+
kernel_initializer='he_normal',
70+
name=conv_name_base + '2b')(x)
6871
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
6972
x = layers.Activation('relu')(x)
7073

71-
x = layers.Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
74+
x = layers.Conv2D(filters3, (1, 1),
75+
kernel_initializer='he_normal',
76+
name=conv_name_base + '2c')(x)
7277
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
7378

7479
x = layers.add([x, input_tensor])
@@ -109,19 +114,24 @@ def conv_block(input_tensor,
109114
bn_name_base = 'bn' + str(stage) + block + '_branch'
110115

111116
x = layers.Conv2D(filters1, (1, 1), strides=strides,
117+
kernel_initializer='he_normal',
112118
name=conv_name_base + '2a')(input_tensor)
113119
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2a')(x)
114120
x = layers.Activation('relu')(x)
115121

116122
x = layers.Conv2D(filters2, kernel_size, padding='same',
123+
kernel_initializer='he_normal',
117124
name=conv_name_base + '2b')(x)
118125
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2b')(x)
119126
x = layers.Activation('relu')(x)
120127

121-
x = layers.Conv2D(filters3, (1, 1), name=conv_name_base + '2c')(x)
128+
x = layers.Conv2D(filters3, (1, 1),
129+
kernel_initializer='he_normal',
130+
name=conv_name_base + '2c')(x)
122131
x = layers.BatchNormalization(axis=bn_axis, name=bn_name_base + '2c')(x)
123132

124133
shortcut = layers.Conv2D(filters3, (1, 1), strides=strides,
134+
kernel_initializer='he_normal',
125135
name=conv_name_base + '1')(input_tensor)
126136
shortcut = layers.BatchNormalization(
127137
axis=bn_axis, name=bn_name_base + '1')(shortcut)
@@ -214,6 +224,7 @@ def ResNet50(include_top=True,
214224
x = layers.Conv2D(64, (7, 7),
215225
strides=(2, 2),
216226
padding='valid',
227+
kernel_initializer='he_normal',
217228
name='conv1')(x)
218229
x = layers.BatchNormalization(axis=bn_axis, name='bn_conv1')(x)
219230
x = layers.Activation('relu')(x)

0 commit comments

Comments
 (0)