Skip to content

Commit

Permalink
Small refactoring
Browse files Browse the repository at this point in the history
  • Loading branch information
Vitaly Protasov committed Mar 23, 2024
1 parent aa7af38 commit 707c852
Show file tree
Hide file tree
Showing 3 changed files with 80 additions and 100 deletions.
9 changes: 4 additions & 5 deletions probing/data_former.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,9 +22,10 @@ def __init__(
):
self.probe_task = probe_task
self.shuffle = shuffle
self.sep = sep
self.data_path = get_probe_task_path(probe_task, data_path)

self.samples, self.unique_labels = self.form_data(sep=sep)
self.samples, self.unique_labels = self.form_data()

def __len__(self):
return len(self.samples)
Expand All @@ -42,12 +43,10 @@ def ratio_by_classes(self) -> Dict[str, Dict[str, int]]:
return ratio_by_classes

@typing.no_type_check
def form_data(
self, sep: str = "\t"
) -> Tuple[DefaultDict[str, np.ndarray], Set[str]]:
def form_data(self) -> Tuple[DefaultDict[str, np.ndarray], Set[str]]:
samples_dict = defaultdict(list)
unique_labels = set()
dataset = pd.read_csv(self.data_path, sep=sep, header=None, dtype=str)
dataset = pd.read_csv(self.data_path, sep=self.sep, header=None, dtype=str)
for _, (stage, label, text) in dataset.iterrows():
samples_dict[stage].append((text, label))
unique_labels.add(label)
Expand Down
124 changes: 76 additions & 48 deletions probing/pipeline.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import os
from collections import Counter
from time import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Literal, Optional, Tuple, Union

import numpy as np
import torch
Expand Down Expand Up @@ -89,12 +88,21 @@ def train(self, train_loader: DataLoader, layer: int) -> float:
epoch_train_losses = []
self.classifier.train()
for i, batch in enumerate(train_loader):
# x is already on device since it was passed through the model
y = batch[1].to(self.transformer_model.device, non_blocking=True)

x = batch[0].permute(1, 0, 2)
x = torch.squeeze(x[layer], 0).float()
x = torch.unsqueeze(x, 0) if len(x.size()) == 1 else x
if len(batch[0].size()) == 3:
# x is already on device since it was passed through the model
x = batch[0].permute(1, 0, 2)
x = torch.squeeze(x[layer], 0).float()
x = torch.unsqueeze(x, 0) if len(x.size()) == 1 else x
elif len(x.size()) == 2:
logger.warning(
"Note that you provide 2-D tensor, which means that you consider not ",
"layerwise probing task, but rather an output from some specific layer.",
)
x = x.to(self.transformer_model.device, non_blocking=True)
else:
raise NotImplementedError()

self.classifier.zero_grad(set_to_none=True)
prediction = self.classifier(x)
Expand All @@ -120,12 +128,21 @@ def evaluate(
self.classifier.eval()
with torch.no_grad():
for x, y in dataloader:
# x is already on device since it was passed through the model
y = y.to(self.transformer_model.device, non_blocking=True)

x = x.permute(1, 0, 2)
x = torch.squeeze(x[layer], 0).float()
x = torch.unsqueeze(x, 0) if len(x.size()) == 1 else x
if len(x.size()) == 3:
# x is already on device since it was passed through the model
x = x.permute(1, 0, 2)
x = torch.squeeze(x[layer], 0).float()
x = torch.unsqueeze(x, 0) if len(x.size()) == 1 else x
elif len(x.size()) == 2:
logger.warning(
"Note that you provide 2-D tensor, which means that you consider not ",
"layerwise probing task, but rather an output from some specific layer.",
)
x = x.to(self.transformer_model.device, non_blocking=True)
else:
raise NotImplementedError()

prediction = self.classifier(x)
loss = self.criterion(prediction, y)
Expand All @@ -145,50 +162,48 @@ def run(
self,
probe_task: Union[UDProbingTaskName, str],
path_to_task_file: Optional[os.PathLike] = None,
probing_dataloaders: Dict[Literal["tr", "va", "te"], DataLoader] = None,
mapped_labels: Dict[str, int] = None,
train_epochs: int = 10,
is_scheduler: bool = False,
do_control_task: bool = False,
save_checkpoints: bool = False,
verbose: bool = True,
do_control_task: bool = False,
sep: str = "\t",
) -> None:
task_data = TextFormer(probe_task, path_to_task_file, sep)
task_dataset, num_classes = task_data.samples, len(task_data.unique_labels)
task_language, task_category = lang_category_extraction(task_data.data_path)

self.log_info["params"]["probing_task"] = probe_task
self.log_info["params"]["file_path"] = task_data.data_path
self.log_info["params"]["task_language"] = task_language
self.log_info["params"]["task_category"] = task_category
self.log_info["params"]["probing_type"] = self.probing_type
self.log_info["params"]["encoding_batch_size"] = self.encoding_batch_size
self.log_info["params"]["classifier_batch_size"] = self.classifier_batch_size
self.log_info["params"][
"hf_model_name"
] = self.transformer_model.config._name_or_path
self.log_info["params"]["classifier_name"] = self.classifier_name
self.log_info["params"]["metric_names"] = self.metric_names
self.log_info["params"]["original_classes_ratio"] = task_data.ratio_by_classes
if path_to_task_file:
task_data = TextFormer(probe_task, path_to_task_file)
task_dataset, num_classes = task_data.samples, len(task_data.unique_labels)
probing_task_language, probing_task_category = lang_category_extraction(
task_data.data_path
)
probing_file_path = task_data.data_path
original_classes_ratio = task_data.ratio_by_classes
elif probing_dataloaders:
unique_classes = np.unique(
[item for sublist in probing_dataloaders["tr"] for item in sublist[1]]
)
num_classes = len(unique_classes)
else:
raise NotImplementedError()

if verbose:
print(
f"Task in progress: {probe_task}\nPath to data: {task_data.data_path}"
)
print(f"Task in progress: {probe_task}\nPath to data: {probing_file_path}")

clear_memory()
start_time = time()
(
probing_dataloaders,
mapped_labels,
) = self.transformer_model.get_encoded_dataloaders(
task_dataset,
self.encoding_batch_size,
self.classifier_batch_size,
self.shuffle,
self.aggregation_embeddings,
verbose,
do_control_task=do_control_task,
)
if path_to_task_file:
(
probing_dataloaders,
mapped_labels,
) = self.transformer_model.get_encoded_dataloaders(
task_dataset,
self.encoding_batch_size,
self.classifier_batch_size,
self.shuffle,
self.aggregation_embeddings,
verbose,
do_control_task=do_control_task,
)

probing_iter_range = (
trange(
Expand All @@ -198,8 +213,6 @@ def run(
if verbose
else range(self.transformer_model.config.num_hidden_layers)
)
self.log_info["params"]["tr_mapped_labels"] = mapped_labels
self.log_info["results"]["elapsed_time(sec)"] = 0

for layer in probing_iter_range:
self.classifier = self.get_classifier(
Expand All @@ -211,9 +224,8 @@ def run(
# getting weights for each label in order to provide it further to the loss function
# be sure that the last element in each data sample is a label!
tr_labels = torch.cat(
[element[-1] for element in list(probing_dataloaders["tr"])]
[element[-1] for element in iter(probing_dataloaders["tr"])]
).tolist()
# self.log_info["params"]["train_classes_ratio"] = Counter(tr_labels)

class_weights = compute_class_weight(
"balanced", classes=np.unique(tr_labels), y=tr_labels
Expand Down Expand Up @@ -264,7 +276,23 @@ def run(
layer, epoch_test_score[m]
)

self.log_info["params"]["mapped_labels"] = mapped_labels
self.log_info["results"]["elapsed_time(sec)"] = time() - start_time
self.log_info["params"]["probing_task"] = probe_task
self.log_info["params"]["file_path"] = probing_file_path
self.log_info["params"]["task_language"] = probing_task_language
self.log_info["params"]["task_category"] = probing_task_category
self.log_info["params"]["original_classes_ratio"] = original_classes_ratio

self.log_info["params"]["probing_type"] = self.probing_type
self.log_info["params"]["encoding_batch_size"] = self.encoding_batch_size
self.log_info["params"]["classifier_batch_size"] = self.classifier_batch_size
self.log_info["params"][
"hf_model_name"
] = self.transformer_model.config._name_or_path
self.log_info["params"]["classifier_name"] = self.classifier_name
self.log_info["params"]["metric_names"] = self.metric_names

output_path = self.log_info.save_log(probe_task)
if verbose:
print(f"Experiments were saved in the folder: {str(output_path)}")
Expand Down
47 changes: 0 additions & 47 deletions tests/filter_test/test_filtering_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -108,53 +108,6 @@ def test_upload_files_no_paths(self):
with self.assertRaises(Exception):
self.probing_filter.upload_files()

def test_filter_and_convert_all_saved(self):
    """Filter the fixture conllu files and verify that every matched
    sentence is kept in ``probing_dict`` (grouped by query name) and that
    the converted CSV contains the expected number of rows.
    """
    # Expected filtering result: query name -> list of matched sentences.
    # NOTE(review): these literals must match the fixture conllu data under
    # self.dir_conllu_path exactly — do not reflow or "fix" the spacing,
    # the tokenized punctuation (" , ", " .") is intentional.
    queries_sents = {
        "ADVCL": [
            "Она решила попытаться остановить машину — хотя выйдя под дождь , сразу же промокла насквозь .",
            "И охота завыть , вскинув морду к луне .",
            "И не предложит выпить , если ты решил жить трезвым .",
            "Смерть твоя — настолько благая весть , что посовестись — и умри !",
            "Ну , ложись им под ноги , в прах ложись , потому что уже пора !",
            "В печали ль , в радости ль , во хмелю , в потемках земельных недр , Я вас всей кровью своей люблю , сады мои — метр на метр !",
            "Как защитить их , себя казня , до жуткой храня поры ?",
            "Как сообщается в пресс - релизе университета , программу можно использовать на любом смартфоне .",
            "Вячеслав , почему бы Вам не возглавить КПРФ Пока оно ещё есть .",
            "Если ты устал , иди спать .",
            "Если ты голодный , иди есть .",
            "Когда на улице темно , надо быть осторожным .",
            "Когда друзья тебя не слышат , не надо быть настойчивым .",
            "Мы всё сдадим , потому что мы хорошие студенты .",
        ],
        "ACL": [
            "Счастье это качество , не имеющее будущего и прошлого .",
            "Но есть мужчина , которого я не хотела бы потерять ...",
            "Среди разных сыновей был один , который звал себя Сыном Божьим .",
            "Неужто вправду сгорел тот мост , которым я к ним пройду ?!",
            "Она заставляет смартфон постоянно испускать высокочастотный звук , неразличимый для человеческого уха , но улавливаемый микрофоном устройства .",
            "То , что никакого отношения к ним не имеет",
            'Депутат ЛДПР , которого не пустили в " Европейский ", объяснил причину конфликта с охранниками',
            "И пусть всё то , что кажется так сложно , решается красиво и легко !",
            "Пришел мальчик , которому мама не дает конфеты .",
            "Я увидел девочку , которая очень хочет спать .",
            "Я увидел женщину , которую показывали в новостях .",
            "Мальчик взял игрушку , с которой не расставался с самого рождения .",
            "Девочка съела кашу , которую для нее приготовил папа .",
        ],
    }

    # Convert into a throwaway directory; cleaned up explicitly at the end
    # (only on the success path — an assertion failure skips cleanup).
    task_dir = TemporaryDirectory()
    self.probing_filter.upload_files(dir_conllu_path=self.dir_conllu_path)
    self.probing_filter.filter_and_convert(
        queries=self.queries,
        save_dir_path=task_dir.name,
        task_name="cl",
    )
    # Every matched sentence must be saved, grouped exactly as above.
    self.assertEqual(queries_sents, self.probing_filter.probing_dict)
    # 27 CSV rows = 14 ADVCL + 13 ACL sentences.
    # NOTE(review): output filename appears to be <treebank>_<task_name>.csv
    # ("ru_taiga" presumably comes from the fixture conllu names) — confirm.
    with open(f"{task_dir.name}/ru_taiga_cl.csv") as f:
        self.assertEqual(27, len(f.readlines()))
    task_dir.cleanup()

def test_filter_and_convert_too_few_sentences(self):
self.probing_filter.upload_files(
dir_conllu_path=Path(
Expand Down

0 comments on commit 707c852

Please sign in to comment.