forked from Anirudh-Muthukumar/Causal-Mediation-Analysis
-
Notifications
You must be signed in to change notification settings - Fork 0
/
winogender.py
118 lines (107 loc) · 4.91 KB
/
winogender.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
import csv
import inspect
import os
from experiment import Intervention
def _read_occupation_stats(stats_path):
    """Read the occupation statistics TSV into two lookup dicts.

    Args:
        stats_path: Path to ``winogender_occupation_stats.tsv``. After one
            header row, the first three columns of each row are read as
            occupation, Bergsma percent female, BLS percent female.

    Returns:
        A tuple ``(bergsma_pct_female, bls_pct_female)`` of dicts mapping the
        occupation string to a float.
    """
    bergsma_pct_female = {}
    bls_pct_female = {}
    # newline='' is the file mode the csv module documents for its readers.
    with open(stats_path, newline='', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:  # tolerate blank lines instead of failing on row[0]
                continue
            occupation = row[0]
            bergsma_pct_female[occupation] = float(row[1])
            bls_pct_female[occupation] = float(row[2])
    return bergsma_pct_female, bls_pct_female


def _example_from_row_pair(first_row, second_row, bergsma_pct_female, bls_pct_female):
    """Validate one pair of template rows and build a WinogenderExample from it.

    Both rows must parse (via ``_parse_row``) to the same base string and
    pronoun pair, with distinct non-empty continuations. The continuation of
    the row whose answer is 0 becomes ``continuation_occupation``; the other
    becomes ``continuation_participant``.

    Raises:
        ValueError: if the pair is inconsistent.
        KeyError: if the occupation is missing from the statistics dicts.
    """
    base_string1, female_pronoun1, male_pronoun1, candidate1, occupation1, answer1 = _parse_row(first_row)
    base_string2, female_pronoun2, male_pronoun2, candidate2, _occupation2, _answer2 = _parse_row(second_row)

    # Explicit checks rather than `assert`, which is stripped under `python -O`.
    if base_string1 != base_string2:
        raise ValueError(f'Row pair has different base strings: {base_string1!r} vs {base_string2!r}')
    if not base_string1:
        raise ValueError('Row pair has an empty base string')
    if '$' in base_string1:
        raise ValueError(f'Unsubstituted placeholder in base string: {base_string1!r}')
    if female_pronoun1 != female_pronoun2 or male_pronoun1 != male_pronoun2:
        raise ValueError(f'Row pair uses different pronouns: {base_string1!r}')
    if not candidate1 or not candidate2 or candidate1 == candidate2:
        raise ValueError(f'Row pair needs two distinct non-empty continuations: {candidate1!r}, {candidate2!r}')

    if answer1 == 0:
        continuation_occupation, continuation_participant = candidate1, candidate2
    else:
        continuation_occupation, continuation_participant = candidate2, candidate1

    return WinogenderExample(base_string1,
                             female_pronoun1,
                             male_pronoun1,
                             continuation_occupation,
                             continuation_participant,
                             occupation1,
                             bergsma_pct_female[occupation1],
                             bls_pct_female[occupation1])


def load_examples(path='winogender_data/'):
    """Load the Winogender templates as a list of WinogenderExample objects.

    Reads two tab-separated files from ``path``:

    * ``winogender_occupation_stats.tsv``: per-occupation Bergsma and BLS
      percent-female statistics.
    * ``winogender_templates_structurefilter.tsv``: template rows, consumed
      as consecutive pairs that share the same sentence prefix and pronouns
      and differ only in the text after the pronoun.

    Args:
        path: Directory containing both TSV files.

    Returns:
        One WinogenderExample per row pair, in file order.

    Raises:
        ValueError: if a row pair is inconsistent, a row cannot be parsed, or
            the template file has an odd number of (non-blank) data rows.
        KeyError: if a template's occupation has no statistics row.
    """
    bergsma_pct_female, bls_pct_female = _read_occupation_stats(
        os.path.join(path, 'winogender_occupation_stats.tsv'))

    examples = []
    pending_row = None  # first row of the pair currently being assembled
    with open(os.path.join(path, 'winogender_templates_structurefilter.tsv'),
              newline='', encoding='utf-8') as f:
        reader = csv.reader(f, delimiter='\t')
        next(reader, None)  # skip the header row
        for row in reader:
            if not row:  # blank line; not part of any pair
                continue
            if pending_row is None:
                pending_row = row
                continue
            examples.append(_example_from_row_pair(pending_row, row, bergsma_pct_female, bls_pct_female))
            pending_row = None

    # Previously a dangling row was silently dropped; it indicates broken pairing.
    if pending_row is not None:
        raise ValueError(f'Template file has an unpaired trailing row: {pending_row!r}')
    return examples
def _parse_row(row):
    """Split one template row into a prompt prefix, pronoun pair and continuation.

    Args:
        row: A 4-column template row ``(occupation, participant, answer,
            sentence)``. ``sentence`` may contain the placeholders
            ``$OCCUPATION`` and ``$PARTICIPANT``. It must contain
            ``$NOM_PRONOUN`` or ``$POSS_PRONOUN``; they are checked in that
            order and the first one found is used.

    Returns:
        ``(base_string, female_pronoun, male_pronoun, candidate, occupation,
        answer)``, where:

        * ``base_string`` is the text before the pronoun, with placeholders
          filled in and ``'{}'`` appended as the pronoun slot.
        * ``candidate`` is the stripped text after the pronoun, with
          placeholders filled in.
        * ``answer`` is ``int(answer)``.

    Raises:
        ValueError: if the row does not have exactly four columns, the
            sentence has no pronoun placeholder or repeats the chosen one,
            or ``answer`` is not an integer string.
    """
    occupation, participant, answer, sentence = row

    def fill_names(text):
        # Same substitution order as before: occupation first, then participant.
        return text.replace('$OCCUPATION', occupation).replace('$PARTICIPANT', participant)

    pronoun_to_substitutes = (
        ('$NOM_PRONOUN', ('she', 'he')),
        ('$POSS_PRONOUN', ('her', 'his')),
    )
    for pronoun_type, (female_pronoun, male_pronoun) in pronoun_to_substitutes:
        if pronoun_type not in sentence:
            continue
        # A second occurrence used to surface as an opaque unpacking error.
        if sentence.count(pronoun_type) > 1:
            raise ValueError(f'Sentence contains {pronoun_type} more than once: {sentence!r}')
        context, _, candidate = sentence.partition(pronoun_type)
        base_string = fill_names(context) + '{}'
        # The continuation must be substituted too, otherwise a literal
        # '$OCCUPATION' / '$PARTICIPANT' would leak into the scored text.
        return base_string, female_pronoun, male_pronoun, fill_names(candidate).strip(), occupation, int(answer)
    raise ValueError('Sentence does not contain pronoun type')
class WinogenderExample:
    """A Winogender sentence prefix with its pronoun pair and two continuations.

    ``load_examples`` builds ``base_string`` so that it ends in a ``'{}'``
    slot where the pronoun is inserted.
    """

    def __init__(self, base_string, female_pronoun, male_pronoun, continuation_occupation, continuation_participant, occupation,
                 bergsma_pct_female, bls_pct_female):
        # Prompt text and the two pronouns that can fill its slot.
        self.base_string = base_string
        self.female_pronoun = female_pronoun
        self.male_pronoun = male_pronoun
        # Candidate continuations following the pronoun.
        self.continuation_occupation = continuation_occupation
        self.continuation_participant = continuation_participant
        # The occupation and its two percent-female statistics.
        self.occupation = occupation
        self.bergsma_pct_female = bergsma_pct_female
        self.bls_pct_female = bls_pct_female

    def to_intervention(self, tokenizer, stat):
        """Wrap this example in an ``experiment.Intervention``.

        Args:
            tokenizer: Passed through unchanged to ``Intervention``.
            stat: ``'bergsma'`` or ``'bls'``, selecting which percent-female
                statistic is used.

        The continuation order depends on that statistic. Above 50, the
        occupation continuation is the female candidate and the participant
        continuation the male one; otherwise the roles are swapped.

        Raises:
            ValueError: if ``stat`` is not one of the two known names.
        """
        if stat not in ('bergsma', 'bls'):
            raise ValueError('Invalid: ' + stat)
        pct_female = self.bergsma_pct_female if stat == 'bergsma' else self.bls_pct_female

        occupation_leans_female = pct_female > 50
        if occupation_leans_female:
            female_continuation, male_continuation = self.continuation_occupation, self.continuation_participant
        else:
            female_continuation, male_continuation = self.continuation_participant, self.continuation_occupation

        return Intervention(
            tokenizer=tokenizer,
            base_string=self.base_string,
            substitutes=[self.female_pronoun, self.male_pronoun],
            candidates=[female_continuation, male_continuation]
        )

    def __str__(self):
        """Render every field as a ``name: value`` line."""
        labelled_fields = (
            ('base_string', self.base_string),
            ('female_pronoun', self.female_pronoun),
            ('male_pronoun', self.male_pronoun),
            ('continuation_occupation', self.continuation_occupation),
            ('continuation_participant', self.continuation_participant),
            ('occupation', self.occupation),
            ('bergsma_pct_female', self.bergsma_pct_female),
            ('bls_pct_female', self.bls_pct_female),
        )
        rendered = '\n'.join(f'{label}: {value}' for label, value in labelled_fields)
        # cleandoc keeps the original's tab expansion / blank-edge trimming.
        return inspect.cleandoc(rendered)
if __name__ == "__main__":
    # Dump every loaded example, each preceded by an empty line.
    for example in load_examples():
        print(f"\n{example}")