Skip to content

Commit

Permalink
Modified proxy function in ttir_builder and added MNIST golden test (#…
Browse files Browse the repository at this point in the history
…1609)

- Needed a more "robust" test for accuracy overlay in explorer. Few
modifications to make eltwise_proxy more general and added simple MNIST
test.
  • Loading branch information
vprajapati-tt authored Dec 18, 2024
1 parent 844cb9c commit 640377d
Show file tree
Hide file tree
Showing 2 changed files with 146 additions and 10 deletions.
130 changes: 120 additions & 10 deletions python/test_infra/ttir_builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -287,32 +287,108 @@ def empty(
return op

# ----- TTIR op factories -----
def eltwise_proxy(
def _organize_eltwise_ttir(
self, inputs: List[Operand], output: OpView, output_shape: Optional[Shape]
):
return ([self._get_type(output)], inputs, [output])

def _organize_eltwise_golden(
self, inputs: List[Operand], output: OpView, output_shape: Optional[Shape]
):
return [self._get_golden_tensor(inp) for inp in inputs]

def op_proxy(
self,
op_golden_function: Callable,
op_ttir_function: Callable,
inputs: List[Operand],
) -> OpView:
organize_ttir_args: Optional[Callable] = None,
organize_golden_args: Optional[Callable] = None,
output_shape: Optional[Shape] = None,
golden_kwargs: dict = {},
ttir_kwargs: dict = {},
) -> Any:
"""
Provides a general interface for proxy-ing OPs and creating them.
Parameters:
- op_golden_function (Callable): A function that creates the OP using a golden approach.
- op_ttir_function (Callable): A function that creates the OP using a TTIR approach.
- inputs (List[Operand]): A list of operands serving as inputs to the OP.
- organize_ttir_args (Callable): A function that organizes the inputs and other positional arguments for the TTIR approach.
- Function signature:
def organize_ttir_args(inputs: List[Operand], output: OpView, output_shape: Optional[Shape]) -> List/Tuple
The list/tuple will then be unpacked as the positional arguments for the op_ttir_function
- organize_golden_args (Callable): A function that organizes the inputs and other arguments for the golden approach.
- Function signature:
def organize_golden_args(inputs: List[Operand], output: OpView, output_shape: Optional[Shape]) -> List/Tuple
The list/tuple will then be unpacked as the positional arugments for the op_golden_function
- output_shape (Optional[Shape]): An optional argument specifying the shape of the output of the OP.
- golden_kwargs (dict): Additional keyword arguments for the `op_golden_function`.
- ttir_kwargs (dict): Additional keyword arguments for the `op_ttir_function`.
Returns:
- OpView: The created op
"""
# Snoop the location of the first caller outside of this file to
# annotate the MLIR with. NOTE that this location is _NOT_ row:col, but
# instead row:id, where id is a unique id given to all calls to builder
# funcs. See `get_next_global_id` for more details
stack = inspect.stack()

# find the innermost frame outside of this file
cur_filename = stack[0].filename

while len(stack) > 0 and stack[0].filename == cur_filename:
stack = stack[1:]

assert (
len(stack) > 0
), "Top of callstack to builder funcs must be outside this file"

id = self.get_next_global_id()
loc = get_loc_of_extra_file_callee(id=id)
if organize_ttir_args is None:
organize_ttir_args = self._organize_eltwise_ttir

if organize_golden_args is None:
organize_golden_args = self._organize_eltwise_golden

with self._ctx, self._loc:
output = self.empty(self.get_shape(inputs[0]))
shape = self.get_shape(inputs[0]) if not output_shape else output_shape
output = self.empty(shape)

op = op_ttir_function([self._get_type(output)], inputs, [output], loc=loc)
id = self.get_next_global_id()
loc = get_loc_of_extra_file_callee(id=id)

goldens = []
for input in inputs:
goldens.append(self._get_golden_tensor(input))
op = op_ttir_function(
*organize_ttir_args(inputs, output, output_shape),
loc=loc,
**ttir_kwargs,
)

golden = Golden(op_golden_function(*goldens))
golden = Golden(
op_golden_function(
*organize_golden_args(inputs, output, output_shape), **golden_kwargs
)
)
self.id_golden_map[str(loc)] = golden
self._store_golden(op, golden)
self._override_golden(output, golden)

return op

def eltwise_proxy(
self,
op_golden_function: Callable,
op_ttir_function: Callable,
inputs: List[Operand],
) -> OpView:
return self.op_proxy(op_golden_function, op_ttir_function, inputs)

def exp(self, in0: Operand) -> OpView:
return self.eltwise_proxy(torch.exp, ttir.ExpOp, [in0])

Expand Down Expand Up @@ -378,3 +454,37 @@ def div(self, in0: Operand, in1: Operand) -> OpView:

def maximum(self, in0: Operand, in1: Operand) -> OpView:
return self.eltwise_proxy(torch.maximum, ttir.MaximumOp, [in0, in1])

def matmul(
self, in0: Operand, in1: Operand, bias: Optional[Operand] = None
) -> OpView:
# Calculate the output shape for Matmul
inputs = [in0, in1]
if bias:
inputs.append(bias)
shapes = [self.get_shape(x) for x in inputs]
shape = (shapes[0][0], shapes[1][1])
assert (
shapes[0][1] == shapes[1][0]
), "Input Shapes not compatible for Matrix Multiplication"
return self.op_proxy(
torch.matmul,
ttir.MatmulOp,
inputs,
output_shape=shape,
organize_ttir_args=lambda i, o, shape: (self._get_type(o), i[0], i[1], o),
)

def softmax(self, in0: Operand, dimension: int = 1) -> OpView:
return self.op_proxy(
torch.softmax,
ttir.SoftmaxOp,
[in0],
golden_kwargs={"dim": dimension},
organize_ttir_args=lambda i, o, shape: (
self._get_type(o),
i[0],
o,
dimension,
),
)
26 changes: 26 additions & 0 deletions test/python/golden/test_ttir_ops.py
Original file line number Diff line number Diff line change
Expand Up @@ -211,6 +211,32 @@ def test_arbitrary_op_chain(
return builder.multiply(add, exp)


@compile_to_flatbuffer(
[
(1, 784),
(784, 256),
(1, 256),
(256, 10),
(1, 10),
],
targets=["ttnn"],
)
def test_mnist(
in0: Operand, # Input 28x28 image
in1: Operand, # Weight 1
in2: Operand, # Bias 1
in3: Operand, # Weight 2
in4: Operand, # Bias 2
builder: TTIRBuilder,
):
matmul_1 = builder.matmul(in0, in1)
add_2 = builder.add(matmul_1, in2)
relu_3 = builder.relu(add_2)
matmul_5 = builder.matmul(relu_3, in3)
add_6 = builder.add(matmul_5, in4)
return builder.softmax(add_6, dimension=1)


if __name__ == "__main__":
test_functions = inspect.getmembers(
inspect.getmodule(inspect.currentframe()), inspect.isfunction
Expand Down

0 comments on commit 640377d

Please sign in to comment.