Skip to content

Commit

Permalink
#0: Port eltwise and some misc ops to use TensorSpec (#15471)
Browse files Browse the repository at this point in the history
### Ticket

### Problem description
We need to migrate all ops to use `compute_output_specs` with
TensorSpec, instead of older `compute_output_shapes`

### What's changed
Migrated ops to TensorSpec:
- binary
- unary
- bernoulli
- embedding
- embedding_backward
- uniform

Fixes in the infrastructure to support the migration.

### Checklist
- [x] [Post commit CI
passes](https://github.com/tenstorrent/tt-metal/actions/runs/12040667994)
- [ ] Blackhole Post commit (if applicable)
- [ ] Model regression CI testing passes (if applicable)
- [ ] Device performance regression CI testing passes (if applicable)
- [x] New/Existing tests provide coverage for changes
  • Loading branch information
sminakov-tt authored Nov 27, 2024
1 parent 9adbc6d commit d9cc84c
Showing 14 changed files with 83 additions and 98 deletions.
21 changes: 15 additions & 6 deletions ttnn/cpp/ttnn/device_operation.hpp
Original file line number Diff line number Diff line change
@@ -32,18 +32,27 @@ concept ProgramFactoryConcept = requires {
};
};

template <typename device_operation_t>
concept HasComputeOutputShapes = requires(device_operation_t op,
const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
{op.compute_output_shapes(operation_attributes, tensor_args)} -> std::same_as<typename device_operation_t::shape_return_value_t>;
};

template <typename device_operation_t>
concept HasComputeOutputSpecs = requires(device_operation_t op,
const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
{op.compute_output_specs(operation_attributes, tensor_args)} -> std::same_as<typename device_operation_t::spec_return_value_t>;
};

template <typename device_operation_t>
concept DeviceOperationConcept = requires {
[](const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
device_operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args);
device_operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args);

using shape_return_value_t = typename device_operation_t::shape_return_value_t;
static_assert(std::same_as<
decltype(device_operation_t::compute_output_shapes(operation_attributes, tensor_args)),
shape_return_value_t>);

using tensor_return_value_t = typename device_operation_t::tensor_return_value_t;
static_assert(std::same_as<
decltype(device_operation_t::create_output_tensors(operation_attributes, tensor_args)),
@@ -57,7 +66,7 @@ concept DeviceOperationConcept = requires {
},
program_factory);
};
};
} && (HasComputeOutputSpecs<device_operation_t> || HasComputeOutputShapes<device_operation_t>);

template <typename device_operation_t>
concept DeviceOperationWithCustomProgramCacheConcept =
Original file line number Diff line number Diff line change
@@ -47,9 +47,14 @@ void BernoulliDeviceOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
}

BernoulliDeviceOperation::shape_return_value_t BernoulliDeviceOperation::compute_output_shapes(
BernoulliDeviceOperation::spec_return_value_t BernoulliDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
if (tensor_args.output.has_value()) {
return tensor_args.output->get_tensor_spec();
}

auto output_shape = tensor_args.input.get_logical_shape();
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
}

BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create_output_tensors(
@@ -58,13 +63,7 @@ BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create
return tensor_args.output.value();
}

auto output_shapes = compute_output_shapes(operation_attributes, tensor_args);
return create_device_tensor(
output_shapes,
operation_attributes.dtype,
Layout::TILE,
tensor_args.input.device(),
operation_attributes.memory_config);
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
}

std::tuple<BernoulliDeviceOperation::operation_attributes_t, BernoulliDeviceOperation::tensor_args_t>
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ struct BernoulliDeviceOperation {
const std::optional<Tensor>& output;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
@@ -52,7 +52,7 @@ struct BernoulliDeviceOperation {
static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Original file line number Diff line number Diff line change
@@ -152,11 +152,17 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
TT_FATAL(width_a == width_b || width_a == 1 || width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch");
}

BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
const auto input_shape_a = tensor_args.input_tensor_a.shape();
BinaryDeviceOperation::spec_return_value_t BinaryDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& output_tensor = tensor_args.output_tensor;
if (output_tensor.has_value()) {
return output_tensor->get_tensor_spec();
}

const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto input_shape_a = input_tensor_a.logical_shape();
const auto& tensor_b = tensor_args.input_tensor_b;
const auto input_shape_b = tensor_b.has_value() ? tensor_b->shape() : ttnn::Shape{1, 1};
const auto input_shape_b = tensor_b.has_value() ? tensor_b->logical_shape() : ttnn::SimpleShape{};

const int rank_a = input_shape_a.rank();
const int rank_b = input_shape_b.rank();
@@ -181,24 +187,9 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu
output_shape[i + larger_rank] = dim_a + dim_b - 1;
}
}
return output_shape;
return ttnn::SimpleShape(output_shape);
};

const auto logical_shape_a = input_shape_a.logical_shape();
const auto logical_shape_b = input_shape_b.logical_shape();
return ttnn::SimpleShape(compute_broadcasted_output(logical_shape_a, logical_shape_b));
}

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
using namespace tt::constants;
auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto& output_tensor = tensor_args.output_tensor;

if (output_tensor.has_value()) {
return output_tensor.value();
}
auto output_shape = compute_broadcasted_output(input_shape_a, input_shape_b);

auto program_factory = select_program_factory(operation_attributes, tensor_args);
if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
@@ -214,8 +205,7 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
}
auto memory_config = operation_attributes.memory_config;
memory_config.shard_spec = shard_spec;
return create_device_tensor(
output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
}
} else {
if (operation_attributes.memory_config.is_sharded()) {
@@ -226,16 +216,18 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
}
auto memory_config = operation_attributes.memory_config;
memory_config.shard_spec = shard_spec;
return create_device_tensor(
output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
}
}
return create_device_tensor(
output_shape,
operation_attributes.dtype,
Layout::TILE,
input_tensor_a.device(),
operation_attributes.memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
}

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
if (tensor_args.output_tensor.has_value()) {
return *tensor_args.output_tensor;
}
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input_tensor_a.device());
}

tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash(
Original file line number Diff line number Diff line change
@@ -47,7 +47,7 @@ struct BinaryDeviceOperation {
std::optional<Tensor> input_tensor_b;
std::optional<Tensor> output_tensor;
};
using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ElementWiseMultiCore {
@@ -203,7 +203,7 @@ struct BinaryDeviceOperation {
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t&);
Original file line number Diff line number Diff line change
@@ -138,7 +138,7 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(
}

if (preallocated_output_tensor.has_value()) {
const auto computed_output_shape = compute_output_shapes(args, tensor_args);
const auto computed_output_shape = compute_output_specs(args, tensor_args).logical_shape();
const auto preallocated_output_shape = preallocated_output_tensor.value().get_logical_shape();
TT_FATAL(
preallocated_output_shape == computed_output_shape,
@@ -155,25 +155,27 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(
}
}

shape_return_value_t UnaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
return {tensor_args.input.get_logical_shape()};
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
spec_return_value_t UnaryDeviceOperation::compute_output_specs(
const operation_attributes_t& args, const tensor_args_t& tensor_args) {
if (tensor_args.preallocated_output.has_value()) {
return tensor_args.preallocated_output.value();
return tensor_args.preallocated_output->get_tensor_spec();
}

auto output_layout = Layout::TILE;
if (args.output_memory_config.is_sharded()) {
output_layout = tensor_args.input.get_layout();
}

const auto output_shape = tensor_args.input.shape();
return create_device_tensor(
output_shape, args.output_dtype, output_layout, tensor_args.input.device(), args.output_memory_config);
const auto output_shape = tensor_args.input.logical_shape();
return TensorSpec(output_shape, TensorLayout(args.output_dtype, output_layout, args.output_memory_config));
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
const operation_attributes_t& args, const tensor_args_t& tensor_args) {
if (tensor_args.preallocated_output.has_value()) {
return *tensor_args.preallocated_output;
}
return create_device_tensor(compute_output_specs(args, tensor_args), tensor_args.input.device());
}

tt::stl::hash::hash_t UnaryDeviceOperation::compute_program_hash(
Original file line number Diff line number Diff line change
@@ -24,7 +24,7 @@ struct UnaryDeviceOperation {

using operation_attributes_t = unary::operation_attributes_t;
using tensor_args_t = unary::tensor_args_t;
using shape_return_value_t = unary::shape_return_value_t;
using spec_return_value_t = unary::spec_return_value_t;
using tensor_return_value_t = unary::tensor_return_value_t;
using program_factory_t = std::variant<program::UnaryProgramFactory, program::UnaryShardedProgramFactory>;

@@ -33,7 +33,7 @@ struct UnaryDeviceOperation {
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(const operation_attributes_t& operation_attributes, const tensor_args_t&);

Original file line number Diff line number Diff line change
@@ -28,6 +28,6 @@ struct tensor_args_t {

using tensor_return_value_t = Tensor;

using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;

} // namespace ttnn::operations::unary
Original file line number Diff line number Diff line change
@@ -45,26 +45,16 @@ void Embeddings::validate(const std::vector<Tensor> &input_tensors) const {
}
}

std::vector<tt::tt_metal::LegacyShape> Embeddings::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
std::vector<TensorSpec> Embeddings::compute_output_specs(const std::vector<Tensor> &input_tensors) const {
const auto &input_tensor = input_tensors.at(0);
const auto &weight_tensor = input_tensors.at(1);
auto num_output_embeddings = input_tensor.get_legacy_shape()[3];
auto batch_num = input_tensor.get_legacy_shape()[0];
auto num_embedding_dims = weight_tensor.get_legacy_shape()[3];
auto num_output_embeddings = input_tensor.logical_shape()[3];
auto batch_num = input_tensor.logical_shape()[0];
auto num_embedding_dims = weight_tensor.logical_shape()[3];

tt::tt_metal::LegacyShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims});
return {output_shape};
}

std::vector<Tensor> Embeddings::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto &weight_tensor = input_tensors.at(1);
if (!tilized) {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::ROW_MAJOR, this->output_mem_config);
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
}
ttnn::SimpleShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims});
auto output_layout = tilized ? Layout::TILE : Layout::ROW_MAJOR;
return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(output_layout), output_mem_config))};
}

operation::ProgramWithCallbacks Embeddings::create_program(
Original file line number Diff line number Diff line change
@@ -24,8 +24,7 @@ struct Embeddings {
const DataType output_dtype;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
};
Original file line number Diff line number Diff line change
@@ -68,18 +68,13 @@ void EmbeddingBackward::validate(const std::vector<Tensor> &input_tensors) const
"Number of rows in gradient tensor must be equal to number of indices in index tensor");
}

std::vector<tt::tt_metal::LegacyShape> EmbeddingBackward::compute_output_shapes(
std::vector<TensorSpec> EmbeddingBackward::compute_output_specs(
const std::vector<Tensor> &input_tensors) const {
const auto &grad_tensor = input_tensors.at(1);
auto embedding_dim = grad_tensor.get_legacy_shape()[-1];
auto embedding_dim = grad_tensor.get_logical_shape()[-1];

tt::tt_metal::LegacyShape output_shape({1, 1, this->num_embeddings, embedding_dim});
return {output_shape};
}

std::vector<Tensor> EmbeddingBackward::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
ttnn::SimpleShape output_shape({1, 1, this->num_embeddings, embedding_dim});
return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(Layout::TILE), output_mem_config))};
}

operation::ProgramWithCallbacks EmbeddingBackward::create_program(
Original file line number Diff line number Diff line change
@@ -24,8 +24,7 @@ struct EmbeddingBackward {
uint32_t num_embeddings;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
tt::stl::reflection::Attributes attributes() const;
Original file line number Diff line number Diff line change
@@ -32,9 +32,9 @@ void UniformDeviceOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
}

UniformDeviceOperation::shape_return_value_t UniformDeviceOperation::compute_output_shapes(
TensorSpec UniformDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
return tensor_args.input.get_tensor_spec();
}

UniformDeviceOperation::tensor_return_value_t UniformDeviceOperation::create_output_tensors(
Original file line number Diff line number Diff line change
@@ -21,7 +21,7 @@ struct UniformDeviceOperation {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
@@ -51,7 +51,7 @@ struct UniformDeviceOperation {
static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(

0 comments on commit d9cc84c

Please sign in to comment.