diff --git a/src/baskerville/blocks.py b/src/baskerville/blocks.py
index 527dd27..8e74e31 100644
--- a/src/baskerville/blocks.py
+++ b/src/baskerville/blocks.py
@@ -1152,6 +1152,7 @@ def transformer(
     kernel_initializer="he_normal",
     adapter=None,
     latent=16,
+    seqlen_train=None,
     **kwargs,
 ):
     """Construct a transformer block.
@@ -1183,6 +1184,7 @@ def transformer(
         initializer=mha_initializer,
         l2_scale=mha_l2_scale,
         qkv_width=qkv_width,
+        seqlen_train=seqlen_train
     )(current)

     # dropout
diff --git a/src/baskerville/layers.py b/src/baskerville/layers.py
index d0513dc..e1e974d 100644
--- a/src/baskerville/layers.py
+++ b/src/baskerville/layers.py
@@ -448,6 +448,7 @@ def __init__(
         initializer="he_normal",
         l2_scale=0,
         qkv_width=1,
+        seqlen_train=None
     ):
         """Creates a MultiheadAttention module.
         Original version written by Ziga Avsec.
@@ -480,6 +481,7 @@ def __init__(
         self._gated = gated
         self._relative_position_symmetric = relative_position_symmetric
         self._relative_position_functions = relative_position_functions
+        self.seqlen_train = seqlen_train
         if num_position_features is None:
             # num_position_features needs to be divisible by the number of
             # relative positional functions *2 (for symmetric & asymmetric version).
@@ -641,13 +643,23 @@ def call(self, inputs, training=False):
         else:
             # Project positions to form relative keys.
             distances = tf.range(-seq_len + 1, seq_len, dtype=tf.float32)[tf.newaxis]
-            positional_encodings = positional_features(
-                positions=distances,
-                feature_size=self._num_position_features,
-                seq_length=seq_len,
-                symmetric=self._relative_position_symmetric,
-            )
-            # [1, 2T-1, Cr]
+
+            if self.seqlen_train is None:
+                positional_encodings = positional_features(
+                    positions=distances,
+                    feature_size=self._num_position_features,
+                    seq_length=seq_len,
+                    symmetric=self._relative_position_symmetric,
+                )
+                # [1, 2T-1, Cr]
+            else:
+                positional_encodings = positional_features(
+                    positions=distances,
+                    feature_size=self._num_position_features,
+                    seq_length=self.seqlen_train,
+                    symmetric=self._relative_position_symmetric,
+                )
+                # [1, 2T-1, Cr]

             if training:
                 positional_encodings = tf.nn.dropout(