Skip to content

Commit

Permalink
fix padding in extract.py, fix intrinsic.py old usage of extract()
Browse files Browse the repository at this point in the history
  • Loading branch information
bminixhofer committed Dec 2, 2023
1 parent 48ab11a commit 9009374
Show file tree
Hide file tree
Showing 2 changed files with 7 additions and 6 deletions.
9 changes: 5 additions & 4 deletions wtpsplit/evaluation/intrinsic.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,9 @@
from tqdm.auto import tqdm
from transformers import AutoModelForTokenClassification, HfArgumentParser

import wtpsplit.models

Check failure on line 13 in wtpsplit/evaluation/intrinsic.py

View workflow job for this annotation

GitHub Actions / build (3.8)

Ruff (F401)

wtpsplit/evaluation/intrinsic.py:13:8: F401 `wtpsplit.models` imported but unused
from wtpsplit.evaluation import evaluate_mixture, get_labels, train_mixture
from wtpsplit.extract import extract
from wtpsplit.extract import PyTorchWrapper, extract
from wtpsplit.utils import Constants


Expand Down Expand Up @@ -91,7 +92,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
block_size=args.block_size,
batch_size=args.batch_size,
pad_last_batch=True,
)[0].numpy()
)[0]
test_labels = get_labels(lang_code, test_sentences, after_space=False)

dset_group.create_dataset("test_logits", data=test_logits)
Expand All @@ -110,7 +111,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
block_size=args.block_size,
batch_size=args.batch_size,
pad_last_batch=False,
)[0].numpy()
)[0]
train_labels = get_labels(lang_code, train_sentences, after_space=False)

dset_group.create_dataset("train_logits", data=train_logits)
Expand All @@ -128,7 +129,7 @@ def load_or_compute_logits(args, model, eval_data, valid_data=None, max_n_train_
else:
valid_data = None

model = AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device)
model = PyTorchWrapper(AutoModelForTokenClassification.from_pretrained(args.model_path).to(args.device))

# first, logits for everything.
f = load_or_compute_logits(args, model, eval_data, valid_data)
Expand Down
4 changes: 2 additions & 2 deletions wtpsplit/extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,8 +149,8 @@ def extract(
if len(batch_input_hashes) < batch_size and pad_last_batch:
n_missing = batch_size - len(batch_input_hashes)

batch_input_hashes = np.pad(batch_input_hashes, (0, n_missing, 0, 0, 0, 0))
batch_attention_mask = np.pad(batch_attention_mask, (0, n_missing, 0, 0))
batch_input_hashes = np.pad(batch_input_hashes, ((0, n_missing), (0, 0), (0, 0)))
batch_attention_mask = np.pad(batch_attention_mask, ((0, n_missing), (0, 0)))

kwargs = {"language_ids": language_ids[: len(batch_input_hashes)]} if uses_lang_adapters else {}

Expand Down

0 comments on commit 9009374

Please sign in to comment.