#0: Port eltwise and some misc ops to use TensorSpec #15471
@@ -150,11 +150,17 @@ void BinaryDeviceOperation::validate_on_program_cache_hit(
     TT_FATAL(width_a == width_b || width_a == 1 || width_b == 1, "ttnn::operations::binary::BinaryDeviceOperation: width mismatch");
 }

-BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_output_shapes(
-    const operation_attributes_t&, const tensor_args_t& tensor_args) {
-    const auto input_shape_a = tensor_args.input_tensor_a.shape();
+BinaryDeviceOperation::spec_return_value_t BinaryDeviceOperation::compute_output_specs(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    const auto& output_tensor = tensor_args.output_tensor;
+    if (output_tensor.has_value()) {
+        return output_tensor->get_tensor_spec();
+    }
+
+    const auto& input_tensor_a = tensor_args.input_tensor_a;
+    const auto input_shape_a = input_tensor_a.logical_shape();
     const auto& tensor_b = tensor_args.input_tensor_b;
-    const auto input_shape_b = tensor_b.has_value() ? tensor_b->shape() : ttnn::Shape{1, 1};
+    const auto input_shape_b = tensor_b.has_value() ? tensor_b->logical_shape() : ttnn::SimpleShape{};

     const int rank_a = input_shape_a.rank();
     const int rank_b = input_shape_b.rank();

@@ -179,24 +185,9 @@ BinaryDeviceOperation::shape_return_value_t BinaryDeviceOperation::compute_outpu
                 output_shape[i + larger_rank] = dim_a + dim_b - 1;
             }
         }
-        return output_shape;
+        return ttnn::SimpleShape(output_shape);
     };

-    const auto logical_shape_a = input_shape_a.logical_shape();
-    const auto logical_shape_b = input_shape_b.logical_shape();
-    return ttnn::SimpleShape(compute_broadcasted_output(logical_shape_a, logical_shape_b));
-}
-
-BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
-    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
-    using namespace tt::constants;
-    auto output_shape = compute_output_shapes(operation_attributes, tensor_args);
-    const auto& input_tensor_a = tensor_args.input_tensor_a;
-    const auto& output_tensor = tensor_args.output_tensor;
-
-    if (output_tensor.has_value()) {
-        return output_tensor.value();
-    }
+    auto output_shape = compute_broadcasted_output(input_shape_a, input_shape_b);

     auto program_factory = select_program_factory(operation_attributes, tensor_args);
Review comment: Got a bit surprised that we call this method here. It means we call it at least two times per operation call. Not that it's a problem, but it's a surprise.

Reply: Yeah. I would expect that
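To make the surprise concrete, here is a rough sketch of one plausible dispatch flow. The launch helper and its body are illustrative assumptions, not code from this repository; only the device-operation method names come from the file above.

// Hypothetical, simplified dispatch flow for a new-style device operation.
template <typename DeviceOp>
auto launch(const typename DeviceOp::operation_attributes_t& attrs,
            const typename DeviceOp::tensor_args_t& tensors) {
    // First use: the infrastructure picks the program factory for this call.
    auto factory = DeviceOp::select_program_factory(attrs, tensors);

    // Second use: create_output_tensors -> compute_output_specs, which calls
    // select_program_factory again to decide how the output should be sharded.
    auto output = DeviceOp::create_output_tensors(attrs, tensors);

    // ... build or fetch the cached program via `factory`, enqueue it ...
    return output;
}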
     if (std::holds_alternative<ElementWiseMultiCore>(program_factory)) {

@@ -212,8 +203,7 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
             }
             auto memory_config = operation_attributes.memory_config;
             memory_config.shard_spec = shard_spec;
-            return create_device_tensor(
-                output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
+            return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
         }
     } else {
         if (operation_attributes.memory_config.is_sharded()) {

@@ -224,16 +214,18 @@ BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_outpu
             }
             auto memory_config = operation_attributes.memory_config;
             memory_config.shard_spec = shard_spec;
-            return create_device_tensor(
-                output_shape, operation_attributes.dtype, Layout::TILE, input_tensor_a.device(), memory_config);
+            return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), memory_config));
         }
     }
-    return create_device_tensor(
-        output_shape,
-        operation_attributes.dtype,
-        Layout::TILE,
-        input_tensor_a.device(),
-        operation_attributes.memory_config);
+    return TensorSpec(output_shape, TensorLayout(operation_attributes.dtype, PageConfig(Layout::TILE), operation_attributes.memory_config));
 }

+BinaryDeviceOperation::tensor_return_value_t BinaryDeviceOperation::create_output_tensors(
+    const operation_attributes_t& operation_attributes, const tensor_args_t& tensor_args) {
+    if (tensor_args.output_tensor.has_value()) {
+        return *tensor_args.output_tensor;
+    }
+    return create_device_tensor(compute_output_specs(operation_attributes, tensor_args), tensor_args.input_tensor_a.device());
Review comment: Interesting. A Tensor might have multiple devices today. I see this part of the code did not change, but I wonder: at this level, multi-device basically does not exist today, if I get it right.

Reply: Yes. Multi-device seems to be problematic overall at the moment. With a single MeshDevice it should work as expected, though.
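As a usage sketch of the new pattern: TensorSpec, TensorLayout, PageConfig and create_device_tensor are the types and calls from the diff above, while the concrete shape, dtype, memory config and the `device` handle are illustrative assumptions.

// Illustrative only: a TensorSpec bundles the logical shape with a TensorLayout
// (dtype + page/tile configuration + memory config); no buffer is allocated until
// the spec is materialized on a device via create_device_tensor.
ttnn::SimpleShape shape({1, 1, 32, 32});                                   // arbitrary example shape
TensorLayout layout(DataType::BFLOAT16, PageConfig(Layout::TILE), MemoryConfig{});
TensorSpec spec(shape, layout);

Tensor output = create_device_tensor(spec, device);                       // `device` assumed to be obtained elsewhere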
+}

 tt::stl::hash::hash_t BinaryDeviceOperation::compute_program_hash(
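For reference, the broadcasting rule visible in the second hunk (dims compared right-aligned; each pair must match or one of them must be 1; dim_a + dim_b - 1 then equals the larger dim) is ordinary NumPy-style broadcasting. Below is a standalone sketch of the same computation, independent of the ttnn types; the real implementation is the compute_broadcasted_output lambda above.

#include <algorithm>
#include <cstdint>
#include <stdexcept>
#include <vector>

// Standalone illustration of the broadcast rule: shapes are compared right-aligned,
// and missing leading dims count as 1.
std::vector<uint32_t> broadcast_shape(const std::vector<uint32_t>& a, const std::vector<uint32_t>& b) {
    const int rank_a = static_cast<int>(a.size());
    const int rank_b = static_cast<int>(b.size());
    const int larger_rank = std::max(rank_a, rank_b);
    std::vector<uint32_t> out(static_cast<size_t>(larger_rank), 1);
    for (int i = -1; i >= -larger_rank; --i) {
        const uint32_t dim_a = (i >= -rank_a) ? a[static_cast<size_t>(rank_a + i)] : 1;
        const uint32_t dim_b = (i >= -rank_b) ? b[static_cast<size_t>(rank_b + i)] : 1;
        if (dim_a != dim_b && dim_a != 1 && dim_b != 1) {
            throw std::invalid_argument("shapes are not broadcastable");
        }
        // When the dims differ, one of them is 1, so dim_a + dim_b - 1 == max(dim_a, dim_b).
        out[static_cast<size_t>(i + larger_rank)] = (dim_a == dim_b) ? dim_a : dim_a + dim_b - 1;
    }
    return out;
}

For example, broadcast_shape({2, 1, 32}, {4, 32}) yields {2, 4, 32}, which is what the lambda in the diff would wrap in a ttnn::SimpleShape.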
Review comment: "Has value": what does it mean?

Reply: tensor_b is optional here. I'm unsure what it means for a binary op to have its second argument be optional, but I'm preserving the existing behavior.

Reply: I remember now. This is how a float scalar is handled. Think of it as a union: it is either a float or a tensor, just expressed in this way.
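To make that concrete, here is a minimal self-contained sketch of the float-or-tensor idea. Tensor below is a placeholder standing in for ttnn's tensor type, and the operation itself expresses the same thing with an optional input_tensor_b plus a scalar attribute rather than an actual std::variant.

#include <optional>
#include <variant>

struct Tensor {};  // placeholder for ttnn's tensor type, for illustration only

// Conceptually the second operand is a union: either a float scalar or a tensor.
using SecondOperand = std::variant<float, Tensor>;

// The op carries std::optional<Tensor> input_tensor_b; when it is empty,
// the scalar attribute is used instead.
SecondOperand resolve_second_operand(const std::optional<Tensor>& input_tensor_b, float scalar) {
    if (input_tensor_b.has_value()) {
        return *input_tensor_b;  // tensor-tensor eltwise
    }
    return scalar;               // tensor-scalar eltwise
}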