Allow the LLVM cache machinery to cache multiple modules at a time.
bluescarni committed Aug 1, 2024
1 parent b2f71ef commit 2f29bdb
Showing 3 changed files with 81 additions and 36 deletions.
9 changes: 6 additions & 3 deletions include/heyoka/llvm_state.hpp
@@ -12,6 +12,7 @@
 #include <heyoka/config.hpp>
 
 #include <concepts>
+#include <cstddef>
 #include <cstdint>
 #include <memory>
 #include <ostream>
@@ -338,12 +339,14 @@ namespace detail
 
 // The value contained in the in-memory cache.
 struct llvm_mc_value {
-    std::string opt_bc, opt_ir, obj;
+    std::vector<std::string> opt_bc, opt_ir, obj;
+
+    std::size_t total_size() const;
 };
 
 // Cache lookup and insertion.
-std::optional<llvm_mc_value> llvm_state_mem_cache_lookup(const std::string &, unsigned);
-void llvm_state_mem_cache_try_insert(std::string, unsigned, llvm_mc_value);
+std::optional<llvm_mc_value> llvm_state_mem_cache_lookup(const std::vector<std::string> &, unsigned);
+void llvm_state_mem_cache_try_insert(std::vector<std::string>, unsigned, llvm_mc_value);
 
 } // namespace detail
 
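To make the new layout concrete, here is a minimal standalone sketch — not code from the commit; mc_value_sketch is a hypothetical stand-in for detail::llvm_mc_value — showing that the three vectors are parallel arrays indexed by module, which is the invariant the new assertions enforce:

    #include <cassert>
    #include <string>
    #include <vector>

    // Hypothetical stand-in mirroring detail::llvm_mc_value: three parallel
    // vectors, one entry per module, holding optimised bitcode, textual IR
    // and the compiled object file.
    struct mc_value_sketch {
        std::vector<std::string> opt_bc, opt_ir, obj;
    };

    int main()
    {
        mc_value_sketch v;

        // Two modules compiled together form a single cache entry.
        v.opt_bc = {"<bc module 0>", "<bc module 1>"};
        v.opt_ir = {"<ir module 0>", "<ir module 1>"};
        v.obj = {"<obj module 0>", "<obj module 1>"};

        // The parallel-vector invariant checked by the commit's assertions.
        assert(v.opt_bc.size() == v.opt_ir.size());
        assert(v.opt_bc.size() == v.obj.size());
    }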
26 changes: 17 additions & 9 deletions src/llvm_state.cpp
@@ -1325,11 +1325,15 @@ void llvm_state::compile()
     // to fix the module and re-attempt compilation without having
     // altered the module and without having already added the trigger
     // function.
+    // NOTE: this function does its own cleanup, no need to
+    // start the try catch block yet.
     add_obj_trigger();
 
     try {
         // Fetch the bitcode *before* optimisation.
         auto orig_bc = get_bc();
+        std::vector<std::string> obc;
+        obc.push_back(std::move(orig_bc));
 
         // Combine m_opt_level, m_force_avx512, m_slp_vectorize and m_c_model into a single value,
         // as they all affect codegen.
@@ -1341,22 +1345,26 @@
         assert(m_opt_level <= 3u);
         assert(static_cast<unsigned>(m_c_model) <= 7u);
         static_assert(std::numeric_limits<unsigned>::digits >= 7u);
-        const auto olevel = m_opt_level + (static_cast<unsigned>(m_force_avx512) << 2)
-                            + (static_cast<unsigned>(m_slp_vectorize) << 3) + (static_cast<unsigned>(m_c_model) << 4);
+        const auto comp_flag = m_opt_level + (static_cast<unsigned>(m_force_avx512) << 2)
+                               + (static_cast<unsigned>(m_slp_vectorize) << 3)
+                               + (static_cast<unsigned>(m_c_model) << 4);
 
-        if (auto cached_data = detail::llvm_state_mem_cache_lookup(orig_bc, olevel)) {
+        if (auto cached_data = detail::llvm_state_mem_cache_lookup(obc, comp_flag)) {
             // Cache hit.
 
             // Assign the snapshots.
-            m_ir_snapshot = std::move(cached_data->opt_ir);
-            m_bc_snapshot = std::move(cached_data->opt_bc);
+            assert(cached_data->opt_ir.size() == 1u);
+            assert(cached_data->opt_bc.size() == 1u);
+            assert(cached_data->obj.size() == 1u);
+            m_ir_snapshot = std::move(cached_data->opt_ir[0]);
+            m_bc_snapshot = std::move(cached_data->opt_bc[0]);
 
             // Clear out module and builder.
             m_module.reset();
             m_builder.reset();
 
             // Assign the object file.
-            detail::llvm_state_add_obj_to_jit(*m_jitter, std::move(cached_data->obj));
+            detail::llvm_state_add_obj_to_jit(*m_jitter, std::move(cached_data->obj[0]));
         } else {
             sw.reset();
 
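For reference, the packing above is easy to verify by hand: the optimisation level sits in bits 0-1, force_avx512 in bit 2, slp_vectorize in bit 3 and the code model in bits 4-6. A self-contained sketch of the same encoding (pack_comp_flag is a hypothetical name, not a function from the codebase):

    #include <cassert>

    // Pack the compilation settings the same way the commit does:
    // opt_level in bits 0-1, force_avx512 in bit 2, slp_vectorize in
    // bit 3, code model in bits 4-6.
    unsigned pack_comp_flag(unsigned opt_level, bool force_avx512, bool slp_vectorize, unsigned c_model)
    {
        assert(opt_level <= 3u);
        assert(c_model <= 7u);

        return opt_level + (static_cast<unsigned>(force_avx512) << 2)
               + (static_cast<unsigned>(slp_vectorize) << 3) + (c_model << 4);
    }

    int main()
    {
        // opt_level=2, avx512 off, slp on, c_model=1:
        // 2 + 0 + 8 + 16 = 26.
        assert(pack_comp_flag(2u, false, true, 1u) == 26u);
    }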
@@ -1372,10 +1380,10 @@
 
             logger->trace("materialisation runtime: {}", sw);
 
-            // Try to insert orig_bc into the cache.
-            detail::llvm_state_mem_cache_try_insert(std::move(orig_bc), olevel,
+            // Try to insert obc into the cache.
+            detail::llvm_state_mem_cache_try_insert(std::move(obc), comp_flag,
                                                     // NOLINTNEXTLINE(bugprone-unchecked-optional-access)
-                                                    {m_bc_snapshot, m_ir_snapshot, *m_jitter->m_object_file});
+                                                    {{m_bc_snapshot}, {m_ir_snapshot}, {*m_jitter->m_object_file}});
         }
         // LCOV_EXCL_START
     } catch (...) {
82 changes: 58 additions & 24 deletions src/llvm_state_mem_cache.cpp
@@ -17,7 +17,7 @@
 #include <optional>
 #include <string>
 #include <utility>
-#include <variant>
+#include <vector>
 
 #include <boost/container_hash/hash.hpp>
 #include <boost/numeric/conversion/cast.hpp>
@@ -28,7 +28,8 @@
 #include <heyoka/llvm_state.hpp>
 
 // This in-memory cache maps the bitcode
-// of an LLVM module and an optimisation level to:
+// of one or more LLVM modules and an integer flag
+// (representing several compilation settings) to:
 //
 // - the optimised version of the bitcode,
 // - the textual IR corresponding
@@ -43,6 +44,26 @@ HEYOKA_BEGIN_NAMESPACE
 namespace detail
 {
 
+// Helper to compute the total size in bytes
+// of the data contained in an llvm_mc_value.
+// Will throw on overflow.
+std::size_t llvm_mc_value::total_size() const
+{
+    assert(!opt_bc.empty());
+    assert(opt_bc.size() == opt_ir.size());
+    assert(opt_bc.size() == obj.size());
+
+    boost::safe_numerics::safe<std::size_t> ret = 0;
+
+    for (decltype(opt_bc.size()) i = 0; i < opt_bc.size(); ++i) {
+        ret += opt_bc[i].size();
+        ret += opt_ir[i].size();
+        ret += obj[i].size();
+    }
+
+    return ret;
+}
+
 namespace
 {
 
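The boost::safe_numerics::safe<std::size_t> accumulator is what makes total_size() throw on overflow rather than silently wrap around. A minimal standalone illustration of that behaviour (assuming Boost.SafeNumerics is available; the exact exception type is a detail of the library's default policy, so the sketch just catches std::exception):

    #include <cstddef>
    #include <exception>
    #include <iostream>
    #include <limits>

    #include <boost/safe_numerics/safe_integer.hpp>

    int main()
    {
        boost::safe_numerics::safe<std::size_t> acc = std::numeric_limits<std::size_t>::max();

        try {
            acc += 1u; // Would wrap around with a plain std::size_t.
        } catch (const std::exception &e) {
            std::cout << "overflow detected: " << e.what() << '\n';
        }
    }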
@@ -56,16 +77,33 @@ HEYOKA_CONSTINIT
 std::mutex mem_cache_mutex;
 
 // Definition of the data structures for the cache.
-using lru_queue_t = std::list<std::pair<std::string, unsigned>>;
+using lru_queue_t = std::list<std::pair<std::vector<std::string>, unsigned>>;
 
 using lru_key_t = lru_queue_t::iterator;
 
+// Implementation of hashing for std::pair<std::vector<std::string>, unsigned> and
+// its heterogeneous counterpart.
+template <typename T>
+auto cache_key_hasher(const T &k) noexcept
+{
+    assert(!k.first.empty());
+
+    // Combine the bitcodes.
+    auto seed = std::hash<std::string>{}(k.first[0]);
+    for (decltype(k.first.size()) i = 1; i < k.first.size(); ++i) {
+        boost::hash_combine(seed, k.first[i]);
+    }
+
+    // Combine with the compilation flag.
+    boost::hash_combine(seed, static_cast<std::size_t>(k.second));
+
+    return seed;
+}
+
 struct lru_hasher {
     std::size_t operator()(const lru_key_t &k) const noexcept
     {
-        auto seed = std::hash<std::string>{}(k->first);
-        boost::hash_combine(seed, k->second);
-        return seed;
+        return cache_key_hasher(*k);
     }
 };
 
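Routing both hashers through the single cache_key_hasher template guarantees that the owning key (std::pair<std::vector<std::string>, unsigned>) and the reference-based heterogeneous key hash identically; otherwise heterogeneous lookups would silently miss entries that are present. A standalone check along those lines (key_hasher mirrors the scheme above but is a hypothetical stand-in, not the library function):

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <string>
    #include <utility>
    #include <vector>

    #include <boost/container_hash/hash.hpp>

    // Same scheme as the commit: hash the first bitcode, then fold in the
    // remaining bitcodes and the compilation flag.
    template <typename T>
    auto key_hasher(const T &k)
    {
        assert(!k.first.empty());

        auto seed = std::hash<std::string>{}(k.first[0]);
        for (decltype(k.first.size()) i = 1; i < k.first.size(); ++i) {
            boost::hash_combine(seed, k.first[i]);
        }
        boost::hash_combine(seed, static_cast<std::size_t>(k.second));

        return seed;
    }

    int main()
    {
        const std::vector<std::string> bc = {"module0", "module1"};

        // Owning key vs. heterogeneous (reference-based) key.
        const std::pair<std::vector<std::string>, unsigned> owning{bc, 42u};
        const std::pair<const std::vector<std::string> &, unsigned> compat{bc, 42u};

        // Both forms of the key must produce the same hash.
        assert(key_hasher(owning) == key_hasher(compat));
    }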
@@ -96,16 +134,16 @@ HEYOKA_CONSTINIT std::uint64_t mem_cache_limit = 2147483648ull;
 
 // Machinery for heterogeneous lookup into the cache.
 // NOTE: this function MUST be invoked while holding the global lock.
-auto llvm_state_mem_cache_hl(const std::string &bc, unsigned opt_level)
+auto llvm_state_mem_cache_hl(const std::vector<std::string> &bc, unsigned comp_flag)
 {
-    using compat_key_t = std::pair<const std::string &, unsigned>;
+    // NOTE: the heterogeneous version of the key replaces std::vector<std::string>
+    // with a const reference.
+    using compat_key_t = std::pair<const std::vector<std::string> &, unsigned>;
 
     struct compat_hasher {
         std::size_t operator()(const compat_key_t &k) const noexcept
         {
-            auto seed = std::hash<std::string>{}(k.first);
-            boost::hash_combine(seed, k.second);
-            return seed;
+            return cache_key_hasher(k);
         }
     };
 
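The three-argument lru_map.find(...) in the next hunk performs a compatible-key lookup: the stored key owns its std::vector<std::string>, while the probe key only references the caller's vector, so no potentially large bitcode strings are copied just to query the map. The same idea expressed with standard C++20 transparent hashing on std::unordered_map — an analogy only, not the container the cache actually uses:

    #include <cassert>
    #include <cstddef>
    #include <functional>
    #include <string>
    #include <string_view>
    #include <unordered_map>

    // Transparent hasher/comparator: lookups can be done with a
    // std::string_view without materialising a std::string key.
    struct sv_hash {
        using is_transparent = void;
        std::size_t operator()(std::string_view sv) const noexcept
        {
            return std::hash<std::string_view>{}(sv);
        }
    };

    struct sv_eq {
        using is_transparent = void;
        bool operator()(std::string_view a, std::string_view b) const noexcept { return a == b; }
    };

    int main()
    {
        std::unordered_map<std::string, int, sv_hash, sv_eq> m;
        m.emplace("some long bitcode blob", 1);

        // Heterogeneous lookup: no temporary std::string is constructed.
        assert(m.find(std::string_view{"some long bitcode blob"}) != m.end());
    }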
@@ -120,7 +158,7 @@ auto llvm_state_mem_cache_hl(const std::string &bc, unsigned opt_level)
         }
     };
 
-    return lru_map.find(std::make_pair(std::cref(bc), opt_level), compat_hasher{}, compat_cmp{});
+    return lru_map.find(std::make_pair(std::cref(bc), comp_flag), compat_hasher{}, compat_cmp{});
 }
 
 // Debug function to run sanity checks on the cache.
@@ -131,23 +169,21 @@ void llvm_state_mem_cache_sanity_checks()
 
     // Check that the computed size of the cache is consistent with mem_cache_size.
     assert(std::accumulate(lru_map.begin(), lru_map.end(), boost::safe_numerics::safe<std::size_t>(0),
-                           [](const auto &a, const auto &p) {
-                               return a + p.second.opt_bc.size() + p.second.opt_ir.size() + p.second.obj.size();
-                           })
+                           [](const auto &a, const auto &p) { return a + p.second.total_size(); })
            == mem_cache_size);
 }
 
 } // namespace
 
-std::optional<llvm_mc_value> llvm_state_mem_cache_lookup(const std::string &bc, unsigned opt_level)
+std::optional<llvm_mc_value> llvm_state_mem_cache_lookup(const std::vector<std::string> &bc, unsigned comp_flag)
 {
     // Lock down.
     const std::lock_guard lock(mem_cache_mutex);
 
     // Sanity checks.
     llvm_state_mem_cache_sanity_checks();
 
-    if (const auto it = llvm_state_mem_cache_hl(bc, opt_level); it == lru_map.end()) {
+    if (const auto it = llvm_state_mem_cache_hl(bc, comp_flag); it == lru_map.end()) {
         // Cache miss.
         return {};
     } else {
@@ -163,7 +199,7 @@ std::optional<llvm_mc_value> llvm_state_mem_cache_lookup(const std::string &bc,
     }
 }
 
-void llvm_state_mem_cache_try_insert(std::string bc, unsigned opt_level, llvm_mc_value val)
+void llvm_state_mem_cache_try_insert(std::vector<std::string> bc, unsigned comp_flag, llvm_mc_value val)
 {
     // Lock down.
     const std::lock_guard lock(mem_cache_mutex);
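Both entry points follow the same locking discipline: take the single global mutex, run the debug sanity checks, then touch the shared data structures. Distilled into a standalone sketch (hypothetical names, with a plain map standing in for the real LRU structures):

    #include <mutex>
    #include <optional>
    #include <string>
    #include <unordered_map>

    std::mutex cache_mutex;                      // protects everything below
    std::unordered_map<std::string, int> cache;  // stand-in for the real LRU structures

    std::optional<int> lookup(const std::string &k)
    {
        // All reads and writes of the cache happen under the same global lock.
        const std::lock_guard lock(cache_mutex);

        if (const auto it = cache.find(k); it == cache.end()) {
            return {};
        } else {
            return it->second;
        }
    }

    int main()
    {
        {
            const std::lock_guard lock(cache_mutex);
            cache.emplace("key", 42);
        }
        return lookup("key") ? 0 : 1;
    }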
@@ -174,7 +210,7 @@ void llvm_state_mem_cache_try_insert(std::string bc, unsigned opt_level, llvm_mc
     // Do a first lookup to check if bc is already in the cache.
     // This could happen, e.g., if two threads are compiling the same
     // code concurrently.
-    if (const auto it = llvm_state_mem_cache_hl(bc, opt_level); it != lru_map.end()) {
+    if (const auto it = llvm_state_mem_cache_hl(bc, comp_flag); it != lru_map.end()) {
         assert(val.opt_bc == it->second.opt_bc);
         assert(val.opt_ir == it->second.opt_ir);
         assert(val.obj == it->second.obj);
assert(val.obj == it->second.obj);
Expand All @@ -183,8 +219,7 @@ void llvm_state_mem_cache_try_insert(std::string bc, unsigned opt_level, llvm_mc
}

// Compute the new cache size.
auto new_cache_size = static_cast<std::size_t>(boost::safe_numerics::safe<std::size_t>(mem_cache_size)
+ val.opt_bc.size() + val.opt_ir.size() + val.obj.size());
auto new_cache_size = boost::safe_numerics::safe<std::size_t>(mem_cache_size) + val.total_size();

// Remove items from the cache if we are exceeding
// the limit.
@@ -195,8 +230,7 @@ void llvm_state_mem_cache_try_insert(std::string bc, unsigned opt_level, llvm_mc
         const auto &cur_val = cur_it->second;
         // NOTE: no possibility of overflow here, as cur_size is guaranteed
         // not to be greater than mem_cache_size.
-        const auto cur_size
-            = static_cast<std::size_t>(cur_val.opt_bc.size()) + cur_val.opt_ir.size() + cur_val.obj.size();
+        const auto cur_size = cur_val.total_size();
 
         // NOTE: the next 4 lines cannot throw, which ensures that the
         // cache cannot be left in an inconsistent state.
@@ -222,7 +256,7 @@ void llvm_state_mem_cache_try_insert(std::string bc, unsigned opt_level, llvm_mc
     // Add the new item to the front of the queue.
    // NOTE: if this throws, we have not modified lru_map yet,
    // no cleanup needed.
-    lru_queue.emplace_front(std::move(bc), opt_level);
+    lru_queue.emplace_front(std::move(bc), comp_flag);
 
     // Add the new item to the map.
     try {
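Taken together, lru_queue and lru_map implement a classic size-bounded LRU cache: hits move an entry to the front of the queue, insertions go to the front, and eviction pops from the back until the configured byte budget is respected. A compact standalone sketch of that discipline (simplified string keys and byte sizes, not the actual heyoka types; the real code also handles oversized items and exception safety more carefully):

    #include <cassert>
    #include <cstddef>
    #include <list>
    #include <string>
    #include <unordered_map>

    // Minimal LRU with a byte budget: most recently used at the front.
    class lru_sketch
    {
        std::list<std::string> queue;                                          // keys, MRU first
        std::unordered_map<std::string, std::list<std::string>::iterator> map; // key -> queue slot
        std::size_t size = 0, limit;

    public:
        explicit lru_sketch(std::size_t limit) : limit(limit) {}

        bool lookup(const std::string &k)
        {
            const auto it = map.find(k);
            if (it == map.end()) {
                return false;
            }
            // Cache hit: move the entry to the front of the queue.
            queue.splice(queue.begin(), queue, it->second);
            return true;
        }

        void insert(std::string k)
        {
            const auto sz = k.size();
            // Evict from the back until the new item fits.
            while (size + sz > limit && !queue.empty()) {
                size -= queue.back().size();
                map.erase(queue.back());
                queue.pop_back();
            }
            queue.push_front(std::move(k));
            map.emplace(queue.front(), queue.begin());
            size += sz;
        }
    };

    int main()
    {
        lru_sketch cache(10);
        cache.insert("aaaa");         // 4 bytes
        cache.insert("bbbb");         // 4 bytes
        assert(cache.lookup("aaaa")); // "aaaa" is now most recent
        cache.insert("cccc");         // 4 bytes: evicts "bbbb", the least recent
        assert(!cache.lookup("bbbb"));
        assert(cache.lookup("aaaa"));
    }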
