make the model take variable-size input
hy395 committed May 14, 2024
Parent: 21271c9 · Commit: e1be2c8
Showing 2 changed files with 21 additions and 7 deletions.
src/baskerville/blocks.py (2 additions, 0 deletions)
@@ -1152,6 +1152,7 @@ def transformer(
     kernel_initializer="he_normal",
     adapter=None,
     latent=16,
+    seqlen_train=None,
     **kwargs,
 ):
     """Construct a transformer block.
@@ -1183,6 +1184,7 @@ def transformer(
         initializer=mha_initializer,
         l2_scale=mha_l2_scale,
         qkv_width=qkv_width,
+        seqlen_train=seqlen_train
     )(current)
 
     # dropout
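The blocks.py change is pure plumbing: transformer() gains a seqlen_train argument (default None, which preserves the previous behavior) and forwards it unchanged to MultiheadAttention. Below is a minimal stand-alone sketch of that forwarding pattern; ToyMHA and toy_transformer are illustrative stand-ins, not the library's actual classes or signatures.

import tensorflow as tf


class ToyMHA(tf.keras.layers.Layer):
    """Stand-in attention layer: just stores the training sequence length."""

    def __init__(self, channels, seqlen_train=None):
        super().__init__()
        self.dense = tf.keras.layers.Dense(channels)
        self.seqlen_train = seqlen_train

    def call(self, x):
        # A real layer would build relative-position features here, using
        # self.seqlen_train (when set) instead of the runtime length.
        return self.dense(x)


def toy_transformer(x, channels=64, seqlen_train=None, **kwargs):
    # The block forwards seqlen_train unchanged, as blocks.transformer now does.
    return ToyMHA(channels, seqlen_train=seqlen_train)(x)


# The same block definition accepts inputs of different lengths.
print(toy_transformer(tf.zeros([1, 1024, 64]), seqlen_train=1024).shape)  # (1, 1024, 64)
print(toy_transformer(tf.zeros([1, 4096, 64]), seqlen_train=1024).shape)  # (1, 4096, 64)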
src/baskerville/layers.py (19 additions, 7 deletions)
Expand Up @@ -448,6 +448,7 @@ def __init__(
initializer="he_normal",
l2_scale=0,
qkv_width=1,
seqlen_train=None
):
"""Creates a MultiheadAttention module.
Original version written by Ziga Avsec.
Expand Down Expand Up @@ -480,6 +481,7 @@ def __init__(
self._gated = gated
self._relative_position_symmetric = relative_position_symmetric
self._relative_position_functions = relative_position_functions
self.seqlen_train = seqlen_train
if num_position_features is None:
# num_position_features needs to be divisible by the number of
# relative positional functions *2 (for symmetric & asymmetric version).
Expand Down Expand Up @@ -641,13 +643,23 @@ def call(self, inputs, training=False):
else:
# Project positions to form relative keys.
distances = tf.range(-seq_len + 1, seq_len, dtype=tf.float32)[tf.newaxis]
positional_encodings = positional_features(
positions=distances,
feature_size=self._num_position_features,
seq_length=seq_len,
symmetric=self._relative_position_symmetric,
)
# [1, 2T-1, Cr]

if self.seqlen_train is None:
positional_encodings = positional_features(
positions=distances,
feature_size=self._num_position_features,
seq_length=seq_len,
symmetric=self._relative_position_symmetric,
)
# [1, 2T-1, Cr]
else:
positional_encodings = positional_features(
positions=distances,
feature_size=self._num_position_features,
seq_length=self.seqlen_train,
symmetric=self._relative_position_symmetric,
)
# [1, 2T-1, Cr]

if training:
positional_encodings = tf.nn.dropout(
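The substantive change is in layers.py: when seqlen_train is set, positional_features is evaluated with the training sequence length instead of the runtime seq_len, so the relative-position features are computed on the same basis no matter how long the input is, which appears to be what allows variable-size input. The sketch below illustrates the idea with a single exponential basis; it is not the library's positional_features implementation, and exponential_basis and its parameters are assumptions for illustration.

import tensorflow as tf


def exponential_basis(positions, feature_size, seq_length):
    """Exponentially decaying relative-position features (illustrative only).

    The half-lives of the basis functions are derived from seq_length, so the
    same seq_length always yields the same basis, regardless of how many
    positions are evaluated.
    """
    max_range = tf.math.log(tf.cast(seq_length, tf.float32)) / tf.math.log(2.0)
    half_life = tf.pow(2.0, tf.linspace(1.0, max_range, feature_size))[tf.newaxis]
    positions = tf.abs(positions)[..., tf.newaxis]
    return tf.exp(-tf.math.log(2.0) / half_life * positions)  # [1, 2T-1, feature_size]


seqlen_train = 1024  # length the model was trained at
seq_len = 4096       # longer input seen at inference time
distances = tf.range(-seq_len + 1, seq_len, dtype=tf.float32)[tf.newaxis]

# seq_length=seqlen_train keeps the feature scale fixed; seq_length=seq_len
# (the old behavior) would rescale the basis with every new input length.
encodings = exponential_basis(distances, feature_size=32, seq_length=seqlen_train)
print(encodings.shape)  # (1, 8191, 32)

With seq_length pinned to the training length, a given genomic distance maps to the same feature values at any input length, so a model trained at one sequence length can plausibly be evaluated at another.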
