
Commit

fix: prob attention shape error while bs>1 (#50)
LongxingTan authored Oct 16, 2023
1 parent 679e349 commit f3db233
Showing 2 changed files with 10 additions and 10 deletions.
tests/test_models/test_informer.py (17 changes: 9 additions, 8 deletions)
@@ -129,20 +129,21 @@ def test_train(self):
         predict_length = 10
         n_encoder_feature = 2
         n_decoder_feature = 3
+        batch_size = 1

         x_train = (
-            np.random.rand(1, train_length, 1),
-            np.random.rand(1, train_length, n_encoder_feature),
-            np.random.rand(1, predict_length, n_decoder_feature),
+            np.random.rand(batch_size, train_length, 1),
+            np.random.rand(batch_size, train_length, n_encoder_feature),
+            np.random.rand(batch_size, predict_length, n_decoder_feature),
         )
-        y_train = np.random.rand(1, predict_length, 1)  # target: (batch, predict_length, 1)
+        y_train = np.random.rand(batch_size, predict_length, 1)  # target: (batch, predict_length, 1)

         x_valid = (
-            np.random.rand(1, train_length, 1),
-            np.random.rand(1, train_length, n_encoder_feature),
-            np.random.rand(1, predict_length, n_decoder_feature),
+            np.random.rand(batch_size, train_length, 1),
+            np.random.rand(batch_size, train_length, n_encoder_feature),
+            np.random.rand(batch_size, predict_length, n_decoder_feature),
         )
-        y_valid = np.random.rand(1, predict_length, 1)
+        y_valid = np.random.rand(batch_size, predict_length, 1)

         model = AutoModel("Informer", predict_length=predict_length, custom_model_params=custom_params)
         trainer = KerasTrainer(model)
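
With the batch size factored into a single variable, the same input construction can be pointed at the multi-sample case this commit fixes. A minimal, illustrative sketch (not part of the commit; train_length is a placeholder value, the other names mirror the test above):

import numpy as np

# Illustrative only: exercise the batch_size > 1 path that the attention fix targets.
train_length = 30        # placeholder; the real test defines its own value above this hunk
predict_length = 10
n_encoder_feature = 2
n_decoder_feature = 3
batch_size = 4           # > 1, the case that previously raised a shape error

x_train = (
    np.random.rand(batch_size, train_length, 1),
    np.random.rand(batch_size, train_length, n_encoder_feature),
    np.random.rand(batch_size, predict_length, n_decoder_feature),
)
y_train = np.random.rand(batch_size, predict_length, 1)  # target: (batch, predict_length, 1)
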
tfts/layers/attention_layer.py (3 changes: 1 addition, 2 deletions)
@@ -155,10 +155,9 @@ def _prob_qk(self, q, k, sample_k, top_n):
         K_sample = tf.gather(K_sample, indx_q_seq, axis=2)
         K_sample = tf.gather(K_sample, indx_k_seq, axis=3)

-        Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)))
+        Q_K_sample = tf.squeeze(tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample)), axis=3)
         M = tf.math.reduce_max(Q_K_sample, axis=-1) - tf.raw_ops.Div(x=tf.reduce_sum(Q_K_sample, axis=-1), y=L)
         m_top = tf.math.top_k(M, top_n, sorted=False)[1]
-        m_top = m_top[tf.newaxis, tf.newaxis] if B == 1 else m_top

         batch_indexes = tf.tile(tf.range(B)[:, tf.newaxis, tf.newaxis], (1, H, top_n))
         head_indexes = tf.tile(tf.range(H)[tf.newaxis, :, tf.newaxis], (B, 1, top_n))
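
In the changed line, tf.squeeze without an axis argument drops every dimension of size 1, so the rank of Q_K_sample depends on which dimensions happen to be 1 at run time (for example a batch of 1), which is what the B == 1 special case for m_top was compensating for. Passing axis=3 removes exactly the length-1 dimension introduced by tf.expand_dims(q, -2), so the result stays rank 4 for any batch size and the workaround can be dropped. A standalone sketch of that squeeze behaviour (not the library code; the dimension sizes and tensors below are illustrative placeholders):

import tensorflow as tf

B, H, L, S, D = 2, 4, 16, 5, 8           # batch, heads, query length, sampled keys, head dim
q = tf.random.normal((B, H, L, D))
K_sample = tf.random.normal((B, H, L, S, D))

scores = tf.matmul(tf.expand_dims(q, -2), tf.einsum("...ij->...ji", K_sample))
print(scores.shape)                      # (B, H, L, 1, S)

# No axis: every size-1 dimension is removed, so the rank depends on the data
# (a batch of 1 would also lose its batch axis here).
print(tf.squeeze(scores).shape)          # (2, 4, 16, 5) only because B > 1
# Explicit axis: only the dimension added by expand_dims is removed.
print(tf.squeeze(scores, axis=3).shape)  # (2, 4, 16, 5) for any batch size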
