-
Notifications
You must be signed in to change notification settings - Fork 11
Expand file tree
/
Copy pathimage_datasets.py
More file actions
74 lines (63 loc) · 2.13 KB
/
image_datasets.py
File metadata and controls
74 lines (63 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os

import numpy as np
import PIL
import PIL.Image
import torch.utils.data as data
from torch.utils.data import Dataset, DataLoader
from torchvision import transforms
from torchvision.utils import save_image
def _list_image_files(data_dir, extensions=("jpg", "jpeg", "png", "gif")):
    """Recursively collect image file paths under ``data_dir``.

    A file is kept when its name contains a ``.`` and the text after the
    last ``.`` matches one of ``extensions`` case-insensitively.

    Args:
        data_dir: Root directory to walk; subdirectories are included.
        extensions: Accepted extensions without the leading dot. Defaults to
            jpg/jpeg/png/gif.

    Returns:
        A list of ``os.path.join(root, name)`` paths in a deterministic
        order: files within each directory are sorted by name, and
        subdirectories are visited in sorted order.
    """
    wanted = {ext.lower() for ext in extensions}
    results = []
    for root, dirs, files in os.walk(data_dir):
        # os.walk lists subdirectories in filesystem order, which differs
        # between machines. Sorting ``dirs`` in place (honoured by the default
        # top-down walk) makes the traversal, and thus the result, reproducible.
        dirs.sort()
        for name in sorted(files):
            _, dot, ext = name.rpartition(".")
            # An empty ``dot`` means the name has no extension at all.
            if dot and ext.lower() in wanted:
                results.append(os.path.join(root, name))
    return results
def InfiniteSampler(n):
    """Yield indices from ``range(n)`` forever, reshuffling after every pass.

    The first pass is truncated: only the last entry of the initial
    permutation is produced before the first reshuffle. Before each
    reshuffle, NumPy's global RNG is reseeded with ``np.random.seed()`` (no
    argument). That discards any seed the caller set, so indices after the
    first one are not reproducible.

    Raises:
        IndexError: on the first ``next()`` when ``n == 0``.
    """
    pos = n - 1
    perm = np.random.permutation(n)
    while True:
        yield perm[pos]
        pos += 1
        if pos < n:
            continue
        # End of the current permutation: reseed, reshuffle, restart at 0.
        np.random.seed()
        perm = np.random.permutation(n)
        pos = 0
class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that hands out indices endlessly via ``InfiniteSampler``.

    Only the size of ``data_source`` is recorded; the dataset itself is not
    kept. ``__len__`` reports a large nominal value (2**31) because the
    index stream produced by ``__iter__`` never ends.
    """

    def __init__(self, data_source):
        size = len(data_source)
        self.num_samples = size

    def __iter__(self):
        index_stream = InfiniteSampler(self.num_samples)
        return iter(index_stream)

    def __len__(self):
        return 1 << 31
class ImageLoader(Dataset):
    """Map-style dataset that loads images from a list of file paths.

    Each item is opened with PIL, converted to RGB, and passed through
    ``transform`` when one is supplied.
    """

    def __init__(self, image_paths, transform=None):
        """
        Args:
            image_paths: Indexable sequence of image file paths.
            transform: Optional callable applied to the RGB ``PIL.Image``.
        """
        super().__init__()
        self.image_paths = image_paths
        self.transform = transform

    def __len__(self):
        return len(self.image_paths)

    def __getitem__(self, idx):
        # Open inside a context manager so the file handle is released
        # deterministically. Without this, PIL can keep the file open, e.g.
        # for multi-frame GIFs. convert() returns a new, fully loaded image,
        # so it stays valid after the source file is closed.
        with PIL.Image.open(self.image_paths[idx]) as src:
            img = src.convert('RGB')
        if self.transform is not None:
            img = self.transform(img)
        return img
def _to_symmetric_range(t):
    """Rescale ``t`` with ``2 * t - 1`` (maps ``[0, 1]`` onto ``[-1, 1]``).

    Defined at module level rather than as a lambda so the transform
    pipeline can be pickled when DataLoader workers are started with the
    ``spawn`` method (the default on Windows and macOS).
    """
    return (t * 2) - 1


def create_loader(data_dir, img_size, batch_size, num_workers=4, pin_memory=True):
    """Build an endless iterator of image batches from ``data_dir``.

    Images are found recursively by ``_list_image_files``, resized to
    ``img_size`` x ``img_size``, converted to tensors, and rescaled from
    ``[0, 1]`` to ``[-1, 1]``.

    Args:
        data_dir: Root directory searched recursively for images.
        img_size: Target height and width after resizing.
        batch_size: Number of images per batch.
        num_workers: DataLoader worker processes (default 4).
        pin_memory: Forwarded to ``DataLoader`` (default True).

    Returns:
        An iterator over batches. Indices are drawn by
        ``InfiniteSamplerWrapper``, so it never raises ``StopIteration``.

    Raises:
        ValueError: if no image files are found under ``data_dir``.
    """
    all_files = _list_image_files(data_dir)
    if not all_files:
        # Fail early with a clear message. Otherwise the infinite sampler
        # indexes an empty permutation and surfaces an opaque IndexError.
        raise ValueError(f"No image files found under {data_dir!r}")
    transform = transforms.Compose([
        transforms.Resize((img_size, img_size)),
        transforms.ToTensor(),
        transforms.Lambda(_to_symmetric_range),
    ])
    dataset = ImageLoader(all_files, transform)
    # shuffle must stay False: DataLoader rejects shuffle=True together with
    # a custom sampler, and the sampler already randomizes the order.
    loader = iter(DataLoader(
        dataset, batch_size=batch_size, shuffle=False,
        sampler=InfiniteSamplerWrapper(dataset),
        num_workers=num_workers, pin_memory=pin_memory,
    ))
    return loader