From d51ab9e080b9fba38e7241b1d2c46e64163dfa6a Mon Sep 17 00:00:00 2001 From: Kaushal BHogale Date: Fri, 19 Apr 2024 15:08:57 +0530 Subject: [PATCH] Add support for logprobs in CTC model --- .../asr/models/hybrid_rnnt_ctc_bpe_models.py | 2 +- .../asr/models/hybrid_rnnt_ctc_models.py | 15 ++++++++++----- nemo/collections/asr/models/rnnt_models.py | 2 ++ .../asr/modules/audio_preprocessing.py | 4 ++-- nemo/collections/asr/modules/conv_asr.py | 6 +++--- .../collections/asr/parts/mixins/transcription.py | 3 +++ .../asr/parts/preprocessing/features.py | 12 ++++++------ 7 files changed, 27 insertions(+), 17 deletions(-) diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py index ea702d053..4da9bf826 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py @@ -165,7 +165,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None): ) # setting the RNNT decoder as the default one - self.cur_decoder = "rnnt" + self.cur_decoder = "ctc" def _setup_dataloader_from_config(self, config: Optional[Dict]): diff --git a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py index 3cdc7e15e..ecdf80d24 100644 --- a/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py +++ b/nemo/collections/asr/models/hybrid_rnnt_ctc_models.py @@ -106,6 +106,7 @@ def transcribe( verbose: bool = True, override_config: Optional[TranscribeConfig] = None, language_id: str = None, #CTEMO + logprobs: bool = False, ) -> TranscriptionReturnType: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. @@ -145,7 +146,8 @@ def transcribe( augmentor=augmentor, verbose=verbose, override_config=override_config, - language_id = language_id #CTEMO + language_id = language_id, #CTEMO + logprobs=logprobs ) def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig): @@ -161,6 +163,7 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig): self.ctc_decoder.unfreeze() def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): + if self.cur_decoder == "rnnt": return super()._transcribe_forward(batch, trcfg) @@ -172,7 +175,7 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig): language_ids = [language_id] * len(batch[0]) logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids) output = dict(logits=logits, encoded_len=encoded_len, language_ids=language_ids) - + del encoded return output @@ -200,9 +203,11 @@ def _transcribe_output_processing( best_hyp[idx].alignments = best_hyp[idx].y_sequence # DEPRECATED? - # if logprobs: - # for logit, elen in zip(logits, encoded_len): - # logits_list.append(logit[:elen]) + if trcfg.logprobs: + logits_list = [] + for logit, elen in zip(logits, encoded_len): + logits_list.append(logit[:elen]) + return logits_list del logits, encoded_len diff --git a/nemo/collections/asr/models/rnnt_models.py b/nemo/collections/asr/models/rnnt_models.py index ff119aab3..e19518533 100644 --- a/nemo/collections/asr/models/rnnt_models.py +++ b/nemo/collections/asr/models/rnnt_models.py @@ -244,6 +244,7 @@ def transcribe( verbose: bool = True, language_id: str = None, #CTEMO override_config: Optional[TranscribeConfig] = None, + logprobs: bool = False, ) -> TranscriptionReturnType: """ Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping. @@ -281,6 +282,7 @@ def transcribe( augmentor=augmentor, verbose=verbose, override_config=override_config, + logprobs=logprobs, # Additional arguments partial_hypothesis=partial_hypothesis, ) diff --git a/nemo/collections/asr/modules/audio_preprocessing.py b/nemo/collections/asr/modules/audio_preprocessing.py index cc5312403..f70015257 100644 --- a/nemo/collections/asr/modules/audio_preprocessing.py +++ b/nemo/collections/asr/modules/audio_preprocessing.py @@ -280,8 +280,8 @@ def __init__( stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility ) - def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200): - batch_size = torch.randint(low=1, high=max_batch, size=[1]).item() + def input_example(self, max_batch: int = 1, max_dim: int = 32000, min_length: int = 200): + batch_size = 1 # torch.randint(low=1, high=max_batch, size=[1]).item() max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item() signals = torch.rand(size=[batch_size, max_length]) * 2 - 1 lengths = torch.randint(low=min_length, high=max_dim, size=[batch_size]) diff --git a/nemo/collections/asr/modules/conv_asr.py b/nemo/collections/asr/modules/conv_asr.py index e959a6388..b3599a8c5 100644 --- a/nemo/collections/asr/modules/conv_asr.py +++ b/nemo/collections/asr/modules/conv_asr.py @@ -453,6 +453,7 @@ def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary= self.temperature = 1.0 self.multisoftmax = multisoftmax self.language_masks = language_masks + @typecheck() def forward(self, encoder_output, language_ids=None): #CTEMO # Adapter module forward step @@ -477,9 +478,8 @@ def forward(self, encoder_output, language_ids=None): #CTEMO # Send mask to GPU mask = mask.to(decoder_output.device) # masked_output = self.masked_softmax(decoder_output, mask) # B x T x 3073 -> B x T x 257 - decoder_output = torch.masked_select(decoder_output, mask).view(decoder_output.shape[0],decoder_output.shape[1],-1) - - del mask + decoder_output = torch.masked_select(decoder_output, mask).view(decoder_output.shape[0],decoder_output.shape[1],-1) + del mask # print(mask[0][0]) # softmax_output = self.masked_softmax(decoder_output, mask) # return softmax_output diff --git a/nemo/collections/asr/parts/mixins/transcription.py b/nemo/collections/asr/parts/mixins/transcription.py index 5a7167960..449374d0a 100644 --- a/nemo/collections/asr/parts/mixins/transcription.py +++ b/nemo/collections/asr/parts/mixins/transcription.py @@ -60,6 +60,7 @@ class TranscribeConfig: channel_selector: ChannelSelectorType = None augmentor: Optional[DictConfig] = None verbose: bool = True + logprobs: bool = False # Utility partial_hypothesis: Optional[List[Any]] = None @@ -194,6 +195,7 @@ def transcribe( augmentor: DictConfig = None, verbose: bool = True, override_config: Optional[TranscribeConfig] = None, + logprobs: bool = False, **config_kwargs, ) -> GenericTranscriptionType: """ @@ -242,6 +244,7 @@ def transcribe( channel_selector=channel_selector, augmentor=augmentor, verbose=verbose, + logprobs=logprobs, **config_kwargs, ) else: diff --git a/nemo/collections/asr/parts/preprocessing/features.py b/nemo/collections/asr/parts/preprocessing/features.py index 67813f3e6..8d0194b6e 100644 --- a/nemo/collections/asr/parts/preprocessing/features.py +++ b/nemo/collections/asr/parts/preprocessing/features.py @@ -63,12 +63,12 @@ def normalize_batch(x, seq_len, normalize_type): x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device) for i in range(x.shape[0]): - if x[i, :, : seq_len[i]].shape[1] == 1: - raise ValueError( - "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result " - "in torch.std() returning nan. Make sure your audio length has enough samples for a single " - "feature (ex. at least `hop_length` for Mel Spectrograms)." - ) + # if x[i, :, : seq_len[i]].shape[1] == 1: + # raise ValueError( + # "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result " + # "in torch.std() returning nan. Make sure your audio length has enough samples for a single " + # "feature (ex. at least `hop_length` for Mel Spectrograms)." + # ) x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1) x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1) # make sure x_std is not zero