From d0983f941aa0564563a0e9a5117a2c9f8b7333f8 Mon Sep 17 00:00:00 2001
From: Borys Bradel <164946524+bbradelTT@users.noreply.github.com>
Date: Wed, 25 Sep 2024 09:17:20 -0400
Subject: [PATCH] #12828: update ttnn matmul doc string (#13071)

* #12828: update ttnn matmul doc string

* #12828: update matmul doc string about 0 dimensions
---
 .../ttnn/operations/matmul/matmul_pybind.cpp | 139 +++++++++++-------
 1 file changed, 86 insertions(+), 53 deletions(-)

diff --git a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
index 87ce3bb3646..f59fb95c9d2 100644
--- a/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
+++ b/ttnn/cpp/ttnn/operations/matmul/matmul_pybind.cpp
@@ -152,40 +152,65 @@ void py_module(py::module& module) {
         R"doc(
         Returns the matrix product of two tensors.
 
-        The behavior depends on the dimensionality of the tensors as follows:
-
-        - If both arguments are 2-dimensional, the matrix-matrix product is returned.
-        - If the first argument is 1-dimensional and the second argument is 2-dimensional,
-          a 1 is prepended to its dimension for the purpose of the matrix multiply.
-        - After the matrix multiply, the prepended dimension is removed.
-        - If the first argument is 2-dimensional and the second argument is 1-dimensional,
-          the matrix-vector product is returned in 2 dimensions.
-        - If both arguments are at least 1-dimensional and at least one argument is
-          N-dimensional (where N > 2), then a batched matrix multiply is returned. If the first
-          argument is 1-dimensional, a 1 is prepended to its dimension for the purpose of the
-          batched matrix multiply. If the second argument is 1-dimensional, a
-          1 is appended to its dimension for the purpose of the batched matrix multiple.
-        - The non-matrix (i.e. batch) dimensions must be broadcastable.
-
-        The behaviour is the same as PyTorch, with the exception of two cases of batch dimensions:
-
-        - The two batch dimensions are swapped. E.g. :math:`(j \times 1)` and :math:`(1 \times j)`
-          or :math:`(1 \times j)` and :math:`(j \times 1)`.
-        - When a batch dimension is implicitly extended then the two patch dimensions are swapped.
-          E.g. :math:`(j \times 1)` and :math:`(j)` which is treated as
-          :math:`(j \times 1)` and :math:`(1 \times j)`.
-
-        In order to leverage sharded matmul implementations we can shard both :attr:`input_tensor_a` and :attr:`input_tensor_b`. The sharding strategy used will be according
-        to the sharding strategy on the respective tensor. A sharded 1D matmul can be either HEIGHT or WIDTH sharded, 2D matmuls can be block sharded.
-
-        Note that the broadcasting logic only looks at the batch dimensions when determining if the inputs
-        are broadcastable, and not the matrix dimensions. For example, if :attr:`input_tensor_a` is a
-        :math:`(j \times 1 \times n \times m)` tensor and :attr:`input_tensor_b` is a :math:`(k \times m \times p)`
-        tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the
-        matrix dimensions) are different. The operation will return a :math:`(j \times k \times n \times p)` tensor.
-
-        Note:
-            The 1-dimensional dot product version of this function is currently returning the Tensor with a non-empty shape. This is expected to be fixed in an upcoming release.
+        The input tensors need to be tiled. Therefore, they have to be
+        at least 2-dimensional.
+
+        If the input tensors have more than two dimensions, the additional front
+        dimensions may be used for a batched matrix multiply.
+        These front dimensions may also be referred to as batch dimensions.
+        E.g. a tensor with dimensions :math:`(a \\times b \\times c \\times d)`
+        has batch dimensions a and b.
+        The following are the allowed possibilities for batch dimensions.
+        The examples below show concrete operations and tensor sizes.
+
+        - If all batch dimensions are of size 1, then there is no batched operation.
+
+        - If both inputs have batch dimensions that are not all of size 1, then the
+          batch dimensions of both inputs must be the same. If the dimensions are
+          not the same, then, although some combinations may work, in most
+          cases various errors will be reported.
+
+        - If the first input has batch dimensions that are not all of size 1, and the
+          second input has no batch dimensions or has batch dimensions all of size 1,
+          then the second input is broadcast to align appropriately with the first
+          input.
+
+        - Matrix multiplication will not work if the first input has batch
+          dimensions that are all of size 1 and the second input has batch dimensions
+          that are not all of size 1.
+
+        - Note: Dimensions of size 0 are not supported.
+
+        - Note: In general, the number of dimensions of the two inputs should
+          match. There may be cases where they don't. In that case, if the inputs
+          are not valid based on the above criteria, the error messages may
+          be unexpected and refer to non-obvious issues.
+
+        - Note: There are various combinations of dimensions possible. The behaviour
+          is the same as PyTorch, with two exceptions.
+          These exceptions involve the following scenarios related to batch
+          dimensions:
+
+          - The two batch dimensions are swapped. E.g. the first input has :math:`(j \\times 1)`
+            and the second input has :math:`(1 \\times j)`,
+            or the first input has :math:`(1 \\times j)` and the second input has
+            :math:`(j \\times 1)`.
+          - When a batch dimension is implicitly extended, the two batch dimensions are swapped.
+            E.g. :math:`(j \\times 1)` and :math:`(j)`, which is treated as
+            :math:`(j \\times 1)` and :math:`(1 \\times j)`.
+
+        - In order to leverage sharded matmul implementations, we can shard both :attr:`input_tensor_a` and :attr:`input_tensor_b`. The sharding strategy used will be according
+          to the sharding strategy on the respective tensor. A sharded 1D matmul can be either HEIGHT or WIDTH sharded, while 2D matmuls can be block sharded.
+
+        Note: the broadcasting logic only looks at the batch dimensions when determining if the inputs
+        are broadcastable, and not the matrix dimensions. For example, if :attr:`input_tensor_a` is a
+        :math:`(j \\times 1 \\times n\_size \\times m\_size)` tensor and :attr:`input_tensor_b` is a :math:`(k\_size \\times m\_size \\times p)`
+        tensor, these inputs are valid for broadcasting even though the final two dimensions (i.e. the
+        matrix dimensions) are different. The operation will return a :math:`(j \\times k\_size \\times n\_size \\times p)` tensor.
+
+        - Note: there are various additional constraints related to the specific program
+          config chosen. Please look at the error messages carefully and fix
+          any problems appropriately.
 
         Args:
             input_tensor_a (ttnn.Tensor): the first tensor to be multiplied. Needs to be on the device.
             input_tensor_b (ttnn.Tensor): the second tensor to be multiplied. Needs to be on the device.
@@ -203,36 +228,42 @@ void py_module(py::module& module) {
             ttnn.Tensor: the output tensor.
 
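+        Note: as an illustrative sketch of the broadcasting note above, using the
+        hypothetical sizes :math:`j = 2`, :math:`k\_size = 3`, :math:`n\_size = 64`,
+        :math:`m\_size = 32` and :math:`p = 128` (``device`` is assumed to be an
+        already-open device handle, as in the examples below):
+
+            >>> # both inputs are tiled, as required
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((2, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((3, 32, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> # per the broadcasting note above, these inputs broadcast and the result
+            >>> # is expected to have shape (2, 3, 64, 128)
+            >>> output = ttnn.matmul(tensor1, tensor2)
+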
         Example:
-            >>> # vector x vector
-            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((32), dtype=torch.bfloat16)), device)
-            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32), dtype=torch.bfloat16)), device)
-            >>> output = tensor1 @ tensor2
-            >>> print(output.shape)
-            [32]
-            >>> # matrix x vector
+            >>> # matrix x matrix - no batch dimensions
             >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((64, 32), dtype=torch.bfloat16)), device)
-            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32), dtype=torch.bfloat16)), device)
-            >>> output = tensor1 @ tensor2
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
+            >>> output = ttnn.matmul(tensor1, tensor2)
+            >>> print(output.shape)
+            [64, 64]
+            >>> # extended matrix x extended matrix - all batch dimensions of size 1
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> output = ttnn.matmul(tensor1, tensor2)
+            >>> print(output.shape)
+            [1, 1, 64, 64]
+            >>> # extended matrix x matrix with fewer dimensions - all batch dimensions of size 1
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 32, 64), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> output = ttnn.matmul(tensor1, tensor2)
             >>> print(output.shape)
-            [64, 1]
-            >>> # batched matrix x broadcasted vector
+            [1, 1, 64, 64]
+            >>> # batched matrix x broadcasted matrix - first input has batch dimensions not of size 1
             >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
-            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32), dtype=torch.bfloat16)), device)
-            >>> output = tensor1 @ tensor2
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 64), dtype=torch.bfloat16)), device)
+            >>> output = ttnn.matmul(tensor1, tensor2)
             >>> print(output.shape)
-            [10, 64, 1]
-            >>> # batched matrix x batched matrix
+            [10, 64, 64]
+            >>> # batched matrix x batched matrix - both inputs have batch dimensions
             >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
             >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 32, 128), dtype=torch.bfloat16)), device)
-            >>> output = tensor1 @ tensor2
+            >>> output = tensor1 @ tensor2 # alternative to ttnn.matmul(tensor1, tensor2)
             >>> print(output.shape)
             [10, 64, 128]
-            >>> # batched matrix x broadcasted matrix
+            >>> # batched matrix x broadcasted extended matrix - first input has batch dimensions not of size 1
            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((10, 64, 32), dtype=torch.bfloat16)), device)
-            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 128), dtype=torch.bfloat16)), device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((1, 1, 32, 128), dtype=torch.bfloat16)), device)
             >>> output = tensor1 @ tensor2
             >>> print(output.shape)
-            [10, 64, 128]
+            [1, 10, 64, 128]
         )doc",
         ttnn::pybind_overload_t{
             [](decltype(::ttnn::matmul)& self,
@@ -280,6 +311,8 @@ void py_module(py::module& module) {
         R"doc(
         Returns the linear transformation of the inputs.
 
+        The limitations and behaviours are the same as for matmul.
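+
+        As an illustrative sketch (the shapes are hypothetical and ``device`` is assumed
+        to be an already-open device handle, as in the examples for matmul):
+
+            >>> tensor1 = ttnn.to_device(ttnn.from_torch(torch.randn((64, 32), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> tensor2 = ttnn.to_device(ttnn.from_torch(torch.randn((32, 128), dtype=torch.bfloat16), layout=ttnn.TILE_LAYOUT), device=device)
+            >>> # like matmul, this computes the matrix product of the two inputs,
+            >>> # so the expected output shape is [64, 128]
+            >>> output = ttnn.linear(tensor1, tensor2)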
+
         Args:
             input_tensor_a (ttnn.Tensor): the first tensor to be multiplied. Needs to be on the device.
             input_tensor_b (ttnn.Tensor): the second tensor to be multiplied. Needs to be on the device.