-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
bs_sent np optimization and more confs
- Loading branch information
Showing
6 changed files
with
73 additions
and
66 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,30 +1,32 @@ | ||
import sys | ||
from os import path | ||
|
||
file_path = path.abspath(__file__) | ||
sys.path.append(path.dirname(path.dirname(file_path))) | ||
|
||
import typing | ||
from bertscore_sentence import eval | ||
import numpy as np | ||
from mnli.sim import similarity | ||
from dar_env import nlp | ||
from dar_env import nlp_spacy | ||
import functools | ||
import transformers | ||
|
||
|
||
def mnli_sim_mat(
    cand: str, ref: str, classifier: "transformers.Pipeline"
) -> typing.Tuple[np.ndarray, typing.List[str], typing.List[str]]:
    """Build the sentence-level MNLI similarity matrix between two texts.

    Both texts are split into sentences with ``nlp_spacy``. Every
    (reference sentence, candidate sentence) combination is joined with a
    single space and scored by ``similarity`` in one batched call.

    Args:
        cand: Candidate text.
        ref: Reference text.
        classifier: Classifier forwarded unchanged to ``similarity``.

    Returns:
        A tuple ``(sim_mat, cand_sents, ref_sents)``. ``sim_mat`` has shape
        ``(len(ref_sents), len(cand_sents))``; entry ``[i][j]`` is the score
        of reference sentence ``i`` paired with candidate sentence ``j``.

    Raises:
        ValueError: If ``similarity`` does not return exactly one score per
            sentence pair.
    """

    def segmentation(piece: str) -> typing.List[str]:
        # Sentence boundaries come from the spaCy pipeline's ``doc.sents``.
        doc = nlp_spacy(piece)
        return [sent.text for sent in doc.sents]

    cand_sents = segmentation(cand)
    ref_sents = segmentation(ref)
    sim_mat = np.empty((len(ref_sents), len(cand_sents)))
    if sim_mat.size == 0:
        # No sentence pairs to score: skip the classifier call entirely.
        return sim_mat, cand_sents, ref_sents

    # Reference-major order matches the row-major layout that ``flat`` fills.
    sent_pairs = [" ".join([x, y]) for x in ref_sents for y in cand_sents]
    scores = similarity(sent_pairs, classifier)
    if len(scores) != sim_mat.size:
        # Assigning through ``flat`` does not check the length of the source,
        # so a mismatch must be detected here rather than corrupting the matrix.
        raise ValueError(
            "Expected {} similarity scores, got {}".format(sim_mat.size, len(scores))
        )
    sim_mat.flat = scores
    return sim_mat, cand_sents, ref_sents
|
||
|
||
def bertscore_sentence_compute(
    predictions: typing.List[str],
    references: typing.List[str],
    classifier: "transformers.Pipeline",
) -> typing.Dict:
    """Run the sentence-level BERTScore evaluation with an MNLI similarity matrix.

    ``mnli_sim_mat`` is bound to ``classifier`` and handed to ``eval.compute``
    as the ``sim_mat_f`` callback. Note that ``eval`` here is the module
    imported from ``bertscore_sentence``, not the builtin.

    Args:
        predictions: Candidate texts.
        references: Reference texts.
        classifier: Classifier used by ``mnli_sim_mat`` to score sentence pairs.

    Returns:
        Whatever ``eval.compute`` returns for the given inputs.
    """
    sim_mat_f = functools.partial(mnli_sim_mat, classifier=classifier)
    # Ordinary instances have no ``__name__``; fall back to the class name so
    # labelling the callback never raises AttributeError.
    classifier_name = getattr(classifier, "__name__", type(classifier).__name__)
    sim_mat_f.__name__ = " ".join(["mnli", classifier_name])
    return eval.compute(predictions=predictions, references=references, sim_mat_f=sim_mat_f)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,21 +1,24 @@ | ||
import sys | ||
from os import path | ||
|
||
file_path = path.abspath(__file__) | ||
sys.path.append(path.dirname(path.dirname(file_path))) | ||
|
||
from dar_env import mnli_classifier | ||
import typing | ||
import transformers | ||
|
||
|
||
def similarity(sent_pairs: typing.List[str], classifier: "transformers.Pipeline") -> typing.List[float]:
    """Score sentence pairs as ``1 - P(NEUTRAL)`` using an MNLI classifier.

    Args:
        sent_pairs: Sentence pairs, each already joined into a single string.
        classifier: Callable invoked once on the whole list. For every input
            it is expected to yield an iterable of dicts that carry a
            ``"label"`` and a ``"score"`` entry.

    Returns:
        One score per classifier result, in the order the results arrive.

    Raises:
        ValueError: If a result contains no ``"NEUTRAL"`` label. Skipping it
            silently would shift every later score onto the wrong pair.
    """
    scores = []
    for index, categories in enumerate(classifier(sent_pairs)):
        for category in categories:
            if category["label"] == "NEUTRAL":
                scores.append(1 - category["score"])
                break
        else:
            # The inner loop finished without finding the label.
            raise ValueError(f"Not found NEUTRAL class for sentence pair {index}")
    return scores
|
||
|
||
if __name__ == "__main__":
    # Manual smoke test: score one pair of unrelated sentences.
    sample_a = "Each computer program uses a region of memory called the stack to enable functions to work properly."
    sample_b = "From the outside, Les 4G, a Lyonnais bouchon (traditional restaurant), looked much like the nondescript cafe-cum-tobacco shops that can be found in most small French towns, but inside the decor was as warm and inviting as a country pub."
    # ``similarity`` requires the classifier explicitly; use the one imported
    # from ``dar_env`` at the top of this file.
    print(similarity([" ".join([sample_a, sample_b])], mnli_classifier))
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters