Skip to content

Commit

Permalink
#5389: minor fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
arakhmati committed May 10, 2024
1 parent 13e23a6 commit e570cd5
Show file tree
Hide file tree
Showing 6 changed files with 8 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,6 @@ const operation::Hash EltwiseBinary::compute_program_hash(
const auto& input_tensor_b = input_tensors.at(1);

operation::Hash hash = tt::stl::hash::hash_objects(
0,
typeid(*this).hash_code(),
this->op_type,
parallelization_strategy,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ struct EltwiseBinary {
return result;
}
static constexpr auto attribute_names =
std::make_tuple("op_type", "fused_activations", "output_mem_config", "output_dtype");
std::make_tuple("op_type", "fused_activations", "output_mem_config", "output_dtype", "in_place");
const auto attribute_values() const {
return std::make_tuple(
std::cref(this->op_type),
Expand Down
2 changes: 0 additions & 2 deletions tt_eager/tt_dnn/op_library/eltwise_unary/eltwise_unary_op.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -324,12 +324,10 @@ const operation::Hash EltwiseUnary::compute_program_hash(const std::vector<Tenso
const auto& input_shape = input_tensor.get_legacy_shape();

operation::Hash hash = tt::stl::hash::hash_objects(
0,
typeid(*this).hash_code(),
compute_volume(input_shape),
input_tensor.get_dtype(),
input_tensor.memory_config(),
input_tensor.device()->id(),
this->output_mem_config);

for (const auto& unary_with_param_op : this->op_chain) {
Expand Down
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/operation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ using Hash = tt::stl::hash::hash_t;
template <typename OperationType, typename... Types>
static Hash hash_operation(const Types&... objects) {
auto operation_type_hash = typeid(OperationType).hash_code();
return stl::hash::hash_objects(0, operation_type_hash, objects...);
return stl::hash::hash_objects(operation_type_hash, objects...);
}

using OverrideAddressesCallback =
Expand Down
2 changes: 1 addition & 1 deletion tt_eager/tt_dnn/op_library/run_operation.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,7 @@ OutputTensors run_device_operation(
};
}

operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
auto output_tensors = operation.create_output_tensors(input_tensors, optional_output_tensors);
auto program = get_or_create_program(operation, input_tensors, optional_input_tensors, output_tensors);
uint32_t device_id = detail::get_device(input_tensors, optional_input_tensors)->id();
Expand Down Expand Up @@ -418,7 +419,6 @@ OutputTensors run(
// }
// }

operation.validate(input_tensors, optional_input_tensors, optional_output_tensors);
if (detail::any_tensor_on_multi_device(input_tensors)) {
return detail::decorate_device_operation(detail::run_multi_device_operation<OutputTensors>)(
std::nullopt, operation, input_tensors, optional_input_tensors, optional_output_tensors);
Expand Down
6 changes: 5 additions & 1 deletion tt_metal/tt_stl/reflection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -187,6 +187,10 @@ constexpr bool supports_runtime_time_attributes_v = std::experimental::is_detect

template <typename T>
inline constexpr std::size_t get_num_attributes() {
static_assert(
std::tuple_size_v<decltype(T::attribute_names)> ==
std::tuple_size_v<decltype(std::declval<T>().attribute_values())>,
"Number of attribute_names must match number of attribute_values");
return std::tuple_size_v<decltype(T::attribute_names)>;
}
template <typename T>
Expand Down Expand Up @@ -566,7 +570,7 @@ inline hash_t hash_object(const T& object) noexcept {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing struct {} using run-time attributes: {}\n", get_type_name<T>(), object);
}
return hash_objects(0, typeid(T).hash_code(), object.attributes());
return hash_objects(typeid(T).hash_code(), object.attributes());
} else if constexpr (detail::is_specialization_v<T, std::vector>) {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::vector of type {}: {}\n", get_type_name<T>(), object);
Expand Down

0 comments on commit e570cd5

Please sign in to comment.