-
Notifications
You must be signed in to change notification settings - Fork 5
Expand file tree
/
Copy pathclasses.py
More file actions
44 lines (37 loc) · 1.31 KB
/
classes.py
File metadata and controls
44 lines (37 loc) · 1.31 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
class modelSet:
    """Pair a model name with a model-class value.

    Args:
        modelName: Stored unchanged as ``self.name``.
        modelClass: Stored unchanged as ``self.modelClass``; defaults to
            the string ``'cat'``.
    """

    def __init__(self, modelName, modelClass='cat'):
        # Tuple assignment binds left to right: ``name`` first, then ``modelClass``.
        self.name, self.modelClass = modelName, modelClass
class modelWrapper:
    """Allow single-net and double-net models to be used interchangeably.

    Args:
        models: Either a dict with exactly the keys ``'classModel'`` and
            ``'vecModel'`` (double-net mode; ``self.combined`` is False),
            or any other object exposing a ``predict`` method (single
            combined net; ``self.combined`` is True).
        bb8Labels: Stored unchanged as ``self.bb8Labels``; this class does
            not read it itself (presumably consumed by callers - confirm).

    Raises:
        ValueError: If ``models`` is a dict whose keys are not exactly
            ``'classModel'`` and ``'vecModel'``.
    """

    _REQUIRED_KEYS = ('classModel', 'vecModel')

    def __init__(self, models, bb8Labels):
        # isinstance (not ``type(...) is dict``) so dict subclasses such as
        # OrderedDict are treated as a two-model mapping instead of being
        # mistaken for a single combined model.
        if isinstance(models, dict):
            # Validate key names, not just the key count, so a malformed dict
            # fails here with a clear message rather than a bare KeyError.
            if set(models) != set(self._REQUIRED_KEYS):
                raise ValueError(
                    "models dict must have exactly the keys 'classModel' and "
                    "'vecModel'; got keys {!r}".format(list(models))
                )
            self.classModel = models['classModel']
            self.vecModel = models['vecModel']
            self.combined = False
        else:
            self.combModel = models
            self.combined = True
        self.bb8Labels = bb8Labels

    def genPredict(self, input):
        """Predict with whichever model layout was supplied.

        Args:
            input: Passed unchanged to the underlying ``predict`` call(s).

        Returns:
            A 2-tuple. In double-net mode it is
            ``(vecModel.predict(input), classModel.predict(input))``.
            In combined mode it is the first two elements of
            ``combModel.predict(input)``; any further outputs are dropped.
            The combined output is presumably ordered (vector, class) -
            confirm against the model definition.
        """
        if self.combined:
            pred = self.combModel.predict(input)
            return pred[0], pred[1]
        return self.vecModel.predict(input), self.classModel.predict(input)
class modelDictVal:
    """Plain container for one model configuration.

    Every constructor argument is stored unchanged as an attribute of the
    same name. The one exception is the default ``metrics``, which becomes
    a fresh ``['accuracy']`` list per instance. The field meanings below
    are inferred from names and are presumably consumed by training code
    elsewhere; confirm against callers.

    Args:
        structure: Stored as ``self.structure``.
        generator: Stored as ``self.generator``.
        losses: Stored as ``self.losses``.
        outVectors: Stored as ``self.outVectors``.
        outClasses: Stored as ``self.outClasses``.
        epochs: Stored as ``self.epochs``; defaults to 3.
        lr: Stored as ``self.lr``; defaults to 0.01.
        metrics: Stored as ``self.metrics``. When omitted, each instance
            gets its own new ``['accuracy']`` list. An explicitly passed
            value (including ``None``) is stored as-is.
        outVecName: Stored as ``self.outVecName``; defaults to None.
        outClassName: Stored as ``self.outClassName``; defaults to None.
        bb8Labels: Stored as ``self.bb8Labels``; defaults to True.
        augmentation: Stored as ``self.augmentation``; defaults to True.
    """

    # Immutable sentinel default for ``metrics``. A list default would be
    # created once and shared (and mutated) across all instances.
    _DEFAULT_METRICS = ('accuracy',)

    def __init__(self, structure, generator, losses, outVectors, outClasses,
                 epochs=3, lr=0.01, metrics=_DEFAULT_METRICS, outVecName=None,
                 outClassName=None, bb8Labels=True, augmentation=True):
        if metrics is self._DEFAULT_METRICS:
            # Fresh list per instance, matching the old ['accuracy'] default.
            metrics = list(metrics)
        self.structure = structure
        self.generator = generator
        self.losses = losses
        self.outVectors = outVectors
        self.outClasses = outClasses
        self.epochs = epochs
        self.metrics = metrics
        self.lr = lr
        self.outVecName = outVecName
        self.outClassName = outClassName
        self.bb8Labels = bb8Labels
        self.augmentation = augmentation