Skip to content

Commit 7f7a50d

Browse files
authored
Merge branch 'master' into sfloat
2 parents dbff487 + ee12ab8 commit 7f7a50d

File tree

8 files changed

+216
-5
lines changed

8 files changed

+216
-5
lines changed

examples/cifar/models/vgg.rb

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# A single VGG building block: convolution -> batch normalization -> ReLU.
class Block < Chainer::Chain
  # @param [Integer] out_channels number of output feature maps
  # @param [Integer] ksize convolution kernel size
  # @param [Integer] pad zero-padding width added to each side (default: 1)
  def initialize(out_channels, ksize, pad: 1)
    super()
    init_scope do
      # nil in_channels lets the convolution infer its input size on first call;
      # nobias because the following batch normalization has its own shift term.
      @conv = Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true)
      @bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels)
    end
  end

  # Runs +x+ through conv, batch norm, and ReLU; returns the activation.
  def call(x)
    Chainer::Functions::Activation::Relu.relu(@bn.(@conv.(x)))
  end
end
16+
17+
# VGG-style convolutional network for CIFAR classification:
# five conv stages (64/128/256/512/512 channels) followed by a
# fully-connected classifier head.
class VGG < Chainer::Chain
  # @param [Integer] class_labels number of output classes (default: 10)
  def initialize(class_labels: 10)
    super()
    init_scope do
      @block1_1 = Block.new(64, 3)
      @block1_2 = Block.new(64, 3)
      @block2_1 = Block.new(128, 3)
      @block2_2 = Block.new(128, 3)
      @block3_1 = Block.new(256, 3)
      @block3_2 = Block.new(256, 3)
      @block3_3 = Block.new(256, 3)
      @block4_1 = Block.new(512, 3)
      @block4_2 = Block.new(512, 3)
      @block4_3 = Block.new(512, 3)
      @block5_1 = Block.new(512, 3)
      @block5_2 = Block.new(512, 3)
      @block5_3 = Block.new(512, 3)
      @fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
      @bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
      @fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
    end
  end

  # Forward pass: returns the (pre-softmax) class scores for +x+.
  def call(x)
    # Local shorthands for the dropout and 2x2/stride-2 max-pooling calls
    # repeated throughout the network.
    drop = ->(v, ratio) { Chainer::Functions::Noise::Dropout.dropout(v, ratio: ratio) }
    pool = ->(v) { Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(v, 2, stride: 2) }

    # 64-channel stage
    h = @block1_1.(x)
    h = drop.(h, 0.3)
    h = @block1_2.(h)
    h = pool.(h)

    # 128-channel stage
    h = @block2_1.(h)
    h = drop.(h, 0.4)
    h = @block2_2.(h)
    h = pool.(h)

    # 256-channel stage
    h = @block3_1.(h)
    h = drop.(h, 0.4)
    h = @block3_2.(h)
    h = drop.(h, 0.4)
    h = @block3_3.(h)
    h = pool.(h)

    # first 512-channel stage
    h = @block4_1.(h)
    h = drop.(h, 0.4)
    h = @block4_2.(h)
    h = drop.(h, 0.4)
    h = @block4_3.(h)
    h = pool.(h)

    # second 512-channel stage
    h = @block5_1.(h)
    h = drop.(h, 0.4)
    h = @block5_2.(h)
    h = drop.(h, 0.4)
    h = @block5_3.(h)
    h = pool.(h)

    # classifier head
    h = drop.(h, 0.5)
    h = Chainer::Functions::Activation::Relu.relu(@bn_fc1.(@fc1.(h)))
    @fc2.(drop.(h, 0.5))
  end
end

examples/cifar/train_cifar.rb

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
# Trains a VGG network on CIFAR-10 or CIFAR-100 using plain SGD with momentum.
require 'chainer'
require_relative 'models/vgg'
require 'optparse'

# Default hyper-parameters; each may be overridden from the command line.
args = {
  dataset: 'cifar10',
  frequency: -1,
  batchsize: 64,
  learnrate: 0.05,
  epoch: 300,
  out: 'result',
  resume: nil
}

opt = OptionParser.new
opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v }
opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
opt.parse!(ARGV)

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
case args[:dataset]
when 'cifar10'
  puts 'Using CIFAR10 dataset.'
  class_labels = 10
  train, test = Chainer::Datasets::CIFAR.get_cifar10
when 'cifar100'
  puts 'Using CIFAR100 dataset.'
  class_labels = 100
  train, test = Chainer::Datasets::CIFAR.get_cifar100
else
  raise "Invalid dataset choice: #{args[:dataset]} (expected cifar10 or cifar100)"
end

puts "setup..."

model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))

optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
optimizer.setup(model)

train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)

# device: -1 selects the CPU backend.
updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])

# Evaluate the model on the held-out test set after each epoch.
trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))

# Halve the learning rate every 25 epochs.
trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])

# -1 means "snapshot once, at the end"; otherwise snapshot every N epochs (N >= 1).
frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])

trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)

# Restore trainer state (model, optimizer, iterators) from a snapshot file.
if args[:resume]
  Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
end

trainer.run
70+

lib/chainer.rb

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@
6262
require 'chainer/optimizers/momentum_sgd'
6363
require 'chainer/dataset/download'
6464
require 'chainer/datasets/mnist'
65+
require 'chainer/datasets/cifar'
6566
require 'chainer/datasets/tuple_dataset'
6667
require 'chainer/reporter'
6768
require 'chainer/serializer'

lib/chainer/datasets/cifar.rb

Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
require 'datasets'
2+
3+
module Chainer
  module Datasets
    # Loaders for the CIFAR-10 / CIFAR-100 image classification datasets,
    # backed by the red-datasets gem.
    module CIFAR
      # Returns [train, test] for CIFAR-10.
      #
      # @param [boolean] with_label if true, pair images with Int32 labels in a TupleDataset
      # @param [Integer] ndim 1 => flat (N, 3072) arrays, 3 => (N, 3, 32, 32) arrays
      # @param [Float] scale pixel values are scaled to [0, scale]
      def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
        get_cifar(10, with_label, ndim, scale)
      end

      # Returns [train, test] for CIFAR-100 (uses fine-grained labels).
      # Same keyword arguments as +get_cifar10+.
      def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
        get_cifar(100, with_label, ndim, scale)
      end

      # Downloads (if needed) and preprocesses both splits of the requested dataset.
      def self.get_cifar(n_classes, with_label, ndim, scale)
        train_data, train_labels = fetch_split(n_classes, :train)
        test_data, test_labels = fetch_split(n_classes, :test)

        [
          preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
          preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
        ]
      end

      # Collects raw pixel arrays and labels for one split (:train or :test).
      # CIFAR-100 records expose the class as +fine_label+ rather than +label+.
      def self.fetch_split(n_classes, type)
        data = []
        labels = []
        ::Datasets::CIFAR.new(n_classes: n_classes, type: type).each do |record|
          data << record.pixels
          labels << (n_classes == 10 ? record.label : record.fine_label)
        end
        [data, labels]
      end

      # Reshapes and rescales raw byte images; optionally zips them with labels.
      #
      # @raise [RuntimeError] if ndim is neither 1 nor 3
      def self.preprocess_cifar(images, labels, with_label, ndim, scale)
        case ndim
        when 1
          # Flat vectors: 3 channels * 32 * 32 = 3072 values per image.
          images = images.reshape(images.shape[0], 3072)
        when 3
          # Channel-first layout expected by the convolution links.
          images = images.reshape(images.shape[0], 3, 32, 32)
        else
          raise 'invalid ndim for CIFAR dataset'
        end
        images = images.cast_to(Numo::DFloat)
        images *= scale / 255.0

        if with_label
          labels = labels.cast_to(Numo::Int32)
          TupleDataset.new(images, labels)
        else
          images
        end
      end
    end
  end
end
56+

lib/chainer/iterators/serial_iterator.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,8 @@ def next
1818
@previous_epoch_detail = epoch_detail
1919

2020
i = @current_position
21-
i_end = i + @batch_size
2221
n = @dataset.size
22+
i_end = [i + @batch_size, n].min
2323

2424
batch = @order[i...i_end].to_a.map { |index| @dataset[index] }
2525

lib/chainer/links/normalization/batch_normalization.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@ class BatchNormalization < Chainer::Link
2323
# @param [Numo::NArray.dtype] dtype Type to use in computing.
2424
# @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
2525
# @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
26-
def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::Float32, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
26+
def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::DFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
2727
super()
2828
@avg_mean = dtype.zeros(size)
2929
register_persistent('avg_mean')

lib/chainer/serializers/marshal.rb

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@ def call(key, value)
2828
arr = Numo::Bit[1]
2929
elsif value.is_a?(FalseClass)
3030
arr = Numo::Bit[0]
31-
elsif value.instance_of?(String)
31+
elsif value.instance_of?(String) || value.nil?
3232
arr = value
3333
else
3434
arr = Numo::NArray.cast(value)

lib/chainer/utils/variable.rb

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,8 @@ def self.check_grad_type(func, x, gx)
66
return
77
end
88

9-
unless gx.instance_of?(x.data.class)
10-
raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}"
9+
unless gx.is_a?(x.data.class.superclass)
10+
raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}"
1111
end
1212

1313
unless gx.class == x.data.class

0 commit comments

Comments
 (0)