diff --git a/pypesto/sample/pymc.py b/pypesto/sample/pymc.py index 3b1de1933..70042c905 100644 --- a/pypesto/sample/pymc.py +++ b/pypesto/sample/pymc.py @@ -4,6 +4,7 @@ import importlib import logging +from typing import Any import numpy as np @@ -15,10 +16,9 @@ logger = logging.getLogger(__name__) -# Lazy import of pymc, arviz, and pytensor +# Lazy import of pymc and pytensor # Check availability once at module load time _HAS_PYMC = importlib.util.find_spec("pymc") is not None -_HAS_ARVIZ = importlib.util.find_spec("arviz") is not None if _HAS_PYMC: import pymc @@ -30,15 +30,20 @@ pt = None _PT_OP_BASE = object -if _HAS_ARVIZ: - import arviz as az -else: - az = None - # implementation based on: # https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html +# TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0 +# and simplify this helper to `data.posterior.to_dataset()`. +def _get_posterior_dataset(data: Any) -> Any: + """Return posterior as an xarray Dataset across ArviZ versions.""" + posterior = data.posterior + if hasattr(posterior, "to_array"): + return posterior + return posterior.to_dataset() + + class PymcObjectiveOp(_PT_OP_BASE): """PyTensor wrapper around a (non-normalized) log-probability function.""" @@ -151,7 +156,7 @@ def __init__( self.problem: Problem | None = None self.x0: np.ndarray | None = None self.trace: pymc.backends.Text | None = None - self.data: az.InferenceData | None = None + self.data: Any | None = None @classmethod def translate_options(cls, options): @@ -251,10 +256,10 @@ def sample(self, n_samples: int, beta: float = 1.0): def get_samples(self) -> McmcPtResult: """Convert result from pymc to McmcPtResult.""" + posterior = _get_posterior_dataset(self.data) + # dimensions - n_par, n_chain, n_iter = np.asarray( - self.data.posterior.to_array() - ).shape + n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape n_par -= 1 # remove log-posterior # parameters @@ -263,10 +268,10 @@ def get_samples(self) -> McmcPtResult: if len(par_ids) != n_par: raise AssertionError("Mismatch of parameter dimension") for i_par, par_id in enumerate(par_ids): - trace_x[:, :, i_par] = np.asarray(self.data.posterior[par_id]) + trace_x[:, :, i_par] = np.asarray(posterior[par_id]) # function values - trace_neglogpost = -np.asarray(self.data.posterior["loggyposty"]) + trace_neglogpost = -np.asarray(posterior["loggyposty"]) if ( trace_x.shape[0] != trace_neglogpost.shape[0] diff --git a/pypesto/variational/pymc.py b/pypesto/variational/pymc.py index d8b4cf711..f31adcb61 100644 --- a/pypesto/variational/pymc.py +++ b/pypesto/variational/pymc.py @@ -9,7 +9,11 @@ from ..objective import FD from ..result import McmcPtResult -from ..sample.pymc import PymcObjectiveOp, PymcSampler +from ..sample.pymc import ( + PymcObjectiveOp, + PymcSampler, + _get_posterior_dataset, +) from ..sample.sampler import SamplerImportError logger = logging.getLogger(__name__) @@ -89,7 +93,7 @@ def fit( } # create model context - with pymc.Model(): + with pymc.Model() as model: # parameter bounds as uniform prior _k = [ pymc.Uniform(x_name, lower=lb, upper=ub) @@ -119,6 +123,7 @@ def fit( ) self.data = data + self.model = model def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: """ @@ -130,16 +135,19 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult: Number of samples to be computed. """ # get InferenceData object - pymc_data = self.data.sample(n_samples) + with self.model: + pymc_data = self.data.sample(n_samples) + posterior = _get_posterior_dataset(pymc_data) + x_names_free = self.problem.get_reduced_vector(self.problem.x_names) post_samples = np.concatenate( - [pymc_data.posterior[name].values for name in x_names_free] + [posterior[name].values for name in x_names_free] ).T return McmcPtResult( trace_x=post_samples[np.newaxis, :], - trace_neglogpost=pymc_data.posterior.loggyposty.values, + trace_neglogpost=posterior.loggyposty.values, trace_neglogprior=np.full( - pymc_data.posterior.loggyposty.values.shape, np.nan + posterior.loggyposty.values.shape, np.nan ), betas=np.array([1.0] * post_samples.shape[0]), burn_in=0, diff --git a/pyproject.toml b/pyproject.toml index 0e336b3be..d5b62ed13 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -106,7 +106,9 @@ mpi = [ ] pymc = [ - "arviz>=0.12.1", + # TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0. + "arviz>=0.12.1,<1.0; python_version < '3.12'", + "arviz>=1.1.0; python_version >= '3.12'", "pymc>=4.2.1", ]