interleave.py (forked from spyysalo/interleave-layer)
#!/usr/bin/env python
from __future__ import print_function
from keras.engine import Layer
from keras import backend as K
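# NOTE: this module targets the Keras 1.x layer API. Under Keras 2.x the
# rough equivalents are keras.layers.Layer, compute_output_shape and
# Model(inputs=..., outputs=...) in place of keras.engine.Layer,
# get_output_shape_for and Model(input=..., output=...) used below; no
# port is attempted here.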
class Interleave(Layer):
    """Special-purpose layer for interleaving sequences.

    Intended to merge a sequence of word vectors with a sequence of
    vectors representing dependencies in a (word, dependency, word)
    pattern. Note that word vectors other than the first and the last
    are duplicated.

    For example, given

        [ [w11 w12 ...] [w21 w22 ...] [w31 w32 ...] ... ]
        [ [d11 d12 ...] [d21 d22 ...] ... ]

    produces

        [ [w11 w12 ... d11 d12 ... w21 w22 ...]
          [w21 w22 ... d21 d22 ... w31 w32 ...] ... ]

    (first dimension for batch not shown.)
    """

    def call(self, inputs, mask=None):
        if type(inputs) is not list or len(inputs) != 2:
            raise ValueError('Interleave must be called with a list '
                             'of two tensors')
        if mask is not None and any(m is not None for m in mask):
            raise NotImplementedError('mask for Interleave')
        a, b = inputs
        if K.ndim(a) != 3 or K.ndim(b) != 3:
            raise ValueError('Interleaved tensors must have ndim 3')
        # Concatenate the sequences so that each item in b is preceded
        # by an item in a and followed by the next item in a.
        return K.concatenate([a[:, :-1, :], b, a[:, 1:, :]], axis=2)

    def get_output_shape_for(self, input_shape):
        # Output keeps b's sequence length; features are word + dep + word.
        a_shape, b_shape = input_shape
        return (a_shape[0], b_shape[1], 2*a_shape[2] + b_shape[2])
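

# For reference, the docstring's interleaving pattern expressed on plain
# NumPy arrays. This helper is hypothetical (not part of the original
# layer); it mirrors what Interleave.call does symbolically, assuming a
# has shape (batch, n, word_dim) and b has shape (batch, n - 1, dep_dim).
def interleave_numpy(a, b):
    import numpy as np
    return np.concatenate([a[:, :-1, :], b, a[:, 1:, :]], axis=2)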


if __name__ == '__main__':
    # Example
    import numpy as np
    from keras.models import Model
    from keras.layers import Input, Embedding

    # 5-dim word embeddings, 4-dim dep embeddings
    we = np.arange(10).repeat(5).reshape((-1, 5))
    de = np.arange(10).repeat(4).reshape((-1, 4)) * 10

    # Inputs are sequences of three words and two dependencies
    w_in = Input(shape=(3,))
    d_in = Input(shape=(2,))
    w_emb = Embedding(we.shape[0], we.shape[1], weights=[we])(w_in)
    d_emb = Embedding(de.shape[0], de.shape[1], weights=[de])(d_in)

    out = Interleave()([w_emb, d_emb])
    model = Model(input=[w_in, d_in], output=out)
    model.compile('adam', 'mse')

    words = np.array([[1, 2, 3], [4, 5, 6]])
    deps = np.array([[1, 2], [4, 5]])
    print('words:\n', words)
    print('deps:\n', deps)
    print('embedded words:\n', we[words])
    print('embedded deps:\n', de[deps])
    print('interleaved:\n', model.predict([words, deps]))
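    # With the toy embeddings above, the interleaved output should have
    # shape (2, 2, 14): a batch of 2, two (word, dep, word) positions,
    # and 5 + 4 + 5 = 14 features per position.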