Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type) {
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type)
run_bgmCompare_parallel <- function(observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type, progress_callback = NULL) {
.Call(`_bgms_run_bgmCompare_parallel`, observations, num_groups, counts_per_category, blume_capel_stats, pairwise_stats, num_categories, main_alpha, main_beta, pairwise_scale, pairwise_scaling_factors, difference_scale, difference_selection_alpha, difference_selection_beta, difference_prior, iter, warmup, na_impute, missing_data_indices, is_ordinal_variable, baseline_category, difference_selection, main_difference_selection, main_effect_indices, pairwise_effect_indices, target_accept, nuts_max_depth, learn_mass_matrix, projection, group_membership, group_indices, interaction_index_matrix, inclusion_probability, num_chains, nThreads, seed, update_method, hmc_num_leapfrogs, progress_type, progress_callback)
}

get_explog_switch <- function() {
Expand Down Expand Up @@ -117,16 +117,16 @@ run_mixed_simulation_parallel <- function(mux_samples, disc_samples, muy_samples
.Call(`_bgms_run_mixed_simulation_parallel`, mux_samples, disc_samples, muy_samples, cont_samples, cross_samples, draw_indices, num_states, p, q, num_categories, variable_type_r, baseline_category, iter, nThreads, seed, progress_type)
}

sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)
sample_ggm <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, na_impute = FALSE, missing_index_nullable = NULL) {
.Call(`_bgms_sample_ggm`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, na_impute, missing_index_nullable)
}

sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) {
.Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)
sample_mixed_mrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, sampler_type = "mh", target_acceptance = 0.80, max_tree_depth = 10L, num_leapfrogs = 100L, na_impute = FALSE, missing_index_discrete_nullable = NULL, missing_index_continuous_nullable = NULL) {
.Call(`_bgms_sample_mixed_mrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, seed, no_threads, progress_type, progress_callback, edge_prior, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, sampler_type, target_acceptance, max_tree_depth, num_leapfrogs, na_impute, missing_index_discrete_nullable, missing_index_continuous_nullable)
}

sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) {
.Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable)
sample_omrf <- function(inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback = NULL, edge_prior = "Bernoulli", na_impute = FALSE, missing_index_nullable = NULL, beta_bernoulli_alpha = 1.0, beta_bernoulli_beta = 1.0, beta_bernoulli_alpha_between = 1.0, beta_bernoulli_beta_between = 1.0, dirichlet_alpha = 1.0, lambda = 1.0, target_acceptance = 0.8, max_tree_depth = 10L, num_leapfrogs = 10L, pairwise_scaling_factors_nullable = NULL) {
.Call(`_bgms_sample_omrf`, inputFromR, prior_inclusion_prob, initial_edge_indicators, no_iter, no_warmup, no_chains, edge_selection, sampler_type, seed, no_threads, progress_type, progress_callback, edge_prior, na_impute, missing_index_nullable, beta_bernoulli_alpha, beta_bernoulli_beta, beta_bernoulli_alpha_between, beta_bernoulli_beta_between, dirichlet_alpha, lambda, target_acceptance, max_tree_depth, num_leapfrogs, pairwise_scaling_factors_nullable)
}

compute_Vn_mfm_sbm <- function(num_variables, dirichlet_alpha, t_max, lambda) {
Expand Down
12 changes: 11 additions & 1 deletion R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -273,6 +273,14 @@
#' \code{"total"} (single combined bar), or \code{"none"} (no progress).
#' Default: \code{"per-chain"}.
#'
#' @param progress_callback An optional R function with signature
#' \code{function(completed, total)} that is called at regular intervals
#' during sampling, where \code{completed} is the number of iterations
#' completed across all chains and \code{total} is the total number of
#' iterations. Useful for external front-ends (e.g., JASP) that supply
#' their own progress reporting.
#' When \code{NULL} (the default), no callback is invoked.
#'
#' @param verbose Logical. If \code{TRUE}, prints informational messages
#' during data processing (e.g., missing data handling, variable recoding).
#' Defaults to \code{getOption("bgms.verbose", TRUE)}. Set
Expand Down Expand Up @@ -447,6 +455,7 @@ bgm = function(
standardize = FALSE,
pseudolikelihood = c("conditional", "marginal"),
verbose = getOption("bgms.verbose", TRUE),
progress_callback = NULL,
interaction_scale,
burnin,
save,
Expand Down Expand Up @@ -519,7 +528,8 @@ bgm = function(
seed = seed,
display_progress = display_progress,
verbose = verbose,
pseudolikelihood = pseudolikelihood
pseudolikelihood = pseudolikelihood,
progress_callback = progress_callback
)

raw = run_sampler(spec)
Expand Down
11 changes: 10 additions & 1 deletion R/bgmCompare.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,13 @@
#' @param display_progress Character. Controls progress reporting:
#' \code{"per-chain"}, \code{"total"}, or \code{"none"}.
#' Default: \code{"per-chain"}.
#' @param progress_callback An optional R function with signature
#' \code{function(completed, total)} that is called at regular intervals
#' during sampling, where \code{completed} is the number of iterations
#' completed across all chains and \code{total} is the total number of
#' iterations. Useful for external front-ends (e.g., JASP) that supply
#' their own progress reporting.
#' When \code{NULL} (the default), no callback is invoked.
#' @param verbose Logical. If \code{TRUE}, prints informational messages
#' during data processing (e.g., missing data handling, variable recoding).
#' Defaults to \code{getOption("bgms.verbose", TRUE)}. Set
Expand Down Expand Up @@ -225,6 +232,7 @@ bgmCompare = function(
seed = NULL,
standardize = FALSE,
verbose = getOption("bgms.verbose", TRUE),
progress_callback = NULL,
main_difference_model,
reference_category,
main_difference_scale,
Expand Down Expand Up @@ -366,7 +374,8 @@ bgmCompare = function(
cores = cores,
seed = seed,
display_progress = display_progress,
verbose = verbose
verbose = verbose,
progress_callback = progress_callback
)

raw = run_sampler(spec)
Expand Down
9 changes: 6 additions & 3 deletions R/bgm_spec.R
Original file line number Diff line number Diff line change
Expand Up @@ -285,7 +285,8 @@ bgm_spec = function(x,
seed = NULL,
display_progress = c("per-chain", "total", "none"),
verbose = TRUE,
pseudolikelihood = c("conditional", "marginal")) {
pseudolikelihood = c("conditional", "marginal"),
progress_callback = NULL) {
model_type = match.arg(model_type)
na_action = tryCatch(match.arg(na_action), error = function(e) {
stop(paste0(
Expand Down Expand Up @@ -340,7 +341,8 @@ bgm_spec = function(x,
display_progress = display_progress,
is_continuous = is_continuous,
edge_selection = if(model_type == "compare") FALSE else edge_selection,
verbose = verbose
verbose = verbose,
progress_callback = progress_callback
)

# --- Build by model type ----------------------------------------------------
Expand Down Expand Up @@ -1079,7 +1081,8 @@ sampler_sublist = function(s) {
nuts_max_depth = as.integer(s$nuts_max_depth),
learn_mass_matrix = s$learn_mass_matrix,
seed = as.integer(s$seed),
progress_type = as.integer(s$progress_type)
progress_type = as.integer(s$progress_type),
progress_callback = s$progress_callback
)
}

Expand Down
5 changes: 4 additions & 1 deletion R/run_sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -137,6 +137,7 @@ run_sampler_omrf = function(spec) {
no_chains = s$chains,
no_threads = s$cores,
progress_type = s$progress_type,
progress_callback = s$progress_callback,
edge_selection = p$edge_selection,
sampler_type = s$update_method,
seed = s$seed,
Expand Down Expand Up @@ -198,6 +199,7 @@ run_sampler_mixed_mrf = function(spec) {
seed = s$seed,
no_threads = s$cores,
progress_type = s$progress_type,
progress_callback = s$progress_callback,
edge_prior = p$edge_prior,
beta_bernoulli_alpha = p$beta_bernoulli_alpha,
beta_bernoulli_beta = p$beta_bernoulli_beta,
Expand Down Expand Up @@ -267,6 +269,7 @@ run_sampler_compare = function(spec) {
seed = s$seed,
update_method = s$update_method,
hmc_num_leapfrogs = s$hmc_num_leapfrogs,
progress_type = s$progress_type
progress_type = s$progress_type,
progress_callback = s$progress_callback
)
}
6 changes: 4 additions & 2 deletions R/validate_sampler.R
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,8 @@ validate_sampler = function(update_method,
display_progress = c("per-chain", "total", "none"),
is_continuous = FALSE,
edge_selection = FALSE,
verbose = TRUE) {
verbose = TRUE,
progress_callback = NULL) {
# --- update_method ----------------------------------------------------------
user_chose_method = length(update_method) == 1
update_method = match.arg(
Expand Down Expand Up @@ -207,6 +208,7 @@ validate_sampler = function(update_method,
chains = chains,
cores = cores,
seed = seed,
progress_type = progress_type
progress_type = progress_type,
progress_callback = progress_callback
)
}
6 changes: 6 additions & 0 deletions man/bgm.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 6 additions & 0 deletions man/bgmCompare.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading