Faster CUDA prompt speeds (#925)
* Faster cuda attnmask impl

* Faster cuda pagedattn pp speeds

* Clippy and fmt
EricLBuehler authored Nov 21, 2024
1 parent 200d313 commit 2b20951
Showing 33 changed files with 155 additions and 129 deletions.
5 changes: 1 addition & 4 deletions mistralrs-core/src/attention.rs
@@ -180,10 +180,7 @@ impl Sdpa {
         let k = k.flatten(0, 1)?;
         let q = q.flatten(0, 1)?;
         let v = v.flatten(0, 1)?;
-        let attention_bias = match mask {
-            Some(mask) => Some(mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?),
-            None => None,
-        };
+        let attention_bias = mask.cloned();
 
         // If attention_bias is set, we fuse the add by giving it as the output matrix
         // and setting beta to 1.0
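The net effect of this hunk is that Sdpa no longer repeats the attention bias across heads on every call; the per-head expansion now happens once, when the mask is built (see layers_masker.rs below). The comment that survives the diff refers to the standard GEMM contract C = alpha*A*B + beta*C: writing the bias into the output buffer C and setting beta to 1.0 folds the mask add into the QK^T matmul. A minimal scalar sketch of that contract, not the cuBLASLt kernel itself (the function name is illustrative):

    // Sketch of the c = alpha*a*b + beta*c contract that the fused mask-add relies on.
    // Row-major: a is m x k, b is k x n, c is m x n and is pre-filled with the attention bias.
    fn gemm_with_beta(
        a: &[f32],
        b: &[f32],
        c: &mut [f32],
        m: usize,
        k: usize,
        n: usize,
        alpha: f32,
        beta: f32,
    ) {
        for i in 0..m {
            for j in 0..n {
                let mut acc = 0.0f32;
                for p in 0..k {
                    acc += a[i * k + p] * b[p * n + j];
                }
                // With beta = 1.0 and c holding the bias, the mask add comes for free
                // in the matmul epilogue instead of as a separate kernel.
                c[i * n + j] = alpha * acc + beta * c[i * n + j];
            }
        }
    }

Because the bias buffer doubles as the GEMM output, it has to be materialized at the full (n_attn_heads, tgt_len, kv_len) shape up front, which is why the repeat moves out of this per-call hot path and into the masker.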
160 changes: 78 additions & 82 deletions mistralrs-core/src/layers_masker.rs
@@ -136,38 +136,39 @@ impl CausalMasker {
         input_ids: &Tensor,
         cache: &dyn PastKvLenCache,
         dtype: DType,
+        n_attn_heads: usize,
     ) -> Result<Option<Tensor>> {
         let past_kv_len = cache.get_past_kv_len()?;
         let (_b_sz, tgt_len) = input_ids.dims2()?;
         if tgt_len == 1 {
             return Ok(None);
         }
 
-        let causal_mask = {
-            let mask = self
-                .make_mask(tgt_len, past_kv_len, input_ids.device())?
-                .to_dtype(DType::U8)?;
-            Some(mask)
-        };
+        let mut causal_mask = self
+            .make_mask(tgt_len, past_kv_len, input_ids.device())?
+            .to_dtype(DType::U8)?;
 
         let zero = Tensor::new(0.0f32, input_ids.device())?;
-        let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
-            let mask = mask.broadcast_as((mask.dims()[0], mask.dims()[1]))?;
+        causal_mask = {
+            let mut mask =
+                causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
             // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
-            let mask = masked_fill(
+            mask = masked_fill(
                 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                 &mask,
                 f32::NEG_INFINITY,
             )?;
-
-            Ok(mask)
-        });
-        let mask: Option<Tensor> = if let Some(mask) = causal_mask {
-            Some(mask?)
-        } else {
-            None
+            mask
         };
-        Ok(mask)
+
+        // IMPORTANT: this must match the logic in attention.rs.
+        if causal_mask.device().is_cuda()
+            && !get_use_matmul_via_f16()
+            && CUBLASLT_HANDLE.lock().unwrap().is_some()
+        {
+            causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
+        }
+        Ok(Some(causal_mask))
     }
 
     pub fn make_sliding_window_causal_mask_matrix(
@@ -176,9 +177,10 @@ impl CausalMasker {
         cache: &dyn PastKvLenCache,
         sliding_window: Option<usize>,
         dtype: DType,
+        n_attn_heads: usize,
     ) -> Result<Option<Tensor>> {
         if sliding_window.is_none() {
-            return self.make_causal_mask_matrix(input_ids, cache, dtype);
+            return self.make_causal_mask_matrix(input_ids, cache, dtype, n_attn_heads);
         }
         let sliding_window = sliding_window.unwrap();
         let past_kv_len = cache.get_past_kv_len()?;
@@ -187,34 +189,35 @@ impl CausalMasker {
             return Ok(None);
         }
 
-        let causal_mask = {
+        let mut causal_mask = {
             let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
             let diagonal = past_kv_len as isize - sliding_window as isize - 1;
             let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
-            let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
-                .to_dtype(DType::U8)?;
 
-            Some(mask)
+            masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?
+                .to_dtype(DType::U8)?
         };
 
         let zero = Tensor::new(0.0f32, input_ids.device())?;
-        let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
-            let mask = mask.broadcast_as((mask.dims()[0], mask.dims()[1]))?;
+        causal_mask = {
+            let mask = causal_mask.broadcast_as((causal_mask.dims()[0], causal_mask.dims()[1]))?;
             // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
-            let mask = masked_fill(
+
+            masked_fill(
                 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                 &mask,
                 f32::NEG_INFINITY,
-            )?;
-
-            Ok(mask)
-        });
-        let mask: Option<Tensor> = if let Some(mask) = causal_mask {
-            Some(mask?)
-        } else {
-            None
+            )?
         };
-        Ok(mask)
+
+        // IMPORTANT: this must match the logic in attention.rs.
+        if causal_mask.device().is_cuda()
+            && !get_use_matmul_via_f16()
+            && CUBLASLT_HANDLE.lock().unwrap().is_some()
+        {
+            causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
+        }
+        Ok(Some(causal_mask))
     }
 
     #[deprecated(
@@ -234,41 +237,38 @@ impl CausalMasker {
             return Ok(None);
         }
 
-        let causal_mask = {
-            let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
-            let mask = mask
+        let mut causal_mask = {
+            let mut mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
+            mask = mask
                 .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
                 .to_dtype(DType::U8)?;
-            Some(mask)
+            mask
         };
 
         let zero = Tensor::new(0.0f32, input_ids.device())?;
-        let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
-            let mask =
-                mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
+        causal_mask = {
+            let mut mask = causal_mask.broadcast_as((
+                causal_mask.dims()[0],
+                n_attn_heads,
+                causal_mask.dims()[2],
+                causal_mask.dims()[3],
+            ))?;
             // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
-            let mask = masked_fill(
+            mask = masked_fill(
                 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                 &mask,
                 f32::NEG_INFINITY,
             )?;
 
-            Ok(mask)
-        });
-        let mask: Option<Tensor> = if let Some(mask) = causal_mask {
-            let mut mask = mask?;
-            // IMPORTANT: this must match the logic in attention.rs
-            if mask.device().is_cuda()
-                && CUBLASLT_HANDLE.lock().unwrap().is_some()
-                && !get_use_matmul_via_f16()
-            {
-                mask = mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
-            }
-            Some(mask)
-        } else {
-            None
+            mask
         };
-        Ok(mask)
+
+        // IMPORTANT: this must match the logic in attention.rs. Assume the cublaslt handle will be initialized
+        if causal_mask.device().is_cuda() && !get_use_matmul_via_f16() {
+            causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
+        }
+
+        Ok(Some(causal_mask))
     }
 
     #[deprecated(
@@ -294,45 +294,41 @@ impl CausalMasker {
             return Ok(None);
         }
 
-        let causal_mask = {
-            let mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
+        let mut causal_mask = {
+            let mut mask = self.make_mask(tgt_len, past_kv_len, input_ids.device())?;
             let diagonal = past_kv_len as isize - sliding_window as isize - 1;
             let context_mask = apply_tril(&mask.ones_like()?, diagonal)?;
-            let mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
-            let mask = mask
+            mask = masked_fill(&mask.to_dtype(DType::F32)?, &context_mask, f32::MIN)?;
+            mask = mask
                 .expand((b_sz, 1, tgt_len, tgt_len + past_kv_len))?
                 .to_dtype(DType::U8)?;
 
-            Some(mask)
+            mask
         };
 
         let zero = Tensor::new(0.0f32, input_ids.device())?;
-        let causal_mask: Option<Result<Tensor>> = causal_mask.map(|mask| {
-            let mask =
-                mask.broadcast_as((mask.dims()[0], n_attn_heads, mask.dims()[2], mask.dims()[3]))?;
+        causal_mask = {
+            let mut mask = causal_mask.broadcast_as((
+                causal_mask.dims()[0],
+                n_attn_heads,
+                causal_mask.dims()[2],
+                causal_mask.dims()[3],
+            ))?;
            // Mask: 1 means use from x (add 0.0), 0 means mask out (add -inf)
-            let mask = masked_fill(
+            mask = masked_fill(
                 &zero.to_dtype(dtype)?.broadcast_as(mask.shape())?,
                 &mask,
                 f32::NEG_INFINITY,
             )?;
-
-            Ok(mask)
-        });
-        let mask: Option<Tensor> = if let Some(mask) = causal_mask {
-            let mut mask = mask?;
-            // IMPORTANT: this must match the logic in attention.rs
-            if mask.device().is_cuda()
-                && CUBLASLT_HANDLE.lock().unwrap().is_some()
-                && !get_use_matmul_via_f16()
-            {
-                mask = mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
-            }
-            Some(mask)
-        } else {
-            None
+            mask
         };
-        Ok(mask)
+
+        // IMPORTANT: this must match the logic in attention.rs. Assume the cublaslt handle will be initialized
+        if causal_mask.device().is_cuda() && !get_use_matmul_via_f16() {
+            causal_mask = causal_mask.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))?;
        }
+
+        Ok(Some(causal_mask))
     }
 
     pub fn apply_mask_one_and_zero(
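To make the new mask shape concrete, here is a minimal self-contained sketch. Assumptions not shown in the diff: plain candle_core on CPU, the usual causal pattern for make_mask (query i may attend to keys 0..=past_kv_len + i), the additive f32 bias built directly rather than via the U8 mask and masked_fill used above, and none of the CUDA / cuBLASLt-handle / get_use_matmul_via_f16 gating the real code applies before repeating:

    use candle_core::{Device, Result, Tensor};

    // Builds a causal bias of shape (tgt_len, tgt_len + past_kv_len), then pre-repeats it
    // across heads, mirroring the unsqueeze(0)?.repeat((n_attn_heads, 1, 1)) call above.
    fn prerepeated_causal_bias(
        tgt_len: usize,
        past_kv_len: usize,
        n_attn_heads: usize,
        device: &Device,
    ) -> Result<Tensor> {
        let kv_len = tgt_len + past_kv_len;
        // 0.0 where attention is allowed, -inf where the position must be masked out.
        let vals: Vec<f32> = (0..tgt_len)
            .flat_map(|i| {
                (0..kv_len).map(move |j| {
                    if j > past_kv_len + i {
                        f32::NEG_INFINITY
                    } else {
                        0.0
                    }
                })
            })
            .collect();
        let bias = Tensor::from_slice(&vals, (tgt_len, kv_len), device)?;
        // New in this commit: expand across heads once here, so SDPA can hand the
        // (n_attn_heads, tgt_len, kv_len) bias straight to the fused matmul.
        bias.unsqueeze(0)?.repeat((n_attn_heads, 1, 1))
    }

    fn main() -> Result<()> {
        let bias = prerepeated_causal_bias(4, 2, 8, &Device::Cpu)?;
        assert_eq!(bias.dims(), &[8, 4, 6]);
        Ok(())
    }

The model files below only change by threading the head count into these masker calls, so the masker can decide whether to do this expansion.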
1 change: 1 addition & 0 deletions mistralrs-core/src/models/gemma.rs
@@ -565,6 +565,7 @@ impl Model {
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
9 changes: 7 additions & 2 deletions mistralrs-core/src/models/gemma2.rs
@@ -606,13 +606,18 @@ impl Model {
         let xs = self.embed_tokens.forward(input_ids)?;
         let mut xs = (xs * (self.hidden_size as f64).sqrt())?;
         let cache = &mut self.cache.normal().0;
-        let attention_mask =
-            CausalMasker.make_causal_mask_matrix(input_ids, &*cache, xs.dtype())?;
+        let attention_mask = CausalMasker.make_causal_mask_matrix(
+            input_ids,
+            &*cache,
+            xs.dtype(),
+            self.cfg.num_attn_heads,
+        )?;
         let sliding_attention_mask = CausalMasker.make_sliding_window_causal_mask_matrix(
             input_ids,
             &*cache,
             Some(self.sliding_window),
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/llama.rs
@@ -414,6 +414,7 @@ impl Llama {
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             x.dtype(),
+            self.blocks[0].attn.num_attention_heads,
         )?;
         for (block_idx, block) in self.blocks.iter().enumerate() {
             x = self.mapper.map(x, block_idx)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/mistral.rs
@@ -603,6 +603,7 @@ impl Model {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/mixtral.rs
@@ -609,6 +609,7 @@ impl Model {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/phi2.rs
@@ -548,6 +548,7 @@ impl Model {
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/phi3.rs
@@ -542,6 +542,7 @@ impl Model {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/phi3_5_moe.rs
@@ -664,6 +664,7 @@ impl Model {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_llama.rs
@@ -668,6 +668,7 @@ impl ModelWeights {
                 .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             DType::F32,
+            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_phi2.rs
@@ -358,6 +358,7 @@ impl ModelWeights {
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             DType::F32,
+            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_phi3.rs
@@ -375,6 +375,7 @@ impl ModelWeights {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             Some(self.max_seq_len),
             DType::F32,
+            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_qwen2.rs
@@ -389,6 +389,7 @@ impl ModelWeights {
                 .map(|(_, _)| &start_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             DType::F32,
+            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/quantized_starcoder2.rs
@@ -376,6 +376,7 @@ impl ModelWeights {
                 .map(|(_, _)| &seqlen_offsets as &dyn PastKvLenCache)
                 .unwrap_or(cache as &dyn PastKvLenCache),
             DType::F32,
+            self.layers[0].n_head,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             if let Some(ref mapper) = self.mapper {
1 change: 1 addition & 0 deletions mistralrs-core/src/models/qwen2.rs
@@ -549,6 +549,7 @@ impl Model {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             Some(self.sliding_window),
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
         for (i, layer) in self.layers.iter().enumerate() {
             xs = self.mapper.map(xs, i)?;
1 change: 1 addition & 0 deletions mistralrs-core/src/models/starcoder2.rs
@@ -531,6 +531,7 @@ impl Model {
                 .unwrap_or(cache as &dyn PastKvLenCache),
             self.sliding_window,
             xs.dtype(),
+            self.cfg.num_attn_heads,
         )?;
 
         for (i, layer) in self.layers.iter().enumerate() {