Skip to content

Commit

Permalink
Use transformer encoder
Browse files (browse the repository at this point in the history)
  • Loading branch information
CyberZHG committed Nov 9, 2018
1 parent 27af8c4 commit f98382a
Show file tree
Hide file tree
Showing 19 changed files with 30 additions and 523 deletions.
1 change: 0 additions & 1 deletion keras_bert/activations/__init__.py

This file was deleted.

13 changes: 0 additions & 13 deletions keras_bert/activations/gelu.py

This file was deleted.

39 changes: 17 additions & 22 deletions keras_bert/bert.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,9 @@
import random
import keras
import numpy as np
from keras_multi_head import MultiHeadAttention
from .layers import (get_inputs, Embeddings, Transformer,
FeedForward, Masked, Extract, LayerNormalization)
from .activations import gelu
from keras_transformer import gelu, get_encoders
from keras_transformer import get_custom_objects as get_encoder_custom_objects
from .layers import (get_inputs, Embeddings, Masked, Extract)


TOKEN_PAD = '' # Token for padding
Expand Down Expand Up @@ -61,14 +60,14 @@ def get_model(token_num,
kwargs['trainable'] = training
transformed = custom_layers(transformed, **kwargs)
else:
for i in range(transformer_num):
transformed = Transformer(
head_num=head_num,
hidden_dim=feed_forward_dim,
dropout_rate=dropout_rate,
trainable=training,
name='Transformer-%d' % (i + 1),
)(transformed)
transformed = get_encoders(
encoder_num=transformer_num,
input_layer=transformed,
head_num=head_num,
hidden_dim=feed_forward_dim,
activation=gelu,
dropout_rate=dropout_rate,
)
if not training:
return inputs, transformed
mlm_pred_layer = keras.layers.Dense(
Expand Down Expand Up @@ -96,16 +95,12 @@ def get_model(token_num,

def get_custom_objects():
"""Get all custom objects for loading saved models."""
return {
'Embeddings': Embeddings,
'MultiHeadAttention': MultiHeadAttention,
'FeedForward': FeedForward,
'LayerNormalization': LayerNormalization,
'Transformer': Transformer,
'Masked': Masked,
'Extract': Extract,
'gelu': gelu,
}
custom_objects = get_encoder_custom_objects()
custom_objects['Embeddings'] = Embeddings
custom_objects['Masked'] = Masked
custom_objects['Extract'] = Extract
custom_objects['gelu'] = gelu
return custom_objects


def get_base_dict():
Expand Down
3 changes: 0 additions & 3 deletions keras_bert/layers/__init__.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,5 @@
from .wrapper import Wrapper
from .inputs import get_inputs
from .embedding import Embeddings
from .feed_forward import FeedForward
from .layer_norm import LayerNormalization
from .transformer import Transformer
from .masked import Masked
from .extract import Extract
49 changes: 0 additions & 49 deletions keras_bert/layers/feed_forward.py

This file was deleted.

34 changes: 0 additions & 34 deletions keras_bert/layers/layer_norm.py

This file was deleted.

106 changes: 0 additions & 106 deletions keras_bert/layers/transformer.py

This file was deleted.

1 change: 1 addition & 0 deletions requirements-dev.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,3 @@
tensorflow
pycodestyle
coverage
3 changes: 1 addition & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@
numpy
tensorflow
Keras
keras-multi-head==0.7.0
keras-transformer==0.4.0
4 changes: 2 additions & 2 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name='keras-bert',
version='0.13.0',
version='0.14.0',
packages=find_packages(),
url='https://github.com/CyberZHG/keras-bert',
license='MIT',
Expand All @@ -13,7 +13,7 @@
install_requires=[
'numpy',
'keras',
'keras-multi-head==0.7.0',
'keras-transformer==0.4.0',
],
classifiers=(
"Programming Language :: Python :: 2.7",
Expand Down
Empty file removed tests/activations/__init__.py
Empty file.
16 changes: 0 additions & 16 deletions tests/activations/test_gelu.py

This file was deleted.

2 changes: 1 addition & 1 deletion tests/layers/test_embedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ def test_sample(self):
model.compile(
optimizer='adam',
loss='mse',
metrics=['mse'],
metrics={},
)
model.summary(line_length=120)
self.assertEqual((None, 512, 768), model.layers[-1].output_shape)
2 changes: 1 addition & 1 deletion tests/layers/test_extract.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ def test_sample(self):
model.compile(
optimizer='adam',
loss='mse',
metrics=['mse'],
metrics={},
)
model.summary()
inputs = np.asarray([[
Expand Down
Loading

0 comments on commit f98382a

Please sign in to comment.