202 model selection module implementation #232
Pull request overview
This PR implements a model selection module to address issue #202, introducing two key classes: CandidateSelector for managing model selection with grid search, and BaseFittedModel as a container for individual candidate models. The module wraps sklearn's GridSearchCV and provides enhanced functionality for evaluating models based on RMSE metrics, variance, and overfitting indicators (RMSE ratio), with visualization capabilities.
Changes:
- Added `CandidateSelector` class that extends GridSearchCV with candidate ranking, filtering, and RMSE-specific metrics
- Added `BaseFittedModel` dataclass for storing individual model candidate results and metadata
- Implemented comprehensive test suite covering instantiation, fitting, candidate retrieval, filtering, and visualization
Reviewed changes
Copilot reviewed 4 out of 4 changed files in this pull request and generated 18 comments.
| File | Description |
|---|---|
| `chemotools/model_selection/__init__.py` | Module initialization exporting the two main classes |
| `chemotools/model_selection/_fitted_model.py` | `BaseFittedModel` dataclass implementation for storing candidate model information, including RMSE metrics |
| `chemotools/model_selection/_candidate_selector.py` | `CandidateSelector` class wrapping GridSearchCV with enhanced model selection, filtering, and visualization features |
| `tests/model_selection/test_candidate_selector.py` | Comprehensive test suite covering core functionality, edge cases, and plotting methods |
```python
markers = ["o", "s", "^", "D", "v", "*", "p", "h"]
cmap = plt.colormaps.get_cmap("tab10")

for idx, key in enumerate(sorted(groups.keys())):
    data = groups[key]
    ax.scatter(
        [d[0] for d in data],
        [d[1] for d in data],
        marker=markers[idx % len(markers)],
        c=[cmap(idx % 10)],
        s=80,
        label=str(key),
        edgecolors="black",
        linewidths=0.5,
        alpha=0.8,
    )
```
Potential issue with plot coloring when there are many parameter groups. The code uses cmap(idx % 10) on line 329, which means after 10 groups, colors will repeat. Similarly, markers cycle through 8 options. When a grid search has more than 8-10 unique values for the color_by parameter, the visualization becomes ambiguous. Consider: (1) adding a warning when there are more groups than available colors/markers, (2) using a different colormap that handles more distinct values, or (3) documenting this limitation in the docstring.
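A minimal sketch of option (1), reusing the `markers` list and the 10-color `tab10` colormap from the snippet above; `groups` is assumed to be the same dict the plotting loop iterates over:

```python
import warnings

# tab10 yields 10 distinct colors and the markers list 8 shapes, so beyond
# min(8, 10) groups the marker/color combinations start to repeat
max_distinct = min(len(markers), 10)
if len(groups) > max_distinct:
    warnings.warn(
        f"Plotting {len(groups)} groups with only {max_distinct} distinct "
        "marker/color combinations; styles will repeat and groups may be "
        "hard to tell apart.",
        UserWarning,
    )
```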
```python
def get_candidate(self, rank: int = 1) -> BaseFittedModel:
    """Return candidate by rank (1 = best)."""
    check_is_fitted(self, ["candidates_"])
    for c in self.candidates_:
        if c.rank == rank:
            return c
    raise ValueError(f"No candidate with rank {rank}.")
```
Missing input validation for the rank parameter. The method does not validate that rank is positive. If a user passes rank=0 or rank=-1, the method will raise a "No candidate with rank" ValueError, but it would be clearer to validate the input upfront and provide a more specific error message indicating that rank must be a positive integer (>= 1).
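A sketch of such an upfront check, mirroring the validation suggested for `n` further down:

```python
def get_candidate(self, rank: int = 1) -> BaseFittedModel:
    """Return candidate by rank (1 = best)."""
    # Validate before searching so invalid input gets a specific message
    if not isinstance(rank, int) or rank < 1:
        raise ValueError(f"rank must be a positive integer (>= 1), got {rank!r}.")
    check_is_fitted(self, ["candidates_"])
    for c in self.candidates_:
        if c.rank == rank:
            return c
    raise ValueError(f"No candidate with rank {rank}.")
```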
```python
best_params_ : dict
    Parameters that achieved the best score.
best_score_ : float
    Best cross-validation score.
```
The best_index_ attribute is set during fit() but is not documented in the class docstring's "Attributes" section. For API consistency, all fitted attributes (those set during fit and ending with underscore) should be documented. Add best_index_ to the Attributes section with a description like "Index of best parameter combination in cv_results_."
Suggested change:
```python
best_score_ : float
    Best cross-validation score.
best_index_ : int
    Index of best parameter combination in ``cv_results_``.
```
```python
import matplotlib.pyplot as plt
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_is_fitted
import operator
```
Import order inconsistency: the operator module should be imported earlier with other standard library imports (before third-party imports like matplotlib, numpy, sklearn). Following PEP 8 conventions, imports should be grouped as: (1) standard library, (2) third-party, (3) local imports.
Suggested change:
```python
import operator

import matplotlib.pyplot as plt
import numpy as np
from sklearn.base import BaseEstimator
from sklearn.model_selection import GridSearchCV
from sklearn.utils.validation import check_is_fitted
```
```python
@pytest.fixture
def fitted_selector(dummy_data_loader):
    """Return a fitted CandidateSelector for testing."""
    X, y = dummy_data_loader
    selector = CandidateSelector(
        estimator=Ridge(random_state=0),
        param_grid={"alpha": [0.1, 1.0, 10.0]},
        cv=3,
        scoring="neg_root_mean_squared_error",
        return_train_score=True,
        n_jobs=1,
    )
    selector.fit(X, y)
    return selector
```
Missing test coverage for non-RMSE scoring functions. All tests use scoring="neg_root_mean_squared_error", but the CandidateSelector should work with other scoring metrics too. The behavior differs when RMSE metrics are None (as documented in line 108 of _fitted_model.py). Add tests that verify: (1) the selector works with other scorers like "r2" or "neg_mean_squared_error", (2) RMSE-specific attributes (rmsecv, rmse_train, rmse_ratio) are None when not using RMSE scoring, and (3) plotting and filtering still work with non-RMSE metrics.
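A sketch of such a test, reusing the existing `dummy_data_loader` fixture; the accessor name `get_candidates` and the exact candidate field names are assumptions based on the review comments above:

```python
import pytest
from sklearn.linear_model import Ridge

from chemotools.model_selection import CandidateSelector


@pytest.mark.parametrize("scoring", ["r2", "neg_mean_squared_error"])
def test_non_rmse_scoring_leaves_rmse_metrics_none(dummy_data_loader, scoring):
    X, y = dummy_data_loader
    selector = CandidateSelector(
        estimator=Ridge(random_state=0),
        param_grid={"alpha": [0.1, 1.0]},
        cv=3,
        scoring=scoring,
        return_train_score=True,
        n_jobs=1,
    )
    selector.fit(X, y)

    # RMSE-specific metrics should not be computed for non-RMSE scorers
    for candidate in selector.get_candidates():
        assert candidate.rmsecv is None
        assert candidate.rmse_train is None
        assert candidate.rmse_ratio is None
```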
```python
from ._candidate_selector import CandidateSelector
```
Missing module-level docstring. Other modules in the codebase (e.g., smooth, inspector) have docstrings at the module level. Add a docstring describing the model selection module's purpose and main classes.
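For instance (wording is only a suggestion; the `BaseFittedModel` import mirrors what the file table above implies the package exports):

```python
"""Model selection tools for chemotools.

This module provides :class:`CandidateSelector`, a GridSearchCV wrapper with
candidate ranking, filtering, and visualization, and :class:`BaseFittedModel`,
a container holding the metrics of a single fitted candidate.
"""

from ._candidate_selector import CandidateSelector
from ._fitted_model import BaseFittedModel

__all__ = ["CandidateSelector", "BaseFittedModel"]
```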
```python
# Calculate RMSE metrics if using neg_root_mean_squared_error
rmsecv, rmse_train_val, rmse_ratio = None, None, None
if isinstance(scoring, str) and "neg_root_mean_squared_error" in scoring:
    rmsecv = -mean_test
    if mean_train is not None:
        rmse_train_val = -mean_train
        if rmse_train_val != 0:
            rmse_ratio = rmsecv / rmse_train_val
```
The RMSE metric calculation only checks for "neg_root_mean_squared_error" in the scoring string, which may miss other RMSE-related scorers or fail silently for non-RMSE metrics. Consider: (1) documenting this RMSE-specific behavior more prominently in the class docstring, (2) adding a warning when RMSE metrics will not be calculated, or (3) providing a clearer way to enable/disable RMSE-specific metrics. Users who don't use RMSE scoring will have None for these metrics but may not understand why.
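A sketch of option (2), using the names from the snippet above; in practice the warning would probably be emitted once in `fit()` rather than per candidate:

```python
import warnings

rmsecv, rmse_train_val, rmse_ratio = None, None, None
if isinstance(scoring, str) and "neg_root_mean_squared_error" in scoring:
    rmsecv = -mean_test
    if mean_train is not None:
        rmse_train_val = -mean_train
        if rmse_train_val != 0:
            rmse_ratio = rmsecv / rmse_train_val
else:
    # Make the RMSE-only behavior explicit instead of failing silently
    warnings.warn(
        "RMSE-specific metrics (rmsecv, rmse_train, rmse_ratio) are only "
        f"computed for 'neg_root_mean_squared_error' scoring; got {scoring!r}, "
        "so these attributes will be None.",
        UserWarning,
    )
```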
| """Return top *n* candidates (all if n is None).""" | ||
| check_is_fitted(self, ["candidates_"]) | ||
| if n is None: | ||
| return self.candidates_ |
Missing input validation for the n parameter. The method does not validate that n is positive when provided. If a user passes n=0 or n=-1, the method will return an empty list or unexpected results respectively, without any warning or error. Consider adding validation to ensure n is either None or a positive integer, raising a ValueError for invalid inputs.
Suggested change:
```python
if n is None:
    return self.candidates_
if not isinstance(n, int) or n < 1:
    raise ValueError(
        f"n must be a positive integer or None, got {n!r}."
    )
```
```python
import numpy as np
import pytest
from sklearn.linear_model import Ridge

from chemotools.model_selection import BaseFittedModel, CandidateSelector
```
Missing sklearn compatibility test. Other modules in the codebase (e.g., smooth/test_savitzky_golay_filter.py) include a test using check_estimator from sklearn.utils.estimator_checks to verify sklearn API compliance. Consider adding a similar test for CandidateSelector to ensure it follows sklearn conventions, though note that meta-estimators like this one may need special handling or exclusions for certain checks.
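A sketch of what that could look like with `parametrize_with_checks`; as noted, a meta-estimator like this may need some checks skipped or marked as expected failures:

```python
from sklearn.linear_model import Ridge
from sklearn.utils.estimator_checks import parametrize_with_checks

from chemotools.model_selection import CandidateSelector


@parametrize_with_checks(
    [CandidateSelector(estimator=Ridge(), param_grid={"alpha": [0.1, 1.0]})]
)
def test_sklearn_compatibility(estimator, check):
    # Runs the full suite of sklearn API compliance checks
    check(estimator)
```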
```python
# Auto-detect color_by parameter
if color_by is None and self.candidates_:
    color_by = next(iter(self.candidates_[0].params), None)
```
Potential edge case: if self.candidates_[0].params is an empty dictionary (which could happen if param_grid is empty or has no keys), next(iter(self.candidates_[0].params), None) will return None, which is handled. However, this edge case is unlikely in practice since GridSearchCV requires a non-empty param_grid. Consider adding a check or documenting this assumption.
Suggested change:
```python
# Auto-detect color_by parameter if candidates have parameters
if color_by is None and self.candidates_:
    first_params = getattr(self.candidates_[0], "params", None) or {}
    if first_params:
        # Use the first available parameter name for coloring
        color_by = next(iter(first_params))
```
This PR will hopefully close #202 at some point 😄
Goal of the PR - what does it solve?
A `CandidateSelector` class has been implemented with the goal of easily creating, managing, and selecting "candidates" based on performance metrics, while also including less commonly used, but still important, metrics such as variance. The introduced methods encourage people not only to use metrics such as RMSE to determine the best candidate, but also to select it based on the variance of the model, which this class hopefully makes easier to do. 😃
Implementation details.
Two files have been created to implement the model selection module:
- the `CandidateSelector` class, which is the main driving force that makes it work
- the `BaseFittedModel` class, which works as a container for a single fitted model; `CandidateSelector` gives its output in the form of `BaseFittedModel` instances (one for each candidate)

Result from the implementation
Below are some of the ways that you can use it (at least for now):
Example 1
Imagine we need to train a series of PLS models with different parameters and preprocessing applied:
Normally you would use something like `GridSearchCV`, but now we can instantiate a `CandidateSelector` instead (which will in fact run a `GridSearchCV` under the hood) and fit it:
From this, each "model" trained will be a `BaseFittedModel` instance that contains several interesting metrics, such as the rank, parameters, and test metric scores (RMSE and variance), as well as RMSECV, the RMSE ratio, etc.

Example 2 - accessing candidates after training
After you have trained your candidates, several calls can be made to explore them further; for instance, you can ask for the top 3 candidates and print some of their metrics:
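Something along these lines; the accessor name `get_candidates` and the exact field names are assumptions based on the review comments above:

```python
# Ask for the top 3 candidates (rank 1 = best) and print some of their metrics
for candidate in selector.get_candidates(n=3):
    print(candidate.rank, candidate.params, candidate.rmsecv, candidate.rmse_ratio)
```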
You can also filter your candidates based on the RMSE ratio (i.e. RMSECV / RMSE from train) to identify overfitting:
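For instance, with a hypothetical `filter_candidates` signature (the real method name and arguments may differ):

```python
# Keep only candidates whose RMSECV / RMSE(train) ratio stays below 1.2,
# i.e. drop models that look clearly overfitted (threshold is arbitrary)
stable = selector.filter_candidates("rmse_ratio", "<", 1.2)
```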
Visualization is also possible with the `plot_cv_metrics` call, which plots something like this (visually speaking not good, but it shows the idea quite well nonetheless 😃):
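The method name `plot_cv_metrics` is from the PR; the no-argument call below is an assumption:

```python
import matplotlib.pyplot as plt

ax = selector.plot_cv_metrics()
plt.show()
```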
Or if you want to have a plot against the test metric and variance, which produces plots like this:
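The review snippet above suggests this is a grouped scatter of test metric against variance; a hypothetical call might look like this (the method name is a guess, while `color_by` matches the parameter seen in the plotting code reviewed above):

```python
# Hypothetical method name; color_by groups points by a grid parameter
ax = selector.plot_candidates(color_by="n_components")
plt.show()
```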
What to do from here

With the addition of the inspector module, it is now of interest to somehow align these two modules, as I imagine the workflow should be something like `model selection module --> inspector module`. In that sense, I would like to get feedback on how I should go about this (and also feedback in general, i.e. whether or not this is good or whether I should make any larger changes).

So I don't expect to be done with this PR now; it should merely work as a "feedback" round for the model selection module, so it can be better aligned, and to determine how it should be approached in order to do that 😃
model selection module --> inspector moduleso in that sense I would like to get feedback on how I should go about this (and also get feedback just in general, i.e. whether or not this is good or if I should make any larger changes).So I don't expect to be done with this PR now, but merely have it work as a "feedback" round for the model selection module to be better aligned and how it should be approached in order to do that 😃