From d3f77de0b97dc1a1b79f43196ed2821442143fdb Mon Sep 17 00:00:00 2001 From: Nathan Vogt Date: Thu, 27 Jun 2024 10:44:05 -0700 Subject: [PATCH] fix --- td/learning/constrained_decoding.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/td/learning/constrained_decoding.py b/td/learning/constrained_decoding.py index d3ac6e4..43e3b23 100644 --- a/td/learning/constrained_decoding.py +++ b/td/learning/constrained_decoding.py @@ -137,13 +137,19 @@ def feed_token(self, token: int, probs=None): self._current_start = self._token_positions_to_real[ self._tokenizer.token_to_position(token) ] - self._current_end = self._edit_spans[self._current_start] self._decode_state = DecoderState.States.TOKEN position_rule = self._position_rule(self._current_start) self._interactive_parser = self._grammar._lark_parser_for_start[ position_rule ].parse_interactive(start=position_rule) - self._recompute_mask() + + # Find the correct end position based on the selected rule + for node in self._position_to_nodes[self._current_start]: + if self._node_rule(node) == position_rule: + self._current_end = node.meta.end_pos + break + else: + self._current_end = self._edit_spans[self._current_start] elif self._decode_state == DecoderState.States.TOKEN: token_str = self._tokenizer._index_to_token[token] token_name = self._grammar.rev_vocabulary_map[token_str]