Skip to content

Commit

Permalink
Merge pull request #4 from evaline-ju/immutable
Browse files Browse the repository at this point in the history
Add ImmutableConfig
  • Loading branch information
gabe-l-hart authored Apr 25, 2023
2 parents 2a60398 + b0af872 commit ea6566a
Show file tree
Hide file tree
Showing 4 changed files with 114 additions and 14 deletions.
2 changes: 1 addition & 1 deletion aconfig/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@
'''Enable calls to desired components of package.
'''

from .aconfig import AttributeAccessDict, Config
from .aconfig import AttributeAccessDict, Config, ImmutableAttributeAccessDict, ImmutableConfig
60 changes: 50 additions & 10 deletions aconfig/aconfig.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,28 +32,31 @@ def __init__(self, input_map):
overrides dict's methods to enable this. Can be modified later on and keep the same
behavior.
'''
assert isinstance(input_map, dict), \
'`input_map` argument should be of type dict, but found type: <{0}>'.format(
type(input_map))
if not isinstance(input_map, dict):
raise TypeError('`input_map` argument should be of type dict, but found type: <{0}>'.format(
type(input_map)))

# copy so as not to modify passed in dictionary
copied_map = copy.deepcopy(input_map)

# recursively instantiate sub-dicts
for key, value in copied_map.items():
copied_map[key] = AttributeAccessDict._make_attribute_access_dict(value)
copied_map[key] = self.__class__._make_attribute_access_dict(value)

# make it accessible like native Python dict
super().__init__(**copied_map)

@staticmethod
def _make_attribute_access_dict(value):
if isinstance(value, AttributeAccessDict):
@classmethod
def _make_attribute_access_dict(cls, value):
"""Recursively walk down any `dict`s or `list`s and build attribute access dicts
🌶️🌶️🌶️: This is a classmethod so that inheritance is respected.
"""
if isinstance(value, cls):
return value
elif isinstance(value, dict):
return AttributeAccessDict(value)
return cls(value)
elif isinstance(value, list):
return [AttributeAccessDict._make_attribute_access_dict(v) for v in value]
return [cls._make_attribute_access_dict(v) for v in value]
else:
return value

Expand Down Expand Up @@ -83,7 +86,35 @@ def __deepcopy__(self, memo):
'''This enables deepcopy to successfully copy a Config object, despite
the default value semantics
'''
return AttributeAccessDict(copy.deepcopy(dict(self)))
return self.__class__(copy.deepcopy(dict(self)))


class ImmutableAttributeAccessDict(AttributeAccessDict):
"""This class subclasses AttributeAccessDict and removes the setters,
to allow the creation of immutable dicts.
Using inheritance this way allows the dicts to be recursively created via
AttributeAccessDict, while maintaining nested immutability.
"""

def __init__(self, input_map, *_):
"""See :func:`~aconfig.aconfig.AttributeAccessDict.__init__`"""
if not isinstance(input_map, dict):
raise TypeError('`input_map` argument should be of type dict, but found type: <{0}>'.format(
type(input_map)))
# 🌶️🌶️🌶️: we explicitly cast back down to `dict` for the immutable case
# If we were to build an immutable dict from the top-down, that would
# obviously fail.
input_map = dict(input_map)
# Invoke the AttributeAccessDict initializer
super().__init__(input_map)

def __setitem__(self, key, value):
raise TypeError("ImmutableAttributeAccessDict does not support item assignment")

def __setattr__(self, key, value):
raise AttributeError("ImmutableAttributeAccessDict does not support attribute assignment")


class Config(AttributeAccessDict):
'''Config which holds the configurations at the given config location.
Expand Down Expand Up @@ -293,6 +324,15 @@ def _env_var_from_key(self, config_key):
return re.sub(self._search_pattern, '_', config_key.upper())


class ImmutableConfig(ImmutableAttributeAccessDict, Config):
"""This class is the Immutable version of Config"""
def __init__(self, config, override_env_vars=True):
"""See :func:`~aconfig.aconfig.Config.__init__`"""
if not isinstance(config, dict):
raise TypeError("config must be a dict")
super().__init__(config, override_env_vars)


## yaml representation safe ########################################################################

yaml.add_representer(AttributeAccessDict, SafeRepresenter.represent_dict)
Expand Down
31 changes: 29 additions & 2 deletions test/test_attribute_access_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,11 +33,11 @@ def test__init__pass(self):
def test__init__fail(self):
'''Test that initialization will fail with invalid dict's.
'''
with self.assertRaises(AssertionError) as ex:
with self.assertRaises(TypeError) as ex:
bad_dict = aconfig.AttributeAccessDict(['a', 'list', 'of', 'strings'])

raised_exception = ex.exception
self.assertIsInstance(raised_exception, AssertionError)
self.assertIsInstance(raised_exception, TypeError)

# make sure bad_dict was NOT initialized
self.assertEqual(getattr(locals(), 'bad_dict', None), None)
Expand Down Expand Up @@ -213,3 +213,30 @@ def test_builtin_method(self):

self.assertNotEqual(aad.update, None)
self.assertNotIn('update', aad)

def test_immutable_flat_access_dict(self):
'''Test that immutable flat dict cannot be changed
'''
flat_dict = aconfig.ImmutableAttributeAccessDict(fixtures.GOOD_FLAT_DICT)
self.assertIsInstance(flat_dict, aconfig.AttributeAccessDict)

with self.assertRaises(TypeError):
flat_dict['str_key'] = 'new_key'

def test_immutable_nested_access_dict(self):
'''Test that immutable nested dict cannot be changed
'''
flat_dict = aconfig.ImmutableAttributeAccessDict(fixtures.GOOD_NESTED_DICT)
self.assertIsInstance(flat_dict, aconfig.AttributeAccessDict)

with self.assertRaises(TypeError):
flat_dict['key2']['key4'] = 'new_key'

def test_immutable_dict_attr(self):
'''Test that immutable dict cannot be changed via attribute
'''
flat_dict = aconfig.ImmutableAttributeAccessDict(fixtures.GOOD_FLAT_DICT)
self.assertIsInstance(flat_dict, aconfig.AttributeAccessDict)

with self.assertRaises(AttributeError):
flat_dict.str_key = 'new_key'
35 changes: 34 additions & 1 deletion test/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -221,6 +221,12 @@ def test_deepcopy(self):
self.assertEqual(cfg.a.b[0].c, 1)
self.assertEqual(cfg_copy.a.b[0].c, 2)

immutable_cfg = aconfig.ImmutableConfig(cfg)
immutable_cfg_copy = copy.deepcopy(immutable_cfg)
self.assertIsInstance(immutable_cfg_copy, aconfig.ImmutableConfig)
self.assertEqual(immutable_cfg_copy, immutable_cfg)
self.assertIsNot(immutable_cfg_copy, immutable_cfg)

def test_yaml_dump(self):
'''Test yaml.dump(config) works'''
loaded_yaml = aconfig.Config.from_yaml(fixtures.GOOD_CONFIG_LOCATION)
Expand All @@ -231,6 +237,33 @@ def test_yaml_dump(self):
assert yaml_dump == yaml_safe_dump

# Load both ways
yaml_loaded = yaml.load(yaml_dump)
yaml_loaded = yaml.full_load(yaml_dump)
yaml_safe_loaded = yaml.safe_load(yaml_dump)
assert yaml_loaded == yaml_safe_loaded

def test_immutable_config(self):
cfg = aconfig.ImmutableConfig({'a': {'b': [{'c': 1}]}})
self.assertEqual(cfg.a.b[0].c, 1)

with self.assertRaises(AttributeError):
cfg.a.b[0].c = 2
with self.assertRaises(AttributeError):
cfg.a.b = [1, 2, 3]
with self.assertRaises(AttributeError):
cfg.a = 1

def test_immutable_config_with_env_overrides(self):
# set an environment
os.environ['KEY1'] = '12345678'
cfg = aconfig.ImmutableConfig({"key1": 1, "key2": 2}, override_env_vars=True)

assert cfg.key2 == 2
assert cfg.key1 == 12345678
with self.assertRaises(AttributeError):
cfg.key1 = 1

def test_immutable_config_from_mutable_config(self):
cfg = aconfig.Config({'a': {'b': [{'c': 1}]}})
immutable_config = aconfig.ImmutableConfig(cfg)

assert cfg == immutable_config

0 comments on commit ea6566a

Please sign in to comment.