Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix(search_family): Remove the output of extra fields in the FT.AGGREGATE command #4231

Open
wants to merge 1 commit into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
47 changes: 33 additions & 14 deletions src/server/search/aggregator.cc
Original file line number Diff line number Diff line change
Expand Up @@ -11,10 +11,10 @@ namespace dfly::aggregate {
namespace {

struct GroupStep {
PipelineResult operator()(std::vector<DocValues> values) {
PipelineResult operator()(PipelineResult result) {
// Separate items into groups
absl::flat_hash_map<absl::FixedArray<Value>, std::vector<DocValues>> groups;
for (auto& value : values) {
for (auto& value : result.values) {
groups[Extract(value)].push_back(std::move(value));
}

Expand All @@ -28,7 +28,18 @@ struct GroupStep {
}
out.push_back(std::move(doc));
}
return out;

absl::flat_hash_set<std::string> fields_to_print;
fields_to_print.reserve(fields_.size() + reducers_.size());

for (auto& field : fields_) {
fields_to_print.insert(std::move(field));
}
for (auto& reducer : reducers_) {
fields_to_print.insert(std::move(reducer.result_field));
}

return {std::move(out), std::move(fields_to_print)};
}

absl::FixedArray<Value> Extract(const DocValues& dv) {
Expand Down Expand Up @@ -104,34 +115,42 @@ PipelineStep MakeGroupStep(absl::Span<const std::string_view> fields,
}

PipelineStep MakeSortStep(std::string_view field, bool descending) {
return [field = std::string(field), descending](std::vector<DocValues> values) -> PipelineResult {
return [field = std::string(field), descending](PipelineResult result) -> PipelineResult {
auto& values = result.values;

std::sort(values.begin(), values.end(), [field](const DocValues& l, const DocValues& r) {
auto it1 = l.find(field);
auto it2 = r.find(field);
return it1 == l.end() || (it2 != r.end() && it1->second < it2->second);
});
if (descending)

if (descending) {
std::reverse(values.begin(), values.end());
return values;
}
Comment on lines +127 to +129
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Instead of reversing afterwards, you could sort in the correct order directly (e.g. flip the comparator when descending).


result.fields_to_print.insert(field);
return result;
};
}

PipelineStep MakeLimitStep(size_t offset, size_t num) {
return [offset, num](std::vector<DocValues> values) -> PipelineResult {
return [offset, num](PipelineResult result) {
auto& values = result.values;
values.erase(values.begin(), values.begin() + std::min(offset, values.size()));
values.resize(std::min(num, values.size()));
return values;
return result;
};
}

PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps) {
PipelineResult Process(std::vector<DocValues> values,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please put on your prod-duty hat and think about all the VLOG statements you would want to have if we ever hit problems in production with this code.

absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps) {
PipelineResult result{std::move(values), {fields_to_print.begin(), fields_to_print.end()}};
for (auto& step : steps) {
auto result = step(std::move(values));
if (!result.has_value())
return result;
values = std::move(result.value());
PipelineResult step_result = step(std::move(result));
result = std::move(step_result);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I would prefer to see the step method modify the result in place instead of all these moves.

}
return values;
return result;
}

} // namespace dfly::aggregate
17 changes: 13 additions & 4 deletions src/server/search/aggregator.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include <absl/container/flat_hash_map.h>
#include <absl/container/flat_hash_set.h>
#include <absl/types/span.h>

#include <string>
Expand All @@ -19,10 +20,16 @@ namespace dfly::aggregate {
using Value = ::dfly::search::SortableValue;
using DocValues = absl::flat_hash_map<std::string, Value>; // documents sent through the pipeline

// TODO: Replace DocValues with compact linear search map instead of hash map
struct PipelineResult {
// Values to be passed to the next step
// TODO: Replace DocValues with compact linear search map instead of hash map
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

is this TODO relevant? I do not see a hash map here

std::vector<DocValues> values;

using PipelineResult = io::Result<std::vector<DocValues>, facade::ErrorReply>;
using PipelineStep = std::function<PipelineResult(std::vector<DocValues>)>; // Group, Sort, etc.
// Fields from values to be printed
absl::flat_hash_set<std::string> fields_to_print;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why use flat_hash_set here?

};

using PipelineStep = std::function<PipelineResult(PipelineResult)>; // Group, Sort, etc.

// Iterator over Span<DocValues> that yields doc[field] or monostate if not present.
// Extra clumsy for STL compatibility!
Expand Down Expand Up @@ -82,6 +89,8 @@ PipelineStep MakeSortStep(std::string_view field, bool descending = false);
PipelineStep MakeLimitStep(size_t offset, size_t num);

// Process values with given steps
PipelineResult Process(std::vector<DocValues> values, absl::Span<const PipelineStep> steps);
PipelineResult Process(std::vector<DocValues> values,
absl::Span<const std::string_view> fields_to_print,
absl::Span<const PipelineStep> steps);

} // namespace dfly::aggregate
52 changes: 24 additions & 28 deletions src/server/search/aggregator_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,11 @@ TEST(AggregatorTest, Sort) {
};
PipelineStep steps[] = {MakeSortStep("a", false)};

auto result = Process(values, steps);
auto result = Process(values, {"a"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->at(0)["a"], Value(0.5));
EXPECT_EQ(result->at(1)["a"], Value(1.0));
EXPECT_EQ(result->at(2)["a"], Value(1.5));
EXPECT_EQ(result.values[0]["a"], Value(0.5));
EXPECT_EQ(result.values[1]["a"], Value(1.0));
EXPECT_EQ(result.values[2]["a"], Value(1.5));
}

TEST(AggregatorTest, Limit) {
Expand All @@ -35,12 +34,11 @@ TEST(AggregatorTest, Limit) {
};
PipelineStep steps[] = {MakeLimitStep(1, 2)};

auto result = Process(values, steps);
auto result = Process(values, {"i"}, steps);

EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
EXPECT_EQ(result->at(0)["i"], Value(2.0));
EXPECT_EQ(result->at(1)["i"], Value(3.0));
EXPECT_EQ(result.values.size(), 2);
EXPECT_EQ(result.values[0]["i"], Value(2.0));
EXPECT_EQ(result.values[1]["i"], Value(3.0));
}

TEST(AggregatorTest, SimpleGroup) {
Expand All @@ -54,12 +52,11 @@ TEST(AggregatorTest, SimpleGroup) {
std::string_view fields[] = {"tag"};
PipelineStep steps[] = {MakeGroupStep(fields, {})};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

EXPECT_EQ(result->at(0).size(), 1);
std::set<Value> groups{result->at(0)["tag"], result->at(1)["tag"]};
EXPECT_EQ(result.values[0].size(), 1);
std::set<Value> groups{result.values[0]["tag"], result.values[1]["tag"]};
std::set<Value> expected{"even", "odd"};
EXPECT_EQ(groups, expected);
}
Expand All @@ -83,25 +80,24 @@ TEST(AggregatorTest, GroupWithReduce) {
Reducer{"null-field", "distinct-null", FindReducerFunc(ReducerFunc::COUNT_DISTINCT)}};
PipelineStep steps[] = {MakeGroupStep(fields, std::move(reducers))};

auto result = Process(values, steps);
EXPECT_TRUE(result);
EXPECT_EQ(result->size(), 2);
auto result = Process(values, {"i", "half-i", "tag"}, steps);
EXPECT_EQ(result.values.size(), 2);

// Reorder even first
if (result->at(0).at("tag") == Value("odd"))
std::swap(result->at(0), result->at(1));
if (result.values[0].at("tag") == Value("odd"))
std::swap(result.values[0], result.values[1]);

// Even
EXPECT_EQ(result->at(0).at("count"), Value{(double)5});
EXPECT_EQ(result->at(0).at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result->at(0).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(0).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[0].at("count"), Value{(double)5});
EXPECT_EQ(result.values[0].at("sum-i"), Value{(double)2 + 4 + 6 + 8});
EXPECT_EQ(result.values[0].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[0].at("distinct-null"), Value{(double)1});

// Odd
EXPECT_EQ(result->at(1).at("count"), Value{(double)5});
EXPECT_EQ(result->at(1).at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result->at(1).at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result->at(1).at("distinct-null"), Value{(double)1});
EXPECT_EQ(result.values[1].at("count"), Value{(double)5});
EXPECT_EQ(result.values[1].at("sum-i"), Value{(double)1 + 3 + 5 + 7 + 9});
EXPECT_EQ(result.values[1].at("distinct-hi"), Value{(double)3});
EXPECT_EQ(result.values[1].at("distinct-null"), Value{(double)1});
}

} // namespace dfly::aggregate
31 changes: 22 additions & 9 deletions src/server/search/search_family.cc
Original file line number Diff line number Diff line change
Expand Up @@ -980,22 +980,35 @@ void SearchFamily::FtAggregate(CmdArgList args, Transaction* tx, SinkReplyBuilde
make_move_iterator(sub_results.end()));
}

auto agg_results = aggregate::Process(std::move(values), params->steps);
if (!agg_results.has_value())
return builder->SendError(agg_results.error());
std::vector<std::string_view> load_fields;
if (params->load_fields) {
load_fields.reserve(params->load_fields->size());
for (const auto& field : params->load_fields.value()) {
load_fields.push_back(field.GetShortName());
}
}

auto agg_results = aggregate::Process(std::move(values), load_fields, params->steps);

size_t result_size = agg_results->size();
auto* rb = static_cast<RedisReplyBuilder*>(builder);
auto sortable_value_sender = SortableValueSender(rb);

const size_t result_size = agg_results.values.size();
rb->StartArray(result_size + 1);
rb->SendLong(result_size);

for (const auto& result : agg_results.value()) {
rb->StartArray(result.size() * 2);
for (const auto& [k, v] : result) {
rb->SendBulkString(k);
std::visit(sortable_value_sender, v);
const size_t field_count = agg_results.fields_to_print.size();
for (const auto& value : agg_results.values) {
rb->StartArray(field_count * 2);
for (const auto& field : agg_results.fields_to_print) {
rb->SendBulkString(field);

auto it = value.find(field);
if (it != value.end()) {
Comment on lines +1006 to +1007
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: if (auto it = value.find(field); it != value.end())

std::visit(sortable_value_sender, it->second);
} else {
rb->SendNull();
}
}
}
}
Expand Down
41 changes: 33 additions & 8 deletions src/server/search/search_family_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -962,15 +962,12 @@ TEST_F(SearchFamilyTest, AggregateGroupBy) {
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("foo_total", "20", "word", "item2"),
IsMap("foo_total", "50", "word", "item1")));

/*
Temporary not supported
resp = Run({"ft.aggregate", "i1", "*", "LOAD", "2", "foo", "text", "GROUPBY", "2", "@word",
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"}); EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("foo_total", "20", "word", ArgType(RespExpr::NIL), "text", "\"second
key\""), IsMap("foo_total", "40", "word", ArgType(RespExpr::NIL), "text", "\"third key\""),
IsMap({"foo_total", "10", "word", ArgType(RespExpr::NIL), "text", "\"first key"})));
*/
"@text", "REDUCE", "SUM", "1", "@foo", "AS", "foo_total"});
EXPECT_THAT(resp, IsUnordArrayWithSize(
IsMap("foo_total", "40", "word", "item1", "text", "\"third key\""),
IsMap("foo_total", "20", "word", "item2", "text", "\"second key\""),
IsMap("foo_total", "10", "word", "item1", "text", "\"first key\"")));
}

TEST_F(SearchFamilyTest, JsonAggregateGroupBy) {
Expand Down Expand Up @@ -1632,4 +1629,32 @@ TEST_F(SearchFamilyTest, SearchLoadReturnHash) {
EXPECT_THAT(resp, IsMapWithSize("h2", IsMap("a", "two"), "h1", IsMap("a", "one")));
}

// Test that FT.AGGREGATE prints only needed fields
TEST_F(SearchFamilyTest, AggregateResultFields) {
Run({"JSON.SET", "j1", ".", R"({"a":"1","b":"2","c":"3"})"});
Run({"JSON.SET", "j2", ".", R"({"a":"4","b":"5","c":"6"})"});
Run({"JSON.SET", "j3", ".", R"({"a":"7","b":"8","c":"9"})"});

auto resp = Run({"FT.CREATE", "index", "ON", "JSON", "SCHEMA", "$.a", "AS", "a", "TEXT",
"SORTABLE", "$.b", "AS", "b", "TEXT", "$.c", "AS", "c", "TEXT"});
EXPECT_EQ(resp, "OK");

resp = Run({"FT.AGGREGATE", "index", "*"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap(), IsMap(), IsMap()));

resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("a", "1"), IsMap("a", "4"), IsMap("a", "7")));

resp = Run({"FT.AGGREGATE", "index", "*", "LOAD", "1", "@b", "SORTBY", "1", "a"});
EXPECT_THAT(resp,
IsUnordArrayWithSize(IsMap("b", "\"2\"", "a", "1"), IsMap("b", "\"5\"", "a", "4"),
IsMap("b", "\"8\"", "a", "7")));

resp = Run({"FT.AGGREGATE", "index", "*", "SORTBY", "1", "a", "GROUPBY", "2", "@b", "@a",
"REDUCE", "COUNT", "0", "AS", "count"});
EXPECT_THAT(resp, IsUnordArrayWithSize(IsMap("b", "\"8\"", "a", "7", "count", "1"),
IsMap("b", "\"2\"", "a", "1", "count", "1"),
IsMap("b", "\"5\"", "a", "4", "count", "1")));
}

} // namespace dfly
Loading