Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Adding output CL arguments, setting up logging. #118

Merged
merged 7 commits into from
May 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 29 additions & 1 deletion src/adler/adler.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,26 @@
import logging
import argparse
import astropy.units as u

from adler.dataclasses.AdlerPlanetoid import AdlerPlanetoid
from adler.science.PhaseCurve import PhaseCurve
from adler.utilities.AdlerCLIArguments import AdlerCLIArguments
from adler.utilities.adler_logging import setup_adler_logging

logger = logging.getLogger(__name__)


def runAdler(cli_args):
logger.info("Beginning Adler.")
logger.info("Ingesting all data for object {} from RSP...".format(cli_args.ssObjectId))

planetoid = AdlerPlanetoid.construct_from_RSP(
cli_args.ssObjectId, cli_args.filter_list, cli_args.date_range
)

logger.info("Data successfully ingested.")
logger.info("Calculating phase curves...")

# now let's do some phase curves!

# get the r filter SSObject metadata
Expand All @@ -37,7 +47,7 @@ def runAdler(cli_args):


def main():
parser = argparse.ArgumentParser(description="Runs Adler for a select planetoid and given user input.")
parser = argparse.ArgumentParser(description="Runs Adler for select planetoid(s) and given user input.")

parser.add_argument("-s", "--ssObjectId", help="SSObject ID of planetoid.", type=str, required=True)
parser.add_argument(
Expand All @@ -56,11 +66,29 @@ def main():
type=float,
default=[60000.0, 67300.0],
)
parser.add_argument(
"-o",
"--outpath",
help="Output path location. Default is current working directory.",
type=str,
default="./",
)
parser.add_argument(
"-n",
"--db_name",
help="Stem filename of output database. If this doesn't exist, it will be created. Default: adler_out.",
type=str,
default="adler_out",
)

args = parser.parse_args()

cli_args = AdlerCLIArguments(args)

adler_logger = setup_adler_logging(cli_args.outpath)

cli_args.logger = adler_logger

runAdler(cli_args)


Expand Down
12 changes: 12 additions & 0 deletions src/adler/dataclasses/AdlerData.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import os
import sqlite3
import logging
import numpy as np
from dataclasses import dataclass, field
from datetime import datetime, timezone
Expand All @@ -15,6 +16,8 @@
"phase_parameter_2_err",
]

logger = logging.getLogger(__name__)


@dataclass
class AdlerData:
Expand Down Expand Up @@ -70,10 +73,12 @@ def populate_phase_parameters(self, filter_name, **kwargs):
try:
filter_index = self.filter_list.index(filter_name)
except ValueError:
logger.error("ValueError: Filter {} does not exist in AdlerData.filter_list.".format(filter_name))
raise ValueError("Filter {} does not exist in AdlerData.filter_list.".format(filter_name))

# if model-dependent parameters exist without a model name, return an error
if not kwargs.get("model_name") and any(name in kwargs for name in MODEL_DEPENDENT_KEYS):
logger.error("NameError: No model name given. Cannot update model-specific phase parameters.")
raise NameError("No model name given. Cannot update model-specific phase parameters.")

# update the value if it's in **kwargs
Expand Down Expand Up @@ -163,6 +168,7 @@ def get_phase_parameters_in_filter(self, filter_name, model_name=None):
try:
filter_index = self.filter_list.index(filter_name)
except ValueError:
logger.error("ValueError: Filter {} does not exist in AdlerData.filter_list.".format(filter_name))
raise ValueError("Filter {} does not exist in AdlerData.filter_list.".format(filter_name))

output_obj = PhaseParameterOutput()
Expand All @@ -173,11 +179,17 @@ def get_phase_parameters_in_filter(self, filter_name, model_name=None):
output_obj.arc = self.filter_dependent_values[filter_index].arc

if not model_name:
logger.warn("No model name was specified. Returning non-model-dependent phase parameters.")
print("No model name specified. Returning non-model-dependent phase parameters.")
else:
try:
model_index = self.filter_dependent_values[filter_index].model_list.index(model_name)
except ValueError:
logger.error(
"ValueError: Model {} does not exist for filter {} in AdlerData.model_lists.".format(
model_name, filter_name
)
)
raise ValueError(
"Model {} does not exist for filter {} in AdlerData.model_lists.".format(
model_name, filter_name
Expand Down
38 changes: 34 additions & 4 deletions src/adler/dataclasses/AdlerPlanetoid.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
from lsst.rsp import get_tap_service
import pandas as pd
import logging

from adler.dataclasses.Observations import Observations
from adler.dataclasses.MPCORB import MPCORB
from adler.dataclasses.SSObject import SSObject
from adler.dataclasses.AdlerData import AdlerData
from adler.dataclasses.dataclass_utilities import get_data_table

logger = logging.getLogger(__name__)


class AdlerPlanetoid:
"""AdlerPlanetoid class. Contains the Observations, MPCORB and SSObject dataclass objects."""
Expand Down Expand Up @@ -80,19 +83,27 @@ def construct_from_SQL(
"""

if len(date_range) != 2:
logger.error("ValueError: date_range attribute must be of length 2.")
raise ValueError("date_range attribute must be of length 2.")

observations_by_filter = cls.populate_observations(
cls, ssObjectId, filter_list, date_range, sql_filename=sql_filename, schema=schema
)

if len(observations_by_filter) == 0:
logger.error(
"No observations found for this object in the given filter(s). Check SSOID and try again."
)
raise Exception(
"No observations found for this object in the given filter(s). Check SSOID and try again."
)

# redo the filter list based on the available filters in observations_by_filter
filter_list = [obs_object.filter_name for obs_object in observations_by_filter]
if len(filter_list) > len(observations_by_filter):
logger.info(
"Not all specified filters have observations. Recalculating filter list based on past observations."
)
filter_list = [obs_object.filter_name for obs_object in observations_by_filter]
logger.info("New filter list is: {}".format(filter_list))

mpcorb = cls.populate_MPCORB(cls, ssObjectId, sql_filename=sql_filename, schema=schema)
ssobject = cls.populate_SSObject(
Expand Down Expand Up @@ -127,19 +138,29 @@ def construct_from_RSP(
raise Exception("date_range argument must be of length 2.")

service = get_tap_service("ssotap")
logger.info("Getting past observations from DIASource/SSSource...")
observations_by_filter = cls.populate_observations(
cls, ssObjectId, filter_list, date_range, service=service
)

if len(observations_by_filter) == 0:
logger.error(
"No observations found for this object in the given filter(s). Check SSOID and try again."
)
raise Exception(
"No observations found for this object in the given filter(s). Check SSOID and try again."
)

# redo the filter list based on the available filters in observations_by_filter
filter_list = [obs_object.filter_name for obs_object in observations_by_filter]
if len(filter_list) > len(observations_by_filter):
logger.info(
"Not all specified filters have observations. Recalculating filter list based on past observations."
)
filter_list = [obs_object.filter_name for obs_object in observations_by_filter]
logger.info("New filter list is: {}".format(filter_list))

logger.info("Populating MPCORB metadata...")
mpcorb = cls.populate_MPCORB(cls, ssObjectId, service=service)
logger.info("Populating SSObject metadata...")
ssobject = cls.populate_SSObject(cls, ssObjectId, filter_list, service=service)

adler_data = AdlerData(ssObjectId, filter_list)
Expand Down Expand Up @@ -203,6 +224,11 @@ def populate_observations(
data_table = get_data_table(observations_sql_query, service=service, sql_filename=sql_filename)

if len(data_table) == 0:
logger.warning(
"No observations found in {} filter for this object. Skipping this filter.".format(
filter_name
)
)
print(
"WARNING: No observations found in {} filter for this object. Skipping this filter.".format(
filter_name
Expand Down Expand Up @@ -253,6 +279,7 @@ def populate_MPCORB(self, ssObjectId, service=None, sql_filename=None, schema="d
data_table = get_data_table(MPCORB_sql_query, service=service, sql_filename=sql_filename)

if len(data_table) == 0:
logger.error("No MPCORB data for this object could be found for this SSObjectId.")
raise Exception("No MPCORB data for this object could be found for this SSObjectId.")

return MPCORB.construct_from_data_table(ssObjectId, data_table)
Expand Down Expand Up @@ -310,6 +337,7 @@ def populate_SSObject(
data_table = get_data_table(SSObject_sql_query, service=service, sql_filename=sql_filename)

if len(data_table) == 0:
logger.error("No SSObject data for this object could be found for this SSObjectId.")
raise Exception("No SSObject data for this object could be found for this SSObjectId.")

return SSObject.construct_from_data_table(ssObjectId, filter_list, data_table)
Expand All @@ -332,6 +360,7 @@ def observations_in_filter(self, filter_name):
try:
filter_index = self.filter_list.index(filter_name)
except ValueError:
logger.error("ValueError: Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))
raise ValueError("Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))

return self.observations_by_filter[filter_index]
Expand All @@ -354,6 +383,7 @@ def SSObject_in_filter(self, filter_name):
try:
filter_index = self.filter_list.index(filter_name)
except ValueError:
logger.error("ValueError: Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))
raise ValueError("Filter {} is not in AdlerPlanetoid.filter_list.".format(filter_name))

return self.SSObject.filter_dependent_values[filter_index]
12 changes: 12 additions & 0 deletions src/adler/dataclasses/dataclass_utilities.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@
import pandas as pd
import sqlite3
import warnings
import logging

logger = logging.getLogger(__name__)


def get_data_table(sql_query, service=None, sql_filename=None):
Expand Down Expand Up @@ -85,12 +88,18 @@ def get_from_table(data_table, column_name, data_type, table_name="default"):
elif data_type == np.ndarray:
data_val = np.array(data_table[column_name])
else:
logger.error(
"TypeError: Type for argument data_type not recognised for column {} in table {}: must be str, float, int or np.ndarray.".format(
column_name, table_name
)
)
raise TypeError(
"Type for argument data_type not recognised for column {} in table {}: must be str, float, int or np.ndarray.".format(
column_name, table_name
)
)
except ValueError:
logger.error("ValueError: Could not cast column name to type.")
raise ValueError("Could not cast column name to type.")

# here we alert the user if one of the values is unpopulated and change it to a NaN
Expand Down Expand Up @@ -129,6 +138,9 @@ def check_value_populated(data_val, data_type, column_name, table_name):
str_is_empty = data_type == str and len(data_val) == 0

if array_length_zero or number_is_nan or str_is_empty:
logger.warning(
"{} unpopulated in {} table for this object. Storing NaN instead.".format(column_name, table_name)
)
print(
"WARNING: {} unpopulated in {} table for this object. Storing NaN instead.".format(
column_name, table_name
Expand Down
21 changes: 17 additions & 4 deletions src/adler/utilities/AdlerCLIArguments.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
import os


class AdlerCLIArguments:
"""
Class for storing and validating Adler command-line arguments.
Expand All @@ -13,38 +16,48 @@ def __init__(self, args):
self.ssObjectId = args.ssObjectId
self.filter_list = args.filter_list
self.date_range = args.date_range
self.outpath = args.outpath
self.db_name = args.db_name

self.validate_arguments()

def validate_arguments(self):
    """Run every validator on the stored command-line arguments.

    Each _validate_* helper raises ValueError on invalid input.
    """
    self._validate_filter_list()
    self._validate_ssObjectId()
    self._validate_date_range()
    self._validate_outpath()

def _validate_filter_list(self):
expected_filters = ["u", "g", "r", "i", "z", "y"]

if not set(self.filter_list).issubset(expected_filters):
raise ValueError(
"Unexpected filters found in filter_list command-line argument. filter_list must be a list of LSST filters."
"Unexpected filters found in --filter_list command-line argument. --filter_list must be a list of LSST filters."
)

def _validate_ssObjectId(self):
try:
int(self.ssObjectId)
except ValueError:
raise ValueError("ssObjectId command-line argument does not appear to be a valid ssObjectId.")
raise ValueError("--ssObjectId command-line argument does not appear to be a valid ssObjectId.")

def _validate_date_range(self):
for d in self.date_range:
try:
float(d)
except ValueError:
raise ValueError(
"One or both of the values for the date_range command-line argument do not seem to be valid numbers."
"One or both of the values for the --date_range command-line argument do not seem to be valid numbers."
)

if any(d > 250000 for d in self.date_range):
raise ValueError(
"Dates for date_range command-line argument seem rather large. Did you input JD instead of MJD?"
"Dates for --date_range command-line argument seem rather large. Did you input JD instead of MJD?"
)

def _validate_outpath(self):
    """Normalise --outpath to an absolute path and check the directory exists.

    Raises
    ------
    ValueError
        If the resolved output path is not an existing directory.
    """
    # make it an absolute path if it's relative!
    self.outpath = os.path.abspath(self.outpath)

    if not os.path.isdir(self.outpath):
        raise ValueError("The output path for the command-line argument --outpath cannot be found.")
42 changes: 42 additions & 0 deletions src/adler/utilities/adler_logging.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import logging
import os
from datetime import datetime


def setup_adler_logging(
    log_location,
    log_format="%(asctime)s %(name)-12s %(levelname)-8s %(message)s ",
    log_name="",
    log_file_info="adler.log",
    log_file_error="adler.err",
):
    """Configure file-based logging for Adler.

    Two log files are created in ``log_location``, each prefixed with a
    timestamp and the process ID so successive or concurrent runs do not
    overwrite each other: one capturing INFO and above, and one capturing
    only WARNING and above for quick triage.

    Parameters
    ----------
    log_location : str
        Directory in which the log files will be created. Must exist.
    log_format : str, optional
        Format string passed to ``logging.Formatter``.
    log_name : str, optional
        Name of the logger to configure. The default ``""`` configures the
        root logger, so module-level loggers propagate into these files.
    log_file_info : str, optional
        Stem filename of the INFO-and-above log. Default: ``"adler.log"``.
    log_file_error : str, optional
        Stem filename of the WARNING-and-above log. Default: ``"adler.err"``.

    Returns
    -------
    logging.Logger
        The configured logger.
    """
    log = logging.getLogger(log_name)
    log_formatter = logging.Formatter(log_format)

    # Uncomment to also echo log records to the console.
    # stream_handler = logging.StreamHandler()
    # stream_handler.setFormatter(log_formatter)
    # log.addHandler(stream_handler)

    # Timestamp + PID prefix keeps log files from separate runs distinct.
    dstr = datetime.now().strftime("%Y-%m-%d-%H-%M-%S")
    cpid = os.getpid()

    log_file_info = os.path.join(log_location, dstr + "-p" + str(cpid) + "-" + log_file_info)
    log_file_error = os.path.join(log_location, dstr + "-p" + str(cpid) + "-" + log_file_error)

    # Full log: basic info, but also warnings and errors.
    file_handler_info = logging.FileHandler(log_file_info, mode="w")
    file_handler_info.setFormatter(log_formatter)
    file_handler_info.setLevel(logging.INFO)
    log.addHandler(file_handler_info)

    # Error log: only warnings and errors, so they can be looked at quickly
    # without a lot of scrolling. (WARNING is the canonical level name;
    # logging.WARN is a deprecated alias.)
    file_handler_error = logging.FileHandler(log_file_error, mode="w")
    file_handler_error.setFormatter(log_formatter)
    file_handler_error.setLevel(logging.WARNING)
    log.addHandler(file_handler_error)

    # Loggers default to WARNING; lower the logger's own effective level so
    # that INFO records are not filtered out before reaching the handlers.
    log.setLevel(logging.INFO)

    return log
Loading
Loading