Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions R/score.R
Original file line number Diff line number Diff line change
Expand Up @@ -150,6 +150,19 @@ score.default <- function(forecast, metrics, ...) {
#' @returns A data table with the forecasts and the calculated metrics.
#' @keywords internal
apply_metrics <- function(forecast, metrics, ...) {
clashing <- intersect(names(metrics), colnames(forecast))
if (length(clashing) > 0) {
#nolint start: keyword_quote_linter
cli_warn(
c(
"!" = "Column names {.val {clashing}} are already present in the
forecast data and will be overwritten by metric output.",
"i" = "Consider renaming these metrics to avoid clashing with
existing column names."
)
)
#nolint end
}
lapply(names(metrics), function(metric_name) {
result <- do.call(
run_safely,
Expand Down
77 changes: 73 additions & 4 deletions tests/testthat/test-score.R
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,73 @@ test_that("Manipulating scores objects with .[ works as expected", {
expect_no_condition(ex[, extra_col := "something"])
})

# metric name / column name clash detection ------------------------------------
test_that("score() warns when metric name clashes with existing column in binary forecast", {
horizon_metric <- function(observed, predicted) rep(1, length(observed))
expect_warning(
score(example_binary, metrics = list(
brier_score = brier_score, horizon = horizon_metric
)),
"horizon"
)
})

test_that("score() warns when metric name clashes with existing column in quantile forecast", {
location_metric <- function(observed, predicted, quantile_level) {
rep(0, length(observed))
}
expect_warning(
score(example_quantile, metrics = list(
wis = wis, location = location_metric
)),
"location"
)
})

test_that("score() warns when metric name clashes with existing column in sample forecast", {
model_metric <- function(observed, predicted) rep(0.5, length(observed))
expect_warning(
score(example_sample_continuous, metrics = list(
crps = crps_sample, model = model_metric
)),
"model"
)
})

test_that("score() warns when metric name clashes with existing column in point forecast", {
horizon_metric <- function(observed, predicted) abs(observed - predicted)
expect_warning(
score(example_point, metrics = list(
ae_point = Metrics::ae, horizon = horizon_metric
)),
"horizon"
)
})

test_that("score() does not warn when metric names don't clash with column names", {
expect_no_condition(score(example_binary))
})

test_that("score() warns about multiple clashing metric names at once", {
horizon_metric <- function(observed, predicted) rep(1, length(observed))
model_metric <- function(observed, predicted) rep(2, length(observed))
expect_warning(
score(example_binary, metrics = list(
brier_score = brier_score, horizon = horizon_metric, model = model_metric
)),
"horizon.*model|model.*horizon"
)
})

test_that("score() warns when metric name clashes with a protected column name", {
observed_metric <- function(observed, predicted) rep(0, length(observed))
expect_warning(
score(example_binary, metrics = list(
brier_score = brier_score, observed = observed_metric
)),
"observed"
)
})

# test integer and continuous case ---------------------------------------------
test_that("function produces output for a continuous format case", {
Expand Down Expand Up @@ -121,18 +188,20 @@ test_that("apply_metrics() works", {
expect_equal(dt$test, 2:11) # nolint: expect_identical_linter

# additional named argument works
dt2 <- data.table::data.table(x = 1:10)
expect_no_condition(
scoringutils:::apply_metrics( # nolint: undesirable_operator_linter
forecast = dt, metrics = list(test = function(x) x + 1),
dt$x, y = dt$test
forecast = dt2, metrics = list(test = function(x) x + 1),
dt2$x, y = dt2$x
)
)

# additional unnamed argument does not work
dt3 <- data.table::data.table(x = 1:10)
expect_warning(
scoringutils:::apply_metrics( # nolint: undesirable_operator_linter
forecast = dt, metrics = list(test = function(x) x + 1),
dt$x, dt$test
forecast = dt3, metrics = list(test = function(x) x + 1),
dt3$x, dt3$x
)
)
})
Expand Down
Loading