Showing 2 changed files with 1,726 additions and 0 deletions.
import pandas as pd
import numpy as np
from keras.preprocessing.sequence import pad_sequences
from keras.utils import to_categorical
from sklearn.model_selection import train_test_split
from keras.models import Model, Input
from keras.layers import LSTM, Embedding, Dense, TimeDistributed, Dropout, Bidirectional
from keras_contrib.layers import CRF

# Load the annotated NER corpus
data = pd.read_csv('../data/ner_dataset.csv', encoding='latin1')

# Forward-fill NaNs: the sentence number appears only on the first token of each sentence
data = data.fillna(method='ffill')

# Vocabulary: all unique words in the corpus
words = list(set(data['Word'].values))
n_words = len(words)

# Label set: all unique NER tags
tags = list(set(data["Tag"].values))
n_tags = len(tags)

# Groups the flat token table back into sentences of (word, POS, tag) tuples
class SentenceGetter(object):

    def __init__(self, data):
        self.n_sent = 1
        self.data = data
        self.empty = False
        agg_func = lambda s: [(w, p, t) for w, p, t in zip(s["Word"].values.tolist(),
                                                           s["POS"].values.tolist(),
                                                           s["Tag"].values.tolist())]
        self.grouped = self.data.groupby("Sentence #").apply(agg_func)
        self.sentences = [s for s in self.grouped]

    def get_next(self):
        try:
            s = self.grouped["Sentence: {}".format(self.n_sent)]
            self.n_sent += 1
            return s
        except KeyError:
            return None


getter = SentenceGetter(data)
sentences = getter.sentences  # all sentences as lists of (word, POS, tag) tuples

max_len = 75  # pad / truncate every sentence to 75 tokens
word2idx = {w: i + 1 for i, w in enumerate(words)}  # index 0 is reserved for padding
n_words = len(word2idx)
# word2idx['<unk>'] = len(word2idx) + 1
tag2idx = {t: i + 1 for i, t in enumerate(tags)}
tag2idx['<pad>'] = 0
n_tags = len(tag2idx)  # the extra <pad> tag raises the tag count from 17 to 18

# Word2idx & padding for X
X = [[word2idx[w[0]] for w in s] for s in sentences]
X = pad_sequences(maxlen=max_len, sequences=X, padding="post", value=0)

# Tag2idx & padding for y
y = [[tag2idx[w[2]] for w in s] for s in sentences]
y = pad_sequences(maxlen=max_len, sequences=y, padding="post", value=0)

# One-hot labels and train/test split
y = [to_categorical(i, num_classes=n_tags) for i in y]
X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.1)

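# --- Optional sanity check (not in the original script) ---
# X should be (n_sentences, max_len) word indices and y a list of (max_len, n_tags)
# one-hot matrices; the exact sentence count depends on the CSV being used.
print(X.shape)
print(np.array(y).shape)
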
# ============== Bi-LSTM CRF =============
input_layer = Input(shape=(max_len,))
model = Embedding(input_dim=n_words + 1, output_dim=20,
                  input_length=max_len, mask_zero=True)(input_layer)  # 20-dim embedding
model = Bidirectional(LSTM(units=50, return_sequences=True,
                           recurrent_dropout=0.1))(model)  # variational biLSTM
model = TimeDistributed(Dense(50, activation="tanh"))(model)  # a dense layer as suggested by neuralNer
crf = CRF(n_tags)  # CRF output layer
out = crf(model)

model = Model(input_layer, out)
model.compile(optimizer="rmsprop", loss=crf.loss_function, metrics=[crf.accuracy])
model.summary()

history = model.fit(X_tr, np.array(y_tr), batch_size=32, epochs=1,
                    validation_split=0.1, verbose=1)

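# --- Optional: inspect the training history (not in the original script) ---
# The accuracy key name depends on the keras-contrib version (often 'crf_viterbi_accuracy'),
# so list the available keys instead of hard-coding one.
print(history.history.keys())
print(history.history['loss'], history.history['val_loss'])
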
# Predictions
idx2word = {value: key for key, value in word2idx.items()}
idx2tag = {value: key for key, value in tag2idx.items()}

p_all = model.predict(np.array(X_te))  # (4796, 75, 18) tag probabilities per token
p_all = np.argmax(p_all, axis=-1)      # (4796, 75) predicted tag indices
p_all_tags = [[idx2tag[idx] for idx in s if idx != 0] for s in p_all]  # e.g. ['B-gpe', 'O', 'O', 'O']

true_all = np.argmax(y_te, -1)
true_all_tags = [[idx2tag[idx] for idx in s if idx != 0] for s in true_all]

# Evaluation: entity-level F1 with seqeval
from seqeval.metrics import f1_score
print(f1_score(true_all_tags, p_all_tags))
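
# --- Optional: a fuller evaluation and a qualitative look at one prediction ---
# classification_report (also from seqeval) gives per-entity-type precision/recall/F1.
from seqeval.metrics import classification_report
print(classification_report(true_all_tags, p_all_tags))

# Show the first test sentence with its gold and predicted tags side by side;
# index 0 is the padding index, so padded positions are skipped.
for w_idx, t_idx, p_idx in zip(X_te[0], true_all[0], p_all[0]):
    if w_idx != 0:
        print("{:15} {:8} {}".format(idx2word[w_idx], idx2tag[t_idx], idx2tag[p_idx]))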