Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 6 additions & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,10 @@ setup_classification_plot(
title="Logistic Regression",
x_label="Age",
y_label="Estimated Salary",
meshgrid={
0: {"min": 10, "max": 10, "step": 0.25},
1: {"min": 1000, "max": 1000, "step": 0.25},
},
feature_scale=lambda x_set, y_set: (
sc.inverse_transform(x_set), y_set
),
Expand Down Expand Up @@ -147,7 +151,8 @@ as follows:
* If the `feature_scale` lambda is not defined, `x_set` and `y_set` are
assigned the values of `x` and `y`, respectively
* `meshgrid` function from the NumPy library returns a tuple of coordinate
matrices from coordinate vectors
matrices from coordinate vectors (the ranges for each axis are controlled by
the `meshgrid` dict parameter passed to `setup_classification_plot`).
* Two sets of matrices (`x1` and `x2`) are returned with coordinate vectors
* `x1`
* `arange` function is called with a defined start and stop interval
Expand Down
28 changes: 24 additions & 4 deletions src/opengood/py_ml_plot/classification_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@ def setup_classification_plot(
title,
x_label,
y_label,
feature_scale: lambda x1, y1: (),
predict: lambda x1, x2: (),
meshgrid=None,
feature_scale=None,
predict=None,
):
"""
Sets up a classification plot with decision boundaries and classified regions.
Expand All @@ -27,6 +28,13 @@ class and overlays the decision boundaries of the classifier.
title (str): Title for the plot.
x_label (str): Label for the x-axis.
y_label (str): Label for the y-axis.
meshgrid (dict[int, dict[str, float]]): Controls the np.meshgrid arange parameters for each axis.
Provide a dict with two entries indexed by 0 and 1 (x-axis and y-axis respectively),
where each entry is a dict with keys:
- "min": float padding to subtract from the min value for the start of the range
- "max": float padding to add to the max value for the end of the range
- "step": float step size between values in the range
Example: {0: {"min": 10, "max": 10, "step": 0.25}, 1: {"min": 1000, "max": 1000, "step": 0.25}}
feature_scale (callable): Function to transform the feature data for visualization.
Should take x and y as input and return the transformed x and y.
If None, no transformation is applied.
Expand All @@ -44,18 +52,30 @@ class and overlays the decision boundaries of the classifier.
... title="Logistic Regression",
... x_label="Feature 1",
... y_label="Feature 2",
... meshgrid={0: {"min": 10, "max": 10, "step": 0.25}, 1: {"min": 1000, "max": 1000, "step": 0.25}},
... feature_scale=lambda x_set, y_set: (x_set, y_set),
... predict=lambda x1, x2: classifier.predict(np.array([x1.ravel(), x2.ravel()]).T).reshape(x1.shape)
... )
"""
if meshgrid is None:
meshgrid = {
0: {"min": 10, "max": 10, "step": 0.25},
1: {"min": 1000, "max": 1000, "step": 0.25},
}
if feature_scale is not None:
x_set, y_set = feature_scale(x, y)
else:
x_set, y_set = x, y
x1, x2 = np.meshgrid(
np.arange(start=x_set[:, 0].min() - 10, stop=x_set[:, 0].max() + 10, step=0.25),
np.arange(
start=x_set[:, 1].min() - 1000, stop=x_set[:, 1].max() + 1000, step=0.25
start=x_set[:, 0].min() - meshgrid[0]["min"],
stop=x_set[:, 0].max() + meshgrid[0]["max"],
step=meshgrid[0]["step"],
),
np.arange(
start=x_set[:, 1].min() - meshgrid[1]["min"],
stop=x_set[:, 1].max() + meshgrid[1]["max"],
step=meshgrid[1]["step"],
),
)
y_pred = predict(x1, x2)
Expand Down
8 changes: 8 additions & 0 deletions tests/py_ml_plot/test_classification_plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ def test_logistic_regression_setup_classification_plot_with_shaded_regions(self)
title="Logistic Regression",
x_label="Age",
y_label="Estimated Salary",
meshgrid={
0: {"min": 10, "max": 10, "step": 0.25},
1: {"min": 1000, "max": 1000, "step": 0.25},
},
feature_scale=lambda x_set, y_set: (sc.inverse_transform(x_set), y_set),
predict=lambda x1, x2: (
classifier.predict(
Expand Down Expand Up @@ -98,6 +102,10 @@ def test_k_nearest_neighbor_setup_classification_plot_with_shaded_regions(self):
title="K-Nearest Neighbor (K-NN)",
x_label="Age",
y_label="Estimated Salary",
meshgrid={
0: {"min": 10, "max": 10, "step": 0.25},
1: {"min": 1000, "max": 1000, "step": 0.25},
},
feature_scale=lambda x_set, y_set: (sc.inverse_transform(x_set), y_set),
predict=lambda x1, x2: (
classifier.predict(
Expand Down
Loading