From 8018de0b43da8d66617f3ef10d3f2a41c1d78836 Mon Sep 17 00:00:00 2001 From: Katherine Crowson Date: Tue, 7 Jan 2025 20:56:57 +0000 Subject: [PATCH] Enable vmap, jvp, double backward for apply_rotary_emb_() --- k_diffusion/models/image_transformer_v2.py | 21 +++++++++++++++------ 1 file changed, 15 insertions(+), 6 deletions(-) diff --git a/k_diffusion/models/image_transformer_v2.py b/k_diffusion/models/image_transformer_v2.py index 303c91e..4f3410d 100644 --- a/k_diffusion/models/image_transformer_v2.py +++ b/k_diffusion/models/image_transformer_v2.py @@ -200,23 +200,32 @@ def _apply_rotary_emb_inplace(x, theta, conj): class ApplyRotaryEmbeddingInplace(torch.autograd.Function): - @staticmethod - def forward(x, theta, conj): - _apply_rotary_emb_inplace(x, theta, conj=conj) - return x + generate_vmap_rule = True @staticmethod def setup_context(ctx, inputs, output): - _, theta, conj = inputs + x, theta, conj = inputs + ctx.mark_dirty(x) ctx.save_for_backward(theta) + ctx.save_for_forward(theta) ctx.conj = conj + @staticmethod + def forward(x, theta, conj): + _apply_rotary_emb_inplace(x, theta, conj) + return x + @staticmethod def backward(ctx, grad_output): theta, = ctx.saved_tensors - _apply_rotary_emb_inplace(grad_output, theta, conj=not ctx.conj) + grad_output = ApplyRotaryEmbeddingInplace.apply(grad_output.clone(), theta, not ctx.conj) return grad_output, None, None + @staticmethod + def jvp(ctx, grad_input, _, __): + theta, = ctx.saved_tensors + return ApplyRotaryEmbeddingInplace.apply(grad_input, theta, ctx.conj) + def apply_rotary_emb_(x, theta): return ApplyRotaryEmbeddingInplace.apply(x, theta, False)