From c108b1281b3e77ce23bb06dae0151149e4d97611 Mon Sep 17 00:00:00 2001
From: Stepan Bagritsevich
Date: Fri, 29 Nov 2024 19:25:16 +0400
Subject: [PATCH] fix(search_family): Remove the output of extra fields in the
 FT.AGGREGATE command

fixes dragonflydb#4230

Signed-off-by: Stepan Bagritsevich
---
 src/server/search/aggregator.cc         | 47 +++++++++++++++-------
 src/server/search/aggregator.h          | 17 ++++++--
 src/server/search/aggregator_test.cc    | 52 ++++++++++++-------------
 src/server/search/search_family.cc      | 31 ++++++++++-----
 src/server/search/search_family_test.cc | 41 +++++++++++++++----
 5 files changed, 125 insertions(+), 63 deletions(-)

diff --git a/src/server/search/aggregator.cc b/src/server/search/aggregator.cc
index 255d82e10857..4b6b4a5620cf 100644
--- a/src/server/search/aggregator.cc
+++ b/src/server/search/aggregator.cc
@@ -11,10 +11,10 @@ namespace dfly::aggregate {
 namespace {
 
 struct GroupStep {
-  PipelineResult operator()(std::vector<DocValues> values) {
+  PipelineResult operator()(PipelineResult result) {
     // Separate items into groups
     absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
-    for (auto& value : values) {
+    for (auto& value : result.values) {
       groups[Extract(value)].push_back(std::move(value));
     }
 
@@ -28,7 +28,18 @@ struct GroupStep {
       }
       out.push_back(std::move(doc));
     }
-    return out;
+
+    absl::flat_hash_set<std::string> fields_to_print;
+    fields_to_print.reserve(fields_.size() + reducers_.size());
+
+    for (auto& field : fields_) {
+      fields_to_print.insert(std::move(field));
+    }
+    for (auto& reducer : reducers_) {
+      fields_to_print.insert(std::move(reducer.result_field));
+    }
+
+    return {std::move(out), std::move(fields_to_print)};
   }
 
   absl::FixedArray<Value> Extract(const DocValues& dv) {
@@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
 }
 
 PipelineStep MakeSortStep(std::string_view field, bool descending) {
-  return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
+  return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
+    auto& values = result.values;
+
     std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
       auto it1 = l.find(field);
      auto it2 = r.find(field);
       return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
     });
-    if (descending)
+
+    if (descending) {
      std::reverse(values.begin(), values.end());
-    return values;
+    }
+
+    result.fields_to_print.insert(field);
+    return result;
   };
 }
 
 PipelineStep MakeLimitStep(size_t offset, size_t num) {
-  return [offset, num](std::vector<DocValues> values) -> PipelineResult {
+  return [offset, num](PipelineResult result) {
+    auto& values = result.values;
     values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
     values.resize(std::min(num, values.size()));
-    return values;
+    return result;
   };
 }
 
-PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
+PipelineResult Process(std::vector<DocValues> values,
+                       absl::Span<const std::string_view> fields_to_print,
+                       absl::Span<const PipelineStep> steps) {
+  PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
   for (auto& step : steps) {
-    auto result = step(std::move(values));
-    if (!result.has_value())
-      return result;
-    values = std::move(result.value());
+    PipelineResult step_result = step(std::move(result));
+    result = std::move(step_result);
   }
-  return values;
+  return result;
 }
 
 }  // namespace dfly::aggregate
diff --git a/src/server/search/aggregator.h b/src/server/search/aggregator.h
index 727c0ba96ed0..4f4008bce238 100644
--- a/src/server/search/aggregator.h
+++ b/src/server/search/aggregator.h
@@ -5,6 +5,7 @@
 #pragma once
 
 #include <absl/container/flat_hash_map.h>
+#include <absl/container/flat_hash_set.h>
 
 #include <functional>
 #include <string>
@@ -19,10 +20,16 @@ namespace dfly::aggregate {
 using Value = ::dfly::search::SortableValue;
 using DocValues = absl::flat_hash_map<std::string, Value>;  // documents sent through the pipeline
 
-// TODO: Replace DocValues with compact linear search map instead of hash map
+struct PipelineResult {
+  // Values to be passed to the next step
+  // TODO: Replace DocValues with compact linear search map instead of hash map
+  std::vector<DocValues> values;
 
-using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
-using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>;  // Group, Sort, etc.
+  // Fields from values to be printed
+  absl::flat_hash_set<std::string> fields_to_print;
+};
+
+using PipelineStep = std::function<PipelineResult(PipelineResult)>;  // Group, Sort, etc.
 
 // Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
 // Extra clumsy for STL compatibility!
@@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
 PipelineStep MakeLimitStep(size_t offset, size_t num);
 
 // Process values with given steps
-PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
+PipelineResult Process(std::vector<DocValues> values,
+                       absl::Span<const std::string_view> fields_to_print,
+                       absl::Span<const PipelineStep> steps);
 
 }  // namespace dfly::aggregate
diff --git a/src/server/search/aggregator_test.cc b/src/server/search/aggregator_test.cc
index d7dcc8d6061c..3ee8b58e1f5a 100644
--- a/src/server/search/aggregator_test.cc
+++ b/src/server/search/aggregator_test.cc
@@ -18,12 +18,11 @@ TEST(AggregatorTest, Sort) {
   };
   PipelineStep steps[] = {MakeSortStep("a", false)};
 
-  auto result = Process(values, steps);
+  auto result = Process(values, {"a"}, steps);
 
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->at(0)["a"], Value(0.5));
-  EXPECT_EQ(result->at(1)["a"], Value(1.0));
-  EXPECT_EQ(result->at(2)["a"], Value(1.5));
+  EXPECT_EQ(result.values[0]["a"], Value(0.5));
+  EXPECT_EQ(result.values[1]["a"], Value(1.0));
+  EXPECT_EQ(result.values[2]["a"], Value(1.5));
 }
 
 TEST(AggregatorTest, Limit) {
@@ -35,12 +34,11 @@ TEST(AggregatorTest, Limit) {
   };
   PipelineStep steps[] = {MakeLimitStep(1, 2)};
 
-  auto result = Process(values, steps);
+  auto result = Process(values, {"i"}, steps);
 
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->size(), 2);
-  EXPECT_EQ(result->at(0)["i"], Value(2.0));
-  EXPECT_EQ(result->at(1)["i"], Value(3.0));
+  EXPECT_EQ(result.values.size(), 2);
+  EXPECT_EQ(result.values[0]["i"], Value(2.0));
+  EXPECT_EQ(result.values[1]["i"], Value(3.0));
 }
 
 TEST(AggregatorTest, SimpleGroup) {
@@ -54,12 +52,11 @@ TEST(AggregatorTest, SimpleGroup) {
   std::string_view fields[] = {"tag"};
   PipelineStep steps[] = {MakeGroupStep(fields, {})};
 
-  auto result = Process(values, steps);
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->size(), 2);
+  auto result = Process(values, {"i", "tag"}, steps);
+  EXPECT_EQ(result.values.size(), 2);
 
-  EXPECT_EQ(result->at(0).size(), 1);
-  std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
+  EXPECT_EQ(result.values[0].size(), 1);
+  std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
   std::set<Value> expected{"even", "odd"};
   EXPECT_EQ(groups, expected);
 }
@@ -83,25 +80,24 @@ TEST(AggregatorTest, GroupWithReduce) {
       Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
   PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};
 
-  auto result = Process(values, steps);
-  EXPECT_TRUE(result);
-  EXPECT_EQ(result->size(), 2);
+  auto result = Process(values, {"i", "half-i", "tag"}, steps);
+  EXPECT_EQ(result.values.size(), 2);
 
   // Reorder even first
-  if (result->at(0).at("tag") == Value("odd"))
-    std::swap(result->at(0), result->at(1));
+  if (result.values[0].at("tag") == Value("odd"))
+    std::swap(result.values[0], result.values[1]);
 
   // Even
-  EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
-  EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
-  EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
-  EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});
+  EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
+  EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
+  EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
+  EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});
 
   // Odd
-  EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
-  EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
-  EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
-  EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
+  EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
+  EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
+  EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
+  EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
 }
 
 }  // namespace dfly::aggregate
diff --git a/src/server/search/search_family.cc b/src/server/search/search_family.cc
index f1151dc60a67..1b37d1943de2 100644
--- a/src/server/search/search_family.cc
+++ b/src/server/search/search_family.cc
@@ -980,22 +980,35 @@ void SearchFamily::FtAggregate(CmdArgList args, Transaction* tx, SinkReplyBuilde
                   make_move_iterator(sub_results.end()));
   }
 
-  auto agg_results = aggregate::Process(std::move(values), params->steps);
-  if (!agg_results.has_value())
-    return builder->SendError(agg_results.error());
+  std::vector<std::string_view> load_fields;
+  if (params->load_fields) {
+    load_fields.reserve(params->load_fields->size());
+    for (const auto& field : params->load_fields.value()) {
+      load_fields.push_back(field.GetShortName());
+    }
+  }
+
+  auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);
 
-  size_t result_size = agg_results->size();
   auto* rb = static_cast<RedisReplyBuilder*>(builder);
   auto sortable_value_sender = SortableValueSender(rb);
 
+  const size_t result_size = agg_results.values.size();
   rb->StartArray(result_size + 1);
   rb->SendLong(result_size);
 
-  for (const auto& result : agg_results.value()) {
-    rb->StartArray(result.size() * 2);
-    for (const auto& [k, v] : result) {
-      rb->SendBulkString(k);
-      std::visit(sortable_value_sender, v);
+  const size_t field_count = agg_results.fields_to_print.size();
+  for (const auto& value : agg_results.values) {
+    rb->StartArray(field_count * 2);
+    for (const auto& field : agg_results.fields_to_print) {
+      rb->SendBulkString(field);
+
+      auto it = value.find(field);
+      if (it != value.end()) {
+        std::visit(sortable_value_sender, it->second);
+      } else {
+        rb->SendNull();
+      }
     }
   }
 }
diff --git a/src/server/search/search_family_test.cc b/src/server/search/search_family_test.cc
index 9fa68bd66757..fe89c412f27a 100644
--- a/src/server/search/search_family_test.cc
+++ b/src/server/search/search_family_test.cc
@@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
   EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
                                          IsMap("foo_total", "50", "word", "item1")));
 
-  /*
-  Temporary not supported
   resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
-  "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
-  EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second - key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""), - IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"}))); - */ + "@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); + EXPECT_THAT(resp, IsUnordArrayWithSize( + IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""), + IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""), + IsMap("foo_total", "10", "word", "item1", "text", "\"first key\""))); } TEST_F(SearchFamilyTest, JsonAggregateGroupBy) { @@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) { EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one"))); } +// Test that FT.AGGREGATE prints only needed fields +TEST_F(SearchFamilyTest, AggregateResultFields) { + Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"}); + Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"}); + Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"}); + + auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT", + "SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"}); + EXPECT_EQ(resp, "OK"); + + resp = Run({"FT.AGGREGATE", "index", "*"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap())); + + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7"))); + + resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"}); + EXPECT_THAT(resp, + IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"), + IsMap("b", "\"8\"", "a", "7"))); + + resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a", + "REDUCE", "COUNT", "0", "AS", "count"}); + EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"), + IsMap("b", "\"2\"", "a", "1", "count", "1"), + IsMap("b", "\"5\"", "a", "4", "count", "1"))); +} + } // namespace dfly