diff --git a/ttnn/cpp/ttnn/device_operation.hpp b/ttnn/cpp/ttnn/device_operation.hpp index e6d4320b494..b50132efb6e 100644 --- a/ttnn/cpp/ttnn/device_operation.hpp +++ b/ttnn/cpp/ttnn/device_operation.hpp @@ -32,6 +32,20 @@ concept ProgramFactoryConcept = requires { }; }; +template +concept HasComputeOutputShapes = requires(device_operation_t op, + const typename device_operation_t::operation_attributes_t& operation_attributes, + const typename device_operation_t::tensor_args_t& tensor_args) { + {op.compute_output_shapes(operation_attributes, tensor_args)} -> std::same_as; +}; + +template +concept HasComputeOutputSpecs = requires(device_operation_t op, + const typename device_operation_t::operation_attributes_t& operation_attributes, + const typename device_operation_t::tensor_args_t& tensor_args) { + {op.compute_output_specs(operation_attributes, tensor_args)} -> std::same_as; +}; + template concept DeviceOperationConcept = requires { [](const typename device_operation_t::operation_attributes_t& operation_attributes, @@ -39,11 +53,6 @@ concept DeviceOperationConcept = requires { device_operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args); device_operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args); - using shape_return_value_t = typename device_operation_t::shape_return_value_t; - static_assert(std::same_as< - decltype(device_operation_t::compute_output_shapes(operation_attributes, tensor_args)), - shape_return_value_t>); - using tensor_return_value_t = typename device_operation_t::tensor_return_value_t; static_assert(std::same_as< decltype(device_operation_t::create_output_tensors(operation_attributes, tensor_args)), @@ -57,7 +66,7 @@ concept DeviceOperationConcept = requires { }, program_factory); }; -}; +} && (HasComputeOutputSpecs || HasComputeOutputShapes); template concept DeviceOperationWithCustomProgramCacheConcept = diff --git a/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.cpp b/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.cpp index da542429466..70c85630c26 100644 --- a/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.cpp @@ -47,9 +47,14 @@ void BernoulliDeviceOperation::validate_on_program_cache_hit( validate_inputs(operation_attributes, tensor_args); } -BernoulliDeviceOperation::shape_return_value_t BernoulliDeviceOperation::compute_output_shapes( +BernoulliDeviceOperation::spec_return_value_t BernoulliDeviceOperation::compute_output_specs( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return tensor_args.input.get_logical_shape(); + if (tensor_args.output.has_value()) { + return tensor_args.output->get_tensor_spec(); + } + + auto output_shape = tensor_args.input.get_logical_shape(); + return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config)); } BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create_output_tensors( @@ -58,13 +63,7 @@ BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create return tensor_args.output.value(); } - auto output_shapes = compute_output_shapes(operation_attributes, tensor_args); - return create_device_tensor( - output_shapes, - operation_attributes.dtype, - Layout::TILE, - tensor_args.input.device(), - operation_attributes.memory_config); + return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device()); } std::tuple diff --git a/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp b/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp index a085a650432..bec42699629 100644 --- a/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/bernoulli/device/bernoulli_device_operation.hpp @@ -21,7 +21,7 @@ struct BernoulliDeviceOperation { const std::optional& output; }; - using shape_return_value_t = SimpleShape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct ProgramFactory { @@ -52,7 +52,7 @@ struct BernoulliDeviceOperation { static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args); static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp index 55bf9c308cb..692edc87486 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.cpp @@ -152,11 +152,17 @@ void BinaryDeviceOperation::validate_on_program_cache_hit( TT_FATAL(width_a == width_b || width_a == 1 || width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch"); } -BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes( - const operation_attributes_t&, const tensor_args_t& tensor_args) { - const auto input_shape_a = tensor_args.input_tensor_a.shape(); +BinaryDeviceOperation::spec_return_value_t BinaryDeviceOperation::compute_output_specs( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + const auto& output_tensor = tensor_args.output_tensor; + if (output_tensor.has_value()) { + return output_tensor->get_tensor_spec(); + } + + const auto& input_tensor_a = tensor_args.input_tensor_a; + const auto input_shape_a = input_tensor_a.logical_shape(); const auto& tensor_b = tensor_args.input_tensor_b; - const auto input_shape_b = tensor_b.has_value() ? tensor_b->shape() : ttnn::Shape{1, 1}; + const auto input_shape_b = tensor_b.has_value() ? tensor_b->logical_shape() : ttnn::SimpleShape{}; const int rank_a = input_shape_a.rank(); const int rank_b = input_shape_b.rank(); @@ -181,24 +187,9 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu output_shape[i + larger_rank] = dim_a + dim_b - 1; } } - return output_shape; + return ttnn::SimpleShape(output_shape); }; - - const auto logical_shape_a = input_shape_a.logical_shape(); - const auto logical_shape_b = input_shape_b.logical_shape(); - return ttnn::SimpleShape(compute_broadcasted_output(logical_shape_a, logical_shape_b)); -} - -BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors( - const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - using namespace tt::constants; - auto output_shape = compute_output_shapes(operation_attributes, tensor_args); - const auto& input_tensor_a = tensor_args.input_tensor_a; - const auto& output_tensor = tensor_args.output_tensor; - - if (output_tensor.has_value()) { - return output_tensor.value(); - } + auto output_shape = compute_broadcasted_output(input_shape_a, input_shape_b); auto program_factory = select_program_factory(operation_attributes, tensor_args); if (std::holds_alternative(program_factory)) { @@ -214,8 +205,7 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu } auto memory_config = operation_attributes.memory_config; memory_config.shard_spec = shard_spec; - return create_device_tensor( - output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config); + return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config)); } } else { if (operation_attributes.memory_config.is_sharded()) { @@ -226,16 +216,18 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu } auto memory_config = operation_attributes.memory_config; memory_config.shard_spec = shard_spec; - return create_device_tensor( - output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config); + return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config)); } } - return create_device_tensor( - output_shape, - operation_attributes.dtype, - Layout::TILE, - input_tensor_a.device(), - operation_attributes.memory_config); + return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config)); +} + +BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors( + const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { + if (tensor_args.output_tensor.has_value()) { + return *tensor_args.output_tensor; + } + return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input_tensor_a.device()); } tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash( diff --git a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp index 0e37500c43f..d4c42ca9810 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/binary/device/binary_device_operation.hpp @@ -47,7 +47,7 @@ struct BinaryDeviceOperation { std::optional input_tensor_b; std::optional output_tensor; }; - using shape_return_value_t = ttnn::SimpleShape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct ElementWiseMultiCore { @@ -203,7 +203,7 @@ struct BinaryDeviceOperation { static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors( const operation_attributes_t& operation_attributes, const tensor_args_t&); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp index 79995d76d33..56ba7717bcf 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.cpp @@ -138,7 +138,7 @@ void UnaryDeviceOperation::validate_on_program_cache_miss( } if (preallocated_output_tensor.has_value()) { - const auto computed_output_shape = compute_output_shapes(args, tensor_args); + const auto computed_output_shape = compute_output_specs(args, tensor_args).logical_shape(); const auto preallocated_output_shape = preallocated_output_tensor.value().get_logical_shape(); TT_FATAL( preallocated_output_shape == computed_output_shape, @@ -155,15 +155,10 @@ void UnaryDeviceOperation::validate_on_program_cache_miss( } } -shape_return_value_t UnaryDeviceOperation::compute_output_shapes( - const operation_attributes_t&, const tensor_args_t& tensor_args) { - return {tensor_args.input.get_logical_shape()}; -} - -tensor_return_value_t UnaryDeviceOperation::create_output_tensors( +spec_return_value_t UnaryDeviceOperation::compute_output_specs( const operation_attributes_t& args, const tensor_args_t& tensor_args) { if (tensor_args.preallocated_output.has_value()) { - return tensor_args.preallocated_output.value(); + return tensor_args.preallocated_output->get_tensor_spec(); } auto output_layout = Layout::TILE; @@ -171,9 +166,16 @@ tensor_return_value_t UnaryDeviceOperation::create_output_tensors( output_layout = tensor_args.input.get_layout(); } - const auto output_shape = tensor_args.input.shape(); - return create_device_tensor( - output_shape, args.output_dtype, output_layout, tensor_args.input.device(), args.output_memory_config); + const auto output_shape = tensor_args.input.logical_shape(); + return TensorSpec(output_shape, TensorLayout(args.output_dtype, output_layout, args.output_memory_config)); +} + +tensor_return_value_t UnaryDeviceOperation::create_output_tensors( + const operation_attributes_t& args, const tensor_args_t& tensor_args) { + if (tensor_args.preallocated_output.has_value()) { + return *tensor_args.preallocated_output; + } + return create_device_tensor(compute_output_specs(args, tensor_args), tensor_args.input.device()); } tt::stl::hash::hash_t UnaryDeviceOperation::compute_program_hash( diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp index a8bdafcf64b..cf303a8aa39 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation.hpp @@ -24,7 +24,7 @@ struct UnaryDeviceOperation { using operation_attributes_t = unary::operation_attributes_t; using tensor_args_t = unary::tensor_args_t; - using shape_return_value_t = unary::shape_return_value_t; + using spec_return_value_t = unary::spec_return_value_t; using tensor_return_value_t = unary::tensor_return_value_t; using program_factory_t = std::variant; @@ -33,7 +33,7 @@ struct UnaryDeviceOperation { static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors(const operation_attributes_t& operation_attributes, const tensor_args_t&); diff --git a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp index 3c9ce09fb75..e29cd4d728a 100644 --- a/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp +++ b/ttnn/cpp/ttnn/operations/eltwise/unary/device/unary_device_operation_types.hpp @@ -28,6 +28,6 @@ struct tensor_args_t { using tensor_return_value_t = Tensor; -using shape_return_value_t = ttnn::SimpleShape; +using spec_return_value_t = TensorSpec; } // namespace ttnn::operations::unary diff --git a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp index e71c798f886..96bac8fbbc2 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.cpp @@ -45,26 +45,16 @@ void Embeddings::validate(const std::vector &input_tensors) const { } } -std::vector Embeddings::compute_output_shapes(const std::vector &input_tensors) const { +std::vector Embeddings::compute_output_specs(const std::vector &input_tensors) const { const auto &input_tensor = input_tensors.at(0); const auto &weight_tensor = input_tensors.at(1); - auto num_output_embeddings = input_tensor.get_legacy_shape()[3]; - auto batch_num = input_tensor.get_legacy_shape()[0]; - auto num_embedding_dims = weight_tensor.get_legacy_shape()[3]; + auto num_output_embeddings = input_tensor.logical_shape()[3]; + auto batch_num = input_tensor.logical_shape()[0]; + auto num_embedding_dims = weight_tensor.logical_shape()[3]; - tt::tt_metal::LegacyShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims}); - return {output_shape}; -} - -std::vector Embeddings::create_output_tensors(const std::vector &input_tensors) const { - const auto &weight_tensor = input_tensors.at(1); - if (!tilized) { - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, Layout::ROW_MAJOR, this->output_mem_config); - } else { - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config); - } + ttnn::SimpleShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims}); + auto output_layout = tilized ? Layout::TILE : Layout::ROW_MAJOR; + return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(output_layout), output_mem_config))}; } operation::ProgramWithCallbacks Embeddings::create_program( diff --git a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp index 9bacd2e4501..69e14c39165 100644 --- a/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/embedding/device/embedding_device_operation.hpp @@ -24,8 +24,7 @@ struct Embeddings { const DataType output_dtype; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; + std::vector compute_output_specs(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; }; diff --git a/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp b/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp index 7f610c90fc2..f824df4dd5c 100644 --- a/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.cpp @@ -68,18 +68,13 @@ void EmbeddingBackward::validate(const std::vector &input_tensors) const "Number of rows in gradient tensor must be equal to number of indices in index tensor"); } -std::vector EmbeddingBackward::compute_output_shapes( +std::vector EmbeddingBackward::compute_output_specs( const std::vector &input_tensors) const { const auto &grad_tensor = input_tensors.at(1); - auto embedding_dim = grad_tensor.get_legacy_shape()[-1]; + auto embedding_dim = grad_tensor.get_logical_shape()[-1]; - tt::tt_metal::LegacyShape output_shape({1, 1, this->num_embeddings, embedding_dim}); - return {output_shape}; -} - -std::vector EmbeddingBackward::create_output_tensors(const std::vector &input_tensors) const { - return operation::generic_create_output_tensors( - *this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config); + ttnn::SimpleShape output_shape({1, 1, this->num_embeddings, embedding_dim}); + return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(Layout::TILE), output_mem_config))}; } operation::ProgramWithCallbacks EmbeddingBackward::create_program( diff --git a/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.hpp b/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.hpp index 22dea1a9172..27a18bc9b0b 100644 --- a/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/embedding_backward/device/embedding_backward_device_operation.hpp @@ -24,8 +24,7 @@ struct EmbeddingBackward { uint32_t num_embeddings; void validate(const std::vector &input_tensors) const; - std::vector compute_output_shapes(const std::vector &input_tensors) const; - std::vector create_output_tensors(const std::vector &input_tensors) const; + std::vector compute_output_specs(const std::vector &input_tensors) const; operation::ProgramWithCallbacks create_program( const std::vector &input_tensors, std::vector &output_tensors) const; tt::stl::reflection::Attributes attributes() const; diff --git a/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp index 031af15dbb1..34f93caaaeb 100644 --- a/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp +++ b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.cpp @@ -32,9 +32,9 @@ void UniformDeviceOperation::validate_on_program_cache_hit( validate_inputs(operation_attributes, tensor_args); } -UniformDeviceOperation::shape_return_value_t UniformDeviceOperation::compute_output_shapes( +TensorSpec UniformDeviceOperation::compute_output_specs( const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) { - return tensor_args.input.get_logical_shape(); + return tensor_args.input.get_tensor_spec(); } UniformDeviceOperation::tensor_return_value_t UniformDeviceOperation::create_output_tensors( diff --git a/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp index 72bf4f0b7b3..115727fb6cd 100644 --- a/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp +++ b/ttnn/cpp/ttnn/operations/uniform/device/uniform_device_operation.hpp @@ -21,7 +21,7 @@ struct UniformDeviceOperation { const Tensor& input; }; - using shape_return_value_t = SimpleShape; + using spec_return_value_t = TensorSpec; using tensor_return_value_t = Tensor; struct ProgramFactory { @@ -51,7 +51,7 @@ struct UniformDeviceOperation { static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args); static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&); static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&); - static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&); + static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&); static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&); static std::tuple invoke(