-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathload_dataset.py
78 lines (69 loc) · 2.64 KB
/
load_dataset.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
import matplotlib.pyplot as plt
import numpy as np
import os
from codebase.data import (
DataLoader,
ImageFolderDataset,
RescaleTransform,
NormalizeTransform,
ReshapeTransform,
FlattenTransform,
ComposeTransform,
)
from codebase.data.image_folder_dataset import RandomHorizontalFlip
# Module-wide matplotlib defaults, applied once when this module is imported.
plt.rcParams.update({
    'figure.figsize': (10.0, 8.0),      # default size of plots
    'image.interpolation': 'nearest',
    'image.cmap': 'gray',
})
# ------------------------------------
def get_dataloaders(datasets, batch_size=256):
    """Build one ``DataLoader`` per split.

    Args:
        datasets: Mapping that is indexed with the keys ``'train'``,
            ``'val'`` and ``'test'``; each value is passed to
            ``DataLoader`` as its ``dataset`` argument.
        batch_size: Batch size forwarded unchanged to every loader.

    Returns:
        A dict with the keys ``'train'``, ``'val'`` and ``'test'`` (in that
        order) mapping to the corresponding ``DataLoader`` instances.

    Note:
        All three loaders, including the ones for ``'val'`` and ``'test'``,
        are created with ``shuffle=True`` and ``drop_last=True``.
    """
    return {
        split_name: DataLoader(
            dataset=datasets[split_name],
            batch_size=batch_size,
            shuffle=True,
            drop_last=True,
        )
        for split_name in ('train', 'val', 'test')
    }
def get_compose_transform(useFlatten=True, training=False):
    """Assemble the transform pipeline applied to each sample.

    The pipeline is always rescale -> normalize -> reshape. When
    ``training`` is true, a ``RandomHorizontalFlip`` is placed in front of
    those three steps.

    Args:
        useFlatten: Accepted for interface compatibility but never read by
            this function; the pipeline always ends with
            ``ReshapeTransform``.
        training: If true, prepend the random horizontal flip.

    Returns:
        A ``ComposeTransform`` wrapping the list of transforms described
        above.
    """
    # Normalisation uses 0.5 for every channel's mean and std. The values
    # previously noted alongside these constants
    # (mean 0.49191375, 0.48235852, 0.44673872 /
    #  std  0.24706447, 0.24346213, 0.26147554) are not active.
    channel_mean = np.array([0.5, 0.5, 0.5])
    channel_std = np.array([0.5, 0.5, 0.5])

    pipeline = [
        RescaleTransform(),
        NormalizeTransform(mean=channel_mean, std=channel_std),
        ReshapeTransform(),
    ]

    # The flip is instantiated regardless of `training`; it only becomes
    # part of the pipeline in training mode, where it runs first.
    horizontal_flip = RandomHorizontalFlip()
    if training:
        pipeline.insert(0, horizontal_flip)

    return ComposeTransform(pipeline)
def get_datasets(DATASET, cifar_root, compose_transform,
                 compose_transform_training=None, split=None):
    """Create the train, validation and test datasets.

    Args:
        DATASET: Callable (typically a dataset class) invoked once per mode
            with the keyword arguments ``mode``, ``root``, ``transform``
            and ``split``.
        cifar_root: Value passed through as ``root`` to every dataset.
        compose_transform: Transform used for every mode, unless overridden
            for ``'train'`` by ``compose_transform_training``.
        compose_transform_training: Optional transform used only for the
            ``'train'`` dataset. When ``None``, ``compose_transform`` is
            used for ``'train'`` as well.
        split: Optional mapping of split ratios passed as ``split`` to
            every dataset. Defaults to
            ``{'train': 0.65, 'val': 0.15, 'test': 0.2}``.

    Returns:
        A dict with the keys ``'train'``, ``'val'`` and ``'test'`` (in that
        order) mapping to whatever ``DATASET`` returned for that mode.
    """
    if split is None:
        split = {'train': 0.65, 'val': 0.15, 'test': 0.2}

    datasets = {}
    for mode in ['train', 'val', 'test']:
        use_training_transform = (
            mode == 'train' and compose_transform_training is not None
        )
        transform = (
            compose_transform_training if use_training_transform
            else compose_transform
        )
        datasets[mode] = DATASET(
            mode=mode,
            root=cifar_root,
            transform=transform,
            # Hand each dataset its own copy so that one dataset mutating
            # its split cannot affect the others or the caller's mapping.
            split=dict(split),
        )
    return datasets