Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 18 additions & 13 deletions pypesto/sample/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import importlib
import logging
from typing import Any

import numpy as np

Expand All @@ -15,10 +16,9 @@

logger = logging.getLogger(__name__)

# Lazy import of pymc, arviz, and pytensor
# Lazy import of pymc and pytensor
# Check availability once at module load time
_HAS_PYMC = importlib.util.find_spec("pymc") is not None
_HAS_ARVIZ = importlib.util.find_spec("arviz") is not None

if _HAS_PYMC:
import pymc
Expand All @@ -30,15 +30,20 @@
pt = None
_PT_OP_BASE = object

if _HAS_ARVIZ:
import arviz as az
else:
az = None

# implementation based on:
# https://www.pymc.io/projects/examples/en/latest/case_studies/blackbox_external_likelihood_numpy.html


# TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0
# and simplify this helper to `data.posterior.to_dataset()`.
def _get_posterior_dataset(data: Any) -> Any:
"""Return posterior as an xarray Dataset across ArviZ versions."""
posterior = data.posterior
if hasattr(posterior, "to_array"):
return posterior
return posterior.to_dataset()


class PymcObjectiveOp(_PT_OP_BASE):
"""PyTensor wrapper around a (non-normalized) log-probability function."""

Expand Down Expand Up @@ -151,7 +156,7 @@ def __init__(
self.problem: Problem | None = None
self.x0: np.ndarray | None = None
self.trace: pymc.backends.Text | None = None
self.data: az.InferenceData | None = None
self.data: Any | None = None

@classmethod
def translate_options(cls, options):
Expand Down Expand Up @@ -251,10 +256,10 @@ def sample(self, n_samples: int, beta: float = 1.0):

def get_samples(self) -> McmcPtResult:
"""Convert result from pymc to McmcPtResult."""
posterior = _get_posterior_dataset(self.data)

# dimensions
n_par, n_chain, n_iter = np.asarray(
self.data.posterior.to_array()
).shape
n_par, n_chain, n_iter = np.asarray(posterior.to_array()).shape
n_par -= 1 # remove log-posterior

# parameters
Expand All @@ -263,10 +268,10 @@ def get_samples(self) -> McmcPtResult:
if len(par_ids) != n_par:
raise AssertionError("Mismatch of parameter dimension")
for i_par, par_id in enumerate(par_ids):
trace_x[:, :, i_par] = np.asarray(self.data.posterior[par_id])
trace_x[:, :, i_par] = np.asarray(posterior[par_id])

# function values
trace_neglogpost = -np.asarray(self.data.posterior["loggyposty"])
trace_neglogpost = -np.asarray(posterior["loggyposty"])

if (
trace_x.shape[0] != trace_neglogpost.shape[0]
Expand Down
20 changes: 14 additions & 6 deletions pypesto/variational/pymc.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,11 @@

from ..objective import FD
from ..result import McmcPtResult
from ..sample.pymc import PymcObjectiveOp, PymcSampler
from ..sample.pymc import (
PymcObjectiveOp,
PymcSampler,
_get_posterior_dataset,
)
from ..sample.sampler import SamplerImportError

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -89,7 +93,7 @@ def fit(
}

# create model context
with pymc.Model():
with pymc.Model() as model:
# parameter bounds as uniform prior
_k = [
pymc.Uniform(x_name, lower=lb, upper=ub)
Expand Down Expand Up @@ -119,6 +123,7 @@ def fit(
)

self.data = data
self.model = model

def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult:
"""
Expand All @@ -130,16 +135,19 @@ def sample(self, n_samples: int, beta: float = 1.0) -> McmcPtResult:
Number of samples to be computed.
"""
# get InferenceData object
pymc_data = self.data.sample(n_samples)
with self.model:
pymc_data = self.data.sample(n_samples)
posterior = _get_posterior_dataset(pymc_data)

x_names_free = self.problem.get_reduced_vector(self.problem.x_names)
post_samples = np.concatenate(
[pymc_data.posterior[name].values for name in x_names_free]
[posterior[name].values for name in x_names_free]
).T
return McmcPtResult(
trace_x=post_samples[np.newaxis, :],
trace_neglogpost=pymc_data.posterior.loggyposty.values,
trace_neglogpost=posterior.loggyposty.values,
trace_neglogprior=np.full(
pymc_data.posterior.loggyposty.values.shape, np.nan
posterior.loggyposty.values.shape, np.nan
),
betas=np.array([1.0] * post_samples.shape[0]),
burn_in=0,
Expand Down
4 changes: 3 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,9 @@ mpi = [
]

pymc = [
"arviz>=0.12.1",
# TODO: once Python 3.11 support is dropped, require only ArviZ >=1.1.0.
"arviz>=0.12.1,<1.0; python_version < '3.12'",
"arviz>=1.1.0; python_version >= '3.12'",
"pymc>=4.2.1",
]

Expand Down
Loading