
Commit 1a5f365

bminixhofer committed Dec 14, 2023
1 parent 8289739 commit 1a5f365
Showing 4 changed files with 10 additions and 2 deletions.

wtpsplit/extract.py (2 changes: 1 addition & 1 deletion)

@@ -58,7 +58,7 @@ def __call__(self, hashed_ids, attention_mask, language_ids=None):
 
         return {"logits": logits}
 
-
+# TODO: comment
 def extract(
     batch_of_texts,
     model,

wtpsplit/train/train.py (7 changes: 7 additions & 0 deletions)

@@ -77,6 +77,7 @@ def forward(
 
         losses = []
 
+        # main (newline prediction) objective
         if self.do_sentence_training:
             sentence_labels = (0.5 - self.loss_margin) + (labels == Constants.NEWLINE_INDEX + 1).to(
                 logits.dtype
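
The sentence_labels line above is the loss-margin trick: hard {0, 1} newline targets are squeezed into [0.5 - margin, 0.5 + margin]. A minimal runnable sketch, assuming the targets feed a binary cross-entropy loss (the shapes and margin value here are illustrative, not the repo's defaults):

import torch
import torch.nn.functional as F

loss_margin = 0.4                       # illustrative; 0.5 recovers hard 0/1 targets
logits = torch.randn(2, 8)              # (batch, seq) newline logits
is_newline = (torch.rand(2, 8) < 0.1).float()  # stand-in for labels == NEWLINE_INDEX + 1

# {0, 1} -> {0.5 - m, 0.5 + m}
sentence_labels = (0.5 - loss_margin) + is_newline * loss_margin * 2
loss = F.binary_cross_entropy_with_logits(logits, sentence_labels)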

@@ -95,6 +96,7 @@
                     / reduced_attention_mask.sum()
                 )
 
+        # auxiliary (punctuation prediction) objective
         if self.do_auxiliary_training:
             loss_fn = nn.CrossEntropyLoss()
 
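
The auxiliary objective is a plain cross-entropy over punctuation classes. A self-contained sketch with made-up shapes and class count (the real label space is defined elsewhere in the repo):

import torch
import torch.nn as nn

n_aux_classes = 31                      # hypothetical: 30 punctuation chars + "none"
aux_logits = torch.randn(2, 8, n_aux_classes)         # (batch, seq, classes)
aux_labels = torch.randint(0, n_aux_classes, (2, 8))  # (batch, seq)

loss_fn = nn.CrossEntropyLoss()
aux_loss = loss_fn(aux_logits.view(-1, n_aux_classes), aux_labels.view(-1))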

@@ -166,6 +168,7 @@ def collate_fn(batch, args, label_args, label_dict):
     all_label_weights = []
 
     for sample in batch:
+        # NOTE: this is specific to characters at the moment!
         input_ids = [ord(c) for c in sample[args.text_column]]
         lang = sample["lang"]
 

@@ -278,13 +281,15 @@ def prepare_dataset(
     if shuffle:
         dataset = dataset.shuffle(seed=42)
 
+    # very likely not relevant / used only for the compound part
     if args.ignore_non_hyphen:
         with training_args.main_process_first():
             dataset = dataset.filter(
                 lambda sample: any(c in sample[args.text_column] for c in label_args.hyphen_chars),
                 num_proc=args.preprocessing_num_workers,
             )
 
+    # "punctuation-specific sampling" in the paper
     if args.non_punctuation_sample_ratio is not None:
         languages_without_punctuation = {
             lang_code
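
A rough idea of punctuation-specific sampling: samples from languages whose scripts lack sentence-final punctuation are kept only up to a target fraction of the dataset. The helper below is hypothetical and much simplified; the commit's actual logic lives in drop_some_non_punctuation_samples (next hunk):

import random

def subsample_non_punctuation(samples, langs_without_punct, ratio=0.1):
    # hypothetical sketch: keep all samples whose language uses punctuation,
    # then add just enough of the rest to make up roughly `ratio` overall
    kept = [s for s in samples if s["lang"] not in langs_without_punct]
    rest = [s for s in samples if s["lang"] in langs_without_punct]
    n_keep = int(ratio / (1 - ratio) * len(kept)) if ratio < 1 else len(rest)
    return kept + random.sample(rest, min(n_keep, len(rest)))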
Expand Down Expand Up @@ -330,6 +335,7 @@ def drop_some_non_punctuation_samples(examples):
num_proc=num_workers,
)

# similar to group_texts in huggingface's run_clm.py / run_mlm.py: https://github.com/huggingface/transformers/blob/main/examples/pytorch/language-modeling/run_mlm.py
def group_texts(examples):
all_input_blocks = []
all_input_block_lengths = []
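
group_texts follows the same idea as Hugging Face's run_mlm.py: concatenate examples, then cut the stream into fixed-size blocks. In miniature (the per-block language and length bookkeeping that wtpsplit adds is omitted):

def group_texts_sketch(batches_of_input_ids, block_size=512):
    concatenated = [i for ids in batches_of_input_ids for i in ids]
    usable = (len(concatenated) // block_size) * block_size  # drop the ragged tail
    return [concatenated[i : i + block_size] for i in range(0, usable, block_size)]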

@@ -352,6 +358,7 @@ def maybe_pad(text):
                 if lang == current_lang
             ]
 
+            # pack_samples used for the compound part, so irrelevant
             if args.pack_samples:
                 blocks = []
                 block_ids = []

wtpsplit/train/trainer.py (1 change: 1 addition & 0 deletions)

@@ -124,6 +124,7 @@ def create_scheduler(self, num_training_steps: int, optimizer: torch.optim.Optimizer = None):
         if self.lr_scheduler is None:
             warmup_steps = self.args.get_warmup_steps(self.args.max_steps)
 
+            # MODIFIED: add lang adapter lr scheduler
             def lr_lambda(current_step: int):
                 if current_step < self.args.adapter_warmup_steps:
                     return 0.0
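
The modified schedule keeps the learning rate at zero until adapter_warmup_steps have passed. The diff cuts off after return 0.0, so the continuation below is only a guess at a typical linear warmup/decay; all step counts are invented:

import torch
from torch.optim.lr_scheduler import LambdaLR

adapter_warmup_steps, warmup_steps, max_steps = 100, 500, 10_000

def lr_lambda(current_step: int):
    if current_step < adapter_warmup_steps:
        return 0.0                      # main lr frozen while adapters warm up
    step = current_step - adapter_warmup_steps
    if step < warmup_steps:
        return step / warmup_steps      # linear warmup
    return max(0.0, (max_steps - step) / (max_steps - warmup_steps))  # linear decay

optimizer = torch.optim.AdamW([torch.nn.Parameter(torch.zeros(1))], lr=1e-4)
scheduler = LambdaLR(optimizer, lr_lambda)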

wtpsplit/utils.py (2 changes: 1 addition & 1 deletion)

@@ -116,7 +116,7 @@ def lang_code_to_lang(lang_code):
     except KeyError:
         return languages.get(part3=lang_code).name
 
-
+# does the steps in Figure 2 of the paper
 def corrupt(
     input_ids,
     block_ids,
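
The languages.get(part3=...) call in lang_code_to_lang points to the iso-639 package (an assumption based on the part3 keyword). A quick usage example:

from iso639 import languages  # pip install iso-639 (assumed dependency)

print(languages.get(part3="deu").name)  # German
print(languages.get(part3="eng").name)  # English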