From 12d5fb62edf71b43a8d6ed2e80345fe019b8e9f7 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 26 Apr 2020 19:45:59 +0000 Subject: [PATCH 01/21] Query contaier --- include/pisa/query.hpp | 99 +++++++++++++++++++++++++ src/query.cpp | 161 +++++++++++++++++++++++++++++++++++++++++ test/test_query.cpp | 84 +++++++++++++++++++++ 3 files changed, 344 insertions(+) create mode 100644 include/pisa/query.hpp create mode 100644 src/query.cpp create mode 100644 test/test_query.cpp diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp new file mode 100644 index 000000000..7f481ee08 --- /dev/null +++ b/include/pisa/query.hpp @@ -0,0 +1,99 @@ +#pragma once + +#include +#include +#include +#include +#include + +#include + +namespace pisa { + +struct QueryContainerInner; + +struct ParsedTerm { + std::uint32_t id; + std::string term; +}; + +using TermProcessorFn = std::function(std::string)>; +using ParseFn = std::function(std::string const&)>; + +class QueryContainer; + +/// Query is a special container that maintains important invariants, such as sorted term IDs, +/// and also has some additional data, like term weights, etc. +class Query { + public: + explicit Query(QueryContainer const& data); + + [[nodiscard]] auto term_ids() const -> gsl::span; + [[nodiscard]] auto threshold() const -> std::optional; + + private: + std::optional m_threshold{}; + std::vector m_term_ids{}; +}; + +class QueryContainer { + public: + QueryContainer(QueryContainer const&); + QueryContainer(QueryContainer&&) noexcept; + QueryContainer& operator=(QueryContainer const&); + QueryContainer& operator=(QueryContainer&&) noexcept; + ~QueryContainer(); + + /// Constructs a query from a raw string. + [[nodiscard]] static auto raw(std::string query_string) -> QueryContainer; + + /// Constructs a query from a list of terms. + /// + /// \param terms List of terms + /// \param term_processor Function executed for each term before stroring them, + /// e.g., stemming or filtering. This function returns + /// `std::optional`, and all `std::nullopt` values + /// will be filtered out. + [[nodiscard]] static auto + from_terms(std::vector terms, std::optional term_processor) + -> QueryContainer; + + /// Constructs a query from a list of term IDs. + [[nodiscard]] static auto from_term_ids(std::vector term_ids) -> QueryContainer; + + // Accessors + + [[nodiscard]] auto string() const noexcept -> std::optional const&; + [[nodiscard]] auto terms() const noexcept -> std::optional> const&; + [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; + [[nodiscard]] auto threshold() const noexcept -> std::optional const&; + + /// Sets the raw string. + [[nodiscard]] auto string(std::string) -> QueryContainer&; + + /// Sets processed terms. + /// + /// NOTE: If the intent is to parse the query, use `parse` method instead. + /// This method is intended to be used when loading a query from JSON or another + /// external representation. + /// + /// \throws std::domain_error when term IDs are set but the lengths don't match + auto processed_terms(std::vector terms) -> QueryContainer&; + + /// Parses the raw query with the given parser. + /// + /// \throws std::domain_error when raw string is not set + auto parse(ParseFn parse_fn) -> QueryContainer&; + + /// Sets the query score threshold. + auto threshold(float score) -> QueryContainer&; + + /// Returns a query ready to be used for retrieval. + [[nodiscard]] auto query() const -> Query; + + private: + QueryContainer(); + std::unique_ptr m_data; +}; + +} // namespace pisa diff --git a/src/query.cpp b/src/query.cpp new file mode 100644 index 000000000..2838d9dcb --- /dev/null +++ b/src/query.cpp @@ -0,0 +1,161 @@ +#include "query.hpp" + +#include + +#include + +namespace pisa { + +Query::Query(QueryContainer const& data) : m_threshold(data.threshold()) +{ + if (auto term_ids = data.term_ids(); term_ids) { + m_term_ids = *term_ids; + std::sort(m_term_ids.begin(), m_term_ids.end()); + auto last = std::unique(m_term_ids.begin(), m_term_ids.end()); + m_term_ids.erase(last, m_term_ids.end()); + } + throw std::domain_error("Query not parsed."); +} + +auto Query::term_ids() const -> gsl::span +{ + return gsl::span(m_term_ids); +} + +auto Query::threshold() const -> std::optional +{ + return m_threshold; +} + +struct QueryContainerInner { + std::optional query_string; + std::optional> processed_terms; + std::optional> term_ids; + std::optional threshold; +}; + +QueryContainer::QueryContainer() : m_data(std::make_unique()) {} + +QueryContainer::QueryContainer(QueryContainer const& other) + : m_data(std::make_unique(*other.m_data)) +{} +QueryContainer::QueryContainer(QueryContainer&&) noexcept = default; +QueryContainer& QueryContainer::operator=(QueryContainer const& other) +{ + this->m_data = std::make_unique(*other.m_data); + return *this; +} +QueryContainer& QueryContainer::operator=(QueryContainer&&) noexcept = default; +QueryContainer::~QueryContainer() = default; + +auto QueryContainer::raw(std::string query_string) -> QueryContainer +{ + QueryContainer query; + query.m_data->query_string = std::move(query_string); + return query; +} + +auto QueryContainer::from_terms( + std::vector terms, std::optional term_processor) -> QueryContainer +{ + QueryContainer query; + query.m_data->processed_terms = std::vector{}; + auto& processed_terms = *query.m_data->processed_terms; + for (auto&& term: terms) { + if (term_processor) { + auto fn = *term_processor; + if (auto processed = fn(std::move(term)); processed) { + processed_terms.push_back(std::move(*processed)); + } + } else { + processed_terms.push_back(std::move(term)); + } + } + return query; +} + +auto QueryContainer::from_term_ids(std::vector term_ids) -> QueryContainer +{ + QueryContainer query; + query.m_data->term_ids = std::move(term_ids); + return query; +} + +auto QueryContainer::string() const noexcept -> std::optional const& +{ + return m_data->query_string; +} +auto QueryContainer::terms() const noexcept -> std::optional> const& +{ + return m_data->processed_terms; +} + +auto QueryContainer::term_ids() const noexcept -> std::optional> const& +{ + return m_data->term_ids; +} + +auto QueryContainer::threshold() const noexcept -> std::optional const& +{ + return m_data->threshold; +} + +auto QueryContainer::string(std::string raw_query) -> QueryContainer& +{ + m_data->query_string = std::move(raw_query); + return *this; +} + +auto QueryContainer::processed_terms(std::vector terms) -> QueryContainer& +{ + if (auto&& term_ids = m_data->term_ids; term_ids.has_value() && term_ids->size() != terms.size()) { + throw std::domain_error(fmt::format( + "Number of terms ({}) must match number of term IDs ({})", + fmt::join(terms, ", "), + fmt::join(*term_ids, ", "))); + } + m_data->processed_terms = std::move(terms); + return *this; +} + +// auto QueryContainer::term_ids(std::vector term_ids) -> Query& +//{ +// if (auto&& terms = m_data->processed_terms; +// terms.has_value() && terms->size() != term_ids.size()) { +// throw std::domain_error(fmt::format( +// "Number of terms ({}) must match number of term IDs ({})", +// fmt::join(*terms, ", "), +// fmt::join(term_ids, ", "))); +// } +// m_data->term_ids = std::move(term_ids); +// return *this; +//} + +auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& +{ + if (not m_data->query_string) { + throw std::domain_error("Cannot parse, query string not set"); + } + auto parsed_terms = parse_fn(*m_data->query_string); + std::vector processed_terms; + std::vector term_ids; + for (auto&& term: parsed_terms) { + processed_terms.push_back(std::move(term.term)); + term_ids.push_back(term.id); + } + m_data->term_ids = std::move(term_ids); + return *this; +} + +auto QueryContainer::threshold(float score) -> QueryContainer& +{ + m_data->threshold = score; + return *this; +} + +auto QueryContainer::query() const -> Query +{ + return Query(*this); +} + +} // namespace pisa diff --git a/test/test_query.cpp b/test/test_query.cpp new file mode 100644 index 000000000..e65ff1283 --- /dev/null +++ b/test/test_query.cpp @@ -0,0 +1,84 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "query.hpp" + +using pisa::QueryContainer; + +TEST_CASE("Construct from raw string") +{ + auto raw_query = "brooklyn tea house"; + auto query = QueryContainer::raw(raw_query); + REQUIRE(*query.string() == raw_query); +} + +TEST_CASE("Construct from terms") +{ + std::vector terms{"brooklyn", "tea", "house"}; + auto query = QueryContainer::from_terms(terms, std::nullopt); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); +} + +TEST_CASE("Construct from terms with processor") +{ + std::vector terms{"brooklyn", "tea", "house"}; + auto proc = [](std::string term) -> std::optional { + if (term.size() > 3) { + return term.substr(0, 4); + } + return std::nullopt; + }; + auto query = QueryContainer::from_terms(terms, proc); + REQUIRE(*query.terms() == std::vector{"broo", "hous"}); +} + +TEST_CASE("Construct from term IDs") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); +} + +TEST_CASE("Set processed terms") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + query.processed_terms(std::vector{"brooklyn", "tea", "house"}); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); + REQUIRE_THROWS_AS( + query.processed_terms(std::vector{"tea", "house"}), std::domain_error); +} + +TEST_CASE("Parse query") +{ + auto raw_query = "brooklyn tea house brooklyn"; + auto query = QueryContainer::raw(raw_query); + std::vector lexicon{"house", "brooklyn"}; + auto term_proc = [](std::string term) -> std::optional { return term; }; + query.parse([&](auto&& q) { + std::istringstream is(q); + std::string term; + std::vector parsed_terms; + while (is >> term) { + if (auto t = term_proc(term); t) { + if (auto pos = std::find(lexicon.begin(), lexicon.end(), *t); pos != lexicon.end()) { + auto id = static_cast(std::distance(lexicon.begin(), pos)); + parsed_terms.push_back(pisa::ParsedTerm{id, *t}); + } + } + } + return parsed_terms; + }); + REQUIRE(*query.term_ids() == std::vector{1, 0, 1}); +} + +TEST_CASE("Parsing throws without raw query") +{ + std::vector term_ids{1, 0, 3}; + auto query = QueryContainer::from_term_ids(term_ids); + REQUIRE_THROWS_AS( + query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); +} From 1d3aa8851d77eeef4f4512082e3725878e9eba14 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 26 Apr 2020 23:12:42 +0000 Subject: [PATCH 02/21] Query container parsing --- CMakeLists.txt | 1 + include/pisa/query.hpp | 32 +++++---- src/query.cpp | 125 ++++++++++++++++++++++++++--------- test/test_query.cpp | 79 +++++++++++++++++++--- tools/CMakeLists.txt | 6 ++ tools/app.hpp | 15 +++++ tools/filter_queries.cpp | 138 +++++++++++++++++++++++++++++++++++++++ 7 files changed, 344 insertions(+), 52 deletions(-) create mode 100644 tools/filter_queries.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index 86e0f2460..769fee467 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -105,6 +105,7 @@ target_link_libraries(pisa PUBLIC # TODO(michal): are there any of these we can spdlog fmt::fmt range-v3 + nlohmann_json::nlohmann_json ) target_include_directories(pisa PUBLIC external) diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index 7f481ee08..1b39e3d51 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -24,9 +24,9 @@ class QueryContainer; /// Query is a special container that maintains important invariants, such as sorted term IDs, /// and also has some additional data, like term weights, etc. -class Query { +class QueryRequest { public: - explicit Query(QueryContainer const& data); + explicit QueryRequest(QueryContainer const& data); [[nodiscard]] auto term_ids() const -> gsl::span; [[nodiscard]] auto threshold() const -> std::optional; @@ -61,8 +61,25 @@ class QueryContainer { /// Constructs a query from a list of term IDs. [[nodiscard]] static auto from_term_ids(std::vector term_ids) -> QueryContainer; + /// Constructs a query from a JSON object. + [[nodiscard]] static auto from_json(std::string_view json_string) -> QueryContainer; + + [[nodiscard]] auto to_json() const -> std::string; + + /// Constructs a query from a colon-separated format: + /// + /// ``` + /// id:raw query string + /// ``` + /// or + /// ``` + /// raw query string + /// ``` + [[nodiscard]] static auto from_colon_format(std::string_view line) -> QueryContainer; + // Accessors + [[nodiscard]] auto id() const noexcept -> std::optional const&; [[nodiscard]] auto string() const noexcept -> std::optional const&; [[nodiscard]] auto terms() const noexcept -> std::optional> const&; [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; @@ -71,15 +88,6 @@ class QueryContainer { /// Sets the raw string. [[nodiscard]] auto string(std::string) -> QueryContainer&; - /// Sets processed terms. - /// - /// NOTE: If the intent is to parse the query, use `parse` method instead. - /// This method is intended to be used when loading a query from JSON or another - /// external representation. - /// - /// \throws std::domain_error when term IDs are set but the lengths don't match - auto processed_terms(std::vector terms) -> QueryContainer&; - /// Parses the raw query with the given parser. /// /// \throws std::domain_error when raw string is not set @@ -89,7 +97,7 @@ class QueryContainer { auto threshold(float score) -> QueryContainer&; /// Returns a query ready to be used for retrieval. - [[nodiscard]] auto query() const -> Query; + [[nodiscard]] auto query() const -> QueryRequest; private: QueryContainer(); diff --git a/src/query.cpp b/src/query.cpp index 2838d9dcb..39dedc846 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -3,10 +3,11 @@ #include #include +#include namespace pisa { -Query::Query(QueryContainer const& data) : m_threshold(data.threshold()) +QueryRequest::QueryRequest(QueryContainer const& data) : m_threshold(data.threshold()) { if (auto term_ids = data.term_ids(); term_ids) { m_term_ids = *term_ids; @@ -17,17 +18,18 @@ Query::Query(QueryContainer const& data) : m_threshold(data.threshold()) throw std::domain_error("Query not parsed."); } -auto Query::term_ids() const -> gsl::span +auto QueryRequest::term_ids() const -> gsl::span { return gsl::span(m_term_ids); } -auto Query::threshold() const -> std::optional +auto QueryRequest::threshold() const -> std::optional { return m_threshold; } struct QueryContainerInner { + std::optional id; std::optional query_string; std::optional> processed_terms; std::optional> term_ids; @@ -81,6 +83,10 @@ auto QueryContainer::from_term_ids(std::vector term_ids) -> Query return query; } +auto QueryContainer::id() const noexcept -> std::optional const& +{ + return m_data->id; +} auto QueryContainer::string() const noexcept -> std::optional const& { return m_data->query_string; @@ -106,31 +112,6 @@ auto QueryContainer::string(std::string raw_query) -> QueryContainer& return *this; } -auto QueryContainer::processed_terms(std::vector terms) -> QueryContainer& -{ - if (auto&& term_ids = m_data->term_ids; term_ids.has_value() && term_ids->size() != terms.size()) { - throw std::domain_error(fmt::format( - "Number of terms ({}) must match number of term IDs ({})", - fmt::join(terms, ", "), - fmt::join(*term_ids, ", "))); - } - m_data->processed_terms = std::move(terms); - return *this; -} - -// auto QueryContainer::term_ids(std::vector term_ids) -> Query& -//{ -// if (auto&& terms = m_data->processed_terms; -// terms.has_value() && terms->size() != term_ids.size()) { -// throw std::domain_error(fmt::format( -// "Number of terms ({}) must match number of term IDs ({})", -// fmt::join(*terms, ", "), -// fmt::join(term_ids, ", "))); -// } -// m_data->term_ids = std::move(term_ids); -// return *this; -//} - auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& { if (not m_data->query_string) { @@ -153,9 +134,93 @@ auto QueryContainer::threshold(float score) -> QueryContainer& return *this; } -auto QueryContainer::query() const -> Query +auto QueryContainer::query() const -> QueryRequest { - return Query(*this); + return QueryRequest(*this); +} + +template +[[nodiscard]] auto get(nlohmann::json const& node, std::string_view field) -> std::optional +{ + if (auto pos = node.find(field); pos != node.end()) { + try { + return std::make_optional(pos->get()); + } catch (nlohmann::detail::exception const& err) { + throw std::runtime_error(fmt::format("Requested field {} is of wrong type", field)); + } + } + return std::optional{}; +} + +auto QueryContainer::from_json(std::string_view json_string) -> QueryContainer +{ + try { + auto json = nlohmann::json::parse(json_string); + QueryContainer query; + QueryContainerInner& data = *query.m_data; + bool at_least_one_required = false; + if (auto id = get(json, "id"); id) { + data.id = std::move(id); + } + if (auto raw = get(json, "query"); raw) { + data.query_string = std::move(raw); + at_least_one_required = true; + } + if (auto terms = get>(json, "terms"); terms) { + data.processed_terms = std::move(terms); + at_least_one_required = true; + } + if (auto term_ids = get>(json, "term_ids"); term_ids) { + data.term_ids = std::move(term_ids); + at_least_one_required = true; + } + if (auto threshold = get(json, "threshold"); threshold) { + data.threshold = threshold; + } + if (not at_least_one_required) { + throw std::invalid_argument(fmt::format( + "JSON must have either raw query, terms, or term IDs: {}", json_string)); + } + return query; + } catch (nlohmann::detail::exception const& err) { + throw std::runtime_error( + fmt::format("Failed to parse JSON: `{}`: {}", json_string, err.what())); + } +} + +auto QueryContainer::to_json() const -> std::string +{ + nlohmann::json json; + if (auto id = m_data->id; id) { + json["id"] = *id; + } + if (auto raw = m_data->query_string; raw) { + json["query"] = *raw; + } + if (auto terms = m_data->processed_terms; terms) { + json["terms"] = *terms; + } + if (auto term_ids = m_data->term_ids; term_ids) { + json["term_ids"] = *term_ids; + } + if (auto threshold = m_data->threshold; threshold) { + json["threshold"] = *threshold; + } + return json.dump(); +} + +auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer +{ + auto pos = std::find(line.begin(), line.end(), ':'); + QueryContainer query; + QueryContainerInner& data = *query.m_data; + if (pos == line.end()) { + data.query_string = std::string(line); + } else { + data.id = std::string(line.begin(), pos); + data.query_string = std::string(std::next(pos), line.end()); + } + return query; } } // namespace pisa diff --git a/test/test_query.cpp b/test/test_query.cpp index e65ff1283..e3e6128e7 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -42,16 +42,6 @@ TEST_CASE("Construct from term IDs") REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); } -TEST_CASE("Set processed terms") -{ - std::vector term_ids{1, 0, 3}; - auto query = QueryContainer::from_term_ids(term_ids); - query.processed_terms(std::vector{"brooklyn", "tea", "house"}); - REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); - REQUIRE_THROWS_AS( - query.processed_terms(std::vector{"tea", "house"}), std::domain_error); -} - TEST_CASE("Parse query") { auto raw_query = "brooklyn tea house brooklyn"; @@ -82,3 +72,72 @@ TEST_CASE("Parsing throws without raw query") REQUIRE_THROWS_AS( query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); } + +TEST_CASE("Parse query container from colon-delimited format") +{ + auto query = QueryContainer::from_colon_format(""); + REQUIRE(query.string()->empty()); + REQUIRE_FALSE(query.id()); + + query = QueryContainer::from_colon_format("brooklyn tea house"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE_FALSE(query.id()); + + query = QueryContainer::from_colon_format("BTH:brooklyn tea house"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE(*query.id() == "BTH"); + + query = QueryContainer::from_colon_format("BTH:"); + REQUIRE(query.string()->empty()); + REQUIRE(*query.id() == "BTH"); +} + +TEST_CASE("Parse query container from JSON") +{ + REQUIRE_THROWS_AS(QueryContainer::from_json(""), std::runtime_error); + REQUIRE_THROWS_AS(QueryContainer::from_json(R"({"id":"ID"})"), std::invalid_argument); + + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house" + } + )"); + REQUIRE(*query.id() == "ID"); + REQUIRE(*query.string() == "brooklyn tea house"); + REQUIRE_FALSE(query.terms()); + REQUIRE_FALSE(query.term_ids()); + REQUIRE_FALSE(query.threshold()); + + query = QueryContainer::from_json(R"( + { + "term_ids": [1, 0, 3], + "terms": ["brooklyn", "tea", "house"], + "threshold": 10.8 + } + )"); + REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); + REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); + REQUIRE(*query.threshold() == Approx(10.8)); + REQUIRE_FALSE(query.id()); + REQUIRE_FALSE(query.string()); + + REQUIRE_THROWS_AS(QueryContainer::from_json(R"({"terms":[1, 2]})"), std::runtime_error); +} + +TEST_CASE("Serialize query container to JSON") +{ + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "terms": ["brooklyn", "tea", "house"], + "term_ids": [1, 0, 3], + "threshold": 10.0 + } + )"); + auto serialized = query.to_json(); + REQUIRE( + serialized + == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"threshold":10.0})"); +} diff --git a/tools/CMakeLists.txt b/tools/CMakeLists.txt index d011d56ea..0bc424e99 100644 --- a/tools/CMakeLists.txt +++ b/tools/CMakeLists.txt @@ -136,3 +136,9 @@ target_link_libraries(reorder-docids pisa CLI11 ) + +add_executable(filter-queries filter_queries.cpp) +target_link_libraries(filter-queries + pisa + CLI11 +) diff --git a/tools/app.hpp b/tools/app.hpp index 3e723e9d8..16a4ac314 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -101,6 +101,21 @@ namespace arg { return q; } + [[nodiscard]] auto term_lexicon() const -> std::optional const& + { + return m_term_lexicon; + } + + [[nodiscard]] auto stemmer() const -> std::optional const& + { + return m_stemmer; + } + + [[nodiscard]] auto stop_words() const -> std::optional const& + { + return m_stop_words; + } + [[nodiscard]] auto k() const -> int { return m_k; } private: diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp new file mode 100644 index 000000000..0b637d979 --- /dev/null +++ b/tools/filter_queries.cpp @@ -0,0 +1,138 @@ +#include + +#include +#include +#include + +#include "app.hpp" +#include "query.hpp" +#include "tokenizer.hpp" + +namespace arg = pisa::arg; +using pisa::QueryContainer; +using pisa::io::for_each_line; + +class TermProcessor { + private: + std::unordered_set stopwords; + + std::function(std::string const&)> m_to_id; + pisa::Stemmer_t m_stemmer; + + public: + TermProcessor( + std::optional const& terms_file, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) + { + auto source = std::make_shared(terms_file->c_str()); + auto terms = pisa::Payload_Vector<>::from(*source); + + m_to_id = [source = std::move(source), terms](auto str) -> std::optional { + // Note: the lexicographical order of the terms matters. + auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); + if (*pos == std::string_view(str)) { + return std::distance(terms.begin(), pos); + } + return std::nullopt; + }; + + m_stemmer = pisa::term_processor(stemmer_type); + + if (stopwords_filename) { + std::ifstream is(*stopwords_filename); + pisa::io::for_each_line(is, [&](auto&& word) { + if (auto processed_term = m_to_id(std::move(word)); processed_term.has_value()) { + stopwords.insert(*processed_term); + } + }); + } + } + + [[nodiscard]] std::optional operator()(std::string token) + { + token = m_stemmer(token); + auto id = m_to_id(token); + if (not id) { + return std::nullopt; + } + if (is_stopword(*id)) { + return std::nullopt; + } + return pisa::ParsedTerm{*id, token}; + } + + [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool + { + return stopwords.find(term) != stopwords.end(); + } +}; + +enum class Format { Json, Colon }; + +void filter_queries( + std::optional const& query_file, + std::optional const& term_lexicon, + std::optional const& stemmer, + std::size_t min_query_len, + std::size_t max_query_len) +{ + std::optional fmt{}; + auto parser = [term_processor = TermProcessor(term_lexicon, {}, stemmer)](auto query) mutable { + std::vector parsed_terms; + pisa::TermTokenizer tokenizer(query); + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = term_processor(*term_iter); + if (term) { + parsed_terms.push_back(std::move(*term)); + } + } + return parsed_terms; + }; + auto filter = [&](auto&& line) { + auto query = [&] { + if (fmt) { + if (*fmt == Format::Json) { + return QueryContainer::from_json(line); + } + return QueryContainer::from_colon_format(line); + } + try { + auto query = QueryContainer::from_json(line); + fmt = Format::Json; + return query; + } catch (std::exception const& err) { + fmt = Format::Colon; + return QueryContainer::from_colon_format(line); + } + }(); + query.parse(parser); + if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { + std::cout << query.to_json() << '\n'; + } + }; + if (query_file) { + std::ifstream is(*query_file); + for_each_line(is, filter); + } else { + for_each_line(std::cin, filter); + } +} + +int main(int argc, char** argv) +{ + spdlog::drop(""); + spdlog::set_default_logger(spdlog::stderr_color_mt("")); + + std::size_t min_query_len = 1; + std::size_t max_query_len = std::numeric_limits::max(); + + pisa::App> app( + "Filters out empty queries against a v1 index."); + app.add_option("--min", min_query_len, "Minimum query legth to consider"); + app.add_option("--max", max_query_len, "Maximum query legth to consider"); + CLI11_PARSE(app, argc, argv); + + filter_queries(app.query_file(), app.term_lexicon(), app.stemmer(), min_query_len, max_query_len); + return 0; +} From cdc17f3ae858dd72418bfb5163540a74e077b96a Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 14:08:05 +0000 Subject: [PATCH 03/21] CLI test --- .travis.yml | 10 +++++- include/pisa/query.hpp | 4 +-- src/query.cpp | 1 + test/cli/run.sh | 5 +++ test/cli/setup.sh | 25 ++++++++++++++ test/cli/test_filter_queries.sh | 58 +++++++++++++++++++++++++++++++++ tools/filter_queries.cpp | 14 ++++++-- 7 files changed, 111 insertions(+), 6 deletions(-) create mode 100755 test/cli/run.sh create mode 100755 test/cli/setup.sh create mode 100644 test/cli/test_filter_queries.sh diff --git a/.travis.yml b/.travis.yml index 58c62d718..596f5a76b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -56,7 +56,7 @@ matrix: apt: sources: *all_sources packages: ['g++-9'] - env: MATRIX_EVAL="CC=gcc-9 && CXX=g++-9 && COVERAGE=Off && DOCKER=Off" + env: MATRIX_EVAL="CC=gcc-9 && CXX=g++-9 && COVERAGE=Off && DOCKER=Off && TEST_CLI=On" - os: linux dist: xenial compiler: clang @@ -112,6 +112,11 @@ before_install: brew install ccache; export PATH="/usr/local/opt/ccache/libexec:$PATH"; fi + - if [[ "$TEST_CLI" == "On" ]]; then + git clone https://github.com/sstephenson/bats.git + cd bats + sudo ./install.sh /usr/local + fi - eval "${MATRIX_EVAL}" script: @@ -121,6 +126,9 @@ script: make -j2; if [[ "$TIDY" != "On" ]]; then CTEST_OUTPUT_ON_FAILURE=TRUE ctest -j2; + if [[ "$TEST_CLI" != "On" ]]; then + bash ../test/cli/run.sh + fi fi fi - if [[ "$CLANG_FORMAT" == "On" ]]; then diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index 1b39e3d51..6bb66be5c 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -22,8 +22,8 @@ using ParseFn = std::function(std::string const&)>; class QueryContainer; -/// Query is a special container that maintains important invariants, such as sorted term IDs, -/// and also has some additional data, like term weights, etc. +/// QueryRequest is a special container that maintains important invariants, such as sorted term +/// IDs, and also has some additional data, like term weights, etc. class QueryRequest { public: explicit QueryRequest(QueryContainer const& data); diff --git a/src/query.cpp b/src/query.cpp index 39dedc846..b2aacfca6 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -125,6 +125,7 @@ auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& term_ids.push_back(term.id); } m_data->term_ids = std::move(term_ids); + m_data->processed_terms = std::move(processed_terms); return *this; } diff --git a/test/cli/run.sh b/test/cli/run.sh new file mode 100755 index 000000000..51486a077 --- /dev/null +++ b/test/cli/run.sh @@ -0,0 +1,5 @@ +#!/usr/bin/env bash + +DIR=$(dirname "$0") +$DIR/setup.sh +bats $DIR/test_filter_queries.sh diff --git a/test/cli/setup.sh b/test/cli/setup.sh new file mode 100755 index 000000000..c55753f1a --- /dev/null +++ b/test/cli/setup.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash + +# This script should be executed within the build directory that is directly +# in the project directory, e.g., /path/to/pisa/build + +PISA_BIN="./bin" +export PATH="$PISA_BIN:$PATH" + +cat "../test/test_data/clueweb1k.plaintext" | parse_collection \ + --stemmer porter2 \ + --output "./fwd" \ + --format plaintext + +invert --input "./fwd" --output "./inv" + +compress_inverted_index --check \ + --encoding block_simdbp \ + --collection "./inv" \ + --output "./simdbp" + +create_wand_data \ + --scorer bm25 \ + --collection "./inv" \ + --output "./bm25.bmw" \ + --block-size 32 diff --git a/test/cli/test_filter_queries.sh b/test/cli/test_filter_queries.sh new file mode 100644 index 000000000..44e8f0f86 --- /dev/null +++ b/test/cli/test_filter_queries.sh @@ -0,0 +1,58 @@ +#!/usr/bin/env bats + +PISA_BIN="bin" +export PATH="$PISA_BIN:$PATH" + +function write_lines { + file=$1 + rm -f "$file" + shift + for line in "$@" + do + echo "$line" >> "$file" + done +} + + +function setup { + write_lines "$BATS_TMPDIR/queries.txt" "brooklyn tea house" "labradoodle" 'Tell your dog I said "hi"' + write_lines "$BATS_TMPDIR/stopwords.txt" "i" "your" +} + +@test "Filter from plain" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with minimum length" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --min 4 --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with maximum length" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --max 4 --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter with stopwords" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stopwords "$BATS_TMPDIR/stopwords.txt" --stemmer porter2 --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,10396,26032,15114],"terms":["tell","dog","said","hi"]}' + [[ "$result" = "$expected" ]] +} + +@test "Filter without stemmer" { + result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --terms ./fwd.termlex) + echo "$result" > "$BATS_TMPDIR/result" + expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]} +{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' + [[ "$result" = "$expected" ]] +} diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index 0b637d979..17748126c 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -74,11 +74,13 @@ void filter_queries( std::optional const& query_file, std::optional const& term_lexicon, std::optional const& stemmer, + std::optional const& stopwords_filename, std::size_t min_query_len, std::size_t max_query_len) { std::optional fmt{}; - auto parser = [term_processor = TermProcessor(term_lexicon, {}, stemmer)](auto query) mutable { + auto parser = [term_processor = TermProcessor(term_lexicon, stopwords_filename, stemmer)]( + auto query) mutable { std::vector parsed_terms; pisa::TermTokenizer tokenizer(query); for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { @@ -127,12 +129,18 @@ int main(int argc, char** argv) std::size_t min_query_len = 1; std::size_t max_query_len = std::numeric_limits::max(); - pisa::App> app( + pisa::App> app( "Filters out empty queries against a v1 index."); app.add_option("--min", min_query_len, "Minimum query legth to consider"); app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); - filter_queries(app.query_file(), app.term_lexicon(), app.stemmer(), min_query_len, max_query_len); + filter_queries( + app.query_file(), + app.term_lexicon(), + app.stemmer(), + app.stop_words(), + min_query_len, + max_query_len); return 0; } From b0e5d1a7a84f05292aa92fec4772ae06b9da18fa Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 15:57:42 +0000 Subject: [PATCH 04/21] Fix .travis.yml syntax --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 596f5a76b..685313f01 100644 --- a/.travis.yml +++ b/.travis.yml @@ -127,7 +127,7 @@ script: if [[ "$TIDY" != "On" ]]; then CTEST_OUTPUT_ON_FAILURE=TRUE ctest -j2; if [[ "$TEST_CLI" != "On" ]]; then - bash ../test/cli/run.sh + bash ../test/cli/run.sh; fi fi fi From 6e2ab62bf185afb0cce0ddf8b37d269f561c80d0 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 16:08:55 +0000 Subject: [PATCH 05/21] Fix .travis.yml syntax --- .travis.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.travis.yml b/.travis.yml index 685313f01..34604e9d5 100644 --- a/.travis.yml +++ b/.travis.yml @@ -113,9 +113,9 @@ before_install: export PATH="/usr/local/opt/ccache/libexec:$PATH"; fi - if [[ "$TEST_CLI" == "On" ]]; then - git clone https://github.com/sstephenson/bats.git - cd bats - sudo ./install.sh /usr/local + git clone https://github.com/sstephenson/bats.git; + cd bats; + sudo ./install.sh /usr/local; fi - eval "${MATRIX_EVAL}" From 2cce2cdaea040abc7e11b2e0982c0c135177f5a1 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Mon, 27 Apr 2020 16:42:21 +0000 Subject: [PATCH 06/21] Fix when cli test are executed --- .travis.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.travis.yml b/.travis.yml index 34604e9d5..241028838 100644 --- a/.travis.yml +++ b/.travis.yml @@ -126,7 +126,7 @@ script: make -j2; if [[ "$TIDY" != "On" ]]; then CTEST_OUTPUT_ON_FAILURE=TRUE ctest -j2; - if [[ "$TEST_CLI" != "On" ]]; then + if [[ "$TEST_CLI" == "On" ]]; then bash ../test/cli/run.sh; fi fi From 1838258a1d0c69c0edf2a43dccd9c81ea348077d Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Tue, 28 Apr 2020 23:47:18 +0000 Subject: [PATCH 07/21] Refactor out common code from tool --- include/pisa/query.hpp | 37 ++++++++- include/pisa/query/parser.hpp | 52 +++++++++++++ src/query.cpp | 39 ++++++++++ src/query/parser.cpp | 94 +++++++++++++++++++++++ test/cli/test_filter_queries.sh | 26 +++++-- test/test_query.cpp | 6 +- test/test_query_parser.cpp | 28 +++++++ tools/filter_queries.cpp | 130 ++++++++------------------------ 8 files changed, 302 insertions(+), 110 deletions(-) create mode 100644 include/pisa/query/parser.hpp create mode 100644 src/query/parser.cpp create mode 100644 test/test_query_parser.cpp diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index 6bb66be5c..f40055e06 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -1,6 +1,7 @@ #pragma once #include +#include #include #include #include @@ -12,13 +13,13 @@ namespace pisa { struct QueryContainerInner; -struct ParsedTerm { +struct ResolvedTerm { std::uint32_t id; std::string term; }; using TermProcessorFn = std::function(std::string)>; -using ParseFn = std::function(std::string const&)>; +using ParseFn = std::function(std::string const&)>; class QueryContainer; @@ -104,4 +105,36 @@ class QueryContainer { std::unique_ptr m_data; }; +enum class Format { Json, Colon }; + +class QueryReader { + public: + /// Open reader from file. + static auto from_file(std::string const& file) -> QueryReader; + /// Open reader from stdin. + static auto from_stdin() -> QueryReader; + + /// Read next query or return `nullopt` if stream has ended. + [[nodiscard]] auto next() -> std::optional; + + /// Execute `fn(q)` for each query `q`. + template + void for_each(Fn&& fn) + { + auto query = next(); + while (query) { + fn(std::move(*query)); + query = next(); + } + } + + private: + explicit QueryReader(std::unique_ptr stream, std::istream& stream_ref); + + std::unique_ptr m_stream; + std::istream& m_stream_ref; + std::string m_line_buf{}; + std::optional m_format{}; +}; + } // namespace pisa diff --git a/include/pisa/query/parser.hpp b/include/pisa/query/parser.hpp new file mode 100644 index 000000000..589c5c52b --- /dev/null +++ b/include/pisa/query/parser.hpp @@ -0,0 +1,52 @@ +#pragma once + +#include +#include +#include + +#include "query.hpp" + +namespace pisa { + +using TermResolver = std::function(std::string)>; + +struct StandardTermResolverParams; + +/// Provides a standard implementation of `TermResolver`. +class StandardTermResolver { + public: + StandardTermResolver( + std::string const& term_lexicon_path, + std::optional const& stopwords_filename, + std::optional const& stemmer_type); + StandardTermResolver(StandardTermResolver const&); + StandardTermResolver(StandardTermResolver&&) noexcept; + StandardTermResolver& operator=(StandardTermResolver const&); + StandardTermResolver& operator=(StandardTermResolver&&) noexcept; + ~StandardTermResolver(); + + [[nodiscard]] auto operator()(std::string token) const -> std::optional; + + private: + [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool; + + std::unique_ptr m_self; +}; + +/// Parses a query string to processed terms. +class QueryParser { + public: + explicit QueryParser(TermResolver term_processor); + /// Given a query string, it returns a list of (possibly processed) terms. + /// + /// Possible transformations of terms include lower-casing and stemming. + /// Some terms could be also removed, e.g., because they are on a list of + /// stop words. The exact implementation depends on the term processor + /// passed to the constructor. + auto operator()(std::string const&) -> std::vector; + + private: + TermResolver m_term_resolver; +}; + +} // namespace pisa diff --git a/src/query.cpp b/src/query.cpp index b2aacfca6..b202ce172 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -1,6 +1,8 @@ #include "query.hpp" #include +#include +#include #include #include @@ -224,4 +226,41 @@ auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer return query; } +auto QueryReader::from_file(std::string const& file) -> QueryReader +{ + auto input = std::make_unique(file); + auto& ref = *input; + return QueryReader(std::move(input), ref); +} + +auto QueryReader::from_stdin() -> QueryReader +{ + return QueryReader(nullptr, std::cin); +} + +QueryReader::QueryReader(std::unique_ptr input, std::istream& stream_ref) + : m_stream(std::move(input)), m_stream_ref(stream_ref) +{} + +auto QueryReader::next() -> std::optional +{ + if (std::getline(m_stream_ref, m_line_buf)) { + if (m_format) { + if (*m_format == Format::Json) { + return QueryContainer::from_json(m_line_buf); + } + return QueryContainer::from_colon_format(m_line_buf); + } + try { + auto query = QueryContainer::from_json(m_line_buf); + m_format = Format::Json; + return query; + } catch (std::exception const& err) { + m_format = Format::Colon; + return QueryContainer::from_colon_format(m_line_buf); + } + } + return std::nullopt; +} + } // namespace pisa diff --git a/src/query/parser.cpp b/src/query/parser.cpp new file mode 100644 index 000000000..178a425bf --- /dev/null +++ b/src/query/parser.cpp @@ -0,0 +1,94 @@ +#include + +#include "io.hpp" +#include "payload_vector.hpp" +#include "query.hpp" +#include "query/parser.hpp" +#include "query/term_processor.hpp" +#include "tokenizer.hpp" + +namespace pisa { + +StandardTermResolver::StandardTermResolver(StandardTermResolver const& other) + : m_self(std::make_unique(*other.m_self)) +{} +StandardTermResolver::StandardTermResolver(StandardTermResolver&&) noexcept = default; +StandardTermResolver& StandardTermResolver::operator=(StandardTermResolver const& other) +{ + m_self = std::make_unique(*other.m_self); + return *this; +} +StandardTermResolver& StandardTermResolver::operator=(StandardTermResolver&&) noexcept = default; +StandardTermResolver::~StandardTermResolver() = default; + +struct StandardTermResolverParams { + std::vector stopwords; + std::function(std::string const&)> to_id; + std::function transform; +}; + +StandardTermResolver::StandardTermResolver( + std::string const& term_lexicon_path, + std::optional const& stopwords_filename, + std::optional const& stemmer_type) + : m_self(std::make_unique()) +{ + auto source = std::make_shared(term_lexicon_path.c_str()); + auto terms = pisa::Payload_Vector<>::from(*source); + + m_self->to_id = [source = std::move(source), terms](auto str) -> std::optional { + auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); + if (*pos == std::string_view(str)) { + return std::distance(terms.begin(), pos); + } + return std::nullopt; + }; + + m_self->transform = pisa::term_processor(stemmer_type); + + if (stopwords_filename) { + std::ifstream is(*stopwords_filename); + pisa::io::for_each_line(is, [&](auto&& word) { + if (auto term_id = m_self->to_id(std::move(word)); term_id.has_value()) { + m_self->stopwords.push_back(*term_id); + } + }); + std::sort(m_self->stopwords.begin(), m_self->stopwords.end()); + } +} + +auto StandardTermResolver::operator()(std::string token) const -> std::optional +{ + token = m_self->transform(token); + auto id = m_self->to_id(token); + if (not id) { + return std::nullopt; + } + if (is_stopword(*id)) { + return std::nullopt; + } + return pisa::ResolvedTerm{*id, token}; +} + +auto StandardTermResolver::is_stopword(std::uint32_t const term) const -> bool +{ + auto pos = std::lower_bound(m_self->stopwords.begin(), m_self->stopwords.end(), term); + return pos != m_self->stopwords.end() && *pos == term; +} + +QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} + +auto QueryParser::operator()(std::string const& query) -> std::vector +{ + TermTokenizer tokenizer(query); + std::vector terms; + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = m_term_resolver(*term_iter); + if (term) { + terms.push_back(std::move(*term)); + } + } + return terms; +} + +} // namespace pisa diff --git a/test/cli/test_filter_queries.sh b/test/cli/test_filter_queries.sh index 44e8f0f86..2fa486074 100644 --- a/test/cli/test_filter_queries.sh +++ b/test/cli/test_filter_queries.sh @@ -21,7 +21,6 @@ function setup { @test "Filter from plain" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} {"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' [[ "$result" = "$expected" ]] @@ -29,21 +28,18 @@ function setup { @test "Filter with minimum length" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --min 4 --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' [[ "$result" = "$expected" ]] } @test "Filter with maximum length" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --max 4 --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}' [[ "$result" = "$expected" ]] } @test "Filter with stopwords" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --stopwords "$BATS_TMPDIR/stopwords.txt" --stemmer porter2 --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]} {"query":"Tell your dog I said \"hi\"","term_ids":[29287,10396,26032,15114],"terms":["tell","dog","said","hi"]}' [[ "$result" = "$expected" ]] @@ -51,8 +47,28 @@ function setup { @test "Filter without stemmer" { result=$(cat $BATS_TMPDIR/queries.txt | filter-queries --terms ./fwd.termlex) - echo "$result" > "$BATS_TMPDIR/result" expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]} {"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}' [[ "$result" = "$expected" ]] } + +@test "Accept JSON" { + echo '{"query":"brooklyn tea house"}' > "$BATS_TMPDIR/queries.json" + result=$(cat $BATS_TMPDIR/queries.json | filter-queries --terms ./fwd.termlex) + expected='{"query":"brooklyn tea house","term_ids":[6535,29194],"terms":["brooklyn","tea"]}' + [[ "$result" = "$expected" ]] +} + +@test "Accept JSON without --terms if already parsed" { + echo '{"term_ids":[6535,29194]}' > "$BATS_TMPDIR/queries.json" + result=$(cat $BATS_TMPDIR/queries.json | filter-queries) + expected='{"term_ids":[6535,29194]}' + [[ "$result" = "$expected" ]] +} + +@test "Fail when no --terms and not parsed" { + echo '{"query":"brooklyn tea house"}' > "$BATS_TMPDIR/queries.json" + run filter-queries < $BATS_TMPDIR/queries.json + [[ "$status" -eq 1 ]] + [[ "$output" = *"[error] Unresoved queries (without IDs) require term lexicon." ]] +} diff --git a/test/test_query.cpp b/test/test_query.cpp index e3e6128e7..00b3aea0e 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -51,12 +51,12 @@ TEST_CASE("Parse query") query.parse([&](auto&& q) { std::istringstream is(q); std::string term; - std::vector parsed_terms; + std::vector parsed_terms; while (is >> term) { if (auto t = term_proc(term); t) { if (auto pos = std::find(lexicon.begin(), lexicon.end(), *t); pos != lexicon.end()) { auto id = static_cast(std::distance(lexicon.begin(), pos)); - parsed_terms.push_back(pisa::ParsedTerm{id, *t}); + parsed_terms.push_back(pisa::ResolvedTerm{id, *t}); } } } @@ -70,7 +70,7 @@ TEST_CASE("Parsing throws without raw query") std::vector term_ids{1, 0, 3}; auto query = QueryContainer::from_term_ids(term_ids); REQUIRE_THROWS_AS( - query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); + query.parse([](auto&& str) { return std::vector{}; }), std::domain_error); } TEST_CASE("Parse query container from colon-delimited format") diff --git a/test/test_query_parser.cpp b/test/test_query_parser.cpp new file mode 100644 index 000000000..32ba2b32c --- /dev/null +++ b/test/test_query_parser.cpp @@ -0,0 +1,28 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "query/parser.hpp" + +using pisa::QueryContainer; +using pisa::QueryParser; + +TEST_CASE("Parse with lower-case processor and stop word") +{ + std::uint32_t init_id = 0; + auto term_proc = [id = init_id](auto&& term) mutable { + std::transform( + term.begin(), term.end(), term.begin(), [](unsigned char c) { return std::tolower(c); }); + if (term == "house") { + return std::optional{}; + } + return std::optional{pisa::ResolvedTerm{id++, term}}; + }; + QueryParser parser(term_proc); + auto terms = parser("Brooklyn tea house"); + REQUIRE(terms.size() == 2); + REQUIRE(terms[0].term == "brooklyn"); + REQUIRE(terms[1].term == "tea"); +} diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index 17748126c..e29a66b44 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -6,68 +6,17 @@ #include "app.hpp" #include "query.hpp" +#include "query/parser.hpp" #include "tokenizer.hpp" namespace arg = pisa::arg; + using pisa::QueryContainer; +using pisa::QueryParser; +using pisa::QueryReader; +using pisa::StandardTermResolver; using pisa::io::for_each_line; -class TermProcessor { - private: - std::unordered_set stopwords; - - std::function(std::string const&)> m_to_id; - pisa::Stemmer_t m_stemmer; - - public: - TermProcessor( - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type) - { - auto source = std::make_shared(terms_file->c_str()); - auto terms = pisa::Payload_Vector<>::from(*source); - - m_to_id = [source = std::move(source), terms](auto str) -> std::optional { - // Note: the lexicographical order of the terms matters. - auto pos = std::lower_bound(terms.begin(), terms.end(), std::string_view(str)); - if (*pos == std::string_view(str)) { - return std::distance(terms.begin(), pos); - } - return std::nullopt; - }; - - m_stemmer = pisa::term_processor(stemmer_type); - - if (stopwords_filename) { - std::ifstream is(*stopwords_filename); - pisa::io::for_each_line(is, [&](auto&& word) { - if (auto processed_term = m_to_id(std::move(word)); processed_term.has_value()) { - stopwords.insert(*processed_term); - } - }); - } - } - - [[nodiscard]] std::optional operator()(std::string token) - { - token = m_stemmer(token); - auto id = m_to_id(token); - if (not id) { - return std::nullopt; - } - if (is_stopword(*id)) { - return std::nullopt; - } - return pisa::ParsedTerm{*id, token}; - } - - [[nodiscard]] auto is_stopword(std::uint32_t const term) const -> bool - { - return stopwords.find(term) != stopwords.end(); - } -}; - enum class Format { Json, Colon }; void filter_queries( @@ -78,47 +27,23 @@ void filter_queries( std::size_t min_query_len, std::size_t max_query_len) { - std::optional fmt{}; - auto parser = [term_processor = TermProcessor(term_lexicon, stopwords_filename, stemmer)]( - auto query) mutable { - std::vector parsed_terms; - pisa::TermTokenizer tokenizer(query); - for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { - auto term = term_processor(*term_iter); - if (term) { - parsed_terms.push_back(std::move(*term)); - } + auto reader = [&] { + if (query_file) { + return QueryReader::from_file(*query_file); } - return parsed_terms; - }; - auto filter = [&](auto&& line) { - auto query = [&] { - if (fmt) { - if (*fmt == Format::Json) { - return QueryContainer::from_json(line); - } - return QueryContainer::from_colon_format(line); - } - try { - auto query = QueryContainer::from_json(line); - fmt = Format::Json; - return query; - } catch (std::exception const& err) { - fmt = Format::Colon; - return QueryContainer::from_colon_format(line); + return QueryReader::from_stdin(); + }(); + reader.for_each([&](auto query) { + if (not query.term_ids()) { + if (not term_lexicon) { + throw std::runtime_error("Unresoved queries (without IDs) require term lexicon."); } - }(); - query.parse(parser); + query.parse(QueryParser(StandardTermResolver(*term_lexicon, stopwords_filename, stemmer))); + } if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { std::cout << query.to_json() << '\n'; } - }; - if (query_file) { - std::ifstream is(*query_file); - for_each_line(is, filter); - } else { - for_each_line(std::cin, filter); - } + }); } int main(int argc, char** argv) @@ -135,12 +60,17 @@ int main(int argc, char** argv) app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); - filter_queries( - app.query_file(), - app.term_lexicon(), - app.stemmer(), - app.stop_words(), - min_query_len, - max_query_len); - return 0; + try { + filter_queries( + app.query_file(), + app.term_lexicon(), + app.stemmer(), + app.stop_words(), + min_query_len, + max_query_len); + return 0; + } catch (std::runtime_error const& err) { + spdlog::error(err.what()); + return 1; + } } From 7107f65efa82bc611ef7d0ed83703a15b22ee5ce Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 1 May 2020 14:53:28 +0000 Subject: [PATCH 08/21] Small refactoring and term resolver tests --- include/pisa/query.hpp | 2 + include/pisa/query/query_parser.hpp | 26 +++++++ .../query/{parser.hpp => term_resolver.hpp} | 31 +++++---- src/query.cpp | 12 ++++ src/query/query_parser.cpp | 27 ++++++++ src/query/{parser.cpp => term_resolver.cpp} | 42 ++++++----- test/test_query_parser.cpp | 2 +- test/test_term_resolver.cpp | 69 +++++++++++++++++++ tools/filter_queries.cpp | 48 ++++--------- 9 files changed, 189 insertions(+), 70 deletions(-) create mode 100644 include/pisa/query/query_parser.hpp rename include/pisa/query/{parser.hpp => term_resolver.hpp} (64%) create mode 100644 src/query/query_parser.cpp rename src/query/{parser.cpp => term_resolver.cpp} (76%) create mode 100644 test/test_term_resolver.cpp diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index f40055e06..eb6488c0c 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -45,6 +45,8 @@ class QueryContainer { QueryContainer& operator=(QueryContainer&&) noexcept; ~QueryContainer(); + [[nodiscard]] auto operator==(QueryContainer const& other) const noexcept -> bool; + /// Constructs a query from a raw string. [[nodiscard]] static auto raw(std::string query_string) -> QueryContainer; diff --git a/include/pisa/query/query_parser.hpp b/include/pisa/query/query_parser.hpp new file mode 100644 index 000000000..a3500962a --- /dev/null +++ b/include/pisa/query/query_parser.hpp @@ -0,0 +1,26 @@ +#pragma once + +#include + +#include "query.hpp" +#include "term_resolver.hpp" + +namespace pisa { + +/// Parses a query string to processed terms. +class QueryParser { + public: + explicit QueryParser(TermResolver term_processor); + /// Given a query string, it returns a list of (possibly processed) terms. + /// + /// Possible transformations of terms include lower-casing and stemming. + /// Some terms could be also removed, e.g., because they are on a list of + /// stop words. The exact implementation depends on the term processor + /// passed to the constructor. + auto operator()(std::string const&) -> std::vector; + + private: + TermResolver m_term_resolver; +}; + +} // namespace pisa diff --git a/include/pisa/query/parser.hpp b/include/pisa/query/term_resolver.hpp similarity index 64% rename from include/pisa/query/parser.hpp rename to include/pisa/query/term_resolver.hpp index 589c5c52b..fc9088a48 100644 --- a/include/pisa/query/parser.hpp +++ b/include/pisa/query/term_resolver.hpp @@ -8,6 +8,10 @@ namespace pisa { +/// Thrown if expected resolver but none found. +struct MissingResolverError { +}; + using TermResolver = std::function(std::string)>; struct StandardTermResolverParams; @@ -33,20 +37,17 @@ class StandardTermResolver { std::unique_ptr m_self; }; -/// Parses a query string to processed terms. -class QueryParser { - public: - explicit QueryParser(TermResolver term_processor); - /// Given a query string, it returns a list of (possibly processed) terms. - /// - /// Possible transformations of terms include lower-casing and stemming. - /// Some terms could be also removed, e.g., because they are on a list of - /// stop words. The exact implementation depends on the term processor - /// passed to the constructor. - auto operator()(std::string const&) -> std::vector; - - private: - TermResolver m_term_resolver; -}; +/// Reads queries from `query_file`, resolves them with `term_resolver`, filters by +/// query length (number of resolved terms in the query), and prints the selected +/// queries to `out`. +/// +/// \throws MissingResolverError When no resolver passed but queries don't have IDs resolved. +// +void filter_queries( + std::optional const& query_file, + std::optional term_resolver, + std::size_t min_query_len, + std::size_t max_query_len, + std::ostream& out); } // namespace pisa diff --git a/src/query.cpp b/src/query.cpp index b202ce172..1b54efe66 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -36,6 +36,13 @@ struct QueryContainerInner { std::optional> processed_terms; std::optional> term_ids; std::optional threshold; + + [[nodiscard]] auto operator==(QueryContainerInner const& other) const noexcept -> bool + { + return id == other.id && query_string == other.query_string + && processed_terms == other.processed_terms && term_ids == other.term_ids + && threshold == other.threshold; + } }; QueryContainer::QueryContainer() : m_data(std::make_unique()) {} @@ -52,6 +59,11 @@ QueryContainer& QueryContainer::operator=(QueryContainer const& other) QueryContainer& QueryContainer::operator=(QueryContainer&&) noexcept = default; QueryContainer::~QueryContainer() = default; +auto QueryContainer::operator==(QueryContainer const& other) const noexcept -> bool +{ + return *m_data == *other.m_data; +} + auto QueryContainer::raw(std::string query_string) -> QueryContainer { QueryContainer query; diff --git a/src/query/query_parser.cpp b/src/query/query_parser.cpp new file mode 100644 index 000000000..fc64a7891 --- /dev/null +++ b/src/query/query_parser.cpp @@ -0,0 +1,27 @@ +#include + +#include "io.hpp" +#include "payload_vector.hpp" +#include "query.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" +#include "tokenizer.hpp" + +namespace pisa { + +QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} + +auto QueryParser::operator()(std::string const& query) -> std::vector +{ + TermTokenizer tokenizer(query); + std::vector terms; + for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { + auto term = m_term_resolver(*term_iter); + if (term) { + terms.push_back(std::move(*term)); + } + } + return terms; +} + +} // namespace pisa diff --git a/src/query/parser.cpp b/src/query/term_resolver.cpp similarity index 76% rename from src/query/parser.cpp rename to src/query/term_resolver.cpp index 178a425bf..7d00f59ad 100644 --- a/src/query/parser.cpp +++ b/src/query/term_resolver.cpp @@ -1,11 +1,6 @@ -#include - -#include "io.hpp" -#include "payload_vector.hpp" -#include "query.hpp" -#include "query/parser.hpp" +#include "query/term_resolver.hpp" +#include "query/query_parser.hpp" #include "query/term_processor.hpp" -#include "tokenizer.hpp" namespace pisa { @@ -76,19 +71,30 @@ auto StandardTermResolver::is_stopword(std::uint32_t const term) const -> bool return pos != m_self->stopwords.end() && *pos == term; } -QueryParser::QueryParser(TermResolver term_resolver) : m_term_resolver(std::move(term_resolver)) {} - -auto QueryParser::operator()(std::string const& query) -> std::vector +void filter_queries( + std::optional const& query_file, + std::optional term_resolver, + std::size_t min_query_len, + std::size_t max_query_len, + std::ostream& out) { - TermTokenizer tokenizer(query); - std::vector terms; - for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { - auto term = m_term_resolver(*term_iter); - if (term) { - terms.push_back(std::move(*term)); + auto reader = [&] { + if (query_file) { + return QueryReader::from_file(*query_file); } - } - return terms; + return QueryReader::from_stdin(); + }(); + reader.for_each([&](auto query) { + if (not query.term_ids()) { + if (not term_resolver) { + throw MissingResolverError{}; + } + query.parse(QueryParser(*term_resolver)); + } + if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { + out << query.to_json() << '\n'; + } + }); } } // namespace pisa diff --git a/test/test_query_parser.cpp b/test/test_query_parser.cpp index 32ba2b32c..95f028124 100644 --- a/test/test_query_parser.cpp +++ b/test/test_query_parser.cpp @@ -4,7 +4,7 @@ #include -#include "query/parser.hpp" +#include "query/query_parser.hpp" using pisa::QueryContainer; using pisa::QueryParser; diff --git a/test/test_term_resolver.cpp b/test/test_term_resolver.cpp new file mode 100644 index 000000000..913060b85 --- /dev/null +++ b/test/test_term_resolver.cpp @@ -0,0 +1,69 @@ +#define CATCH_CONFIG_MAIN + +#include + +#include + +#include "io.hpp" +#include "query/term_resolver.hpp" +#include "temporary_directory.hpp" + +using pisa::QueryContainer; +using pisa::StandardTermResolver; + +TEST_CASE("Filter queries") +{ + std::uint32_t id = 0; + auto term_resolver = [&id](auto&& term) mutable { + return std::optional{pisa::ResolvedTerm{id++, term}}; + }; + Temporary_Directory tmp; + auto input = (tmp.path() / "input.txt"); + { + std::ofstream os(input.c_str()); + os << "a b c d\n"; + os << "e\n"; + os << "f g h i j\n"; + os << "k l m\n"; + os << "n o\n"; + } + + SECTION("Between 2 and 4") + { + std::ostringstream os; + pisa::filter_queries( + std::make_optional(input.string()), std::make_optional(term_resolver), 2, 4, os); + std::vector queries; + std::istringstream is(os.str()); + pisa::io::for_each_line( + is, [&queries](auto&& line) { queries.push_back(QueryContainer::from_json(line)); }); + REQUIRE(queries.size() == 3); + REQUIRE(*queries[0].terms() == std::vector{"a", "b", "c", "d"}); + REQUIRE(*queries[0].term_ids() == std::vector{0, 1, 2, 3}); + REQUIRE(*queries[1].terms() == std::vector{"k", "l", "m"}); + REQUIRE(*queries[1].term_ids() == std::vector{10, 11, 12}); + REQUIRE(*queries[2].terms() == std::vector{"n", "o"}); + REQUIRE(*queries[2].term_ids() == std::vector{13, 14}); + + SECTION("Don't fail if no resolver but IDs already resolved") + { + auto json_input = (tmp.path() / "input.json"); + { + std::ofstream json_out(json_input.c_str()); + for (auto&& query: queries) { + json_out << query.to_json() << '\n'; + } + } + std::ostringstream output; + pisa::filter_queries(std::make_optional(json_input.string()), std::nullopt, 2, 4, output); + REQUIRE(output.str() == os.str()); + } + } + + SECTION("Fail without IDs and resolver") + { + REQUIRE_THROWS_AS( + pisa::filter_queries(std::make_optional(input.string()), std::nullopt, 2, 4, std::cerr), + pisa::MissingResolverError); + } +} diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index e29a66b44..1b1507c81 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -6,46 +6,20 @@ #include "app.hpp" #include "query.hpp" -#include "query/parser.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" #include "tokenizer.hpp" namespace arg = pisa::arg; +using pisa::filter_queries; using pisa::QueryContainer; using pisa::QueryParser; using pisa::QueryReader; using pisa::StandardTermResolver; +using pisa::TermResolver; using pisa::io::for_each_line; -enum class Format { Json, Colon }; - -void filter_queries( - std::optional const& query_file, - std::optional const& term_lexicon, - std::optional const& stemmer, - std::optional const& stopwords_filename, - std::size_t min_query_len, - std::size_t max_query_len) -{ - auto reader = [&] { - if (query_file) { - return QueryReader::from_file(*query_file); - } - return QueryReader::from_stdin(); - }(); - reader.for_each([&](auto query) { - if (not query.term_ids()) { - if (not term_lexicon) { - throw std::runtime_error("Unresoved queries (without IDs) require term lexicon."); - } - query.parse(QueryParser(StandardTermResolver(*term_lexicon, stopwords_filename, stemmer))); - } - if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { - std::cout << query.to_json() << '\n'; - } - }); -} - int main(int argc, char** argv) { spdlog::drop(""); @@ -60,15 +34,17 @@ int main(int argc, char** argv) app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); + std::optional term_resolver{}; + if (app.term_lexicon()) { + term_resolver = StandardTermResolver(*app.term_lexicon(), app.stop_words(), app.stemmer()); + } + try { filter_queries( - app.query_file(), - app.term_lexicon(), - app.stemmer(), - app.stop_words(), - min_query_len, - max_query_len); + app.query_file(), std::move(term_resolver), min_query_len, max_query_len, std::cout); return 0; + } catch (pisa::MissingResolverError err) { + spdlog::error("Unresoved queries(without IDs) require term lexicon."); } catch (std::runtime_error const& err) { spdlog::error(err.what()); return 1; From ede9c984a745fae61f2120be4dfa23892a4bdad6 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 1 May 2020 14:55:13 +0000 Subject: [PATCH 09/21] Fix tool description --- tools/filter_queries.cpp | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index 1b1507c81..dbc89469e 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -28,8 +28,7 @@ int main(int argc, char** argv) std::size_t min_query_len = 1; std::size_t max_query_len = std::numeric_limits::max(); - pisa::App> app( - "Filters out empty queries against a v1 index."); + pisa::App> app("Filters queries by their length"); app.add_option("--min", min_query_len, "Minimum query legth to consider"); app.add_option("--max", max_query_len, "Maximum query legth to consider"); CLI11_PARSE(app, argc, argv); From b8f625cd5fd08f53297a151f6d391e275ab97ebd Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 3 May 2020 01:42:17 +0000 Subject: [PATCH 10/21] Multiple thresholds per query --- include/pisa/query.hpp | 17 +++++++--- src/query.cpp | 73 ++++++++++++++++++++++++++++++++++-------- test/test_query.cpp | 10 +++--- 3 files changed, 76 insertions(+), 24 deletions(-) diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index eb6488c0c..dbec6ba51 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -27,12 +27,14 @@ class QueryContainer; /// IDs, and also has some additional data, like term weights, etc. class QueryRequest { public: - explicit QueryRequest(QueryContainer const& data); + explicit QueryRequest(QueryContainer const& data, std::size_t k); [[nodiscard]] auto term_ids() const -> gsl::span; [[nodiscard]] auto threshold() const -> std::optional; + [[nodiscard]] auto k() const -> std::optional; private: + std::size_t m_k; std::optional m_threshold{}; std::vector m_term_ids{}; }; @@ -86,7 +88,9 @@ class QueryContainer { [[nodiscard]] auto string() const noexcept -> std::optional const&; [[nodiscard]] auto terms() const noexcept -> std::optional> const&; [[nodiscard]] auto term_ids() const noexcept -> std::optional> const&; - [[nodiscard]] auto threshold() const noexcept -> std::optional const&; + [[nodiscard]] auto threshold(std::size_t k) const noexcept -> std::optional; + [[nodiscard]] auto thresholds() const noexcept + -> std::vector> const&; /// Sets the raw string. [[nodiscard]] auto string(std::string) -> QueryContainer&; @@ -96,11 +100,14 @@ class QueryContainer { /// \throws std::domain_error when raw string is not set auto parse(ParseFn parse_fn) -> QueryContainer&; - /// Sets the query score threshold. - auto threshold(float score) -> QueryContainer&; + /// Sets the query score threshold for `k`. + /// + /// If another threshold for the same `k` exists, it will be replaced, + /// and `true` will be returned. Otherwise, `false` will be returned. + auto add_threshold(std::size_t k, float score) -> bool; /// Returns a query ready to be used for retrieval. - [[nodiscard]] auto query() const -> QueryRequest; + [[nodiscard]] auto query(std::size_t k) const -> QueryRequest; private: QueryContainer(); diff --git a/src/query.cpp b/src/query.cpp index 1b54efe66..da943bd24 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -9,7 +9,13 @@ namespace pisa { -QueryRequest::QueryRequest(QueryContainer const& data) : m_threshold(data.threshold()) +[[nodiscard]] auto first_equal_to(std::size_t k) +{ + return [k](auto&& pair) { return pair.first == k; }; +} + +QueryRequest::QueryRequest(QueryContainer const& data, std::size_t k) + : m_k(k), m_threshold(data.threshold(k)) { if (auto term_ids = data.term_ids(); term_ids) { m_term_ids = *term_ids; @@ -35,13 +41,13 @@ struct QueryContainerInner { std::optional query_string; std::optional> processed_terms; std::optional> term_ids; - std::optional threshold; + std::vector> thresholds; [[nodiscard]] auto operator==(QueryContainerInner const& other) const noexcept -> bool { return id == other.id && query_string == other.query_string && processed_terms == other.processed_terms && term_ids == other.term_ids - && threshold == other.threshold; + && thresholds == other.thresholds; } }; @@ -115,9 +121,18 @@ auto QueryContainer::term_ids() const noexcept -> std::optionalterm_ids; } -auto QueryContainer::threshold() const noexcept -> std::optional const& +auto QueryContainer::threshold(std::size_t k) const noexcept -> std::optional +{ + auto pos = std::find_if(m_data->thresholds.begin(), m_data->thresholds.end(), first_equal_to(k)); + if (pos == m_data->thresholds.end()) { + return std::nullopt; + } + return std::make_optional(pos->second); +} + +auto QueryContainer::thresholds() const noexcept -> std::vector> const& { - return m_data->threshold; + return m_data->thresholds; } auto QueryContainer::string(std::string raw_query) -> QueryContainer& @@ -143,15 +158,21 @@ auto QueryContainer::parse(ParseFn parse_fn) -> QueryContainer& return *this; } -auto QueryContainer::threshold(float score) -> QueryContainer& +auto QueryContainer::add_threshold(std::size_t k, float score) -> bool { - m_data->threshold = score; - return *this; + if (auto pos = + std::find_if(m_data->thresholds.begin(), m_data->thresholds.end(), first_equal_to(k)); + pos != m_data->thresholds.end()) { + pos->second = score; + return true; + } + m_data->thresholds.emplace_back(k, score); + return false; } -auto QueryContainer::query() const -> QueryRequest +auto QueryContainer::query(std::size_t k) const -> QueryRequest { - return QueryRequest(*this); + return QueryRequest(*this, k); } template @@ -189,8 +210,25 @@ auto QueryContainer::from_json(std::string_view json_string) -> QueryContainer data.term_ids = std::move(term_ids); at_least_one_required = true; } - if (auto threshold = get(json, "threshold"); threshold) { - data.threshold = threshold; + if (auto thresholds = json.find("thresholds"); thresholds != json.end()) { + auto raise_error = [&]() { + throw std::runtime_error( + fmt::format("Field \"thresholds\" is invalid: {}", thresholds->dump())); + }; + if (not thresholds->is_array()) { + raise_error(); + } + for (auto&& threshold_entry: *thresholds) { + if (not threshold_entry.is_object()) { + raise_error(); + } + auto k = get(threshold_entry, "k"); + auto score = get(threshold_entry, "score"); + if (not k or not score) { + raise_error(); + } + data.thresholds.emplace_back(*k, *score); + } } if (not at_least_one_required) { throw std::invalid_argument(fmt::format( @@ -218,8 +256,15 @@ auto QueryContainer::to_json() const -> std::string if (auto term_ids = m_data->term_ids; term_ids) { json["term_ids"] = *term_ids; } - if (auto threshold = m_data->threshold; threshold) { - json["threshold"] = *threshold; + if (not m_data->thresholds.empty()) { + auto thresholds = nlohmann::json::array(); + for (auto&& [k, score]: m_data->thresholds) { + auto entry = nlohmann::json::object(); + entry["k"] = k; + entry["score"] = score; + thresholds.push_back(std::move(entry)); + } + json["thresholds"] = thresholds; } return json.dump(); } diff --git a/test/test_query.cpp b/test/test_query.cpp index 00b3aea0e..69c7406d8 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -107,18 +107,18 @@ TEST_CASE("Parse query container from JSON") REQUIRE(*query.string() == "brooklyn tea house"); REQUIRE_FALSE(query.terms()); REQUIRE_FALSE(query.term_ids()); - REQUIRE_FALSE(query.threshold()); + REQUIRE(query.thresholds().empty()); query = QueryContainer::from_json(R"( { "term_ids": [1, 0, 3], "terms": ["brooklyn", "tea", "house"], - "threshold": 10.8 + "thresholds": [{"k": 10, "score": 10.8}] } )"); REQUIRE(*query.terms() == std::vector{"brooklyn", "tea", "house"}); REQUIRE(*query.term_ids() == std::vector{1, 0, 3}); - REQUIRE(*query.threshold() == Approx(10.8)); + REQUIRE(*query.threshold(10) == Approx(10.8)); REQUIRE_FALSE(query.id()); REQUIRE_FALSE(query.string()); @@ -133,11 +133,11 @@ TEST_CASE("Serialize query container to JSON") "query": "brooklyn tea house", "terms": ["brooklyn", "tea", "house"], "term_ids": [1, 0, 3], - "threshold": 10.0 + "thresholds": [{"k": 10, "score": 10.0}] } )"); auto serialized = query.to_json(); REQUIRE( serialized - == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"threshold":10.0})"); + == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"thresholds":[{"k":10,"score":10.0}]})"); } From 78cf15c6a788ea043b7b59d5687d533cf2b9736f Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 3 May 2020 16:01:36 +0000 Subject: [PATCH 11/21] Return program with 1 if fails --- tools/filter_queries.cpp | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tools/filter_queries.cpp b/tools/filter_queries.cpp index dbc89469e..99a609d5d 100644 --- a/tools/filter_queries.cpp +++ b/tools/filter_queries.cpp @@ -43,9 +43,9 @@ int main(int argc, char** argv) app.query_file(), std::move(term_resolver), min_query_len, max_query_len, std::cout); return 0; } catch (pisa::MissingResolverError err) { - spdlog::error("Unresoved queries(without IDs) require term lexicon."); + spdlog::error("Unresoved queries (without IDs) require term lexicon."); } catch (std::runtime_error const& err) { spdlog::error(err.what()); - return 1; } + return 1; } From bee2fc3cea76eedb2ae02bfd12491f1bb81b0298 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sun, 3 May 2020 17:45:07 +0000 Subject: [PATCH 12/21] Partial query container usage --- include/pisa/cursor/cursor.hpp | 5 +- include/pisa/cursor/max_scored_cursor.hpp | 29 +-- include/pisa/query.hpp | 29 +++ include/pisa/query/queries.hpp | 51 +++--- src/query.cpp | 14 +- src/query/queries.cpp | 210 +++++++++++----------- test/test_query.cpp | 22 +++ tools/app.hpp | 28 ++- tools/profile_queries.cpp | 3 +- tools/thresholds.cpp | 11 +- 10 files changed, 240 insertions(+), 162 deletions(-) diff --git a/include/pisa/cursor/cursor.hpp b/include/pisa/cursor/cursor.hpp index c27c67726..6b6014be0 100644 --- a/include/pisa/cursor/cursor.hpp +++ b/include/pisa/cursor/cursor.hpp @@ -1,12 +1,13 @@ #pragma once -#include "query/queries.hpp" #include +#include "query.hpp" + namespace pisa { template -[[nodiscard]] auto make_cursors(Index const& index, Query query) +[[nodiscard]] auto make_cursors(Index const& index, QueryRequest query) { auto terms = query.terms; remove_duplicate_terms(terms); diff --git a/include/pisa/cursor/max_scored_cursor.hpp b/include/pisa/cursor/max_scored_cursor.hpp index a7d0d475b..f076c3cb1 100644 --- a/include/pisa/cursor/max_scored_cursor.hpp +++ b/include/pisa/cursor/max_scored_cursor.hpp @@ -1,9 +1,13 @@ #pragma once +#include + +#include + +#include "query.hpp" #include "query/queries.hpp" #include "scorer/index_scorer.hpp" #include "wand_data.hpp" -#include namespace pisa { @@ -18,22 +22,19 @@ struct max_scored_cursor { }; template -[[nodiscard]] auto -make_max_scored_cursors(Index const& index, WandType const& wdata, Scorer const& scorer, Query query) +[[nodiscard]] auto make_max_scored_cursors( + Index const& index, WandType const& wdata, Scorer const& scorer, QueryRequest query) { - auto terms = query.terms; - auto query_term_freqs = query_freqs(terms); + auto term_ids = query.term_ids(); + auto term_weights = query.term_ids(); std::vector> cursors; - cursors.reserve(query_term_freqs.size()); - std::transform( - query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), [&](auto&& term) { - auto list = index[term.first]; - float q_weight = term.second; - auto max_weight = q_weight * wdata.max_term_weight(term.first); - return max_scored_cursor{ - std::move(list), q_weight, scorer.term_scorer(term.first), max_weight}; - }); + cursors.reserve(term_ids.size()); + + for (auto [term_id, term_weight]: ranges::views::zip(term_ids, term_weights)) { + cursors.push_back( + max_scored_cursor{index[term_id], term_weight * wdata.max_term_weight(term_id)}); + } return cursors; } diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index dbec6ba51..eb61e1d12 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -11,6 +11,11 @@ namespace pisa { +// using DocId = std::uint32_t; +// using Frequency = std::uint32_t; +// using Score = float; +using TermId = std::uint32_t; + struct QueryContainerInner; struct ResolvedTerm { @@ -30,6 +35,7 @@ class QueryRequest { explicit QueryRequest(QueryContainer const& data, std::size_t k); [[nodiscard]] auto term_ids() const -> gsl::span; + [[nodiscard]] auto term_weights() const -> gsl::span; [[nodiscard]] auto threshold() const -> std::optional; [[nodiscard]] auto k() const -> std::optional; @@ -37,6 +43,7 @@ class QueryRequest { std::size_t m_k; std::optional m_threshold{}; std::vector m_term_ids{}; + std::vector m_term_weights{}; }; class QueryContainer { @@ -146,4 +153,26 @@ class QueryReader { std::optional m_format{}; }; +/// Eliminates duplicates in a sorted sequence, and returns a vector of counts. +template +[[nodiscard]] auto unique_with_counts(ForwardIt first, ForwardIt last) -> std::vector +{ + std::vector counts; + + if (first == last) { + return counts; + } + + ForwardIt result = first; + while (++first != last) { + if (!(*result == *first) && ++result != first) { + *result = std::move(*first); + counts.back() += 1; + } else { + counts.push_back(1); + } + } + return counts; +} + } // namespace pisa diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index e0d83ea4a..9501b94d7 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -16,30 +16,31 @@ using term_id_vec = std::vector; using term_freq_pair = std::pair; using term_freq_vec = std::vector; -struct Query { - std::optional id; - std::vector terms; - std::vector term_weights; -}; - -[[nodiscard]] auto split_query_at_colon(std::string const& query_string) - -> std::pair, std::string_view>; - -[[nodiscard]] auto parse_query_terms(std::string const& query_string, TermProcessor term_processor) - -> Query; - -[[nodiscard]] auto parse_query_ids(std::string const& query_string) -> Query; - -[[nodiscard]] std::function resolve_query_parser( - std::vector& queries, - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type); - -bool read_query(term_id_vec& ret, std::istream& is = std::cin); - -void remove_duplicate_terms(term_id_vec& terms); - -term_freq_vec query_freqs(term_id_vec terms); +// struct Query { +// std::optional id; +// std::vector terms; +// std::vector term_weights; +// }; +// +// [[nodiscard]] auto split_query_at_colon(std::string const& query_string) +// -> std::pair, std::string_view>; +// +// [[nodiscard]] auto parse_query_terms(std::string const& query_string, TermProcessor +// term_processor) +// -> Query; +// +// [[nodiscard]] auto parse_query_ids(std::string const& query_string) -> Query; +// +// [[nodiscard]] std::function resolve_query_parser( +// std::vector& queries, +// std::optional const& terms_file, +// std::optional const& stopwords_filename, +// std::optional const& stemmer_type); +// +// bool read_query(term_id_vec& ret, std::istream& is = std::cin); +// +// void remove_duplicate_terms(term_id_vec& terms); +// +// term_freq_vec query_freqs(term_id_vec terms); } // namespace pisa diff --git a/src/query.cpp b/src/query.cpp index da943bd24..5979e143f 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -19,9 +19,12 @@ QueryRequest::QueryRequest(QueryContainer const& data, std::size_t k) { if (auto term_ids = data.term_ids(); term_ids) { m_term_ids = *term_ids; - std::sort(m_term_ids.begin(), m_term_ids.end()); - auto last = std::unique(m_term_ids.begin(), m_term_ids.end()); - m_term_ids.erase(last, m_term_ids.end()); + auto counts = unique_with_counts(m_term_ids.begin(), m_term_ids.end()); + m_term_weights.resize(counts.size()); + m_term_ids.resize(counts.size()); + std::transform(counts.begin(), counts.end(), m_term_weights.begin(), [](auto count) { + return static_cast(count); + }); } throw std::domain_error("Query not parsed."); } @@ -31,6 +34,11 @@ auto QueryRequest::term_ids() const -> gsl::span return gsl::span(m_term_ids); } +auto QueryRequest::term_weights() const -> gsl::span +{ + return gsl::span(m_term_weights); +} + auto QueryRequest::threshold() const -> std::optional { return m_threshold; diff --git a/src/query/queries.cpp b/src/query/queries.cpp index de4a0674d..add25db90 100644 --- a/src/query/queries.cpp +++ b/src/query/queries.cpp @@ -11,109 +11,111 @@ namespace pisa { -auto split_query_at_colon(std::string const& query_string) - -> std::pair, std::string_view> -{ - // query id : terms (or ids) - auto colon = std::find(query_string.begin(), query_string.end(), ':'); - std::optional id; - if (colon != query_string.end()) { - id = std::string(query_string.begin(), colon); - } - auto pos = colon == query_string.end() ? query_string.begin() : std::next(colon); - auto raw_query = std::string_view(&*pos, std::distance(pos, query_string.end())); - return {std::move(id), raw_query}; -} - -auto parse_query_terms(std::string const& query_string, TermProcessor term_processor) -> Query -{ - auto [id, raw_query] = split_query_at_colon(query_string); - TermTokenizer tokenizer(raw_query); - std::vector parsed_query; - for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { - auto raw_term = *term_iter; - auto term = term_processor(raw_term); - if (term) { - if (!term_processor.is_stopword(*term)) { - parsed_query.push_back(*term); - } else { - spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); - } - } else { - spdlog::warn("Term `{}` not found and will be ignored", raw_term); - } - } - return {std::move(id), std::move(parsed_query), {}}; -} - -auto parse_query_ids(std::string const& query_string) -> Query -{ - auto [id, raw_query] = split_query_at_colon(query_string); - std::vector parsed_query; - std::vector term_ids; - boost::split(term_ids, raw_query, boost::is_any_of("\t, ,\v,\f,\r,\n")); - - auto is_empty = [](const std::string& val) { return val.empty(); }; - // remove_if move matching elements to the end, preparing them for erase. - term_ids.erase(std::remove_if(term_ids.begin(), term_ids.end(), is_empty), term_ids.end()); - - try { - auto to_int = [](const std::string& val) { return std::stoi(val); }; - std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(parsed_query), to_int); - } catch (std::invalid_argument& err) { - spdlog::error("Could not parse term identifiers of query `{}`", raw_query); - exit(1); - } - return {std::move(id), std::move(parsed_query), {}}; -} - -std::function resolve_query_parser( - std::vector& queries, - std::optional const& terms_file, - std::optional const& stopwords_filename, - std::optional const& stemmer_type) -{ - if (terms_file) { - auto term_processor = TermProcessor(terms_file, stopwords_filename, stemmer_type); - return [&queries, term_processor = std::move(term_processor)](std::string const& query_line) { - queries.push_back(parse_query_terms(query_line, term_processor)); - }; - } - return [&queries](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); - }; -} - -bool read_query(term_id_vec& ret, std::istream& is) -{ - ret.clear(); - std::string line; - if (!std::getline(is, line)) { - return false; - } - ret = parse_query_ids(line).terms; - return true; -} - -void remove_duplicate_terms(term_id_vec& terms) -{ - std::sort(terms.begin(), terms.end()); - terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); -} - -term_freq_vec query_freqs(term_id_vec terms) -{ - term_freq_vec query_term_freqs; - std::sort(terms.begin(), terms.end()); - // count query term frequencies - for (size_t i = 0; i < terms.size(); ++i) { - if (i == 0 || terms[i] != terms[i - 1]) { - query_term_freqs.emplace_back(terms[i], 1); - } else { - query_term_freqs.back().second += 1; - } - } - return query_term_freqs; -} +// auto split_query_at_colon(std::string const& query_string) +// -> std::pair, std::string_view> +//{ +// // query id : terms (or ids) +// auto colon = std::find(query_string.begin(), query_string.end(), ':'); +// std::optional id; +// if (colon != query_string.end()) { +// id = std::string(query_string.begin(), colon); +// } +// auto pos = colon == query_string.end() ? query_string.begin() : std::next(colon); +// auto raw_query = std::string_view(&*pos, std::distance(pos, query_string.end())); +// return {std::move(id), raw_query}; +//} +// +// auto parse_query_terms(std::string const& query_string, TermProcessor term_processor) -> Query +//{ +// auto [id, raw_query] = split_query_at_colon(query_string); +// TermTokenizer tokenizer(raw_query); +// std::vector parsed_query; +// for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { +// auto raw_term = *term_iter; +// auto term = term_processor(raw_term); +// if (term) { +// if (!term_processor.is_stopword(*term)) { +// parsed_query.push_back(*term); +// } else { +// spdlog::warn("Term `{}` is a stopword and will be ignored", raw_term); +// } +// } else { +// spdlog::warn("Term `{}` not found and will be ignored", raw_term); +// } +// } +// return {std::move(id), std::move(parsed_query), {}}; +//} +// +// auto parse_query_ids(std::string const& query_string) -> Query +//{ +// auto [id, raw_query] = split_query_at_colon(query_string); +// std::vector parsed_query; +// std::vector term_ids; +// boost::split(term_ids, raw_query, boost::is_any_of("\t, ,\v,\f,\r,\n")); +// +// auto is_empty = [](const std::string& val) { return val.empty(); }; +// // remove_if move matching elements to the end, preparing them for erase. +// term_ids.erase(std::remove_if(term_ids.begin(), term_ids.end(), is_empty), term_ids.end()); +// +// try { +// auto to_int = [](const std::string& val) { return std::stoi(val); }; +// std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(parsed_query), +// to_int); +// } catch (std::invalid_argument& err) { +// spdlog::error("Could not parse term identifiers of query `{}`", raw_query); +// exit(1); +// } +// return {std::move(id), std::move(parsed_query), {}}; +//} +// +// std::function resolve_query_parser( +// std::vector& queries, +// std::optional const& terms_file, +// std::optional const& stopwords_filename, +// std::optional const& stemmer_type) +//{ +// if (terms_file) { +// auto term_processor = TermProcessor(terms_file, stopwords_filename, stemmer_type); +// return [&queries, term_processor = std::move(term_processor)](std::string const& +// query_line) { +// queries.push_back(parse_query_terms(query_line, term_processor)); +// }; +// } +// return [&queries](std::string const& query_line) { +// queries.push_back(parse_query_ids(query_line)); +// }; +//} +// +// bool read_query(term_id_vec& ret, std::istream& is) +//{ +// ret.clear(); +// std::string line; +// if (!std::getline(is, line)) { +// return false; +// } +// ret = parse_query_ids(line).terms; +// return true; +//} +// +// void remove_duplicate_terms(term_id_vec& terms) +//{ +// std::sort(terms.begin(), terms.end()); +// terms.erase(std::unique(terms.begin(), terms.end()), terms.end()); +//} +// +// term_freq_vec query_freqs(term_id_vec terms) +//{ +// term_freq_vec query_term_freqs; +// std::sort(terms.begin(), terms.end()); +// // count query term frequencies +// for (size_t i = 0; i < terms.size(); ++i) { +// if (i == 0 || terms[i] != terms[i - 1]) { +// query_term_freqs.emplace_back(terms[i], 1); +// } else { +// query_term_freqs.back().second += 1; +// } +// } +// return query_term_freqs; +//} } // namespace pisa diff --git a/test/test_query.cpp b/test/test_query.cpp index 69c7406d8..00e12cb3f 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -7,6 +7,8 @@ #include "query.hpp" using pisa::QueryContainer; +using pisa::TermId; +using pisa::unique_with_counts; TEST_CASE("Construct from raw string") { @@ -141,3 +143,23 @@ TEST_CASE("Serialize query container to JSON") serialized == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"thresholds":[{"k":10,"score":10.0}]})"); } + +TEST_CASE("Test dedup terms.") +{ + SECTION("Double in front") + { + std::vector terms{0, 0, 1, 2, 2, 2, 3}; + auto counts = unique_with_counts(terms.begin(), terms.end()); + REQUIRE(counts == std::vector{2, 1, 3, 1}); + terms.resize(counts.size()); + REQUIRE(terms == std::vector{0, 1, 2, 3}); + } + SECTION("Double at the end") + { + std::vector terms{1, 2, 2, 2, 4, 4}; + auto counts = unique_with_counts(terms.begin(), terms.end()); + REQUIRE(counts == std::vector{1, 3, 2}); + terms.resize(counts.size()); + REQUIRE(terms == std::vector{1, 2, 4}); + } +} diff --git a/tools/app.hpp b/tools/app.hpp index 15158b3ff..e5d83b018 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -12,7 +12,9 @@ #include #include "io.hpp" +#include "query.hpp" #include "query/queries.hpp" +#include "query/term_resolver.hpp" #include "scorer/scorer.hpp" #include "sharding.hpp" #include "type_safe.hpp" @@ -89,17 +91,27 @@ namespace arg { return std::nullopt; } - [[nodiscard]] auto queries() const -> std::vector<::pisa::Query> + [[nodiscard]] auto term_resolver() -> std::optional + { + if (term_lexicon()) { + return StandardTermResolver(*term_lexicon(), stop_words(), stemmer()); + } + return std::nullopt; + } + + [[nodiscard]] auto queries() const -> std::vector<::pisa::QueryContainer> + { + std::vector<::pisa::QueryContainer> queries; + query_reader().for_each([&](auto&& query) { queries.push_back(std::move(query)); }); + return queries; + } + + [[nodiscard]] auto query_reader() const -> QueryReader { - std::vector<::pisa::Query> q; - auto parse_query = resolve_query_parser(q, m_term_lexicon, m_stop_words, m_stemmer); if (m_query_file) { - std::ifstream is(*m_query_file); - io::for_each_line(is, parse_query); - } else { - io::for_each_line(std::cin, parse_query); + return QueryReader::from_file(*m_query_file); } - return q; + return QueryReader::from_stdin(); } [[nodiscard]] auto term_lexicon() const -> std::optional const& diff --git a/tools/profile_queries.cpp b/tools/profile_queries.cpp index da87b8361..64d17140c 100644 --- a/tools/profile_queries.cpp +++ b/tools/profile_queries.cpp @@ -14,6 +14,7 @@ #include "cursor/scored_cursor.hpp" #include "index_types.hpp" #include "mappable/mapper.hpp" +#include "query.hpp" #include "query/algorithm.hpp" #include "scorer/scorer.hpp" #include "util/util.hpp" @@ -22,7 +23,7 @@ using namespace pisa; template -void op_profile(QueryOperator const& query_op, std::vector const& queries) +void op_profile(QueryOperator const& query_op, std::vector const& queries) { using namespace pisa; diff --git a/tools/thresholds.cpp b/tools/thresholds.cpp index 52c09c4a7..fbf82bfa1 100644 --- a/tools/thresholds.cpp +++ b/tools/thresholds.cpp @@ -14,6 +14,7 @@ #include "index_types.hpp" #include "io.hpp" #include "mappable/mapper.hpp" +#include "query.hpp" #include "query/algorithm.hpp" #include "scorer/scorer.hpp" #include "util/util.hpp" @@ -26,7 +27,7 @@ template void thresholds( const std::string& index_filename, const std::optional& wand_data_filename, - const std::vector& queries, + QueryReader queries, std::string const& type, ScorerParams const& scorer_params, uint64_t k, @@ -52,8 +53,8 @@ void thresholds( } topk_queue topk(k); wand_query wand_q(topk); - for (auto const& query: queries) { - wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + queries.for_each([](auto&& query) { + wand_q(make_max_scored_cursors(index, wdata, *scorer, query.query(k)), index.num_docs()); topk.finalize(); auto results = topk.topk(); topk.clear(); @@ -62,7 +63,7 @@ void thresholds( threshold = results.back().first; } std::cout << threshold << '\n'; - } + }); } using wand_raw_index = wand_data; @@ -88,7 +89,7 @@ int main(int argc, const char** argv) auto params = std::make_tuple( app.index_filename(), app.wand_data_path(), - app.queries(), + app.query_reader(), app.index_encoding(), app.scorer_params(), app.k(), From 0a94dcae788b8520a8d75b23cf96f7dc1ed35b77 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 6 May 2020 13:21:37 +0000 Subject: [PATCH 13/21] Replace Query with QueryContainer --- .../pisa/cursor/block_max_scored_cursor.hpp | 32 +- include/pisa/cursor/cursor.hpp | 13 +- include/pisa/cursor/max_scored_cursor.hpp | 20 +- include/pisa/cursor/scored_cursor.hpp | 21 +- include/pisa/intersection.hpp | 46 +- include/pisa/query.hpp | 40 +- include/pisa/query/queries.hpp | 29 - include/pisa/query/query_stemmer.hpp | 21 +- include/pisa/query/term_resolver.hpp | 1 + src/query.cpp | 60 ++- src/query/term_resolver.cpp | 2 +- test/cli/common.sh | 11 + test/cli/run.sh | 1 + test/cli/test_compute_intersection.sh | 40 ++ test/cli/test_filter_queries.sh | 11 +- test/test_bmw_queries.cpp | 12 +- test/test_data/queries.jl | 500 ++++++++++++++++++ test/test_intersection.cpp | 125 +++-- test/test_queries.cpp | 136 ----- test/test_query.cpp | 99 +++- test/test_ranked_queries.cpp | 18 +- test/test_term_resolver.cpp | 2 +- test/test_tokenizer.cpp | 16 +- tools/app.hpp | 19 +- tools/compute_intersection.cpp | 61 +-- tools/evaluate_queries.cpp | 70 ++- tools/map_queries.cpp | 11 +- tools/profile_queries.cpp | 33 +- tools/queries.cpp | 81 +-- tools/selective_queries.cpp | 19 +- tools/thresholds.cpp | 30 +- 31 files changed, 1069 insertions(+), 511 deletions(-) create mode 100644 test/cli/common.sh create mode 100644 test/cli/test_compute_intersection.sh create mode 100644 test/test_data/queries.jl delete mode 100644 test/test_queries.cpp diff --git a/include/pisa/cursor/block_max_scored_cursor.hpp b/include/pisa/cursor/block_max_scored_cursor.hpp index d5d1ad2a5..175bded4c 100644 --- a/include/pisa/cursor/block_max_scored_cursor.hpp +++ b/include/pisa/cursor/block_max_scored_cursor.hpp @@ -1,9 +1,10 @@ #pragma once -#include "query/queries.hpp" +#include + +#include "query.hpp" #include "scorer/index_scorer.hpp" #include "wand_data.hpp" -#include namespace pisa { @@ -23,21 +24,22 @@ struct block_max_scored_cursor { template [[nodiscard]] auto make_block_max_scored_cursors( - Index const& index, WandType const& wdata, Scorer const& scorer, Query query) + Index const& index, WandType const& wdata, Scorer const& scorer, QueryRequest query) { - auto terms = query.terms; - auto query_term_freqs = query_freqs(terms); - - std::vector> cursors; - cursors.reserve(query_term_freqs.size()); + using cursor_type = block_max_scored_cursor; + auto term_ids = query.term_ids(); + auto term_weights = query.term_weights(); + std::vector cursors; + cursors.reserve(term_ids.size()); std::transform( - query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), [&](auto&& term) { - auto list = index[term.first]; - auto w_enum = wdata.getenum(term.first); - float q_weight = term.second; - auto max_weight = q_weight * wdata.max_term_weight(term.first); - return block_max_scored_cursor{ - std::move(list), w_enum, q_weight, scorer.term_scorer(term.first), max_weight}; + term_ids.begin(), + term_ids.end(), + term_weights.begin(), + std::back_inserter(cursors), + [&](auto term_id, auto weight) { + auto max_weight = weight * wdata.max_term_weight(term_id); + return cursor_type{ + index[term_id], wdata.getenum(term_id), weight, scorer.term_scorer(term_id), max_weight}; }); return cursors; } diff --git a/include/pisa/cursor/cursor.hpp b/include/pisa/cursor/cursor.hpp index 6b6014be0..8df868939 100644 --- a/include/pisa/cursor/cursor.hpp +++ b/include/pisa/cursor/cursor.hpp @@ -9,14 +9,11 @@ namespace pisa { template [[nodiscard]] auto make_cursors(Index const& index, QueryRequest query) { - auto terms = query.terms; - remove_duplicate_terms(terms); - using cursor = typename Index::document_enumerator; - - std::vector cursors; - cursors.reserve(terms.size()); - std::transform(terms.begin(), terms.end(), std::back_inserter(cursors), [&](auto&& term) { - return index[term]; + auto term_ids = query.term_ids(); + std::vector cursors; + cursors.reserve(term_ids.size()); + std::transform(term_ids.begin(), term_ids.end(), std::back_inserter(cursors), [&](auto&& term_id) { + return index[term_id]; }); return cursors; diff --git a/include/pisa/cursor/max_scored_cursor.hpp b/include/pisa/cursor/max_scored_cursor.hpp index f076c3cb1..b46837148 100644 --- a/include/pisa/cursor/max_scored_cursor.hpp +++ b/include/pisa/cursor/max_scored_cursor.hpp @@ -25,16 +25,20 @@ template [[nodiscard]] auto make_max_scored_cursors( Index const& index, WandType const& wdata, Scorer const& scorer, QueryRequest query) { + using cursor_type = max_scored_cursor; auto term_ids = query.term_ids(); - auto term_weights = query.term_ids(); - - std::vector> cursors; + auto term_weights = query.term_weights(); + std::vector cursors; cursors.reserve(term_ids.size()); - - for (auto [term_id, term_weight]: ranges::views::zip(term_ids, term_weights)) { - cursors.push_back( - max_scored_cursor{index[term_id], term_weight * wdata.max_term_weight(term_id)}); - } + std::transform( + term_ids.begin(), + term_ids.end(), + term_weights.begin(), + std::back_inserter(cursors), + [&](auto term_id, auto weight) { + auto max_weight = weight * wdata.max_term_weight(term_id); + return cursor_type{index[term_id], weight, scorer.term_scorer(term_id), max_weight}; + }); return cursors; } diff --git a/include/pisa/cursor/scored_cursor.hpp b/include/pisa/cursor/scored_cursor.hpp index 1d6d344ca..1711c2155 100644 --- a/include/pisa/cursor/scored_cursor.hpp +++ b/include/pisa/cursor/scored_cursor.hpp @@ -1,6 +1,6 @@ #pragma once -#include "query/queries.hpp" +#include "query.hpp" #include "scorer/index_scorer.hpp" #include "wand_data.hpp" #include @@ -17,18 +17,19 @@ struct scored_cursor { }; template -[[nodiscard]] auto make_scored_cursors(Index const& index, Scorer const& scorer, Query query) +[[nodiscard]] auto make_scored_cursors(Index const& index, Scorer const& scorer, QueryRequest query) { - auto terms = query.terms; - auto query_term_freqs = query_freqs(terms); - + auto term_ids = query.term_ids(); + auto term_weights = query.term_weights(); std::vector> cursors; - cursors.reserve(query_term_freqs.size()); + cursors.reserve(term_ids.size()); std::transform( - query_term_freqs.begin(), query_term_freqs.end(), std::back_inserter(cursors), [&](auto&& term) { - auto list = index[term.first]; - float q_weight = term.second; - return scored_cursor{std::move(list), q_weight, scorer.term_scorer(term.first)}; + term_ids.begin(), + term_ids.end(), + term_weights.begin(), + std::back_inserter(cursors), + [&](auto term_id, auto weight) { + return scored_cursor{index[term_id], weight, scorer.term_scorer(term_id)}; }); return cursors; } diff --git a/include/pisa/intersection.hpp b/include/pisa/intersection.hpp index 29ecc776c..d315450c2 100644 --- a/include/pisa/intersection.hpp +++ b/include/pisa/intersection.hpp @@ -4,8 +4,10 @@ #include #include +#include + +#include "query.hpp" #include "query/algorithm/and_query.hpp" -#include "query/queries.hpp" #include "scorer/scorer.hpp" namespace pisa { @@ -23,22 +25,18 @@ namespace intersection { using Mask = std::bitset; /// Returns a filtered copy of `query` containing only terms indicated by ones in the bit mask. - [[nodiscard]] inline auto filter(Query const& query, Mask mask) -> Query + [[nodiscard]] inline auto filter(QueryContainer const& query, Mask mask) -> QueryContainer { - if (query.terms.size() > MAX_QUERY_LEN) { - throw std::invalid_argument("Queries can be at most 2^32 terms long"); - } - std::vector terms; - std::vector weights; - for (std::size_t bitpos = 0; bitpos < query.terms.size(); ++bitpos) { - if (((1U << bitpos) & mask.to_ulong()) > 0) { - terms.push_back(query.terms.at(bitpos)); - if (bitpos < query.term_weights.size()) { - weights.push_back(query.term_weights[bitpos]); - } + std::vector positions; + for (std::size_t bitpos = 0; mask.any(); ++bitpos) { + if (mask.test(bitpos)) { + positions.push_back(bitpos); + mask.reset(bitpos); } } - return Query{query.id, terms, weights}; + QueryContainer filtered_query(query); + filtered_query.filter_terms(positions); + return filtered_query; } } // namespace intersection @@ -53,19 +51,23 @@ struct Intersection { inline static auto compute( Index const& index, Wand const& wand, - Query const& query, + QueryContainer const& query, std::optional term_mask = std::nullopt) -> Intersection; }; template inline auto Intersection::compute( - Index const& index, Wand const& wand, Query const& query, std::optional term_mask) - -> Intersection + Index const& index, + Wand const& wand, + QueryContainer const& query, + std::optional term_mask) -> Intersection { auto filtered_query = term_mask ? intersection::filter(query, *term_mask) : query; scored_and_query retrieve{}; auto scorer = scorer::from_params(ScorerParams("bm25"), wand); - auto results = retrieve(make_scored_cursors(index, *scorer, filtered_query), index.num_docs()); + auto results = retrieve( + make_scored_cursors(index, *scorer, filtered_query.query(query::unlimited)), + index.num_docs()); auto max_element = [&](auto const& vec) -> float { auto order = [](auto const& lhs, auto const& rhs) { return lhs.second < rhs.second; }; if (auto pos = std::max_element(results.begin(), results.end(), order); pos != results.end()) { @@ -78,14 +80,14 @@ inline auto Intersection::compute( } /// Do `func` for all intersections in a query that have a given maximum number of terms. -/// `Fn` takes `Query` and `Mask`. +/// `Fn` takes `QueryContainer` and `Mask`. template -auto for_all_subsets(Query const& query, std::optional max_term_count, Fn func) +auto for_all_subsets(QueryContainer const& query, std::optional max_term_count, Fn func) { - auto subset_count = 1U << query.terms.size(); + auto subset_count = 1U << query.term_ids()->size(); for (auto subset = 1U; subset < subset_count; ++subset) { auto mask = intersection::Mask(subset); - if (!max_term_count || mask.count() <= *max_term_count) { + if (!max_term_count || (mask.count() <= *max_term_count)) { func(query, mask); } } diff --git a/include/pisa/query.hpp b/include/pisa/query.hpp index eb61e1d12..4758cc3d1 100644 --- a/include/pisa/query.hpp +++ b/include/pisa/query.hpp @@ -8,9 +8,14 @@ #include #include +#include namespace pisa { +namespace query { + constexpr std::size_t unlimited = std::numeric_limits::max(); +} + // using DocId = std::uint32_t; // using Frequency = std::uint32_t; // using Score = float; @@ -32,12 +37,12 @@ class QueryContainer; /// IDs, and also has some additional data, like term weights, etc. class QueryRequest { public: - explicit QueryRequest(QueryContainer const& data, std::size_t k); + explicit QueryRequest(QueryContainer const& data, std::size_t k = query::unlimited); [[nodiscard]] auto term_ids() const -> gsl::span; [[nodiscard]] auto term_weights() const -> gsl::span; [[nodiscard]] auto threshold() const -> std::optional; - [[nodiscard]] auto k() const -> std::optional; + [[nodiscard]] auto k() const -> std::size_t; private: std::size_t m_k; @@ -76,7 +81,8 @@ class QueryContainer { /// Constructs a query from a JSON object. [[nodiscard]] static auto from_json(std::string_view json_string) -> QueryContainer; - [[nodiscard]] auto to_json() const -> std::string; + [[nodiscard]] auto to_json_string() const -> std::string; + [[nodiscard]] auto to_json() const -> nlohmann::json; /// Constructs a query from a colon-separated format: /// @@ -113,7 +119,13 @@ class QueryContainer { /// and `true` will be returned. Otherwise, `false` will be returned. auto add_threshold(std::size_t k, float score) -> bool; + /// Preserve only terms at given positions. + void filter_terms(gsl::span term_positions); + /// Returns a query ready to be used for retrieval. + /// + /// This function takes `k` and resolves the associated threshold if exists. + /// For unranked queries, pass `pisa::query::unlimited` explicitly to avoidi mistakes. [[nodiscard]] auto query(std::size_t k) const -> QueryRequest; private: @@ -153,26 +165,4 @@ class QueryReader { std::optional m_format{}; }; -/// Eliminates duplicates in a sorted sequence, and returns a vector of counts. -template -[[nodiscard]] auto unique_with_counts(ForwardIt first, ForwardIt last) -> std::vector -{ - std::vector counts; - - if (first == last) { - return counts; - } - - ForwardIt result = first; - while (++first != last) { - if (!(*result == *first) && ++result != first) { - *result = std::move(*first); - counts.back() += 1; - } else { - counts.push_back(1); - } - } - return counts; -} - } // namespace pisa diff --git a/include/pisa/query/queries.hpp b/include/pisa/query/queries.hpp index 9501b94d7..f9ac54d21 100644 --- a/include/pisa/query/queries.hpp +++ b/include/pisa/query/queries.hpp @@ -7,8 +7,6 @@ #include #include -#include "query/term_processor.hpp" - namespace pisa { using term_id_type = uint32_t; @@ -16,31 +14,4 @@ using term_id_vec = std::vector; using term_freq_pair = std::pair; using term_freq_vec = std::vector; -// struct Query { -// std::optional id; -// std::vector terms; -// std::vector term_weights; -// }; -// -// [[nodiscard]] auto split_query_at_colon(std::string const& query_string) -// -> std::pair, std::string_view>; -// -// [[nodiscard]] auto parse_query_terms(std::string const& query_string, TermProcessor -// term_processor) -// -> Query; -// -// [[nodiscard]] auto parse_query_ids(std::string const& query_string) -> Query; -// -// [[nodiscard]] std::function resolve_query_parser( -// std::vector& queries, -// std::optional const& terms_file, -// std::optional const& stopwords_filename, -// std::optional const& stemmer_type); -// -// bool read_query(term_id_vec& ret, std::istream& is = std::cin); -// -// void remove_duplicate_terms(term_id_vec& terms); -// -// term_freq_vec query_freqs(term_id_vec terms); - } // namespace pisa diff --git a/include/pisa/query/query_stemmer.hpp b/include/pisa/query/query_stemmer.hpp index 11df9c8c0..3f7d2a885 100644 --- a/include/pisa/query/query_stemmer.hpp +++ b/include/pisa/query/query_stemmer.hpp @@ -1,13 +1,17 @@ #pragma once + #include #include #include -#include "query/queries.hpp" +#include + +#include "query.hpp" #include "query/term_processor.hpp" #include "tokenizer.hpp" -#include + namespace pisa { + class QueryStemmer { public: explicit QueryStemmer(std::optional const& stemmer_name) @@ -15,15 +19,15 @@ class QueryStemmer { {} std::string operator()(std::string const& query_string) { + auto query = QueryContainer::from_colon_format(query_string); std::stringstream tokenized_query; - auto [id, raw_query] = split_query_at_colon(query_string); std::vector stemmed_terms; - TermTokenizer tokenizer(raw_query); + TermTokenizer tokenizer(*query.string()); for (auto term_iter = tokenizer.begin(); term_iter != tokenizer.end(); ++term_iter) { - stemmed_terms.push_back(std::move(m_stemmer(*term_iter))); + stemmed_terms.push_back(m_stemmer(*term_iter)); } - if (id) { - tokenized_query << *(id) << ":"; + if (auto id = query.id(); id) { + tokenized_query << *id << ":"; } using boost::algorithm::join; tokenized_query << join(stemmed_terms, " "); @@ -32,4 +36,5 @@ class QueryStemmer { Stemmer_t m_stemmer; }; -} // namespace pisa \ No newline at end of file + +} // namespace pisa diff --git a/include/pisa/query/term_resolver.hpp b/include/pisa/query/term_resolver.hpp index fc9088a48..402b88133 100644 --- a/include/pisa/query/term_resolver.hpp +++ b/include/pisa/query/term_resolver.hpp @@ -4,6 +4,7 @@ #include #include +#include "payload_vector.hpp" #include "query.hpp" namespace pisa { diff --git a/src/query.cpp b/src/query.cpp index 5979e143f..9d746634b 100644 --- a/src/query.cpp +++ b/src/query.cpp @@ -18,15 +18,17 @@ QueryRequest::QueryRequest(QueryContainer const& data, std::size_t k) : m_k(k), m_threshold(data.threshold(k)) { if (auto term_ids = data.term_ids(); term_ids) { - m_term_ids = *term_ids; - auto counts = unique_with_counts(m_term_ids.begin(), m_term_ids.end()); - m_term_weights.resize(counts.size()); - m_term_ids.resize(counts.size()); - std::transform(counts.begin(), counts.end(), m_term_weights.begin(), [](auto count) { - return static_cast(count); - }); + std::map counts; + for (auto term_id: *term_ids) { + counts[term_id] += 1; + } + for (auto [term_id, count]: counts) { + m_term_ids.push_back(term_id); + m_term_weights.push_back(static_cast(count)); + } + } else { + throw std::domain_error("Query not parsed."); } - throw std::domain_error("Query not parsed."); } auto QueryRequest::term_ids() const -> gsl::span @@ -248,8 +250,12 @@ auto QueryContainer::from_json(std::string_view json_string) -> QueryContainer fmt::format("Failed to parse JSON: `{}`: {}", json_string, err.what())); } } +auto QueryContainer::to_json_string() const -> std::string +{ + return to_json().dump(); +} -auto QueryContainer::to_json() const -> std::string +auto QueryContainer::to_json() const -> nlohmann::json { nlohmann::json json; if (auto id = m_data->id; id) { @@ -274,7 +280,7 @@ auto QueryContainer::to_json() const -> std::string } json["thresholds"] = thresholds; } - return json.dump(); + return json; } auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer @@ -291,6 +297,40 @@ auto QueryContainer::from_colon_format(std::string_view line) -> QueryContainer return query; } +void QueryContainer::filter_terms(gsl::span term_positions) +{ + auto const& processed_terms = m_data->processed_terms; + auto const& term_ids = m_data->term_ids; + if (not processed_terms && not term_ids) { + return; + } + auto query_length = 0; + if (processed_terms) { + query_length = processed_terms->size(); + } else if (term_ids) { + query_length = term_ids->size(); + } + std::vector filtered_terms; + std::vector filtered_ids; + for (auto position: term_positions) { + if (position >= query_length) { + throw std::out_of_range("Passed term position out of range"); + } + if (processed_terms) { + filtered_terms.push_back(std::move((*m_data->processed_terms)[position])); + } + if (term_ids) { + filtered_ids.push_back((*m_data->term_ids)[position]); + } + } + if (processed_terms) { + m_data->processed_terms = filtered_terms; + } + if (term_ids) { + m_data->term_ids = filtered_ids; + } +} + auto QueryReader::from_file(std::string const& file) -> QueryReader { auto input = std::make_unique(file); diff --git a/src/query/term_resolver.cpp b/src/query/term_resolver.cpp index 7d00f59ad..be6dc8239 100644 --- a/src/query/term_resolver.cpp +++ b/src/query/term_resolver.cpp @@ -92,7 +92,7 @@ void filter_queries( query.parse(QueryParser(*term_resolver)); } if (auto len = query.term_ids()->size(); len >= min_query_len && len <= max_query_len) { - out << query.to_json() << '\n'; + out << query.to_json_string() << '\n'; } }); } diff --git a/test/cli/common.sh b/test/cli/common.sh new file mode 100644 index 000000000..e7e729a06 --- /dev/null +++ b/test/cli/common.sh @@ -0,0 +1,11 @@ +#!/usr/bin/env bats + +function write_lines { + file=$1 + rm -f "$file" + shift + for line in "$@" + do + echo "$line" >> "$file" + done +} diff --git a/test/cli/run.sh b/test/cli/run.sh index 51486a077..652e1c9f5 100755 --- a/test/cli/run.sh +++ b/test/cli/run.sh @@ -3,3 +3,4 @@ DIR=$(dirname "$0") $DIR/setup.sh bats $DIR/test_filter_queries.sh +bats $DIR/test_compute_intersection.sh diff --git a/test/cli/test_compute_intersection.sh b/test/cli/test_compute_intersection.sh new file mode 100644 index 000000000..efe2e79a2 --- /dev/null +++ b/test/cli/test_compute_intersection.sh @@ -0,0 +1,40 @@ +#!/usr/bin/env bats + +PISA_BIN="bin" +export PATH="$PISA_BIN:$PATH" + +. "$BATS_TEST_DIRNAME/common.sh" + +function setup { + write_lines "$BATS_TMPDIR/queries.txt" "brooklyn tea house" "labradoodle" 'Tell your dog I said "hi"' +} + +@test "Compute single intersection" { + result=$(compute_intersection --stemmer porter2 --terms ./fwd.termlex -q $BATS_TMPDIR/queries.txt \ + -e block_simdbp -i ./simdbp -w ./bm25.bmw) + expected='{"intersections":[{"length":2,"mask":7,"max_score":11.045351028442383}],"query":{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}} +{"intersections":[{"length":0,"mask":0,"max_score":0.0}],"query":{"query":"labradoodle","term_ids":[],"terms":[]}} +{"intersections":[{"length":1,"mask":63,"max_score":10.947325706481934}],"query":{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}}' + echo $result > "$BATS_TMPDIR/test.log" + [[ "$result" = "$expected" ]] +} + +@test "Compute combinations with --mtc 1 (effectively single terms)" { + result=$(compute_intersection --stemmer porter2 --terms ./fwd.termlex -q $BATS_TMPDIR/queries.txt \ + -e block_simdbp -i ./simdbp -w ./bm25.bmw --combinations --mtc 1) + expected='{"intersections":[{"length":10,"mask":1,"max_score":6.536393642425537},{"length":20,"mask":2,"max_score":6.352736949920654},{"length":82,"mask":4,"max_score":3.8619942665100098}],"query":{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}} +{"intersections":[],"query":{"query":"labradoodle","term_ids":[],"terms":[]}} +{"intersections":[{"length":156,"mask":1,"max_score":2.819181442260742},{"length":493,"mask":2,"max_score":0.05130492523312569},{"length":168,"mask":4,"max_score":2.7635655403137207},{"length":408,"mask":8,"max_score":0.6954182386398315},{"length":103,"mask":16,"max_score":3.5857112407684326},{"length":33,"mask":32,"max_score":6.272759914398193}],"query":{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}}' + echo $result > "$BATS_TMPDIR/test.log" + [[ "$result" = "$expected" ]] +} + +@test "Compute combinations with --mtc 1 (single terms and pairs)" { + result=$(compute_intersection --stemmer porter2 --terms ./fwd.termlex -q $BATS_TMPDIR/queries.txt \ + -e block_simdbp -i ./simdbp -w ./bm25.bmw --combinations --mtc 2) + expected='{"intersections":[{"length":10,"mask":1,"max_score":6.536393642425537},{"length":20,"mask":2,"max_score":6.352736949920654},{"length":2,"mask":3,"max_score":8.58621883392334},{"length":82,"mask":4,"max_score":3.8619942665100098},{"length":2,"mask":5,"max_score":7.098772048950195},{"length":5,"mask":6,"max_score":8.58519458770752}],"query":{"query":"brooklyn tea house","term_ids":[6535,29194,15462],"terms":["brooklyn","tea","hous"]}} +{"intersections":[],"query":{"query":"labradoodle","term_ids":[],"terms":[]}} +{"intersections":[{"length":156,"mask":1,"max_score":2.819181442260742},{"length":493,"mask":2,"max_score":0.05130492523312569},{"length":82,"mask":3,"max_score":2.850811243057251},{"length":168,"mask":4,"max_score":2.7635655403137207},{"length":13,"mask":5,"max_score":4.625458240509033},{"length":114,"mask":6,"max_score":2.7950499057769775},{"length":408,"mask":8,"max_score":0.6954182386398315},{"length":64,"mask":9,"max_score":3.0511868000030518},{"length":219,"mask":10,"max_score":0.7440643310546875},{"length":64,"mask":12,"max_score":2.8497495651245117},{"length":103,"mask":16,"max_score":3.5857112407684326},{"length":32,"mask":17,"max_score":5.245534420013428},{"length":87,"mask":18,"max_score":3.626940965652466},{"length":26,"mask":20,"max_score":4.820082187652588},{"length":89,"mask":24,"max_score":4.2715864181518555},{"length":33,"mask":32,"max_score":6.272759914398193},{"length":16,"mask":33,"max_score":6.646618843078613},{"length":26,"mask":34,"max_score":6.29996919631958},{"length":3,"mask":36,"max_score":6.239107131958008},{"length":28,"mask":40,"max_score":6.949945449829102},{"length":15,"mask":48,"max_score":8.103652000427246}],"query":{"query":"Tell your dog I said \"hi\"","term_ids":[29287,32766,10396,15670,26032,15114],"terms":["tell","your","dog","i","said","hi"]}}' + echo $result > "$BATS_TMPDIR/test.log" + [[ "$result" = "$expected" ]] +} diff --git a/test/cli/test_filter_queries.sh b/test/cli/test_filter_queries.sh index 2fa486074..52cf7274d 100644 --- a/test/cli/test_filter_queries.sh +++ b/test/cli/test_filter_queries.sh @@ -3,16 +3,7 @@ PISA_BIN="bin" export PATH="$PISA_BIN:$PATH" -function write_lines { - file=$1 - rm -f "$file" - shift - for line in "$@" - do - echo "$line" >> "$file" - done -} - +. "$BATS_TEST_DIRNAME/common.sh" function setup { write_lines "$BATS_TMPDIR/queries.txt" "brooklyn tea house" "labradoodle" 'Tell your dog I said "hi"' diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index 6bf2f54a7..4b540a9bf 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -46,7 +46,7 @@ struct IndexData { term_id_vec q; std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); auto push_query = [&](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); + queries.push_back(QueryContainer::from_json(query_line)); }; io::for_each_line(qfile, push_query); } @@ -55,7 +55,7 @@ struct IndexData { binary_freq_collection collection; binary_collection document_sizes; Index index; - std::vector queries; + std::vector queries; WandTypePlain wdata; [[nodiscard]] static auto @@ -83,8 +83,12 @@ auto test(Wand& wdata, std::string const& s_name) auto scorer = scorer::from_params(ScorerParams(s_name), data->wdata); for (auto const& q: data->queries) { - wand_q(make_max_scored_cursors(data->index, data->wdata, *scorer, q), data->index.num_docs()); - op_q(make_block_max_scored_cursors(data->index, wdata, *scorer, q), data->index.num_docs()); + wand_q( + make_max_scored_cursors(data->index, data->wdata, *scorer, q.query(10)), + data->index.num_docs()); + op_q( + make_block_max_scored_cursors(data->index, wdata, *scorer, q.query(10)), + data->index.num_docs()); topk_1.finalize(); topk_2.finalize(); REQUIRE(topk_2.topk().size() == topk_1.topk().size()); diff --git a/test/test_data/queries.jl b/test/test_data/queries.jl new file mode 100644 index 000000000..f0256760a --- /dev/null +++ b/test/test_data/queries.jl @@ -0,0 +1,500 @@ +{"term_ids":[101587,61936]} +{"term_ids":[40429,86328]} +{"term_ids":[13975,94987,102912,75488,86157]} +{"term_ids":[80811,110278,90269,96541]} +{"term_ids":[33726]} +{"term_ids":[78401,68238]} +{"term_ids":[59451,82510]} +{"term_ids":[110622,102912,53265,66945,43418,101818,99022,54523,54209]} +{"term_ids":[67842,54513,67848]} +{"term_ids":[55900,91909]} +{"term_ids":[51079,89883]} +{"term_ids":[38616,96982]} +{"term_ids":[97986,43403]} +{"term_ids":[106967,75552,59184]} +{"term_ids":[86328,82481,95555,80147]} +{"term_ids":[101785,47930]} +{"term_ids":[44232,103219]} +{"term_ids":[90882,72383]} +{"term_ids":[48145,68857]} +{"term_ids":[73102,55872,68283]} +{"term_ids":[43460,110362]} +{"term_ids":[46586]} +{"term_ids":[47320,33596]} +{"term_ids":[101682,72197]} +{"term_ids":[62885,43748]} +{"term_ids":[110278,44879]} +{"term_ids":[62574,93388,40150,68583]} +{"term_ids":[102046,74112]} +{"term_ids":[65953,111200]} +{"term_ids":[101365,17496,110642,53842]} +{"term_ids":[82777,83431,41152,44915]} +{"term_ids":[60341,49248,34323,95878,67486,75119]} +{"term_ids":[102133,112621,65989]} +{"term_ids":[60740,78250,62198]} +{"term_ids":[60392,75877,86281]} +{"term_ids":[67574]} +{"term_ids":[33856,88404]} +{"term_ids":[40975]} +{"term_ids":[97369,110949]} +{"term_ids":[110717,76695,110770,74156,102912,54599,42353,111450]} +{"term_ids":[73411,82481,72583,79520,46235]} +{"term_ids":[40013,42353,42958,106267]} +{"term_ids":[51571,51834,82481,91489]} +{"term_ids":[46410,47753]} +{"term_ids":[81496,33252,59377]} +{"term_ids":[80219,72531,82632]} +{"term_ids":[61559,110479,71821]} +{"term_ids":[46352,86758,75773]} +{"term_ids":[105328,44427,5924,86157]} +{"term_ids":[82607,103402,98558]} +{"term_ids":[59519,47436,39332]} +{"term_ids":[105329,61936]} +{"term_ids":[74447,49248,49285]} +{"term_ids":[67262,62044,105677,67262,62044,105677,96886]} +{"term_ids":[113086,52033]} +{"term_ids":[69774,67486,50806]} +{"term_ids":[43974,96023,91015]} +{"term_ids":[62557]} +{"term_ids":[86738,96807,40429,59978,57905]} +{"term_ids":[99001,78599]} +{"term_ids":[65446,91071,50240,93962,111030]} +{"term_ids":[55612,111457]} +{"term_ids":[82620,79303,111530,102324,97353,68820,34390,112715,66631,71126,69016]} +{"term_ids":[97366,85132]} +{"term_ids":[92457,91889]} +{"term_ids":[111200,67486,84677]} +{"term_ids":[86157,93388,82481,61684,41505,70086,8468,48343]} +{"term_ids":[111450,93388,91851,67486,94022,38961]} +{"term_ids":[102133,97188,47852]} +{"term_ids":[60392,44792,47436,39332]} +{"term_ids":[109782,78596,68754,42738]} +{"term_ids":[71780,82481,102046,91015,65989]} +{"term_ids":[102503,62083]} +{"term_ids":[34247,44390]} +{"term_ids":[42771,63843]} +{"term_ids":[93479]} +{"term_ids":[110622,68820,102133,47977,82481,102133,43302]} +{"term_ids":[45777,102533,61690]} +{"term_ids":[58684,42983]} +{"term_ids":[102133,39983,82481,77197,34202,76695]} +{"term_ids":[91753,40749]} +{"term_ids":[47487,57873,62029]} +{"term_ids":[33229,44941]} +{"term_ids":[69805,31550,42004]} +{"term_ids":[79610,98398]} +{"term_ids":[30300,32436,71869,59978,93645,94610,106016]} +{"term_ids":[90013]} +{"term_ids":[86938,74830,44915]} +{"term_ids":[69359,100705,58774,78596,102889]} +{"term_ids":[58625,48720,82481,78623]} +{"term_ids":[78500,49248,80811,90144,56796]} +{"term_ids":[33708,42738]} +{"term_ids":[93788,70008,93879,102339,96015]} +{"term_ids":[68581,58195,59978,53338,34202,88081]} +{"term_ids":[69805,96470,93944]} +{"term_ids":[102133,72383,48169,67558]} +{"term_ids":[110622,68820,53187]} +{"term_ids":[100035,68289,45194,79365]} +{"term_ids":[56694]} +{"term_ids":[92489,84496,47977]} +{"term_ids":[47675,44915]} +{"term_ids":[74156,53113]} +{"term_ids":[48792,82481,44782,44145,82481,104965,88209]} +{"term_ids":[50265,53793,95978]} +{"term_ids":[54599,34202,76462]} +{"term_ids":[52857]} +{"term_ids":[57681,47478]} +{"term_ids":[66014,102912,48819,58131,68462,98077,59953]} +{"term_ids":[60120,34194]} +{"term_ids":[64274,69016,83392,74156,69016]} +{"term_ids":[80432,102046,80422]} +{"term_ids":[51590,94716,79520]} +{"term_ids":[60317,75609,79072]} +{"term_ids":[76897,73806]} +{"term_ids":[101585,95555]} +{"term_ids":[99306,68335,68551]} +{"term_ids":[91214,95878,59978,96921,53338,93388,67486,77217]} +{"term_ids":[102133,57803,93670,44596]} +{"term_ids":[69571,74156,65063]} +{"term_ids":[57729,47034,101846,45930]} +{"term_ids":[70609,65356]} +{"term_ids":[71712,89029,105677,40967,99737]} +{"term_ids":[97478,32942,90144,56796]} +{"term_ids":[99856,42059,111730]} +{"term_ids":[101268,80644]} +{"term_ids":[75039,101681]} +{"term_ids":[81398,55245,84949,104433]} +{"term_ids":[68820,102133,40683,80689,51060,31550,46819,40683]} +{"term_ids":[93959,97650,33229,95458]} +{"term_ids":[61690,62697]} +{"term_ids":[110278,44713]} +{"term_ids":[93788,69773,49248,49285,94399]} +{"term_ids":[52165,78514,89883]} +{"term_ids":[47089,82481,75567]} +{"term_ids":[58663,58634,69640]} +{"term_ids":[53889,105983,96013,105677,67486,59951]} +{"term_ids":[68645,58676,95458,103402,44145]} +{"term_ids":[59451,97116]} +{"term_ids":[90435,56089,88388,47753]} +{"term_ids":[66631,72410]} +{"term_ids":[65814,63815,75496]} +{"term_ids":[104388,44850,33229,47302]} +{"term_ids":[102555,87083,95997,91738]} +{"term_ids":[78567,65741,59978,93645]} +{"term_ids":[84890,98474,56035]} +{"term_ids":[66945,112832,74156,97319,95496,102176,60392]} +{"term_ids":[106967,33286]} +{"term_ids":[38616,76506,86773]} +{"term_ids":[98388,63026]} +{"term_ids":[69800,76231]} +{"term_ids":[60392,82179,82481,71537]} +{"term_ids":[41996,110299]} +{"term_ids":[42585,78960,46337]} +{"term_ids":[104317,72842,83942,57392]} +{"term_ids":[49314,47382]} +{"term_ids":[49251,59940,67558]} +{"term_ids":[102133,41029,47521,64342]} +{"term_ids":[44297,71101]} +{"term_ids":[96035,63790,99413]} +{"term_ids":[80377,83553,71627]} +{"term_ids":[33385,103552,51209]} +{"term_ids":[32556,34202,57081]} +{"term_ids":[104322,103029]} +{"term_ids":[79660,103590,98779,87320]} +{"term_ids":[62029,30298,4807,96598,26877,6386,46406,47487]} +{"term_ids":[87869,111161,80913,68238]} +{"term_ids":[98289,85861,98077]} +{"term_ids":[88154,110278,68583,60392]} +{"term_ids":[97600,96472,96062]} +{"term_ids":[80377,112825,47089,89876,88225]} +{"term_ids":[86000,101610,67910]} +{"term_ids":[54191,58195]} +{"term_ids":[106830,82481,104506,76023,58520]} +{"term_ids":[100072]} +{"term_ids":[64131,51040,92214,101985]} +{"term_ids":[86537,60870]} +{"term_ids":[88435,110278,32606]} +{"term_ids":[81950,47436,39332]} +{"term_ids":[100437,87304,100018]} +{"term_ids":[72377,87092]} +{"term_ids":[42250]} +{"term_ids":[44241,59978,93645,67486,59451]} +{"term_ids":[104801,98449]} +{"term_ids":[65447,82481,95754,92013,80811]} +{"term_ids":[31550,109770,82984,102133,78408,78623]} +{"term_ids":[54550,67486,67203,8802]} +{"term_ids":[110770,84205,43628,75415,38658,61157,100705]} +{"term_ids":[74433,49248,101444,38817]} +{"term_ids":[69571,61327]} +{"term_ids":[56809,41152]} +{"term_ids":[41911,103874]} +{"term_ids":[50855,82984,45058,47750]} +{"term_ids":[72231,54729]} +{"term_ids":[41173]} +{"term_ids":[105871,62567,62697,59978,54935]} +{"term_ids":[49806,65959]} +{"term_ids":[46444,47487]} +{"term_ids":[60845,91919]} +{"term_ids":[32007]} +{"term_ids":[102095,112839]} +{"term_ids":[43302,44961,73912]} +{"term_ids":[110349,45930]} +{"term_ids":[57771,64563,110245,96541]} +{"term_ids":[97919,48164,102749]} +{"term_ids":[94508,59978,63248]} +{"term_ids":[91705,73102,50393]} +{"term_ids":[60392,98612,101985,47427,67203]} +{"term_ids":[110786,86769,39667,109901,103219,2671]} +{"term_ids":[63254,47673,97604]} +{"term_ids":[82697,75944,103402]} +{"term_ids":[97593,34202,62207,47753,59978,46369]} +{"term_ids":[71089,34175]} +{"term_ids":[33300,76282]} +{"term_ids":[85795,33745]} +{"term_ids":[65959,80377,112825]} +{"term_ids":[48754,89457,50481,97213]} +{"term_ids":[45286]} +{"term_ids":[77016,65807]} +{"term_ids":[93959,89635]} +{"term_ids":[4542,84803]} +{"term_ids":[65543]} +{"term_ids":[105922,80724,60551,86294,105677,103960]} +{"term_ids":[69628,42585,33229,73293]} +{"term_ids":[106928,47521,67701,110389]} +{"term_ids":[96585,51814]} +{"term_ids":[109945,94508,82481,88549]} +{"term_ids":[59995,89564,49516,55913]} +{"term_ids":[98449,87992,69227,40277,85111]} +{"term_ids":[49254,65741,73764]} +{"term_ids":[46248,12608,60458]} +{"term_ids":[102133,51198,113242]} +{"term_ids":[45612,76695]} +{"term_ids":[43422,64630]} +{"term_ids":[86157,71692,75182]} +{"term_ids":[68689,62558]} +{"term_ids":[85853,60484,91015,67486,80377,69613]} +{"term_ids":[66333,96160]} +{"term_ids":[111542,39667,65741]} +{"term_ids":[65741,57563,73126]} +{"term_ids":[90595,97823,53778,45773,93388,79216]} +{"term_ids":[8859]} +{"term_ids":[66309,97919]} +{"term_ids":[84734,94508,60458,44782]} +{"term_ids":[34281,67486,33941]} +{"term_ids":[80377,69613,53287]} +{"term_ids":[46556,86987]} +{"term_ids":[69032,55929,52484]} +{"term_ids":[95458,60836,65741,88572]} +{"term_ids":[80913,46579,72575,41346,32477]} +{"term_ids":[70461]} +{"term_ids":[63837,93388,49024,78067]} +{"term_ids":[42353,51339]} +{"term_ids":[93176]} +{"term_ids":[97489,84852]} +{"term_ids":[84672,89564]} +{"term_ids":[67558]} +{"term_ids":[27281]} +{"term_ids":[47647,80143,78250]} +{"term_ids":[97863,32177]} +{"term_ids":[112916,61891,82620]} +{"term_ids":[39717,40712,64889]} +{"term_ids":[39749,80410]} +{"term_ids":[111497,49248]} +{"term_ids":[105871,51834,82481,54935,83069,85130]} +{"term_ids":[61508,109936,102679]} +{"term_ids":[111457,85054]} +{"term_ids":[33883]} +{"term_ids":[72160,95997]} +{"term_ids":[87439,82519]} +{"term_ids":[39363,100394,84617]} +{"term_ids":[57929,105065,68394]} +{"term_ids":[34267]} +{"term_ids":[75721,98492,42738,82481,49134]} +{"term_ids":[77175,38658]} +{"term_ids":[98955,97248,96035]} +{"term_ids":[48062,99737,93880,47223]} +{"term_ids":[65741,86670]} +{"term_ids":[85319,71012]} +{"term_ids":[79365,104515,40277,66631,71573,54383,93388,79365,54383,100705]} +{"term_ids":[85853,64590]} +{"term_ids":[67567]} +{"term_ids":[94508,113242,44381,32606]} +{"term_ids":[59451,98492,53114,82481,49053]} +{"term_ids":[110952,80377,69613,82179,58076,9365]} +{"term_ids":[103393,84803,67486,47979,97986]} +{"term_ids":[32942,33944,87059,96541]} +{"term_ids":[111530,102912,77013,97353]} +{"term_ids":[44941,79216,45804]} +{"term_ids":[64169,51039]} +{"term_ids":[46607,100605,59978,79216]} +{"term_ids":[98705,45882,34202,111635,64585]} +{"term_ids":[32747,31550,85853]} +{"term_ids":[103368,63248]} +{"term_ids":[52853,112626]} +{"term_ids":[86783,72044,59439]} +{"term_ids":[49066,95458]} +{"term_ids":[44596,98492,48757,101985]} +{"term_ids":[43653,65886,96216,93536]} +{"term_ids":[63864,9072,103744]} +{"term_ids":[69032,84983,95868,99381]} +{"term_ids":[90640,102142,106822,80377,112825]} +{"term_ids":[47521,76492]} +{"term_ids":[80377,69613,80410]} +{"term_ids":[59951,99019]} +{"term_ids":[99373,67486,78960]} +{"term_ids":[110684,111455,79303,69453,53612,73754]} +{"term_ids":[43987,86092]} +{"term_ids":[34522,84496,49472]} +{"term_ids":[70624]} +{"term_ids":[102339,34202,74890,39919,48343]} +{"term_ids":[83993,48669,91087]} +{"term_ids":[51400,49583]} +{"term_ids":[106403,47089,99045]} +{"term_ids":[65959,55753,71627,75361]} +{"term_ids":[85065,89402,47930]} +{"term_ids":[42490]} +{"term_ids":[91013,102912,44347,60870]} +{"term_ids":[60392,88156,53847]} +{"term_ids":[76756,41520,104515,93388,31550]} +{"term_ids":[112799,41183,68820,96935,102181,102133]} +{"term_ids":[98819,49251,79216]} +{"term_ids":[95793,96987]} +{"term_ids":[80811,74112]} +{"term_ids":[40157,76848,43843,79303,101688]} +{"term_ids":[68271,101635]} +{"term_ids":[87885,64601]} +{"term_ids":[45967,104367,83015,60120,79315]} +{"term_ids":[9174,59995,57368]} +{"term_ids":[69553,106830]} +{"term_ids":[63974]} +{"term_ids":[84852,51834,78464,106255]} +{"term_ids":[54264,96107]} +{"term_ids":[104405,96293,48186,59978,110677,83392,81520,102265]} +{"term_ids":[99045,110663,102265,95217,78960,67486,69297,77095]} +{"term_ids":[47977,67266]} +{"term_ids":[44782,80410,84689]} +{"term_ids":[74964,64286,39332]} +{"term_ids":[79610,64619,9174,46410,47753]} +{"term_ids":[52853,65959]} +{"term_ids":[61566,70878]} +{"term_ids":[112601,79303,92489]} +{"term_ids":[94679,41646,32241]} +{"term_ids":[41650,83906,78567]} +{"term_ids":[55987,74044,63248]} +{"term_ids":[44878,92539,93143]} +{"term_ids":[56498,77200]} +{"term_ids":[39750,97650]} +{"term_ids":[87869,111450]} +{"term_ids":[60392,85801,94916]} +{"term_ids":[60830]} +{"term_ids":[60392,57206,82481,88464,111542,90847]} +{"term_ids":[93959,78586]} +{"term_ids":[69805,13974,6756]} +{"term_ids":[71860]} +{"term_ids":[86110,45512]} +{"term_ids":[56640,49248,94508,53047]} +{"term_ids":[32804,34202,75808]} +{"term_ids":[86610,19275]} +{"term_ids":[89012]} +{"term_ids":[40240,74112]} +{"term_ids":[66014,102912,68183,31550,45860,57755]} +{"term_ids":[97578,111530,59978,60484]} +{"term_ids":[61293,82481,87731]} +{"term_ids":[51656,97353]} +{"term_ids":[60612]} +{"term_ids":[90144,56796,67486,63326,82713]} +{"term_ids":[77811,60870]} +{"term_ids":[93925,76278]} +{"term_ids":[94904,85497]} +{"term_ids":[102168]} +{"term_ids":[95831,88277]} +{"term_ids":[86113,96015]} +{"term_ids":[80811,65959]} +{"term_ids":[75799,65907,86157]} +{"term_ids":[42395,84494]} +{"term_ids":[8682,102476]} +{"term_ids":[64756,70537]} +{"term_ids":[68910]} +{"term_ids":[97356,48943]} +{"term_ids":[94679]} +{"term_ids":[89613,63041,67486,42857,66839]} +{"term_ids":[105922,82549,88153,87992]} +{"term_ids":[41650,105871,41152]} +{"term_ids":[69291,54520,63814]} +{"term_ids":[91754,74719,110639]} +{"term_ids":[71730,49248,58828]} +{"term_ids":[48928,43179,63334]} +{"term_ids":[89621,98558,32804,87398,83459]} +{"term_ids":[59451,71899,101813,47753]} +{"term_ids":[73414,102912,50745,93388,78250,44381,91787]} +{"term_ids":[69227,78542]} +{"term_ids":[102133,104677,78266,80410,84689]} +{"term_ids":[89283,65959,110507,103834,32807]} +{"term_ids":[112590,46758,53831,48169]} +{"term_ids":[99008,89437,60535,78623]} +{"term_ids":[80377,112825,95668,80643,47521]} +{"term_ids":[92739]} +{"term_ids":[102339,91795,82984,103402]} +{"term_ids":[66204,49248,49285]} +{"term_ids":[105922,102133,61091,52558,13828,91223,42958]} +{"term_ids":[93190,50806]} +{"term_ids":[103189,101119,85189]} +{"term_ids":[99378,63218,59978,31550,58181,87083,63176]} +{"term_ids":[71101]} +{"term_ids":[59451,91355,83446]} +{"term_ids":[85003,82481,48062,96921,86799,59978,85314]} +{"term_ids":[70911]} +{"term_ids":[82834,111457]} +{"term_ids":[40467,46414]} +{"term_ids":[53778,76438]} +{"term_ids":[82070,69904,80410]} +{"term_ids":[66246,81952]} +{"term_ids":[66014,102912,48510,31550,97640,95173]} +{"term_ids":[86157,82481,62521,63041,102133,110949,96530]} +{"term_ids":[103446]} +{"term_ids":[64855,60146]} +{"term_ids":[73885,44611,39332]} +{"term_ids":[34393,49569,91087]} +{"term_ids":[92007,59954]} +{"term_ids":[68789,95458]} +{"term_ids":[55364,75285,72096,33432]} +{"term_ids":[48731,33252]} +{"term_ids":[71102,58520,53718,86328]} +{"term_ids":[61647,34202,47415,56096,67486,102133,54523,74719]} +{"term_ids":[46630,106255,94508]} +{"term_ids":[82620,79303,75285,66631,74156,65357,97142]} +{"term_ids":[34281,67486,33941,89437]} +{"term_ids":[52672,104474,70970]} +{"term_ids":[54577,62065]} +{"term_ids":[53573]} +{"term_ids":[52309,44879]} +{"term_ids":[103552,67486,43388]} +{"term_ids":[91754,33634]} +{"term_ids":[59978,75428,82915,80081]} +{"term_ids":[88154,75471,102912,47977,82481,33353]} +{"term_ids":[61625,103347]} +{"term_ids":[44893,67486,102133,32674]} +{"term_ids":[73783,66358]} +{"term_ids":[81507,75476]} +{"term_ids":[10647,42254,66853]} +{"term_ids":[110622,66945,112850,53338,61784,42284,44381,31550,102533]} +{"term_ids":[48669,88081,67695]} +{"term_ids":[48343,67486,41597,67702]} +{"term_ids":[52079,59451,65664,68070,41158]} +{"term_ids":[90953,109804]} +{"term_ids":[89575,60535,102133,90144,111591]} +{"term_ids":[32556,84936]} +{"term_ids":[96392,46410]} +{"term_ids":[102133,50714,106909,47753,78623]} +{"term_ids":[55245,53484,49285]} +{"term_ids":[110291]} +{"term_ids":[94680,44882,44056,57457,66113,103219]} +{"term_ids":[98492,40732,82481,102046]} +{"term_ids":[60392,102116,112832,45169]} +{"term_ids":[46535,86587,39212,81926]} +{"term_ids":[111389]} +{"term_ids":[41389,49248,71537,61559]} +{"term_ids":[110230,48030,75739,74830]} +{"term_ids":[69876,51553,106251,90144,98492,39842]} +{"term_ids":[90144,39012]} +{"term_ids":[51039,43703]} +{"term_ids":[85497,98558,91767]} +{"term_ids":[97425,51021,87059]} +{"term_ids":[47089,82481,83588,33353]} +{"term_ids":[65959,34202,78553,59978,93645,67486,97478,102046]} +{"term_ids":[57916,49248,85515,88846,90374]} +{"term_ids":[88374,53793]} +{"term_ids":[65938,79568,58828]} +{"term_ids":[80926,85619]} +{"term_ids":[85975,102535]} +{"term_ids":[106830,61241,40467]} +{"term_ids":[83918,40702,91015,88438]} +{"term_ids":[80443]} +{"term_ids":[44189,44824]} +{"term_ids":[46556,48087,88438]} +{"term_ids":[33972,80525]} +{"term_ids":[65768]} +{"term_ids":[46328,34202,102133,46758,57613,78623]} +{"term_ids":[61684,45612]} +{"term_ids":[65357]} +{"term_ids":[46999,96987,90144,56796,33003]} +{"term_ids":[96267,32199]} +{"term_ids":[49569,45169,40150,88323]} +{"term_ids":[107372,40601,23611]} +{"term_ids":[46406,59377]} +{"term_ids":[56006,39992,67486]} +{"term_ids":[65356,86281]} +{"term_ids":[43616,95458]} +{"term_ids":[68565,94045,40702]} +{"term_ids":[86066,68381,33262]} +{"term_ids":[106170,65745]} +{"term_ids":[53576,55403]} +{"term_ids":[79075]} +{"term_ids":[85577,43189]} +{"term_ids":[46414,93766]} +{"term_ids":[96392,61241]} +{"term_ids":[101688,46344]} +{"term_ids":[45147,58429,96216,96676]} +{"term_ids":[39485,49251]} +{"term_ids":[43537]} diff --git a/test/test_intersection.cpp b/test/test_intersection.cpp index cf419f281..86917a6f4 100644 --- a/test/test_intersection.cpp +++ b/test/test_intersection.cpp @@ -12,28 +12,42 @@ using namespace pisa::intersection; TEST_CASE("filter query", "[intersection][unit]") { - GIVEN("Four-term query") + GIVEN("With term IDs") { - Query query{ - "Q1", // query ID - {6, 1, 5}, // terms - {0.1, 0.4, 1.0} // weights - }; - auto [mask, expected] = GENERATE(table({ - {0b001, Query{"Q1", {6}, {0.1}}}, - {0b010, Query{"Q1", {1}, {0.4}}}, - {0b100, Query{"Q1", {5}, {1.0}}}, - {0b011, Query{"Q1", {6, 1}, {0.1, 0.4}}}, - {0b101, Query{"Q1", {6, 5}, {0.1, 1.0}}}, - {0b110, Query{"Q1", {1, 5}, {0.4, 1.0}}}, - {0b111, Query{"Q1", {6, 1, 5}, {0.1, 0.4, 1.0}}}, + auto query = QueryContainer::from_term_ids({6, 1, 5}); + auto [mask, expected] = GENERATE(table({ + {0b001, QueryContainer::from_term_ids({6})}, + {0b010, QueryContainer::from_term_ids({1})}, + {0b100, QueryContainer::from_term_ids({5})}, + {0b011, QueryContainer::from_term_ids({6, 1})}, + {0b101, QueryContainer::from_term_ids({6, 5})}, + {0b110, QueryContainer::from_term_ids({1, 5})}, + {0b111, QueryContainer::from_term_ids({6, 1, 5})}, })); WHEN("Filtered with mask " << mask) { auto actual = filter(query, mask); - CHECK(actual.id == expected.id); - CHECK(actual.terms == expected.terms); - CHECK(actual.term_weights == expected.term_weights); + CHECK(actual.term_ids() == expected.term_ids()); + CHECK(actual.terms() == expected.terms()); + } + } + GIVEN("With terms") + { + auto query = QueryContainer::from_terms({"a", "b", "c"}, std::nullopt); + auto [mask, expected] = GENERATE(table({ + {0b001, QueryContainer::from_terms({"a"}, std::nullopt)}, + {0b010, QueryContainer::from_terms({"b"}, std::nullopt)}, + {0b100, QueryContainer::from_terms({"c"}, std::nullopt)}, + {0b011, QueryContainer::from_terms({"a", "b"}, std::nullopt)}, + {0b101, QueryContainer::from_terms({"a", "c"}, std::nullopt)}, + {0b110, QueryContainer::from_terms({"b", "c"}, std::nullopt)}, + {0b111, QueryContainer::from_terms({"a", "b", "c"}, std::nullopt)}, + })); + WHEN("Filtered with mask " << mask) + { + auto actual = filter(query, mask); + REQUIRE(actual.term_ids() == expected.term_ids()); + REQUIRE(*actual.terms() == *expected.terms()); } } } @@ -89,10 +103,11 @@ struct InMemoryIndex { throw std::out_of_range( fmt::format("Term {} is out of range; index contains {} terms", term_id, size())); } - return {gsl::make_span(documents[term_id]), - gsl::make_span(frequencies[term_id]), - num_documents, - {num_documents}}; + return { + gsl::make_span(documents[term_id]), + gsl::make_span(frequencies[term_id]), + num_documents, + {num_documents}}; } [[nodiscard]] auto size() const noexcept -> std::size_t { return documents.size(); } @@ -177,32 +192,29 @@ TEST_CASE("compute intersection", "[intersection][unit]") { GIVEN("Four-term query, index, and wand data object") { - InMemoryIndex index{{ - {0}, // 0 - {0, 1, 2}, // 1 - {0}, // 2 - {0}, // 3 - {0}, // 4 - {0, 1, 4}, // 5 - {1, 4, 8}, // 6 - }, - { - {1}, // 0 - {1, 1, 1}, // 1 - {1}, // 2 - {1}, // 3 - {1}, // 4 - {1, 1, 1}, // 5 - {1, 1, 1}, // 6 - }, - 10}; + InMemoryIndex index{ + { + {0}, // 0 + {0, 1, 2}, // 1 + {0}, // 2 + {0}, // 3 + {0}, // 4 + {0, 1, 4}, // 5 + {1, 4, 8}, // 6 + }, + { + {1}, // 0 + {1, 1, 1}, // 1 + {1}, // 2 + {1}, // 3 + {1}, // 4 + {1, 1, 1}, // 5 + {1, 1, 1}, // 6 + }, + 10}; InMemoryWand wand{{0.0, 1.0, 0.0, 0.0, 0.0, 5.0, 6.0}, 10}; - Query query{ - "Q1", // query ID - {6, 1, 5}, // terms - {0.1, 0.4, 1.0} // weights - }; + auto query = QueryContainer::from_term_ids({6, 1, 5}); auto [mask, len, max] = GENERATE(table({ {0b001, 3, 1.84583f}, {0b010, 3, 1.84583f}, @@ -226,12 +238,8 @@ TEST_CASE("for_all_subsets", "[intersection][unit]") GIVEN("A query and a mock function that accumulates arguments") { std::vector masks; - auto accumulate = [&](Query const&, Mask const& mask) { masks.push_back(mask); }; - Query query{ - "Q1", // query ID - {6, 1, 5}, // terms - {0.1, 0.4, 1.0} // weights - }; + auto accumulate = [&](QueryContainer const&, Mask const& mask) { masks.push_back(mask); }; + auto query = QueryContainer::from_term_ids({6, 1, 5}); WHEN("Executed with limit 0") { for_all_subsets(query, 0, accumulate); @@ -263,13 +271,14 @@ TEST_CASE("for_all_subsets", "[intersection][unit]") { CHECK( masks - == std::vector{Mask(0b001), - Mask(0b010), - Mask(0b011), - Mask(0b100), - Mask(0b101), - Mask(0b110), - Mask(0b111)}); + == std::vector{ + Mask(0b001), + Mask(0b010), + Mask(0b011), + Mask(0b100), + Mask(0b101), + Mask(0b110), + Mask(0b111)}); } } } diff --git a/test/test_queries.cpp b/test/test_queries.cpp deleted file mode 100644 index a97aa1b7e..000000000 --- a/test/test_queries.cpp +++ /dev/null @@ -1,136 +0,0 @@ -#define CATCH_CONFIG_MAIN - -#include - -#include "query/algorithm.hpp" -#include "temporary_directory.hpp" - -using namespace pisa; - -TEST_CASE("Parse query term ids without query id") -{ - auto raw_query = "1 2\t3 4"; - auto q = parse_query_ids(raw_query); - REQUIRE(q.id.has_value() == false); - REQUIRE(q.terms == std::vector{1, 2, 3, 4}); -} - -TEST_CASE("Parse query term ids with query id") -{ - auto raw_query = "1: 1\t2 3\t4"; - auto q = parse_query_ids(raw_query); - REQUIRE(q.id == "1"); - REQUIRE(q.terms == std::vector{1, 2, 3, 4}); -} - -TEST_CASE("Compute parsing function") -{ - Temporary_Directory tmpdir; - - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"a", "account", "he", "she", "usa", "world"})) - .to_file(lexfile.string()); - auto stopwords_filename = tmpdir.path() / "stop"; - { - std::ofstream os(stopwords_filename.string()); - os << "a\nthe\n"; - } - - std::vector queries; - - WHEN("No stopwords, terms, or stemmer") - { - auto parse = resolve_query_parser(queries, std::nullopt, std::nullopt, std::nullopt); - THEN("Parse query IDs") - { - parse("1:0 2 4"); - REQUIRE(queries[0].id == std::optional("1")); - REQUIRE(queries[0].terms == std::vector{0, 2, 4}); - REQUIRE(queries[0].term_weights.empty()); - } - } - WHEN("With terms and stopwords. No stemmer") - { - auto parse = resolve_query_parser( - queries, lexfile.string(), stopwords_filename.string(), std::nullopt); - THEN("Parse query IDs") - { - parse("1:a he usa"); - REQUIRE(queries[0].id == std::optional("1")); - REQUIRE(queries[0].terms == std::vector{2, 4}); - REQUIRE(queries[0].term_weights.empty()); - } - } - WHEN("With terms, stopwords, and stemmer") - { - auto parse = - resolve_query_parser(queries, lexfile.string(), stopwords_filename.string(), "porter2"); - THEN("Parse query IDs") - { - parse("1:a he usa"); - REQUIRE(queries[0].id == std::optional("1")); - REQUIRE(queries[0].terms == std::vector{2, 4}); - REQUIRE(queries[0].term_weights.empty()); - } - } -} - -TEST_CASE("Load stopwords in term processor with all stopwords present in the lexicon") -{ - Temporary_Directory tmpdir; - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"a", "account", "he", "she", "usa", "world"})) - .to_file(lexfile.string()); - - auto stopwords_filename = (tmpdir.path() / "stopwords").string(); - std::ofstream is(stopwords_filename); - is << "a\nshe\nhe"; - is.close(); - - TermProcessor tprocessor( - std::make_optional(lexfile.string()), std::make_optional(stopwords_filename), std::nullopt); - REQUIRE(tprocessor.get_stopwords() == std::vector{0, 2, 3}); -} - -TEST_CASE("Load stopwords in term processor with some stopwords not present in the lexicon") -{ - Temporary_Directory tmpdir; - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"account", "coffee", "he", "she", "usa", "world"})) - .to_file(lexfile.string()); - - auto stopwords_filename = (tmpdir.path() / "stopwords").string(); - std::ofstream is(stopwords_filename); - is << "\nis\nto\na\nshe\nhe"; - is.close(); - - TermProcessor tprocessor( - std::make_optional(lexfile.string()), std::make_optional(stopwords_filename), std::nullopt); - REQUIRE(tprocessor.get_stopwords() == std::vector{2, 3}); -} - -TEST_CASE("Check if term is stopword") -{ - Temporary_Directory tmpdir; - auto lexfile = tmpdir.path() / "lex"; - encode_payload_vector( - gsl::make_span(std::vector{"account", "coffee", "he", "she", "usa", "world"})) - .to_file(lexfile.string()); - - auto stopwords_filename = (tmpdir.path() / "stopwords").string(); - std::ofstream is(stopwords_filename); - is << "\nis\nto\na\nshe\nhe"; - is.close(); - - TermProcessor tprocessor( - std::make_optional(lexfile.string()), std::make_optional(stopwords_filename), std::nullopt); - REQUIRE(!tprocessor.is_stopword(0)); - REQUIRE(!tprocessor.is_stopword(1)); - REQUIRE(tprocessor.is_stopword(2)); - REQUIRE(tprocessor.is_stopword(3)); - REQUIRE(!tprocessor.is_stopword(4)); - REQUIRE(!tprocessor.is_stopword(5)); -} diff --git a/test/test_query.cpp b/test/test_query.cpp index 00e12cb3f..ec4b298a9 100644 --- a/test/test_query.cpp +++ b/test/test_query.cpp @@ -8,7 +8,6 @@ using pisa::QueryContainer; using pisa::TermId; -using pisa::unique_with_counts; TEST_CASE("Construct from raw string") { @@ -138,28 +137,98 @@ TEST_CASE("Serialize query container to JSON") "thresholds": [{"k": 10, "score": 10.0}] } )"); - auto serialized = query.to_json(); + auto serialized = query.to_json_string(); REQUIRE( serialized == R"({"id":"ID","query":"brooklyn tea house","term_ids":[1,0,3],"terms":["brooklyn","tea","house"],"thresholds":[{"k":10,"score":10.0}]})"); } -TEST_CASE("Test dedup terms.") +TEST_CASE("Copy constructor and assignment") { - SECTION("Double in front") + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "terms": ["brooklyn", "tea", "house"], + "term_ids": [1, 0, 3], + "thresholds": [{"k": 10, "score": 10.0}] + } + )"); + { + QueryContainer copy(query); + REQUIRE(query.string() == copy.string()); + REQUIRE(*query.id() == copy.id()); + REQUIRE(*query.terms() == *copy.terms()); + REQUIRE(*query.term_ids() == *copy.term_ids()); + REQUIRE(query.thresholds() == copy.thresholds()); + } { - std::vector terms{0, 0, 1, 2, 2, 2, 3}; - auto counts = unique_with_counts(terms.begin(), terms.end()); - REQUIRE(counts == std::vector{2, 1, 3, 1}); - terms.resize(counts.size()); - REQUIRE(terms == std::vector{0, 1, 2, 3}); + auto copy = QueryContainer::raw(""); + copy = query; + REQUIRE(query.string() == copy.string()); + REQUIRE(*query.id() == copy.id()); + REQUIRE(*query.terms() == *copy.terms()); + REQUIRE(*query.term_ids() == *copy.term_ids()); + REQUIRE(query.thresholds() == copy.thresholds()); } - SECTION("Double at the end") +} + +TEST_CASE("Filter terms") +{ + SECTION("Both terms and IDs") { - std::vector terms{1, 2, 2, 2, 4, 4}; - auto counts = unique_with_counts(terms.begin(), terms.end()); - REQUIRE(counts == std::vector{1, 3, 2}); - terms.resize(counts.size()); - REQUIRE(terms == std::vector{1, 2, 4}); + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "terms": ["brooklyn", "tea", "house"], + "term_ids": [1, 0, 3], + "thresholds": [{"k": 10, "score": 10.0}] + } + )"); + SECTION("First") + { + query.filter_terms(std::vector{0}); + REQUIRE(*query.terms() == std::vector{"brooklyn"}); + REQUIRE(*query.term_ids() == std::vector{1}); + } + SECTION("Second") + { + query.filter_terms(std::vector{1}); + REQUIRE(*query.terms() == std::vector{"tea"}); + REQUIRE(*query.term_ids() == std::vector{0}); + } + SECTION("Third") + { + query.filter_terms(std::vector{2}); + REQUIRE(*query.terms() == std::vector{"house"}); + REQUIRE(*query.term_ids() == std::vector{3}); + } + } + SECTION("Only terms") + { + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "terms": ["brooklyn", "tea", "house"], + "thresholds": [{"k": 10, "score": 10.0}] + } + )"); + query.filter_terms(std::vector{1}); + REQUIRE(*query.terms() == std::vector{"tea"}); + } + SECTION("Only IDs") + { + auto query = QueryContainer::from_json(R"( + { + "id": "ID", + "query": "brooklyn tea house", + "term_ids": [1, 0, 3], + "thresholds": [{"k": 10, "score": 10.0}] + } + )"); + query.filter_terms(std::vector{1}); + REQUIRE(*query.term_ids() == std::vector{0}); } } diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index 9db4c6acb..be3e78f20 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -40,9 +40,9 @@ struct IndexData { builder.build(index); term_id_vec q; - std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries.jl"); auto push_query = [&](std::string const& query_line) { - queries.push_back(parse_query_ids(query_line)); + queries.push_back(QueryContainer::from_json(query_line)); }; io::for_each_line(qfile, push_query); @@ -63,7 +63,7 @@ struct IndexData { binary_freq_collection collection; binary_collection document_sizes; Index index; - std::vector queries; + std::vector queries; wand_data wdata; }; @@ -123,9 +123,9 @@ TEMPLATE_TEST_CASE( auto scorer = scorer::from_params(ScorerParams(s_name), data->wdata); for (auto const& q: data->queries) { - or_q(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); + or_q(make_scored_cursors(data->index, *scorer, q.query(10)), data->index.num_docs()); op_q( - make_block_max_scored_cursors(data->index, data->wdata, *scorer, q), + make_block_max_scored_cursors(data->index, data->wdata, *scorer, q.query(10)), data->index.num_docs()); topk_1.finalize(); topk_2.finalize(); @@ -158,9 +158,9 @@ TEMPLATE_TEST_CASE("Ranked AND query test", "[query][ranked][integration]", bloc auto scorer = scorer::from_params(ScorerParams(s_name), data->wdata); for (auto const& q: data->queries) { - and_q(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); + and_q(make_scored_cursors(data->index, *scorer, q.query(10)), data->index.num_docs()); op_q( - make_block_max_scored_cursors(data->index, data->wdata, *scorer, q), + make_block_max_scored_cursors(data->index, data->wdata, *scorer, q.query(10)), data->index.num_docs()); topk_1.finalize(); topk_2.finalize(); @@ -191,8 +191,8 @@ TEST_CASE("Top k") auto scorer = scorer::from_params(ScorerParams(s_name), data->wdata); for (auto const& q: data->queries) { - or_10(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); - or_1(make_scored_cursors(data->index, *scorer, q), data->index.num_docs()); + or_10(make_scored_cursors(data->index, *scorer, q.query(10)), data->index.num_docs()); + or_1(make_scored_cursors(data->index, *scorer, q.query(1)), data->index.num_docs()); topk_1.finalize(); topk_2.finalize(); if (not or_10.topk().empty()) { diff --git a/test/test_term_resolver.cpp b/test/test_term_resolver.cpp index 913060b85..f1d5ecd24 100644 --- a/test/test_term_resolver.cpp +++ b/test/test_term_resolver.cpp @@ -51,7 +51,7 @@ TEST_CASE("Filter queries") { std::ofstream json_out(json_input.c_str()); for (auto&& query: queries) { - json_out << query.to_json() << '\n'; + json_out << query.to_json_string() << '\n'; } } std::ostringstream output; diff --git a/test/test_tokenizer.cpp b/test/test_tokenizer.cpp index ca670909b..ba9560857 100644 --- a/test/test_tokenizer.cpp +++ b/test/test_tokenizer.cpp @@ -8,7 +8,9 @@ #include #include "payload_vector.hpp" -#include "query/queries.hpp" +#include "query.hpp" +#include "query/query_parser.hpp" +#include "query/term_resolver.hpp" #include "temporary_directory.hpp" #include "tokenizer.hpp" @@ -33,7 +35,7 @@ TEST_CASE("Parse query terms to ids") .to_file(lexfile.string()); auto [query, id, parsed] = - GENERATE(table, std::vector>( + GENERATE(table, std::vector>( {{"17:obama family tree", "17", {1, 3}}, {"obama family tree", std::nullopt, {1, 3}}, {"obama, family, trees", std::nullopt, {1, 3}}, @@ -41,8 +43,10 @@ TEST_CASE("Parse query terms to ids") {"lol's", std::nullopt, {0}}, {"U.S.A.!?", std::nullopt, {4}}})); CAPTURE(query); - TermProcessor term_processor(std::make_optional(lexfile.string()), std::nullopt, "krovetz"); - auto q = parse_query_terms(query, term_processor); - REQUIRE(q.id == id); - REQUIRE(q.terms == parsed); + QueryParser parser( + StandardTermResolver(lexfile.string(), std::nullopt, std::make_optional("krovetz"))); + auto query_container = QueryContainer::from_colon_format(query); + query_container.parse(parser); + REQUIRE(query_container.id() == id); + REQUIRE(*query_container.term_ids() == parsed); } diff --git a/tools/app.hpp b/tools/app.hpp index e5d83b018..3ad42b92a 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -14,6 +14,7 @@ #include "io.hpp" #include "query.hpp" #include "query/queries.hpp" +#include "query/query_parser.hpp" #include "query/term_resolver.hpp" #include "scorer/scorer.hpp" #include "sharding.hpp" @@ -91,7 +92,7 @@ namespace arg { return std::nullopt; } - [[nodiscard]] auto term_resolver() -> std::optional + [[nodiscard]] auto term_resolver() const -> std::optional { if (term_lexicon()) { return StandardTermResolver(*term_lexicon(), stop_words(), stemmer()); @@ -99,6 +100,22 @@ namespace arg { return std::nullopt; } + [[nodiscard]] auto resolved_queries() const -> std::vector<::pisa::QueryContainer> + { + auto term_resolver = this->term_resolver(); + std::vector<::pisa::QueryContainer> queries; + query_reader().for_each([&](auto query) { + if (not query.term_ids()) { + if (not term_resolver) { + throw MissingResolverError{}; + } + query.parse(QueryParser(*term_resolver)); + } + queries.push_back(std::move(query)); + }); + return queries; + } + [[nodiscard]] auto queries() const -> std::vector<::pisa::QueryContainer> { std::vector<::pisa::QueryContainer> queries; diff --git a/tools/compute_intersection.cpp b/tools/compute_intersection.cpp index 9d4338470..ce29ce116 100644 --- a/tools/compute_intersection.cpp +++ b/tools/compute_intersection.cpp @@ -6,6 +6,7 @@ #include "mappable/mapper.hpp" #include #include +#include #include #include #include @@ -27,7 +28,7 @@ void intersect( std::optional const& wand_data_filename, QueryRange&& queries, IntersectionType intersection_type, - std::optional max_term_count = std::nullopt) + std::optional max_term_count) { IndexType index; mio::mmap_source m(index_filename.c_str()); @@ -46,30 +47,31 @@ void intersect( mapper::map(wdata, md, mapper::map_flags::warmup); } - std::size_t qid = 0U; - - auto print_intersection = [&](auto const& query, auto const& mask) { - auto intersection = Intersection::compute(index, wdata, query, mask); - std::cout << fmt::format( - "{}\t{}\t{}\t{}\n", - query.id ? *query.id : std::to_string(qid), - mask.to_ulong(), - intersection.length, - intersection.max_score); - }; - for (auto&& query: queries) { if (intersection_type == IntersectionType::Combinations) { - for_all_subsets(query, max_term_count, print_intersection); + auto intersections = nlohmann::json::array(); + auto process_intersection = [&](auto const& query, auto const& mask) { + auto intersection = Intersection::compute(index, wdata, query, mask); + intersections.push_back(nlohmann::json{ + {"length", intersection.length}, + {"max_score", intersection.max_score}, + {"mask", mask.to_ulong()}}); + }; + for_all_subsets(query, max_term_count, process_intersection); + auto output = + nlohmann::json{{"query", query.to_json()}, {"intersections", intersections}}; + std::cout << output.dump() << '\n'; } else { auto intersection = Intersection::compute(index, wdata, query); - std::cout << fmt::format( - "{}\t{}\t{}\n", - query.id ? *query.id : std::to_string(qid), - intersection.length, - intersection.max_score); + auto query_json = query.to_json(); + auto intersection_json = nlohmann::json::object(); + intersection_json["length"] = intersection.length; + intersection_json["max_score"] = intersection.max_score; + intersection_json["mask"] = (1U << query.term_ids()->size()) - 1; + auto output = nlohmann::json{ + {"query", query_json}, {"intersections", nlohmann::json::array({intersection_json})}}; + std::cout << output.dump() << '\n'; } - qid += 1; } } @@ -80,11 +82,11 @@ int main(int argc, const char** argv) spdlog::drop(""); spdlog::set_default_logger(spdlog::stderr_color_mt("")); - std::optional max_term_count; + std::optional max_term_count; std::size_t min_query_len = 0; std::size_t max_query_len = std::numeric_limits::max(); bool combinations = false; - bool header = false; + // bool header = false; App> app{ "Computes intersections of posting lists."}; @@ -97,23 +99,14 @@ int main(int argc, const char** argv) ->needs(combinations_flag); app.add_option("--min-query-len", min_query_len, "Minimum query length"); app.add_option("--max-query-len", max_query_len, "Maximum query length"); - app.add_flag("--header", header, "Write TSV header"); CLI11_PARSE(app, argc, argv); - auto queries = app.queries(); + auto queries = app.resolved_queries(); auto filtered_queries = ranges::views::filter(queries, [&](auto&& query) { - auto size = query.terms.size(); - return size < min_query_len || size > max_query_len; + auto size = query.term_ids()->size(); + return size >= min_query_len || size <= max_query_len; }); - if (header) { - if (combinations) { - std::cout << "qid\tterm_mask\tlength\tmax_score\n"; - } else { - std::cout << "qid\tlength\tmax_score\n"; - } - } - IntersectionType intersection_type = combinations ? IntersectionType::Combinations : IntersectionType::Query; diff --git a/tools/evaluate_queries.cpp b/tools/evaluate_queries.cpp index c8c9e2a99..dd2d8f8ff 100644 --- a/tools/evaluate_queries.cpp +++ b/tools/evaluate_queries.cpp @@ -21,6 +21,7 @@ #include "cursor/scored_cursor.hpp" #include "index_types.hpp" #include "io.hpp" +#include "query.hpp" #include "query/algorithm.hpp" #include "scorer/scorer.hpp" #include "util/util.hpp" @@ -34,7 +35,7 @@ template void evaluate_queries( const std::string& index_filename, const std::optional& wand_data_filename, - const std::vector& queries, + const std::vector& queries, const std::optional& thresholds_filename, std::string const& type, std::string const& query_type, @@ -63,10 +64,10 @@ void evaluate_queries( mapper::map(wdata, md, mapper::map_flags::warmup); } - std::function>(Query)> query_fun; + std::function>(QueryRequest)> query_fun; if (query_type == "wand" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); wand_query wand_q(topk); wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); @@ -74,7 +75,7 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "block_max_wand" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); block_max_wand_query block_max_wand_q(topk); block_max_wand_q( @@ -83,7 +84,7 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "block_max_maxscore" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); block_max_maxscore_query block_max_maxscore_q(topk); block_max_maxscore_q( @@ -92,7 +93,7 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "block_max_ranked_and" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); block_max_ranked_and_query block_max_ranked_and_q(topk); block_max_ranked_and_q( @@ -101,7 +102,7 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "ranked_and" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); ranked_and_query ranked_and_q(topk); ranked_and_q(make_scored_cursors(index, *scorer, query), index.num_docs()); @@ -109,7 +110,7 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "ranked_or" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); ranked_or_query ranked_or_q(topk); ranked_or_q(make_scored_cursors(index, *scorer, query), index.num_docs()); @@ -117,7 +118,7 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "maxscore" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryRequest query) { topk_queue topk(k); maxscore_query maxscore_q(topk); maxscore_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); @@ -125,23 +126,25 @@ void evaluate_queries( return topk.topk(); }; } else if (query_type == "ranked_or_taat" && wand_data_filename) { - query_fun = [&, accumulator = Simple_Accumulator(index.num_docs())](Query query) mutable { - topk_queue topk(k); - ranked_or_taat_query ranked_or_taat_q(topk); - ranked_or_taat_q( - make_scored_cursors(index, *scorer, query), index.num_docs(), accumulator); - topk.finalize(); - return topk.topk(); - }; + query_fun = + [&, accumulator = Simple_Accumulator(index.num_docs())](QueryRequest query) mutable { + topk_queue topk(k); + ranked_or_taat_query ranked_or_taat_q(topk); + ranked_or_taat_q( + make_scored_cursors(index, *scorer, query), index.num_docs(), accumulator); + topk.finalize(); + return topk.topk(); + }; } else if (query_type == "ranked_or_taat_lazy" && wand_data_filename) { - query_fun = [&, accumulator = Lazy_Accumulator<4>(index.num_docs())](Query query) mutable { - topk_queue topk(k); - ranked_or_taat_query ranked_or_taat_q(topk); - ranked_or_taat_q( - make_scored_cursors(index, *scorer, query), index.num_docs(), accumulator); - topk.finalize(); - return topk.topk(); - }; + query_fun = + [&, accumulator = Lazy_Accumulator<4>(index.num_docs())](QueryRequest query) mutable { + topk_queue topk(k); + ranked_or_taat_query ranked_or_taat_q(topk); + ranked_or_taat_q( + make_scored_cursors(index, *scorer, query), index.num_docs(), accumulator); + topk.finalize(); + return topk.topk(); + }; } else { spdlog::error("Unsupported query type: {}", query_type); } @@ -152,13 +155,13 @@ void evaluate_queries( std::vector>> raw_results(queries.size()); auto start_batch = std::chrono::steady_clock::now(); tbb::parallel_for(size_t(0), queries.size(), [&, query_fun](size_t query_idx) { - raw_results[query_idx] = query_fun(queries[query_idx]); + raw_results[query_idx] = query_fun(queries[query_idx].query(k)); }); auto end_batch = std::chrono::steady_clock::now(); for (size_t query_idx = 0; query_idx < raw_results.size(); ++query_idx) { auto results = raw_results[query_idx]; - auto qid = queries[query_idx].id; + auto qid = queries[query_idx].id(); for (auto&& [rank, result]: enumerate(results)) { std::cout << fmt::format( "{}\t{}\t{}\t{}\t{}\t{}\n", @@ -208,10 +211,21 @@ int main(int argc, const char** argv) auto iteration = "Q0"; + std::vector queries; + try { + queries = app.resolved_queries(); + } catch (pisa::MissingResolverError err) { + spdlog::error("Unresoved queries (without IDs) require term lexicon."); + std::exit(1); + } catch (std::runtime_error const& err) { + spdlog::error(err.what()); + std::exit(1); + } + auto params = std::make_tuple( app.index_filename(), app.wand_data_path(), - app.queries(), + queries, app.thresholds_file(), app.index_encoding(), app.algorithm(), diff --git a/tools/map_queries.cpp b/tools/map_queries.cpp index 244b53ea5..4d70ab702 100644 --- a/tools/map_queries.cpp +++ b/tools/map_queries.cpp @@ -25,11 +25,12 @@ int main(int argc, const char** argv) using boost::adaptors::transformed; using boost::algorithm::join; - for (auto&& q: app.queries()) { - if (query_id and q.id) { - std::cout << *(q.id) << ":"; + for (auto&& q: app.resolved_queries()) { + if (query_id and q.id()) { + std::cout << *q.id() << ":"; } - std::cout << join(q.terms | transformed([](auto d) { return std::to_string(d); }), separator) - << '\n'; + std::cout + << join(*q.term_ids() | transformed([](auto d) { return std::to_string(d); }), separator) + << '\n'; } } diff --git a/tools/profile_queries.cpp b/tools/profile_queries.cpp index 64d17140c..7327fa093 100644 --- a/tools/profile_queries.cpp +++ b/tools/profile_queries.cpp @@ -65,7 +65,7 @@ void profile( const std::string index_filename, const std::optional& wand_data_filename, - std::vector const& queries, + std::vector const& queries, std::string const& type, std::string const& query_type) { @@ -98,44 +98,45 @@ void profile( for (auto const& t: query_types) { spdlog::info("Query type: {}", t); - std::function query_fun; + std::function query_fun; if (t == "and") { - query_fun = [&](Query query) { + query_fun = [&](QueryContainer query) { and_query and_q; return and_q( - make_cursors::type>(index, query), + make_cursors::type>( + index, query.query(query::unlimited)), index.num_docs()) .size(); }; } else if (t == "ranked_and" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryContainer query) { topk_queue topk(10); ranked_and_query ranked_and_q(topk); ranked_and_q( make_scored_cursors::type>( - index, *scorer, query), + index, *scorer, query.query(10)), index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "wand" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryContainer query) { topk_queue topk(10); wand_query wand_q(topk); wand_q( make_max_scored_cursors::type, WandType>( - index, wdata, *scorer, query), + index, wdata, *scorer, query.query(10)), index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "maxscore" && wand_data_filename) { - query_fun = [&](Query query) { + query_fun = [&](QueryContainer query) { topk_queue topk(10); maxscore_query maxscore_q(topk); maxscore_q( make_max_scored_cursors::type, WandType>( - index, wdata, *scorer, query), + index, wdata, *scorer, query.query(10)), index.num_docs()); topk.finalize(); return topk.topk().size(); @@ -163,7 +164,7 @@ int main(int argc, const char** argv) args++; } - std::vector queries; + std::vector queries; term_id_vec q; if (std::string(argv[args]) == "--file") { args++; @@ -171,13 +172,15 @@ int main(int argc, const char** argv) std::filebuf fb; if (fb.open(argv[args], std::ios::in) != nullptr) { std::istream is(&fb); - while (read_query(q, is)) { - queries.push_back({std::nullopt, q, {}}); + std::string query_line; + while (std::getline(is, query_line)) { + queries.push_back(QueryContainer::from_colon_format(query_line)); } } } else { - while (read_query(q)) { - queries.push_back({std::nullopt, q, {}}); + std::string query_line; + while (std::getline(std::cin, query_line)) { + queries.push_back(QueryContainer::from_colon_format(query_line)); } } diff --git a/tools/queries.cpp b/tools/queries.cpp index cd37f3e0d..ea37aba2f 100644 --- a/tools/queries.cpp +++ b/tools/queries.cpp @@ -35,7 +35,7 @@ using ranges::views::enumerate; template void extract_times( Fn fn, - std::vector const& queries, + std::vector const& queries, std::vector const& thresholds, std::string const& index_type, std::string const& query_type, @@ -51,14 +51,14 @@ void extract_times( .count(); }); auto mean = std::accumulate(times.begin(), times.end(), std::size_t{0}, std::plus<>()) / runs; - os << fmt::format("{}\t{}\n", query.id.value_or(std::to_string(qid)), mean); + os << fmt::format("{}\t{}\n", query.id().value_or(std::to_string(qid)), mean); } } template void op_perftest( Functor query_func, - std::vector const& queries, + std::vector const& queries, std::vector const& thresholds, std::string const& index_type, std::string const& query_type, @@ -118,7 +118,7 @@ template void perftest( const std::string& index_filename, const std::optional& wand_data_filename, - const std::vector& queries, + const std::vector& queries, const std::optional& thresholds_filename, std::string const& type, std::string const& query_type, @@ -134,8 +134,8 @@ void perftest( spdlog::info("Warming up posting lists"); std::unordered_set warmed_up; - for (auto const& q: queries) { - for (auto t: q.terms) { + for (auto&& query: queries) { + for (auto t: *query.term_ids()) { if (!warmed_up.count(t)) { index.warmup(t); warmed_up.insert(t); @@ -179,85 +179,90 @@ void perftest( for (auto&& t: query_types) { spdlog::info("Query type: {}", t); - std::function query_fun; + std::function query_fun; if (t == "and") { - query_fun = [&](Query query, Threshold) { + query_fun = [&](QueryContainer const& query, Threshold) { and_query and_q; - return and_q(make_cursors(index, query), index.num_docs()).size(); + return and_q(make_cursors(index, query.query(k)), index.num_docs()).size(); }; } else if (t == "or") { - query_fun = [&](Query query, Threshold) { + query_fun = [&](QueryContainer const& query, Threshold) { or_query or_q; - return or_q(make_cursors(index, query), index.num_docs()); + return or_q(make_cursors(index, query.query(k)), index.num_docs()); }; } else if (t == "or_freq") { - query_fun = [&](Query query, Threshold) { + query_fun = [&](QueryContainer const& query, Threshold) { or_query or_q; - return or_q(make_cursors(index, query), index.num_docs()); + return or_q(make_cursors(index, query.query(k)), index.num_docs()); }; } else if (t == "wand" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); wand_query wand_q(topk); - wand_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + wand_q( + make_max_scored_cursors(index, wdata, *scorer, query.query(k)), index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "block_max_wand" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); block_max_wand_query block_max_wand_q(topk); block_max_wand_q( - make_block_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + make_block_max_scored_cursors(index, wdata, *scorer, query.query(k)), + index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "block_max_maxscore" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); block_max_maxscore_query block_max_maxscore_q(topk); block_max_maxscore_q( - make_block_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + make_block_max_scored_cursors(index, wdata, *scorer, query.query(k)), + index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "ranked_and" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); ranked_and_query ranked_and_q(topk); - ranked_and_q(make_scored_cursors(index, *scorer, query), index.num_docs()); + ranked_and_q(make_scored_cursors(index, *scorer, query.query(k)), index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "block_max_ranked_and" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); block_max_ranked_and_query block_max_ranked_and_q(topk); block_max_ranked_and_q( - make_block_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + make_block_max_scored_cursors(index, wdata, *scorer, query.query(k)), + index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "ranked_or" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); ranked_or_query ranked_or_q(topk); - ranked_or_q(make_scored_cursors(index, *scorer, query), index.num_docs()); + ranked_or_q(make_scored_cursors(index, *scorer, query.query(k)), index.num_docs()); topk.finalize(); return topk.topk().size(); }; } else if (t == "maxscore" && wand_data_filename) { - query_fun = [&](Query query, Threshold t) { + query_fun = [&](QueryContainer const& query, Threshold t) { topk_queue topk(k); topk.set_threshold(t); maxscore_query maxscore_q(topk); - maxscore_q(make_max_scored_cursors(index, wdata, *scorer, query), index.num_docs()); + maxscore_q( + make_max_scored_cursors(index, wdata, *scorer, query.query(k)), index.num_docs()); topk.finalize(); return topk.topk().size(); }; @@ -265,10 +270,11 @@ void perftest( Simple_Accumulator accumulator(index.num_docs()); topk_queue topk(k); ranked_or_taat_query ranked_or_taat_q(topk); - query_fun = [&, ranked_or_taat_q, accumulator](Query query, Threshold t) mutable { + query_fun = [&, ranked_or_taat_q, accumulator]( + QueryContainer const& query, Threshold t) mutable { topk.set_threshold(t); ranked_or_taat_q( - make_scored_cursors(index, *scorer, query), index.num_docs(), accumulator); + make_scored_cursors(index, *scorer, query.query(k)), index.num_docs(), accumulator); topk.finalize(); return topk.topk().size(); }; @@ -276,10 +282,11 @@ void perftest( Lazy_Accumulator<4> accumulator(index.num_docs()); topk_queue topk(k); ranked_or_taat_query ranked_or_taat_q(topk); - query_fun = [&, ranked_or_taat_q, accumulator](Query query, Threshold t) mutable { + query_fun = [&, ranked_or_taat_q, accumulator]( + QueryContainer const& query, Threshold t) mutable { topk.set_threshold(t); ranked_or_taat_q( - make_scored_cursors(index, *scorer, query), index.num_docs(), accumulator); + make_scored_cursors(index, *scorer, query.query(k)), index.num_docs(), accumulator); topk.finalize(); return topk.topk().size(); }; @@ -324,10 +331,20 @@ int main(int argc, const char** argv) std::cout << "qid\tusec\n"; } + std::vector queries; + try { + queries = app.resolved_queries(); + } catch (pisa::MissingResolverError err) { + spdlog::error("Unresoved queries (without IDs) require term lexicon."); + std::exit(1); + } catch (std::runtime_error const& err) { + spdlog::error(err.what()); + std::exit(1); + } auto params = std::make_tuple( app.index_filename(), app.wand_data_path(), - app.queries(), + queries, app.thresholds_file(), app.index_encoding(), app.algorithm(), diff --git a/tools/selective_queries.cpp b/tools/selective_queries.cpp index 112f9e5a2..20a912bf4 100644 --- a/tools/selective_queries.cpp +++ b/tools/selective_queries.cpp @@ -16,7 +16,9 @@ using namespace pisa; template void selective_queries( - const std::string& index_filename, std::string const& encoding, std::vector const& queries) + const std::string& index_filename, + std::string const& encoding, + std::vector const& queries) { IndexType index; spdlog::info("Loading index from {}", index_filename); @@ -28,14 +30,17 @@ void selective_queries( using boost::adaptors::transformed; using boost::algorithm::join; for (auto const& query: queries) { - size_t and_results = and_query()(make_cursors(index, query), index.num_docs()).size(); - size_t or_results = or_query()(make_cursors(index, query), index.num_docs()); + size_t and_results = + and_query()(make_cursors(index, query.query(pisa::query::unlimited)), index.num_docs()) + .size(); + size_t or_results = or_query()( + make_cursors(index, query.query(pisa::query::unlimited)), index.num_docs()); double selectiveness = double(and_results) / double(or_results); if (selectiveness < 0.005) { - std::cout - << join(query.terms | transformed([](auto d) { return std::to_string(d); }), " ") - << '\n'; + std::cout << join( + *query.term_ids() | transformed([](auto d) { return std::to_string(d); }), " ") + << '\n'; } } } @@ -52,7 +57,7 @@ int main(int argc, const char** argv) else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ { \ selective_queries( \ - app.index_filename(), app.index_encoding(), app.queries()); + app.index_filename(), app.index_encoding(), app.resolved_queries()); /**/ BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); diff --git a/tools/thresholds.cpp b/tools/thresholds.cpp index fbf82bfa1..019220215 100644 --- a/tools/thresholds.cpp +++ b/tools/thresholds.cpp @@ -53,7 +53,7 @@ void thresholds( } topk_queue topk(k); wand_query wand_q(topk); - queries.for_each([](auto&& query) { + queries.for_each([&](auto&& query) { wand_q(make_max_scored_cursors(index, wdata, *scorer, query.query(k)), index.num_docs()); topk.finalize(); auto results = topk.topk(); @@ -97,19 +97,21 @@ int main(int argc, const char** argv) /**/ if (false) { -#define LOOP_BODY(R, DATA, T) \ - } \ - else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ - { \ - if (app.is_wand_compressed()) { \ - if (quantized) { \ - std::apply( \ - thresholds, params); \ - } else { \ - std::apply(thresholds, params); \ - } \ - } else { \ - std::apply(thresholds, params); \ +#define LOOP_BODY(R, DATA, T) \ + } \ + else if (app.index_encoding() == BOOST_PP_STRINGIZE(T)) \ + { \ + if (app.is_wand_compressed()) { \ + if (quantized) { \ + std::apply( \ + thresholds, \ + std::move(params)); \ + } else { \ + std::apply( \ + thresholds, std::move(params)); \ + } \ + } else { \ + std::apply(thresholds, std::move(params)); \ } /**/ BOOST_PP_SEQ_FOR_EACH(LOOP_BODY, _, PISA_INDEX_TYPES); From 3e4a58c701f2d45a9d8ea508703f1a6bbdf95ed8 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 6 May 2020 13:42:57 +0000 Subject: [PATCH 14/21] Add evaluate_queries CLI test --- test/cli/test_evaluate_queries.sh | 57 +++++++++++++++++++++++++++++++ 1 file changed, 57 insertions(+) create mode 100644 test/cli/test_evaluate_queries.sh diff --git a/test/cli/test_evaluate_queries.sh b/test/cli/test_evaluate_queries.sh new file mode 100644 index 000000000..7e39fc10c --- /dev/null +++ b/test/cli/test_evaluate_queries.sh @@ -0,0 +1,57 @@ +#!/usr/bin/env bats + +PISA_BIN="bin" +export PATH="$PISA_BIN:$PATH" + +. "$BATS_TEST_DIRNAME/common.sh" + +function setup { + write_lines "$BATS_TMPDIR/queries.txt" "brooklyn tea house" "the" 'Tell your dog I said "hi"' + write_lines "$BATS_TMPDIR/queries_with_ids.txt" "2:brooklyn tea house" "0:the" '1:Tell your dog I said "hi"' + jq '{"query": .}' "$BATS_TMPDIR/queries.txt" -Rc > "$BATS_TMPDIR/queries.jl" + jq '{"id": .|split(":")[0], "query": .|split(":")[1] }' "$BATS_TMPDIR/queries_with_ids.txt" -Rc \ + > "$BATS_TMPDIR/queries_with_ids.jl" +} + +@test "From echo - ID is 0" { + result=$(echo "brooklyn tea house" | ./bin/evaluate_queries -e block_simdbp -i ./simdbp --stemmer porter2 \ + -k 3 -a wand --scorer bm25 --documents ./fwd.doclex --terms ./fwd.termlex -w ./bm25.bmw \ + | cut -f1) + expected=$(printf "0\n0\n0") + [[ "$result" = "$expected" ]] +} + +@test "From plan text - consecutive IDs" { + result=$(cat $BATS_TMPDIR/queries.txt | ./bin/evaluate_queries -e block_simdbp -i ./simdbp --stemmer porter2 \ + -k 3 -a wand --scorer bm25 --documents ./fwd.doclex --terms ./fwd.termlex -w ./bm25.bmw \ + | cut -f1) + expected=$(printf "0\n0\n0\n1\n1\n1\n2\n2\n2") + [[ "$result" = "$expected" ]] +} + +@test "From plan text - predefined IDs" { + result=$(cat $BATS_TMPDIR/queries_with_ids.txt | ./bin/evaluate_queries -e block_simdbp \ + -i ./simdbp --stemmer porter2 -k 3 -a wand --scorer bm25 \ + --documents ./fwd.doclex --terms ./fwd.termlex -w ./bm25.bmw \ + | cut -f1) + expected=$(printf "2\n2\n2\n0\n0\n0\n1\n1\n1") + [[ "$result" = "$expected" ]] +} + +@test "From JSON without IDs" { + result=$(cat $BATS_TMPDIR/queries.jl | ./bin/evaluate_queries -e block_simdbp \ + -i ./simdbp --stemmer porter2 -k 3 -a wand --scorer bm25 \ + --documents ./fwd.doclex --terms ./fwd.termlex -w ./bm25.bmw \ + | cut -f1) + expected=$(printf "0\n0\n0\n1\n1\n1\n2\n2\n2") + [[ "$result" = "$expected" ]] +} + +@test "From JSON with IDs" { + result=$(cat $BATS_TMPDIR/queries_with_ids.jl | ./bin/evaluate_queries -e block_simdbp \ + -i ./simdbp --stemmer porter2 -k 3 -a wand --scorer bm25 \ + --documents ./fwd.doclex --terms ./fwd.termlex -w ./bm25.bmw \ + | cut -f1) + expected=$(printf "2\n2\n2\n0\n0\n0\n1\n1\n1") + [[ "$result" = "$expected" ]] +} From 5a8fa5e0cb8ee583c419227dcac8a18fca157b54 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Wed, 6 May 2020 16:34:06 +0000 Subject: [PATCH 15/21] Add missing include --- test/test_bmw_queries.cpp | 1 + test/test_invert.cpp | 62 +++++++++++++++++++----------------- test/test_ranked_queries.cpp | 1 + tools/app.hpp | 1 + 4 files changed, 36 insertions(+), 29 deletions(-) diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index 4b540a9bf..e32b70d34 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -8,6 +8,7 @@ #include "cursor/block_max_scored_cursor.hpp" #include "cursor/max_scored_cursor.hpp" #include "index_types.hpp" +#include "io.hpp" #include "pisa_config.hpp" #include "query/algorithm.hpp" #include "wand_data.hpp" diff --git a/test/test_invert.cpp b/test/test_invert.cpp index 458ae20f6..9ac65fb25 100644 --- a/test/test_invert.cpp +++ b/test/test_invert.cpp @@ -11,6 +11,7 @@ #include "binary_collection.hpp" #include "filesystem.hpp" #include "invert.hpp" +#include "io.hpp" #include "pisa_config.hpp" #include "temporary_directory.hpp" @@ -61,15 +62,16 @@ TEST_CASE("Join term from one index to the same term from another", "[invert][un TEST_CASE("Accumulate postings to Inverted_Index", "[invert][unit]") { - std::vector> postings = {{0_t, 0_d}, - {0_t, 1_d}, - {0_t, 2_d}, - {1_t, 0_d}, - {1_t, 0_d}, - {1_t, 0_d}, - {1_t, 0_d}, - {1_t, 1_d}, - {2_t, 5_d}}; + std::vector> postings = { + {0_t, 0_d}, + {0_t, 1_d}, + {0_t, 2_d}, + {1_t, 0_d}, + {1_t, 0_d}, + {1_t, 0_d}, + {1_t, 0_d}, + {1_t, 1_d}, + {2_t, 5_d}}; using iterator_type = decltype(postings.begin()); invert::Inverted_Index index; index(tbb::blocked_range(postings.begin(), postings.end())); @@ -98,28 +100,30 @@ TEST_CASE("Accumulate postings to Inverted_Index one by one", "[invert][unit]") } REQUIRE( index.documents - == std::unordered_map>{{0_t, {0_d, 1_d, 4_d}}, - {1_t, {2_d, 4_d}}, - {2_t, {0_d, 1_d}}, - {3_t, {0_d, 1_d, 4_d}}, - {4_t, {1_d, 4_d}}, - {5_t, {1_d, 2_d, 3_d, 4_d}}, - {6_t, {1_d, 4_d}}, - {7_t, {1_d}}, - {8_t, {2_d, 3_d, 4_d}}, - {9_t, {0_d, 2_d, 3_d, 4_d}}}); + == std::unordered_map>{ + {0_t, {0_d, 1_d, 4_d}}, + {1_t, {2_d, 4_d}}, + {2_t, {0_d, 1_d}}, + {3_t, {0_d, 1_d, 4_d}}, + {4_t, {1_d, 4_d}}, + {5_t, {1_d, 2_d, 3_d, 4_d}}, + {6_t, {1_d, 4_d}}, + {7_t, {1_d}}, + {8_t, {2_d, 3_d, 4_d}}, + {9_t, {0_d, 2_d, 3_d, 4_d}}}); REQUIRE( index.frequencies - == std::unordered_map>{{0_t, {2_f, 1_f, 1_f}}, - {1_t, {1_f, 1_f}}, - {2_t, {1_f, 1_f}}, - {3_t, {1_f, 1_f, 1_f}}, - {4_t, {2_f, 1_f}}, - {5_t, {2_f, 1_f, 1_f, 1_f}}, - {6_t, {1_f, 4_f}}, - {7_t, {1_f}}, - {8_t, {3_f, 1_f, 1_f}}, - {9_t, {1_f, 1_f, 1_f, 1_f}}}); + == std::unordered_map>{ + {0_t, {2_f, 1_f, 1_f}}, + {1_t, {1_f, 1_f}}, + {2_t, {1_f, 1_f}}, + {3_t, {1_f, 1_f, 1_f}}, + {4_t, {2_f, 1_f}}, + {5_t, {2_f, 1_f, 1_f, 1_f}}, + {6_t, {1_f, 4_f}}, + {7_t, {1_f}}, + {8_t, {3_f, 1_f, 1_f}}, + {9_t, {1_f, 1_f, 1_f, 1_f}}}); } TEST_CASE("Join Inverted_Index to another", "[invert][unit]") diff --git a/test/test_ranked_queries.cpp b/test/test_ranked_queries.cpp index be3e78f20..802b35769 100644 --- a/test/test_ranked_queries.cpp +++ b/test/test_ranked_queries.cpp @@ -8,6 +8,7 @@ #include "cursor/max_scored_cursor.hpp" #include "cursor/scored_cursor.hpp" #include "index_types.hpp" +#include "io.hpp" #include "pisa_config.hpp" #include "query/algorithm.hpp" #include "test_common.hpp" diff --git a/tools/app.hpp b/tools/app.hpp index 3ad42b92a..3854f093f 100644 --- a/tools/app.hpp +++ b/tools/app.hpp @@ -5,6 +5,7 @@ #include #include #include +#include #include #include From e1c000048b8d79c4dd6dab04e127254022c259cc Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 5 Jun 2020 16:25:19 +0000 Subject: [PATCH 16/21] Fix merge issues --- include/pisa/cursor/block_max_scored_cursor.hpp | 2 +- include/pisa/cursor/max_scored_cursor.hpp | 2 +- include/pisa/cursor/scored_cursor.hpp | 2 +- src/query/term_resolver.cpp | 3 ++- test/test_bmw_queries.cpp | 2 +- 5 files changed, 6 insertions(+), 5 deletions(-) diff --git a/include/pisa/cursor/block_max_scored_cursor.hpp b/include/pisa/cursor/block_max_scored_cursor.hpp index dbb7a3f8b..e54c11d51 100644 --- a/include/pisa/cursor/block_max_scored_cursor.hpp +++ b/include/pisa/cursor/block_max_scored_cursor.hpp @@ -57,7 +57,7 @@ template [&](auto term_id, auto weight) { auto max_weight = weight * wdata.max_term_weight(term_id); return cursor_type{ - index[term_id], wdata.getenum(term_id), weight, scorer.term_scorer(term_id), max_weight}; + index[term_id], scorer.term_scorer(term_id), weight, max_weight, wdata.getenum(term_id)}; }); return cursors; } diff --git a/include/pisa/cursor/max_scored_cursor.hpp b/include/pisa/cursor/max_scored_cursor.hpp index 1fbbf13cb..3e18b6725 100644 --- a/include/pisa/cursor/max_scored_cursor.hpp +++ b/include/pisa/cursor/max_scored_cursor.hpp @@ -47,7 +47,7 @@ template std::back_inserter(cursors), [&](auto term_id, auto weight) { auto max_weight = weight * wdata.max_term_weight(term_id); - return cursor_type{index[term_id], weight, scorer.term_scorer(term_id), max_weight}; + return cursor_type{index[term_id], scorer.term_scorer(term_id), weight, max_weight}; }); return cursors; } diff --git a/include/pisa/cursor/scored_cursor.hpp b/include/pisa/cursor/scored_cursor.hpp index 5fbfc2ddd..d133c8116 100644 --- a/include/pisa/cursor/scored_cursor.hpp +++ b/include/pisa/cursor/scored_cursor.hpp @@ -58,7 +58,7 @@ template term_weights.begin(), std::back_inserter(cursors), [&](auto term_id, auto weight) { - return cursor_type{index[term_id], weight, scorer.term_scorer(term_id)}; + return cursor_type{index[term_id], scorer.term_scorer(term_id), weight}; }); return cursors; } diff --git a/src/query/term_resolver.cpp b/src/query/term_resolver.cpp index be6dc8239..8b8e78ec4 100644 --- a/src/query/term_resolver.cpp +++ b/src/query/term_resolver.cpp @@ -1,5 +1,6 @@ #include "query/term_resolver.hpp" #include "query/query_parser.hpp" +#include "query/query_stemmer.hpp" #include "query/term_processor.hpp" namespace pisa { @@ -39,7 +40,7 @@ StandardTermResolver::StandardTermResolver( return std::nullopt; }; - m_self->transform = pisa::term_processor(stemmer_type); + m_self->transform = pisa::QueryStemmer(stemmer_type); if (stopwords_filename) { std::ifstream is(*stopwords_filename); diff --git a/test/test_bmw_queries.cpp b/test/test_bmw_queries.cpp index 74a2272e6..74f915629 100644 --- a/test/test_bmw_queries.cpp +++ b/test/test_bmw_queries.cpp @@ -46,7 +46,7 @@ struct IndexData { } builder.build(index); term_id_vec q; - std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries"); + std::ifstream qfile(PISA_SOURCE_DIR "/test/test_data/queries.jl"); auto push_query = [&](std::string const& query_line) { queries.push_back(QueryContainer::from_json(query_line)); }; From 65ae025bab5c78086eabf94398766bbf5586fc68 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 5 Jun 2020 16:29:39 +0000 Subject: [PATCH 17/21] Fix formatting --- include/pisa/cursor/block_max_scored_cursor.hpp | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/include/pisa/cursor/block_max_scored_cursor.hpp b/include/pisa/cursor/block_max_scored_cursor.hpp index e54c11d51..61ba070aa 100644 --- a/include/pisa/cursor/block_max_scored_cursor.hpp +++ b/include/pisa/cursor/block_max_scored_cursor.hpp @@ -56,8 +56,11 @@ template std::back_inserter(cursors), [&](auto term_id, auto weight) { auto max_weight = weight * wdata.max_term_weight(term_id); - return cursor_type{ - index[term_id], scorer.term_scorer(term_id), weight, max_weight, wdata.getenum(term_id)}; + return cursor_type{index[term_id], + scorer.term_scorer(term_id), + weight, + max_weight, + wdata.getenum(term_id)}; }); return cursors; } From a9a235900d160260a16c39caa092b127577abc5b Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 5 Jun 2020 16:33:59 +0000 Subject: [PATCH 18/21] Fix formatting --- test/test_intersection.cpp | 63 ++++++++++++++++++-------------------- test/test_invert.cpp | 61 ++++++++++++++++++------------------ 2 files changed, 59 insertions(+), 65 deletions(-) diff --git a/test/test_intersection.cpp b/test/test_intersection.cpp index 86917a6f4..7c482a59a 100644 --- a/test/test_intersection.cpp +++ b/test/test_intersection.cpp @@ -103,11 +103,10 @@ struct InMemoryIndex { throw std::out_of_range( fmt::format("Term {} is out of range; index contains {} terms", term_id, size())); } - return { - gsl::make_span(documents[term_id]), - gsl::make_span(frequencies[term_id]), - num_documents, - {num_documents}}; + return {gsl::make_span(documents[term_id]), + gsl::make_span(frequencies[term_id]), + num_documents, + {num_documents}}; } [[nodiscard]] auto size() const noexcept -> std::size_t { return documents.size(); } @@ -192,26 +191,25 @@ TEST_CASE("compute intersection", "[intersection][unit]") { GIVEN("Four-term query, index, and wand data object") { - InMemoryIndex index{ - { - {0}, // 0 - {0, 1, 2}, // 1 - {0}, // 2 - {0}, // 3 - {0}, // 4 - {0, 1, 4}, // 5 - {1, 4, 8}, // 6 - }, - { - {1}, // 0 - {1, 1, 1}, // 1 - {1}, // 2 - {1}, // 3 - {1}, // 4 - {1, 1, 1}, // 5 - {1, 1, 1}, // 6 - }, - 10}; + InMemoryIndex index{{ + {0}, // 0 + {0, 1, 2}, // 1 + {0}, // 2 + {0}, // 3 + {0}, // 4 + {0, 1, 4}, // 5 + {1, 4, 8}, // 6 + }, + { + {1}, // 0 + {1, 1, 1}, // 1 + {1}, // 2 + {1}, // 3 + {1}, // 4 + {1, 1, 1}, // 5 + {1, 1, 1}, // 6 + }, + 10}; InMemoryWand wand{{0.0, 1.0, 0.0, 0.0, 0.0, 5.0, 6.0}, 10}; auto query = QueryContainer::from_term_ids({6, 1, 5}); @@ -271,14 +269,13 @@ TEST_CASE("for_all_subsets", "[intersection][unit]") { CHECK( masks - == std::vector{ - Mask(0b001), - Mask(0b010), - Mask(0b011), - Mask(0b100), - Mask(0b101), - Mask(0b110), - Mask(0b111)}); + == std::vector{Mask(0b001), + Mask(0b010), + Mask(0b011), + Mask(0b100), + Mask(0b101), + Mask(0b110), + Mask(0b111)}); } } } diff --git a/test/test_invert.cpp b/test/test_invert.cpp index 9ac65fb25..94f5d1f3d 100644 --- a/test/test_invert.cpp +++ b/test/test_invert.cpp @@ -62,16 +62,15 @@ TEST_CASE("Join term from one index to the same term from another", "[invert][un TEST_CASE("Accumulate postings to Inverted_Index", "[invert][unit]") { - std::vector> postings = { - {0_t, 0_d}, - {0_t, 1_d}, - {0_t, 2_d}, - {1_t, 0_d}, - {1_t, 0_d}, - {1_t, 0_d}, - {1_t, 0_d}, - {1_t, 1_d}, - {2_t, 5_d}}; + std::vector> postings = {{0_t, 0_d}, + {0_t, 1_d}, + {0_t, 2_d}, + {1_t, 0_d}, + {1_t, 0_d}, + {1_t, 0_d}, + {1_t, 0_d}, + {1_t, 1_d}, + {2_t, 5_d}}; using iterator_type = decltype(postings.begin()); invert::Inverted_Index index; index(tbb::blocked_range(postings.begin(), postings.end())); @@ -100,30 +99,28 @@ TEST_CASE("Accumulate postings to Inverted_Index one by one", "[invert][unit]") } REQUIRE( index.documents - == std::unordered_map>{ - {0_t, {0_d, 1_d, 4_d}}, - {1_t, {2_d, 4_d}}, - {2_t, {0_d, 1_d}}, - {3_t, {0_d, 1_d, 4_d}}, - {4_t, {1_d, 4_d}}, - {5_t, {1_d, 2_d, 3_d, 4_d}}, - {6_t, {1_d, 4_d}}, - {7_t, {1_d}}, - {8_t, {2_d, 3_d, 4_d}}, - {9_t, {0_d, 2_d, 3_d, 4_d}}}); + == std::unordered_map>{{0_t, {0_d, 1_d, 4_d}}, + {1_t, {2_d, 4_d}}, + {2_t, {0_d, 1_d}}, + {3_t, {0_d, 1_d, 4_d}}, + {4_t, {1_d, 4_d}}, + {5_t, {1_d, 2_d, 3_d, 4_d}}, + {6_t, {1_d, 4_d}}, + {7_t, {1_d}}, + {8_t, {2_d, 3_d, 4_d}}, + {9_t, {0_d, 2_d, 3_d, 4_d}}}); REQUIRE( index.frequencies - == std::unordered_map>{ - {0_t, {2_f, 1_f, 1_f}}, - {1_t, {1_f, 1_f}}, - {2_t, {1_f, 1_f}}, - {3_t, {1_f, 1_f, 1_f}}, - {4_t, {2_f, 1_f}}, - {5_t, {2_f, 1_f, 1_f, 1_f}}, - {6_t, {1_f, 4_f}}, - {7_t, {1_f}}, - {8_t, {3_f, 1_f, 1_f}}, - {9_t, {1_f, 1_f, 1_f, 1_f}}}); + == std::unordered_map>{{0_t, {2_f, 1_f, 1_f}}, + {1_t, {1_f, 1_f}}, + {2_t, {1_f, 1_f}}, + {3_t, {1_f, 1_f, 1_f}}, + {4_t, {2_f, 1_f}}, + {5_t, {2_f, 1_f, 1_f, 1_f}}, + {6_t, {1_f, 4_f}}, + {7_t, {1_f}}, + {8_t, {3_f, 1_f, 1_f}}, + {9_t, {1_f, 1_f, 1_f, 1_f}}}); } TEST_CASE("Join Inverted_Index to another", "[invert][unit]") From 4189626be90e99eb90bc58eb770a983211abee72 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 5 Jun 2020 16:34:31 +0000 Subject: [PATCH 19/21] Fix formatting --- tools/compute_intersection.cpp | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/tools/compute_intersection.cpp b/tools/compute_intersection.cpp index 78a758784..79b30ea59 100644 --- a/tools/compute_intersection.cpp +++ b/tools/compute_intersection.cpp @@ -52,10 +52,9 @@ void intersect( auto intersections = nlohmann::json::array(); auto process_intersection = [&](auto const& query, auto const& mask) { auto intersection = Intersection::compute(index, wdata, query, mask); - intersections.push_back(nlohmann::json{ - {"length", intersection.length}, - {"max_score", intersection.max_score}, - {"mask", mask.to_ulong()}}); + intersections.push_back(nlohmann::json{{"length", intersection.length}, + {"max_score", intersection.max_score}, + {"mask", mask.to_ulong()}}); }; for_all_subsets(query, max_term_count, process_intersection); auto output = From c35c5a04c26ed9014037bbd81db8b4bf0d3e2ec9 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Fri, 5 Jun 2020 16:34:56 +0000 Subject: [PATCH 20/21] Fix formatting --- include/pisa/binary_collection.hpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/include/pisa/binary_collection.hpp b/include/pisa/binary_collection.hpp index 98867d3b1..b78452234 100644 --- a/include/pisa/binary_collection.hpp +++ b/include/pisa/binary_collection.hpp @@ -92,7 +92,7 @@ class base_binary_collection { auto const& operator*() const { return m_cur_seq; } - auto const* operator->() const { return &m_cur_seq; } + auto const* operator-> () const { return &m_cur_seq; } base_iterator& operator++() { From d83abe7932ed8c9d41b19cbe17879b173b003574 Mon Sep 17 00:00:00 2001 From: Michal Siedlaczek Date: Sat, 6 Jun 2020 01:26:22 +0000 Subject: [PATCH 21/21] Add missing header --- include/pisa/query/algorithm/or_query.hpp | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/include/pisa/query/algorithm/or_query.hpp b/include/pisa/query/algorithm/or_query.hpp index 75e9be4ce..647cb34da 100644 --- a/include/pisa/query/algorithm/or_query.hpp +++ b/include/pisa/query/algorithm/or_query.hpp @@ -1,8 +1,10 @@ #pragma once -#include "query/queries.hpp" +#include #include +#include "query/queries.hpp" + namespace pisa { template