From cc0dbc2241736b9870ba902a867c41b6f5762117 Mon Sep 17 00:00:00 2001 From: nikosbosse Date: Fri, 13 Feb 2026 03:18:13 +0100 Subject: [PATCH] Fixes #942: Add plot_discrimination() for binary forecasts Co-Authored-By: Claude Opus 4.6 --- NAMESPACE | 2 + R/plot-discrimination.R | 42 +++++ man/plot_discrimination.Rd | 31 ++++ .../plot-discrimination-facet-model.svg | 159 ++++++++++++++++++ .../plot-discrimination.svg | 62 +++++++ tests/testthat/test-plot_discrimination.R | 65 +++++++ 6 files changed, 361 insertions(+) create mode 100644 R/plot-discrimination.R create mode 100644 man/plot_discrimination.Rd create mode 100644 tests/testthat/_snaps/plot_discrimination/plot-discrimination-facet-model.svg create mode 100644 tests/testthat/_snaps/plot_discrimination/plot-discrimination.svg create mode 100644 tests/testthat/test-plot_discrimination.R diff --git a/NAMESPACE b/NAMESPACE index 151aa6149..3a6a75efc 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -91,6 +91,7 @@ export(overprediction_quantile) export(overprediction_sample) export(pit_histogram_sample) export(plot_correlations) +export(plot_discrimination) export(plot_forecast_counts) export(plot_heatmap) export(plot_interval_coverage) @@ -178,6 +179,7 @@ importFrom(ggplot2,element_text) importFrom(ggplot2,facet_grid) importFrom(ggplot2,facet_wrap) importFrom(ggplot2,geom_col) +importFrom(ggplot2,geom_density) importFrom(ggplot2,geom_line) importFrom(ggplot2,geom_linerange) importFrom(ggplot2,geom_polygon) diff --git a/R/plot-discrimination.R b/R/plot-discrimination.R new file mode 100644 index 000000000..e06df18f5 --- /dev/null +++ b/R/plot-discrimination.R @@ -0,0 +1,42 @@ +#' @title Plot discrimination for binary forecasts +#' +#' @description +#' Visualise the discrimination ability of binary forecasts by plotting the +#' distribution of predicted probabilities, stratified by the observed outcome. +#' A well-discriminating model will show clearly separated distributions for +#' the two observed levels. +#' +#' @param forecast A data.table (or data.frame) containing at least columns +#' `observed` (factor with two levels) and `predicted` (numeric probabilities +#' between 0 and 1). Typically a `forecast_binary` object or the output of +#' [as_forecast_binary()]. +#' @returns A ggplot object showing overlapping density curves of predicted +#' probabilities, coloured by observed outcome level. +#' @importFrom ggplot2 ggplot aes geom_density labs .data +#' @importFrom checkmate assert assert_data_frame +#' @export +#' @examples +#' library(ggplot2) +#' plot_discrimination(na.omit(example_binary)) +#' +#' plot_discrimination(na.omit(example_binary)) + +#' facet_wrap(~model) + +plot_discrimination <- function(forecast) { + forecast <- ensure_data.table(forecast) + assert(check_columns_present(forecast, c("observed", "predicted"))) + + plot <- ggplot( + forecast, + aes(x = .data[["predicted"]], fill = .data[["observed"]]) + ) + + geom_density(alpha = 0.5) + + labs( + x = "Predicted probability", + y = "Density", + fill = "Observed" + ) + + theme_scoringutils() + + return(plot) +} diff --git a/man/plot_discrimination.Rd b/man/plot_discrimination.Rd new file mode 100644 index 000000000..a93746897 --- /dev/null +++ b/man/plot_discrimination.Rd @@ -0,0 +1,31 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/plot-discrimination.R +\name{plot_discrimination} +\alias{plot_discrimination} +\title{Plot discrimination for binary forecasts} +\usage{ +plot_discrimination(forecast) +} +\arguments{ +\item{forecast}{A data.table (or data.frame) containing at least columns +\code{observed} (factor with two levels) and \code{predicted} (numeric probabilities +between 0 and 1). Typically a \code{forecast_binary} object or the output of +\code{\link[=as_forecast_binary]{as_forecast_binary()}}.} +} +\value{ +A ggplot object showing overlapping density curves of predicted +probabilities, coloured by observed outcome level. +} +\description{ +Visualise the discrimination ability of binary forecasts by plotting the +distribution of predicted probabilities, stratified by the observed outcome. +A well-discriminating model will show clearly separated distributions for +the two observed levels. +} +\examples{ +library(ggplot2) +plot_discrimination(na.omit(example_binary)) + +plot_discrimination(na.omit(example_binary)) + + facet_wrap(~model) +} diff --git a/tests/testthat/_snaps/plot_discrimination/plot-discrimination-facet-model.svg b/tests/testthat/_snaps/plot_discrimination/plot-discrimination-facet-model.svg new file mode 100644 index 000000000..c4544b714 --- /dev/null +++ b/tests/testthat/_snaps/plot_discrimination/plot-discrimination-facet-model.svg @@ -0,0 +1,159 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + +UMass-MechBayes + + + + + + + + + +epiforecasts-EpiNow2 + + + + + + + + + +EuroCOVIDhub-baseline + + + + + + + + + +EuroCOVIDhub-ensemble + + + + + + + +0.0 +0.2 +0.4 +0.6 + + + + + +0.0 +0.2 +0.4 +0.6 + +0 +2 +4 +6 + + + + + +0 +2 +4 +6 + + + + +Predicted probability +Density +Observed + + +0 +1 +plot_discrimination_facet_model + + diff --git a/tests/testthat/_snaps/plot_discrimination/plot-discrimination.svg b/tests/testthat/_snaps/plot_discrimination/plot-discrimination.svg new file mode 100644 index 000000000..e06873707 --- /dev/null +++ b/tests/testthat/_snaps/plot_discrimination/plot-discrimination.svg @@ -0,0 +1,62 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + +0 +2 +4 +6 + + + + + + + + + +0.0 +0.2 +0.4 +0.6 +Predicted probability +Density +Observed + + +0 +1 +plot_discrimination + + diff --git a/tests/testthat/test-plot_discrimination.R b/tests/testthat/test-plot_discrimination.R new file mode 100644 index 000000000..0ef5dd982 --- /dev/null +++ b/tests/testthat/test-plot_discrimination.R @@ -0,0 +1,65 @@ +test_that("plot_discrimination() works with a forecast_binary object", { + p <- plot_discrimination(na.omit(example_binary)) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_discrimination", p) +}) + +test_that("plot_discrimination() works with faceting by model", { + p <- plot_discrimination(na.omit(example_binary)) + + facet_wrap(~model) + expect_s3_class(p, "ggplot") + skip_on_cran() + vdiffr::expect_doppelganger("plot_discrimination_facet_model", p) +}) + +test_that("plot_discrimination() works with a plain data.frame input", { + df <- data.frame( + observed = factor(c("0", "0", "1", "1"), levels = c("0", "1")), + predicted = c(0.1, 0.3, 0.7, 0.9), + model = "test_model" + ) + p <- plot_discrimination(df) + expect_s3_class(p, "ggplot") +}) + +test_that("plot_discrimination() errors with missing required columns", { + df_no_observed <- data.frame(predicted = c(0.1, 0.5, 0.9)) + df_no_predicted <- data.frame( + observed = factor(c("0", "1", "0"), levels = c("0", "1")) + ) + expect_error(plot_discrimination(df_no_observed), "observed") + expect_error(plot_discrimination(df_no_predicted), "predicted") +}) + +test_that("plot_discrimination() handles single-model data", { + single_model <- na.omit(example_binary)[ + model == "EuroCOVIDhub-ensemble" + ] + p <- plot_discrimination(single_model) + expect_s3_class(p, "ggplot") +}) + +test_that("plot_discrimination() shows separation between observed levels", { + df <- data.frame( + observed = factor(c(rep("0", 50), rep("1", 50)), levels = c("0", "1")), + predicted = c(rep(0.1, 50), rep(0.9, 50)), + model = "perfect" + ) + p <- plot_discrimination(df) + expect_s3_class(p, "ggplot") + + build_data <- ggplot2::ggplot_build(p) + # The density layer should have at least 2 groups + layer_data <- build_data$data[[1]] + expect_true(length(unique(layer_data$group)) >= 2) +}) + +test_that("plot_discrimination() handles edge case with all identical predictions", { + df <- data.frame( + observed = factor(c("0", "0", "1", "1"), levels = c("0", "1")), + predicted = c(0.5, 0.5, 0.5, 0.5), + model = "constant" + ) + expect_no_error(plot_discrimination(df)) +})