From f500397ed970c4a1f2fb6212edefc95cbd18cbad Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Fri, 27 Mar 2026 11:26:57 +0100 Subject: [PATCH] implement s7 classes --- DESCRIPTION | 1 + NAMESPACE | 7 + R/bgm.R | 2 - R/bgmCompare.R | 1 - R/bgmcompare-methods.r | 70 +++-- R/bgms-methods.R | 77 +++-- R/bgms_s7.R | 217 +++++++++++++ R/build_output.R | 9 +- R/extractor_functions.R | 65 ++-- R/fit_accessors.R | 70 +++++ R/mcmc_summary.R | 2 +- R/simulate_predict.R | 50 +-- man/cash-.bgmCompare.Rd | 7 +- man/cash-.bgms.Rd | 7 +- tests/testthat/test-fit-object-contract.R | 365 ++++++++++++++++++++++ 15 files changed, 839 insertions(+), 111 deletions(-) create mode 100644 R/bgms_s7.R create mode 100644 R/fit_accessors.R create mode 100644 tests/testthat/test-fit-object-contract.R diff --git a/DESCRIPTION b/DESCRIPTION index 573c5110..2362c7ec 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -32,6 +32,7 @@ Imports: Rcpp (>= 1.0.7), RcppParallel, Rdpack, + S7, methods, lifecycle, stats diff --git a/NAMESPACE b/NAMESPACE index 09c79686..c03c8a09 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -27,6 +27,8 @@ S3method(extract_precision,bgms) S3method(extract_rhat,bgmCompare) S3method(extract_rhat,bgms) S3method(extract_sbm,bgms) +S3method(names,bgmCompare) +S3method(names,bgms) S3method(predict,bgmCompare) S3method(predict,bgms) S3method(print,bgmCompare) @@ -63,6 +65,11 @@ importFrom(Rcpp,evalCpp) importFrom(RcppParallel,defaultNumThreads) importFrom(RcppParallel,setThreadOptions) importFrom(Rdpack,reprompt) +importFrom(S7,class_any) +importFrom(S7,class_character) +importFrom(S7,new_class) +importFrom(S7,new_property) +importFrom(S7,prop) importFrom(methods,hasArg) importFrom(stats,dpois) importFrom(stats,predict) diff --git a/R/bgm.R b/R/bgm.R index ad0af91f..942a1cce 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -524,7 +524,5 @@ bgm = function( raw = run_sampler(spec) output = build_output(spec, raw) - - output$.bgm_spec = spec return(output) } diff --git a/R/bgmCompare.R b/R/bgmCompare.R index 648a3c56..7ef56197 100644 --- a/R/bgmCompare.R +++ b/R/bgmCompare.R @@ -372,6 +372,5 @@ bgmCompare = function( raw = run_sampler(spec) output = build_output(spec, raw) - output$.bgm_spec = spec return(output) } diff --git a/R/bgmcompare-methods.r b/R/bgmcompare-methods.r index e751d9b4..db2478bc 100644 --- a/R/bgmcompare-methods.r +++ b/R/bgmcompare-methods.r @@ -301,7 +301,8 @@ coef.bgmCompare = function(object, ...) { # ============================================================ # ---- main effects ---- - array3d_main = to_array3d(object$raw_samples$main) + raw = get_raw_samples(object) + array3d_main = to_array3d(raw$main) stopifnot(!is.null(array3d_main)) mean_main = apply(array3d_main, 3, mean) @@ -335,7 +336,7 @@ coef.bgmCompare = function(object, ...) { # ============================================================ # ---- pairwise effects ---- - array3d_pair = to_array3d(object$raw_samples$pairwise) + array3d_pair = to_array3d(raw$pairwise) stopifnot(!is.null(array3d_pair)) mean_pair = apply(array3d_pair, 3, mean) @@ -369,7 +370,7 @@ coef.bgmCompare = function(object, ...) { # ============================================================ # ---- indicators (present only if selection was used) ---- indicators = NULL - array3d_ind = to_array3d(object$raw_samples$indicator) + array3d_ind = to_array3d(raw$indicator) if(!is.null(array3d_ind)) { mean_ind = apply(array3d_ind, 3, mean) @@ -413,9 +414,10 @@ coef.bgmCompare = function(object, ...) { #' Access elements of a bgmCompare object #' -#' @description Intercepts access to \code{posterior_summary_*} fields and -#' triggers lazy computation from cache when needed. All other fields pass -#' through using standard list extraction. +#' @description Provides \code{$} access to S7 properties. Lazy +#' \code{posterior_summary_*} properties trigger computation on first +#' access via S7 property getters. Also supports legacy S3 list-based +#' fit objects. #' #' @param x A \code{bgmCompare} object. #' @param name Name of the element to access. @@ -426,17 +428,21 @@ coef.bgmCompare = function(object, ...) { #' @export #' @keywords internal `$.bgmCompare` = function(x, name) { - if(startsWith(name, "posterior_summary_")) { - cache = .subset2(x, "cache") - if(!is.null(cache)) { - ensure_summaries(x) - val = cache[[name]] - if(!is.null(val)) { - return(val) + if(inherits(x, "S7_object")) { + S7::prop(x, name) + } else { + if(startsWith(name, "posterior_summary_")) { + cache = .subset2(x, "cache") + if(!is.null(cache)) { + ensure_summaries(x) + val = cache[[name]] + if(!is.null(val)) { + return(val) + } } } + .subset2(x, name) } - .subset2(x, name) } @@ -446,15 +452,35 @@ coef.bgmCompare = function(object, ...) { #' @export #' @keywords internal `[[.bgmCompare` = function(x, name, ...) { - if(is.character(name) && startsWith(name, "posterior_summary_")) { - cache = .subset2(x, "cache") - if(!is.null(cache)) { - ensure_summaries(x) - val = cache[[name]] - if(!is.null(val)) { - return(val) + if(inherits(x, "S7_object")) { + if(is.character(name)) { + S7::prop(x, name) + } else { + stop("numeric indexing is not supported for bgmCompare objects") + } + } else { + if(is.character(name) && startsWith(name, "posterior_summary_")) { + cache = .subset2(x, "cache") + if(!is.null(cache)) { + ensure_summaries(x) + val = cache[[name]] + if(!is.null(val)) { + return(val) + } } } + .subset2(x, name) + } +} + + +#' @method names bgmCompare +#' @export +#' @keywords internal +names.bgmCompare = function(x) { + if(inherits(x, "S7_object")) { + S7::prop(x, ".field_names") + } else { + NextMethod() } - .subset2(x, name) } diff --git a/R/bgms-methods.R b/R/bgms-methods.R index 31da3c4c..b9b362a7 100644 --- a/R/bgms-methods.R +++ b/R/bgms-methods.R @@ -124,7 +124,7 @@ summary.bgms = function(object, ...) { if(!is.null(object$posterior_summary_pairwise_allocations)) { out$allocations = object$posterior_summary_pairwise_allocations - out$mean_allocations = object$posterior_mean_allocations + out$mean_allocations = get_posterior_mean(object, "allocations") out$mode_allocations = object$posterior_mode_allocations out$num_blocks = object$posterior_num_blocks } @@ -261,15 +261,15 @@ print.summary.bgms = function(x, digits = 3, ...) { #' @export coef.bgms = function(object, ...) { out = list( - main = object$posterior_mean_main, - pairwise = object$posterior_mean_associations + main = get_posterior_mean(object, "main"), + pairwise = get_posterior_mean(object, "associations") ) - if(!is.null(object$posterior_mean_indicator)) { - out$indicator = object$posterior_mean_indicator + if(!is.null(get_posterior_mean(object, "indicator"))) { + out$indicator = get_posterior_mean(object, "indicator") } - if(!is.null(object$posterior_mean_allocations)) { - out$mean_allocations = object$posterior_mean_allocations + if(!is.null(get_posterior_mean(object, "allocations"))) { + out$mean_allocations = get_posterior_mean(object, "allocations") out$mode_allocations = object$posterior_mode_allocations out$num_blocks = object$posterior_num_blocks } @@ -280,9 +280,10 @@ coef.bgms = function(object, ...) { #' Access elements of a bgms object #' -#' @description Intercepts access to \code{posterior_summary_*} fields and -#' triggers lazy computation from cache when needed. All other fields pass -#' through using standard list extraction. +#' @description Provides \code{$} access to S7 properties. Lazy +#' \code{posterior_summary_*} properties trigger computation on first +#' access via S7 property getters. Also supports legacy S3 list-based +#' fit objects. #' #' @param x A \code{bgms} object. #' @param name Name of the element to access. @@ -293,17 +294,21 @@ coef.bgms = function(object, ...) { #' @export #' @keywords internal `$.bgms` = function(x, name) { - if(startsWith(name, "posterior_summary_")) { - cache = .subset2(x, "cache") - if(!is.null(cache)) { - ensure_summaries(x) - val = cache[[name]] - if(!is.null(val)) { - return(val) + if(inherits(x, "S7_object")) { + S7::prop(x, name) + } else { + if(startsWith(name, "posterior_summary_")) { + cache = .subset2(x, "cache") + if(!is.null(cache)) { + ensure_summaries(x) + val = cache[[name]] + if(!is.null(val)) { + return(val) + } } } + .subset2(x, name) } - .subset2(x, name) } @@ -313,17 +318,37 @@ coef.bgms = function(object, ...) { #' @export #' @keywords internal `[[.bgms` = function(x, name, ...) { - if(is.character(name) && startsWith(name, "posterior_summary_")) { - cache = .subset2(x, "cache") - if(!is.null(cache)) { - ensure_summaries(x) - val = cache[[name]] - if(!is.null(val)) { - return(val) + if(inherits(x, "S7_object")) { + if(is.character(name)) { + S7::prop(x, name) + } else { + stop("numeric indexing is not supported for bgms objects") + } + } else { + if(is.character(name) && startsWith(name, "posterior_summary_")) { + cache = .subset2(x, "cache") + if(!is.null(cache)) { + ensure_summaries(x) + val = cache[[name]] + if(!is.null(val)) { + return(val) + } } } + .subset2(x, name) + } +} + + +#' @method names bgms +#' @export +#' @keywords internal +names.bgms = function(x) { + if(inherits(x, "S7_object")) { + S7::prop(x, ".field_names") + } else { + NextMethod() } - .subset2(x, name) } diff --git a/R/bgms_s7.R b/R/bgms_s7.R new file mode 100644 index 00000000..52089b83 --- /dev/null +++ b/R/bgms_s7.R @@ -0,0 +1,217 @@ +# ============================================================================== +# S7 Class Definitions for bgms and bgmCompare Fit Objects +# ============================================================================== +# +# These S7 classes define the structure of fit objects returned by bgm() +# and bgmCompare(). They replace the previous S3 list-based representation +# while preserving the same user-facing API via $ and [[ compatibility +# methods (see bgms-methods.R and bgmcompare-methods.r). +# +# Lazy summary computation: +# The posterior_summary_* properties use S7 custom getters that trigger +# ensure_summaries() on first access. Computed values are stored in +# the cache environment (reference semantics) and returned on subsequent +# access without recomputation. +# +# names(fit) contract: +# The .field_names property stores the set of names that names(fit) +# should return, matching the previous S3 list-based behavior where +# conditional fields only appear when present. This is computed during +# construction. +# ============================================================================== + + +# ------------------------------------------------------------------ +# bgms S7 class +# ------------------------------------------------------------------ + +#' @importFrom S7 new_class new_property class_any class_character prop +bgms_class = new_class("bgms", + package = NULL, + properties = list( + # --- Core (always present) --- + arguments = new_property(class_any), + raw_samples = new_property(class_any), + cache = new_property(class_any), + + # --- Posterior means (set during construction, immutable) --- + posterior_mean_main = new_property(class_any, default = NULL), + posterior_mean_associations = new_property(class_any, default = NULL), + posterior_mean_residual_variance = new_property(class_any, default = NULL), + posterior_mean_indicator = new_property(class_any, default = NULL), + posterior_mean_coclustering_matrix = new_property(class_any, default = NULL), + posterior_mean_allocations = new_property(class_any, default = NULL), + posterior_mode_allocations = new_property(class_any, default = NULL), + posterior_num_blocks = new_property(class_any, default = NULL), + + # --- Pre-computed summaries (not lazy) --- + posterior_summary_pairwise_allocations = new_property(class_any, default = NULL), + + # --- Lazy MCMC diagnostics (computed on first access via getter) --- + posterior_summary_main = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_main"]] + } + ), + posterior_summary_pairwise = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_pairwise"]] + } + ), + posterior_summary_indicator = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_indicator"]] + } + ), + posterior_summary_quadratic = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_quadratic"]] + } + ), + + # --- Optional --- + nuts_diag = new_property(class_any, default = NULL), + + # --- easybgm compatibility (deprecated) --- + indicator = new_property(class_any, default = NULL), + interactions = new_property(class_any, default = NULL), + thresholds = new_property(class_any, default = NULL), + + # --- Internal --- + .bgm_spec = new_property(class_any, default = NULL), + .field_names = new_property(class_character) + ) +) + + +# ------------------------------------------------------------------ +# s3_list_to_bgms +# ------------------------------------------------------------------ +# Converts an S3 list-based bgms fit (built incrementally in +# build_output_bgm / build_output_mixed_mrf) to the S7 bgms_class. +# +# @param results A named list with class "bgms". +# +# Returns: A bgms_class S7 object. +# ------------------------------------------------------------------ +s3_list_to_bgms = function(results) { + bgms_class( + arguments = .subset2(results, "arguments"), + raw_samples = .subset2(results, "raw_samples"), + cache = .subset2(results, "cache"), + posterior_mean_main = .subset2(results, "posterior_mean_main"), + posterior_mean_associations = .subset2(results, "posterior_mean_associations"), + posterior_mean_residual_variance = .subset2(results, "posterior_mean_residual_variance"), + posterior_mean_indicator = .subset2(results, "posterior_mean_indicator"), + posterior_mean_coclustering_matrix = .subset2(results, "posterior_mean_coclustering_matrix"), + posterior_mean_allocations = .subset2(results, "posterior_mean_allocations"), + posterior_mode_allocations = .subset2(results, "posterior_mode_allocations"), + posterior_num_blocks = .subset2(results, "posterior_num_blocks"), + posterior_summary_pairwise_allocations = .subset2(results, "posterior_summary_pairwise_allocations"), + nuts_diag = .subset2(results, "nuts_diag"), + indicator = .subset2(results, "indicator"), + interactions = .subset2(results, "interactions"), + thresholds = .subset2(results, "thresholds"), + .bgm_spec = .subset2(results, ".bgm_spec"), + .field_names = names(results) + ) +} + + +# ------------------------------------------------------------------ +# bgmCompare S7 class +# ------------------------------------------------------------------ + +bgmCompare_class = new_class("bgmCompare", + package = NULL, + properties = list( + # --- Core (always present) --- + arguments = new_property(class_any), + raw_samples = new_property(class_any), + cache = new_property(class_any), + + # --- Posterior means (set during construction, immutable) --- + posterior_mean_main_baseline = new_property(class_any, default = NULL), + posterior_mean_associations_baseline = new_property(class_any, default = NULL), + posterior_mean_main_differences = new_property(class_any, default = NULL), + posterior_mean_associations_differences = new_property(class_any, default = NULL), + + # --- Lazy MCMC diagnostics (computed on first access via getter) --- + posterior_summary_main_baseline = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_main_baseline"]] + } + ), + posterior_summary_pairwise_baseline = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_pairwise_baseline"]] + } + ), + posterior_summary_main_differences = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_main_differences"]] + } + ), + posterior_summary_pairwise_differences = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_pairwise_differences"]] + } + ), + posterior_summary_indicator = new_property( + class = class_any, + getter = function(self) { + ensure_summaries(self) + self@cache[["posterior_summary_indicator"]] + } + ), + + # --- Optional --- + nuts_diag = new_property(class_any, default = NULL), + + # --- Internal --- + .bgm_spec = new_property(class_any, default = NULL), + .field_names = new_property(class_character) + ) +) + + +# ------------------------------------------------------------------ +# s3_list_to_bgmCompare +# ------------------------------------------------------------------ +# Converts an S3 list-based bgmCompare fit (built incrementally in +# build_output_compare) to the S7 bgmCompare_class. +# +# @param results A named list with class "bgmCompare". +# +# Returns: A bgmCompare_class S7 object. +# ------------------------------------------------------------------ +s3_list_to_bgmCompare = function(results) { + bgmCompare_class( + arguments = .subset2(results, "arguments"), + raw_samples = .subset2(results, "raw_samples"), + cache = .subset2(results, "cache"), + posterior_mean_main_baseline = .subset2(results, "posterior_mean_main_baseline"), + posterior_mean_associations_baseline = .subset2(results, "posterior_mean_associations_baseline"), + posterior_mean_main_differences = .subset2(results, "posterior_mean_main_differences"), + posterior_mean_associations_differences = .subset2(results, "posterior_mean_associations_differences"), + nuts_diag = .subset2(results, "nuts_diag"), + .bgm_spec = .subset2(results, ".bgm_spec"), + .field_names = names(results) + ) +} diff --git a/R/build_output.R b/R/build_output.R index e990cd13..644c1cfd 100644 --- a/R/build_output.R +++ b/R/build_output.R @@ -506,7 +506,8 @@ build_output_bgm = function(spec, raw) { ) } - results + results$.bgm_spec = spec + s3_list_to_bgms(results) } @@ -781,7 +782,8 @@ build_output_mixed_mrf = function(spec, raw) { ) } - results + results$.bgm_spec = spec + s3_list_to_bgms(results) } @@ -971,7 +973,8 @@ build_output_compare = function(spec, raw) { ) } - results + results$.bgm_spec = spec + s3_list_to_bgmCompare(results) } diff --git a/R/extractor_functions.R b/R/extractor_functions.R index 4d145541..46345b35 100644 --- a/R/extractor_functions.R +++ b/R/extractor_functions.R @@ -79,10 +79,11 @@ extract_indicators.bgms = function(bgms_object) { } # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$indicator)) { - indicators_list = bgms_object$raw_samples$indicator + raw = get_raw_samples(bgms_object) + if(!is.null(raw$indicator)) { + indicators_list = raw$indicator indicator_samples = do.call(rbind, indicators_list) - param_names = bgms_object$raw_samples$parameter_names$indicator + param_names = raw$parameter_names$indicator stopifnot("parameter_names$indicator missing in fit object" = !is.null(param_names)) colnames(indicator_samples) = param_names return(indicator_samples) @@ -115,9 +116,10 @@ extract_indicators.bgmCompare = function(bgms_object) { } # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$indicator)) { - indicator_samples = do.call(rbind, bgms_object$raw_samples$indicator) - param_names = bgms_object$raw_samples$parameter_names$indicators + raw = get_raw_samples(bgms_object) + if(!is.null(raw$indicator)) { + indicator_samples = do.call(rbind, raw$indicator) + param_names = raw$parameter_names$indicators if(!is.null(param_names)) { colnames(indicator_samples) = param_names } @@ -177,7 +179,8 @@ extract_posterior_inclusion_probabilities.bgms = function(bgms_object) { data_columnnames = arguments$data_columnnames # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$indicator)) { + raw = get_raw_samples(bgms_object) + if(!is.null(raw$indicator)) { indicator_samples = extract_indicators(bgms_object) edge_means = colMeans(indicator_samples) } else if(!is.null(bgms_object$indicator)) { @@ -278,8 +281,9 @@ extract_posterior_inclusion_probabilities.bgmCompare = function(bgms_object) { } # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$indicator)) { - array3d_ind = to_array3d(bgms_object$raw_samples$indicator) + raw = get_raw_samples(bgms_object) + if(!is.null(raw$indicator)) { + array3d_ind = to_array3d(raw$indicator) mean_ind = apply(array3d_ind, 3, mean) # reconstruct VxV matrix using the sampler’s interleaved order: @@ -424,13 +428,14 @@ extract_pairwise_interactions.bgms = function(bgms_object) { var_names = arguments$data_columnnames # Current format (0.1.6.0+): raw samples - if(!is.null(bgms_object$raw_samples)) { - mats = bgms_object$raw_samples$pairwise + raw = get_raw_samples(bgms_object) + if(!is.null(raw)) { + mats = raw$pairwise mat = do.call(rbind, mats) # Use stored parameter names when available (correct for all model types # including mixed MRF where block order differs from upper-triangle order) - stored_names = bgms_object$raw_samples$parameter_names$pairwise + stored_names = raw$parameter_names$pairwise if(!is.null(stored_names)) { edge_names = stored_names } else { @@ -481,14 +486,15 @@ extract_pairwise_interactions.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$pairwise)) { - pairwise_samples = do.call(rbind, bgms_object$raw_samples$pairwise) + raw = get_raw_samples(bgms_object) + if(!is.null(raw$pairwise)) { + pairwise_samples = do.call(rbind, raw$pairwise) num_vars = bgms_object$arguments$num_variables num_pairs = num_vars * (num_vars - 1) / 2 pairwise_samples = pairwise_samples[, 1:num_pairs] - colnames(pairwise_samples) = bgms_object$raw_samples$parameter_names$pairwise_baseline + colnames(pairwise_samples) = raw$parameter_names$pairwise_baseline return(pairwise_samples) } @@ -563,12 +569,13 @@ extract_main_effects.bgms = function(bgms_object) { # Mixed MRF: return pre-built list from posterior_mean_main if(isTRUE(arguments$is_mixed)) { - return(bgms_object$posterior_mean_main) + return(get_posterior_mean(bgms_object, "main")) } # OMRF: return pre-built threshold matrix - if(!is.null(bgms_object$posterior_mean_main)) { - return(bgms_object$posterior_mean_main) + pm_main = get_posterior_mean(bgms_object, "main") + if(!is.null(pm_main)) { + return(pm_main) } # Deprecated format (0.1.4–0.1.5): $thresholds @@ -598,14 +605,15 @@ extract_main_effects.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$main)) { - main_samples = do.call(rbind, bgms_object$raw_samples$main) + raw = get_raw_samples(bgms_object) + if(!is.null(raw$main)) { + main_samples = do.call(rbind, raw$main) num_vars = bgms_object$arguments$num_variables - num_main = length(bgms_object$raw_samples$parameter_names$main_baseline) + num_main = length(raw$parameter_names$main_baseline) main_samples = main_samples[, 1:num_main] - colnames(main_samples) = bgms_object$raw_samples$parameter_names$main_baseline + colnames(main_samples) = raw$parameter_names$main_baseline return(main_samples) } @@ -687,7 +695,7 @@ extract_group_params.bgmCompare = function(bgms_object) { arguments = extract_arguments(bgms_object) # Current format (0.1.6.0+) - if(!is.null(bgms_object$raw_samples$main)) { + if(!is.null(get_raw_samples(bgms_object)$main)) { return(.extract_group_params_current(bgms_object, arguments)) } @@ -725,7 +733,8 @@ extract_group_params.bgmCompare = function(bgms_object) { # ============================================================ # ---- main effects ---- - array3d_main = to_array3d(bgms_object$raw_samples$main) + raw = get_raw_samples(bgms_object) + array3d_main = to_array3d(raw$main) mean_main = apply(array3d_main, 3, mean) stopifnot(length(mean_main) %% num_groups == 0L) @@ -758,7 +767,7 @@ extract_group_params.bgmCompare = function(bgms_object) { # ============================================================ # ---- pairwise effects ---- - array3d_pair = to_array3d(bgms_object$raw_samples$pairwise) + array3d_pair = to_array3d(raw$pairwise) mean_pair = apply(array3d_pair, 3, mean) stopifnot(length(mean_pair) %% num_groups == 0L) @@ -1164,8 +1173,8 @@ extract_precision.bgms = function(bgms_object) { return(invisible(NULL)) } - rv = bgms_object$posterior_mean_residual_variance - associations = bgms_object$posterior_mean_associations + rv = get_posterior_mean(bgms_object, "residual_variance") + associations = get_posterior_mean(bgms_object, "associations") if(isTRUE(arguments$is_mixed)) { # Mixed MRF: extract the q x q continuous block, convert to precision @@ -1292,7 +1301,7 @@ extract_log_odds.bgms = function(bgms_object) { return(invisible(NULL)) } - associations = bgms_object$posterior_mean_associations + associations = get_posterior_mean(bgms_object, "associations") if(isTRUE(arguments$is_mixed)) { # Mixed MRF: extract the p x p discrete block, convert to log-odds diff --git a/R/fit_accessors.R b/R/fit_accessors.R new file mode 100644 index 00000000..2d1f6b3d --- /dev/null +++ b/R/fit_accessors.R @@ -0,0 +1,70 @@ +# ============================================================================== +# Internal Accessor Helpers for bgms / bgmCompare Fit Objects +# ============================================================================== +# +# These functions provide a single abstraction point for reading fields +# from bgms and bgmCompare fit objects. They handle both S7 (current) and +# legacy S3 list-based objects transparently. +# +# NOT for use in the `$.bgms`/`[[.bgms` compatibility methods (which use +# S7::prop() directly). +# ============================================================================== + + +# ------------------------------------------------------------------ +# get_fit_cache +# ------------------------------------------------------------------ +# Extracts the cache environment from a fit object. +# +# @param fit A bgms or bgmCompare object (S7 or legacy S3). +# +# Returns: The cache environment, or NULL if absent. +# ------------------------------------------------------------------ +get_fit_cache = function(fit) { + if(inherits(fit, "S7_object")) { + fit@cache + } else { + .subset2(fit, "cache") + } +} + + +# ------------------------------------------------------------------ +# get_raw_samples +# ------------------------------------------------------------------ +# Extracts the raw_samples list from a fit object. +# +# @param fit A bgms or bgmCompare object (S7 or legacy S3). +# +# Returns: A list with components main, pairwise, indicator (if present), +# nchains, niter, parameter_names. +# ------------------------------------------------------------------ +get_raw_samples = function(fit) { + if(inherits(fit, "S7_object")) { + fit@raw_samples + } else { + .subset2(fit, "raw_samples") + } +} + + +# ------------------------------------------------------------------ +# get_posterior_mean +# ------------------------------------------------------------------ +# Extracts a named posterior_mean_* field from a fit object. +# +# @param fit A bgms or bgmCompare object (S7 or legacy S3). +# @param field The suffix after "posterior_mean_", e.g. "associations", +# "main", "indicator", "residual_variance", "allocations", +# "associations_baseline", "associations_differences". +# +# Returns: The posterior mean value (matrix, vector, or list), or NULL +# if the field does not exist. +# ------------------------------------------------------------------ +get_posterior_mean = function(fit, field) { + if(inherits(fit, "S7_object")) { + S7::prop(fit, paste0("posterior_mean_", field)) + } else { + .subset2(fit, paste0("posterior_mean_", field)) + } +} diff --git a/R/mcmc_summary.R b/R/mcmc_summary.R index 18c287f1..5f792dc8 100644 --- a/R/mcmc_summary.R +++ b/R/mcmc_summary.R @@ -14,7 +14,7 @@ # Returns: invisible(NULL). Results are stored in fit$cache. # ------------------------------------------------------------------ ensure_summaries = function(fit) { - cache = fit$cache + cache = get_fit_cache(fit) if(is.null(cache)) { return(invisible(NULL)) } diff --git a/R/simulate_predict.R b/R/simulate_predict.R index 475da3bf..63f02d07 100644 --- a/R/simulate_predict.R +++ b/R/simulate_predict.R @@ -707,8 +707,8 @@ simulate.bgms = function(object, if(method == "posterior-mean") { # Use posterior mean parameters - pairwise = object$posterior_mean_associations - main = object$posterior_mean_main + pairwise = get_posterior_mean(object, "associations") + main = get_posterior_mean(object, "main") # Set R's RNG for simulate_mrf if(!is.null(seed)) set.seed(seed) @@ -729,8 +729,9 @@ simulate.bgms = function(object, return(result) } else { # Use posterior samples with parallel processing - pairwise_samples = do.call(rbind, object$raw_samples$pairwise) - main_samples = do.call(rbind, object$raw_samples$main) + raw = get_raw_samples(object) + pairwise_samples = do.call(rbind, raw$pairwise) + main_samples = do.call(rbind, raw$main) total_draws = nrow(pairwise_samples) if(is.null(ndraws)) { @@ -1159,8 +1160,8 @@ predict.bgms = function(object, if(method == "posterior-mean") { # Use posterior mean parameters - pairwise = object$posterior_mean_associations - main = object$posterior_mean_main + pairwise = get_posterior_mean(object, "associations") + main = get_posterior_mean(object, "main") probs = compute_conditional_probs( observations = newdata_recoded, @@ -1181,8 +1182,9 @@ predict.bgms = function(object, } } else { # Use posterior samples - pairwise_samples = do.call(rbind, object$raw_samples$pairwise) - main_samples = do.call(rbind, object$raw_samples$main) + raw = get_raw_samples(object) + pairwise_samples = do.call(rbind, raw$pairwise) + main_samples = do.call(rbind, raw$main) total_draws = nrow(pairwise_samples) if(is.null(ndraws)) { @@ -1607,8 +1609,8 @@ predict_bgms_ggm = function(object, newdata, predict_vars, data_columnnames, if(method == "posterior-mean") { # Reconstruct precision matrix from posterior means omega = reconstruct_precision( - object$posterior_mean_associations, - object$posterior_mean_residual_variance + get_posterior_mean(object, "associations"), + get_posterior_mean(object, "residual_variance") ) result = compute_conditional_ggm( @@ -1627,8 +1629,9 @@ predict_bgms_ggm = function(object, newdata, predict_vars, data_columnnames, } } else { # Use posterior samples - pairwise_samples = do.call(rbind, object$raw_samples$pairwise) - main_samples = do.call(rbind, object$raw_samples$main) + raw = get_raw_samples(object) + pairwise_samples = do.call(rbind, raw$pairwise) + main_samples = do.call(rbind, raw$main) total_draws = nrow(pairwise_samples) if(is.null(ndraws)) { @@ -1724,8 +1727,8 @@ simulate_bgms_ggm = function(object, nsim, seed, method, ndraws, if(method == "posterior-mean") { # Reconstruct precision matrix from off-diagonal + separate diagonal precision = reconstruct_precision( - object$posterior_mean_associations, - object$posterior_mean_residual_variance + get_posterior_mean(object, "associations"), + get_posterior_mean(object, "residual_variance") ) # Call simulate_mrf with variable_type = "continuous" @@ -1742,8 +1745,9 @@ simulate_bgms_ggm = function(object, nsim, seed, method, ndraws, return(result) } else { # Use posterior samples with parallel processing - pairwise_samples = do.call(rbind, object$raw_samples$pairwise) - main_samples = do.call(rbind, object$raw_samples$main) + raw = get_raw_samples(object) + pairwise_samples = do.call(rbind, raw$pairwise) + main_samples = do.call(rbind, raw$main) total_draws = nrow(pairwise_samples) if(is.null(ndraws)) { @@ -2050,7 +2054,7 @@ build_mixed_params_mean = function(object, arguments) { disc_idx = arguments$discrete_indices cont_idx = arguments$continuous_indices - pmat = object$posterior_mean_associations + pmat = get_posterior_mean(object, "associations") pairwise_disc = matrix(0, p, p) for(i in seq_len(p)) { @@ -2073,15 +2077,16 @@ build_mixed_params_mean = function(object, arguments) { } } # Convert residual variance back to association-scale diagonal - rv = object$posterior_mean_residual_variance + rv = get_posterior_mean(object, "residual_variance") for(j in seq_len(q)) { pairwise_cont[j, j] = -1 / (2 * rv[j]) } - mux = object$posterior_mean_main$discrete + pm_main = get_posterior_mean(object, "main") + mux = pm_main$discrete mux[is.na(mux)] = 0 - muy = as.numeric(object$posterior_mean_main$continuous[, "mean"]) + muy = as.numeric(pm_main$continuous[, "mean"]) list(pairwise_disc = pairwise_disc, pairwise_cross = pairwise_cross, pairwise_cont = pairwise_cont, mux = mux, muy = muy) } @@ -2105,8 +2110,9 @@ split_mixed_raw_samples = function(object, arguments) { num_categories = arguments$num_categories is_ordinal = arguments$is_ordinal - main_all = do.call(rbind, object$raw_samples$main) - pairwise_all = do.call(rbind, object$raw_samples$pairwise) + raw = get_raw_samples(object) + main_all = do.call(rbind, raw$main) + pairwise_all = do.call(rbind, raw$pairwise) total_draws = nrow(main_all) # Main layout: [mux_flat | muy | cont_diag] diff --git a/man/cash-.bgmCompare.Rd b/man/cash-.bgmCompare.Rd index 6de878a1..32c8dec4 100644 --- a/man/cash-.bgmCompare.Rd +++ b/man/cash-.bgmCompare.Rd @@ -20,8 +20,9 @@ The requested element. } \description{ -Intercepts access to \code{posterior_summary_*} fields and -triggers lazy computation from cache when needed. All other fields pass -through using standard list extraction. +Provides \code{$} access to S7 properties. Lazy +\code{posterior_summary_*} properties trigger computation on first +access via S7 property getters. Also supports legacy S3 list-based +fit objects. } \keyword{internal} diff --git a/man/cash-.bgms.Rd b/man/cash-.bgms.Rd index 41c681be..a50be4c0 100644 --- a/man/cash-.bgms.Rd +++ b/man/cash-.bgms.Rd @@ -20,8 +20,9 @@ The requested element. } \description{ -Intercepts access to \code{posterior_summary_*} fields and -triggers lazy computation from cache when needed. All other fields pass -through using standard list extraction. +Provides \code{$} access to S7 properties. Lazy +\code{posterior_summary_*} properties trigger computation on first +access via S7 property getters. Also supports legacy S3 list-based +fit objects. } \keyword{internal} diff --git a/tests/testthat/test-fit-object-contract.R b/tests/testthat/test-fit-object-contract.R new file mode 100644 index 00000000..1c72d16e --- /dev/null +++ b/tests/testthat/test-fit-object-contract.R @@ -0,0 +1,365 @@ +# ============================================================================== +# Phase 0: Fit-Object Contract Tests +# ============================================================================== +# +# Regression tests for the fit-object contract documented in +# dev/plans/fit-object-contract.md. These lock down: +# +# 1. Serialization round-trips (saveRDS / readRDS) +# 2. Lazy summary computation semantics +# 3. names(fit) stability +# +# These tests exist to guard the contract before any structural refactor +# (e.g. S7 migration). +# ============================================================================== + + +# ============================================================================== +# 1. Serialization Round-Trips +# ============================================================================== + +test_that("bgms fit survives saveRDS/readRDS without prior summary access", { + fit = get_bgms_fit() + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + expect_s3_class(restored, "bgms") + expect_equal(names(restored), names(fit)) + expect_equal(restored$arguments, fit$arguments) + expect_equal( + restored$posterior_mean_associations, + fit$posterior_mean_associations + ) + + # Lazy summaries must still work after deserialization + s = restored$posterior_summary_pairwise + expect_true(is.data.frame(s) || is.matrix(s)) + expect_true(nrow(s) > 0) +}) + +test_that("bgms fit survives saveRDS/readRDS after summary access", { + fit = get_bgms_fit() + # Force summary computation before saving + s_before = fit$posterior_summary_pairwise + + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + s_after = restored$posterior_summary_pairwise + expect_equal(s_after, s_before) +}) + +test_that("bgmCompare fit survives saveRDS/readRDS without prior summary access", { + fit = get_bgmcompare_fit() + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + expect_s3_class(restored, "bgmCompare") + expect_equal(names(restored), names(fit)) + expect_equal(restored$arguments, fit$arguments) + + # Lazy summaries must still work after deserialization + s = restored$posterior_summary_pairwise_baseline + expect_true(is.data.frame(s) || is.matrix(s)) + expect_true(nrow(s) > 0) +}) + +test_that("bgmCompare fit survives saveRDS/readRDS after summary access", { + fit = get_bgmcompare_fit() + s_before = fit$posterior_summary_pairwise_baseline + + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + s_after = restored$posterior_summary_pairwise_baseline + expect_equal(s_after, s_before) +}) + +test_that("GGM fit survives saveRDS/readRDS", { + fit = get_bgms_fit_ggm() + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + expect_s3_class(restored, "bgms") + expect_equal(names(restored), names(fit)) + expect_equal( + restored$posterior_mean_residual_variance, + fit$posterior_mean_residual_variance + ) + + s = restored$posterior_summary_pairwise + expect_true(is.data.frame(s) || is.matrix(s)) +}) + + +# ============================================================================== +# 2. Lazy Summary Computation Semantics +# ============================================================================== + +test_that("fresh fit starts with summaries_computed = FALSE", { + # Use a dedicated fit to avoid shared-fixture interference + data("ADHD", package = "bgms") + fit = bgm( + ADHD[1:30, 2:5], + iter = 25, warmup = 50, chains = 1, + seed = 99998, + display_progress = "none" + ) + cache = fit$cache + expect_false(isTRUE(cache$summaries_computed)) + + s = fit$posterior_summary_pairwise + expect_true(is.data.frame(s) || is.matrix(s)) + expect_true(isTRUE(cache$summaries_computed)) +}) + +test_that("accessing posterior_summary_pairwise populates cache (bgms)", { + fit = get_bgms_fit() + cache = fit$cache + + # Access triggers computation (or returns cached) + s = fit$posterior_summary_pairwise + expect_true(is.data.frame(s) || is.matrix(s)) + expect_true(nrow(s) > 0) + + # Cache must be populated after access + expect_true(isTRUE(cache$summaries_computed)) +}) + +test_that("second access returns cached result without recomputation (bgms)", { + fit = get_bgms_fit() + + # First access + s1 = fit$posterior_summary_pairwise + + # Second access — should be identical (same object from cache) + s2 = fit$posterior_summary_pairwise + expect_identical(s1, s2) +}) + +test_that("accessing posterior_summary_pairwise_baseline populates cache (bgmCompare)", { + fit = get_bgmcompare_fit() + cache = fit$cache + + s = fit$posterior_summary_pairwise_baseline + expect_true(is.data.frame(s) || is.matrix(s)) + expect_true(nrow(s) > 0) + + expect_true(isTRUE(cache$summaries_computed)) +}) + +test_that("summary() triggers lazy computation (bgms)", { + fit = get_bgms_fit() + cache = fit$cache + + sm = summary(fit) + expect_true(isTRUE(cache$summaries_computed)) + expect_s3_class(sm, "summary.bgms") +}) + +test_that("summary() triggers lazy computation (bgmCompare)", { + fit = get_bgmcompare_fit() + cache = fit$cache + + sm = summary(fit) + expect_true(isTRUE(cache$summaries_computed)) + expect_s3_class(sm, "summary.bgmCompare") +}) + + +# ============================================================================== +# 3. names(fit) Stability +# ============================================================================== + +test_that("bgms fit (edge selection) has all required names", { + fit = get_bgms_fit() + nm = names(fit) + + # Core fields + expect_true("arguments" %in% nm) + expect_true("raw_samples" %in% nm) + expect_true("posterior_mean_associations" %in% nm) + expect_true("cache" %in% nm) + + # Lazy summary placeholders + expect_true("posterior_summary_main" %in% nm) + expect_true("posterior_summary_pairwise" %in% nm) + + # Edge selection fields + expect_true("posterior_mean_indicator" %in% nm) + expect_true("posterior_summary_indicator" %in% nm) +}) + +test_that("bgms fit (no edge selection) does not have indicator fields", { + fit = get_bgms_fit_ggm_no_es() + nm = names(fit) + + expect_true("arguments" %in% nm) + expect_true("raw_samples" %in% nm) + expect_true("posterior_mean_associations" %in% nm) + expect_true("posterior_summary_main" %in% nm) + expect_true("posterior_summary_pairwise" %in% nm) + expect_false("posterior_mean_indicator" %in% nm) + expect_false("posterior_summary_indicator" %in% nm) +}) + +test_that("bgms GGM fit has residual variance but no main effects", { + fit = get_bgms_fit_ggm() + nm = names(fit) + + expect_true("posterior_mean_residual_variance" %in% nm) + expect_null(fit$posterior_mean_main) +}) + +test_that("bgms OMRF fit has main effects", { + fit = get_bgms_fit_ordinal() + expect_false(is.null(fit$posterior_mean_main)) +}) + +test_that("bgmCompare fit has all required names", { + fit = get_bgmcompare_fit() + nm = names(fit) + + expect_true("arguments" %in% nm) + expect_true("raw_samples" %in% nm) + expect_true("cache" %in% nm) + + # Baseline fields + expect_true("posterior_mean_main_baseline" %in% nm) + expect_true("posterior_mean_associations_baseline" %in% nm) + + # Difference fields + expect_true("posterior_mean_main_differences" %in% nm) + expect_true("posterior_mean_associations_differences" %in% nm) + + # Lazy summary placeholders + expect_true("posterior_summary_main_baseline" %in% nm) + expect_true("posterior_summary_pairwise_baseline" %in% nm) + expect_true("posterior_summary_main_differences" %in% nm) + expect_true("posterior_summary_pairwise_differences" %in% nm) +}) + +test_that("bgmCompare fit (with difference selection) has indicator summary", { + fit = get_bgmcompare_fit_beta_bernoulli() + nm = names(fit) + + expect_true("posterior_summary_indicator" %in% nm) +}) + + +# ============================================================================== +# 4. Parameterized: all bgms fixtures survive serialization +# ============================================================================== + +for(spec in get_bgms_fixtures()) { + test_that( + sprintf("saveRDS/readRDS round-trip preserves structure (%s)", spec$label), + { + fit = spec$get_fit() + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + expect_s3_class(restored, "bgms") + expect_equal(names(restored), names(fit)) + expect_equal(restored$arguments, fit$arguments) + } + ) +} + +for(spec in get_bgmcompare_fixtures()) { + test_that( + sprintf("saveRDS/readRDS round-trip preserves structure (%s)", spec$label), + { + fit = spec$get_fit() + tmp = tempfile(fileext = ".rds") + on.exit(unlink(tmp), add = TRUE) + + saveRDS(fit, tmp) + restored = readRDS(tmp) + + expect_s3_class(restored, "bgmCompare") + expect_equal(names(restored), names(fit)) + expect_equal(restored$arguments, fit$arguments) + } + ) +} + + +# ============================================================================== +# 5. Parameterized: lazy summaries work for all bgms fixtures +# ============================================================================== + +for(spec in get_bgms_fixtures()) { + test_that( + sprintf("lazy summaries compute on first access (%s)", spec$label), + { + fit = spec$get_fit() + cache = fit$cache + + # After fixture caching the summaries may already be computed from + # other tests. Reset the flag to test the lazy path. + # (This is safe because the fixture cache returns the same object.) + # Instead, just verify that accessing a summary field returns data. + summary_fields = grep( + "^posterior_summary_", names(fit), value = TRUE + ) + for(field in summary_fields) { + val = fit[[field]] + # Some fields may legitimately be NULL (e.g. main for GGM) + if(!is.null(val)) { + expect_true( + is.data.frame(val) || is.matrix(val), + info = sprintf("[%s] %s is not a data.frame/matrix", spec$label, field) + ) + expect_true( + nrow(val) > 0, + info = sprintf("[%s] %s has zero rows", spec$label, field) + ) + } + } + } + ) +} + +for(spec in get_bgmcompare_fixtures()) { + test_that( + sprintf("lazy summaries compute on first access (%s)", spec$label), + { + fit = spec$get_fit() + summary_fields = grep( + "^posterior_summary_", names(fit), value = TRUE + ) + for(field in summary_fields) { + val = fit[[field]] + if(!is.null(val)) { + expect_true( + is.data.frame(val) || is.matrix(val), + info = sprintf("[%s] %s is not a data.frame/matrix", spec$label, field) + ) + expect_true( + nrow(val) > 0, + info = sprintf("[%s] %s has zero rows", spec$label, field) + ) + } + } + } + ) +}