pass in subsampled positions
ameroyer committed Dec 23, 2024
1 parent bfa3789 commit af11cf8
Showing 1 changed file with 33 additions and 12 deletions.
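The change threads an optional subsampled_positions tensor from Model::forward through Transformer, AttentionLayer, and Attention down into RotaryEmbedding::apply_rotary_emb_qkv, which now selects only the cos/sin rows for those positions instead of always using the full precomputed tables.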
candle-transformers/src/models/pixtral/vision_model.rs (33 additions, 12 deletions)
@@ -104,6 +104,7 @@ impl Attention {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let (b, patches, _) = xs.dims3()?;
@@ -116,7 +117,8 @@ impl Attention {
         let key_states = key_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;
         let value_states = value_states.reshape(shape)?.transpose(1, 2)?.contiguous()?;

-        let (query_states, key_states) = emb.apply_rotary_emb_qkv(&query_states, &key_states)?;
+        let (query_states, key_states) =
+            emb.apply_rotary_emb_qkv(&query_states, &key_states, subsampled_positions)?;
         let attn_weights = (query_states.matmul(&key_states.t()?)? * self.scale)?;

         let attn_weights = match attention_mask {
@@ -189,12 +191,16 @@ impl AttentionLayer {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let residual = xs;
-        let xs = self
-            .attention
-            .forward(&xs.apply(&self.attention_norm)?, emb, attention_mask)?;
+        let xs = self.attention.forward(
+            &xs.apply(&self.attention_norm)?,
+            emb,
+            subsampled_positions,
+            attention_mask,
+        )?;
         let xs = (residual + xs)?;
         let residual = &xs;
         let xs = xs.apply(&self.ffn_norm)?.apply(&self.feed_forward)?;
@@ -222,11 +228,12 @@ impl Transformer {
         &self,
         xs: &Tensor,
         emb: &RotaryEmbedding,
+        subsampled_positions: Option<&Tensor>,
         attention_mask: Option<&Tensor>,
     ) -> Result<Tensor> {
         let mut xs = xs.clone();
         for layer in self.layers.iter() {
-            xs = layer.forward(&xs, emb, attention_mask)?
+            xs = layer.forward(&xs, emb, subsampled_positions, attention_mask)?
         }
         Ok(xs)
     }
@@ -270,10 +277,20 @@ impl RotaryEmbedding {
         Ok(Self { cos, sin })
     }

-    fn apply_rotary_emb_qkv(&self, q: &Tensor, k: &Tensor) -> Result<(Tensor, Tensor)> {
+    fn apply_rotary_emb_qkv(
+        &self,
+        q: &Tensor,
+        k: &Tensor,
+        subsampled_positions: Option<&Tensor>,
+    ) -> Result<(Tensor, Tensor)> {
         let (_b_sz, _h, _seq_len, _n_embd) = q.dims4()?;
-        let cos = &self.cos;
-        let sin = &self.sin;
+        let (cos, sin) = match subsampled_positions {
+            None => (&self.cos, &self.sin),
+            Some(pos) => (
+                &self.cos.index_select(pos, 0)?,
+                &self.sin.index_select(pos, 0)?,
+            ),
+        };
         let q_embed = candle_nn::rotary_emb::rope(q, cos, sin)?;
         let k_embed = candle_nn::rotary_emb::rope(k, cos, sin)?;
         Ok((q_embed, k_embed))
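The RotaryEmbedding hunk above is the core of the change: when positions are passed in, index_select keeps only the rows of the precomputed cos/sin tables that correspond to the occupied patch positions. Below is a minimal standalone sketch of that row selection, assuming candle-core as a dependency; the table sizes, values, and positions used here are illustrative, not taken from the model.

use candle_core::{Device, Result, Tensor};

fn main() -> Result<()> {
    let dev = Device::Cpu;

    // Hypothetical precomputed rotary tables: one row per possible patch position.
    let (max_positions, half_dim) = (16usize, 4usize);
    let cos = Tensor::rand(0f32, 1f32, (max_positions, half_dim), &dev)?;
    let sin = Tensor::rand(0f32, 1f32, (max_positions, half_dim), &dev)?;

    // Flattened indices of the positions actually occupied by this image's patches.
    let positions = Tensor::new(&[0u32, 1, 2, 4, 5, 6], &dev)?;

    // Same operation as the Some(pos) arm above: keep only the selected rows so the
    // tables line up with the flattened patch sequence fed to the attention layers.
    let cos_sub = cos.index_select(&positions, 0)?;
    let sin_sub = sin.index_select(&positions, 0)?;
    assert_eq!(cos_sub.dims(), &[6, half_dim]);
    assert_eq!(sin_sub.dims(), &[6, half_dim]);
    Ok(())
}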
@@ -333,13 +350,17 @@ impl Model {
 impl Module for Model {
     fn forward(&self, xs: &Tensor) -> Result<Tensor> {
         let patch_embeds = xs.apply(&self.patch_conv)?;
-        let susampled_positions = self.position_ids_in_meshgrid(
+        let subsampled_positions = Some(self.position_ids_in_meshgrid(
             patch_embeds.dim(2)?,
             patch_embeds.dim(3)?,
             patch_embeds.device(),
-        )?;
+        )?);
         let patch_embeds = patch_embeds.flatten_from(2)?.t()?.apply(&self.ln_pre)?;
-        self.transformer
-            .forward(&patch_embeds, &self.patch_positional_embedding, None)
+        self.transformer.forward(
+            &patch_embeds,
+            &self.patch_positional_embedding,
+            subsampled_positions.as_ref(),
+            None,
+        )
     }
 }
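The positions built in Model::forward index an h x w grid of patches. position_ids_in_meshgrid itself is not part of this diff, so the helper below is only a hypothetical illustration of how such flattened grid indices could be produced, assuming a row-major layout against an illustrative max_grid_width; the real helper in vision_model.rs may differ.

use candle_core::{Device, Result, Tensor};

// Hypothetical helper (not the one from vision_model.rs): flattened position ids for
// an h x w patch grid laid out row-major in a table max_grid_width columns wide.
fn grid_position_ids(h: usize, w: usize, max_grid_width: usize, dev: &Device) -> Result<Tensor> {
    // positions[i][j] = i * max_grid_width + j
    let ids: Vec<u32> = (0..h)
        .flat_map(|i| (0..w).map(move |j| (i * max_grid_width + j) as u32))
        .collect();
    Tensor::from_vec(ids, h * w, dev)
}

fn main() -> Result<()> {
    let dev = Device::Cpu;
    // A 2 x 3 patch grid inside a maximum width of 8 patches: [0, 1, 2, 8, 9, 10].
    let ids = grid_position_ids(2, 3, 8, &dev)?;
    println!("{ids}");
    Ok(())
}

Feeding such indices to index_select, as in the previous sketch, yields cos/sin tables whose length equals h * w, matching the flattened patch embeddings that Model::forward passes to the transformer.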
