CEM #373
base: master
Codecov Report
@@ Coverage Diff @@
## master #373 +/- ##
==========================================
- Coverage 91.28% 91.25% -0.04%
==========================================
Files 90 92 +2
Lines 3809 3910 +101
==========================================
+ Hits 3477 3568 +91
- Misses 332 342 +10
This pull request introduces 3 alerts when merging a90e8d0 into 52b0b4c (view on LGTM.com).
This pull request introduces 4 alerts when merging 3b2067d into 25eb018 (view on LGTM.com).
Looks good. Mostly just doubts. How's the performance of the agent now? (Does it hit 500?)
genrl/agents/modelbased/base.py
Outdated
        raise NotImplementedError


class ModelBasedAgent(ABC):
Can this inherit from the genrl/deep BaseAgent?
It can, and I think that's a better option (for now at least).
        # No need for this here
        pass

    def collect_rollouts(self, state: torch.Tensor):
Looks pretty similar to the OnPolicyAgent method. Shouldn't this return values and dones though? Not sure if this is a consequence of the algo.
genrl/agents/modelbased/cem/cem.py
Outdated
for i, done in enumerate(dones):
    if done or timestep == self.rollout_size - 1:
        self.rewards.append(self.env.episode_reward[i].detach().clone())
        # self.env.reset_single_env(i)
Why is this commented out? It is needed to reset environments as soon as they report done. (It's not good practice to call env.step() if the env is already returning done = True.)
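To make the concern concrete, here is a minimal sketch of a rollout over several environments, with and without a per-environment reset. ToyEnv, rollout, and the steps_after_done counter are hypothetical stand-ins written for this illustration; they are not genrl classes, and the env.reset() call only plays the role that reset_single_env(i) has in the PR.

```python
class ToyEnv:
    """Hypothetical single environment whose episode ends after `horizon` steps."""

    def __init__(self, horizon):
        self.horizon = horizon
        self.t = 0
        self.steps_after_done = 0  # step() calls made on an already finished episode

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        if self.t >= self.horizon:
            self.steps_after_done += 1
        self.t += 1
        done = self.t >= self.horizon
        return float(self.t), 1.0, done


def rollout(envs, rollout_size, reset_on_done):
    """Step every env for rollout_size timesteps, optionally resetting finished ones."""
    for env in envs:
        env.reset()
    for _ in range(rollout_size):
        for env in envs:
            _, _, done = env.step(0)
            if done and reset_on_done:
                env.reset()  # stands in for reset_single_env(i)


with_reset = [ToyEnv(3), ToyEnv(5)]
rollout(with_reset, rollout_size=10, reset_on_done=True)

without_reset = [ToyEnv(3), ToyEnv(5)]
rollout(without_reset, rollout_size=10, reset_on_done=False)

print([env.steps_after_done for env in with_reset])     # [0, 0]
print([env.steps_after_done for env in without_reset])  # [7, 5]
```

Without the per-environment reset, the shorter episodes keep receiving step() calls after they have finished, which is the situation the comment warns about.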
Since I break out of the action loop as soon as env.step() returns done=True, and every planning session (the plan function) starts with env.reset(), I think the reset is redundant here, which is why it's commented out.
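The argument can be sketched as follows, assuming a single environment. ToyEnv and this plan function are hypothetical stand-ins for illustration, not the code in the PR; with several environments stepping in lockstep, one env finishing early would not by itself end the loop, so the reasoning may not carry over unchanged.

```python
class ToyEnv:
    """Hypothetical single environment whose episode ends after `horizon` steps."""

    def __init__(self, horizon):
        self.horizon = horizon
        self.t = 0
        self.steps_after_done = 0  # step() calls made on an already finished episode

    def reset(self):
        self.t = 0
        return 0.0

    def step(self, action):
        if self.t >= self.horizon:
            self.steps_after_done += 1
        self.t += 1
        return float(self.t), 1.0, self.t >= self.horizon


def plan(env, actions):
    """Evaluate one action sequence: reset first, stop at the first done."""
    env.reset()
    total_reward = 0.0
    for action in actions:
        _, reward, done = env.step(action)
        total_reward += reward
        if done:
            break  # never step an environment that has already finished
    return total_reward


env = ToyEnv(horizon=3)
returns = [plan(env, [0] * 10) for _ in range(4)]
print(returns, env.steps_after_done)  # [3.0, 3.0, 3.0, 3.0] 0
```

Because each call resets before stepping and breaks on done, no step() is ever issued on a finished episode in this single-environment setting.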
def test_CEM():
    env = VectorEnv("CartPole-v0", 1)
Why set it to 1? It does work with multiple envs, right?
Yeah it does
Also, forgot to mention the docs. The CEM agent code didn't have docstrings, AFAIR.
from genrl.trainers import OnPolicyTrainer


def test_CEM():
Also please make this a class so the tests are easier to find/understand
Yeah, I'll get that done too.
This pull request introduces 2 alerts when merging f5a189d into 25eb018 (view on LGTM.com).
This pull request introduces 2 alerts when merging 4b11c16 into 25eb018 (view on LGTM.com).
Wrt #363