From 6cfc6bdff1d244eded3678763a0b06e474c24217 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Mon, 18 May 2026 15:46:42 +0200 Subject: [PATCH 1/9] Fix PyMC posterior extraction for newer ArviZ --- pypesto/sample/pymc.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 3b1de1933..0d9459203 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -251,10 +251,13 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" + posterior = self.data.posterior + if not hasattr(posterior, "to_array"): + # Newer ArviZ/xarray returns a DataTree group here. + posterior = posterior.dataset + # dimensions - n_par, n_chain, n_iter = np.asarray( - self.data.posterior.to_array() - ).shape + n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape n_par -= 1 # remove log-posterior # parameters @@ -263,10 +266,10 @@ def get_samples(self) -> McmcPtResult: if len(par_ids) != n_par: raise AssertionError("Mismatch of parameter dimension") for i_par, par_id in enumerate(par_ids): - trace_x[:, :, i_par] = np.asarray(self.data.posterior[par_id]) + trace_x[:, :, i_par] = np.asarray(posterior[par_id]) # function values - trace_neglogpost = -np.asarray(self.data.posterior["loggyposty"]) + trace_neglogpost = -np.asarray(posterior["loggyposty"]) if ( trace_x.shape[0] != trace_neglogpost.shape[0] From 9ddcb73180b34501a008d8502faf2a208ac3e396 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Mon, 18 May 2026 16:12:58 +0200 Subject: [PATCH 2/9] Fix PyMC outputs with newer ArviZ --- pypesto/variational/pymc.py | 17 ++++++++++++----- 1 file changed, 12 insertions(+), 5 deletions(-) diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index d8b4cf711..3b0947b4b 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -89,7 +89,7 @@ def fit( } # create model context - with pymc.Model(): + with pymc.Model() as model: # parameter bounds as uniform prior _k = [ pymc.Uniform(x_name, lower=lb, upper=ub) @@ -119,6 +119,7 @@ def fit( ) self.data = data + self.model = model def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: """ @@ -130,16 +131,22 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: Number of samples to be computed. """ # get InferenceData object - pymc_data = self.data.sample(n_samples) + with self.model: + pymc_data = self.data.sample(n_samples) + posterior = pymc_data.posterior + if not hasattr(posterior, "to_array"): + # Newer ArviZ/xarray returns a DataTree group here. + posterior = posterior.dataset + x_names_free = self.problem.get_reduced_vector(self.problem.x_names) post_samples = np.concatenate( - [pymc_data.posterior[name].values for name in x_names_free] + [posterior[name].values for name in x_names_free] ).T return McmcPtResult( trace_x=post_samples[np.newaxis, :], - trace_neglogpost=pymc_data.posterior.loggyposty.values, + trace_neglogpost=posterior.loggyposty.values, trace_neglogprior=np.full( - pymc_data.posterior.loggyposty.values.shape, np.nan + posterior.loggyposty.values.shape, np.nan ), betas=np.array([1.0] * post_samples.shape[0]), burn_in=0, From 4217d4e1d33544b36daa80df4f6ce57caa9c39c4 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Wed, 20 May 2026 15:11:50 +0200 Subject: [PATCH 3/9] add comment to change when bumped --- pypesto/sample/pymc.py | 4 +++- pypesto/variational/pymc.py | 4 +++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 0d9459203..742f7577e 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -252,8 +252,10 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" posterior = self.data.posterior + # ArviZ may expose posterior either as an xarray Dataset (older) or + # as a DataTree group (newer). If the minimum ArviZ version is bumped + # to DataTree-only, simplify this to `self.data.posterior.dataset`. if not hasattr(posterior, "to_array"): - # Newer ArviZ/xarray returns a DataTree group here. posterior = posterior.dataset # dimensions diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index 3b0947b4b..4c41fd0b7 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -134,8 +134,10 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: with self.model: pymc_data = self.data.sample(n_samples) posterior = pymc_data.posterior + # ArviZ may expose posterior either as an xarray Dataset (older) or + # as a DataTree group (newer). If the minimum ArviZ version is bumped + # to DataTree-only, simplify this to `pymc_data.posterior.dataset`. if not hasattr(posterior, "to_array"): - # Newer ArviZ/xarray returns a DataTree group here. posterior = posterior.dataset x_names_free = self.problem.get_reduced_vector(self.problem.x_names) From 8964bd35078fd8ad9c49b147eca5241093f81d3a Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Fri, 22 May 2026 15:39:45 +0200 Subject: [PATCH 4/9] Bump arviz>=0.21 and drop old code --- pypesto/sample/pymc.py | 7 +------ pypesto/variational/pymc.py | 7 +------ pyproject.toml | 2 +- 3 files changed, 3 insertions(+), 13 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 742f7577e..6fcc2ede3 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -251,12 +251,7 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" - posterior = self.data.posterior - # ArviZ may expose posterior either as an xarray Dataset (older) or - # as a DataTree group (newer). If the minimum ArviZ version is bumped - # to DataTree-only, simplify this to `self.data.posterior.dataset`. - if not hasattr(posterior, "to_array"): - posterior = posterior.dataset + posterior = self.data.posterior.dataset # dimensions n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index 4c41fd0b7..527cda008 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -133,12 +133,7 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: # get InferenceData object with self.model: pymc_data = self.data.sample(n_samples) - posterior = pymc_data.posterior - # ArviZ may expose posterior either as an xarray Dataset (older) or - # as a DataTree group (newer). If the minimum ArviZ version is bumped - # to DataTree-only, simplify this to `pymc_data.posterior.dataset`. - if not hasattr(posterior, "to_array"): - posterior = posterior.dataset + posterior = pymc_data.posterior.dataset x_names_free = self.problem.get_reduced_vector(self.problem.x_names) post_samples = np.concatenate( diff --git a/pyproject.toml b/pyproject.toml index 0e336b3be..4a5ba4235 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ mpi = [ ] pymc = [ - "arviz>=0.12.1", + "arviz>=0.21", "pymc>=4.2.1", ] From 7a1e849dd9baa890f453a328a22cdb2a5a26c27c Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Fri, 22 May 2026 15:57:34 +0200 Subject: [PATCH 5/9] Fix: drop incorrect .dataset access --- pypesto/sample/pymc.py | 10 +++++----- pypesto/variational/pymc.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 6fcc2ede3..3b1de1933 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -251,10 +251,10 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" - posterior = self.data.posterior.dataset - # dimensions - n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape + n_par, n_chain, n_iter = np.asarray( + self.data.posterior.to_array() + ).shape n_par -= 1 # remove log-posterior # parameters @@ -263,10 +263,10 @@ def get_samples(self) -> McmcPtResult: if len(par_ids) != n_par: raise AssertionError("Mismatch of parameter dimension") for i_par, par_id in enumerate(par_ids): - trace_x[:, :, i_par] = np.asarray(posterior[par_id]) + trace_x[:, :, i_par] = np.asarray(self.data.posterior[par_id]) # function values - trace_neglogpost = -np.asarray(posterior["loggyposty"]) + trace_neglogpost = -np.asarray(self.data.posterior["loggyposty"]) if ( trace_x.shape[0] != trace_neglogpost.shape[0] diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index 527cda008..cc91233fd 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -133,7 +133,7 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: # get InferenceData object with self.model: pymc_data = self.data.sample(n_samples) - posterior = pymc_data.posterior.dataset + posterior = pymc_data.posterior x_names_free = self.problem.get_reduced_vector(self.problem.x_names) post_samples = np.concatenate( From ae13c6b2d074868e2e786ff966cee58621de5c04 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Tue, 26 May 2026 15:39:45 +0200 Subject: [PATCH 6/9] require arviz 1.1.0 -- that's the correct one --- pypesto/sample/pymc.py | 24 +++++++++++------------- pyproject.toml | 2 +- 2 files changed, 12 insertions(+), 14 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 3b1de1933..f536c2730 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -4,6 +4,7 @@ import importlib import logging +from typing import TYPE_CHECKING import numpy as np @@ -15,10 +16,12 @@ logger = logging.getLogger(__name__) -# Lazy import of pymc, arviz, and pytensor +# Lazy import of pymc and pytensor # Check availability once at module load time _HAS_PYMC = importlib.util.find_spec("pymc") is not None -_HAS_ARVIZ = importlib.util.find_spec("arviz") is not None + +if TYPE_CHECKING: + import xarray as xr if _HAS_PYMC: import pymc @@ -30,11 +33,6 @@ pt = None _PT_OP_BASE = object -if _HAS_ARVIZ: - import arviz as az -else: - az = None - # implementation based on: # https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html @@ -151,7 +149,7 @@ def __init__( self.problem: Problem | None = None self.x0: np.ndarray | None = None self.trace: pymc.backends.Text | None = None - self.data: az.InferenceData | None = None + self.data: xr.DataTree | None = None @classmethod def translate_options(cls, options): @@ -251,10 +249,10 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" + posterior = self.data.posterior + # dimensions - n_par, n_chain, n_iter = np.asarray( - self.data.posterior.to_array() - ).shape + n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape n_par -= 1 # remove log-posterior # parameters @@ -263,10 +261,10 @@ def get_samples(self) -> McmcPtResult: if len(par_ids) != n_par: raise AssertionError("Mismatch of parameter dimension") for i_par, par_id in enumerate(par_ids): - trace_x[:, :, i_par] = np.asarray(self.data.posterior[par_id]) + trace_x[:, :, i_par] = np.asarray(posterior[par_id]) # function values - trace_neglogpost = -np.asarray(self.data.posterior["loggyposty"]) + trace_neglogpost = -np.asarray(posterior["loggyposty"]) if ( trace_x.shape[0] != trace_neglogpost.shape[0] diff --git a/pyproject.toml b/pyproject.toml index 4a5ba4235..12e40c20e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,7 @@ mpi = [ ] pymc = [ - "arviz>=0.21", + "arviz>=1.1.0", "pymc>=4.2.1", ] From 5bac1cac22b7c483ca19a22c8d94c9adf2903c3d Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Tue, 26 May 2026 15:55:07 +0200 Subject: [PATCH 7/9] Support PyMC posterior extraction across ArviZ versions - Use Python-version-specific ArviZ requirements so Python 3.11 installs ArviZ 0.x while Python 3.12+ uses ArviZ 1.1+ - Keep PyMC posterior extraction on the common posterior interface shared by old and new ArviZ result objects --- pypesto/sample/pymc.py | 7 ++----- pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 6 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index f536c2730..971d54fea 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -4,7 +4,7 @@ import importlib import logging -from typing import TYPE_CHECKING +from typing import Any import numpy as np @@ -20,9 +20,6 @@ # Check availability once at module load time _HAS_PYMC = importlib.util.find_spec("pymc") is not None -if TYPE_CHECKING: - import xarray as xr - if _HAS_PYMC: import pymc import pytensor.tensor as pt @@ -149,7 +146,7 @@ def __init__( self.problem: Problem | None = None self.x0: np.ndarray | None = None self.trace: pymc.backends.Text | None = None - self.data: xr.DataTree | None = None + self.data: Any | None = None @classmethod def translate_options(cls, options): diff --git a/pyproject.toml b/pyproject.toml index 12e40c20e..5105cc0b6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,8 @@ mpi = [ ] pymc = [ - "arviz>=1.1.0", + "arviz>=0.23.4,<1.0; python_version < '3.12'", + "arviz>=1.1.0; python_version >= '3.12'", "pymc>=4.2.1", ] From d0bd566e5e72c17e33a370ceea87b4b10220e169 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Tue, 26 May 2026 17:08:40 +0200 Subject: [PATCH 8/9] some final cleanups --- pypesto/sample/pymc.py | 10 +++++++++- pypesto/variational/pymc.py | 8 ++++++-- 2 files changed, 15 insertions(+), 3 deletions(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 971d54fea..f69dbab9d 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -34,6 +34,14 @@ # https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html +def _get_posterior_dataset(data: Any) -> Any: + """Return posterior as an xarray Dataset across ArviZ versions.""" + posterior = data.posterior + if hasattr(posterior, "to_array"): + return posterior + return posterior.to_dataset() + + class PymcObjectiveOp(_PT_OP_BASE): """PyTensor wrapper around a (non-normalized) log-probability function.""" @@ -246,7 +254,7 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" - posterior = self.data.posterior + posterior = _get_posterior_dataset(self.data) # dimensions n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index cc91233fd..f31adcb61 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -9,7 +9,11 @@ from ..objective import FD from ..result import McmcPtResult -from ..sample.pymc import PymcObjectiveOp, PymcSampler +from ..sample.pymc import ( + PymcObjectiveOp, + PymcSampler, + _get_posterior_dataset, +) from ..sample.sampler import SamplerImportError logger = logging.getLogger(__name__) @@ -133,7 +137,7 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: # get InferenceData object with self.model: pymc_data = self.data.sample(n_samples) - posterior = pymc_data.posterior + posterior = _get_posterior_dataset(pymc_data) x_names_free = self.problem.get_reduced_vector(self.problem.x_names) post_samples = np.concatenate( From da44abaec4245168df4f25a67b08af2fed3b2ee8 Mon Sep 17 00:00:00 2001 From: Doresic <85789271+Doresic@users.noreply.github.com> Date: Tue, 26 May 2026 18:08:31 +0200 Subject: [PATCH 9/9] add TODOs --- pypesto/sample/pymc.py | 2 ++ pyproject.toml | 3 ++- 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index f69dbab9d..70042c905 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -34,6 +34,8 @@ # https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html +# TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0 +# and simplify this helper to `data.posterior.to_dataset()`. def _get_posterior_dataset(data: Any) -> Any: """Return posterior as an xarray Dataset across ArviZ versions.""" posterior = data.posterior diff --git a/pyproject.toml b/pyproject.toml index 5105cc0b6..d5b62ed13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,8 @@ mpi = [ ] pymc = [ - "arviz>=0.23.4,<1.0; python_version < '3.12'", + # TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0. + "arviz>=0.12.1,<1.0; python_version < '3.12'", "arviz>=1.1.0; python_version >= '3.12'", "pymc>=4.2.1", ]