Skip to content

Commit

Permalink
Fix jit.script breaking with features_fx
Browse files Browse the repository at this point in the history
  • Loading branch information
dsuess committed Jun 28, 2024
1 parent d4ef0b4 commit 197c104
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 0 deletions.
32 changes: 32 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -631,3 +631,35 @@ def test_model_forward_fx_torchscript(model_name, batch_size):

assert outputs.shape[0] == batch_size
assert not torch.isnan(outputs).any(), 'Output included NaNs'

@pytest.mark.timeout(120)
@pytest.mark.parametrize('model_name', ["regnetx_002"])
@pytest.mark.parametrize('batch_size', [1])
def test_model_forward_torchscript_with_features_fx(model_name, batch_size):
"""Create a model with feature extraction based on fx, script it, and run
a single forward pass"""
if not has_fx_feature_extraction:
pytest.skip("Can't test FX. Torch >= 1.10 and Torchvision >= 0.11 are required.")

allowed_models = list_models(
exclude_filters=EXCLUDE_FILTERS + EXCLUDE_JIT_FILTERS + EXCLUDE_FX_JIT_FILTERS,
name_matches_cfg=True
)
assert model_name in allowed_models, f"{model_name=} not supported for this test"

input_size = _get_input_size(model_name=model_name, target=TARGET_JIT_SIZE)
assert max(input_size) <= MAX_JIT_SIZE, "Fixed input size model > limit. Pick a different model to run this test"

with set_scriptable(True):
model = create_model(model_name, pretrained=False, features_only=True, feature_cfg={"feature_cls": "fx"})
model.eval()

model = torch.jit.script(model)
with torch.no_grad():
outputs = model(torch.randn((batch_size, *input_size)))

assert isinstance(outputs, list)

for tensor in outputs:
assert tensor.shape[0] == batch_size
assert not torch.isnan(tensor).any(), 'Output included NaNs'
4 changes: 4 additions & 0 deletions timm/models/_features_fx.py
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,8 @@ def create_feature_extractor(model: nn.Module, return_nodes: Union[Dict[str, str
class FeatureGraphNet(nn.Module):
""" A FX Graph based feature extractor that works with the model feature_info metadata
"""
return_dict: torch.jit.Final[bool]

def __init__(
self,
model: nn.Module,
Expand Down Expand Up @@ -155,6 +157,8 @@ class GraphExtractNet(nn.Module):
squeeze_out: if only one output, and output in list format, flatten to single tensor
return_dict: return as dictionary from extractor with node names as keys, ignores squeeze_out arg
"""
return_dict: torch.jit.Final[bool]

def __init__(
self,
model: nn.Module,
Expand Down

0 comments on commit 197c104

Please sign in to comment.