Skip to content

Commit

Permalink
Merge pull request #11 from alkemics/search
Browse files Browse the repository at this point in the history
update aggregation tabular serialization
  • Loading branch information
alk-lbinet authored May 4, 2020
2 parents cd97086 + 3ae71d6 commit 52d5a93
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 183 deletions.
50 changes: 26 additions & 24 deletions pandagg/response.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,19 +229,10 @@ def serialize_as_tabular(
:return: index, index_names, values
"""
grouping_agg = self._grouping_agg(grouped_by)
if grouping_agg is None or isinstance(grouping_agg, ShadowRoot):
index = ((None,),) if row_as_tuple else ({},)
values = (self.data,)
index_names = [""]
if grouping_agg is None:
index_values = [(tuple() if row_as_tuple else dict(), self.data)]
index_names = []
else:
index_values = list(
self._parse_group_by(
response=self.data,
row_as_tuple=row_as_tuple,
until=grouping_agg.name,
with_single_bucket_groups=with_single_bucket_groups,
)
)
index_names = [
a.name
for a in self.__aggs.ancestors(
Expand All @@ -250,20 +241,30 @@ def serialize_as_tabular(
+ [grouping_agg]
if not isinstance(a, UniqueBucketAgg) or with_single_bucket_groups
]
index_values = list(
self._parse_group_by(
response=self.data,
row_as_tuple=row_as_tuple,
until=grouping_agg.name,
with_single_bucket_groups=with_single_bucket_groups,
)
)
if not index_values:
return [], [], []
index, values = zip(*index_values)

serialized_values = [
self.serialize_columns(
v,
normalize=normalize,
total_agg=grouping_agg,
expand_columns=expand_columns,
return [], []

rows = [
(
row_index,
self.serialize_columns(
row_values,
normalize=normalize,
total_agg=grouping_agg,
expand_columns=expand_columns,
),
)
for v in values
for row_index, row_values in index_values
]
return index, index_names, serialized_values
return index_names, rows

def serialize_columns(self, row_data, normalize, expand_columns, total_agg=None):
# extract value (usually 'doc_count') of grouping agg node
Expand Down Expand Up @@ -305,12 +306,13 @@ def serialize_as_dataframe(
'Using dataframe output format requires to install pandas. Please install "pandas" or '
"use another output format."
)
index, index_names, values = self.serialize_as_tabular(
index_names, index_values = self.serialize_as_tabular(
row_as_tuple=True,
grouped_by=grouped_by,
normalize=normalize_children,
with_single_bucket_groups=with_single_bucket_groups,
)
index, values = zip(*index_values)
if not index:
return pd.DataFrame()
if len(index[0]) == 0:
Expand Down
2 changes: 1 addition & 1 deletion pandagg/search.py
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ def update_from_dict(self, d):
"""
d = d.copy()
if "query" in d:
self._query = Query(d.pop("query"))
self._query = Query(d.pop("query"))
if "post_filter" in d:
self._post_filter = Query(d.pop("post_filter"))

Expand Down
2 changes: 2 additions & 0 deletions pandagg/tree/aggs.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,8 @@ def _clone_init(self, deep=False):

def _is_eligible_grouping_node(self, nid):
node = self.get(nid)
if isinstance(node, ShadowRoot):
return False
if not isinstance(node, BucketAggNode):
return False
# special aggregations not returning anything
Expand Down
62 changes: 53 additions & 9 deletions tests/test_response.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from pandagg.tree.aggs import Aggs

import tests.testing_samples.data_sample as sample
from pandagg.utils import equal_queries
from pandagg.utils import equal_queries, ordered
from tests.testing_samples.mapping_example import MAPPING


Expand Down Expand Up @@ -129,25 +129,49 @@ def test_normalize_buckets(self):
client=None,
query=None,
).serialize_as_normalized()
self.assertTrue(equal_queries(response, sample.EXPECTED_NORMALIZED_RESPONSE,))
self.assertEqual(
ordered(response), ordered(sample.EXPECTED_NORMALIZED_RESPONSE)
)

def test_parse_as_tabular(self):
# with single agg at root
my_agg = Aggs(sample.EXPECTED_AGG_QUERY, mapping=MAPPING)
index, index_names, values = Aggregations(
index_names, index_values = Aggregations(
data=sample.ES_AGG_RESPONSE,
aggs=my_agg,
index=None,
client=None,
query=None,
).serialize_as_tabular()
).serialize_as_tabular(row_as_tuple=True)

self.assertEqual(
index_names, ["classification_type", "global_metrics.field.name"]
)
self.assertEqual(len(index), len(values))
self.assertEqual(len(index), 10)
self.assertEqual(index, sample.EXPECTED_TABULAR_INDEX)
self.assertEqual(values, sample.EXPECTED_TABULAR_VALUES)
self.assertEqual(
index_values,
[
(
("multilabel", "ispracticecompatible"),
{"avg_f1_micro": 0.72, "avg_nb_classes": 18.71, "doc_count": 128},
),
(
("multilabel", "gpc"),
{"avg_f1_micro": 0.95, "avg_nb_classes": 183.21, "doc_count": 119},
),
(
("multilabel", "preservationmethods"),
{"avg_f1_micro": 0.8, "avg_nb_classes": 9.97, "doc_count": 76},
),
(
("multiclass", "kind"),
{"avg_f1_micro": 0.89, "avg_nb_classes": 206.5, "doc_count": 370},
),
(
("multiclass", "gpc"),
{"avg_f1_micro": 0.93, "avg_nb_classes": 211.12, "doc_count": 198},
),
],
)

def test_parse_as_dataframe(self):
my_agg = Aggs(sample.EXPECTED_AGG_QUERY, mapping=MAPPING)
Expand All @@ -165,4 +189,24 @@ def test_parse_as_dataframe(self):
self.assertEqual(
set(df.columns), {"avg_f1_micro", "avg_nb_classes", "doc_count"}
)
self.assertEqual(df.shape, (len(sample.EXPECTED_TABULAR_INDEX), 3))
self.assertEqual(
df.index.to_list(),
[
("multilabel", "ispracticecompatible"),
("multilabel", "gpc"),
("multilabel", "preservationmethods"),
("multiclass", "kind"),
("multiclass", "gpc"),
],
)

self.assertEqual(
df.to_dict(orient="rows"),
[
{"avg_f1_micro": 0.72, "avg_nb_classes": 18.71, "doc_count": 128},
{"avg_f1_micro": 0.95, "avg_nb_classes": 183.21, "doc_count": 119},
{"avg_f1_micro": 0.8, "avg_nb_classes": 9.97, "doc_count": 76},
{"avg_f1_micro": 0.89, "avg_nb_classes": 206.5, "doc_count": 370},
{"avg_f1_micro": 0.93, "avg_nb_classes": 211.12, "doc_count": 198},
],
)
131 changes: 0 additions & 131 deletions tests/testing_samples/data_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,36 +83,6 @@ def get_node_hierarchy():
"doc_count": 1797,
"global_metrics.field.name": {
"buckets": [
{
"avg_f1_micro": {"value": 0.83},
"avg_nb_classes": {"value": 5.20},
"doc_count": 369,
"key": "hazardpictograms",
},
{
"avg_f1_micro": {"value": 0.81},
"avg_nb_classes": {"value": 88.72},
"doc_count": 369,
"key": "islabeledby",
},
{
"avg_f1_micro": {"value": 0.41},
"avg_nb_classes": {"value": 27.57},
"doc_count": 367,
"key": "flavors",
},
{
"avg_f1_micro": {"value": 0.83},
"avg_nb_classes": {"value": 107.82},
"doc_count": 239,
"key": "hasnotableingredients",
},
{
"avg_f1_micro": {"value": 0.82},
"avg_nb_classes": {"value": 65.59},
"doc_count": 130,
"key": "allergentypelist",
},
{
"avg_f1_micro": {"value": 0.72},
"avg_nb_classes": {"value": 18.71},
Expand Down Expand Up @@ -174,24 +144,9 @@ def get_node_hierarchy():
│ ├── avg_f1_micro 0.89
│ └── avg_nb_classes 206.5
└── classification_type=multilabel 1797
├── global_metrics.field.name=allergentypelist 130
│ ├── avg_f1_micro 0.82
│ └── avg_nb_classes 65.59
├── global_metrics.field.name=flavors 367
│ ├── avg_f1_micro 0.41
│ └── avg_nb_classes 27.57
├── global_metrics.field.name=gpc 119
│ ├── avg_f1_micro 0.95
│ └── avg_nb_classes 183.21
├── global_metrics.field.name=hasnotableingredients 239
│ ├── avg_f1_micro 0.83
│ └── avg_nb_classes 107.82
├── global_metrics.field.name=hazardpictograms 369
│ ├── avg_f1_micro 0.83
│ └── avg_nb_classes 5.2
├── global_metrics.field.name=islabeledby 369
│ ├── avg_f1_micro 0.81
│ └── avg_nb_classes 88.72
├── global_metrics.field.name=ispracticecompatible 128
│ ├── avg_f1_micro 0.72
│ └── avg_nb_classes 18.71
Expand All @@ -207,51 +162,6 @@ def get_node_hierarchy():
"children": [
{
"children": [
{
"children": [
{"key": None, "level": "avg_nb_classes", "value": 5.2},
{"key": None, "level": "avg_f1_micro", "value": 0.83},
],
"key": "hazardpictograms",
"level": "global_metrics.field.name",
"value": 369,
},
{
"children": [
{"key": None, "level": "avg_nb_classes", "value": 88.72},
{"key": None, "level": "avg_f1_micro", "value": 0.81},
],
"key": "islabeledby",
"level": "global_metrics.field.name",
"value": 369,
},
{
"children": [
{"key": None, "level": "avg_nb_classes", "value": 27.57},
{"key": None, "level": "avg_f1_micro", "value": 0.41},
],
"key": "flavors",
"level": "global_metrics.field.name",
"value": 367,
},
{
"children": [
{"key": None, "level": "avg_nb_classes", "value": 107.82},
{"key": None, "level": "avg_f1_micro", "value": 0.83},
],
"key": "hasnotableingredients",
"level": "global_metrics.field.name",
"value": 239,
},
{
"children": [
{"key": None, "level": "avg_nb_classes", "value": 65.59},
{"key": None, "level": "avg_f1_micro", "value": 0.82},
],
"key": "allergentypelist",
"level": "global_metrics.field.name",
"value": 130,
},
{
"children": [
{"key": None, "level": "avg_nb_classes", "value": 18.71},
Expand Down Expand Up @@ -314,44 +224,3 @@ def get_node_hierarchy():
"level": "root",
"value": None,
}

EXPECTED_TABULAR_INDEX = (
{
"classification_type": "multilabel",
"global_metrics.field.name": "hazardpictograms",
},
{"classification_type": "multilabel", "global_metrics.field.name": "islabeledby"},
{"classification_type": "multilabel", "global_metrics.field.name": "flavors"},
{
"classification_type": "multilabel",
"global_metrics.field.name": "hasnotableingredients",
},
{
"classification_type": "multilabel",
"global_metrics.field.name": "allergentypelist",
},
{
"classification_type": "multilabel",
"global_metrics.field.name": "ispracticecompatible",
},
{"classification_type": "multilabel", "global_metrics.field.name": "gpc"},
{
"classification_type": "multilabel",
"global_metrics.field.name": "preservationmethods",
},
{"classification_type": "multiclass", "global_metrics.field.name": "kind"},
{"classification_type": "multiclass", "global_metrics.field.name": "gpc"},
)

EXPECTED_TABULAR_VALUES = [
{"avg_f1_micro": 0.83, "avg_nb_classes": 5.2, u"doc_count": 369},
{"avg_f1_micro": 0.81, "avg_nb_classes": 88.72, u"doc_count": 369},
{"avg_f1_micro": 0.41, "avg_nb_classes": 27.57, u"doc_count": 367},
{"avg_f1_micro": 0.83, "avg_nb_classes": 107.82, u"doc_count": 239},
{"avg_f1_micro": 0.82, "avg_nb_classes": 65.59, u"doc_count": 130},
{"avg_f1_micro": 0.72, "avg_nb_classes": 18.71, u"doc_count": 128},
{"avg_f1_micro": 0.95, "avg_nb_classes": 183.21, u"doc_count": 119},
{"avg_f1_micro": 0.8, "avg_nb_classes": 9.97, u"doc_count": 76},
{"avg_f1_micro": 0.89, "avg_nb_classes": 206.5, u"doc_count": 370},
{"avg_f1_micro": 0.93, "avg_nb_classes": 211.12, u"doc_count": 198},
]
Loading

0 comments on commit 52d5a93

Please sign in to comment.