-
Notifications
You must be signed in to change notification settings - Fork 0
/
data_module.py
54 lines (44 loc) · 2.22 KB
/
data_module.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import os

import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision.io import read_image
from torchvision.transforms import (
    ColorJitter,
    Compose,
    Normalize,
    RandomAffine,
    RandomHorizontalFlip,
    RandomResizedCrop,
    RandomRotation,
    RandomVerticalFlip,
    Resize,
    ToTensor,
)
# Training-time preprocessing/augmentation pipeline; it is the default
# `transform` of BinaryImageDataset.
#
# All random augmentations run on the PIL image first, and ToTensor/Normalize
# come LAST. Normalizing first (the previous order) broke ColorJitter:
# torchvision's tensor colour ops clamp their output to [0, 1], which wipes out
# the negative half of the [-1, 1] range produced by Normalize(0.5, 0.5).
# Note that rotation/affine padding is now filled with black (pixel value 0)
# rather than the mid-grey that a 0 meant in normalized space.
transform = Compose([
    RandomHorizontalFlip(p=0.2),
    RandomVerticalFlip(p=0.2),
    RandomRotation(degrees=15),  # rotate by a random angle in [-15, +15] degrees
    # Random crop resized to a fixed 128x128, so every sample has the same size.
    RandomResizedCrop(size=(128, 128), scale=(0.8, 1.0)),
    ColorJitter(brightness=0.2, contrast=0.2, saturation=0.2, hue=0.1),
    RandomAffine(degrees=10, translate=(0.1, 0.1), scale=(0.8, 1.2), shear=10),
    ToTensor(),  # PIL image -> float CxHxW tensor (values in [0, 1] for 8-bit images)
    # A single mean/std value is broadcast over every channel, so this maps each
    # channel from [0, 1] to [-1, 1] for grayscale and multi-channel inputs alike.
    Normalize([0.5], [0.5]),
])
class BinaryImageDataset(Dataset):
    """Binary image-classification dataset built from two flat directories.

    Every regular file directly inside ``synthetic_dir`` is labelled 0 and
    every regular file directly inside ``original_dir`` is labelled 1.
    Indices cover all label-0 images first, then all label-1 images; within
    each group files are in sorted filename order, so the index -> file mapping
    is reproducible across runs and machines.
    """

    def __init__(self, synthetic_dir, original_dir, transform=transform):
        """
        Args:
            synthetic_dir (string): Directory with all the images for label 0.
            original_dir (string): Directory with all the images for label 1.
            transform (callable or None): Applied to each PIL image in
                ``__getitem__``. Defaults to the module-level ``transform``
                pipeline; pass ``None`` to get the PIL image itself back.
        """
        self.synthetic_dir = synthetic_dir
        self.original_dir = original_dir
        self.label0_images = self._list_images(synthetic_dir)
        self.label1_images = self._list_images(original_dir)
        self.total_images = self.label0_images + self.label1_images
        # labels[i] is the class of total_images[i]: 0s for synthetic, then 1s.
        self.labels = [0] * len(self.label0_images) + [1] * len(self.label1_images)
        self.transform = transform

    @staticmethod
    def _list_images(directory):
        """Return sorted paths of the regular files directly inside ``directory``."""
        # os.listdir order is arbitrary; sorting makes dataset indices stable.
        # Sub-directories are skipped because Image.open cannot read them.
        candidates = (os.path.join(directory, name) for name in sorted(os.listdir(directory)))
        return [path for path in candidates if os.path.isfile(path)]

    def __len__(self):
        """Return the total number of images across both directories."""
        return len(self.total_images)

    def __getitem__(self, idx):
        """Load image ``idx`` and return ``(image, label)``.

        ``image`` is the output of ``self.transform`` applied to the opened PIL
        image, or an in-memory copy of the PIL image when ``transform`` is None.
        ``label`` is 0 for synthetic images and 1 for original images.
        """
        image_path = self.total_images[idx]
        label = self.labels[idx]
        # The context manager releases the file handle even if the transform raises.
        with Image.open(image_path) as img:
            if self.transform is None:
                # copy() loads the pixels into a new image that stays usable
                # after the underlying file is closed.
                image = img.copy()
            else:
                image = self.transform(img)
        return image, label
# Evaluation-time preprocessing: a deterministic resize to the same 128x128
# size the training crop produces, followed by the same tensor conversion and
# normalization as training, but with NO random augmentation. Evaluating with
# the random training pipeline made test metrics noisy and non-reproducible.
test_transform = Compose([
    Resize((128, 128)),
    ToTensor(),
    Normalize([0.5], [0.5]),
])

# Training split: uses the class's default (augmenting) `transform`.
train_synthetic_data_path = 'dataset/train/synthetic'
train_original_data_path = 'dataset/train/original'
train_orginal_data_path = train_original_data_path  # misspelled alias kept for existing importers
train_dataset = BinaryImageDataset(train_synthetic_data_path, train_original_data_path)

# Test split: deterministic preprocessing only.
test_synthetic_data_path = 'dataset/test/synthetic'
test_original_data_path = 'dataset/test/original'
test_orginal_data_path = test_original_data_path  # misspelled alias kept for existing importers
test_dataset = BinaryImageDataset(
    test_synthetic_data_path, test_original_data_path, transform=test_transform
)