diff --git a/conda_envs/environment-test.yaml b/conda_envs/environment-test.yaml index 4574d91..76d17ce 100644 --- a/conda_envs/environment-test.yaml +++ b/conda_envs/environment-test.yaml @@ -25,8 +25,3 @@ dependencies: - threadpoolctl - typing_extensions - pip - - pip: - - jax - - jaxlib - - optax - - gpjax diff --git a/ptgp/kernels/stationary.py b/ptgp/kernels/stationary.py index a20b958..9f6046b 100644 --- a/ptgp/kernels/stationary.py +++ b/ptgp/kernels/stationary.py @@ -1,3 +1,4 @@ +import pytensor import pytensor.tensor as pt from ptgp.kernels.base import Kernel @@ -87,8 +88,9 @@ class Matern52(Stationary): def _eval(self, X, Y): tau = self._scaled_euclid_dist(X, Y) - sqrt5 = pt.sqrt(5.0) - return (1.0 + sqrt5 * tau + 5.0 / 3.0 * pt.square(tau)) * pt.exp(-sqrt5 * tau) + sqrt5 = pt.constant(5.0**0.5, dtype=pytensor.config.floatX) + five_thirds = pt.constant(5.0 / 3.0, dtype=pytensor.config.floatX) + return (1.0 + sqrt5 * tau + five_thirds * pt.square(tau)) * pt.exp(-sqrt5 * tau) class Matern32(Stationary): @@ -100,7 +102,7 @@ class Matern32(Stationary): def _eval(self, X, Y): tau = self._scaled_euclid_dist(X, Y) - sqrt3 = pt.sqrt(3.0) + sqrt3 = pt.constant(3.0**0.5, dtype=pytensor.config.floatX) return (1.0 + sqrt3 * tau) * pt.exp(-sqrt3 * tau) diff --git a/tests/_svgp_ref.py b/tests/_svgp_ref.py new file mode 100644 index 0000000..ead2a06 --- /dev/null +++ b/tests/_svgp_ref.py @@ -0,0 +1,77 @@ +"""Independent whitened-SVGP ELBO reference (numpy + scipy, no PTGP internals). + +Shared by the SVGP ELBO-match tests. The reference assembles the ELBO from +first principles (whitened predict by hand, scipy.integrate.quad for the +variational expectation, closed-form whitened KL) and only needs a numpy +``log_prob(f, y)`` for the likelihood, so it is library-agnostic. + +The leading underscore keeps pytest from collecting this module as a test file. +""" + +import numpy as np +import scipy.linalg + +from scipy import integrate + + +def _matern52_numpy(X1, X2, ls, eta): + """Matern52 kernel in numpy: k(r) = eta^2 (1 + sqrt(5)r + 5r^2/3) exp(-sqrt(5)r).""" + sqd = np.sum(X1**2, axis=-1)[:, None] + np.sum(X2**2, axis=-1)[None, :] - 2.0 * X1 @ X2.T + r = np.sqrt(np.maximum(sqd, 0.0)) / ls + s5 = np.sqrt(5.0) + return eta**2 * (1.0 + s5 * r + 5.0 * r**2 / 3.0) * np.exp(-s5 * r) + + +def whitened_predict(X, Z, ls, eta, q_mu, q_sqrt): + """Whitened-SVGP marginal predictive mean and variance at X.""" + Kzz = _matern52_numpy(Z, Z, ls, eta) + Kzx = _matern52_numpy(Z, X, ls, eta) + K_diag = np.full(X.shape[0], eta**2) # Matern52 diag is eta^2 + Lz = scipy.linalg.cholesky(Kzz, lower=True) + A = scipy.linalg.solve_triangular(Lz, Kzx, lower=True) # (M, N) + mu_f = A.T @ q_mu + var_f = K_diag - np.sum(A**2, axis=0) + np.sum((A.T @ q_sqrt) ** 2, axis=1) + return mu_f, var_f + + +def whitened_kl(q_mu, q_sqrt): + """Closed-form KL[N(q_mu, q_sqrt q_sqrt^T) || N(0, I)].""" + M = q_mu.size + logdet_S = 2.0 * np.sum(np.log(np.abs(np.diag(q_sqrt)))) + return 0.5 * (np.sum(q_sqrt**2) + q_mu @ q_mu - M - logdet_S) + + +def variational_expectation_quad(log_prob_fn, y, mu_f, var_f): + """Sum of per-point E_{q(f_n)}[log p(y_n|f_n)] via adaptive quadrature. + + Integrates in the standard-normal z-scale: f = mu + sqrt(var) * z. The + [-30, 30] z-range is well past the numerical support of exp(-z^2/2). + """ + total = 0.0 + for yn, m, v in zip(y, mu_f, var_f): + sd = np.sqrt(v) + + def integrand(z, yn=yn, m=m, sd=sd): + return log_prob_fn(m + sd * z, yn) * np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi) + + val, _ = integrate.quad(integrand, -30.0, 30.0) + total += val + return total + + +def reference_elbo(X, y, Z, q_mu, q_sqrt, ls, eta, log_prob_fn): + """Whitened-SVGP ELBO assembled from numpy + scipy.""" + mu_f, var_f = whitened_predict(X, Z, ls, eta, q_mu, q_sqrt) + ve = variational_expectation_quad(log_prob_fn, y, mu_f, var_f) + kl = whitened_kl(q_mu, q_sqrt) + return ve - kl + + +def fixed_config(rng, N=40, M=8, x_range=(-2.0, 2.0)): + """Shared fixed configuration for ELBO-match tests.""" + X = np.sort(rng.uniform(x_range[0], x_range[1], N))[:, None] + Z = np.linspace(x_range[0], x_range[1], M)[:, None] + q_mu = rng.normal(0, 0.3, M) + L = np.tril(rng.normal(0, 0.2, (M, M))) + L[np.arange(M), np.arange(M)] = np.abs(L[np.arange(M), np.arange(M)]) + 0.5 + return X, Z, q_mu, L diff --git a/tests/kernels/test_stationary.py b/tests/kernels/test_stationary.py index d563e9d..a2701e7 100644 --- a/tests/kernels/test_stationary.py +++ b/tests/kernels/test_stationary.py @@ -1,25 +1,20 @@ -"""Stationary kernel tests against GPJax reference implementation.""" +"""Stationary kernel tests against closed-form analytic references. + +Each kernel is checked against its textbook formula evaluated in NumPy, so the +suite pins down that the kernels are mathematically correct without depending on +another GP library. +""" -import jax.numpy as jnp import numpy as np import pytensor import pytensor.tensor as pt import pytest -from gpjax.kernels.stationary import ( - RBF as GPJaxRBF, -) -from gpjax.kernels.stationary import ( - Matern32 as GPJaxMatern32, -) -from gpjax.kernels.stationary import ( - Matern52 as GPJaxMatern52, -) - from ptgp.kernels import ExpQuad, Matern12, Matern32, Matern52 -# GPJax uses float32 internally, so comparisons are limited to ~1e-6 precision. -ATOL = 1e-5 +# All kernels evaluate in float64, so the comparison against the analytic +# reference is tight. +ATOL = 1e-10 def _ptgp_eval(kernel, X_np, Y_np=None): @@ -31,18 +26,31 @@ def _ptgp_eval(kernel, X_np, Y_np=None): return f() -def _gpjax_gram(kernel, X_np): - X_jnp = jnp.array(X_np, dtype=jnp.float32) - return np.array(kernel.gram(X_jnp).to_dense()) +def _scaled_dist(X, Y, ls): + """Euclidean distance between rows of X and Y after dividing by lengthscale. + ``ls`` may be a scalar (isotropic) or a per-dimension vector (ARD). + """ + ls = np.asarray(ls, dtype=float) + Xs, Ys = X / ls, Y / ls + sqd = np.sum(Xs**2, axis=1)[:, None] + np.sum(Ys**2, axis=1)[None, :] - 2.0 * Xs @ Ys.T + return np.sqrt(np.maximum(sqd, 0.0)) -def _gpjax_cross(kernel, X_np, Y_np): - return np.array( - kernel.cross_covariance( - jnp.array(X_np, dtype=jnp.float32), - jnp.array(Y_np, dtype=jnp.float32), - ) - ) + +def _analytic(X, Y, ls, eta, kind): + """Closed-form covariance matrix for a stationary kernel, in NumPy.""" + r = _scaled_dist(X, Y, ls) + if kind == "expquad": + k = np.exp(-0.5 * r**2) + elif kind == "matern52": + k = (1.0 + np.sqrt(5.0) * r + 5.0 * r**2 / 3.0) * np.exp(-np.sqrt(5.0) * r) + elif kind == "matern32": + k = (1.0 + np.sqrt(3.0) * r) * np.exp(-np.sqrt(3.0) * r) + elif kind == "matern12": + k = np.exp(-r) + else: # pragma: no cover - guards against typos in test parametrization + raise ValueError(f"unknown kernel kind {kind!r}") + return eta**2 * k @pytest.fixture @@ -71,21 +79,20 @@ class TestExpQuad: def test_gram_1d(self, X_1d): ls, eta = 1.5, 2.0 ptgp_k = eta**2 * ExpQuad(input_dim=1, ls=ls) - gpjax_k = GPJaxRBF(lengthscale=jnp.array(ls), variance=jnp.array(eta**2)) - np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), _gpjax_gram(gpjax_k, X_1d), atol=ATOL) + ref = _analytic(X_1d, X_1d, ls, eta, "expquad") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), ref, atol=ATOL) def test_cross_1d(self, X_1d, X_1d_other): ls, eta = 1.5, 2.0 ptgp_k = eta**2 * ExpQuad(input_dim=1, ls=ls) - gpjax_k = GPJaxRBF(lengthscale=jnp.array(ls), variance=jnp.array(eta**2)) - np.testing.assert_allclose( - _ptgp_eval(ptgp_k, X_1d, X_1d_other), _gpjax_cross(gpjax_k, X_1d, X_1d_other), atol=ATOL - ) + ref = _analytic(X_1d, X_1d_other, ls, eta, "expquad") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d, X_1d_other), ref, atol=ATOL) def test_gram_2d(self, X_2d): - ptgp_k = ExpQuad(input_dim=2, ls=0.8) - gpjax_k = GPJaxRBF(lengthscale=jnp.array(0.8), variance=jnp.array(1.0)) - np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), _gpjax_gram(gpjax_k, X_2d), atol=ATOL) + ls, eta = 0.8, 1.0 + ptgp_k = ExpQuad(input_dim=2, ls=ls) + ref = _analytic(X_2d, X_2d, ls, eta, "expquad") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), ref, atol=ATOL) def test_symmetric_annotation(self, X_1d): X_pt = pt.as_tensor_variable(X_1d) @@ -110,40 +117,42 @@ class TestMatern52: def test_gram_1d(self, X_1d): ls, eta = 1.2, 1.5 ptgp_k = eta**2 * Matern52(input_dim=1, ls=ls) - gpjax_k = GPJaxMatern52(lengthscale=jnp.array(ls), variance=jnp.array(eta**2)) - np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), _gpjax_gram(gpjax_k, X_1d), atol=ATOL) + ref = _analytic(X_1d, X_1d, ls, eta, "matern52") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), ref, atol=ATOL) def test_cross_1d(self, X_1d, X_1d_other): ls, eta = 1.2, 1.5 ptgp_k = eta**2 * Matern52(input_dim=1, ls=ls) - gpjax_k = GPJaxMatern52(lengthscale=jnp.array(ls), variance=jnp.array(eta**2)) - np.testing.assert_allclose( - _ptgp_eval(ptgp_k, X_1d, X_1d_other), _gpjax_cross(gpjax_k, X_1d, X_1d_other), atol=ATOL - ) + ref = _analytic(X_1d, X_1d_other, ls, eta, "matern52") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d, X_1d_other), ref, atol=ATOL) def test_gram_2d(self, X_2d): ls, eta = 0.5, 2.0 ptgp_k = eta**2 * Matern52(input_dim=2, ls=ls) - gpjax_k = GPJaxMatern52(lengthscale=jnp.array(ls), variance=jnp.array(eta**2)) - np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), _gpjax_gram(gpjax_k, X_2d), atol=ATOL) + ref = _analytic(X_2d, X_2d, ls, eta, "matern52") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), ref, atol=ATOL) class TestMatern32: def test_gram_1d(self, X_1d): - ptgp_k = Matern32(input_dim=1, ls=2.0) - gpjax_k = GPJaxMatern32(lengthscale=jnp.array(2.0), variance=jnp.array(1.0)) - np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), _gpjax_gram(gpjax_k, X_1d), atol=ATOL) + ls, eta = 2.0, 1.0 + ptgp_k = Matern32(input_dim=1, ls=ls) + ref = _analytic(X_1d, X_1d, ls, eta, "matern32") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_1d), ref, atol=ATOL) def test_cross_2d(self, X_2d, X_2d_other): ls, eta = 0.7, 1.3 ptgp_k = eta**2 * Matern32(input_dim=2, ls=ls) - gpjax_k = GPJaxMatern32(lengthscale=jnp.array(ls), variance=jnp.array(eta**2)) - np.testing.assert_allclose( - _ptgp_eval(ptgp_k, X_2d, X_2d_other), _gpjax_cross(gpjax_k, X_2d, X_2d_other), atol=ATOL - ) + ref = _analytic(X_2d, X_2d_other, ls, eta, "matern32") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d, X_2d_other), ref, atol=ATOL) class TestMatern12: + def test_gram_1d(self, X_1d): + ls, eta = 1.0, 1.0 + ref = _analytic(X_1d, X_1d, ls, eta, "matern12") + np.testing.assert_allclose(_ptgp_eval(Matern12(input_dim=1, ls=ls), X_1d), ref, atol=ATOL) + def test_gram_symmetry(self, X_1d): K = _ptgp_eval(Matern12(input_dim=1, ls=1.0), X_1d) np.testing.assert_allclose(K, K.T, atol=1e-14) @@ -180,16 +189,9 @@ def test_scalar_and_vector_ls_match_when_equal(self, X_2d): k_ard = Matern52(input_dim=2, ls=np.array([0.5, 0.5])) np.testing.assert_allclose(_ptgp_eval(k_iso, X_2d), _ptgp_eval(k_ard, X_2d), atol=1e-14) - def test_ard_vs_pymc(self, X_2d): - """ARD Matern52 with per-dim lengthscales matches PyMC's implementation.""" - import pymc as pm - + def test_ard_matches_analytic(self, X_2d): + """ARD Matern52 with per-dim lengthscales matches the closed-form reference.""" ls = np.array([0.5, 1.2]) ptgp_k = Matern52(input_dim=2, ls=ls) - pymc_k = pm.gp.cov.Matern52(input_dim=2, ls=ls) - - K_ptgp = _ptgp_eval(ptgp_k, X_2d) - K_pymc = pymc_k(pt.as_tensor_variable(X_2d)).eval() - # PyMC and PTGP use different NaN-safe sqrt strategies near zero, so - # allow small numerical differences. - np.testing.assert_allclose(K_ptgp, K_pymc, atol=1e-8) + ref = _analytic(X_2d, X_2d, ls, 1.0, "matern52") + np.testing.assert_allclose(_ptgp_eval(ptgp_k, X_2d), ref, atol=ATOL) diff --git a/tests/likelihoods/test_bernoulli.py b/tests/likelihoods/test_bernoulli.py index ca75fa8..e4c593a 100644 --- a/tests/likelihoods/test_bernoulli.py +++ b/tests/likelihoods/test_bernoulli.py @@ -1,16 +1,14 @@ -"""Bernoulli likelihood tests against GPJax reference and analytical results.""" +"""Bernoulli likelihood tests against an independent quadrature reference.""" -import jax.numpy as jnp import numpy as np import pytensor import pytensor.tensor as pt -from gpjax.integrators import GHQuadratureIntegrator -from gpjax.likelihoods import Bernoulli as GPJaxBernoulli +from scipy.special import erf from ptgp.likelihoods import Bernoulli -ATOL = 1e-5 +ATOL = 1e-10 def _eval(*tensors): @@ -18,8 +16,26 @@ def _eval(*tensors): return f() +def _bernoulli_ve_reference(y, mu, var, n_points=20): + """Gauss-Hermite quadrature of E_{q(f)}[log p(y | f)] for the probit Bernoulli. + + Reimplements the same quadrature rule PTGP uses (physicist Hermite nodes, + weights divided by sqrt(pi), f = mu + sqrt(2 var) t) and the same clamped + probit link, in plain NumPy. This validates PTGP's pytensor implementation + against a from-scratch reference rather than against another GP library. + """ + jitter = 1e-3 # mirrors ptgp.likelihoods.bernoulli.inv_probit clamping + nodes, weights = np.polynomial.hermite.hermgauss(n_points) + weights = weights / np.sqrt(np.pi) + sd = np.sqrt(var)[:, None] + F = mu[:, None] + np.sqrt(2.0) * sd * nodes[None, :] + p = 0.5 * (1.0 + erf(F / np.sqrt(2.0))) * (1.0 - 2.0 * jitter) + jitter + log_prob = y[:, None] * np.log(p) + (1.0 - y[:, None]) * np.log(1.0 - p) + return np.sum(log_prob * weights[None, :], axis=1) + + class TestBernoulli: - def test_ve_against_gpjax(self): + def test_ve_matches_quadrature(self): mu, var = np.array([0.0, 1.0, -1.0]), np.array([0.25, 0.5, 1.0]) y = np.array([1.0, 1.0, 0.0]) @@ -29,18 +45,9 @@ def test_ve_against_gpjax(self): ) ) - gpjax_ve = np.array( - GPJaxBernoulli( - num_datapoints=3, - integrator=GHQuadratureIntegrator(num_points=20), - ).expected_log_likelihood( - y=jnp.array(y)[:, None], - mean=jnp.array(mu)[:, None], - variance=jnp.array(var)[:, None], - ) - ) + expected = _bernoulli_ve_reference(y, mu, var, n_points=20) - np.testing.assert_allclose(ve, gpjax_ve, atol=ATOL) + np.testing.assert_allclose(ve, expected, atol=ATOL) def test_predict_mean_and_var_closed_form(self): mu, var = np.array([0.0, 2.0, -2.0]), np.array([0.1, 0.5, 1.0]) diff --git a/tests/likelihoods/test_gaussian.py b/tests/likelihoods/test_gaussian.py index 3593739..c997c65 100644 --- a/tests/likelihoods/test_gaussian.py +++ b/tests/likelihoods/test_gaussian.py @@ -1,16 +1,12 @@ -"""Gaussian likelihood tests against GPJax reference and analytical results.""" +"""Gaussian likelihood tests against closed-form analytic results.""" -import jax.numpy as jnp import numpy as np import pytensor import pytensor.tensor as pt -from gpjax.integrators import AnalyticalGaussianIntegrator -from gpjax.likelihoods import Gaussian as GPJaxGaussian - from ptgp.likelihoods import Gaussian -ATOL = 1e-5 +ATOL = 1e-12 def _eval(*tensors): @@ -19,7 +15,7 @@ def _eval(*tensors): class TestGaussian: - def test_ve_against_gpjax(self): + def test_ve_closed_form(self): mu, var = np.array([0.0, 0.5, -1.0]), np.array([0.1, 0.5, 1.0]) y, sigma = np.array([0.1, 0.3, -0.8]), 0.5 @@ -29,19 +25,11 @@ def test_ve_against_gpjax(self): ) ) - gpjax_ve = np.array( - GPJaxGaussian( - num_datapoints=3, - obs_stddev=jnp.array(sigma), - integrator=AnalyticalGaussianIntegrator(), - ).expected_log_likelihood( - y=jnp.array(y)[:, None], - mean=jnp.array(mu)[:, None], - variance=jnp.array(var)[:, None], - ) - ) + # E_{q(f)}[log N(y | f, sigma^2)] for q(f) = N(mu, var) has the closed form + # -0.5 log(2 pi sigma^2) - 0.5 ((y - mu)^2 + var) / sigma^2 + expected = -0.5 * np.log(2 * np.pi * sigma**2) - 0.5 * ((y - mu) ** 2 + var) / sigma**2 - np.testing.assert_allclose(ve, gpjax_ve, atol=ATOL) + np.testing.assert_allclose(ve, expected, atol=ATOL) def test_zero_var_matches_log_prob(self): mu, y, sigma = np.array([0.0, 1.0]), np.array([0.1, 0.9]), 0.3 diff --git a/tests/test_svgp.py b/tests/test_svgp.py index 8614de8..19e0b8c 100644 --- a/tests/test_svgp.py +++ b/tests/test_svgp.py @@ -1,27 +1,37 @@ -"""SVGP tests for non-Gaussian likelihoods, cross-checked against GPJax. +"""SVGP tests for non-Gaussian likelihoods, cross-checked against an +independent numpy+scipy reference. The likelihood unit tests already verify the Gauss-Hermite ``variational_expectation`` in isolation. These tests close the remaining gap: exercising the full SVGP ELBO wiring (predict + KL + variational -expectation) with a non-Gaussian likelihood end-to-end. +expectation) with a non-Gaussian likelihood end-to-end. The reference ELBO +machinery lives in ``tests/_svgp_ref.py``. """ -import jax +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt -jax.config.update("jax_enable_x64", True) +from scipy.special import erf, gammaln -import gpjax as gpx # noqa: E402 -import jax.numpy as jnp # noqa: E402 -import numpy as np # noqa: E402 -import pymc as pm # noqa: E402 -import pytensor # noqa: E402 -import pytensor.tensor as pt # noqa: E402 +import ptgp as pg -from scipy.special import erf # noqa: E402 +from tests._svgp_ref import fixed_config, reference_elbo -import ptgp as pg # noqa: E402 +ATOL = 1e-5 -ATOL = 1e-4 # cross-library noise is ~3e-5 once both libs jitter at 1e-6 + +def _bernoulli_logprob(f, y): + """Numpy log p(y | f) for the probit Bernoulli (matches ptgp's clamped link).""" + jitter = 1e-3 # mirrors ptgp.likelihoods.bernoulli.inv_probit clamping + p = 0.5 * (1.0 + erf(f / np.sqrt(2.0))) * (1.0 - 2.0 * jitter) + jitter + return y * np.log(p) + (1.0 - y) * np.log(1.0 - p) + + +def _poisson_logprob(f, y): + """Numpy log p(y | f) for the log-link Poisson: y f - exp(f) - log(y!).""" + return y * f - np.exp(f) - gammaln(y + 1.0) def _binary_data(rng, n=80): @@ -94,24 +104,14 @@ def test_loss_decreases_and_classifies(self): assert acc > 0.85, f"classification accuracy {acc:.2f} too low" -class TestSVGPBernoulliElboMatchesGPJax: - """Evaluate the whitened-SVGP ELBO in PTGP and GPJax at a fixed - configuration (hyperparameters, inducing points, q_mu, q_sqrt all - identical) and require the two scalars to match at atol=1e-5. No +class TestSVGPBernoulliElboMatchesReference: + """Evaluate the whitened-SVGP ELBO in PTGP and in the numpy+scipy + reference at a fixed configuration (hyperparameters, inducing points, + q_mu, q_sqrt all identical) and require the two scalars to match. No optimizer — this pins the ELBO math (predict + KL + variational - expectation) against a reference implementation. + expectation) against a first-principles reference. """ - def _fixed_config(self, rng, N=40, M=8): - X = np.sort(rng.uniform(-3, 3, N))[:, None] - y = rng.integers(0, 2, N).astype(np.float64) - Z = np.linspace(-3, 3, M)[:, None] - q_mu = rng.normal(0, 0.3, M) - # Lower-triangular factor with positive diagonal. - L = np.tril(rng.normal(0, 0.2, (M, M))) - L[np.arange(M), np.arange(M)] = np.abs(L[np.arange(M), np.arange(M)]) + 0.5 - return X, y, Z, q_mu, L - def _ptgp_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val): """Evaluate PTGP whitened-SVGP ELBO at the fixed configuration.""" ls = pt.scalar("ls") @@ -132,35 +132,16 @@ def _ptgp_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val): fn = pytensor.function([X_var, y_var, *vp.extra_vars, ls, eta], elbo_expr) return float(fn(X, y, *vp.extra_init, ls_val, eta_val)) - def _gpjax_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val): - """Evaluate GPJax whitened-SVGP ELBO at the same configuration.""" - kernel = gpx.kernels.Matern52( - active_dims=[0], lengthscale=jnp.array(ls_val), variance=jnp.array(eta_val**2) - ) - meanf = gpx.mean_functions.Zero() - prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) - likelihood = gpx.likelihoods.Bernoulli(num_datapoints=X.shape[0]) - posterior = prior * likelihood - # Match ptgp's _DEFAULT_JITTER = 1e-6 so the two libraries' Kzz match. - q = gpx.variational_families.WhitenedVariationalGaussian( - posterior=posterior, - inducing_inputs=jnp.array(Z), - variational_mean=jnp.array(q_mu_val)[:, None], - variational_root_covariance=jnp.array(q_sqrt_val), - jitter=1e-6, - ) - data = gpx.Dataset(X=jnp.array(X), y=jnp.array(y)[:, None]) - return float(gpx.objectives.elbo(q, data)) - def test_elbo_match(self): rng = np.random.default_rng(1) - X, y, Z, q_mu_val, q_sqrt_val = self._fixed_config(rng) + X, Z, q_mu_val, q_sqrt_val = fixed_config(rng, x_range=(-3.0, 3.0)) + y = rng.integers(0, 2, X.shape[0]).astype(np.float64) ls_val, eta_val = 1.3, 0.9 e_ptgp = self._ptgp_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val) - e_gpjax = self._gpjax_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val) + e_ref = reference_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val, _bernoulli_logprob) - np.testing.assert_allclose(e_ptgp, e_gpjax, atol=ATOL) + np.testing.assert_allclose(e_ptgp, e_ref, atol=ATOL) def _count_data(rng, n=80): @@ -233,24 +214,14 @@ def test_loss_decreases_and_rate_correlates(self): assert r > 0.8, f"rate correlation {r:.2f} too low" -class TestSVGPPoissonElboMatchesGPJax: - """Evaluate the whitened-SVGP ELBO in PTGP and GPJax at a fixed - Poisson configuration and require the two scalars to match at - atol=1e-5. Pins the closed-form Poisson variational expectation - plus the rest of the ELBO (predict + KL) against GPJax, independent - of any optimizer. +class TestSVGPPoissonElboMatchesReference: + """Evaluate the whitened-SVGP ELBO in PTGP and in the numpy+scipy + reference at a fixed Poisson configuration and require the two scalars + to match. Pins the closed-form Poisson variational expectation plus the + rest of the ELBO (predict + KL) against a first-principles reference, + independent of any optimizer. """ - def _fixed_config(self, rng, N=40, M=8): - X = np.sort(rng.uniform(-2, 2, N))[:, None] - rate = np.exp(0.5 * X[:, 0] + 0.3) - y = rng.poisson(rate).astype(np.float64) - Z = np.linspace(-2, 2, M)[:, None] - q_mu = rng.normal(0, 0.3, M) - L = np.tril(rng.normal(0, 0.2, (M, M))) - L[np.arange(M), np.arange(M)] = np.abs(L[np.arange(M), np.arange(M)]) + 0.5 - return X, y, Z, q_mu, L - def _ptgp_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val): """Evaluate PTGP whitened-SVGP ELBO at the fixed configuration.""" ls = pt.scalar("ls") @@ -271,35 +242,17 @@ def _ptgp_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val): fn = pytensor.function([X_var, y_var, *vp.extra_vars, ls, eta], elbo_expr) return float(fn(X, y, *vp.extra_init, ls_val, eta_val)) - def _gpjax_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val): - """Evaluate GPJax whitened-SVGP ELBO at the same configuration.""" - kernel = gpx.kernels.Matern52( - active_dims=[0], lengthscale=jnp.array(ls_val), variance=jnp.array(eta_val**2) - ) - meanf = gpx.mean_functions.Zero() - prior = gpx.gps.Prior(mean_function=meanf, kernel=kernel) - likelihood = gpx.likelihoods.Poisson(num_datapoints=X.shape[0]) - posterior = prior * likelihood - # Match ptgp's _DEFAULT_JITTER = 1e-6 so the two libraries' Kzz match. - q = gpx.variational_families.WhitenedVariationalGaussian( - posterior=posterior, - inducing_inputs=jnp.array(Z), - variational_mean=jnp.array(q_mu_val)[:, None], - variational_root_covariance=jnp.array(q_sqrt_val), - jitter=1e-6, - ) - data = gpx.Dataset(X=jnp.array(X), y=jnp.array(y)[:, None]) - return float(gpx.objectives.elbo(q, data)) - def test_elbo_match(self): rng = np.random.default_rng(3) - X, y, Z, q_mu_val, q_sqrt_val = self._fixed_config(rng) + X, Z, q_mu_val, q_sqrt_val = fixed_config(rng) + rate = np.exp(0.5 * X[:, 0] + 0.3) + y = rng.poisson(rate).astype(np.float64) ls_val, eta_val = 1.3, 0.9 e_ptgp = self._ptgp_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val) - e_gpjax = self._gpjax_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val) + e_ref = reference_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val, _poisson_logprob) - np.testing.assert_allclose(e_ptgp, e_gpjax, atol=ATOL) + np.testing.assert_allclose(e_ptgp, e_ref, atol=ATOL) class TestSVGPPointsUnwhitenedRegression: diff --git a/tests/test_svgp_scipy_ref.py b/tests/test_svgp_scipy_ref.py index 3c8cb58..0f12917 100644 --- a/tests/test_svgp_scipy_ref.py +++ b/tests/test_svgp_scipy_ref.py @@ -1,103 +1,26 @@ -"""SVGP tests for likelihoods that GPJax does not implement. - -These tests are the structural equivalent of the GPJax cross-checks in -``tests/test_svgp.py``: evaluate the whitened-SVGP ELBO in PTGP and in -an independent reference at a fixed configuration, and require them to -match at atol=1e-5. The difference is the reference — GPJax doesn't ship -StudentT or NegativeBinomial likelihoods, so the reference ELBO here is -assembled from numpy + scipy (whitened predict by hand, scipy.integrate.quad -for the variational expectation, closed-form whitened KL). - -If we ever want to drop GPJax as a test dependency, the Bernoulli and -Poisson tests in ``test_svgp.py`` could be converted to this same pattern -— the reference machinery below is likelihood-agnostic and only needs a -numpy log-prob function. -""" +"""SVGP tests for likelihoods cross-checked against an independent reference. -import jax +Evaluate the whitened-SVGP ELBO in PTGP and in a numpy+scipy reference at a +fixed configuration, and require them to match at atol=1e-5. The reference +machinery lives in ``tests/_svgp_ref.py``; it is likelihood-agnostic and only +needs a numpy log-prob function, so StudentT and NegativeBinomial (which GPJax +does not ship) are checked the same way as Bernoulli and Poisson. +""" -jax.config.update("jax_enable_x64", True) +import numpy as np +import pymc as pm +import pytensor +import pytensor.tensor as pt -import numpy as np # noqa: E402 -import pymc as pm # noqa: E402 -import pytensor # noqa: E402 -import pytensor.tensor as pt # noqa: E402 -import scipy.linalg # noqa: E402 +from scipy.special import gammaln -from scipy import integrate # noqa: E402 -from scipy.special import gammaln # noqa: E402 +import ptgp as pg -import ptgp as pg # noqa: E402 +from tests._svgp_ref import fixed_config, reference_elbo ATOL = 1e-5 -# ---- Reference ELBO machinery (numpy + scipy, no PTGP internals) ---------- - - -def _matern52_numpy(X1, X2, ls, eta): - """Matern52 kernel in numpy: k(r) = eta^2 (1 + sqrt(5)r + 5r^2/3) exp(-sqrt(5)r).""" - sqd = np.sum(X1**2, axis=-1)[:, None] + np.sum(X2**2, axis=-1)[None, :] - 2.0 * X1 @ X2.T - r = np.sqrt(np.maximum(sqd, 0.0)) / ls - s5 = np.sqrt(5.0) - return eta**2 * (1.0 + s5 * r + 5.0 * r**2 / 3.0) * np.exp(-s5 * r) - - -def _whitened_predict(X, Z, ls, eta, q_mu, q_sqrt): - """Whitened-SVGP marginal predictive mean and variance at X.""" - Kzz = _matern52_numpy(Z, Z, ls, eta) - Kzx = _matern52_numpy(Z, X, ls, eta) - K_diag = np.full(X.shape[0], eta**2) # Matern52 diag is eta^2 - Lz = scipy.linalg.cholesky(Kzz, lower=True) - A = scipy.linalg.solve_triangular(Lz, Kzx, lower=True) # (M, N) - mu_f = A.T @ q_mu - var_f = K_diag - np.sum(A**2, axis=0) + np.sum((A.T @ q_sqrt) ** 2, axis=1) - return mu_f, var_f - - -def _whitened_kl(q_mu, q_sqrt): - """Closed-form KL[N(q_mu, q_sqrt q_sqrt^T) || N(0, I)].""" - M = q_mu.size - logdet_S = 2.0 * np.sum(np.log(np.abs(np.diag(q_sqrt)))) - return 0.5 * (np.sum(q_sqrt**2) + q_mu @ q_mu - M - logdet_S) - - -def _variational_expectation_quad(log_prob_fn, y, mu_f, var_f): - """Sum of per-point E_{q(f_n)}[log p(y_n|f_n)] via adaptive quadrature. - - Integrates in the standard-normal z-scale: f = mu + sqrt(var) * z. The - [-30, 30] z-range is well past the numerical support of exp(-z^2/2). - """ - total = 0.0 - for yn, m, v in zip(y, mu_f, var_f): - sd = np.sqrt(v) - - def integrand(z): - return log_prob_fn(m + sd * z, yn) * np.exp(-0.5 * z**2) / np.sqrt(2.0 * np.pi) - - val, _ = integrate.quad(integrand, -30.0, 30.0) - total += val - return total - - -def _reference_elbo(X, y, Z, q_mu, q_sqrt, ls, eta, log_prob_fn): - """Whitened-SVGP ELBO assembled from numpy + scipy.""" - mu_f, var_f = _whitened_predict(X, Z, ls, eta, q_mu, q_sqrt) - ve = _variational_expectation_quad(log_prob_fn, y, mu_f, var_f) - kl = _whitened_kl(q_mu, q_sqrt) - return ve - kl - - -def _fixed_config(rng, N=40, M=8, x_range=(-2.0, 2.0)): - """Shared fixed configuration for ELBO-match tests.""" - X = np.sort(rng.uniform(x_range[0], x_range[1], N))[:, None] - Z = np.linspace(x_range[0], x_range[1], M)[:, None] - q_mu = rng.normal(0, 0.3, M) - L = np.tril(rng.normal(0, 0.2, (M, M))) - L[np.arange(M), np.arange(M)] = np.abs(L[np.arange(M), np.arange(M)]) + 0.5 - return X, Z, q_mu, L - - # ---- StudentT ------------------------------------------------------------- @@ -218,12 +141,12 @@ def _ptgp_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val, nu_val, sig def test_elbo_match(self): rng = np.random.default_rng(5) - X, Z, q_mu_val, q_sqrt_val = _fixed_config(rng) + X, Z, q_mu_val, q_sqrt_val = fixed_config(rng) y = np.sin(X[:, 0]) + 0.3 * rng.standard_t(5.0, X.shape[0]) ls_val, eta_val, nu_val, sigma_val = 1.3, 0.9, 5.0, 0.3 e_ptgp = self._ptgp_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val, nu_val, sigma_val) - e_ref = _reference_elbo( + e_ref = reference_elbo( X, y, Z, @@ -357,7 +280,7 @@ def _ptgp_elbo(self, X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val, alpha_val): def test_elbo_match(self): rng = np.random.default_rng(7) - X, Z, q_mu_val, q_sqrt_val = _fixed_config(rng) + X, Z, q_mu_val, q_sqrt_val = fixed_config(rng) alpha_val = 2.0 mu_true = np.exp(0.5 * X[:, 0] + 0.3) p = alpha_val / (alpha_val + mu_true) @@ -365,7 +288,7 @@ def test_elbo_match(self): ls_val, eta_val = 1.3, 0.9 e_ptgp = self._ptgp_elbo(X, y, Z, q_mu_val, q_sqrt_val, ls_val, eta_val, alpha_val) - e_ref = _reference_elbo( + e_ref = reference_elbo( X, y, Z,