Skip to content

Commit

Permalink
#4003: return TorchTensor from ttnn.to_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Feb 26, 2024
1 parent 9d198cf commit d4d5f3d
Showing 1 changed file with 12 additions and 15 deletions.
27 changes: 12 additions & 15 deletions ttnn/ttnn/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
from typing import Union, Tuple, Optional, Any

from loguru import logger
import torch

import tt_lib as ttl

Expand Down Expand Up @@ -368,6 +369,16 @@ def _to_torch_validate_input_tensors(operation_name, input_tensor, *args, **kwar
)


class TorchTensor(torch.Tensor):
@classmethod
def __torch_function__(cls, func, types, args=(), kwargs=None):
# this tells torch to treat TorchTensor just like torch.Tensor's.
# Otherwise, torch will complain that it doesn't know how to handle it.
types = tuple(torch.Tensor if t == TorchTensor else t for t in types)
func = ttl.tensor.decorate_external_operation(func, function_name=f"(torch) {func.__name__}")
return super().__torch_function__(func, types, args, kwargs)


@ttnn.register_operation(name="ttnn.to_torch", validate_input_tensors=_to_torch_validate_input_tensors)
def to_torch(tensor: ttnn.Tensor, *, torch_rank: Optional[int] = None) -> "torch.Tensor":
"""
Expand Down Expand Up @@ -413,21 +424,7 @@ def impl(ttl_tensor):
tensor = ttnn.Tensor(ttl_tensor.reshape(tensor.shape.with_tile_padding().value))
tensor = ttl.tensor.decorate_external_operation(impl, function_name="ttnn.to_torch")(tensor.value)

import torch

class TorchTensor(torch.Tensor):
@classmethod
def __torch_function__(
cls: Any,
func,
types: Any,
args: Any = (),
kwargs: Any = None,
) -> Any:
func = ttl.tensor.decorate_external_operation(func, function_name=f"(torch) {func.__name__}")
return super().__torch_function__(func, types, args, kwargs)

return tensor
return TorchTensor(tensor)


def _to_device_validate_input_tensors(operation_name, tensor, *args, **kwargs):
Expand Down

0 comments on commit d4d5f3d

Please sign in to comment.