Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hatappi committed May 5, 2018
1 parent 7aef029 commit 2f82cc3
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 109 deletions.
9 changes: 3 additions & 6 deletions examples/cifar/train_cifar.rb
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
require 'chainer'
require './models/vgg'
require __dir__ + '/models/vgg'
require 'optparse'

args = {
Expand Down Expand Up @@ -29,11 +29,11 @@
if args[:dataset] == 'cifar10'
puts 'Using CIFAR10 dataset.'
class_labels = 10
train, test = Chainer::Datasets::Cifar.get_cifar10
train, test = Chainer::Datasets::CIFAR.get_cifar10
elsif args[:dataset] == 'cifar100'
puts 'Using CIFAR100 dataset.'
class_labels = 100
train, test = Chainer::Datasets::Cifar.get_cifar100
train, test = Chainer::Datasets::CIFAR.get_cifar100
else
raise 'Invalid dataset choice.'
end
Expand All @@ -45,9 +45,6 @@
optimizer = Chainer::Optimizers::MomentumSGD.new(lr: args[:learnrate])
optimizer.setup(model)

# TODO
# optimizer.add_hook(chainer.optimizer.WeightDecay(5e-4))

train_iter = Chainer::Iterators::SerialIterator.new(train, args[:batchsize])
test_iter = Chainer::Iterators::SerialIterator.new(test, args[:batchsize], repeat: false, shuffle: false)

Expand Down
125 changes: 23 additions & 102 deletions lib/chainer/datasets/cifar.rb
Original file line number Diff line number Diff line change
@@ -1,114 +1,35 @@
require 'rubygems/package'
require 'datasets'

module Chainer
module Datasets
module Cifar
def self.get_cifar10(withlabel: true, ndim: 3, scale: 1.0)
get_cifar('cifar-10', withlabel, ndim, scale)
module CIFAR
def self.get_cifar10(with_label: true, ndim: 3, scale: 1.0)
get_cifar(10, with_label, ndim, scale)
end

def self.get_cifar100(withlabel: true, ndim: 3, scale: 1.0)
get_cifar('cifar-100', withlabel, ndim, scale)
def self.get_cifar100(with_label: true, ndim: 3, scale: 1.0)
get_cifar(100, with_label, ndim, scale)
end

def self.get_cifar(name, withlabel, ndim, scale)
root = Chainer::Dataset::Download.get_dataset_directory('cifar')
filename = "#{name}-binary.tar.gz"
url = "http://www.cs.toronto.edu/~kriz/#{filename}"
path = File.expand_path(filename, root)
extractfile(root, url)
raw = creator(root, name)
train = preprocess_cifar(raw[:train_x], raw[:train_y], withlabel, ndim, scale)
test = preprocess_cifar(raw[:test_x], raw[:test_y], withlabel, ndim, scale)
[train, test]
end

def self.extractfile(root, url)
archive_path = Chainer::Dataset::Download.cached_download(url)
Gem::Package::TarReader.new(Zlib::GzipReader.open(archive_path)) do |tar|
tar.each do |entry|
dest = File.expand_path(entry.full_name, root)
if entry.directory?
FileUtils.mkdir_p(dest)
else
File.open(dest, "wb") do |f|
f.print(entry.read)
end
end
end
end
end

def self.creator(root, name)
if name == 'cifar-10'
train_x = Numo::UInt8.new(5, 10000, 3072).rand(1)
train_y = Numo::UInt8.new(5, 10000).rand(1)
test_x = Numo::UInt8.new(10000, 3072).rand(1)
test_y = Numo::UInt8.new(10000).rand(1)

dir = File.expand_path("cifar-10-batches-bin", root)
(1..5).each do |i|
file_name = "#{dir}/data_batch_#{i}.bin"
open(file_name, "rb") do |f|
s = 0
while b = f.read(3073) do
datasets = b.unpack("C*")
train_y[i - 1, s] = datasets.shift
train_x[i - 1, s, false] = datasets
s += 1
end
end
end
def self.get_cifar(n_classes, with_label, ndim, scale)
train_data = []
train_labels = []
::Datasets::CIFAR.new(n_classes: n_classes, type: :train).each do |record|
train_data << record.pixels
train_labels << (n_classes == 10 ? record.label : record.fine_label)
end

file_name = "#{dir}/test_batch.bin"
open(file_name, "rb") do |f|
s = 0
while b = f.read(3073) do
datasets = b.unpack("C*")
test_y[s] = datasets.shift
test_x[s, false] = datasets
s += 1
end
end

train_x = train_x.reshape(50000, 3072)
train_y = train_y.reshape(50000)
else
train_x = Numo::UInt8.new(50017, 3072).rand(1)
train_y = Numo::UInt8.new(50017).rand(1)
test_x = Numo::UInt8.new(10004, 3072).rand(1)
test_y = Numo::UInt8.new(10004).rand(1)
dir = File.expand_path("cifar-100-binary", root)

file_name = "#{dir}/train.bin"
open(file_name, "rb") do |f|
s = 0
while b = f.read(3073) do
datasets = b.unpack("C*")
train_y[s] = datasets.shift
train_x[s, false] = datasets
s += 1
end
end

file_name = "#{dir}/test.bin"
open(file_name, "rb") do |f|
s = 0
while b = f.read(3073) do
datasets = b.unpack("C*")
test_y[s] = datasets.shift
test_x[s, false] = datasets
s += 1
end
end
test_data = []
test_labels = []
::Datasets::CIFAR.new(n_classes: n_classes, type: :test).each do |record|
test_data << record.pixels
test_labels << (n_classes == 10 ? record.label : record.fine_label)
end

{
train_x: train_x,
train_y: train_y,
test_x: test_x,
test_y: test_y
}
[
preprocess_cifar(Numo::UInt8[*train_data], Numo::UInt8[*train_labels], with_label, ndim, scale),
preprocess_cifar(Numo::UInt8[*test_data], Numo::UInt8[*test_labels], with_label, ndim, scale)
]
end

def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
Expand All @@ -119,7 +40,7 @@ def self.preprocess_cifar(images, labels, withlabel, ndim, scale)
else
raise 'invalid ndim for CIFAR dataset'
end
images = images.cast_to(Numo::Float32)
images = images.cast_to(Numo::DFloat)
images *= scale / 255.0

if withlabel
Expand Down
2 changes: 1 addition & 1 deletion lib/chainer/links/normalization/batch_normalization.rb
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ class BatchNormalization < Chainer::Link
# @param [Numo::NArray.dtype] dtype Type to use in computing.
# @param [boolean] use_gamma If `true`, use scaling parameter. Otherwise, use unit(1) which makes no effect.
# @param [boolean] use_beta If `true`, use shifting parameter. Otherwise, use unit(0) which makes no effect.
def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::Float32, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
def initialize(size, decay: 0.9, eps: 2e-5, dtype: Numo::DFloat, use_gamma: true, use_beta: true, initial_gamma: nil, initial_beta: nil)
super()
@avg_mean = dtype.zeros(size)
register_persistent('avg_mean')
Expand Down

0 comments on commit 2f82cc3

Please sign in to comment.