diff --git a/pypesto/visualize/_style.py b/pypesto/visualize/_style.py index 5f341feb2..4475f7e53 100644 --- a/pypesto/visualize/_style.py +++ b/pypesto/visualize/_style.py @@ -8,11 +8,24 @@ :data:`_DEFAULTS`:: waterfall(result, style_kwargs={"mle_color": "tab:purple"}) + +Keys are named after the **visual element**, not the plot using them: +``line_*`` (any line), ``dash_*`` (any tick / rug marker), ``rectangle_*`` +(fills), ``bound_*`` (parameter bounds). Plot-specific keys (e.g. +``trace_linewidth``) only when a default genuinely diverges. + +TODO (capstone): trim the naming note above once the series has settled. """ from __future__ import annotations import warnings +from typing import Literal + +import matplotlib as mpl +import matplotlib.axes +import numpy as np +from matplotlib.lines import Line2D # Colors — semantic roles # ----------------------- @@ -23,6 +36,40 @@ # --------- CMAP_DISCRETE = "tab10" # qualitative: cluster + per-variable colours +# Lines (KDE curves, simulation / model-fit lines, …) +# --------------------------------------------------- +LINE_COLOR = "#145685" +LINEWIDTH = 1.5 + +# Dash markers (rug ticks, CI endpoints, …) +# ----------------------------------------- +DASH_COLOR = "#174261" +DASH_LINEWIDTH = 1.2 # markeredgewidth +DASH_MARKERSIZE = 10 # marker length +DASH_ALPHA = 0.8 + +# Rectangle / histogram fills +# --------------------------- +RECTANGLE_COLOR = "#3182bd" +RECTANGLE_EDGECOLOR = "#000000" +RECTANGLE_LINEWIDTH = 1.0 +RECTANGLE_ALPHA = 0.6 + +# Grid sizing — per-panel inches for multi-panel grids (size=None default) +# ------------------------------------------------------------------------ +GRID_SIZE_PER_COL = 3.5 +GRID_SIZE_PER_ROW = 2.5 + +# Parameter bounds +# ---------------- +BOUND_LINESTYLE = "--" +BOUND_COLOR = "0.5" +BOUND_LINEWIDTH = 1.4 +BOUND_ALPHA = 0.95 +BOUND_VIEW_MARGIN = ( + 0.03 # axis-limit padding so bound lines aren't flush with the spine +) + # Style registry # -------------- @@ -30,6 +77,20 @@ "mle_color": MLE_COLOR, "outlier_color": OUTLIER_COLOR, "cmap_discrete": CMAP_DISCRETE, + "line_color": LINE_COLOR, + "linewidth": LINEWIDTH, + "dash_color": DASH_COLOR, + "dash_linewidth": DASH_LINEWIDTH, + "dash_markersize": DASH_MARKERSIZE, + "dash_alpha": DASH_ALPHA, + "rectangle_color": RECTANGLE_COLOR, + "rectangle_edgecolor": RECTANGLE_EDGECOLOR, + "rectangle_linewidth": RECTANGLE_LINEWIDTH, + "rectangle_alpha": RECTANGLE_ALPHA, + "bound_color": BOUND_COLOR, + "bound_linestyle": BOUND_LINESTYLE, + "bound_linewidth": BOUND_LINEWIDTH, + "bound_alpha": BOUND_ALPHA, } @@ -60,3 +121,110 @@ def resolve_style(style_kwargs: dict | None = None) -> dict: ) style.update(style_kwargs) return style + + +# rcParams preset, not default, opt-in via ``apply_style()`` +# --------------- + + +def apply_style() -> None: + """Apply pyPESTO's recommended matplotlib rcParams. + + Sets larger axis/tick labels, removes top/right spines globally, + styles legends (auto-placed, framed, lightly translucent fill), and enables + ``constrained_layout`` for sensible panel spacing. + + Opt-in: not called automatically. Users (and pyPESTO's example + notebooks/docs) call this once at the top of a session. + """ + mpl.rcParams.update( + { + "axes.labelsize": 13, + "axes.labelweight": mpl.rcParamsDefault["axes.labelweight"], + "axes.titlesize": 14, + "axes.titleweight": "bold", + "xtick.labelsize": 11, + "ytick.labelsize": 11, + "legend.fontsize": mpl.rcParamsDefault["legend.fontsize"], + # Legends: auto-placed, framed, and lightly translucent so text reads + # clearly without making the legend feel heavy. + "legend.loc": "best", + "legend.frameon": True, + "legend.framealpha": 0.6, + "legend.edgecolor": "0.7", + "axes.spines.top": False, + "axes.spines.right": False, + "axes.grid": False, + "figure.constrained_layout.use": True, + } + ) + + +# Bound-line helpers +# ------------------ + + +def _bounds_legend_handle( + label: str = "Bounds", style: dict | None = None +) -> Line2D: + """Return a Line2D matching the bound style suitable as a legend handle.""" + s = style or {} + return Line2D( + [0], + [0], + color=s.get("bound_color", BOUND_COLOR), + linestyle=s.get("bound_linestyle", BOUND_LINESTYLE), + linewidth=s.get("bound_linewidth", BOUND_LINEWIDTH), + alpha=s.get("bound_alpha", BOUND_ALPHA), + label=label, + ) + + +def draw_bounds_1d( + ax: matplotlib.axes.Axes, + lb: float, + ub: float, + *, + axis: Literal["x", "y"] = "x", + view_margin: bool = True, + style: dict | None = None, +) -> Line2D: + """Draw the canonical pyPESTO parameter-bound lines on *ax*. + + ``axis="x"`` draws two vertical dashed lines (``axvline``) at *lb* and + *ub*; ``axis="y"`` draws two horizontal dashed lines (``axhline``). + + When *view_margin* is true the corresponding axis limits are extended by + :data:`BOUND_VIEW_MARGIN` * (ub - lb) so the bound lines are visible + rather than flush with the spine. + + Returns a :class:`~matplotlib.lines.Line2D` that can be passed as a + legend handle (the lines drawn on the axis are not labeled to keep the + automatic legend clean). + """ + if axis not in ("x", "y"): + raise ValueError(f"axis must be 'x' or 'y', got {axis!r}") + s = style or {} + color = s.get("bound_color", BOUND_COLOR) + linestyle = s.get("bound_linestyle", BOUND_LINESTYLE) + linewidth = s.get("bound_linewidth", BOUND_LINEWIDTH) + alpha = s.get("bound_alpha", BOUND_ALPHA) + drawer = ax.axvline if axis == "x" else ax.axhline + for bound in (lb, ub): + drawer( + bound, + color=color, + linestyle=linestyle, + linewidth=linewidth, + alpha=alpha, + zorder=1, + ) + if view_margin and np.isfinite(lb) and np.isfinite(ub) and ub > lb: + margin = BOUND_VIEW_MARGIN * (ub - lb) + if axis == "x": + cur_lo, cur_hi = ax.get_xlim() + ax.set_xlim(min(cur_lo, lb - margin), max(cur_hi, ub + margin)) + else: + cur_lo, cur_hi = ax.get_ylim() + ax.set_ylim(min(cur_lo, lb - margin), max(cur_hi, ub + margin)) + return _bounds_legend_handle(style=s) diff --git a/pypesto/visualize/misc.py b/pypesto/visualize/misc.py index 661aed86b..3eebf493f 100644 --- a/pypesto/visualize/misc.py +++ b/pypesto/visualize/misc.py @@ -26,6 +26,12 @@ ) from ..result import Result from ..util import assign_clusters, delete_nan_inf +from ._style import ( + GRID_SIZE_PER_COL, + GRID_SIZE_PER_ROW, + draw_bounds_1d, + resolve_style, +) from .clust_color import assign_colors_for_list logger = logging.getLogger(__name__) @@ -491,7 +497,9 @@ def get_axes_array( Expected grid shape. size: Figure size ``(width, height)`` in inches; only used when ``axes`` - is None. + is None. When ``None`` a single panel uses matplotlib's default + figure size, and a multi-panel grid uses + ``(GRID_SIZE_PER_COL * ncols, GRID_SIZE_PER_ROW * nrows)``. Returns ------- @@ -499,6 +507,8 @@ def get_axes_array( A 2-D NumPy object array containing matplotlib Axes. """ if axes is None: + if size is None and nrows * ncols > 1: + size = (GRID_SIZE_PER_COL * ncols, GRID_SIZE_PER_ROW * nrows) _, axes = plt.subplots( nrows, ncols, @@ -630,6 +640,140 @@ def plot_diagonal_marginal( ax.set_ylabel("Count") +def plot_density_panel( + ax: matplotlib.axes.Axes, + values: np.ndarray, + bins: int | str = "auto", + bw_method: str = "scott", + style: dict | None = None, + *, + show_hist: bool = True, + show_kde: bool = True, + show_rug: bool = True, + show_bounds: bool = False, + lb: float | None = None, + ub: float | None = None, +): + """Draw a density panel: histogram, KDE overlay, rug marks, and bounds. + + Element styling is read from the resolved *style* dict: + + - histogram bars: ``rectangle_*`` + - KDE line: ``line_color`` / ``linewidth`` + - rug marks: ``dash_color`` / ``dash_linewidth`` / ``dash_markersize`` / + ``dash_alpha`` + - parameter-bound lines: ``bound_*`` (drawn only when ``show_bounds`` is + ``True`` and ``lb`` / ``ub`` are finite) + + Sets x-axis limits to the data range with a 5% margin; if bounds are + drawn, the limits are extended to include them (never shrunk). + + Parameters + ---------- + ax: + Axes to draw into. + values: + 1-D array of data values. + bins: + Histogram bins — passed directly to :func:`matplotlib.axes.Axes.hist`. + bw_method: + Bandwidth method for :class:`scipy.stats.gaussian_kde`. + style: + Pre-resolved visualization style dict, as returned by + :func:`pypesto.visualize._style.resolve_style`. When ``None``, defaults + are used. + show_hist: + Whether to draw the histogram bars. + show_kde: + Whether to draw the KDE line overlay. + show_rug: + Whether to draw rug marks along the x-axis. + show_bounds: + Whether to draw parameter-bound lines via + :func:`pypesto.visualize._style.draw_bounds_1d`. Requires ``lb`` and + ``ub`` to be passed and finite; silently skipped otherwise. + lb, ub: + Lower and upper parameter bounds. Used only when ``show_bounds`` is + ``True``. + + Returns + ------- + The bound legend handle if bounds were drawn, otherwise ``None``. + Callers wire this into their per-panel legend. + """ + from scipy.stats import gaussian_kde + + style = style if style is not None else resolve_style(None) + + values = np.asarray(values, dtype=float) + values = values[np.isfinite(values)] + if values.size == 0: + return + + if show_hist: + ax.hist( + values, + bins=bins, + density=True, + color=style["rectangle_color"], + alpha=style["rectangle_alpha"], + edgecolor=style["rectangle_edgecolor"], + linewidth=style["rectangle_linewidth"], + ) + + if show_kde and len(values) > 1 and np.std(values) > 0: + try: + kde = gaussian_kde(values, bw_method=bw_method) + # Extend 3 bandwidths beyond the data so the curve tapers to zero. + bw = kde.factor * np.std(values, ddof=1) + x_grid = np.linspace( + values.min() - 3 * bw, values.max() + 3 * bw, 300 + ) + ax.plot( + x_grid, + kde(x_grid), + color=style["line_color"], + linewidth=style["linewidth"], + ) + except np.linalg.LinAlgError: + pass + + if show_rug: + ax.plot( + values, + np.zeros(len(values)), + marker="|", + linestyle="none", + color=style["dash_color"], + alpha=style["dash_alpha"], + markersize=style["dash_markersize"], + markeredgewidth=style["dash_linewidth"], + transform=ax.get_xaxis_transform(), + clip_on=False, + zorder=5, + ) + + # Frame the panel tightly to the data (5% margin around the range, or a + # sensible fallback when the data is constant). + data_min = float(values.min()) + data_max = float(values.max()) + spread = data_max - data_min + margin = spread * 0.05 if spread > 0 else max(abs(data_min) * 0.05, 1.0) + ax.set_xlim(data_min - margin, data_max + margin) + + # Draw parameter bounds (xlim is extended by draw_bounds_1d when bounds + # fall outside the data frame; never shrunk). + if ( + show_bounds + and lb is not None + and ub is not None + and np.isfinite(lb) + and np.isfinite(ub) + ): + return draw_bounds_1d(ax, lb, ub, axis="x", style=style) + return None + + #: Sentinel meaning "this kwarg was not passed at all." #: Use as the default for deprecated kwargs so that an explicit #: ``f(old_kwarg=None)`` can be detected and warned about. @@ -637,29 +781,38 @@ def plot_diagonal_marginal( def process_deprecated_kwarg( - canonical_name: str, + canonical_name: str | None, canonical_value, deprecated_name: str, deprecated_value=_UNSET, stacklevel: int = 3, + note: str | None = None, ): """ - Resolve a kwarg that has been renamed. + Resolve a kwarg that has been renamed or removed. The deprecated kwarg must use :data:`_UNSET` as its default in the calling function so that an explicit ``f(old_kwarg=None)`` is correctly detected and warned about. - Returns the canonical value if the deprecated kwarg was not passed, - the deprecated value (with a ``DeprecationWarning``) if only the old - name was used, or raises ``ValueError`` if both are given. + Two modes: + + - **Rename** (``canonical_name`` is given): returns the canonical value + if the deprecated kwarg was not passed, the deprecated value (with a + ``DeprecationWarning``) if only the old name was used, or raises + ``ValueError`` if both are given. + - **Removal** (``canonical_name`` is ``None``): emits a + ``DeprecationWarning`` if the deprecated kwarg was passed and returns + ``None``. ``canonical_value`` is ignored. Parameters ---------- canonical_name: - Name of the canonical (new) kwarg, used in messages. + Name of the canonical (new) kwarg, used in messages. Pass ``None`` + to indicate the kwarg is removed without replacement. canonical_value: - Value passed under the canonical name (or ``None``). + Value passed under the canonical name (or ``None``). Ignored when + ``canonical_name`` is ``None``. deprecated_name: Name of the deprecated (old) kwarg, used in messages. deprecated_value: @@ -668,12 +821,26 @@ def process_deprecated_kwarg( Forwarded to :func:`warnings.warn`. Default 3 attributes the warning to the caller of the public function that invoked this helper. + note: + Optional additional sentence appended to the deprecation message + (e.g. explaining where the behaviour moved to). Useful for the + removal case. Returns ------- value: - The resolved value, or ``None`` if neither was given. + The resolved value, or ``None`` if neither was given (or in the + removal case). """ + if canonical_name is None: + if deprecated_value is _UNSET: + return None + message = f"`{deprecated_name}` is deprecated and has no effect." + if note: + message = f"{message} {note}" + warnings.warn(message, DeprecationWarning, stacklevel=stacklevel) + return None + if deprecated_value is _UNSET: return canonical_value if canonical_value is not None: @@ -681,9 +848,10 @@ def process_deprecated_kwarg( f"Pass either `{canonical_name}` or the deprecated " f"`{deprecated_name}`, not both." ) - warnings.warn( - f"`{deprecated_name}` is deprecated; use `{canonical_name}` instead.", - DeprecationWarning, - stacklevel=stacklevel, + message = ( + f"`{deprecated_name}` is deprecated; use `{canonical_name}` instead." ) + if note: + message = f"{message} {note}" + warnings.warn(message, DeprecationWarning, stacklevel=stacklevel) return deprecated_value diff --git a/pypesto/visualize/parameters.py b/pypesto/visualize/parameters.py index 470531310..9c1d7e88e 100644 --- a/pypesto/visualize/parameters.py +++ b/pypesto/visualize/parameters.py @@ -5,6 +5,7 @@ import numpy as np import pandas as pd from matplotlib.colors import Colormap +from matplotlib.lines import Line2D from matplotlib.ticker import MaxNLocator from pypesto.util import delete_nan_inf @@ -20,9 +21,12 @@ from ._style import resolve_style from .clust_color import assign_colors from .misc import ( + _UNSET, get_ax, get_axes_array, + plot_density_panel, plot_diagonal_marginal, + process_deprecated_kwarg, process_parameter_indices, process_result_list, process_start_indices, @@ -215,39 +219,78 @@ def scale_parameters(x): def parameter_hist( result: Result, parameter_name: str, + start_indices: int | list[int] | None = None, + plot_type: str = "both", bins: int | str = "auto", + bw_method: str = "scott", + show_bounds: bool = True, + title: str | None = "Parameter histogram", + size: tuple[float, float] | None = None, ax: matplotlib.axes.Axes | None = None, - size: tuple[float, float] | None = (18.5, 10.5), - color: COLOR | None = None, - start_indices: int | list[int] | None = None, + style_kwargs: dict | None = None, + color: COLOR = _UNSET, ) -> matplotlib.axes.Axes: """ - Plot parameter values as a histogram. + Plot one parameter's values across starts as a histogram + KDE + rug. Parameters ---------- result: - Optimization result obtained by 'optimize.py' + Optimization result obtained by 'optimize.py'. parameter_name: - The name of the parameter that should be plotted + Name of the parameter to plot. + start_indices: + Which optimization starts to include: a list of indices, or an int + ``n`` for the first ``n`` starts. Default: all starts. + plot_type: {'hist'|'kde'|'both'} + Histogram only, KDE line only, or both with rug marks (default). bins: - Specifies bins of the histogram - ax: - Axes object to use + Number of bins, or a matplotlib binning strategy (``'auto'``, + ``'sturges'``, …). Passed to ``ax.hist``. + bw_method: {'scott', 'silverman' | scalar | pair of scalars} + Kernel bandwidth method for the KDE overlay. + show_bounds: + If ``True`` (default) draw the parameter bound lines and frame the + x-axis to include them; if ``False`` frame tightly to the data. + title: + Axes title. Pass ``None`` to suppress. size: - Figure size (width, height) in inches. Is only applied when no ax - object is specified + Figure size in inches. Defaults to matplotlib's default. + ax: + Axes object to use. + style_kwargs: + Style overrides. Keys used by this function: + + - ``rectangle_color``, ``rectangle_alpha``, ``rectangle_edgecolor``, + ``rectangle_linewidth`` — histogram bar styling. + - ``line_color``, ``linewidth`` — KDE curve styling. + - ``dash_color``, ``dash_linewidth``, ``dash_markersize``, + ``dash_alpha`` — rug-mark styling. + - ``bound_color``, ``bound_linestyle``, ``bound_linewidth``, + ``bound_alpha`` — parameter-bound line styling. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. color: - Color recognized by matplotlib. - start_indices: - List of integers specifying the multistarts to be plotted or - int specifying up to which start index should be plotted + Deprecated. Pass ``style_kwargs`` instead — see + ``rectangle_color`` / ``line_color`` / ``dash_color`` above. Returns ------- ax: The plot axes. """ + process_deprecated_kwarg( + canonical_name=None, + canonical_value=None, + deprecated_name="color", + deprecated_value=color, + note=( + "Pass style_kwargs={'rectangle_color': ..., 'line_color': ..., " + "'dash_color': ...} instead." + ), + ) + style = resolve_style(style_kwargs) ax = get_ax(ax, size) xs = result.optimize_result.x @@ -259,12 +302,68 @@ def parameter_hist( xs = [xs[ind] for ind in start_indices] parameter_index = result.problem.x_names.index(parameter_name) - parameter_values = [x[parameter_index] for x in xs] + parameter_values = np.array([x[parameter_index] for x in xs]) + + # bounds and scale for this parameter + lb_val = result.problem.lb_full[parameter_index] + ub_val = result.problem.ub_full[parameter_index] + x_scales = getattr(result.problem, "x_scales", None) + scale = x_scales[parameter_index] if x_scales is not None else None + + bound_handle = plot_density_panel( + ax, + parameter_values, + bins=bins, + bw_method=bw_method, + style=style, + show_hist=(plot_type in ("hist", "both")), + show_kde=(plot_type in ("kde", "both")), + show_rug=(plot_type in ("hist", "both")), + show_bounds=show_bounds, + lb=lb_val, + ub=ub_val, + ) - ax.hist(parameter_values, color=color, bins=bins, label=parameter_name) - ax.set_xlabel(parameter_name) - ax.set_ylabel("counts") - ax.set_title(f"{parameter_name}") + legend_handles, legend_labels = [], [] + show_kde = plot_type in ("kde", "both") + show_rug = plot_type in ("hist", "both") + finite_vals = parameter_values[np.isfinite(parameter_values)] + if finite_vals.size > 0: + if show_kde: + legend_handles.append( + Line2D( + [0], [0], color=style["line_color"], lw=style["linewidth"] + ) + ) + legend_labels.append("KDE") + if show_rug: + legend_handles.append( + Line2D( + [0], + [0], + color=style["dash_color"], + marker="|", + lw=0, + markersize=style["dash_markersize"], + markeredgewidth=style["dash_linewidth"], + ) + ) + legend_labels.append("Starts") + + if bound_handle is not None: + legend_handles.append(bound_handle) + legend_labels.append("Bounds") + + if legend_handles: + ax.legend(handles=legend_handles, labels=legend_labels) + + xlabel = ( + f"{parameter_name} ({scale})" if scale is not None else parameter_name + ) + if title is not None: + ax.set_title(title) + ax.set_xlabel(xlabel) + ax.set_ylabel("Density") return ax diff --git a/pypesto/visualize/sampling.py b/pypesto/visualize/sampling.py index bb520d352..fb7ab3c37 100644 --- a/pypesto/visualize/sampling.py +++ b/pypesto/visualize/sampling.py @@ -25,12 +25,14 @@ from ..ensemble import EnsemblePrediction, get_percentile_label from ..result import McmcPtResult, PredictionResult, Result from ..sample import calculate_ci_mcmc_sample +from ._style import resolve_style from .misc import ( _UNSET, get_ax, get_axes_array, hide_unused_axes, make_grid_shape, + plot_density_panel, plot_diagonal_marginal, process_deprecated_kwarg, rgba2rgb, @@ -1308,48 +1310,78 @@ def sampling_scatter( def sampling_1d_marginals( result: Result, i_chain: int = 0, - parameter_indices: Sequence[int] = None, + parameter_indices: Sequence[int] | None = None, stepsize: int = 1, plot_type: str = "both", + bins: int | str = "auto", bw_method: str = "scott", - suptitle: str | None = None, + show_bounds: bool = True, + title: str | None = None, size: tuple[float, float] | None = None, axes: np.ndarray | None = None, + style_kwargs: dict | None = None, par_indices: Sequence[int] = _UNSET, + suptitle: str | None = _UNSET, ) -> np.ndarray: """ - Plot marginals. + Plot 1-D marginals of the sampled parameters as histogram + KDE + rug. Parameters ---------- result: The pyPESTO result object with filled sample result. i_chain: - Which chain to plot. Default: First chain. - parameter_indices: list of integer values - List of integer values specifying which parameters to plot. - Default: All parameters are shown. + Which chain to plot. Default: first chain. + parameter_indices: + Which parameters to plot, as a list of indices. Default: all parameters. stepsize: - Only one in `stepsize` values is plotted. + Thinning factor — plot every ``stepsize``-th sample (``1`` = all). + Reduces overplotting and speeds up rendering for long chains. plot_type: {'hist'|'kde'|'both'} - Specify whether to plot a histogram ('hist'), a kernel density estimate - ('kde'), or both ('both'). + Histogram only, KDE line only, or both with rug marks (default). + bins: + Number of bins, or a matplotlib binning strategy (``'auto'``, + ``'sturges'``, …). Passed to ``ax.hist``. bw_method: {'scott', 'silverman' | scalar | pair of scalars} - Kernel bandwidth method. - suptitle: - Figure super title. + Kernel bandwidth method for the KDE overlay. + show_bounds: + If ``True`` (default) draw the parameter bound lines and frame each + panel's x-axis to include them; if ``False`` frame each panel tightly + to its data. + title: + Figure title. Default: none (grids omit a title by default). size: - Figure size in inches. + Figure size in inches. When ``None`` the grid uses + ``GRID_SIZE_PER_COL * num_col`` × ``GRID_SIZE_PER_ROW * num_row`` + (defaults from :mod:`pypesto.visualize._style`). axes: Axes grid to use. Must match the computed subplot layout. + style_kwargs: + Style overrides. Keys used by this function: + + - ``rectangle_color``, ``rectangle_alpha``, ``rectangle_edgecolor``, + ``rectangle_linewidth`` — histogram bar styling. + - ``line_color``, ``linewidth`` — KDE curve styling. + - ``dash_color``, ``dash_linewidth``, ``dash_markersize``, + ``dash_alpha`` — rug-mark styling. + - ``bound_color``, ``bound_linestyle``, ``bound_linewidth``, + ``bound_alpha`` — parameter-bound line styling. + + All valid keys and their defaults are listed in + :data:`pypesto.visualize._style._DEFAULTS`. par_indices: Deprecated. Use ``parameter_indices`` instead. + suptitle: + Deprecated. Use ``title`` instead. - Return - -------- + Returns + ------- axes: 2-D NumPy array containing one matplotlib Axes per panel. """ + style = resolve_style(style_kwargs) + title = process_deprecated_kwarg("title", title, "suptitle", suptitle) + parameter_indices = process_deprecated_kwarg( "parameter_indices", parameter_indices, @@ -1357,8 +1389,6 @@ def sampling_1d_marginals( par_indices, ) - import seaborn as sns - # get data which should be plotted nr_params, params_fval, theta_lb, theta_ub, param_names = get_data_to_plot( result=result, @@ -1368,43 +1398,93 @@ def sampling_1d_marginals( ) num_row, num_col = make_grid_shape(nr_params) - if size is None and axes is None: - size = (3.5 * num_col, 2.5 * num_row) axes = get_axes_array(axes=axes, nrows=num_row, ncols=num_col, size=size) fig = axes.flat[0].figure axes = hide_unused_axes(axes=axes, n_used=nr_params, clear=True) - par_ax = dict(zip(param_names, axes.flat, strict=True)) + par_ax = dict(zip(param_names, axes.flat[:nr_params], strict=True)) + + # Build name→index map for looking up per-parameter lb/ub/scale. + all_reduced_names = result.problem.get_reduced_vector( + result.problem.x_names + ) + name_to_reduced_idx = {name: i for i, name in enumerate(all_reduced_names)} + x_scales_reduced = ( + result.problem.get_reduced_vector(result.problem.x_scales) + if getattr(result.problem, "x_scales", None) is not None + else None + ) + + _show_kde = plot_type in ("kde", "both") + _show_rug = plot_type in ("hist", "both") - # fig, ax = plt.subplots(nr_params, figsize=size)[1] for idx, par_id in enumerate(param_names): - if plot_type == "kde": - # TODO: add bw_adjust as option? - sns.kdeplot( - params_fval[par_id], bw_method=bw_method, ax=par_ax[par_id] - ) - elif plot_type == "hist": - # fixes usage of sns distplot which throws a future warning - sns.histplot( - x=params_fval[par_id], ax=par_ax[par_id], stat="density" - ) - sns.rugplot(x=params_fval[par_id], ax=par_ax[par_id]) - elif plot_type == "both": - sns.histplot( - x=params_fval[par_id], - kde=True, - ax=par_ax[par_id], - stat="density", - ) - sns.rugplot(x=params_fval[par_id], ax=par_ax[par_id]) + ax = par_ax[par_id] + vals = np.asarray(params_fval[par_id]) + finite_vals = vals[np.isfinite(vals)] + par_reduced_idx = name_to_reduced_idx.get(par_id, idx) + lb_val = theta_lb[par_reduced_idx] + ub_val = theta_ub[par_reduced_idx] + + bound_handle = plot_density_panel( + ax, + vals, + bins=bins, + bw_method=bw_method, + style=style, + show_hist=(plot_type in ("hist", "both")), + show_kde=_show_kde, + show_rug=_show_rug, + show_bounds=show_bounds, + lb=lb_val, + ub=ub_val, + ) - par_ax[par_id].set_xlabel(param_names[idx]) - par_ax[par_id].set_ylabel("Density") + legend_handles, legend_labels = [], [] + if finite_vals.size > 0 and idx == 0: + if _show_kde: + legend_handles.append( + Line2D( + [0], + [0], + color=style["line_color"], + lw=style["linewidth"], + ) + ) + legend_labels.append("KDE") + if _show_rug: + legend_handles.append( + Line2D( + [0], + [0], + color=style["dash_color"], + marker="|", + lw=0, + markersize=style["dash_markersize"], + markeredgewidth=style["dash_linewidth"], + ) + ) + legend_labels.append("Samples") - sns.despine() + if bound_handle is not None and idx == 0: + legend_handles.append(bound_handle) + legend_labels.append("Bounds") - if suptitle: - fig.suptitle(suptitle) + if legend_handles: + ax.legend(handles=legend_handles, labels=legend_labels) + + scale = ( + x_scales_reduced[par_reduced_idx] + if x_scales_reduced is not None + else None + ) + xlabel = f"{par_id} ({scale})" if scale is not None else par_id + ax.set_xlabel(xlabel) + # y-label only on the leftmost column to avoid grid-wide repetition + ax.set_ylabel("Density" if idx % num_col == 0 else "") + + if title is not None: + fig.suptitle(title) return axes