From 640377de10a5347b4abcb724b8232e0b1d517080 Mon Sep 17 00:00:00 2001 From: Vraj Prajapati Date: Wed, 18 Dec 2024 08:41:53 -0600 Subject: [PATCH] Modified proxy function in ttir_builder and added MNIST golden test (#1609) - Needed a more "robust" test for accuracy overlay in explorer. Few modifications to make eltwise_proxy more general and added simple MNIST test. --- python/test_infra/ttir_builder.py | 130 +++++++++++++++++++++++++--- test/python/golden/test_ttir_ops.py | 26 ++++++ 2 files changed, 146 insertions(+), 10 deletions(-) diff --git a/python/test_infra/ttir_builder.py b/python/test_infra/ttir_builder.py index 471c07ca7..31e0adb30 100644 --- a/python/test_infra/ttir_builder.py +++ b/python/test_infra/ttir_builder.py @@ -287,32 +287,108 @@ def empty( return op # ----- TTIR op factories ----- - def eltwise_proxy( + def _organize_eltwise_ttir( + self, inputs: List[Operand], output: OpView, output_shape: Optional[Shape] + ): + return ([self._get_type(output)], inputs, [output]) + + def _organize_eltwise_golden( + self, inputs: List[Operand], output: OpView, output_shape: Optional[Shape] + ): + return [self._get_golden_tensor(inp) for inp in inputs] + + def op_proxy( self, op_golden_function: Callable, op_ttir_function: Callable, inputs: List[Operand], - ) -> OpView: + organize_ttir_args: Optional[Callable] = None, + organize_golden_args: Optional[Callable] = None, + output_shape: Optional[Shape] = None, + golden_kwargs: dict = {}, + ttir_kwargs: dict = {}, + ) -> Any: + """ + Provides a general interface for proxy-ing OPs and creating them. + + Parameters: + - op_golden_function (Callable): A function that creates the OP using a golden approach. + - op_ttir_function (Callable): A function that creates the OP using a TTIR approach. + - inputs (List[Operand]): A list of operands serving as inputs to the OP. + - organize_ttir_args (Callable): A function that organizes the inputs and other positional arguments for the TTIR approach. + - Function signature: + + def organize_ttir_args(inputs: List[Operand], output: OpView, output_shape: Optional[Shape]) -> List/Tuple + + The list/tuple will then be unpacked as the positional arguments for the op_ttir_function + + - organize_golden_args (Callable): A function that organizes the inputs and other arguments for the golden approach. + - Function signature: + + def organize_golden_args(inputs: List[Operand], output: OpView, output_shape: Optional[Shape]) -> List/Tuple + + The list/tuple will then be unpacked as the positional arugments for the op_golden_function + - output_shape (Optional[Shape]): An optional argument specifying the shape of the output of the OP. + - golden_kwargs (dict): Additional keyword arguments for the `op_golden_function`. + - ttir_kwargs (dict): Additional keyword arguments for the `op_ttir_function`. + + Returns: + - OpView: The created op + """ + # Snoop the location of the first caller outside of this file to + # annotate the MLIR with. NOTE that this location is _NOT_ row:col, but + # instead row:id, where id is a unique id given to all calls to builder + # funcs. See `get_next_global_id` for more details + stack = inspect.stack() + + # find the innermost frame outside of this file + cur_filename = stack[0].filename + + while len(stack) > 0 and stack[0].filename == cur_filename: + stack = stack[1:] + + assert ( + len(stack) > 0 + ), "Top of callstack to builder funcs must be outside this file" - id = self.get_next_global_id() - loc = get_loc_of_extra_file_callee(id=id) + if organize_ttir_args is None: + organize_ttir_args = self._organize_eltwise_ttir + + if organize_golden_args is None: + organize_golden_args = self._organize_eltwise_golden with self._ctx, self._loc: - output = self.empty(self.get_shape(inputs[0])) + shape = self.get_shape(inputs[0]) if not output_shape else output_shape + output = self.empty(shape) - op = op_ttir_function([self._get_type(output)], inputs, [output], loc=loc) + id = self.get_next_global_id() + loc = get_loc_of_extra_file_callee(id=id) - goldens = [] - for input in inputs: - goldens.append(self._get_golden_tensor(input)) + op = op_ttir_function( + *organize_ttir_args(inputs, output, output_shape), + loc=loc, + **ttir_kwargs, + ) - golden = Golden(op_golden_function(*goldens)) + golden = Golden( + op_golden_function( + *organize_golden_args(inputs, output, output_shape), **golden_kwargs + ) + ) self.id_golden_map[str(loc)] = golden self._store_golden(op, golden) self._override_golden(output, golden) return op + def eltwise_proxy( + self, + op_golden_function: Callable, + op_ttir_function: Callable, + inputs: List[Operand], + ) -> OpView: + return self.op_proxy(op_golden_function, op_ttir_function, inputs) + def exp(self, in0: Operand) -> OpView: return self.eltwise_proxy(torch.exp, ttir.ExpOp, [in0]) @@ -378,3 +454,37 @@ def div(self, in0: Operand, in1: Operand) -> OpView: def maximum(self, in0: Operand, in1: Operand) -> OpView: return self.eltwise_proxy(torch.maximum, ttir.MaximumOp, [in0, in1]) + + def matmul( + self, in0: Operand, in1: Operand, bias: Optional[Operand] = None + ) -> OpView: + # Calculate the output shape for Matmul + inputs = [in0, in1] + if bias: + inputs.append(bias) + shapes = [self.get_shape(x) for x in inputs] + shape = (shapes[0][0], shapes[1][1]) + assert ( + shapes[0][1] == shapes[1][0] + ), "Input Shapes not compatible for Matrix Multiplication" + return self.op_proxy( + torch.matmul, + ttir.MatmulOp, + inputs, + output_shape=shape, + organize_ttir_args=lambda i, o, shape: (self._get_type(o), i[0], i[1], o), + ) + + def softmax(self, in0: Operand, dimension: int = 1) -> OpView: + return self.op_proxy( + torch.softmax, + ttir.SoftmaxOp, + [in0], + golden_kwargs={"dim": dimension}, + organize_ttir_args=lambda i, o, shape: ( + self._get_type(o), + i[0], + o, + dimension, + ), + ) diff --git a/test/python/golden/test_ttir_ops.py b/test/python/golden/test_ttir_ops.py index e693196f5..8a9b2ba68 100644 --- a/test/python/golden/test_ttir_ops.py +++ b/test/python/golden/test_ttir_ops.py @@ -211,6 +211,32 @@ def test_arbitrary_op_chain( return builder.multiply(add, exp) +@compile_to_flatbuffer( + [ + (1, 784), + (784, 256), + (1, 256), + (256, 10), + (1, 10), + ], + targets=["ttnn"], +) +def test_mnist( + in0: Operand, # Input 28x28 image + in1: Operand, # Weight 1 + in2: Operand, # Bias 1 + in3: Operand, # Weight 2 + in4: Operand, # Bias 2 + builder: TTIRBuilder, +): + matmul_1 = builder.matmul(in0, in1) + add_2 = builder.add(matmul_1, in2) + relu_3 = builder.relu(add_2) + matmul_5 = builder.matmul(relu_3, in3) + add_6 = builder.add(matmul_5, in4) + return builder.softmax(add_6, dimension=1) + + if __name__ == "__main__": test_functions = inspect.getmembers( inspect.getmodule(inspect.currentframe()), inspect.isfunction