Multi-chip spec update, add meshShape to DeviceAttr (#659)
- Updated the device spec to outline new meshShape attribute usage and
  lowering to TTNN.
- Add meshShape to DeviceAttr which drives the interpretation of device and
  tensor grids in multi-device systems.
- Add pass and pipeline option "mesh-shape" for providing a mesh shape
  to the device attribute creation.
- Make device and system_desc attributes alias attributes, i.e. hoisted in
  MLIR printing.
- Add LLVM lit config features based on the provided system_desc
- Update `test/python/device_attr.py` to match spec
- Add a simple test, marked as XFAIL for now
nsmithtt authored Sep 14, 2024
1 parent c4d70db commit 0c45bba
Showing 13 changed files with 363 additions and 114 deletions.
155 changes: 126 additions & 29 deletions docs/src/specs/device.md
@@ -21,6 +21,8 @@ document will use the following definitions:
is a view over the system.
- **Logical Grid** or just **Grid**: Is a logical shape that abstracts one or
more **Physical Grids**.
- **Mesh Shape**: Describes the virtual layout of the chips with respect to each
other. In practice the mesh shape is used to derive the logical grid.

## Motivation

@@ -38,6 +40,13 @@ The device attribute strives to achieve the following goals:
- Enable many forms of data parallel execution strategies for single and
  multi-chip systems under a single representation.

## Scope

This document will cover how the device attribute is encoded and how it can be
lowered to backend dialects. The document will not cover the algorithm for
choosing the best, or even legal, device configurations for a given physical
system.

## Examples

All of the following examples will assume the physical hardware has an 8x8 physical
@@ -48,17 +57,27 @@ each with an 8x8 physical grid.
underlying physical hardware device.

```mlir
#tt.device<
  workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>,
  meshShape = 1,
  chipIds = [0]
>
```

Let's break down what each of these attributes means:
- `workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>`: This is a 2D logical grid with dim 8x8.
  It's followed by an affine map `(d0, d1) -> (0, d0, d1)` that provides a mapping
  from the logical grid to the physical grid. In this case, the logical grid is the same
  as the physical grid, so the mapping is the identity function. The logical
  grid can have any rank, but the physical mapping is always 3D, with the first
  dimension being the chip index, followed by the 2D physical core index within
  the chip. A small sketch after this list evaluates this map by hand.
- `meshShape = 1`: A shape provided as part of the `DeviceAttr` constructor that
  describes the virtual layout of the chips with respect to each other. Note that
  in a multi-chip system, this grid encapsulates the entire system's grid shape,
  e.g. an 8x16 grid could be made up of a 1x2 mesh of chips side-by-side. The mesh
  attribute configures how the above grid/map attributes are created such that they
  implement this mesh topology.
- `chipIds = [0]`: This is a list of chip indices. These chip indices directly reference
  the same chip indices in the system descriptor. The `SystemDesc` attribute
  that this is in reference to is tagged on the top level `ModuleOp`.
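
To make the logical-to-physical mapping concrete, here is a minimal sketch that evaluates the map `(d0, d1) -> (0, d0, d1)` by hand. It is plain C++ rather than the MLIR `AffineMap` API, and the helper name is made up for illustration.

```cpp
#include <array>
#include <cstdio>

// Mirrors the affine map (d0, d1) -> (0, d0, d1): a logical core (d0, d1)
// on this single-chip device lands on chip 0 at physical core (d0, d1).
std::array<int, 3> logicalToPhysical(int d0, int d1) { return {0, d0, d1}; }

int main() {
  auto [chip, y, x] = logicalToPhysical(3, 5);
  std::printf("chip=%d core=(%d,%d)\n", chip, y, x); // chip=0 core=(3,5)
  return 0;
}
```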

@@ -98,9 +117,14 @@ Specific examples that this document will cover:

Given a 2 chip system, `[2, 8x8]`, we can represent a simple data parallel
logical grid that divides the batch dimension in half across the two chips.
This is denoted by `meshShape = 2x1x1` which means the logical grid is 3D.

```mlir
#tt.device<
  workerGrid = #tt.grid<2x8x8, (d0, d1, d2) -> (d0, d1, d2)>,
  meshShape = 2x1x1,
  chipIds = [0, 1]
>
```

The affine map here is just identity, so dims `d1` and `d2` directly index
@@ -127,10 +151,15 @@ a 2x4 grid out of the 8x8 physical grid available.

In this example we will consider a 2 chip system, `[2, 8x8]`, and view it as
though the two chips are concatenated together side by side to form a single
`8x16` grid. This is denoted by `meshShape = 1x2` which means to concatenate
the chips in the second dimension.

```mlir
#tt.device<
  workerGrid = #tt.grid<8x16, (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0, d1 mod 8)>,
  meshShape = 1x2,
  chipIds = [0, 1]
>
```
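
As a sanity check, the sketch below (plain C++, illustrative helper name, not the MLIR API) evaluates this map to show that logical columns 0-7 land on chip 0 and columns 8-15 land on chip 1, each at physical core `(d0, d1 mod 8)`.

```cpp
#include <array>
#include <cstdio>

// Mirrors (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0, d1 mod 8)
// for the 8x16 logical grid spread over a 1x2 mesh of 8x8 chips.
std::array<int, 3> logicalToPhysical(int d0, int d1) {
  return {(d0 / 8) * 2 + d1 / 8, d0, d1 % 8};
}

int main() {
  auto [chipA, yA, xA] = logicalToPhysical(3, 5);  // chip 0, core (3, 5)
  auto [chipB, yB, xB] = logicalToPhysical(3, 13); // chip 1, core (3, 5)
  std::printf("(3,5)  -> chip %d core (%d,%d)\n", chipA, yA, xA);
  std::printf("(3,13) -> chip %d core (%d,%d)\n", chipB, yB, xB);
  return 0;
}
```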

Here we can see that the affine map encodes an indexing pattern such that when
@@ -157,10 +186,15 @@ physically spanning across two chips.

The previous 2 examples can be composed together to form a logical grid that
divides tensor across multiple dimensions. Here we will consider a 4 chip
system `[4, 8x8]` and view it as a `2x8x16` grid. Note that the `meshShape` is
`2x1x2` which means to concatenate the chips in the first and third dimensions.

```mlir
#tt.device<
  workerGrid = #tt.grid<2x8x16, (d0, d1, d2) -> (d0 * 2 + (d1 floordiv 8) * 2 + d2 floordiv 8, d1, d2 mod 8)>,
  meshShape = 2x1x2,
  chipIds = [0, 1, 2, 3]
>
```
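
The same exercise for this 4-chip map (again a plain C++ sketch with a made-up helper name) shows that `d0` selects the pair of chips and `d2 floordiv 8` selects the chip within the pair: `d0 = 0` uses chips 0 and 1, `d0 = 1` uses chips 2 and 3.

```cpp
#include <array>
#include <cstdio>

// Mirrors (d0, d1, d2) -> (d0 * 2 + (d1 floordiv 8) * 2 + d2 floordiv 8, d1, d2 mod 8)
// for the 2x8x16 logical grid spread over a 2x1x2 mesh of 8x8 chips.
std::array<int, 3> logicalToPhysical(int d0, int d1, int d2) {
  return {d0 * 2 + (d1 / 8) * 2 + d2 / 8, d1, d2 % 8};
}

int main() {
  std::printf("chip %d\n", logicalToPhysical(0, 4, 3)[0]);  // chip 0
  std::printf("chip %d\n", logicalToPhysical(0, 4, 12)[0]); // chip 1
  std::printf("chip %d\n", logicalToPhysical(1, 4, 3)[0]);  // chip 2
  std::printf("chip %d\n", logicalToPhysical(1, 4, 12)[0]); // chip 3
  return 0;
}
```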

We can evaluate the affine map to see that the chips are interpreted in chunks of
@@ -194,8 +228,16 @@ take 4 chips and interpret them differently (though they could take the same
logical grid).

```mlir
#tt.device<
  workerGrid = #tt.grid<2x8x16, (d0, d1, d2) -> (d0 * 2 + (d1 floordiv 8) * 2 + d2 floordiv 8, d1, d2 mod 8)>,
  meshShape = 2x1x2,
  chipIds = [0, 1, 2, 3]
>
#tt.device<
  workerGrid = #tt.grid<16x16, (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0 mod 8, d1 mod 8)>,
  meshShape = 2x2,
  chipIds = [4, 5, 6, 7]
>
```
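
For the second device above, the sketch below (plain C++, illustrative only) evaluates the `16x16` map and then resolves the device-local chip index through `chipIds = [4, 5, 6, 7]`, so the four quadrants of the logical grid land on system chips 4 through 7.

```cpp
#include <array>
#include <cstdio>

// Mirrors (d0, d1) -> ((d0 floordiv 8) * 2 + d1 floordiv 8, d0 mod 8, d1 mod 8)
// for a 16x16 logical grid over a 2x2 mesh of 8x8 chips.
std::array<int, 3> logicalToPhysical(int d0, int d1) {
  return {(d0 / 8) * 2 + d1 / 8, d0 % 8, d1 % 8};
}

int main() {
  const int chipIds[] = {4, 5, 6, 7}; // device-local chip -> system chip id
  auto [chip, y, x] = logicalToPhysical(9, 12); // bottom-right quadrant
  std::printf("local chip %d -> system chip %d, core (%d,%d)\n",
              chip, chipIds[chip], y, x); // local chip 3 -> system chip 7, core (1,4)
  return 0;
}
```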

### Reinterpreted Grids (Transpose)
@@ -222,12 +264,20 @@ relu(aT)

1. We'll establish a regular, single chip, identity logical grid:
```mlir
#tt.device<
  workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, d1)>,
  meshShape = 1,
  chipIds = [0]
>
```
2. Execute `exp`.
3. We'll reinterpret the grid as transposed:
```mlir
#tt.device<
  workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d1, d0)>,
  meshShape = 1,
  chipIds = [0]
>
```
4. _Execute_ `transpose`. Note that each core only needs to transpose their
data locally. Eventually this could be implemented as a no-op by reindexing
@@ -243,39 +293,86 @@ For the sake of examples, here's a few more ways of reinterpreting the logical grid.

#### Extra Wide Grid
```mlir
#tt.device<
  workerGrid = #tt.grid<1x64, (d0, d1) -> (0, d0 * 8 + d1 floordiv 8, d1 mod 8)>,
  meshShape = 1,
  chipIds = [0]
>
```

#### Extra Tall + Transposed Grid
```mlir
#tt.device<
  workerGrid = #tt.grid<64x1, (d0, d1) -> (0, d1 * 8 + d0 floordiv 8, d0 mod 8)>,
  meshShape = 1,
  chipIds = [0]
>
```

#### Staircase
```mlir
#tt.device<
  workerGrid = #tt.grid<8x8, (d0, d1) -> (0, d0, (d0 + d1) mod 8)>,
  meshShape = 1,
  chipIds = [0]
>
```

This could be an interesting starting position for data in implementing matmul as a
systolic array in a ring topology.

## Lowering to TTNN

While the above device attribute encoding is quite flexible, this does not
necessarily mean the target backend can actually support all of these
interpretations. TTNN backend will be constrained to support only the
specialized grid topologies that are supported by the API.

### Grid/Shard Orientation

TODO

### Multi-device

Please refer to [TTNN Mesh Programming Docs](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md)
for more information on how to program multi-device systems with TTNN API.

Multi-device TTNN dialect will try to stay as close to the TTNN API as
possible. Let's consider what this looks like from the compiler and runtime
perspectives:

#### Compiler

- **Device Creation**: The TTNN device in the compiler is exactly the same attribute
  from the ttir dialect. It will encode the `meshShape` into the flatbuffer
  which can be directly used to program `::ttnn::MeshShape`.
- **Tensor Layout**: Again, the tensor layout is inherited in TTNN dialect from the
  ttir dialect. The grid attribute in the tensor layout can be trivially
  divided by `meshShape` to determine the shape of the tensor slice on each device
  (see the sketch after this list).
  Broadcasting rules can be applied to determine which [Distribution Strategy](https://github.com/tenstorrent/tt-metal/blob/main/tech_reports/Programming%20Mesh%20of%20Devices/Programming%20Mesh%20of%20Devices%20with%20TT-NN.md#3-distributing-tensor-to-meshdevice)
  to use:
  - **Mesh Sharded**: If the tensor grid is > 1 along the `meshShape` dimensions,
    the tensor will be sharded across the mesh devices.
  - **Replication**: If the tensor needs to be broadcasted for this op, by
    extension the tensor layout will be replicated across the mesh devices.
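
A minimal sketch of the grid division described above, assuming the tensor grid divides evenly by `meshShape`; the helper is illustrative and not the compiler's actual API:

```cpp
#include <cassert>
#include <cstdint>
#include <cstdio>
#include <vector>

// Divide a tensor layout's grid elementwise by the device meshShape to get
// the per-device slice grid; a dimension equal to 1 in meshShape leaves the
// tensor dimension untouched (replication is decided separately).
std::vector<int64_t> perDeviceGrid(std::vector<int64_t> tensorGrid,
                                   const std::vector<int64_t> &meshShape) {
  assert(tensorGrid.size() == meshShape.size());
  for (size_t i = 0; i < tensorGrid.size(); ++i) {
    assert(tensorGrid[i] % meshShape[i] == 0);
    tensorGrid[i] /= meshShape[i];
  }
  return tensorGrid;
}

int main() {
  // meshShape = 2x1x1 over a 2x8x8 tensor grid: each device holds 1x8x8.
  auto slice = perDeviceGrid({2, 8, 8}, {2, 1, 1});
  std::printf("%lld x %lld x %lld\n", (long long)slice[0], (long long)slice[1],
              (long long)slice[2]);
  return 0;
}
```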

#### Runtime

- **Device Creation**: The ttnn runtime will wholesale switch to working with
  mesh devices via the API `ttnn::multi_device::open_mesh_device`; this is possible
  because a 1x1 mesh device is a valid single device. The mesh shape during
  device open will always be `1xN` where `N` is the number of deviceIds in the
  array. Note that this shape can be reinterpreted by flatbuffer programs on
  the fly with the `SubMesh` API.
- **Tensor Creation**: Tensor creation in a multi-device system is a bit more
  involved. In order to upload a multi-device tensor to the mesh, the host
  tensor must first be created with `MultiDeviceHostStorage`. The ttnn runtime
  can automatically do this during `handleToHostMemoryConfigOp`:
  - A regular host tensor will bounce through a new tensor with
    `MultiDeviceHostStorage` type.
  - `tensor.to(mesh_device)` will allocate/move the tensor to the mesh device.

## Lowering to TTMetal

In the TTMetal dialect we are only constrained by what we've implemented in the
tt-mlir compiler, which means it is much more flexible and can theoretically
13 changes: 9 additions & 4 deletions include/ttmlir/Dialect/TT/IR/TTOpsTypes.td
@@ -328,17 +328,22 @@ def TT_DeviceAttr : TT_Attr<"Device", "device", []> {
- A grid attribute that describes the device's compute grid shape. It not only describes the shape of the compute grid, but also
carries an affine map that describes how the logical grid maps to the physical grid.
- Two affine maps that describe how a tensor layout's linear attribute maps to the L1 and DRAM memory spaces.
- A mesh shape that describes the virtual layout of the chips with respect to each other. Note that in a multi-chip system, this grid
encapsulates the entire system's grid shape, e.g. 8x16 grid could be made up of a 1x2 mesh of chips side-by-side. The mesh
attribute configures how the above grid/map attributes are created such that they implement this mesh topology.
- An array of chip ids that this device is made up of. This array's length must match the volume of the mesh shape and should be
interpreted in row-major order.
}];
let parameters = (ins TT_GridAttr:$workerGrid,
                      "AffineMap":$l1Map,
                      "AffineMap":$dramMap,
                      ArrayRefParameter<"int64_t">:$meshShape,
                      ArrayRefParameter<"unsigned">:$chipIds);
let assemblyFormat = "`<` `workerGrid` `=` qualified($workerGrid) `,` `l1Map` `=` qualified($l1Map) `,` `dramMap` `=` qualified($dramMap) `,` `meshShape` `=` custom<DimensionList>($meshShape) `,` `chipIds` `=` `[` $chipIds `]` `>`";

let extraClassDeclaration = [{
static DeviceAttr get(::mlir::MLIRContext *context, SystemDescAttr systemDesc, ArrayRef<int64_t> meshShape, ArrayRef<unsigned> chipIds);
static DeviceAttr get(::mlir::MLIRContext *context, SystemDescAttr systemDesc, ArrayRef<int64_t> meshShape = {});
AffineMap getMapForMemorySpace(MemorySpace memorySpace) const {
switch (memorySpace) {
case MemorySpace::DeviceL1:
5 changes: 5 additions & 0 deletions include/ttmlir/Dialect/TTIR/Transforms/Passes.td
@@ -13,6 +13,11 @@ def TTIRImplicitDevice: Pass<"ttir-implicit-device", "::mlir::ModuleOp"> {
This pass will take a view of the system descriptor and create an implicit
device around it.
}];

let options = [
  ListOption<"meshShape", "mesh-shape", "int64_t",
             "Set the multi-device mesh shape.">,
];
}

def TTIRGenericKernel: Pass<"ttir-generic-kernel", "::mlir::ModuleOp"> {
3 changes: 3 additions & 0 deletions include/ttmlir/Dialect/TTNN/Pipelines/TTNNPipelines.h
@@ -94,6 +94,9 @@ struct TTIRToTTNNBackendPipelineOptions
  llvm::cl::desc(
      "Pass in a system descriptor flatbuffer to compile against."),
  llvm::cl::init("")};

ListOption<int64_t> meshShape{
    *this, "mesh-shape", llvm::cl::desc("Set the multi-device mesh shape.")};
};

void createTTIRToTTNNBackendPipeline(
8 changes: 8 additions & 0 deletions include/ttmlir/Utils.h
@@ -52,6 +52,14 @@ llvm::SmallVector<int64_t> evalShape(mlir::AffineMap map, Vector shape) {
return result;
}

template <typename IntType> IntType volume(mlir::ArrayRef<IntType> shape) {
  IntType result = 1;
  for (auto dim : shape) {
    result *= dim;
  }
  return result;
}

template <typename Enum>
constexpr std::underlying_type_t<Enum> enum_as_int(Enum e) {
return static_cast<std::underlying_type_t<Enum>>(e);
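
As an illustrative usage note (not part of the diff), the sketch below mirrors the `volume` helper to show the `DeviceAttr` contract that the `chipIds` array length must equal the mesh shape volume; names are hypothetical.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Same computation as the volume() helper added above: multiply all dims.
int64_t volume(const std::vector<int64_t> &shape) {
  int64_t result = 1;
  for (int64_t dim : shape) {
    result *= dim;
  }
  return result;
}

int main() {
  std::vector<int64_t> meshShape = {2, 1, 2};
  std::vector<unsigned> chipIds = {0, 1, 2, 3};
  // DeviceAttr expects chipIds.size() == volume(meshShape), row-major order.
  assert(static_cast<int64_t>(chipIds.size()) == volume(meshShape));
  return 0;
}
```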