diff --git a/optimum/exporters/utils.py b/optimum/exporters/utils.py index 60de169de5..51a90e3a2b 100644 --- a/optimum/exporters/utils.py +++ b/optimum/exporters/utils.py @@ -675,3 +675,15 @@ def _get_submodels_and_export_configs( export_config = next(iter(models_and_export_configs.values()))[1] return export_config, models_and_export_configs + + +class DisableCompileContextManager: + def __init__(self): + self._original_compile = torch.compile + + def __enter__(self): + # Turn torch.compile into a no-op + torch.compile = lambda *args, **kwargs: lambda x: x + + def __exit__(self, exc_type, exc_val, exc_tb): + torch.compile = self._original_compile