Commit
#13693: Fixes to tech_report/Programming Mesh of Devices (#13714)
cfjchu authored Oct 10, 2024
1 parent 554363d · commit 68e3aa7
Showing 1 changed file with 3 additions and 3 deletions.
@@ -289,7 +289,7 @@ Let's see an example of how to use the Ring All-Gather operation:
 ```py
 import ttnn
 
-mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
+mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
 
 # Construct test tensor of data; 8 chunks of 32x32
 torch_tensor = torch.rand((1,1,32,128), dtype=torch.bfloat16)
@@ -318,7 +318,7 @@ This time, we'll issue the CCL Line All-Gather operation along the cluster y-axis
 ```py
 import ttnn
 
-mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4))
+mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2, 4), mesh_type=ttnn.MeshType.Ring)
 
 # Construct test tensor of data; 8 chunks of 32x32
 torch_tensor = torch.rand((1,1,32,128), dtype=torch.bfloat16)
@@ -516,7 +516,7 @@ torch_hidden_states = (torch.rand(batch_size, 1, sequence_length, config.hidden_
 torch_output = model.forward(torch_hidden_states)
 
 # Device Initialization
-mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4))
+mesh_device = ttnn.open_mesh_device(ttnn.MeshShape(2,4), mesh_type=ttnn.MeshType.Ring)
 
 # Initialize input activations on all devices in the mesh
 # Alternatively, we can shard the input activations on the height dimension and
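Each hunk makes the same fix: the `ttnn.open_mesh_device` call now passes an explicit `mesh_type=ttnn.MeshType.Ring` argument instead of relying on a default. A minimal sketch of that call shape, using hypothetical stand-in stubs (a real run requires the `ttnn` package and Tenstorrent mesh hardware, so nothing here is the actual ttnn implementation):

```python
from dataclasses import dataclass
from enum import Enum, auto

# Stand-in stubs that only mirror the call shape shown in the diff;
# the real MeshType/MeshShape/open_mesh_device live in the ttnn package.
class MeshType(Enum):
    Ring = auto()
    Line = auto()

@dataclass
class MeshShape:
    rows: int
    cols: int

def open_mesh_device(shape: MeshShape, mesh_type: MeshType) -> dict:
    # The stub just echoes its configuration; ttnn would open the devices.
    return {"shape": (shape.rows, shape.cols), "mesh_type": mesh_type.name}

# The pattern the commit standardizes on: name the mesh topology explicitly.
mesh_device = open_mesh_device(MeshShape(2, 4), mesh_type=MeshType.Ring)
print(mesh_device)  # {'shape': (2, 4), 'mesh_type': 'Ring'}
```

Passing the topology as a keyword argument makes the intended collective-communication pattern (ring vs. line) visible at the call site, which matches how the surrounding tech report discusses Ring and Line All-Gather separately.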
