Skip to content

Remove GPJax dependency from the test suite#37

Open
bwengals wants to merge 2 commits into
mainfrom
remove-gpjax-test-dep
Open

Remove GPJax dependency from the test suite#37
bwengals wants to merge 2 commits into
mainfrom
remove-gpjax-test-dep

Conversation

@bwengals

@bwengals bwengals commented Jun 11, 2026

Copy link
Copy Markdown
Collaborator

Closes #36.

Replaces the GPJax reference comparisons with self-contained correctness checks, so the suite no longer breaks on GPJax API changes (the reason #35 had to pin gpjax==0.13.6).

What changed

  • Kernels (test_stationary.py): compared against closed-form analytic references in NumPy.
  • Gaussian / Bernoulli likelihoods: closed-form VE, and an independent NumPy Gauss-Hermite quadrature.
  • SVGP Bernoulli/Poisson ELBO (test_svgp.py): reuse the numpy+scipy reference that already backs the StudentT/NegBinom tests, factored into tests/_svgp_ref.py.
  • Kernel fix (stationary.py): Matern52/32 used pt.sqrt(5.0), which autocasts the literal to float32 regardless of floatX. Typed via pt.constant(..., dtype=floatX) so it follows the backend; Matern now matches the reference to ~3e-15. Other kernel literals are float32-exact, so no other changes.
  • Env: drop gpjax (plus now-unused jax/jaxlib/optax).

Test plan

  • pytest tests/ — 265 passed
  • pre-commit run --all-files clean
  • scripts/run_mypy.py — 37/37 pass

bwengals added 2 commits June 11, 2026 00:37
Replace the GPJax reference comparisons with self-contained correctness
checks so the suite no longer breaks on GPJax API changes (e.g. the 0.15
gram-operator change that required pinning gpjax==0.13.6).

- Kernels: compare ExpQuad/Matern52/32 against closed-form analytic
  references in numpy. Matern uses a looser tolerance to absorb the
  float32 sqrt(5)/sqrt(3) constant pytensor evaluates in floatX.
- Gaussian likelihood: closed-form variational expectation.
- Bernoulli likelihood: independent numpy Gauss-Hermite quadrature.
- SVGP Bernoulli/Poisson ELBO: reuse the numpy+scipy reference machinery
  (factored into tests/_svgp_ref.py and shared with the StudentT/NegBinom
  tests), as the scipy-ref module already anticipated.
- Drop gpjax from conda_envs/environment-test.yaml.

Closes #36.
… env

- Matern52/32 used pt.sqrt(5.0)/pt.sqrt(3.0), which autocasts the literal to
  float32 regardless of config.floatX, injecting ~1e-7 error into an otherwise
  float64 kernel. Use pt.constant(..., dtype=config.floatX) so the constant
  follows the backend's precision; Matern now matches the analytic reference to
  ~3e-15 in float64. Tighten the kernel test tolerance accordingly.
- ptgp, and the test suite after the GPJax removal, never import jax/optax, so
  drop jax/jaxlib/optax from conda_envs/environment-test.yaml.
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

Remove GPJax dependency from the test suite

1 participant