From 7210be5fc08dd9433f206bb0352b56c2b6119ac2 Mon Sep 17 00:00:00 2001 From: Austin Ho Date: Fri, 6 Dec 2024 21:09:52 +0000 Subject: [PATCH] #0: Add reflection for std::map and std::unordered_map --- tt_metal/tt_stl/reflection.hpp | 120 +++++++++++++++++++++++++++++++-- 1 file changed, 114 insertions(+), 6 deletions(-) diff --git a/tt_metal/tt_stl/reflection.hpp b/tt_metal/tt_stl/reflection.hpp index 42c3aecb6a4b..2a76a8a57750 100644 --- a/tt_metal/tt_stl/reflection.hpp +++ b/tt_metal/tt_stl/reflection.hpp @@ -448,6 +448,32 @@ std::ostream& operator<<(std::ostream& os, const std::set& set) { return os; } +template +std::ostream& operator<<(std::ostream& os, const std::map& map) { + os << "{"; + for (auto it = map.begin(); it != map.end(); ++it) { + os << it->first << ": " << it->second; + if (it != map.end()) { + os << ", "; + } + } + os << "}"; + return os; +} + +template +std::ostream& operator<<(std::ostream& os, const std::unordered_map& map) { + os << "{"; + for (auto it = map.begin(); it != map.end(); ++it) { + os << it->first << ": " << it->second; + if (it != map.end()) { + os << ", "; + } + } + os << "}"; + return os; +} + template requires(tt::stl::concepts::Reflectable and not(std::integral or std::is_array::value)) std::ostream& operator<<(std::ostream& os, const T& object) { @@ -978,6 +1004,30 @@ struct fmt::formatter> { } }; +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } + + auto format(const std::map& map, format_context& ctx) const -> format_context::iterator { + using tt::stl::reflection::operator<<; + std::stringstream ss; + ss << map; + return fmt::format_to(ctx.out(), "{}", ss.str()); + } +}; + +template +struct fmt::formatter> { + constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); } + + auto format(const std::unordered_map& map, format_context& ctx) const -> format_context::iterator { + using tt::stl::reflection::operator<<; + std::stringstream ss; + ss << map; + return fmt::format_to(ctx.out(), "{}", ss.str()); + } +}; + template requires( tt::stl::concepts::Reflectable and not(std::integral or std::is_array::value or @@ -1063,7 +1113,7 @@ inline hash_t hash_object(const T& object) noexcept { fmt::print("Hashing struct {} using compile-time attributes: {}\n", get_type_name(), object); } constexpr auto num_attributes = reflection::detail::get_num_attributes(); - std::size_t hash = 0; + hash_t hash = 0; const auto attribute_values = object.attribute_values(); [&object, &hash, &attribute_values](std::index_sequence) { ( @@ -1074,11 +1124,26 @@ inline hash_t hash_object(const T& object) noexcept { ...); }(std::make_index_sequence{}); return hash; + } else if constexpr (is_specialization_v) { + if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { + fmt::print("Hashing std::tuple of type {}: {}\n", get_type_name(), object); + } + constexpr auto num_elements = std::tuple_size_v; + hash_t hash = 0; + [&object, &hash](std::index_sequence) { + ( + [&object, &hash] { + const auto& element = std::get(object); + hash = hash_objects(hash, element); + }(), + ...); + }(std::make_index_sequence{}); + return hash; } else if constexpr (is_specialization_v) { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::vector of type {}: {}\n", get_type_name(), object); } - auto hash = 0; + hash_t hash = 0; for (const auto& element : object) { hash = hash_objects(hash, element); } @@ -1087,11 +1152,32 @@ inline hash_t hash_object(const T& object) noexcept { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::set of type {}: {}\n", get_type_name(), object); } - auto hash = 0; + hash_t hash = 0; for (const auto& element : object) { hash = hash_objects(hash, element); } return hash; + } else if constexpr (is_specialization_v) { + if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { + fmt::print("Hashing std::map of type {}: {}\n", get_type_name(), object); + } + hash_t hash = 0; + for (const auto& [key, value] : object) { + hash = hash_objects(hash, key, value); + } + return hash; + } else if constexpr (is_specialization_v) { + if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { + fmt::print("Hashing std::unordered_map of type {}: {}\n", get_type_name(), object); + } + constexpr hash_t seed = 0x9e3779b9; + hash_t hash = 0; + // Combine using xor so that the hash is order invariant + // Alternative could be to sort the keys before hashing + for (const auto& [key, value] : object) { + hash ^= hash_objects(seed, key, value); + } + return hash; } else if constexpr (is_specialization_v) { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing std::optional of type {}: {}\n", get_type_name(), object); @@ -1105,7 +1191,7 @@ inline hash_t hash_object(const T& object) noexcept { if constexpr (DEBUG_HASH_OBJECT_FUNCTION) { fmt::print("Hashing struct {} using reflect library: {}\n", get_type_name(), object); } - std::size_t hash = 0; + hash_t hash = 0; reflect::for_each([&hash, &object](auto I) { hash = hash_objects(hash, reflect::get(object)); }, object); return hash; } else { @@ -1335,7 +1421,7 @@ struct to_json_t> { nlohmann::json operator()(const std::map& object) { nlohmann::json json_object = nlohmann::json::object(); for (const auto& [key, value] : object) { - json_object[to_json(key)] = to_json(value); + json_object[to_json(key).dump()] = to_json(value); } return json_object; } @@ -1346,7 +1432,29 @@ struct from_json_t> { std::map operator()(const nlohmann::json& json_object) { std::map object; for (const auto& [key, value] : json_object.items()) { - object[from_json(key)] = from_json(value); + object[from_json(nlohmann::json::parse(key))] = from_json(value); + } + return object; + } +}; + +template +struct to_json_t> { + nlohmann::json operator()(const std::unordered_map& object) { + nlohmann::json json_object = nlohmann::json::object(); + for (const auto& [key, value] : object) { + json_object[to_json(key).dump()] = to_json(value); + } + return json_object; + } +}; + +template +struct from_json_t> { + std::map operator()(const nlohmann::json& json_object) { + std::unordered_map object; + for (const auto& [key, value] : json_object.items()) { + object[from_json(nlohmann::json::parse(key))] = from_json(value); } return object; }