
Commit 89d49ac

merge code_ppo_lag and code_ppo_pen
0 parents · commit 89d49ac

249 files changed, 358887 insertions(+), 0 deletions(-)


03-icl-mix-improved.py

Lines changed: 133 additions & 0 deletions
import common
import tools
import numpy as np
import torch
from cppo import CPPO
from ppolag2 import PPOLag2
from costadjustment import CostAdjustment

# Get configuration
configuration = common.get_configuration(method_name="icl-mix-improved")

# Create manual cost function
if configuration["cost_condition"] != "":
    manual_cost = common.create_manual_cost_function(configuration)
    manualcostvalues, manualcostmap = \
        manual_cost.outputs(configuration["state_action_space"])
    manualcostvalues = np.array(manualcostvalues).squeeze()
    configuration["logger"].update({
        "expert_cost": manualcostmap.fig,
    })

# Create cost function
cost = tools.functions.CostFunction(configuration, i=configuration["i"], h=64, o=1)
configuration.update({"cost": cost})
costvalues, costmap = cost.outputs(configuration["state_action_space"])
costvalues = np.array(costvalues).squeeze()
configuration["logger"].update({"cost": costmap.fig})
if configuration["cost_condition"] != "":
    configuration["logger"].update({"cost_comparison": \
        configuration["cost_comparison"](manualcostvalues, costvalues)})

# Expert dataset accrual + train flow
expert_dataset = tools.base.TrajectoryDataset.load()
eS = configuration["vector_state_reduction"](expert_dataset.S)
eA = configuration["vector_action_reduction"](expert_dataset.A)
eSA = configuration["vector_input_format"](eS, eA).view(-1, configuration["i"])[
    torch.nonzero(expert_dataset.M.view(-1)).view(-1)]
flow = tools.functions.create_flow(configuration, eSA, "realnvp", configuration["i"])
for flowepoch in range(configuration["flow_epochs"]):
    configuration["logger"].update(flow.train())
ep = -flow.log_probs(eSA)
em, es = ep.mean(), ep.std()
configuration.update({"flow": flow, "expert_nll": (em.item(), es.item())})
configuration["logger"].update({"expert_nll": (em.item(), es.item())})
expert_acr, expert_acrplot = tools.functions.NormalizedAccrual()({
    "state_reduction": configuration["state_reduction"],
    "dataset": expert_dataset,
    "spaces": configuration["state_action_space"],
    "normalize_func": configuration["normalize_func"],
})
expert_acr = np.array(expert_acr).squeeze()
configuration["logger"].update({"expert_accrual": expert_acrplot.fig})

# Alternating process
for outer_epoch in range(configuration["outer_epochs"]):

    # Constrained PPO
    algorithm = {
        "CPPO": CPPO,
        "PPOLag2": PPOLag2,
    }[configuration["forward_crl"]](configuration)
    for epoch in range(configuration["ppo_epochs"]):
        metrics = algorithm.train()
        configuration["logger"].update(metrics)
    dataset = configuration["env"].trajectory_dataset(algorithm.policy,
        configuration["expert_episodes"], weights=configuration["past_pi_weights"],
        p=configuration["past_pi_dissimilarities"], config=configuration)
    acr, acrplot = tools.functions.NormalizedAccrual()({
        "state_reduction": configuration["state_reduction"],
        "dataset": dataset,
        "spaces": configuration["state_action_space"],
        "normalize_func": configuration["normalize_func"],
    })
    acr = np.array(acr).squeeze()
    configuration["logger"].update({
        "accrual": acrplot.fig,
        "accrual_comparison": configuration["accrual_comparison"](expert_acr, acr),
    })

    # Cost adjustment
    adjustment = CostAdjustment(configuration)
    for inner_epoch in range(configuration["updates_per_epoch"]):
        metrics = adjustment.train()
        configuration["logger"].update(metrics)
    costvalues, costmap = cost.outputs(configuration["state_action_space"])
    costvalues = np.array(costvalues).squeeze()
    configuration["logger"].update({"cost": costmap.fig})
    if configuration["cost_condition"] != "":
        configuration["logger"].update({"cost_comparison": \
            configuration["cost_comparison"](manualcostvalues, costvalues)})

# Constrained PPO
algorithm = {
    "CPPO": CPPO,
    "PPOLag2": PPOLag2,
}[configuration["forward_crl"]](configuration)
for epoch in range(configuration["ppo_epochs"]):
    metrics = algorithm.train(no_mix=True)
    configuration["logger"].update(metrics)
dataset = configuration["env"].trajectory_dataset(algorithm.policy,
    configuration["expert_episodes"], cost=configuration["cost"], config=configuration)
acr, acrplot = tools.functions.NormalizedAccrual()({
    "state_reduction": configuration["state_reduction"],
    "dataset": dataset,
    "spaces": configuration["state_action_space"],
    "normalize_func": configuration["normalize_func"],
})
acr = np.array(acr).squeeze()
configuration["accruals"] = acr
configuration["expert_accruals"] = expert_acr
configuration["logger"].update({
    "ppo_accrual": acrplot.fig,
    "ppo_accrual_comparison": configuration["accrual_comparison"](expert_acr, acr),
})

# Constrained PPO no cost
dataset = configuration["env"].trajectory_dataset(algorithm.policy,
    configuration["expert_episodes"], config=configuration)
acr, acrplot = tools.functions.NormalizedAccrual()({
    "state_reduction": configuration["state_reduction"],
    "dataset": dataset,
    "spaces": configuration["state_action_space"],
    "normalize_func": configuration["normalize_func"],
})
acr = np.array(acr).squeeze()
configuration["accruals_no_cost"] = acr
configuration["logger"].update({
    "ppo_accrual_no_cost": acrplot.fig,
    "ppo_accrual_comparison_no_cost": configuration["accrual_comparison"](expert_acr, acr),
})

# Finally
common.finish(configuration)
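For orientation, below is a sketch of the configuration keys that 03-icl-mix-improved.py reads from the object returned by common.get_configuration. Only the key names are taken from the script above; the values are hypothetical placeholders, not the defaults shipped with this commit.

# Hypothetical sketch of configuration keys used by 03-icl-mix-improved.py.
# Key names come from the script; values are illustrative placeholders only.
example_configuration = {
    "cost_condition": "",        # non-empty enables the manual (expert) cost comparison
    "i": 4,                      # input dimension of the cost network and the flow
    "flow_epochs": 50,           # epochs to fit the RealNVP flow to expert state-action pairs
    "outer_epochs": 10,          # alternations between constrained PPO and cost adjustment
    "ppo_epochs": 100,           # constrained-PPO training epochs per alternation
    "updates_per_epoch": 20,     # cost-adjustment updates per alternation
    "forward_crl": "PPOLag2",    # forward constrained-RL algorithm: "CPPO" or "PPOLag2"
    "expert_episodes": 25,       # episodes sampled for each trajectory dataset
}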

11-gail.py

Lines changed: 197 additions & 0 deletions
import datetime
import json
import os
import sys
import time
import torch
import tools
import wandb
import gym
import yaml
import numpy as np
import common
sys.path += ["baselines"]  # make the bundled baselines package importable

from baselines.constraint_models.constraint_net.gail_net import GailDiscriminator, GailCallback
# from baselines.exploration.exploration import CostShapingCallback
from baselines.stable_baselines3 import PPO
from baselines.stable_baselines3.common.utils import get_schedule_fn
from baselines.utils.data_utils import read_args, load_config
from baselines.utils.model_utils import load_ppo_config

tools.utils.nowarnings()

def train(args):
    configuration, seed = load_config(args)
    configuration["seed"] = seed
    configuration = tools.data.Configuration(tools.utils.convert_lambdas(configuration))
    state_action_space = tools.environments.get_state_action_space(
        configuration["env_type"], configuration["env_id"])
    configuration.update({"state_action_space": state_action_space})
    config_name = os.path.splitext(os.path.basename(args.c))[0]
    logdir = "%s(%s)-%s-%s-(%.2f,%d)" % ("GAIL",
        "GC", config_name.split("-")[-1], tools.utils.timestamp(),
        0, configuration["seed"])
    logger = tools.data.Logger(project="ICL",
        window=configuration["window"], logdir=logdir)
    configuration.update({"logger": logger})
    wandb.run.log_code()
    # wandb.run.log_code(root=args.c, include_fn=lambda path: path.endswith(".json"))
    yaml_artifact = wandb.Artifact('config-yaml', type='yaml')
    yaml_artifact.add_file(args.c)
    wandb.log_artifact(yaml_artifact)

    # Create manual cost function
    if configuration["cost_condition"] != "":
        manual_cost = common.create_manual_cost_function(configuration)
        manualcostvalues, manualcostmap = \
            manual_cost.outputs(configuration["state_action_space"])
        manualcostvalues = np.array(manualcostvalues).squeeze()
        configuration["logger"].update({
            "expert_cost": manualcostmap.fig,
        })
        configuration.update({
            "manualcostvalues": manualcostvalues,
        })

    # Create cost function
    cost = tools.functions.CostFunction(configuration, i=configuration["i"], h=64, o=1)
    configuration.update({"cost": cost})
    costvalues, costmap = cost.outputs(configuration["state_action_space"], invert=True)
    costvalues = np.array(costvalues).squeeze()
    configuration["logger"].update({"cost": costmap.fig})
    if configuration["cost_condition"] != "":
        configuration["logger"].update({"cost_comparison": \
            configuration["cost_comparison"](manualcostvalues, costvalues)})

    # Expert dataset accrual
    expert_dataset = tools.base.TrajectoryDataset.load()
    expert_acr, expert_acrplot = tools.functions.NormalizedAccrual()({
        "state_reduction": configuration["state_reduction"],
        "dataset": expert_dataset,
        "spaces": configuration["state_action_space"],
        "normalize_func": configuration["normalize_func"],
    })
    expert_acr = np.array(expert_acr).squeeze()
    configuration["logger"].update({
        "expert_accrual": expert_acrplot.fig
    })
    configuration.update({
        "expert_acr": expert_acr,
    })

    # Set specs
    train_env = configuration["env"]
    is_discrete = isinstance(train_env.action_space, gym.spaces.Discrete)
    obs_dim = train_env.observation_space.shape[0]
    acs_dim = train_env.action_space.n if is_discrete else train_env.action_space.shape[0]
    action_low, action_high = None, None
    if isinstance(train_env.action_space, gym.spaces.Box):
        action_low, action_high = train_env.action_space.low, train_env.action_space.high

    # Load expert data: an iterable of (states, actions) pairs per episode
    expert_data = torch.load("data.pt")
    expert_obs = []
    expert_acs = []
    for S, A in expert_data:
        for s in S:
            expert_obs += [s]
        for a in A:
            expert_acs += [a]
    expert_obs = np.array(expert_obs)
    expert_acs = np.array(expert_acs)

    discriminator = GailDiscriminator(
        obs_dim,
        acs_dim,
        configuration["cost"],
        configuration['PPO']['batch_size'],
        expert_obs,
        expert_acs,
        is_discrete,
        obs_select_dim=None,
        acs_select_dim=None,
        clip_obs=configuration['DISC']['clip_obs'],
        initial_obs_mean=None,
        initial_obs_var=None,
        action_low=action_low,
        action_high=action_high,
        num_spurious_features=None,
        freeze_weights=False,
        eps=float(configuration['DISC']['disc_eps']),
        device=configuration['t'].device,
    )

    # true_cost_function = get_true_cost_function(configuration['env']['eval_env_id'])

    # costShapingCallback = CostShapingCallback(obs_dim,
    #     acs_dim,
    #     use_nn_for_shaping=configuration['DISC']['use_cost_net'])
    # all_callbacks = [costShapingCallback]

    # Define and train model
    ppo_parameters = load_ppo_config(config=configuration, train_env=train_env, seed=seed, log_file=None)
    model = PPO(logger, **ppo_parameters)

    class GAILPolicy(tools.base.Policy):
        def act(self, s):
            return model.policy.forward(torch.as_tensor([s]).to(configuration['t'].device))[0].detach().view(-1).cpu().numpy()
    policy = GAILPolicy()

    gail_update = GailCallback(logger, configuration, policy, configuration['plot_interval'],
        discriminator=discriminator,
        learn_cost=configuration['DISC']['learn_cost'],
        plot_disc=False)
    all_callbacks = [gail_update]

    # Train
    try:
        model.learn(total_timesteps=int(configuration['PPO']['timesteps']),
            callback=all_callbacks)
    except:
        # swallow errors/interrupts so the post-training evaluation below still runs
        pass

    costvalues, costmap = cost.outputs(configuration["state_action_space"], invert=True)
    costvalues = np.array(costvalues).squeeze()
    configuration["logger"].update({"cost": costmap.fig})
    if configuration["cost_condition"] != "":
        configuration["logger"].update({"cost_comparison": \
            configuration["cost_comparison"](manualcostvalues, costvalues)})

    dataset = configuration["env"].trajectory_dataset(policy,
        configuration["expert_episodes"], cost=configuration["cost"])
    acr, acrplot = tools.functions.NormalizedAccrual()({
        "state_reduction": configuration["state_reduction"],
        "dataset": dataset,
        "spaces": configuration["state_action_space"],
        "normalize_func": configuration["normalize_func"],
    })
    acr = np.array(acr).squeeze()
    configuration["accruals"] = acr
    configuration["expert_accruals"] = expert_acr
    configuration["logger"].update({
        "accrual": acrplot.fig,
        "accrual_comparison": configuration["accrual_comparison"](expert_acr, acr),
    })

    dataset = configuration["env"].trajectory_dataset(policy,
        configuration["expert_episodes"])
    configuration.update({"agent_dataset": dataset})
    acr, acrplot = tools.functions.NormalizedAccrual()({
        "state_reduction": configuration["state_reduction"],
        "dataset": dataset,
        "spaces": configuration["state_action_space"],
        "normalize_func": configuration["normalize_func"],
    })
    acr = np.array(acr).squeeze()
    configuration["accruals_no_cost"] = acr
    configuration["logger"].update({
        "accrual_no_cost": acrplot.fig,
        "accrual_comparison_no_cost": configuration["accrual_comparison"](expert_acr, acr),
    })

    common.finish(configuration)

if __name__ == "__main__":
    args = read_args()
    train(args)
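11-gail.py loads expert demonstrations with torch.load("data.pt") and iterates over (S, A) pairs of state and action lists. Below is a minimal, hypothetical sketch of how such a file could be produced. Only the (states, actions) pair format and the data.pt file name come from the script; the environment, the random stand-in "expert", the episode count, and the older gym reset/step signature are assumptions.

# Hypothetical sketch: writing a data.pt file in the format 11-gail.py expects,
# i.e. an iterable of (S, A) pairs where S is a list of states and A a list of actions.
# The environment and random policy below are placeholders, not part of this commit.
import gym
import torch

env = gym.make("CartPole-v1")              # placeholder environment
episodes = []
for _ in range(5):                         # placeholder episode count
    S, A = [], []
    s = env.reset()                        # older gym API assumed (reset returns the observation)
    done = False
    while not done:
        a = env.action_space.sample()      # stand-in for a real expert policy
        S.append(s)
        A.append(a)
        s, _, done, _ = env.step(a)        # older gym API assumed (4-tuple step)
    episodes.append((S, A))

torch.save(episodes, "data.pt")            # read back by 11-gail.py via torch.load("data.pt")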
