vectorizer.py (forked from mattdangerw/keras-text-generation)
# -*- coding: utf-8 -*-
from __future__ import print_function
from collections import Counter
import numpy as np
from utils import print_red
from utils import word_tokenize, word_detokenize


class Vectorizer:
    """
    Transforms text to and from vectors of integers representing tokens.
    Handles word-level and character-level tokenization.
    """

    def __init__(self, text, word_tokens, pristine_input, pristine_output):
        self.word_tokens = word_tokens
        self._pristine_input = pristine_input
        self._pristine_output = pristine_output

        tokens = self._tokenize(text)
        print('corpus length:', len(tokens))
        token_counts = Counter(tokens)
        # Sort so most common tokens come first in our vocabulary
        tokens = [x[0] for x in token_counts.most_common()]
        self._token_indices = {x: i for i, x in enumerate(tokens)}
        self._indices_token = {i: x for i, x in enumerate(tokens)}
        self.vocab_size = len(tokens)
        print('vocab size:', self.vocab_size)
    def _tokenize(self, text):
        # Unless the input is 'pristine', normalize case before tokenizing
        if not self._pristine_input:
            text = text.lower()
        if self.word_tokens:
            # Pristine input is split on whitespace only; otherwise use the
            # smarter tokenizer from utils
            if self._pristine_input:
                return text.split()
            return word_tokenize(text)
        return text

    def _detokenize(self, tokens):
        if self.word_tokens:
            if self._pristine_output:
                return ' '.join(tokens)
            return word_detokenize(tokens)
        return ''.join(tokens)
    def vectorize(self, text):
        """Transforms text to a vector of integers"""
        tokens = self._tokenize(text)
        indices = []
        for token in tokens:
            if token in self._token_indices:
                indices.append(self._token_indices[token])
            else:
                print_red('Ignoring unrecognized token:', token)
        return np.array(indices, dtype=np.int32)
    def unvectorize(self, vector):
        """Transforms a vector of integers back to text"""
        tokens = [self._indices_token[index] for index in vector.tolist()]
        return self._detokenize(tokens)
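

if __name__ == '__main__':
    # Minimal usage sketch, not part of the original module: builds a
    # word-level vocabulary from a tiny corpus and round-trips a phrase
    # through vectorize/unvectorize. The sample text and parameter choices
    # are illustrative only, and this assumes the repo's utils module
    # (word_tokenize, word_detokenize, print_red) is importable as above.
    sample = 'the quick brown fox jumps over the lazy dog . the fox runs .'
    vectorizer = Vectorizer(sample, word_tokens=True,
                            pristine_input=False, pristine_output=False)
    vector = vectorizer.vectorize('the fox jumps')
    print('vector:', vector)
    print('round trip:', vectorizer.unvectorize(vector))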