From d894f8fd0bbe77168567e91041eb92b9f5857491 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Mon, 4 May 2026 12:07:44 -0500 Subject: [PATCH 1/3] Standardize the tau(X) term for BCF when a tau_0 treatment intercept term is included in the model --- R/bcf.R | 52 +++++++++++++------------ R/posterior_transformation.R | 11 +++++- man/ForestKernelComputation.Rd | 2 +- man/computeBCFPosteriorInterval.Rd | 11 +++++- man/predict.bcfmodel.Rd | 13 ++++++- stochtree/bcf.py | 59 ++++++++++++++++++----------- test/R/testthat/test-bcf.R | 61 ++++++++++++++++++++++++++++++ test/python/test_predict.py | 58 ++++++++++++++++++++++++++++ tools/debug/gh-376.py | 46 ++++++++++++++++++++++ 9 files changed, 262 insertions(+), 51 deletions(-) create mode 100644 tools/debug/gh-376.py diff --git a/R/bcf.R b/R/bcf.R index a00b11ec..0cb49e3d 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -3076,13 +3076,12 @@ bcf <- function( tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) * y_std_train } - # tau_hat_train stores the forest-only component tau(X); compute cate_train - # (tau_0 + tau(X)) separately for the treatment term used in y_hat + # Fold tau_0 into tau_hat_train so it holds the full CATE (tau_0 + tau(X)) if (sample_tau_0) { tau_0_vec <- as.numeric(tau_0_samples) # num_retained_samples vector (scalar treatment) if (adaptive_coding) { # CATE = (b_1 - b_0) * (tau_0 + tau(X)); control adj to mu = b_0 * (tau_0 + tau(X)) - cate_train <- sweep( + tau_hat_train <- sweep( tau_hat_train, 2, (b_1_samples - b_0_samples) * tau_0_vec * y_std_train, @@ -3095,20 +3094,17 @@ bcf <- function( "+" ) } else if (!has_multivariate_treatment) { - cate_train <- sweep(tau_hat_train, 2, tau_0_vec * y_std_train, "+") + tau_hat_train <- sweep(tau_hat_train, 2, tau_0_vec * y_std_train, "+") } else { # tau_hat_train: n x p x num_retained_samples; tau_0_samples: p x num_retained_samples - cate_train <- tau_hat_train for (j in seq_len(p_tau0)) { - cate_train[, j, ] <- cate_train[, j, ] + + tau_hat_train[, j, ] <- tau_hat_train[, j, ] + outer(rep(1, nrow(X_train)), tau_0_samples[j, ] * y_std_train) } } - } else { - cate_train <- tau_hat_train } if (has_multivariate_treatment) { - tau_train_dim <- dim(cate_train) + tau_train_dim <- dim(tau_hat_train) tau_num_obs <- tau_train_dim[1] tau_num_samples <- tau_train_dim[3] treatment_term_train <- matrix( @@ -3118,11 +3114,11 @@ bcf <- function( ) for (i in 1:nrow(Z_train)) { treatment_term_train[i, ] <- colSums( - cate_train[i, , ] * Z_train[i, ] + tau_hat_train[i, , ] * Z_train[i, ] ) } } else { - treatment_term_train <- cate_train * as.numeric(Z_train) + treatment_term_train <- tau_hat_train * as.numeric(Z_train) } y_hat_train <- mu_hat_train + treatment_term_train if (has_test) { @@ -3145,10 +3141,10 @@ bcf <- function( ) * y_std_train } - # tau_hat_test stores forest-only tau(X); compute cate_test for y_hat + # Fold tau_0 into tau_hat_test so it holds the full CATE (tau_0 + tau(X)) if (sample_tau_0) { if (adaptive_coding) { - cate_test <- sweep( + tau_hat_test <- sweep( tau_hat_test, 2, (b_1_samples - b_0_samples) * tau_0_vec * y_std_train, @@ -3161,19 +3157,16 @@ bcf <- function( "+" ) } else if (!has_multivariate_treatment) { - cate_test <- sweep(tau_hat_test, 2, tau_0_vec * y_std_train, "+") + tau_hat_test <- sweep(tau_hat_test, 2, tau_0_vec * y_std_train, "+") } else { - cate_test <- tau_hat_test for (j in seq_len(p_tau0)) { - cate_test[, j, ] <- cate_test[, j, ] + + tau_hat_test[, j, ] <- tau_hat_test[, j, ] + outer(rep(1, nrow(X_test)), tau_0_samples[j, ] * y_std_train) } } - } else { - cate_test <- tau_hat_test } if (has_multivariate_treatment) { - tau_test_dim <- dim(cate_test) + tau_test_dim <- dim(tau_hat_test) tau_num_obs <- tau_test_dim[1] tau_num_samples <- tau_test_dim[3] treatment_term_test <- matrix( @@ -3183,11 +3176,11 @@ bcf <- function( ) for (i in 1:nrow(Z_test)) { treatment_term_test[i, ] <- colSums( - cate_test[i, , ] * Z_test[i, ] + tau_hat_test[i, , ] * Z_test[i, ] ) } } else { - treatment_term_test <- cate_test * as.numeric(Z_test) + treatment_term_test <- tau_hat_test * as.numeric(Z_test) } y_hat_test <- mu_hat_test + treatment_term_test } @@ -3390,7 +3383,18 @@ bcf <- function( #' that were not in the training set. #' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used. #' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". -#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". +#' @param terms (Optional) Which model terms to include in the prediction. Options include `"y_hat"`, `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"rfx"`, `"variance_forest"`, or `"all"`. +#' +#' The treatment effect terms follow a three-level hierarchy: +#' \itemize{ +#' \item `"tau"` returns `tau_0 + tau(X)`: the parametric treatment intercept (if sampled) plus the treatment forest. This matches `model$tau_hat_train` / `model$tau_hat_test`. +#' \item `"cate"` additionally folds in the random slope on treatment when random effects are fit with `rfx_model_spec = "intercept_plus_treatment"`; otherwise it is identical to `"tau"`. +#' \item The raw forest-only component (without `tau_0`) is not directly returned by this method; use `model$forests_tau` to access it. +#' } +#' +#' Similarly for the prognostic term: `"mu"` returns the prognostic forest only, while `"prognostic_function"` additionally folds in the random intercept when `rfx_model_spec` is `"intercept_only"` or `"intercept_plus_treatment"`; otherwise the two are identical. +#' +#' If a model doesn't have random effects or variance forest predictions but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present, this function will return `NULL` along with a warning. Default: `"all"`. #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param ... (Optional) Other prediction parameters. #' @@ -3875,7 +3879,7 @@ predict.bcfmodel <- function( mu_hat <- pnorm(mu_hat_forest) } if (predict_tau_forest) { - tau_hat <- pnorm(tau_hat_forest) + tau_hat <- pnorm(cate_hat_forest) } if (predict_prog_function) { prognostic_function <- pnorm(prognostic_function) @@ -3897,7 +3901,7 @@ predict.bcfmodel <- function( mu_hat <- mu_hat_forest } if (predict_tau_forest) { - tau_hat <- tau_hat_forest + tau_hat <- cate_hat_forest } if (predict_prog_function) { prognostic_function <- prognostic_function diff --git a/R/posterior_transformation.R b/R/posterior_transformation.R index e3776393..8f4ab8c8 100644 --- a/R/posterior_transformation.R +++ b/R/posterior_transformation.R @@ -969,7 +969,16 @@ posterior_predictive_heuristic_multiplier <- function( #' Compute posterior credible intervals for specified terms from a fitted BCF model. Supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions. #' #' @param model_object A fitted BCF model object of class `bcfmodel`. -#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. +#' @param terms A character string specifying the model term(s) for which to compute intervals. Options are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. +#' +#' The treatment effect terms follow a three-level hierarchy: +#' \itemize{ +#' \item `"tau"` returns `tau_0 + tau(X)`: the parametric treatment intercept (if sampled) plus the treatment forest. This matches `model$tau_hat_train` / `model$tau_hat_test`. +#' \item `"cate"` additionally folds in the random slope on treatment when random effects are fit with `rfx_model_spec = "intercept_plus_treatment"`; otherwise it is identical to `"tau"`. +#' \item The raw forest-only component (without `tau_0`) is not directly returned by this method; use `model$forests_tau` to access it. +#' } +#' +#' Similarly for the prognostic term: `"mu"` returns the prognostic forest only, while `"prognostic_function"` additionally folds in the random intercept when `rfx_model_spec` is `"intercept_only"` or `"intercept_plus_treatment"`; otherwise the two are identical. #' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval). #' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear". #' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions). diff --git a/man/ForestKernelComputation.Rd b/man/ForestKernelComputation.Rd index d834495d..f41543fd 100644 --- a/man/ForestKernelComputation.Rd +++ b/man/ForestKernelComputation.Rd @@ -66,7 +66,7 @@ This function group offers utilities for evaluating this kernel. \code{computeForestLeafIndices} computes and return a vector representation of a forest's leaf predictions for every observation in a dataset. -The resulting vector has a "row-major" format that can be easily re-represented as +The resulting vector has a "tree-major" format that can be easily re-represented as as a CSR sparse matrix: elements are organized so that the first \code{n} elements correspond to leaf predictions for all \code{n} observations in a dataset for the first tree in an ensemble, the next \code{n} elements correspond to predictions for diff --git a/man/computeBCFPosteriorInterval.Rd b/man/computeBCFPosteriorInterval.Rd index 1086e28c..9c95c431 100644 --- a/man/computeBCFPosteriorInterval.Rd +++ b/man/computeBCFPosteriorInterval.Rd @@ -19,7 +19,16 @@ computeBCFPosteriorInterval( \arguments{ \item{model_object}{A fitted BCF model object of class \code{bcfmodel}.} -\item{terms}{A character string specifying the model term(s) for which to compute intervals. Options for BCF models are \code{"prognostic_function"}, \code{"mu"}, \code{"cate"}, \code{"tau"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}. Note that \code{"mu"} is only different from \code{"prognostic_function"} if random effects are included with a model spec of \code{"intercept_only"} or \code{"intercept_plus_treatment"} and \code{"tau"} is only different from \code{"cate"} if random effects are included with a model spec of \code{"intercept_plus_treatment"}.} +\item{terms}{A character string specifying the model term(s) for which to compute intervals. Options are \code{"prognostic_function"}, \code{"mu"}, \code{"cate"}, \code{"tau"}, \code{"variance_forest"}, \code{"rfx"}, or \code{"y_hat"}. + +The treatment effect terms follow a three-level hierarchy: +\itemize{ +\item \code{"tau"} returns \code{tau_0 + tau(X)}: the parametric treatment intercept (if sampled) plus the treatment forest. This matches \code{model$tau_hat_train} / \code{model$tau_hat_test}. +\item \code{"cate"} additionally folds in the random slope on treatment when random effects are fit with \code{rfx_model_spec = "intercept_plus_treatment"}; otherwise it is identical to \code{"tau"}. +\item The raw forest-only component (without \code{tau_0}) is not directly returned by this method; use \code{model$forests_tau} to access it. +} + +Similarly for the prognostic term: \code{"mu"} returns the prognostic forest only, while \code{"prognostic_function"} additionally folds in the random intercept when \code{rfx_model_spec} is \code{"intercept_only"} or \code{"intercept_plus_treatment"}; otherwise the two are identical.} \item{level}{A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95\% credible interval).} diff --git a/man/predict.bcfmodel.Rd b/man/predict.bcfmodel.Rd index f2d62ae2..0ea40f4d 100644 --- a/man/predict.bcfmodel.Rd +++ b/man/predict.bcfmodel.Rd @@ -34,7 +34,18 @@ that were not in the training set.} \item{type}{(Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".} -\item{terms}{(Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is request, the request will simply be ignored. If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return \code{NULL} along with a warning. Default: "all".} +\item{terms}{(Optional) Which model terms to include in the prediction. Options include \code{"y_hat"}, \code{"prognostic_function"}, \code{"mu"}, \code{"cate"}, \code{"tau"}, \code{"rfx"}, \code{"variance_forest"}, or \code{"all"}. + +The treatment effect terms follow a three-level hierarchy: +\itemize{ +\item \code{"tau"} returns \code{tau_0 + tau(X)}: the parametric treatment intercept (if sampled) plus the treatment forest. This matches \code{model$tau_hat_train} / \code{model$tau_hat_test}. +\item \code{"cate"} additionally folds in the random slope on treatment when random effects are fit with \code{rfx_model_spec = "intercept_plus_treatment"}; otherwise it is identical to \code{"tau"}. +\item The raw forest-only component (without \code{tau_0}) is not directly returned by this method; use \code{model$forests_tau} to access it. +} + +Similarly for the prognostic term: \code{"mu"} returns the prognostic forest only, while \code{"prognostic_function"} additionally folds in the random intercept when \code{rfx_model_spec} is \code{"intercept_only"} or \code{"intercept_plus_treatment"}; otherwise the two are identical. + +If a model doesn't have random effects or variance forest predictions but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present, this function will return \code{NULL} along with a warning. Default: \code{"all"}.} \item{scale}{(Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing \code{y == 1}. "probability" is only valid for models fit with a probit outcome model. Default: "linear".} diff --git a/stochtree/bcf.py b/stochtree/bcf.py index ac04bda9..d1108927 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -2973,34 +2973,30 @@ def sample( self.tau_hat_train = self.tau_hat_train * adaptive_coding_weights self.mu_hat_train = self.mu_hat_train + np.squeeze(control_adj_train) self.tau_hat_train = np.squeeze(self.tau_hat_train * self.y_std) - # tau_hat_train stores the forest-only component tau(X); compute cate_train - # (tau_0 + tau(X)) separately for the treatment term used in y_hat + # Fold tau_0 into tau_hat_train so the attribute holds the full CATE (tau_0 + tau(X)) if self.sample_tau_0: tau_0_vec = self.tau_0_samples[0, :] # num_samples vector (scalar treatment) if self.adaptive_coding: # CATE = (b_1 - b_0) * (tau_0 + tau(X)); control adj to mu = b_0 * (tau_0 + tau(X)) - cate_train = self.tau_hat_train + ( + self.tau_hat_train = self.tau_hat_train + ( (self.b1_samples - self.b0_samples) * tau_0_vec * self.y_std ) self.mu_hat_train = self.mu_hat_train + ( self.b0_samples * tau_0_vec * self.y_std ) elif self.multivariate_treatment: - cate_train = self.tau_hat_train.copy() for j in range(p_tau0): - cate_train[:, :, j] = cate_train[:, :, j] + ( + self.tau_hat_train[:, :, j] = self.tau_hat_train[:, :, j] + ( self.tau_0_samples[j, :] * self.y_std ) else: - cate_train = self.tau_hat_train + tau_0_vec * self.y_std - else: - cate_train = self.tau_hat_train + self.tau_hat_train = self.tau_hat_train + tau_0_vec * self.y_std if self.multivariate_treatment: treatment_term_train = np.multiply( - np.atleast_3d(Z_train).swapaxes(1, 2), cate_train + np.atleast_3d(Z_train).swapaxes(1, 2), self.tau_hat_train ).sum(axis=2) else: - treatment_term_train = Z_train * np.squeeze(cate_train) + treatment_term_train = Z_train * np.squeeze(self.tau_hat_train) self.y_hat_train = self.mu_hat_train + treatment_term_train if self.has_test: mu_raw_test = self.forest_container_mu.forest_container_cpp.Predict( @@ -3020,31 +3016,28 @@ def sample( self.tau_hat_test = self.tau_hat_test * adaptive_coding_weights_test self.mu_hat_test = self.mu_hat_test + np.squeeze(control_adj_test) self.tau_hat_test = np.squeeze(self.tau_hat_test * self.y_std) - # tau_hat_test stores forest-only tau(X); compute cate_test for y_hat + # Fold tau_0 into tau_hat_test so the attribute holds the full CATE (tau_0 + tau(X)) if self.sample_tau_0: if self.adaptive_coding: - cate_test = self.tau_hat_test + ( + self.tau_hat_test = self.tau_hat_test + ( (self.b1_samples - self.b0_samples) * tau_0_vec * self.y_std ) self.mu_hat_test = self.mu_hat_test + ( self.b0_samples * tau_0_vec * self.y_std ) elif self.multivariate_treatment: - cate_test = self.tau_hat_test.copy() for j in range(p_tau0): - cate_test[:, :, j] = cate_test[:, :, j] + ( + self.tau_hat_test[:, :, j] = self.tau_hat_test[:, :, j] + ( self.tau_0_samples[j, :] * self.y_std ) else: - cate_test = self.tau_hat_test + tau_0_vec * self.y_std - else: - cate_test = self.tau_hat_test + self.tau_hat_test = self.tau_hat_test + tau_0_vec * self.y_std if self.multivariate_treatment: treatment_term_test = np.multiply( - np.atleast_3d(Z_test).swapaxes(1, 2), cate_test + np.atleast_3d(Z_test).swapaxes(1, 2), self.tau_hat_test ).sum(axis=2) else: - treatment_term_test = Z_test * np.squeeze(cate_test) + treatment_term_test = Z_test * np.squeeze(self.tau_hat_test) self.y_hat_test = self.mu_hat_test + treatment_term_test # TODO: make rfx_preds_train and rfx_preds_test persistent properties @@ -3140,7 +3133,17 @@ def predict( type : str, optional Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BART model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior". terms : str, optional - Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is request, the request will simply be ignored. If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all". + Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". + + The treatment effect terms follow a three-level hierarchy: + + - ``"tau"`` returns ``tau_0 + tau(X)``: the parametric treatment intercept (if sampled) plus the treatment forest. This matches ``model.tau_hat_train`` / ``model.tau_hat_test``. + - ``"cate"`` additionally folds in the random slope on treatment when random effects are fit with ``rfx_model_spec="intercept_plus_treatment"``; otherwise it is identical to ``"tau"``. + - The raw forest-only component (without ``tau_0``) is not directly returned by this method; use ``model.forest_container_tau`` to access it. + + Similarly for the prognostic term: ``"mu"`` returns the prognostic forest only, while ``"prognostic_function"`` additionally folds in the random intercept when ``rfx_model_spec`` is ``"intercept_only"`` or ``"intercept_plus_treatment"``; otherwise the two are identical. + + If a model doesn't have mean forest, random effects, or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present in a model, this function will return ``None`` along with a warning. Default: "all". scale : str, optional Scale on which to return predictions. Options are "linear" (the default), which returns predictions on the original outcome scale, and "probit", which returns predictions on the probit (latent) scale. Only applicable for models fit with probit link. @@ -3485,7 +3488,7 @@ def predict( if predict_mu_forest: mu_x = norm.cdf(mu_x_forest) if predict_tau_forest: - tau_x = norm.cdf(tau_x_forest) + tau_x = norm.cdf(cate_x_forest) if predict_prog_function: prognostic_function = norm.cdf(prognostic_function) if predict_cate_function: @@ -3500,7 +3503,7 @@ def predict( if predict_mu_forest: mu_x = mu_x_forest if predict_tau_forest: - tau_x = tau_x_forest + tau_x = cate_x_forest if predict_prog_function: prognostic_function = prognostic_function if predict_cate_function: @@ -3722,7 +3725,17 @@ def compute_posterior_interval( rfx_basis : np.array, optional Optional matrix of basis function evaluations for random effects. Required if the requested term includes random effects. terms : str, optional - Character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"tau_0"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Defaults to `"all"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`. `"tau_0"` is only available when the model was fit with `sample_intercept = True`. + Which model terms to include in the interval. This can be a single term or a list of model terms. Options include ``"y_hat"``, ``"prognostic_function"``, ``"mu"``, ``"cate"``, ``"tau"``, ``"rfx"``, ``"variance_forest"``, or ``"all"``. + + The treatment effect terms follow a three-level hierarchy: + + - ``"tau"`` returns ``tau_0 + tau(X)``: the parametric treatment intercept (if sampled) plus the treatment forest. This matches ``model.tau_hat_train`` / ``model.tau_hat_test``. + - ``"cate"`` additionally folds in the random slope on treatment when random effects are fit with ``rfx_model_spec="intercept_plus_treatment"``; otherwise it is identical to ``"tau"``. + - The raw forest-only component (without ``tau_0``) is not directly returned by this method; use ``model.forest_container_tau`` to access it. + + Similarly for the prognostic term: ``"mu"`` returns the prognostic forest only, while ``"prognostic_function"`` additionally folds in the random intercept when ``rfx_model_spec`` is ``"intercept_only"`` or ``"intercept_plus_treatment"``; otherwise the two are identical. + + Default: ``"all"``. scale : str, optional Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Defaults to `"linear"`. level : float, optional diff --git a/test/R/testthat/test-bcf.R b/test/R/testthat/test-bcf.R index 552227d3..36502d70 100644 --- a/test/R/testthat/test-bcf.R +++ b/test/R/testthat/test-bcf.R @@ -1089,3 +1089,64 @@ test_that("predict.bcfmodel works with data frame X when internal propensity mod ) expect_equal(nrow(result$y_hat), length(test_inds)) }) + +test_that("predict(terms='tau') == tau_hat_test, tau==cate without treatment RFX, y_hat decomposition holds", { + skip_on_cran() + + set.seed(42) + n <- 200 + p <- 5 + X <- matrix(runif(n * p), nrow = n) + pi_x <- 0.4 + 0.2 * X[, 1] + Z <- rbinom(n, 1, pi_x) + mu_x <- 1 + 2 * X[, 1] + tau_x <- 0.5 + X[, 2] + y <- mu_x + tau_x * Z + rnorm(n) + + n_train <- 160 + train_inds <- seq_len(n_train) + test_inds <- seq(n_train + 1, n) + X_train <- X[train_inds, ]; X_test <- X[test_inds, ] + Z_train <- Z[train_inds]; Z_test <- Z[test_inds] + y_train <- y[train_inds] + pi_train <- pi_x[train_inds]; pi_test <- pi_x[test_inds] + + # Fit BCF with sample_intercept = TRUE (default) + bcf_model <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 5, num_burnin = 0, num_mcmc = 10 + ) + + # predict(terms = "tau") must match tau_hat_test exactly + tau_from_predict <- predict(bcf_model, X = X_test, Z = Z_test, + propensity = pi_test, terms = "tau") + expect_equal(tau_from_predict, bcf_model$tau_hat_test) + + # predict(terms = "tau") == predict(terms = "cate") when no treatment RFX + cate_from_predict <- predict(bcf_model, X = X_test, Z = Z_test, + propensity = pi_test, terms = "cate") + expect_equal(tau_from_predict, cate_from_predict) + + # y_hat_test = mu_hat_test + Z_test * tau_hat_test (stored attributes decompose) + expected_y <- bcf_model$mu_hat_test + sweep(bcf_model$tau_hat_test, 1, as.numeric(Z_test), "*") + expect_equal(bcf_model$y_hat_test, expected_y) + + # y_hat_train = mu_hat_train + Z_train * tau_hat_train + expected_y_train <- bcf_model$mu_hat_train + sweep(bcf_model$tau_hat_train, 1, as.numeric(Z_train), "*") + expect_equal(bcf_model$y_hat_train, expected_y_train) + + # With sample_intercept = FALSE, tau includes only the forest; decomposition still holds + bcf_no_int <- bcf( + X_train = X_train, Z_train = Z_train, y_train = y_train, + propensity_train = pi_train, X_test = X_test, Z_test = Z_test, + propensity_test = pi_test, num_gfr = 5, num_burnin = 0, num_mcmc = 10, + treatment_effect_forest_params = list(sample_intercept = FALSE) + ) + tau_no_int <- predict(bcf_no_int, X = X_test, Z = Z_test, + propensity = pi_test, terms = "tau") + expect_equal(tau_no_int, bcf_no_int$tau_hat_test) + expected_y_no_int <- bcf_no_int$mu_hat_test + + sweep(bcf_no_int$tau_hat_test, 1, as.numeric(Z_test), "*") + expect_equal(bcf_no_int$y_hat_test, expected_y_no_int) +}) diff --git a/test/python/test_predict.py b/test/python/test_predict.py index 031b53df..9b568e96 100644 --- a/test/python/test_predict.py +++ b/test/python/test_predict.py @@ -528,6 +528,64 @@ def g(x5): terms=["variance_forest"], ) + def test_bcf_tau_cate_decomposition(self): + """predict(terms='tau') == tau_hat_test, tau==cate without treatment RFX, + and y_hat = mu_hat + Z * tau_hat for stored train/test attributes.""" + rng = np.random.default_rng(42) + n = 200 + p = 5 + X = rng.uniform(size=(n, p)) + pi_x = 0.4 + 0.2 * X[:, 0] + Z = rng.binomial(1, pi_x).astype(float) + mu_x = 1 + 2 * X[:, 0] + tau_x = 0.5 + X[:, 1] + y = mu_x + tau_x * Z + rng.normal(size=n) + train_inds, test_inds = train_test_split(np.arange(n), test_size=0.2, random_state=0) + X_train, X_test = X[train_inds], X[test_inds] + Z_train, Z_test = Z[train_inds], Z[test_inds] + y_train = y[train_inds] + pi_train, pi_test = pi_x[train_inds], pi_x[test_inds] + + # Fit BCF with sample_intercept=True (default) + bcf_model = BCFModel() + bcf_model.sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, + propensity_test=pi_test, num_gfr=5, num_burnin=0, num_mcmc=10, + ) + + # predict(terms="tau") must match tau_hat_test exactly + tau_from_predict = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, terms="tau") + np.testing.assert_allclose(tau_from_predict, bcf_model.tau_hat_test) + + # predict(terms="tau") == predict(terms="cate") when no treatment RFX + cate_from_predict = bcf_model.predict(X=X_test, Z=Z_test, propensity=pi_test, terms="cate") + np.testing.assert_allclose(tau_from_predict, cate_from_predict) + + # y_hat_test = mu_hat_test + Z_test * tau_hat_test (stored attributes decompose) + expected_y = bcf_model.mu_hat_test + Z_test[:, None] * bcf_model.tau_hat_test + np.testing.assert_allclose(bcf_model.y_hat_test, expected_y) + + # y_hat_train = mu_hat_train + Z_train * tau_hat_train + expected_y_train = bcf_model.mu_hat_train + Z_train[:, None] * bcf_model.tau_hat_train + np.testing.assert_allclose(bcf_model.y_hat_train, expected_y_train) + + # With sample_intercept=False, tau includes only the forest; mean ATE still recoverable + bcf_no_intercept = BCFModel() + bcf_no_intercept.sample( + X_train=X_train, Z_train=Z_train, y_train=y_train, + propensity_train=pi_train, X_test=X_test, Z_test=Z_test, + propensity_test=pi_test, num_gfr=5, num_burnin=0, num_mcmc=10, + treatment_effect_forest_params={"sample_intercept": False}, + ) + tau_no_intercept = bcf_no_intercept.predict( + X=X_test, Z=Z_test, propensity=pi_test, terms="tau" + ) + np.testing.assert_allclose(tau_no_intercept, bcf_no_intercept.tau_hat_test) + # y_hat decomposition still holds without tau_0 + expected_y_no_int = bcf_no_intercept.mu_hat_test + Z_test[:, None] * bcf_no_intercept.tau_hat_test + np.testing.assert_allclose(bcf_no_intercept.y_hat_test, expected_y_no_int) + def test_bart_cloglog_binary_interval_and_contrast(self): # Generate binary cloglog data rng = np.random.default_rng(42) diff --git a/tools/debug/gh-376.py b/tools/debug/gh-376.py new file mode 100644 index 00000000..47cc4dd7 --- /dev/null +++ b/tools/debug/gh-376.py @@ -0,0 +1,46 @@ +""" +Replication script from: +https://github.com/StochasticTree/stochtree/issues/376 +""" +import warnings +warnings.filterwarnings("ignore") +import numpy as np +from sklearn.ensemble import HistGradientBoostingClassifier +from sklearn.model_selection import cross_val_predict +from threadpoolctl import threadpool_limits +from stochtree import BCFModel + + +def confounded_dgp(n=445, p=8, n_seed=0): + """Small confounded DGP roughly mimicking LaLonde-style earnings data.""" + rng = np.random.default_rng(n_seed) + X = rng.normal(size=(n, p)) + e = 1 / (1 + np.exp(-(1.5 * X[:, 0] - 0.7 * X[:, 1]))) # confounded propensity + T = rng.binomial(1, e).astype(float) + tau = 1500 + 800 * X[:, 2] - 500 * X[:, 3] # heterogeneous, mean ~1500 + Y0 = np.maximum(np.random.default_rng(n_seed + 1).normal( + loc=3000 + 1500 * X[:, 0], scale=4000, size=n), 0) + Y = Y0 + T * tau + return X.astype(float), T, Y.astype(float) + +X, T, Y = confounded_dgp(n_seed=0) + +for seed in [0, 1, 7, 13, 42, 100, 2026]: + propensity = cross_val_predict( + HistGradientBoostingClassifier(max_iter=200, random_state=int(seed)), + X, T.astype(int), method="predict_proba", cv=5, + )[:, 1] + propensity = np.clip(propensity, 0.01, 0.99) + + m = BCFModel() + with threadpool_limits(limits=1): + m.sample( + X_train=X, Z_train=T, y_train=Y, + propensity_train=propensity, + num_gfr=5, num_burnin=200, num_mcmc=200, + general_params={"random_seed": int(seed) % (2**31)}, + ) + tau_hat = m.predict(X=X, Z=T, propensity=propensity, + type="posterior", terms="cate") + print(f"seed={seed:>4} BCF ATE={float(np.mean(tau_hat)):>+8.0f}") + From c8592fc36c52c1084c1c7639d2c979aad8a3c59b Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Tue, 5 May 2026 23:29:16 -0500 Subject: [PATCH 2/3] Fixed bug in prognostic forest prediction for BCF --- R/bcf.R | 3 ++- stochtree/bcf.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/R/bcf.R b/R/bcf.R index 0cb49e3d..27bef862 100644 --- a/R/bcf.R +++ b/R/bcf.R @@ -3576,7 +3576,8 @@ predict.bcfmodel <- function( predict_mu_forest_intermediate <- ((predict_y_hat || predict_prog_function) && has_mu_forest) predict_tau_forest_intermediate <- ((predict_y_hat || - predict_cate_function) && + predict_cate_function || + (object$model_params$adaptive_coding && (predict_mu_forest || predict_prog_function))) && has_tau_forest) # Make sure covariates are matrix or data frame diff --git a/stochtree/bcf.py b/stochtree/bcf.py index d1108927..191212bc 100644 --- a/stochtree/bcf.py +++ b/stochtree/bcf.py @@ -3242,7 +3242,8 @@ def predict( predict_y_hat or predict_prog_function ) and has_mu_forest predict_tau_forest_intermediate = ( - predict_y_hat or predict_cate_function + predict_y_hat or predict_cate_function or + (self.adaptive_coding and (predict_mu_forest or predict_prog_function)) ) and has_tau_forest if not self.is_sampled(): From fea9aa003e32764562428b342b0dcc1429c1bcd8 Mon Sep 17 00:00:00 2001 From: Drew Herren Date: Wed, 6 May 2026 00:06:46 -0500 Subject: [PATCH 3/3] Update NEWS.md --- NEWS.md | 1 + 1 file changed, 1 insertion(+) diff --git a/NEWS.md b/NEWS.md index d191d521..663fa007 100644 --- a/NEWS.md +++ b/NEWS.md @@ -4,6 +4,7 @@ ## Bug Fixes +* Fixed BCF prediction bug for prognostic function and mu(X) when adaptive coding is used and the tau(x) forest is not also requested by a prediction call [#377](https://github.com/StochasticTree/stochtree/pull/377) * Fixed R BCF prediction bug when covariates are passed as dataframes and an internal propensity is sampled [#374](https://github.com/StochasticTree/stochtree/issues/374) ## Documentation and Other Maintenance