Skip to content

Commit

Permalink
some simplificatoin
Browse files Browse the repository at this point in the history
  • Loading branch information
awni committed Dec 12, 2024
1 parent d6c12bd commit 990aa66
Showing 1 changed file with 12 additions and 24 deletions.
36 changes: 12 additions & 24 deletions mlx/export.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -55,16 +55,16 @@ constexpr bool is_iterable<
decltype(std::declval<T>().end())>> = true;

template <template <typename...> class T, typename U>
struct is_specialization_of : std::false_type {};
constexpr bool is_specialization_of = false;

template <template <typename...> class T, typename... Us>
struct is_specialization_of<T, T<Us...>> : std::true_type {};
constexpr bool is_specialization_of<T, T<Us...>> = true;

template <typename T>
constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>{};
constexpr bool is_pair = is_specialization_of<std::pair, std::decay_t<T>>;

template <typename T>
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>{};
constexpr bool is_tuple = is_specialization_of<std::tuple, std::decay_t<T>>;

template <typename>
constexpr bool dependent_false = false;
Expand Down Expand Up @@ -168,25 +168,6 @@ array deserialize(Reader& is) {
return array(std::move(shape), type, nullptr, std::vector<array>{});
}

template <typename T, typename U>
std::shared_ptr<T> construct_primitive(Stream s, U&& arg) {
return std::make_shared<T>(s, std::forward<U>(arg));
}
template <typename T, typename... Args>
std::shared_ptr<T> construct_primitive(Stream s, std::tuple<Args...> args) {
auto fn = [s](auto&&... args) {
return std::make_shared<T>(s, std::forward<decltype(args)>(args)...);
};
return std::apply(fn, args);
}
template <typename T, typename... Args>
std::shared_ptr<T> construct_primitive(Stream s, std::pair<Args...> args) {
auto fn = [s](auto&&... args) {
return std::make_shared<T>(s, std::forward<decltype(args)>(args)...);
};
return std::apply(fn, args);
}

template <typename, typename = void>
constexpr bool has_state = false;

Expand All @@ -205,7 +186,14 @@ template <typename T>
std::shared_ptr<T> deserialize_primitive(Reader& is, Stream s) {
if constexpr (has_state<T>) {
auto args = deserialize<decltype(std::declval<T>().state())>(is);
return construct_primitive<T>(s, args);
if constexpr (is_pair<decltype(args)> || is_tuple<decltype(args)>) {
auto fn = [s](auto&&... args) {
return std::make_shared<T>(s, std::move(args)...);
};
return std::apply(fn, std::move(args));
} else {
return std::make_shared<T>(s, std::move(args));
}
} else {
return std::make_shared<T>(s);
}
Expand Down

0 comments on commit 990aa66

Please sign in to comment.