forked from Ighina/DeepTiling
Commit
Showing 13 changed files with 1,490 additions and 0 deletions.
@@ -0,0 +1,127 @@
""" | ||
LexRank implementation | ||
Source: https://github.com/crabcamp/lexrank/tree/dev | ||
""" | ||
|
||
import numpy as np | ||
from scipy.sparse.csgraph import connected_components | ||
|
||
def degree_centrality_scores( | ||
similarity_matrix, | ||
threshold=None, | ||
increase_power=True, | ||
): | ||
if not ( | ||
threshold is None | ||
or isinstance(threshold, float) | ||
and 0 <= threshold < 1 | ||
): | ||
raise ValueError( | ||
'\'threshold\' should be a floating-point number ' | ||
'from the interval [0, 1) or None', | ||
) | ||
|
||
if threshold is None: | ||
markov_matrix = create_markov_matrix(similarity_matrix) | ||
|
||
else: | ||
markov_matrix = create_markov_matrix_discrete( | ||
similarity_matrix, | ||
threshold, | ||
) | ||
|
||
scores = stationary_distribution( | ||
markov_matrix, | ||
increase_power=increase_power, | ||
normalized=False, | ||
) | ||
|
||
return scores | ||
|
||
|
||
def _power_method(transition_matrix, increase_power=True): | ||
eigenvector = np.ones(len(transition_matrix)) | ||
|
||
if len(eigenvector) == 1: | ||
return eigenvector | ||
|
||
transition = transition_matrix.transpose() | ||
|
||
while True: | ||
eigenvector_next = np.dot(transition, eigenvector) | ||
|
||
if np.allclose(eigenvector_next, eigenvector): | ||
return eigenvector_next | ||
|
||
eigenvector = eigenvector_next | ||
|
||
if increase_power: | ||
transition = np.dot(transition, transition) | ||
|
||
|
||
def connected_nodes(matrix): | ||
_, labels = connected_components(matrix) | ||
|
||
groups = [] | ||
|
||
for tag in np.unique(labels): | ||
group = np.where(labels == tag)[0] | ||
groups.append(group) | ||
|
||
return groups | ||
|
||
|
||
def create_markov_matrix(weights_matrix): | ||
n_1, n_2 = weights_matrix.shape | ||
if n_1 != n_2: | ||
raise ValueError('\'weights_matrix\' should be square') | ||
|
||
row_sum = weights_matrix.sum(axis=1, keepdims=True) | ||
|
||
return weights_matrix / row_sum | ||
|
||
|
||
def create_markov_matrix_discrete(weights_matrix, threshold): | ||
discrete_weights_matrix = np.zeros(weights_matrix.shape) | ||
ixs = np.where(weights_matrix >= threshold) | ||
discrete_weights_matrix[ixs] = 1 | ||
|
||
return create_markov_matrix(discrete_weights_matrix) | ||
|
||
|
||
def graph_nodes_clusters(transition_matrix, increase_power=True): | ||
clusters = connected_nodes(transition_matrix) | ||
clusters.sort(key=len, reverse=True) | ||
|
||
centroid_scores = [] | ||
|
||
for group in clusters: | ||
t_matrix = transition_matrix[np.ix_(group, group)] | ||
eigenvector = _power_method(t_matrix, increase_power=increase_power) | ||
centroid_scores.append(eigenvector / len(group)) | ||
|
||
return clusters, centroid_scores | ||
|
||
|
||
def stationary_distribution( | ||
transition_matrix, | ||
increase_power=True, | ||
normalized=True, | ||
): | ||
n_1, n_2 = transition_matrix.shape | ||
if n_1 != n_2: | ||
raise ValueError('\'transition_matrix\' should be square') | ||
|
||
distribution = np.zeros(n_1) | ||
|
||
grouped_indices = connected_nodes(transition_matrix) | ||
|
||
for group in grouped_indices: | ||
t_matrix = transition_matrix[np.ix_(group, group)] | ||
eigenvector = _power_method(t_matrix, increase_power=increase_power) | ||
distribution[group] = eigenvector | ||
|
||
if normalized: | ||
distribution /= n_1 | ||
|
||
return distribution |
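As a quick sanity check, a minimal usage sketch (not part of the commit): build a toy cosine-similarity matrix over a few sentence vectors and rank sentences by degree centrality. The embedding values below are illustrative placeholders; in practice the similarity matrix would typically come from a sentence encoder.

import numpy as np

# Toy sentence vectors; each row stands in for an encoded sentence.
embeddings = np.array([
    [0.1, 0.9, 0.0],
    [0.2, 0.8, 0.1],
    [0.9, 0.1, 0.2],
])

# Cosine-similarity matrix between sentences.
normalized = embeddings / np.linalg.norm(embeddings, axis=1, keepdims=True)
similarity = normalized @ normalized.T

# Higher score = more central sentence; threshold=None gives continuous LexRank.
scores = degree_centrality_scores(similarity, threshold=None)
print(np.argsort(-scores))  # sentence indices, most central first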
@@ -0,0 +1,122 @@
# -*- coding: utf-8 -*-
"""
Created on Sun Jan 31 12:38:51 2021

@author: Iacopo
"""

from __future__ import print_function
from pathlib2 import Path

import torch
from torch.utils.data import Dataset
import numpy as np
import random
from text_manipulation import split_sentences, extract_sentence_words
import utils
import math


logger = utils.setup_logger(__name__, 'train.log')


def get_choi_files(path):
    all_objects = Path(path).glob('**/*.ref')
    files = [str(p) for p in all_objects if p.is_file()]
    return files


def collate_fn(batch):
    batched_data = []
    batched_targets = []
    paths = []

    window_size = 1
    before_sentence_count = int(math.ceil(float(window_size - 1) / 2))
    after_sentence_count = window_size - before_sentence_count - 1

    for data, targets, path in batch:
        try:
            max_index = len(data)
            tensored_data = []
            for curr_sentence_index in range(0, len(data)):
                from_index = max([0, curr_sentence_index - before_sentence_count])
                to_index = min([curr_sentence_index + after_sentence_count + 1, max_index])
                sentences_window = [word for sentence in data[from_index:to_index] for word in sentence]
                tensored_data.append(torch.FloatTensor(np.concatenate(sentences_window)))
            tensored_targets = torch.zeros(len(data)).long()
            tensored_targets[torch.LongTensor(targets)] = 1
            tensored_targets = tensored_targets[:-1]
            batched_data.append(tensored_data)
            batched_targets.append(tensored_targets)
            paths.append(path)
        except Exception as e:
            logger.info('Exception "%s" in file: "%s"', e, path)
            logger.debug('Exception!', exc_info=True)
            continue

    return batched_data, batched_targets, paths


def clean_paragraph(paragraph):
    cleaned_paragraph = paragraph.replace("'' ", " ").replace(" 's", "'s").replace("``", "").strip('\n')
    return cleaned_paragraph


def read_choi_file(path, train=False, manifesto=False):
    separator = '========' if manifesto else '=========='
    with Path(path).open('r') as f:
        raw_text = f.read()
    paragraphs = [clean_paragraph(p) for p in raw_text.strip().split(separator)
                  if len(p) > 5 and p != "\n"]
    if train:
        random.shuffle(paragraphs)

    targets = []
    new_text = []
    lastparagraphsentenceidx = 0

    for paragraph in paragraphs:
        if manifesto:
            sentences = split_sentences(paragraph, 0)
        else:
            sentences = [s for s in paragraph.split('\n') if len(s.split()) > 0]

        if sentences:
            sentences_count = 0
            # Count the sentences in this paragraph; the last one marks a segment boundary.
            for sentence in sentences:
                words = extract_sentence_words(sentence)
                if len(words) == 0:
                    continue
                sentences_count += 1
                new_text.append(sentence)

            lastparagraphsentenceidx += sentences_count
            targets.append(lastparagraphsentenceidx - 1)

    return new_text, targets, path


# Dataset over the Choi corpus: each item is the list of sentences in one .ref file
# plus the indices of the sentences that end a segment.
class ChoiDataset(Dataset):
    def __init__(self, root, train=False, folder=False, manifesto=False, folders_paths=None):
        self.manifesto = manifesto
        if folders_paths is not None:
            self.textfiles = []
            for f in folders_paths:
                self.textfiles.extend(list(f.glob('*.ref')))
        elif folder:
            self.textfiles = get_choi_files(root)
        else:
            self.textfiles = list(Path(root).glob('**/*.ref'))

        if len(self.textfiles) == 0:
            raise RuntimeError('Found 0 files in subfolders of: {}'.format(root))
        self.train = train
        self.root = root

    def __getitem__(self, index):
        path = self.textfiles[index]

        return read_choi_file(path, self.train, manifesto=self.manifesto)

    def __len__(self):
        return len(self.textfiles)
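For context, a hedged usage sketch (not part of the commit): load the Choi corpus from a local directory and inspect one item. The 'data/choi' path is an assumption; point it at wherever the .ref files actually live.

# Hypothetical dataset root; adjust to your local copy of the Choi corpus.
dataset = ChoiDataset('data/choi', train=False)

sentences, targets, path = dataset[0]
print(path)
print(len(sentences), 'sentences; segments end at sentence indices', targets)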
@@ -0,0 +1,10 @@
import json

jsondata = {
    "word2vecfile": "/home/omri/datasets/word2vec/GoogleNews-vectors-negative300.bin",
    "choidataset": "/home/omri/code/text-segmentation-2017/data/choi",
    "wikidataset": "/home/omri/datasets/wikipedia/process_dump_r",
}

with open('config.json', 'w') as f:
    json.dump(jsondata, f)
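A small companion sketch (assumed, not part of the commit) showing how the generated config.json could be read back by other scripts; the key names match the dictionary above.

import json

with open('config.json', 'r') as f:
    config = json.load(f)

print(config['word2vecfile'])
print(config['choidataset'])
print(config['wikidataset'])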