
Merge pull request #49 from hatappi/feature/cifar
add CIFAR example
hatappi authored May 5, 2018
2 parents a7a18df + 2f82cc3 commit ee12ab8
Showing 8 changed files with 216 additions and 5 deletions.
84 changes: 84 additions & 0 deletions examples/cifar/models/vgg.rb
@@ -0,0 +1,84 @@
class Block < Chainer::Chain
def initialize(out_channels, ksize, pad: 1)
super()
init_scope do
@conv = Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true)
@bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels)
end
end

def call(x)
h = @conv.(x)
h = @bn.(h)
Chainer::Functions::Activation::Relu.relu(h)
end
end

class VGG < Chainer::Chain
def initialize(class_labels: 10)
super()
init_scope do
@block1_1 = Block.new(64, 3)
@block1_2 = Block.new(64, 3)
@block2_1 = Block.new(128, 3)
@block2_2 = Block.new(128, 3)
@block3_1 = Block.new(256, 3)
@block3_2 = Block.new(256, 3)
@block3_3 = Block.new(256, 3)
@block4_1 = Block.new(512, 3)
@block4_2 = Block.new(512, 3)
@block4_3 = Block.new(512, 3)
@block5_1 = Block.new(512, 3)
@block5_2 = Block.new(512, 3)
@block5_3 = Block.new(512, 3)
@fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
@bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
@fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
end
end

def call(x)
# 64 channel blocks:
h = @block1_1.(x)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.3)
h = @block1_2.(h)
h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)

# 128 channel blocks:
h = @block2_1.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block2_2.(h)
h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)

# 256 channel blocks:
h = @block3_1.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block3_2.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block3_3.(h)
h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)

# 512 channel blocks:
h = @block4_1.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block4_2.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block4_3.(h)
h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)

# 512 channel blocks:
h = @block5_1.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block5_2.(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
h = @block5_3.(h)
h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)

h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
h = @fc1.(h)
h = @bn_fc1.(h)
h = Chainer::Functions::Activation::Relu.relu(h)
h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
@fc2.(h)
end
end
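
For reference, a minimal smoke test of the model above. This is a hypothetical sketch, not part of the commit: it assumes the chain accepts a raw Numo batch, and the random array stands in for real CIFAR images.

require 'chainer'
require_relative 'models/vgg'

# Push one random batch of four 32x32 RGB images through the 10-class VGG.
model = VGG.new(class_labels: 10)
x = Numo::DFloat.new(4, 3, 32, 32).rand
y = model.(x)       # forward pass through every block defined above
p y.data.shape      # expected [4, 10]: one row of logits per image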
70 changes: 70 additions & 0 deletions examples/cifar/train_cifar.rb
@@ -0,0 +1,70 @@
require 'chainer'
require __dir__ + '/models/vgg'
require 'optparse'

args = {
dataset: 'cifar10',
frequency: -1,
batchsize: 64,
learnrate: 0.05,
epoch: 300,
out: 'result',
resume: nil
}


opt = OptionParser.new
opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v }
opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v }
opt.parse!(ARGV)

# Set up a neural network to train.
# Classifier reports softmax cross entropy loss and accuracy at every
# iteration, which will be used by the PrintReport extension below.
if args[:dataset] == 'cifar10'
puts 'Using CIFAR10 dataset.'
class_labels = 10
train, test = Chainer::Datasets::CIFAR.get_cifar10
elsif args[:dataset] == 'cifar100'
puts 'Using CIFAR100 dataset.'
class_labels = 100
train, test = Chainer::Datasets::CIFAR.get_cifar100
else
raise 'Invalid dataset choice.'
end

puts "setup..."

model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))

optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
optimizer.setup(model)

train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)

updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])

trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))

trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])

frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])

trainer.extend(Chainer::Training::Extensions::LogReport.new)
trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
trainer.extend(Chainer::Training::Extensions::ProgressBar.new)

if args[:resume]
Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
end

trainer.run
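
Usage note (flags as defined by the option parser above): a bare

ruby examples/cifar/train_cifar.rb

trains VGG on CIFAR-10 with all defaults, while an invocation like

ruby examples/cifar/train_cifar.rb -d cifar100 -b 128 -e 50 -o result_c100

switches to CIFAR-100 and writes logs and snapshots under result_c100. Pass -r with a snapshot file produced by the Snapshot extension to resume training.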

1 change: 1 addition & 0 deletions lib/chainer.rb
@@ -62,6 +62,7 @@
require 'chainer/optimizers/momentum_sgd'
require 'chainer/dataset/download'
require 'chainer/datasets/mnist'
+ require 'chainer/datasets/cifar'
require 'chainer/datasets/tuple_dataset'
require 'chainer/reporter'
require 'chainer/serializer'
56 changes: 56 additions & 0 deletions lib/chainer/datasets/cifar.rb
@@ -0,0 +1,56 @@
require 'datasets'

module Chainer
module Datasets
module CIFAR
def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
get_cifar(10, with_label, ndim, scale)
end

def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
get_cifar(100, with_label, ndim, scale)
end

def self.get_cifar(n_classes, with_label, ndim, scale)
train_data = []
train_labels = []
::Datasets::CIFAR.new(n_classes: n_classes, type: :train).each do |record|
train_data << record.pixels
train_labels << (n_classes == 10 ? record.label : record.fine_label)
end

test_data = []
test_labels = []
::Datasets::CIFAR.new(n_classes: n_classes, type: :test).each do |record|
test_data << record.pixels
test_labels << (n_classes == 10 ? record.label : record.fine_label)
end

[
preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
]
end

def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
if ndim == 1
images = images.reshape(images.shape[0], 3072)
elsif ndim == 3
images = images.reshape(images.shape[0], 3, 32, 32)
else
raise 'invalid ndim for CIFAR dataset'
end
images = images.cast_to(Numo::DFloat)
images *= scale / 255.0

if withlabel
labels = labels.cast_to(Numo::Int32)
TupleDataset.new(images, labels)
else
images
end
end
end
end
end
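
A hedged sketch of consuming these helpers (it assumes the red-datasets gem can download CIFAR, and that TupleDataset supports index access, as the iterator code elsewhere in this commit suggests):

require 'chainer'

train, test = Chainer::Datasets::CIFAR.get_cifar10
image, label = train[0]   # each element pairs an image with its label
p image.class             # Numo::DFloat, after preprocess_cifar's cast
p image.shape             # [3, 32, 32] with the default ndim: 3
p label                   # class id in 0..9 (0..99 for get_cifar100)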

2 changes: 1 addition & 1 deletion lib/chainer/iterators/serial_iterator.rb
@@ -18,8 +18,8 @@ def next
@previous_epoch_detail = epoch_detail

i = @current_position
- i_end = i + @batch_size
n = @dataset.size
+ i_end = [i + @batch_size, n].min

batch = @order[i...i_end].to_a.map { |index| @dataset[index] }

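This one-line fix matters whenever the dataset size is not a multiple of the batch size. An illustration of the clamped index arithmetic with CIFAR-10's numbers:

n = 50_000                        # CIFAR-10 training set size
batch_size = 64
i = 49_984                        # start of the final batch (64 * 781)
i_end = [i + batch_size, n].min   # => 50_000 rather than 50_048
# The final batch now holds the remaining 16 examples instead of
# indexing past the end of @order.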
2 changes: 1 addition & 1 deletion lib/chainer/links/normalization/batch_normalization.rb
@@ -23,7 +23,7 @@ class BatchNormalization < Chainer::Link
# @param [Numo::NArray.dtype] dtype Type to use in computing.
# @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
# @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
- def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::Float32, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
+ def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::DFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
super()
@avg_mean = dtype.zeros(size)
register_persistent('avg_mean')
2 changes: 1 addition & 1 deletion lib/chainer/serializers/marshal.rb
@@ -28,7 +28,7 @@ def call(key, value)
arr = Numo::Bit[1]
elsif value.is_a?(FalseClass)
arr = Numo::Bit[0]
- elsif value.instance_of?(String)
+ elsif value.instance_of?(String) || value.nil?
arr = value
else
arr = Numo::NArray.cast(value)
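Why nil is now allowed here is an inference from the diff rather than verified behavior: Numo::NArray.cast cannot convert nil, so a nil-valued entry would previously have raised during dumping. A standalone sketch of the branch logic, with a hypothetical helper name:

require 'numo/narray'

# Hypothetical mirror of the branch above: nil, like String, is stored as-is.
def to_dumpable(value)
  if value.is_a?(TrueClass)
    Numo::Bit[1]
  elsif value.is_a?(FalseClass)
    Numo::Bit[0]
  elsif value.instance_of?(String) || value.nil?
    value
  else
    Numo::NArray.cast(value)
  end
end

p to_dumpable(nil)    # => nil
p to_dumpable(true)   # => a one-element Numo::Bit array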
4 changes: 2 additions & 2 deletions lib/chainer/utils/variable.rb
@@ -6,8 +6,8 @@ def self.check_grad_type(func, x, gx)
return
end

- unless gx.instance_of?(x.data.class)
-   raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}"
+ unless gx.is_a?(x.data.class.superclass)
+   raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}"
end

unless gx.class == x.data.class
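The relaxed check leans on Numo's class hierarchy (assuming SFloat < DFloat < NArray, as in numo-narray), so a gradient only has to belong to the same NArray family as the data; the error message also now reports the data's class rather than the Variable's. A small illustrative sketch:

require 'numo/narray'

data = Numo::DFloat[1.0, 2.0]
grad = Numo::SFloat[0.1, 0.2]

# Old check: grad.instance_of?(data.class) is false, so this raised TypeError.
# New check: membership in the same NArray family is enough.
p grad.is_a?(data.class.superclass)   # => true (Numo::DFloat.superclass is Numo::NArray)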
