@@ -1,5 +1,4 @@
 import math
-from typing import Dict, Tuple

 import gymnasium as gym
 import torch
@@ -43,7 +42,7 @@ def __init__(self, envs: gym.vector.SyncVectorEnv, act_fun: str = "relu", ortho_
             layer_init(torch.nn.Linear(64, envs.single_action_space.n), std=0.01, ortho_init=ortho_init),
         )

-    def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]:
         logits = self.actor(x)
         distribution = Categorical(logits=logits)
         if action is None:
@@ -58,12 +57,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor:
     def get_value(self, x: Tensor) -> Tensor:
         return self.critic(x)

-    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         action, log_prob, entropy = self.get_action(x, action)
         value = self.get_value(x)
         return action, log_prob, entropy, value

-    def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         return self.get_action_and_value(x, action)

     @torch.no_grad()
@@ -77,7 +76,7 @@ def estimate_returns_and_advantages(
         num_steps: int,
         gamma: float,
         gae_lambda: float,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         next_value = self.get_value(next_obs).reshape(1, -1)
         advantages = torch.zeros_like(rewards)
         lastgaelam = 0
@@ -143,7 +142,7 @@ def __init__(
         self.avg_value_loss = MeanMetric(**torchmetrics_kwargs)
         self.avg_ent_loss = MeanMetric(**torchmetrics_kwargs)

-    def get_action(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor]:
+    def get_action(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor]:
         logits = self.actor(x)
         distribution = Categorical(logits=logits)
         if action is None:
@@ -158,12 +157,12 @@ def get_greedy_action(self, x: Tensor) -> Tensor:
     def get_value(self, x: Tensor) -> Tensor:
         return self.critic(x)

-    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def get_action_and_value(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         action, log_prob, entropy = self.get_action(x, action)
         value = self.get_value(x)
         return action, log_prob, entropy, value

-    def forward(self, x: Tensor, action: Tensor = None) -> Tuple[Tensor, Tensor, Tensor, Tensor]:
+    def forward(self, x: Tensor, action: Tensor = None) -> tuple[Tensor, Tensor, Tensor, Tensor]:
         return self.get_action_and_value(x, action)

     @torch.no_grad()
@@ -177,7 +176,7 @@ def estimate_returns_and_advantages(
         num_steps: int,
         gamma: float,
         gae_lambda: float,
-    ) -> Tuple[Tensor, Tensor]:
+    ) -> tuple[Tensor, Tensor]:
         next_value = self.get_value(next_obs).reshape(1, -1)
         advantages = torch.zeros_like(rewards)
         lastgaelam = 0
@@ -193,7 +192,7 @@ def estimate_returns_and_advantages(
         returns = advantages + values
         return returns, advantages

-    def training_step(self, batch: Dict[str, Tensor]):
+    def training_step(self, batch: dict[str, Tensor]):
         # Get actions and values given the current observations
         _, newlogprob, entropy, newvalue = self(batch["obs"], batch["actions"].long())
         logratio = newlogprob - batch["logprobs"]
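
Every hunk above makes the same modernization: the `typing.Tuple` / `typing.Dict` aliases are replaced with the built-in `tuple` and `dict` generics from PEP 585, so the `from typing import Dict, Tuple` import can be dropped. A minimal sketch of the idea, assuming Python 3.9 or newer (the built-in generics are not valid annotations on older interpreters without `from __future__ import annotations`); `Tensor` is stubbed out here only so the sketch runs without PyTorch:

# Sketch only: illustrates the annotation change, not this repository's actual module.
from typing import Tuple  # pre-3.9 spelling, removed by this diff

Tensor = float  # stand-in for torch.Tensor, just to keep the example self-contained


def old_style(x: Tensor) -> Tuple[Tensor, Tensor]:  # typing.Tuple alias
    return x, x


def new_style(x: Tensor) -> tuple[Tensor, Tensor]:  # PEP 585 built-in generic, no import needed
    return x, x


def consume(batch: dict[str, Tensor]) -> None:  # same substitution for Dict -> dict
    pass

The runtime behavior is unchanged; only the spelling of the annotations differs, which is why the diff touches nothing but signatures and the import block.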