Skip to content

Commit

Permalink
#0: Add reflection for std::map and std::unordered_map
Browse files Browse the repository at this point in the history
  • Loading branch information
tt-aho committed Dec 9, 2024
1 parent 4393945 commit 7210be5
Showing 1 changed file with 114 additions and 6 deletions.
120 changes: 114 additions & 6 deletions tt_metal/tt_stl/reflection.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -448,6 +448,32 @@ std::ostream& operator<<(std::ostream& os, const std::set<T>& set) {
return os;
}

template <typename K, typename V>
std::ostream& operator<<(std::ostream& os, const std::map<K, V>& map) {
os << "{";
for (auto it = map.begin(); it != map.end(); ++it) {
os << it->first << ": " << it->second;
if (it != map.end()) {
os << ", ";
}
}
os << "}";
return os;
}

template <typename K, typename V>
std::ostream& operator<<(std::ostream& os, const std::unordered_map<K, V>& map) {
os << "{";
for (auto it = map.begin(); it != map.end(); ++it) {
os << it->first << ": " << it->second;
if (it != map.end()) {
os << ", ";
}
}
os << "}";
return os;
}

template <typename T>
requires(tt::stl::concepts::Reflectable<T> and not(std::integral<T> or std::is_array<T>::value))
std::ostream& operator<<(std::ostream& os, const T& object) {
Expand Down Expand Up @@ -978,6 +1004,30 @@ struct fmt::formatter<std::set<T>> {
}
};

template <typename K, typename V>
struct fmt::formatter<std::map<K, V>> {
constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); }

auto format(const std::map<K, V>& map, format_context& ctx) const -> format_context::iterator {
using tt::stl::reflection::operator<<;
std::stringstream ss;
ss << map;
return fmt::format_to(ctx.out(), "{}", ss.str());
}
};

template <typename K, typename V>
struct fmt::formatter<std::unordered_map<K, V>> {
constexpr auto parse(format_parse_context& ctx) -> format_parse_context::iterator { return ctx.end(); }

auto format(const std::unordered_map<K, V>& map, format_context& ctx) const -> format_context::iterator {
using tt::stl::reflection::operator<<;
std::stringstream ss;
ss << map;
return fmt::format_to(ctx.out(), "{}", ss.str());
}
};

template <typename T>
requires(
tt::stl::concepts::Reflectable<T> and not(std::integral<T> or std::is_array<T>::value or
Expand Down Expand Up @@ -1063,7 +1113,7 @@ inline hash_t hash_object(const T& object) noexcept {
fmt::print("Hashing struct {} using compile-time attributes: {}\n", get_type_name<T>(), object);
}
constexpr auto num_attributes = reflection::detail::get_num_attributes<T>();
std::size_t hash = 0;
hash_t hash = 0;
const auto attribute_values = object.attribute_values();
[&object, &hash, &attribute_values]<size_t... Ns>(std::index_sequence<Ns...>) {
(
Expand All @@ -1074,11 +1124,26 @@ inline hash_t hash_object(const T& object) noexcept {
...);
}(std::make_index_sequence<num_attributes>{});
return hash;
} else if constexpr (is_specialization_v<T, std::tuple>) {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::tuple of type {}: {}\n", get_type_name<T>(), object);
}
constexpr auto num_elements = std::tuple_size_v<T>;
hash_t hash = 0;
[&object, &hash]<size_t... Ns>(std::index_sequence<Ns...>) {
(
[&object, &hash] {
const auto& element = std::get<Ns>(object);
hash = hash_objects(hash, element);
}(),
...);
}(std::make_index_sequence<num_elements>{});
return hash;
} else if constexpr (is_specialization_v<T, std::vector>) {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::vector of type {}: {}\n", get_type_name<T>(), object);
}
auto hash = 0;
hash_t hash = 0;
for (const auto& element : object) {
hash = hash_objects(hash, element);
}
Expand All @@ -1087,11 +1152,32 @@ inline hash_t hash_object(const T& object) noexcept {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::set of type {}: {}\n", get_type_name<T>(), object);
}
auto hash = 0;
hash_t hash = 0;
for (const auto& element : object) {
hash = hash_objects(hash, element);
}
return hash;
} else if constexpr (is_specialization_v<T, std::map>) {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::map of type {}: {}\n", get_type_name<T>(), object);
}
hash_t hash = 0;
for (const auto& [key, value] : object) {
hash = hash_objects(hash, key, value);
}
return hash;
} else if constexpr (is_specialization_v<T, std::unordered_map>) {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::unordered_map of type {}: {}\n", get_type_name<T>(), object);
}
constexpr hash_t seed = 0x9e3779b9;
hash_t hash = 0;
// Combine using xor so that the hash is order invariant
// Alternative could be to sort the keys before hashing
for (const auto& [key, value] : object) {
hash ^= hash_objects(seed, key, value);
}
return hash;
} else if constexpr (is_specialization_v<T, std::optional>) {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing std::optional of type {}: {}\n", get_type_name<T>(), object);
Expand All @@ -1105,7 +1191,7 @@ inline hash_t hash_object(const T& object) noexcept {
if constexpr (DEBUG_HASH_OBJECT_FUNCTION) {
fmt::print("Hashing struct {} using reflect library: {}\n", get_type_name<T>(), object);
}
std::size_t hash = 0;
hash_t hash = 0;
reflect::for_each([&hash, &object](auto I) { hash = hash_objects(hash, reflect::get<I>(object)); }, object);
return hash;
} else {
Expand Down Expand Up @@ -1335,7 +1421,7 @@ struct to_json_t<std::map<K, V>> {
nlohmann::json operator()(const std::map<K, V>& object) {
nlohmann::json json_object = nlohmann::json::object();
for (const auto& [key, value] : object) {
json_object[to_json(key)] = to_json(value);
json_object[to_json(key).dump()] = to_json(value);
}
return json_object;
}
Expand All @@ -1346,7 +1432,29 @@ struct from_json_t<std::map<K, V>> {
std::map<K, V> operator()(const nlohmann::json& json_object) {
std::map<K, V> object;
for (const auto& [key, value] : json_object.items()) {
object[from_json<K>(key)] = from_json<V>(value);
object[from_json<K>(nlohmann::json::parse(key))] = from_json<V>(value);
}
return object;
}
};

template <typename K, typename V>
struct to_json_t<std::unordered_map<K, V>> {
nlohmann::json operator()(const std::unordered_map<K, V>& object) {
nlohmann::json json_object = nlohmann::json::object();
for (const auto& [key, value] : object) {
json_object[to_json(key).dump()] = to_json(value);
}
return json_object;
}
};

template <typename K, typename V>
struct from_json_t<std::unordered_map<K, V>> {
std::map<K, V> operator()(const nlohmann::json& json_object) {
std::unordered_map<K, V> object;
for (const auto& [key, value] : json_object.items()) {
object[from_json<K>(nlohmann::json::parse(key))] = from_json<V>(value);
}
return object;
}
Expand Down

0 comments on commit 7210be5

Please sign in to comment.