diff --git a/src/adler/adler.py b/src/adler/adler.py index bc24b29..7d29f86 100644 --- a/src/adler/adler.py +++ b/src/adler/adler.py @@ -3,17 +3,23 @@ from adler.dataclasses.AdlerPlanetoid import AdlerPlanetoid from adler.science.PhaseCurve import PhaseCurve +from adler.utilities.AdlerCLIArguments import AdlerCLIArguments -def runAdler(args): - planetoid = AdlerPlanetoid.construct_from_RSP(args.ssoid, args.filter_list, args.date_range) +def runAdler(cli_args): + planetoid = AdlerPlanetoid.construct_from_RSP( + cli_args.ssObjectId, cli_args.filter_list, cli_args.date_range + ) # now let's do some phase curves! + # get the r filter SSObject metadata + sso_r = planetoid.SSObject_in_filter("r") + # get the RSP r filter model pc = PhaseCurve( - abs_mag=planetoid.SSObject.H[2] * u.mag, - phase_param=planetoid.SSObject.G12[2], + abs_mag=sso_r.H * u.mag, + phase_param=sso_r.G12, model_name="HG12_Pen16", ) print(pc) @@ -33,9 +39,14 @@ def runAdler(args): def main(): parser = argparse.ArgumentParser(description="Runs Adler for a select planetoid and given user input.") - parser.add_argument("-s", "--ssoid", help="SSObject ID of planetoid.", type=str, required=True) + parser.add_argument("-s", "--ssObjectId", help="SSObject ID of planetoid.", type=str, required=True) parser.add_argument( - "-f", "--filters", help="Comma-separated list of filters required.", type=str, default="u,g,r,i,z,y" + "-f", + "--filter_list", + help="Filters required.", + nargs="*", + type=str, + default=["u", "g", "r", "i", "z", "y"], ) parser.add_argument( "-d", @@ -48,9 +59,9 @@ def main(): args = parser.parse_args() - args.filter_list = args.filters.split(",") + cli_args = AdlerCLIArguments(args) - runAdler(args) + runAdler(cli_args) if __name__ == "__main__": diff --git a/src/adler/dataclasses/MPCORB.py b/src/adler/dataclasses/MPCORB.py index 483acd7..a17bd18 100644 --- a/src/adler/dataclasses/MPCORB.py +++ b/src/adler/dataclasses/MPCORB.py @@ -2,6 +2,22 @@ from 
adler.dataclasses.dataclass_utilities import get_from_table +MPCORB_KEYS = { + "mpcDesignation": str, + "mpcNumber": int, + "mpcH": float, + "mpcG": float, + "epoch": float, + "peri": float, + "node": float, + "incl": float, + "e": float, + "n": float, + "q": float, + "uncertaintyParameter": str, + "flags": str, +} + @dataclass class MPCORB: @@ -87,33 +103,9 @@ def construct_from_data_table(cls, ssObjectId, data_table): """ - mpcDesignation = get_from_table(data_table, "mpcDesignation", "str") - mpcNumber = get_from_table(data_table, "mpcNumber", "int") - mpcH = get_from_table(data_table, "mpcH", "float") - mpcG = get_from_table(data_table, "mpcG", "float") - epoch = get_from_table(data_table, "epoch", "float") - peri = get_from_table(data_table, "peri", "float") - node = get_from_table(data_table, "node", "float") - incl = get_from_table(data_table, "incl", "float") - e = get_from_table(data_table, "e", "float") - n = get_from_table(data_table, "n", "float") - q = get_from_table(data_table, "q", "float") - uncertaintyParameter = get_from_table(data_table, "uncertaintyParameter", "str") - flags = get_from_table(data_table, "flags", "str") - - return cls( - ssObjectId, - mpcDesignation, - mpcNumber, - mpcH, - mpcG, - epoch, - peri, - node, - incl, - e, - n, - q, - uncertaintyParameter, - flags, - ) + mpcorb_dict = {"ssObjectId": ssObjectId} + + for mpcorb_key, mpcorb_type in MPCORB_KEYS.items(): + mpcorb_dict[mpcorb_key] = get_from_table(data_table, mpcorb_key, mpcorb_type, "MPCORB") + + return cls(**mpcorb_dict) diff --git a/src/adler/dataclasses/Observations.py b/src/adler/dataclasses/Observations.py index 77768f6..64b3900 100644 --- a/src/adler/dataclasses/Observations.py +++ b/src/adler/dataclasses/Observations.py @@ -3,6 +3,17 @@ from adler.dataclasses.dataclass_utilities import get_from_table +OBSERVATIONS_KEYS = { + "mag": np.ndarray, + "magErr": np.ndarray, + "midPointMjdTai": np.ndarray, + "ra": np.ndarray, + "dec": np.ndarray, + "phaseAngle": np.ndarray, + 
"topocentricDist": np.ndarray, + "heliocentricDist": np.ndarray, +} + @dataclass class Observations: @@ -24,7 +35,7 @@ class Observations: magErr: array_like of floats Magnitude error. This is a placeholder and will be replaced by flux error. - midpointMjdTai: array_like of floats + midPointMjdTai: array_like of floats Effective mid-visit time for this diaSource, expressed as Modified Julian Date, International Atomic Time. ra: array_like of floats @@ -54,7 +65,7 @@ class Observations: filter_name: str = "" mag: np.ndarray = field(default_factory=lambda: np.zeros(0)) magErr: np.ndarray = field(default_factory=lambda: np.zeros(0)) - midpointMjdTai: np.ndarray = field(default_factory=lambda: np.zeros(0)) + midPointMjdTai: np.ndarray = field(default_factory=lambda: np.zeros(0)) ra: np.ndarray = field(default_factory=lambda: np.zeros(0)) dec: np.ndarray = field(default_factory=lambda: np.zeros(0)) phaseAngle: np.ndarray = field(default_factory=lambda: np.zeros(0)) @@ -85,32 +96,17 @@ def construct_from_data_table(cls, ssObjectId, filter_name, data_table): """ - mag = get_from_table(data_table, "mag", "array") - magErr = get_from_table(data_table, "magErr", "array") - midpointMjdTai = get_from_table(data_table, "midPointMjdTai", "array") - ra = get_from_table(data_table, "ra", "array") - dec = get_from_table(data_table, "dec", "array") - phaseAngle = get_from_table(data_table, "phaseAngle", "array") - topocentricDist = get_from_table(data_table, "topocentricDist", "array") - heliocentricDist = get_from_table(data_table, "heliocentricDist", "array") - - reduced_mag = cls.calculate_reduced_mag(cls, mag, topocentricDist, heliocentricDist) - - return cls( - ssObjectId, - filter_name, - mag, - magErr, - midpointMjdTai, - ra, - dec, - phaseAngle, - topocentricDist, - heliocentricDist, - reduced_mag, - len(data_table), + obs_dict = {"ssObjectId": ssObjectId, "filter_name": filter_name, "num_obs": len(data_table)} + + for obs_key, obs_type in OBSERVATIONS_KEYS.items(): + 
obs_dict[obs_key] = get_from_table(data_table, obs_key, obs_type, "SSSource/DIASource") + + obs_dict["reduced_mag"] = cls.calculate_reduced_mag( + cls, obs_dict["mag"], obs_dict["topocentricDist"], obs_dict["heliocentricDist"] ) + return cls(**obs_dict) + def calculate_reduced_mag(self, mag, topocentric_dist, heliocentric_dist): """ Calculates the reduced magnitude column. diff --git a/src/adler/dataclasses/SSObject.py b/src/adler/dataclasses/SSObject.py index 2207c51..9ec0443 100644 --- a/src/adler/dataclasses/SSObject.py +++ b/src/adler/dataclasses/SSObject.py @@ -3,6 +3,16 @@ from adler.dataclasses.dataclass_utilities import get_from_table +SSO_KEYS = { + "discoverySubmissionDate": float, + "firstObservationDate": float, + "arc": float, + "numObs": int, + "maxExtendedness": float, + "minExtendedness": float, + "medianExtendedness": float, +} + @dataclass class SSObject: @@ -57,47 +67,24 @@ class SSObject: @classmethod def construct_from_data_table(cls, ssObjectId, filter_list, data_table): - discoverySubmissionDate = get_from_table(data_table, "discoverySubmissionDate", "float") - firstObservationDate = get_from_table(data_table, "firstObservationDate", "float") - arc = get_from_table(data_table, "arc", "float") - numObs = get_from_table(data_table, "numObs", "int") - - H = np.zeros(len(filter_list)) - G12 = np.zeros(len(filter_list)) - Herr = np.zeros(len(filter_list)) - G12err = np.zeros(len(filter_list)) - nData = np.zeros(len(filter_list)) + sso_dict = {"ssObjectId": ssObjectId, "filter_list": filter_list, "filter_dependent_values": []} - filter_dependent_values = [] + for sso_key, sso_type in SSO_KEYS.items(): + sso_dict[sso_key] = get_from_table(data_table, sso_key, sso_type, "SSObject") for i, filter_name in enumerate(filter_list): filter_dept_object = FilterDependentSSO( filter_name=filter_name, - H=get_from_table(data_table, filter_name + "_H", "float"), - G12=get_from_table(data_table, filter_name + "_G12", "float"), - Herr=get_from_table(data_table, 
filter_name + "_HErr", "float"), - G12err=get_from_table(data_table, filter_name + "_G12Err", "float"), - nData=get_from_table(data_table, filter_name + "_Ndata", "int"), + H=get_from_table(data_table, filter_name + "_H", float, "SSObject"), + G12=get_from_table(data_table, filter_name + "_G12", float, "SSObject"), + Herr=get_from_table(data_table, filter_name + "_HErr", float, "SSObject"), + G12err=get_from_table(data_table, filter_name + "_G12Err", float, "SSObject"), + nData=get_from_table(data_table, filter_name + "_Ndata", float, "SSObject"), ) - filter_dependent_values.append(filter_dept_object) - - maxExtendedness = get_from_table(data_table, "maxExtendedness", "float") - minExtendedness = get_from_table(data_table, "minExtendedness", "float") - medianExtendedness = get_from_table(data_table, "medianExtendedness", "float") - - return cls( - ssObjectId, - filter_list, - discoverySubmissionDate, - firstObservationDate, - arc, - numObs, - filter_dependent_values, - maxExtendedness, - minExtendedness, - medianExtendedness, - ) + sso_dict["filter_dependent_values"].append(filter_dept_object) + + return cls(**sso_dict) @dataclass diff --git a/src/adler/dataclasses/dataclass_utilities.py b/src/adler/dataclasses/dataclass_utilities.py index ebba3d8..172a264 100644 --- a/src/adler/dataclasses/dataclass_utilities.py +++ b/src/adler/dataclasses/dataclass_utilities.py @@ -48,34 +48,92 @@ def get_data_table(sql_query, service=None, sql_filename=None): return data_table -def get_from_table(data_table, column_name, data_type): - """Retrieves information from the data_table class variable and forces it to be a specified type. +def get_from_table(data_table, column_name, data_type, table_name="default"): + """Retrieves information from the data_table and forces it to be a specified type. Parameters ----------- + data_table : DALResultsTable or Pandas dataframe + Data table containing columns of interest. 
+ column_name : str Column name under which the data of interest is stored. - type : str - String delineating data type. Should be "str", "float", "int" or "array". + + data_type : type + Data type. Should be int, float, str or np.ndarray. + + table_name : str + Name of the table. This is mostly for more informative error messages. Default="default". Returns ----------- - data : any type + data_val : str, float, int or np.ndarray The data requested from the table cast to the type required. """ - try: - if data_type == "str": - return str(data_table[column_name][0]) - elif data_type == "float": - return float(data_table[column_name][0]) - elif data_type == "int": - return int(data_table[column_name][0]) - elif data_type == "array": - return np.array(data_table[column_name]) - else: - raise TypeError( - "Type for argument data_type not recognised: must be one of 'str', 'float', 'int', 'array'." + with warnings.catch_warnings(): + warnings.filterwarnings( + "ignore", category=UserWarning + ) # RSP tables mask unpopulated elements, which get converted to NaN here and trigger a warning we don't care about. 
+ try: + if data_type == str: + data_val = str(data_table[column_name][0]) + elif data_type == float: + data_val = float(data_table[column_name][0]) + elif data_type == int: + data_val = int(data_table[column_name][0]) + elif data_type == np.ndarray: + data_val = np.array(data_table[column_name]) + else: + raise TypeError( + "Type for argument data_type not recognised for column {} in table {}: must be str, float, int or np.ndarray.".format( + column_name, table_name + ) + ) + except ValueError: + raise ValueError("Could not cast column name to type.") + + # here we alert the user if one of the values is unpopulated and change it to a NaN + data_val = check_value_populated(data_val, data_type, column_name, table_name) + + return data_val + + +def check_value_populated(data_val, data_type, column_name, table_name): + """Checks to see if data_val was populated properly and prints a helpful warning if it wasn't. + Usually this will trigger because the RSP hasn't populated that field for this particular object. + + Parameters + ----------- + data_val : str, float, int or np.ndarray + The value to check. + + data_type: type + Data type. Should be int, float, str or np.ndarray. + + column_name: str + Column name under which the data of interest is stored. + + table_name : str + Name of the table. This is mostly for more informative error messages. Default="default". + + Returns + ----------- + data_val : str, float, int, np.ndarray or np.nan + Either returns the original data_val or an np.nan if it detected that the value was not populated. + + """ + + array_length_zero = data_type == np.ndarray and len(data_val) == 0 + number_is_nan = data_type in [float, int] and np.isnan(data_val) + str_is_empty = data_type == str and len(data_val) == 0 + + if array_length_zero or number_is_nan or str_is_empty: + print( + "WARNING: {} unpopulated in {} table for this object. 
Storing NaN instead.".format( + column_name, table_name ) - except ValueError: - raise ValueError("Could not cast column name to type.") + ) + data_val = np.nan + + return data_val diff --git a/src/adler/utilities/AdlerCLIArguments.py b/src/adler/utilities/AdlerCLIArguments.py new file mode 100644 index 0000000..03cf41c --- /dev/null +++ b/src/adler/utilities/AdlerCLIArguments.py @@ -0,0 +1,50 @@ +class AdlerCLIArguments: + """ + Class for storing and validating Adler command-line arguments. + + Attributes: + ----------- + args : argparse.Namespace object + argparse.Namespace object created by calling parse_args(). + + """ + + def __init__(self, args): + self.ssObjectId = args.ssObjectId + self.filter_list = args.filter_list + self.date_range = args.date_range + + self.validate_arguments() + + def validate_arguments(self): + self._validate_filter_list() + self._validate_ssObjectId() + self._validate_date_range() + + def _validate_filter_list(self): + expected_filters = ["u", "g", "r", "i", "z", "y"] + + if not set(self.filter_list).issubset(expected_filters): + raise ValueError( + "Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters." + ) + + def _validate_ssObjectId(self): + try: + int(self.ssObjectId) + except ValueError: + raise ValueError("ssObjectId command-line argument does not appear to be a valid ssObjectId.") + + def _validate_date_range(self): + for d in self.date_range: + try: + float(d) + except ValueError: + raise ValueError( + "One or both of the values for the date_range command-line argument do not seem to be valid numbers." + ) + + if any(d > 250000 for d in self.date_range): + raise ValueError( + "Dates for date_range command-line argument seem rather large. Did you input JD instead of MJD?" 
+ ) diff --git a/tests/adler/dataclasses/test_AdlerPlanetoid.py b/tests/adler/dataclasses/test_AdlerPlanetoid.py index 288c5cc..a011540 100644 --- a/tests/adler/dataclasses/test_AdlerPlanetoid.py +++ b/tests/adler/dataclasses/test_AdlerPlanetoid.py @@ -86,7 +86,7 @@ def test_construct_with_date_range(): ] ) - assert_almost_equal(test_planetoid.observations_by_filter[0].midpointMjdTai, expected_dates) + assert_almost_equal(test_planetoid.observations_by_filter[0].midPointMjdTai, expected_dates) with pytest.raises(ValueError) as error_info_1: test_planetoid = AdlerPlanetoid.construct_from_SQL( diff --git a/tests/adler/dataclasses/test_MPCORB.py b/tests/adler/dataclasses/test_MPCORB.py index f94db75..d139d00 100644 --- a/tests/adler/dataclasses/test_MPCORB.py +++ b/tests/adler/dataclasses/test_MPCORB.py @@ -36,5 +36,5 @@ def test_construct_MPCORB_from_data_table(): assert_almost_equal(test_MPCORB.e, 0.7168805704972735, decimal=6) assert np.isnan(test_MPCORB.n) assert_almost_equal(test_MPCORB.q, 0.5898291078470536, decimal=6) - assert test_MPCORB.uncertaintyParameter == "" + assert np.isnan(test_MPCORB.uncertaintyParameter) assert test_MPCORB.flags == "0" diff --git a/tests/adler/dataclasses/test_Observations.py b/tests/adler/dataclasses/test_Observations.py index 24c2b25..44a3717 100644 --- a/tests/adler/dataclasses/test_Observations.py +++ b/tests/adler/dataclasses/test_Observations.py @@ -161,7 +161,7 @@ def test_construct_observations_from_data_table(): assert_almost_equal(test_observations.mag, expected_mag) assert_almost_equal(test_observations.magErr, expected_magerr) - assert_almost_equal(test_observations.midpointMjdTai, expected_mjd) + assert_almost_equal(test_observations.midPointMjdTai, expected_mjd) assert_almost_equal(test_observations.ra, expected_ra) assert_almost_equal(test_observations.dec, expected_dec) assert_almost_equal(test_observations.phaseAngle, expected_phaseangle) diff --git a/tests/adler/dataclasses/test_dataclass_utilities.py 
b/tests/adler/dataclasses/test_dataclass_utilities.py index dfae4b7..b087e0e 100644 --- a/tests/adler/dataclasses/test_dataclass_utilities.py +++ b/tests/adler/dataclasses/test_dataclass_utilities.py @@ -1,10 +1,12 @@ import pytest import pandas as pd +import numpy as np from pandas.testing import assert_frame_equal from numpy.testing import assert_equal from adler.dataclasses.dataclass_utilities import get_data_table from adler.dataclasses.dataclass_utilities import get_from_table +from adler.dataclasses.dataclass_utilities import check_value_populated from adler.utilities.tests_utilities import get_test_data_filepath @@ -37,13 +39,13 @@ def test_get_from_table(): {"string_col": "a test string", "int_col": 4, "float_col": 4.5, "array_col": [5, 6]} ) - assert get_from_table(test_table, "string_col", "str") == "a test string" - assert get_from_table(test_table, "int_col", "int") == 4 - assert get_from_table(test_table, "float_col", "float") == 4.5 - assert_equal(get_from_table(test_table, "array_col", "array"), [5, 6]) + assert get_from_table(test_table, "string_col", str) == "a test string" + assert get_from_table(test_table, "int_col", int) == 4 + assert get_from_table(test_table, "float_col", float) == 4.5 + assert_equal(get_from_table(test_table, "array_col", np.ndarray), [5, 6]) with pytest.raises(ValueError) as error_info_1: - get_from_table(test_table, "string_col", "int") + get_from_table(test_table, "string_col", int) assert error_info_1.value.args[0] == "Could not cast column name to type." @@ -52,5 +54,18 @@ def test_get_from_table(): assert ( error_info_2.value.args[0] - == "Type for argument data_type not recognised: must be one of 'str', 'float', 'int', 'array'." + == "Type for argument data_type not recognised for column string_col in table default: must be str, float, int or np.ndarray." 
) + + +def test_check_value_populated(): + populated_value = check_value_populated(3, int, "column", "table") + assert populated_value == 3 + + array_length_zero = check_value_populated(np.array([]), np.ndarray, "column", "table") + number_is_nan = check_value_populated(np.nan, float, "column", "table") + str_is_empty = check_value_populated("", str, "column", "table") + + assert np.isnan(array_length_zero) + assert np.isnan(number_is_nan) + assert np.isnan(str_is_empty) diff --git a/tests/adler/utilities/test_AdlerCLIArguments.py b/tests/adler/utilities/test_AdlerCLIArguments.py new file mode 100644 index 0000000..e99d3b5 --- /dev/null +++ b/tests/adler/utilities/test_AdlerCLIArguments.py @@ -0,0 +1,73 @@ +import pytest +from adler.utilities.AdlerCLIArguments import AdlerCLIArguments + + +# AdlerCLIArguments object takes an object as input, so we define a quick one here +class args: + def __init__(self, ssObjectId, filter_list, date_range): + self.ssObjectId = ssObjectId + self.filter_list = filter_list + self.date_range = date_range + + +def test_AdlerCLIArguments(): + # test correct population + good_input_dict = {"ssObjectId": "666", "filter_list": ["g", "r", "i"], "date_range": [60000.0, 67300.0]} + good_arguments = args(**good_input_dict) + good_arguments_object = AdlerCLIArguments(good_arguments) + + assert good_arguments_object.__dict__ == good_input_dict + + # test that a bad ssObjectId triggers the right error + bad_ssoid_arguments = args("hello!", ["g", "r", "i"], [60000.0, 67300.0]) + + with pytest.raises(ValueError) as bad_ssoid_error: + bad_ssoid_object = AdlerCLIArguments(bad_ssoid_arguments) + + assert ( + bad_ssoid_error.value.args[0] + == "ssObjectId command-line argument does not appear to be a valid ssObjectId." 
+ ) + + # test that non-LSST or unexpected filters trigger the right error + bad_filter_arguments = args("666", ["g", "r", "i", "m"], [60000.0, 67300.0]) + + with pytest.raises(ValueError) as bad_filter_error: + bad_filter_object = AdlerCLIArguments(bad_filter_arguments) + + assert ( + bad_filter_error.value.args[0] + == "Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters." + ) + + bad_filter_arguments_2 = args("666", ["pony"], [60000.0, 67300.0]) + + with pytest.raises(ValueError) as bad_filter_error_2: + bad_filter_object = AdlerCLIArguments(bad_filter_arguments_2) + + assert ( + bad_filter_error_2.value.args[0] + == "Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters." + ) + + # test that overly-large dates trigger the right error + big_date_arguments = args("666", ["g", "r", "i"], [260000.0, 267300.0]) + + with pytest.raises(ValueError) as big_date_error: + big_date_object = AdlerCLIArguments(big_date_arguments) + + assert ( + big_date_error.value.args[0] + == "Dates for date_range command-line argument seem rather large. Did you input JD instead of MJD?" + ) + + # test that unexpected date values trigger the right error + bad_date_arguments = args("666", ["g", "r", "i"], [260000.0, "cheese"]) + + with pytest.raises(ValueError) as bad_date_error: + bad_date_object = AdlerCLIArguments(bad_date_arguments) + + assert ( + bad_date_error.value.args[0] + == "One or both of the values for the date_range command-line argument do not seem to be valid numbers." + )