-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdata.lua
More file actions
118 lines (100 loc) · 2.79 KB
/
Copy pathdata.lua
File metadata and controls
118 lines (100 loc) · 2.79 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
local datasets = require 'datasets'
local paths = require 'paths'
local threads = require 'threads'
-- Select the 'threads.sharedserialize' scheme for everything the threads
-- library serializes from here on.
-- NOTE(review): presumably this makes tensors handed to workers share storage
-- with the main thread instead of being copied -- confirm against the threads
-- package documentation.
threads.Threads.serialization('threads.sharedserialize')
-- Register the class under the name 'torch.DataProvider'; the rest of this
-- file constructs instances through torch.DataProvider(...).
local DataProvider = torch.class('torch.DataProvider')
--- Construct a provider that serves batches of `dataset` through a worker pool.
-- @param opt     option table; this constructor reads nThreads, seed,
--                batchSize and (for the training split) shuffle
-- @param dataset object exposing :size(); workers later call :getBatch on it
-- @param split   'train' selects the training behaviour; any other value is
--                treated as 'test'
function DataProvider:__init(opt, dataset, split)
    -- Runs once in every worker before anything else: load the dataset
    -- classes there too.
    local function loadDependencies()
        require 'datasets'
    end

    -- Runs once in every worker with its id: give each worker a distinct
    -- seed and publish the dataset under the global name that the job
    -- submitted by getBatch looks up. The global is deliberate.
    local function setupWorker(threadid)
        torch.manualSeed(opt.seed + threadid)
        threadDataset = dataset
    end

    self.pool = threads.Threads(opt.nThreads, loadDependencies, setupWorker)

    self.count = 0                   -- position of the last index handed out
    self.batchSize = opt.batchSize
    self.epochSize = dataset:size()

    local isTrain = (split == 'train')
    self.split = isTrain and 'train' or 'test'
    if isTrain then
        -- The training order always starts as a random permutation;
        -- `shuffle` only controls whether it is redrawn later.
        self.shuffle = opt.shuffle
        self.idcs = torch.randperm(self.epochSize)
    else
        -- Non-training data is visited in its natural order 1..epochSize;
        -- self.shuffle is intentionally left unset here.
        self.idcs = torch.range(1, self.epochSize)
    end
end
--- Return one loaded mini-batch while keeping the worker pool busy.
--
-- Each call first queues batch-loading jobs until the pool stops accepting
-- them, then waits for one job to finish and returns its result.
--
-- Index selection walks self.idcs in steps of self.batchSize. When a step
-- runs past the end of the epoch, the batch is completed with indices from
-- the start of the next pass (after redrawing the permutation when
-- self.shuffle is set), so every batch has exactly batchSize entries.
--
-- The finished batch is stored on the instance (self.batch) rather than in a
-- file-level variable, so separate providers (e.g. train and test) never
-- overwrite each other's result. It cannot be a plain local of this function
-- because a job queued in one call may only complete during a later call.
--
-- @return whatever the dataset's getBatch(idcs, split) returned in the worker
function DataProvider:getBatch()
    while self.pool:acceptsjob() do
        local first = self.count + 1
        local last = self.count + self.batchSize
        local idcs

        if last <= self.epochSize then
            -- Common case: the whole batch fits inside the current pass.
            self.count = last
            idcs = self.idcs[{{first, last}}]
        else
            -- Leftover indices of the current pass; nil when the previous
            -- batch ended exactly on the epoch boundary.
            local tail
            if first <= self.epochSize then
                tail = self.idcs[{{first, self.epochSize}}]
            end

            -- Number of indices to take from the beginning of the next pass.
            local wrapped = self.batchSize - (tail and tail:size(1) or 0)
            self.count = wrapped

            -- Redraw the order only after the tail has been sliced off.
            if self.shuffle then
                self.idcs = torch.randperm(self.epochSize)
            end

            local head = self.idcs[{{1, wrapped}}]
            idcs = tail and torch.cat(tail, head, 1) or head
        end

        self.pool:addjob(
            -- Worker side: must not capture main-thread locals; it relies on
            -- the `threadDataset` global installed when the pool was created.
            function(batchIdcs, batchSplit)
                return threadDataset:getBatch(batchIdcs, batchSplit)
            end,
            -- Main-thread side: remember the result on this instance.
            function(result)
                self.batch = result
            end,
            idcs,
            self.split
        )
    end

    -- Wait for one queued job to finish; its end callback fills self.batch.
    self.pool:dojob()
    return self.batch
end
--- Rewind the provider to the beginning of an epoch.
-- Waits for the pool to synchronize, resets the position counter and, for
-- providers with `shuffle` set, draws a new random order.
function DataProvider:reset()
    self.pool:synchronize()
    self.count = 0
    if not self.shuffle then
        return
    end
    self.idcs = torch.randperm(self.epochSize)
end
--- Build the train and test providers for the scene dataset stored in
-- `<opt.dataDir>/<opt.dataset>.t7`.
-- @param opt option table; reads dataDir and dataset here and is forwarded
--            unchanged to torch.DataProvider
-- @return train provider, test provider
local function siftflow(opt)
    local archivePath = paths.concat(opt.dataDir, opt.dataset .. '.t7')
    local scene = torch.load(archivePath)

    -- One table shared by both datasets.
    local mean = {128, 128, 128}

    -- Wrap one part of the archive in a SceneDataset and a DataProvider.
    -- NOTE(review): 3, 256, 256 look like channels/height/width, and the
    -- last two arguments look like an augmentation switch with its ranges --
    -- confirm against datasets.SceneDataset.
    local function buildProvider(part, flag, ranges, split)
        local sceneDataset = datasets.SceneDataset(
            part.data, part.labels, 3, 256, 256, mean, nil, flag, ranges
        )
        return torch.DataProvider(opt, sceneDataset, split)
    end

    -- Train is built before test, as in the original.
    local trainDataProvider = buildProvider(
        scene.train, true, {0.5, 1.5, 0.5, 1.5}, 'train'
    )
    local testDataProvider = buildProvider(scene.test, nil, nil, 'test')

    return trainDataProvider, testDataProvider
end
--- Create the (train, test) data providers for the dataset named in `opt`.
-- @param opt option table; opt.dataset selects the dataset and the whole
--            table is forwarded to the dataset-specific builder
-- @return train provider, test provider
-- Raises an error for an unrecognised opt.dataset instead of silently
-- returning nothing, which would only surface later as a nil-index error in
-- the caller.
local function getDataProvider(opt)
    if opt.dataset == 'siftflow' then
        return siftflow(opt)
    end
    -- Level 2 attributes the error to the caller that passed the bad option.
    error(
        string.format(
            "getDataProvider: unsupported dataset '%s'",
            tostring(opt.dataset)
        ),
        2
    )
end
-- Module export: require-ing this file yields the factory function itself,
-- to be called with the option table.
return getDataProvider