From e7437b5af934f78e702d564ef677469fa001a69b Mon Sep 17 00:00:00 2001 From: Chris Jaehnen Date: Tue, 9 Sep 2025 20:23:22 -0400 Subject: [PATCH] Add mesgrid parameter to setup_classification_plot function --- README.md | 7 ++++- .../py_ml_plot/classification_plot.py | 28 ++++++++++++++++--- tests/py_ml_plot/test_classification_plot.py | 8 ++++++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/README.md b/README.md index 63d934d..90f8a1e 100644 --- a/README.md +++ b/README.md @@ -78,6 +78,10 @@ setup_classification_plot( title="Logistic Regression", x_label="Age", y_label="Estimated Salary", + meshgrid={ + 0: {"min": 10, "max": 10, "step": 0.25}, + 1: {"min": 1000, "max": 1000, "step": 0.25}, + }, feature_scale=lambda x_set, y_set: ( sc.inverse_transform(x_set), y_set ), @@ -147,7 +151,8 @@ as follows: * If the `feature_scale` lambda is not defined, `x_set` and `y_set` are assigned the values of `x` and `y`, respectively * `meshgrid` function from the NumPy library returns a tuple of coordinate - matrices from coordinate vectors + matrices from coordinate vectors (the ranges for each axis are controlled by + the `meshgrid` dict parameter passed to `setup_classification_plot`). * Two sets of matrices (`x1` and `x2`) are returned with coordinate vectors * `x1` * `arange` function is called with a defined start and stop interval diff --git a/src/opengood/py_ml_plot/classification_plot.py b/src/opengood/py_ml_plot/classification_plot.py index 033d27a..919d10f 100644 --- a/src/opengood/py_ml_plot/classification_plot.py +++ b/src/opengood/py_ml_plot/classification_plot.py @@ -10,8 +10,9 @@ def setup_classification_plot( title, x_label, y_label, - feature_scale: lambda x1, y1: (), - predict: lambda x1, x2: (), + meshgrid=None, + feature_scale=None, + predict=None, ): """ Sets up a classification plot with decision boundaries and classified regions. @@ -27,6 +28,13 @@ class and overlays the decision boundaries of the classifier. title (str): Title for the plot. x_label (str): Label for the x-axis. y_label (str): Label for the y-axis. + meshgrid (dict[int, dict[str, float]]): Controls the np.meshgrid arange parameters for each axis. + Provide a dict with two entries indexed by 0 and 1 (x-axis and y-axis respectively), + where each entry is a dict with keys: + - "min": float padding to subtract from the min value for the start of the range + - "max": float padding to add to the max value for the end of the range + - "step": float step size between values in the range + Example: {0: {"min": 10, "max": 10, "step": 0.25}, 1: {"min": 1000, "max": 1000, "step": 0.25}} feature_scale (callable): Function to transform the feature data for visualization. Should take x and y as input and return the transformed x and y. If None, no transformation is applied. @@ -44,18 +52,30 @@ class and overlays the decision boundaries of the classifier. ... title="Logistic Regression", ... x_label="Feature 1", ... y_label="Feature 2", + ... meshgrid={0: {"min": 10, "max": 10, "step": 0.25}, 1: {"min": 1000, "max": 1000, "step": 0.25}}, ... feature_scale=lambda x_set, y_set: (x_set, y_set), ... predict=lambda x1, x2: classifier.predict(np.array([x1.ravel(), x2.ravel()]).T).reshape(x1.shape) ... ) """ + if meshgrid is None: + meshgrid = { + 0: {"min": 10, "max": 10, "step": 0.25}, + 1: {"min": 1000, "max": 1000, "step": 0.25}, + } if feature_scale is not None: x_set, y_set = feature_scale(x, y) else: x_set, y_set = x, y x1, x2 = np.meshgrid( - np.arange(start=x_set[:, 0].min() - 10, stop=x_set[:, 0].max() + 10, step=0.25), np.arange( - start=x_set[:, 1].min() - 1000, stop=x_set[:, 1].max() + 1000, step=0.25 + start=x_set[:, 0].min() - meshgrid[0]["min"], + stop=x_set[:, 0].max() + meshgrid[0]["max"], + step=meshgrid[0]["step"], + ), + np.arange( + start=x_set[:, 1].min() - meshgrid[1]["min"], + stop=x_set[:, 1].max() + meshgrid[1]["max"], + step=meshgrid[1]["step"], ), ) y_pred = predict(x1, x2) diff --git a/tests/py_ml_plot/test_classification_plot.py b/tests/py_ml_plot/test_classification_plot.py index b8435f0..56f20ce 100644 --- a/tests/py_ml_plot/test_classification_plot.py +++ b/tests/py_ml_plot/test_classification_plot.py @@ -39,6 +39,10 @@ def test_logistic_regression_setup_classification_plot_with_shaded_regions(self) title="Logistic Regression", x_label="Age", y_label="Estimated Salary", + meshgrid={ + 0: {"min": 10, "max": 10, "step": 0.25}, + 1: {"min": 1000, "max": 1000, "step": 0.25}, + }, feature_scale=lambda x_set, y_set: (sc.inverse_transform(x_set), y_set), predict=lambda x1, x2: ( classifier.predict( @@ -98,6 +102,10 @@ def test_k_nearest_neighbor_setup_classification_plot_with_shaded_regions(self): title="K-Nearest Neighbor (K-NN)", x_label="Age", y_label="Estimated Salary", + meshgrid={ + 0: {"min": 10, "max": 10, "step": 0.25}, + 1: {"min": 1000, "max": 1000, "step": 0.25}, + }, feature_scale=lambda x_set, y_set: (sc.inverse_transform(x_set), y_set), predict=lambda x1, x2: ( classifier.predict(