-
Notifications
You must be signed in to change notification settings - Fork 3
Split expectation and strategy computations #98
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: qvalue
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -205,7 +205,35 @@ function _value_iteration!(problem::AbstractIntervalMDPProblem, alg; callback = | |||||
| return value_function.current, k, value_function.previous, strategy_cache | ||||||
| end | ||||||
|
|
||||||
| function bellman_update!(workspace, strategy_cache, value_function, k, mp, spec) | ||||||
| function bellman_update!(workspace, strategy_cache, value_function::StateValueFunction, k, mp, spec) | ||||||
|
|
||||||
| # 1. compute expectation for Q(s, a) | ||||||
| expectation!( | ||||||
| workspace, | ||||||
| select_strategy_cache(strategy_cache, k), | ||||||
| value_function.intermediate_state_action_value, | ||||||
| value_function.previous, | ||||||
| select_model(mp, k); # For time-varying available and labelling functions | ||||||
| upper_bound = isoptimistic(spec), | ||||||
| maximize = ismaximize(spec), | ||||||
| prop = system_property(spec), | ||||||
| ) | ||||||
|
|
||||||
| # 2. extract strategy and compute V(s) = max_a Q(s, a) | ||||||
| strategy!( | ||||||
| select_strategy_cache(strategy_cache, k), | ||||||
| value_function.current, | ||||||
| value_function.intermediate_state_action_value, | ||||||
| select_model(mp, k), | ||||||
| ismaximize(spec), | ||||||
| ) | ||||||
|
|
||||||
| # 3. post process | ||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| step_postprocess_value_function!(value_function, spec) | ||||||
| step_postprocess_strategy_cache!(strategy_cache) | ||||||
| end | ||||||
|
|
||||||
| function bellman_update!(workspace, strategy_cache::NonOptimizingStrategyCache, value_function::StateValueFunction, k, mp, spec) | ||||||
| expectation!( | ||||||
| workspace, | ||||||
| select_strategy_cache(strategy_cache, k), | ||||||
|
|
||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -159,3 +159,26 @@ function _extract_strategy!(cur_strategy, values, available_actions, neutral, j | |||||
| @inbounds cur_strategy[jₛ] = opt_index | ||||||
| return opt_val | ||||||
| end | ||||||
|
|
||||||
|
|
||||||
| function strategy!( | ||||||
| strategy_cache::OptimizingStrategyCache, | ||||||
| Vres::AbstractArray{R}, | ||||||
| Q::AbstractArray{R}, | ||||||
| model, | ||||||
| maximize, | ||||||
| ) where {R <: Real} | ||||||
|
|
||||||
| #TODO: can be threaded? | ||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Leave it be for now - we can always change it later. |
||||||
| for jₛ in CartesianIndices(source_shape(model)) | ||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
|
|
||||||
| Vres[jₛ] = extract_strategy!( | ||||||
| strategy_cache, | ||||||
| Q[:, jₛ], | ||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Let's try to avoid expensive copies (expensive = many allocations, not large allocations. Previously it was allocating and copying for every state).
Suggested change
|
||||||
| available(model, jₛ), | ||||||
| jₛ, | ||||||
| maximize, | ||||||
| ) | ||||||
|
|
||||||
| end | ||||||
| end | ||||||
| Original file line number | Diff line number | Diff line change | ||||
|---|---|---|---|---|---|---|
|
|
@@ -12,8 +12,8 @@ function StateValueFunction(problem::AbstractIntervalMDPProblem) | |||||
| previous .= zero(valuetype(mp)) | ||||||
| current = copy(previous) | ||||||
|
|
||||||
| dim = Tuple(Iterators.flatten(zip(action_values(mp), state_values(mp)))) | ||||||
| # interleaved concat gives shape: (a1, a2) , (s1, s2) => (a1, s1, a2, s2) | ||||||
| dim = (action_values(mp)..., state_values(mp)...) | ||||||
| # concat gives shape: (a1, a2) , (s1, s2) => (a1, a2, s1, s2) | ||||||
| # (a, s) to access s more frequently due to column major | ||||||
|
Owner
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Suggested change
|
||||||
| # TODO: works for IMDP, need to check for fIMDP | ||||||
| intermediate_state_action_value = arrayfactory(mp, valuetype(mp), dim) | ||||||
|
|
@@ -45,7 +45,7 @@ end | |||||
|
|
||||||
| function StateActionValueFunction(problem::AbstractIntervalMDPProblem) | ||||||
| mp = system(problem) | ||||||
| dim = Tuple(Iterators.flatten(zip(action_values(mp), state_values(mp)))) | ||||||
| dim = (action_values(mp)..., state_values(mp)...) | ||||||
| # TODO: works for IMDP, need to check for fIMDP | ||||||
| previous = arrayfactory(mp, valuetype(mp), dim) | ||||||
| previous .= zero(valuetype(mp)) | ||||||
|
|
||||||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.