choiloader_sentences.py (forked from Ighina/DeepTiling)
# -*- coding: utf-8 -*-
"""
Created on Sun Jan 31 12:38:51 2021
@author: Iacopo
"""
from __future__ import print_function

import math
import random

import numpy as np
import torch
from torch.utils.data import Dataset
from pathlib2 import Path

import utils
from text_manipulation import split_sentences, extract_sentence_words

logger = utils.setup_logger(__name__, 'train.log')

def get_choi_files(path):
    """Recursively collect the paths of all .ref files under the given directory."""
    all_objects = Path(path).glob('**/*.ref')
    files = [str(p) for p in all_objects if p.is_file()]
    return files

def collate_fn(batch):
    """Collate (data, targets, path) triples into a batch, concatenating each
    sentence with its neighbours inside a sliding window of window_size sentences."""
    batched_data = []
    batched_targets = []
    paths = []

    window_size = 1
    before_sentence_count = int(math.ceil(float(window_size - 1) / 2))
    after_sentence_count = window_size - before_sentence_count - 1

    for data, targets, path in batch:
        try:
            max_index = len(data)
            tensored_data = []
            for curr_sentence_index in range(0, len(data)):
                from_index = max([0, curr_sentence_index - before_sentence_count])
                to_index = min([curr_sentence_index + after_sentence_count + 1, max_index])
                sentences_window = [word for sentence in data[from_index:to_index] for word in sentence]
                tensored_data.append(torch.FloatTensor(np.concatenate(sentences_window)))
            tensored_targets = torch.zeros(len(data)).long()
            tensored_targets[torch.LongTensor(targets)] = 1
            # The final sentence always ends a segment, so drop it from the targets.
            tensored_targets = tensored_targets[:-1]
            batched_data.append(tensored_data)
            batched_targets.append(tensored_targets)
            paths.append(path)
        except Exception as e:
            logger.info('Exception "%s" in file: "%s"', e, path)
            logger.debug('Exception!', exc_info=True)
            continue

    return batched_data, batched_targets, paths

def clean_paragraph(paragraph):
    """Strip Penn-Treebank-style quote tokens and trailing newlines from a paragraph."""
    cleaned_paragraph = paragraph.replace("'' ", " ").replace(" 's", "'s").replace("``", "").strip('\n')
    return cleaned_paragraph

def read_choi_file(path, train=False, manifesto=False):
    separator = '========' if manifesto else '=========='
    with Path(path).open('r') as f:
        raw_text = f.read()
    paragraphs = [clean_paragraph(p) for p in raw_text.strip().split(separator)
                  if len(p) > 5 and p != "\n"]
    if train:
        random.shuffle(paragraphs)

    targets = []
    new_text = []
    lastparagraphsentenceidx = 0

    for paragraph in paragraphs:
        if manifesto:
            sentences = split_sentences(paragraph, 0)
        else:
            sentences = [s for s in paragraph.split('\n') if len(s.split()) > 0]

        if sentences:
            # Count the sentences in this paragraph; the last one marks where we need to split.
            sentences_count = 0
            for sentence in sentences:
                words = extract_sentence_words(sentence)
                if len(words) == 0:
                    continue
                sentences_count += 1
                new_text.append(sentence)

            lastparagraphsentenceidx += sentences_count
            targets.append(lastparagraphsentenceidx - 1)

    return new_text, targets, path

# Dataset of Choi-style files: each item is a list of sentences plus the indices of
# segment-final sentences; downstream, each word is encoded using word2vec.
class ChoiDataset(Dataset):
    def __init__(self, root, train=False, folder=False, manifesto=False, folders_paths=None):
        self.manifesto = manifesto
        if folders_paths is not None:
            self.textfiles = []
            for f in folders_paths:
                self.textfiles.extend(list(f.glob('*.ref')))
        elif folder:
            self.textfiles = get_choi_files(root)
        else:
            self.textfiles = list(Path(root).glob('**/*.ref'))
        if len(self.textfiles) == 0:
            raise RuntimeError('Found 0 .ref files in subfolders of: {}'.format(root))
        self.train = train
        self.root = root

    def __getitem__(self, index):
        path = self.textfiles[index]
        return read_choi_file(path, self.train, manifesto=self.manifesto)

    def __len__(self):
        return len(self.textfiles)
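
# --- Usage sketch (illustrative, not part of the original module) ---
# A minimal example of reading a Choi-style corpus with ChoiDataset. The
# directory 'data/choi' is a placeholder assumption; point it at any folder
# containing .ref files. Note that collate_fn above assumes each sentence has
# already been mapped to a sequence of word vectors, so the raw sentence
# strings returned here would need to be embedded before batching.
if __name__ == '__main__':
    dataset = ChoiDataset('data/choi')  # hypothetical corpus location
    sentences, targets, path = dataset[0]
    print('{}: {} sentences, segment-final indices {}'.format(path, len(sentences), targets))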