From 0ec7d7e93b21c5d5687f3e7212a9f97c6a42b90e Mon Sep 17 00:00:00 2001 From: Jeremy Kubica <104161096+jeremykubica@users.noreply.github.com> Date: Fri, 12 Jul 2024 09:27:54 -0400 Subject: [PATCH] Change setters list to dict --- src/tdastro/base_models.py | 28 ++++++++++++++-------------- 1 file changed, 14 insertions(+), 14 deletions(-) diff --git a/src/tdastro/base_models.py b/src/tdastro/base_models.py index 77f40995..e2fb3e13 100644 --- a/src/tdastro/base_models.py +++ b/src/tdastro/base_models.py @@ -21,9 +21,9 @@ class ParameterizedModel: Attributes ---------- - setters : `list` of `tuple` + setters : `dict` of `tuple` A dictionary to information about the setters for the parameters in the form: - (name, ParameterSource, setter information, required). The attributes are + (ParameterSource, setter information, required). The attributes are stored in the order in which they need to be set. sample_iteration : `int` A counter used to syncronize sampling runs. Tracks how many times this @@ -31,7 +31,7 @@ class ParameterizedModel: """ def __init__(self, **kwargs): - self.setters = [] + self.setters = {} self.sample_iteration = 0 def __str__(self): @@ -63,11 +63,9 @@ def set_parameter(self, name, value=None, **kwargs): Raise a ``ValueError`` if the parameter is required, but set to None. """ # Check for parameter has been added and if so, find the index. - try: - ind = next(ind for ind, entry in enumerate(self.setters) if entry[0] == name) - except StopIteration: + if name not in self.setters: raise KeyError(f"Tried to set parameter {name} that has not been added.") from None - required = self.setters[ind][3] + required = self.setters[name][2] if value is None and name in kwargs: # The value wasn't set, but the name is in kwargs. @@ -76,26 +74,26 @@ def set_parameter(self, name, value=None, **kwargs): if value is not None: if isinstance(value, types.FunctionType): # Case 1: If we are getting from a static function, sample it. - self.setters[ind] = (name, ParameterSource.FUNCTION, value, required) + self.setters[name] = (ParameterSource.FUNCTION, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, types.MethodType) and isinstance(value.__self__, ParameterizedModel): # Case 2: We are trying to use the method from a ParameterizedModel. # Note that this will (correctly) fail if we are adding a model method from the current # object that requires an unset attribute. - self.setters[ind] = (name, ParameterSource.MODEL_METHOD, value, required) + self.setters[name] = (ParameterSource.MODEL_METHOD, value, required) setattr(self, name, value(**kwargs)) elif isinstance(value, ParameterizedModel): # Case 3: We are trying to access an attribute from a parameterized model. if not hasattr(value, name): raise ValueError(f"Attribute {name} missing from parent.") - self.setters[ind] = (name, ParameterSource.MODEL_ATTRIBUTE, value, required) + self.setters[name] = (ParameterSource.MODEL_ATTRIBUTE, value, required) setattr(self, name, getattr(value, name)) else: # Case 4: The value is constant. - self.setters[ind] = (name, ParameterSource.CONSTANT, value, required) + self.setters[name] = (ParameterSource.CONSTANT, value, required) setattr(self, name, value) elif not required: - self.setters[ind] = (name, ParameterSource.CONSTANT, None, required) + self.setters[name] = (ParameterSource.CONSTANT, None, required) setattr(self, name, None) else: raise ValueError(f"Missing required parameter {name}") @@ -134,7 +132,7 @@ def add_parameter(self, name, value=None, required=False, **kwargs): # Add an entry for the setter function and fill in the remaining # information using set_parameter(). - self.setters.append((name, None, None, required)) + self.setters[name] = (None, None, required) self.set_parameter(name, value, **kwargs) def sample_parameters(self, max_depth=50, **kwargs): @@ -159,7 +157,9 @@ def sample_parameters(self, max_depth=50, **kwargs): raise ValueError(f"Maximum sampling depth exceeded at {self}. Potential infinite loop.") # Run through each parameter and sample it based on the given recipe. - for name, source_type, setter, _ in self.setters: + # As of Python 3.7 dictionaries are guaranteed to preserve insertion ordering, + # so this will iterate through attributes in the order they were inserted. + for name, (source_type, setter, _) in self.setters.items(): sampled_value = None if source_type == ParameterSource.CONSTANT: sampled_value = setter