diff --git a/examples/cifar/models/vgg.rb b/examples/cifar/models/vgg.rb
new file mode 100644
index 0000000..e564ad0
--- /dev/null
+++ b/examples/cifar/models/vgg.rb
@@ -0,0 +1,84 @@
+class Block < Chainer::Chain
+  def initialize(out_channels, ksize, pad: 1)
+    super()
+    init_scope do
+      @conv = Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true)
+      @bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels)
+    end
+  end
+
+  def call(x)
+    h = @conv.(x)
+    h = @bn.(h)
+    Chainer::Functions::Activation::Relu.relu(h)
+  end
+end
+
+class VGG < Chainer::Chain
+  def initialize(class_labels: 10)
+    super()
+    init_scope do
+      @block1_1 = Block.new(64, 3)
+      @block1_2 = Block.new(64, 3)
+      @block2_1 = Block.new(128, 3)
+      @block2_2 = Block.new(128, 3)
+      @block3_1 = Block.new(256, 3)
+      @block3_2 = Block.new(256, 3)
+      @block3_3 = Block.new(256, 3)
+      @block4_1 = Block.new(512, 3)
+      @block4_2 = Block.new(512, 3)
+      @block4_3 = Block.new(512, 3)
+      @block5_1 = Block.new(512, 3)
+      @block5_2 = Block.new(512, 3)
+      @block5_3 = Block.new(512, 3)
+      @fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true)
+      @bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512)
+      @fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true)
+    end
+  end
+
+  def call(x)
+    # 64 channel blocks:
+    h = @block1_1.(x)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.3)
+    h = @block1_2.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 128 channel blocks:
+    h = @block2_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block2_2.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 256 channel blocks:
+    h = @block3_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block3_2.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block3_3.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 512 channel blocks:
+    h = @block4_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block4_2.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block4_3.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    # 512 channel blocks:
+    h = @block5_1.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block5_2.(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4)
+    h = @block5_3.(h)
+    h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2)
+
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
+    h = @fc1.(h)
+    h = @bn_fc1.(h)
+    h = Chainer::Functions::Activation::Relu.relu(h)
+    h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5)
+    @fc2.(h)
+  end
+end
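Reviewer note on shapes: the five 2x2/stride-2 poolings in VGG#call halve the 32x32 CIFAR images five times (32 → 16 → 8 → 4 → 2 → 1), so @fc1 receives a 512-vector per image, and the `nil` input sizes let Convolution2D and Linear infer their input shapes on the first forward pass. A minimal smoke test along these lines should confirm the output shape (the dummy batch below is illustrative, not part of this patch):

    require 'chainer'
    require_relative 'models/vgg'

    model = VGG.new(class_labels: 10)
    # A dummy mini-batch of four 32x32 RGB images in NCHW layout.
    x = Chainer::Variable.new(Numo::DFloat.new(4, 3, 32, 32).rand)
    y = model.(x)
    p y.data.shape  # expected: [4, 10], one row of class scores per image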
diff --git a/examples/cifar/train_cifar.rb b/examples/cifar/train_cifar.rb
new file mode 100644
index 0000000..034f387
--- /dev/null
+++ b/examples/cifar/train_cifar.rb
@@ -0,0 +1,70 @@
+require 'chainer'
+require __dir__ + '/models/vgg'
+require 'optparse'
+
+args = {
+  dataset: 'cifar10',
+  frequency: -1,
+  batchsize: 64,
+  learnrate: 0.05,
+  epoch: 300,
+  out: 'result',
+  resume: nil
+}
+
+opt = OptionParser.new
+opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v }
+opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i }
+opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i }
+opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f }
+opt.on('-e', '--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i }
+opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v }
+opt.on('-r', '--resume VALUE', 'Resume the training from snapshot') { |v| args[:resume] = v }
+opt.parse!(ARGV)
+
+# Set up a neural network to train.
+# Classifier reports softmax cross entropy loss and accuracy at every
+# iteration, which will be used by the PrintReport extension below.
+if args[:dataset] == 'cifar10'
+  puts 'Using CIFAR10 dataset.'
+  class_labels = 10
+  train, test = Chainer::Datasets::CIFAR.get_cifar10
+elsif args[:dataset] == 'cifar100'
+  puts 'Using CIFAR100 dataset.'
+  class_labels = 100
+  train, test = Chainer::Datasets::CIFAR.get_cifar100
+else
+  raise 'Invalid dataset choice.'
+end
+
+puts 'setup...'
+
+model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels))
+
+optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
+optimizer.setup(model)
+
+train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
+test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)
+
+updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1)
+trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out])
+
+trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1))
+
+trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch'])
+
+frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max
+trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch'])
+
+trainer.extend(Chainer::Training::Extensions::LogReport.new)
+trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))
+trainer.extend(Chainer::Training::Extensions::ProgressBar.new)
+
+if args[:resume]
+  Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer)
+end
+
+trainer.run
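Typical invocations, for the record (flag names map one-to-one onto the args hash above; the snapshot path is a placeholder):

    $ ruby examples/cifar/train_cifar.rb --dataset cifar10 --batchsize 64 --learnrate 0.05 --epoch 300 --out result
    $ ruby examples/cifar/train_cifar.rb --resume result/snapshot_file   # a snapshot written by the Snapshot extension

Note that frequency defaults to -1, which resolves to args[:epoch], so by default a snapshot is written only once, at the end of training; pass -f to snapshot more often.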
diff --git a/lib/chainer.rb b/lib/chainer.rb
index 3c84f44..10dbe61 100644
--- a/lib/chainer.rb
+++ b/lib/chainer.rb
@@ -62,6 +62,7 @@
 require 'chainer/optimizers/momentum_sgd'
 require 'chainer/dataset/download'
 require 'chainer/datasets/mnist'
+require 'chainer/datasets/cifar'
 require 'chainer/datasets/tuple_dataset'
 require 'chainer/reporter'
 require 'chainer/serializer'
diff --git a/lib/chainer/datasets/cifar.rb b/lib/chainer/datasets/cifar.rb
new file mode 100644
index 0000000..b6f55bc
--- /dev/null
+++ b/lib/chainer/datasets/cifar.rb
@@ -0,0 +1,56 @@
+require 'datasets'
+
+module Chainer
+  module Datasets
+    module CIFAR
+      def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
+        get_cifar(10, with_label, ndim, scale)
+      end
+
+      def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
+        get_cifar(100, with_label, ndim, scale)
+      end
+
+      def self.get_cifar(n_classes, with_label, ndim, scale)
+        train_data = []
+        train_labels = []
+        ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).each do |record|
+          train_data << record.pixels
+          train_labels << (n_classes == 10 ? record.label : record.fine_label)
+        end
+
+        test_data = []
+        test_labels = []
+        ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).each do |record|
+          test_data << record.pixels
+          test_labels << (n_classes == 10 ? record.label : record.fine_label)
+        end
+
+        [
+          preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
+          preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
+        ]
+      end
+
+      def self.preprocess_cifar(images, labels, with_label, ndim, scale)
+        if ndim == 1
+          images = images.reshape(images.shape[0], 3072)
+        elsif ndim == 3
+          images = images.reshape(images.shape[0], 3, 32, 32)
+        else
+          raise 'invalid ndim for CIFAR dataset'
+        end
+        images = images.cast_to(Numo::DFloat)
+        images *= scale / 255.0
+
+        if with_label
+          labels = labels.cast_to(Numo::Int32)
+          TupleDataset.new(images, labels)
+        else
+          images
+        end
+      end
+    end
+  end
+end
diff --git a/lib/chainer/iterators/serial_iterator.rb b/lib/chainer/iterators/serial_iterator.rb
index e0967b3..ed212a6 100644
--- a/lib/chainer/iterators/serial_iterator.rb
+++ b/lib/chainer/iterators/serial_iterator.rb
@@ -18,8 +18,8 @@ def next
       @previous_epoch_detail = epoch_detail
 
       i = @current_position
-      i_end = i + @batch_size
       n = @dataset.size
+      i_end = [i + @batch_size, n].min
 
       batch = @order[i...i_end].to_a.map { |index| @dataset[index] }
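The i_end clamp is the behavioral fix here: when the dataset size is not a multiple of the batch size, the final slice of @order used to extend past the end of the dataset, which matters for the test iterator above (repeat: false) since it must emit a short final batch rather than wrap around. A worked example with the 10,000-image CIFAR test set and the default batch size of 64 (plain Ruby, illustrative only):

    n          = 10_000                # test set size
    batch_size = 64
    i          = 9_984                 # start of the final batch
    i_end      = [i + batch_size, n].min
    p i_end - i                        # => 16: a short final batch instead of an out-of-range slice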
diff --git a/lib/chainer/links/normalization/batch_normalization.rb b/lib/chainer/links/normalization/batch_normalization.rb
index 40fe2e5..2d23dbb 100644
--- a/lib/chainer/links/normalization/batch_normalization.rb
+++ b/lib/chainer/links/normalization/batch_normalization.rb
@@ -23,7 +23,7 @@ class BatchNormalization < Chainer::Link
       # @param [Numo::NArray.dtype] dtype Type to use in computing.
       # @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
       # @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
-      def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::Float32, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
+      def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::DFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
         super()
         @avg_mean = dtype.zeros(size)
         register_persistent('avg_mean')
diff --git a/lib/chainer/serializers/marshal.rb b/lib/chainer/serializers/marshal.rb
index ea80a37..df9d6b9 100644
--- a/lib/chainer/serializers/marshal.rb
+++ b/lib/chainer/serializers/marshal.rb
@@ -28,7 +28,7 @@ def call(key, value)
         arr = Numo::Bit[1]
       elsif value.is_a?(FalseClass)
         arr = Numo::Bit[0]
-      elsif value.instance_of?(String)
+      elsif value.instance_of?(String) || value.nil?
         arr = value
       else
         arr = Numo::NArray.cast(value)
diff --git a/lib/chainer/utils/variable.rb b/lib/chainer/utils/variable.rb
index 0f9a991..677eca3 100644
--- a/lib/chainer/utils/variable.rb
+++ b/lib/chainer/utils/variable.rb
@@ -6,8 +6,8 @@ def self.check_grad_type(func, x, gx)
         return
       end
 
-      unless gx.instance_of?(x.data.class)
-        raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}"
+      unless gx.is_a?(x.data.class.superclass)
+        raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}"
       end
 
       unless gx.class == x.data.class