Skip to content

Commit a8775d0

Browse files
committed
[Feature] Add deterministic_sample to masked categorical
ghstack-source-id: 6dfdf0e Pull Request resolved: #2708
1 parent db210ac commit a8775d0

File tree

1 file changed

+4
-0
lines changed

1 file changed

+4
-0
lines changed

torchrl/modules/distributions/discrete.py

+4
Original file line numberDiff line numberDiff line change
@@ -319,6 +319,10 @@ def _mask_logits(
319319
logits.masked_fill_(padding_mask, neg_inf)
320320
return logits
321321

322+
@property
323+
def deterministic_sample(self):
324+
return self.mode
325+
322326

323327
class MaskedOneHotCategorical(MaskedCategorical):
324328
"""MaskedCategorical distribution.

0 commit comments

Comments
 (0)