#12828: update ttnn matmul doc string (#13071)
* #12828: update ttnn matmul doc string

* #12828: update matmul doc string about 0 dimensions
bbradelTT authored Sep 25, 2024
1 parent 3875f61 commit d0983f9
Showing 1 changed file with 86 additions and 53 deletions.
139 changes: 86 additions & 53 deletions ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
@@ -152,40 +152,65 @@ void py_module(py::module& module) {
R"doc(
Returns the matrix product of two tensors.
The input tensors need to be tiled. Therefore, the input tensors have to be
at least 2-dimensional.
If the input tensors have more than two dimensions, the additional leading
dimensions may be used for a batched matrix multiply.
These leading dimensions may also be referred to as batch dimensions.
E.g. a tensor with dimensions :math:`(a \\times b \\times c \\times d)`
has batch dimensions a and b.
The following are the allowed possibilities for batch dimensions.
Examples below show concrete operations and tensor sizes.
- If all batch dimensions are of size 1, then there is no batched operation.
- If both inputs have batch dimensions that are not all of size 1, then the
batch dimensions of both inputs should be the same. If the dimensions are
not the same, a few combinations may happen to work, but in most cases
various errors will be reported.
- If the first input has batch dimensions that are not all of size 1, and the
second input has no batch dimensions or has batch dimensions all of size 1,
then the second input is broadcast to align appropriately with the first
input.
- Matrix multiplication will not work if the first input has batch
dimensions that are all of size 1 and the second input has batch dimensions
that are not all of size 1 (see the contrast sketch in the Example section).
- Note: Dimensions of size 0 are not supported.
- Note: In general, the number of dimensions of the two inputs should
match. There may be cases where they don't. In that case, if the inputs
are not valid based on the above criteria, the error messages may
be unexpected and refer to non-obvious issues.
- Note: There are various possible combinations of dimensions. The behaviour
is the same as in PyTorch, with two exceptions.
These exceptions arise in the following scenarios related to batch
dimensions:
- The two batch dimensions are swapped. E.g. the first input has :math:`(j \\times 1)`
and the second input has :math:`(1 \\times j)`
or the first input has :math:`(1 \\times j)` and the second input has
:math:`(j \\times 1)`
- When a batch dimension is implicitly extended, the two batch dimensions are swapped.
E.g. :math:`(j \\times 1)` and :math:`(j)` which is treated as
:math:`(j \\times 1)` and :math:`(1 \\times j)`
- To leverage sharded matmul implementations, both input_tensor_a and input_tensor_b can be sharded. The sharding strategy used follows
the sharding strategy set on the respective tensor. A sharded 1D matmul can be either HEIGHT or WIDTH sharded; 2D matmuls can be block sharded (a sketch follows these notes).
- Note: The broadcasting logic only looks at the batch dimensions when determining if the inputs
are broadcastable, and not the matrix dimensions. For example, if :attr:`input_tensor_a` is a
:math:`(j \\times 1 \\times n\_size \\times m\_size)` tensor and :attr:`input_tensor_b` is a :math:`(k\_size \\times m\_size \\times p)`
tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the
matrix dimensions) are different. The operation will return a :math:`(j \\times k\_size \\times n\_size \\times p)` tensor
(a concrete sketch appears at the end of the Example section).
- Note: There are various additional constraints related to the specific program
config chosen. Please read the error messages carefully and address any
problems appropriately.
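As a loose illustration of the sharding note above, here is a sketch (not part of this commit's docstring; it assumes ttnn.create_sharded_memory_config, ttnn.ShardStrategy, and ttnn.CoreGrid behave as in recent ttnn releases, and the exact arguments may differ):
>>> # hypothetical: block-shard one 2D matmul input before multiplying
>>> sharded_config = ttnn.create_sharded_memory_config(
...     (512, 512),
...     core_grid=ttnn.CoreGrid(y=8, x=8),
...     strategy=ttnn.ShardStrategy.BLOCK,
... )
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((512, 512), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device, memory_config=sharded_config)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((512, 512), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)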
Args:
input_tensor_a (ttnn.Tensor): the first tensor to be multiplied. Needs to be on the device.
@@ -203,36 +228,42 @@ void py_module(py::module& module) {
ttnn.Tensor: the output tensor.
Example:
>>> # matrix x matrix - no batch dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[64, 64]
>>> # extended matrix x extended matrix - all batch dimensions of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[1, 1, 64, 64]
>>> # extended matrix x extended matrix - all batch dimensions of size 1, inputs with different numbers of dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[1, 1, 64, 64]
>>> # batched matrix x broadcasted matrix - first input has batch dimensions not of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[10, 64, 64]
>>> # batched matrix x batched matrix - both inputs have batch dimensions
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 32, 128), dtype=torch.bfloat16)), device)
>>> output = tensor1 @ tensor2
>>> output = tensor1 @ tensor2 # alternative to ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[10, 64, 128]
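For contrast, the unsupported ordering from the notes above (first input with batch dimensions all of size 1, second input with batch dimensions not all of size 1) is sketched next; this snippet is illustrative only and the exact error depends on the build:
>>> # extended matrix x batched matrix - NOT supported
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 32, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> # ttnn.matmul(tensor1, tensor2) would raise an error here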
>>> # batched matrix x broadcasted extended matrix - first input has batch dimensions not of size 1
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 128), dtype=torch.bfloat16)), device)
>>> output = tensor1 @ tensor2
>>> print(output.shape)
[1, 10, 64, 128]
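One more case makes the broadcasting note concrete. This is a sketch following the :math:`(j \\times 1 \\times n \\times m)` by :math:`(k \\times m \\times p)` pattern described in the notes; the concrete shapes are illustrative:
>>> # broadcast with mismatched ranks - (2, 1, 64, 32) x (3, 32, 128) -> (2, 3, 64, 128)
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((2, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((3, 32, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.matmul(tensor1, tensor2)
>>> print(output.shape)
[2, 3, 64, 128]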
)doc",
ttnn::pybind_overload_t{
[](decltype(::ttnn::matmul)& self,
@@ -280,6 +311,8 @@ void py_module(py::module& module) {
R"doc(
Returns the linear transformation of the inputs.
The limitations and behaviours are the same as for matmul.
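A minimal usage sketch follows (illustrative, not from this commit; it assumes ttnn.linear accepts a bias argument as in current ttnn releases; see the Args below for the actual signature):
>>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> bias = ttnn.to_device(ttnn.from_torch(torch.randn((1, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
>>> output = ttnn.linear(tensor1, tensor2, bias=bias)  # expected output shape: [64, 128]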
Args:
input_tensor_a (ttnn.Tensor): the first tensor to be multiplied. Needs to be on the device.
input_tensor_b (ttnn.Tensor): the second tensor to be multiplied. Needs to be on the device.
