Rename the mask variable in _build_alibi_tensor.
SamanehSaadat committed May 23, 2024
1 parent e1e7bdd · commit 316e122
Showing 1 changed file with 4 additions and 4 deletions.
8 changes: 4 additions & 4 deletions keras_nlp/src/models/falcon/falcon_transformer_decoder.py
@@ -220,16 +220,16 @@ def _compute_attention_mask(
             else causal_mask
         )

-    def _build_alibi_tensor(self, num_heads, attention_mask):
+    def _build_alibi_tensor(self, num_heads, mask):
         slopes = ops.convert_to_tensor(
             self._get_slopes(num_heads),
             dtype=self.compute_dtype,
         )  # num_heads
-        attention_mask = ops.cast(attention_mask, dtype="int32")
+        mask = ops.cast(mask, dtype="int32")
         # TODO: cumsum always outputs int64 in Keras 2 so the casting of cumsum
         # result to int32 can be removed when keras 2 support is removed.
-        cumsum_mask = ops.cast(ops.cumsum(attention_mask, axis=-1) - 1, "int32")
-        arange_tensor = (cumsum_mask * attention_mask)[:, None, :]
+        cumsum_mask = ops.cast(ops.cumsum(mask, axis=-1) - 1, "int32")
+        arange_tensor = (cumsum_mask * mask)[:, None, :]
         alibi = slopes[..., None] * ops.cast(arange_tensor, self.compute_dtype)
         alibi = ops.expand_dims(
             alibi, 0
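For context on what the renamed method computes, here is a minimal, self-contained sketch of the ALiBi bias construction, using NumPy in place of keras.ops. The inlined get_slopes is an assumption: it uses the standard ALiBi power-of-two slope formula as a stand-in for the class's _get_slopes, and build_alibi_tensor is written as a free function rather than a method. Names and shapes beyond what the diff shows are illustrative, not the library's API.

import numpy as np

def get_slopes(num_heads):
    # Assumption: standard ALiBi slopes, a geometric sequence
    # 2**(-8*i/num_heads) for i = 1..num_heads (power-of-two head counts).
    start = 2.0 ** (-8.0 / num_heads)
    return np.array([start ** (i + 1) for i in range(num_heads)])

def build_alibi_tensor(num_heads, mask):
    # mask: (batch_size, seq_length) padding mask of 0s and 1s.
    slopes = get_slopes(num_heads)  # (num_heads,)
    mask = mask.astype("int32")
    # Index of each token among the non-padded tokens; multiplying by
    # mask zeroes out the padded positions.
    cumsum_mask = np.cumsum(mask, axis=-1) - 1
    arange_tensor = (cumsum_mask * mask)[:, None, :]  # (batch, 1, seq)
    # Per-head linear bias: slope * key position.
    alibi = slopes[..., None] * arange_tensor.astype("float32")
    # Leading axis mirrors the expand_dims(alibi, 0) in the diff.
    return alibi[None, ...]  # (1, batch, num_heads, seq_length)

# Example: two heads, one sequence with a trailing pad token.
bias = build_alibi_tensor(2, np.array([[1, 1, 1, 0]]))
print(bias.shape)  # (1, 1, 2, 4)

As the diff shows, the commit is a pure rename of the parameter and its local uses; the computed bias is unchanged.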
