Skip to content

Commit

Permalink
* fix concat_arrays_with_padding()
Browse files Browse the repository at this point in the history
* added testcase for Dataset::Convert.
  • Loading branch information
naitoh committed May 11, 2018
1 parent 7f7a50d commit 283ea68
Show file tree
Hide file tree
Showing 2 changed files with 125 additions and 9 deletions.
36 changes: 28 additions & 8 deletions lib/chainer/dataset/convert.rb
Original file line number Diff line number Diff line change
Expand Up @@ -29,31 +29,51 @@ def self.concat_examples(batch, device: nil, padding: nil)

def self.concat_arrays(arrays, padding)
unless arrays[0].kind_of?(Numo::NArray)
# [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
arrays = Numo::NArray.cast(arrays)
if padding
return concat_arrays_with_padding(arrays, padding)
end
return arrays
end

if padding
return concat_arrays_with_padding(arrays, padding)
end

# [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
# => Numo::SFloat#shape=[2,2]
# [[1, 2], [3, 4]]
a = arrays.map{|arr| arr[:-, false]}
a[0].concatenate(*a[1..-1])
end

def self.concat_arrays_with_padding(arrays, padding)
shape = Numo::Int32.[](arrays[0].shape)
arrays[1...arrays.len].each do |array|
if Numo::Bit.[](shape != array.shape).any?
# TODO: numpy maximum
if arrays[0].is_a? Numo::NArray
shape = Numo::Int32.cast(arrays[0].shape)
arrays[1..-1].each do |array|
if Numo::Bit.[](shape != array.shape).any?
shape = Numo::Int32.maximum(shape, array.shape)
end
end
else # Integer
shape = []
end

shape = shape.insert(0, arrays.size).to_a
if arrays[0].is_a? Numo::NArray
result = arrays[0].class.new(shape).fill(padding)
else # Integer
result = Numo::Int32.new(shape).fill(padding)
end

shape = [shape.insert(0, arrays.size)]
result = arrays[0].dtype.[](*shape).full(padding)
arrays.size.times do |i|
src = arrays[i]
slices = src.shape.map { |s| [s] }
result[[i] + slices] = src
if src.is_a? Numo::NArray
result[i, 0...src.shape[0], 0...src.shape[1]] = src
else # Integer
result[i] = src
end
end

result
Expand Down
98 changes: 97 additions & 1 deletion test/dataset/convert_test.rb
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ def get_tuple_arrays_to_concat(xumo)
def check_concat_tuples(tuples, device: nil)
arrays = Chainer::Dataset::Convert.method(:concat_examples).call(tuples, device: device)
assert_equal(tuples[0].size, arrays.size)
for i in arrays.size.times
arrays.size.times do |i|
shape = [tuples.size] + tuples[0][i].shape
assert_equal(shape, arrays[i].shape)
check_device(arrays[i], device)
Expand All @@ -50,3 +50,99 @@ def test_concat_tuples_cpu()
check_concat_tuples(tuples)
end
end

class TestConcatExamplesWithPadding < Test::Unit::TestCase
def check_concat_arrays_padding(xumo)
arrays = [xumo::DFloat.new(3, 4).rand(), xumo::DFloat.new(2, 5).rand(), xumo::DFloat.new(4, 3).rand()]
array = Chainer::Dataset::Convert.method(:concat_examples).call(arrays, padding: 0)

assert_equal([3, 4, 5], array.shape)
assert_equal(arrays[0].class, array.class)
arrays = arrays.map{|a| array.class.cast(a)}
assert_true array[0, 0...3, 0...4].nearly_eq(arrays[0]).all?
assert_true array[0, 3..-1, 0..-1].nearly_eq(0).all?
assert_true array[0, 0..-1, 4..-1].nearly_eq(0).all?
assert_true array[1, 0...2, 0...5].nearly_eq(arrays[1]).all?
assert_true array[1, 2..-1, 0..-1].nearly_eq(0).all?
assert_true array[2, 0...4, 0...3].nearly_eq(arrays[2]).all?
assert_true array[2, 0..-1, 3..-1].nearly_eq(0).all?
end

def test_concat_arrays_padding_cpu()
check_concat_arrays_padding(Numo)
end

def check_concat_tuples_padding(xumo)
tuples = [[xumo::DFloat.new(3, 4).rand(), xumo::DFloat.new(2, 5).rand()],
[xumo::DFloat.new(4, 4).rand(), xumo::DFloat.new(3, 4).rand()],
[xumo::DFloat.new(2, 5).rand(), xumo::DFloat.new(2, 6).rand()]]
arrays = Chainer::Dataset::Convert.method(:concat_examples).call(tuples, padding: 0)

assert_equal(2, arrays.size)
assert_equal([3, 4, 5], arrays[0].shape)
assert_equal([3, 3, 6], arrays[1].shape)
assert_equal(tuples[0][0].class, arrays[0].class)
assert_equal(tuples[0][1].class, arrays[1].class)
tuples.size.times do |i|
tuples[i] = [tuples[i][0], tuples[i][1]]
end

arrays = arrays.to_a
assert_true arrays[0][0, 0...3, 0...4].nearly_eq(tuples[0][0]).all?
assert_true arrays[0][0, 3..-1, 0..-1].nearly_eq(0).all?
assert_true arrays[0][0, 0..-1, 4..-1].nearly_eq(0).all?
assert_true arrays[0][1, 0...4, 0...4].nearly_eq(tuples[1][0]).all?
assert_true arrays[0][1, 0..-1, 4..-1].nearly_eq(0).all?
assert_true arrays[0][2, 0...2, 0...5].nearly_eq(tuples[2][0]).all?
assert_true arrays[0][2, 2..-1, 0..-1].nearly_eq(0).all?
assert_true arrays[1][0, 0...2, 0...5].nearly_eq(tuples[0][1]).all?
assert_true arrays[1][0, 2..-1, 0..-1].nearly_eq(0).all?
assert_true arrays[1][0, 0..-1, 5..-1].nearly_eq(0).all?
assert_true arrays[1][1, 0...3, 0...4].nearly_eq(tuples[1][1]).all?
#assert_true arrays[1][1, 3..-1, 0..-1].nearly_eq(0).all? # range error
assert_true arrays[1][1, 0..-1, 4..-1].nearly_eq(0).all?
assert_true arrays[1][2, 0...2, 0...6].nearly_eq(tuples[2][1]).all?
assert_true arrays[1][2, 2..-1, 0..-1].nearly_eq(0).all?
end

def test_concat_tuples_padding_cpu()
check_concat_tuples_padding(Numo)
end
end

class TestConcatExamplesWithBuiltInTypes < Test::Unit::TestCase
data = {
'test1' => {padding: nil},
'test2' => {padding: 0}}

@@int_arrays = [1, 2, 3]
@@float_arrays = [1.0, 2.0, 3.0]

def check_device(array, device)
if device && device >= 0
# T.B.I (GPU Check)
else
assert_true array.is_a?(Numo::NArray)
end
end

def check_concat_arrays(arrays, device:, expected_type:)
array = Chainer::Dataset::Convert.method(:concat_examples).call(arrays, device: device, padding: @padding)
assert_equal([arrays.size], array.shape)
check_device(array, device)

array.to_a.zip(arrays.to_a).each do |x, y|
assert_true Numo::NArray.cast(y).nearly_eq(Numo::NArray.cast(x)).all?
end
end

data(data)
def test_concat_arrays_cpu(data)
@padding = data[:padding]

[-1, nil].each do |device|
check_concat_arrays(@@int_arrays, device: device, expected_type: Numo::Int64)
check_concat_arrays(@@float_arrays, device: device, expected_type: Numo::DFloat)
end
end
end

0 comments on commit 283ea68

Please sign in to comment.