#9681: set __name__ attribute for ttnn operations when fast runtime mode is disabled

arakhmati committed Jun 27, 2024
1 parent 34ae6fb commit c3042ac
Showing 5 changed files with 28 additions and 19 deletions.
38 changes: 28 additions & 10 deletions ttnn/ttnn/decorators.py
@@ -380,6 +380,10 @@ class Operation:
is_cpp_function: bool
is_experimental: bool

@property
def __name__(self):
return self.python_fully_qualified_name

def __gt__(self, other):
return self.python_fully_qualified_name < other.python_fully_qualified_name

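The new `__name__` property derives an operation's name from its stored fully qualified name, so every `Operation` instance reports it without any per-instance assignment. A minimal sketch of the pattern, with the dataclass reduced to the one field involved (illustrative, not the full class):

    from dataclasses import dataclass


    @dataclass
    class Operation:
        python_fully_qualified_name: str

        @property
        def __name__(self):
            # Introspection tools (logging, profilers, functools.wraps)
            # read __name__; the property derives it on demand.
            return self.python_fully_qualified_name


    op = Operation(python_fully_qualified_name="ttnn.example_op")
    print(op.__name__)  # ttnn.example_op

Note that a class-level `__name__` property only shadows the name on instances; `Operation.__name__` itself still resolves to the class name through the metaclass descriptor.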
@@ -748,6 +752,18 @@ def operation_decorator(function: callable):
global OPERATION_TO_GOLDEN_FUNCTION
global OPERATION_TO_FALLBACK_FUNCTION

is_cpp_function = hasattr(function, "__ttnn__")

python_fully_qualified_name = name
if is_cpp_function:
if doc is not None:
raise RuntimeError(f"Registering {name}: documentation for C++ function has to be set from C++")
if python_fully_qualified_name is not None:
raise RuntimeError(f"Registering {name}: name is not allowed for ttnn functions")
python_fully_qualified_name = function.python_fully_qualified_name # Replace C++ name with python
elif not is_experimental:
logger.warning(f"{name} should be migrated to C++!")

def fallback_function(*function_args, **function_kwargs):
preprocess_inputs = preprocess_golden_function_inputs or default_preprocess_golden_function_inputs
postprocess_outputs = postprocess_golden_function_outputs or default_postprocess_golden_function_outputs
@@ -759,19 +775,21 @@ def fallback_function(*function_args, **function_kwargs):
return output

if ttnn.CONFIG.enable_fast_runtime_mode:

def name_decorator(function):
@wraps(function)
def wrapper(*args, **kwargs):
return function(*args, **kwargs)

return wrapper

function = name_decorator(function)
function.__name__ = python_fully_qualified_name

OPERATION_TO_GOLDEN_FUNCTION[function] = golden_function
OPERATION_TO_FALLBACK_FUNCTION[function] = fallback_function
return function

is_cpp_function = hasattr(function, "__ttnn__")

python_fully_qualified_name = name
if is_cpp_function:
if doc is not None:
raise RuntimeError(f"Registering {name}: documentation for C++ functiomn has to be set from C++")
if python_fully_qualified_name is not None:
raise RuntimeError(f"Registering {name}: name is not allowed for ttnn functions")
python_fully_qualified_name = function.python_fully_qualified_name # Replace C++ name with python
return function

# Wrap functions before attaching documentation to avoid errors
if doc is not None:
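The `name_decorator` indirection above matters in fast runtime mode because the registered object is the raw (often C++-bound) function rather than an `Operation` instance, and bound functions generally reject attribute assignment. A standalone sketch of the same wrapping pattern; `with_name` is a hypothetical helper, and `len` merely stands in for a C++-bound ttnn function:

    from functools import wraps


    def with_name(function, python_fully_qualified_name):
        # C++-bound callables (e.g. pybind11 functions) typically reject
        # attribute assignment, so wrap them in a plain Python function first.
        @wraps(function)
        def wrapper(*args, **kwargs):
            return function(*args, **kwargs)

        wrapper.__name__ = python_fully_qualified_name
        return wrapper


    # Setting len.__name__ directly raises AttributeError, but the
    # wrapper accepts the new name while delegating all calls.
    fast_len = with_name(len, "ttnn.example_length")
    print(fast_len.__name__)    # ttnn.example_length
    print(fast_len([1, 2, 3]))  # 3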
2 changes: 0 additions & 2 deletions ttnn/ttnn/operations/binary.py
@@ -49,7 +49,6 @@ def binary_function(
return output_tensor

if isinstance(binary_function, ttnn.decorators.Operation):
binary_function.__name__ = f"ttnn.{name}"
binary_function.decorated_function.__doc__ = doc + (
binary_function.__doc__ if binary_function.__doc__ is not None else ""
)
@@ -348,7 +347,6 @@ def elt_binary_function(
return output_tensor

if isinstance(elt_binary_function, ttnn.decorators.Operation):
elt_binary_function.__name__ = f"ttnn.{name}"
elt_binary_function.decorated_function.__doc__ = f"""{name}(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Performs eltwise-binary {op_name} operation on two tensors :attr:`input_a` and :attr:`input_b`.
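This removal, and the matching ones in losses.py, ternary.py, and unary.py below, are all the same cleanup: the manual `__name__` assignment is superseded by the `Operation.__name__` property (or, in fast runtime mode, by the wrapper set up in decorators.py). A hypothetical sanity check, assuming a working ttnn build where `add` registers under the fully qualified name `ttnn.add`:

    import ttnn  # requires a ttnn build; shown for illustration only

    # The name is now derived rather than assigned per operation.
    assert ttnn.add.__name__ == "ttnn.add"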
1 change: 0 additions & 1 deletion ttnn/ttnn/operations/losses.py
@@ -79,7 +79,6 @@ def loss_function(
return output_tensor

if isinstance(loss_function, ttnn.decorators.Operation):
loss_function.__name__ = f"ttnn.{name}"
loss_function.decorated_function.__doc__ = f"""{name}(input_tensor_a: ttnn.Tensor, input_tensor_b: ttnn.Tensor, loss_mode: str, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Applies {name} to :attr:`input_tensor_a` and :attr:`input_tensor_b` with loss_mode :attr:`loss_mode`.
2 changes: 0 additions & 2 deletions ttnn/ttnn/operations/ternary.py
@@ -99,7 +99,6 @@ def ternary_function(
return output_tensor

if isinstance(ternary_function, ttnn.decorators.Operation):
ternary_function.__name__ = f"ttnn.{name}"
ternary_function.decorated_function.__doc__ = f"""{name}(input_tensor: ttnn.Tensor, input_tensor1: ttnn.Tensor, input_tensor2: ttnn.Tensor, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Returns tensor with the {name} of all of elements of the input tensors input, tensor1, tensor2.
@@ -221,7 +220,6 @@ def ternary_function(
return output_tensor

if isinstance(ternary_function, ttnn.decorators.Operation):
ternary_function.__name__ = f"ttnn.{(name)}"
ternary_function.decorated_function.__doc__ = f"""{(name)}(input_tensor: ttnn.Tensor, input_tensor1: ttnn.Tensor, input_tensor2: ttnn.Tensor, parameter, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Performs the element-wise {op_name} of tensor1 by tensor2, multiplies the result by the scalar value and adds it to input input.
4 changes: 0 additions & 4 deletions ttnn/ttnn/operations/unary.py
@@ -260,7 +260,6 @@ def unary_function(
return output_tensor

if isinstance(unary_function, ttnn.decorators.Operation):
unary_function.__name__ = f"ttnn.{(name)}"
unary_function.decorated_function.__doc__ = f"""{(name)}(input_tensor: ttnn.Tensor, parameter, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Applies the {name} function to the elements of the input tensor :attr:`input_tensor` with :attr:`{param}` parameter.
@@ -346,7 +345,6 @@ def activation_function(
return output_tensor

if isinstance(activation_function, ttnn.decorators.Operation):
activation_function.__name__ = f"ttnn.{(name)}"
activation_function.decorated_function.__doc__ = f"""{(name)}(input_tensor: ttnn.Tensor, parameter, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Applies the {name} function to the elements of the input tensor :attr:`input_tensor` with :attr:`{param}` parameter.
@@ -430,7 +428,6 @@ def activation_function(
return output_tensor

if isinstance(activation_function, ttnn.decorators.Operation):
activation_function.__name__ = f"ttnn.{(name)}"
activation_function.decorated_function.__doc__ = f"""{(name)}(input_tensor: ttnn.Tensor, parameter, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Applies the {name} function to the elements of the input tensor :attr:`input_tensor` with :attr:`{param1_name}` and :attr:`{param2_name}` parameters.
@@ -550,7 +547,6 @@ def activation_function(
return output_tensor

if isinstance(activation_function, ttnn.decorators.Operation):
activation_function.__name__ = f"ttnn.{(name)}"
activation_function.decorated_function.__doc__ = f"""{(name)}(input_tensor: ttnn.Tensor, dim: int = -1, *, memory_config: ttnn.MemoryConfig = ttnn.DRAM_MEMORY_CONFIG) -> ttnn.Tensor
Applies the {name} function to the elements of the input tensor :attr:`input_tensor` split along :attr:`{param}`.
