From b422cfcaede860d999b04f9e983e383a90b18718 Mon Sep 17 00:00:00 2001
From: Vincent Nguyen
Date: Mon, 18 Mar 2024 16:33:22 +0100
Subject: [PATCH] Load data at the correct position when resuming from a checkpoint

---
 onmt/inputters/dynamic_iterator.py | 26 ++++++++-
 onmt/inputters/text_corpus.py      | 46 ++++++++++++---
 onmt/models/model_saver.py         | 92 +++++++++++++++++++++++++++++-
 onmt/opts.py                       |  7 +++
 onmt/train_single.py               | 13 +++--
 onmt/trainer.py                    |  5 ++
 6 files changed, 173 insertions(+), 16 deletions(-)

diff --git a/onmt/inputters/dynamic_iterator.py b/onmt/inputters/dynamic_iterator.py
index 8289d1a3ce..db7446742f 100644
--- a/onmt/inputters/dynamic_iterator.py
+++ b/onmt/inputters/dynamic_iterator.py
@@ -129,6 +129,7 @@ def __init__(
         batch_type,
         batch_size,
         batch_size_multiple,
+        resume_corpora_info={},
         data_type="text",
         bucket_size=2048,
         bucket_size_init=-1,
@@ -144,6 +145,7 @@ def __init__(
         self.transforms = transforms
         self.vocabs = vocabs
         self.corpora_info = corpora_info
+        self.resume_corpora_info = resume_corpora_info
         self.task = task
         self.init_iterators = False
         self.batch_size = batch_size
@@ -171,7 +173,17 @@ def __init__(
 
     @classmethod
     def from_opt(
-        cls, corpora, transforms, vocabs, opt, task, copy, device, stride=1, offset=0
+        cls,
+        corpora,
+        transforms,
+        vocabs,
+        opt,
+        task,
+        copy,
+        device,
+        resume_corpora_info={},
+        stride=1,
+        offset=0,
     ):
         """Initilize `DynamicDatasetIter` with options parsed from `opt`."""
         corpora_info = {}
@@ -206,6 +218,7 @@ def from_opt(
             opt.batch_type,
             batch_size,
             batch_size_multiple,
+            resume_corpora_info=resume_corpora_info,
             data_type=opt.data_type,
             bucket_size=bucket_size,
             bucket_size_init=bucket_size_init,
@@ -388,6 +401,7 @@ def build_dynamic_dataset_iter(
     vocabs,
     copy=False,
     task=CorpusTask.TRAIN,
+    resume_corpora_info={},
     stride=1,
     offset=0,
     src=None,
@@ -412,7 +426,14 @@ def build_dynamic_dataset_iter(
         advance to avoid the GPU waiting during the refilling of the bucket.
     """
     transforms = make_transforms(opt, transforms_cls, vocabs)
-    corpora = get_corpora(opt, task, src=src, tgt=tgt, align=align)
+    corpora = get_corpora(
+        opt,
+        task,
+        src=src,
+        tgt=tgt,
+        align=align,
+        resume_corpora_info=resume_corpora_info,
+    )
     if corpora is None:
         assert task != CorpusTask.TRAIN, "only valid corpus is ignorable."
         return None
@@ -442,6 +463,7 @@ def build_dynamic_dataset_iter(
         vocabs,
         opt,
         task,
+        resume_corpora_info=resume_corpora_info,
         copy=copy,
         stride=stride,
         offset=offset,
diff --git a/onmt/inputters/text_corpus.py b/onmt/inputters/text_corpus.py
index ca32cbbf0e..6f9ecc170b 100644
--- a/onmt/inputters/text_corpus.py
+++ b/onmt/inputters/text_corpus.py
@@ -99,7 +99,14 @@ class ParallelCorpus(object):
     """A parallel corpus file pair that can be loaded to iterate."""
 
     def __init__(
-        self, name, src, tgt, align=None, n_src_feats=0, src_feats_defaults=None
+        self,
+        name,
+        src,
+        tgt,
+        align=None,
+        n_src_feats=0,
+        src_feats_defaults=None,
+        line_number_to_resume=0,
     ):
         """Initialize src & tgt side file path."""
         self.id = name
@@ -108,6 +115,12 @@ def __init__(
         self.align = align
         self.n_src_feats = n_src_feats
         self.src_feats_defaults = src_feats_defaults
+        self.line_number_to_resume = line_number_to_resume
+        self.can_read_file = False
+
+    def activate_reading_mode(self, line_number):
+        if line_number >= self.line_number_to_resume:
+            self.can_read_file = True
 
     def load(self, offset=0, stride=1):
         """
         Load if the corpus is non-empty and the `cid` is specified, in the data
         config file, path to the file is specified, the file will be loaded.
         The corpus will yield one example per `stride` example, starting from
         `offset`.
""" - def make_ex(sline, tline, align): + def make_ex(sline, tline, align, line_number): sline, sfeats = parse_features( sline, n_feats=self.n_src_feats, @@ -131,6 +144,7 @@ def make_ex(sline, tline, align): "tgt": tline, "src_original": sline, "tgt_original": tline, + "cid_line_number": line_number, } if align is not None: example["align"] = align @@ -145,19 +159,25 @@ def make_ex(sline, tline, align): for i, (sline, tline, align) in enumerate( itertools.zip_longest(fs, ft, fa) ): + self.activate_reading_mode(line_number=i) + if not self.can_read_file: + continue if (i // stride) % stride == offset: - yield make_ex(sline, tline, align) + yield make_ex(sline, tline, align, i) else: with exfile_open(self.src, mode="rb") as fs, exfile_open( self.tgt, mode="rb" ) as ft, exfile_open(self.align, mode="rb") as fa: for i, (sline, tline, align) in enumerate(zip(fs, ft, fa)): + self.activate_reading_mode(line_number=i) + if not self.can_read_file: + continue if (i // stride) % stride == offset: if tline is not None: tline = tline.decode("utf-8") if align is not None: align = align.decode("utf-8") - yield make_ex(sline.decode("utf-8"), tline, align) + yield make_ex(sline.decode("utf-8"), tline, align, i) def __str__(self): cls_name = type(self).__name__ @@ -169,12 +189,17 @@ def __str__(self): ) -def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None): +def get_corpora( + opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None, resume_corpora_info={} +): corpora_dict = {} if task == CorpusTask.TRAIN: for corpus_id, corpus_dict in opts.data.items(): if corpus_id != CorpusName.VALID: if corpus_dict.get("path_txt", None) is None: + resume_line = 0 + if corpus_id in resume_corpora_info: + resume_line = resume_corpora_info[corpus_id] corpora_dict[corpus_id] = ParallelCorpus( corpus_id, corpus_dict["path_src"], @@ -182,6 +207,7 @@ def get_corpora(opts, task=CorpusTask.TRAIN, src=None, tgt=None, align=None): corpus_dict["path_align"], n_src_feats=opts.n_src_feats, src_feats_defaults=opts.src_feats_defaults, + line_number_to_resume=resume_line, ) else: corpora_dict[corpus_id] = BlockwiseCorpus( @@ -244,8 +270,6 @@ def _process(self, stream): example["src_feats"] = [ feat.strip().split(" ") for feat in example["src_feats"] ] - line_number = i * self.stride + self.offset - example["cid_line_number"] = line_number example["cid"] = self.cid if "align" in example: example["align"] = example["align"].strip().split(" ") @@ -258,6 +282,7 @@ def _process(self, stream): or ("align" in example and example["align"] == 0) ): # empty example: skip + line_number = example["cid_line_number"] empty_msg = f"Empty line in {self.cid}#{line_number}." 
                 if self.skip_empty_level == "error":
                     raise IOError(empty_msg)
@@ -282,7 +307,12 @@ def __iter__(self):
 
 
 def build_corpora_iters(
-    corpora, transforms, corpora_info, skip_empty_level="warning", stride=1, offset=0
+    corpora,
+    transforms,
+    corpora_info,
+    skip_empty_level="warning",
+    stride=1,
+    offset=0,
 ):
     """Return `ParallelCorpusIterator` for all corpora defined in opts."""
     corpora_iters = dict()
diff --git a/onmt/models/model_saver.py b/onmt/models/model_saver.py
index 986ca7ae99..db7383660c 100644
--- a/onmt/models/model_saver.py
+++ b/onmt/models/model_saver.py
@@ -1,13 +1,17 @@
 import os
 import torch
 import re
+import subprocess
 from collections import deque
+import onmt.utils
 from onmt.utils.logging import logger
 from onmt.inputters.inputter import vocabs_to_dict
 from onmt.modules.lora import lora_state_dict
 
 
-def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
+def build_model_saver(
+    model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id
+):
     # _check_save_model_path
     save_model_path = os.path.abspath(opt.save_model)
     os.makedirs(os.path.dirname(save_model_path), exist_ok=True)
@@ -20,6 +24,7 @@ def build_model_saver(model_opt, opt, model, vocabs, optim, device_id):
         optim,
         opt.keep_checkpoint,
         opt.save_format,
+        resume_corpora_info,
         device_id,
     )
     return model_saver
@@ -81,6 +86,65 @@ def fix_key(s):
     return checkpoint
 
 
+def load_corpora_info(opts, checkpoint):
+    message_resume_from_beginning = (
+        "The training will resume from the beginning of each corpus."
+    )
+    # Check if resume_from_corpora is True
+    if not opts.resume_from_corpora:
+        logger.info(
+            "No resume from corpora is specified. " + message_resume_from_beginning
+        )
+        return {}
+
+    # Check whether the corpus lists from the last training
+    # and from the new training are identical.
+    checkpoint_corpora = checkpoint.get("corpus_info", None)
+    if checkpoint_corpora is None:
+        logger.info(
+            "Incoherent info: The checkpoint does not contain "
+            + "any corpus information to resume from. "
+            + message_resume_from_beginning
+        )
+        return {}
+
+    checkpoint_corpus_names = [name for name in checkpoint_corpora]
+    new_corpus_names = [name for name in opts.data]
+    if set(checkpoint_corpus_names) != set(new_corpus_names):
+        logger.info(
+            "Incoherent info: Some corpora in the last training "
+            + "and in the new list do not match. "
+            + message_resume_from_beginning
+        )
+        return {}
+
+    # For each corpus, check if the last line number to resume
+    # is smaller than or equal to the number of text lines.
+    message_incoherent_line_number = (
+        "Incoherent info: text line numbers "
+        + "to resume in some corpora exceed their total numbers of lines. "
" + + message_resume_from_beginning + ) + for c_name in checkpoint_corpora: + number_of_text_lines = int( + subprocess.getoutput( + "wc -l " + opts.data[c_name]["path_src"] + " | awk '{print $1}'" + ) + ) + if checkpoint_corpora[c_name] > number_of_text_lines - 1: + logger.info(message_incoherent_line_number) + return {} + + # To set the text lines to resume, we increase all text lines by 1 + # (and return to the beginning if the end is reached) + checkpoint_corpora[c_name] = ( + checkpoint_corpora[c_name] + 1 + ) % number_of_text_lines + + logger.info("The training will resume from the saved text line in each corpus.") + return checkpoint_corpora + + class ModelSaverBase(object): """Base class for model saving operations @@ -98,6 +162,7 @@ def __init__( optim, keep_checkpoint=-1, save_format="pytorch", + resume_corpora_info={}, device_id=0, ): self.base_path = base_path @@ -108,6 +173,7 @@ def __init__( self.last_saved_step = None self.keep_checkpoint = keep_checkpoint self.save_format = save_format + self.corpus_info = resume_corpora_info self.device_id = device_id if keep_checkpoint > 0: @@ -115,7 +181,27 @@ def __init__( if save_format == "safetensors": self.model_queue = deque([], maxlen=keep_checkpoint) - def save(self, step, moving_average=None): + def update_corpus_info_from_batches(self, batches, distributed=False): + # Update the last text line of each corpus + if batches is not None: + # Gather corpus line numbers to save to checkpoints + batch_cids = sum([batch["cid"] for batch in batches], []) + batch_cid_line_numbers = sum( + [batch["cid_line_number"] for batch in batches], [] + ) + if distributed: + batch_cids = sum(onmt.utils.distributed.all_gather_list(batch_cids), []) + batch_cid_line_numbers = sum( + onmt.utils.distributed.all_gather_list(batch_cid_line_numbers), [] + ) + # Save the last processed line number of each corpus + new_corpus_info = { + c_name: cid_line_number + for c_name, cid_line_number in zip(batch_cids, batch_cid_line_numbers) + } + self.corpus_info = {**self.corpus_info, **new_corpus_info} + + def save(self, step, moving_average=None, batches=None, distributed=False): """Main entry point for model saver It wraps the `_save` method with checks and apply `keep_checkpoint` @@ -266,6 +352,7 @@ def _save(self, step, model): "vocab": vocabs_to_dict(self.vocabs), "opt": self.model_opt, "optim": self.optim.state_dict(), + "corpus_info": self.corpus_info, } if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: logger.info("Saving checkpoint %s_step_%d.pt" % (self.base_path, step)) @@ -355,6 +442,7 @@ def _st_save(self, step, model): "vocab": vocabs_to_dict(self.vocabs), "opt": self.model_opt, "optim": self.optim.state_dict(), + "corpus_info": self.corpus_info, } if not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0: diff --git a/onmt/opts.py b/onmt/opts.py index 21abd96a3d..6576246aa9 100644 --- a/onmt/opts.py +++ b/onmt/opts.py @@ -1263,6 +1263,13 @@ def _add_train_general_opts(parser): help="If training from a checkpoint then this is the " "path to the pretrained model's state_dict.", ) + group.add( + "--resume_from_corpora", + "-resume_from_corpora", + action="store_true", + help="If training from a checkpoint and this is set to True " + " then the data generator will resume from the last line of each corpora.", + ) group.add( "--reset_optim", "-reset_optim", diff --git a/onmt/train_single.py b/onmt/train_single.py index 76ab3bef66..987c948c67 100644 --- a/onmt/train_single.py +++ b/onmt/train_single.py @@ 
@@ -17,7 +17,7 @@
 from onmt.inputters.dynamic_iterator import build_dynamic_dataset_iter
 from onmt.inputters.text_corpus import save_transformed_sample
 from onmt.model_builder import build_model
-from onmt.models.model_saver import load_checkpoint
+from onmt.models.model_saver import load_checkpoint, load_corpora_info
 from onmt.utils.optimizers import Optimizer
 from onmt.utils.misc import set_random_seed
 from onmt.trainer import build_trainer
@@ -80,6 +80,7 @@ def _init_train(opt):
     if opt.train_from:
         # Load checkpoint if we resume from a previous training.
         checkpoint = load_checkpoint(ckpt_path=opt.train_from)
+        resume_corpora_info = load_corpora_info(opt, checkpoint)
         vocabs = dict_to_vocabs(checkpoint["vocab"])
         if (
             hasattr(checkpoint["opt"], "_all_transform")
@@ -105,8 +106,9 @@ def _init_train(opt):
     else:
         checkpoint = None
         vocabs = prepare_transforms_vocabs(opt, transforms_cls)
+        resume_corpora_info = {}
 
-    return checkpoint, vocabs, transforms_cls
+    return checkpoint, resume_corpora_info, vocabs, transforms_cls
 
 
 def configure_process(opt, device_id):
@@ -159,7 +161,7 @@ def main(opt, device_id):
     configure_process(opt, device_id)
     init_logger(opt.log_file)
 
-    checkpoint, vocabs, transforms_cls = _init_train(opt)
+    checkpoint, resume_corpora_info, vocabs, transforms_cls = _init_train(opt)
     model_opt = _get_model_opts(opt, checkpoint=checkpoint)
 
     # Build model.
@@ -197,7 +199,9 @@ def main(opt, device_id):
     del checkpoint
 
     # Build model saver
-    model_saver = build_model_saver(model_opt, opt, model, vocabs, optim, device_id)
+    model_saver = build_model_saver(
+        model_opt, opt, model, vocabs, optim, resume_corpora_info, device_id
+    )
 
     trainer = build_trainer(
         opt, device_id, model, vocabs, optim, model_saver=model_saver
@@ -211,6 +215,7 @@ def main(opt, device_id):
         transforms_cls,
         vocabs,
         task=CorpusTask.TRAIN,
+        resume_corpora_info=resume_corpora_info,
         copy=opt.copy_attn,
         stride=stride,
         offset=offset,
diff --git a/onmt/trainer.py b/onmt/trainer.py
index 6916ec3ba9..911e565cd3 100644
--- a/onmt/trainer.py
+++ b/onmt/trainer.py
@@ -349,6 +349,11 @@ def train(
                     logger.info("earlystopper has_stopped!")
                     break
 
+            self.model_saver.update_corpus_info_from_batches(
+                batches,
+                distributed=(self.n_gpu > 1 and self.parallel_mode == "data_parallel"),
+            )
+
             if self.model_saver is not None and (
                 save_checkpoint_steps != 0 and step % save_checkpoint_steps == 0
             ):
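
Note (not part of the patch): the standalone sketch below distills the resume logic this diff implements, assuming a single corpus and a single reader. The checkpoint stores, per corpus, the line number of the last processed example (the `corpus_info` kept by `ModelSaverBase`); on restart the next line to read is computed with a wrap-around modulo as in `load_corpora_info`, and the reader skips every earlier line before yielding again, as `ParallelCorpus.activate_reading_mode` does inside `load`. The helper names `next_resume_line` and `iter_corpus_from` are illustrative only, not part of the OpenNMT-py API.

    # Illustrative sketch of the resume mechanism; hypothetical helper names.
    from typing import Dict, Iterator, List, Tuple


    def next_resume_line(last_processed_line: int, corpus_length: int) -> int:
        # Advance by one line and wrap back to the beginning of the corpus when
        # the end is reached, mirroring the modulo step in load_corpora_info().
        return (last_processed_line + 1) % corpus_length


    def iter_corpus_from(lines: List[str], resume_line: int = 0) -> Iterator[Tuple[int, str]]:
        # Skip every line before `resume_line`, then yield (line_number, line),
        # mirroring ParallelCorpus.activate_reading_mode() inside load().
        for i, line in enumerate(lines):
            if i < resume_line:
                continue
            yield i, line


    if __name__ == "__main__":
        corpus = ["src1 ||| tgt1", "src2 ||| tgt2", "src3 ||| tgt3", "src4 ||| tgt4"]
        # Pretend the checkpoint recorded line 2 as the last processed line.
        saved_corpus_info: Dict[str, int] = {"corpus_1": 2}
        resume = next_resume_line(saved_corpus_info["corpus_1"], len(corpus))
        for line_number, line in iter_corpus_from(corpus, resume):
            print(line_number, line)  # prints only line 3; the next pass restarts at 0

With several GPUs the per-corpus line numbers come from different shards (stride/offset), which is why `update_corpus_info_from_batches` gathers them with `onmt.utils.distributed.all_gather_list` before the checkpoint is written.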