Skip to content

Commit

Permalink
Clippy and fmt
Browse files Browse the repository at this point in the history
  • Loading branch information
EricLBuehler committed Nov 21, 2024
1 parent 1fdc4d8 commit bbc886e
Showing 1 changed file with 5 additions and 6 deletions.
11 changes: 5 additions & 6 deletions mistralrs-core/src/layers_masker.rs
Original file line number Diff line number Diff line change
Expand Up @@ -193,22 +193,21 @@ impl CausalMasker {
let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
let diagonal = past_kv_len as isize - sliding_window as isize - 1;
let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
.to_dtype(DType::U8)?;

mask
masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
.to_dtype(DType::U8)?
};

let zero = Tensor::new(0.0f32, input_ids.device())?;
causal_mask = {
let mask = causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
// Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
let mask = masked_fill(

masked_fill(
&zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
&mask,
f32::NEG_INFINITY,
)?;
mask
)?
};

// IMPORTANT: this must match the logic in attention.rs.
Expand Down

0 comments on commit bbc886e

Please sign in to comment.