#16153: Implement binary-ng fused input activations
patrickroberts committed Jan 8, 2025
1 parent 386a47c · commit 7f4eb32
Showing 18 changed files with 726 additions and 357 deletions.
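For context before the diffs: the updated unit tests below exercise the new fused-activation keyword arguments on the experimental binary-ng ops. The following is a minimal usage sketch assembled from those tests, not code taken from the commit itself; the device handle, tensor shapes, and the particular activation mix are illustrative assumptions.

import torch
import ttnn

# Assumption: `device` is an already-open ttnn device handle (e.g. the pytest fixture used below).
torch.manual_seed(0)
a_pt = torch.rand([5, 3, 32, 32], dtype=torch.bfloat16)
b_pt = torch.rand([5, 3, 32, 32], dtype=torch.bfloat16)
a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)

# Unary activations fused into the binary op: SQUARE on the lhs input,
# SIN on the rhs input, and ABS applied to the result.
out_tt = ttnn.experimental.add(
    a_tt,
    b_tt,
    queue_id=0,
    lhs_activations=[ttnn.UnaryOpType.SQUARE],
    rhs_activations=[ttnn.UnaryOpType.SIN],
    post_activations=[ttnn.UnaryOpType.ABS],
)
out_host = ttnn.to_torch(out_tt)

# Host-side reference mirroring the fusion order checked by the tests:
# input activations first, then the binary op, then the post activation.
out_pt = torch.abs(torch.square(a_pt) + torch.sin(b_pt)).bfloat16()
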
239 changes: 165 additions & 74 deletions tests/ttnn/unit_tests/operations/eltwise/test_binary_bcast.py
@@ -10,58 +10,173 @@
compare_pcc,
)
from models.utility_functions import skip_for_grayskull
from itertools import product as parameters


binary_fns = {
"gte",
"gt",
"lte",
"lt",
"eq",
"ne",
"logical_and",
"logical_or",
"logical_xor",
"ldexp",
"logaddexp",
"logaddexp2",
"squared_difference",
"add",
"sub",
"mul",
"div",
"bias_gelu",
}
activation_fns = {
"EXP": torch.exp,
"GELU": torch.nn.functional.gelu,
"RELU": torch.relu,
"SQRT": torch.sqrt,
"SIGMOID": torch.sigmoid,
"LOG": torch.log,
"TANH": torch.tanh,
"LOG2": torch.log2,
"LOG10": torch.log10,
"SIN": torch.sin,
"COS": torch.cos,
"ABS": torch.abs,
"SIGN": torch.sign,
"SQUARE": torch.square,
"EQZ": lambda x: torch.eq(x, 0),
"NEZ": lambda x: torch.not_equal(x, 0),
"GTZ": lambda x: torch.greater(x, 0),
"LTZ": lambda x: torch.less(x, 0),
"GEZ": lambda x: torch.greater_equal(x, 0),
"LEZ": lambda x: torch.less_equal(x, 0),
"EXP2": torch.exp2,
"EXPM1": torch.expm1,
"SIGNBIT": torch.signbit,
"RSQRT": torch.rsqrt,
"RELU6": torch.nn.functional.relu6,
"ATAN": torch.atan,
"ERF": torch.erf,
"ERFC": torch.erfc,
"ISINF": torch.isinf,
"ISPOSINF": torch.isposinf,
"ISNEGINF": torch.isneginf,
"ISNAN": torch.isnan,
"LOGICAL_NOT_UNARY": torch.logical_not,
"ISFINITE": torch.isfinite,
"ERFINV": torch.erfinv,
"I0": torch.i0,
"TAN": torch.tan,
"SILU": torch.nn.functional.silu,
"NEG": torch.neg,
"FLOOR": torch.floor,
"CEIL": torch.ceil,
}
no_activations = ((), (), ())
square_lhs = (("SQUARE",), (), ())
sin_rhs = ((), ("SIN",), ())
floor_lhs_ceil_rhs_cos_post = (("FLOOR",), ("CEIL",), ("COS",))
exp_floor_lhs_exp_rhs = (("FLOOR", "EXP"), ("EXP",), ())
log_lhs_sqrt_abs_post = (("LOG",), (), ("ABS", "SQRT"))
exp_post = ((), (), ("EXP",))
log_post = ((), (), ("LOG",))
tanh_post = ((), (), ("TANH",))
log2_post = ((), (), ("LOG2",))
log10_post = ((), (), ("LOG10",))
exp2_post = ((), (), ("EXP2",))
expm1_post = ((), (), ("EXPM1",))
erfinv_post = ((), (), ("ERFINV",))
i0_post = ((), (), ("I0",))
tan_post = ((), (), ("TAN",))
floor_post = ((), (), ("FLOOR",))
ceil_post = ((), (), ("CEIL",))


def rand_bf16_gen(shape, device, *, min=0, max=1):
pt = torch.rand(shape, dtype=torch.bfloat16) * (max - min) + min
tt = ttnn.from_torch(pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
return pt, tt


@skip_for_grayskull("Possible accuracy issues with grayskull")
@pytest.mark.parametrize(
"input_shapes",
"a_shape, b_shape",
(
(torch.Size([1, 1, 1, 1]), torch.Size([5, 3, 32, 32])),
(torch.Size([5, 1, 64, 1]), torch.Size([1, 3, 1, 128])),
(torch.Size([5, 1, 1, 64]), torch.Size([1, 3, 128, 1])),
),
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.gte,
ttnn.experimental.gt,
ttnn.experimental.lte,
ttnn.experimental.lt,
ttnn.experimental.eq,
ttnn.experimental.ne,
ttnn.experimental.logical_and,
ttnn.experimental.logical_or,
ttnn.experimental.logical_xor,
ttnn.experimental.ldexp,
ttnn.experimental.logaddexp,
ttnn.experimental.logaddexp2,
ttnn.experimental.squared_difference,
ttnn.experimental.add,
ttnn.experimental.sub,
ttnn.experimental.mul,
ttnn.experimental.div,
ttnn.experimental.bias_gelu,
],
"ttnn_fn, activations",
{
*parameters(
binary_fns,
{
no_activations,
square_lhs,
sin_rhs,
floor_lhs_ceil_rhs_cos_post,
exp_floor_lhs_exp_rhs,
log_lhs_sqrt_abs_post,
},
),
*parameters({"add"}, {((), (), (op,)) for op in activation_fns.keys()}),
}.difference(
parameters({"eq", "ne"}, {square_lhs, sin_rhs, exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}),
parameters({"logaddexp", "logaddexp2"}, {floor_lhs_ceil_rhs_cos_post}),
parameters({"gte", "lt", "lte"}, {exp_floor_lhs_exp_rhs, log_lhs_sqrt_abs_post}),
parameters({"logical_and", "logical_or", "logical_xor", "bias_gelu"}, {log_lhs_sqrt_abs_post}),
parameters({"div"}, {exp_post, tanh_post, exp2_post, expm1_post, i0_post, tan_post}),
parameters({"sub"}, {log_post, log2_post, log10_post}),
parameters({"ldexp"}, {erfinv_post, tan_post, floor_post, ceil_post}),
parameters({"squared_difference"}, {erfinv_post, i0_post}),
parameters({"add"}, {tan_post, tanh_post}),
{("mul", log_lhs_sqrt_abs_post)},
),
)
def test_binary_scalar_ops(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
a_pt = torch.rand(a_shape).bfloat16()
b_pt = torch.rand(b_shape).bfloat16()
def test_binary_scalar_ops(a_shape, b_shape, ttnn_fn, activations, device):
torch.manual_seed(0)
ttnn_op = getattr(ttnn.experimental, ttnn_fn)
lhs, rhs, post = ([getattr(ttnn.UnaryOpType, op) for op in ops] for ops in activations)
golden_lhs, golden_rhs, golden_post = ((activation_fns[op] for op in ops) for ops in activations)
# make 0 exclusive for rhs of div
min, max = (1, 0) if ttnn_fn == "div" else (0, 1)

a_pt, a_tt = rand_bf16_gen(a_shape, device)
b_pt, b_tt = rand_bf16_gen(b_shape, device, min=min, max=max)

a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
cq_id = 0
out_tt = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
golden_fn = ttnn.get_golden_function(ttnn_fn)
out_pt = golden_fn(a_pt, b_pt)
out_tt = ttnn_op(a_tt, b_tt, queue_id=cq_id, lhs_activations=lhs, rhs_activations=rhs, post_activations=post)

for golden_activation in golden_lhs:
a_pt = golden_activation(a_pt).bfloat16()

for golden_activation in golden_rhs:
b_pt = golden_activation(b_pt).bfloat16()

golden_fn = ttnn.get_golden_function(ttnn_op)
out_pt = golden_fn(a_pt, b_pt).bfloat16()

comp_pass = compare_pcc([out_tt], [out_pt])
assert comp_pass
for golden_activation in golden_post:
out_pt = golden_activation(out_pt).bfloat16()

def compare(tt, pt):
imprecise_cases = {
*parameters({"bias_gelu"}, {square_lhs, floor_lhs_ceil_rhs_cos_post}),
*parameters({"gte", "gt", "lte", "lt"}, {sin_rhs}),
}
return compare_pcc(tt, pt, 0.98) if (ttnn_fn, activations) in imprecise_cases else compare_pcc(tt, pt)

assert compare([out_tt], [out_pt])
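
A detail worth noting in the test above: for "div" the rhs tensor is drawn with rand_bf16_gen(b_shape, device, min=1, max=0). Swapping the bounds turns the generator into 1 - torch.rand(...), so the divisor lies in (0, 1] and an exact zero cannot occur. A tiny self-contained check of that trick (the shape is arbitrary):

import torch

min, max = 1, 0  # the swapped bounds used for the rhs of div, shadowing builtins as in the test
x = torch.rand([4, 4], dtype=torch.bfloat16) * (max - min) + min  # equivalent to 1 - rand(...)
assert torch.all(x > 0) and torch.all(x <= 1)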


@pytest.mark.parametrize(
"input_shapes",
"a_shape, b_shape",
(
(torch.Size([1, 1, 31, 32]), torch.Size([5, 3, 32, 32])),
(torch.Size([5, 2, 64, 1]), torch.Size([1, 3, 1, 128])),
@@ -70,43 +185,23 @@ def test_binary_scalar_ops(input_shapes, ttnn_fn, device):
)
@pytest.mark.parametrize(
"ttnn_fn",
[
ttnn.experimental.gte,
ttnn.experimental.gt,
ttnn.experimental.lte,
ttnn.experimental.lt,
ttnn.experimental.eq,
ttnn.experimental.ne,
ttnn.experimental.logical_and,
ttnn.experimental.logical_or,
ttnn.experimental.logical_xor,
ttnn.experimental.ldexp,
ttnn.experimental.logaddexp,
ttnn.experimental.logaddexp2,
ttnn.experimental.squared_difference,
ttnn.experimental.add,
ttnn.experimental.sub,
ttnn.experimental.mul,
ttnn.experimental.div,
ttnn.experimental.bias_gelu,
],
binary_fns,
)
def test_binary_scalar_ops_invalid_bcast(input_shapes, ttnn_fn, device):
a_shape, b_shape = input_shapes
a_pt = torch.rand(a_shape).bfloat16()
b_pt = torch.rand(b_shape).bfloat16()
def test_binary_scalar_ops_invalid_bcast(a_shape, b_shape, ttnn_fn, device):
torch.manual_seed(0)
ttnn_op = getattr(ttnn.experimental, ttnn_fn)

a_tt = ttnn.from_torch(a_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
b_tt = ttnn.from_torch(b_pt, device=device, layout=ttnn.TILE_LAYOUT, memory_config=ttnn.DRAM_MEMORY_CONFIG)
_, a_tt = rand_bf16_gen(a_shape, device)
_, b_tt = rand_bf16_gen(b_shape, device)

with pytest.raises(RuntimeError) as e:
cq_id = 0
_ = ttnn_fn(a_tt, b_tt, queue_id=cq_id)
_ = ttnn_op(a_tt, b_tt, queue_id=cq_id)
assert "Broadcasting rule violation" in str(e.value)


@pytest.mark.parametrize(
"shapes",
"a_shape, b_shape",
[
[[1, 71, 7, 7], [7, 7]],
[[920, 1, 256], [256]],
Expand All @@ -119,17 +214,14 @@ def test_binary_scalar_ops_invalid_bcast(input_shapes, ttnn_fn, device):
[[16, 1], [1, 1, 32]],
],
)
def test_unequal_ranks(device, shapes):
def test_unequal_ranks(a_shape, b_shape, device):
torch.manual_seed(0)
torch_input_tensor_a = torch.rand(shapes[0], dtype=torch.bfloat16)
torch_input_tensor_b = torch.rand(shapes[1], dtype=torch.bfloat16)

torch_input_tensor_a, input_tensor_a = rand_bf16_gen(a_shape, device)
torch_input_tensor_b, input_tensor_b = rand_bf16_gen(b_shape, device)

torch_output_tensor = torch_input_tensor_a + torch_input_tensor_b
input_tensor_a = ttnn.from_torch(
torch_input_tensor_a, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)
input_tensor_b = ttnn.from_torch(
torch_input_tensor_b, layout=ttnn.TILE_LAYOUT, device=device, memory_config=ttnn.DRAM_MEMORY_CONFIG
)

output_tensor = ttnn.experimental.add(input_tensor_a, input_tensor_b, memory_config=ttnn.DRAM_MEMORY_CONFIG)
output_tensor = ttnn.to_torch(output_tensor)

@@ -138,7 +230,7 @@ def test_unequal_ranks(device, shapes):


@pytest.mark.parametrize(
"data",
"a, b, c_golden",
[
([], [], []),
([1], [2], [3]),
Expand All @@ -150,8 +242,7 @@ def test_unequal_ranks(device, shapes):
],
)
@pytest.mark.parametrize("memory_config", [ttnn.DRAM_MEMORY_CONFIG, ttnn.L1_MEMORY_CONFIG])
def test_01_volume_tensors(device, data, memory_config):
(a, b, c_golden) = data
def test_01_volume_tensors(device, a, b, c_golden, memory_config):
a = torch.BFloat16Tensor(a)
b = torch.BFloat16Tensor(b)
assert torch.add(a, b).tolist() == c_golden
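
A note on the parametrization style used in test_binary_scalar_ops above: itertools.product is imported as parameters, so each generated case is an (op_name, activations) pair, and combinations that are unsupported or numerically imprecise are subtracted from the Cartesian product before pytest expands them. A small self-contained sketch of the same idiom, using hypothetical op names and activation tuples:

from itertools import product as parameters

ops = {"add", "sub"}
variants = {("EXP",), ("LOG",)}

# Cartesian product of ops and activation variants, minus pairs excluded by hand
# (the exclusion here is purely illustrative).
cases = {*parameters(ops, variants)}.difference(
    parameters({"sub"}, {("LOG",)}),
)

assert cases == {("add", ("EXP",)), ("add", ("LOG",)), ("sub", ("EXP",))}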
@@ -151,11 +151,11 @@ std::map<std::string, std::string> get_defines(
defines["ELTWISE_OP"] = op_name.c_str();
defines["ELTWISE_OP_TYPE"] = op_binary_type.c_str();
if (fused_activations.has_value()) {
if (op_type == BinaryOpType::ADD and fused_activations.value().size() == 1 and
fused_activations.value().at(0).op_type == UnaryOpType::RELU) {
if (op_type == BinaryOpType::ADD and fused_activations->size() == 1 and
fused_activations->at(0).op_type == UnaryOpType::RELU and not input_tensor_a_activation.has_value()) {
defines["PACK_RELU"] = "1";
} else {
defines.merge(ttnn::operations::unary::utils::get_block_defines(fused_activations.value(), "0", idst));
defines.merge(ttnn::operations::unary::utils::get_block_defines(*fused_activations, "0", idst));
}
}
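
The get_defines change above narrows the PACK_RELU fast path: it is now taken only when the op is ADD, the single fused activation is RELU, and no activation was requested on input tensor a; every other combination falls through to the general per-block unary defines. A Python mirror of the updated guard, for illustration only (the real check is the C++ above; the string arguments stand in for the enum values):

def uses_pack_relu(op_type, fused_activations, input_tensor_a_activation):
    # ADD + exactly one fused RELU + no a-input activation -> emit the PACK_RELU define.
    return (
        fused_activations is not None
        and op_type == "ADD"
        and len(fused_activations) == 1
        and fused_activations[0] == "RELU"
        and input_tensor_a_activation is None
    )

assert uses_pack_relu("ADD", ["RELU"], None)
assert not uses_pack_relu("ADD", ["RELU"], "SQUARE")  # an input activation now disables the fast path
assert not uses_pack_relu("ADD", ["RELU", "EXP"], None)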
