-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata_utils_v2.py
74 lines (58 loc) · 2.14 KB
/
data_utils_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
from torch.utils import data
import pandas as pd
import os, argparse, random
import torch
from torch.utils.data import Dataset, DataLoader
# Languages the OLID dataset accepts; each one names a sub-directory of the
# data root that holds that language's TSV files.
langs = [
    'english',
    'danish',
    'turkish',
    'arabic',
    'greek',
]

# Default data root. NOTE(review): looks like a Colab Google Drive mount
# path -- confirm before running anywhere else.
dataroot = '/content/drive/MyDrive/data/'
class OLID(Dataset):
    """Offensive-language tweet dataset backed by per-language TSV files.

    Each language is read from ``<dataroot>/<lang>/<mode>2.tsv``. Only the
    ``tweet`` and ``subtask_a`` columns are kept, and rows with a missing
    value in either column are dropped. The rows of every requested language
    are pooled into ``self.sents`` and shuffled once at construction time.

    Attributes set by the constructor:
        dataroot, mode, lang: the constructor arguments, stored as given.
        sents: list of ``[tweet, subtask_a]`` pairs, shuffled.
        df: the filtered frame of the most recently loaded language (with
            ``lang='all'`` this is only the last language, not the pool).
    """

    def __init__(self, dataroot, mode='train', lang='english'):
        """Load and shuffle the samples for one language or for all of them.

        Args:
            dataroot: Directory containing one sub-directory per language.
            mode: Split to load, either ``'train'`` or ``'test'``.
            lang: A name from the module-level ``langs`` list, or ``'all'``
                to pool every language in that list.

        Raises:
            ValueError: If ``lang`` or ``mode`` is not a supported value.
        """
        super().__init__()

        # Raise instead of print + exit(): exit() is an interactive-shell
        # helper that ends the process with a success status, which hides
        # the bad argument from whoever launched the run.
        if lang != 'all' and lang not in langs:
            raise ValueError(
                f'language {lang} not supported. Supported all, {",".join(langs)}'
            )
        if mode not in ['train', 'test']:
            raise ValueError(f'mode {mode} not supported. Supported train, test')

        self.dataroot = dataroot
        self.mode = mode
        self.lang = lang
        self.sents = []

        selected = langs if lang == 'all' else [lang]
        for name in selected:
            self.df = self._load_frame(name)
            self.sents.extend(self.df.values.tolist())

        # Uses the global `random` state, so the order is reproducible only
        # if the caller seeds `random` beforehand.
        random.shuffle(self.sents)

    def _load_frame(self, lang):
        """Read one language's TSV and keep the complete tweet/label rows."""
        tsv_anno = os.path.join(self.dataroot, lang, f'{self.mode}2.tsv')
        frame = pd.read_csv(tsv_anno, sep='\t')
        return frame[['tweet', 'subtask_a']].dropna()

    def __len__(self):
        """Return the number of pooled samples."""
        return len(self.sents)

    def __getitem__(self, idx):
        """Return sample ``idx`` as ``{'input': tweet, 'label': 0 or 1}``.

        The label is 1 when ``subtask_a`` equals ``'OFF'`` and 0 for any
        other value.
        """
        text, tag = self.sents[idx][0], self.sents[idx][1]
        return {'input': text, 'label': 1 if tag == 'OFF' else 0}