diff --git a/python/test_infra/test_ttir_ops.py b/python/test_infra/test_ttir_ops.py index 61fac59f6..aa18e1036 100644 --- a/python/test_infra/test_ttir_ops.py +++ b/python/test_infra/test_ttir_ops.py @@ -10,47 +10,47 @@ from ttmlir.ttir_builder import Operand, TTIRBuilder -@compile_to_flatbuffer([(128, 128)], test_name="test_exp") +@compile_to_flatbuffer([(128, 128)]) def test_exp(in0: Operand, builder: TTIRBuilder): return builder.exp(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_abs", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_abs(in0: Operand, builder: TTIRBuilder): return builder.abs(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_logical_not", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_logical_not(in0: Operand, builder: TTIRBuilder): return builder.logical_not(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_neg", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_neg(in0: Operand, builder: TTIRBuilder): return builder.neg(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_relu", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_relu(in0: Operand, builder: TTIRBuilder): return builder.relu(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_sqrt", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_sqrt(in0: Operand, builder: TTIRBuilder): return builder.sqrt(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_rsqrt", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_rsqrt(in0: Operand, builder: TTIRBuilder): return builder.rsqrt(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_sigmoid", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_sigmoid(in0: Operand, builder: TTIRBuilder): return builder.sigmoid(in0) -@compile_to_flatbuffer([(128, 128)], test_name="test_reciprocal", targets=["ttnn"]) +@compile_to_flatbuffer([(128, 128)], targets=["ttnn"]) def test_reciprocal(in0: Operand, builder: TTIRBuilder): return builder.reciprocal(in0) @@ -60,7 +60,6 @@ def test_reciprocal(in0: Operand, builder: TTIRBuilder): (64, 128), (64, 128), ], - test_name="test_add", ) def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.add(in0, in1) @@ -71,7 +70,6 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_multiply", ) def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.multiply(in0, in1) @@ -82,7 +80,6 @@ def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_logical_and", targets=["ttnn"], ) def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -94,7 +91,6 @@ def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_logical_or", targets=["ttnn"], ) def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -106,7 +102,6 @@ def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_subtract", targets=["ttnn"], ) def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -118,7 +113,6 @@ def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_eq", targets=["ttnn"], ) def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -130,7 +124,6 @@ def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_ne", targets=["ttnn"], ) def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -142,7 +135,6 @@ def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_ge", targets=["ttnn"], ) def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -154,7 +146,6 @@ def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_gt", targets=["ttnn"], ) def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -166,7 +157,6 @@ def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_le", targets=["ttnn"], ) def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -178,7 +168,6 @@ def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_lt", targets=["ttnn"], ) def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -190,7 +179,6 @@ def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_div", targets=["ttnn"], ) def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -202,7 +190,6 @@ def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder): (64, 64), (64, 64), ], - test_name="test_maximum", targets=["ttnn"], ) def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder): @@ -215,7 +202,6 @@ def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder): (32, 32), (32, 32), ], - test_name="test_arbitrary_op_chain", ) def test_arbitrary_op_chain( in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder diff --git a/python/test_infra/test_utils.py b/python/test_infra/test_utils.py index dabf9eac5..89193a70c 100644 --- a/python/test_infra/test_utils.py +++ b/python/test_infra/test_utils.py @@ -248,7 +248,7 @@ def ttmetal_to_flatbuffer( def compile_to_flatbuffer( inputs_shapes: List[Shape], - test_name: str, + test_name: Optional[str] = None, targets: List[str] = ["ttmetal", "ttnn"], module_dump: bool = False, ): @@ -272,9 +272,10 @@ def compile_to_flatbuffer( inputs_shapes: List[Shape] Shapes of the respective ranked tensor inputs of the test function. - test_name: str - The name of the decorated function. Used as the base name for dumped - files during the process + test_name: Optional[str] + The string to be used as the base name for dumped files throughout the + process. If `None` is provided, then the `__name__` of the decorated + function will be used. targets: List[str] A list that can only contain the following strings: 'ttnn' or @@ -297,6 +298,13 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): """ def decorator(test_fn: Callable): + + # Snoop the name of `test_fn` if no override to the test name is provided + if test_name is None: + test_base = test_fn.__name__ + else: + test_base = test_name + def wrapper(): # NOTE: since `ttir_to_tt{nn,metal} modifies the module in place, @@ -305,13 +313,13 @@ def wrapper(): if "ttmetal" in targets: module, builder = compile_as_mlir_module(test_fn, inputs_shapes) - module = ttir_to_ttmetal(module, builder, test_name + ".mlir") - ttmetal_to_flatbuffer(module, builder, test_name + ".ttm") + module = ttir_to_ttmetal(module, builder, test_base + ".mlir") + ttmetal_to_flatbuffer(module, builder, test_base + ".ttm") if "ttnn" in targets: module, builder = compile_as_mlir_module(test_fn, inputs_shapes) - module = ttir_to_ttnn(module, builder, test_name + ".mlir") - ttnn_to_flatbuffer(module, builder, test_name + ".ttnn") + module = ttir_to_ttnn(module, builder, test_base + ".mlir") + ttnn_to_flatbuffer(module, builder, test_base + ".ttnn") return wrapper