From c43d020f70b3071597681fd7ea75878c0cfc769c Mon Sep 17 00:00:00 2001 From: hatappi Date: Sun, 25 Mar 2018 23:34:09 +0900 Subject: [PATCH 1/8] add cifar 10, 100 datasets --- lib/chainer.rb | 1 + lib/chainer/datasets/cifar.rb | 135 ++++++++++++++++++++++++++ lib/chainer/datasets/tuple_dataset.rb | 3 +- 3 files changed, 138 insertions(+), 1 deletion(-) create mode 100644 lib/chainer/datasets/cifar.rb diff --git a/lib/chainer.rb b/lib/chainer.rb index cd88525..693e78a 100644 --- a/lib/chainer.rb +++ b/lib/chainer.rb @@ -56,6 +56,7 @@ require 'chainer/optimizers/momentum_sgd' require 'chainer/dataset/download' require 'chainer/datasets/mnist' +require 'chainer/datasets/cifar' require 'chainer/datasets/tuple_dataset' require 'chainer/reporter' require 'chainer/serializer' diff --git a/lib/chainer/datasets/cifar.rb b/lib/chainer/datasets/cifar.rb new file mode 100644 index 0000000..46c8538 --- /dev/null +++ b/lib/chainer/datasets/cifar.rb @@ -0,0 +1,135 @@ +require 'rubygems/package' + +module Chainer + module Datasets + module Cifar + def self.get_cifar10(withlabel: true, ndim: 3, scale: 1.0) + get_cifar('cifar-10', withlabel, ndim, scale) + end + + def self.get_cifar100(withlabel: true, ndim: 3, scale: 1.0) + get_cifar('cifar-100', withlabel, ndim, scale) + end + + def self.get_cifar(name, withlabel, ndim, scale) + root = Chainer::Dataset::Download.get_dataset_directory('cifar') + filename = "#{name}-binary.tar.gz" + url = "http://www.cs.toronto.edu/~kriz/#{filename}" + path = File.expand_path(filename, root) + extractfile(root, url) + raw = creator(root, name) + train = preprocess_cifar(raw[:train_x], raw[:train_y], withlabel, ndim, scale) + test = preprocess_cifar(raw[:test_x], raw[:test_y], withlabel, ndim, scale) + [train, test] + end + + def self.extractfile(root, url) + archive_path = Chainer::Dataset::Download.cached_download(url) + Gem::Package::TarReader.new(Zlib::GzipReader.open(archive_path)) do |tar| + tar.each do |entry| + dest = File.expand_path(entry.full_name, root) + if entry.directory? 
+ FileUtils.mkdir_p(dest) + else + File.open(dest, "wb") do |f| + f.print(entry.read) + end + end + end + end + end + + def self.creator(root, name) + if name == 'cifar-10' + train_x = Numo::UInt8.new(5, 10000, 3072).rand(1) + train_y = Numo::UInt8.new(5, 10000).rand(1) + test_x = Numo::UInt8.new(10000, 3072).rand(1) + test_y = Numo::UInt8.new(10000).rand(1) + + dir = File.expand_path("cifar-10-batches-bin", root) + (1..5).each do |i| + file_name = "#{dir}/data_batch_#{i}.bin" + open(file_name, "rb") do |f| + s = 0 + while b = f.read(3073) do + datasets = b.unpack("C*") + train_y[i - 1, s] = datasets.shift + train_x[i - 1, s, false] = datasets + s += 1 + end + end + end + + file_name = "#{dir}/test_batch.bin" + open(file_name, "rb") do |f| + s = 0 + while b = f.read(3073) do + datasets = b.unpack("C*") + test_y[s] = datasets.shift + test_x[s, false] = datasets + s += 1 + end + end + + train_x = train_x.reshape(50000, 3072) + train_y = train_y.reshape(50000) + else + train_x = Numo::UInt8.new(50000, 3072).rand(1) + train_y = Numo::UInt8.new(50000).rand(1) + test_x = Numo::UInt8.new(10000, 3072).rand(1) + test_y = Numo::UInt8.new(10000).rand(1) + dir = File.expand_path("cifar-100-batches-bin", root) + + file_name = "#{dir}/train.bin" + open(file_name, "rb") do |f| + s = 0 + while b = f.read(3073) do + datasets = b.unpack("C*") + train_y[s] = datasets.shift + train_x[s, false] = datasets + s += 1 + end + end + + file_name = "#{dir}/test.bin" + open(file_name, "rb") do |f| + s = 0 + while b = f.read(3073) do + datasets = b.unpack("C*") + test_y[s] = datasets.shift + test_x[s, false] = datasets + s += 1 + end + end + end + + { + train_x: train_x, + train_y: train_y, + test_x: test_x, + test_y: test_y + } + end + + def self.preprocess_cifar(images, labels, withlabel, ndim, scale) + if ndim == 1 + images = images.reshape(images.shape[0], 3072) + elsif ndim == 3 + images = images.reshape(images.shape[0], 3, 32, 32) + else + raise 'invalid ndim for CIFAR dataset' + end + images = images.cast_to(Numo::Float32) + images *= scale / 255.0 + + if withlabel + labels = labels.cast_to(Numo::Int32) + TupleDataset.new(images, labels) + else + images + end + end + end + end +end + diff --git a/lib/chainer/datasets/tuple_dataset.rb b/lib/chainer/datasets/tuple_dataset.rb index 26481a9..659f16e 100644 --- a/lib/chainer/datasets/tuple_dataset.rb +++ b/lib/chainer/datasets/tuple_dataset.rb @@ -16,7 +16,8 @@ def initialize(*datasets) end def [](index) - batches = @datasets.map { |dataset| dataset.ndim > 1 ? 
dataset[index, 0...dataset.shape[1]] : dataset[index] } + batches = @datasets.map { |dataset| dataset[index, false] } + if index.kind_of?(Enumerable) length = batches[0].shape[0] length.times.map {|i| batches.map { |m| m[i] } } From aee23bdc86783dece3f61b92197d84987b1c0e45 Mon Sep 17 00:00:00 2001 From: hatappi Date: Sun, 25 Mar 2018 23:46:03 +0900 Subject: [PATCH 2/8] fix setup cifar-100 --- lib/chainer/datasets/cifar.rb | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/lib/chainer/datasets/cifar.rb b/lib/chainer/datasets/cifar.rb index 46c8538..a35a1c6 100644 --- a/lib/chainer/datasets/cifar.rb +++ b/lib/chainer/datasets/cifar.rb @@ -74,11 +74,11 @@ def self.creator(root, name) train_x = train_x.reshape(50000, 3072) train_y = train_y.reshape(50000) else - train_x = Numo::UInt8.new(50000, 3072).rand(1) - train_y = Numo::UInt8.new(50000).rand(1) - test_x = Numo::UInt8.new(10000, 3072).rand(1) - test_y = Numo::UInt8.new(10000).rand(1) - dir = File.expand_path("cifar-100-batches-bin", root) + train_x = Numo::UInt8.new(50017, 3072).rand(1) + train_y = Numo::UInt8.new(50017).rand(1) + test_x = Numo::UInt8.new(10004, 3072).rand(1) + test_y = Numo::UInt8.new(10004).rand(1) + dir = File.expand_path("cifar-100-binary", root) file_name = "#{dir}/train.bin" open(file_name, "rb") do |f| From c4c7619934af3b624520c6803ae28fbfa9be7680 Mon Sep 17 00:00:00 2001 From: hatappi Date: Wed, 28 Mar 2018 23:01:50 +0900 Subject: [PATCH 3/8] fix reshape --- lib/chainer/functions/connection/linear.rb | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/lib/chainer/functions/connection/linear.rb b/lib/chainer/functions/connection/linear.rb index b364e0d..6681809 100644 --- a/lib/chainer/functions/connection/linear.rb +++ b/lib/chainer/functions/connection/linear.rb @@ -40,7 +40,8 @@ def backward(inputs, grad_outputs) def as_mat(x) return x if x.ndim == 2 - x.reshape(x.size, -1) + sum = x.shape.reduce(:*) + x.reshape(x.shape[0], sum / x.shape[0]) end end end From 0643d89194e6334149aa13a1737b651462769328 Mon Sep 17 00:00:00 2001 From: hatappi Date: Wed, 28 Mar 2018 23:03:39 +0900 Subject: [PATCH 4/8] fix Variable class check --- lib/chainer/utils/variable.rb | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/lib/chainer/utils/variable.rb b/lib/chainer/utils/variable.rb index c452d19..6995895 100644 --- a/lib/chainer/utils/variable.rb +++ b/lib/chainer/utils/variable.rb @@ -6,8 +6,8 @@ def self.check_grad_type(func, x, gx) return end - unless gx.instance_of?(x.data.class) - raise TypeError, "Type of data and grad mismatch\n#{x.class} != #{gx.class}" + unless gx.is_a?(x.data.class.superclass) + raise TypeError, "Type of data and grad mismatch\n#{x.data.class} != #{gx.class}" end unless gx.shape == x.data.shape From f231b027712984f7177f92fc64aa566524330ea2 Mon Sep 17 00:00:00 2001 From: hatappi Date: Wed, 28 Mar 2018 23:08:43 +0900 Subject: [PATCH 5/8] add CIFAR example --- examples/cifar/models/vgg.rb | 84 +++++++++++++++++++++++++++++++++++ examples/cifar/train_cifar.rb | 73 ++++++++++++++++++++++++++++++ 2 files changed, 157 insertions(+) create mode 100644 examples/cifar/models/vgg.rb create mode 100644 examples/cifar/train_cifar.rb diff --git a/examples/cifar/models/vgg.rb b/examples/cifar/models/vgg.rb new file mode 100644 index 0000000..e564ad0 --- /dev/null +++ b/examples/cifar/models/vgg.rb @@ -0,0 +1,84 @@ +class Block < Chainer::Chain + def initialize(out_channels, ksize, pad: 1) + super() + init_scope do + @conv = 
Chainer::Links::Connection::Convolution2D.new(nil, out_channels, ksize, pad: pad, nobias: true) + @bn = Chainer::Links::Normalization::BatchNormalization.new(out_channels) + end + end + + def call(x) + h = @conv.(x) + h = @bn.(h) + Chainer::Functions::Activation::Relu.relu(h) + end +end + +class VGG < Chainer::Chain + def initialize(class_labels: 10) + super() + init_scope do + @block1_1 = Block.new(64, 3) + @block1_2 = Block.new(64, 3) + @block2_1 = Block.new(128, 3) + @block2_2 = Block.new(128, 3) + @block3_1 = Block.new(256, 3) + @block3_2 = Block.new(256, 3) + @block3_3 = Block.new(256, 3) + @block4_1 = Block.new(512, 3) + @block4_2 = Block.new(512, 3) + @block4_3 = Block.new(512, 3) + @block5_1 = Block.new(512, 3) + @block5_2 = Block.new(512, 3) + @block5_3 = Block.new(512, 3) + @fc1 = Chainer::Links::Connection::Linear.new(nil, out_size: 512, nobias: true) + @bn_fc1 = Chainer::Links::Normalization::BatchNormalization.new(512) + @fc2 = Chainer::Links::Connection::Linear.new(nil, out_size: class_labels, nobias: true) + end + end + + def call(x) + # 64 channel blocks: + h = @block1_1.(x) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.3) + h = @block1_2.(h) + h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2) + + # 128 channel blocks: + h = @block2_1.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block2_2.(h) + h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride:2) + + # 256 channel blocks: + h = @block3_1.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block3_2.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block3_3.(h) + h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2) + + # 512 channel blocks: + h = @block4_1.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block4_2.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block4_3.(h) + h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2) + + # 512 channel blocks: + h = @block5_1.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block5_2.(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.4) + h = @block5_3.(h) + h = Chainer::Functions::Pooling::MaxPooling2D.max_pooling_2d(h, 2, stride: 2) + + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5) + h = @fc1.(h) + h = @bn_fc1.(h) + h = Chainer::Functions::Activation::Relu.relu(h) + h = Chainer::Functions::Noise::Dropout.dropout(h, ratio: 0.5) + @fc2.(h) + end +end diff --git a/examples/cifar/train_cifar.rb b/examples/cifar/train_cifar.rb new file mode 100644 index 0000000..a25b841 --- /dev/null +++ b/examples/cifar/train_cifar.rb @@ -0,0 +1,73 @@ +require 'chainer' +require './models/vgg' +require 'optparse' + +args = { + dataset: 'cifar10', + frequency: -1, + batchsize: 64, + learnrate: 0.05, + epoch: 300, + out: 'result', + resume: nil +} + + +opt = OptionParser.new +opt.on('-d', '--dataset VALUE', "The dataset to use: cifar10 or cifar100 (default: #{args[:dataset]})") { |v| args[:dataset] = v } +opt.on('-b', '--batchsize VALUE', "Number of images in each mini-batch (default: #{args[:batchsize]})") { |v| args[:batchsize] = v.to_i } +opt.on('-f', '--frequency VALUE', "Frequency of taking a snapshot (default: #{args[:frequency]})") { |v| args[:frequency] = v.to_i } +opt.on('-l', '--learnrate VALUE', "Learning rate for SGD (default: #{args[:learnrate]})") { |v| args[:learnrate] = v.to_f } +opt.on('-e', 
'--epoch VALUE', "Number of sweeps over the dataset to train (default: #{args[:epoch]})") { |v| args[:epoch] = v.to_i } +opt.on('-o', '--out VALUE', "Directory to output the result (default: #{args[:out]})") { |v| args[:out] = v } +opt.on('-r', '--resume VALUE', "Resume the training from snapshot") { |v| args[:resume] = v } +opt.parse!(ARGV) + +# Set up a neural network to train. +# Classifier reports softmax cross entropy loss and accuracy at every +# iteration, which will be used by the PrintReport extension below. +if args[:dataset] == 'cifar10' + puts 'Using CIFAR10 dataset.' + class_labels = 10 + train, test = Chainer::Datasets::Cifar.get_cifar10 +elsif args[:dataset] == 'cifar100' + puts 'Using CIFAR100 dataset.' + class_labels = 100 + train, test = Chainer::Datasets::Cifar.get_cifar100 +else + raise 'Invalid dataset choice.' +end + +puts "setup..." + +model = Chainer::Links::Model::Classifier.new(VGG.new(class_labels: class_labels)) + +optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate]) +optimizer.setup(model) + +# TODO +# optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4)) + +train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize]) +test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false) + +updater = Chainer::Training::StandardUpdater.new(train_iter, optimizer, device: -1) +trainer = Chainer::Training::Trainer.new(updater, stop_trigger: [args[:epoch], 'epoch'], out: args[:out]) + +trainer.extend(Chainer::Training::Extensions::Evaluator.new(test_iter, model, device: -1)) + +trainer.extend(Chainer::Training::Extensions::ExponentialShift.new('lr', 0.5), trigger: [25, 'epoch']) + +frequency = args[:frequency] == -1 ? args[:epoch] : [1, args[:frequency]].max +trainer.extend(Chainer::Training::Extensions::Snapshot.new, trigger: [frequency, 'epoch']) + +trainer.extend(Chainer::Training::Extensions::LogReport.new) +trainer.extend(Chainer::Training::Extensions::PrintReport.new(['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy', 'elapsed_time'])) +trainer.extend(Chainer::Training::Extensions::ProgressBar.new) + +if args[:resume] + Chainer::Serializers::MarshalDeserializer.load_file(args[:resume], trainer) +end + +trainer.run + From 79e1031bd1a6bbe4d555d5de38ba7d796b8218da Mon Sep 17 00:00:00 2001 From: hatappi Date: Thu, 29 Mar 2018 01:38:33 +0900 Subject: [PATCH 6/8] fix serialize --- lib/chainer/serializers/marshal.rb | 2 +- lib/chainer/training/extensions/exponential_shift.rb | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/lib/chainer/serializers/marshal.rb b/lib/chainer/serializers/marshal.rb index 33885c8..438c18b 100644 --- a/lib/chainer/serializers/marshal.rb +++ b/lib/chainer/serializers/marshal.rb @@ -24,7 +24,7 @@ def call(key, value) arr = Numo::Bit[1] elsif value.is_a?(FalseClass) arr = Numo::Bit[0] - elsif value.instance_of?(String) + elsif value.instance_of?(String) || value.nil? 
arr = value else arr = Numo::NArray.cast(value) diff --git a/lib/chainer/training/extensions/exponential_shift.rb b/lib/chainer/training/extensions/exponential_shift.rb index 91d2b3d..a9ff89b 100644 --- a/lib/chainer/training/extensions/exponential_shift.rb +++ b/lib/chainer/training/extensions/exponential_shift.rb @@ -55,8 +55,8 @@ def call(trainer) end def serialize(serializer) - @t = serializer('t', @t) - @last_value = serializer('last_value', @last_value) + @t = serializer.('t', @t) + @last_value = serializer.('last_value', @last_value) if @last_value.is_a?(Numo::NArray) @last_value = @last_value[0] end From f57406829f665289866ef7218a3198f9b516ca81 Mon Sep 17 00:00:00 2001 From: hatappi Date: Thu, 29 Mar 2018 14:35:37 +0900 Subject: [PATCH 7/8] fix end --- lib/chainer/iterators/serial_iterator.rb | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/lib/chainer/iterators/serial_iterator.rb b/lib/chainer/iterators/serial_iterator.rb index e0967b3..ed212a6 100644 --- a/lib/chainer/iterators/serial_iterator.rb +++ b/lib/chainer/iterators/serial_iterator.rb @@ -18,8 +18,8 @@ def next @previous_epoch_detail = epoch_detail i = @current_position - i_end = i + @batch_size n = @dataset.size + i_end = [i + @batch_size, n].min batch = @order[i...i_end].to_a.map { |index| @dataset[index] } From 2f82cc3f42a381f361c0472cf744ecf0153c7cfc Mon Sep 17 00:00:00 2001 From: hatappi Date: Sat, 5 May 2018 16:54:24 +0900 Subject: [PATCH 8/8] fix --- examples/cifar/train_cifar.rb | 9 +- lib/chainer/datasets/cifar.rb | 125 ++++-------------- .../normalization/batch_normalization.rb | 2 +- 3 files changed, 27 insertions(+), 109 deletions(-) diff --git a/examples/cifar/train_cifar.rb b/examples/cifar/train_cifar.rb index a25b841..034f387 100644 --- a/examples/cifar/train_cifar.rb +++ b/examples/cifar/train_cifar.rb @@ -1,5 +1,5 @@ require 'chainer' -require './models/vgg' +require __dir__ + '/models/vgg' require 'optparse' args = { @@ -29,11 +29,11 @@ if args[:dataset] == 'cifar10' puts 'Using CIFAR10 dataset.' class_labels = 10 - train, test = Chainer::Datasets::Cifar.get_cifar10 + train, test = Chainer::Datasets::CIFAR.get_cifar10 elsif args[:dataset] == 'cifar100' puts 'Using CIFAR100 dataset.' class_labels = 100 - train, test = Chainer::Datasets::Cifar.get_cifar100 + train, test = Chainer::Datasets::CIFAR.get_cifar100 else raise 'Invalid dataset choice.' 
end @@ -45,9 +45,6 @@ optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate]) optimizer.setup(model) -# TODO -# optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4)) - train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize]) test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false) diff --git a/lib/chainer/datasets/cifar.rb b/lib/chainer/datasets/cifar.rb index a35a1c6..b6f55bc 100644 --- a/lib/chainer/datasets/cifar.rb +++ b/lib/chainer/datasets/cifar.rb @@ -1,114 +1,35 @@ -require 'rubygems/package' +require 'datasets' module Chainer module Datasets - module Cifar - def self.get_cifar10(withlabel: true, ndim: 3, scale: 1.0) - get_cifar('cifar-10', withlabel, ndim, scale) + module CIFAR + def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0) + get_cifar(10, with_label, ndim, scale) end - def self.get_cifar100(withlabel: true, ndim: 3, scale: 1.0) - get_cifar('cifar-100', withlabel, ndim, scale) + def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0) + get_cifar(100, with_label, ndim, scale) end - def self.get_cifar(name, withlabel, ndim, scale) - root = Chainer::Dataset::Download.get_dataset_directory('cifar') - filename = "#{name}-binary.tar.gz" - url = "http://www.cs.toronto.edu/~kriz/#{filename}" - path = File.expand_path(filename, root) - extractfile(root, url) - raw = creator(root, name) - train = preprocess_cifar(raw[:train_x], raw[:train_y], withlabel, ndim, scale) - test = preprocess_cifar(raw[:test_x], raw[:test_y], withlabel, ndim, scale) - [train, test] - end - - def self.extractfile(root, url) - archive_path = Chainer::Dataset::Download.cached_download(url) - Gem::Package::TarReader.new(Zlib::GzipReader.open(archive_path)) do |tar| - tar.each do |entry| - dest = File.expand_path(entry.full_name, root) - if entry.directory? - FileUtils.mkdir_p(dest) - else - File.open(dest, "wb") do |f| - f.print(entry.read) - end - end - end - end - end - - def self.creator(root, name) - if name == 'cifar-10' - train_x = Numo::UInt8.new(5, 10000, 3072).rand(1) - train_y = Numo::UInt8.new(5, 10000).rand(1) - test_x = Numo::UInt8.new(10000, 3072).rand(1) - test_y = Numo::UInt8.new(10000).rand(1) - - dir = File.expand_path("cifar-10-batches-bin", root) - (1..5).each do |i| - file_name = "#{dir}/data_batch_#{i}.bin" - open(file_name, "rb") do |f| - s = 0 - while b = f.read(3073) do - datasets = b.unpack("C*") - train_y[i - 1, s] = datasets.shift - train_x[i - 1, s, false] = datasets - s += 1 - end - end - end + def self.get_cifar(n_classes, with_label, ndim, scale) + train_data = [] + train_labels = [] + ::Datasets::CIFAR.new(n_classes: n_classes, type: :train).each do |record| + train_data << record.pixels + train_labels << (n_classes == 10 ? 
record.label : record.fine_label) + end - file_name = "#{dir}/test_batch.bin" - open(file_name, "rb") do |f| - s = 0 - while b = f.read(3073) do - datasets = b.unpack("C*") - test_y[s] = datasets.shift - test_x[s, false] = datasets - s += 1 - end - end - - train_x = train_x.reshape(50000, 3072) - train_y = train_y.reshape(50000) - else - train_x = Numo::UInt8.new(50017, 3072).rand(1) - train_y = Numo::UInt8.new(50017).rand(1) - test_x = Numo::UInt8.new(10004, 3072).rand(1) - test_y = Numo::UInt8.new(10004).rand(1) - dir = File.expand_path("cifar-100-binary", root) - - file_name = "#{dir}/train.bin" - open(file_name, "rb") do |f| - s = 0 - while b = f.read(3073) do - datasets = b.unpack("C*") - train_y[s] = datasets.shift - train_x[s, false] = datasets - s += 1 - end - end - - file_name = "#{dir}/test.bin" - open(file_name, "rb") do |f| - s = 0 - while b = f.read(3073) do - datasets = b.unpack("C*") - test_y[s] = datasets.shift - test_x[s, false] = datasets - s += 1 - end - end + test_data = [] + test_labels = [] + ::Datasets::CIFAR.new(n_classes: n_classes, type: :test).each do |record| + test_data << record.pixels + test_labels << (n_classes == 10 ? record.label : record.fine_label) end - { - train_x: train_x, - train_y: train_y, - test_x: test_x, - test_y: test_y - } + [ + preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale), + preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale) + ] end def self.preprocess_cifar(images, labels, withlabel, ndim, scale) @@ -119,7 +40,7 @@ def self.preprocess_cifar(images, labels, withlabel, ndim, scale) else raise 'invalid ndim for CIFAR dataset' end - images = images.cast_to(Numo::Float32) + images = images.cast_to(Numo::DFloat) images *= scale / 255.0 if withlabel diff --git a/lib/chainer/links/normalization/batch_normalization.rb b/lib/chainer/links/normalization/batch_normalization.rb index 40fe2e5..2d23dbb 100644 --- a/lib/chainer/links/normalization/batch_normalization.rb +++ b/lib/chainer/links/normalization/batch_normalization.rb @@ -23,7 +23,7 @@ class BatchNormalization < Chainer::Link # @param [Numo::NArray.dtype] dtype Type to use in computing. # @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect. # @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect. - def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::Float32, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil) + def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::DFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil) super() @avg_mean = dtype.zeros(size) register_persistent('avg_mean')
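

---

Notes on the series. The sketches below are plain-Ruby reviewer illustrations, not part of the patches; anything not taken verbatim from the diffs is an assumption and is flagged as such.

Patch 1/8 simplifies TupleDataset#[] to `dataset[index, false]`. In Numo, `false` in an index expression plays the role of NumPy's `...`: it stands for all remaining axes, so a single expression covers member datasets of any rank. A minimal sketch, assuming only numo-narray:

    require 'numo/narray'

    a = Numo::Int32.new(4, 3, 2).seq   # 3-d dataset (e.g. images)
    b = Numo::Int32.new(4).seq         # 1-d dataset (e.g. labels)

    # `false` fills in whatever axes remain, whatever the rank:
    p a[1, false].shape    # => [3, 2]  (one record, all trailing axes)
    p b[1, false]          # => 1      (plain element for 1-d data)

    # Range indices go through the same path, which is what the
    # index.kind_of?(Enumerable) branch relies on for mini-batches:
    p a[0..1, false].shape # => [2, 3, 2]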
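Patch 2/8's sizes look like typos but are arithmetic: CIFAR-100 records are 3074 bytes (coarse label + fine label + 3072 pixels), while the reader from patch 1/8 keeps pulling 3073-byte chunks, so train.bin yields ceil(50000 * 3074 / 3073) = 50017 reads and test.bin ceil(10000 * 3074 / 3073) = 10004. A one-line check:

    p (50_000 * 3074 / 3073.0).ceil  # => 50017
    p (10_000 * 3074 / 3073.0).ceil  # => 10004

Reading mis-sized records lets the labels drift out of alignment, which is presumably part of why patch 8/8 drops this reader altogether.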
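Patch 3/8's as_mat fix: the original line ports Chainer's `x.reshape(len(x), -1)`, but `x.size` in Numo is the total element count (not the batch size) and Numo::NArray#reshape has no NumPy-style -1 wildcard, so the fixed version derives both dimensions explicitly (`x.shape.reduce(:*)` is just x.size). A sketch of the flattening Linear relies on:

    require 'numo/narray'

    x = Numo::SFloat.new(8, 3, 32, 32).seq  # a batch of 8 CHW images

    # Keep the batch axis, flatten everything behind it:
    batch = x.shape[0]
    mat = x.reshape(batch, x.size / batch)
    p mat.shape  # => [8, 3072]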
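Patch 4/8 relaxes the grad/data type check. In numo-narray the single-precision types are defined as subclasses of their double-precision counterparts (e.g. Numo::SFloat < Numo::DFloat), so comparing against `x.data.class.superclass` tolerates a grad whose float width differs from the data's; the message also now prints x.data.class rather than the Variable class. A sketch of what the relaxed check accepts (hierarchy as I recall it; worth verifying on your numo version):

    require 'numo/narray'

    data = Numo::SFloat[1, 2, 3]
    grad = Numo::DFloat[0.1, 0.2, 0.3]

    p grad.instance_of?(data.class)       # => false  (old check rejects)
    p data.class.superclass               # => Numo::DFloat
    p grad.is_a?(data.class.superclass)   # => true   (new check accepts)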
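The ExponentialShift half of patch 6/8 is pure Ruby mechanics: `serializer('t', @t)` parses as a call to a *method* named serializer, which doesn't exist, while `serializer.('t', @t)` is sugar for `serializer.call(...)` on the callable that was passed in. Illustrated with a lambda:

    doubler = ->(x) { x * 2 }

    p doubler.(21)      # => 42, identical to doubler.call(21)
    # doubler(21)       # NoMethodError: undefined method `doubler'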
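The Marshal half of patch 6/8 lets nil pass through untouched because the fallback, Numo::NArray.cast, only understands numerics and array-likes; feeding it nil raises. A quick check (the exact error class is an assumption, hence the broad rescue):

    require 'numo/narray'

    p Numo::NArray.cast([1.0, 2.0])  # => Numo::DFloat[1, 2]
    begin
      Numo::NArray.cast(nil)
    rescue StandardError => e
      p e  # cast refuses NilClass
    end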
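Patch 7/8 clamps the batch end index. Assuming @order is a Numo array (the `.to_a` in the surrounding code suggests it), a range running past the end raises IndexError rather than quietly truncating the way Ruby's Array does, so the final partial batch of an epoch used to blow up. A sketch of the clamp:

    require 'numo/narray'

    order = Numo::Int64.new(10).seq
    i, batch_size = 8, 4

    # order[8...12] raises IndexError on a Numo array (a Ruby Array
    # would just return [8, 9]).
    i_end = [i + batch_size, order.size].min
    p order[i...i_end].to_a  # => [8, 9]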
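Patch 8/8 swaps the hand-rolled binary reader for the red-datasets gem. The record interface the new code depends on is exactly what the diff uses (#pixels, plus #label for CIFAR-10 or #fine_label for CIFAR-100) and can be poked at directly; note the first call downloads the data:

    require 'datasets'

    cifar = Datasets::CIFAR.new(n_classes: 10, type: :test)
    record = cifar.first  # datasets are Enumerable

    p record.pixels.size  # => 3072 flat channel-planar bytes
    p record.label        # => an Integer class id in 0..9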
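End to end, with the whole series applied on top of red-chainer, the dataset API reads as below (keyword with_label and module name CIFAR are the spellings after patch 8/8; a sketch, not a tested session):

    require 'chainer'

    train, test = Chainer::Datasets::CIFAR.get_cifar10(with_label: true, ndim: 3, scale: 1.0)

    image, label = train[0]   # TupleDataset indexing from patch 1/8
    p image.shape             # => [3, 32, 32]
    p label                   # => an Integer class id
    p image.max               # <= scale, after preprocess_cifar's cast and * scale / 255.0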