diff --git a/python/test_infra/test_ttir_ops.py b/python/test_infra/test_ttir_ops.py new file mode 100644 index 000000000..61fac59f6 --- /dev/null +++ b/python/test_infra/test_ttir_ops.py @@ -0,0 +1,235 @@ +# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC +# +# SPDX-License-Identifier: Apache-2.0 + +# RUN: %python %s + +import inspect + +from ttmlir.test_utils import compile_to_flatbuffer +from ttmlir.ttir_builder import Operand, TTIRBuilder + + +@compile_to_flatbuffer([(128, 128)], test_name="test_exp") +def test_exp(in0: Operand, builder: TTIRBuilder): + return builder.exp(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_abs", targets=["ttnn"]) +def test_abs(in0: Operand, builder: TTIRBuilder): + return builder.abs(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_logical_not", targets=["ttnn"]) +def test_logical_not(in0: Operand, builder: TTIRBuilder): + return builder.logical_not(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_neg", targets=["ttnn"]) +def test_neg(in0: Operand, builder: TTIRBuilder): + return builder.neg(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_relu", targets=["ttnn"]) +def test_relu(in0: Operand, builder: TTIRBuilder): + return builder.relu(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_sqrt", targets=["ttnn"]) +def test_sqrt(in0: Operand, builder: TTIRBuilder): + return builder.sqrt(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_rsqrt", targets=["ttnn"]) +def test_rsqrt(in0: Operand, builder: TTIRBuilder): + return builder.rsqrt(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_sigmoid", targets=["ttnn"]) +def test_sigmoid(in0: Operand, builder: TTIRBuilder): + return builder.sigmoid(in0) + + +@compile_to_flatbuffer([(128, 128)], test_name="test_reciprocal", targets=["ttnn"]) +def test_reciprocal(in0: Operand, builder: TTIRBuilder): + return builder.reciprocal(in0) + + +@compile_to_flatbuffer( + [ + (64, 128), + (64, 128), + ], + test_name="test_add", +) +def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.add(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_multiply", +) +def test_multiply(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.multiply(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_logical_and", + targets=["ttnn"], +) +def test_logical_and(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.logical_and(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_logical_or", + targets=["ttnn"], +) +def test_logical_or(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.logical_or(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_subtract", + targets=["ttnn"], +) +def test_subtract(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.subtract(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_eq", + targets=["ttnn"], +) +def test_eq(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.eq(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_ne", + targets=["ttnn"], +) +def test_ne(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.ne(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_ge", + targets=["ttnn"], +) +def test_ge(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.ge(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_gt", + targets=["ttnn"], +) +def test_gt(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.gt(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_le", + targets=["ttnn"], +) +def test_le(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.le(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_lt", + targets=["ttnn"], +) +def test_lt(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.lt(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_div", + targets=["ttnn"], +) +def test_div(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.div(in0, in1) + + +@compile_to_flatbuffer( + [ + (64, 64), + (64, 64), + ], + test_name="test_maximum", + targets=["ttnn"], +) +def test_maximum(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.maximum(in0, in1) + + +@compile_to_flatbuffer( + [ + (32, 32), + (32, 32), + (32, 32), + ], + test_name="test_arbitrary_op_chain", +) +def test_arbitrary_op_chain( + in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder +): + add = builder.add(in0, in1) + exp = builder.exp(in2) + return builder.multiply(add, exp) + + +if __name__ == "__main__": + test_functions = inspect.getmembers( + inspect.getmodule(inspect.currentframe()), inspect.isfunction + ) + + for function_name, func in test_functions: + if function_name.startswith("test_"): + func() diff --git a/python/test_infra/test_ttir_ops_ttmetal.py b/python/test_infra/test_ttir_ops_ttmetal.py deleted file mode 100644 index 713cdae32..000000000 --- a/python/test_infra/test_ttir_ops_ttmetal.py +++ /dev/null @@ -1,58 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 - -# RUN: %python %s - -import inspect - -from ttmlir.test_utils import ( - compile_as_mlir_module, - ttmetal_to_flatbuffer, - ttir_to_ttmetal, -) -from ttmlir.ttir_builder import Operand, TTIRBuilder - - -@ttmetal_to_flatbuffer(output_file_name="test_exp.ttm") -@ttir_to_ttmetal(output_file_name="test_exp.mlir") -@compile_as_mlir_module((128, 128)) -def test_exp_ttmetal(in0: Operand, builder: TTIRBuilder): - return builder.exp(in0) - - -@ttmetal_to_flatbuffer(output_file_name="test_add.ttm") -@ttir_to_ttmetal(output_file_name="test_add.mlir") -@compile_as_mlir_module((64, 128), (64, 128)) -def test_add_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.add(in0, in1) - - -@ttmetal_to_flatbuffer(output_file_name="test_multiply.ttm") -@ttir_to_ttmetal(output_file_name="test_multiply.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_multiply_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.multiply(in0, in1) - - -@ttmetal_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttm") -@ttir_to_ttmetal(output_file_name="test_arbitrary_op_chain.mlir") -@compile_as_mlir_module((32, 32), (32, 32), (32, 32)) -def test_arbitrary_op_chain_ttmetal( - in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder -): - add = builder.add(in0, in1) - exp = builder.exp(in2) - mul = builder.multiply(add, exp) - in3 = builder.empty(builder.get_shape(mul)) - return builder.multiply(mul, in3) - - -if __name__ == "__main__": - test_functions = inspect.getmembers( - inspect.getmodule(inspect.currentframe()), inspect.isfunction - ) - - for function_name, func in test_functions: - if function_name.startswith("test_"): - func() diff --git a/python/test_infra/test_ttir_ops_ttnn.py b/python/test_infra/test_ttir_ops_ttnn.py deleted file mode 100644 index 417c994f5..000000000 --- a/python/test_infra/test_ttir_ops_ttnn.py +++ /dev/null @@ -1,189 +0,0 @@ -# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC -# -# SPDX-License-Identifier: Apache-2.0 - -# RUN: %python %s - -import inspect - -from ttmlir.test_utils import ( - compile_as_mlir_module, - ttnn_to_flatbuffer, - ttir_to_ttnn, -) -from ttmlir.ttir_builder import Operand, TTIRBuilder - - -@ttnn_to_flatbuffer(output_file_name="test_exp.ttnn") -@ttir_to_ttnn(output_file_name="test_exp.mlir") -@compile_as_mlir_module((128, 128)) -def test_exp_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.exp(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_abs.ttnn") -@ttir_to_ttnn(output_file_name="test_abs.mlir") -@compile_as_mlir_module((128, 128)) -def test_abs_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.abs(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_logical_not.ttnn") -@ttir_to_ttnn(output_file_name="test_logical_not.mlir") -@compile_as_mlir_module((128, 128)) -def test_logical_not_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.logical_not(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_neg.ttnn") -@ttir_to_ttnn(output_file_name="test_neg.mlir") -@compile_as_mlir_module((128, 128)) -def test_neg_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.neg(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_relu.ttnn") -@ttir_to_ttnn(output_file_name="test_relu.mlir") -@compile_as_mlir_module((128, 128)) -def test_relu_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.relu(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_sqrt.ttnn") -@ttir_to_ttnn(output_file_name="test_sqrt.mlir") -@compile_as_mlir_module((128, 128)) -def test_sqrt_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.sqrt(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_rsqrt.ttnn") -@ttir_to_ttnn(output_file_name="test_rsqrt.mlir") -@compile_as_mlir_module((128, 128)) -def test_rsqrt_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.rsqrt(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_sigmoid.ttnn") -@ttir_to_ttnn(output_file_name="test_sigmoid.mlir") -@compile_as_mlir_module((128, 128)) -def test_sigmoid_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.sigmoid(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_reciprocal.ttnn") -@ttir_to_ttnn(output_file_name="test_reciprocal.mlir") -@compile_as_mlir_module((128, 128)) -def test_reciprocal_ttnn(in0: Operand, builder: TTIRBuilder): - return builder.reciprocal(in0) - - -@ttnn_to_flatbuffer(output_file_name="test_add.ttnn") -@ttir_to_ttnn(output_file_name="test_add.mlir") -@compile_as_mlir_module((64, 128), (64, 128)) -def test_add_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.add(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_multiply.ttnn") -@ttir_to_ttnn(output_file_name="test_multiply.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_multiply_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.multiply(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_logical_and.ttnn") -@ttir_to_ttnn(output_file_name="test_logical_and.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_logical_and_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.logical_and(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_logical_or.ttnn") -@ttir_to_ttnn(output_file_name="test_logical_or.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_logical_or_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.logical_or(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_subtract.ttnn") -@ttir_to_ttnn(output_file_name="test_subtract.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_subtract_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.subtract(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_eq.ttnn") -@ttir_to_ttnn(output_file_name="test_eq.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_eq_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.eq(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_ne.ttnn") -@ttir_to_ttnn(output_file_name="test_ne.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_ne_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.ne(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_ge.ttnn") -@ttir_to_ttnn(output_file_name="test_ge.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_ge_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.ge(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_gt.ttnn") -@ttir_to_ttnn(output_file_name="test_gt.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_gt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.gt(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_le.ttnn") -@ttir_to_ttnn(output_file_name="test_le.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_le_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.le(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_lt.ttnn") -@ttir_to_ttnn(output_file_name="test_lt.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_lt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.lt(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_div.ttnn") -@ttir_to_ttnn(output_file_name="test_div.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_div_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.div(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_maximum.ttnn") -@ttir_to_ttnn(output_file_name="test_maximum.mlir") -@compile_as_mlir_module((64, 64), (64, 64)) -def test_maximum_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder): - return builder.maximum(in0, in1) - - -@ttnn_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttnn") -@ttir_to_ttnn(output_file_name="test_arbitrary_op_chain.mlir") -@compile_as_mlir_module((32, 32), (32, 32), (32, 32)) -def test_arbitrary_op_chain_ttnn( - in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder -): - add = builder.add(in0, in1) - exp = builder.exp(in2) - return builder.multiply(add, exp) - - -if __name__ == "__main__": - test_functions = inspect.getmembers( - inspect.getmodule(inspect.currentframe()), inspect.isfunction - ) - - for function_name, func in test_functions: - if function_name.startswith("test_"): - func() diff --git a/python/test_infra/test_utils.py b/python/test_infra/test_utils.py index 12af0afae..dabf9eac5 100644 --- a/python/test_infra/test_utils.py +++ b/python/test_infra/test_utils.py @@ -3,7 +3,7 @@ # SPDX-License-Identifier: Apache-2.0 import os -from typing import Callable, Dict, Tuple, List, Optional +from typing import Callable, Dict, List, Optional import torch from ttmlir.dialects import func @@ -28,23 +28,27 @@ def _dump_module(module: Module) -> None: print(module) -# ----- Decorators for doing passes and compiling to flatbuffer ----- +# ----- General Purpose Helpers - Could Be Used In Other Files ----- def compile_as_mlir_module( - *inputs_shapes: Tuple[Shape], + test_fn: Callable, + inputs_shapes: List[Shape], module_dump: bool = False, ): """ - Decorator to define a MLIR module specified as a python function. + Define a MLIR module specified as a python function. - It will wrap decorated test function in a MLIR FuncOp and then wrap that in a MLIR + It will wrap `test_fn` in a MLIR FuncOp and then wrap that in a MLIR module, and finally tie arguments of that FuncOp to test function inputs. It will also pass a `TTIRBuilder` object as the last argument of test function. Arguments --------- - inputs_shapes: Tuple[Shape] + test_fn : Callable + Python function to be converted to MLIR + + inputs_shapes: List[Shape] Shapes of the respective ranked tensor inputs of the test function. module_dump: bool @@ -56,18 +60,16 @@ def compile_as_mlir_module( Returns ------- - MLIR module containing MLIR op graph defined by decorated test function. + MLIR module containing MLIR op graph defined by `test_fn` Example ------- ```python - @compile_as_mlir_module((32, 32), (32, 32)) def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): return builder.add(in0, in1) - - test_add() # NOTE Called without arguments. + compile_as_mlir_module(test_add, ((32, 32), (32, 32))) ``` which returns @@ -90,49 +92,40 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): https://github.com/llvm/llvm-project/blob/main/mlir/test/python/dialects/tensor.py """ - def decorator(test_fn: Callable): - # test_fn should be called with no args. - def wrapper(): - ctx = Context() - loc = Location.unknown(ctx) - # Instantiate builder which is passed as the last argument to - # `test_fn` so the user can use it to build ops. - builder = TTIRBuilder(ctx, loc) - - with ctx, loc: - test_fn_input_types = [ - builder.ranked_tensor_type(input_shape) - for input_shape in inputs_shapes - ] + ctx = Context() + loc = Location.unknown(ctx) + # Instantiate builder which is passed as the last argument to + # `test_fn` so the user can use it to build ops. + builder = TTIRBuilder(ctx, loc) - # Wrap everything in a mlir module. - module = Module.create() + with ctx, loc: + test_fn_input_types = [ + builder.ranked_tensor_type(input_shape) for input_shape in inputs_shapes + ] - with InsertionPoint(module.body): - # Wrap everything in a mlir function. - @func.func(*test_fn_input_types, name=test_fn.__name__) - def decorated_func(*inputs): - # Randomly generate golden tensors for function inputs. - for index, i in enumerate(inputs): - builder.generate_input_golden(i, index) + # Wrap everything in a mlir module. + module = Module.create() - return test_fn(*inputs, builder=builder) + with InsertionPoint(module.body): + # Wrap everything in a mlir function. + @func.func(*test_fn_input_types, name=test_fn.__name__) + def decorated_func(*inputs): + # Randomly generate golden tensors for function inputs. + for index, i in enumerate(inputs): + builder.generate_input_golden(i, index) - print( - f"`{test_fn.__name__}` sucessfully transformed into a MLIR module." - ) + return test_fn(*inputs, builder=builder) - if module_dump: - _dump_module(module) + print(f"`{test_fn.__name__}` sucessfully transformed into a MLIR module.") - return module, builder + if module_dump: + _dump_module(module) - return wrapper - - return decorator + return module, builder def ttir_to_ttnn( + module, dump_to_file: bool = True, output_file_name: str = "test.mlir", system_desc_path: Optional[str] = None, @@ -152,50 +145,43 @@ def ttir_to_ttnn( Returns ------- - MLIR module containing MLIR op graph defined by decorated test function and instance of TTIRBuilder. + MLIR module containing MLIR op graph defined by `module` and instance of TTIRBuilder. """ # Default to the `SYSTEM_DESC_PATH` envvar if system_desc_path is None: system_desc_path = os.getenv("SYSTEM_DESC_PATH", "") - def decorator(fn: Callable): - def wrapper(*args, **kwargs): - # First, call the decorated function to get the MLIR module and builder instance - module, builder = fn(*args, **kwargs) - - # Now, pass it through the TTIR to TTNN pipeline. Module gets - # modified in place. - ttir_to_ttnn_backend_pipeline( - module, f"system-desc-path={system_desc_path}" - ) + # Now, pass it through the TTIR to TTNN pipeline. Module gets + # modified in place. + ttir_to_ttnn_backend_pipeline(module, f"system-desc-path={system_desc_path}") - print("`ttir_to_ttnn_backend_pipeline` passed successfully.") + print("`ttir_to_ttnn_backend_pipeline` passed successfully.") - # Optionally dump to file. - if dump_to_file: - with open(output_file_name, "w") as f: - f.write(str(module)) + # Optionally dump to file. + if dump_to_file: + with open(output_file_name, "w") as f: + f.write(str(module)) - return module, builder - - return wrapper - - return decorator + return module def ttir_to_ttmetal( + module, dump_to_file: bool = True, output_file_name: str = "test.mlir", system_desc_path: Optional[str] = None, ): """ - Converts TTIR module to TTMetal module and optionally dumps to file. + Converts TTIR module `module` to TTMetal module and optionally dumps to file. Wrapper around `ttir_to_ttmetal_backend_pipeline` pybound pass. Arguments --------- + module: ??? + TTIR module to convert to TTMetal module + dump_to_file: bool Flag which indicates that generated TTMetal module will be dumped to file. @@ -204,86 +190,128 @@ def ttir_to_ttmetal( Returns ------- - MLIR module containing MLIR op graph defined by decorated test function and instance of TTIRBuilder. + MLIR module containing MLIR op graph defined by `module` and instance of TTIRBuilder. """ # Default to the `SYSTEM_DESC_PATH` envvar if system_desc_path is None: system_desc_path = os.getenv("SYSTEM_DESC_PATH", "") - def decorator(fn: Callable): - def wrapper(*args, **kwargs): - # First, call the decorated function to get the MLIR module. - module, builder = fn(*args, **kwargs) + # Now, pass it through the TTIR to TTMetal pipeline. Module gets + # modified in place. + ttir_to_ttmetal_backend_pipeline(module, f"system-desc-path={system_desc_path}") - # Now, pass it through the TTIR to TTMetal pipeline. Module gets - # modified in place. - ttir_to_ttmetal_backend_pipeline( - module, f"system-desc-path={system_desc_path}" - ) + print("`ttir_to_ttmetal_backend_pipeline` passed successfully.") - print("`ttir_to_ttmetal_backend_pipeline` passed successfully.") + # Optionally dump to file. + if dump_to_file: + with open(output_file_name, "w") as f: + f.write(str(module)) - # Optionally dump to file. - if dump_to_file: - with open(output_file_name, "w") as f: - f.write(str(module)) - - return module, builder - - return wrapper - - return decorator + return module def ttnn_to_flatbuffer( + module, + builder, output_file_name: str = "ttnn_fb.ttnn", ): """ - Converts TTNN module to flatbuffer and saves to file, meant to be used as a - decorator on top of `ttir_to_ttnn` decorator. Take note that `ttir_to_ttnn` - has to return module instead of file name if decorated with this decorator. - - Wrapper around `ttnn_to_flatbuffer_file` pybound pass. + Converts TTNN module to flatbuffer and saves to file. Wrapper around + `ttnn_to_flatbuffer_file` pybound pass. """ - def decorator(test_fn: Callable): - def wrapper(*args, **kwargs): - # Get the TTNN module by calling the wrapped function. - module, builder = test_fn(*args, **kwargs) + # Convert to flatbuffer file. + ttnn_to_flatbuffer_file(module, output_file_name, builder.get_golden_map()) - # Convert to flatbuffer file. - ttnn_to_flatbuffer_file(module, output_file_name, builder.get_golden_map()) + print("`ttnn_to_flatbuffer_file` passed successfully.") - print("`ttnn_to_flatbuffer_file` passed successfully.") - return wrapper +def ttmetal_to_flatbuffer( + module, + builder, + output_file_name: str = "ttmetal_fb.ttm", +): + """ + Converts TTMetal module to flatbuffer and saves to file. Wrapper around + `ttmetal_to_flatbuffer_file` pybound pass. + """ - return decorator + # Convert to flatbuffer file. + ttmetal_to_flatbuffer_file(module, output_file_name, builder.get_golden_map()) + print("`ttmetal_to_flatbuffer_file` passed successfully.") -def ttmetal_to_flatbuffer( - output_file_name: str = "ttmetal_fb.ttmg", + +# ----- Decorators for doing passes and compiling to flatbuffer ----- + + +def compile_to_flatbuffer( + inputs_shapes: List[Shape], + test_name: str, + targets: List[str] = ["ttmetal", "ttnn"], + module_dump: bool = False, ): """ - Converts TTMetal module to flatbuffer and saves to file, meant to be used as a - decorator on top of `ttir_to_ttmetal` decorator. Take note that `ttir_to_ttmetal` - has to return module instead of file name if decorated with this decorator. + Decorator to run an e2e Python -> Flatbuffer test using the decorated + function, using the TTNN and/or TTMetal backends. + + This decorator is mainly a wrapper around the following functions, with + each next function called on the output of the last: + + 1. `compile_as_mlir_module` + 2. `ttir_to_tt{nn,metal}` + 3. `tt{nn,metal}_to_flatbuffer` + + The choice of TTNN, TTMetal, or both is controlled by membership of those + strings in the `targets` parameter. + + Arguments + --------- + + inputs_shapes: List[Shape] + Shapes of the respective ranked tensor inputs of the test function. + + test_name: str + The name of the decorated function. Used as the base name for dumped + files during the process - Wrapper around `ttmetal_to_flatbuffer_file` pybound pass. + targets: List[str] + A list that can only contain the following strings: 'ttnn' or + 'ttmetal'. Inclusion in this list will signal this decorator to execute + their respective backend paths. Either, neither, or both are valid inputs. + + module_dump: bool + Set to True to print out generated MLIR module. + + Example + ------- + + ```python + @compile_and_convert(((32, 32), (32, 32)), test_name="test_add") + def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder): + return builder.add(in0, in1) + + test_add() # NOTE: called without arguments + ``` """ def decorator(test_fn: Callable): - def wrapper(*args, **kwargs): - # Get the TTMetal module by calling the wrapped function. - module, builder = test_fn(*args, **kwargs) + def wrapper(): + + # NOTE: since `ttir_to_tt{nn,metal} modifies the module in place, + # `compile_as_mlir_module` needs to be run twice in the case that + # both targets are chosen - # Convert to flatbuffer file. - ttmetal_to_flatbuffer_file( - module, output_file_name, builder.get_golden_map() - ) + if "ttmetal" in targets: + module, builder = compile_as_mlir_module(test_fn, inputs_shapes) + module = ttir_to_ttmetal(module, builder, test_name + ".mlir") + ttmetal_to_flatbuffer(module, builder, test_name + ".ttm") - print("`ttmetal_to_flatbuffer_file` passed successfully.") + if "ttnn" in targets: + module, builder = compile_as_mlir_module(test_fn, inputs_shapes) + module = ttir_to_ttnn(module, builder, test_name + ".mlir") + ttnn_to_flatbuffer(module, builder, test_name + ".ttnn") return wrapper