-
Notifications
You must be signed in to change notification settings - Fork 0
/
augmentations.py
80 lines (74 loc) · 2.67 KB
/
augmentations.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
import albumentations as A
from albumentations.pytorch.transforms import ToTensorV2
def _photometric_transforms():
    """Return fresh instances of the photometric transforms used by the
    "oneof" and "strong" training pipelines.

    Every child carries ``p=0.5``. Because all four values are equal, each
    transform is presumably equally likely to be picked by the enclosing
    ``A.OneOf`` / ``A.SomeOf``. Albumentations documents a child's ``p`` as
    a selection weight inside those compositions; confirm this for the
    pinned library version.
    """
    return [
        A.GaussianBlur(sigma_limit=(0.0, 3.0), p=0.5),
        A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
        A.ISONoise(p=0.5),
        A.ColorJitter(brightness=0.4, contrast=0.4, saturation=0.0, hue=0.0, p=0.5),
    ]


def get_training_augmentations(aug_type, resize_size=256, crop_size=240):
    """Build the albumentations training pipeline selected by ``aug_type``.

    ``aug_type`` is matched by substring, so keywords can be combined
    (e.g. ``"geometric_strong"``).

    * ``"geometric"``: start with a ``RandomResizedCrop`` to
      ``resize_size`` x ``resize_size`` (scale 0.95-1.0, aspect ratio
      0.8-1.2). Without it, start with a plain ``Resize`` to the same size.
    * Photometric keywords are checked in this order, and the first match wins:
      - ``"soft"``: no photometric augmentation.
      - ``"wang"``: independent Gaussian blur and JPEG compression, each
        with p=0.5.
      - ``"oneof"``: ``A.OneOf`` (library-default p) over blur, JPEG,
        ISO noise and color jitter.
      - ``"strong"``: ``A.SomeOf`` drawing 2 of those same four transforms.
      If none of these keywords is present, no photometric augmentation is
      added (same as ``"soft"``).

    Every pipeline then ends with ``RandomCrop`` to ``crop_size`` x
    ``crop_size``, ``HorizontalFlip`` (library-default p), ``Normalize``
    (library-default mean/std) and ``ToTensorV2``.

    Args:
        aug_type: String containing the keywords described above.
        resize_size: Side length produced by the initial resize or resized crop.
        crop_size: Side length of the final random crop. Must not exceed
            ``resize_size``.

    Returns:
        An ``A.Compose`` pipeline.

    Raises:
        ValueError: If ``crop_size`` is larger than ``resize_size``. The final
            crop could never succeed in that case, so fail at construction
            instead of on the first sample.
    """
    if crop_size > resize_size:
        raise ValueError(
            f"crop_size ({crop_size}) must not exceed resize_size ({resize_size})"
        )

    if "geometric" in aug_type:
        # Keyword arguments make the (height, width, scale, ratio) mapping
        # explicit instead of relying on positional order.
        augmentations = [
            A.RandomResizedCrop(
                height=resize_size,
                width=resize_size,
                scale=(0.95, 1.0),
                ratio=(0.8, 1.2),
            )
        ]
    else:
        augmentations = [A.Resize(resize_size, resize_size)]

    if "soft" in aug_type:
        pass  # Geometric/resize + final crop only; no photometric changes.
    elif "wang" in aug_type:
        # Wang augmentation pipeline (blur + JPEG) translated into albumentations.
        augmentations.extend(
            [
                A.GaussianBlur(sigma_limit=(0.0, 3.0), p=0.5),
                A.ImageCompression(quality_lower=30, quality_upper=100, p=0.5),
            ]
        )
    elif "oneof" in aug_type:
        augmentations.append(A.OneOf(_photometric_transforms()))
    elif "strong" in aug_type:
        # replace=True is spelled out so the sampling mode does not depend on
        # the library's default.
        # NOTE(review): with replacement, the same transform can be drawn
        # twice. Use replace=False if two *distinct* transforms are intended.
        augmentations.append(A.SomeOf(_photometric_transforms(), n=2, replace=True))

    augmentations += [
        A.RandomCrop(crop_size, crop_size),
        A.HorizontalFlip(),
        A.Normalize(),
        ToTensorV2(),
    ]
    return A.Compose(augmentations)
def get_validation_augmentations(resize_size=256, crop_size=240):
    """Build the deterministic evaluation pipeline.

    The pipeline resizes to ``resize_size`` x ``resize_size`` and center-crops
    to ``crop_size`` x ``crop_size``. It then applies ``Normalize``
    (library-default mean/std) and converts to a tensor with ``ToTensorV2``.

    Args:
        resize_size: Side length after the resize.
        crop_size: Side length of the center crop. Must not exceed
            ``resize_size``.

    Returns:
        An ``A.Compose`` pipeline.

    Raises:
        ValueError: If ``crop_size`` is larger than ``resize_size``. The
            center crop could never succeed in that case, so fail at
            construction instead of on the first sample.
    """
    if crop_size > resize_size:
        raise ValueError(
            f"crop_size ({crop_size}) must not exceed resize_size ({resize_size})"
        )

    return A.Compose(
        [
            A.Resize(resize_size, resize_size),
            A.CenterCrop(crop_size, crop_size),
            A.Normalize(),
            ToTensorV2(),
        ]
    )