Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions src/cloudai/_core/configurator/cloudai_gym.py
Original file line number Diff line number Diff line change
Expand Up @@ -136,17 +136,17 @@ def step(self, action: Any) -> Tuple[list, float, bool, dict]:
return [-1.0], -1.0, True, {}

logging.info(f"Running step {self.test_run.step} with action {action}")
self.runner.runner.test_scenario.test_runs = [copy.deepcopy(self.test_run)]
new_tr = copy.deepcopy(self.test_run)
self.runner.runner.test_scenario.test_runs = [new_tr]
asyncio.run(self.runner.run())
self.test_run = self.runner.runner.test_scenario.test_runs[0]

observation = self.get_observation(action)
reward = self.compute_reward(observation)
done = False
info = {}

self.write_trajectory(self.test_run.step, action, reward, observation)

return observation, reward, done, info
return observation, reward, False, {}

def render(self, mode: str = "human"):
"""
Expand Down
2 changes: 1 addition & 1 deletion src/cloudai/cli/handlers.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ def handle_dse_job(runner: Runner, args: argparse.Namespace):
if result is None:
break
step, action = result
test_run.step = step
env.test_run.step = step
observation, reward, done, info = env.step(action)
feedback = {"trial_index": step, "value": reward}
agent.update_policy(feedback)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ def generate_report(self) -> None:
f.write("Max: {max}\n".format(max=stats["max"]))

def get_metric(self, metric: str) -> float:
logging.debug(f"Getting metric {metric} from {self.results_file.absolute()}")
step_timings = extract_timings(self.results_file)
if not step_timings:
return METRIC_ERROR
Expand Down
2 changes: 1 addition & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,7 @@ def slurm_system(tmp_path: Path) -> SlurmSystem:
],
)
system.scheduler = "slurm"
system.monitor_interval = 10
system.monitor_interval = 0
return system


Expand Down
16 changes: 15 additions & 1 deletion tests/test_cloudaigym.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
import pytest

from cloudai._core.configurator.cloudai_gym import CloudAIGymEnv
from cloudai._core.configurator.grid_search import GridSearchAgent
from cloudai._core.runner import Runner
from cloudai._core.test import Test
from cloudai._core.test_scenario import TestRun, TestScenario
Expand Down Expand Up @@ -68,7 +69,7 @@ def setup_env(slurm_system: SlurmSystem) -> tuple[TestRun, Runner]:
slurm_system.output_path / test_scenario.name / test_run.name / f"{test_run.current_iteration}"
)

runner = Runner(mode="run", system=slurm_system, test_scenario=test_scenario)
runner = Runner(mode="dry-run", system=slurm_system, test_scenario=test_scenario)

return test_run, runner

Expand Down Expand Up @@ -237,3 +238,16 @@ def test_update_test_run_obj():

env.update_test_run_obj(cmd_args, "trainer.num_nodes", [3, 4])
assert cmd_args.trainer.num_nodes == [3, 4]


def test_tr_output_path(setup_env: tuple[TestRun, Runner]):
test_run, runner = setup_env
test_run.test.test_definition.cmd_args.data.global_batch_size = 8 # avoid constraint check failure
env = CloudAIGymEnv(test_run=test_run, runner=runner)
agent = GridSearchAgent(env)

_, action = agent.select_action()
env.test_run.step = 42
env.step(action)

assert env.test_run.output_path.name == "42"