From 1bdd54febf414c0cc166267ba14bff353de3cea8 Mon Sep 17 00:00:00 2001 From: TNTwise Date: Mon, 7 Oct 2024 18:07:49 -0500 Subject: [PATCH] on my way to gmfss support maybe --- .../InterpolateArchs/DetectInterpolateArch.py | 26 +++++++++++++++++++ 1 file changed, 26 insertions(+) diff --git a/backend/src/InterpolateArchs/DetectInterpolateArch.py b/backend/src/InterpolateArchs/DetectInterpolateArch.py index 242fbaa8..a1290ee7 100644 --- a/backend/src/InterpolateArchs/DetectInterpolateArch.py +++ b/backend/src/InterpolateArchs/DetectInterpolateArch.py @@ -39,6 +39,7 @@ def excluded_keys() -> tuple: "module.caltime.8.weight", "module.caltime.8.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -72,6 +73,7 @@ def excluded_keys() -> tuple: "module.caltime.8.weight", "module.caltime.8.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -102,6 +104,7 @@ def excluded_keys() -> tuple: "module.caltime.8.weight", "module.caltime.8.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -122,6 +125,7 @@ def excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -142,6 +146,7 @@ def excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -162,6 +167,7 @@ def excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -181,8 +187,28 @@ def excluded_keys() -> tuple: "module.encode.0.bias", "module.encode.1.weight", "module.encode.1.bias", + "transformer.layers.4.self_attn.merge.weight", ] +class GMFSS: + def __init__(): + pass + + def __name__(): + return "rife413" + + def unique_shapes() -> dict: + return {"transformer.layers.4.self_attn.merge.weight": "torch.Size([128, 128])"} + + def excluded_keys() -> tuple: + return [ + "module.encode.0.weight", + "module.encode.0.bias", + "module.encode.1.weight", + "module.encode.1.bias", + ] + + archs = [RIFE46, RIFE47, RIFE413, RIFE420, RIFE421, RIFE422lite, RIFE425]