quick fix to continue with issue 71 (#73)
Summary:
Pull Request resolved: #73

Follow-up diff.

Reviewed By: FindHao

Differential Revision: D66372865

fbshipit-source-id: 94716dd4701949b54bd55a60ed87a99cadbf95e3
adamomainz authored and facebook-github-bot committed Nov 22, 2024
1 parent e8f5ba4 commit 648466b
Showing 1 changed file with 3 additions and 4 deletions: tritonbench/operator_loader/__init__.py
@@ -4,7 +4,6 @@
 from typing import Any, Generator, List, Optional
 
 import torch
-from torch._dynamo.backends.cudagraphs import cudagraphs_inner
 from torch._inductor.utils import gen_gm_and_inputs
 from torch._ops import OpOverload
 from torch.utils._pytree import tree_map_only
@@ -85,6 +84,9 @@ def __init__(
         ), f"AtenOpBenchmark only supports fp16 and fp32, but got {self.dtype}"
 
     def get_input_iter(self) -> Generator:
+        from torch._dynamo.backends.cudagraphs import cudagraphs_inner
+        from torch._inductor.compile_fx import compile_fx
+
         inps_gens = [self.huggingface_loader, self.torchbench_loader, self.timm_loader]
         for inp_gen in inps_gens:
             for inp in inp_gen.get_inputs_for_operator(
@@ -101,9 +103,6 @@ def get_input_iter(self) -> Generator:
                     "aten::convolution_backward",
                 )
                 if self.device == "cuda":
-                    from torch._inductor.compile_fx import compile_fx
-
-
                     cudagraph_eager = cudagraphs_inner(
                         gm, gm_args, copy_outputs=False, copy_inputs=False
                     )
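The fix moves two torch-internal imports (cudagraphs_inner and compile_fx) from module scope into get_input_iter, deferring them until the method is first called, presumably so that importing tritonbench.operator_loader stays cheap and does not fail when those internals are unavailable. Below is a minimal, hypothetical sketch of that deferred-import pattern; OperatorLoader and the payloads are illustrative stand-ins, and json substitutes for the torch internals so the sketch runs anywhere:

# Minimal, hypothetical sketch of the deferred-import pattern used in the
# diff above. "json" stands in for the torch-internal modules so this runs
# without torch installed.


class OperatorLoader:
    """Illustrative stand-in for the real operator loader."""

    def get_input_iter(self):
        # Deferred import: resolved on the first call to this method,
        # not when the module containing OperatorLoader is imported.
        import json

        for payload in ({"op": "aten::add"}, {"op": "aten::mm"}):
            yield json.dumps(payload)


if __name__ == "__main__":
    loader = OperatorLoader()
    for line in loader.get_input_iter():
        print(line)

The trade-off: the importing module loads without pulling in the deferred dependency, at the cost of a small overhead on the first call (subsequent calls resolve the import from sys.modules).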
