-
Notifications
You must be signed in to change notification settings - Fork 38
Expand file tree
/
Copy pathaugment.py
More file actions
158 lines (135 loc) · 5.26 KB
/
augment.py
File metadata and controls
158 lines (135 loc) · 5.26 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
"""Image augmentation transformations using torchvision v2."""
import numpy as np
from torchvision.transforms import v2 as T
from PIL import Image
# Transform registry: name -> factory taking a magnitude m in [0, 1] and
# returning a torchvision v2 transform. Each lambda rescales m into a range
# sensible for its transform; several (autocontrast, equalize, invert) ignore
# m entirely and always apply with p=1.0.
# see https://github.com/pytorch/vision/blob/218d2ab791d437309f91e0486eb9fa7f00badc17/torchvision/transforms/transforms.py
TRANSFORMS = {
    # Geometric transforms
    "rotate": lambda m: T.RandomRotation(degrees=int(180 * m)),  # up to +/-180 deg at m=1
    "flip_h": lambda m: T.RandomHorizontalFlip(p=m),  # m is the flip probability
    "flip_v": lambda m: T.RandomVerticalFlip(p=m),
    "affine": lambda m: T.RandomAffine(degrees=0, translate=(m * 0.2, m * 0.2)),  # translate only
    "shear": lambda m: T.RandomAffine(degrees=0, shear=int(45 * m)),
    "perspective": lambda m: T.RandomPerspective(distortion_scale=0.5 * m, p=1.0),
    "elastic": lambda m: T.ElasticTransform(alpha=m * 50.0),
    # Color/photometric transforms
    "brightness": lambda m: T.ColorJitter(brightness=m * 0.5),
    "contrast": lambda m: T.ColorJitter(contrast=m * 0.5),
    "saturation": lambda m: T.ColorJitter(saturation=m * 0.5),
    "hue": lambda m: T.ColorJitter(hue=0.1 * m),
    "color_jitter": lambda m: T.ColorJitter(
        brightness=m * 0.3, contrast=m * 0.3, saturation=m * 0.3, hue=0.05 * m
    ),
    # Advanced color transforms
    "sharpen": lambda m: T.RandomAdjustSharpness(sharpness_factor=1 + m * 3, p=1.0),
    "autocontrast": lambda m: T.RandomAutocontrast(p=1.0),  # magnitude-independent
    "equalize": lambda m: T.RandomEqualize(p=1.0),  # magnitude-independent
    "invert": lambda m: T.RandomInvert(p=1.0),  # magnitude-independent
    "solarize": lambda m: T.RandomSolarize(threshold=int(128 + 127 * m), p=1.0),  # higher m -> higher threshold -> milder effect
    "posterize": lambda m: T.RandomPosterize(bits=max(1, int(2 + 6 * m)), p=1.0),  # 2..8 bits, clamped >= 1
    "grayscale": lambda m: T.RandomGrayscale(p=m),  # m is the conversion probability
    # Blur and noise
    "blur": lambda m: T.GaussianBlur(kernel_size=3, sigma=(0.1, 3 + 20 * m)),
    # Occlusion/masking
    # NOTE(review): RandomErasing is documented for tensor input; confirm it
    # accepts the PIL images produced by apply_policy's pipeline.
    "erasing": lambda m: T.RandomErasing(p=m, scale=(0.02, 0.33), ratio=(0.3, 3.3)),
    # 'cutout': lambda m: T.RandomErasing(p=1.0, scale=(0.02, m*0.4), ratio=(1.0, 1.0)),
    # Advanced augmentation techniques
    # Binary on/off: permutes channels only when m > 0.5, otherwise identity.
    "channel_permute": lambda m: T.RandomChannelPermutation()
    if m > 0.5
    else T.Identity(),
    "photometric_distort": lambda m: T.RandomPhotometricDistort(
        brightness=(1 - 0.3 * m, 1 + 0.3 * m),
        contrast=(1 - 0.3 * m, 1 + 0.3 * m),
        saturation=(1 - 0.3 * m, 1 + 0.3 * m),
        hue=(-0.05 * m, 0.05 * m),
    ),
    # Crop and resize (useful for scale invariance)
    # The fixed size=32 is a placeholder; apply_policy rebuilds this transform
    # with the actual input size.
    "random_crop": lambda m: T.RandomResizedCrop(
        size=32,  # Will be adjusted based on input size
        scale=(1 - m * 0.3, 1.0),
        ratio=(0.75, 1.33),
        antialias=True,
    ),
}
def get_transform_categories():
    """Return the registry's transform names grouped by category."""
    geometric = [
        "rotate",
        "flip_h",
        "flip_v",
        "affine",
        "shear",
        "perspective",
        "elastic",
        "random_crop",
    ]
    advanced_color = [
        "sharpen",
        "autocontrast",
        "equalize",
        "invert",
        "solarize",
        "posterize",
        "grayscale",
    ]
    return {
        "geometric": geometric,
        "color": ["brightness", "contrast", "saturation", "hue", "color_jitter"],
        "advanced_color": advanced_color,
        "blur_noise": ["blur"],
        "occlusion": ["erasing"],
        "advanced": ["channel_permute", "photometric_distort"],
    }
def numpy_to_pil(image):
    """Return *image* as a PIL Image; numpy arrays are cast to uint8 first."""
    if not isinstance(image, np.ndarray):
        # Anything that is not an ndarray is passed through untouched.
        return image
    return Image.fromarray(image.astype(np.uint8))
def pil_to_numpy(image):
    """Return *image* as a numpy array; non-PIL inputs pass through untouched."""
    return np.array(image) if isinstance(image, Image.Image) else image
def make_transform(name, magnitude):
    """Instantiate the registered transform *name*.

    The magnitude is clipped to the [0, 1] range before being handed to the
    registry factory.

    Raises:
        ValueError: if *name* is not a key of TRANSFORMS.
    """
    factory = TRANSFORMS.get(name)
    if factory is None:
        raise ValueError(
            f"Unknown transform: {name}. Available: {list(TRANSFORMS.keys())}"
        )
    return factory(np.clip(magnitude, 0.0, 1.0))
def apply_policy(image, policy):
    """
    Apply an augmentation policy to an image.

    Args:
        image: RGB image as numpy array (H, W, C) or PIL Image
        policy: List of (transform_name, magnitude) tuples

    Returns:
        Augmented image as a numpy array if the input was a numpy array,
        otherwise a PIL Image.

    Raises:
        ValueError: if a policy entry names an unregistered transform.
    """
    # Convert to PIL if needed, remembering the input kind for the round-trip.
    was_numpy = isinstance(image, np.ndarray)
    pil_image = numpy_to_pil(image)
    # Original size (W, H) for size-dependent transforms.
    orig_size = pil_image.size
    transforms = []
    for name, mag in policy:
        if name == "random_crop":
            # Build the crop directly at the image's native (square) size.
            # The registry factory hard-codes size=32, so going through
            # make_transform here would construct a transform only to throw
            # it away.
            transform = T.RandomResizedCrop(
                size=min(orig_size),
                scale=(1 - mag * 0.3, 1.0),
                ratio=(0.75, 1.33),
                antialias=True,
            )
        else:
            transform = make_transform(name, mag)
        transforms.append(transform)
    # An empty policy is a no-op; T.Compose rejects an empty transform list.
    augmented = T.Compose(transforms)(pil_image) if transforms else pil_image
    # Convert back to numpy if input was numpy
    if was_numpy:
        return pil_to_numpy(augmented)
    return augmented
def create_augmenter(policy):
    """Build a reusable Compose pipeline from a (name, magnitude) policy."""
    return T.Compose([make_transform(name, mag) for name, mag in policy])