Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ Imports:
Rcpp (>= 1.0.7),
RcppParallel,
Rdpack,
S7,
methods,
lifecycle,
stats
Expand Down
7 changes: 7 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ S3method(extract_precision,bgms)
S3method(extract_rhat,bgmCompare)
S3method(extract_rhat,bgms)
S3method(extract_sbm,bgms)
S3method(names,bgmCompare)
S3method(names,bgms)
S3method(predict,bgmCompare)
S3method(predict,bgms)
S3method(print,bgmCompare)
Expand Down Expand Up @@ -63,6 +65,11 @@ importFrom(Rcpp,evalCpp)
importFrom(RcppParallel,defaultNumThreads)
importFrom(RcppParallel,setThreadOptions)
importFrom(Rdpack,reprompt)
importFrom(S7,class_any)
importFrom(S7,class_character)
importFrom(S7,new_class)
importFrom(S7,new_property)
importFrom(S7,prop)
importFrom(methods,hasArg)
importFrom(stats,dpois)
importFrom(stats,predict)
Expand Down
2 changes: 0 additions & 2 deletions R/bgm.R
Original file line number Diff line number Diff line change
Expand Up @@ -524,7 +524,5 @@ bgm = function(

raw = run_sampler(spec)
output = build_output(spec, raw)

output$.bgm_spec = spec
return(output)
}
1 change: 0 additions & 1 deletion R/bgmCompare.R
Original file line number Diff line number Diff line change
Expand Up @@ -372,6 +372,5 @@ bgmCompare = function(
raw = run_sampler(spec)
output = build_output(spec, raw)

output$.bgm_spec = spec
return(output)
}
70 changes: 48 additions & 22 deletions R/bgmcompare-methods.r
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,8 @@ coef.bgmCompare = function(object, ...) {

# ============================================================
# ---- main effects ----
array3d_main = to_array3d(object$raw_samples$main)
raw = get_raw_samples(object)
array3d_main = to_array3d(raw$main)
stopifnot(!is.null(array3d_main))
mean_main = apply(array3d_main, 3, mean)

Expand Down Expand Up @@ -335,7 +336,7 @@ coef.bgmCompare = function(object, ...) {

# ============================================================
# ---- pairwise effects ----
array3d_pair = to_array3d(object$raw_samples$pairwise)
array3d_pair = to_array3d(raw$pairwise)
stopifnot(!is.null(array3d_pair))
mean_pair = apply(array3d_pair, 3, mean)

Expand Down Expand Up @@ -369,7 +370,7 @@ coef.bgmCompare = function(object, ...) {
# ============================================================
# ---- indicators (present only if selection was used) ----
indicators = NULL
array3d_ind = to_array3d(object$raw_samples$indicator)
array3d_ind = to_array3d(raw$indicator)
if(!is.null(array3d_ind)) {
mean_ind = apply(array3d_ind, 3, mean)

Expand Down Expand Up @@ -413,9 +414,10 @@ coef.bgmCompare = function(object, ...) {

#' Access elements of a bgmCompare object
#'
#' @description Intercepts access to \code{posterior_summary_*} fields and
#' triggers lazy computation from cache when needed. All other fields pass
#' through using standard list extraction.
#' @description Provides \code{$} access to S7 properties. Lazy
#' \code{posterior_summary_*} properties trigger computation on first
#' access via S7 property getters. Also supports legacy S3 list-based
#' fit objects.
#'
#' @param x A \code{bgmCompare} object.
#' @param name Name of the element to access.
Expand All @@ -426,17 +428,21 @@ coef.bgmCompare = function(object, ...) {
#' @export
#' @keywords internal
`$.bgmCompare` = function(x, name) {
if(startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
if(inherits(x, "S7_object")) {
S7::prop(x, name)
} else {
if(startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
}
}
}
.subset2(x, name)
}
.subset2(x, name)
}


Expand All @@ -446,15 +452,35 @@ coef.bgmCompare = function(object, ...) {
#' @export
#' @keywords internal
`[[.bgmCompare` = function(x, name, ...) {
if(is.character(name) && startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
if(inherits(x, "S7_object")) {
if(is.character(name)) {
S7::prop(x, name)
} else {
stop("numeric indexing is not supported for bgmCompare objects")
}
} else {
if(is.character(name) && startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
}
}
}
.subset2(x, name)
}
}


#' @method names bgmCompare
#' @export
#' @keywords internal
names.bgmCompare = function(x) {
if(inherits(x, "S7_object")) {
S7::prop(x, ".field_names")
} else {
NextMethod()
}
.subset2(x, name)
}
77 changes: 51 additions & 26 deletions R/bgms-methods.R
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,7 @@ summary.bgms = function(object, ...) {

if(!is.null(object$posterior_summary_pairwise_allocations)) {
out$allocations = object$posterior_summary_pairwise_allocations
out$mean_allocations = object$posterior_mean_allocations
out$mean_allocations = get_posterior_mean(object, "allocations")
out$mode_allocations = object$posterior_mode_allocations
out$num_blocks = object$posterior_num_blocks
}
Expand Down Expand Up @@ -261,15 +261,15 @@ print.summary.bgms = function(x, digits = 3, ...) {
#' @export
coef.bgms = function(object, ...) {
out = list(
main = object$posterior_mean_main,
pairwise = object$posterior_mean_associations
main = get_posterior_mean(object, "main"),
pairwise = get_posterior_mean(object, "associations")
)
if(!is.null(object$posterior_mean_indicator)) {
out$indicator = object$posterior_mean_indicator
if(!is.null(get_posterior_mean(object, "indicator"))) {
out$indicator = get_posterior_mean(object, "indicator")
}

if(!is.null(object$posterior_mean_allocations)) {
out$mean_allocations = object$posterior_mean_allocations
if(!is.null(get_posterior_mean(object, "allocations"))) {
out$mean_allocations = get_posterior_mean(object, "allocations")
out$mode_allocations = object$posterior_mode_allocations
out$num_blocks = object$posterior_num_blocks
}
Expand All @@ -280,9 +280,10 @@ coef.bgms = function(object, ...) {

#' Access elements of a bgms object
#'
#' @description Intercepts access to \code{posterior_summary_*} fields and
#' triggers lazy computation from cache when needed. All other fields pass
#' through using standard list extraction.
#' @description Provides \code{$} access to S7 properties. Lazy
#' \code{posterior_summary_*} properties trigger computation on first
#' access via S7 property getters. Also supports legacy S3 list-based
#' fit objects.
#'
#' @param x A \code{bgms} object.
#' @param name Name of the element to access.
Expand All @@ -293,17 +294,21 @@ coef.bgms = function(object, ...) {
#' @export
#' @keywords internal
`$.bgms` = function(x, name) {
if(startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
if(inherits(x, "S7_object")) {
S7::prop(x, name)
} else {
if(startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
}
}
}
.subset2(x, name)
}
.subset2(x, name)
}


Expand All @@ -313,17 +318,37 @@ coef.bgms = function(object, ...) {
#' @export
#' @keywords internal
`[[.bgms` = function(x, name, ...) {
if(is.character(name) && startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
if(inherits(x, "S7_object")) {
if(is.character(name)) {
S7::prop(x, name)
} else {
stop("numeric indexing is not supported for bgms objects")
}
} else {
if(is.character(name) && startsWith(name, "posterior_summary_")) {
cache = .subset2(x, "cache")
if(!is.null(cache)) {
ensure_summaries(x)
val = cache[[name]]
if(!is.null(val)) {
return(val)
}
}
}
.subset2(x, name)
}
}


#' @method names bgms
#' @export
#' @keywords internal
names.bgms = function(x) {
if(inherits(x, "S7_object")) {
S7::prop(x, ".field_names")
} else {
NextMethod()
}
.subset2(x, name)
}


Expand Down
Loading
Loading