From 197c10463b901bf4a0112d0d0161ed30b94a45c5 Mon Sep 17 00:00:00 2001 From: Daniel Suess Date: Fri, 28 Jun 2024 03:57:06 +0000 Subject: [PATCH] Fix jit.script breaking with features_fx --- tests/test_models.py | 32 ++++++++++++++++++++++++++++++++ timm/models/_features_fx.py | 4 ++++ 2 files changed, 36 insertions(+) diff --git a/tests/test_models.py b/tests/test_models.py index 9f7a91546c..030fb255d1 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -631,3 +631,35 @@ def test_model_forward_fx_torchscript(model_name, batch_size): assert outputs.shape[0] == batch_size assert not torch.isnan(outputs).any(), 'Output included NaNs' + + @pytest.mark.timeout(120) + @pytest.mark.parametrize('model_name', ["regnetx_002"]) + @pytest.mark.parametrize('batch_size', [1]) + def test_model_forward_torchscript_with_features_fx(model_name, batch_size): + """Create a model with feature extraction based on fx, script it, and run + a single forward pass""" + if not has_fx_feature_extraction: + pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.") + + allowed_models = list_models( + exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS, + name_matches_cfg=True + ) + assert model_name in allowed_models, f"{model_name=} not supported for this test" + + input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE) + assert max(input_size) <= MAX_JIT_SIZE, "Fixed input size model > limit. Pick a different model to run this test" + + with set_scriptable(True): + model = create_model(model_name, pretrained=False, features_only=True, feature_cfg={"feature_cls": "fx"}) + model.eval() + + model = torch.jit.script(model) + with torch.no_grad(): + outputs = model(torch.randn((batch_size, *input_size))) + + assert isinstance(outputs, list) + + for tensor in outputs: + assert tensor.shape[0] == batch_size + assert not torch.isnan(tensor).any(), 'Output included NaNs' \ No newline at end of file diff --git a/timm/models/_features_fx.py b/timm/models/_features_fx.py index 1ea4a4f4a1..6679b38b46 100644 --- a/timm/models/_features_fx.py +++ b/timm/models/_features_fx.py @@ -116,6 +116,8 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str class FeatureGraphNet(nn.Module): """ A FX Graph based feature extractor that works with the model feature_info metadata """ + return_dict: torch.jit.Final[bool] + def __init__( self, model: nn.Module, @@ -155,6 +157,8 @@ class GraphExtractNet(nn.Module): squeeze_out: if only one output, and output in list format, flatten to single tensor return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg """ + return_dict: torch.jit.Final[bool] + def __init__( self, model: nn.Module,