-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathsim_expr.py
32 lines (26 loc) · 1.19 KB
/
sim_expr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
import dar_type
def not_neutral(categories: dar_type.MNLICategories) -> float:
    """Return one minus the score of the "neutral" category.

    Labels are compared case-insensitively and the first matching entry
    wins; later entries are not inspected.

    Args:
        categories: iterable of mappings, each with a "label" (str) and a
            "score" key. Presumably classifier output for MNLI labels as
            declared in ``dar_type.MNLICategories`` -- confirm there.

    Returns:
        ``1 - score`` of the first entry whose label is "neutral".

    Raises:
        ValueError: if no entry is labelled "neutral". (ValueError subclasses
            Exception, so callers catching Exception are unaffected.)
    """
    for category in categories:
        if category["label"].lower() == "neutral":
            return 1 - category["score"]
    raise ValueError("no neutral category")
def entail_only(categories: dar_type.MNLICategories) -> float:
    """Return the score of the "entailment" category.

    Labels are compared case-insensitively and the first matching entry
    wins; later entries are not inspected.

    Args:
        categories: iterable of mappings, each with a "label" (str) and a
            "score" key. Presumably classifier output for MNLI labels as
            declared in ``dar_type.MNLICategories`` -- confirm there.

    Returns:
        The "score" of the first entry whose label is "entailment".

    Raises:
        ValueError: if no entry is labelled "entailment". (ValueError
            subclasses Exception, so callers catching Exception are
            unaffected.)
    """
    for category in categories:
        if category["label"].lower() == "entailment":
            return category["score"]
    raise ValueError("no entailment category")
def entail_contradict(categories: dar_type.MNLICategories) -> float:
    """Return the entailment score minus the contradiction score.

    Labels are compared case-insensitively. Entries whose label is neither
    "entailment" nor "contradiction" are ignored. Unlike the single-label
    helpers, every entry is scanned so that duplicate labels are detected.

    Args:
        categories: iterable of mappings, each with a "label" (str) and a
            "score" key. Presumably classifier output for MNLI labels as
            declared in ``dar_type.MNLICategories`` -- confirm there.

    Returns:
        ``entailment_score - contradiction_score``.

    Raises:
        ValueError: if either label occurs more than once, or if either
            label is missing. (ValueError subclasses Exception, so callers
            catching Exception are unaffected.)
    """
    entail_score = None
    contradict_score = None
    for category in categories:
        label = category["label"].lower()
        if label == "entailment":
            if entail_score is not None:
                raise ValueError("multiple entailment scores")
            entail_score = category["score"]
        elif label == "contradiction":
            if contradict_score is not None:
                raise ValueError("multiple contradiction scores")
            contradict_score = category["score"]
    # Name the specific missing label rather than an ambiguous "or".
    if entail_score is None:
        raise ValueError("no entailment category")
    if contradict_score is None:
        raise ValueError("no contradiction category")
    return entail_score - contradict_score