Skip to content

Commit

Permalink
feat(asr): add support for training hybrid models in maglev
Browse files Browse the repository at this point in the history
  • Loading branch information
Mayank Jain committed Nov 16, 2023
1 parent 88ecb71 commit 7ae46a6
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -65,6 +65,11 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from pytorch_lightning.plugins.environments import LightningEnvironment
class MaglevEnvironment(LightningEnvironment):
def creates_children(self) -> bool:
logging.info(f"MaglevEnvironment: creates_children returning True")
return True


@hydra_runner(
Expand All @@ -73,7 +78,10 @@
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

trainer = pl.Trainer(**cfg.trainer)
if int(cfg.trainer.num_nodes) > 1:
trainer = pl.Trainer(**cfg.trainer, plugins=[MaglevEnvironment()])
else:
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
asr_model = EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -76,13 +76,21 @@
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager
from pytorch_lightning.plugins.environments import LightningEnvironment
class MaglevEnvironment(LightningEnvironment):
def creates_children(self) -> bool:
logging.info(f"MaglevEnvironment: creates_children returning True")
return True


@hydra_runner(config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc")
def main(cfg):
logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

trainer = pl.Trainer(**cfg.trainer)
if int(cfg.trainer.num_nodes) > 1:
trainer = pl.Trainer(**cfg.trainer, plugins=[MaglevEnvironment()])
else:
trainer = pl.Trainer(**cfg.trainer)
exp_manager(trainer, cfg.get("exp_manager", None))
asr_model = EncDecHybridRNNTCTCModel(cfg=cfg.model, trainer=trainer)

Expand Down
10 changes: 8 additions & 2 deletions nemo/core/connectors/save_restore_connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -427,7 +427,10 @@ def _handle_artifacts(self, model, nemo_file_folder):
# Note uuid.uuid4().hex is guaranteed to be 32 character long
artifact_base_name = os.path.basename(artiitem.path)
artifact_uniq_name = f"{uuid.uuid4().hex}_{artifact_base_name}"
shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))
try:
shutil.copy2(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))
except:
shutil.copy(artiitem.path, os.path.join(nemo_file_folder, artifact_uniq_name))

# Update artifacts registry
artiitem.hashed_path = "nemo:" + artifact_uniq_name
Expand Down Expand Up @@ -483,7 +486,10 @@ def _handle_artifacts(self, model, nemo_file_folder):
artifact_base_name = os.path.basename(artiitem.path)
# no need to hash here as we are in tarfile_artifacts which are already hashed
artifact_uniq_name = artifact_base_name
shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name))
try:
shutil.copy2(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name))
except:
shutil.copy(artifact_base_name, os.path.join(nemo_file_folder, artifact_uniq_name))

# Update artifacts registry
new_artiitem = model_utils.ArtifactItem()
Expand Down

0 comments on commit 7ae46a6

Please sign in to comment.