From 39a4ab8cebf68e765e826a8e6db2eeb3490b7693 Mon Sep 17 00:00:00 2001 From: Steph Merritt <97111051+astronomerritt@users.noreply.github.com> Date: Fri, 20 Sep 2024 15:03:21 +0100 Subject: [PATCH] Adding functionality to grab lightcurves from Cassandra. (#167) * Adding functionality to grab lightcurves from Cassandra. * Changing class name for PEP8 compliance * [pre-commit.ci lite] apply automatic fixes * Moving import statement * Moving import statement * [pre-commit.ci lite] apply automatic fixes * Adding docstrings. --------- Co-authored-by: pre-commit-ci-lite[bot] <117423508+pre-commit-ci-lite[bot]@users.noreply.github.com> --- src/adler/__init__.py | 1 + src/adler/dataclasses/AdlerPlanetoid.py | 85 ++++++++-- src/adler/dataclasses/MPCORB.py | 20 ++- src/adler/dataclasses/Observations.py | 28 +++- src/adler/dataclasses/SSObject.py | 52 +++++- src/adler/dataclasses/dataclass_utilities.py | 28 +++- src/adler/lasair/cassandra_fetcher.py | 160 +++++++++++++++++++ src/adler/utilities/AdlerCLIArguments.py | 18 +++ 8 files changed, 367 insertions(+), 25 deletions(-) create mode 100644 src/adler/lasair/cassandra_fetcher.py diff --git a/src/adler/__init__.py b/src/adler/__init__.py index 443fa66..6d46c05 100644 --- a/src/adler/__init__.py +++ b/src/adler/__init__.py @@ -1,3 +1,4 @@ from . import dataclasses from . import science from . import utilities +from . 
import lasair diff --git a/src/adler/dataclasses/AdlerPlanetoid.py b/src/adler/dataclasses/AdlerPlanetoid.py index 56d4679..a350b8e 100644 --- a/src/adler/dataclasses/AdlerPlanetoid.py +++ b/src/adler/dataclasses/AdlerPlanetoid.py @@ -117,30 +117,87 @@ def construct_from_SQL( return cls(ssObjectId, filter_list, date_range, observations_by_filter, mpcorb, ssobject, adler_data) @classmethod - def construct_from_JSON(cls, json_filename): - with open(json_filename) as f: - json_dict = json.load(f) + def construct_from_cassandra( + cls, + ssObjectId, + filter_list=["u", "g", "r", "i", "z", "y"], + date_range=[60000.0, 67300.0], + cassandra_hosts=["10.21.3.123"], + ): # pragma: no cover + """Custom constructor which builds the AdlerPlanetoid object and the associated Observations, MPCORB and SSObject objects from + a Cassandra database. Used only for Lasair integration. + + TODO: move method to its own class which inherits from AdlerPlanetoid and move to adler-lasair repo? + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. - observations_dict = {**json_dict["SSSource"], **json_dict["DiaSource"]} + filter_list : list of str + A comma-separated list of the filters of interest. + + date_range : list of float + The minimum and maximum dates of the desired observations. + + cassandra_hosts : list of str + Location of the Cassandra database - usually an IP address. Default is ["10.21.3.123"]. - filter_list = [observations_dict["band"]] + """ + # do not move this import! CassandraFetcher requires the non-mandatory + # cassandra-driver library - if not installed, and this import is at the top, + # test collection will break. 
+ from adler.lasair.cassandra_fetcher import CassandraFetcher + + fetcher = CassandraFetcher(cassandra_hosts=cassandra_hosts) + + MPCORB_dict = fetcher.fetch_MPCORB(ssObjectId) + SSObject_dict = fetcher.fetch_SSObject(ssObjectId, filter_list) + observations_dict = fetcher.fetch_observations(ssObjectId) - MPCORB_dict = json_dict["MPCORB"] - SSObject_dict = json_dict["SSObject"] + # note that Cassandra doesn't allow filters/joins + # instead we pull all observations for this ID, then filter with Pandas later + observations_table = pd.DataFrame(observations_dict) + observations_table.rename(columns={"decl": "dec"}, inplace=True) - ssObjectId = observations_dict["ssObjectId"] + observations_by_filter = [] + for filter_name in filter_list: + obs_slice = observations_table[ + (observations_table["band"] == filter_name) + & (observations_table["midpointmjdtai"].between(date_range[0], date_range[1])) + ] + + if len(obs_slice) == 0: + logger.warning( + "No observations found in {} filter for this object. Skipping this filter.".format( + filter_name + ) + ) + else: + observations = Observations.construct_from_data_table(ssObjectId, filter_name, obs_slice) + observations_by_filter.append(observations) + + if len(observations_by_filter) == 0: + logger.error( + "No observations found for this object in the given filter(s). Check SSOID and try again." + ) + raise Exception( + "No observations found for this object in the given filter(s). Check SSOID and try again." + ) + + if len(filter_list) > len(observations_by_filter): + logger.info( + "Not all specified filters have observations. Recalculating filter list based on past observations." 
+ ) + filter_list = [obs_object.filter_name for obs_object in observations_by_filter] + logger.info("New filter list is: {}".format(filter_list)) - observations_by_filter = [ - Observations.construct_from_dictionary(ssObjectId, filter_list[0], observations_dict) - ] mpcorb = MPCORB.construct_from_dictionary(ssObjectId, MPCORB_dict) ssobject = SSObject.construct_from_dictionary(ssObjectId, filter_list, SSObject_dict) adler_data = AdlerData(ssObjectId, filter_list) - return cls( - ssObjectId, filter_list, [np.nan, np.nan], observations_by_filter, mpcorb, ssobject, adler_data - ) + return cls(ssObjectId, filter_list, date_range, observations_by_filter, mpcorb, ssobject, adler_data) @classmethod def construct_from_RSP( diff --git a/src/adler/dataclasses/MPCORB.py b/src/adler/dataclasses/MPCORB.py index 844cef9..237df1d 100644 --- a/src/adler/dataclasses/MPCORB.py +++ b/src/adler/dataclasses/MPCORB.py @@ -122,9 +122,27 @@ def construct_from_data_table(cls, ssObjectId, data_table): @classmethod def construct_from_dictionary(cls, ssObjectId, data_dict): + """Initialises the MPCORB object from a dictionary of data. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. + + data_dict : dict or dict-like object + Dictionary of data from which attributes should be populated. + + Returns + ----------- + MPCORB object + MPCORB object with class attributes populated from data_dict.
+ + """ mpcorb_dict = {"ssObjectId": ssObjectId} for mpcorb_key, mpcorb_type in MPCORB_KEYS.items(): - mpcorb_dict[mpcorb_key] = get_from_dictionary(data_dict, mpcorb_key, mpcorb_type, "MPCORB") + mpcorb_dict[mpcorb_key] = get_from_dictionary( + data_dict, mpcorb_key.casefold(), mpcorb_type, "MPCORB" + ) return cls(**mpcorb_dict) diff --git a/src/adler/dataclasses/Observations.py b/src/adler/dataclasses/Observations.py index eb92af2..9403330 100644 --- a/src/adler/dataclasses/Observations.py +++ b/src/adler/dataclasses/Observations.py @@ -115,7 +115,7 @@ class Observations: num_obs: int = 0 @classmethod - def construct_from_data_table(cls, ssObjectId, filter_name, data_table): + def construct_from_data_table(cls, ssObjectId, filter_name, data_table, cassandra=False): """Initialises the Observations object from a table of data. Parameters @@ -139,7 +139,12 @@ def construct_from_data_table(cls, ssObjectId, filter_name, data_table): obs_dict = {"ssObjectId": ssObjectId, "filter_name": filter_name, "num_obs": len(data_table)} for obs_key, obs_type in OBSERVATIONS_KEYS.items(): - obs_dict[obs_key] = get_from_table(data_table, obs_key, obs_type, "SSSource/DIASource") + try: + obs_dict[obs_key] = get_from_table(data_table, obs_key, obs_type, "SSSource/DIASource") + except KeyError: # sometimes we have case issues... + obs_dict[obs_key] = get_from_table( + data_table, obs_key.casefold(), obs_type, "SSSource/DIASource" + ) obs_dict["reduced_mag"] = cls.calculate_reduced_mag( cls, obs_dict["mag"], obs_dict["topocentricDist"], obs_dict["heliocentricDist"] @@ -149,6 +154,25 @@ def construct_from_data_table(cls, ssObjectId, filter_name, data_table): @classmethod def construct_from_dictionary(cls, ssObjectId, filter_name, data_dict): + """Initialises the Observations object from a dictionary of data. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. 
+ + filter_name : str + String of the filter the observations are taken in. + + data_dict : dict or dict-like object + Dictionary of data from which attributes should be populated. + + Returns + ----------- + Observations object + Observations object with class attributes populated from data_dict. + + """ obs_dict = {"ssObjectId": ssObjectId, "filter_name": filter_name, "num_obs": 1} for obs_key, obs_type in OBSERVATIONS_KEYS.items(): diff --git a/src/adler/dataclasses/SSObject.py b/src/adler/dataclasses/SSObject.py index eedc724..c0ba468 100644 --- a/src/adler/dataclasses/SSObject.py +++ b/src/adler/dataclasses/SSObject.py @@ -67,6 +67,25 @@ class SSObject: @classmethod def construct_from_data_table(cls, ssObjectId, filter_list, data_table): + """Initialises the SSObject object from a table of data. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. + + filter_list : list of str + A comma-separated list of the filters of interest. + + data_table : table-like object + Table of data from which attributes should be populated. + + Returns + ----------- + SSObject object + SSObject object with class attributes populated from data_table. + + """ sso_dict = {"ssObjectId": ssObjectId, "filter_list": filter_list, "filter_dependent_values": []} for sso_key, sso_type in SSO_KEYS.items(): @@ -88,19 +107,40 @@ def construct_from_data_table(cls, ssObjectId, filter_list, data_table): @classmethod def construct_from_dictionary(cls, ssObjectId, filter_list, data_dict): + """Initialises the SSObject object from a dictionary of data. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. + + filter_list : list of str + A comma-separated list of the filters of interest. + + data_dict : dict or dict-like object + Dictionary of data from which attributes should be populated. + + Returns + ----------- + SSObject object + SSObject object with class attributes populated from data_dict.
+ + """ sso_dict = {"ssObjectId": ssObjectId, "filter_list": filter_list, "filter_dependent_values": []} for sso_key, sso_type in SSO_KEYS.items(): - sso_dict[sso_key] = get_from_dictionary(data_dict, sso_key, sso_type, "SSObject") + sso_dict[sso_key] = get_from_dictionary(data_dict, sso_key.casefold(), sso_type, "SSObject") for i, filter_name in enumerate(filter_list): filter_dept_object = FilterDependentSSO( filter_name=filter_name, - H=get_from_dictionary(data_dict, filter_name + "_H", float, "SSObject"), - G12=get_from_dictionary(data_dict, filter_name + "_G12", float, "SSObject"), - Herr=get_from_dictionary(data_dict, filter_name + "_HErr", float, "SSObject"), - G12err=get_from_dictionary(data_dict, filter_name + "_G12Err", float, "SSObject"), - nData=get_from_dictionary(data_dict, filter_name + "_Ndata", float, "SSObject"), + H=get_from_dictionary(data_dict, (filter_name + "_H").casefold(), float, "SSObject"), + G12=get_from_dictionary(data_dict, (filter_name + "_G12").casefold(), float, "SSObject"), + Herr=get_from_dictionary(data_dict, (filter_name + "_HErr").casefold(), float, "SSObject"), + G12err=get_from_dictionary( + data_dict, (filter_name + "_G12Err").casefold(), float, "SSObject" + ), + nData=get_from_dictionary(data_dict, (filter_name + "_Ndata").casefold(), float, "SSObject"), ) sso_dict["filter_dependent_values"].append(filter_dept_object) diff --git a/src/adler/dataclasses/dataclass_utilities.py b/src/adler/dataclasses/dataclass_utilities.py index d82c031..d6dd61e 100644 --- a/src/adler/dataclasses/dataclass_utilities.py +++ b/src/adler/dataclasses/dataclass_utilities.py @@ -109,6 +109,29 @@ def get_from_table(data_table, column_name, data_type, table_name="default"): def get_from_dictionary(data_dict, key_name, data_type, table_name="default"): + """Retrieves information from a dictionary and forces it to be a specified type. + + Parameters + ----------- + data_dict : dict or dict-like object + Dictionary containing columns of interest. 
+ + key_name : str + Key name under which the data of interest is stored. + + data_type : type + Data type. Should be int, float, str or np.ndarray. + + table_name : str + Name of the table or dictionary. This is mostly for more informative error messages. Default="default". + + Returns + ----------- + data_val : str, float, int or np.ndarray + The data requested from the dictionary cast to the type required. + + """ + try: if data_type == str: data_val = str(data_dict[key_name]) @@ -124,14 +147,15 @@ def get_from_dictionary(data_dict, key_name, data_type, table_name="default"): except ValueError: print("error message") - data_val = check_value_populated(data_val, data_type, key_name, "JSON") + data_val = check_value_populated(data_val, data_type, key_name, "dictionary") return data_val def check_value_populated(data_val, data_type, column_name, table_name): """Checks to see if data_val populated properly and prints a helpful warning if it didn't. - Usually this will trigger because the RSP hasn't populated that field for this particular object. + Usually this will trigger because the RSP or Cassandra database hasn't populated that + field for this particular object. Parameters ----------- diff --git a/src/adler/lasair/cassandra_fetcher.py b/src/adler/lasair/cassandra_fetcher.py new file mode 100644 index 0000000..ba6e157 --- /dev/null +++ b/src/adler/lasair/cassandra_fetcher.py @@ -0,0 +1,160 @@ +import json +import sys +from cassandra.cluster import Cluster, ConsistencyLevel +from cassandra.query import dict_factory, SimpleStatement + + +class CassandraFetcher: # pragma: no cover + """Class to fetch data from a Cassandra database, used for Lasair integration. + + TODO: move to the lasair-adler repo. + + Attributes + ----------- + cassandra_hosts : list of str + Location of the Cassandra database - usually an IP address. Default is ["10.21.3.123"].
+ + """ + + def __init__(self, cassandra_hosts): + self.cluster = Cluster(cassandra_hosts) + self.session = self.cluster.connect() + # Set the row_factory to dict_factory, otherwise + # the data returned will be in the form of object properties. + self.session.row_factory = dict_factory + self.session.set_keyspace("adler") + + def fetch_SSObject(self, ssObjectId, filter_list): + """Fetches the metadata from the SSObject table of a Cassandra database as a dictionary. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. + + filter_list : list of str + A comma-separated list of the filters of interest. + + Returns + ----------- + dict + A dictionary of metadata for the object of interest in the filters + of interest. + + """ + + filter_dependent_columns = "" + for filter_name in filter_list: + filter_string = "{}_H, {}_G12, {}_HErr, {}_G12Err, {}_Ndata, ".format( + filter_name, filter_name, filter_name, filter_name, filter_name + ) + + filter_dependent_columns += filter_string + + obj = {} + + SSObject_sql_query = f""" + SELECT + discoverySubmissionDate, firstObservationDate, arc, numObs, + {filter_dependent_columns} + maxExtendedness, minExtendedness, medianExtendedness + FROM + ssobjects + WHERE + ssObjectId = {ssObjectId} + """ + + ret = self.session.execute(SSObject_sql_query) + + for ssObject in ret: + obj = ssObject + + return obj + + def fetch_MPCORB(self, ssObjectId): + """Fetches the metadata from the MPCORB table of a Cassandra database as a dictionary. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. + + Returns + ----------- + dict + A dictionary of metadata for the object of interest. 
+ + """ + + obj = {} + + MPCORB_sql_query = f""" + SELECT + ssObjectId, mpcDesignation, fullDesignation, mpcNumber, mpcH, mpcG, epoch, tperi, peri, node, incl, e, n, q, + uncertaintyParameter, flags + FROM + mpcorbs + WHERE + ssObjectId = {ssObjectId} + """ + + ret = self.session.execute(MPCORB_sql_query) + + for MPCORB in ret: + obj = MPCORB + + return obj + + def fetch_observations(self, ssObjectId): + """Fetches the source observations from the DIASource and SSSource tables as a list of dictionaries. + Note that it will retrieve ALL observations for the object regardless of filter and date range, + so any filtering must be performed later. This is due to restrictions on queries to Cassandra. + + Parameters + ----------- + ssObjectId : str + ssObjectId of the object of interest. + + Returns + ----------- + list of dict + One dictionary per diasourceid, holding the DIASource columns updated with the + matching SSSource columns. + + """ + + sourceDict = {} + + dia_query = f""" + SELECT + diasourceid, band, mag, magErr, midPointMjdTai, ra, decl + FROM + diasources + WHERE + ssObjectId = {ssObjectId} + """ + ret = self.session.execute(dia_query) + + n = 0 + for diaSource in ret: + sourceDict[diaSource["diasourceid"]] = diaSource + n += 1 + + ss_query = f"""SELECT diasourceid, phaseAngle, topocentricDist, heliocentricDist, heliocentricX, heliocentricY, heliocentricZ, + topocentricX, topocentricY, topocentricZ, eclipticLambda, eclipticBeta + FROM sssources + WHERE + ssObjectId = {ssObjectId} + """ + ret = self.session.execute(ss_query) + + n = 0 + for ssSource in ret: + n += 1 + sourceDict[ssSource["diasourceid"]].update(ssSource) + + sources = [] + for k, v in sourceDict.items(): + sources.append(v) + + return sources diff --git a/src/adler/utilities/AdlerCLIArguments.py b/src/adler/utilities/AdlerCLIArguments.py index bbcb33c..5f3e476 100644 --- a/src/adler/utilities/AdlerCLIArguments.py +++ b/src/adler/utilities/AdlerCLIArguments.py @@ -27,6 +27,8 @@ def __init__(self, args): self.validate_arguments() def
validate_arguments(self): + """Checks and validates the command-line arguments.""" + self._validate_filter_list() self._validate_date_range() self._validate_outpath() @@ -41,6 +43,7 @@ def validate_arguments(self): self._validate_sql_filename() def _validate_filter_list(self): + """Validation checks for the filter_list command-line argument.""" expected_filters = ["u", "g", "r", "i", "z", "y"] if not set(self.filter_list).issubset(expected_filters): @@ -52,6 +55,9 @@ def _validate_filter_list(self): ) def _validate_ssObjectId(self): + """ + Validation checks for the ssObjectId command-line argument. + """ try: int(self.ssObjectId) except ValueError: @@ -59,6 +65,9 @@ def _validate_ssObjectId(self): raise ValueError("--ssObjectId command-line argument does not appear to be a valid ssObjectId.") def _validate_date_range(self): + """ + Validation checks for the date_range command-line argument. + """ for d in self.date_range: try: float(d) @@ -79,6 +88,9 @@ def _validate_date_range(self): ) def _validate_outpath(self): + """ + Validation checks for the outpath command-line argument. + """ # make it an absolute path if it's relative! self.outpath = os.path.abspath(self.outpath) @@ -87,6 +99,9 @@ def _validate_outpath(self): raise ValueError("The output path for the command-line argument --outpath cannot be found.") def _validate_ssObjectId_list(self): + """ + Validation checks for the ssObjectId_list command-line argument. + """ self.ssObjectId_list = os.path.abspath(self.ssObjectId_list) if not os.path.exists(self.ssObjectId_list): @@ -98,6 +113,9 @@ def _validate_ssObjectId_list(self): ) def _validate_sql_filename(self): + """ + Validation checks for the sql_filename command-line argument. + """ self.sql_filename = os.path.abspath(self.sql_filename) if not os.path.exists(self.sql_filename):