Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

## Bug Fixes

* Fixed BCF prediction bug for prognostic function and mu(X) when adaptive coding is used and the tau(X) forest is not also requested by a prediction call [#377](https://github.com/StochasticTree/stochtree/pull/377)
* Fixed R BCF prediction bug when covariates are passed as dataframes and an internal propensity is sampled [#374](https://github.com/StochasticTree/stochtree/issues/374)

## Documentation and Other Maintenance
Expand Down
55 changes: 30 additions & 25 deletions R/bcf.R
Original file line number Diff line number Diff line change
Expand Up @@ -3076,13 +3076,12 @@ bcf <- function(
tau_hat_train <- forest_samples_tau$predict_raw(forest_dataset_train) *
y_std_train
}
# tau_hat_train stores the forest-only component tau(X); compute cate_train
# (tau_0 + tau(X)) separately for the treatment term used in y_hat
# Fold tau_0 into tau_hat_train so it holds the full CATE (tau_0 + tau(X))
if (sample_tau_0) {
tau_0_vec <- as.numeric(tau_0_samples) # num_retained_samples vector (scalar treatment)
if (adaptive_coding) {
# CATE = (b_1 - b_0) * (tau_0 + tau(X)); control adj to mu = b_0 * (tau_0 + tau(X))
cate_train <- sweep(
tau_hat_train <- sweep(
tau_hat_train,
2,
(b_1_samples - b_0_samples) * tau_0_vec * y_std_train,
Expand All @@ -3095,20 +3094,17 @@ bcf <- function(
"+"
)
} else if (!has_multivariate_treatment) {
cate_train <- sweep(tau_hat_train, 2, tau_0_vec * y_std_train, "+")
tau_hat_train <- sweep(tau_hat_train, 2, tau_0_vec * y_std_train, "+")
} else {
# tau_hat_train: n x p x num_retained_samples; tau_0_samples: p x num_retained_samples
cate_train <- tau_hat_train
for (j in seq_len(p_tau0)) {
cate_train[, j, ] <- cate_train[, j, ] +
tau_hat_train[, j, ] <- tau_hat_train[, j, ] +
outer(rep(1, nrow(X_train)), tau_0_samples[j, ] * y_std_train)
}
}
} else {
cate_train <- tau_hat_train
}
if (has_multivariate_treatment) {
tau_train_dim <- dim(cate_train)
tau_train_dim <- dim(tau_hat_train)
tau_num_obs <- tau_train_dim[1]
tau_num_samples <- tau_train_dim[3]
treatment_term_train <- matrix(
Expand All @@ -3118,11 +3114,11 @@ bcf <- function(
)
for (i in 1:nrow(Z_train)) {
treatment_term_train[i, ] <- colSums(
cate_train[i, , ] * Z_train[i, ]
tau_hat_train[i, , ] * Z_train[i, ]
)
}
} else {
treatment_term_train <- cate_train * as.numeric(Z_train)
treatment_term_train <- tau_hat_train * as.numeric(Z_train)
}
y_hat_train <- mu_hat_train + treatment_term_train
if (has_test) {
Expand All @@ -3145,10 +3141,10 @@ bcf <- function(
) *
y_std_train
}
# tau_hat_test stores forest-only tau(X); compute cate_test for y_hat
# Fold tau_0 into tau_hat_test so it holds the full CATE (tau_0 + tau(X))
if (sample_tau_0) {
if (adaptive_coding) {
cate_test <- sweep(
tau_hat_test <- sweep(
tau_hat_test,
2,
(b_1_samples - b_0_samples) * tau_0_vec * y_std_train,
Expand All @@ -3161,19 +3157,16 @@ bcf <- function(
"+"
)
} else if (!has_multivariate_treatment) {
cate_test <- sweep(tau_hat_test, 2, tau_0_vec * y_std_train, "+")
tau_hat_test <- sweep(tau_hat_test, 2, tau_0_vec * y_std_train, "+")
} else {
cate_test <- tau_hat_test
for (j in seq_len(p_tau0)) {
cate_test[, j, ] <- cate_test[, j, ] +
tau_hat_test[, j, ] <- tau_hat_test[, j, ] +
outer(rep(1, nrow(X_test)), tau_0_samples[j, ] * y_std_train)
}
}
} else {
cate_test <- tau_hat_test
}
if (has_multivariate_treatment) {
tau_test_dim <- dim(cate_test)
tau_test_dim <- dim(tau_hat_test)
tau_num_obs <- tau_test_dim[1]
tau_num_samples <- tau_test_dim[3]
treatment_term_test <- matrix(
Expand All @@ -3183,11 +3176,11 @@ bcf <- function(
)
for (i in 1:nrow(Z_test)) {
treatment_term_test[i, ] <- colSums(
cate_test[i, , ] * Z_test[i, ]
tau_hat_test[i, , ] * Z_test[i, ]
)
}
} else {
treatment_term_test <- cate_test * as.numeric(Z_test)
treatment_term_test <- tau_hat_test * as.numeric(Z_test)
}
y_hat_test <- mu_hat_test + treatment_term_test
}
Expand Down Expand Up @@ -3390,7 +3383,18 @@ bcf <- function(
#' that were not in the training set.
#' @param rfx_basis (Optional) Test set basis for "random-slope" regression in additive random effects model. If the model was sampled with a random effects `model_spec` of "intercept_only" or "intercept_plus_treatment", this is optional, but if it is provided, it will be used.
#' @param type (Optional) Type of prediction to return. Options are "mean", which averages the predictions from every draw of a BCF model, and "posterior", which returns the entire matrix of posterior predictions. Default: "posterior".
#' @param terms (Optional) Which model terms to include in the prediction. This can be a single term or a list of model terms. Options include "y_hat", "prognostic_function", "mu", "cate", "tau", "rfx", "variance_forest", or "all". If a model doesn't have random effects or variance forest predictions, but one of those terms is requested, the request will simply be ignored. If a model has random effects fit with either "intercept_only" or "intercept_plus_treatment" model_spec, then "prognostic_function" refers to the predictions of the prognostic forest plus the random intercept and "cate" refers to the predictions of the treatment effect forest plus the random slope on the treatment variable. For these models, the forest predictions alone can be requested via "mu" (prognostic forest) and "tau" (treatment effect forest). In all other cases, "mu" will return exactly the same result as "prognostic_function" and "tau" will return exactly the same result as "cate". If none of the requested terms are present in a model, this function will return `NULL` along with a warning. Default: "all".
#' @param terms (Optional) Which model terms to include in the prediction. Options include `"y_hat"`, `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"rfx"`, `"variance_forest"`, or `"all"`.
#'
#' The treatment effect terms follow a three-level hierarchy:
#' \itemize{
#' \item `"tau"` returns `tau_0 + tau(X)`: the parametric treatment intercept (if sampled) plus the treatment forest. This matches `model$tau_hat_train` / `model$tau_hat_test`.
#' \item `"cate"` additionally folds in the random slope on treatment when random effects are fit with `rfx_model_spec = "intercept_plus_treatment"`; otherwise it is identical to `"tau"`.
#' \item The raw forest-only component (without `tau_0`) is not directly returned by this method; use `model$forests_tau` to access it.
#' }
#'
#' Similarly for the prognostic term: `"mu"` returns the prognostic forest only, while `"prognostic_function"` additionally folds in the random intercept when `rfx_model_spec` is `"intercept_only"` or `"intercept_plus_treatment"`; otherwise the two are identical.
#'
#' If a model doesn't have random effects or variance forest predictions but one of those terms is requested, the request will simply be ignored. If none of the requested terms are present, this function will return `NULL` along with a warning. Default: `"all"`.
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param ... (Optional) Other prediction parameters.
#'
Expand Down Expand Up @@ -3572,7 +3576,8 @@ predict.bcfmodel <- function(
predict_mu_forest_intermediate <- ((predict_y_hat || predict_prog_function) &&
has_mu_forest)
predict_tau_forest_intermediate <- ((predict_y_hat ||
predict_cate_function) &&
predict_cate_function ||
(object$model_params$adaptive_coding && (predict_mu_forest || predict_prog_function))) &&
has_tau_forest)

# Make sure covariates are matrix or data frame
Expand Down Expand Up @@ -3875,7 +3880,7 @@ predict.bcfmodel <- function(
mu_hat <- pnorm(mu_hat_forest)
}
if (predict_tau_forest) {
tau_hat <- pnorm(tau_hat_forest)
tau_hat <- pnorm(cate_hat_forest)
}
if (predict_prog_function) {
prognostic_function <- pnorm(prognostic_function)
Expand All @@ -3897,7 +3902,7 @@ predict.bcfmodel <- function(
mu_hat <- mu_hat_forest
}
if (predict_tau_forest) {
tau_hat <- tau_hat_forest
tau_hat <- cate_hat_forest
}
if (predict_prog_function) {
prognostic_function <- prognostic_function
Expand Down
11 changes: 10 additions & 1 deletion R/posterior_transformation.R
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,16 @@ posterior_predictive_heuristic_multiplier <- function(
#' Compute posterior credible intervals for specified terms from a fitted BCF model. Supports intervals for prognostic forests, CATE forests, variance forests, random effects, and overall mean outcome predictions.
#'
#' @param model_object A fitted BCF model object of class `bcfmodel`.
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options for BCF models are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`. Note that `"mu"` is only different from `"prognostic_function"` if random effects are included with a model spec of `"intercept_only"` or `"intercept_plus_treatment"` and `"tau"` is only different from `"cate"` if random effects are included with a model spec of `"intercept_plus_treatment"`.
#' @param terms A character string specifying the model term(s) for which to compute intervals. Options are `"prognostic_function"`, `"mu"`, `"cate"`, `"tau"`, `"variance_forest"`, `"rfx"`, or `"y_hat"`.
#'
#' The treatment effect terms follow a three-level hierarchy:
#' \itemize{
#' \item `"tau"` returns `tau_0 + tau(X)`: the parametric treatment intercept (if sampled) plus the treatment forest. This matches `model$tau_hat_train` / `model$tau_hat_test`.
#' \item `"cate"` additionally folds in the random slope on treatment when random effects are fit with `rfx_model_spec = "intercept_plus_treatment"`; otherwise it is identical to `"tau"`.
#' \item The raw forest-only component (without `tau_0`) is not directly returned by this function; use `model$forests_tau` to access it.
#' }
#'
#' Similarly for the prognostic term: `"mu"` returns the prognostic forest only, while `"prognostic_function"` additionally folds in the random intercept when `rfx_model_spec` is `"intercept_only"` or `"intercept_plus_treatment"`; otherwise the two are identical.
#' @param level A numeric value between 0 and 1 specifying the credible interval level (default is 0.95 for a 95% credible interval).
#' @param scale (Optional) Scale of mean function predictions. Options are "linear", which returns predictions on the original scale of the mean forest / RFX terms, and "probability", which transforms predictions into a probability of observing `y == 1`. "probability" is only valid for models fit with a probit outcome model. Default: "linear".
#' @param X (Optional) A matrix or data frame of covariates at which to compute the intervals. Required if the requested term depends on covariates (e.g., prognostic forest, CATE forest, variance forest, or overall predictions).
Expand Down
2 changes: 1 addition & 1 deletion man/ForestKernelComputation.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 10 additions & 1 deletion man/computeBCFPosteriorInterval.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 12 additions & 1 deletion man/predict.bcfmodel.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading