From 7ae46a6ca68700a6e5f73414aa6d8b8160d51c94 Mon Sep 17 00:00:00 2001 From: Mayank Jain Date: Tue, 17 Oct 2023 11:38:11 +0530 Subject: [PATCH] feat(asr): add support for training hybrid models in maglev --- .../speech_to_text_hybrid_rnnt_ctc_bpe.py | 10 +++++++++- .../speech_to_text_hybrid_rnnt_ctc_char.py | 10 +++++++++- nemo/core/connectors/save_restore_connector.py | 10 ++++++++-- 3 files changed, 26 insertions(+), 4 deletions(-) diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py index 2de150c71328..99ab15a2263a 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_bpe.py @@ -65,6 +65,11 @@ from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager +from pytorch_lightning.plugins.environments import LightningEnvironment +class MaglevEnvironment(LightningEnvironment): + def creates_children(self) -> bool: + logging.info(f"MaglevEnvironment: creates_children returning True") + return True @hydra_runner( @@ -73,7 +78,10 @@ def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') - trainer = pl.Trainer(**cfg.trainer) + if int(cfg.trainer.num_nodes) > 1: + trainer = pl.Trainer(**cfg.trainer, plugins=[MaglevEnvironment()]) + else: + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer) diff --git a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py index 532e2c9ed0be..3944d802f5da 100644 --- a/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py +++ b/examples/asr/asr_hybrid_transducer_ctc/speech_to_text_hybrid_rnnt_ctc_char.py @@ -76,13 +76,21 @@ from nemo.core.config import hydra_runner from nemo.utils import logging from nemo.utils.exp_manager import exp_manager +from pytorch_lightning.plugins.environments import LightningEnvironment +class MaglevEnvironment(LightningEnvironment): + def creates_children(self) -> bool: + logging.info(f"MaglevEnvironment: creates_children returning True") + return True @hydra_runner(config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc") def main(cfg): logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}') - trainer = pl.Trainer(**cfg.trainer) + if int(cfg.trainer.num_nodes) > 1: + trainer = pl.Trainer(**cfg.trainer, plugins=[MaglevEnvironment()]) + else: + trainer = pl.Trainer(**cfg.trainer) exp_manager(trainer, cfg.get("exp_manager", None)) asr_model = EncDecHybridRNNTCTCModel(cfg=cfg.model, trainer=trainer) diff --git a/nemo/core/connectors/save_restore_connector.py b/nemo/core/connectors/save_restore_connector.py index bbcb32f84024..5428b02eeaac 100644 --- a/nemo/core/connectors/save_restore_connector.py +++ b/nemo/core/connectors/save_restore_connector.py @@ -427,7 +427,10 @@ def _handle_artifacts(self, model, nemo_file_folder): # Note uuid.uuid4().hex is guaranteed to be 32 character long artifact_base_name = os.path.basename(artiitem.path) artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}" - shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name)) + try: + shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name)) + except: + shutil.copy(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name)) # Update artifacts registry artiitem.hashed_path = "nemo:" + artifact_uniq_name @@ -483,7 +486,10 @@ def _handle_artifacts(self, model, nemo_file_folder): artifact_base_name = os.path.basename(artiitem.path) # no need to hash here as we are in tarfile_artifacts which are already hashed artifact_uniq_name = artifact_base_name - shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name)) + try: + shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name)) + except: + shutil.copy(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name)) # Update artifacts registry new_artiitem = model_utils.ArtifactItem()