bugfix
Signed-off-by: Aleksandr Laptev <[email protected]>
GNroy committed Apr 20, 2024
1 parent cacb012 commit 83ec9b6
Showing 3 changed files with 53 additions and 33 deletions.
61 changes: 33 additions & 28 deletions nemo/collections/asr/parts/submodules/tdt_loop_labels_computer.py
@@ -250,14 +250,15 @@ def loop_labels_torch(
 
         # do not recalculate joint projection, project only once
         encoder_output_projected = self.joint.project_encoder(encoder_output)
+        dtype = encoder_output_projected.dtype
 
         # init output structures: BatchedHyps (for results), BatchedAlignments + last decoder state
         # init empty batched hypotheses
         batched_hyps = rnnt_utils.BatchedHyps(
             batch_size=batch_size,
             init_length=max_time * self.max_symbols if self.max_symbols is not None else max_time,
             device=device,
-            float_dtype=encoder_output_projected.dtype,
+            float_dtype=dtype,
         )
         # sample state, will be replaced further when the decoding for hypothesis is done
         last_decoder_state = self.decoder.initialize_state(encoder_output_projected)
@@ -269,7 +270,7 @@ def loop_labels_torch(
             logits_dim=self.joint.num_classes_with_blank,
             init_length=max_time * 2 if use_alignments else 1,  # blank for each timestep + text tokens
             device=device,
-            float_dtype=encoder_output_projected.dtype,
+            float_dtype=dtype,
             store_alignments=self.preserve_alignments,
             store_frame_confidence=self.preserve_frame_confidence,
             with_duration_confidence=self.include_duration_confidence,
@@ -334,19 +335,19 @@ def loop_labels_torch(
                 time_indices=time_indices_current_labels,
                 logits=logits if self.preserve_alignments else None,
                 labels=labels if self.preserve_alignments else None,
-                confidence=torch.cat(
+                confidence=torch.stack(
                     (
-                        self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).unsqueeze(
-                            1
+                        self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(
+                            dtype=dtype
                         ),
-                        self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).unsqueeze(
-                            1
+                        self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to(
+                            dtype=dtype
                         ),
                     ),
-                    dim=1,
+                    dim=-1,
                 )
                 if self.include_duration_confidence
-                else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1))
+                else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(dtype=dtype)
                 if self.preserve_frame_confidence
                 else None,
             )
@@ -386,19 +387,21 @@ def loop_labels_torch(
                     time_indices=time_indices_current_labels,
                     logits=logits if self.preserve_alignments else None,
                     labels=more_labels if self.preserve_alignments else None,
-                    confidence=torch.cat(
+                    confidence=torch.stack(
                         (
-                            self._get_confidence_tensor(
-                                F.log_softmax(logits[:, :-num_durations], dim=-1)
-                            ).unsqueeze(1),
-                            self._get_confidence_tensor(
-                                F.log_softmax(logits[:, -num_durations:], dim=-1)
-                            ).unsqueeze(1),
+                            self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(
+                                dtype=dtype
+                            ),
+                            self._get_confidence_tensor(F.log_softmax(logits[:, -num_durations:], dim=-1)).to(
+                                dtype=dtype
+                            ),
                         ),
-                        dim=1,
+                        dim=-1,
                     )
                     if self.include_duration_confidence
-                    else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1))
+                    else self._get_confidence_tensor(F.log_softmax(logits[:, :-num_durations], dim=-1)).to(
+                        dtype=dtype
+                    )
                     if self.preserve_frame_confidence
                     else None,
                 )
@@ -647,6 +650,7 @@ def _before_inner_loop_get_joint_output(self):
         # stage 2: get joint output, iteratively seeking for non-blank labels
         # blank label in `labels` tensor means "end of hypothesis" (for this index)
         self.state.active_mask_prev.copy_(self.state.active_mask, non_blocking=True)
+        dtype = self.state.encoder_output_projected.dtype
         logits = (
             self.joint.joint_after_projection(
                 self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze(
@@ -675,21 +679,21 @@ def _before_inner_loop_get_joint_output(self):
                 time_indices=self.state.time_indices_current_labels,
                 logits=logits if self.preserve_alignments else None,
                 labels=self.state.labels if self.preserve_alignments else None,
-                confidence=torch.cat(
+                confidence=torch.stack(
                     (
                         self._get_confidence_tensor(
                             F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1)
-                        ).unsqueeze(1),
+                        ).to(dtype=dtype),
                         self._get_confidence_tensor(
                             F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1)
-                        ).unsqueeze(1),
+                        ).to(dtype=dtype),
                     ),
-                    dim=1,
+                    dim=-1,
                 )
                 if self.include_duration_confidence
                 else self._get_confidence_tensor(
                     F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1)
-                )
+                ).to(dtype=dtype)
                 if self.preserve_frame_confidence
                 else None,
             )
@@ -715,6 +719,7 @@ def _inner_loop_code(self):
             self.state.time_indices_current_labels,
             out=self.state.time_indices_current_labels,
         )
+        dtype = self.state.encoder_output_projected.dtype
         logits = (
             self.joint.joint_after_projection(
                 self.state.encoder_output_projected[self.state.batch_indices, self.state.safe_time_indices].unsqueeze(
@@ -741,21 +746,21 @@ def _inner_loop_code(self):
                 time_indices=self.state.time_indices_current_labels,
                 logits=logits if self.preserve_alignments else None,
                 labels=more_labels if self.preserve_alignments else None,
-                confidence=torch.cat(
+                confidence=torch.stack(
                     (
                         self._get_confidence_tensor(
                             F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1)
-                        ).unsqueeze(1),
+                        ).to(dtype=dtype),
                         self._get_confidence_tensor(
                             F.log_softmax(logits[:, -self.state.all_durations.shape[0] :], dim=-1)
-                        ).unsqueeze(1),
+                        ).to(dtype=dtype),
                     ),
-                    dim=1,
+                    dim=-1,
                 )
                 if self.include_duration_confidence
                 else self._get_confidence_tensor(
                     F.log_softmax(logits[:, : -self.state.all_durations.shape[0]], dim=-1)
-                )
+                ).to(dtype=dtype)
                 if self.preserve_frame_confidence
                 else None,
             )
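All four hunks above make the same two-part change to how per-frame confidence is assembled: the token and duration confidence vectors are now combined with torch.stack(..., dim=-1) instead of torch.cat over explicitly unsqueezed copies, and each part is cast to the dtype of the projected encoder output so the stored confidence matches the float_dtype now passed to BatchedHyps/BatchedAlignments (e.g. float16 under mixed precision). A minimal standalone sketch of the tensor-level effect, with illustrative names and shapes (not NeMo code):

import torch

batch = 4
dtype = torch.float16  # stands in for encoder_output_projected.dtype under AMP

token_conf = torch.rand(batch)     # per-frame token confidence, float32
duration_conf = torch.rand(batch)  # per-frame duration confidence, float32

# old pattern: concatenate unsqueezed copies along dim=1 (result stays float32)
before = torch.cat((token_conf.unsqueeze(1), duration_conf.unsqueeze(1)), dim=1)

# new pattern: stack along the last dim and cast to the encoder dtype
after = torch.stack((token_conf.to(dtype=dtype), duration_conf.to(dtype=dtype)), dim=-1)

assert before.shape == after.shape == (batch, 2)  # same (batch, 2) layout either way
assert after.dtype == dtype  # now consistent with the alignments buffer dtype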
6 changes: 4 additions & 2 deletions nemo/collections/asr/parts/utils/asr_confidence_utils.py
@@ -373,8 +373,8 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None):
         self._aggregate_confidence = self.confidence_aggregation_bank[self.word_confidence_aggregation]
 
         # Update preserve frame confidence
-        if self.preserve_frame_confidence is False:
-            if self.cfg.strategy in ['greedy', 'greedy_batch']:
+        if self.cfg.strategy in ['greedy', 'greedy_batch']:
+            if not self.preserve_frame_confidence:
                 self.preserve_frame_confidence = self.cfg.greedy.get('preserve_frame_confidence', False)
                 # OmegaConf.structured ensures that post_init check is always executed
                 confidence_method_cfg = OmegaConf.structured(self.cfg.greedy).get('confidence_method_cfg', None)
@@ -383,6 +383,8 @@ def _init_confidence(self, confidence_cfg: Optional[DictConfig] = None):
                     if confidence_method_cfg is None
                     else OmegaConf.structured(ConfidenceMethodConfig(**confidence_method_cfg))
                 )
+            if not self.tdt_include_duration_confidence:
+                self.tdt_include_duration_confidence = self.cfg.greedy.get('tdt_include_duration_confidence', False)
 
     @abstractmethod
     def compute_confidence(self, hypotheses_list: List[Hypothesis]) -> List[Hypothesis]:
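The reordering above puts the strategy check first, so the greedy sub-config is consulted for greedy/greedy_batch decoding even when preserve_frame_confidence is already set, and the newly read tdt_include_duration_confidence flag is picked up the same way. A hypothetical config snippet showing how the two flags would be supplied together (field names taken from this diff; the real NeMo decoding schema has more keys):

from omegaconf import OmegaConf

# illustrative decoding config; only the fields touched by this commit are shown
decoding_cfg = OmegaConf.create(
    {
        "strategy": "greedy_batch",
        "greedy": {
            "preserve_frame_confidence": True,
            "tdt_include_duration_confidence": True,  # now honored by _init_confidence
        },
    }
)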
19 changes: 16 additions & 3 deletions tutorials/asr/ASR_Confidence_Estimation.ipynb
@@ -531,12 +531,25 @@
    },
    "outputs": [],
    "source": [
+    "def round_confidence(confidence_number, ndigits=3):\n",
+    "    if isinstance(confidence_number, float):\n",
+    "        return round(confidence_number, ndigits)\n",
+    "    elif len(confidence_number.size()) == 0:  # torch.tensor with one element\n",
+    "        return round(confidence_number.item(), ndigits)\n",
+    "    elif len(confidence_number.size()) == 1:  # torch.tensor with a list of elements\n",
+    "        return [round(c.item(), ndigits) for c in confidence_number]\n",
+    "    else:\n",
+    "        raise RuntimeError(f\"Unexpected confidence_number: `{confidence_number}`\")\n",
+    "\n",
+    "\n",
     "tran = transcriptions[0]\n",
     "print(\n",
     "    f\"\"\" Recognized text: `{tran.text}`\\n\n",
-    "    Word confidence: {[round(c, 3) for c in tran.word_confidence]}\\n\n",
-    "    Token confidence: {[round(c, 3) for c in tran.token_confidence]}\\n\n",
-    "    Frame confidence: {[([round(cc, 3) for cc in c] if is_rnnt else round(c, 3)) for c in tran.frame_confidence]}\"\"\"\n",
+    "    Word confidence: {[round_confidence(c) for c in tran.word_confidence]}\\n\n",
+    "    Token confidence: {[round_confidence(c) for c in tran.token_confidence]}\\n\n",
+    "    Frame confidence: {\n",
+    "        [([round_confidence(cc) for cc in c] if is_rnnt else round_confidence(c)) for c in tran.frame_confidence]\n",
+    "    }\"\"\"\n",
     ")"
    ]
   },
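The new round_confidence helper replaces the plain round(c, 3) calls because, with TDT duration confidence enabled (see the first file in this commit), a frame confidence entry can be a two-element tensor holding the token and duration confidences rather than a Python float. A small usage sketch with made-up values, assuming round_confidence from the cell above is in scope:

import torch

round_confidence(0.98765)                         # plain float      -> 0.988
round_confidence(torch.tensor(0.98765))           # 0-d tensor       -> 0.988
round_confidence(torch.tensor([0.9876, 0.1234]))  # token + duration -> [0.988, 0.123]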
