generated from aaivu/aaivu-project-template
-
Notifications
You must be signed in to change notification settings - Fork 0
/
MemoryTagger.py
30 lines (26 loc) · 864 Bytes
/
MemoryTagger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
from collections import Counter, defaultdict

from sklearn.base import BaseEstimator, TransformerMixin
class MemoryTagger(BaseEstimator, TransformerMixin):
    """Tag each word with the tag it most often received during training.

    A lookup-table baseline: ``fit`` counts how often each (word, tag) pair
    occurs and remembers, per word, the most frequent tag. ``predict``
    returns that remembered tag, or ``default_tag`` for words never seen
    during ``fit``.

    Parameters
    ----------
    default_tag : str, default='PRODUCT_NAME'
        Tag returned for words absent from the training data.

    Attributes
    ----------
    tags : list
        Distinct tags seen in ``y``, in order of first appearance.
    memory : dict
        Maps each training word to its most frequent tag. Ties are broken
        in favour of the tag that was seen first for that word.
    """

    def __init__(self, default_tag='PRODUCT_NAME'):
        # Stored verbatim so BaseEstimator.get_params/set_params/clone work.
        self.default_tag = default_tag

    def fit(self, X, y):
        """Learn the most frequent tag for every word.

        Parameters
        ----------
        X : iterable of hashable
            Words (tokens), one per position.
        y : iterable
            Tags aligned position-by-position with ``X``.

        Returns
        -------
        self : MemoryTagger
            The fitted estimator, per the scikit-learn ``fit`` convention.

        Raises
        ------
        ValueError
            If ``X`` and ``y`` have different lengths.
        """
        # Materialize once so generators work and lengths can be compared;
        # zip() would otherwise silently drop the unmatched tail.
        X = list(X)
        y = list(y)
        if len(X) != len(y):
            raise ValueError(
                f"X and y must have the same length, got {len(X)} and {len(y)}"
            )

        # dict.fromkeys de-duplicates in O(n) while keeping first-seen order.
        self.tags = list(dict.fromkeys(y))

        counts = defaultdict(Counter)
        for word, tag in zip(X, y):
            counts[word][tag] += 1

        # max() returns the first maximal key in insertion order, so ties go
        # to the tag that was seen first for that word.
        self.memory = {
            word: max(tag_counts, key=tag_counts.get)
            for word, tag_counts in counts.items()
        }
        return self

    def predict(self, X, y=None):
        """Return the remembered tag for each word in ``X``.

        Words not seen during ``fit`` get ``self.default_tag``
        ('PRODUCT_NAME' unless configured otherwise). ``y`` is ignored and
        accepted only for API compatibility.
        """
        # getattr fallback keeps instances pickled before default_tag existed
        # working with the historical default.
        default = getattr(self, 'default_tag', 'PRODUCT_NAME')
        return [self.memory.get(x, default) for x in X]

    def transform(self, X, y=None):
        """Alias for ``predict`` so ``TransformerMixin.fit_transform`` works."""
        return self.predict(X)