Skip to content

Commit

Permalink
#4003: keep track of all torch operations that run after ttnn.to_torch
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed Feb 22, 2024
1 parent 48de35a commit 553c162
Showing 1 changed file with 27 additions and 2 deletions.
29 changes: 27 additions & 2 deletions ttnn/ttnn/operations/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

import math
import pathlib
from typing import Union, Tuple, Optional
from typing import Union, Tuple, Optional, Any

from loguru import logger

Expand Down Expand Up @@ -411,7 +411,32 @@ def impl(ttl_tensor):

ttl_tensor = tensor.value
tensor = ttnn.Tensor(ttl_tensor.reshape(tensor.shape.with_tile_padding().value))
return ttl.tensor.decorate_external_operation(impl, function_name="ttnn.to_torch")(tensor.value)
tensor = ttl.tensor.decorate_external_operation(impl, function_name="ttnn.to_torch")(tensor.value)

import torch

class TorchTensor(torch.Tensor):
@staticmethod
def __new__(
cls: Any,
tensor: Any,
*function_args: Any,
**function_kwargs: Any,
) -> Any:
return super().__new__(cls, tensor, *function_args, **function_kwargs) # type: ignore[call-arg]

@classmethod
def __torch_function__(
cls: Any,
function,
types: Any,
function_args: Any = (),
function_kwargs: Any = None,
) -> Any:
function = ttl.tensor.decorate_external_operation(function, function_name=function.__name__)
return super().__torch_function__(function, types, function_args, function_kwargs)

return TorchTensor(tensor)


def _to_device_validate_input_tensors(operation_name, tensor, *args, **kwargs):
Expand Down

0 comments on commit 553c162

Please sign in to comment.