From 07014d76ed2539c2ab1537e9226fec3d1b72f869 Mon Sep 17 00:00:00 2001 From: Don van den Bergh Date: Fri, 27 Mar 2026 13:29:17 +0100 Subject: [PATCH] add callback function for JASP --- R/RcppExports.R | 16 +++++++-------- R/bgm.R | 12 +++++++++++- R/bgmCompare.R | 11 ++++++++++- R/bgm_spec.R | 9 ++++++--- R/run_sampler.R | 5 ++++- R/validate_sampler.R | 6 ++++-- man/bgm.Rd | 6 ++++++ man/bgmCompare.Rd | 6 ++++++ src/RcppExports.cpp | 36 +++++++++++++++++++--------------- src/bgmCompare_interface.cpp | 7 +++++-- src/sample_ggm.cpp | 3 ++- src/sample_mixed.cpp | 4 +++- src/sample_omrf.cpp | 4 +++- src/utils/progress_manager.cpp | 32 +++++++++++++++++++++++------- src/utils/progress_manager.h | 7 ++++++- 15 files changed, 119 insertions(+), 45 deletions(-) diff --git a/R/RcppExports.R b/R/RcppExports.R index 2bd030d0..955cb1c0 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,8 +1,8 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 -run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) { - .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) +run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type, progress_callback = NULL) { + .Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type, progress_callback) } get_explog_switch <- function() { @@ -117,16 +117,16 @@ run_mixed_simulation_parallel <- function(mux_samples, disc_samples, muy_samples .Call(`_bgms_run_mixed_simulation_parallel`, mux_samples, disc_samples, muy_samples, cont_samples, cross_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type) } -sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) { - .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable) +sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) { + .Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable) } -sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) { - .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable) +sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) { + .Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable) } -sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) { - .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable) +sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) { + .Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable) } compute_Vn_mfm_sbm <- function(num_variables, dirichlet_alpha, t_max, lambda) { diff --git a/R/bgm.R b/R/bgm.R index ad0af91f..8c13567f 100644 --- a/R/bgm.R +++ b/R/bgm.R @@ -273,6 +273,14 @@ #' \code{"total"} (single combined bar), or \code{"none"} (no progress). #' Default: \code{"per-chain"}. #' +#' @param progress_callback An optional R function with signature +#' \code{function(completed, total)} that is called at regular intervals +#' during sampling, where \code{completed} is the number of iterations +#' completed across all chains and \code{total} is the total number of +#' iterations. Useful for external front-ends (e.g., JASP) that supply +#' their own progress reporting. +#' When \code{NULL} (the default), no callback is invoked. +#' #' @param verbose Logical. If \code{TRUE}, prints informational messages #' during data processing (e.g., missing data handling, variable recoding). #' Defaults to \code{getOption("bgms.verbose", TRUE)}. Set @@ -447,6 +455,7 @@ bgm = function( standardize = FALSE, pseudolikelihood = c("conditional", "marginal"), verbose = getOption("bgms.verbose", TRUE), + progress_callback = NULL, interaction_scale, burnin, save, @@ -519,7 +528,8 @@ bgm = function( seed = seed, display_progress = display_progress, verbose = verbose, - pseudolikelihood = pseudolikelihood + pseudolikelihood = pseudolikelihood, + progress_callback = progress_callback ) raw = run_sampler(spec) diff --git a/R/bgmCompare.R b/R/bgmCompare.R index 648a3c56..a88f17b4 100644 --- a/R/bgmCompare.R +++ b/R/bgmCompare.R @@ -122,6 +122,13 @@ #' @param display_progress Character. Controls progress reporting: #' \code{"per-chain"}, \code{"total"}, or \code{"none"}. #' Default: \code{"per-chain"}. +#' @param progress_callback An optional R function with signature +#' \code{function(completed, total)} that is called at regular intervals +#' during sampling, where \code{completed} is the number of iterations +#' completed across all chains and \code{total} is the total number of +#' iterations. Useful for external front-ends (e.g., JASP) that supply +#' their own progress reporting. +#' When \code{NULL} (the default), no callback is invoked. #' @param verbose Logical. If \code{TRUE}, prints informational messages #' during data processing (e.g., missing data handling, variable recoding). #' Defaults to \code{getOption("bgms.verbose", TRUE)}. Set @@ -225,6 +232,7 @@ bgmCompare = function( seed = NULL, standardize = FALSE, verbose = getOption("bgms.verbose", TRUE), + progress_callback = NULL, main_difference_model, reference_category, main_difference_scale, @@ -366,7 +374,8 @@ bgmCompare = function( cores = cores, seed = seed, display_progress = display_progress, - verbose = verbose + verbose = verbose, + progress_callback = progress_callback ) raw = run_sampler(spec) diff --git a/R/bgm_spec.R b/R/bgm_spec.R index 243547f5..1a9a4afb 100644 --- a/R/bgm_spec.R +++ b/R/bgm_spec.R @@ -285,7 +285,8 @@ bgm_spec = function(x, seed = NULL, display_progress = c("per-chain", "total", "none"), verbose = TRUE, - pseudolikelihood = c("conditional", "marginal")) { + pseudolikelihood = c("conditional", "marginal"), + progress_callback = NULL) { model_type = match.arg(model_type) na_action = tryCatch(match.arg(na_action), error = function(e) { stop(paste0( @@ -340,7 +341,8 @@ bgm_spec = function(x, display_progress = display_progress, is_continuous = is_continuous, edge_selection = if(model_type == "compare") FALSE else edge_selection, - verbose = verbose + verbose = verbose, + progress_callback = progress_callback ) # --- Build by model type ---------------------------------------------------- @@ -1079,7 +1081,8 @@ sampler_sublist = function(s) { nuts_max_depth = as.integer(s$nuts_max_depth), learn_mass_matrix = s$learn_mass_matrix, seed = as.integer(s$seed), - progress_type = as.integer(s$progress_type) + progress_type = as.integer(s$progress_type), + progress_callback = s$progress_callback ) } diff --git a/R/run_sampler.R b/R/run_sampler.R index 3eb7b684..bf6a0cc2 100644 --- a/R/run_sampler.R +++ b/R/run_sampler.R @@ -137,6 +137,7 @@ run_sampler_omrf = function(spec) { no_chains = s$chains, no_threads = s$cores, progress_type = s$progress_type, + progress_callback = s$progress_callback, edge_selection = p$edge_selection, sampler_type = s$update_method, seed = s$seed, @@ -198,6 +199,7 @@ run_sampler_mixed_mrf = function(spec) { seed = s$seed, no_threads = s$cores, progress_type = s$progress_type, + progress_callback = s$progress_callback, edge_prior = p$edge_prior, beta_bernoulli_alpha = p$beta_bernoulli_alpha, beta_bernoulli_beta = p$beta_bernoulli_beta, @@ -267,6 +269,7 @@ run_sampler_compare = function(spec) { seed = s$seed, update_method = s$update_method, hmc_num_leapfrogs = s$hmc_num_leapfrogs, - progress_type = s$progress_type + progress_type = s$progress_type, + progress_callback = s$progress_callback ) } diff --git a/R/validate_sampler.R b/R/validate_sampler.R index 901de223..e84f4389 100644 --- a/R/validate_sampler.R +++ b/R/validate_sampler.R @@ -99,7 +99,8 @@ validate_sampler = function(update_method, display_progress = c("per-chain", "total", "none"), is_continuous = FALSE, edge_selection = FALSE, - verbose = TRUE) { + verbose = TRUE, + progress_callback = NULL) { # --- update_method ---------------------------------------------------------- user_chose_method = length(update_method) == 1 update_method = match.arg( @@ -207,6 +208,7 @@ validate_sampler = function(update_method, chains = chains, cores = cores, seed = seed, - progress_type = progress_type + progress_type = progress_type, + progress_callback = progress_callback ) } diff --git a/man/bgm.Rd b/man/bgm.Rd index 19a38c58..bb193c79 100644 --- a/man/bgm.Rd +++ b/man/bgm.Rd @@ -35,6 +35,7 @@ bgm( standardize = FALSE, pseudolikelihood = c("conditional", "marginal"), verbose = getOption("bgms.verbose", TRUE), + progress_callback = NULL, interaction_scale, burnin, save, @@ -191,6 +192,11 @@ during data processing (e.g., missing data handling, variable recoding). Defaults to \code{getOption("bgms.verbose", TRUE)}. Set \code{options(bgms.verbose = FALSE)} to suppress messages globally.} +\item{progress_callback}{An optional R function (taking no arguments) that +is called at regular intervals during sampling. Useful for external +front-ends (e.g., JASP) that supply their own progress reporting. +When \code{NULL} (the default), no callback is invoked.} + \item{interaction_scale, burnin, save, threshold_alpha, threshold_beta}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} Deprecated arguments as of \strong{bgms 0.1.6.0}. Use \code{pairwise_scale}, \code{warmup}, \code{main_alpha}, and \code{main_beta} instead.} diff --git a/man/bgmCompare.Rd b/man/bgmCompare.Rd index 8efdc028..9ec50985 100644 --- a/man/bgmCompare.Rd +++ b/man/bgmCompare.Rd @@ -34,6 +34,7 @@ bgmCompare( seed = NULL, standardize = FALSE, verbose = getOption("bgms.verbose", TRUE), + progress_callback = NULL, main_difference_model, reference_category, main_difference_scale, @@ -151,6 +152,11 @@ during data processing (e.g., missing data handling, variable recoding). Defaults to \code{getOption("bgms.verbose", TRUE)}. Set \code{options(bgms.verbose = FALSE)} to suppress messages globally.} +\item{progress_callback}{An optional R function (taking no arguments) that +is called at regular intervals during sampling. Useful for external +front-ends (e.g., JASP) that supply their own progress reporting. +When \code{NULL} (the default), no callback is invoked.} + \item{main_difference_model, reference_category, pairwise_difference_scale, main_difference_scale, pairwise_difference_prior, main_difference_prior, pairwise_difference_probability, main_difference_probability, pairwise_beta_bernoulli_alpha, pairwise_beta_bernoulli_beta, main_beta_bernoulli_alpha, main_beta_bernoulli_beta, interaction_scale, threshold_alpha, threshold_beta, burnin, save}{\ifelse{html}{\href{https://lifecycle.r-lib.org/articles/stages.html#deprecated}{\figure{lifecycle-deprecated.svg}{options: alt='[Deprecated]'}}}{\strong{[Deprecated]}} Deprecated arguments as of \strong{bgms 0.1.6.0}. Use \code{difference_scale}, \code{difference_prior}, \code{difference_probability}, diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index ad4e8d3c..e4993e5a 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -12,8 +12,8 @@ Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif // run_bgmCompare_parallel -Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& counts_per_category, const std::vector& blume_capel_stats, const std::vector& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, const arma::mat& pairwise_scaling_factors, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int warmup, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, bool main_difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs, int progress_type); -RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP pairwise_scaling_factorsSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP progress_typeSEXP) { +Rcpp::List run_bgmCompare_parallel(const arma::imat& observations, int num_groups, const std::vector& counts_per_category, const std::vector& blume_capel_stats, const std::vector& pairwise_stats, const arma::ivec& num_categories, double main_alpha, double main_beta, double pairwise_scale, const arma::mat& pairwise_scaling_factors, double difference_scale, double difference_selection_alpha, double difference_selection_beta, const std::string& difference_prior, int iter, int warmup, bool na_impute, const arma::imat& missing_data_indices, const arma::uvec& is_ordinal_variable, const arma::ivec& baseline_category, bool difference_selection, bool main_difference_selection, const arma::imat& main_effect_indices, const arma::imat& pairwise_effect_indices, double target_accept, int nuts_max_depth, bool learn_mass_matrix, const arma::mat& projection, const arma::ivec& group_membership, const arma::imat& group_indices, const arma::imat& interaction_index_matrix, const arma::mat& inclusion_probability, int num_chains, int nThreads, int seed, const std::string& update_method, int hmc_num_leapfrogs, int progress_type, SEXP progress_callback); +RcppExport SEXP _bgms_run_bgmCompare_parallel(SEXP observationsSEXP, SEXP num_groupsSEXP, SEXP counts_per_categorySEXP, SEXP blume_capel_statsSEXP, SEXP pairwise_statsSEXP, SEXP num_categoriesSEXP, SEXP main_alphaSEXP, SEXP main_betaSEXP, SEXP pairwise_scaleSEXP, SEXP pairwise_scaling_factorsSEXP, SEXP difference_scaleSEXP, SEXP difference_selection_alphaSEXP, SEXP difference_selection_betaSEXP, SEXP difference_priorSEXP, SEXP iterSEXP, SEXP warmupSEXP, SEXP na_imputeSEXP, SEXP missing_data_indicesSEXP, SEXP is_ordinal_variableSEXP, SEXP baseline_categorySEXP, SEXP difference_selectionSEXP, SEXP main_difference_selectionSEXP, SEXP main_effect_indicesSEXP, SEXP pairwise_effect_indicesSEXP, SEXP target_acceptSEXP, SEXP nuts_max_depthSEXP, SEXP learn_mass_matrixSEXP, SEXP projectionSEXP, SEXP group_membershipSEXP, SEXP group_indicesSEXP, SEXP interaction_index_matrixSEXP, SEXP inclusion_probabilitySEXP, SEXP num_chainsSEXP, SEXP nThreadsSEXP, SEXP seedSEXP, SEXP update_methodSEXP, SEXP hmc_num_leapfrogsSEXP, SEXP progress_typeSEXP, SEXP progress_callbackSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -55,7 +55,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const std::string& >::type update_method(update_methodSEXP); Rcpp::traits::input_parameter< int >::type hmc_num_leapfrogs(hmc_num_leapfrogsSEXP); Rcpp::traits::input_parameter< int >::type progress_type(progress_typeSEXP); - rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type)); + Rcpp::traits::input_parameter< SEXP >::type progress_callback(progress_callbackSEXP); + rcpp_result_gen = Rcpp::wrap(run_bgmCompare_parallel(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type, progress_callback)); return rcpp_result_gen; END_RCPP } @@ -516,8 +517,8 @@ BEGIN_RCPP END_RCPP } // sample_ggm -Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const bool na_impute, const Rcpp::Nullable missing_index_nullable); -RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP) { +Rcpp::List sample_ggm(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, SEXP progress_callback, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const bool na_impute, const Rcpp::Nullable missing_index_nullable); +RcppExport SEXP _bgms_sample_ggm(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP progress_callbackSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -532,6 +533,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< SEXP >::type progress_callback(progress_callbackSEXP); Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); @@ -543,13 +545,13 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_nullable(missing_index_nullableSEXP); - rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)); + rcpp_result_gen = Rcpp::wrap(sample_ggm(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)); return rcpp_result_gen; END_RCPP } // sample_mixed_mrf -Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable); -RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP) { +Rcpp::List sample_mixed_mrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const int seed, const int no_threads, const int progress_type, SEXP progress_callback, const std::string& edge_prior, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const std::string& sampler_type, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const bool na_impute, const Rcpp::Nullable missing_index_discrete_nullable, const Rcpp::Nullable missing_index_continuous_nullable); +RcppExport SEXP _bgms_sample_mixed_mrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP progress_callbackSEXP, SEXP edge_priorSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP sampler_typeSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP na_imputeSEXP, SEXP missing_index_discrete_nullableSEXP, SEXP missing_index_continuous_nullableSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -563,6 +565,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< SEXP >::type progress_callback(progress_callbackSEXP); Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); Rcpp::traits::input_parameter< const double >::type beta_bernoulli_alpha(beta_bernoulli_alphaSEXP); Rcpp::traits::input_parameter< const double >::type beta_bernoulli_beta(beta_bernoulli_betaSEXP); @@ -577,13 +580,13 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_discrete_nullable(missing_index_discrete_nullableSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_continuous_nullable(missing_index_continuous_nullableSEXP); - rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)); + rcpp_result_gen = Rcpp::wrap(sample_mixed_mrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)); return rcpp_result_gen; END_RCPP } // sample_omrf -Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable); -RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP) { +Rcpp::List sample_omrf(const Rcpp::List& inputFromR, const arma::mat& prior_inclusion_prob, const arma::imat& initial_edge_indicators, const int no_iter, const int no_warmup, const int no_chains, const bool edge_selection, const std::string& sampler_type, const int seed, const int no_threads, const int progress_type, SEXP progress_callback, const std::string& edge_prior, const bool na_impute, const Rcpp::Nullable missing_index_nullable, const double beta_bernoulli_alpha, const double beta_bernoulli_beta, const double beta_bernoulli_alpha_between, const double beta_bernoulli_beta_between, const double dirichlet_alpha, const double lambda, const double target_acceptance, const int max_tree_depth, const int num_leapfrogs, const Rcpp::Nullable pairwise_scaling_factors_nullable); +RcppExport SEXP _bgms_sample_omrf(SEXP inputFromRSEXP, SEXP prior_inclusion_probSEXP, SEXP initial_edge_indicatorsSEXP, SEXP no_iterSEXP, SEXP no_warmupSEXP, SEXP no_chainsSEXP, SEXP edge_selectionSEXP, SEXP sampler_typeSEXP, SEXP seedSEXP, SEXP no_threadsSEXP, SEXP progress_typeSEXP, SEXP progress_callbackSEXP, SEXP edge_priorSEXP, SEXP na_imputeSEXP, SEXP missing_index_nullableSEXP, SEXP beta_bernoulli_alphaSEXP, SEXP beta_bernoulli_betaSEXP, SEXP beta_bernoulli_alpha_betweenSEXP, SEXP beta_bernoulli_beta_betweenSEXP, SEXP dirichlet_alphaSEXP, SEXP lambdaSEXP, SEXP target_acceptanceSEXP, SEXP max_tree_depthSEXP, SEXP num_leapfrogsSEXP, SEXP pairwise_scaling_factors_nullableSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -598,6 +601,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type seed(seedSEXP); Rcpp::traits::input_parameter< const int >::type no_threads(no_threadsSEXP); Rcpp::traits::input_parameter< const int >::type progress_type(progress_typeSEXP); + Rcpp::traits::input_parameter< SEXP >::type progress_callback(progress_callbackSEXP); Rcpp::traits::input_parameter< const std::string& >::type edge_prior(edge_priorSEXP); Rcpp::traits::input_parameter< const bool >::type na_impute(na_imputeSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type missing_index_nullable(missing_index_nullableSEXP); @@ -611,7 +615,7 @@ BEGIN_RCPP Rcpp::traits::input_parameter< const int >::type max_tree_depth(max_tree_depthSEXP); Rcpp::traits::input_parameter< const int >::type num_leapfrogs(num_leapfrogsSEXP); Rcpp::traits::input_parameter< const Rcpp::Nullable >::type pairwise_scaling_factors_nullable(pairwise_scaling_factors_nullableSEXP); - rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable)); + rcpp_result_gen = Rcpp::wrap(sample_omrf(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable)); return rcpp_result_gen; END_RCPP } @@ -631,7 +635,7 @@ END_RCPP } static const R_CallMethodDef CallEntries[] = { - {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 38}, + {"_bgms_run_bgmCompare_parallel", (DL_FUNC) &_bgms_run_bgmCompare_parallel, 39}, {"_bgms_get_explog_switch", (DL_FUNC) &_bgms_get_explog_switch, 0}, {"_bgms_rcpp_ieee754_exp", (DL_FUNC) &_bgms_rcpp_ieee754_exp, 1}, {"_bgms_rcpp_ieee754_log", (DL_FUNC) &_bgms_rcpp_ieee754_log, 1}, @@ -660,9 +664,9 @@ static const R_CallMethodDef CallEntries[] = { {"_bgms_run_ggm_simulation_parallel", (DL_FUNC) &_bgms_run_ggm_simulation_parallel, 9}, {"_bgms_sample_mixed_mrf_gibbs", (DL_FUNC) &_bgms_sample_mixed_mrf_gibbs, 11}, {"_bgms_run_mixed_simulation_parallel", (DL_FUNC) &_bgms_run_mixed_simulation_parallel, 16}, - {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 22}, - {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 24}, - {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 24}, + {"_bgms_sample_ggm", (DL_FUNC) &_bgms_sample_ggm, 23}, + {"_bgms_sample_mixed_mrf", (DL_FUNC) &_bgms_sample_mixed_mrf, 25}, + {"_bgms_sample_omrf", (DL_FUNC) &_bgms_sample_omrf, 25}, {"_bgms_compute_Vn_mfm_sbm", (DL_FUNC) &_bgms_compute_Vn_mfm_sbm, 4}, {NULL, NULL, 0} }; diff --git a/src/bgmCompare_interface.cpp b/src/bgmCompare_interface.cpp index ff1730ff..1974a32d 100644 --- a/src/bgmCompare_interface.cpp +++ b/src/bgmCompare_interface.cpp @@ -329,6 +329,8 @@ struct GibbsCompareChainRunner : public Worker { // - seed: Base random seed (incremented per chain). // - update_method: Sampler type ("adaptive-metropolis", "hamiltonian-mc", "nuts"). // - hmc_num_leapfrogs: Number of leapfrog steps for HMC. +// - progress_type: Progress bar style (0 = none, 1 = total, 2 = per-chain). +// - progress_callback: R function (SEXP) called as callback(completed, total) at regular intervals, or R_NilValue. // // Returns: // - Rcpp::List of length `num_chains`, where each element is either: @@ -383,7 +385,8 @@ Rcpp::List run_bgmCompare_parallel( int seed, const std::string& update_method, int hmc_num_leapfrogs, - int progress_type + int progress_type, + SEXP progress_callback = R_NilValue ) { std::vector results(num_chains); @@ -398,7 +401,7 @@ Rcpp::List run_bgmCompare_parallel( // only used to determine the total no. warmup iterations, a bit hacky WarmupSchedule warmup_schedule_temp(warmup, difference_selection, (update_method_enum != adaptive_metropolis)); int total_warmup = warmup_schedule_temp.total_warmup; - ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type); + ProgressManager pm(num_chains, iter, total_warmup, 50, progress_type, true, progress_callback); GibbsCompareChainRunner worker( observations, num_groups, diff --git a/src/sample_ggm.cpp b/src/sample_ggm.cpp index 48ef75ea..943aaf91 100644 --- a/src/sample_ggm.cpp +++ b/src/sample_ggm.cpp @@ -25,6 +25,7 @@ Rcpp::List sample_ggm( const int seed, const int no_threads, const int progress_type, + SEXP progress_callback = R_NilValue, const std::string& edge_prior = "Bernoulli", const double beta_bernoulli_alpha = 1.0, const double beta_bernoulli_beta = 1.0, @@ -64,7 +65,7 @@ Rcpp::List sample_ggm( config.na_impute = na_impute; // Set up progress manager - ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type, true, progress_callback); // Create edge prior EdgePrior edge_prior_enum = edge_prior_from_string(edge_prior); diff --git a/src/sample_mixed.cpp b/src/sample_mixed.cpp index e93b6a2e..0421d544 100644 --- a/src/sample_mixed.cpp +++ b/src/sample_mixed.cpp @@ -34,6 +34,7 @@ // @param seed Random seed // @param no_threads Number of threads for parallel execution // @param progress_type Progress bar type +// @param progress_callback R function (SEXP) called as callback(completed, total) at regular intervals, or R_NilValue // @param edge_prior Edge prior type // @param beta_bernoulli_alpha Beta-Bernoulli alpha hyperparameter // @param beta_bernoulli_beta Beta-Bernoulli beta hyperparameter @@ -62,6 +63,7 @@ Rcpp::List sample_mixed_mrf( const int seed, const int no_threads, const int progress_type, + SEXP progress_callback = R_NilValue, const std::string& edge_prior = "Bernoulli", const double beta_bernoulli_alpha = 1.0, const double beta_bernoulli_beta = 1.0, @@ -134,7 +136,7 @@ Rcpp::List sample_mixed_mrf( config.num_leapfrogs = num_leapfrogs; // Set up progress manager - ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type, true, progress_callback); // Run MCMC using unified infrastructure std::vector results = run_mcmc_sampler( diff --git a/src/sample_omrf.cpp b/src/sample_omrf.cpp index f2b96bcc..1f44b852 100644 --- a/src/sample_omrf.cpp +++ b/src/sample_omrf.cpp @@ -27,6 +27,7 @@ // @param seed Random seed // @param no_threads Number of threads for parallel execution // @param progress_type Progress bar type +// @param progress_callback R function (SEXP) called as callback(completed, total) at regular intervals, or R_NilValue // @param edge_prior Edge prior type: "Bernoulli", "Beta-Bernoulli", "Stochastic-Block" // @param na_impute Whether to impute missing data // @param missing_index Matrix of missing data indices (n_missing x 2, 0-based) @@ -54,6 +55,7 @@ Rcpp::List sample_omrf( const int seed, const int no_threads, const int progress_type, + SEXP progress_callback = R_NilValue, const std::string& edge_prior = "Bernoulli", const bool na_impute = false, const Rcpp::Nullable missing_index_nullable = R_NilValue, @@ -108,7 +110,7 @@ Rcpp::List sample_omrf( config.na_impute = na_impute; // Set up progress manager - ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type); + ProgressManager pm(no_chains, no_iter, no_warmup, 50, progress_type, true, progress_callback); // Run MCMC using unified infrastructure std::vector results = run_mcmc_sampler( diff --git a/src/utils/progress_manager.cpp b/src/utils/progress_manager.cpp index 83610a69..8f537c14 100644 --- a/src/utils/progress_manager.cpp +++ b/src/utils/progress_manager.cpp @@ -1,8 +1,11 @@ #include "utils/progress_manager.h" -ProgressManager::ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_, int progress_type_, bool useUnicode_) +ProgressManager::ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_, int progress_type_, bool useUnicode_, SEXP progress_callback) : nChains(nChains_), nIter(nIter_ + nWarmup_), nWarmup(nWarmup_), printEvery(printEvery_), - progress_type(progress_type_), useUnicode(useUnicode_), progress(nChains_) { + progress_type(progress_type_), useUnicode(useUnicode_), progress(nChains_), callback(progress_callback) { + + // When a callback is provided, suppress the built-in progress display + if (callback.isNotNull()) progress_type = 0; for (size_t i = 0; i < nChains; i++) progress[i] = 0; start = Clock::now(); @@ -70,9 +73,18 @@ void ProgressManager::update(size_t chainId) { auto now = Clock::now(); std::chrono::duration sinceLast = now - lastPrint; + bool has_output = (progress_type != 0) || callback.isNotNull(); + // Throttle printing to avoid spamming - if (progress_type != 0 && sinceLast.count() >= 0.5) { - print(); + if (has_output && sinceLast.count() >= 0.5) { + if (progress_type != 0) { + print(); + } + if (callback.isNotNull()) { + size_t done = std::reduce(progress.begin(), progress.end()); + size_t totalWork = nChains * nIter; + Rcpp::Function(callback.get())(done, totalWork); + } lastPrint = now; } } @@ -81,10 +93,11 @@ void ProgressManager::update(size_t chainId) { void ProgressManager::finish() { - if (progress_type == 0) return; // No progress display or user interrupt + if (progress_type == 0 && callback.isNull()) return; if (needsToExit) { - Rcpp::Rcout << "All chains terminated.\n"; + if (progress_type != 0) + Rcpp::Rcout << "All chains terminated.\n"; return; } @@ -92,7 +105,12 @@ void ProgressManager::finish() { for (size_t i = 0; i < nChains; i++) progress[i] = nIter; - print(); + if (progress_type != 0) + print(); + if (callback.isNotNull()) { + size_t totalWork = nChains * nIter; + Rcpp::Function(callback.get())(totalWork, totalWork); + } } diff --git a/src/utils/progress_manager.h b/src/utils/progress_manager.h index 70402f64..6ab48fae 100644 --- a/src/utils/progress_manager.h +++ b/src/utils/progress_manager.h @@ -41,12 +41,14 @@ inline bool checkInterrupt() { * - Thread-safe printing with mutex protection * - Console width adaptation and change detection * - User interrupt checking + * - Optional R callback for external progress reporting (e.g., JASP), + * invoked as callback(completed, total) */ class ProgressManager { public: - ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_ = 10, int progress_type = 2, bool useUnicode_ = true); + ProgressManager(int nChains_, int nIter_, int nWarmup_, int printEvery_ = 10, int progress_type = 2, bool useUnicode_ = true, SEXP progress_callback = R_NilValue); void update(size_t chainId); void finish(); bool shouldExit() const; @@ -118,6 +120,9 @@ class ProgressManager { // Thread synchronization std::mutex printMutex; // Mutex for thread-safe printing + + // R callback (called without arguments at throttled intervals) + Rcpp::Nullable callback; }; #endif // PROGRESS_MANAGER_H \ No newline at end of file