diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp index c2833971a11..cdaa4f48b6f 100644 --- a/ttnn/cpp/ttnn/device_operation.hpp +++ b/ttnn/cpp/ttnn/device_operation.hpp @@ -33,25 +33,17 @@ concept ProgramFactoryConcept = requires { }; template -concept HasComputeOutputShapes = requires { - [](const typename device_operation_t::operation_attributes_t& operation_attributes, - const typename device_operation_t::tensor_args_t& tensor_args) { - using shape_return_value_t = typename device_operation_t::shape_return_value_t; - static_assert(std::same_as< - decltype(device_operation_t::compute_output_shapes(operation_attributes, tensor_args)), - shape_return_value_t>); - }; +concept HasComputeOutputShapes = requires(device_operation_t op, + const typename device_operation_t::operation_attributes_t& operation_attributes, + const typename device_operation_t::tensor_args_t& tensor_args) { + {op.compute_output_shapes(operation_attributes, tensor_args)} -> std::same_as; }; template -concept HasComputeOutputSpecs = requires { - [](const typename device_operation_t::operation_attributes_t& operation_attributes, - const typename device_operation_t::tensor_args_t& tensor_args) { - using spec_return_value_t = typename device_operation_t::spec_return_value_t; - static_assert(std::same_as< - decltype(device_operation_t::compute_output_specs(operation_attributes, tensor_args)), - spec_return_value_t>); - }; +concept HasComputeOutputSpecs = requires(device_operation_t op, + const typename device_operation_t::operation_attributes_t& operation_attributes, + const typename device_operation_t::tensor_args_t& tensor_args) { + {op.compute_output_specs(operation_attributes, tensor_args)} -> std::same_as; }; template