diff --git a/pandagg/response.py b/pandagg/response.py index 62075407..94507e04 100644 --- a/pandagg/response.py +++ b/pandagg/response.py @@ -155,9 +155,9 @@ def get(self, key): def _parse_group_by( self, response, + until, row=None, agg_name=None, - until=None, ancestors=None, row_as_tuple=False, with_single_bucket_groups=False, @@ -166,11 +166,14 @@ def _parse_group_by( Yields each row for which last bucket aggregation generated buckets. """ + # initialization: find ancestors once for faster computation if ancestors is None: ancestors = self._aggs.ancestors(until, id_only=True) + # remove eventual fake root + ancestors = [until] + [a for a in ancestors if a != "_"] + agg_name = ancestors[-1] if not row: row = [] if row_as_tuple else {} - agg_name = self._aggs.root if agg_name is None else agg_name if agg_name in response: agg_node = self._aggs.get(agg_name) for key, raw_bucket in agg_node.extract_buckets(response[agg_name]): diff --git a/tests/test_response.py b/tests/test_response.py index 2b504779..3228feb9 100644 --- a/tests/test_response.py +++ b/tests/test_response.py @@ -271,6 +271,16 @@ def test_parse_as_tabular_multiple_roots(self): }, ) + # with specified grouped_by + index_names, index_values = Aggregations( + data=raw_response, search=Search().aggs(my_agg) + ).to_tabular(grouped_by="classification_type") + self.assertEqual(index_names, ["classification_type"]) + self.assertEqual( + index_values, + {("multiclass",): {"doc_count": 439}, ("multilabel",): {"doc_count": 433}}, + ) + def test_parse_as_dataframe(self): my_agg = Aggs(sample.EXPECTED_AGG_QUERY, mapping=MAPPING) df = Aggregations(