From 1a5f3656dce239b4219db6f2d9f10e91fd47b50f Mon Sep 17 00:00:00 2001
From: Benjamin Minixhofer
Date: Thu, 14 Dec 2023 14:42:47 +0000
Subject: [PATCH] comments

---
 wtpsplit/extract.py       | 2 +-
 wtpsplit/train/train.py   | 7 +++++++
 wtpsplit/train/trainer.py | 1 +
 wtpsplit/utils.py         | 2 +-
 4 files changed, 10 insertions(+), 2 deletions(-)

diff --git a/wtpsplit/extract.py b/wtpsplit/extract.py
index 96718bb0..bfe5bb9c 100644
--- a/wtpsplit/extract.py
+++ b/wtpsplit/extract.py
@@ -58,7 +58,7 @@ def __call__(self, hashed_ids, attention_mask, language_ids=None):
 
         return {"logits": logits}
 
-
+# TODO: comment
 def extract(
     batch_of_texts,
     model,
diff --git a/wtpsplit/train/train.py b/wtpsplit/train/train.py
index 538b699b..41bc4f0f 100644
--- a/wtpsplit/train/train.py
+++ b/wtpsplit/train/train.py
@@ -77,6 +77,7 @@ def forward(
 
         losses = []
 
+        # main (newline prediction) objective
         if self.do_sentence_training:
             sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to(
                 logits.dtype
@@ -95,6 +96,7 @@ def forward(
                 / reduced_attention_mask.sum()
             )
 
+        # auxiliary (punctuation prediction) objective
         if self.do_auxiliary_training:
             loss_fn = nn.CrossEntropyLoss()
 
@@ -166,6 +168,7 @@ def collate_fn(batch, args, label_args, label_dict):
     all_label_weights = []
 
     for sample in batch:
+        # NOTE: this is specific to characters at the moment!
         input_ids = [ord(c) for c in sample[args.text_column]]
         lang = sample["lang"]
 
@@ -278,6 +281,7 @@ def prepare_dataset(
         if shuffle:
             dataset = dataset.shuffle(seed=42)
 
+        # very likely not relevant / used only for the compound part
         if args.ignore_non_hyphen:
             with training_args.main_process_first():
                 dataset = dataset.filter(
@@ -285,6 +289,7 @@ def prepare_dataset(
                     num_proc=args.preprocessing_num_workers,
                 )
 
+        # "punctuation-specific sampling" in the paper
        if args.non_punctuation_sample_ratio is not None:
             languages_without_punctuation = {
                 lang_code
@@ -330,6 +335,7 @@ def drop_some_non_punctuation_samples(examples):
                     num_proc=num_workers,
                 )
 
+        # similar to group_texts in huggingface's run_clm.py / run_mlm.py: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py
         def group_texts(examples):
             all_input_blocks = []
             all_input_block_lengths = []
@@ -352,6 +358,7 @@ def maybe_pad(text):
                     if lang == current_lang
                 ]
 
+                # pack_samples used for the compound part, so irrelevant
                 if args.pack_samples:
                     blocks = []
                     block_ids = []
diff --git a/wtpsplit/train/trainer.py b/wtpsplit/train/trainer.py
index 85b0ef57..adf8fad8 100644
--- a/wtpsplit/train/trainer.py
+++ b/wtpsplit/train/trainer.py
@@ -124,6 +124,7 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optim
         if self.lr_scheduler is None:
             warmup_steps = self.args.get_warmup_steps(self.args.max_steps)
 
+            # MODIFIED: add lang adapter lr scheduler
             def lr_lambda(current_step: int):
                 if current_step < self.args.adapter_warmup_steps:
                     return 0.0
diff --git a/wtpsplit/utils.py b/wtpsplit/utils.py
index fa9ad84f..6fdfb04b 100644
--- a/wtpsplit/utils.py
+++ b/wtpsplit/utils.py
@@ -116,7 +116,7 @@ def lang_code_to_lang(lang_code):
     except KeyError:
         return languages.get(part3=lang_code).name
 
-
+# does the steps in Figure 2 of the paper
 def corrupt(
     input_ids,
     block_ids,
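
Note: the trainer.py hunk above only shows the first branch of lr_lambda (returning 0.0 while current_step is below adapter_warmup_steps). As a minimal sketch of how such a schedule could be wired up with torch's LambdaLR, not the repository's actual implementation: the helper name build_adapter_schedule and the linear warmup/decay after the adapter phase are assumptions.

# Sketch only (assumed names, not the repository's code): hold the LR at zero
# for `adapter_warmup_steps`, then do a standard linear warmup and linear decay.
import torch
from torch.optim.lr_scheduler import LambdaLR


def build_adapter_schedule(optimizer, adapter_warmup_steps, warmup_steps, max_steps):
    def lr_lambda(current_step: int):
        if current_step < adapter_warmup_steps:
            return 0.0  # adapter phase: keep this parameter group effectively frozen
        step = current_step - adapter_warmup_steps  # shift so warmup starts afterwards
        if step < warmup_steps:
            return step / max(1, warmup_steps)  # linear warmup
        return max(0.0, (max_steps - step) / max(1, max_steps - warmup_steps))  # linear decay

    return LambdaLR(optimizer, lr_lambda)


# Example usage with a dummy model:
model = torch.nn.Linear(8, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-4)
scheduler = build_adapter_schedule(optimizer, adapter_warmup_steps=100, warmup_steps=500, max_steps=10_000)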