
Commit

Use existing scaled dot-product attention
CyberZHG committed Nov 6, 2018
1 parent 7b3203e commit c1dd746
Showing 8 changed files with 7 additions and 81 deletions.
5 changes: 3 additions & 2 deletions keras_bert/bert.py
@@ -1,7 +1,8 @@
 import random
 import keras
 import numpy as np
-from .layers import (get_inputs, Embeddings, Transformer, Attention, MultiHeadAttention,
+from keras_self_attention import ScaledDotProductAttention
+from .layers import (get_inputs, Embeddings, Transformer, MultiHeadAttention,
                      FeedForward, Masked, Extract, LayerNormalization)
 from .activations import gelu

@@ -97,7 +98,7 @@ def get_custom_objects():
"""Get all custom objects for loading saved models."""
return {
'Embeddings': Embeddings,
'Attention': Attention,
'ScaledDotProductAttention': ScaledDotProductAttention,
'MultiHeadAttention': MultiHeadAttention,
'FeedForward': FeedForward,
'LayerNormalization': LayerNormalization,
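
The dictionary returned by get_custom_objects() is meant to be passed to Keras when a saved keras-bert model is reloaded, so that layers such as ScaledDotProductAttention can be deserialized by name. A minimal sketch, assuming a model file named bert_fit.h5 and that get_custom_objects is importable from the package root (both are assumptions, not part of this diff):

import keras
from keras_bert import get_custom_objects  # import path assumed for illustration

# custom_objects maps layer names stored in the HDF5 file back to their Python classes.
model = keras.models.load_model('bert_fit.h5', custom_objects=get_custom_objects())
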
1 change: 0 additions & 1 deletion keras_bert/layers/__init__.py
@@ -1,7 +1,6 @@
 from .wrapper import Wrapper
 from .inputs import get_inputs
 from .embedding import Embeddings
-from .attention import Attention
 from .multi_head import MultiHeadAttention
 from .feed_forward import FeedForward
 from .layer_norm import LayerNormalization
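
With the in-house layer gone, keras_bert.layers no longer exports Attention; the remaining exports are unchanged. A hypothetical check, assuming the package is installed from this revision:

# Still valid: MultiHeadAttention is re-exported exactly as before.
from keras_bert.layers import MultiHeadAttention

# No longer valid after this commit; the module was deleted.
# from keras_bert.layers import Attention  # raises ImportError
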
31 changes: 0 additions & 31 deletions keras_bert/layers/attention.py

This file was deleted.
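The deleted module held the project's own scaled dot-product attention layer, whose job is now done by the keras-self-attention dependency. Both compute softmax(Q K^T / sqrt(d_k)) V. The sketch below is a plain Keras-backend restatement of that formula for reference, not the deleted file's actual contents:

import keras.backend as K

def scaled_dot_product_attention(query, key, value):
    # Scale the logits by sqrt(d_k) so the softmax stays well-conditioned.
    d_k = K.cast(K.shape(key)[-1], K.floatx())
    scores = K.batch_dot(query, key, axes=[2, 2]) / K.sqrt(d_k)
    weights = K.softmax(scores)                      # attention weights over key positions
    return K.batch_dot(weights, value, axes=[2, 1])  # weighted sum of values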

4 changes: 2 additions & 2 deletions keras_bert/layers/multi_head.py
@@ -1,5 +1,5 @@
 import keras
-from .attention import Attention
+from keras_self_attention import ScaledDotProductAttention
 from ..activations.gelu import gelu
 from .wrapper import Wrapper

@@ -86,7 +86,7 @@ def build(self, input_shape):
                 name='%s-Dense-Dropout-V_%d' % (self.name, i + 1),
             )
             self.layers[layer.name] = layer
-            layer = Attention(
+            layer = ScaledDotProductAttention(
                 trainable=self.trainable,
                 name='%s-Attention_%d' % (self.name, i + 1),
             )
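
Inside MultiHeadAttention the external layer is constructed with only the standard trainable and name keyword arguments, making it a drop-in replacement for the old Attention class. Used on its own it behaves like any other Keras layer; a minimal sketch of single-tensor self-attention, assuming keras-self-attention 0.30.0 accepts one input tensor as query, key and value alike (its call conventions may differ):

import keras
from keras_self_attention import ScaledDotProductAttention

seq_input = keras.layers.Input(shape=(None, 64))
# One input tensor: the layer attends the sequence to itself.
attended = ScaledDotProductAttention(name='SDP-Attention')(seq_input)
model = keras.models.Model(inputs=seq_input, outputs=attended)
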
1 change: 1 addition & 0 deletions requirements.txt
@@ -1,3 +1,4 @@
 numpy
 tensorflow
 Keras
+keras-self-attention==0.30.0
2 changes: 1 addition & 1 deletion setup.py
@@ -2,7 +2,7 @@

 setup(
     name='keras-bert',
-    version='0.10.0',
+    version='0.11.0',
     packages=find_packages(),
     url='https://github.com/CyberZHG/keras-bert',
     license='MIT',
44 changes: 0 additions & 44 deletions tests/layers/test_attention.py

This file was deleted.

Binary file modified tests/test_bert_fit.h5
Binary file not shown.
