Skip to content

Commit

Permalink
Fix the keras2 test.
Browse files Browse the repository at this point in the history
  • Loading branch information
SamanehSaadat committed May 22, 2024
1 parent 7ecd466 commit 3533413
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 4 deletions.
6 changes: 3 additions & 3 deletions keras_nlp/src/models/falcon/falcon_transformer_decoder.py
Original file line number Diff line number Diff line change
@@ -133,7 +133,7 @@ def call(
mask = decoder_padding_mask
if mask is None:
batch_size, seq_length = ops.shape(inputs)[:2]
- mask = ops.ones((batch_size, seq_length), dtype="int")
+ mask = ops.ones((batch_size, seq_length), dtype="int32")
alibi = self._build_alibi_tensor(self.num_attention_heads, mask)

# Attention block.
@@ -225,9 +225,9 @@ def _build_alibi_tensor(self, num_heads, attention_mask):
self._get_slopes(num_heads),
dtype=self.compute_dtype,
) # num_heads
- attention_mask = ops.cast(attention_mask, dtype="int")
+ cumsum_mask = ops.cumsum(ops.cast(attention_mask, "int32"), axis=-1) - 1
arange_tensor = (
- ((ops.cumsum(attention_mask, axis=-1) - 1) * attention_mask)
+ ops.cast(cumsum_mask, "int32") * ops.cast(attention_mask, "int32")
)[:, None, :]
alibi = slopes[..., None] * ops.cast(arange_tensor, self.compute_dtype)
alibi = ops.expand_dims(
Expand Down
1 change: 0 additions & 1 deletion keras_nlp/src/models/phi3/phi3_decoder.py
Original file line number Diff line number Diff line change
@@ -65,7 +65,6 @@ def __init__(
self.kernel_initializer = keras.initializers.get(kernel_initializer)

def build(self, decoder_sequence_shape):
-
# Pre-attention layernorm.
self.pre_attention_layernorm = Phi3LayerNorm(
epsilon=self.layer_norm_epsilon,
Expand Down

0 comments on commit 3533413

Please sign in to comment.