Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add perform_wsi_on_ds_gen function #4

Open
wants to merge 6 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 4 additions & 4 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
tqdm==4.26.0
pytorch_pretrained_bert==0.6.1
torch==1.0.0
numpy==1.15.1
tqdm>=4.27
transformers==4.8.1
torch==1.9.0
numpy>=1.17
spacy==2.1.3
scipy==1.1.0
scikit_learn==0.21.1
Expand Down
3 changes: 2 additions & 1 deletion wsi/WSISettings.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

WSISettings = namedtuple('WSISettings', ['n_represents', 'n_samples_per_rep', 'cuda_device', 'debug_dir',
'disable_tfidf', 'disable_lemmatization', 'run_name', 'patterns',
'min_sense_instances', 'bert_model',
'min_sense_instances', 'bert_model', 'spacy_lang',
'max_batch_size', 'prediction_cutoff', 'max_number_senses',
])

Expand All @@ -26,6 +26,7 @@
# sense clusters that dominate less than this number of samples
# would be remapped to their closest big sense

spacy_lang="en",
max_batch_size=10,
prediction_cutoff=200,
bert_model='bert-large-uncased'
Expand Down
10 changes: 5 additions & 5 deletions wsi/lm_bert.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from .slm_interface import SLM
import multiprocessing
from pytorch_pretrained_bert import BertForMaskedLM, tokenization
from transformers import BertForMaskedLM, BertTokenizer
import torch
import numpy as np
from tqdm import tqdm
Expand All @@ -26,7 +26,7 @@ def get_batches(from_iter, group_size):

class LMBert(SLM):

def __init__(self, cuda_device, bert_model, max_batch_size=20):
def __init__(self, cuda_device, bert_model, spacy_lang="en", max_batch_size=20):
super().__init__()
logging.info(
'creating bert in device %d. bert ath %s'
Expand All @@ -43,7 +43,7 @@ def __init__(self, cuda_device, bert_model, max_batch_size=20):
model.eval()
self.bert = model

self.tokenizer = tokenization.BertTokenizer.from_pretrained(bert_model)
self.tokenizer = BertTokenizer.from_pretrained(bert_model)

self.max_sent_len = model.config.max_position_embeddings
# self.max_sent_len = config.max_position_embeddings
Expand All @@ -54,7 +54,7 @@ def __init__(self, cuda_device, bert_model, max_batch_size=20):
self.original_vocab = []

import spacy
nlp = spacy.load("en", disable=['ner', 'parser'])
nlp = spacy.load(spacy_lang, disable=['ner', 'parser'])
self._lemmas_cache = {}
self._spacy = nlp
for spacyed in tqdm(
Expand Down Expand Up @@ -141,7 +141,7 @@ def predict_sent_substitute_representatives(self, inst_id_to_sentence: Dict[str,

torch_mask = torch_input_ids != 0

logits_all_tokens = self.bert(torch_input_ids, attention_mask=torch_mask)
logits_all_tokens = self.bert(torch_input_ids, attention_mask=torch_mask).logits

logits_target_tokens = torch.zeros((len(batch_sents), logits_all_tokens.shape[2])).to(self.device)
for i in range(0, len(batch_sents)):
Expand Down
55 changes: 55 additions & 0 deletions wsi/wsi.py
Original file line number Diff line number Diff line change
Expand Up @@ -124,3 +124,58 @@ def run(self, wsisettings: WSISettings,
print(msg)

return scores2010['all'], scores2013['all']


from typing import Dict, List, Tuple


def perform_wsi_on_ds_gen(
lm: SLM,
ds_name: str,
gen: List[Tuple[str, str, str, str]],
wsisettings: WSISettings,
print_progress=False,
) -> Dict[str, Dict[str, int]]:

ds_by_target = defaultdict(dict)
for pre, target, post, inst_id in gen:
lemma_pos = inst_id.rsplit('.', 1)[0]
ds_by_target[lemma_pos][inst_id] = (pre, target, post)

inst_id_to_sense = {}
gen = ds_by_target.items()
if print_progress:
gen = tqdm(gen, desc=f'predicting substitutes {ds_name}')
for lemma_pos, inst_id_to_sentence in gen:
inst_ids_to_representatives = \
lm.predict_sent_substitute_representatives(
inst_id_to_sentence=inst_id_to_sentence,
wsisettings=wsisettings,
)

clusters, statistics = cluster_inst_ids_representatives(
inst_ids_to_representatives=inst_ids_to_representatives,
max_number_senses=wsisettings.max_number_senses,
min_sense_instances=wsisettings.min_sense_instances,
disable_tfidf=wsisettings.disable_tfidf,
explain_features=True,
)
inst_id_to_sense.update(clusters)
if statistics:
logging.info('Sense cluster statistics:')
for idx, (rep_count, best_features, best_features_pmi, best_instance_id) in enumerate(statistics):
best_instance = ds_by_target[lemma_pos][best_instance_id]
nice_print_instance = f'{best_instance[0]} -{best_instance[1]}- {best_instance[2]}'
logging.info(
f'Sense {idx}, # reps: {rep_count}, best feature words: {", ".join(best_features)}.'
f', best feature words(PMI): {", ".join(best_features_pmi)}.'
f' closest instance({best_instance_id}):\n---\n{nice_print_instance}\n---\n')

out_key_path = None
if wsisettings.debug_dir:
out_key_path = os.path.join(wsisettings.debug_dir, f'{wsisettings.run_name}-{ds_name}.key')

if print_progress:
print(f'writing {ds_name} key file to %s' % out_key_path)

return inst_id_to_sense
1 change: 1 addition & 0 deletions wsi_bert.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
startmsg = startmsg.strip()

lm = LMBert(settings.cuda_device, settings.bert_model,
spacy_lang=settings.spacy_lang,
max_batch_size=settings.max_batch_size)

if settings.debug_dir:
Expand Down