Skip to content

Commit

Permalink
Make bislerp work on GPU.
Browse files Browse the repository at this point in the history
  • Loading branch information
comfyanonymous committed Nov 14, 2023
1 parent 420beee commit c962884
Showing 1 changed file with 5 additions and 5 deletions.
10 changes: 5 additions & 5 deletions comfy/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -307,13 +307,13 @@ def slerp(b1, b2, r):
res[dot < 1e-5 - 1] = (b1 * (1.0-r) + b2 * r)[dot < 1e-5 - 1]
return res

def generate_bilinear_data(length_old, length_new):
coords_1 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32)
def generate_bilinear_data(length_old, length_new, device):
coords_1 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1))
coords_1 = torch.nn.functional.interpolate(coords_1, size=(1, length_new), mode="bilinear")
ratios = coords_1 - coords_1.floor()
coords_1 = coords_1.to(torch.int64)

coords_2 = torch.arange(length_old).reshape((1,1,1,-1)).to(torch.float32) + 1
coords_2 = torch.arange(length_old, dtype=torch.float32, device=device).reshape((1,1,1,-1)) + 1
coords_2[:,:,:,-1] -= 1
coords_2 = torch.nn.functional.interpolate(coords_2, size=(1, length_new), mode="bilinear")
coords_2 = coords_2.to(torch.int64)
Expand All @@ -323,7 +323,7 @@ def generate_bilinear_data(length_old, length_new):
h_new, w_new = (height, width)

#linear w
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new)
ratios, coords_1, coords_2 = generate_bilinear_data(w, w_new, samples.device)
coords_1 = coords_1.expand((n, c, h, -1))
coords_2 = coords_2.expand((n, c, h, -1))
ratios = ratios.expand((n, 1, h, -1))
Expand All @@ -336,7 +336,7 @@ def generate_bilinear_data(length_old, length_new):
result = result.reshape(n, h, w_new, c).movedim(-1, 1)

#linear h
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new)
ratios, coords_1, coords_2 = generate_bilinear_data(h, h_new, samples.device)
coords_1 = coords_1.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
coords_2 = coords_2.reshape((1,1,-1,1)).expand((n, c, -1, w_new))
ratios = ratios.reshape((1,1,-1,1)).expand((n, 1, -1, w_new))
Expand Down

0 comments on commit c962884

Please sign in to comment.