Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
26 changes: 26 additions & 0 deletions R/metrics-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,32 @@ assert_input_binary <- function(observed, predicted) {
assert_factor(observed, n.levels = 2, min.len = 1)
assert_numeric(predicted, lower = 0, upper = 1)
assert_dims_ok_point(observed, predicted)

# Warn if factor levels appear to be in counterintuitive order.
# Predictions represent P(outcome = highest factor level). If the levels
# are e.g. c("1", "0"), the highest level is "0", meaning predictions are
# interpreted as P(outcome = "0"), which is almost certainly unintended.
lvls <- levels(observed)
counterintuitive <- FALSE
if (setequal(lvls, c("0", "1")) && lvls[1] == "1") {
counterintuitive <- TRUE
} else if (setequal(lvls, c("TRUE", "FALSE")) && lvls[1] == "TRUE") {
counterintuitive <- TRUE
}
if (counterintuitive) {
#nolint start: keyword_quote_linter
cli_warn(c(
"!" = "Factor levels of {.var observed} appear to be in
counterintuitive order: {.val {lvls}}.",
"i" = "Predictions will be interpreted as the probability of
observing {.val {lvls[2]}} (the highest factor level).",
"i" = "If this is not intended, consider reordering the factor levels,
e.g. {.code factor(observed, levels = c({.val {lvls[2]}},
{.val {lvls[1]}}))}"
))
#nolint end
}

return(invisible(NULL))
}

Expand Down
21 changes: 21 additions & 0 deletions tests/testthat/test-class-forecast-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,27 @@



test_that("as_forecast_binary() warns when data has reversed 0/1 factor levels", {

Check warning on line 50 in tests/testthat/test-class-forecast-binary.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=tests/testthat/test-class-forecast-binary.R,line=50,col=12,[nonportable_path_linter] Use file.path() to construct portable file paths.
dt <- data.table(
model = "m1",
id = 1:4,
observed = factor(c(0, 1, 1, 0), levels = c("1", "0")),
predicted = c(0.1, 0.9, 0.8, 0.2)
)
expect_warning(
as_forecast_binary(dt),
"counterintuitive"
)
})

test_that("score() produces correct results with standard 0/1 factor levels", {

Check warning on line 63 in tests/testthat/test-class-forecast-binary.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=tests/testthat/test-class-forecast-binary.R,line=63,col=12,[nonportable_path_linter] Use file.path() to construct portable file paths.
# example_binary has standard levels c("0", "1"), should not warn about levels
expect_no_warning(
suppressMessages(score(as_forecast_binary(example_binary)))
)
})


# ==============================================================================
# score.forecast_binary() # nolint: commented_code_linter
# ==============================================================================
Expand Down
58 changes: 58 additions & 0 deletions tests/testthat/test-metrics-binary.R
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,64 @@
})


# ==============================================================================
# Test factor level order warning
# ==============================================================================
test_that("assert_input_binary() warns when 0/1 factor levels are in counterintuitive order", {

Check warning on line 112 in tests/testthat/test-metrics-binary.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=tests/testthat/test-metrics-binary.R,line=112,col=12,[nonportable_path_linter] Use file.path() to construct portable file paths.
observed_rev <- factor(c(0, 1, 1, 0, 1), levels = c("1", "0"))
predicted_rev <- c(0.1, 0.9, 0.8, 0.2, 0.7)
expect_warning(
assert_input_binary(observed_rev, predicted_rev),
"counterintuitive"
)
})

test_that("assert_input_binary() does not warn for standard 0/1 level order", {

Check warning on line 121 in tests/testthat/test-metrics-binary.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=tests/testthat/test-metrics-binary.R,line=121,col=12,[nonportable_path_linter] Use file.path() to construct portable file paths.
observed_std <- factor(c(0, 1, 1, 0, 1), levels = c("0", "1"))
predicted_std <- c(0.1, 0.9, 0.8, 0.2, 0.7)
expect_no_warning(assert_input_binary(observed_std, predicted_std))
})

test_that("assert_input_binary() does not warn for non-numeric factor levels", {
observed_ab <- factor(c("a", "b", "b", "a"), levels = c("a", "b"))
predicted_ab <- c(0.3, 0.7, 0.6, 0.4)
expect_no_warning(assert_input_binary(observed_ab, predicted_ab))
})

test_that("brier_score() produces different results with reversed factor levels", {
observed_correct <- factor(c(0, 1, 1, 0), levels = c("0", "1"))
observed_reversed <- factor(c(0, 1, 1, 0), levels = c("1", "0"))
predicted_bs <- c(0.1, 0.9, 0.8, 0.2)

scores_correct <- brier_score(observed_correct, predicted_bs)
expect_equal(scores_correct, c(0.01, 0.01, 0.04, 0.04)) # nolint: expect_identical_linter

expect_warning(
scores_reversed <- brier_score(observed_reversed, predicted_bs),
"counterintuitive"
)
expect_false(all(scores_correct == scores_reversed))
})

test_that("logs_binary() warns with reversed 0/1 factor levels", {

Check warning on line 148 in tests/testthat/test-metrics-binary.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=tests/testthat/test-metrics-binary.R,line=148,col=12,[nonportable_path_linter] Use file.path() to construct portable file paths.
observed_reversed <- factor(c(0, 1, 1, 0), levels = c("1", "0"))
predicted_lb <- c(0.1, 0.9, 0.8, 0.2)
expect_warning(
logs_binary(observed_reversed, predicted_lb),
"counterintuitive"
)
})

test_that("assert_input_binary() warns for TRUE/FALSE levels in counterintuitive order", {

Check warning on line 157 in tests/testthat/test-metrics-binary.R

View workflow job for this annotation

GitHub Actions / lint-changed-files

file=tests/testthat/test-metrics-binary.R,line=157,col=12,[nonportable_path_linter] Use file.path() to construct portable file paths.
observed_tf <- factor(c(TRUE, FALSE, TRUE), levels = c("TRUE", "FALSE"))
predicted_tf <- c(0.8, 0.2, 0.9)
expect_warning(
assert_input_binary(observed_tf, predicted_tf),
"counterintuitive"
)
})


# ==============================================================================
# Test Binary Metrics
# ==============================================================================
Expand Down
Loading