Skip to content

Commit

Permalink
Flux torch.compile fix (comfyanonymous#5082)
Browse files Browse the repository at this point in the history
  • Loading branch information
city96 authored Sep 28, 2024
1 parent 83b01f9 commit 8733191
Showing 1 changed file with 2 additions and 2 deletions.
4 changes: 2 additions & 2 deletions comfy/ldm/flux/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,8 +151,8 @@ def forward(self, x, timestep, context, y, guidance, control=None, **kwargs):
h_len = ((h + (patch_size // 2)) // patch_size)
w_len = ((w + (patch_size // 2)) // patch_size)
img_ids = torch.zeros((h_len, w_len, 3), device=x.device, dtype=x.dtype)
img_ids[..., 1] = img_ids[..., 1] + torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype)[:, None]
img_ids[..., 2] = img_ids[..., 2] + torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype)[None, :]
img_ids[:, :, 1] = torch.linspace(0, h_len - 1, steps=h_len, device=x.device, dtype=x.dtype).unsqueeze(1)
img_ids[:, :, 2] = torch.linspace(0, w_len - 1, steps=w_len, device=x.device, dtype=x.dtype).unsqueeze(0)
img_ids = repeat(img_ids, "h w c -> b (h w) c", b=bs)

txt_ids = torch.zeros((bs, context.shape[1], 3), device=x.device, dtype=x.dtype)
Expand Down

0 comments on commit 8733191

Please sign in to comment.