diff --git a/backend/src/InterpolateArchs/DetectInterpolateArch.py b/backend/src/InterpolateArchs/DetectInterpolateArch.py index 242fbaa8..a1290ee7 100644 --- a/backend/src/InterpolateArchs/DetectInterpolateArch.py +++ b/backend/src/InterpolateArchs/DetectInterpolateArch.py @@ -39,6 +39,7 @@ def excluded_keys() -> tuple: "module.caltime.8.weight", "module.caltime.8.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -72,6 +73,7 @@ def excluded_keys() -> tuple: "module.caltime.8.weight", "module.caltime.8.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -102,6 +104,7 @@ def excluded_keys() -> tuple: "module.caltime.8.weight", "module.caltime.8.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -122,6 +125,7 @@ def excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -142,6 +146,7 @@ def excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -162,6 +167,7 @@ def excluded_keys() -> tuple: "module.encode.1.weight", "module.encode.1.bias", "module.block4.lastconv.0.bias", + "transformer.layers.4.self_attn.merge.weight", ] @@ -181,8 +187,28 @@ def excluded_keys() -> tuple: "module.encode.0.bias", "module.encode.1.weight", "module.encode.1.bias", + "transformer.layers.4.self_attn.merge.weight", ] +class GMFSS: + def __init__(): + pass + + def __name__(): + return "rife413" + + def unique_shapes() -> dict: + return {"transformer.layers.4.self_attn.merge.weight": "torch.Size([128, 128])"} + + def excluded_keys() -> tuple: + return [ + "module.encode.0.weight", + "module.encode.0.bias", + "module.encode.1.weight", + "module.encode.1.bias", + ] + + archs = [RIFE46, RIFE47, RIFE413, RIFE420, RIFE421, RIFE422lite, RIFE425]