diff --git a/tritonbench/operator_loader/__init__.py b/tritonbench/operator_loader/__init__.py
index 1f73fabc..e2332422 100644
--- a/tritonbench/operator_loader/__init__.py
+++ b/tritonbench/operator_loader/__init__.py
@@ -4,7 +4,6 @@
 from typing import Any, Generator, List, Optional
 
 import torch
-from torch._dynamo.backends.cudagraphs import cudagraphs_inner
 from torch._inductor.utils import gen_gm_and_inputs
 from torch._ops import OpOverload
 from torch.utils._pytree import tree_map_only
@@ -85,6 +84,9 @@ def __init__(
         ), f"AtenOpBenchmark only supports fp16 and fp32, but got {self.dtype}"
 
     def get_input_iter(self) -> Generator:
+        from torch._dynamo.backends.cudagraphs import cudagraphs_inner
+        from torch._inductor.compile_fx import compile_fx
+
         inps_gens = [self.huggingface_loader, self.torchbench_loader, self.timm_loader]
         for inp_gen in inps_gens:
             for inp in inp_gen.get_inputs_for_operator(
@@ -101,9 +103,6 @@ def get_input_iter(self) -> Generator:
                     "aten::convolution_backward",
                 )
                 if self.device == "cuda":
-                    from torch._inductor.compile_fx import compile_fx
-
-
                     cudagraph_eager = cudagraphs_inner(
                         gm, gm_args, copy_outputs=False, copy_inputs=False
                     )