forked from PINTO0309/DeepLearningMugenKnock
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy path load_cifar10.py
56 lines (39 loc) · 1.43 KB
/
load_cifar10.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
import numpy as np
import os
import pickle
# Human-readable CIFAR-10 class names. Position i is presumably the name for
# integer label i in the batches (dataset's batches.meta order) — confirm.
labels = [
    "airplane",
    "automobile",
    "bird",
    "cat",
    "deer",
    "dog",
    "frog",
    "horse",
    "ship",
    "truck",
]
def _download_cifar10(path, url):
    """Fetch and unpack the CIFAR-10 python archive so that *path* exists.

    The archive is saved into the parent directory of *path* and extracted
    there. The official archive is expected to unpack into a directory named
    ``cifar-10-batches-py``, so *path* should end with that name for the
    download to satisfy it.

    Raises:
        FileNotFoundError: if *path* still does not exist after extraction.
    """
    import tarfile
    import urllib.request

    parent = os.path.dirname(os.path.abspath(path))
    archive = os.path.join(parent, url.rsplit('/', 1)[-1])
    # Reuse a previously downloaded archive instead of fetching it again.
    if not os.path.exists(archive):
        urllib.request.urlretrieve(url, archive)
    # NOTE: extractall trusts member paths; acceptable for the official
    # archive. On Python >= 3.12 consider passing filter='data'.
    with tarfile.open(archive, 'r:gz') as tar:
        tar.extractall(parent)
    if not os.path.isdir(path):
        raise FileNotFoundError(
            "{} not found after extracting {}".format(path, archive))


def _read_batch(data_path):
    """Load one pickled CIFAR-10 batch file and return ``(images, labels)``.

    Images are returned channel-last with shape (N, 32, 32, 3) in the dtype
    stored in the pickle. Labels are returned as a 1-D array of Python-int
    dtype. The file path is printed after loading.
    """
    # NOTE(security): pickle can execute arbitrary code while loading; only
    # read batches from a trusted source such as the official archive.
    with open(data_path, 'rb') as f:
        batch = pickle.load(f, encoding='bytes')
    print(data_path)
    x = batch[b'data']
    # Each row is stored channel-planar (3 x 32 x 32); move channels last.
    x = x.reshape(x.shape[0], 3, 32, 32).transpose(0, 2, 3, 1)
    # dtype=int replaces the removed np.int alias (it was the builtin int).
    y = np.array(batch[b'labels'], dtype=int)
    return x, y


def load_cifar10(path='cifar-10-batches-py',
                 url='https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz'):
    """Load CIFAR-10 from its python-pickle batch directory.

    If *path* does not exist, the archive at *url* is downloaded and extracted
    next to it first.

    Args:
        path: directory containing ``data_batch_1`` .. ``data_batch_5`` and
            ``test_batch``.
        url: archive URL used only when *path* is missing.

    Returns:
        ``(train_x, train_y, test_x, test_y)``. ``train_x`` is float32 with
        shape (N, 32, 32, 3). ``test_x`` has the same layout but keeps the
        dtype stored in the pickle. Both label arrays are 1-D int arrays.
        The float32/stored-dtype difference is preserved from the original
        implementation for caller compatibility.

    Raises:
        FileNotFoundError: if *path* is still missing after the download.
    """
    if not os.path.exists(path):
        _download_cifar10(path, url)

    # Train data: collect the five batches and concatenate once, instead of
    # re-copying a growing array with vstack/hstack on every iteration.
    xs, ys = [], []
    for i in range(1, 6):
        x, y = _read_batch(os.path.join(path, 'data_batch_{}'.format(i)))
        xs.append(x)
        ys.append(y)
    # The original stacked onto an empty float32 array, yielding float32.
    train_x = np.concatenate(xs).astype(np.float32)
    train_y = np.concatenate(ys)
    print(train_x.shape)
    print(train_y.shape)

    # Test data
    test_x, test_y = _read_batch(os.path.join(path, 'test_batch'))
    print(test_x.shape)
    print(test_y.shape)
    return train_x, train_y, test_x, test_y
# Run the loader only when executed as a script, so that importing this
# module does not trigger a dataset load/download as a side effect.
if __name__ == "__main__":
    load_cifar10()