Skip to content

Commit

Permalink
Cleanup system description discovery (#1184) (#1196)
Browse files Browse the repository at this point in the history
This change moves grabbing the path for the `ttrt` generated system
description from the env into decorators, so that it doesn't need to be
passed into every invocation of the decorator. This is a simple cleanup
change that does not change behavior.
  • Loading branch information
ctodTT authored Nov 8, 2024
1 parent 2cb0bf2 commit dca9c66
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 68 deletions.
20 changes: 4 additions & 16 deletions python/test_infra/test_ttir_ops_ttmetal.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# RUN: %python %s

import inspect
import os

from ttmlir.test_utils import (
compile_as_mlir_module,
Expand All @@ -14,41 +13,30 @@
)
from ttmlir.ttir_builder import Operand, TTIRBuilder

system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")


@ttmetal_to_flatbuffer(output_file_name="test_exp.ttm")
@ttir_to_ttmetal(
output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttmetal(output_file_name="test_exp.mlir")
@compile_as_mlir_module((128, 128))
def test_exp_ttmetal(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@ttmetal_to_flatbuffer(output_file_name="test_add.ttm")
@ttir_to_ttmetal(
output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttmetal(output_file_name="test_add.mlir")
@compile_as_mlir_module((64, 128), (64, 128))
def test_add_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@ttmetal_to_flatbuffer(output_file_name="test_multiply.ttm")
@ttir_to_ttmetal(
output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttmetal(output_file_name="test_multiply.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_multiply_ttmetal(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@ttmetal_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttm")
@ttir_to_ttmetal(
output_file_name="test_arbitrary_op_chain.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttmetal(output_file_name="test_arbitrary_op_chain.mlir")
@compile_as_mlir_module((32, 32), (32, 32), (32, 32))
def test_arbitrary_op_chain_ttmetal(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
Expand Down
72 changes: 23 additions & 49 deletions python/test_infra/test_ttir_ops_ttnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@
# RUN: %python %s

import inspect
import os

from ttmlir.test_utils import (
compile_as_mlir_module,
Expand All @@ -14,188 +13,163 @@
)
from ttmlir.ttir_builder import Operand, TTIRBuilder

system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")


@ttnn_to_flatbuffer(output_file_name="test_exp.ttnn")
@ttir_to_ttnn(output_file_name="test_exp.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_exp.mlir")
@compile_as_mlir_module((128, 128))
def test_exp_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.exp(in0)


@ttnn_to_flatbuffer(output_file_name="test_abs.ttnn")
@ttir_to_ttnn(output_file_name="test_abs.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_abs.mlir")
@compile_as_mlir_module((128, 128))
def test_abs_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.abs(in0)


@ttnn_to_flatbuffer(output_file_name="test_logical_not.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_not.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttnn(output_file_name="test_logical_not.mlir")
@compile_as_mlir_module((128, 128))
def test_logical_not_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


@ttnn_to_flatbuffer(output_file_name="test_neg.ttnn")
@ttir_to_ttnn(output_file_name="test_neg.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_neg.mlir")
@compile_as_mlir_module((128, 128))
def test_neg_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.neg(in0)


@ttnn_to_flatbuffer(output_file_name="test_relu.ttnn")
@ttir_to_ttnn(output_file_name="test_relu.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_relu.mlir")
@compile_as_mlir_module((128, 128))
def test_relu_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.relu(in0)


@ttnn_to_flatbuffer(output_file_name="test_sqrt.ttnn")
@ttir_to_ttnn(output_file_name="test_sqrt.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_sqrt.mlir")
@compile_as_mlir_module((128, 128))
def test_sqrt_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.sqrt(in0)


@ttnn_to_flatbuffer(output_file_name="test_rsqrt.ttnn")
@ttir_to_ttnn(
output_file_name="test_rsqrt.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_rsqrt.mlir")
@compile_as_mlir_module((128, 128))
def test_rsqrt_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.rsqrt(in0)


@ttnn_to_flatbuffer(output_file_name="test_sigmoid.ttnn")
@ttir_to_ttnn(
output_file_name="test_sigmoid.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_sigmoid.mlir")
@compile_as_mlir_module((128, 128))
def test_sigmoid_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.sigmoid(in0)


@ttnn_to_flatbuffer(output_file_name="test_reciprocal.ttnn")
@ttir_to_ttnn(
output_file_name="test_reciprocal.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_reciprocal.mlir")
@compile_as_mlir_module((128, 128))
def test_reciprocal_ttnn(in0: Operand, builder: TTIRBuilder):
return builder.reciprocal(in0)


@ttnn_to_flatbuffer(output_file_name="test_add.ttnn")
@ttir_to_ttnn(output_file_name="test_add.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_add.mlir")
@compile_as_mlir_module((64, 128), (64, 128))
def test_add_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.add(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_multiply.ttnn")
@ttir_to_ttnn(
output_file_name="test_multiply.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_multiply.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_multiply_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.multiply(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_logical_and.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_and.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttnn(output_file_name="test_logical_and.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_logical_and_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_and(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_logical_or.ttnn")
@ttir_to_ttnn(
output_file_name="test_logical_or.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_logical_or.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_logical_or_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_or(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_subtract.ttnn")
@ttir_to_ttnn(
output_file_name="test_subtract.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_subtract.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_subtract_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.subtract(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_eq.ttnn")
@ttir_to_ttnn(output_file_name="test_eq.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_eq.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_eq_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.eq(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_ne.ttnn")
@ttir_to_ttnn(output_file_name="test_ne.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_ne.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_ne_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ne(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_ge.ttnn")
@ttir_to_ttnn(output_file_name="test_ge.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_ge.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_ge_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.ge(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_gt.ttnn")
@ttir_to_ttnn(output_file_name="test_gt.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_gt.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_gt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.gt(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_le.ttnn")
@ttir_to_ttnn(output_file_name="test_le.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_le.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_le_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.le(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_lt.ttnn")
@ttir_to_ttnn(output_file_name="test_lt.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_lt.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_lt_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.lt(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_div.ttnn")
@ttir_to_ttnn(output_file_name="test_div.mlir", system_desc_path=f"{system_desc_path}")
@ttir_to_ttnn(output_file_name="test_div.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_div_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.div(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_maximum.ttnn")
@ttir_to_ttnn(
output_file_name="test_maximum.mlir", system_desc_path=f"{system_desc_path}"
)
@ttir_to_ttnn(output_file_name="test_maximum.mlir")
@compile_as_mlir_module((64, 64), (64, 64))
def test_maximum_ttnn(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.maximum(in0, in1)


@ttnn_to_flatbuffer(output_file_name="test_arbitrary_op_chain.ttnn")
@ttir_to_ttnn(
output_file_name="test_arbitrary_op_chain.mlir",
system_desc_path=f"{system_desc_path}",
)
@ttir_to_ttnn(output_file_name="test_arbitrary_op_chain.mlir")
@compile_as_mlir_module((32, 32), (32, 32), (32, 32))
def test_arbitrary_op_chain_ttnn(
in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder
Expand Down
14 changes: 11 additions & 3 deletions python/test_infra/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
# SPDX-License-Identifier: Apache-2.0

import os
from typing import Callable, Dict, Tuple, List
from typing import Callable, Dict, Tuple, List, Optional

import torch
from ttmlir.dialects import func
Expand Down Expand Up @@ -135,7 +135,7 @@ def decorated_func(*inputs):
def ttir_to_ttnn(
dump_to_file: bool = True,
output_file_name: str = "test.mlir",
system_desc_path: str = "",
system_desc_path: Optional[str] = None,
):
"""
Converts TTIR module to TTNN module and optionally dumps to file.
Expand All @@ -155,6 +155,10 @@ def ttir_to_ttnn(
MLIR module containing MLIR op graph defined by decorated test function and instance of TTIRBuilder.
"""

# Default to the `SYSTEM_DESC_PATH` envvar
if system_desc_path is None:
system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")

def decorator(fn: Callable):
def wrapper(*args, **kwargs):
# First, call the decorated function to get the MLIR module and builder instance
Expand Down Expand Up @@ -183,7 +187,7 @@ def wrapper(*args, **kwargs):
def ttir_to_ttmetal(
dump_to_file: bool = True,
output_file_name: str = "test.mlir",
system_desc_path: str = "",
system_desc_path: Optional[str] = None,
):
"""
Converts TTIR module to TTMetal module and optionally dumps to file.
Expand All @@ -203,6 +207,10 @@ def ttir_to_ttmetal(
MLIR module containing MLIR op graph defined by decorated test function and instance of TTIRBuilder.
"""

# Default to the `SYSTEM_DESC_PATH` envvar
if system_desc_path is None:
system_desc_path = os.getenv("SYSTEM_DESC_PATH", "")

def decorator(fn: Callable):
def wrapper(*args, **kwargs):
# First, call the decorated function to get the MLIR module.
Expand Down

0 comments on commit dca9c66

Please sign in to comment.