
Commit

#0: fixes
yugaoTT committed Dec 11, 2024
1 parent f949ba3 · commit 01aed9b
Showing 1 changed file with 15 additions and 20 deletions.
35 changes: 15 additions & 20 deletions tests/ttnn/unit_tests/operations/test_matmul.py
@@ -980,6 +980,11 @@ def run_matmul_1d_multiple_output_blocks_per_core(
     if out_sharded and num_out_block_w > 1:
         pytest.skip("out sharded not support multiple blocks on w dim")
 
+    if not mcast_in0:
+        tmp = m
+        m = n
+        n = tmp
+
     in0_shape = [1, 1, m, k]
     in1_shape = [1, 1, k, n]
     bias_shape = [1, 1, n]
@@ -1012,26 +1017,16 @@ def run_matmul_1d_multiple_output_blocks_per_core(
     in0 = torch.randn(in0_shape).bfloat16().float()
     in1 = torch.randn(in1_shape).bfloat16().float()
 
-    if in_sharded:
-        if mcast_in0:
-            in0_memory_config = ttnn.create_sharded_memory_config(
-                (1, 1, m, k),
-                core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
-                strategy=ttnn.ShardStrategy.WIDTH,
-                orientation=ttnn.ShardOrientation.ROW_MAJOR,
-            )
-            in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
-        else:
-            in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
-            in1_memory_config = ttnn.create_sharded_memory_config(
-                (1, 1, k, n),
-                core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
-                strategy=ttnn.ShardStrategy.WIDTH,
-                orientation=ttnn.ShardOrientation.ROW_MAJOR,
-            )
+    if in_sharded and mcast_in0:
+        in0_memory_config = ttnn.create_sharded_memory_config(
+            (1, 1, m, k),
+            core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
+            strategy=ttnn.ShardStrategy.WIDTH,
+            orientation=ttnn.ShardOrientation.ROW_MAJOR,
+        )
     else:
         in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
-        in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
+    in1_memory_config = ttnn.DRAM_MEMORY_CONFIG
     in0_t = ttnn.from_torch(
         in0,
         dtype=ttnn.bfloat16,
@@ -1085,13 +1080,13 @@ def run_matmul_1d_multiple_output_blocks_per_core(
         fp32_dest_acc_en=False,
         packer_l1_acc=True,
     )
-    if out_sharded:
+    if out_sharded and mcast_in0:
         out_mem_config = ttnn.MemoryConfig(
             memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
             buffer_type=ttnn.BufferType.L1,
         )
     else:
-        out_mem_config = ttnn.L1_MEMORY_CONFIG
+        out_mem_config = ttnn.DRAM_MEMORY_CONFIG
 
     if has_bias:
         output_t = ttnn.linear(
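
For orientation, below is a minimal standalone sketch (not part of the commit) that consolidates the memory-config selection as it stands after this change. The helper name pick_memory_configs is hypothetical, and grid_size is assumed to be an (x, y) tuple as in the surrounding test; only ttnn calls that already appear in the diff are used.

import ttnn

def pick_memory_configs(in_sharded, out_sharded, mcast_in0, m, k, grid_size):
    # After this commit, in0 is width-sharded in L1 only on the mcast_in0
    # path; every other combination falls back to interleaved DRAM.
    if in_sharded and mcast_in0:
        in0_memory_config = ttnn.create_sharded_memory_config(
            (1, 1, m, k),
            core_grid=ttnn.CoreGrid(y=grid_size[1], x=grid_size[0]),
            strategy=ttnn.ShardStrategy.WIDTH,
            orientation=ttnn.ShardOrientation.ROW_MAJOR,
        )
    else:
        in0_memory_config = ttnn.DRAM_MEMORY_CONFIG
    # in1 now always lives in DRAM; the old width-sharded in1 branch
    # was removed by this commit.
    in1_memory_config = ttnn.DRAM_MEMORY_CONFIG

    # The output is width-sharded in L1 only when both out_sharded and
    # mcast_in0 hold; the fallback changed from L1 to DRAM interleaved.
    if out_sharded and mcast_in0:
        out_mem_config = ttnn.MemoryConfig(
            memory_layout=ttnn.TensorMemoryLayout.WIDTH_SHARDED,
            buffer_type=ttnn.BufferType.L1,
        )
    else:
        out_mem_config = ttnn.DRAM_MEMORY_CONFIG
    return in0_memory_config, in1_memory_config, out_mem_config

Note also the new swap at the top of the first hunk: when mcast_in0 is false, m and n are exchanged (the three-line tmp dance is equivalent to m, n = n, m). A plausible reading is that this transposes the problem shape so the non-mcast_in0 variant is exercised with the wide dimension on the appropriate operand, but the commit message does not say so explicitly.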
