Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions R/adaptive_btl_refit.R
Original file line number Diff line number Diff line change
Expand Up @@ -1785,29 +1785,41 @@
proxy_scores = spoke_scores
)

reliability_stop_pass <- isTRUE(stats_row$link_reliability_stop_pass %||% FALSE)
delta_sd_pass <- isTRUE(stats_row$delta_sd_pass %||% FALSE)
log_alpha_sd_pass <- is.na(stats_row$log_alpha_sd_pass %||% NA) || isTRUE(stats_row$log_alpha_sd_pass)
transform_mode <- as.character(stats_row$link_transform_mode %||%
.adaptive_link_transform_mode_for_spoke(controller, spoke_id))
reliability_stop_pass <- as.logical(stats_row$link_reliability_stop_pass %||% NA)
delta_sd_pass <- as.logical(stats_row$delta_sd_pass %||% NA)
log_alpha_sd_pass <- as.logical(stats_row$log_alpha_sd_pass %||% NA)
lag_eligible <- isTRUE(stats_row$lag_eligible %||% FALSE)
delta_change_pass <- isTRUE(stats_row$delta_change_pass %||% FALSE)
log_alpha_change_pass <- is.na(stats_row$log_alpha_change_pass %||% NA) || isTRUE(stats_row$log_alpha_change_pass)
rank_stability_pass <- isTRUE(stats_row$rank_stability_pass %||% FALSE)
link_stop_eligible <- isTRUE(lag_eligible) && !is.na(diagnostics_pass)
delta_change_pass <- as.logical(stats_row$delta_change_pass %||% NA)
log_alpha_change_pass <- as.logical(stats_row$log_alpha_change_pass %||% NA)
rank_stability_pass <- as.logical(stats_row$rank_stability_pass %||% NA)
required_defined <- !is.na(diagnostics_pass) &&
!is.na(reliability_stop_pass) &&
!is.na(delta_sd_pass) &&
!is.na(delta_change_pass) &&
!is.na(rank_stability_pass)
if (identical(transform_mode, "shift_scale")) {
required_defined <- isTRUE(required_defined) &&
!is.na(log_alpha_sd_pass) &&
!is.na(log_alpha_change_pass)
}
# Eligibility means all required stop gates are defined at this refit.
link_stop_eligible <- isTRUE(lag_eligible) && isTRUE(required_defined)
link_stop_pass <- isTRUE(link_stop_eligible) &&
isTRUE(diagnostics_pass) &&
isTRUE(reliability_stop_pass) &&
isTRUE(delta_sd_pass) &&
isTRUE(log_alpha_sd_pass) &&
(isTRUE(log_alpha_sd_pass) || identical(transform_mode, "shift_only")) &&
isTRUE(delta_change_pass) &&
isTRUE(log_alpha_change_pass) &&
(isTRUE(log_alpha_change_pass) || identical(transform_mode, "shift_only")) &&
isTRUE(rank_stability_pass)

rows[[idx]] <- list(
refit_id = as.integer(refit_id),
spoke_id = as.integer(spoke_id),
hub_id = as.integer(hub_id),
link_transform_mode = as.character(stats_row$link_transform_mode %||%
.adaptive_link_transform_mode_for_spoke(controller, spoke_id)),
link_transform_mode = as.character(transform_mode),
link_refit_mode = as.character(controller$link_refit_mode %||% NA_character_),
hub_lock_mode = as.character(controller$hub_lock_mode %||% NA_character_),
hub_lock_kappa = if (identical(as.character(controller$hub_lock_mode %||% NA_character_), "soft_lock")) {
Expand Down
3 changes: 2 additions & 1 deletion R/adaptive_rank.R
Original file line number Diff line number Diff line change
Expand Up @@ -334,7 +334,8 @@ make_adaptive_judge_llm <- function(
#' `boundary_frac`, `p_star_override_margin`, and
#' `star_override_budget_per_round`, linking controls (`run_mode`, `hub_id`,
#' `link_transform_mode`, `link_refit_mode`, `shift_only_theta_treatment`,
#' `judge_param_mode`, `hub_lock_mode`, `hub_lock_kappa`), and Phase A controls
#' `judge_param_mode`, `hub_lock_mode`, `hub_lock_kappa`,
#' `allow_spoke_spoke_cross_set`), and Phase A controls
#' (`phase_a_mode`, `phase_a_import_failure_policy`,
#' `phase_a_required_reliability_min`, `phase_a_compatible_model_ids`,
#' `phase_a_compatible_config_hashes`, `phase_a_artifacts`,
Expand Down
27 changes: 24 additions & 3 deletions R/adaptive_round_candidates.R
Original file line number Diff line number Diff line change
Expand Up @@ -480,8 +480,10 @@ generate_stage_candidates_from_state <- function(state,
"Phase metadata and routing mode disagree: no active spoke could be selected for phase_b."
)
}
allow_spoke_spoke <- isTRUE(controller$allow_spoke_spoke_cross_set %||% FALSE)
hub_ids <- as.character(state$items$item_id[as.integer(state$items$set_id) == hub_id])
spoke_ids <- as.character(state$items$item_id[as.integer(state$items$set_id) == spoke_id])
active_spoke_ids <- as.character(state$items$item_id[as.integer(state$items$set_id) %in% eligible_spokes])
if (length(hub_ids) < 1L) {
rlang::abort(
paste0(
Expand All @@ -501,7 +503,11 @@ generate_stage_candidates_from_state <- function(state,
)
)
}
active_ids <- unique(c(hub_ids, spoke_ids))
active_ids <- if (isTRUE(allow_spoke_spoke)) {
unique(c(hub_ids, active_spoke_ids))
} else {
unique(c(hub_ids, spoke_ids))
}
if (length(active_ids) < 2L) {
return(tibble::tibble(i = character(), j = character()))
}
Expand Down Expand Up @@ -555,6 +561,7 @@ generate_stage_candidates_from_state <- function(state,
link_spoke_id <- integer()
coverage_bins_used <- integer()
coverage_source <- character()
set_map <- stats::setNames(as.integer(state$items$set_id), as.character(state$items$item_id))

for (a in seq_len(length(ids) - 1L)) {
i_id <- ids[[a]]
Expand All @@ -564,9 +571,17 @@ generate_stage_candidates_from_state <- function(state,
dist <- abs(as.integer(stratum_map[[i_id]]) - as.integer(stratum_map[[j_id]]))

if (isTRUE(link_phase_b_active)) {
i_set <- as.integer(set_map[[i_id]] %||% NA_integer_)
j_set <- as.integer(set_map[[j_id]] %||% NA_integer_)
if (is.na(i_set) || is.na(j_set) || i_set == j_set) {
next
}
i_hub <- i_id %in% hub_ids
j_hub <- j_id %in% hub_ids
if (!isTRUE(xor(i_hub, j_hub))) {
if (!isTRUE(allow_spoke_spoke) && !isTRUE(xor(i_hub, j_hub))) {
next
}
if (isTRUE(allow_spoke_spoke) && !isTRUE(i_set == spoke_id || j_set == spoke_id)) {
next
}
i_anchor <- i_id %in% hub_anchor_ids
Expand Down Expand Up @@ -596,7 +611,13 @@ generate_stage_candidates_from_state <- function(state,
j_vals <- c(j_vals, j_id)
dist_vals <- c(dist_vals, as.integer(dist))
if (isTRUE(link_phase_b_active)) {
spoke_item <- if (i_id %in% spoke_ids) i_id else j_id
spoke_item <- if (as.integer(set_map[[i_id]] %||% NA_integer_) == spoke_id) {
i_id
} else if (as.integer(set_map[[j_id]] %||% NA_integer_) == spoke_id) {
j_id
} else {
NA_character_
}
spoke_bin <- as.integer(coverage$bin_map[[spoke_item]] %||% NA_integer_)
priority <- as.integer(!is.na(spoke_bin) && spoke_bin %in% coverage$bins_undercovered)
coverage_priority <- c(coverage_priority, priority)
Expand Down
6 changes: 4 additions & 2 deletions R/adaptive_run.R
Original file line number Diff line number Diff line change
Expand Up @@ -597,7 +597,8 @@
#' `link_transform_escalation_refits_required`,
#' `link_transform_escalation_is_one_way`,
#' `spoke_quantile_coverage_bins`,
#' `spoke_quantile_coverage_min_per_bin_per_refit`, `multi_spoke_mode`,
#' `spoke_quantile_coverage_min_per_bin_per_refit`,
#' `allow_spoke_spoke_cross_set`, `multi_spoke_mode`,
#' `min_cross_set_pairs_per_spoke_per_refit`,
#' `phase_a_mode`, `phase_a_import_failure_policy`,
#' `phase_a_required_reliability_min`, `phase_a_compatible_model_ids`,
Expand Down Expand Up @@ -760,7 +761,8 @@ adaptive_rank_start <- function(items,
#' `link_transform_escalation_refits_required`,
#' `link_transform_escalation_is_one_way`,
#' `spoke_quantile_coverage_bins`,
#' `spoke_quantile_coverage_min_per_bin_per_refit`, `multi_spoke_mode`,
#' `spoke_quantile_coverage_min_per_bin_per_refit`,
#' `allow_spoke_spoke_cross_set`, `multi_spoke_mode`,
#' `min_cross_set_pairs_per_spoke_per_refit`,
#' `phase_a_mode`, `phase_a_import_failure_policy`,
#' `phase_a_required_reliability_min`, `phase_a_compatible_model_ids`,
Expand Down
106 changes: 54 additions & 52 deletions R/adaptive_select.R
Original file line number Diff line number Diff line change
Expand Up @@ -455,42 +455,65 @@ adaptive_defaults <- function(N) {
current_map
}

.adaptive_link_attach_predictive_utility <- function(candidates, state, controller, spoke_id) {
cand <- tibble::as_tibble(candidates)
if (nrow(cand) < 1L || is.na(spoke_id)) {
return(cand)
.adaptive_link_theta_global_map_for_items <- function(state, controller, item_ids) {
ids <- unique(as.character(item_ids))
ids <- ids[!is.na(ids)]
if (length(ids) < 1L) {
return(stats::setNames(numeric(), character()))
}
hub_id <- as.integer(controller$hub_id %||% 1L)
transform_mode <- .adaptive_link_transform_mode_for_spoke(controller, spoke_id)
prefer_current_theta <- identical(as.character(controller$link_refit_mode %||% "shift_only"), "joint_refit")
set_by_item <- stats::setNames(as.integer(state$items$set_id), as.character(state$items$item_id))
set_ids <- unique(as.integer(set_by_item[ids]))
set_ids <- set_ids[!is.na(set_ids)]
if (length(set_ids) < 1L) {
return(stats::setNames(numeric(), character()))
}

link_stats <- controller$link_refit_stats_by_spoke %||% list()
stats_row <- link_stats[[as.character(spoke_id)]] %||% list()
delta <- as.double(stats_row$delta_spoke_mean %||% 0)
if (!is.finite(delta)) {
delta <- 0
theta_global <- stats::setNames(numeric(), character())
for (set_id in set_ids) {
theta_map <- .adaptive_link_safe_theta_map(
state,
set_id = as.integer(set_id),
prefer_current = prefer_current_theta
)
if (length(theta_map) < 1L) {
next
}
if (!identical(as.integer(set_id), hub_id)) {
stats_row <- link_stats[[as.character(set_id)]] %||% list()
transform_mode <- .adaptive_link_transform_mode_for_spoke(controller, as.integer(set_id))
delta <- as.double(stats_row$delta_spoke_mean %||% 0)
if (!is.finite(delta)) {
delta <- 0
}
log_alpha <- as.double(stats_row$log_alpha_spoke_mean %||% NA_real_)
alpha <- if (identical(transform_mode, "shift_scale") && is.finite(log_alpha)) exp(log_alpha) else 1
theta_vals <- delta + alpha * as.double(theta_map)
names(theta_vals) <- names(theta_map)
theta_map <- theta_vals
}
theta_global <- c(theta_global, theta_map)
}
log_alpha <- as.double(stats_row$log_alpha_spoke_mean %||% NA_real_)
alpha <- if (identical(transform_mode, "shift_scale") && is.finite(log_alpha)) exp(log_alpha) else 1
theta_global[!duplicated(names(theta_global))]
}

prefer_current_theta <- identical(as.character(controller$link_refit_mode %||% "shift_only"), "joint_refit")
hub_theta <- .adaptive_link_safe_theta_map(
state,
set_id = hub_id,
prefer_current = prefer_current_theta
)
spoke_theta <- .adaptive_link_safe_theta_map(
state,
set_id = spoke_id,
prefer_current = prefer_current_theta
.adaptive_link_attach_predictive_utility <- function(candidates, state, controller, spoke_id) {
cand <- tibble::as_tibble(candidates)
if (nrow(cand) < 1L || is.na(spoke_id)) {
return(cand)
}
theta_global <- .adaptive_link_theta_global_map_for_items(
state = state,
controller = controller,
item_ids = c(as.character(cand$i), as.character(cand$j))
)
if (length(hub_theta) < 1L || length(spoke_theta) < 1L) {
if (length(theta_global) < 2L) {
cand$link_p <- NA_real_
cand$link_u <- NA_real_
return(cand)
}
spoke_theta_global <- delta + alpha * as.double(spoke_theta)
names(spoke_theta_global) <- names(spoke_theta)
theta_global <- c(hub_theta, spoke_theta_global)
theta_global <- theta_global[!duplicated(names(theta_global))]

startup_gap <- .adaptive_link_phase_b_startup_gap_for_spoke(state, spoke_id = as.integer(spoke_id))
judge_params <- .adaptive_link_judge_params(
Expand Down Expand Up @@ -530,35 +553,14 @@ adaptive_defaults <- function(N) {
if (is.na(spoke_id) || is.na(A_id) || is.na(B_id)) {
return(NA_real_)
}
hub_id <- as.integer(controller$hub_id %||% 1L)
transform_mode <- .adaptive_link_transform_mode_for_spoke(controller, spoke_id)
link_stats <- controller$link_refit_stats_by_spoke %||% list()
stats_row <- link_stats[[as.character(spoke_id)]] %||% list()
delta <- as.double(stats_row$delta_spoke_mean %||% 0)
if (!is.finite(delta)) {
delta <- 0
}
log_alpha <- as.double(stats_row$log_alpha_spoke_mean %||% NA_real_)
alpha <- if (identical(transform_mode, "shift_scale") && is.finite(log_alpha)) exp(log_alpha) else 1

prefer_current_theta <- identical(as.character(controller$link_refit_mode %||% "shift_only"), "joint_refit")
hub_theta <- .adaptive_link_safe_theta_map(
state,
set_id = hub_id,
prefer_current = prefer_current_theta
)
spoke_theta <- .adaptive_link_safe_theta_map(
state,
set_id = spoke_id,
prefer_current = prefer_current_theta
theta_global <- .adaptive_link_theta_global_map_for_items(
state = state,
controller = controller,
item_ids = c(as.character(A_id), as.character(B_id))
)
if (length(hub_theta) < 1L || length(spoke_theta) < 1L) {
if (length(theta_global) < 2L) {
return(NA_real_)
}
spoke_theta_global <- delta + alpha * as.double(spoke_theta)
names(spoke_theta_global) <- names(spoke_theta)
theta_global <- c(hub_theta, spoke_theta_global)
theta_global <- theta_global[!duplicated(names(theta_global))]

startup_gap <- .adaptive_link_phase_b_startup_gap_for_spoke(state, spoke_id = as.integer(spoke_id))
judge_params <- .adaptive_link_judge_params(
Expand Down
3 changes: 3 additions & 0 deletions R/adaptive_state.R
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@
link_transform_escalation_is_one_way = TRUE,
spoke_quantile_coverage_bins = 3L,
spoke_quantile_coverage_min_per_bin_per_refit = 1L,
allow_spoke_spoke_cross_set = FALSE,
multi_spoke_mode = "independent",
min_cross_set_pairs_per_spoke_per_refit = 5L,
cross_set_utility = "linking_cross_set_p_times_1_minus_p",
Expand Down Expand Up @@ -172,6 +173,7 @@
"link_transform_escalation_is_one_way",
"spoke_quantile_coverage_bins",
"spoke_quantile_coverage_min_per_bin_per_refit",
"allow_spoke_spoke_cross_set",
"multi_spoke_mode",
"min_cross_set_pairs_per_spoke_per_refit",
"cross_set_utility",
Expand Down Expand Up @@ -322,6 +324,7 @@
1L,
Inf
)
out$allow_spoke_spoke_cross_set <- read_logical("allow_spoke_spoke_cross_set")
out$multi_spoke_mode <- read_choice("multi_spoke_mode", c("independent", "concurrent"))
out$min_cross_set_pairs_per_spoke_per_refit <- read_integer(
"min_cross_set_pairs_per_spoke_per_refit",
Expand Down
6 changes: 6 additions & 0 deletions R/adaptive_step.R
Original file line number Diff line number Diff line change
Expand Up @@ -452,6 +452,12 @@ run_one_step <- function(state, judge, ...) {
} else {
NA_integer_
}
if (isTRUE(is_cross_set) && is.na(link_spoke_id)) {
selected_spoke_id <- as.integer(selection$link_spoke_id_selected %||% NA_integer_)
if (!is.na(selected_spoke_id) && selected_spoke_id %in% c(set_i, set_j)) {
link_spoke_id <- selected_spoke_id
}
}
link_stats <- controller$link_refit_stats_by_spoke %||% list()
spoke_key <- as.character(link_spoke_id)
spoke_stats <- if (!is.na(link_spoke_id)) link_stats[[spoke_key]] %||% list() else list()
Expand Down
3 changes: 2 additions & 1 deletion man/adaptive_rank.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/adaptive_rank_run_live.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/adaptive_rank_start.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

17 changes: 17 additions & 0 deletions tests/testthat/test-5045-adaptive-helper-branches.R
Original file line number Diff line number Diff line change
Expand Up @@ -571,6 +571,13 @@ test_that("adaptive state and trueskill validators cover additional edge branche
),
"must be one of"
)
expect_error(
pairwiseLLM:::.adaptive_validate_controller_config(
list(allow_spoke_spoke_cross_set = "yes"),
5L
),
"must be TRUE or FALSE"
)
expect_error(
pairwiseLLM:::.adaptive_validate_controller_config(
list(run_mode = "link_multi_spoke"),
Expand Down Expand Up @@ -608,6 +615,16 @@ test_that("adaptive state and trueskill validators cover additional edge branche
set_ids = c(1L, 2L, 3L)
)
expect_identical(cfg_link_ok$hub_id, 1L)
cfg_spoke_spoke <- pairwiseLLM:::.adaptive_validate_controller_config(
list(
run_mode = "link_multi_spoke",
hub_id = 1L,
allow_spoke_spoke_cross_set = TRUE
),
5L,
set_ids = c(1L, 2L, 3L)
)
expect_true(isTRUE(cfg_spoke_spoke$allow_spoke_spoke_cross_set))

resolved_num <- pairwiseLLM:::.adaptive_controller_resolve(5L)
expect_true(is.list(resolved_num))
Expand Down
Loading