|
29 | 29 |
|
30 | 30 |
|
class OneNearestNeighbor(BaseEstimator, ClassifierMixin):
    """One-nearest-neighbor classifier.

    Predicts the label of a sample as the label of the single closest
    training sample, where closeness is measured with the Euclidean
    distance. The estimator has no hyperparameters.

    Attributes
    ----------
    classes_ : ndarray of shape (n_classes,)
        Unique class labels seen during ``fit``.
    n_features_in_ : int
        Number of features seen during ``fit``.
    X_fit_ : ndarray of shape (n_samples, n_features)
        Training samples memorized by ``fit``.
    y_fit_ : ndarray of shape (n_samples,)
        Training labels memorized by ``fit``.
    """

    def __init__(self):  # noqa: D107
        pass

    def fit(self, X, y):
        """Fit the classifier by memorizing the training data.

        Parameters
        ----------
        X : array-like of shape (n_samples, n_features)
            The training input samples.
        y : array-like of shape (n_samples,)
            The target class labels.

        Returns
        -------
        self : OneNearestNeighbor
            The fitted estimator.
        """
        X, y = check_X_y(X, y)
        check_classification_targets(y)
        self.classes_ = np.unique(y)
        self.n_features_in_ = X.shape[1]

        # 1-NN has no parameters to learn: the validated training set is
        # the model.
        self.X_fit_ = X
        self.y_fit_ = y
        return self

    def predict(self, X):
        """Predict the class label of each input sample.

        Each sample receives the label of its closest training sample.
        When several training samples are equally close, the one that
        appears first in the training data is used.

        Parameters
        ----------
        X : array-like of shape (n_samples_test, n_features)
            The input samples to predict.

        Returns
        -------
        y_pred : ndarray of shape (n_samples_test,)
            The predicted class labels.

        Raises
        ------
        ValueError
            If ``X`` does not have the number of features seen in ``fit``.
        """
        check_is_fitted(self)
        X = check_array(X)

        # Without this check a model fitted on a single feature would
        # silently broadcast against wider test data and return labels
        # computed from meaningless distances.
        if X.shape[1] != self.n_features_in_:
            raise ValueError(
                f"X has {X.shape[1]} features, but "
                f"{type(self).__name__} is expecting "
                f"{self.n_features_in_} features as input."
            )

        # Squared distances are enough: the square root is monotonic, so
        # it does not change which training sample is the closest.
        # Test rows are processed one at a time to avoid materializing a
        # (n_samples_test, n_samples, n_features) difference array.
        nearest = np.fromiter(
            (
                np.argmin(np.sum((self.X_fit_ - x_test) ** 2, axis=1))
                for x_test in X
            ),
            dtype=np.intp,
            count=len(X),
        )
        return self.y_fit_[nearest]

    def score(self, X, y):
        """Return the mean accuracy on the given test data and labels.

        Parameters
        ----------
        X : array-like of shape (n_samples_test, n_features)
            The input samples.
        y : array-like of shape (n_samples_test,)
            True labels for ``X``.

        Returns
        -------
        score : float
            Fraction of samples in ``X`` whose predicted label equals the
            corresponding entry of ``y``.
        """
        # NOTE(review): check_X_y is expected to reject empty input with
        # its default settings, so no zero-sample special case is kept.
        X, y = check_X_y(X, y)

        # Averaging the boolean mask avoids the rounding error of summing
        # n separate 1/n terms.
        return np.mean(self.predict(X) == y)
0 commit comments