#0: Port eltwise and some misc ops to use TensorSpec #15471

Merged · 3 commits · Nov 27, 2024
21 changes: 15 additions & 6 deletions ttnn/cpp/ttnn/device_operation.hpp
@@ -32,18 +32,27 @@ concept ProgramFactoryConcept = requires {
};
};

template <typename device_operation_t>
concept HasComputeOutputShapes = requires(device_operation_t op,
const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
{op.compute_output_shapes(operation_attributes, tensor_args)} -> std::same_as<typename device_operation_t::shape_return_value_t>;
};

template <typename device_operation_t>
concept HasComputeOutputSpecs = requires(device_operation_t op,
const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
{op.compute_output_specs(operation_attributes, tensor_args)} -> std::same_as<typename device_operation_t::spec_return_value_t>;
};

template <typename device_operation_t>
concept DeviceOperationConcept = requires {
[](const typename device_operation_t::operation_attributes_t& operation_attributes,
const typename device_operation_t::tensor_args_t& tensor_args) {
device_operation_t::validate_on_program_cache_hit(operation_attributes, tensor_args);
device_operation_t::validate_on_program_cache_miss(operation_attributes, tensor_args);

using shape_return_value_t = typename device_operation_t::shape_return_value_t;
static_assert(std::same_as<
decltype(device_operation_t::compute_output_shapes(operation_attributes, tensor_args)),
shape_return_value_t>);

using tensor_return_value_t = typename device_operation_t::tensor_return_value_t;
static_assert(std::same_as<
decltype(device_operation_t::create_output_tensors(operation_attributes, tensor_args)),
@@ -57,7 +66,7 @@ concept DeviceOperationConcept = requires {
},
program_factory);
};
};
} && (HasComputeOutputSpecs<device_operation_t> || HasComputeOutputShapes<device_operation_t>);

template <typename device_operation_t>
concept DeviceOperationWithCustomProgramCacheConcept =
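Note: the concept now accepts either the legacy compute_output_shapes or the new compute_output_specs. A minimal sketch of an op satisfying the new HasComputeOutputSpecs branch (type and member names are illustrative, not code from this PR):

```cpp
// Minimal sketch of an op satisfying HasComputeOutputSpecs (illustrative only;
// validation, program factory selection, and create_output_tensors omitted).
struct ExampleDeviceOperation {
    struct operation_attributes_t {
        DataType dtype;
        MemoryConfig memory_config;
    };
    struct tensor_args_t {
        Tensor input;
    };
    using spec_return_value_t = TensorSpec;
    using tensor_return_value_t = Tensor;

    static spec_return_value_t compute_output_specs(
        const operation_attributes_t& attrs, const tensor_args_t& args) {
        // Same shape as the input, tiled layout, caller-chosen dtype and memory config.
        return TensorSpec(
            args.input.get_logical_shape(),
            TensorLayout(attrs.dtype, PageConfig(Layout::TILE), attrs.memory_config));
    }
};
```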
@@ -47,9 +47,14 @@ void BernoulliDeviceOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
}

BernoulliDeviceOperation::shape_return_value_t BernoulliDeviceOperation::compute_output_shapes(
BernoulliDeviceOperation::spec_return_value_t BernoulliDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
if (tensor_args.output.has_value()) {
return tensor_args.output->get_tensor_spec();
}

auto output_shape = tensor_args.input.get_logical_shape();
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
}

BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create_output_tensors(
@@ -58,13 +63,7 @@ BernoulliDeviceOperation::tensor_return_value_t BernoulliDeviceOperation::create
return tensor_args.output.value();
}

auto output_shapes = compute_output_shapes(operation_attributes, tensor_args);
return create_device_tensor(
output_shapes,
operation_attributes.dtype,
Layout::TILE,
tensor_args.input.device(),
operation_attributes.memory_config);
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input.device());
}

std::tuple<BernoulliDeviceOperation::operation_attributes_t, BernoulliDeviceOperation::tensor_args_t>
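The Bernoulli change above is the template every op in this PR follows: compute_output_specs returns a full TensorSpec, preferring a preallocated output's spec when one exists, and create_output_tensors collapses to a one-liner. A condensed sketch of the pattern, assuming the create_device_tensor(TensorSpec, device) overload used above (Attrs and Args are placeholder type names):

```cpp
// The recurring pattern in this PR (sketch; Attrs/Args are illustrative).
TensorSpec compute_output_specs(const Attrs& attrs, const Args& args) {
    if (args.output.has_value()) {
        return args.output->get_tensor_spec();  // respect a preallocated output
    }
    return TensorSpec(
        args.input.get_logical_shape(),
        TensorLayout(attrs.dtype, PageConfig(Layout::TILE), attrs.memory_config));
}

Tensor create_output_tensors(const Attrs& attrs, const Args& args) {
    if (args.output.has_value()) {
        return args.output.value();
    }
    return create_device_tensor(compute_output_specs(attrs, args), args.input.device());
}
```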
@@ -21,7 +21,7 @@ struct BernoulliDeviceOperation {
const std::optional<Tensor>& output;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
@@ -52,7 +52,7 @@ struct BernoulliDeviceOperation {
static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(
@@ -150,11 +150,17 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
TT_FATAL(width_a == width_b || width_a == 1 || width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch");
}

BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
const auto input_shape_a = tensor_args.input_tensor_a.shape();
BinaryDeviceOperation::spec_return_value_t BinaryDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
const auto& output_tensor = tensor_args.output_tensor;
if (output_tensor.has_value()) {
return output_tensor->get_tensor_spec();
}

const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto input_shape_a = input_tensor_a.logical_shape();
const auto& tensor_b = tensor_args.input_tensor_b;
const auto input_shape_b = tensor_b.has_value() ? tensor_b->shape() : ttnn::Shape{1, 1};
const auto input_shape_b = tensor_b.has_value() ? tensor_b->logical_shape() : ttnn::SimpleShape{};
Member: What does "has value" mean here?

Contributor (author): tensor_b is optional here. I'm unsure what it means for a binary op to have an optional second argument, but I'm preserving the existing behavior.

Member: I remember now. This is how a float scalar is handled. Think of it as a union: it's either a float or a tensor, just expressed this way.


const int rank_a = input_shape_a.rank();
const int rank_b = input_shape_b.rank();
@@ -179,24 +185,9 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu
output_shape[i + larger_rank] = dim_a + dim_b - 1;
}
}
return output_shape;
return ttnn::SimpleShape(output_shape);
};

const auto logical_shape_a = input_shape_a.logical_shape();
const auto logical_shape_b = input_shape_b.logical_shape();
return ttnn::SimpleShape(compute_broadcasted_output(logical_shape_a, logical_shape_b));
}
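Worth spelling out the dim_a + dim_b - 1 trick in compute_broadcasted_output: assuming the (elided) validation guarantees that mismatched dims have a 1 on one side, a + b - 1 equals max(a, b), e.g. (4, 1) -> 4, (1, 7) -> 7, (1, 1) -> 1. A sketch under that precondition:

```cpp
#include <cassert>

// Sketch: broadcast two dims under the assumption (enforced by earlier
// validation, elided from this diff) that a mismatch implies one side is 1.
int broadcast_dim(int dim_a, int dim_b) {
    assert(dim_a == dim_b || dim_a == 1 || dim_b == 1);
    return dim_a == dim_b ? dim_a : dim_a + dim_b - 1;  // == max(dim_a, dim_b)
}
```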

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
using namespace tt::constants;
auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
const auto& input_tensor_a = tensor_args.input_tensor_a;
const auto& output_tensor = tensor_args.output_tensor;

if (output_tensor.has_value()) {
return output_tensor.value();
}
auto output_shape = compute_broadcasted_output(input_shape_a, input_shape_b);

auto program_factory = select_program_factory(operation_attributes, tensor_args);
Member: I was a bit surprised that we call this method here. It means we call it at least twice per operation call. Not that it's a problem, but it's a surprise.

Contributor (author): Yeah, I would expect select_program_factory to be really quick. It makes sense, at least to some extent, that the output tensor specs depend on the actual factory chosen for these input arguments.

if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {
@@ -212,8 +203,7 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
}
auto memory_config = operation_attributes.memory_config;
memory_config.shard_spec = shard_spec;
return create_device_tensor(
output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
}
} else {
if (operation_attributes.memory_config.is_sharded()) {
@@ -224,16 +214,18 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
}
auto memory_config = operation_attributes.memory_config;
memory_config.shard_spec = shard_spec;
return create_device_tensor(
output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
}
}
return create_device_tensor(
output_shape,
operation_attributes.dtype,
Layout::TILE,
input_tensor_a.device(),
operation_attributes.memory_config);
return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
}

BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
if (tensor_args.output_tensor.has_value()) {
return *tensor_args.output_tensor;
}
return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input_tensor_a.device());
Member: Interesting. A Tensor might span multiple devices today. I see this part of the code did not change, but I wonder...

On this level, multi-device basically does not exist today, if I understand correctly.

Contributor (author): Yes, multi-device seems problematic overall at the moment. With a single MeshDevice it should work as expected, though.

}

tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash(
@@ -47,7 +47,7 @@ struct BinaryDeviceOperation {
std::optional<Tensor> input_tensor_b;
std::optional<Tensor> output_tensor;
};
using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ElementWiseMultiCore {
@@ -203,7 +203,7 @@ struct BinaryDeviceOperation {
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(
const operation_attributes_t& operation_attributes, const tensor_args_t&);
@@ -136,7 +136,7 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(
}

if (preallocated_output_tensor.has_value()) {
const auto computed_output_shape = compute_output_shapes(args, tensor_args);
const auto computed_output_shape = compute_output_specs(args, tensor_args).logical_shape();
const auto preallocated_output_shape = preallocated_output_tensor.value().get_logical_shape();
TT_FATAL(
preallocated_output_shape == computed_output_shape,
@@ -153,25 +153,27 @@ void UnaryDeviceOperation::validate_on_program_cache_miss(
}
}

shape_return_value_t UnaryDeviceOperation::compute_output_shapes(
const operation_attributes_t&, const tensor_args_t& tensor_args) {
return {tensor_args.input.get_logical_shape()};
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
spec_return_value_t UnaryDeviceOperation::compute_output_specs(
const operation_attributes_t& args, const tensor_args_t& tensor_args) {
if (tensor_args.preallocated_output.has_value()) {
return tensor_args.preallocated_output.value();
return tensor_args.preallocated_output->get_tensor_spec();
}

auto output_layout = Layout::TILE;
if (args.output_memory_config.is_sharded()) {
output_layout = tensor_args.input.get_layout();
}

const auto output_shape = tensor_args.input.shape();
return create_device_tensor(
output_shape, args.output_dtype, output_layout, tensor_args.input.device(), args.output_memory_config);
const auto output_shape = tensor_args.input.logical_shape();
return TensorSpec(output_shape, TensorLayout(args.output_dtype, output_layout, args.output_memory_config));
}

tensor_return_value_t UnaryDeviceOperation::create_output_tensors(
const operation_attributes_t& args, const tensor_args_t& tensor_args) {
if (tensor_args.preallocated_output.has_value()) {
return *tensor_args.preallocated_output;
}
return create_device_tensor(compute_output_specs(args, tensor_args), tensor_args.input.device());
}

tt::stl::hash::hash_t UnaryDeviceOperation::compute_program_hash(
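One nicety visible in the unary diff: a TensorSpec strictly subsumes the old shape-only return, so the preallocated-output validation keeps working by projecting the spec back down to a shape. A sketch of that projection (variable names and the message are illustrative):

```cpp
// Sketch: the spec carries the shape, so shape-only checks still work.
const auto spec = compute_output_specs(args, tensor_args);
const auto computed_output_shape = spec.logical_shape();  // old return value
TT_FATAL(
    preallocated_output_shape == computed_output_shape,
    "Preallocated output shape must match the computed output shape");
```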
@@ -24,7 +24,7 @@ struct UnaryDeviceOperation {

using operation_attributes_t = unary::operation_attributes_t;
using tensor_args_t = unary::tensor_args_t;
using shape_return_value_t = unary::shape_return_value_t;
using spec_return_value_t = unary::spec_return_value_t;
using tensor_return_value_t = unary::tensor_return_value_t;
using program_factory_t = std::variant<program::UnaryProgramFactory, program::UnaryShardedProgramFactory>;

@@ -33,7 +33,7 @@ struct UnaryDeviceOperation {
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);

static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);

static tensor_return_value_t create_output_tensors(const operation_attributes_t& operation_attributes, const tensor_args_t&);

@@ -28,6 +28,6 @@ struct tensor_args_t {

using tensor_return_value_t = Tensor;

using shape_return_value_t = ttnn::SimpleShape;
using spec_return_value_t = TensorSpec;

} // namespace ttnn::operations::unary
@@ -45,26 +45,16 @@ void Embeddings::validate(const std::vector<Tensor> &input_tensors) const {
}
}

std::vector<tt::tt_metal::LegacyShape> Embeddings::compute_output_shapes(const std::vector<Tensor> &input_tensors) const {
std::vector<TensorSpec> Embeddings::compute_output_specs(const std::vector<Tensor> &input_tensors) const {
const auto &input_tensor = input_tensors.at(0);
const auto &weight_tensor = input_tensors.at(1);
auto num_output_embeddings = input_tensor.get_legacy_shape()[3];
auto batch_num = input_tensor.get_legacy_shape()[0];
auto num_embedding_dims = weight_tensor.get_legacy_shape()[3];
auto num_output_embeddings = input_tensor.logical_shape()[3];
auto batch_num = input_tensor.logical_shape()[0];
auto num_embedding_dims = weight_tensor.logical_shape()[3];

tt::tt_metal::LegacyShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims});
return {output_shape};
}

std::vector<Tensor> Embeddings::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
const auto &weight_tensor = input_tensors.at(1);
if (!tilized) {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::ROW_MAJOR, this->output_mem_config);
} else {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
}
ttnn::SimpleShape output_shape({batch_num, 1, num_output_embeddings, num_embedding_dims});
auto output_layout = tilized ? Layout::TILE : Layout::ROW_MAJOR;
return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(output_layout), output_mem_config))};
}

operation::ProgramWithCallbacks Embeddings::create_program(
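The Embeddings change also folds create_output_tensors away entirely: once an old-style op reports std::vector<TensorSpec>, the framework can materialize outputs generically. A hedged sketch of what that generic step amounts to (the helper name is an assumption, not code from this PR):

```cpp
// Sketch of generic output creation from specs (illustrative helper).
std::vector<Tensor> create_outputs_from_specs(
    const std::vector<TensorSpec>& specs, Device* device) {
    std::vector<Tensor> outputs;
    outputs.reserve(specs.size());
    for (const auto& spec : specs) {
        outputs.push_back(create_device_tensor(spec, device));
    }
    return outputs;
}
```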
@@ -24,8 +24,7 @@ struct Embeddings {
const DataType output_dtype;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
};
@@ -68,18 +68,13 @@ void EmbeddingBackward::validate(const std::vector<Tensor> &input_tensors) const
"Number of rows in gradient tensor must be equal to number of indices in index tensor");
}

std::vector<tt::tt_metal::LegacyShape> EmbeddingBackward::compute_output_shapes(
std::vector<TensorSpec> EmbeddingBackward::compute_output_specs(
const std::vector<Tensor> &input_tensors) const {
const auto &grad_tensor = input_tensors.at(1);
auto embedding_dim = grad_tensor.get_legacy_shape()[-1];
auto embedding_dim = grad_tensor.get_logical_shape()[-1];

tt::tt_metal::LegacyShape output_shape({1, 1, this->num_embeddings, embedding_dim});
return {output_shape};
}

std::vector<Tensor> EmbeddingBackward::create_output_tensors(const std::vector<Tensor> &input_tensors) const {
return operation::generic_create_output_tensors(
*this, input_tensors, this->output_dtype, Layout::TILE, this->output_mem_config);
ttnn::SimpleShape output_shape({1, 1, this->num_embeddings, embedding_dim});
return {TensorSpec(output_shape, TensorLayout(output_dtype, PageConfig(Layout::TILE), output_mem_config))};
}

operation::ProgramWithCallbacks EmbeddingBackward::create_program(
@@ -24,8 +24,7 @@ struct EmbeddingBackward {
uint32_t num_embeddings;

void validate(const std::vector<Tensor> &input_tensors) const;
std::vector<tt::tt_metal::LegacyShape> compute_output_shapes(const std::vector<Tensor> &input_tensors) const;
std::vector<Tensor> create_output_tensors(const std::vector<Tensor> &input_tensors) const;
std::vector<TensorSpec> compute_output_specs(const std::vector<Tensor> &input_tensors) const;
operation::ProgramWithCallbacks create_program(
const std::vector<Tensor> &input_tensors, std::vector<Tensor> &output_tensors) const;
tt::stl::reflection::Attributes attributes() const;
@@ -32,9 +32,9 @@ void UniformDeviceOperation::validate_on_program_cache_hit(
validate_inputs(operation_attributes, tensor_args);
}

UniformDeviceOperation::shape_return_value_t UniformDeviceOperation::compute_output_shapes(
TensorSpec UniformDeviceOperation::compute_output_specs(
const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
return tensor_args.input.get_logical_shape();
return tensor_args.input.get_tensor_spec();
}

UniformDeviceOperation::tensor_return_value_t UniformDeviceOperation::create_output_tensors(
@@ -21,7 +21,7 @@ struct UniformDeviceOperation {
const Tensor& input;
};

using shape_return_value_t = SimpleShape;
using spec_return_value_t = TensorSpec;
using tensor_return_value_t = Tensor;

struct ProgramFactory {
@@ -51,7 +51,7 @@ struct UniformDeviceOperation {
static void validate_inputs(const operation_attributes_t& attributes, const tensor_args_t& tensor_args);
static void validate_on_program_cache_miss(const operation_attributes_t&, const tensor_args_t&);
static void validate_on_program_cache_hit(const operation_attributes_t&, const tensor_args_t&);
static shape_return_value_t compute_output_shapes(const operation_attributes_t&, const tensor_args_t&);
static spec_return_value_t compute_output_specs(const operation_attributes_t&, const tensor_args_t&);
static tensor_return_value_t create_output_tensors(const operation_attributes_t&, const tensor_args_t&);

static std::tuple<operation_attributes_t, tensor_args_t> invoke(