From d4d5f3de2d8d79647a7498db7c45f6661cc814ef Mon Sep 17 00:00:00 2001 From: Akhmed Rakhmati Date: Mon, 26 Feb 2024 23:03:37 +0000 Subject: [PATCH] #4003: return TorchTensor from ttnn.to_torch --- ttnn/ttnn/operations/core.py | 27 ++++++++++++--------------- 1 file changed, 12 insertions(+), 15 deletions(-) diff --git a/ttnn/ttnn/operations/core.py b/ttnn/ttnn/operations/core.py index 0de98bb43fb..629c7d37925 100644 --- a/ttnn/ttnn/operations/core.py +++ b/ttnn/ttnn/operations/core.py @@ -7,6 +7,7 @@ from typing import Union, Tuple, Optional, Any from loguru import logger +import torch import tt_lib as ttl @@ -368,6 +369,16 @@ def _to_torch_validate_input_tensors(operation_name, input_tensor, *args, **kwar ) +class TorchTensor(torch.Tensor): + @classmethod + def __torch_function__(cls, func, types, args=(), kwargs=None): + # this tells torch to treat TorchTensor just like torch.Tensor's. + # Otherwise, torch will complain that it doesn't know how to handle it. + types = tuple(torch.Tensor if t == TorchTensor else t for t in types) + func = ttl.tensor.decorate_external_operation(func, function_name=f"(torch) {func.__name__}") + return super().__torch_function__(func, types, args, kwargs) + + @ttnn.register_operation(name="ttnn.to_torch", validate_input_tensors=_to_torch_validate_input_tensors) def to_torch(tensor: ttnn.Tensor, *, torch_rank: Optional[int] = None) -> "torch.Tensor": """ @@ -413,21 +424,7 @@ def impl(ttl_tensor): tensor = ttnn.Tensor(ttl_tensor.reshape(tensor.shape.with_tile_padding().value)) tensor = ttl.tensor.decorate_external_operation(impl, function_name="ttnn.to_torch")(tensor.value) - import torch - - class TorchTensor(torch.Tensor): - @classmethod - def __torch_function__( - cls: Any, - func, - types: Any, - args: Any = (), - kwargs: Any = None, - ) -> Any: - func = ttl.tensor.decorate_external_operation(func, function_name=f"(torch) {func.__name__}") - return super().__torch_function__(func, types, args, kwargs) - - return tensor + return TorchTensor(tensor) def _to_device_validate_input_tensors(operation_name, tensor, *args, **kwargs):