diff --git a/keras_nlp/src/models/falcon/falcon_transformer_decoder.py b/keras_nlp/src/models/falcon/falcon_transformer_decoder.py
index c11f9fdee2..177fe038e5 100644
--- a/keras_nlp/src/models/falcon/falcon_transformer_decoder.py
+++ b/keras_nlp/src/models/falcon/falcon_transformer_decoder.py
@@ -133,7 +133,7 @@ def call(
         mask = decoder_padding_mask
         if mask is None:
             batch_size, seq_length = ops.shape(inputs)[:2]
-            mask = ops.ones((batch_size, seq_length), dtype="int")
+            mask = ops.ones((batch_size, seq_length), dtype="int32")
         alibi = self._build_alibi_tensor(self.num_attention_heads, mask)
 
         # Attention block.
@@ -225,9 +225,9 @@ def _build_alibi_tensor(self, num_heads, attention_mask):
             self._get_slopes(num_heads),
             dtype=self.compute_dtype,
         )  # num_heads
-        attention_mask = ops.cast(attention_mask, dtype="int")
+        cumsum_mask = ops.cumsum(ops.cast(attention_mask, "int32"), axis=-1) - 1
         arange_tensor = (
-            ((ops.cumsum(attention_mask, axis=-1) - 1) * attention_mask)
+            ops.cast(cumsum_mask, "int32") * ops.cast(attention_mask, "int32")
         )[:, None, :]
         alibi = slopes[..., None] * ops.cast(arange_tensor, self.compute_dtype)
         alibi = ops.expand_dims(
diff --git a/keras_nlp/src/models/phi3/phi3_decoder.py b/keras_nlp/src/models/phi3/phi3_decoder.py
index 134ce7d71b..28f9d2937d 100644
--- a/keras_nlp/src/models/phi3/phi3_decoder.py
+++ b/keras_nlp/src/models/phi3/phi3_decoder.py
@@ -65,7 +65,6 @@ def __init__(
         self.kernel_initializer = keras.initializers.get(kernel_initializer)
 
     def build(self, decoder_sequence_shape):
-        # Pre-attention layernorm.
         self.pre_attention_layernorm = Phi3LayerNorm(
             epsilon=self.layer_norm_epsilon,
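
For reference, below is a minimal standalone sketch of the ALiBi bias computation the Falcon hunk touches, assuming Keras 3 and `keras.ops`. The core fix appears to be that the bare `"int"` dtype string is not portable across backends, so the patch switches to `"int32"`. `get_slopes` and `build_alibi_tensor` here are simplified, hypothetical stand-ins for the decoder's private `_get_slopes` and `_build_alibi_tensor`; the final reshape in the real layer may differ.

import math

import numpy as np
from keras import ops


def get_slopes(num_heads):
    # Geometric ALiBi slopes, assuming num_heads is a power of two
    # (the general-case fallback in the real helper is omitted here).
    start = 2 ** (-(2 ** -(math.log2(num_heads) - 3)))
    return [start * (start**i) for i in range(num_heads)]


def build_alibi_tensor(num_heads, attention_mask, compute_dtype="float32"):
    # slopes: (num_heads,)
    slopes = ops.convert_to_tensor(get_slopes(num_heads), dtype=compute_dtype)
    # Token positions as a shifted cumulative sum over the padding mask,
    # zeroed back out at padding positions. "int32" is used instead of
    # "int" because the bare alias is not a valid dtype on every backend.
    mask = ops.cast(attention_mask, "int32")
    cumsum_mask = ops.cumsum(mask, axis=-1) - 1
    arange_tensor = (cumsum_mask * mask)[:, None, :]  # (batch, 1, seq_len)
    # Per-head bias: slope * position -> (batch, num_heads, seq_len),
    # then a singleton query axis -> (batch, num_heads, 1, seq_len).
    alibi = slopes[..., None] * ops.cast(arange_tensor, compute_dtype)
    return ops.expand_dims(alibi, 2)


# One batch, 4 heads, 5 tokens of which the last 2 are padding.
mask = np.array([[1, 1, 1, 0, 0]])
print(ops.shape(build_alibi_tensor(4, mask)))  # (1, 4, 1, 5)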