Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions conda_envs/environment-test.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -25,8 +25,3 @@ dependencies:
- threadpoolctl
- typing_extensions
- pip
- pip:
- jax
- jaxlib
- optax
- gpjax
8 changes: 5 additions & 3 deletions ptgp/kernels/stationary.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import pytensor
import pytensor.tensor as pt

from ptgp.kernels.base import Kernel
Expand Down Expand Up @@ -87,8 +88,9 @@ class Matern52(Stationary):

def _eval(self, X, Y):
tau = self._scaled_euclid_dist(X, Y)
sqrt5 = pt.sqrt(5.0)
return (1.0 + sqrt5 * tau + 5.0 / 3.0 * pt.square(tau)) * pt.exp(-sqrt5 * tau)
sqrt5 = pt.constant(5.0**0.5, dtype=pytensor.config.floatX)
five_thirds = pt.constant(5.0 / 3.0, dtype=pytensor.config.floatX)
return (1.0 + sqrt5 * tau + five_thirds * pt.square(tau)) * pt.exp(-sqrt5 * tau)


class Matern32(Stationary):
Expand All @@ -100,7 +102,7 @@ class Matern32(Stationary):

def _eval(self, X, Y):
tau = self._scaled_euclid_dist(X, Y)
sqrt3 = pt.sqrt(3.0)
sqrt3 = pt.constant(3.0**0.5, dtype=pytensor.config.floatX)
return (1.0 + sqrt3 * tau) * pt.exp(-sqrt3 * tau)


Expand Down
77 changes: 77 additions & 0 deletions tests/_svgp_ref.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
"""Independent whitened-SVGP ELBO reference (numpy + scipy, no PTGP internals).

Shared by the SVGP ELBO-match tests. The reference assembles the ELBO from
first principles (whitened predict by hand, scipy.integrate.quad for the
variational expectation, closed-form whitened KL) and only needs a numpy
``log_prob(f, y)`` for the likelihood, so it is library-agnostic.

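The file name matches neither of pytest's default ``test_*.py`` / ``*_test.py``
patterns, so it is not collected as a test file; the leading underscore only
marks it as a private helper module.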
"""

import numpy as np
import scipy.linalg

from scipy import integrate


def _matern52_numpy(X1, X2, ls, eta):
"""Matern52 kernel in numpy: k(r) = eta^2 (1 + sqrt(5)r + 5r^2/3) exp(-sqrt(5)r)."""
sqd = np.sum(X1**2, axis=-1)[:, None] + np.sum(X2**2, axis=-1)[None, :] - 2.0 * X1 @ X2.T
r = np.sqrt(np.maximum(sqd, 0.0)) / ls
s5 = np.sqrt(5.0)
return eta**2 * (1.0 + s5 * r + 5.0 * r**2 / 3.0) * np.exp(-s5 * r)


def whitened_predict(X, Z, ls, eta, q_mu, q_sqrt):
    """Return the marginal predictive mean and variance of whitened SVGP at ``X``.

    ``Z`` are the inducing inputs; ``q_mu`` and ``q_sqrt`` are the whitened
    variational mean and covariance factor. Returns ``(mu_f, var_f)``.
    """
    K_uu = _matern52_numpy(Z, Z, ls, eta)
    K_uf = _matern52_numpy(Z, X, ls, eta)
    chol_uu = scipy.linalg.cholesky(K_uu, lower=True)
    # proj = Lz^{-1} Kzx, shape (M, N)
    proj = scipy.linalg.solve_triangular(chol_uu, K_uf, lower=True)
    proj_t = proj.T

    mean = proj_t @ q_mu

    # The Matern52 diagonal is eta^2, so the prior marginal variance is constant.
    prior_var = np.full(X.shape[0], eta**2)
    explained_var = np.sum(proj**2, axis=0)
    variational_var = np.sum((proj_t @ q_sqrt) ** 2, axis=1)
    return mean, prior_var - explained_var + variational_var


def whitened_kl(q_mu, q_sqrt):
    """Compute KL[N(q_mu, q_sqrt q_sqrt^T) || N(0, I)] in closed form."""
    num_inducing = q_mu.size
    trace_term = np.sum(q_sqrt**2)
    mean_term = q_mu @ q_mu
    # log|S| with S = q_sqrt q_sqrt^T, read off the diagonal of the factor.
    log_det_cov = 2.0 * np.sum(np.log(np.abs(np.diag(q_sqrt))))
    return 0.5 * (trace_term + mean_term - num_inducing - log_det_cov)


def variational_expectation_quad(log_prob_fn, y, mu_f, var_f):
    """Sum the per-point E_{q(f_n)}[log p(y_n|f_n)] using adaptive quadrature.

    Each expectation is integrated over the standard-normal variable z, with
    f = mu + sqrt(var) * z. The z-range [-30, 30] lies well beyond the
    numerical support of exp(-z^2/2).
    """

    def expectation_at(target, mean, variance):
        # One-dimensional expectation for a single data point.
        scale = np.sqrt(variance)

        def weighted_log_prob(z):
            log_lik = log_prob_fn(mean + scale * z, target)
            return log_lik * np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi)

        value, _abs_err = integrate.quad(weighted_log_prob, -30.0, 30.0)
        return value

    # Start from 0.0 so an empty input still yields a float.
    return sum((expectation_at(yn, m, v) for yn, m, v in zip(y, mu_f, var_f)), 0.0)


def reference_elbo(X, y, Z, q_mu, q_sqrt, ls, eta, log_prob_fn):
    """Assemble the whitened-SVGP ELBO: variational expectation minus KL."""
    mean, variance = whitened_predict(X, Z, ls, eta, q_mu, q_sqrt)
    expected_log_lik = variational_expectation_quad(log_prob_fn, y, mean, variance)
    kl_penalty = whitened_kl(q_mu, q_sqrt)
    return expected_log_lik - kl_penalty


def fixed_config(rng, N=40, M=8, x_range=(-2.0, 2.0)):
    """Build the shared fixed configuration used by the ELBO-match tests.

    Returns ``(X, Z, q_mu, L)``: sorted uniform inputs as an (N, 1) column,
    evenly spaced inducing points as an (M, 1) column, a random variational
    mean of length M, and a lower-triangular (M, M) factor.
    """
    lower, upper = x_range[0], x_range[1]

    # The order of the rng draws below determines the generated values.
    inputs = np.sort(rng.uniform(lower, upper, N))[:, None]
    inducing = np.linspace(lower, upper, M)[:, None]
    var_mean = rng.normal(0, 0.3, M)
    var_factor = np.tril(rng.normal(0, 0.2, (M, M)))

    # Make every diagonal entry at least 0.5 in magnitude and positive.
    diag = np.arange(M)
    var_factor[diag, diag] = np.abs(var_factor[diag, diag]) + 0.5
    return inputs, inducing, var_mean, var_factor
120 changes: 61 additions & 59 deletions tests/kernels/test_stationary.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,20 @@
"""Stationary kernel tests against GPJax reference implementation."""
"""Stationary kernel tests against closed-form analytic references.

Each kernel is checked against its textbook formula evaluated in NumPy, so the
suite pins down that the kernels are mathematically correct without depending on
another GP library.
"""

import jax.numpy as jnp
import numpy as np
import pytensor
import pytensor.tensor as pt
import pytest

from gpjax.kernels.stationary import (
RBF as GPJaxRBF,
)
from gpjax.kernels.stationary import (
Matern32 as GPJaxMatern32,
)
from gpjax.kernels.stationary import (
Matern52 as GPJaxMatern52,
)

from ptgp.kernels import ExpQuad, Matern12, Matern32, Matern52

# GPJax uses float32 internally, so comparisons are limited to ~1e-6 precision.
ATOL = 1e-5
# All kernels evaluate in float64, so the comparison against the analytic
# reference is tight.
ATOL = 1e-10


def _ptgp_eval(kernel, X_np, Y_np=None):
Expand All @@ -31,18 +26,31 @@ def _ptgp_eval(kernel, X_np, Y_np=None):
return f()


def _gpjax_gram(kernel, X_np):
X_jnp = jnp.array(X_np, dtype=jnp.float32)
return np.array(kernel.gram(X_jnp).to_dense())
def _scaled_dist(X, Y, ls):
"""Euclidean distance between rows of X and Y after dividing by lengthscale.

``ls`` may be a scalar (isotropic) or a per-dimension vector (ARD).
"""
ls = np.asarray(ls, dtype=float)
Xs, Ys = X / ls, Y / ls
sqd = np.sum(Xs**2, axis=1)[:, None] + np.sum(Ys**2, axis=1)[None, :] - 2.0 * Xs @ Ys.T
return np.sqrt(np.maximum(sqd, 0.0))

def _gpjax_cross(kernel, X_np, Y_np):
return np.array(
kernel.cross_covariance(
jnp.array(X_np, dtype=jnp.float32),
jnp.array(Y_np, dtype=jnp.float32),
)
)

def _analytic(X, Y, ls, eta, kind):
    """Closed-form NumPy covariance matrix for the stationary kernel named ``kind``.

    ``kind`` is one of ``"expquad"``, ``"matern52"``, ``"matern32"`` or
    ``"matern12"``; the unit-variance profile is scaled by ``eta**2``.
    """
    r = _scaled_dist(X, Y, ls)
    # Unit-variance kernel profiles as functions of the scaled distance.
    profiles = {
        "expquad": lambda d: np.exp(-0.5 * d**2),
        "matern52": lambda d: (
            (1.0 + np.sqrt(5.0) * d + 5.0 * d**2 / 3.0) * np.exp(-np.sqrt(5.0) * d)
        ),
        "matern32": lambda d: (1.0 + np.sqrt(3.0) * d) * np.exp(-np.sqrt(3.0) * d),
        "matern12": lambda d: np.exp(-d),
    }
    profile = profiles.get(kind)
    if profile is None:  # pragma: no cover - guards against typos in test parametrization
        raise ValueError(f"unknown kernel kind {kind!r}")
    return eta**2 * profile(r)


@pytest.fixture
Expand Down Expand Up @@ -71,21 +79,20 @@ class TestExpQuad:
def test_gram_1d(self, X_1d):
ls, eta = 1.5, 2.0
ptgp_k = eta**2 * ExpQuad(input_dim=1, ls=ls)
gpjax_k = GPJaxRBF(lengthscale=jnp.array(ls), variance=jnp.array(eta**2))
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), _gpjax_gram(gpjax_k, X_1d), atol=ATOL)
ref = _analytic(X_1d, X_1d, ls, eta, "expquad")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), ref, atol=ATOL)

def test_cross_1d(self, X_1d, X_1d_other):
ls, eta = 1.5, 2.0
ptgp_k = eta**2 * ExpQuad(input_dim=1, ls=ls)
gpjax_k = GPJaxRBF(lengthscale=jnp.array(ls), variance=jnp.array(eta**2))
np.testing.assert_allclose(
_ptgp_eval(ptgp_k, X_1d, X_1d_other), _gpjax_cross(gpjax_k, X_1d, X_1d_other), atol=ATOL
)
ref = _analytic(X_1d, X_1d_other, ls, eta, "expquad")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d, X_1d_other), ref, atol=ATOL)

def test_gram_2d(self, X_2d):
ptgp_k = ExpQuad(input_dim=2, ls=0.8)
gpjax_k = GPJaxRBF(lengthscale=jnp.array(0.8), variance=jnp.array(1.0))
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), _gpjax_gram(gpjax_k, X_2d), atol=ATOL)
ls, eta = 0.8, 1.0
ptgp_k = ExpQuad(input_dim=2, ls=ls)
ref = _analytic(X_2d, X_2d, ls, eta, "expquad")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), ref, atol=ATOL)

def test_symmetric_annotation(self, X_1d):
X_pt = pt.as_tensor_variable(X_1d)
Expand All @@ -110,40 +117,42 @@ class TestMatern52:
def test_gram_1d(self, X_1d):
ls, eta = 1.2, 1.5
ptgp_k = eta**2 * Matern52(input_dim=1, ls=ls)
gpjax_k = GPJaxMatern52(lengthscale=jnp.array(ls), variance=jnp.array(eta**2))
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), _gpjax_gram(gpjax_k, X_1d), atol=ATOL)
ref = _analytic(X_1d, X_1d, ls, eta, "matern52")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), ref, atol=ATOL)

def test_cross_1d(self, X_1d, X_1d_other):
ls, eta = 1.2, 1.5
ptgp_k = eta**2 * Matern52(input_dim=1, ls=ls)
gpjax_k = GPJaxMatern52(lengthscale=jnp.array(ls), variance=jnp.array(eta**2))
np.testing.assert_allclose(
_ptgp_eval(ptgp_k, X_1d, X_1d_other), _gpjax_cross(gpjax_k, X_1d, X_1d_other), atol=ATOL
)
ref = _analytic(X_1d, X_1d_other, ls, eta, "matern52")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d, X_1d_other), ref, atol=ATOL)

def test_gram_2d(self, X_2d):
ls, eta = 0.5, 2.0
ptgp_k = eta**2 * Matern52(input_dim=2, ls=ls)
gpjax_k = GPJaxMatern52(lengthscale=jnp.array(ls), variance=jnp.array(eta**2))
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), _gpjax_gram(gpjax_k, X_2d), atol=ATOL)
ref = _analytic(X_2d, X_2d, ls, eta, "matern52")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), ref, atol=ATOL)


class TestMatern32:
def test_gram_1d(self, X_1d):
ptgp_k = Matern32(input_dim=1, ls=2.0)
gpjax_k = GPJaxMatern32(lengthscale=jnp.array(2.0), variance=jnp.array(1.0))
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), _gpjax_gram(gpjax_k, X_1d), atol=ATOL)
ls, eta = 2.0, 1.0
ptgp_k = Matern32(input_dim=1, ls=ls)
ref = _analytic(X_1d, X_1d, ls, eta, "matern32")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), ref, atol=ATOL)

def test_cross_2d(self, X_2d, X_2d_other):
ls, eta = 0.7, 1.3
ptgp_k = eta**2 * Matern32(input_dim=2, ls=ls)
gpjax_k = GPJaxMatern32(lengthscale=jnp.array(ls), variance=jnp.array(eta**2))
np.testing.assert_allclose(
_ptgp_eval(ptgp_k, X_2d, X_2d_other), _gpjax_cross(gpjax_k, X_2d, X_2d_other), atol=ATOL
)
ref = _analytic(X_2d, X_2d_other, ls, eta, "matern32")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d, X_2d_other), ref, atol=ATOL)


class TestMatern12:
def test_gram_1d(self, X_1d):
ls, eta = 1.0, 1.0
ref = _analytic(X_1d, X_1d, ls, eta, "matern12")
np.testing.assert_allclose(_ptgp_eval(Matern12(input_dim=1, ls=ls), X_1d), ref, atol=ATOL)

def test_gram_symmetry(self, X_1d):
K = _ptgp_eval(Matern12(input_dim=1, ls=1.0), X_1d)
np.testing.assert_allclose(K, K.T, atol=1e-14)
Expand Down Expand Up @@ -180,16 +189,9 @@ def test_scalar_and_vector_ls_match_when_equal(self, X_2d):
k_ard = Matern52(input_dim=2, ls=np.array([0.5, 0.5]))
np.testing.assert_allclose(_ptgp_eval(k_iso, X_2d), _ptgp_eval(k_ard, X_2d), atol=1e-14)

def test_ard_vs_pymc(self, X_2d):
"""ARD Matern52 with per-dim lengthscales matches PyMC's implementation."""
import pymc as pm

def test_ard_matches_analytic(self, X_2d):
"""ARD Matern52 with per-dim lengthscales matches the closed-form reference."""
ls = np.array([0.5, 1.2])
ptgp_k = Matern52(input_dim=2, ls=ls)
pymc_k = pm.gp.cov.Matern52(input_dim=2, ls=ls)

K_ptgp = _ptgp_eval(ptgp_k, X_2d)
K_pymc = pymc_k(pt.as_tensor_variable(X_2d)).eval()
# PyMC and PTGP use different NaN-safe sqrt strategies near zero, so
# allow small numerical differences.
np.testing.assert_allclose(K_ptgp, K_pymc, atol=1e-8)
ref = _analytic(X_2d, X_2d, ls, 1.0, "matern52")
np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), ref, atol=ATOL)
41 changes: 24 additions & 17 deletions tests/likelihoods/test_bernoulli.py
Original file line number Diff line number Diff line change
@@ -1,25 +1,41 @@
"""Bernoulli likelihood tests against GPJax reference and analytical results."""
"""Bernoulli likelihood tests against an independent quadrature reference."""

import jax.numpy as jnp
import numpy as np
import pytensor
import pytensor.tensor as pt

from gpjax.integrators import GHQuadratureIntegrator
from gpjax.likelihoods import Bernoulli as GPJaxBernoulli
from scipy.special import erf

from ptgp.likelihoods import Bernoulli

ATOL = 1e-5
ATOL = 1e-10


def _eval(*tensors):
    """Compile a zero-input pytensor function for ``tensors`` and call it once.

    A single tensor is passed to ``pytensor.function`` as the output directly;
    several tensors are passed as a list of outputs.
    """
    if len(tensors) > 1:
        outputs = list(tensors)
    else:
        outputs = tensors[0]
    compiled = pytensor.function([], outputs)
    return compiled()


def _bernoulli_ve_reference(y, mu, var, n_points=20):
"""Gauss-Hermite quadrature of E_{q(f)}[log p(y | f)] for the probit Bernoulli.

Reimplements the same quadrature rule PTGP uses (physicist Hermite nodes,
weights divided by sqrt(pi), f = mu + sqrt(2 var) t) and the same clamped
probit link, in plain NumPy. This validates PTGP's pytensor implementation
against a from-scratch reference rather than against another GP library.
"""
jitter = 1e-3 # mirrors ptgp.likelihoods.bernoulli.inv_probit clamping
nodes, weights = np.polynomial.hermite.hermgauss(n_points)
weights = weights / np.sqrt(np.pi)
sd = np.sqrt(var)[:, None]
F = mu[:, None] + np.sqrt(2.0) * sd * nodes[None, :]
p = 0.5 * (1.0 + erf(F / np.sqrt(2.0))) * (1.0 - 2.0 * jitter) + jitter
log_prob = y[:, None] * np.log(p) + (1.0 - y[:, None]) * np.log(1.0 - p)
return np.sum(log_prob * weights[None, :], axis=1)


class TestBernoulli:
def test_ve_against_gpjax(self):
def test_ve_matches_quadrature(self):
mu, var = np.array([0.0, 1.0, -1.0]), np.array([0.25, 0.5, 1.0])
y = np.array([1.0, 1.0, 0.0])

Expand All @@ -29,18 +45,9 @@ def test_ve_against_gpjax(self):
)
)

gpjax_ve = np.array(
GPJaxBernoulli(
num_datapoints=3,
integrator=GHQuadratureIntegrator(num_points=20),
).expected_log_likelihood(
y=jnp.array(y)[:, None],
mean=jnp.array(mu)[:, None],
variance=jnp.array(var)[:, None],
)
)
expected = _bernoulli_ve_reference(y, mu, var, n_points=20)

np.testing.assert_allclose(ve, gpjax_ve, atol=ATOL)
np.testing.assert_allclose(ve, expected, atol=ATOL)

def test_predict_mean_and_var_closed_form(self):
mu, var = np.array([0.0, 2.0, -2.0]), np.array([0.1, 0.5, 1.0])
Expand Down
Loading
Loading