Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ authors = [
{name = "Otto Vintola", email="hello@ottovintola.com" },
]
description = "Rust implementation of Bayesian Additive Regression Trees for Probabilistic programming with PyMC"
requires-python = ">=3.12, <3.14"
requires-python = ">=3.12, <3.15"
classifiers = [
"Programming Language :: Rust",
"Programming Language :: Python :: Implementation :: CPython",
Expand Down
2 changes: 1 addition & 1 deletion python/tests/test_sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
NUM_DRAWS = 600
NUM_CHAINS = 4
BATCH_SIZE = (0.1, 0.1)
NUM_TREES = 50
NUM_TREES = 10
NUM_PARTICLES = 10
RANDOM_SEED = 42

Expand Down
17 changes: 17 additions & 0 deletions src/data.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,3 +67,20 @@ impl OwnedData {
}

}


/// Validity check for values that may or may not be able to represent NaN.
pub trait NotNan {
    /// Returns `true` when the value is not NaN.
    fn is_valid(&self) -> bool;
}

impl NotNan for f64 {
    /// A float is valid exactly when it is not NaN; infinities count as valid.
    fn is_valid(&self) -> bool {
        !f64::is_nan(*self)
    }
}

impl NotNan for i32 {
    /// Integers have no NaN representation, so every `i32` is valid.
    fn is_valid(&self) -> bool {
        true
    }
}
22 changes: 22 additions & 0 deletions src/resampling.rs
Original file line number Diff line number Diff line change
Expand Up @@ -130,3 +130,25 @@ impl ResamplingStrategy for ResamplingStrategies {
}
}
}


#[cfg(test)]
mod tests {
    use super::*;

    /// Systematic resampling must return only in-bounds indices and must never
    /// select a particle whose weight is zero.
    #[test]
    fn test_systematic_resample() {
        let mut rng = rand::rng();

        // Index 0 carries zero weight, so only indices 1 and 2 may be drawn.
        let weights = &[0.0, 0.25, 0.75];
        let mut out = Vec::new();
        ResamplingStrategies::Systematic(SystematicResampling)
            .resample_into(&mut rng, weights, &mut out);
        // Guard against a vacuous pass: `all` is trivially true on an empty vector.
        assert!(!out.is_empty(), "resampling produced no indices");
        assert!(out.iter().all(|&index| (1..weights.len()).contains(&index)));

        // Every weight is positive here, so any in-bounds index is acceptable.
        // (A lower-bound check against 0 would be a tautology for `usize`.)
        let weights = &[0.5, 0.3, 0.2];
        let mut out = Vec::new();
        ResamplingStrategies::Systematic(SystematicResampling)
            .resample_into(&mut rng, weights, &mut out);
        assert!(!out.is_empty(), "resampling produced no indices");
        assert!(out.iter().all(|&index| index < weights.len()));
    }
}
32 changes: 31 additions & 1 deletion src/response.rs
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,11 @@ impl ResponseStrategy for LinearStrategy {
for &s in node_samples {
let idx = s as usize;
let v = unsafe { *col.uget(idx) };

if v.is_nan() {
continue;
}

if v <= split_val {
left_idx.push(idx);
} else {
Expand Down Expand Up @@ -299,7 +304,7 @@ pub enum ResponseStrategies {
impl ResponseStrategies {
pub fn from_name(name: &str) -> Result<Self, String> {
match name.to_lowercase().as_str() {
"gaussian" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)),
"constant" => Ok(ResponseStrategies::Gaussian(GaussianResponseStrategy)),
"linear" => Ok(ResponseStrategies::Linear(LinearStrategy)),
"motr" => Ok(ResponseStrategies::Motr(MotrStrategy)),
_ => Err(format!(
Expand Down Expand Up @@ -372,6 +377,7 @@ impl ResponseStrategy for ResponseStrategies {
mod tests {
use super::*;

// 1d-intercept
#[test]
fn test_fit_linear_1d_intercept() {
let x = &[1.0, 2.0, 3.0, 4.0, 5.0];
Expand All @@ -389,6 +395,7 @@ mod tests {
assert!((a - 0.0).abs() < 1e-6, "Expected intercept ~0.0, got {}", a);
}

// 1d-slope
#[test]
fn test_fit_linear_1d_slope() {
let x = &[1.0, 2.0, 3.0, 4.0, 5.0];
Expand All @@ -406,5 +413,28 @@ mod tests {
assert!((b - 2.0).abs() < 1e-6, "Expected intercept ~2.0, got {}", b);

}


/// 1d fit checking intercept and slope together: the points lie exactly on
/// `y = 1 + 2x` and `noise` is set to 0.0.
#[test]
fn test_linear_fit() {
    let xs = &[1.0, 2.0, 3.0, 4.0, 5.0];
    let ys = &[3.0, 5.0, 7.0, 9.0, 11.0];
    let noise = 0.0f64;
    let n_trees = 1;

    // `None` here means the fit failed; the test treats that as a hard error.
    let (intercept, slope) = LinearStrategy::fit_linear_1d(xs, ys, noise, n_trees)
        .expect("Got None when was expecting intercept and slope");

    assert!(
        (intercept - 1.0).abs() < 1e-6,
        "Expected intercept ~1.0, got {}",
        intercept
    );
    assert!(
        (slope - 2.0).abs() < 1e-6,
        "Expected slope ~2.0, got {}",
        slope
    );
}

}
8 changes: 5 additions & 3 deletions src/smc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@ use rand::distr::weighted::WeightedIndex;
use rand_distr::{Distribution};

use crate::config::BartConfig;
use crate::data::DataView;
use crate::data::{DataView, NotNan};
use crate::particle::{Particle};
use crate::resampling::ResamplingStrategy;
use crate::splitting::SplitRules;
Expand Down Expand Up @@ -95,7 +95,7 @@ where
MutationDecision::Reject => {
particle.pop_next_expandable();
}
}
}
}
}

Expand Down Expand Up @@ -190,7 +190,9 @@ fn propose_mutation(
let col = data.x.column(split_var);
let feature_values = node_samples
.iter()
.map(|&s| unsafe { *col.uget(s as usize) });
.map(|&s| unsafe { *col.uget(s as usize) })
.filter( |&v| v.is_valid());


let split_strategy = &split_rules[split_var];
let split_val = match split_strategy.sample_split_value(rng, feature_values) {
Expand Down
128 changes: 127 additions & 1 deletion src/splitting.rs
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@ impl SplitRule for OneHotSplit {
where
I: Iterator<Item = usize>,
{
data_indices.partition(|&idx| (data[[idx, feature_idx]] as i32) == threshold)
data_indices.partition(|&idx| (data[[idx, feature_idx]] as i32) <= threshold)
}
}

Expand Down Expand Up @@ -170,3 +170,129 @@ impl SplitRules {
}
}
}


#[cfg(test)]
mod tests {
    use super::*;
    use numpy::ndarray::array;
    use rand::{rngs::StdRng, SeedableRng};

    /// Number of random split values drawn when estimating left-branch frequencies.
    const N_DRAWS: usize = 10_000;

    /// Asserts that `left` and `right` together cover every row of `data`, that
    /// every `left` row has a feature value `<= threshold`, and that every
    /// `right` row has a feature value `> threshold`.
    fn assert_partition_consistent(
        data: &Array<f64, Ix2>,
        feature_idx: usize,
        threshold: f64,
        left: &[usize],
        right: &[usize],
    ) {
        assert_eq!(left.len() + right.len(), data.nrows());

        for &idx in left {
            assert!(data[[idx, feature_idx]] <= threshold);
        }

        for &idx in right {
            assert!(data[[idx, feature_idx]] > threshold);
        }
    }

    /// Estimates, for each row, the fraction of draws in which it lands on the left.
    ///
    /// `draw_left` is called `n_draws` times and must return the left-hand row
    /// indices produced by one freshly sampled split. Each index must be
    /// `< n_rows` and must appear at most once per draw (true for a partition of
    /// `0..n_rows`), so a row is counted at most once per draw.
    fn left_frequencies(
        n_rows: usize,
        n_draws: usize,
        mut draw_left: impl FnMut() -> Vec<usize>,
    ) -> Vec<f64> {
        let mut counts = vec![0usize; n_rows];
        for _ in 0..n_draws {
            for idx in draw_left() {
                counts[idx] += 1;
            }
        }
        counts
            .into_iter()
            .map(|count| count as f64 / n_draws as f64)
            .collect()
    }

    /// Asserts that all rows but at most one go left in more than 1% of draws,
    /// and that all rows but at most one go left in fewer than 99% of draws.
    fn assert_frequencies_spread(probs: &[f64]) {
        let min_rows = probs.len() - 1;
        assert!(probs.iter().filter(|&&p| p > 0.01).count() >= min_rows);
        assert!(probs.iter().filter(|&&p| p < 0.99).count() >= min_rows);
    }

    /// Exercises `ContinuousSplit`: sampling, partition consistency,
    /// repeatability of the partition, and spread of the sampled splits.
    #[test]
    fn test_continuous_split_rule() {
        let rule = ContinuousSplit;
        let mut rng = StdRng::seed_from_u64(42);

        // With a single candidate value no split value is expected.
        assert_eq!(rule.sample_split_value(&mut rng, vec![0.0].into_iter()), None);

        let available_values: Vec<f64> = (0..10).map(|x| x as f64).collect();
        let sv = rule
            .sample_split_value(&mut rng, available_values.clone().into_iter())
            .expect("expected a split value");

        let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0]];

        let (left, right) = rule.split_data_indices(&data, 0, sv, 0..data.nrows());
        assert_partition_consistent(&data, 0, sv, &left, &right);

        // Splitting again with the same value must give the same partition.
        let (left_repeated, right_repeated) =
            rule.split_data_indices(&data, 0, sv, 0..data.nrows());
        assert_eq!(left, left_repeated);
        assert_eq!(right, right_repeated);

        let probs = left_frequencies(data.nrows(), N_DRAWS, || {
            let split_value = rule
                .sample_split_value(&mut rng, available_values.clone().into_iter())
                .unwrap();
            let (left, _) = rule.split_data_indices(&data, 0, split_value, 0..data.nrows());
            left
        });

        assert_frequencies_spread(&probs);
    }

    /// Exercises `OneHotSplit` with the same checks as the continuous rule,
    /// using integer candidate values.
    #[test]
    fn test_one_hot_split_rule() {
        let rule = OneHotSplit;
        let mut rng = StdRng::seed_from_u64(42);

        // With a single candidate value no split value is expected.
        assert_eq!(rule.sample_split_value(&mut rng, vec![0].into_iter()), None);

        let available_values: Vec<i32> = (0..10).collect();
        let sv = rule
            .sample_split_value(&mut rng, available_values.clone().into_iter())
            .expect("expected a split value");

        let data = array![[0.0], [1.0], [2.0], [3.0], [4.0], [5.0], [6.0], [7.0], [8.0], [9.0]];

        let (left, right) = rule.split_data_indices(&data, 0, sv, 0..data.nrows());
        assert_partition_consistent(&data, 0, sv as f64, &left, &right);

        // Splitting again with the same value must give the same partition.
        let (left_repeated, right_repeated) =
            rule.split_data_indices(&data, 0, sv, 0..data.nrows());
        assert_eq!(left, left_repeated);
        assert_eq!(right, right_repeated);

        let probs = left_frequencies(data.nrows(), N_DRAWS, || {
            let split_value = rule
                .sample_split_value(&mut rng, available_values.clone().into_iter())
                .unwrap();
            let (left, _) = rule.split_data_indices(&data, 0, split_value, 0..data.nrows());
            left
        });

        assert_frequencies_spread(&probs);
    }
}
23 changes: 13 additions & 10 deletions src/tree.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,7 @@ use pyo3::prelude::*;
use pyo3::types::{PyDict, PyList};

use crate::response::{LeafKind, LeafPayload, LeafProposal};

// 1. DONE: remove pgbart.py from pymc-bart and use this sampler
// 2. DONE: linear terms
// 3. CURRENT: logp for samples that only affect the tree
// 4. monotonic response
// 5. Hawks example with separate BARTRVs
// 6. Reseaaaaarch

use crate::data::NotNan;

/// Bartz-style heap-indexed tree with separate internal/leaf arrays.
///
Expand Down Expand Up @@ -387,7 +380,13 @@ impl TreeArrays {
let intercept = self.linear_intercept[param_idx].get(out_idx).copied().unwrap_or(0.0);
let slope = self.linear_slope[param_idx].get(out_idx).copied().unwrap_or(0.0);
let mut contrib = data.column(var).to_owned();
contrib.mapv_inplace(|x| intercept + slope * x);
contrib.mapv_inplace(|x| {
if x.is_valid() {
intercept + slope * x
} else {
0.0
}
});
row += &(weights.clone() * contrib);
}
}
Expand Down Expand Up @@ -422,7 +421,11 @@ impl TreeArrays {
for out_idx in 0..n_outputs {
let intercept = self.linear_intercept[param_idx].get(out_idx).copied().unwrap_or(0.0);
let slope = self.linear_slope[param_idx].get(out_idx).copied().unwrap_or(0.0);
out[[out_idx, sample_idx]] = intercept + slope * x;
out[[out_idx, sample_idx]] = if x.is_nan() {
intercept
} else {
intercept + slope * x
};
}
}
}
Expand Down