diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 4fd37eff5..5abb8e2f2 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -13,7 +13,7 @@ jobs: runs-on: ubuntu-latest strategy: matrix: - python-version: [3.8] + python-version: ['3.9', '3.10', '3.11'] steps: - uses: actions/checkout@v3 diff --git a/requirements.txt b/requirements.txt index 283273cf9..3b2f41356 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,3 +6,4 @@ scipy speclite pyyaml matplotlib +pydantic>=2 \ No newline at end of file diff --git a/setup.py b/setup.py index 6c61cea6f..e57f125a3 100644 --- a/setup.py +++ b/setup.py @@ -19,7 +19,7 @@ setup( author="DESC/SLSC", author_email="sibirrer@gmail.com", - python_requires=">=3.6", + python_requires=">=3.9", classifiers=[ "Development Status :: 2 - Pre-Alpha", "Intended Audience :: Developers", diff --git a/slsim/Deflectors/_params/elliptical_lens_galaxies.py b/slsim/Deflectors/_params/elliptical_lens_galaxies.py new file mode 100644 index 000000000..0ec1deff6 --- /dev/null +++ b/slsim/Deflectors/_params/elliptical_lens_galaxies.py @@ -0,0 +1,5 @@ +from pydantic import BaseModel, PositiveFloat + + +class vel_disp_from_m_star(BaseModel): + m_star: PositiveFloat diff --git a/slsim/Deflectors/_params/velocity_dispersion.py b/slsim/Deflectors/_params/velocity_dispersion.py new file mode 100644 index 000000000..475aae9b1 --- /dev/null +++ b/slsim/Deflectors/_params/velocity_dispersion.py @@ -0,0 +1,11 @@ +from pydantic import BaseModel, PositiveFloat +from astropy.cosmology import Cosmology + + +class vel_disp_composite_model(BaseModel, arbitrary_types_allowed=True): + r: PositiveFloat + m_star: PositiveFloat + rs_star: PositiveFloat + m_halo: PositiveFloat + c_halo: float + cosmo: Cosmology diff --git a/slsim/Deflectors/elliptical_lens_galaxies.py b/slsim/Deflectors/elliptical_lens_galaxies.py index c37788827..4ca7458ff 100644 --- a/slsim/Deflectors/elliptical_lens_galaxies.py +++ b/slsim/Deflectors/elliptical_lens_galaxies.py @@ -4,6 +4,7 @@ from slsim.Deflectors.velocity_dispersion import vel_disp_sdss from slsim.Util import param_util from slsim.Deflectors.deflector_base import DeflectorBase +from slsim.Util import check_params class EllipticalLensGalaxies(DeflectorBase): @@ -120,6 +121,7 @@ def elliptical_projected_eccentricity(ellipticity, **kwargs): return e1_light, e2_light, e1_mass, e2_mass +@check_params def vel_disp_from_m_star(m_star): """Function to calculate the velocity dispersion from the staller mass using empirical relation for elliptical galaxies. diff --git a/slsim/Deflectors/velocity_dispersion.py b/slsim/Deflectors/velocity_dispersion.py index abe9d3078..08ba313f7 100644 --- a/slsim/Deflectors/velocity_dispersion.py +++ b/slsim/Deflectors/velocity_dispersion.py @@ -2,6 +2,7 @@ import scipy from skypy.galaxies.redshift import redshifts_from_comoving_density from skypy.utils.random import schechter +from slsim.Util import check_params """ This module provides functions to compute velocity dispersion using schechter function. @@ -11,6 +12,7 @@ # from skypy.galaxies.velocity_dispersion import schechter_vdf +@check_params def vel_disp_composite_model(r, m_star, rs_star, m_halo, c_halo, cosmo): """Computes the luminosity weighted velocity dispersion for a deflector with a stellar Hernquist profile and a NFW halo profile, assuming isotropic anisotropy. diff --git a/slsim/ParamDistributions/_params/__init__.py b/slsim/ParamDistributions/_params/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/slsim/ParamDistributions/_params/gaussian_mixture_model.py b/slsim/ParamDistributions/_params/gaussian_mixture_model.py new file mode 100644 index 000000000..b00e880aa --- /dev/null +++ b/slsim/ParamDistributions/_params/gaussian_mixture_model.py @@ -0,0 +1,35 @@ +from pydantic import ( + BaseModel, + PositiveFloat, + PositiveInt, + model_validator, + field_validator, +) +import numpy as np + + +class GaussianMixtureModel(BaseModel): + means: list[float] = [0.00330796, -0.07635054, 0.11829008] + stds: list[PositiveFloat] = [ + np.sqrt(0.00283885), + np.sqrt(0.01066668), + np.sqrt(0.0097978), + ] + weights: list[PositiveFloat] = [0.62703102, 0.23732313, 0.13564585] + + @field_validator("weights") + @classmethod + def check_weights(cls, weight_values): + if sum(weight_values) != 1: + raise ValueError("The sum of the weights must be 1") + return weight_values + + @model_validator(mode="after") + def check_lengths(self): + if len(self.means) != len(self.stds) or len(self.means) != len(self.weights): + raise ValueError("The lengths of means, stds and weights must be equal") + return self + + +class GaussianMixtureModel_rvs(BaseModel): + size: PositiveInt diff --git a/slsim/ParamDistributions/gaussian_mixture_model.py b/slsim/ParamDistributions/gaussian_mixture_model.py index 0d23e6a5b..ecbf86ed4 100644 --- a/slsim/ParamDistributions/gaussian_mixture_model.py +++ b/slsim/ParamDistributions/gaussian_mixture_model.py @@ -1,4 +1,5 @@ import numpy as np +from slsim.Util import check_params class GaussianMixtureModel: @@ -8,11 +9,13 @@ class GaussianMixtureModel: is defined by its mean, standard deviation and weight. """ - def __init__(self, means=None, stds=None, weights=None): + @check_params + def __init__(self, means: list[float], stds: list[float], weights: list[float]): """ The constructor for GaussianMixtureModel class. The default values are the means, standard deviations, and weights of the fits to the data in the table - 2 of https://doi.org/10.1093/mnras/stac2235 and others. + 2 of https://doi.org/10.1093/mnras/stac2235 and others. See "_params" for + defaults and validation logic. :param means: the mean values of the Gaussian components. :type means: list of float @@ -21,20 +24,12 @@ def __init__(self, means=None, stds=None, weights=None): :param weights: The weights of the Gaussian components in the mixture. :type weights: list of float """ - if means is None: - means = [0.00330796, -0.07635054, 0.11829008] - if stds is None: - stds = [np.sqrt(0.00283885), np.sqrt(0.01066668), np.sqrt(0.0097978)] - if weights is None: - weights = [0.62703102, 0.23732313, 0.13564585] - assert ( - len(means) == len(stds) == len(weights) - ), "Lengths of means, standard deviations, and weights must be equal." self.means = means self.stds = stds self.weights = weights - def rvs(self, size): + @check_params + def rvs(self, size: int): """Generate random variables from the GMM distribution. :param size: The number of random variables to generate. diff --git a/slsim/Util/__init__.py b/slsim/Util/__init__.py index e69de29bb..daaebfade 100644 --- a/slsim/Util/__init__.py +++ b/slsim/Util/__init__.py @@ -0,0 +1,3 @@ +from .params import check_params + +__all__ = ["check_params"] diff --git a/slsim/Util/params.py b/slsim/Util/params.py new file mode 100644 index 000000000..dd19bbe03 --- /dev/null +++ b/slsim/Util/params.py @@ -0,0 +1,208 @@ +"""Utilities for managing parameter defaults and validation in the slsim package. + +Desgined to be unobtrusive to use. +""" +from functools import wraps +from importlib import import_module +from typing import Callable, Any, TypeVar +from enum import Enum +import inspect +import pydantic + +""" +Set of routines for validating inputs to functions and classes. The elements of this +module should never be imported directly. Instead, @check_params can be imported +directly from the Util module. +""" + + +class SlSimParameterException(Exception): + pass + + +_defaults = {} + + +class _FnType(Enum): + """Enum for the different types of functions we can have. This is used to determine + how to parse the arguments to the function. + + There are three possible cases: + 1. The function is a standard function, defined outside a class + 2. The function is a standard object method, + taking "self" as the first parameter + 3. The funtion is a class method (or staticmethod), not taking + "self" as the first parameter + """ + + STANDARD = 0 + METHOD = 1 + CLASSMETHOD = 2 + + +def determine_fn_type(fn: Callable) -> _FnType: + """Determine which of the three possible cases a function falls into. Cases 0 and 2 + are actually functionally identical. Things only get spicy when we have a "self" + argument. + + However the tricky thing is that decorators operate on functions and methods when + they are imported, not when they are used. This means "inspect.ismethod" will always + return False, even if the function is a method. + + We can get around this by checking if the parent of the function is a class. Then, + we check if the first argument of the function is "self". If both of these are true, + then the function is a method. + """ + if not inspect.isfunction(fn): + raise TypeError("decorator @check_params can only be used on functions!") + qualified_obj_name = fn.__qualname__ + qualified_obj_path = qualified_obj_name.split(".") + if len(qualified_obj_path) == 1: + # If the qualified name isn't split, this is a standard function not + # attached to a class + return _FnType.STANDARD + + spec = inspect.getfullargspec(fn) + if spec.args[0] == "self": + return _FnType.METHOD + else: + return _FnType.CLASSMETHOD + + +T_ = TypeVar("T_") + + +def check_params(fn: Callable[..., T_]) -> Callable[..., T_]: + """A decorator for enforcing checking of params in __init__ methods. This decorator + will automatically load the default parameters for the class and check that the + passed parameters are valid. It expeects a "params.py" file in the same folder as + the class definition. Uses pydantic models to enforce types, sanity checks, and + defaults. + + From and end user perspective, there is no difference between this and a normal + __init__ fn. Developers only need to add @check_params above their __init__ method + definition to enable this feature, then add their default parameters to the + "params.py" file. + """ + fn_type = determine_fn_type(fn) + if fn_type == _FnType.STANDARD: + new_fn = standard_fn_wrapper(fn) + elif fn_type == _FnType.METHOD: + new_fn = method_fn_wrapper(fn) + elif fn_type == _FnType.CLASSMETHOD: + new_fn = standard_fn_wrapper(fn) + + return new_fn + + +def standard_fn_wrapper(fn: Callable[..., T_]) -> Callable[..., T_]: + """A wrapper for standard functions. + + This is used to parse the arguments to the function and check that they are valid. + """ + + @wraps(fn) + def new_fn(*args, **kwargs) -> T_: + # Get function argument names + pargs = {} + if args: + largs = list(inspect.signature(fn).parameters.keys()) + for i in range(len(args)): + arg_value = args[i] + if arg_value is not None: + pargs[largs[i]] = args[i] + # Doing it this way ensures we still catch duplicate arguments + defaults = get_defaults(fn) + parsed_args = defaults(**pargs, **kwargs) + return fn(**dict(parsed_args)) + + return new_fn + + +def method_fn_wrapper(fn: Callable[..., T_]) -> Callable[..., T_]: + @wraps(fn) + def new_fn(obj: Any, *args, **kwargs) -> T_: + # Get function argument names + parsed_args = {} + if args: + largs = list(inspect.signature(fn).parameters.keys()) + + for i in range(len(args)): + arg_value = args[i] + if arg_value is not None: + parsed_args[largs[i + 1]] = arg_value + # Doing it this way ensures we still catch duplicate arguments + parsed_kwargs = {k: v for k, v in kwargs.items() if v is not None} + defaults = get_defaults(fn) + parsed_args = defaults(**parsed_args, **parsed_kwargs) + return fn(obj, **dict(parsed_args)) + + return new_fn + + +def get_defaults(fn: Callable) -> pydantic.BaseModel: + module_trace = inspect.getmodule(fn).__name__.split(".") + file_name = module_trace[-1] + parent_trace = module_trace[:-1] + parent_path = ".".join(parent_trace) + param_path = ".".join([parent_path, "_params"]) + fn_qualname = fn.__qualname__ + cache_name = parent_path + "." + fn_qualname + if cache_name in _defaults: + return _defaults[cache_name] + + try: + _ = import_module(param_path) + except ModuleNotFoundError: + raise SlSimParameterException( + f'No default parameters found in module {".".join(parent_trace)},' + " but something in that module is trying to use the @check_params decorator" + ) + try: + param_model_file = import_module(f"{param_path}.{file_name}") + except AttributeError: + raise SlSimParameterException( + f'No default parameters found for file "{file_name}" in module ' + f'{".".join(parent_trace)}, but something in that module is trying to use ' + "the @check_params decorator" + ) + + if fn.__name__ == "__init__": + expected_model_name = "_".join(fn_qualname.split(".")[:-1]) + else: + expected_model_name = "_".join(fn_qualname.split(".")) + + try: + parameter_model = getattr(param_model_file, expected_model_name) + except AttributeError: + raise SlSimParameterException( + "No default parameters found for function " f'"{fn_qualname}"' + ) + if not issubclass(parameter_model, pydantic.BaseModel): + raise SlSimParameterException( + f'Defaults for "{fn_qualname}" are not in a pydantic model!' + ) + _defaults[cache_name] = parameter_model + return _defaults[cache_name] + + +def load_parameters(modpath: str, obj_name: str) -> pydantic.BaseModel: + """Loads parameters from the "params.py" file which should be in the same folder as + the class definition.""" + try: + defaults = import_module(modpath) + except ModuleNotFoundError: + raise SlSimParameterException( + "No default parameters found in module " f'"{modpath[-2]}"' + ) + try: + obj_defaults = getattr(defaults, obj_name) + except AttributeError: + raise SlSimParameterException( + f"No default parameters found for class " f'"{obj_name}"' + ) + if not issubclass(obj_defaults, pydantic.BaseModel): + raise SlSimParameterException( + f'Defaults for "{obj_name}" are not in a ' "pydantic model!" + ) + return obj_defaults diff --git a/tests/test_Params/__init__.py b/tests/test_Params/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_Params/test_params_check.py b/tests/test_Params/test_params_check.py new file mode 100644 index 000000000..b0d2e3927 --- /dev/null +++ b/tests/test_Params/test_params_check.py @@ -0,0 +1,140 @@ +from slsim.ParamDistributions.gaussian_mixture_model import GaussianMixtureModel +from slsim.Deflectors.velocity_dispersion import vel_disp_composite_model +from astropy.cosmology import FlatLambdaCDM +from pydantic import ValidationError +import pytest +import random + +""" +Test for the parameter checking rountines in slsim. We want to make sure we're using +functions/classes that are actually in slsim, rather than ones in the /tests folder. +This is because the parameter checking routine discovers defaults by importing from +a standard location inside the slsim package. If we import from the /tests folder, +this will fail. +""" + + +def test_all_kwargs_init(good_inputs: dict): + gmm = GaussianMixtureModel(**good_inputs) + assert ( + gmm.means == good_inputs["means"] + and gmm.stds == good_inputs["stds"] + and gmm.weights == good_inputs["weights"] + ) + + +def test_all_args_init(good_inputs: dict): + inputs = list(good_inputs.values()) + gmm = GaussianMixtureModel(*inputs) + assert gmm.means == good_inputs["means"] + + +def test_all_args_method(good_model: GaussianMixtureModel): + output = good_model.rvs(100) + assert len(output) == 100 + + +def test_all_kwargs_method(good_model: GaussianMixtureModel): + output = good_model.rvs(size=100) + assert len(output) == 100 + + +def test_all_args_function(vel_disp_inputs: dict): + inputs = list(vel_disp_inputs.values()) + output = vel_disp_composite_model(*inputs) + assert isinstance(output, float) + + +def test_all_kwargs_function(vel_disp_inputs: dict): + output = vel_disp_composite_model(**vel_disp_inputs) + assert isinstance(output, float) + + +def test_mixture_init(good_inputs: dict): + good_inputs_keys = list(good_inputs.keys()) + input_args = [good_inputs[k] for k in good_inputs_keys[:2]] + input_kwargs = {k: good_inputs[k] for k in good_inputs_keys[2:]} + + gmm = GaussianMixtureModel(*input_args, **input_kwargs) + assert ( + gmm.means == good_inputs["means"] + and gmm.stds == good_inputs["stds"] + and gmm.weights == good_inputs["weights"] + ) + + +def test_mixture_function(vel_disp_inputs: dict): + vel_disp_inputs_keys = list(vel_disp_inputs.keys()) + input_args = [vel_disp_inputs[k] for k in vel_disp_inputs_keys[:3]] + input_kwargs = {k: vel_disp_inputs[k] for k in vel_disp_inputs_keys[3:]} + + output = vel_disp_composite_model(*input_args, **input_kwargs) + assert isinstance(output, float) + + +def test_shuffle_init(good_inputs: dict): + keys = list(good_inputs.keys()) + random.shuffle(keys) + input = {k: good_inputs[k] for k in keys} + gmm = GaussianMixtureModel(**input) + assert ( + gmm.means == good_inputs["means"] + and gmm.stds == good_inputs["stds"] + and gmm.weights == good_inputs["weights"] + ) + + +def test_shuffle_function(vel_disp_inputs: dict): + keys = list(vel_disp_inputs.keys()) + random.shuffle(keys) + input = {k: vel_disp_inputs[k] for k in keys} + output = vel_disp_composite_model(**input) + assert isinstance(output, float) + + +def test_failure_init(good_inputs: dict): + good_inputs_keys = list(good_inputs.keys()) + input_args = [good_inputs[k] for k in good_inputs_keys[:2]] + input_kwargs = {k: good_inputs[k] for k in good_inputs_keys[2:]} + input_kwargs["weights"] = [0.2, 0.3, 0.4] + with pytest.raises(ValueError): + _ = GaussianMixtureModel(*input_args, **input_kwargs) + + +def test_failure_method(good_model: GaussianMixtureModel): + with pytest.raises(ValidationError): + _ = good_model.rvs(-50) + + +def test_failure_function(vel_disp_inputs: dict): + input_kwargs = vel_disp_inputs + input_kwargs["m_star"] = -50 + with pytest.raises(ValidationError): + _ = vel_disp_composite_model(**input_kwargs) + + +@pytest.fixture +def cosmology(): + return FlatLambdaCDM(H0=70, Om0=0.3) + + +@pytest.fixture +def vel_disp_inputs(cosmology): + return { + "r": 5, + "m_star": 10**10, + "rs_star": 30, + "m_halo": 10**14, + "c_halo": 3, + "cosmo": cosmology, + } + + +@pytest.fixture +def good_inputs(): + return {"means": [1, 2, 3], "stds": [1, 2, 3], "weights": [0.2, 0.3, 0.5]} + + +@pytest.fixture +def good_model(good_inputs: dict): + return GaussianMixtureModel(**good_inputs)