diff --git a/lib/chainer/dataset/download.rb b/lib/chainer/dataset/download.rb index 01af695..34b596a 100644 --- a/lib/chainer/dataset/download.rb +++ b/lib/chainer/dataset/download.rb @@ -34,15 +34,22 @@ def self.get_dataset_directory(dataset_name, create_directory: true) path end - def self.cache_or_load_file(path, data) - return PStore.new(path).transaction { |t| t['data'] } if File.exist?(path) + def self.cache_or_load_file(path, &creator) + raise 'Please set dataset creator on block' if creator.nil? - pstore = PStore.new(path) - pstore.transaction{|t| - t["data"] = data - } + return PStore.new(path).transaction { |t| t['data'] } if File.exist?(path) + data = creator.call + PStore.new(path).transaction do |t| + t['data'] = data + end data + rescue TypeError => e + puts e.message + FileUtils.rm_f(path) + cache_or_load_file(path) do + creator.call + end end end end diff --git a/lib/chainer/datasets/mnist.rb b/lib/chainer/datasets/mnist.rb index aadefad..540d445 100644 --- a/lib/chainer/datasets/mnist.rb +++ b/lib/chainer/datasets/mnist.rb @@ -48,8 +48,9 @@ def self.retrieve_mnist_test def self.retrieve_mnist(name, urls) root = Chainer::Dataset::Download.get_dataset_directory('pfnet/chainer/mnist') path = File.expand_path(name, root) - creator = make_npz(path, urls) - Chainer::Dataset::Download.cache_or_load_file(path, creator) + Chainer::Dataset::Download.cache_or_load_file(path) do + make_npz(path, urls) + end end def self.make_npz(path, urls)