diff --git a/include/pisa/codec/block_codec_registry.hpp b/include/pisa/codec/block_codec_registry.hpp new file mode 100644 index 00000000..9fe4af0f --- /dev/null +++ b/include/pisa/codec/block_codec_registry.hpp @@ -0,0 +1,12 @@ +#pragma once + +#include +#include + +#include "codec/block_codec.hpp" + +namespace pisa { + +[[nodiscard]] auto get_block_codec(std::string_view name) -> std::unique_ptr; + +} // namespace pisa diff --git a/include/pisa/codec/interpolative.hpp b/include/pisa/codec/interpolative.hpp index b4f6ff10..c59002df 100644 --- a/include/pisa/codec/interpolative.hpp +++ b/include/pisa/codec/interpolative.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include "codec/block_codec.hpp" @@ -10,6 +11,8 @@ class InterpolativeBlockCodec: public BlockCodec { static constexpr std::uint64_t m_block_size = 128; public: + constexpr static std::string_view name = "block_interpolative"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/maskedvbyte.hpp b/include/pisa/codec/maskedvbyte.hpp index 13adf3a8..86f0686a 100644 --- a/include/pisa/codec/maskedvbyte.hpp +++ b/include/pisa/codec/maskedvbyte.hpp @@ -37,6 +37,8 @@ class MaskedVByteBlockCodec: public BlockCodec { static constexpr std::uint64_t m_overflow = 512; public: + constexpr static std::string_view name = "block_maskedvbyte"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/optpfor.hpp b/include/pisa/codec/optpfor.hpp index 7d209539..f534ed5d 100644 --- a/include/pisa/codec/optpfor.hpp +++ b/include/pisa/codec/optpfor.hpp @@ -46,6 +46,8 @@ class OptPForBlockCodec: public BlockCodec { static const uint64_t m_block_size = Codec::BlockSize; public: + constexpr static std::string_view name = "block_optpfor"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/qmx.hpp b/include/pisa/codec/qmx.hpp index 393a7ff9..f6c9d412 100644 --- a/include/pisa/codec/qmx.hpp +++ b/include/pisa/codec/qmx.hpp @@ -54,6 +54,8 @@ class QmxBlockCodec: public BlockCodec { static constexpr std::uint64_t m_overflow = 512; public: + constexpr static std::string_view name = "block_qmx"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/simdbp.hpp b/include/pisa/codec/simdbp.hpp index 1c35b854..468f2bbf 100644 --- a/include/pisa/codec/simdbp.hpp +++ b/include/pisa/codec/simdbp.hpp @@ -41,6 +41,8 @@ class SimdBpBlockCodec: public BlockCodec { static constexpr std::uint64_t m_block_size = 128; public: + constexpr static std::string_view name = "block_simdbp"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/simple16.hpp b/include/pisa/codec/simple16.hpp index 3d73537a..f145a0f2 100644 --- a/include/pisa/codec/simple16.hpp +++ b/include/pisa/codec/simple16.hpp @@ -40,6 +40,8 @@ class Simple16BlockCodec: public BlockCodec { static constexpr std::uint64_t m_block_size = 128; public: + constexpr static std::string_view name = "block_simple16"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/simple8b.hpp b/include/pisa/codec/simple8b.hpp index 859a1da4..17dc1deb 100644 --- a/include/pisa/codec/simple8b.hpp +++ b/include/pisa/codec/simple8b.hpp @@ -35,6 +35,8 @@ class Simple8bBlockCodec: public BlockCodec { static constexpr std::uint64_t m_block_size = 128; public: + constexpr static std::string_view name = "block_simple8b"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/streamvbyte.hpp b/include/pisa/codec/streamvbyte.hpp index 7fd1b0b5..e9c4817b 100644 --- a/include/pisa/codec/streamvbyte.hpp +++ b/include/pisa/codec/streamvbyte.hpp @@ -3,6 +3,7 @@ #include #include #include +#include #include #include "codec/block_codec.hpp" @@ -43,6 +44,8 @@ class StreamVByteBlockCodec: public BlockCodec { pisa::streamvbyte_max_compressedbytes(m_block_size); public: + constexpr static std::string_view name = "block_streamvbyte"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/varint_g8iu.hpp b/include/pisa/codec/varint_g8iu.hpp index acc62732..f2e48b24 100644 --- a/include/pisa/codec/varint_g8iu.hpp +++ b/include/pisa/codec/varint_g8iu.hpp @@ -1,5 +1,6 @@ #pragma once +#include #include #include "codec/block_codec.hpp" @@ -10,6 +11,8 @@ class VarintG8IUBlockCodec: public BlockCodec { static const uint64_t m_block_size = 128; public: + constexpr static std::string_view name = "block_varintg8iu"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/include/pisa/codec/varintgb.hpp b/include/pisa/codec/varintgb.hpp index 84b370ad..b501821b 100644 --- a/include/pisa/codec/varintgb.hpp +++ b/include/pisa/codec/varintgb.hpp @@ -261,6 +261,8 @@ class VarintGbBlockCodec: public BlockCodec { static constexpr std::uint64_t m_block_size = 128; public: + constexpr static std::string_view name = "block_varintgb"; + void encode(uint32_t const* in, uint32_t sum_of_values, size_t n, std::vector& out) const; uint8_t const* decode(uint8_t const* in, uint32_t* out, uint32_t sum_of_values, size_t n) const; auto block_size() const noexcept -> std::size_t { return m_block_size; } diff --git a/src/codec/block_codec_registry.cpp b/src/codec/block_codec_registry.cpp new file mode 100644 index 00000000..40a7ae0c --- /dev/null +++ b/src/codec/block_codec_registry.cpp @@ -0,0 +1,64 @@ +#include "codec/block_codec_registry.hpp" + +#include +#include +#include +#include + +#include + +#include "codec/block_codec.hpp" +#include "codec/interpolative.hpp" +#include "codec/maskedvbyte.hpp" +#include "codec/optpfor.hpp" +#include "codec/qmx.hpp" +#include "codec/simdbp.hpp" +#include "codec/simple16.hpp" +#include "codec/simple8b.hpp" +#include "codec/streamvbyte.hpp" +#include "codec/varint_g8iu.hpp" +#include "codec/varintgb.hpp" + +namespace pisa { + +template +class BlockCodecRegistry { + using BlockCodecConstructor = std::unique_ptr (*)(); + + constexpr static std::array m_names = + std::array{C::name...}; + + constexpr static std::array m_constructors = + std::array{[]() -> std::unique_ptr { + return std::make_unique(); + }...}; + + public: + constexpr static auto count() -> std::size_t { return sizeof...(C); } + static auto get(std::string_view name) -> std::unique_ptr { + auto pos = std::find(m_names.begin(), m_names.end(), name); + if (pos == m_names.end()) { + throw std::domain_error(fmt::format("invalid codec: {}", name)); + } + auto constructor = m_constructors[std::distance(m_names.begin(), pos)]; + return constructor(); + } +}; + +using BlockCodecs = BlockCodecRegistry< + InterpolativeBlockCodec, + MaskedVByteBlockCodec, + OptPForBlockCodec, + QmxBlockCodec, + SimdBpBlockCodec, + Simple16BlockCodec, + Simple8bBlockCodec, + StreamVByteBlockCodec, + VarintG8IUBlockCodec, + VarintGbBlockCodec>; + +auto get_block_codec(std::string_view name) -> std::unique_ptr { + return BlockCodecs::get(name); +} + +} // namespace pisa diff --git a/tools/queries_dynamic.cpp b/tools/queries_dynamic.cpp index 9f335033..eaae51b7 100644 --- a/tools/queries_dynamic.cpp +++ b/tools/queries_dynamic.cpp @@ -1,4 +1,5 @@ #include +#include #include #include #include @@ -16,17 +17,7 @@ #include "accumulator/simple_accumulator.hpp" #include "app.hpp" #include "block_inverted_index.hpp" -#include "codec/block_codec.hpp" -#include "codec/interpolative.hpp" -#include "codec/maskedvbyte.hpp" -#include "codec/optpfor.hpp" -#include "codec/qmx.hpp" -#include "codec/simdbp.hpp" -#include "codec/simple16.hpp" -#include "codec/simple8b.hpp" -#include "codec/streamvbyte.hpp" -#include "codec/varint_g8iu.hpp" -#include "codec/varintgb.hpp" +#include "codec/block_codec_registry.hpp" #include "cursor/block_max_scored_cursor.hpp" #include "cursor/cursor.hpp" #include "cursor/max_scored_cursor.hpp" @@ -319,40 +310,6 @@ using wand_raw_index = wand_data; using wand_uniform_index = wand_data>; using wand_uniform_index_quantized = wand_data>; -auto resolve_codec(std::string_view encoding) -> std::unique_ptr { - if (encoding == "block_interpolative") { - return std::make_unique(); - } - if (encoding == "block_maskedvbyte") { - return std::make_unique(); - } - if (encoding == "block_optpfor") { - return std::make_unique(); - } - if (encoding == "block_qmx") { - return std::make_unique(); - } - if (encoding == "block_simdbp") { - return std::make_unique(); - } - if (encoding == "block_simple16") { - return std::make_unique(); - } - if (encoding == "block_simple8b") { - return std::make_unique(); - } - if (encoding == "block_streamvbyte") { - return std::make_unique(); - } - if (encoding == "block_varintg8iu") { - return std::make_unique(); - } - if (encoding == "block_varintgb") { - return std::make_unique(); - } - throw std::domain_error("invalid encoding type"); -} - int main(int argc, const char** argv) { bool extract = false; bool safe = false; @@ -379,7 +336,7 @@ int main(int argc, const char** argv) { } BlockInvertedIndex index( - MemorySource::mapped_file(app.index_filename()), resolve_codec(app.index_encoding()) + MemorySource::mapped_file(app.index_filename()), get_block_codec(app.index_encoding()) ); auto params = std::make_tuple(