Fix merge conflicts in conv_asr
kaushal-py committed Apr 22, 2024
2 parents 7a625e1 + d51ab9e commit 3fee261
Showing 7 changed files with 25 additions and 14 deletions.
2 changes: 1 addition & 1 deletion nemo/collections/asr/models/hybrid_rnnt_ctc_bpe_models.py
@@ -165,7 +165,7 @@ def __init__(self, cfg: DictConfig, trainer: Trainer = None):
)

# setting the RNNT decoder as the default one
self.cur_decoder = "rnnt"
self.cur_decoder = "ctc"

def _setup_dataloader_from_config(self, config: Optional[Dict]):

15 changes: 10 additions & 5 deletions nemo/collections/asr/models/hybrid_rnnt_ctc_models.py
@@ -106,6 +106,7 @@ def transcribe(
verbose: bool = True,
override_config: Optional[TranscribeConfig] = None,
language_id: str = None, #CTEMO
+ logprobs: bool = False,
) -> TranscriptionReturnType:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -145,7 +146,8 @@ def transcribe(
augmentor=augmentor,
verbose=verbose,
override_config=override_config,
- language_id = language_id #CTEMO
+ language_id = language_id, #CTEMO
+ logprobs=logprobs
)

def _transcribe_on_begin(self, audio, trcfg: TranscribeConfig):
@@ -161,6 +163,7 @@ def _transcribe_on_end(self, trcfg: TranscribeConfig):
self.ctc_decoder.unfreeze()

def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):

if self.cur_decoder == "rnnt":
return super()._transcribe_forward(batch, trcfg)

@@ -172,7 +175,7 @@ def _transcribe_forward(self, batch: Any, trcfg: TranscribeConfig):
language_ids = [language_id] * len(batch[0])
logits = self.ctc_decoder(encoder_output=encoded, language_ids=language_ids)
output = dict(logits=logits, encoded_len=encoded_len, language_ids=language_ids)

del encoded
return output

@@ -200,9 +203,11 @@ def _transcribe_output_processing(
best_hyp[idx].alignments = best_hyp[idx].y_sequence

# DEPRECATED?
- # if logprobs:
- # for logit, elen in zip(logits, encoded_len):
- # logits_list.append(logit[:elen])
+ if trcfg.logprobs:
+     logits_list = []
+     for logit, elen in zip(logits, encoded_len):
+         logits_list.append(logit[:elen])
+     return logits_list

del logits, encoded_len

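For reference (outside this commit): a minimal usage sketch of the new `logprobs` path through the hybrid model, assuming a placeholder checkpoint and audio file. With the CTC decoder selected, `transcribe(..., logprobs=True)` is expected to return per-utterance logit tensors trimmed to their encoded lengths instead of decoded hypotheses.

```python
# Hedged sketch only; "hybrid_model.nemo" and "sample.wav" are placeholders.
import nemo.collections.asr as nemo_asr

model = nemo_asr.models.EncDecHybridRNNTCTCBPEModel.restore_from("hybrid_model.nemo")
model.cur_decoder = "ctc"  # logprobs are handled on the CTC branch; "rnnt" defers to super()

# language_id is the fork-specific #CTEMO argument; "hi" is a placeholder code.
logits = model.transcribe(["sample.wav"], logprobs=True, language_id="hi")
print(logits[0].shape)  # roughly (num_frames, vocab_size)
```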
2 changes: 2 additions & 0 deletions nemo/collections/asr/models/rnnt_models.py
@@ -244,6 +244,7 @@ def transcribe(
verbose: bool = True,
language_id: str = None, #CTEMO
override_config: Optional[TranscribeConfig] = None,
+ logprobs: bool = False,
) -> TranscriptionReturnType:
"""
Uses greedy decoding to transcribe audio files. Use this method for debugging and prototyping.
@@ -281,6 +282,7 @@
augmentor=augmentor,
verbose=verbose,
override_config=override_config,
+ logprobs=logprobs,
# Additional arguments
partial_hypothesis=partial_hypothesis,
)
4 changes: 2 additions & 2 deletions nemo/collections/asr/modules/audio_preprocessing.py
@@ -280,8 +280,8 @@ def __init__(
stft_conv=stft_conv, # Deprecated arguments; kept for config compatibility
)

- def input_example(self, max_batch: int = 8, max_dim: int = 32000, min_length: int = 200):
- batch_size = torch.randint(low=1, high=max_batch, size=[1]).item()
+ def input_example(self, max_batch: int = 1, max_dim: int = 32000, min_length: int = 200):
+ batch_size = 1 # torch.randint(low=1, high=max_batch, size=[1]).item()
max_length = torch.randint(low=min_length, high=max_dim, size=[1]).item()
signals = torch.rand(size=[batch_size, max_length]) * 2 - 1
lengths = torch.randint(low=min_length, high=max_dim, size=[batch_size])
1 change: 1 addition & 0 deletions nemo/collections/asr/modules/conv_asr.py
@@ -453,6 +453,7 @@ def __init__(self, feat_in, num_classes, init_mode="xavier_uniform", vocabulary=
self.temperature = 1.0
self.multisoftmax = multisoftmax
self.language_masks = language_masks

@typecheck()
def forward(self, encoder_output, language_ids=None): #CTEMO
# Adapter module forward step
3 changes: 3 additions & 0 deletions nemo/collections/asr/parts/mixins/transcription.py
@@ -60,6 +60,7 @@ class TranscribeConfig:
channel_selector: ChannelSelectorType = None
augmentor: Optional[DictConfig] = None
verbose: bool = True
+ logprobs: bool = False

# Utility
partial_hypothesis: Optional[List[Any]] = None
@@ -194,6 +195,7 @@ def transcribe(
augmentor: DictConfig = None,
verbose: bool = True,
override_config: Optional[TranscribeConfig] = None,
+ logprobs: bool = False,
**config_kwargs,
) -> GenericTranscriptionType:
"""
@@ -242,6 +244,7 @@
channel_selector=channel_selector,
augmentor=augmentor,
verbose=verbose,
+ logprobs=logprobs,
**config_kwargs,
)
else:
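For reference (outside this commit): since `TranscribeConfig` now carries a `logprobs` field, the flag can presumably also be supplied through `override_config`; a sketch under the same placeholder assumptions as above, noting that this diff only shows the keyword-argument path explicitly.

```python
# Hedged sketch; checkpoint and audio paths are placeholders.
from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel
from nemo.collections.asr.parts.mixins.transcription import TranscribeConfig

model = EncDecHybridRNNTCTCBPEModel.restore_from("hybrid_model.nemo")
model.cur_decoder = "ctc"

cfg = TranscribeConfig(logprobs=True)  # remaining dataclass fields keep their defaults
logits = model.transcribe(["sample.wav"], override_config=cfg)
```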
12 changes: 6 additions & 6 deletions nemo/collections/asr/parts/preprocessing/features.py
@@ -63,12 +63,12 @@ def normalize_batch(x, seq_len, normalize_type):
x_mean = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
x_std = torch.zeros((seq_len.shape[0], x.shape[1]), dtype=x.dtype, device=x.device)
for i in range(x.shape[0]):
- if x[i, :, : seq_len[i]].shape[1] == 1:
- raise ValueError(
- "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
- "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
- "feature (ex. at least `hop_length` for Mel Spectrograms)."
- )
+ # if x[i, :, : seq_len[i]].shape[1] == 1:
+ # raise ValueError(
+ # "normalize_batch with `per_feature` normalize_type received a tensor of length 1. This will result "
+ # "in torch.std() returning nan. Make sure your audio length has enough samples for a single "
+ # "feature (ex. at least `hop_length` for Mel Spectrograms)."
+ # )
x_mean[i, :] = x[i, :, : seq_len[i]].mean(dim=1)
x_std[i, :] = x[i, :, : seq_len[i]].std(dim=1)
# make sure x_std is not zero
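For reference (outside this commit): the guard commented out above protected against single-frame inputs, where the unbiased `torch.std` is undefined; with it removed, that case surfaces as NaNs in `x_std` rather than an explicit error. A minimal illustration:

```python
import torch

x = torch.randn(1, 80, 1)        # one utterance, 80 features, a single time frame
print(x[0, :, :1].std(dim=1))    # tensor of nan values: std over a single sample is undefined
```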
