Remove sdpa context mgrs
rwightman committed Sep 26, 2023
1 parent 2734bb7 commit 379780b
Showing 1 changed file with 16 additions and 18 deletions.
34 changes: 16 additions & 18 deletions timm/models/vision_transformer_packed.py
@@ -124,7 +124,7 @@ def pack_images(
 ):
     max_seq_len = max_grid_size[0] * max_grid_size[1]
 
-    # patchify if needed, generate position indices, apply patch drop, record seq lengths
+    # patchify, generate position indices, apply patch drop, record seq lengths
     img_tokens = []
     img_pos_indices = []
     img_seq_lens = []
@@ -144,6 +144,7 @@ def pack_images(
                 indexing='ij'),
             dim=-1,
         )
+        # FIXME patch drop here
         img_tokens.append(patches.flatten(0, 1))
         img_pos_indices.append(pos_indices.flatten(0, 1))
         img_seq_lens.append(seq_len)
@@ -221,12 +222,11 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
             attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))
 
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
@@ -374,12 +374,11 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x_attn = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x_attn = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
@@ -507,11 +506,10 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         q = self.q_norm(q)
         k = self.k_norm(k)
         if False:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
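For context, here is a minimal standalone sketch of the pattern this commit changes, assuming PyTorch 2.x; the tensor shapes and variable names are illustrative and not taken from the repository. Before this change, the fused attention path pinned the SDPA backend choice by wrapping the call in the torch.backends.cuda.sdp_kernel(enable_mem_efficient=False) context manager; after it, F.scaled_dot_product_attention is called directly and PyTorch is left to select a backend.

import torch
import torch.nn.functional as F

# Illustrative shapes only: (batch, num_heads, seq_len, head_dim)
q = torch.randn(2, 8, 16, 64)
k = torch.randn(2, 8, 16, 64)
v = torch.randn(2, 8, 16, 64)
# Boolean mask broadcast across heads; True = position may be attended to
attn_mask = torch.ones(2, 8, 16, 16, dtype=torch.bool)

# Before this commit the call was wrapped to rule out the mem-efficient backend:
#     with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
#         x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask)
#
# After this commit SDPA is called directly and PyTorch selects a backend
# (flash, mem-efficient, or the math fallback) based on the inputs and device:
x = F.scaled_dot_product_attention(q, k, v, attn_mask=attn_mask, dropout_p=0.0)
print(x.shape)  # torch.Size([2, 8, 16, 64])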
