diff --git a/R/loo_predictive_metric.R b/R/loo_predictive_metric.R index 8ee18bd2..5bbf45f5 100644 --- a/R/loo_predictive_metric.R +++ b/R/loo_predictive_metric.R @@ -197,6 +197,11 @@ loo_predictive_metric.matrix <- yhat <- as.integer(yhat > 0.5) mask <- y == 0 + if (all(mask) || !any(mask)) { + stop("Balanced accuracy requires both classes (0 and 1) to be present in 'y'.", + call. = FALSE) + } + tn <- mean(yhat[mask] == y[mask]) # True negatives tp <- mean(yhat[!mask] == y[!mask]) # True positives diff --git a/tests/testthat/test_loo_predictive_metric.R b/tests/testthat/test_loo_predictive_metric.R index 3d796249..40ee88bb 100644 --- a/tests/testthat/test_loo_predictive_metric.R +++ b/tests/testthat/test_loo_predictive_metric.R @@ -170,4 +170,14 @@ test_that('Balanced accuracy computation is correct', { 'all(yhat <= 1 & yhat >= 0) is not TRUE', fixed = TRUE ) + expect_error( + .balanced_accuracy(c(1, 1, 1, 1), c(0.8, 0.6, 0.7, 0.9)), + 'Balanced accuracy requires both classes', + fixed = TRUE + ) + expect_error( + .balanced_accuracy(c(0, 0, 0, 0), c(0.1, 0.2, 0.3, 0.4)), + 'Balanced accuracy requires both classes', + fixed = TRUE + ) })