From b2ada30f59f90eb5571047860eaff9d15363630a Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:03:30 -0700 Subject: [PATCH 01/25] Initial version working --- slsim/ParamDistributions/params.py | 8 +++++ slsim/Util/params.py | 57 ++++++++++++++++++++++++++++++ 2 files changed, 65 insertions(+) create mode 100644 slsim/ParamDistributions/params.py create mode 100644 slsim/Util/params.py diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py new file mode 100644 index 000000000..051b38cd9 --- /dev/null +++ b/slsim/ParamDistributions/params.py @@ -0,0 +1,8 @@ +from pydantic import BaseModel +import numpy as np + +class GaussianMixtureModel(BaseModel): + means: list[float] = [0.00330796, -0.07635054, 0.11829008] + stds: list[float] = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] + weights: list[float] = [0.62703102, 0.23732313, 0.13564585] + \ No newline at end of file diff --git a/slsim/Util/params.py b/slsim/Util/params.py new file mode 100644 index 000000000..7cf2781b4 --- /dev/null +++ b/slsim/Util/params.py @@ -0,0 +1,57 @@ +""" +Utilities for managing parameter defaults and validation in the slsim package. +Desgined to be unobtrusive to use. +""" +from functools import wraps +from inspect import getsourcefile, getargspec +from pathlib import Path +from importlib import import_module +from typing import Callable + +class SlSimParameterException(Exception): + pass + +_defaults = {} + +def check_params(init_fn): + if not init_fn.__name__.startswith('__init__'): + raise SlSimParameterException('pcheck decorator can currently only be used'\ + ' with__init__ methods') + + @wraps(init_fn) + def new_init_fn(obj, *args, **kwargs): + # Get function argument names + all_args = {} + if args: + largs = getargspec(init_fn).args + for i in range(len(args)): + all_args[largs[i+1]] = args[i] + all_args.update(kwargs) + parsed_args = get_defaults(init_fn)(**all_args) + return init_fn(obj, **dict(parsed_args)) + return new_init_fn + + +def get_defaults(init_fn): + path = getsourcefile(init_fn) + obj_name = init_fn.__qualname__.split('.')[0] + start = path.rfind("slsim") + modpath = path[start:].split('/') + modpath = modpath[1:-1] + ["params"] + modpath = ".".join(["slsim"] + modpath) + if modpath not in _defaults: + _defaults[modpath] = load_defaults(modpath, obj_name) + return _defaults[modpath] + +def load_defaults(modpath, obj_name): + try: + defaults = import_module(modpath) + except ModuleNotFoundError: + raise SlSimParameterException('No default parameters found for module '\ + f'\"{modpath[-2]}\"') + try: + obj_defaults = getattr(defaults, obj_name) + except AttributeError: + raise SlSimParameterException(f'No default parameters found for class '\ + f'\"{obj_name}\"') + return obj_defaults From 3db5362c968a69ce2dfc58514990a206f05859f5 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:10:55 -0700 Subject: [PATCH 02/25] add docstrings --- slsim/Util/params.py | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 7cf2781b4..5e89f3b92 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -13,7 +13,21 @@ class SlSimParameterException(Exception): _defaults = {} -def check_params(init_fn): +def check_params(init_fn: Callable) -> Callable: + """ + A decorator for enforcing checking of params in __init__ methods. This + decorator will automatically load the default parameters for the class + and check that the passed parameters are valid. It expeects a "params.py" + file in the same folder as the class definition. Uses pydantic models + to enforce types, sanity checks, and defaults. + + From and end user perspective, there is no difference between this and a normal + __init__ fn. Developers only need to add @check_params above their __init__ + method definition to enable this feature, then add their default parameters + to the "params.py" file. + """ + + if not init_fn.__name__.startswith('__init__'): raise SlSimParameterException('pcheck decorator can currently only be used'\ ' with__init__ methods') @@ -39,15 +53,23 @@ def get_defaults(init_fn): modpath = path[start:].split('/') modpath = modpath[1:-1] + ["params"] modpath = ".".join(["slsim"] + modpath) + # Unfortunately, there doesn't seem to be a better way of doing this. + if modpath not in _defaults: - _defaults[modpath] = load_defaults(modpath, obj_name) + #Little optimization. We cache defaults so we don't have to reload them + # every time we construct a new object. + _defaults[modpath] = load_parameters(modpath, obj_name) return _defaults[modpath] -def load_defaults(modpath, obj_name): +def load_parameters(modpath, obj_name): + """ + Loads parameters from the "params.py" file which should be in the same folder + as the class definition. + """ try: defaults = import_module(modpath) except ModuleNotFoundError: - raise SlSimParameterException('No default parameters found for module '\ + raise SlSimParameterException('No default parameters found in module '\ f'\"{modpath[-2]}\"') try: obj_defaults = getattr(defaults, obj_name) From f2652262d3782986877891285fbfd37bc0eb5b87 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:16:52 -0700 Subject: [PATCH 03/25] Finish documentation --- slsim/ParamDistributions/params.py | 1 + slsim/Util/params.py | 20 ++++++++++++-------- 2 files changed, 13 insertions(+), 8 deletions(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index 051b38cd9..8962e56cd 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -1,4 +1,5 @@ from pydantic import BaseModel +from pydantic import Field import numpy as np class GaussianMixtureModel(BaseModel): diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 5e89f3b92..5f5c1d4c6 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -6,7 +6,8 @@ from inspect import getsourcefile, getargspec from pathlib import Path from importlib import import_module -from typing import Callable +from typing import Callable, Any +import pydantic class SlSimParameterException(Exception): pass @@ -33,20 +34,20 @@ def check_params(init_fn: Callable) -> Callable: ' with__init__ methods') @wraps(init_fn) - def new_init_fn(obj, *args, **kwargs): + def new_init_fn(obj: Any, *args, **kwargs) -> Any: # Get function argument names - all_args = {} + pargs = {} if args: largs = getargspec(init_fn).args for i in range(len(args)): - all_args[largs[i+1]] = args[i] - all_args.update(kwargs) - parsed_args = get_defaults(init_fn)(**all_args) + pargs[largs[i+1]] = args[i] + #Doing it this way ensures we still catch duplicate arguments + parsed_args = get_defaults(init_fn)(**pargs, **kwargs) return init_fn(obj, **dict(parsed_args)) return new_init_fn -def get_defaults(init_fn): +def get_defaults(init_fn: Callable) -> pydantic.BaseModel: path = getsourcefile(init_fn) obj_name = init_fn.__qualname__.split('.')[0] start = path.rfind("slsim") @@ -61,7 +62,7 @@ def get_defaults(init_fn): _defaults[modpath] = load_parameters(modpath, obj_name) return _defaults[modpath] -def load_parameters(modpath, obj_name): +def load_parameters(modpath: str, obj_name: str) -> pydantic.BaseModel: """ Loads parameters from the "params.py" file which should be in the same folder as the class definition. @@ -76,4 +77,7 @@ def load_parameters(modpath, obj_name): except AttributeError: raise SlSimParameterException(f'No default parameters found for class '\ f'\"{obj_name}\"') + if not issubclass(obj_defaults, pydantic.BaseModel): + raise SlSimParameterException(f'Defaults for \"{obj_name}\" are not in a '\ + 'pydantic model!') return obj_defaults From 366de8c28914b65ae90a41d65ca0b027a073d7bc Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:22:48 -0700 Subject: [PATCH 04/25] Complete GaussianMixtureModel model --- slsim/ParamDistributions/params.py | 13 ++++++++++--- 1 file changed, 10 insertions(+), 3 deletions(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index 8962e56cd..d73e5a466 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -1,9 +1,16 @@ -from pydantic import BaseModel -from pydantic import Field +from pydantic import BaseModel, Field, model_validator import numpy as np class GaussianMixtureModel(BaseModel): means: list[float] = [0.00330796, -0.07635054, 0.11829008] - stds: list[float] = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] + stds: list[float] = Field(gt=0, + default=[np.sqrt(0.00283885), np.sqrt(0.01066668), + np.sqrt(0.0097978)]) weights: list[float] = [0.62703102, 0.23732313, 0.13564585] + + @model_validator(mode="before") + def check_lenghts(self): + if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): + raise ValueError("The lenghts of means, stds and weights must be equal") + return self \ No newline at end of file From dd03ca223d88b63c3521ad56d098817f68ff1592 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:28:12 -0700 Subject: [PATCH 05/25] Add decorator as top-level import in Util module --- slsim/ParamDistributions/params.py | 2 +- slsim/Util/__init__.py | 3 +++ 2 files changed, 4 insertions(+), 1 deletion(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index d73e5a466..3df61666a 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -8,7 +8,7 @@ class GaussianMixtureModel(BaseModel): np.sqrt(0.0097978)]) weights: list[float] = [0.62703102, 0.23732313, 0.13564585] - @model_validator(mode="before") + @model_validator(mode="after") def check_lenghts(self): if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): raise ValueError("The lenghts of means, stds and weights must be equal") diff --git a/slsim/Util/__init__.py b/slsim/Util/__init__.py index e69de29bb..7973418cf 100644 --- a/slsim/Util/__init__.py +++ b/slsim/Util/__init__.py @@ -0,0 +1,3 @@ +from .params import check_params + +__all__ = ['check_params'] \ No newline at end of file From 84ea259a49ad7cd844b9065e161dc4622a6f5f2a Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:29:26 -0700 Subject: [PATCH 06/25] Add to GaussianMixtureModel --- .../gaussian_mixture_model.py | 18 +++++------------- 1 file changed, 5 insertions(+), 13 deletions(-) diff --git a/slsim/ParamDistributions/gaussian_mixture_model.py b/slsim/ParamDistributions/gaussian_mixture_model.py index 0d23e6a5b..d4db8997a 100644 --- a/slsim/ParamDistributions/gaussian_mixture_model.py +++ b/slsim/ParamDistributions/gaussian_mixture_model.py @@ -1,5 +1,5 @@ import numpy as np - +from slsim.Util import check_params class GaussianMixtureModel: """A Gaussian Mixture Model (GMM) class. @@ -7,12 +7,13 @@ class GaussianMixtureModel: This class is used to represent a mixture of Gaussian distributions, each of which is defined by its mean, standard deviation and weight. """ - - def __init__(self, means=None, stds=None, weights=None): + @check_params + def __init__(self, means, stds, weights): """ The constructor for GaussianMixtureModel class. The default values are the means, standard deviations, and weights of the fits to the data in the table - 2 of https://doi.org/10.1093/mnras/stac2235 and others. + 2 of https://doi.org/10.1093/mnras/stac2235 and others. See "params.py" for + defaults and validation logic. :param means: the mean values of the Gaussian components. :type means: list of float @@ -21,15 +22,6 @@ def __init__(self, means=None, stds=None, weights=None): :param weights: The weights of the Gaussian components in the mixture. :type weights: list of float """ - if means is None: - means = [0.00330796, -0.07635054, 0.11829008] - if stds is None: - stds = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] - if weights is None: - weights = [0.62703102, 0.23732313, 0.13564585] - assert ( - len(means) == len(stds) == len(weights) - ), "Lengths of means, standard deviations, and weights must be equal." self.means = means self.stds = stds self.weights = weights From 405e19b5a9728ca0f28afe6ac42618f961f1189d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:32:16 -0700 Subject: [PATCH 07/25] Move to using pydantic.PositiveFloat --- slsim/ParamDistributions/params.py | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index 3df61666a..ff0cfdde6 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -1,12 +1,10 @@ -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, PositiveFloat, model_validator import numpy as np class GaussianMixtureModel(BaseModel): means: list[float] = [0.00330796, -0.07635054, 0.11829008] - stds: list[float] = Field(gt=0, - default=[np.sqrt(0.00283885), np.sqrt(0.01066668), - np.sqrt(0.0097978)]) - weights: list[float] = [0.62703102, 0.23732313, 0.13564585] + stds: list[PositiveFloat] = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] + weights: list[PositiveFloat] = [0.62703102, 0.23732313, 0.13564585] @model_validator(mode="after") def check_lenghts(self): From dc2202dfce0077629aca3739079907064cbb36e8 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:54:46 -0700 Subject: [PATCH 08/25] GMM: Enforce sum(weights) == 1 --- slsim/ParamDistributions/params.py | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index ff0cfdde6..eb2b58700 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, PositiveFloat, model_validator +from pydantic import BaseModel, Field, PositiveFloat, model_validator, field_validator import numpy as np class GaussianMixtureModel(BaseModel): @@ -6,6 +6,13 @@ class GaussianMixtureModel(BaseModel): stds: list[PositiveFloat] = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] weights: list[PositiveFloat] = [0.62703102, 0.23732313, 0.13564585] + @field_validator("weights") + @classmethod + def check_weights(cls, weight_values): + if sum(weight_values) != 1: + raise ValueError("The sum of the weights must be 1") + return weight_values + @model_validator(mode="after") def check_lenghts(self): if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): From 4f4361d41346afe2ce2d47b1e9607294ac89e799 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 10:56:42 -0700 Subject: [PATCH 09/25] Spelling --- slsim/ParamDistributions/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index eb2b58700..faa33ae15 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -14,7 +14,7 @@ def check_weights(cls, weight_values): return weight_values @model_validator(mode="after") - def check_lenghts(self): + def check_lengths(self): if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): raise ValueError("The lenghts of means, stds and weights must be equal") return self From fa6d791347d979c95dc7539de066ff2f628f577e Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 11:00:12 -0700 Subject: [PATCH 10/25] Remove unnecessary import --- slsim/ParamDistributions/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index faa33ae15..0bae2df6b 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, Field, PositiveFloat, model_validator, field_validator +from pydantic import BaseModel, PositiveFloat, model_validator, field_validator import numpy as np class GaussianMixtureModel(BaseModel): From cb6983d185fbe7c5cb798f63c88f11225f36d78d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 11:20:21 -0700 Subject: [PATCH 11/25] One more spelling mistake --- slsim/ParamDistributions/params.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index 0bae2df6b..dc58223b6 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -16,6 +16,6 @@ def check_weights(cls, weight_values): @model_validator(mode="after") def check_lengths(self): if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): - raise ValueError("The lenghts of means, stds and weights must be equal") + raise ValueError("The lengths of means, stds and weights must be equal") return self \ No newline at end of file From c6d902f6ae70a519d2739c4fae6ea5f21aeb599c Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 27 Oct 2023 18:42:38 +0000 Subject: [PATCH 12/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../gaussian_mixture_model.py | 2 + slsim/ParamDistributions/params.py | 10 ++- slsim/Util/__init__.py | 2 +- slsim/Util/params.py | 65 ++++++++++--------- 4 files changed, 45 insertions(+), 34 deletions(-) diff --git a/slsim/ParamDistributions/gaussian_mixture_model.py b/slsim/ParamDistributions/gaussian_mixture_model.py index d4db8997a..6dada3151 100644 --- a/slsim/ParamDistributions/gaussian_mixture_model.py +++ b/slsim/ParamDistributions/gaussian_mixture_model.py @@ -1,12 +1,14 @@ import numpy as np from slsim.Util import check_params + class GaussianMixtureModel: """A Gaussian Mixture Model (GMM) class. This class is used to represent a mixture of Gaussian distributions, each of which is defined by its mean, standard deviation and weight. """ + @check_params def __init__(self, means, stds, weights): """ diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/params.py index dc58223b6..d476ea161 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/params.py @@ -1,11 +1,16 @@ from pydantic import BaseModel, PositiveFloat, model_validator, field_validator import numpy as np + class GaussianMixtureModel(BaseModel): means: list[float] = [0.00330796, -0.07635054, 0.11829008] - stds: list[PositiveFloat] = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] + stds: list[PositiveFloat] = [ + np.sqrt(0.00283885), + np.sqrt(0.01066668), + np.sqrt(0.0097978), + ] weights: list[PositiveFloat] = [0.62703102, 0.23732313, 0.13564585] - + @field_validator("weights") @classmethod def check_weights(cls, weight_values): @@ -18,4 +23,3 @@ def check_lengths(self): if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): raise ValueError("The lengths of means, stds and weights must be equal") return self - \ No newline at end of file diff --git a/slsim/Util/__init__.py b/slsim/Util/__init__.py index 7973418cf..daaebfade 100644 --- a/slsim/Util/__init__.py +++ b/slsim/Util/__init__.py @@ -1,3 +1,3 @@ from .params import check_params -__all__ = ['check_params'] \ No newline at end of file +__all__ = ["check_params"] diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 5f5c1d4c6..4bdc820d3 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -1,5 +1,5 @@ -""" -Utilities for managing parameter defaults and validation in the slsim package. +"""Utilities for managing parameter defaults and validation in the slsim package. + Desgined to be unobtrusive to use. """ from functools import wraps @@ -9,29 +9,31 @@ from typing import Callable, Any import pydantic + class SlSimParameterException(Exception): pass + _defaults = {} + def check_params(init_fn: Callable) -> Callable: - """ - A decorator for enforcing checking of params in __init__ methods. This - decorator will automatically load the default parameters for the class - and check that the passed parameters are valid. It expeects a "params.py" - file in the same folder as the class definition. Uses pydantic models - to enforce types, sanity checks, and defaults. + """A decorator for enforcing checking of params in __init__ methods. This decorator + will automatically load the default parameters for the class and check that the + passed parameters are valid. It expeects a "params.py" file in the same folder as + the class definition. Uses pydantic models to enforce types, sanity checks, and + defaults. From and end user perspective, there is no difference between this and a normal - __init__ fn. Developers only need to add @check_params above their __init__ - method definition to enable this feature, then add their default parameters - to the "params.py" file. + __init__ fn. Developers only need to add @check_params above their __init__ method + definition to enable this feature, then add their default parameters to the + "params.py" file. """ - - if not init_fn.__name__.startswith('__init__'): - raise SlSimParameterException('pcheck decorator can currently only be used'\ - ' with__init__ methods') + if not init_fn.__name__.startswith("__init__"): + raise SlSimParameterException( + "pcheck decorator can currently only be used" " with__init__ methods" + ) @wraps(init_fn) def new_init_fn(obj: Any, *args, **kwargs) -> Any: @@ -40,44 +42,47 @@ def new_init_fn(obj: Any, *args, **kwargs) -> Any: if args: largs = getargspec(init_fn).args for i in range(len(args)): - pargs[largs[i+1]] = args[i] - #Doing it this way ensures we still catch duplicate arguments + pargs[largs[i + 1]] = args[i] + # Doing it this way ensures we still catch duplicate arguments parsed_args = get_defaults(init_fn)(**pargs, **kwargs) return init_fn(obj, **dict(parsed_args)) + return new_init_fn def get_defaults(init_fn: Callable) -> pydantic.BaseModel: path = getsourcefile(init_fn) - obj_name = init_fn.__qualname__.split('.')[0] + obj_name = init_fn.__qualname__.split(".")[0] start = path.rfind("slsim") - modpath = path[start:].split('/') + modpath = path[start:].split("/") modpath = modpath[1:-1] + ["params"] modpath = ".".join(["slsim"] + modpath) # Unfortunately, there doesn't seem to be a better way of doing this. if modpath not in _defaults: - #Little optimization. We cache defaults so we don't have to reload them + # Little optimization. We cache defaults so we don't have to reload them # every time we construct a new object. _defaults[modpath] = load_parameters(modpath, obj_name) return _defaults[modpath] + def load_parameters(modpath: str, obj_name: str) -> pydantic.BaseModel: - """ - Loads parameters from the "params.py" file which should be in the same folder - as the class definition. - """ + """Loads parameters from the "params.py" file which should be in the same folder as + the class definition.""" try: defaults = import_module(modpath) except ModuleNotFoundError: - raise SlSimParameterException('No default parameters found in module '\ - f'\"{modpath[-2]}\"') + raise SlSimParameterException( + "No default parameters found in module " f'"{modpath[-2]}"' + ) try: obj_defaults = getattr(defaults, obj_name) except AttributeError: - raise SlSimParameterException(f'No default parameters found for class '\ - f'\"{obj_name}\"') + raise SlSimParameterException( + f"No default parameters found for class " f'"{obj_name}"' + ) if not issubclass(obj_defaults, pydantic.BaseModel): - raise SlSimParameterException(f'Defaults for \"{obj_name}\" are not in a '\ - 'pydantic model!') + raise SlSimParameterException( + f'Defaults for "{obj_name}" are not in a ' "pydantic model!" + ) return obj_defaults From 5e5f7cc87ada82806f3be2a7cefcb5dc89ce8c9a Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 11:44:54 -0700 Subject: [PATCH 13/25] Add pydantic to requirements --- requirements.txt | 1 + 1 file changed, 1 insertion(+) diff --git a/requirements.txt b/requirements.txt index 283273cf9..3b2f41356 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ scipy speclite pyyaml matplotlib +pydantic>=2 \ No newline at end of file From 8cb752106ca61009c66016f512089cf27c851520 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Fri, 27 Oct 2023 11:48:24 -0700 Subject: [PATCH 14/25] Fix linting issue --- slsim/Util/params.py | 1 - 1 file changed, 1 deletion(-) diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 4bdc820d3..fafacb6ff 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -4,7 +4,6 @@ """ from functools import wraps from inspect import getsourcefile, getargspec -from pathlib import Path from importlib import import_module from typing import Callable, Any import pydantic From 8903241996444416240533c16ab067f18f037589 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 11:00:23 -0800 Subject: [PATCH 15/25] Several changes - We now check what kind of function is being inputted - Introduce basic tests --- slsim/Util/params.py | 64 +++++++++++++++++++++++++++++--- tests/test_Params/__init__.py | 0 tests/test_Params/test_params.py | 3 ++ 3 files changed, 62 insertions(+), 5 deletions(-) create mode 100644 tests/test_Params/__init__.py create mode 100644 tests/test_Params/test_params.py diff --git a/slsim/Util/params.py b/slsim/Util/params.py index fafacb6ff..f0f38cacd 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -6,8 +6,15 @@ from inspect import getsourcefile, getargspec from importlib import import_module from typing import Callable, Any +from enum import Enum +import inspect import pydantic +""" +Set of routines for validating inputs to functions and classes. The elements of this +module should never be imported directly. Instead, @check_params can be imported +directly from the Util module. +""" class SlSimParameterException(Exception): pass @@ -15,8 +22,56 @@ class SlSimParameterException(Exception): _defaults = {} +class _FnType(Enum): + """ + + Enum for the different types of functions we can have. This is used to determine + how to parse the arguments to the function. + + There are three possible cases: + 1. The function is a standard function, defined outside a class + 2. The function is a standard object method, + taking "self" as the first parameter + 3. The funtion is a class method (or staticmethod), not taking + "self" as the first parameter + + """ + + STANDARD = 0 + METHOD = 1 + CLASSMETHOD = 2 + +def determine_fn_type(fn: Callable) -> _FnType: + """ + Determine which of the three possible cases a function falls into. Cases + 0 and 2 are actually functionally identical. Things only get spicy when we + have a "self" argument. + + However the tricky thing is that decorators operate on functions and methods when + they are imported, not when they are used. This means "inspect.ismethod" will + always return False, even if the function is a method. -def check_params(init_fn: Callable) -> Callable: + We can get around this by checking if the parent of the function is a class. Then, + we check if the first argument of the function is "self". If both of these are true, + then the function is a method. + """ + if not inspect.isfunction(fn): + raise TypeError("decorator @check_params can only be used on functions!") + qualified_obj_name = fn.__qualname__ + qualified_obj_path = qualified_obj_name.split(".") + if len(qualified_obj_path) == 1: + # If the qualified name isn't split, this is a standard function not + # attached to a class + return _FnType.STANDARD + + spec = inspect.getfullargspec(fn) + if spec.args[0] == "self": + return _FnType.METHOD + else: + return _FnType.CLASSMETHOD + + +def check_params(fn: Callable) -> Callable: """A decorator for enforcing checking of params in __init__ methods. This decorator will automatically load the default parameters for the class and check that the passed parameters are valid. It expeects a "params.py" file in the same folder as @@ -27,12 +82,11 @@ def check_params(init_fn: Callable) -> Callable: __init__ fn. Developers only need to add @check_params above their __init__ method definition to enable this feature, then add their default parameters to the "params.py" file. + """ + fn_type = determine_fn_type(fn) + - if not init_fn.__name__.startswith("__init__"): - raise SlSimParameterException( - "pcheck decorator can currently only be used" " with__init__ methods" - ) @wraps(init_fn) def new_init_fn(obj: Any, *args, **kwargs) -> Any: diff --git a/tests/test_Params/__init__.py b/tests/test_Params/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_Params/test_params.py b/tests/test_Params/test_params.py new file mode 100644 index 000000000..5d0daf217 --- /dev/null +++ b/tests/test_Params/test_params.py @@ -0,0 +1,3 @@ +from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel + +GaussianMixtureModel() \ No newline at end of file From 09051087b55e531a8eef0bac75ccf22e8aa675f6 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 19:01:10 +0000 Subject: [PATCH 16/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- slsim/Util/params.py | 30 +++++++++++++----------------- tests/test_Params/test_params.py | 2 +- 2 files changed, 14 insertions(+), 18 deletions(-) diff --git a/slsim/Util/params.py b/slsim/Util/params.py index f0f38cacd..40af4e46a 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -16,40 +16,39 @@ directly from the Util module. """ + class SlSimParameterException(Exception): pass _defaults = {} + class _FnType(Enum): - """ - - Enum for the different types of functions we can have. This is used to determine + """Enum for the different types of functions we can have. This is used to determine how to parse the arguments to the function. - + There are three possible cases: 1. The function is a standard function, defined outside a class - 2. The function is a standard object method, + 2. The function is a standard object method, taking "self" as the first parameter - 3. The funtion is a class method (or staticmethod), not taking + 3. The funtion is a class method (or staticmethod), not taking "self" as the first parameter - """ STANDARD = 0 METHOD = 1 CLASSMETHOD = 2 + def determine_fn_type(fn: Callable) -> _FnType: - """ - Determine which of the three possible cases a function falls into. Cases - 0 and 2 are actually functionally identical. Things only get spicy when we - have a "self" argument. + """Determine which of the three possible cases a function falls into. Cases 0 and 2 + are actually functionally identical. Things only get spicy when we have a "self" + argument. However the tricky thing is that decorators operate on functions and methods when - they are imported, not when they are used. This means "inspect.ismethod" will - always return False, even if the function is a method. + they are imported, not when they are used. This means "inspect.ismethod" will always + return False, even if the function is a method. We can get around this by checking if the parent of the function is a class. Then, we check if the first argument of the function is "self". If both of these are true, @@ -63,7 +62,7 @@ def determine_fn_type(fn: Callable) -> _FnType: # If the qualified name isn't split, this is a standard function not # attached to a class return _FnType.STANDARD - + spec = inspect.getfullargspec(fn) if spec.args[0] == "self": return _FnType.METHOD @@ -82,12 +81,9 @@ def check_params(fn: Callable) -> Callable: __init__ fn. Developers only need to add @check_params above their __init__ method definition to enable this feature, then add their default parameters to the "params.py" file. - """ fn_type = determine_fn_type(fn) - - @wraps(init_fn) def new_init_fn(obj: Any, *args, **kwargs) -> Any: # Get function argument names diff --git a/tests/test_Params/test_params.py b/tests/test_Params/test_params.py index 5d0daf217..aab39932b 100644 --- a/tests/test_Params/test_params.py +++ b/tests/test_Params/test_params.py @@ -1,3 +1,3 @@ from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel -GaussianMixtureModel() \ No newline at end of file +GaussianMixtureModel() From a730c7d54d8d94fb68dd923c12e6a6977095af78 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 11:46:26 -0800 Subject: [PATCH 17/25] Working with all object methods --- slsim/ParamDistributions/_params/__init__.py | 0 .../gaussian_mixture_model.py} | 5 +- .../gaussian_mixture_model.py | 7 +- slsim/Util/params.py | 105 ++++++++++++++---- tests/test_Params/test_params.py | 3 +- 5 files changed, 92 insertions(+), 28 deletions(-) create mode 100644 slsim/ParamDistributions/_params/__init__.py rename slsim/ParamDistributions/{params.py => _params/gaussian_mixture_model.py} (83%) diff --git a/slsim/ParamDistributions/_params/__init__.py b/slsim/ParamDistributions/_params/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slsim/ParamDistributions/params.py b/slsim/ParamDistributions/_params/gaussian_mixture_model.py similarity index 83% rename from slsim/ParamDistributions/params.py rename to slsim/ParamDistributions/_params/gaussian_mixture_model.py index d476ea161..8f717d047 100644 --- a/slsim/ParamDistributions/params.py +++ b/slsim/ParamDistributions/_params/gaussian_mixture_model.py @@ -1,4 +1,4 @@ -from pydantic import BaseModel, PositiveFloat, model_validator, field_validator +from pydantic import BaseModel, PositiveFloat, PositiveInt, model_validator, field_validator import numpy as np @@ -23,3 +23,6 @@ def check_lengths(self): if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): raise ValueError("The lengths of means, stds and weights must be equal") return self + +class GaussianMixtureModel_rvs(BaseModel): + size: PositiveInt \ No newline at end of file diff --git a/slsim/ParamDistributions/gaussian_mixture_model.py b/slsim/ParamDistributions/gaussian_mixture_model.py index 6dada3151..c8143f9b6 100644 --- a/slsim/ParamDistributions/gaussian_mixture_model.py +++ b/slsim/ParamDistributions/gaussian_mixture_model.py @@ -10,11 +10,11 @@ class GaussianMixtureModel: """ @check_params - def __init__(self, means, stds, weights): + def __init__(self, means: list[float], stds: list[float], weights:list[float]): """ The constructor for GaussianMixtureModel class. The default values are the means, standard deviations, and weights of the fits to the data in the table - 2 of https://doi.org/10.1093/mnras/stac2235 and others. See "params.py" for + 2 of https://doi.org/10.1093/mnras/stac2235 and others. See "_params" for defaults and validation logic. :param means: the mean values of the Gaussian components. @@ -28,7 +28,8 @@ def __init__(self, means, stds, weights): self.stds = stds self.weights = weights - def rvs(self, size): + @check_params + def rvs(self, size: int): """Generate random variables from the GMM distribution. :param size: The number of random variables to generate. diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 40af4e46a..8a9801cdb 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -83,36 +83,95 @@ def check_params(fn: Callable) -> Callable: "params.py" file. """ fn_type = determine_fn_type(fn) + if fn_type == _FnType.STANDARD: + new_fn = standard_fn_wrapper(fn) + elif fn_type == _FnType.METHOD: + new_fn = method_fn_wrapper(fn) + elif fn_type == _FnType.CLASSMETHOD: + new_fn = standard_fn_wrapper(fn) + + return new_fn + +def standard_fn_wrapper(fn: Callable) -> Callable: + """A wrapper for standard functions. This is used to parse the arguments to the + function and check that they are valid. + """ + spec = inspect.getfullargspec(fn) + @wraps(fn) + def new_fn(*args, **kwargs) -> Any: + # Get function argument names + pargs = {} + if args: + largs = getargspec(fn).args + for i in range(len(args)): + pargs[largs[i + 1]] = args[i] + # Doing it this way ensures we still catch duplicate arguments + defaults = get_defaults(fn) + parsed_args = defaults(**pargs, **kwargs) + return fn(**dict(parsed_args)) + + return new_fn - @wraps(init_fn) - def new_init_fn(obj: Any, *args, **kwargs) -> Any: +def method_fn_wrapper(fn: Callable) -> Callable: + @wraps(fn) + def new_fn(obj: Any, *args, **kwargs) -> Any: # Get function argument names pargs = {} if args: - largs = getargspec(init_fn).args + largs = getargspec(fn).args for i in range(len(args)): pargs[largs[i + 1]] = args[i] # Doing it this way ensures we still catch duplicate arguments - parsed_args = get_defaults(init_fn)(**pargs, **kwargs) - return init_fn(obj, **dict(parsed_args)) - - return new_init_fn - - -def get_defaults(init_fn: Callable) -> pydantic.BaseModel: - path = getsourcefile(init_fn) - obj_name = init_fn.__qualname__.split(".")[0] - start = path.rfind("slsim") - modpath = path[start:].split("/") - modpath = modpath[1:-1] + ["params"] - modpath = ".".join(["slsim"] + modpath) - # Unfortunately, there doesn't seem to be a better way of doing this. - - if modpath not in _defaults: - # Little optimization. We cache defaults so we don't have to reload them - # every time we construct a new object. - _defaults[modpath] = load_parameters(modpath, obj_name) - return _defaults[modpath] + defaults = get_defaults(fn) + parsed_args = defaults(**pargs, **kwargs) + return fn(obj, **dict(parsed_args)) + return new_fn + + +def get_defaults(fn: Callable) -> pydantic.BaseModel: + module_trace = inspect.getmodule(fn).__name__.split(".") + file_name = module_trace[-1] + parent_trace = module_trace[:-1] + parent_path = ".".join(parent_trace) + param_path = ".".join([parent_path, "_params"]) + fn_qualname = fn.__qualname__ + cache_name = parent_path + "." + fn_qualname + if cache_name in _defaults: + return _defaults[cache_name] + + try: + param_module = import_module(param_path) + except ModuleNotFoundError: + raise SlSimParameterException( + f'No default parameters found in module {".".join(parent_trace)},'\ + ' but something in that module is trying to use the @check_params decorator' + ) + try: + param_model_file = import_module(f'{param_path}.{file_name}') + except AttributeError: + raise SlSimParameterException( + f'No default parameters found for file "{file_name}" in module '\ + f'{".".join(parent_trace)}, but something in that module is trying to use '\ + 'the @check_params decorator' + ) + + if fn.__name__ == "__init__": + expected_model_name = "_".join(fn_qualname.split(".")[:-1]) + else: + expected_model_name = "_".join(fn_qualname.split(".")) + + try: + + parameter_model = getattr(param_model_file, expected_model_name) + except AttributeError: + raise SlSimParameterException("No default parameters found for function "\ + f'"{fn_qualname}"') + if not issubclass(parameter_model, pydantic.BaseModel): + raise SlSimParameterException( + f'Defaults for "{fn_qualname}" are not in a pydantic model!' + ) + _defaults[cache_name] = parameter_model + return _defaults[cache_name] def load_parameters(modpath: str, obj_name: str) -> pydantic.BaseModel: diff --git a/tests/test_Params/test_params.py b/tests/test_Params/test_params.py index aab39932b..d28adde1c 100644 --- a/tests/test_Params/test_params.py +++ b/tests/test_Params/test_params.py @@ -1,3 +1,4 @@ from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel -GaussianMixtureModel() +a = GaussianMixtureModel() +a.rvs(156) \ No newline at end of file From 5fea8ec41de77495c3cf139b0a2c135a2d35c2ef Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 14:47:51 -0800 Subject: [PATCH 18/25] Working with tests --- .github/workflows/tests.yml | 2 +- slsim/Util/params.py | 19 ++++++++++++------- tests/test_Params/test_params.py | 4 ++-- 3 files changed, 15 insertions(+), 10 deletions(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4fd37eff5..c3030d2fe 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: [3.9, 3.10, 3.11] steps: - uses: actions/checkout@v3 diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 8a9801cdb..dc55af8f5 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -3,7 +3,6 @@ Desgined to be unobtrusive to use. """ from functools import wraps -from inspect import getsourcefile, getargspec from importlib import import_module from typing import Callable, Any from enum import Enum @@ -102,9 +101,11 @@ def new_fn(*args, **kwargs) -> Any: # Get function argument names pargs = {} if args: - largs = getargspec(fn).args + largs = inspect.signature(fn).parameters for i in range(len(args)): - pargs[largs[i + 1]] = args[i] + arg_value = args[i] + if arg_value is not None: + pargs[largs[i + 1]] = args[i] # Doing it this way ensures we still catch duplicate arguments defaults = get_defaults(fn) parsed_args = defaults(**pargs, **kwargs) @@ -116,14 +117,18 @@ def method_fn_wrapper(fn: Callable) -> Callable: @wraps(fn) def new_fn(obj: Any, *args, **kwargs) -> Any: # Get function argument names - pargs = {} + parsed_args = {} if args: - largs = getargspec(fn).args + largs = list(inspect.signature(fn).parameters.keys()) + for i in range(len(args)): - pargs[largs[i + 1]] = args[i] + arg_value = args[i] + if arg_value is not None: + parsed_args[largs[i + 1]] = arg_value # Doing it this way ensures we still catch duplicate arguments + parsed_kwargs = {k: v for k, v in kwargs.items() if v is not None} defaults = get_defaults(fn) - parsed_args = defaults(**pargs, **kwargs) + parsed_args = defaults(**parsed_args, **parsed_kwargs) return fn(obj, **dict(parsed_args)) return new_fn diff --git a/tests/test_Params/test_params.py b/tests/test_Params/test_params.py index d28adde1c..2f8437c33 100644 --- a/tests/test_Params/test_params.py +++ b/tests/test_Params/test_params.py @@ -1,4 +1,4 @@ from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel -a = GaussianMixtureModel() -a.rvs(156) \ No newline at end of file +a = GaussianMixtureModel(means=[1,2,3], stds=[1,2,3], weights=[0.4,0.4,0.2]) +print(a.stds) \ No newline at end of file From eadf85ed2b7e7a11839d3ff39d1f0f2f89b3558b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 22:48:06 +0000 Subject: [PATCH 19/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- .../_params/gaussian_mixture_model.py | 11 +++++-- .../gaussian_mixture_model.py | 2 +- slsim/Util/params.py | 33 +++++++++++-------- tests/test_Params/test_params.py | 4 +-- 4 files changed, 31 insertions(+), 19 deletions(-) diff --git a/slsim/ParamDistributions/_params/gaussian_mixture_model.py b/slsim/ParamDistributions/_params/gaussian_mixture_model.py index 8f717d047..b00e880aa 100644 --- a/slsim/ParamDistributions/_params/gaussian_mixture_model.py +++ b/slsim/ParamDistributions/_params/gaussian_mixture_model.py @@ -1,4 +1,10 @@ -from pydantic import BaseModel, PositiveFloat, PositiveInt, model_validator, field_validator +from pydantic import ( + BaseModel, + PositiveFloat, + PositiveInt, + model_validator, + field_validator, +) import numpy as np @@ -24,5 +30,6 @@ def check_lengths(self): raise ValueError("The lengths of means, stds and weights must be equal") return self + class GaussianMixtureModel_rvs(BaseModel): - size: PositiveInt \ No newline at end of file + size: PositiveInt diff --git a/slsim/ParamDistributions/gaussian_mixture_model.py b/slsim/ParamDistributions/gaussian_mixture_model.py index c8143f9b6..ecbf86ed4 100644 --- a/slsim/ParamDistributions/gaussian_mixture_model.py +++ b/slsim/ParamDistributions/gaussian_mixture_model.py @@ -10,7 +10,7 @@ class GaussianMixtureModel: """ @check_params - def __init__(self, means: list[float], stds: list[float], weights:list[float]): + def __init__(self, means: list[float], stds: list[float], weights: list[float]): """ The constructor for GaussianMixtureModel class. The default values are the means, standard deviations, and weights of the fits to the data in the table diff --git a/slsim/Util/params.py b/slsim/Util/params.py index dc55af8f5..fed8593ed 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -88,14 +88,17 @@ def check_params(fn: Callable) -> Callable: new_fn = method_fn_wrapper(fn) elif fn_type == _FnType.CLASSMETHOD: new_fn = standard_fn_wrapper(fn) - + return new_fn + def standard_fn_wrapper(fn: Callable) -> Callable: - """A wrapper for standard functions. This is used to parse the arguments to the - function and check that they are valid. + """A wrapper for standard functions. + + This is used to parse the arguments to the function and check that they are valid. """ spec = inspect.getfullargspec(fn) + @wraps(fn) def new_fn(*args, **kwargs) -> Any: # Get function argument names @@ -113,6 +116,7 @@ def new_fn(*args, **kwargs) -> Any: return new_fn + def method_fn_wrapper(fn: Callable) -> Callable: @wraps(fn) def new_fn(obj: Any, *args, **kwargs) -> Any: @@ -130,6 +134,7 @@ def new_fn(obj: Any, *args, **kwargs) -> Any: defaults = get_defaults(fn) parsed_args = defaults(**parsed_args, **parsed_kwargs) return fn(obj, **dict(parsed_args)) + return new_fn @@ -148,17 +153,17 @@ def get_defaults(fn: Callable) -> pydantic.BaseModel: param_module = import_module(param_path) except ModuleNotFoundError: raise SlSimParameterException( - f'No default parameters found in module {".".join(parent_trace)},'\ - ' but something in that module is trying to use the @check_params decorator' - ) + f'No default parameters found in module {".".join(parent_trace)},' + " but something in that module is trying to use the @check_params decorator" + ) try: - param_model_file = import_module(f'{param_path}.{file_name}') + param_model_file = import_module(f"{param_path}.{file_name}") except AttributeError: raise SlSimParameterException( - f'No default parameters found for file "{file_name}" in module '\ - f'{".".join(parent_trace)}, but something in that module is trying to use '\ - 'the @check_params decorator' - ) + f'No default parameters found for file "{file_name}" in module ' + f'{".".join(parent_trace)}, but something in that module is trying to use ' + "the @check_params decorator" + ) if fn.__name__ == "__init__": expected_model_name = "_".join(fn_qualname.split(".")[:-1]) @@ -166,11 +171,11 @@ def get_defaults(fn: Callable) -> pydantic.BaseModel: expected_model_name = "_".join(fn_qualname.split(".")) try: - parameter_model = getattr(param_model_file, expected_model_name) except AttributeError: - raise SlSimParameterException("No default parameters found for function "\ - f'"{fn_qualname}"') + raise SlSimParameterException( + "No default parameters found for function " f'"{fn_qualname}"' + ) if not issubclass(parameter_model, pydantic.BaseModel): raise SlSimParameterException( f'Defaults for "{fn_qualname}" are not in a pydantic model!' diff --git a/tests/test_Params/test_params.py b/tests/test_Params/test_params.py index 2f8437c33..5bcbeb2f2 100644 --- a/tests/test_Params/test_params.py +++ b/tests/test_Params/test_params.py @@ -1,4 +1,4 @@ from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel -a = GaussianMixtureModel(means=[1,2,3], stds=[1,2,3], weights=[0.4,0.4,0.2]) -print(a.stds) \ No newline at end of file +a = GaussianMixtureModel(means=[1, 2, 3], stds=[1, 2, 3], weights=[0.4, 0.4, 0.2]) +print(a.stds) From 2e37905e17df118270d1666413f88ef435dc41fd Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 14:55:56 -0800 Subject: [PATCH 20/25] Passing linters --- slsim/Util/params.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/slsim/Util/params.py b/slsim/Util/params.py index fed8593ed..bf734781e 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -88,7 +88,7 @@ def check_params(fn: Callable) -> Callable: new_fn = method_fn_wrapper(fn) elif fn_type == _FnType.CLASSMETHOD: new_fn = standard_fn_wrapper(fn) - + return new_fn @@ -97,8 +97,6 @@ def standard_fn_wrapper(fn: Callable) -> Callable: This is used to parse the arguments to the function and check that they are valid. """ - spec = inspect.getfullargspec(fn) - @wraps(fn) def new_fn(*args, **kwargs) -> Any: # Get function argument names @@ -150,7 +148,7 @@ def get_defaults(fn: Callable) -> pydantic.BaseModel: return _defaults[cache_name] try: - param_module = import_module(param_path) + _ = import_module(param_path) except ModuleNotFoundError: raise SlSimParameterException( f'No default parameters found in module {".".join(parent_trace)},' From 3ab59d8180cf60bb7c27e354a74995f810eb3e65 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 6 Nov 2023 22:57:38 +0000 Subject: [PATCH 21/25] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- slsim/Util/params.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/slsim/Util/params.py b/slsim/Util/params.py index bf734781e..8e6a33dac 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -88,7 +88,7 @@ def check_params(fn: Callable) -> Callable: new_fn = method_fn_wrapper(fn) elif fn_type == _FnType.CLASSMETHOD: new_fn = standard_fn_wrapper(fn) - + return new_fn @@ -97,6 +97,7 @@ def standard_fn_wrapper(fn: Callable) -> Callable: This is used to parse the arguments to the function and check that they are valid. """ + @wraps(fn) def new_fn(*args, **kwargs) -> Any: # Get function argument names From 5287a3d8438c4ae23ff60613d93c827d77781409 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 16:14:13 -0800 Subject: [PATCH 22/25] Tests written --- .../_params/elliptical_lens_galaxies.py | 5 + .../Deflectors/_params/velocity_dispersion.py | 11 ++ slsim/Deflectors/elliptical_lens_galaxies.py | 2 + slsim/Deflectors/velocity_dispersion.py | 2 + slsim/Util/params.py | 4 +- tests/test_Params/test_params.py | 4 - tests/test_Params/test_params_check.py | 140 ++++++++++++++++++ 7 files changed, 162 insertions(+), 6 deletions(-) create mode 100644 slsim/Deflectors/_params/elliptical_lens_galaxies.py create mode 100644 slsim/Deflectors/_params/velocity_dispersion.py delete mode 100644 tests/test_Params/test_params.py create mode 100644 tests/test_Params/test_params_check.py diff --git a/slsim/Deflectors/_params/elliptical_lens_galaxies.py b/slsim/Deflectors/_params/elliptical_lens_galaxies.py new file mode 100644 index 000000000..0ec1deff6 --- /dev/null +++ b/slsim/Deflectors/_params/elliptical_lens_galaxies.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel, PositiveFloat + + +class vel_disp_from_m_star(BaseModel): + m_star: PositiveFloat diff --git a/slsim/Deflectors/_params/velocity_dispersion.py b/slsim/Deflectors/_params/velocity_dispersion.py new file mode 100644 index 000000000..475aae9b1 --- /dev/null +++ b/slsim/Deflectors/_params/velocity_dispersion.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, PositiveFloat +from astropy.cosmology import Cosmology + + +class vel_disp_composite_model(BaseModel, arbitrary_types_allowed=True): + r: PositiveFloat + m_star: PositiveFloat + rs_star: PositiveFloat + m_halo: PositiveFloat + c_halo: float + cosmo: Cosmology diff --git a/slsim/Deflectors/elliptical_lens_galaxies.py b/slsim/Deflectors/elliptical_lens_galaxies.py index c37788827..4ca7458ff 100644 --- a/slsim/Deflectors/elliptical_lens_galaxies.py +++ b/slsim/Deflectors/elliptical_lens_galaxies.py @@ -4,6 +4,7 @@ from slsim.Deflectors.velocity_dispersion import vel_disp_sdss from slsim.Util import param_util from slsim.Deflectors.deflector_base import DeflectorBase +from slsim.Util import check_params class EllipticalLensGalaxies(DeflectorBase): @@ -120,6 +121,7 @@ def elliptical_projected_eccentricity(ellipticity, **kwargs): return e1_light, e2_light, e1_mass, e2_mass +@check_params def vel_disp_from_m_star(m_star): """Function to calculate the velocity dispersion from the staller mass using empirical relation for elliptical galaxies. diff --git a/slsim/Deflectors/velocity_dispersion.py b/slsim/Deflectors/velocity_dispersion.py index abe9d3078..08ba313f7 100644 --- a/slsim/Deflectors/velocity_dispersion.py +++ b/slsim/Deflectors/velocity_dispersion.py @@ -2,6 +2,7 @@ import scipy from skypy.galaxies.redshift import redshifts_from_comoving_density from skypy.utils.random import schechter +from slsim.Util import check_params """ This module provides functions to compute velocity dispersion using schechter function. @@ -11,6 +12,7 @@ # from skypy.galaxies.velocity_dispersion import schechter_vdf +@check_params def vel_disp_composite_model(r, m_star, rs_star, m_halo, c_halo, cosmo): """Computes the luminosity weighted velocity dispersion for a deflector with a stellar Hernquist profile and a NFW halo profile, assuming isotropic anisotropy. diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 8e6a33dac..0e120d13a 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -103,11 +103,11 @@ def new_fn(*args, **kwargs) -> Any: # Get function argument names pargs = {} if args: - largs = inspect.signature(fn).parameters + largs = list(inspect.signature(fn).parameters.keys()) for i in range(len(args)): arg_value = args[i] if arg_value is not None: - pargs[largs[i + 1]] = args[i] + pargs[largs[i]] = args[i] # Doing it this way ensures we still catch duplicate arguments defaults = get_defaults(fn) parsed_args = defaults(**pargs, **kwargs) diff --git a/tests/test_Params/test_params.py b/tests/test_Params/test_params.py deleted file mode 100644 index 5bcbeb2f2..000000000 --- a/tests/test_Params/test_params.py +++ /dev/null @@ -1,4 +0,0 @@ -from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel - -a = GaussianMixtureModel(means=[1, 2, 3], stds=[1, 2, 3], weights=[0.4, 0.4, 0.2]) -print(a.stds) diff --git a/tests/test_Params/test_params_check.py b/tests/test_Params/test_params_check.py new file mode 100644 index 000000000..b0d2e3927 --- /dev/null +++ b/tests/test_Params/test_params_check.py @@ -0,0 +1,140 @@ +from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel +from slsim.Deflectors.velocity_dispersion import vel_disp_composite_model +from astropy.cosmology import FlatLambdaCDM +from pydantic import ValidationError +import pytest +import random + +""" +Test for the parameter checking rountines in slsim. We want to make sure we're using +functions/classes that are actually in slsim, rather than ones in the /tests folder. +This is because the parameter checking routine discovers defaults by importing from +a standard location inside the slsim package. If we import from the /tests folder, +this will fail. +""" + + +def test_all_kwargs_init(good_inputs: dict): + gmm = GaussianMixtureModel(**good_inputs) + assert ( + gmm.means == good_inputs["means"] + and gmm.stds == good_inputs["stds"] + and gmm.weights == good_inputs["weights"] + ) + + +def test_all_args_init(good_inputs: dict): + inputs = list(good_inputs.values()) + gmm = GaussianMixtureModel(*inputs) + assert gmm.means == good_inputs["means"] + + +def test_all_args_method(good_model: GaussianMixtureModel): + output = good_model.rvs(100) + assert len(output) == 100 + + +def test_all_kwargs_method(good_model: GaussianMixtureModel): + output = good_model.rvs(size=100) + assert len(output) == 100 + + +def test_all_args_function(vel_disp_inputs: dict): + inputs = list(vel_disp_inputs.values()) + output = vel_disp_composite_model(*inputs) + assert isinstance(output, float) + + +def test_all_kwargs_function(vel_disp_inputs: dict): + output = vel_disp_composite_model(**vel_disp_inputs) + assert isinstance(output, float) + + +def test_mixture_init(good_inputs: dict): + good_inputs_keys = list(good_inputs.keys()) + input_args = [good_inputs[k] for k in good_inputs_keys[:2]] + input_kwargs = {k: good_inputs[k] for k in good_inputs_keys[2:]} + + gmm = GaussianMixtureModel(*input_args, **input_kwargs) + assert ( + gmm.means == good_inputs["means"] + and gmm.stds == good_inputs["stds"] + and gmm.weights == good_inputs["weights"] + ) + + +def test_mixture_function(vel_disp_inputs: dict): + vel_disp_inputs_keys = list(vel_disp_inputs.keys()) + input_args = [vel_disp_inputs[k] for k in vel_disp_inputs_keys[:3]] + input_kwargs = {k: vel_disp_inputs[k] for k in vel_disp_inputs_keys[3:]} + + output = vel_disp_composite_model(*input_args, **input_kwargs) + assert isinstance(output, float) + + +def test_shuffle_init(good_inputs: dict): + keys = list(good_inputs.keys()) + random.shuffle(keys) + input = {k: good_inputs[k] for k in keys} + gmm = GaussianMixtureModel(**input) + assert ( + gmm.means == good_inputs["means"] + and gmm.stds == good_inputs["stds"] + and gmm.weights == good_inputs["weights"] + ) + + +def test_shuffle_function(vel_disp_inputs: dict): + keys = list(vel_disp_inputs.keys()) + random.shuffle(keys) + input = {k: vel_disp_inputs[k] for k in keys} + output = vel_disp_composite_model(**input) + assert isinstance(output, float) + + +def test_failure_init(good_inputs: dict): + good_inputs_keys = list(good_inputs.keys()) + input_args = [good_inputs[k] for k in good_inputs_keys[:2]] + input_kwargs = {k: good_inputs[k] for k in good_inputs_keys[2:]} + input_kwargs["weights"] = [0.2, 0.3, 0.4] + with pytest.raises(ValueError): + _ = GaussianMixtureModel(*input_args, **input_kwargs) + + +def test_failure_method(good_model: GaussianMixtureModel): + with pytest.raises(ValidationError): + _ = good_model.rvs(-50) + + +def test_failure_function(vel_disp_inputs: dict): + input_kwargs = vel_disp_inputs + input_kwargs["m_star"] = -50 + with pytest.raises(ValidationError): + _ = vel_disp_composite_model(**input_kwargs) + + +@pytest.fixture +def cosmology(): + return FlatLambdaCDM(H0=70, Om0=0.3) + + +@pytest.fixture +def vel_disp_inputs(cosmology): + return { + "r": 5, + "m_star": 10**10, + "rs_star": 30, + "m_halo": 10**14, + "c_halo": 3, + "cosmo": cosmology, + } + + +@pytest.fixture +def good_inputs(): + return {"means": [1, 2, 3], "stds": [1, 2, 3], "weights": [0.2, 0.3, 0.5]} + + +@pytest.fixture +def good_model(good_inputs: dict): + return GaussianMixtureModel(**good_inputs) From 067e64269c5b43319a76c53cd2aedb185753b266 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 16:39:41 -0800 Subject: [PATCH 23/25] Bump minimum version to 3.9 --- setup.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/setup.py b/setup.py index 6c61cea6f..e57f125a3 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( author="DESC/SLSC", author_email="sibirrer@gmail.com", - python_requires=">=3.6", + python_requires=">=3.9", classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", From b5d21dd14699137c8796eabaf16efe29a9fbd2f7 Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 17:34:16 -0800 Subject: [PATCH 24/25] partial fix for typing --- slsim/Util/params.py | 15 +++++++++------ 1 file changed, 9 insertions(+), 6 deletions(-) diff --git a/slsim/Util/params.py b/slsim/Util/params.py index 0e120d13a..dd19bbe03 100644 --- a/slsim/Util/params.py +++ b/slsim/Util/params.py @@ -4,7 +4,7 @@ """ from functools import wraps from importlib import import_module -from typing import Callable, Any +from typing import Callable, Any, TypeVar from enum import Enum import inspect import pydantic @@ -69,7 +69,10 @@ def determine_fn_type(fn: Callable) -> _FnType: return _FnType.CLASSMETHOD -def check_params(fn: Callable) -> Callable: +T_ = TypeVar("T_") + + +def check_params(fn: Callable[..., T_]) -> Callable[..., T_]: """A decorator for enforcing checking of params in __init__ methods. This decorator will automatically load the default parameters for the class and check that the passed parameters are valid. It expeects a "params.py" file in the same folder as @@ -92,14 +95,14 @@ def check_params(fn: Callable) -> Callable: return new_fn -def standard_fn_wrapper(fn: Callable) -> Callable: +def standard_fn_wrapper(fn: Callable[..., T_]) -> Callable[..., T_]: """A wrapper for standard functions. This is used to parse the arguments to the function and check that they are valid. """ @wraps(fn) - def new_fn(*args, **kwargs) -> Any: + def new_fn(*args, **kwargs) -> T_: # Get function argument names pargs = {} if args: @@ -116,9 +119,9 @@ def new_fn(*args, **kwargs) -> Any: return new_fn -def method_fn_wrapper(fn: Callable) -> Callable: +def method_fn_wrapper(fn: Callable[..., T_]) -> Callable[..., T_]: @wraps(fn) - def new_fn(obj: Any, *args, **kwargs) -> Any: + def new_fn(obj: Any, *args, **kwargs) -> T_: # Get function argument names parsed_args = {} if args: From cfec7a6f363caf8019941a07e0b252f5eedd8a2d Mon Sep 17 00:00:00 2001 From: Patrick Wells Date: Mon, 6 Nov 2023 17:49:04 -0800 Subject: [PATCH 25/25] Fix to run in 3.10 --- .github/workflows/tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index c3030d2fe..5abb8e2f2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.9, 3.10, 3.11] + python-version: ['3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v3