From 80684371665c1800c1440463a1485e4094db45f8 Mon Sep 17 00:00:00 2001 From: OttoVintola Date: Tue, 9 Jun 2026 16:32:24 +0300 Subject: [PATCH 1/2] loosen python version constraints --- pyproject.toml | 2 +- python/tests/test_sampler.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index b67a595..206ca11 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -9,7 +9,7 @@ authors = [ {name = "Otto Vintola", email="hello@ottovintola.com" }, ] description = "Rust implementation of Bayesian Additive Regression Trees for Probabilistic programming with PyMC" -requires-python = ">=3.12, <3.14" +requires-python = ">=3.12, <3.15" classifiers = [ "Programming Language :: Rust", "Programming Language :: Python :: Implementation :: CPython", diff --git a/python/tests/test_sampler.py b/python/tests/test_sampler.py index 61885e7..59f1e28 100644 --- a/python/tests/test_sampler.py +++ b/python/tests/test_sampler.py @@ -8,7 +8,7 @@ NUM_DRAWS = 600 NUM_CHAINS = 4 BATCH_SIZE = (0.1, 0.1) -NUM_TREES = 50 +NUM_TREES = 10 NUM_PARTICLES = 10 RANDOM_SEED = 42 From 8fe5008f3f55fcc173dd166b042881cd1f90c642 Mon Sep 17 00:00:00 2001 From: OttoVintola Date: Wed, 10 Jun 2026 15:48:21 +0300 Subject: [PATCH 2/2] migrate tests; fix NaN bug --- Cargo.lock | 2 +- src/data.rs | 17 ++++++ src/resampling.rs | 22 ++++++++ src/response.rs | 32 +++++++++++- src/smc.rs | 8 +-- src/splitting.rs | 128 +++++++++++++++++++++++++++++++++++++++++++++- src/tree.rs | 23 +++++---- 7 files changed, 216 insertions(+), 16 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 737bdab..beab318 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -31,7 +31,7 @@ checksum = "ace50bade8e6234aa140d9a2f552bbee1db4d353f69b8217bc503490fc1a9f26" [[package]] name = "bartrs" -version = "0.2.0" +version = "0.3.0" dependencies = [ "criterion", "numpy", diff --git a/src/data.rs b/src/data.rs index d8889e0..132dd2c 100644 --- a/src/data.rs +++ b/src/data.rs @@ -67,3 +67,20 @@ impl OwnedData { } } + + +pub trait NotNan { + fn is_valid(&self) -> bool; +} + +impl NotNan for f64 { + fn is_valid(&self) -> bool { + !self.is_nan() + } +} + +impl NotNan for i32 { + fn is_valid(&self) -> bool { + true + } +} \ No newline at end of file diff --git a/src/resampling.rs b/src/resampling.rs index 6d3c2a4..52af0b8 100644 --- a/src/resampling.rs +++ b/src/resampling.rs @@ -130,3 +130,25 @@ impl ResamplingStrategy for ResamplingStrategies { } } } + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_systematic_resample() { + let mut weights = &[0.0, 0.25, 0.75]; + let mut rng = rand::rng(); + let mut out = Vec::new(); + + ResamplingStrategies::Systematic(SystematicResampling).resample_into(&mut rng, weights, &mut out); + assert!(out.iter().all( |&index| index >= 1 && index < weights.len())); + + weights = &[0.5, 0.3, 0.2]; + out = Vec::new(); + ResamplingStrategies::Systematic(SystematicResampling).resample_into(&mut rng, weights, &mut out); + assert!(out.iter().all( |&index| index >= 0 as usize && index < weights.len())); + } + +} \ No newline at end of file diff --git a/src/response.rs b/src/response.rs index c3311c3..a944f71 100644 --- a/src/response.rs +++ b/src/response.rs @@ -211,6 +211,11 @@ impl ResponseStrategy for LinearStrategy { for &s in node_samples { let idx = s as usize; let v = unsafe { *col.uget(idx) }; + + if v.is_nan() { + continue; + } + if v <= split_val { left_idx.push(idx); } else { @@ -299,7 +304,7 @@ pub enum ResponseStrategies { impl ResponseStrategies { pub fn from_name(name: &str) -> Result { match name.to_lowercase().as_str() { - "gaussian" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)), + "constant" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)), "linear" => Ok(ResponseStrategies::Linear(LinearStrategy)), "motr" => Ok(ResponseStrategies::Motr(MotrStrategy)), _ => Err(format!( @@ -372,6 +377,7 @@ impl ResponseStrategy for ResponseStrategies { mod tests { use super::*; + // 1d-intercept #[test] fn test_fit_linear_1d_intercept() { let x = &[1.0, 2.0, 3.0, 4.0, 5.0]; @@ -389,6 +395,7 @@ mod tests { assert!((a - 0.0).abs() < 1e-6, "Expected intercept ~0.0, got {}", a); } + // 1d-slope #[test] fn test_fit_linear_1d_slope() { let x = &[1.0, 2.0, 3.0, 4.0, 5.0]; @@ -406,5 +413,28 @@ mod tests { assert!((b - 2.0).abs() < 1e-6, "Expected intercept ~2.0, got {}", b); } + + + // 1d-intercept + slope + #[test] + fn test_linear_fit() { + + let x = &[1.0, 2.0, 3.0, 4.0, 5.0]; + let y = &[3.0, 5.0, 7.0, 9.0, 11.0]; + + let noise = 0.0f64; + + let n_trees = 1; + + let Some((intercept, slope)) = LinearStrategy::fit_linear_1d(x, y, noise, n_trees) else { + panic!("Got None when was expecting intercept and slope"); + }; + + let a = intercept; + let b = slope; + + assert!((a - 1.0).abs() < 1e-6, "Expected intercept ~1.0, got {}", a); + assert!((b - 2.0).abs() < 1e-6, "Expected slope ~2.0, got {}", b); + } } \ No newline at end of file diff --git a/src/smc.rs b/src/smc.rs index 81bd46e..7c66ab7 100644 --- a/src/smc.rs +++ b/src/smc.rs @@ -8,7 +8,7 @@ use rand::distr::weighted::WeightedIndex; use rand_distr::{Distribution}; use crate::config::BartConfig; -use crate::data::DataView; +use crate::data::{DataView, NotNan}; use crate::particle::{Particle}; use crate::resampling::ResamplingStrategy; use crate::splitting::SplitRules; @@ -95,7 +95,7 @@ where MutationDecision::Reject => { particle.pop_next_expandable(); } - } + } } } @@ -190,7 +190,9 @@ fn propose_mutation( let col = data.x.column(split_var); let feature_values = node_samples .iter() - .map(|&s| unsafe { *col.uget(s as usize) }); + .map(|&s| unsafe { *col.uget(s as usize) }) + .filter( |&v| v.is_valid()); + let split_strategy = &split_rules[split_var]; let split_val = match split_strategy.sample_split_value(rng, feature_values) { diff --git a/src/splitting.rs b/src/splitting.rs index 12d0b32..f77db6e 100644 --- a/src/splitting.rs +++ b/src/splitting.rs @@ -113,7 +113,7 @@ impl SplitRule for OneHotSplit { where I: Iterator, { - data_indices.partition(|&idx| (data[[idx, feature_idx]] as i32) == threshold) + data_indices.partition(|&idx| (data[[idx, feature_idx]] as i32) <= threshold) } } @@ -170,3 +170,129 @@ impl SplitRules { } } } + + +#[cfg(test)] +mod tests { + use super::*; + use numpy::ndarray::array; + use rand::{rngs::StdRng, SeedableRng}; + + fn assert_partition_consistent( + data: &Array, + feature_idx: usize, + threshold: f64, + left: &[usize], + right: &[usize], + ) { + assert_eq!(left.len() + right.len(), data.nrows()); + + for &idx in left { + assert!(data[[idx, feature_idx]] <= threshold); + } + + for &idx in right { + assert!(data[[idx, feature_idx]] > threshold); + } + } + + #[test] + fn test_continuous_split_rule() { + let rule = ContinuousSplit; + + let mut rng = StdRng::seed_from_u64(42); + assert_eq!(rule.sample_split_value(&mut rng, vec![0.0].into_iter()), None); + + let available_values: Vec = (0..10).map(|x| x as f64).collect(); + let sv = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .expect("expected a split value"); + + let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0]]; + + let (left, right) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + + assert_partition_consistent(&data, 0, sv, &left, &right); + + let (left_repeated, right_repeated) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + assert_eq!(left, left_repeated); + assert_eq!(right, right_repeated); + + let probs = (0..10_000) + .map(|_| { + let split_value = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .unwrap(); + let (left, _) = rule.split_data_indices(&data, 0, split_value, 0..data.nrows()); + let mut mask = vec![false; data.nrows()]; + for idx in left { + mask[idx] = true; + } + mask + }) + .fold(vec![0usize; data.nrows()], |mut acc, mask| { + for (i, b) in mask.into_iter().enumerate() { + if b { + acc[i] += 1; + } + } + acc + }) + .into_iter() + .map(|count| count as f64 / 10_000.0) + .collect::>(); + + assert!(probs.iter().filter(|&&p| p > 0.01).count() >= data.nrows() - 1); + assert!(probs.iter().filter(|&&p| p < 0.99).count() >= data.nrows() - 1); + } + + #[test] + fn test_one_hot_split_rule() { + let rule = OneHotSplit; + + let mut rng = StdRng::seed_from_u64(42); + assert_eq!(rule.sample_split_value(&mut rng, vec![0].into_iter()), None); + + let available_values: Vec = (0..10).collect(); + let sv = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .expect("expected a split value"); + + let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0]]; + + let (left, right) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + + assert_partition_consistent(&data, 0, sv as f64, &left, &right); + + let (left_repeated, right_repeated) = rule.split_data_indices(&data, 0, sv, 0..data.nrows()); + assert_eq!(left, left_repeated); + assert_eq!(right, right_repeated); + + let probs = (0..10_000) + .map(|_| { + let split_value = rule + .sample_split_value(&mut rng, available_values.clone().into_iter()) + .unwrap(); + let (left, _) = rule.split_data_indices(&data, 0, split_value, 0..data.nrows()); + let mut mask = vec![false; data.nrows()]; + for idx in left { + mask[idx] = true; + } + mask + }) + .fold(vec![0usize; data.nrows()], |mut acc, mask| { + for (i, b) in mask.into_iter().enumerate() { + if b { + acc[i] += 1; + } + } + acc + }) + .into_iter() + .map(|count| count as f64 / 10_000.0) + .collect::>(); + + assert!(probs.iter().filter(|&&p| p > 0.01).count() >= data.nrows() - 1); + assert!(probs.iter().filter(|&&p| p < 0.99).count() >= data.nrows() - 1); + } +} \ No newline at end of file diff --git a/src/tree.rs b/src/tree.rs index 9e546cb..a9e9072 100644 --- a/src/tree.rs +++ b/src/tree.rs @@ -4,14 +4,7 @@ use pyo3::prelude::*; use pyo3::types::{PyDict, PyList}; use crate::response::{LeafKind, LeafPayload, LeafProposal}; - -// 1. DONE: remove pgbart.py from pymc-bart and use this sampler -// 2. DONE: linear terms -// 3. CURRENT: logp for samples that only affect the tree -// 4. monotonic response -// 5. Hawks example with separate BARTRVs -// 6. Reseaaaaarch - +use crate::data::NotNan; /// Bartz-style heap-indexed tree with separate internal/leaf arrays. /// @@ -387,7 +380,13 @@ impl TreeArrays { let intercept = self.linear_intercept[param_idx].get(out_idx).copied().unwrap_or(0.0); let slope = self.linear_slope[param_idx].get(out_idx).copied().unwrap_or(0.0); let mut contrib = data.column(var).to_owned(); - contrib.mapv_inplace(|x| intercept + slope * x); + contrib.mapv_inplace(|x| { + if x.is_valid() { + intercept + slope * x + } else { + 0.0 + } + }); row += &(weights.clone() * contrib); } } @@ -422,7 +421,11 @@ impl TreeArrays { for out_idx in 0..n_outputs { let intercept = self.linear_intercept[param_idx].get(out_idx).copied().unwrap_or(0.0); let slope = self.linear_slope[param_idx].get(out_idx).copied().unwrap_or(0.0); - out[[out_idx, sample_idx]] = intercept + slope * x; + out[[out_idx, sample_idx]] = if x.is_nan() { + intercept + } else { + intercept + slope * x + }; } } }