From 8816516dbe72a0aacd1c57675395b9f73685c6b9 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Tue, 26 May 2026 16:49:09 +0200 Subject: [PATCH 1/3] Unify cluster palette and add discrete palette --- pypesto/visualize/_style.py | 85 ++++++++++++++++++++ pypesto/visualize/clust_color.py | 100 ++++++++++++++++++++---- pypesto/visualize/misc.py | 7 +- pypesto/visualize/optimization_stats.py | 63 ++++++++++++++- pypesto/visualize/optimizer_history.py | 27 ++++++- pypesto/visualize/parameters.py | 28 ++++++- pypesto/visualize/profiles.py | 27 ++++++- pypesto/visualize/waterfall.py | 29 ++++++- 8 files changed, 337 insertions(+), 29 deletions(-) create mode 100644 pypesto/visualize/_style.py diff --git a/pypesto/visualize/_style.py b/pypesto/visualize/_style.py new file mode 100644 index 000000000..e944992f1 --- /dev/null +++ b/pypesto/visualize/_style.py @@ -0,0 +1,85 @@ +""" +Shared visualization constants and helpers for ``pypesto.visualize``. + +TODO remove too much info after all PRs have been merged. + +Grown incrementally across the PR 1.5 follow-on series: each per-viz PR +adds the constants and helpers its consumers need, in the same diff as +those consumers. The final PR in the series adds :func:`apply_style` +(an opt-in rcParams preset). + +Style keys are surfaced to users via the ``style_kwargs`` parameter on +every public plotter, validated against :data:`_DEFAULTS`:: + + waterfall(result, style_kwargs={"mle_color": "tab:purple"}) + +How to extend +------------- +- **Add a constant**: ``UPPER_SNAKE`` name with a 1-line comment on its + semantic role, under the appropriate section header. Add a new + section header (``# ===`` block) when the purpose is genuinely new. +- **Add a registry key**: lowercase entry in :data:`_DEFAULTS` referencing + the constant. Unknown keys passed to :func:`resolve_style` raise + ``UserWarning`` so typos surface immediately. +- **Add a helper**: under an existing section, or a new one. Helpers that + cross module boundaries belong here. Module-local helpers stay in + their module. +""" + +from __future__ import annotations + +import warnings + +# Colors — semantic roles +# ----------------------- + +# matplotlib ``tab:red``; used for both the best-cluster colour and MLE markers. +MLE_COLOR = "#d62728" + +# Neutral mid-grey; isolated (singleton) starts and outlier indicators. +OUTLIER_COLOR = "#b3b3b3" + +# Colormaps +# --------- + +# Qualitative palette; secondary cluster colours and per-variable colours in +# prediction-trajectory plots are sampled from this. +CMAP_DISCRETE = "tab10" + +# Style registry +# -------------- + +_DEFAULTS: dict[str, object] = { + "mle_color": MLE_COLOR, + "outlier_color": OUTLIER_COLOR, + "cmap_discrete": CMAP_DISCRETE, +} + + +def resolve_style(style_kwargs: dict | None = None) -> dict: + """Return the effective style dict, merging defaults with caller overrides. + + Parameters + ---------- + style_kwargs: + User-supplied overrides. Unknown keys raise a ``UserWarning`` so + typos surface immediately. + + Returns + ------- + dict + Merged style dict with all keys from :data:`_DEFAULTS`, with + caller overrides applied on top. + """ + style = dict(_DEFAULTS) + if style_kwargs: + unknown = set(style_kwargs) - set(_DEFAULTS) + if unknown: + warnings.warn( + f"Unknown style_kwargs keys: {sorted(unknown)}. " + f"Valid keys: {sorted(_DEFAULTS)}.", + UserWarning, + stacklevel=3, + ) + style.update(style_kwargs) + return style diff --git a/pypesto/visualize/clust_color.py b/pypesto/visualize/clust_color.py index e73a5087b..5cbfe08b0 100644 --- a/pypesto/visualize/clust_color.py +++ b/pypesto/visualize/clust_color.py @@ -1,4 +1,7 @@ -import matplotlib.cm as cm +from __future__ import annotations + +import matplotlib.colors as mcolors +import matplotlib.pyplot as plt import numpy as np from matplotlib.colors import is_color_like @@ -6,10 +9,56 @@ # for typehints from ..C import COLOR +from ._style import resolve_style + + +def _build_cluster_palette(style: dict) -> np.ndarray: + """Sample non-best cluster colors from ``cmap_discrete``. + + Colors close to ``mle_color`` or ``outlier_color`` are filtered out so + those reserved roles remain visually distinct from cycled cluster colors. + """ + cmap = plt.get_cmap(style["cmap_discrete"]) + reserved = [ + np.array(mcolors.to_rgb(style["mle_color"])), + np.array(mcolors.to_rgb(style["outlier_color"])), + ] + + # We remove colors from the cluster palette that are too close to the + # reserved colors (MLE red and outlier grey). The distance threshold is + # just a reasonable heuristic. + _RESERVED_COLOR_DISTANCE = 0.2 + + # Number of evenly-spaced samples taken from a continuous cmap (e.g. viridis) + # when ``cmap_discrete`` is set to one. Categorical cmaps (e.g. tab10) use + # all their listed colors and ignore this. + _CMAP_DISCRETE_SAMPLES = 10 + + if hasattr(cmap, "colors"): + candidates = [mcolors.to_rgba(c) for c in cmap.colors] + else: + candidates = [ + cmap(i / (_CMAP_DISCRETE_SAMPLES - 1)) + for i in range(_CMAP_DISCRETE_SAMPLES) + ] + palette = [ + c + for c in candidates + if all( + np.linalg.norm(np.array(c[:3]) - r) > _RESERVED_COLOR_DISTANCE + for r in reserved + ) + ] + if not palette: + palette = candidates + return np.array(palette) def assign_clustered_colors( - vals: np.ndarray, balance_alpha: bool = True, highlight_global: bool = True + vals: np.ndarray, + balance_alpha: bool = True, + highlight_global: bool = True, + style: dict | None = None, ): """ Cluster and assign colors. @@ -23,6 +72,10 @@ def assign_clustered_colors( avoid overplotting highlight_global: flag indicating whether global optimum should be highlighted + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -36,6 +89,12 @@ def assign_clustered_colors( # assign clusters clusters, cluster_size = assign_clusters(vals) + if style is None: + style = resolve_style(None) + palette = _build_cluster_palette(style) + mle_rgba = list(mcolors.to_rgba(style["mle_color"])) + outlier_rgb = list(mcolors.to_rgb(style["outlier_color"])) + # create list of colors, which has the correct shape n_clusters = 1 + max(clusters) - sum(cluster_size == 1) @@ -43,13 +102,12 @@ def assign_clustered_colors( if highlight_global and cluster_size[0] > 1: n_clusters -= 1 - # fill color array from colormap - colormap = cm.ScalarMappable().to_rgba - color_list = colormap(np.linspace(0.0, 1.0, n_clusters)) + # fill color array by cycling through the categorical cluster palette + color_list = palette[np.arange(n_clusters) % len(palette)].copy() - # best optimum should be colored in red + # best optimum should be colored in MLE red if highlight_global and cluster_size[0] > 1: - color_list = np.concatenate(([[1.0, 0.0, 0.0, 1.0]], color_list)) + color_list = np.concatenate(([mle_rgba], color_list)) # We have clustered the results. However, clusters may have size 1, # so we need to rearrange the regroup the results into "no_clusters", @@ -64,8 +122,8 @@ def assign_clustered_colors( if balance_alpha: # set minimal alpha value to avoid non-visible colors min_alpha = 0.01 - # assign neutral color, add 1 for avoiding division by zero - grey = [0.7, 0.7, 0.7, min(1.0, 5.0 / (no_clusters.size + 1.0))] + # alpha shrinks with the number of singletons to avoid overplotting + grey = [*outlier_rgb, min(1.0, 5.0 / (no_clusters.size + 1.0))] # reduce alpha level depend on size of each cluster n_cluster_size = np.delete(cluster_size, no_clusters) @@ -74,8 +132,7 @@ def assign_clustered_colors( 1.0, max(5.0 / n_cluster_size[icluster], min_alpha) ) else: - # assign neutral color - grey = [0.7, 0.7, 0.7, 1.0] + grey = [*outlier_rgb, 1.0] # create a color list, prfilled with grey values colors = np.array([grey] * clusters.size) @@ -86,9 +143,9 @@ def assign_clustered_colors( ind_of_iclust = np.argwhere(clusters == iclust).flatten() colors[ind_of_iclust, :] = color_list[icol, :] - # if best value was found only once: replace it with red + # if best value was found only once: replace it with MLE red if highlight_global and cluster_size[0] == 1: - colors[0] = [1.0, 0.0, 0.0, 1.0] + colors[0] = mle_rgba return colors @@ -98,6 +155,7 @@ def assign_colors( colors: COLOR | list[COLOR] | np.ndarray | None = None, balance_alpha: bool = True, highlight_global: bool = True, + style: dict | None = None, ) -> np.ndarray: """ Assign colors or format user specified colors. @@ -113,6 +171,10 @@ def assign_colors( avoid overplotting highlight_global: flag indicating whether global optimum should be highlighted + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -129,6 +191,7 @@ def assign_colors( vals, balance_alpha=balance_alpha, highlight_global=highlight_global, + style=style, ) # Get number of elements and use user assigned colors @@ -160,6 +223,7 @@ def assign_colors( def assign_colors_for_list( num_entries: int, colors: COLOR | list[COLOR] | np.ndarray | None = None, + style: dict | None = None, ) -> list[list[float]] | np.ndarray: """ Create a list of colors for a list of items. @@ -173,6 +237,10 @@ def assign_colors_for_list( number of results in list colors: list of colors, or single color + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -188,7 +256,10 @@ def assign_colors_for_list( # we don't want alpha levels for all plotting routines in this case... colors = assign_colors( - dummy_clusters, balance_alpha=False, highlight_global=False + dummy_clusters, + balance_alpha=False, + highlight_global=False, + style=style, ) # dummy cluster had twice as many entries as really there. Reduce. @@ -201,4 +272,5 @@ def assign_colors_for_list( colors=colors, balance_alpha=False, highlight_global=False, + style=style, ) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index b74f2b55f..661aed86b 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -35,6 +35,7 @@ def process_result_list( results: Result | list[Result], colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: str | list[str] | None = None, + style: dict | None = None, ) -> tuple[list[Result], list[COLOR], list[str]]: """ Assign colors and legends to a list of results, check user provided lists. @@ -47,6 +48,10 @@ def process_result_list( list of colors recognized by matplotlib, or single color legends: labels for line plots + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -84,7 +89,7 @@ def process_result_list( legend_type_error = True else: # if more than one result is passed, we use one color per result - colors = assign_colors_for_list(len(results), colors) + colors = assign_colors_for_list(len(results), colors, style=style) # check whether list of legends has the correct length if legends is None: diff --git a/pypesto/visualize/optimization_stats.py b/pypesto/visualize/optimization_stats.py index 5669627f8..7cb719864 100644 --- a/pypesto/visualize/optimization_stats.py +++ b/pypesto/visualize/optimization_stats.py @@ -8,6 +8,7 @@ from ..C import COLOR from ..result import Result +from ._style import resolve_style from .clust_color import assign_colors, assign_colors_for_list from .misc import ( get_ax, @@ -28,6 +29,7 @@ def optimization_run_properties_one_plot( legends: str | list[str] | None = None, plot_type: str = "line", ax: matplotlib.axes.Axes | None = None, + style_kwargs: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot stats for allproperties specified in properties_to_plot on one plot. @@ -54,6 +56,16 @@ def optimization_run_properties_one_plot( Labels, one label per optimization property plot_type: Specifies plot type. Possible values: 'line' and 'hist' + style_kwargs: + Style overrides. Keys used by this function: + + - ``cmap_discrete`` — the categorical palette from which + per-property line colours are sampled. Only consulted when + ``colors`` is ``None``; an explicit ``colors`` short-circuits + palette selection. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- @@ -76,6 +88,8 @@ def optimization_run_properties_one_plot( colors=[[.5, .9, .9, .3], [.2, .1, .9, .5]] ) """ + style = resolve_style(style_kwargs) + if properties_to_plot is None: properties_to_plot = [ "time", @@ -87,7 +101,9 @@ def optimization_run_properties_one_plot( ] if colors is None: - colors = assign_colors_for_list(len(properties_to_plot)) + colors = assign_colors_for_list( + len(properties_to_plot), style=style + ) elif is_color_like(colors): colors = [colors] @@ -136,6 +152,7 @@ def optimization_run_properties_per_multistart( legends: str | list[str] | None = None, plot_type: str = "line", axes: np.ndarray | None = None, + style_kwargs: dict | None = None, ) -> np.ndarray: """ One plot per optimization property in properties_to_plot. @@ -160,6 +177,19 @@ def optimization_run_properties_per_multistart( Labels for line plots, one label per result object plot_type: Specifies plot type. Possible values: 'line' and 'hist' + style_kwargs: + Style overrides forwarded to + :func:`optimization_run_property_per_multistart`. Keys used by + this function: + + - ``cmap_discrete``, ``mle_color``, ``outlier_color`` — colours + of the per-start scatter when clustering is applied (best + cluster, secondary clusters, isolated starts respectively). + Only consulted when ``colors`` is ``None``; an explicit + ``colors`` short-circuits clustering. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- @@ -223,6 +253,7 @@ def optimization_run_properties_per_multistart( colors=colors, legends=legends, plot_type=plot_type, + style_kwargs=style_kwargs, ) return axes @@ -236,6 +267,7 @@ def optimization_run_property_per_multistart( colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: str | list[str] | None = None, plot_type: str = "line", + style_kwargs: dict | None = None, ) -> np.ndarray: """ Plot stats for an optimization run property specified by opt_run_property. @@ -268,12 +300,25 @@ def optimization_run_property_per_multistart( Labels for line plots, one label per result object plot_type: Specifies plot type. Possible values: 'line', 'hist', 'both' + style_kwargs: + Style overrides. Keys used by this function: + + - ``cmap_discrete``, ``mle_color``, ``outlier_color`` — colours + of the per-start scatter when clustering is applied + (single-result default; best cluster, secondary clusters, + isolated starts respectively). Only consulted when + ``colors`` is ``None``; an explicit ``colors`` short-circuits + clustering. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- axes: 2-D NumPy array containing one matplotlib Axes per panel. """ + style = resolve_style(style_kwargs) supported_properties = { "time": "Wall-clock time (seconds)", "n_fval": "Number of function evaluations", @@ -291,7 +336,9 @@ def optimization_run_property_per_multistart( ) # parse input - (results, colors, legends) = process_result_list(results, colors, legends) + (results, colors, legends) = process_result_list( + results, colors, legends, style=style + ) ncols = 2 if plot_type == "both" else 1 axes = get_axes_array(axes=axes, nrows=1, ncols=ncols, size=size) @@ -320,6 +367,7 @@ def optimization_run_property_per_multistart( start_indices, colors[j], legends[j], + style=style, ) stats_lowlevel( @@ -331,6 +379,7 @@ def optimization_run_property_per_multistart( colors[j], legends[j], plot_type="hist", + style=style, ) else: stats_lowlevel( @@ -342,6 +391,7 @@ def optimization_run_property_per_multistart( colors[j], legends[j], plot_type, + style=style, ) if sum(legend is not None for legend in legends) > 0: @@ -363,6 +413,7 @@ def stats_lowlevel( color: COLOR | list[COLOR] | np.ndarray | None = "C0", legend: str | None = None, plot_type: str = "line", + style: dict | None = None, ): """ Plot values of the optimization run property across different multistarts. @@ -389,6 +440,10 @@ def stats_lowlevel( Label describing the result plot_type: Specifies plot type. Possible values: 'line' and 'hist' + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -407,7 +462,9 @@ def stats_lowlevel( n_starts = len(values) # assign colors - colors = assign_colors(vals=fvals, colors=color, balance_alpha=False) + colors = assign_colors( + vals=fvals, colors=color, balance_alpha=False, style=style + ) sorted_indices = sorted(range(n_starts), key=lambda j: fvals[j]) values = values[sorted_indices] diff --git a/pypesto/visualize/optimizer_history.py b/pypesto/visualize/optimizer_history.py index 6fc624010..8baa6ad3f 100644 --- a/pypesto/visualize/optimizer_history.py +++ b/pypesto/visualize/optimizer_history.py @@ -15,6 +15,7 @@ ) from ..history import HistoryBase from ..result import Result +from ._style import resolve_style from .clust_color import assign_colors from .misc import ( get_ax, @@ -44,6 +45,7 @@ def optimizer_history( | list[dict] | None = None, legends: str | list[str] | None = None, + style_kwargs: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot history of optimizer. @@ -86,17 +88,32 @@ def optimizer_history( least a function value fval legends: Labels for line plots, one label per result object + style_kwargs: + Style overrides. Keys used by this function: + + - ``cmap_discrete``, ``mle_color``, ``outlier_color`` — colours + of the per-start history traces when clustering is applied + (best cluster, secondary clusters, isolated starts respectively). + Only consulted when ``colors`` is ``None``; an explicit + ``colors`` short-circuits clustering. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- ax: The plot axes. """ + style = resolve_style(style_kwargs) + if isinstance(start_indices, int): start_indices = list(range(start_indices)) # parse input - (results, colors, legends) = process_result_list(results, colors, legends) + (results, colors, legends) = process_result_list( + results, colors, legends, style=style + ) for j, result in enumerate(results): # extract cost function values from result @@ -119,6 +136,7 @@ def optimizer_history( x_label=x_label, y_label=y_label, legend_text=legends[j], + style=style, ) # parse and apply plotting options @@ -139,6 +157,7 @@ def optimizer_history_lowlevel( x_label: str = "Optimizer steps", y_label: str = "Objective value", legend_text: str | None = None, + style: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot optimizer history using list of numpy arrays. @@ -162,6 +181,10 @@ def optimizer_history_lowlevel( label for y-axis legend_text: Label for line plots + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -194,7 +217,7 @@ def optimizer_history_lowlevel( # assign colors # note: this has to happen before sorting # to get the same colors in different plots - colors = assign_colors(fvals, colors) + colors = assign_colors(fvals, colors, style=style) # sort indices = sorted(range(n_fvals), key=lambda j: fvals[j]) diff --git a/pypesto/visualize/parameters.py b/pypesto/visualize/parameters.py index 597f615f0..470531310 100644 --- a/pypesto/visualize/parameters.py +++ b/pypesto/visualize/parameters.py @@ -17,6 +17,7 @@ InnerParameterType, ) from ..result import Result +from ._style import resolve_style from .clust_color import assign_colors from .misc import ( get_ax, @@ -53,6 +54,7 @@ def parameters( scale_to_interval: tuple[float, float] | None = None, plot_inner_parameters: bool = True, log10_scale_hier_sigma: bool = True, + style_kwargs: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot parameter values. @@ -95,14 +97,29 @@ def parameters( log10_scale_hier_sigma: Flag indicating whether to scale inner parameters of type ``InnerParameterType.SIGMA`` to log10 (default: True). + style_kwargs: + Style overrides. Keys used by this function: + + - ``cmap_discrete``, ``mle_color``, ``outlier_color`` — colours + of the per-start parameter traces when clustering is used + (best cluster, secondary clusters, isolated starts respectively). + Only consulted when ``colors`` is ``None``; an explicit + ``colors`` short-circuits clustering. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- ax: The plot axes. """ + style = resolve_style(style_kwargs) + # parse input - (results, colors, legends) = process_result_list(results, colors, legends) + (results, colors, legends) = process_result_list( + results, colors, legends, style=style + ) if isinstance(parameter_indices, str): if parameter_indices == "all": @@ -161,6 +178,7 @@ def scale_parameters(x): colors=colors[j], legend_text=legends[j], balance_alpha=balance_alpha, + style=style, ) # parse and apply plotting options @@ -264,6 +282,7 @@ def parameters_lowlevel( linestyle: str = "-", legend_text: str | None = None, balance_alpha: bool = True, + style: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot parameters plot using list of parameters. @@ -292,13 +311,16 @@ def parameters_lowlevel( balance_alpha: Flag indicating whether alpha for large clusters should be reduced to avoid overplotting (default: True) + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- ax: The plot axes. """ - if size is None: # 0.5 inch height per parameter size = (18.5, max(xs.shape[1], 1) / 2) @@ -307,7 +329,7 @@ def parameters_lowlevel( # assign colors colors = assign_colors( - vals=fvals, colors=colors, balance_alpha=balance_alpha + vals=fvals, colors=colors, balance_alpha=balance_alpha, style=style ) # parameter indices diff --git a/pypesto/visualize/profiles.py b/pypesto/visualize/profiles.py index d61f5bd7e..c5fef9512 100644 --- a/pypesto/visualize/profiles.py +++ b/pypesto/visualize/profiles.py @@ -12,6 +12,7 @@ from ..problem import Problem from ..profile import chi2_quantile_to_ratio from ..result import Result +from ._style import resolve_style from .clust_color import assign_colors from .misc import get_ax, process_result_list from .reference_points import ReferencePoint, create_references @@ -123,6 +124,7 @@ def profiles( show_bounds: bool = False, plot_objective_values: bool = False, quality_colors: bool = False, + style_kwargs: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot classical 1D profile plot. @@ -173,12 +175,24 @@ def profiles( had to resample the parameter vector due to optimization failure of the previous two. Black indicates a step for which none of the above was necessary. This option is only available if there is only one result and one profile_list_id (one profile per plot). + style_kwargs: + Style overrides. Keys used by this function: + + - ``cmap_discrete``, ``mle_color``, ``outlier_color`` — colours + of the per-result / per-profile-list profile lines (best + cluster, secondary clusters, isolated starts respectively). + Only consulted when ``colors`` is ``None``; an explicit + ``colors`` short-circuits clustering. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- ax: The plot axes. """ + style = resolve_style(style_kwargs) if colors is not None and quality_colors: raise ValueError( @@ -195,7 +209,7 @@ def profiles( # parse input results, profile_list_ids, colors, legends = process_result_list_profiles( - results, profile_list_ids, legends, colors + results, profile_list_ids, legends, colors, style=style ) # get the parameter ids to be plotted @@ -621,6 +635,7 @@ def process_result_list_profiles( profile_list_ids: int | Sequence[int] | None, legends: str | list[str], colors: COLOR | list[COLOR] | np.ndarray | None = None, # todo: check + style: dict | None = None, ) -> tuple[list[Result], list[int] | Sequence[int], list, list[str]]: """ Assign colors and legends to a list of results. @@ -637,6 +652,10 @@ def process_result_list_profiles( list of colors for plotting. legends: Legends for plotting + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -653,7 +672,7 @@ def process_result_list_profiles( if len(results) != 1: # if we have no single result, then use the standard api results, colors, legends = process_result_list( - results, colors, legends + results, colors, legends, style=style ) return results, profile_list_ids, colors, legends else: @@ -662,7 +681,9 @@ def process_result_list_profiles( # If we have a single result, we may still have multiple profile_list_ids # which should be plotted separately: use profile_list_ids as results dummy - _, colors, legends = process_result_list(profile_list_ids, colors, legends) + _, colors, legends = process_result_list( + profile_list_ids, colors, legends, style=style + ) return results, profile_list_ids, colors, legends diff --git a/pypesto/visualize/waterfall.py b/pypesto/visualize/waterfall.py index c25158363..13be87a62 100644 --- a/pypesto/visualize/waterfall.py +++ b/pypesto/visualize/waterfall.py @@ -9,6 +9,7 @@ from ..C import ALL, COLOR, WATERFALL_MAX_VALUE from ..result import Result +from ._style import resolve_style from .clust_color import assign_colors from .misc import ( get_ax, @@ -33,6 +34,7 @@ def waterfall( colors: COLOR | list[COLOR] | np.ndarray | None = None, legends: Sequence[str] | str | None = None, order_by_id: bool = False, + style_kwargs: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot waterfall plot. @@ -71,6 +73,17 @@ def waterfall( the same x-axis position. Only applicable when a list of result objects are provided. Default behavior is to sort the function values of each result independently of other results. + style_kwargs: + Style overrides. Keys used by this function: + + - ``cmap_discrete``, ``mle_color``, ``outlier_color`` — colours + of the per-start scatter dots when clustering is used (best + cluster, secondary clusters, isolated starts respectively). + Only consulted when ``colors`` is ``None``; an explicit + ``colors`` short-circuits clustering. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. Returns ------- @@ -78,6 +91,7 @@ def waterfall( The plot axes. """ ax = get_ax(ax, size) + style = resolve_style(style_kwargs) if n_starts_to_zoom: # create zoom in @@ -89,7 +103,9 @@ def waterfall( inset_axes = None # parse input - (results, colors, legends) = process_result_list(results, colors, legends) + (results, colors, legends) = process_result_list( + results, colors, legends, style=style + ) # handle `order_by_id` if order_by_id: @@ -153,7 +169,7 @@ def waterfall( fvals.sort() # assign colors - coloring = assign_colors(fvals, colors=colors[j]) + coloring = assign_colors(fvals, colors=colors[j], style=style) # call lowlevel plot routine ax = waterfall_lowlevel( @@ -164,6 +180,7 @@ def waterfall( size=size, colors=coloring, legend_text=legends[j], + style=style, ) if inset_axes is not None: @@ -172,6 +189,7 @@ def waterfall( scale_y=scale_y, ax=inset_axes, colors=coloring[:n_starts_to_zoom], + style=style, ) # remove the title and axes labels for the zoom in subplot inset_axes.set(title=None, xlabel=None, ylabel=None) @@ -203,6 +221,7 @@ def waterfall_lowlevel( offset_y: float = 0.0, colors: COLOR | list[COLOR] | np.ndarray | None = None, legend_text: str | None = None, + style: dict | None = None, ) -> matplotlib.axes.Axes: """ Plot waterfall plot using list of function values. @@ -226,6 +245,10 @@ def waterfall_lowlevel( and colors are assigned automatically legend_text: Label for line plots + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. Returns ------- @@ -240,7 +263,7 @@ def waterfall_lowlevel( colors = [colors[i] for i in start_indices] # assign colors - colors = assign_colors(fvals, colors=colors) + colors = assign_colors(fvals, colors=colors, style=style) # plot ax.xaxis.set_major_locator(MaxNLocator(integer=True)) From 835b95eaa2a3c8cb13b5596bd088ccbb000b2b23 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Thu, 28 May 2026 15:35:39 +0200 Subject: [PATCH 2/3] small style cleanup + fix test failure --- pypesto/visualize/_style.py | 39 +++++-------------------- pypesto/visualize/optimization_stats.py | 4 +-- 2 files changed, 9 insertions(+), 34 deletions(-) diff --git a/pypesto/visualize/_style.py b/pypesto/visualize/_style.py index e944992f1..44b9b14e8 100644 --- a/pypesto/visualize/_style.py +++ b/pypesto/visualize/_style.py @@ -1,29 +1,13 @@ """ -Shared visualization constants and helpers for ``pypesto.visualize``. +Visual style for ``pypesto.visualize``. -TODO remove too much info after all PRs have been merged. +Default constants, the ``style_kwargs`` registry, and small cross-module +helpers. -Grown incrementally across the PR 1.5 follow-on series: each per-viz PR -adds the constants and helpers its consumers need, in the same diff as -those consumers. The final PR in the series adds :func:`apply_style` -(an opt-in rcParams preset). - -Style keys are surfaced to users via the ``style_kwargs`` parameter on -every public plotter, validated against :data:`_DEFAULTS`:: +Users override any default per call via ``style_kwargs``, validated against +:data:`_DEFAULTS`:: waterfall(result, style_kwargs={"mle_color": "tab:purple"}) - -How to extend -------------- -- **Add a constant**: ``UPPER_SNAKE`` name with a 1-line comment on its - semantic role, under the appropriate section header. Add a new - section header (``# ===`` block) when the purpose is genuinely new. -- **Add a registry key**: lowercase entry in :data:`_DEFAULTS` referencing - the constant. Unknown keys passed to :func:`resolve_style` raise - ``UserWarning`` so typos surface immediately. -- **Add a helper**: under an existing section, or a new one. Helpers that - cross module boundaries belong here. Module-local helpers stay in - their module. """ from __future__ import annotations @@ -32,19 +16,12 @@ # Colors — semantic roles # ----------------------- - -# matplotlib ``tab:red``; used for both the best-cluster colour and MLE markers. -MLE_COLOR = "#d62728" - -# Neutral mid-grey; isolated (singleton) starts and outlier indicators. -OUTLIER_COLOR = "#b3b3b3" +MLE_COLOR = "#d62728" # tab:red — best cluster + MLE markers +OUTLIER_COLOR = "#b3b3b3" # mid-grey — singleton / outlier starts # Colormaps # --------- - -# Qualitative palette; secondary cluster colours and per-variable colours in -# prediction-trajectory plots are sampled from this. -CMAP_DISCRETE = "tab10" +CMAP_DISCRETE = "tab10" # qualitative: cluster + per-variable colours # Style registry # -------------- diff --git a/pypesto/visualize/optimization_stats.py b/pypesto/visualize/optimization_stats.py index 7cb719864..e98e82e0a 100644 --- a/pypesto/visualize/optimization_stats.py +++ b/pypesto/visualize/optimization_stats.py @@ -101,9 +101,7 @@ def optimization_run_properties_one_plot( ] if colors is None: - colors = assign_colors_for_list( - len(properties_to_plot), style=style - ) + colors = assign_colors_for_list(len(properties_to_plot), style=style) elif is_color_like(colors): colors = [colors] From e17583a0c54f94ec74703ab3be34371972df7658 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Thu, 28 May 2026 16:36:10 +0200 Subject: [PATCH 3/3] small edits --- pypesto/visualize/_style.py | 2 +- pypesto/visualize/clust_color.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pypesto/visualize/_style.py b/pypesto/visualize/_style.py index 44b9b14e8..5f341feb2 100644 --- a/pypesto/visualize/_style.py +++ b/pypesto/visualize/_style.py @@ -39,7 +39,7 @@ def resolve_style(style_kwargs: dict | None = None) -> dict: Parameters ---------- style_kwargs: - User-supplied overrides. Unknown keys raise a ``UserWarning`` so + User-supplied overrides. Unknown keys emit a ``UserWarning`` so typos surface immediately. Returns diff --git a/pypesto/visualize/clust_color.py b/pypesto/visualize/clust_color.py index 5cbfe08b0..98c5c5484 100644 --- a/pypesto/visualize/clust_color.py +++ b/pypesto/visualize/clust_color.py @@ -134,7 +134,7 @@ def assign_clustered_colors( else: grey = [*outlier_rgb, 1.0] - # create a color list, prfilled with grey values + # create a color list, prefilled with grey values colors = np.array([grey] * clusters.size) # assign colors to real clusters