Skip to content

Commit

Permalink
Flatten all levels for MongoDB data source (getredash#6844)
Browse files: browse the repository at this point in the history
  • Loading branch information
KimBioInfoStudio authored and harveyrendell committed Jan 8, 2025
1 parent 42a5d30 commit 199cbeb
Show file tree
Hide file tree
Showing 2 changed files with 112 additions and 29 deletions.
85 changes: 57 additions & 28 deletions redash/query_runner/mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
BaseQueryRunner,
register,
)
from redash.utils import JSONEncoder, json_dumps, json_loads, parse_human_time
from redash.utils import json_loads, parse_human_time

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -42,17 +42,6 @@
}


class MongoDBJSONEncoder(JSONEncoder):
def default(self, o):
if isinstance(o, ObjectId):
return str(o)
elif isinstance(o, Timestamp):
return super(MongoDBJSONEncoder, self).default(o.as_datetime())
elif isinstance(o, Decimal128):
return o.to_decimal()
return super(MongoDBJSONEncoder, self).default(o)


date_regex = re.compile(r'ISODate\("(.*)"\)', re.IGNORECASE)


Expand Down Expand Up @@ -80,7 +69,7 @@ def datetime_parser(dct):
return bson_object_hook(dct, json_options=opts)


def parse_query_json(query):
def parse_query_json(query: str):
query_data = json_loads(query, object_hook=datetime_parser)
return query_data

Expand All @@ -93,26 +82,40 @@ def _get_column_by_name(columns, column_name):
return None


def _parse_dict(dic):
def _parse_dict(dic: dict, flatten: bool = False) -> dict:
res = {}
for key, value in dic.items():
if isinstance(value, dict):
for tmp_key, tmp_value in _parse_dict(value).items():
new_key = "{}.{}".format(key, tmp_key)
res[new_key] = tmp_value

def _flatten(x, name=""):
if isinstance(x, dict):
for k, v in x.items():
_flatten(v, "{}.{}".format(name, k))
elif isinstance(x, list):
for idx, item in enumerate(x):
_flatten(item, "{}.{}".format(name, idx))
else:
res[key] = value
res[name[1:]] = x

if flatten:
_flatten(dic)
else:
for key, value in dic.items():
if isinstance(value, dict):
for tmp_key, tmp_value in _parse_dict(value).items():
new_key = "{}.{}".format(key, tmp_key)
res[new_key] = tmp_value
else:
res[key] = value
return res


def parse_results(results):
def parse_results(results: list, flatten: bool = False) -> list:
rows = []
columns = []

for row in results:
parsed_row = {}

parsed_row = _parse_dict(row)
parsed_row = _parse_dict(row, flatten)
for column_name, value in parsed_row.items():
columns.append(
{
Expand Down Expand Up @@ -151,6 +154,14 @@ def configuration_schema(cls):
],
"title": "Replica Set Read Preference",
},
"flatten": {
"type": "string",
"extendedEnum": [
{"value": "False", "name": "False"},
{"value": "True", "name": "True"},
],
"title": "Flatten Results",
},
},
"secret": ["password"],
"required": ["connectionString", "dbName"],
Expand All @@ -171,6 +182,19 @@ def __init__(self, configuration):
True if "replicaSetName" in self.configuration and self.configuration["replicaSetName"] else False
)

self.flatten = self.configuration.get("flatten", "False").upper() in ["TRUE", "YES", "ON", "1", "Y", "T"]
logger.debug("flatten: {}".format(self.flatten))

@classmethod
def custom_json_encoder(cls, dec, o):
if isinstance(o, ObjectId):
return str(o)
elif isinstance(o, Timestamp):
return dec.default(o.as_datetime())
elif isinstance(o, Decimal128):
return o.to_decimal()
return None

def _get_db(self):
kwargs = {}
if self.is_replica_set:
Expand Down Expand Up @@ -279,8 +303,10 @@ def run_query(self, query, user): # noqa: C901
if "$sort" in step:
sort_list = []
for sort_item in step["$sort"]:
sort_list.append((sort_item["name"], sort_item["direction"]))

if isinstance(sort_item, dict):
sort_list.append((sort_item["name"], sort_item.get("direction", 1)))
elif isinstance(sort_item, list):
sort_list.append(tuple(sort_item))
step["$sort"] = SON(sort_list)

if "fields" in query_data:
Expand All @@ -290,7 +316,10 @@ def run_query(self, query, user): # noqa: C901
if "sort" in query_data and query_data["sort"]:
s = []
for field_data in query_data["sort"]:
s.append((field_data["name"], field_data["direction"]))
if isinstance(field_data, dict):
s.append((field_data["name"], field_data.get("direction", 1)))
elif isinstance(field_data, list):
s.append(tuple(field_data))

columns = []
rows = []
Expand Down Expand Up @@ -331,7 +360,7 @@ def run_query(self, query, user): # noqa: C901

rows.append({"count": cursor})
else:
rows, columns = parse_results(cursor)
rows, columns = parse_results(cursor, flatten=self.flatten)

if f:
ordered_columns = []
Expand All @@ -341,16 +370,16 @@ def run_query(self, query, user): # noqa: C901
ordered_columns.append(column)

columns = ordered_columns
logger.debug("columns: {}".format(columns))

if query_data.get("sortColumns"):
reverse = query_data["sortColumns"] == "desc"
columns = sorted(columns, key=lambda col: col["name"], reverse=reverse)

data = {"columns": columns, "rows": rows}
error = None
json_data = json_dumps(data, cls=MongoDBJSONEncoder)

return json_data, error
return data, error


register(MongoDB)
56 changes: 55 additions & 1 deletion tests/query_runner/test_mongodb.py
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,13 @@ def test_parses_nested_results(self):
"column": 2,
"column2": "test",
"column3": "hello",
"nested": {"a": 2, "b": "str2", "c": "c", "d": {"e": 3}},
"nested": {
"a": 2,
"b": "str2",
"c": "c",
"d": {"e": 3},
"f": {"h": {"i": ["j", "k", "l"]}},
},
},
]

Expand All @@ -158,6 +164,7 @@ def test_parses_nested_results(self):
"nested.b": "str2",
"nested.c": "c",
"nested.d.e": 3,
"nested.f.h.i": ["j", "k", "l"],
},
)

Expand All @@ -167,3 +174,50 @@ def test_parses_nested_results(self):
self.assertIsNotNone(_get_column_by_name(columns, "nested.a"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.b"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.c"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.d.e"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.f.h.i"))

def test_parses_flatten_nested_results(self):
raw_results = [
{
"column": 2,
"column2": "test",
"column3": "hello",
"nested": {
"a": 2,
"b": "str2",
"c": "c",
"d": {"e": 3},
"f": {"h": {"i": ["j", "k", "l"]}},
},
}
]

rows, columns = parse_results(raw_results, flatten=True)
print(rows)
self.assertDictEqual(
rows[0],
{
"column": 2,
"column2": "test",
"column3": "hello",
"nested.a": 2,
"nested.b": "str2",
"nested.c": "c",
"nested.d.e": 3,
"nested.f.h.i.0": "j",
"nested.f.h.i.1": "k",
"nested.f.h.i.2": "l",
},
)

self.assertIsNotNone(_get_column_by_name(columns, "column"))
self.assertIsNotNone(_get_column_by_name(columns, "column2"))
self.assertIsNotNone(_get_column_by_name(columns, "column3"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.a"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.b"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.c"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.d.e"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.f.h.i.0"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.f.h.i.1"))
self.assertIsNotNone(_get_column_by_name(columns, "nested.f.h.i.2"))

0 comments on commit 199cbeb

Please sign in to comment.