Skip to content

Commit

Permalink
Deal with state translation
Browse files Browse the repository at this point in the history
  • Loading branch information
xuzhao9 committed Nov 18, 2024
1 parent 1f4fa82 commit 3b7dda0
Showing 1 changed file with 17 additions and 3 deletions.
20 changes: 17 additions & 3 deletions tritonbench/utils/triton_op.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,6 +521,19 @@ def __call__(cls, *args, **kwargs):
obj.__post__init__()
return obj

def _translate_mode(tb_args):
def _has_and_true(attr):
if hasattr(tb_args, attr) and attr:
return True
return False
if _has_and_true("fwd"):
tb_args.mode = "fwd"
if _has_and_true("bwd"):
tb_args.mode = "bwd"
if _has_and_true("fwd_bwd"):
tb_args.mode = "fwd_bwd"
if _has_and_true("fwd_no_grad"):
tb_args.mode = "fwd_no_grad"

class BenchmarkOperator(metaclass=PostInitProcessor):
mode: Mode = Mode.FWD
Expand Down Expand Up @@ -556,11 +569,12 @@ def __init__(
self.tb_args.cudagraph if self.tb_args.cudagraph else self.use_cuda_graphs
)
# we accept both "fwd" and "eval"
if self.tb_args.mode == "fwd" or self.tb_args.fwd:
_translate_mode(self.tb_args)
if self.tb_args.mode == "fwd":
self.mode = Mode.FWD
elif self.tb_args.mode == "fwd_bwd" or self.tb_args.fwd_bwd:
elif self.tb_args.mode == "fwd_bwd":
self.mode = Mode.FWD_BWD
elif self.tb_args.mode == "fwd_no_grad" or self.tb_args.fwd_no_grad:
elif self.tb_args.mode == "fwd_no_grad":
self.mode = Mode.FWD_NO_GRAD
else:
assert (
Expand Down

0 comments on commit 3b7dda0

Please sign in to comment.