Skip to content

Commit

Permalink
#0: Port Eltwise and some misc ops to use TensorSpec
Browse files Browse the repository at this point in the history
  • Loading branch information
sminakov-tt committed Nov 23, 2024
1 parent f305b49 commit 1171a25
Show file tree
Hide file tree
Showing 14 changed files with 87 additions and 97 deletions.
27 changes: 22 additions & 5 deletions ttnn/cpp/ttnn/device_operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,16 +33,33 @@ concept ProgramFactoryConcept = requires {
};

template <typename device_operation_t>
concept DeviceOperationConcept = requires {
concept HasComputeOutputShapes = requires {
[](const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
device_operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args);
device_operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args);

using shape_return_value_t = typename device_operation_t::shape_return_value_t;
static_assert(std::same_as<
decltype(device_operation_t::compute_output_shapes(operation_attributes, tensor_args)),
shape_return_value_t>);
};
};

template <typename device_operation_t>
concept HasComputeOutputSpecs = requires {
[](const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
using spec_return_value_t = typename device_operation_t::spec_return_value_t;
static_assert(std::same_as<
decltype(device_operation_t::compute_output_specs(operation_attributes, tensor_args)),
spec_return_value_t>);
};
};

template <typename device_operation_t>
concept DeviceOperationConcept = requires {
[](const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
device_operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args);
device_operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args);

using tensor_return_value_t = typename device_operation_t::tensor_return_value_t;
static_assert(std::same_as<
Expand All @@ -57,7 +74,7 @@ concept DeviceOperationConcept = requires {
},
program_factory);
};
};
} && (HasComputeOutputSpecs<device_operation_t> || HasComputeOutputShapes<device_operation_t>);

template <typename device_operation_t>
concept DeviceOperationWithCustomProgramCacheConcept =
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,14 @@ void BernoulliDeviceOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
}

BernoulliDeviceOperation::shape_return_value_t BernoulliDeviceOperation::compute_output_shapes(
BernoulliDeviceOperation::spec_return_value_t BernoulliDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
if (tensor_args.output.has_value()) {
return tensor_args.output->get_tensor_spec();
}

auto output_shape = tensor_args.input.get_logical_shape();
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
}

BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create_output_tensors(
Expand All @@ -58,13 +63,7 @@ BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create
return tensor_args.output.value();
}

auto output_shapes = compute_output_shapes(operation_attributes, tensor_args);
return create_device_tensor(
output_shapes,
operation_attributes.dtype,
Layout::TILE,
tensor_args.input.device(),
operation_attributes.memory_config);
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
}

std::tuple<BernoulliDeviceOperation::operation_attributes_t, BernoulliDeviceOperation::tensor_args_t>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct BernoulliDeviceOperation {
const std::optional<Tensor>& output;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -52,7 +52,7 @@ struct BernoulliDeviceOperation {
static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,11 +150,17 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
TT_FATAL(width_a == width_b || width_a == 1 || width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch");
}

BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
const auto input_shape_a = tensor_args.input_tensor_a.shape();
BinaryDeviceOperation::spec_return_value_t BinaryDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& output_tensor = tensor_args.output_tensor;
if (output_tensor.has_value()) {
return output_tensor->get_tensor_spec();
}

const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto input_shape_a = input_tensor_a.logical_shape();
const auto& tensor_b = tensor_args.input_tensor_b;
const auto input_shape_b = tensor_b.has_value() ? tensor_b->shape() : ttnn::Shape{1, 1};
const auto input_shape_b = tensor_b.has_value() ? tensor_b->logical_shape() : ttnn::SimpleShape{};

const int rank_a = input_shape_a.rank();
const int rank_b = input_shape_b.rank();
Expand All @@ -179,24 +185,9 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu
output_shape[i + larger_rank] = dim_a + dim_b - 1;
}
}
return output_shape;
return ttnn::SimpleShape(output_shape);
};

const auto logical_shape_a = input_shape_a.logical_shape();
const auto logical_shape_b = input_shape_b.logical_shape();
return ttnn::SimpleShape(compute_broadcasted_output(logical_shape_a, logical_shape_b));
}

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
using namespace tt::constants;
auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto& output_tensor = tensor_args.output_tensor;

if (output_tensor.has_value()) {
return output_tensor.value();
}
auto output_shape = compute_broadcasted_output(input_shape_a, input_shape_b);

auto program_factory = select_program_factory(operation_attributes, tensor_args);
if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
Expand All @@ -212,8 +203,7 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
}
auto memory_config = operation_attributes.memory_config;
memory_config.shard_spec = shard_spec;
return create_device_tensor(
output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
}
} else {
if (operation_attributes.memory_config.is_sharded()) {
Expand All @@ -224,16 +214,18 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
}
auto memory_config = operation_attributes.memory_config;
memory_config.shard_spec = shard_spec;
return create_device_tensor(
output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
}
}
return create_device_tensor(
output_shape,
operation_attributes.dtype,
Layout::TILE,
input_tensor_a.device(),
operation_attributes.memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
}

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
if (tensor_args.output_tensor.has_value()) {
return *tensor_args.output_tensor;
}
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input_tensor_a.device());
}

tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ struct BinaryDeviceOperation {
std::optional<Tensor> input_tensor_b;
std::optional<Tensor> output_tensor;
};
using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ElementWiseMultiCore {
Expand Down Expand Up @@ -203,7 +203,7 @@ struct BinaryDeviceOperation {
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t&);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -136,7 +136,7 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(
}

if (preallocated_output_tensor.has_value()) {
const auto computed_output_shape = compute_output_shapes(args, tensor_args);
const auto computed_output_shape = compute_output_specs(args, tensor_args).logical_shape();
const auto preallocated_output_shape = preallocated_output_tensor.value().get_logical_shape();
TT_FATAL(
preallocated_output_shape == computed_output_shape,
Expand All @@ -153,25 +153,24 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(
}
}

shape_return_value_t UnaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
return {tensor_args.input.get_logical_shape()};
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
spec_return_value_t UnaryDeviceOperation::compute_output_specs(
const operation_attributes_t& args, const tensor_args_t& tensor_args) {
if (tensor_args.preallocated_output.has_value()) {
return tensor_args.preallocated_output.value();
return tensor_args.preallocated_output->get_tensor_spec();
}

auto output_layout = Layout::TILE;
if (args.output_memory_config.is_sharded()) {
output_layout = tensor_args.input.get_layout();
}

const auto output_shape = tensor_args.input.shape();
return create_device_tensor(
output_shape, args.output_dtype, output_layout, tensor_args.input.device(), args.output_memory_config);
const auto output_shape = tensor_args.input.logical_shape();
return TensorSpec(output_shape, TensorLayout(args.output_dtype, output_layout, args.output_memory_config));
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
const operation_attributes_t& args, const tensor_args_t& tensor_args) {
return create_device_tensor(compute_output_specs(args, tensor_args), tensor_args.input.device());
}

tt::stl::hash::hash_t UnaryDeviceOperation::compute_program_hash(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ struct UnaryDeviceOperation {

using operation_attributes_t = unary::operation_attributes_t;
using tensor_args_t = unary::tensor_args_t;
using shape_return_value_t = unary::shape_return_value_t;
using spec_return_value_t = unary::spec_return_value_t;
using tensor_return_value_t = unary::tensor_return_value_t;
using program_factory_t = std::variant<program::UnaryProgramFactory, program::UnaryShardedProgramFactory>;

Expand All @@ -33,7 +33,7 @@ struct UnaryDeviceOperation {
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(const operation_attributes_t& operation_attributes, const tensor_args_t&);

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,6 @@ struct tensor_args_t {

using tensor_return_value_t = Tensor;

using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;

} // namespace ttnn::operations::unary
Original file line number Diff line number Diff line change
Expand Up @@ -45,26 +45,16 @@ void Embeddings::validate(const std::vector<Tensor> &input_tensors) const {
}
}

std::vector<tt::tt_metal::LegacyShape> Embeddings::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
std::vector<TensorSpec> Embeddings::compute_output_specs(const std::vector<Tensor> &input_tensors) const {
const auto &input_tensor = input_tensors.at(0);
const auto &weight_tensor = input_tensors.at(1);
auto num_output_embeddings = input_tensor.get_legacy_shape()[3];
auto batch_num = input_tensor.get_legacy_shape()[0];
auto num_embedding_dims = weight_tensor.get_legacy_shape()[3];
auto num_output_embeddings = input_tensor.logical_shape()[3];
auto batch_num = input_tensor.logical_shape()[0];
auto num_embedding_dims = weight_tensor.logical_shape()[3];

tt::tt_metal::LegacyShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims});
return {output_shape};
}

std::vector<Tensor> Embeddings::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto &weight_tensor = input_tensors.at(1);
if (!tilized) {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::ROW_MAJOR, this->output_mem_config);
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
}
ttnn::SimpleShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims});
auto output_layout = tilized ? Layout::TILE : Layout::ROW_MAJOR;
return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(output_layout), output_mem_config))};
}

operation::ProgramWithCallbacks Embeddings::create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ struct Embeddings {
const DataType output_dtype;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
};
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -68,18 +68,13 @@ void EmbeddingBackward::validate(const std::vector<Tensor> &input_tensors) const
"Number of rows in gradient tensor must be equal to number of indices in index tensor");
}

std::vector<tt::tt_metal::LegacyShape> EmbeddingBackward::compute_output_shapes(
std::vector<TensorSpec> EmbeddingBackward::compute_output_specs(
const std::vector<Tensor> &input_tensors) const {
const auto &grad_tensor = input_tensors.at(1);
auto embedding_dim = grad_tensor.get_legacy_shape()[-1];
auto embedding_dim = grad_tensor.get_logical_shape()[-1];

tt::tt_metal::LegacyShape output_shape({1, 1, this->num_embeddings, embedding_dim});
return {output_shape};
}

std::vector<Tensor> EmbeddingBackward::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
ttnn::SimpleShape output_shape({1, 1, this->num_embeddings, embedding_dim});
return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(Layout::TILE), output_mem_config))};
}

operation::ProgramWithCallbacks EmbeddingBackward::create_program(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,7 @@ struct EmbeddingBackward {
uint32_t num_embeddings;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
tt::stl::reflection::Attributes attributes() const;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,9 @@ void UniformDeviceOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
}

UniformDeviceOperation::shape_return_value_t UniformDeviceOperation::compute_output_shapes(
TensorSpec UniformDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
return tensor_args.input.get_tensor_spec();
}

UniformDeviceOperation::tensor_return_value_t UniformDeviceOperation::create_output_tensors(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ struct UniformDeviceOperation {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
Expand Down Expand Up @@ -51,7 +51,7 @@ struct UniformDeviceOperation {
static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
Expand Down

0 comments on commit 1171a25

Please sign in to comment.