Add attention map example
CyberZHG committed Jun 6, 2020
1 parent 46ae8c6 commit 580ca4f
Showing 3 changed files with 34 additions and 2 deletions.
32 changes: 32 additions & 0 deletions demo/load_model/load_and_get_attention_map.py
@@ -0,0 +1,32 @@
import sys
import numpy as np
from keras_bert import load_vocabulary, load_trained_model_from_checkpoint, Tokenizer, get_checkpoint_paths
from keras_bert.backend import backend as K

print('This demo demonstrates how to load the pre-trained model and extract the attention map')

# Use a checkpoint path given on the command line, or download the
# pre-trained Chinese base model.
if len(sys.argv) == 2:
    model_path = sys.argv[1]
else:
    from keras_bert.datasets import get_pretrained, PretrainedList
    model_path = get_pretrained(PretrainedList.chinese_base)

paths = get_checkpoint_paths(model_path)

model = load_trained_model_from_checkpoint(paths.config, paths.checkpoint, seq_len=10)
# Replace the model with a function that maps the original inputs to the
# attention map of the first encoder block.
attention_layer = model.get_layer('Encoder-1-MultiHeadSelfAttention')
model = K.function(model.inputs, attention_layer.attention)

token_dict = load_vocabulary(paths.vocab)

tokenizer = Tokenizer(token_dict)
text = '语言模型'
tokens = tokenizer.tokenize(text)
print('Tokens:', tokens)
indices, segments = tokenizer.encode(first=text, max_len=10)

predicts = model([np.array([indices]), np.array([segments])])[0]
# Print each head's attention weights for every token; BERT-base has 12
# heads, and len(text) + 2 covers the characters plus [CLS] and [SEP].
for i, token in enumerate(tokens):
    print(token)
    for head_index in range(12):
        print(predicts[i][head_index, :len(text) + 2].tolist())
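
To run the demo, pass the path of a pre-trained checkpoint directory on the command line (python demo/load_model/load_and_get_attention_map.py /path/to/checkpoint); with no argument, the Chinese base model is downloaded automatically. As a minimal sketch, not part of this commit, the extracted map could be rendered as a per-head heatmap; this assumes matplotlib is installed and mirrors the indexing used in the demo:

# Minimal visualization sketch (not part of this commit). Assumes
# matplotlib is available; predicts[i][head_index] is token i's
# attention over the sequence, as indexed in the demo above.
import matplotlib.pyplot as plt

head_index = 0  # which of the 12 heads to plot
seq_len = len(tokens)  # 6: [CLS] + four characters + [SEP]
attention = np.array([predicts[i][head_index, :seq_len] for i in range(seq_len)])

fig, ax = plt.subplots()
ax.imshow(attention)
ax.set_xticks(range(seq_len))
ax.set_yticks(range(seq_len))
ax.set_xticklabels(tokens)
ax.set_yticklabels(tokens)
ax.set_xlabel('Attended token')
ax.set_ylabel('Query token')
ax.set_title('Encoder-1 self-attention, head %d' % head_index)
plt.show()

Each row of the heatmap shows where one query token attends across the sequence.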
2 changes: 1 addition & 1 deletion keras_bert/__init__.py
@@ -5,4 +5,4 @@
 from .util import *
 from .datasets import *
 
-__version__ = '0.83.0'
+__version__ = '0.84.0'
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,3 +1,3 @@
 numpy
 Keras
-keras-transformer>=0.35.0
+keras-transformer>=0.37.0
