This repository has been archived by the owner on Feb 29, 2024. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 0
/
cifar_batch.py
88 lines (74 loc) · 2.97 KB
/
cifar_batch.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
""" Custom batch class for storing cifar-10 batch and models
"""
import numpy as np
import os
import pickle
from dataset import Batch, action, model
class Cifar10Batch(Batch):
    """ Cifar-10 batch.

    Holds the components named in `components` ('images', 'labels',
    'picnames') for the items selected by the batch's indices.
    """
    def __init__(self, index, *args, **kwargs):
        """ Init func.

        Args:
            index: forwarded unchanged to `Batch.__init__`, together with any
                extra positional/keyword arguments.
        """
        super().__init__(index, *args, **kwargs)
        # Components stay empty until `load` fills them in.
        self.images = None
        self.labels = None
        self.picnames = None

    @property
    def components(self):
        """ Names of the components of a cifar-10 batch.

        Shapes after `load`:
            images: ndarray of shape (batch_size, 32, 32, 3) (channels-last)
                containing cifar-10 images (presumably RGB channel order).
            labels: ndarray of shape (batch_size, 10) containing one-hot
                encoded labels.
            picnames: ndarray of shape (batch_size, 1) containing filenames of
                cifar-10 images.
        """
        return 'images', 'labels', 'picnames'

    @staticmethod
    def _adjust_shape(component, name, num_classes=10):
        """ Adjust the shape of an array representing a component when loading
        from pickle/memory (ndarray).

        Args:
            component: array containing the component's data.
            name: the name of the component.
            num_classes: width of the one-hot encoding built for 'labels'.
        Return:
            an array with adjusted shape. A component with an unrecognised
            name is returned unchanged (previously it became None).
        """
        if name == 'images':
            # Each flat row is read as three channel-first 32x32 planes and
            # then moved to channels-last: (batch_size, 32, 32, 3).
            return np.reshape(component, (-1, 3, 32, 32)).transpose((0, 2, 3, 1))
        if name == 'labels':
            # Integer class ids of any shape -> one-hot matrix. The output is
            # sized by the number of labels rather than len(component), so a
            # (1, batch_size) array works as well as (batch_size,) or
            # (batch_size, 1).
            flattened = np.reshape(component, -1)
            one_hot_labels = np.zeros(shape=(flattened.size, num_classes))
            one_hot_labels[np.arange(flattened.size), flattened] = 1
            return one_hot_labels
        if name == 'picnames':
            return np.reshape(component, (-1, 1))
        return component

    @action
    def load(self, src, fmt='pkl'):
        """ Load cifar-10 pics.

        Args:
            src: if fmt is 'pkl', then src is a path to a folder containing
                pickled files named '<component>.pkl' for the components
                that should be loaded.
                If fmt is 'ndarray', then src is a dict whose keys are
                component names.
                In both cases arrays are subindexed with `self.indices`.
                Files/keys that do not name a component are ignored.
            fmt: format of src. Can be either 'pkl' (pickle) or 'ndarray'.
        Return:
            self.
        Raises:
            ValueError: if fmt is neither 'pkl' nor 'ndarray'.
        """
        if fmt == 'pkl':
            available = set(os.listdir(src))
            for comp in self.components:
                filename = comp + '.pkl'
                if filename not in available:
                    continue
                # NOTE: pickle.load can execute arbitrary code while
                # unpickling; only load files from a trusted source.
                with open(os.path.join(src, filename), 'rb') as file:
                    data = pickle.load(file)
                # Assumes the unpickled object supports indexing by
                # self.indices (e.g. an ndarray) - TODO confirm.
                setattr(self, comp, self._adjust_shape(data[self.indices], comp))
        elif fmt == 'ndarray':
            # Only known components are loaded. Extra keys in src must not
            # create stray attributes on the batch.
            for comp in self.components:
                if comp in src:
                    setattr(self, comp, self._adjust_shape(src[comp][self.indices], comp))
        else:
            raise ValueError("fmt must be either 'pkl' or 'ndarray', got {!r}".format(fmt))
        return self