diff --git a/metadock/yaml_utils.py b/metadock/yaml_utils.py index 906ceed..9c853eb 100644 --- a/metadock/yaml_utils.py +++ b/metadock/yaml_utils.py @@ -1,3 +1,5 @@ +import operator +from functools import reduce from typing import Any @@ -18,7 +20,10 @@ def flatten_merge_keys(yaml_dict: Any) -> dict: for key in yaml_dict.keys(): flat_yaml_value = flatten_merge_keys(yaml_dict[key]) if key == "<<": - flattened_yaml_dict |= flat_yaml_value + if isinstance(flat_yaml_value, list): + flattened_yaml_dict |= reduce(operator.or_, flat_yaml_value, {}) + elif isinstance(flat_yaml_value, dict): + flattened_yaml_dict |= flat_yaml_value else: flattened_yaml_dict[key] = flat_yaml_value diff --git a/tests/test_yaml_utils.py b/tests/test_yaml_utils.py index 127c99e..963302b 100644 --- a/tests/test_yaml_utils.py +++ b/tests/test_yaml_utils.py @@ -36,6 +36,16 @@ {"key": "value", "inner": "item", "first": "level"}, id="simple 2 level to flatten with adjacent key", ), + pytest.param( + {"<<": [{"key_a": "value"}, {"key_b": "value again"}], "first": "level"}, + {"key_a": "value", "key_b": "value again", "first": "level"}, + id="simple merge list", + ), + pytest.param( + {"<<": [{"k": "value"}, {"k": "value again"}], "first": "level"}, + {"k": "value again", "first": "level"}, + id="simple merge list with collision", + ), ], ) def test_yaml_utils__flatten_merge_keys(yml_dict, flat_dict):