
Commit 04cda3b

feat: added A2C algo train and evaluate functions
1 parent 8129229 commit 04cda3b

File tree

2 files changed: +90 −0 lines changed

A2Ctrain.py

Lines changed: 51 additions & 0 deletions
@@ -0,0 +1,51 @@
### requires ray version 2.6.3 to run

import ray
from ray import tune, air
from ray.tune.registry import register_env
from env_creator import qsimpy_env_creator
from ray.rllib.algorithms.a2c import A2CConfig  # Import A2CConfig
import os

if __name__ == "__main__":
    register_env("QSimPyEnv", qsimpy_env_creator)

    config = (
        A2CConfig()
        .framework(framework="torch")
        .environment(
            env="QSimPyEnv",
            env_config={
                "obs_filter": "rescale_-1_1",
                "reward_filter": None,
                "dataset": "qdataset/qsimpyds_1000_sub_36.csv",
            },
        )
        .training(gamma=0.9, lr=0.01)
        .rollouts(num_rollout_workers=4)
    )

    # Stop when either limit is reached.
    stopping_criteria = {
        "training_iteration": 1000,
        "timesteps_total": 100000,
    }

    # Get the absolute path of the current directory
    current_directory = os.getcwd()

    # Append the "results" folder to the current directory path
    result_directory = os.path.join(current_directory, "results")
    storage_path = f"file://{result_directory}"

    results = tune.Tuner(
        "A2C",  # Specify the A2C algorithm
        run_config=air.RunConfig(
            stop=stopping_criteria,
            # Save a checkpoint every 100 iterations.
            checkpoint_config=air.CheckpointConfig(checkpoint_frequency=100),
            storage_path=storage_path,
            name="A2C_QCE_1000",
        ),
        param_space=config.to_dict(),
    ).fit()
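
Tuner.fit() returns a ResultGrid, so the best checkpoint can be pulled out programmatically instead of hand-picking a path. A minimal sketch, assuming the run logged the standard RLlib metric "episode_reward_mean" (not shown in the commit):

# Sketch: inspect the ResultGrid returned by Tuner.fit().
# Assumes the standard RLlib metric "episode_reward_mean" was logged.
best_result = results.get_best_result(metric="episode_reward_mean", mode="max")
print("Best mean episode reward:", best_result.metrics["episode_reward_mean"])
print("Best checkpoint:", best_result.checkpoint)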

evaluateA2C.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
from env_creator import qsimpy_env_creator
from ray.tune.registry import register_env
from ray.rllib.algorithms import Algorithm

register_env("QSimPyEnv", qsimpy_env_creator)

env = qsimpy_env_creator(
    env_config={
        "obs_filter": "rescale_-1_1",
        "reward_filter": None,
        "dataset": "qdataset/qsimpyds_1000_sub_36.csv",
    }
)

# Raw string so the Windows backslashes are not treated as escape sequences.
checkpoint_path = r"results\A2C_QCE_1000\A2C_QSimPyEnv_3fc51_00000_0_2024-08-11_19-08-18\checkpoint_000100"

model = Algorithm.from_checkpoint(checkpoint_path)

num_ep = 50

for ep in range(num_ep):
    obs = env.reset()
    finished = False
    ep_reward = 0

    while not finished:
        # env.reset() may return an (obs, info) tuple; keep only the observation.
        formatted_obs = obs if not isinstance(obs, tuple) else obs[0]
        action = model.compute_single_action(formatted_obs, explore=False)
        obs, reward, finished, _, info = env.step(action)
        ep_reward += reward

        if finished:
            print(f"Episode {ep} finished with reward {ep_reward} and info {info}")
            break

env.close()
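
The script above prints per-episode rewards but discards them afterwards. A small sketch of how the evaluation loop could aggregate them into a summary statistic; the list eval_rewards is a hypothetical addition, not part of the commit:

# Sketch: collect per-episode rewards and report the mean over all
# evaluation episodes ("eval_rewards" is introduced here for illustration).
eval_rewards = []

for ep in range(num_ep):
    obs = env.reset()
    finished = False
    ep_reward = 0
    while not finished:
        formatted_obs = obs if not isinstance(obs, tuple) else obs[0]
        action = model.compute_single_action(formatted_obs, explore=False)
        obs, reward, finished, _, info = env.step(action)
        ep_reward += reward
    eval_rewards.append(ep_reward)

print(f"Mean reward over {num_ep} episodes: {sum(eval_rewards) / num_ep:.3f}")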
