-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathclassifier.py
More file actions
90 lines (78 loc) · 3.66 KB
/
classifier.py
File metadata and controls
90 lines (78 loc) · 3.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
import numpy as np
import logging
from typing import Dict, List, Tuple
from sklearn.ensemble import RandomForestClassifier
from random_forest_classifier import load_rfc
from classification_source import ClassificationSource
# Called without a name, getLogger() hands back the root logger.
logger: logging.Logger = logging.getLogger()

# For every classification source: the single-letter class codes mapped to the
# human-readable label reported for that code.
CLASSIFIER_LABELS: Dict[ClassificationSource, Dict[str, str]] = {
    ClassificationSource.GAIA: {
        'C': 'Ulens Candidate',
        'N': 'Not Ulens',
    },
    ClassificationSource.TWO_MASS: {
        'B': 'Be Star',
        'U': 'Ulens Candidate',
        'E': 'Evolved',
        'Y': 'YSO',
    },
    ClassificationSource.WISE: {
        'B': 'Be Star',
        'S': 'Main Sequence Star',
        'R': 'Red Giant',
        'E': 'Evolved',
        'Y': 'YSO',
    },
}
def classification_results_response(classifier: RandomForestClassifier,
                                    descriptive_classes: Dict[str, str],
                                    predictions: List[List[float]]) -> List[List[str]]:
    """Pair every class of ``classifier`` with its entry in the first prediction row.

    Args:
        classifier: Object exposing ``classes_``; its iteration order fixes
            the order of the returned entries.
        descriptive_classes: Lookup from each item of ``classifier.classes_``
            to the label that should be reported for it.
        predictions: Indexable rows of per-class values. Only row 0 is read;
            any further rows are ignored.

    Returns:
        One two-item list per class: ``[label, predictions[0][i]]`` where
        ``i`` is the position of the class in ``classifier.classes_``.
        NOTE(review): the second item is the prediction value itself, not a
        string, even though the return annotation says ``List[List[str]]``.

    Raises:
        KeyError: A class has no entry in ``descriptive_classes``.
        IndexError: ``predictions`` has no row 0 (while ``classes_`` is
            non-empty) or row 0 is shorter than ``classifier.classes_``.
    """
    return [
        [descriptive_classes[class_code], predictions[0][position]]
        for position, class_code in enumerate(classifier.classes_)
    ]
class Classifier:
    """Pick a random-forest model by input column count and run it.

    Three models are loaded once, at construction time, through ``load_rfc``:
    one each for ``ClassificationSource.GAIA``, ``ClassificationSource.TWO_MASS``
    and ``ClassificationSource.WISE``.
    """

    def __init__(self):
        # Double-underscore attributes are name-mangled by Python, so callers
        # reach the models only through the read-only properties below.
        self.__gaia_rfc: RandomForestClassifier = load_rfc(ClassificationSource.GAIA)
        self.__2mass_rfc: RandomForestClassifier = load_rfc(ClassificationSource.TWO_MASS)
        self.__wise_rfc: RandomForestClassifier = load_rfc(ClassificationSource.WISE)

    @property
    def gaia_rfc(self) -> RandomForestClassifier:
        """Model loaded for ``ClassificationSource.GAIA``."""
        return self.__gaia_rfc

    @property
    def twomass_rfc(self) -> RandomForestClassifier:
        """Model loaded for ``ClassificationSource.TWO_MASS``."""
        return self.__2mass_rfc

    @property
    def wise_rfc(self) -> RandomForestClassifier:
        """Model loaded for ``ClassificationSource.WISE``."""
        return self.__wise_rfc

    def classify(self, points: np.ndarray) -> Tuple[List[List[str]], str, int]:
        """Classify ``points`` with the model that matches its column count.

        The model is chosen from ``points.shape[1]``: 3 columns select the
        Gaia model, 15 the 2MASS model and 28 the WISE model.

        Args:
            points: Array whose second dimension is the column count. It is
                passed unchanged to the selected model's ``predict_proba``.

        Returns:
            A 3-tuple ``(results, message, source_value)``:

            * ``results`` - output of ``classification_results_response`` for
              the selected model (built from the first prediction row only),
              or ``[]`` when no model matches the column count.
            * ``message`` - text naming the data the classification used, or
              ``'Not enough data to classify'``.
            * ``source_value`` - ``.value`` of the ``ClassificationSource``
              that was used, or ``0`` when no model matches.

        Raises:
            IndexError: ``points.shape`` has fewer than two entries.
        """
        column_count = points.shape[1]

        if column_count == 3:
            source = ClassificationSource.GAIA
            rfc = self.gaia_rfc
            survey = 'Gaia'
            summary = 'Classified with data from Gaia'
        elif column_count == 15:
            source = ClassificationSource.TWO_MASS
            rfc = self.twomass_rfc
            survey = '2MASS'
            summary = 'Classified with data from Gaia and 2MASS'
        elif column_count == 28:
            source = ClassificationSource.WISE
            rfc = self.wise_rfc
            survey = 'WISE'
            summary = 'Classified with data from Gaia, 2MASS and WISE'
        else:
            # The message lists the counts the branches above actually accept.
            # It previously said 16 while the check is for 15.
            # NOTE(review): confirm 15 is the intended 2MASS feature count.
            logger.error('Expecting either 3, 15 or 28 columns, got %d', column_count)
            return [], 'Not enough data to classify', 0

        descriptive_classes: Dict[str, str] = CLASSIFIER_LABELS[source]
        logger.info('Predicting with %s RFC...', survey)
        predictions: List[List[float]] = rfc.predict_proba(points)
        return (classification_results_response(rfc,
                                                descriptive_classes,
                                                predictions),
                summary,
                source.value)