-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdataloader.py
More file actions
52 lines (41 loc) · 2.13 KB
/
dataloader.py
File metadata and controls
52 lines (41 loc) · 2.13 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
import torch
class _InfiniteSampler(torch.utils.data.Sampler):
"""Wraps another Sampler to yield an infinite stream."""
def __init__(self, sampler):
self.sampler = sampler
def __iter__(self):
while True:
for batch in self.sampler:
yield batch
class InfiniteDataLoader:
    """Yields an endless stream of batches sampled with replacement.

    If ``weights`` is not None, examples are drawn with
    ``WeightedRandomSampler`` (probability proportional to the weights);
    otherwise they are drawn uniformly at random. Worker processes are
    created once and reused, since the underlying iterator never ends.

    Args:
        dataset: the dataset to sample from.
        weights: optional per-example sampling weights, or None for uniform.
        batch_size: number of examples per yielded batch.
        num_workers: DataLoader worker process count.
        collate_fn: optional custom collate function.
    """

    def __init__(self, dataset, weights, batch_size, num_workers, collate_fn=None):
        super().__init__()
        if weights is not None:
            # num_samples only bounds a single pass of this sampler;
            # _InfiniteSampler restarts the pass forever, so one batch's
            # worth of samples per pass is sufficient.
            sampler = torch.utils.data.WeightedRandomSampler(
                weights, num_samples=batch_size, replacement=True)
        else:
            sampler = torch.utils.data.RandomSampler(dataset, replacement=True)
        batch_sampler = torch.utils.data.BatchSampler(
            sampler, batch_size=batch_size, drop_last=True)
        # Wrapping the batch sampler in _InfiniteSampler makes the DataLoader
        # iterator inexhaustible, so workers are spawned exactly once.
        self._infinite_iterator = iter(torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler),
            collate_fn=collate_fn))

    def __iter__(self):
        while True:
            yield next(self._infinite_iterator)

    def __len__(self):
        # Deliberately unsupported: this loader is infinite. Raise with an
        # explanatory message instead of a bare ValueError.
        raise ValueError("InfiniteDataLoader has no fixed length")
class FastDataLoader:
    """Finite-epoch DataLoader wrapper that keeps its worker processes alive
    across epochs instead of respawning them each time."""

    def __init__(self, dataset, batch_size, num_workers, collate_fn=None):
        super().__init__()
        shuffled = torch.utils.data.RandomSampler(dataset, replacement=False)
        batch_sampler = torch.utils.data.BatchSampler(
            shuffled, batch_size=batch_size, drop_last=False)
        # The inner DataLoader iterates forever; __iter__ below slices it into
        # epoch-sized chunks so workers persist between epochs.
        loader = torch.utils.data.DataLoader(
            dataset,
            num_workers=num_workers,
            batch_sampler=_InfiniteSampler(batch_sampler),
            collate_fn=collate_fn)
        self._infinite_iterator = iter(loader)
        self._length = len(batch_sampler)

    def __iter__(self):
        for _ in range(self._length):
            yield next(self._infinite_iterator)

    def __len__(self):
        return self._length