diff --git a/lib/chainer/dataset/convert.rb b/lib/chainer/dataset/convert.rb
index 7c07611..5a5b800 100644
--- a/lib/chainer/dataset/convert.rb
+++ b/lib/chainer/dataset/convert.rb
@@ -29,31 +29,54 @@ def self.concat_examples(batch, device: nil, padding: nil)
 
+      # Joins a batch of examples into one Numo::NArray with a new leading batch axis.
+      #
+      # @param arrays [Array] Numo::NArray examples, or plain Ruby numbers
+      # @param padding [Numeric, nil] when given, examples of differing shapes are padded
+      #   with this value; when nil the examples are concatenated as they are
       def self.concat_arrays(arrays, padding)
         unless arrays[0].kind_of?(Numo::NArray)
+          # Plain Ruby numbers, e.g. [1, 2, 3, 4] => Numo::Int32[1, 2, 3, 4]
           arrays = Numo::NArray.cast(arrays)
+          # Scalars have no extent to pad, so the cast result already is the batch.
+          # Returning it directly also keeps Float inputs out of an Int32 padding buffer.
+          return arrays
         end
 
         if padding
           return concat_arrays_with_padding(arrays, padding)
         end
 
+        # Give every example a leading axis, then join the examples along it:
+        #   [Numo::SFloat[1, 2], Numo::SFloat[3, 4]]
+        #   => Numo::SFloat#shape=[2,2]
+        #   [[1, 2], [3, 4]]
         a = arrays.map{|arr| arr[:-, false]}
         a[0].concatenate(*a[1..-1])
       end
 
+      # Stacks examples whose extents differ into one array, filling the gaps with +padding+.
+      #
+      # Every example is copied into the low-index corner of its slot; the slot size along
+      # each axis is the largest extent found in the batch. Examples may have any number of
+      # axes, provided they all have the same number.
+      #
+      # @param arrays [Array<Numo::NArray>] examples to stack (plain numbers are only cast)
+      # @param padding [Numeric] value written to every cell that no example covers
+      # @return [Numo::NArray] shape [arrays.size, *largest extents], same class as arrays[0]
       def self.concat_arrays_with_padding(arrays, padding)
-        shape = Numo::Int32.[](arrays[0].shape)
-        arrays[1...arrays.len].each do |array|
-          if Numo::Bit.[](shape != array.shape).any?
-            # TODO: numpy maximum
-          end
-        end
-
-        shape = [shape.insert(0, arrays.size)]
-        result = arrays[0].dtype.[](*shape).full(padding)
-        arrays.size.times do |i|
-          src = arrays[i]
-          slices = src.shape.map { |s| [s] }
-          result[[i] + slices] = src
-        end
+        # Plain numbers carry no shape to pad; hand them back as a single NArray.
+        unless arrays[0].is_a?(Numo::NArray)
+          return arrays.is_a?(Numo::NArray) ? arrays : Numo::NArray.cast(arrays)
+        end
+
+        # Largest extent along each axis across the batch. Array#transpose raises
+        # IndexError when the examples do not all have the same number of axes.
+        shape = arrays.map(&:shape).transpose.map(&:max)
+
+        result = arrays[0].class.new([arrays.size] + shape).fill(padding)
+        arrays.each_with_index do |src, i|
+          # One range per axis, so the copy works for any rank; cells outside
+          # 0...extent keep the padding value.
+          result[i, *src.shape.map { |extent| 0...extent }] = src
+        end
 
         result
diff --git a/test/dataset/convert_test.rb b/test/dataset/convert_test.rb
index c165a54..4aaadae 100644
--- a/test/dataset/convert_test.rb
+++ b/test/dataset/convert_test.rb
@@ -35,7 +35,7 @@ def get_tuple_arrays_to_concat(xumo)
   def check_concat_tuples(tuples, device: nil)
     arrays = Chainer::Dataset::Convert.method(:concat_examples).call(tuples, device: device)
     assert_equal(tuples[0].size, arrays.size)
-    for i in arrays.size.times
+    arrays.size.times do |i|
       shape = [tuples.size] + tuples[0][i].shape
       assert_equal(shape, arrays[i].shape)
       check_device(arrays[i], device)
@@ -50,3 +50,111 @@ def test_concat_tuples_cpu()
     check_concat_tuples(tuples)
   end
 end
+
+# Padding path: examples of differing extents are stacked into one zero-filled array.
+class TestConcatExamplesWithPadding < Test::Unit::TestCase
+  def check_concat_arrays_padding(xumo)
+    arrays = [xumo::DFloat.new(3, 4).rand(), xumo::DFloat.new(2, 5).rand(), xumo::DFloat.new(4, 3).rand()]
+    array = Chainer::Dataset::Convert.method(:concat_examples).call(arrays, padding: 0)
+
+    assert_equal([3, 4, 5], array.shape)
+    assert_equal(arrays[0].class, array.class)
+    arrays = arrays.map{|a| array.class.cast(a)}
+    assert_true array[0, 0...3, 0...4].nearly_eq(arrays[0]).all?
+    assert_true array[0, 3..-1, 0..-1].nearly_eq(0).all?
+    assert_true array[0, 0..-1, 4..-1].nearly_eq(0).all?
+    assert_true array[1, 0...2, 0...5].nearly_eq(arrays[1]).all?
+    assert_true array[1, 2..-1, 0..-1].nearly_eq(0).all?
+    assert_true array[2, 0...4, 0...3].nearly_eq(arrays[2]).all?
+    assert_true array[2, 0..-1, 3..-1].nearly_eq(0).all?
+  end
+
+  def test_concat_arrays_padding_cpu()
+    check_concat_arrays_padding(Numo)
+  end
+
+  # Examples need not be two-dimensional: 1-D arrays are padded as well.
+  def test_concat_arrays_padding_1d_cpu()
+    arrays = [Numo::DFloat[1, 2, 3], Numo::DFloat[4]]
+    array = Chainer::Dataset::Convert.method(:concat_examples).call(arrays, padding: 0)
+
+    assert_equal([2, 3], array.shape)
+    assert_equal([[1.0, 2.0, 3.0], [4.0, 0.0, 0.0]], array.to_a)
+  end
+
+  def check_concat_tuples_padding(xumo)
+    tuples = [[xumo::DFloat.new(3, 4).rand(), xumo::DFloat.new(2, 5).rand()],
+              [xumo::DFloat.new(4, 4).rand(), xumo::DFloat.new(3, 4).rand()],
+              [xumo::DFloat.new(2, 5).rand(), xumo::DFloat.new(2, 6).rand()]]
+    arrays = Chainer::Dataset::Convert.method(:concat_examples).call(tuples, padding: 0)
+
+    assert_equal(2, arrays.size)
+    assert_equal([3, 4, 5], arrays[0].shape)
+    assert_equal([3, 3, 6], arrays[1].shape)
+    assert_equal(tuples[0][0].class, arrays[0].class)
+    assert_equal(tuples[0][1].class, arrays[1].class)
+    tuples.size.times do |i|
+      tuples[i] = [tuples[i][0], tuples[i][1]]
+    end
+
+    arrays = arrays.to_a
+    assert_true arrays[0][0, 0...3, 0...4].nearly_eq(tuples[0][0]).all?
+    assert_true arrays[0][0, 3..-1, 0..-1].nearly_eq(0).all?
+    assert_true arrays[0][0, 0..-1, 4..-1].nearly_eq(0).all?
+    assert_true arrays[0][1, 0...4, 0...4].nearly_eq(tuples[1][0]).all?
+    assert_true arrays[0][1, 0..-1, 4..-1].nearly_eq(0).all?
+    assert_true arrays[0][2, 0...2, 0...5].nearly_eq(tuples[2][0]).all?
+    assert_true arrays[0][2, 2..-1, 0..-1].nearly_eq(0).all?
+    assert_true arrays[1][0, 0...2, 0...5].nearly_eq(tuples[0][1]).all?
+    assert_true arrays[1][0, 2..-1, 0..-1].nearly_eq(0).all?
+    assert_true arrays[1][0, 0..-1, 5..-1].nearly_eq(0).all?
+    assert_true arrays[1][1, 0...3, 0...4].nearly_eq(tuples[1][1]).all?
+    # No row-padding check for arrays[1][1]: its 3 rows fill the whole slot, so 3..-1 is out of range.
+    assert_true arrays[1][1, 0..-1, 4..-1].nearly_eq(0).all?
+    assert_true arrays[1][2, 0...2, 0...6].nearly_eq(tuples[2][1]).all?
+    assert_true arrays[1][2, 2..-1, 0..-1].nearly_eq(0).all?
+  end
+
+  def test_concat_tuples_padding_cpu()
+    check_concat_tuples_padding(Numo)
+  end
+end
+
+# Plain Ruby numbers are accepted as examples, with and without a padding value.
+class TestConcatExamplesWithBuiltInTypes < Test::Unit::TestCase
+  data = {
+    'test1' => {padding: nil},
+    'test2' => {padding: 0}}
+
+  INT_ARRAYS = [1, 2, 3].freeze
+  FLOAT_ARRAYS = [1.0, 2.5, 3.0].freeze # 2.5 would not survive an integer buffer
+
+  def check_device(array, device)
+    if device && device >= 0
+      # TODO: device check for GPU arrays is not implemented yet
+    else
+      assert_true array.is_a?(Numo::NArray)
+    end
+  end
+
+  def check_concat_arrays(arrays, device:, expected_type:)
+    array = Chainer::Dataset::Convert.method(:concat_examples).call(arrays, device: device, padding: @padding)
+    assert_equal([arrays.size], array.shape)
+    assert_equal(expected_type, array.class)
+    check_device(array, device)
+
+    array.to_a.zip(arrays.to_a).each do |x, y|
+      assert_true Numo::NArray.cast(y).nearly_eq(Numo::NArray.cast(x)).all?
+    end
+  end
+
+  data(data)
+  def test_concat_arrays_cpu(data)
+    @padding = data[:padding]
+
+    [-1, nil].each do |device|
+      check_concat_arrays(INT_ARRAYS, device: device, expected_type: Numo::Int32)
+      check_concat_arrays(FLOAT_ARRAYS, device: device, expected_type: Numo::DFloat)
+    end
+  end
+end