diff --git a/.github/workflows/smoke-test.yml b/.github/workflows/smoke-test.yml
index 9bd50e2..128ed0f 100644
--- a/.github/workflows/smoke-test.yml
+++ b/.github/workflows/smoke-test.yml
@@ -20,7 +20,7 @@ jobs:
     strategy:
       matrix:
         os: ['macos-latest','ubuntu-latest']
-        python-version: ['3.9', '3.10', '3.11']
+        python-version: ['3.10', '3.11', '3.12']
     runs-on: ${{ matrix.os }}
     steps:
diff --git a/.github/workflows/testing-and-coverage.yml b/.github/workflows/testing-and-coverage.yml
index c39f18d..de5d7d4 100644
--- a/.github/workflows/testing-and-coverage.yml
+++ b/.github/workflows/testing-and-coverage.yml
@@ -15,7 +15,7 @@ jobs:
     strategy:
       matrix:
         os: ['macos-latest','ubuntu-latest']
-        python-version: ['3.9', '3.10', '3.11']
+        python-version: ['3.10', '3.11', '3.12']
     runs-on: ${{ matrix.os }}
     steps:
       - uses: actions/checkout@v4
diff --git a/src/adler/adler.py b/src/adler/adler.py
index bc24b29..cdf1d09 100644
--- a/src/adler/adler.py
+++ b/src/adler/adler.py
@@ -1,19 +1,35 @@
+import logging
 import argparse
 
 import astropy.units as u
 
 from adler.dataclasses.AdlerPlanetoid import AdlerPlanetoid
 from adler.science.PhaseCurve import PhaseCurve
+from adler.utilities.AdlerCLIArguments import AdlerCLIArguments
+from adler.utilities.adler_logging import setup_adler_logging
 
+logger = logging.getLogger(__name__)
 
-def runAdler(args):
-    planetoid = AdlerPlanetoid.construct_from_RSP(args.ssoid, args.filter_list, args.date_range)
+
+def runAdler(cli_args):
+    logger.info("Beginning Adler.")
+    logger.info("Ingesting all data for object {} from RSP...".format(cli_args.ssObjectId))
+
+    planetoid = AdlerPlanetoid.construct_from_RSP(
+        cli_args.ssObjectId, cli_args.filter_list, cli_args.date_range
+    )
+
+    logger.info("Data successfully ingested.")
+    logger.info("Calculating phase curves...")
 
     # now let's do some phase curves!
 
+    # get the r filter SSObject metadata
+    sso_r = planetoid.SSObject_in_filter("r")
+
+    # get the RSP r filter model
     pc = PhaseCurve(
-        abs_mag=planetoid.SSObject.H[2] * u.mag,
-        phase_param=planetoid.SSObject.G12[2],
+        abs_mag=sso_r.H * u.mag,
+        phase_param=sso_r.G12,
         model_name="HG12_Pen16",
     )
     print(pc)
@@ -31,11 +47,16 @@ def runAdler(args):
 
 
 def main():
-    parser = argparse.ArgumentParser(description="Runs Adler for a select planetoid and given user input.")
+    parser = argparse.ArgumentParser(description="Runs Adler for select planetoid(s) and given user input.")
 
-    parser.add_argument("-s", "--ssoid", help="SSObject ID of planetoid.", type=str, required=True)
+    parser.add_argument("-s", "--ssObjectId", help="SSObject ID of planetoid.", type=str, required=True)
     parser.add_argument(
-        "-f", "--filters", help="Comma-separated list of filters required.", type=str, default="u,g,r,i,z,y"
+        "-f",
+        "--filter_list",
+        help="Space-separated list of filters required.",
+        nargs="*",
+        type=str,
+        default=["u", "g", "r", "i", "z", "y"],
     )
     parser.add_argument(
         "-d",
@@ -45,12 +66,30 @@ def main():
         type=float,
         default=[60000.0, 67300.0],
     )
+    parser.add_argument(
+        "-o",
+        "--outpath",
+        help="Output path location. Default is current working directory.",
+        type=str,
+        default="./",
+    )
+    parser.add_argument(
+        "-n",
+        "--db_name",
+        help="Stem filename of output database. If this doesn't exist, it will be created. Default: adler_out.",
+        type=str,
+        default="adler_out",
+    )
 
     args = parser.parse_args()
 
-    args.filter_list = args.filters.split(",")
+    cli_args = AdlerCLIArguments(args)
+
+    adler_logger = setup_adler_logging(cli_args.outpath)
+
+    cli_args.logger = adler_logger
 
-    runAdler(args)
+    runAdler(cli_args)
 
 
 if __name__ == "__main__":
diff --git a/src/adler/dataclasses/AdlerData.py b/src/adler/dataclasses/AdlerData.py
index 1453cda..86c2e2a 100644
--- a/src/adler/dataclasses/AdlerData.py
+++ b/src/adler/dataclasses/AdlerData.py
@@ -1,5 +1,6 @@
 import os
 import sqlite3
+import logging
 import numpy as np
 from dataclasses import dataclass, field
 from datetime import datetime, timezone
@@ -15,6 +16,8 @@
     "phase_parameter_2_err",
 ]
 
+logger = logging.getLogger(__name__)
+
 
 @dataclass
 class AdlerData:
@@ -70,10 +73,12 @@ def populate_phase_parameters(self, filter_name, **kwargs):
         try:
             filter_index = self.filter_list.index(filter_name)
         except ValueError:
+            logger.error("ValueError: Filter {} does not exist in AdlerData.filter_list.".format(filter_name))
             raise ValueError("Filter {} does not exist in AdlerData.filter_list.".format(filter_name))
 
         # if model-dependent parameters exist without a model name, return an error
         if not kwargs.get("model_name") and any(name in kwargs for name in MODEL_DEPENDENT_KEYS):
+            logger.error("NameError: No model name given. Cannot update model-specific phase parameters.")
             raise NameError("No model name given. Cannot update model-specific phase parameters.")
 
         # update the value if it's in **kwargs
@@ -163,6 +168,7 @@ def get_phase_parameters_in_filter(self, filter_name, model_name=None):
         try:
             filter_index = self.filter_list.index(filter_name)
         except ValueError:
+            logger.error("ValueError: Filter {} does not exist in AdlerData.filter_list.".format(filter_name))
             raise ValueError("Filter {} does not exist in AdlerData.filter_list.".format(filter_name))
 
         output_obj = PhaseParameterOutput()
@@ -173,11 +179,17 @@
         output_obj.arc = self.filter_dependent_values[filter_index].arc
 
         if not model_name:
+            logger.warning("No model name was specified. Returning non-model-dependent phase parameters.")
             print("No model name specified. Returning non-model-dependent phase parameters.")
         else:
             try:
                 model_index = self.filter_dependent_values[filter_index].model_list.index(model_name)
             except ValueError:
+                logger.error(
+                    "ValueError: Model {} does not exist for filter {} in AdlerData.model_lists.".format(
+                        model_name, filter_name
+                    )
+                )
                 raise ValueError(
                     "Model {} does not exist for filter {} in AdlerData.model_lists.".format(
                         model_name, filter_name
diff --git a/src/adler/dataclasses/AdlerPlanetoid.py b/src/adler/dataclasses/AdlerPlanetoid.py
index 8b03781..93177e0 100644
--- a/src/adler/dataclasses/AdlerPlanetoid.py
+++ b/src/adler/dataclasses/AdlerPlanetoid.py
@@ -1,5 +1,6 @@
 from lsst.rsp import get_tap_service
 import pandas as pd
+import logging
 
 from adler.dataclasses.Observations import Observations
 from adler.dataclasses.MPCORB import MPCORB
@@ -7,6 +8,8 @@
 from adler.dataclasses.AdlerData import AdlerData
 from adler.dataclasses.dataclass_utilities import get_data_table
 
+logger = logging.getLogger(__name__)
+
 
 class AdlerPlanetoid:
     """AdlerPlanetoid class. Contains the Observations, MPCORB and SSObject dataclass objects."""
@@ -80,12 +83,28 @@ def construct_from_SQL(
         """
 
         if len(date_range) != 2:
+            logger.error("ValueError: date_range attribute must be of length 2.")
             raise ValueError("date_range attribute must be of length 2.")
 
         observations_by_filter = cls.populate_observations(
             cls, ssObjectId, filter_list, date_range, sql_filename=sql_filename, schema=schema
         )
 
+        if len(observations_by_filter) == 0:
+            logger.error(
+                "No observations found for this object in the given filter(s). Check SSOID and try again."
+            )
+            raise Exception(
+                "No observations found for this object in the given filter(s). Check SSOID and try again."
+            )
+
+        if len(filter_list) > len(observations_by_filter):
+            logger.info(
+                "Not all specified filters have observations. Recalculating filter list based on past observations."
+            )
+            filter_list = [obs_object.filter_name for obs_object in observations_by_filter]
+            logger.info("New filter list is: {}".format(filter_list))
+
         mpcorb = cls.populate_MPCORB(cls, ssObjectId, sql_filename=sql_filename, schema=schema)
         ssobject = cls.populate_SSObject(
             cls, ssObjectId, filter_list, sql_filename=sql_filename, schema=schema
@@ -119,10 +138,29 @@ def construct_from_RSP(
             raise Exception("date_range argument must be of length 2.")
 
         service = get_tap_service("ssotap")
+        logger.info("Getting past observations from DIASource/SSSource...")
         observations_by_filter = cls.populate_observations(
             cls, ssObjectId, filter_list, date_range, service=service
         )
+
+        if len(observations_by_filter) == 0:
+            logger.error(
+                "No observations found for this object in the given filter(s). Check SSOID and try again."
+            )
+            raise Exception(
+                "No observations found for this object in the given filter(s). Check SSOID and try again."
+            )
+
+        if len(filter_list) > len(observations_by_filter):
+            logger.info(
+                "Not all specified filters have observations. Recalculating filter list based on past observations."
+            )
+            filter_list = [obs_object.filter_name for obs_object in observations_by_filter]
+            logger.info("New filter list is: {}".format(filter_list))
+
+        logger.info("Populating MPCORB metadata...")
         mpcorb = cls.populate_MPCORB(cls, ssObjectId, service=service)
+        logger.info("Populating SSObject metadata...")
         ssobject = cls.populate_SSObject(cls, ssObjectId, filter_list, service=service)
 
         adler_data = AdlerData(ssObjectId, filter_list)
@@ -185,9 +223,21 @@ def populate_observations(
 
             data_table = get_data_table(observations_sql_query, service=service, sql_filename=sql_filename)
 
-            observations_by_filter.append(
-                Observations.construct_from_data_table(ssObjectId, filter_name, data_table)
-            )
+            if len(data_table) == 0:
+                logger.warning(
+                    "No observations found in {} filter for this object. Skipping this filter.".format(
+                        filter_name
+                    )
+                )
+                print(
+                    "WARNING: No observations found in {} filter for this object. Skipping this filter.".format(
+                        filter_name
+                    )
+                )
+            else:
+                observations_by_filter.append(
+                    Observations.construct_from_data_table(ssObjectId, filter_name, data_table)
+                )
 
         return observations_by_filter
 
@@ -228,6 +278,10 @@ def populate_MPCORB(self, ssObjectId, service=None, sql_filename=None, schema="d
 
         data_table = get_data_table(MPCORB_sql_query, service=service, sql_filename=sql_filename)
 
+        if len(data_table) == 0:
+            logger.error("No MPCORB data for this object could be found for this SSObjectId.")
+            raise Exception("No MPCORB data for this object could be found for this SSObjectId.")
+
         return MPCORB.construct_from_data_table(ssObjectId, data_table)
 
     def populate_SSObject(
@@ -282,6 +336,10 @@ def populate_SSObject(
 
         data_table = get_data_table(SSObject_sql_query, service=service, sql_filename=sql_filename)
 
+        if len(data_table) == 0:
+            logger.error("No SSObject data for this object could be found for this SSObjectId.")
+            raise Exception("No SSObject data for this object could be found for this SSObjectId.")
+
         return SSObject.construct_from_data_table(ssObjectId, filter_list, data_table)
 
     def observations_in_filter(self, filter_name):
@@ -302,6 +360,7 @@ def observations_in_filter(self, filter_name):
         try:
             filter_index = self.filter_list.index(filter_name)
         except ValueError:
+            logger.error("ValueError: Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))
             raise ValueError("Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))
 
         return self.observations_by_filter[filter_index]
@@ -324,6 +383,7 @@ def SSObject_in_filter(self, filter_name):
         try:
             filter_index = self.filter_list.index(filter_name)
         except ValueError:
+            logger.error("ValueError: Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))
             raise ValueError("Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))
 
         return self.SSObject.filter_dependent_values[filter_index]
diff --git a/src/adler/dataclasses/MPCORB.py b/src/adler/dataclasses/MPCORB.py
index 483acd7..a17bd18 100644
--- a/src/adler/dataclasses/MPCORB.py
+++ b/src/adler/dataclasses/MPCORB.py
@@ -2,6 +2,22 @@
 
 from adler.dataclasses.dataclass_utilities import get_from_table
 
+MPCORB_KEYS = {
+    "mpcDesignation": str,
+    "mpcNumber": int,
+    "mpcH": float,
+    "mpcG": float,
+    "epoch": float,
+    "peri": float,
+    "node": float,
+    "incl": float,
+    "e": float,
+    "n": float,
+    "q": float,
+    "uncertaintyParameter": str,
+    "flags": str,
+}
+
 
 @dataclass
 class MPCORB:
@@ -87,33 +103,9 @@ def construct_from_data_table(cls, ssObjectId, data_table):
 
         """
 
-        mpcDesignation = get_from_table(data_table, "mpcDesignation", "str")
-        mpcNumber = get_from_table(data_table, "mpcNumber", "int")
-        mpcH = get_from_table(data_table, "mpcH", "float")
-        mpcG = get_from_table(data_table, "mpcG", "float")
-        epoch = get_from_table(data_table, "epoch", "float")
-        peri = get_from_table(data_table, "peri", "float")
-        node = get_from_table(data_table, "node", "float")
-        incl = get_from_table(data_table, "incl", "float")
-        e = get_from_table(data_table, "e", "float")
-        n = get_from_table(data_table, "n", "float")
-        q = get_from_table(data_table, "q", "float")
-        uncertaintyParameter = get_from_table(data_table, "uncertaintyParameter", "str")
-        flags = get_from_table(data_table, "flags", "str")
-
-        return cls(
-            ssObjectId,
-            mpcDesignation,
-            mpcNumber,
-            mpcH,
-            mpcG,
-            epoch,
-            peri,
-            node,
-            incl,
-            e,
-            n,
-            q,
-            uncertaintyParameter,
-            flags,
-        )
+        mpcorb_dict = {"ssObjectId": ssObjectId}
+
+        for mpcorb_key, mpcorb_type in MPCORB_KEYS.items():
+            mpcorb_dict[mpcorb_key] = get_from_table(data_table, mpcorb_key, mpcorb_type, "MPCORB")
+
+        return cls(**mpcorb_dict)
diff --git a/src/adler/dataclasses/Observations.py b/src/adler/dataclasses/Observations.py
index 77768f6..64b3900 100644
--- a/src/adler/dataclasses/Observations.py
+++ b/src/adler/dataclasses/Observations.py
@@ -3,6 +3,17 @@
 
 from adler.dataclasses.dataclass_utilities import get_from_table
 
+OBSERVATIONS_KEYS = {
+    "mag": np.ndarray,
+    "magErr": np.ndarray,
+    "midPointMjdTai": np.ndarray,
+    "ra": np.ndarray,
+    "dec": np.ndarray,
+    "phaseAngle": np.ndarray,
+    "topocentricDist": np.ndarray,
+    "heliocentricDist": np.ndarray,
+}
+
 
 @dataclass
 class Observations:
@@ -24,7 +35,7 @@ class Observations:
     magErr: array_like of floats
         Magnitude error. This is a placeholder and will be replaced by flux error.
 
-    midpointMjdTai: array_like of floats
+    midPointMjdTai: array_like of floats
         Effective mid-visit time for this diaSource, expressed as Modified Julian Date, International Atomic Time.
 
     ra: array_like of floats
@@ -54,7 +65,7 @@ class Observations:
     filter_name: str = ""
     mag: np.ndarray = field(default_factory=lambda: np.zeros(0))
     magErr: np.ndarray = field(default_factory=lambda: np.zeros(0))
-    midpointMjdTai: np.ndarray = field(default_factory=lambda: np.zeros(0))
+    midPointMjdTai: np.ndarray = field(default_factory=lambda: np.zeros(0))
     ra: np.ndarray = field(default_factory=lambda: np.zeros(0))
     dec: np.ndarray = field(default_factory=lambda: np.zeros(0))
     phaseAngle: np.ndarray = field(default_factory=lambda: np.zeros(0))
@@ -85,32 +96,17 @@ def construct_from_data_table(cls, ssObjectId, filter_name, data_table):
 
         """
 
-        mag = get_from_table(data_table, "mag", "array")
-        magErr = get_from_table(data_table, "magErr", "array")
-        midpointMjdTai = get_from_table(data_table, "midPointMjdTai", "array")
-        ra = get_from_table(data_table, "ra", "array")
-        dec = get_from_table(data_table, "dec", "array")
-        phaseAngle = get_from_table(data_table, "phaseAngle", "array")
-        topocentricDist = get_from_table(data_table, "topocentricDist", "array")
-        heliocentricDist = get_from_table(data_table, "heliocentricDist", "array")
-
-        reduced_mag = cls.calculate_reduced_mag(cls, mag, topocentricDist, heliocentricDist)
-
-        return cls(
-            ssObjectId,
-            filter_name,
-            mag,
-            magErr,
-            midpointMjdTai,
-            ra,
-            dec,
-            phaseAngle,
-            topocentricDist,
-            heliocentricDist,
-            reduced_mag,
-            len(data_table),
+        obs_dict = {"ssObjectId": ssObjectId, "filter_name": filter_name, "num_obs": len(data_table)}
+
+        for obs_key, obs_type in OBSERVATIONS_KEYS.items():
+            obs_dict[obs_key] = get_from_table(data_table, obs_key, obs_type, "SSSource/DIASource")
+
+        obs_dict["reduced_mag"] = cls.calculate_reduced_mag(
+            cls, obs_dict["mag"], obs_dict["topocentricDist"], obs_dict["heliocentricDist"]
         )
 
+        return cls(**obs_dict)
+
     def calculate_reduced_mag(self, mag, topocentric_dist, heliocentric_dist):
         """
         Calculates the reduced magnitude column.
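The MPCORB and Observations changes above, and the SSObject change below, all replace long positional constructor calls with a single column-name-to-type mapping that drives get_from_table. A minimal sketch of that pattern, using a hypothetical toy dataclass and a pandas table rather than the real Adler classes:

    import numpy as np
    import pandas as pd
    from dataclasses import dataclass, field

    # toy stand-in for MPCORB_KEYS/OBSERVATIONS_KEYS: column name -> type to cast to
    TOY_KEYS = {"mpcDesignation": str, "mpcH": float, "mag": np.ndarray}

    @dataclass
    class ToyRow:
        ssObjectId: str = ""
        mpcDesignation: str = ""
        mpcH: float = 0.0
        mag: np.ndarray = field(default_factory=lambda: np.zeros(0))

        @classmethod
        def construct_from_data_table(cls, ssObjectId, data_table):
            row_dict = {"ssObjectId": ssObjectId}
            for key, key_type in TOY_KEYS.items():
                # mirrors get_from_table's branches: array columns keep the whole
                # column, scalar columns take the first row's value cast to key_type
                if key_type == np.ndarray:
                    row_dict[key] = np.array(data_table[key])
                else:
                    row_dict[key] = key_type(data_table[key][0])
            return cls(**row_dict)

    table = pd.DataFrame({"mpcDesignation": ["2014 QL433"], "mpcH": [19.88], "mag": [21.3]})
    print(ToyRow.construct_from_data_table("666", table))

The payoff of the design is that adding a new column means touching only the KEYS dictionary, not the constructor call.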
diff --git a/src/adler/dataclasses/SSObject.py b/src/adler/dataclasses/SSObject.py
index 2207c51..9ec0443 100644
--- a/src/adler/dataclasses/SSObject.py
+++ b/src/adler/dataclasses/SSObject.py
@@ -3,6 +3,16 @@
 
 from adler.dataclasses.dataclass_utilities import get_from_table
 
+SSO_KEYS = {
+    "discoverySubmissionDate": float,
+    "firstObservationDate": float,
+    "arc": float,
+    "numObs": int,
+    "maxExtendedness": float,
+    "minExtendedness": float,
+    "medianExtendedness": float,
+}
+
 
 @dataclass
 class SSObject:
@@ -57,47 +67,24 @@ class SSObject:
 
     @classmethod
    def construct_from_data_table(cls, ssObjectId, filter_list, data_table):
-        discoverySubmissionDate = get_from_table(data_table, "discoverySubmissionDate", "float")
-        firstObservationDate = get_from_table(data_table, "firstObservationDate", "float")
-        arc = get_from_table(data_table, "arc", "float")
-        numObs = get_from_table(data_table, "numObs", "int")
-
-        H = np.zeros(len(filter_list))
-        G12 = np.zeros(len(filter_list))
-        Herr = np.zeros(len(filter_list))
-        G12err = np.zeros(len(filter_list))
-        nData = np.zeros(len(filter_list))
+        sso_dict = {"ssObjectId": ssObjectId, "filter_list": filter_list, "filter_dependent_values": []}
 
-        filter_dependent_values = []
+        for sso_key, sso_type in SSO_KEYS.items():
+            sso_dict[sso_key] = get_from_table(data_table, sso_key, sso_type, "SSObject")
 
         for i, filter_name in enumerate(filter_list):
             filter_dept_object = FilterDependentSSO(
                 filter_name=filter_name,
-                H=get_from_table(data_table, filter_name + "_H", "float"),
-                G12=get_from_table(data_table, filter_name + "_G12", "float"),
-                Herr=get_from_table(data_table, filter_name + "_HErr", "float"),
-                G12err=get_from_table(data_table, filter_name + "_G12Err", "float"),
-                nData=get_from_table(data_table, filter_name + "_Ndata", "int"),
+                H=get_from_table(data_table, filter_name + "_H", float, "SSObject"),
+                G12=get_from_table(data_table, filter_name + "_G12", float, "SSObject"),
+                Herr=get_from_table(data_table, filter_name + "_HErr", float, "SSObject"),
+                G12err=get_from_table(data_table, filter_name + "_G12Err", float, "SSObject"),
+                nData=get_from_table(data_table, filter_name + "_Ndata", float, "SSObject"),
             )
 
-            filter_dependent_values.append(filter_dept_object)
-
-        maxExtendedness = get_from_table(data_table, "maxExtendedness", "float")
-        minExtendedness = get_from_table(data_table, "minExtendedness", "float")
-        medianExtendedness = get_from_table(data_table, "medianExtendedness", "float")
-
-        return cls(
-            ssObjectId,
-            filter_list,
-            discoverySubmissionDate,
-            firstObservationDate,
-            arc,
-            numObs,
-            filter_dependent_values,
-            maxExtendedness,
-            minExtendedness,
-            medianExtendedness,
-        )
+            sso_dict["filter_dependent_values"].append(filter_dept_object)
+
+        return cls(**sso_dict)
 
 
 @dataclass
diff --git a/src/adler/dataclasses/dataclass_utilities.py b/src/adler/dataclasses/dataclass_utilities.py
index ebba3d8..497e1a6 100644
--- a/src/adler/dataclasses/dataclass_utilities.py
+++ b/src/adler/dataclasses/dataclass_utilities.py
@@ -2,6 +2,9 @@
 import pandas as pd
 import sqlite3
 import warnings
+import logging
+
+logger = logging.getLogger(__name__)
 
 
 def get_data_table(sql_query, service=None, sql_filename=None):
@@ -48,34 +51,101 @@ def get_data_table(sql_query, service=None, sql_filename=None):
     return data_table
 
 
-def get_from_table(data_table, column_name, data_type):
-    """Retrieves information from the data_table class variable and forces it to be a specified type.
+def get_from_table(data_table, column_name, data_type, table_name="default"):
+    """Retrieves information from the data_table and forces it to be a specified type.
 
     Parameters
    -----------
+    data_table : DALResultsTable or Pandas dataframe
+        Data table containing columns of interest.
+
     column_name : str
         Column name under which the data of interest is stored.
 
-    type : str
-        String delineating data type. Should be "str", "float", "int" or "array".
+
+    data_type : type
+        Data type. Should be int, float, str or np.ndarray.
+
+    table_name : str
+        Name of the table. This is mostly for more informative error messages. Default="default".
 
     Returns
     -----------
-    data : any type
+    data_val : str, float, int or np.ndarray
         The data requested from the table cast to the type required.
 
     """
-    try:
-        if data_type == "str":
-            return str(data_table[column_name][0])
-        elif data_type == "float":
-            return float(data_table[column_name][0])
-        elif data_type == "int":
-            return int(data_table[column_name][0])
-        elif data_type == "array":
-            return np.array(data_table[column_name])
-        else:
-            raise TypeError(
-                "Type for argument data_type not recognised: must be one of 'str', 'float', 'int', 'array'."
+    with warnings.catch_warnings():
+        warnings.filterwarnings(
+            "ignore", category=UserWarning
+        )  # RSP tables mask unpopulated elements, which get converted to NaN here and trigger a warning we don't care about.
+        try:
+            if data_type == str:
+                data_val = str(data_table[column_name][0])
+            elif data_type == float:
+                data_val = float(data_table[column_name][0])
+            elif data_type == int:
+                data_val = int(data_table[column_name][0])
+            elif data_type == np.ndarray:
+                data_val = np.array(data_table[column_name])
+            else:
+                logger.error(
+                    "TypeError: Type for argument data_type not recognised for column {} in table {}: must be str, float, int or np.ndarray.".format(
+                        column_name, table_name
+                    )
+                )
+                raise TypeError(
+                    "Type for argument data_type not recognised for column {} in table {}: must be str, float, int or np.ndarray.".format(
+                        column_name, table_name
+                    )
+                )
+        except ValueError:
+            logger.error("ValueError: Could not cast column name to type.")
+            raise ValueError("Could not cast column name to type.")
+
+    # here we alert the user if one of the values is unpopulated and change it to a NaN
+    data_val = check_value_populated(data_val, data_type, column_name, table_name)
+
+    return data_val
+
+
+def check_value_populated(data_val, data_type, column_name, table_name):
+    """Checks to see if data_val was populated properly and prints a helpful warning if it wasn't.
+    Usually this will trigger because the RSP hasn't populated that field for this particular object.
+
+    Parameters
+    -----------
+    data_val : str, float, int or np.ndarray
+        The value to check.
+
+    data_type : type
+        Data type. Should be int, float, str or np.ndarray.
+
+    column_name : str
+        Column name under which the data of interest is stored.
+
+    table_name : str
+        Name of the table. This is mostly for more informative error messages. Default="default".
+
+    Returns
+    -----------
+    data_val : str, float, int, np.ndarray or np.nan
+        Either returns the original data_val or an np.nan if it detected that the value was not populated.
+
+    """
+
+    array_length_zero = data_type == np.ndarray and len(data_val) == 0
+    number_is_nan = data_type in [float, int] and np.isnan(data_val)
+    str_is_empty = data_type == str and len(data_val) == 0
+
+    if array_length_zero or number_is_nan or str_is_empty:
+        logger.warning(
+            "{} unpopulated in {} table for this object. Storing NaN instead.".format(column_name, table_name)
+        )
+        print(
+            "WARNING: {} unpopulated in {} table for this object. Storing NaN instead.".format(
+                column_name, table_name
             )
-    except ValueError:
-        raise ValueError("Could not cast column name to type.")
+        )
+        data_val = np.nan
+
+    return data_val
diff --git a/src/adler/utilities/AdlerCLIArguments.py b/src/adler/utilities/AdlerCLIArguments.py
new file mode 100644
index 0000000..dea7e98
--- /dev/null
+++ b/src/adler/utilities/AdlerCLIArguments.py
@@ -0,0 +1,63 @@
+import os
+
+
+class AdlerCLIArguments:
+    """
+    Class for storing and validating Adler command-line arguments.
+
+    Attributes:
+    -----------
+    args : argparse.Namespace object
+        argparse.Namespace object created by calling parse_args().
+
+    """
+
+    def __init__(self, args):
+        self.ssObjectId = args.ssObjectId
+        self.filter_list = args.filter_list
+        self.date_range = args.date_range
+        self.outpath = args.outpath
+        self.db_name = args.db_name
+
+        self.validate_arguments()
+
+    def validate_arguments(self):
+        self._validate_filter_list()
+        self._validate_ssObjectId()
+        self._validate_date_range()
+        self._validate_outpath()
+
+    def _validate_filter_list(self):
+        expected_filters = ["u", "g", "r", "i", "z", "y"]
+
+        if not set(self.filter_list).issubset(expected_filters):
+            raise ValueError(
+                "Unexpected filters found in --filter_list command-line argument. --filter_list must be a list of LSST filters."
+            )
+
+    def _validate_ssObjectId(self):
+        try:
+            int(self.ssObjectId)
+        except ValueError:
+            raise ValueError("--ssObjectId command-line argument does not appear to be a valid ssObjectId.")
+
+    def _validate_date_range(self):
+        for d in self.date_range:
+            try:
+                float(d)
+            except ValueError:
+                raise ValueError(
+                    "One or both of the values for the --date_range command-line argument do not seem to be valid numbers."
+                )
+
+        if any(float(d) > 250000 for d in self.date_range):
+            raise ValueError(
+                "Dates for --date_range command-line argument seem rather large. Did you input JD instead of MJD?"
+            )
+
+    def _validate_outpath(self):
+        # make it an absolute path if it's relative!
+        self.outpath = os.path.abspath(self.outpath)
+
+        if not os.path.isdir(self.outpath):
+            raise ValueError("The output path for the command-line argument --outpath cannot be found.")
diff --git a/src/adler/utilities/adler_logging.py b/src/adler/utilities/adler_logging.py
new file mode 100644
index 0000000..d885c79
--- /dev/null
+++ b/src/adler/utilities/adler_logging.py
@@ -0,0 +1,42 @@
+import logging
+import os
+from datetime import datetime
+
+
+def setup_adler_logging(
+    log_location,
+    log_format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s ",
+    log_name="",
+    log_file_info="adler.log",
+    log_file_error="adler.err",
+):
+    log = logging.getLogger(log_name)
+    log_formatter = logging.Formatter(log_format)
+
+    # uncomment these lines to also send log output to the console
+    # stream_handler = logging.StreamHandler()
+    # stream_handler.setFormatter(log_formatter)
+    # log.addHandler(stream_handler)
+
+    dstr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
+    cpid = os.getpid()
+
+    log_file_info = os.path.join(log_location, dstr + "-p" + str(cpid) + "-" + log_file_info)
+    log_file_error = os.path.join(log_location, dstr + "-p" + str(cpid) + "-" + log_file_error)
+
+    # this log will log pretty much everything: basic info, but also warnings and errors
+    file_handler_info = logging.FileHandler(log_file_info, mode="w")
+    file_handler_info.setFormatter(log_formatter)
+    file_handler_info.setLevel(logging.INFO)
+    log.addHandler(file_handler_info)
+
+    # this log only logs warnings and errors, so they can be looked at quickly without a lot of scrolling
+    file_handler_error = logging.FileHandler(log_file_error, mode="w")
+    file_handler_error.setFormatter(log_formatter)
+    file_handler_error.setLevel(logging.WARNING)
+    log.addHandler(file_handler_error)
+
+    # the logger itself defaults to WARNING, so it must be set to INFO here or the INFO handler never receives those records
+    log.setLevel(logging.INFO)
+
+    return log
diff --git a/tests/adler/dataclasses/test_AdlerPlanetoid.py b/tests/adler/dataclasses/test_AdlerPlanetoid.py
index 288c5cc..3ac451d 100644
--- a/tests/adler/dataclasses/test_AdlerPlanetoid.py
+++ b/tests/adler/dataclasses/test_AdlerPlanetoid.py
@@ -11,13 +11,15 @@
 
 
 def test_construct_from_SQL():
-    test_planetoid = AdlerPlanetoid.construct_from_SQL(ssoid, test_db_path)
+    test_planetoid = AdlerPlanetoid.construct_from_SQL(
+        ssoid, test_db_path, filter_list=["u", "g", "r", "i", "z", "y"]
+    )
 
     # testing just a few values here to ensure correct setup: these objects have their own unit tests
     assert test_planetoid.MPCORB.mpcH == 19.8799991607666
     assert test_planetoid.SSObject.discoverySubmissionDate == 60218.0
     assert_almost_equal(
-        test_planetoid.observations_by_filter[1].mag,
+        test_planetoid.observations_by_filter[0].mag,
         [
             21.33099937,
             22.67099953,
@@ -31,10 +33,10 @@ def test_construct_from_SQL():
         ],
     )
 
-    # did we pick up all the filters?
-    assert len(test_planetoid.observations_by_filter) == 6
-    assert len(test_planetoid.SSObject.filter_dependent_values) == 6
-    assert test_planetoid.filter_list == ["u", "g", "r", "i", "z", "y"]
+    # did we pick up all the filters? note we ask for ugrizy but u and y are unpopulated in DP0.3, so the code should eliminate them
+    assert len(test_planetoid.observations_by_filter) == 4
+    assert len(test_planetoid.SSObject.filter_dependent_values) == 4
+    assert test_planetoid.filter_list == ["g", "r", "i", "z"]
 
     # checking the date range to ensure it's the default
     assert test_planetoid.date_range == [60000.0, 67300.0]
@@ -86,7 +88,7 @@ def test_construct_with_date_range():
         ]
     )
 
-    assert_almost_equal(test_planetoid.observations_by_filter[0].midpointMjdTai, expected_dates)
+    assert_almost_equal(test_planetoid.observations_by_filter[0].midPointMjdTai, expected_dates)
 
     with pytest.raises(ValueError) as error_info_1:
         test_planetoid = AdlerPlanetoid.construct_from_SQL(
@@ -100,12 +102,10 @@ def test_observations_in_filter():
     test_planetoid = AdlerPlanetoid.construct_from_SQL(ssoid, test_db_path)
 
     # Python dataclasses create an __eq__ for you so object-to-object comparison just works, isn't that nice?
-    assert test_planetoid.observations_in_filter("u") == test_planetoid.observations_by_filter[0]
-    assert test_planetoid.observations_in_filter("g") == test_planetoid.observations_by_filter[1]
-    assert test_planetoid.observations_in_filter("r") == test_planetoid.observations_by_filter[2]
-    assert test_planetoid.observations_in_filter("i") == test_planetoid.observations_by_filter[3]
-    assert test_planetoid.observations_in_filter("z") == test_planetoid.observations_by_filter[4]
-    assert test_planetoid.observations_in_filter("y") == test_planetoid.observations_by_filter[5]
+    assert test_planetoid.observations_in_filter("g") == test_planetoid.observations_by_filter[0]
+    assert test_planetoid.observations_in_filter("r") == test_planetoid.observations_by_filter[1]
+    assert test_planetoid.observations_in_filter("i") == test_planetoid.observations_by_filter[2]
+    assert test_planetoid.observations_in_filter("z") == test_planetoid.observations_by_filter[3]
 
     with pytest.raises(ValueError) as error_info_1:
         test_planetoid.observations_in_filter("f")
@@ -116,14 +116,55 @@ def test_SSObject_in_filter():
     test_planetoid = AdlerPlanetoid.construct_from_SQL(ssoid, test_db_path)
 
-    assert test_planetoid.SSObject_in_filter("u") == test_planetoid.SSObject.filter_dependent_values[0]
-    assert test_planetoid.SSObject_in_filter("g") == test_planetoid.SSObject.filter_dependent_values[1]
-    assert test_planetoid.SSObject_in_filter("r") == test_planetoid.SSObject.filter_dependent_values[2]
-    assert test_planetoid.SSObject_in_filter("i") == test_planetoid.SSObject.filter_dependent_values[3]
-    assert test_planetoid.SSObject_in_filter("z") == test_planetoid.SSObject.filter_dependent_values[4]
-    assert test_planetoid.SSObject_in_filter("y") == test_planetoid.SSObject.filter_dependent_values[5]
+    assert test_planetoid.SSObject_in_filter("g") == test_planetoid.SSObject.filter_dependent_values[0]
+    assert test_planetoid.SSObject_in_filter("r") == test_planetoid.SSObject.filter_dependent_values[1]
+    assert test_planetoid.SSObject_in_filter("i") == test_planetoid.SSObject.filter_dependent_values[2]
+    assert test_planetoid.SSObject_in_filter("z") == test_planetoid.SSObject.filter_dependent_values[3]
 
     with pytest.raises(ValueError) as error_info_1:
         test_planetoid.SSObject_in_filter("f")
 
     assert error_info_1.value.args[0] == "Filter f is not in AdlerPlanetoid.filter_list."
+
+
+def test_no_observations():
+    with pytest.raises(Exception) as error_info:
+        test_planetoid = AdlerPlanetoid.construct_from_SQL(826857066833589477, test_db_path)
+
+    assert (
+        error_info.value.args[0]
+        == "No observations found for this object in the given filter(s). Check SSOID and try again."
+    )
+
+
+def test_for_warnings(capsys):
+    test_planetoid = AdlerPlanetoid.construct_from_SQL(ssoid, test_db_path, filter_list=["u", "g"])
+    captured = capsys.readouterr()
+
+    expected = (
+        "WARNING: No observations found in u filter for this object. Skipping this filter.\n"
+        + "WARNING: n unpopulated in MPCORB table for this object. Storing NaN instead.\n"
+        + "WARNING: uncertaintyParameter unpopulated in MPCORB table for this object. Storing NaN instead.\n"
+    )
+
+    assert captured.out == expected
+
+
+def test_failed_SQL_queries():
+    test_planetoid = AdlerPlanetoid.construct_from_SQL(
+        ssoid, test_db_path, filter_list=["u", "g", "r", "i", "z", "y"]
+    )
+
+    with pytest.raises(Exception) as error_info_1:
+        test_planetoid.populate_MPCORB("826857066833589477", sql_filename=test_db_path, schema="")
+
+    assert error_info_1.value.args[0] == "No MPCORB data for this object could be found for this SSObjectId."
+
+    with pytest.raises(Exception) as error_info_2:
+        test_planetoid.populate_SSObject(
+            "826857066833589477", filter_list=["u"], sql_filename=test_db_path, schema=""
+        )
+
+    assert (
+        error_info_2.value.args[0] == "No SSObject data for this object could be found for this SSObjectId."
+    )
diff --git a/tests/adler/dataclasses/test_MPCORB.py b/tests/adler/dataclasses/test_MPCORB.py
index f94db75..d139d00 100644
--- a/tests/adler/dataclasses/test_MPCORB.py
+++ b/tests/adler/dataclasses/test_MPCORB.py
@@ -36,5 +36,5 @@ def test_construct_MPCORB_from_data_table():
     assert_almost_equal(test_MPCORB.e, 0.7168805704972735, decimal=6)
     assert np.isnan(test_MPCORB.n)
     assert_almost_equal(test_MPCORB.q, 0.5898291078470536, decimal=6)
-    assert test_MPCORB.uncertaintyParameter == ""
+    assert np.isnan(test_MPCORB.uncertaintyParameter)
     assert test_MPCORB.flags == "0"
diff --git a/tests/adler/dataclasses/test_Observations.py b/tests/adler/dataclasses/test_Observations.py
index 24c2b25..44a3717 100644
--- a/tests/adler/dataclasses/test_Observations.py
+++ b/tests/adler/dataclasses/test_Observations.py
@@ -161,7 +161,7 @@ def test_construct_observations_from_data_table():
 
     assert_almost_equal(test_observations.mag, expected_mag)
     assert_almost_equal(test_observations.magErr, expected_magerr)
-    assert_almost_equal(test_observations.midpointMjdTai, expected_mjd)
+    assert_almost_equal(test_observations.midPointMjdTai, expected_mjd)
     assert_almost_equal(test_observations.ra, expected_ra)
     assert_almost_equal(test_observations.dec, expected_dec)
     assert_almost_equal(test_observations.phaseAngle, expected_phaseangle)
diff --git a/tests/adler/dataclasses/test_dataclass_utilities.py b/tests/adler/dataclasses/test_dataclass_utilities.py
index dfae4b7..b087e0e 100644
--- a/tests/adler/dataclasses/test_dataclass_utilities.py
+++ b/tests/adler/dataclasses/test_dataclass_utilities.py
@@ -1,10 +1,12 @@
 import pytest
 import pandas as pd
+import numpy as np
 from pandas.testing import assert_frame_equal
 from numpy.testing import assert_equal
 
 from adler.dataclasses.dataclass_utilities import get_data_table
 from adler.dataclasses.dataclass_utilities import get_from_table
+from adler.dataclasses.dataclass_utilities import check_value_populated
 from adler.utilities.tests_utilities import get_test_data_filepath
 
 
@@ -37,13 +39,13 @@ def test_get_from_table():
         {"string_col": "a test string", "int_col": 4, "float_col": 4.5, "array_col": [5, 6]}
     )
 
-    assert get_from_table(test_table, "string_col", "str") == "a test string"
-    assert get_from_table(test_table, "int_col", "int") == 4
-    assert get_from_table(test_table, "float_col", "float") == 4.5
-    assert_equal(get_from_table(test_table, "array_col", "array"), [5, 6])
+    assert get_from_table(test_table, "string_col", str) == "a test string"
+    assert get_from_table(test_table, "int_col", int) == 4
+    assert get_from_table(test_table, "float_col", float) == 4.5
+    assert_equal(get_from_table(test_table, "array_col", np.ndarray), [5, 6])
 
     with pytest.raises(ValueError) as error_info_1:
-        get_from_table(test_table, "string_col", "int")
+        get_from_table(test_table, "string_col", int)
 
     assert error_info_1.value.args[0] == "Could not cast column name to type."
 
@@ -52,5 +54,18 @@ def test_get_from_table():
 
     assert (
         error_info_2.value.args[0]
-        == "Type for argument data_type not recognised: must be one of 'str', 'float', 'int', 'array'."
+        == "Type for argument data_type not recognised for column string_col in table default: must be str, float, int or np.ndarray."
     )
+
+
+def test_check_value_populated():
+    populated_value = check_value_populated(3, int, "column", "table")
+    assert populated_value == 3
+
+    array_length_zero = check_value_populated(np.array([]), np.ndarray, "column", "table")
+    number_is_nan = check_value_populated(np.nan, float, "column", "table")
+    str_is_empty = check_value_populated("", str, "column", "table")
+
+    assert np.isnan(array_length_zero)
+    assert np.isnan(number_is_nan)
+    assert np.isnan(str_is_empty)
diff --git a/tests/adler/utilities/test_AdlerCLIArguments.py b/tests/adler/utilities/test_AdlerCLIArguments.py
new file mode 100644
index 0000000..5dff38c
--- /dev/null
+++ b/tests/adler/utilities/test_AdlerCLIArguments.py
@@ -0,0 +1,104 @@
+import os
+import pytest
+from adler.utilities.AdlerCLIArguments import AdlerCLIArguments
+
+
+# AdlerCLIArguments object takes an object as input, so we define a quick one here
+class args:
+    def __init__(self, ssObjectId, filter_list, date_range, outpath, db_name):
+        self.ssObjectId = ssObjectId
+        self.filter_list = filter_list
+        self.date_range = date_range
+        self.outpath = outpath
+        self.db_name = db_name
+
+
+def test_AdlerCLIArguments_population():
+    # test correct population
+    good_input_dict = {
+        "ssObjectId": "666",
+        "filter_list": ["g", "r", "i"],
+        "date_range": [60000.0, 67300.0],
+        "outpath": "./",
+        "db_name": "output",
+    }
+    good_arguments = args(**good_input_dict)
+    good_arguments_object = AdlerCLIArguments(good_arguments)
+
+    good_input_dict["outpath"] = os.path.abspath("./")
+
+    assert good_arguments_object.__dict__ == good_input_dict
+
+
+def test_AdlerCLIArguments_badSSOID():
+    # test that a bad ssObjectId triggers the right error
+    bad_ssoid_arguments = args("hello!", ["g", "r", "i"], [60000.0, 67300.0], "./", "output")
+
+    with pytest.raises(ValueError) as bad_ssoid_error:
+        bad_ssoid_object = AdlerCLIArguments(bad_ssoid_arguments)
+
+    assert (
+        bad_ssoid_error.value.args[0]
+        == "--ssObjectId command-line argument does not appear to be a valid ssObjectId."
+    )
+
+
+def test_AdlerCLIArguments_badfilters():
+    # test that non-LSST or unexpected filters trigger the right error
+    bad_filter_arguments = args("666", ["g", "r", "i", "m"], [60000.0, 67300.0], "./", "output")
+
+    with pytest.raises(ValueError) as bad_filter_error:
+        bad_filter_object = AdlerCLIArguments(bad_filter_arguments)
+
+    assert (
+        bad_filter_error.value.args[0]
+        == "Unexpected filters found in --filter_list command-line argument. --filter_list must be a list of LSST filters."
+    )
+
+    bad_filter_arguments_2 = args("666", ["pony"], [60000.0, 67300.0], "./", "output")
+
+    with pytest.raises(ValueError) as bad_filter_error_2:
+        bad_filter_object = AdlerCLIArguments(bad_filter_arguments_2)
+
+    assert (
+        bad_filter_error_2.value.args[0]
+        == "Unexpected filters found in --filter_list command-line argument. --filter_list must be a list of LSST filters."
+    )
+
+
+def test_AdlerCLIArguments_baddates():
+    # test that overly-large dates trigger the right error
+    big_date_arguments = args("666", ["g", "r", "i"], [260000.0, 267300.0], "./", "output")
+
+    with pytest.raises(ValueError) as big_date_error:
+        big_date_object = AdlerCLIArguments(big_date_arguments)
+
+    assert (
+        big_date_error.value.args[0]
+        == "Dates for --date_range command-line argument seem rather large. Did you input JD instead of MJD?"
+    )
+
+    # test that unexpected date values trigger the right error
+    bad_date_arguments = args("666", ["g", "r", "i"], [60000.0, "cheese"], "./", "output")
+
+    with pytest.raises(ValueError) as bad_date_error:
+        bad_date_object = AdlerCLIArguments(bad_date_arguments)
+
+    assert (
+        bad_date_error.value.args[0]
+        == "One or both of the values for the --date_range command-line argument do not seem to be valid numbers."
+    )
+
+
+def test_AdlerCLIArguments_badoutput():
+    bad_output_arguments = args(
+        "666", ["g", "r", "i"], [60000.0, 67300.0], "./definitely_fake_folder/", "output"
+    )
+
+    with pytest.raises(ValueError) as bad_output_error:
+        bad_output_object = AdlerCLIArguments(bad_output_arguments)
+
+    assert (
+        bad_output_error.value.args[0]
+        == "The output path for the command-line argument --outpath cannot be found."
+    )
diff --git a/tests/adler/utilities/test_adler_logging.py b/tests/adler/utilities/test_adler_logging.py
new file mode 100644
index 0000000..eefd91a
--- /dev/null
+++ b/tests/adler/utilities/test_adler_logging.py
@@ -0,0 +1,40 @@
+import glob
+import os
+import pytest
+import tempfile
+
+
+def test_setup_adler_logging():
+    from adler.utilities.adler_logging import setup_adler_logging
+
+    with tempfile.TemporaryDirectory() as dir_name:
+        logger = setup_adler_logging(dir_name)
+
+        # Check that the files get created.
+        errlog = glob.glob(os.path.join(dir_name, "*-adler.err"))
+        datalog = glob.glob(os.path.join(dir_name, "*-adler.log"))
+
+        assert os.path.exists(errlog[0])
+        assert os.path.exists(datalog[0])
+
+        # Log some information.
+        logger.info("Test1")
+        logger.info("Test2")
+        logger.error("Error1")
+        logger.info("Test3")
+
+        # Check that all four messages exist in the INFO file.
+        with open(datalog[0], "r") as f_info:
+            log_data = f_info.read()
+            assert "Test1" in log_data
+            assert "Test2" in log_data
+            assert "Error1" in log_data
+            assert "Test3" in log_data
+
+        # Check that only the error message exists in the ERROR file.
+        with open(errlog[0], "r") as f_err:
+            log_data = f_err.read()
+            assert "Test1" not in log_data
+            assert "Test2" not in log_data
+            assert "Error1" in log_data
+            assert "Test3" not in log_data
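End-to-end, the new plumbing validates the command-line arguments first and only then configures logging in the chosen output directory. A minimal sketch of driving it programmatically, assuming the adler package from this patch is installed and reusing the toy argument values from the tests above (the argparse.Namespace stands in for the object parse_args() returns):

    import argparse

    from adler.utilities.AdlerCLIArguments import AdlerCLIArguments
    from adler.utilities.adler_logging import setup_adler_logging

    # the same fields main() wires up via argparse; "666" is the tests' toy ssObjectId
    args = argparse.Namespace(
        ssObjectId="666",
        filter_list=["g", "r", "i"],
        date_range=[60000.0, 67300.0],
        outpath="./",
        db_name="adler_out",
    )

    cli_args = AdlerCLIArguments(args)  # raises ValueError on bad input; absolutises outpath
    cli_args.logger = setup_adler_logging(cli_args.outpath)  # timestamped .log/.err files in outpath
    cli_args.logger.info("Arguments validated; ready to call runAdler(cli_args).")

Validating before any logging is configured does mean argument errors surface only as exceptions, not in the log files; that matches the order main() uses in this patch.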