Skip to content

Commit

Permalink
fix for new paradox
Browse files Browse the repository at this point in the history
  • Loading branch information
mb706 committed Jul 20, 2024
1 parent 4e4f40b commit ca20629
Show file tree
Hide file tree
Showing 5 changed files with 27 additions and 36 deletions.
2 changes: 1 addition & 1 deletion yahpo_gym_r/DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ Imports:
bbotk,
checkmate (>= 2.0.0),
mlr3misc (>= 0.1.1),
paradox (>= 0.3.0),
paradox (>= 1.0.0),
reticulate (>= 1.10),
data.table,
R6
Expand Down
29 changes: 12 additions & 17 deletions yahpo_gym_r/R/BenchmarkSet.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
#' @field check_codomain `logical` \cr
#' Check whether returned values coincide with `codomain`.
check_codomain = NULL,

#' @field noisy `logical` \cr
#' Whether noisy surrogates should be used.
noisy = NULL,
Expand Down Expand Up @@ -130,10 +130,10 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
instance,
multifidelity,
list(
scenario = self$id,
scenario = self$id,
session = self$onnx_session,
active_session = self$active_session,
check = self$check,
active_session = self$active_session,
check = self$check,
multithread = self$multithread,
noisy = self$noisy
),
Expand Down Expand Up @@ -162,20 +162,15 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",
#' A [`paradox::ParamSet`] containing the search space to optimize over.
get_search_space = function(drop_instance_param = TRUE, drop_fidelity_params = FALSE) {
search_space = private$.load_r_domains()$search_space
params = search_space$params
params = search_space$ids()
if (drop_instance_param) {
params[self$py_instance$config$instance_names] = NULL
params = setdiff(params, self$py_instance$config$instance_names)
}
if (drop_fidelity_params) {
params[self$py_instance$config$fidelity_params] = NULL
}
search_space_new = ParamSet$new(params)
if (search_space$has_trafo) {
search_space_new$trafo = search_space$trafo
}
if (search_space$has_deps) {
search_space_new$deps = search_space$deps
params = setdiff(params, self$py_instance$config$fidelity_params)
}
search_space_new = search_space$subset(params) # subset() handles trafo & dependencies

search_space_new
},

Expand All @@ -192,15 +187,15 @@ BenchmarkSet = R6::R6Class("BenchmarkSet",

#' @description
#' Subset the codomain. Sets a new domain.
#'
#'
#' @param keep (`character`) \cr
#' Vector of co-domain target names to keep.
#' @return
#' A [`paradox::ParamSet`] containing the output space (codomain).
subset_codomain = function(keep) {
codomain = self$codomain
assert_subset(keep, names(codomain$params))
new_codomain = ParamSet$new(codomain$params[names(codomain$params) %in% keep])
assert_subset(keep, codomain$ids())
new_codomain = codomain$subset(keep)
private$.domains$codomain = new_codomain
}
),
Expand Down
24 changes: 10 additions & 14 deletions yahpo_gym_r/R/Objective.R
Original file line number Diff line number Diff line change
Expand Up @@ -25,25 +25,21 @@ ObjectiveYAHPO = R6::R6Class("ObjectiveYAHPO",
instance_param = self$py_instance$config$instance_names
fidelity_params = if (!multifidelity) self$py_instance$config$fidelity_params else NULL
pars = setdiff(domain$ids(), c(instance_param, fidelity_params))
domain_new = ParamSet$new(domain$params[pars])
if (domain$has_trafo) {
domain_new$trafo = domain$trafo
}
if (domain$has_deps) {
domain_new$deps = domain$deps
}
domain_new = domain$subset(pars) # this also handles trafo and dependencies

# define constants param_set
cst = ps()
if (length(instance_param)) {
cst$add(domain$params[[instance_param]])
cst$values = insert_named(cst$values, y = setNames(list(instance), nm = instance_param))
cst = ps_union(list(cst,
domain$subset(instance_param)
))
cst$set_values(.values = structure(list(instance), names = instance_param))
}
if (length(fidelity_params)) {
for (fidelity_param in fidelity_params) {
cst$add(domain$params[[fidelity_param]])
cst$values = insert_named(cst$values, y = setNames(list(domain$params[[fidelity_param]]$upper), nm = fidelity_param))
}
cst = ps_union(list(cst,
domain$subset(fidelity_params)
))
cst$set_values(.values = structure(as.list(domain$upper[fidelity_params]), names = fidelity_params))
}

noise = ifelse(py_instance_args$noisy, "noisy", "deterministic")
Expand Down Expand Up @@ -95,7 +91,7 @@ ObjectiveYAHPO = R6::R6Class("ObjectiveYAHPO",
self$py_instance$objective_function(
preproc_xs(xs, ...), seed = self$seed,
logging = self$logging, multithread = self$multithread
)
)
}
}
},
Expand Down
4 changes: 2 additions & 2 deletions yahpo_gym_r/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ and run our search procedure.
```r
library("bbotk")
p = opt("random_search")
ois = OptimInstanceMultiCrit$new(obj, search_space = b$get_search_space(drop_fidelity_params = TRUE), terminator = trm("evals", n_evals = 10))
ois = OptimInstanceBatchMultiCrit$new(obj, search_space = b$get_search_space(drop_fidelity_params = TRUE), terminator = trm("evals", n_evals = 10))
p$optimize(ois)
```

Expand All @@ -118,7 +118,7 @@ to specify `drop_fidelity_params = TRUE` when getting the search space via `$get
```r
library(mlr3hyperband)
obj = b$get_objective("40981", multifidelity = TRUE)
ois = OptimInstanceMultiCrit$new(obj, search_space = b$get_search_space(), terminator = trm("none"))
ois = OptimInstanceBatchMultiCrit$new(obj, search_space = b$get_search_space(), terminator = trm("none"))
p = opt("hyperband")
p$optimize(ois)
```
Expand Down
4 changes: 2 additions & 2 deletions yahpo_gym_r/tests/testthat/test_benchmarkset.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ test_that("subsetting works", {
reticulate::use_condaenv("yahpo_gym", required=TRUE)
b = BenchmarkSet$new("lcbench", active_session = TRUE)
b$subset_codomain("val_accuracy")
expect_true(names(b$codomain$params) == "val_accuracy")
expect_true(b$codomain$ids() == "val_accuracy")
})

test_that("Parallel", {
Expand All @@ -63,7 +63,7 @@ test_that("Parallel", {
xdt = generate_design_random(b$get_search_space(), 1)$data
xss_trafoed = transform_xdt_to_xss(xdt, b$get_search_space())
objective$eval_many(xss_trafoed)

future::plan("multisession")
pss = replicate(2, {
xdt = generate_design_random(b$get_search_space(), 1)$data
Expand Down

0 comments on commit ca20629

Please sign in to comment.