EWC and MAS implementation
ASR committed May 22, 2024
1 parent 0f8bc67 commit 761165a
Showing 4 changed files with 651 additions and 12 deletions.
@@ -58,29 +58,193 @@
"""

import torch, gc, os, pickle
import pytorch_lightning as pl
from omegaconf import OmegaConf

from nemo.collections.asr.models import EncDecHybridRNNTCTCBPEModel, EncDecHybridRNNTCTCBPEModelEWC, EncDecHybridRNNTCTCBPEModelMAS
from nemo.core.config import hydra_runner
from nemo.utils import logging
from nemo.utils.exp_manager import exp_manager

import tqdm


def compute_ewc_params(model, dataloader, device):
    """Snapshot the old-task parameters and estimate the diagonal Fisher information."""
    model.to(device)
    params, fisher = {}, {}
    nb = 0
    for e, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
        cuda_batch = [x.to(device) for x in batch[:4]] + batch[4:]
        loss = model.training_step_custom(cuda_batch, e)
        loss.backward()

        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                if name not in params:
                    params[name] = param.clone().cpu()
                # Diagonal Fisher estimate: accumulate squared gradients of the task loss
                if name not in fisher:
                    fisher[name] = (param.grad.clone() ** 2).cpu()
                else:
                    fisher[name] += (param.grad.clone() ** 2).cpu()

        model.zero_grad()
        nb += 1

    # Average the accumulated squared gradients over the number of batches
    for name in fisher:
        fisher[name] /= nb

    return params, fisher

def compute_mas_params(model, dataloader, device):
    """Snapshot the old-task parameters and estimate MAS importance weights."""
    model.to(device)
    params, importance = {}, {}
    nb = 0
    for e, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataloader)):
        cuda_batch = [x.to(device) for x in batch[:4]] + batch[4:]
        loss = model.training_step_custom(cuda_batch, e)
        loss.backward()

        with torch.no_grad():
            for name, param in model.named_parameters():
                if not param.requires_grad or param.grad is None:
                    continue
                if name not in params:
                    params[name] = param.clone().cpu()
                # MAS importance: accumulate absolute gradients of the task loss
                if name not in importance:
                    importance[name] = param.grad.clone().abs().cpu()
                else:
                    importance[name] += param.grad.clone().abs().cpu()

        model.zero_grad()
        nb += 1

    # Average the accumulated absolute gradients over the number of batches
    for name in importance:
        importance[name] /= nb

    return params, importance

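The params and fisher/importance dictionaries returned above feed the cl_params argument of the EWC and MAS model subclasses used further down. Those subclasses live in hybrid_rnnt_ctc_bpe_models.py and are not part of this hunk, but the standard way to use these quantities is a quadratic penalty added to the task loss. Below is a minimal illustrative sketch, assuming the subclasses keep cl_params on self and add the penalty in their training step; compute_cl_penalty and the key argument are hypothetical names, and the role of 'alpha' (present only in the EWC cl_params) is defined in the subclass and not shown here.

# Illustrative sketch only -- not the committed subclass code.
def compute_cl_penalty(model, cl_params, key='fisher'):
    """Quadratic penalty: lamda * sum_i w_i * (theta_i - theta_old_i)^2."""
    penalty = 0.0
    old_params, weights = cl_params['params'], cl_params[key]
    for name, param in model.named_parameters():
        if name not in weights:
            continue
        w = weights[name].to(param.device)
        old = old_params[name].to(param.device)
        penalty = penalty + (w * (param - old) ** 2).sum()
    return cl_params['lamda'] * penalty

# Inside the subclass, the training loss would then become something like:
#     loss = task_loss + compute_cl_penalty(self, self.cl_params, key='fisher')      # EWC
#     loss = task_loss + compute_cl_penalty(self, self.cl_params, key='importance')  # MAS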
@hydra_runner(
    config_path="../conf/conformer/hybrid_transducer_ctc/", config_name="conformer_hybrid_transducer_ctc_bpe"
)
def main(cfg):
    logging.info(f'Hydra config: {OmegaConf.to_yaml(cfg)}')

    # For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
    torch.set_float32_matmul_precision("medium")

    cl_config = cfg.get('continual_learning_strategy', None)

    if cl_config is None:
        trainer = pl.Trainer(**cfg.trainer)
        exp_manager(trainer, cfg.get("exp_manager", None))
        asr_model = EncDecHybridRNNTCTCBPEModel(cfg=cfg.model, trainer=trainer)

        # Initialize the weights of the model from another model, if provided via config
        asr_model.maybe_init_from_pretrained_checkpoint(cfg)
        trainer.fit(asr_model)

    elif cl_config.name == 'EWC':
        # EWC: estimate parameter importance on the previous episode, then fine-tune with a quadratic penalty
        log_dir = cfg.exp_manager.explicit_log_dir
        if not os.path.exists(f'{log_dir}/ewc.pkl'):
            # Load the previous episode's config so its training dataloader can be rebuilt
            prev_cfg = OmegaConf.load(cl_config.ewc_params.old_config)
            trainer = pl.Trainer(**cfg.trainer)

            # Instantiating the model from the previous config also sets up the previous episode's data
            asr_model_old = EncDecHybridRNNTCTCBPEModel(cfg=prev_cfg.model, trainer=trainer)

            # Load the model weights specified in the current config
            asr_model_old.maybe_init_from_pretrained_checkpoint(cfg)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Compute the Fisher estimate and store the old parameters
            params, fisher = compute_ewc_params(asr_model_old, asr_model_old._train_dl, device)
            with open(f'{log_dir}/ewc.pkl', 'wb') as writer:
                pickle.dump({'params': params, 'fisher': fisher}, writer)

            del asr_model_old, trainer

            gc.collect()
            torch.cuda.empty_cache()

        else:
            # Load the cached parameters and Fisher estimate
            with open(f'{log_dir}/ewc.pkl', 'rb') as reader:
                saved = pickle.load(reader)
            params, fisher = saved['params'], saved['fisher']

        cl_params = {
            'alpha': cl_config.ewc_params.alpha,
            'lamda': cl_config.ewc_params.lda,
            'params': params,
            'fisher': fisher
        }

        trainer = pl.Trainer(**cfg.trainer)
        asr_model = EncDecHybridRNNTCTCBPEModelEWC(cfg=cfg.model, trainer=trainer, cl_params=cl_params)
        asr_model.maybe_init_from_pretrained_checkpoint(cfg)
        trainer.fit(asr_model)

    elif cl_config.name == 'MAS':
        # MAS: estimate parameter importance on the previous episode, then fine-tune with a quadratic penalty
        log_dir = cfg.exp_manager.explicit_log_dir
        if not os.path.exists(f'{log_dir}/mas.pkl'):
            # Load the previous episode's config so its training dataloader can be rebuilt
            prev_cfg = OmegaConf.load(cl_config.mas_params.old_config)
            trainer = pl.Trainer(**cfg.trainer)

            # Instantiating the model from the previous config also sets up the previous episode's data
            asr_model_old = EncDecHybridRNNTCTCBPEModel(cfg=prev_cfg.model, trainer=trainer)

            # Load the model weights specified in the current config
            asr_model_old.maybe_init_from_pretrained_checkpoint(cfg)
            device = 'cuda' if torch.cuda.is_available() else 'cpu'

            # Compute the importance weights and store the old parameters
            params, importance = compute_mas_params(asr_model_old, asr_model_old._train_dl, device)
            with open(f'{log_dir}/mas.pkl', 'wb') as writer:
                pickle.dump({'params': params, 'importance': importance}, writer)

            del asr_model_old, trainer

            gc.collect()
            torch.cuda.empty_cache()

        else:
            # Load the cached parameters and importance weights
            with open(f'{log_dir}/mas.pkl', 'rb') as reader:
                saved = pickle.load(reader)
            params, importance = saved['params'], saved['importance']

        cl_params = {
            'lamda': cl_config.mas_params.lda,
            'params': params,
            'importance': importance
        }

        trainer = pl.Trainer(**cfg.trainer)
        asr_model = EncDecHybridRNNTCTCBPEModelMAS(cfg=cfg.model, trainer=trainer, cl_params=cl_params)
        asr_model.maybe_init_from_pretrained_checkpoint(cfg)
        trainer.fit(asr_model)

    elif cl_config.name == 'LWF':
        # LwF (Learning without Forgetting) is not implemented yet
        pass
    else:
        raise ValueError(f'Unknown continual learning strategy: {cl_config.name}')

    if hasattr(cfg.model, 'test_ds') and cfg.model.test_ds.manifest_filepath is not None:
        if asr_model.prepare_test(trainer):
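For reference, the branching in main() above is driven by a continual_learning_strategy block in the training config. Below is a hypothetical sketch of that block, reconstructed from the keys the script reads; the strategy names come from the code, but every value and the old_config path are placeholders, not part of this commit.

from omegaconf import OmegaConf

# Hypothetical config fragment; values and paths are placeholders.
cl_block = OmegaConf.create({
    'continual_learning_strategy': {
        'name': 'EWC',  # one of 'EWC', 'MAS', 'LWF'
        'ewc_params': {
            'old_config': '/path/to/previous_episode.yaml',  # placeholder path
            'alpha': 1.0,   # placeholder value
            'lda': 0.1,     # placeholder value, read into cl_params['lamda']
        },
        # The 'MAS' branch reads mas_params.old_config and mas_params.lda instead.
    }
})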
6 changes: 5 additions & 1 deletion nemo/collections/asr/models/__init__.py
@@ -24,7 +24,11 @@
from nemo.collections.asr.models.ctc_bpe_models import EncDecCTCModelBPE
from nemo.collections.asr.models.ctc_models import EncDecCTCModel
from nemo.collections.asr.models.enhancement_models import EncMaskDecAudioToAudioModel
from nemo.collections.asr.models.hybrid_rnnt_ctc_bpe_models import (
    EncDecHybridRNNTCTCBPEModel,
    EncDecHybridRNNTCTCBPEModelEWC,
    EncDecHybridRNNTCTCBPEModelMAS,
)
from nemo.collections.asr.models.hybrid_rnnt_ctc_models import EncDecHybridRNNTCTCModel
from nemo.collections.asr.models.k2_sequence_models import (
EncDecK2RnntSeqModel,
