Enable choosing of input type for lowering tests
Closes #1765
ctodTT committed Jan 22, 2025
1 parent 2c8edcc commit 8b0133b
Showing 3 changed files with 136 additions and 61 deletions.
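The change threads an optional `inputs_types` list (one torch dtype per input shape) from the test decorators down to golden-tensor generation and MLIR type construction. A minimal sketch of the new usage, mirroring the `test_bitwise_not` case added below; it assumes the ttmlir test infra is importable exactly as in the existing tests:

import torch

from ttmlir.test_utils import compile_to_flatbuffer
from ttmlir.ttir_builder import Operand, TTIRBuilder


# One dtype per input shape; omitting inputs_types keeps the previous
# behavior of generating float32 inputs for every shape.
@compile_to_flatbuffer([(128, 128)], inputs_types=[torch.int8], targets=["ttnn"])
def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
    return builder.bitwise_not(in0)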
26 changes: 20 additions & 6 deletions python/test_infra/test_utils.py
@@ -4,6 +4,7 @@

import os
import inspect
+import torch
from typing import Callable, List, Optional

from ttmlir.dialects import func
@@ -15,7 +16,7 @@
ttmetal_to_flatbuffer_file,
)

-from .ttir_builder import Golden, Operand, Shape, TTIRBuilder
+from .ttir_builder import Golden, Operand, Shape, TTIRBuilder, DataType

TT_MLIR_HOME = os.environ.get("TT_MLIR_HOME", "")

@@ -34,6 +35,7 @@ def _dump_module(module: Module) -> None:
def compile_as_mlir_module(
test_fn: Callable,
inputs_shapes: List[Shape],
+inputs_types: Optional[List[torch.dtype]] = None,
module_dump: bool = False,
):
"""
@@ -106,9 +108,16 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
# `test_fn` so the user can use it to build ops.
builder = TTIRBuilder(ctx, loc)

+# Default to all f32s
+if inputs_types is None:
+inputs_types = [torch.float32] * len(inputs_shapes)
+
+assert inputs_types is not None and len(inputs_shapes) == len(inputs_types)

with ctx, loc:
test_fn_input_types = [
-builder.ranked_tensor_type(input_shape) for input_shape in inputs_shapes
+builder.ranked_tensor_type(shape, builder.get_type_from_torch_dtype(dtype))
+for (shape, dtype) in zip(inputs_shapes, inputs_types)
]

# Wrap everything in a mlir module.
@@ -119,8 +128,8 @@ def test_add(in0: Operand, in1: Operand, builder: TTIRBuilder):
@func.func(*test_fn_input_types, name=test_fn.__name__)
def decorated_func(*inputs):
# Randomly generate golden tensors for function inputs.
-for index, i in enumerate(inputs):
-builder.generate_input_golden(i, index)
+for index, (operand, dtype) in enumerate(zip(inputs, inputs_types)):
+builder.generate_input_golden(operand, dtype, index)

return test_fn(*inputs, builder=builder)

@@ -256,6 +265,7 @@ def ttmetal_to_flatbuffer(

def compile_to_flatbuffer(
inputs_shapes: List[Shape],
+inputs_types: Optional[List[torch.dtype]] = None,
test_name: Optional[str] = None,
targets: List[str] = ["ttmetal", "ttnn"],
module_dump: bool = False,
@@ -320,12 +330,16 @@ def wrapper():
# both targets are chosen

if "ttmetal" in targets:
-module, builder = compile_as_mlir_module(test_fn, inputs_shapes)
+module, builder = compile_as_mlir_module(
+test_fn, inputs_shapes, inputs_types
+)
module = ttir_to_ttmetal(module, builder, test_base + ".mlir")
ttmetal_to_flatbuffer(module, builder, test_base + ".ttm")

if "ttnn" in targets:
-module, builder = compile_as_mlir_module(test_fn, inputs_shapes)
+module, builder = compile_as_mlir_module(
+test_fn, inputs_shapes, inputs_types
+)
module = ttir_to_ttnn(module, builder, test_base + ".mlir")
ttnn_to_flatbuffer(module, builder, test_base + ".ttnn")

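For completeness, a sketch (not part of the commit) of driving the updated `compile_as_mlir_module` entry point directly with explicit per-input dtypes. The `builder.minimum` op and the `ttmlir.ttir_builder` import are taken from the tests in this commit; exporting `compile_as_mlir_module` from `ttmlir.test_utils` and calling it outside the decorator are assumptions:

import torch

from ttmlir.test_utils import compile_as_mlir_module
from ttmlir.ttir_builder import Operand, TTIRBuilder


def two_input_fn(in0: Operand, in1: Operand, builder: TTIRBuilder):
    return builder.minimum(in0, in1)


# inputs_types is optional; passing None (or omitting it) falls back to
# float32 for every shape, as implemented above.
module, builder = compile_as_mlir_module(
    two_input_fn,
    inputs_shapes=[(32, 32), (32, 32)],
    inputs_types=[torch.float16, torch.float16],
)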
70 changes: 61 additions & 9 deletions python/test_infra/ttir_builder.py
@@ -137,26 +137,32 @@ def get_shape(self, input: Operand) -> Shape:
"""Retrieves shape of operand which is expected to be a shaped type."""
return self._get_type(input).shape

-def generate_and_store_random_golden(self, operand: Operand) -> Golden:
+def generate_and_store_random_golden(
+self, operand: Operand, dtype: torch.dtype = torch.float32
+) -> Golden:
"""
-Generates random tensor of `operand`s shape, assigns it to a golden,
+Generates a random tensor of type `dtype` with `operand`'s shape, assigns it to a golden,
and maps `operand` to that golden.
Returns generated golden.
"""
seed = self._get_seed()
-random_tensor = self._generate_random_tensor(self.get_shape(operand), seed)
+random_tensor = self._generate_random_tensor(
+self.get_shape(operand), dtype, seed
+)
golden = Golden(random_tensor, seed)
self._store_golden(operand, golden)
return golden

-def generate_input_golden(self, operand: Operand, index: int) -> None:
+def generate_input_golden(
+self, operand: Operand, dtype: torch.dtype, index: int
+) -> None:
"""
-Generates random tensor of `input`s shape, assigns it to a golden,
+Generates a random tensor of type `dtype` with `input`'s shape, assigns it to a golden,
and maps `input` to that golden.
"""
self.id_golden_map[f"input_{index}"] = self.generate_and_store_random_golden(
-operand
+operand, dtype
)

def get_golden_map(self) -> Dict:
@@ -200,12 +206,26 @@ def _get_seed(self) -> int:
return seed

@staticmethod
-def _generate_random_tensor(shape: Shape, seed: int) -> torch.Tensor:
+def _generate_random_tensor(
+shape: Shape, dtype: torch.dtype, seed: int
+) -> torch.Tensor:
"""
-Generates random tensor of shape `shape`, using `seed` to seed torch
+Generates a random tensor of shape `shape` and type `dtype`, using `seed` to seed the torch
random generator.
"""
-return torch.randn(shape, generator=torch.manual_seed(seed))

+if dtype.is_floating_point:
+return torch.randn(shape, generator=torch.manual_seed(seed), dtype=dtype)
+else:
+min_int = torch.iinfo(dtype).min
+max_int = torch.iinfo(dtype).max
+return torch.randint(
+low=min_int,
+high=max_int,
+size=shape,
+generator=torch.manual_seed(seed),
+dtype=dtype,
+)

def _get_golden(self, operand: Operand) -> Golden:
"""Retrieves stored golden for `operand`."""
@@ -259,6 +279,38 @@ def _get_type(self, input: Operand):

return typ

+# ----- Utility Conversion -----
+
+def get_type_from_torch_dtype(self, dtype: torch.dtype) -> Type:
+"""
+Returns an MLIR `Type` corresponding to `dtype`.
+"""
+match dtype:
+case torch.float16:
+return F16Type.get(self._ctx)
+case torch.float32:
+return F32Type.get(self._ctx)
+case torch.float64:
+return F64Type.get(self._ctx)
+case torch.int8:
+return IntegerType.get_signless(8, self._ctx)
+case torch.int16:
+return IntegerType.get_signless(16, self._ctx)
+case torch.int32:
+return IntegerType.get_signless(32, self._ctx)
+case torch.int64:
+return IntegerType.get_signless(64, self._ctx)
+case torch.uint8:
+return IntegerType.get_unsigned(8, self._ctx)
+case torch.uint16:
+return IntegerType.get_unsigned(16, self._ctx)
+case torch.uint32:
+return IntegerType.get_unsigned(32, self._ctx)
+case torch.uint64:
+return IntegerType.get_unsigned(64, self._ctx)
+case _:
+raise TypeError(f"Invalid Type {dtype}")

# ----- Utility factories -----

def ranked_tensor_type(
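The golden-generation rule introduced above is easy to restate outside the builder: floating-point dtypes draw from `torch.randn`, integer dtypes from `torch.randint` over the dtype's full range, both seeded for reproducibility. A standalone sketch using only torch (no MLIR context required; the function name is illustrative):

import torch


def sketch_random_golden(shape, dtype: torch.dtype, seed: int) -> torch.Tensor:
    # Mirrors the branching in TTIRBuilder._generate_random_tensor above.
    gen = torch.manual_seed(seed)
    if dtype.is_floating_point:
        return torch.randn(shape, generator=gen, dtype=dtype)
    info = torch.iinfo(dtype)
    return torch.randint(info.min, info.max, size=shape, generator=gen, dtype=dtype)


# Reusing the same seed reproduces the same golden tensor.
a = sketch_random_golden((2, 2), torch.int8, seed=0)
b = sketch_random_golden((2, 2), torch.int8, seed=0)
assert torch.equal(a, b)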
101 changes: 55 additions & 46 deletions test/python/golden/test_ttir_ops.py
@@ -5,6 +5,7 @@
# RUN: SYSTEM_DESC_PATH=%system_desc_path% %python %s

import inspect
+import torch

from ttmlir.test_utils import compile_to_flatbuffer
from ttmlir.ttir_builder import Operand, TTIRBuilder, Attribute
@@ -40,11 +41,11 @@ def test_logical_not(in0: Operand, builder: TTIRBuilder):
return builder.logical_not(in0)


-# TODO: uncomment once we have control over generated input types (bitwise ops
-# don't support floats) (see issue #1765)
-# @compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
-# def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
-# return builder.bitwise_not(in0)
+# NOTE: The generated flatbuffer will currently fail to run because the
+# runtime only supports floats. See issue #1775 for tracking.
+@compile_to_flatbuffer([(128, 128)], inputs_types=[torch.int8], targets=["ttnn"])
+def test_bitwise_not(in0: Operand, builder: TTIRBuilder):
+return builder.bitwise_not(in0)


@compile_to_flatbuffer([(128, 128)], targets=["ttnn"])
@@ -190,39 +191,46 @@ def test_logical_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.logical_xor(in0, in1)


-# TODO: uncomment once we have control over generated input types (bitwise ops
-# don't support floats) (see issue #1765)
-# @compile_to_flatbuffer(
-# [
-# (64, 64),
-# (64, 64),
-# ],
-# targets=["ttnn"],
-# )
-# def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
-# return builder.bitwise_and(in0, in1)
-#
-#
-# @compile_to_flatbuffer(
-# [
-# (64, 64),
-# (64, 64),
-# ],
-# targets=["ttnn"],
-# )
-# def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
-# return builder.bitwise_or(in0, in1)
-#
-#
-# @compile_to_flatbuffer(
-# [
-# (64, 64),
-# (64, 64),
-# ],
-# targets=["ttnn"],
-# )
-# def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
-# return builder.bitwise_xor(in0, in1)
+# NOTE: The generated flatbuffer will currently fail to run because the
+# runtime only supports floats. See issue #1775 for tracking.
+@compile_to_flatbuffer(
+[
+(64, 64),
+(64, 64),
+],
+inputs_types=[torch.int8, torch.int8],
+targets=["ttnn"],
+)
+def test_bitwise_and(in0: Operand, in1: Operand, builder: TTIRBuilder):
+return builder.bitwise_and(in0, in1)
+
+
+# NOTE: The generated flatbuffer will currently fail to run because the
+# runtime only supports floats. See issue #1775 for tracking.
+@compile_to_flatbuffer(
+[
+(64, 64),
+(64, 64),
+],
+inputs_types=[torch.int8, torch.int8],
+targets=["ttnn"],
+)
+def test_bitwise_or(in0: Operand, in1: Operand, builder: TTIRBuilder):
+return builder.bitwise_or(in0, in1)
+
+
+# NOTE: The generated flatbuffer will currently fail to run because the
+# runtime only supports floats. See issue #1775 for tracking.
+@compile_to_flatbuffer(
+[
+(64, 64),
+(64, 64),
+],
+inputs_types=[torch.int8, torch.int8],
+targets=["ttnn"],
+)
+def test_bitwise_xor(in0: Operand, in1: Operand, builder: TTIRBuilder):
+return builder.bitwise_xor(in0, in1)


@compile_to_flatbuffer(
@@ -346,17 +354,18 @@ def test_minimum(in0: Operand, in1: Operand, builder: TTIRBuilder):
return builder.minimum(in0, in1)


-# TODO: uncomment when we have control over the input types
# @compile_to_flatbuffer(
-# [
-# (64, 64),
-# (64, 64),
-# (64, 64),
-# ],
-# targets=["ttnn"],
+# [
+# (64, 64),
+# (64, 64),
+# (64, 64),
+# ],
+# inputs_types = [torch.int8, torch.float32, torch.float32],
+# targets=["ttnn"],
# )
# def test_where(in0: Operand, in1: Operand, in2: Operand, builder: TTIRBuilder):
-# return builder.where(in0, in1, in2)
+# return builder.where(in0, in1, in2)
+#


@compile_to_flatbuffer(
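Finally, a sketch of what the new dtype plumbing resolves to at the MLIR level. `TTIRBuilder`, its constructor arguments, and `get_type_from_torch_dtype` come from the diff above; importing `Context` and `Location` from `ttmlir.ir` and building them this way is an assumption about the bindings. Dtypes without a case in the match statement, such as `torch.bfloat16` or `torch.bool`, raise `TypeError` under this commit:

import torch

from ttmlir.ir import Context, Location
from ttmlir.ttir_builder import TTIRBuilder

ctx = Context()
loc = Location.unknown(ctx)
builder = TTIRBuilder(ctx, loc)

print(builder.get_type_from_torch_dtype(torch.int32))    # i32 (signless)
print(builder.get_type_from_torch_dtype(torch.float16))  # f16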
