From 379780bb6ca3304d63bf8ca789d5bbce5949d0b5 Mon Sep 17 00:00:00 2001
From: Ross Wightman
Date: Mon, 25 Sep 2023 23:30:56 -0700
Subject: [PATCH] Remove sdpa context mgrs

---
 timm/models/vision_transformer_packed.py | 34 ++++++++++++++++------------------
 1 file changed, 16 insertions(+), 18 deletions(-)

diff --git a/timm/models/vision_transformer_packed.py b/timm/models/vision_transformer_packed.py
index efda484f90..bc95093453 100644
--- a/timm/models/vision_transformer_packed.py
+++ b/timm/models/vision_transformer_packed.py
@@ -124,7 +124,7 @@ def pack_images(
 ):
     max_seq_len = max_grid_size[0] * max_grid_size[1]
 
-    # patchify if needed, generate position indices, apply patch drop, record seq lengths
+    # patchify, generate position indices, apply patch drop, record seq lengths
     img_tokens = []
     img_pos_indices = []
     img_seq_lens = []
@@ -144,6 +144,7 @@
                 indexing='ij'),
             dim=-1,
         )
+        # FIXME patch drop here
         img_tokens.append(patches.flatten(0, 1))
         img_pos_indices.append(pos_indices.flatten(0, 1))
         img_seq_lens.append(seq_len)
@@ -221,12 +222,11 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
             attn_mask = attn_mask.expand((-1, self.num_heads, -1, -1))
 
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
@@ -374,12 +374,11 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         k = self.k_norm(k.view(B, N, self.num_heads, self.head_dim)).transpose(1, 2)
         v = v.view(B, N, self.num_heads, self.head_dim).transpose(1, 2)
         if self.fused_attn:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x_attn = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                    dropout_p=self.attn_drop.p,
-                )
+            x_attn = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+                dropout_p=self.attn_drop.p,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
@@ -507,11 +506,10 @@ def forward(self, x, attn_mask: Optional[torch.Tensor] = None):
         q = self.q_norm(q)
         k = self.k_norm(k)
         if False:
-            with torch.backends.cuda.sdp_kernel(enable_mem_efficient=False):
-                x = F.scaled_dot_product_attention(
-                    q, k, v,
-                    attn_mask=attn_mask,
-                )
+            x = F.scaled_dot_product_attention(
+                q, k, v,
+                attn_mask=attn_mask,
+            )
         else:
             q = q * self.scale
             attn = q @ k.transpose(-2, -1)
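
A minimal standalone sketch of the fused-attention call pattern the diff settles on: F.scaled_dot_product_attention is called directly and backend selection is left to PyTorch, instead of being constrained by the torch.backends.cuda.sdp_kernel(enable_mem_efficient=False) context manager removed above. The tensor shapes, the zero dropout probability, and the all-True mask below are illustrative assumptions, not values taken from timm.

import torch
import torch.nn.functional as F

# Illustrative sizes (assumption); the patched modules use (B, num_heads, N, head_dim) layouts.
B, num_heads, N, head_dim = 2, 4, 16, 32
q = torch.randn(B, num_heads, N, head_dim)
k = torch.randn(B, num_heads, N, head_dim)
v = torch.randn(B, num_heads, N, head_dim)

# Boolean mask broadcast to (B, num_heads, N, N), mirroring the patch's
# attn_mask.expand((-1, self.num_heads, -1, -1)); True means the position may attend.
attn_mask = torch.ones(B, 1, N, N, dtype=torch.bool).expand(-1, num_heads, -1, -1)

# Post-patch call: no sdp_kernel context manager wrapping the call.
x = F.scaled_dot_product_attention(
    q, k, v,
    attn_mask=attn_mask,
    dropout_p=0.0,  # stands in for self.attn_drop.p (assumed 0.0 here)
)
print(x.shape)  # torch.Size([2, 4, 16, 32])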