#5337: dense matmul after all-gather
sraizada-tt authored and mtairum committed Jun 5, 2024
1 parent 3023ec0 commit a409944
Showing 1 changed file with 11 additions and 23 deletions.
models/demos/t3000/mixtral8x7b/tt/mixtral_attention.py (34 changes: 11 additions & 23 deletions)
@@ -90,11 +90,11 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype):
             .unsqueeze(0)
             .unsqueeze(0),
             device=self.device_mesh,
-            mesh_mapper=ShardTensorToMesh(self.device_mesh, dim=-2),
+            mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
             dtype=self.dtype,
             memory_config=self.model_config["ATTN_WEIGHTS_MEMCFG"],
             layout=self.model_config["ATTN_W_LAYOUT_TILE"],
-            cache_file_name=cache_name(f"wo_multidevice4d"),
+            cache_file_name=cache_name(f"wo_multidevice4d_H"),
         )

         cache_k = torch.zeros(
@@ -129,17 +129,6 @@ def __init__(self, device_mesh, state_dict, args, layer_num, dtype):

         self.scale = self.head_dim**-0.5

-        reduce_mask_torch = torch.zeros(1, 1, self.max_batch_size, self.max_batch_size * 8)
-        for i in range(self.max_batch_size):
-            reduce_mask_torch[:, :, i, range(i, self.max_batch_size * 8, self.max_batch_size)] = 1
-        self.reduce_mask = ttnn.from_torch(
-            reduce_mask_torch,
-            device=self.device_mesh,
-            mesh_mapper=ReplicateTensorToMesh(self.device_mesh),
-            dtype=ttnn.bfloat8_b,
-            layout=ttnn.TILE_LAYOUT,
-        )
-
         self.compute_kernel = self.model_args.get_compute_kernel_config()
         self.compute_kernel_attn = self.model_args.get_compute_kernel_attn_config()

@@ -300,27 +289,26 @@ def forward(
         )
         attn_output_1B4D.deallocate(True)

-        # attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved(
-        #     attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG
-        # )
+        attn_output_11BH = ttnn.experimental.tensor.sharded_to_interleaved(
+            attn_output_11BH, output_mem_config=ttnn.L1_MEMORY_CONFIG
+        )

         ###
         # Output matmul
         ###
+        # All gather
+        dense_outputs_11BH_gathered = ttnn.all_gather(attn_output_11BH, dim=3, num_links=1)

-        dense_out_11BH = ttnn.experimental.operations.primary.matmul(
-            attn_output_11BH,
+        # return the sum of the outputs
+        dense_outputs_11BH = ttnn.experimental.operations.primary.matmul(
+            dense_outputs_11BH_gathered,
             wo,
             output_mem_config=self.model_config["LM_HEAD_OUTPUT_MEMCFG"],
             # compute_with_storage_grid_size=(8, 8),
             program_config=self.model_config["LM_HEAD_OUTPUT_PROGCFG"],
             compute_kernel_config=self.compute_kernel,
             output_dtype=ttnn.bfloat8_b,
         )
-        attn_output_11BH.deallocate(True)
-        # All gather
-        dense_outputs_11BH = ttnn.all_gather(dense_out_11BH, dim=2, num_links=1)

-        # return the sum of the outputs
-        dense_outputs_11BH = ttnn.experimental.operations.primary.matmul(self.reduce_mask, dense_outputs_11BH)
+        dense_outputs_11BH_gathered.deallocate(True)
         return dense_outputs_11BH
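
Note, for context rather than as part of the commit: the rewrite relies on the block-matmul identity sum_i x_i @ W_i = concat_i(x_i) @ W, where the x_i are the per-device slices of the attention output along the hidden dimension and the W_i are the corresponding row blocks of wo. Gathering the activations first and multiplying by a replicated wo therefore produces the same result as the old shard-then-reduce path, which needed the reduce_mask matmul to sum the per-device partial products. A minimal plain-PyTorch sketch of that identity follows; the 8-way split, shapes, and tolerances are illustrative and not taken from the model config.

# Minimal PyTorch sketch (not ttnn) of why gathering before the dense matmul
# matches summing per-device partial matmuls. Shapes and the 8-way split are
# illustrative assumptions, not values from the Mixtral config.
import torch

torch.manual_seed(0)
num_devices = 8
batch, hidden, out_dim = 32, 4096, 4096

x = torch.randn(batch, hidden)      # full attention output (all heads concatenated)
wo = torch.randn(hidden, out_dim)   # full dense output-projection weight

# Old scheme: wo is row-sharded; each device holds its slice of the activations
# and its row block of wo, producing a partial product that must be summed.
x_shards = x.chunk(num_devices, dim=-1)    # per-device activation slices
wo_shards = wo.chunk(num_devices, dim=0)   # per-device row blocks of wo
partial_sum = sum(xi @ wi for xi, wi in zip(x_shards, wo_shards))

# New scheme: all-gather the activation slices along the hidden dim, then
# multiply by the full, replicated wo on every device.
gathered = torch.cat(x_shards, dim=-1)     # analogue of ttnn.all_gather(..., dim=3)
full_matmul = gathered @ wo

# Equal up to float32 summation-order differences.
assert torch.allclose(partial_sum, full_matmul, rtol=1e-4, atol=1e-3)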
