
Commit

modified: src/forcedphot/ephemeris/data_loader.py
	modified:   src/forcedphot/ephemeris/ephemeris_client.py
	modified:   src/forcedphot/ephemeris/horizons_interface.py
	modified:   src/forcedphot/ephemeris/miriade_interface.py
	modified:   tests/forcedphot/ephemeris/test_data_loader.py
	modified:   tests/forcedphot/ephemeris/test_ephemeris_client.py
	modified:   tests/forcedphot/ephemeris/test_horizons_interface.py
	modified:   tests/forcedphot/ephemeris/test_miriade_interface.py
szilac committed Jul 23, 2024
1 parent 8b32d83 commit 849e6cd
Showing 8 changed files with 250 additions and 167 deletions.
52 changes: 32 additions & 20 deletions src/forcedphot/ephemeris/data_loader.py
@@ -54,39 +54,49 @@ def load_ephemeris_from_ecsv(file_path: str) -> EphemerisData:
"""
try:
# Read the ECSV file
table = Table.read(file_path, format='ascii.ecsv')
table = Table.read(file_path, format="ascii.ecsv")

# Check if all required columns are present
required_columns = [
'datetime_jd', 'RA_deg', 'DEC_deg', 'RA_rate_arcsec_per_h',
'DEC_rate_arcsec_per_h', 'AZ_deg', 'EL_deg', 'r_au', 'delta_au',
'V_mag', 'alpha_deg', 'RSS_3sigma_arcsec'
"datetime_jd",
"RA_deg",
"DEC_deg",
"RA_rate_arcsec_per_h",
"DEC_rate_arcsec_per_h",
"AZ_deg",
"EL_deg",
"r_au",
"delta_au",
"V_mag",
"alpha_deg",
"RSS_3sigma_arcsec",
]
missing_columns = [col for col in required_columns if col not in table.colnames]
if missing_columns:
raise ValueError(f"Missing columns in ECSV file: {', '.join(missing_columns)}")

# Create and populate the EphemerisData object
ephemeris_data = EphemerisData(
datetime_jd=Time(table['datetime_jd'], format='jd'),
RA_deg=np.array(table['RA_deg']),
DEC_deg=np.array(table['DEC_deg']),
RA_rate_arcsec_per_h=np.array(table['RA_rate_arcsec_per_h']),
DEC_rate_arcsec_per_h=np.array(table['DEC_rate_arcsec_per_h']),
AZ_deg=np.array(table['AZ_deg']),
EL_deg=np.array(table['EL_deg']),
r_au=np.array(table['r_au']),
delta_au=np.array(table['delta_au']),
V_mag=np.array(table['V_mag']),
alpha_deg=np.array(table['alpha_deg']),
RSS_3sigma_arcsec=np.array(table['RSS_3sigma_arcsec'])
datetime_jd=Time(table["datetime_jd"], format="jd"),
RA_deg=np.array(table["RA_deg"]),
DEC_deg=np.array(table["DEC_deg"]),
RA_rate_arcsec_per_h=np.array(table["RA_rate_arcsec_per_h"]),
DEC_rate_arcsec_per_h=np.array(table["DEC_rate_arcsec_per_h"]),
AZ_deg=np.array(table["AZ_deg"]),
EL_deg=np.array(table["EL_deg"]),
r_au=np.array(table["r_au"]),
delta_au=np.array(table["delta_au"]),
V_mag=np.array(table["V_mag"]),
alpha_deg=np.array(table["alpha_deg"]),
RSS_3sigma_arcsec=np.array(table["RSS_3sigma_arcsec"])
)

DataLoader.logger.info(f"Loaded ephemeris data with {len(ephemeris_data.datetime_jd)} points from {file_path}.")
DataLoader.logger.info(
f"Loaded ephemeris data with {len(ephemeris_data.datetime_jd)} points from {file_path}."
)

return ephemeris_data


except FileNotFoundError:
DataLoader.logger.error(f"The file {file_path} was not found.")
raise
@@ -133,8 +143,10 @@ def load_multiple_ephemeris_files(file_paths: list[str]) -> list[EphemerisData]:
# print(f"Error: {str(e)}")

# Example of loading multiple files
file_paths = ["./Ceres_2024-01-01_00-00-00.000_2025-12-31_23-59-00.000.ecsv",
"./Encke_2024-01-01_00-00-00.000_2024-06-30_23-59-00.000.ecsv"]
file_paths = [
"./Ceres_2024-01-01_00-00-00.000_2025-12-31_23-59-00.000.ecsv",
"./Encke_2024-01-01_00-00-00.000_2024-06-30_23-59-00.000.ecsv"
]
try:
ephemeris_list = DataLoader.load_multiple_ephemeris_files(file_paths)
print(f"Loaded {len(ephemeris_list)} ephemeris files.")
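
For orientation, a minimal round-trip sketch of the loader reformatted above. It is not part of the commit: the import path is the one implied by the src/ layout in the file listing, the table holds placeholder values rather than real ephemerides, and only the column names are taken from the required_columns check in the diff.

import numpy as np
from astropy.table import Table

from forcedphot.ephemeris.data_loader import DataLoader

# Column names copied from the required_columns check in the diff above.
required_columns = [
    "datetime_jd", "RA_deg", "DEC_deg", "RA_rate_arcsec_per_h",
    "DEC_rate_arcsec_per_h", "AZ_deg", "EL_deg", "r_au", "delta_au",
    "V_mag", "alpha_deg", "RSS_3sigma_arcsec",
]

# Two dummy rows: Julian dates plus zeros for every other quantity (placeholders only).
data = {name: np.zeros(2) for name in required_columns}
data["datetime_jd"] = np.array([2460310.5, 2460311.5])
Table(data).write("dummy_ephemeris.ecsv", format="ascii.ecsv", overwrite=True)

ephemeris = DataLoader.load_ephemeris_from_ecsv("dummy_ephemeris.ecsv")
print(len(ephemeris.datetime_jd))  # 2 points loaded

Dropping any of the required columns from the dummy table would trigger the ValueError built in the diff above.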
47 changes: 32 additions & 15 deletions src/forcedphot/ephemeris/ephemeris_client.py
@@ -40,8 +40,17 @@ def __init__(self):
self.logger = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

def query_single(self, service: str, target: str, target_type: str, start: str, end: str, step: str,
observer_location: str, save_data: bool = DEFAUT_SAVE_DATA) -> Union[QueryInput, None]:
def query_single(
self,
service: str,
target: str,
target_type: str,
start: str,
end: str,
step: str,
observer_location: str,
save_data: bool = DEFAUT_SAVE_DATA,
) -> Union[QueryInput, None]:
"""
Query ephemeris for a single target using the specified service.
@@ -125,6 +134,7 @@ def load_ephemeris_from_multi_ecsv(self, ecsv_files: list[str]) -> EphemerisData
"""
return DataLoader.load_multiple_ephemeris_files(ecsv_files)


def main():
"""
Main function to handle command-line arguments and execute ephemeris queries.
@@ -163,20 +173,24 @@ def main():
Returns:
result (list[EphemerisData]): List of ephemeris data as a dataclass.
"""
parser = argparse.ArgumentParser(description=
"Query ephemeris data using Horizons or Miriade services or"
" load ephemeris data from existing ECSV.")
parser = argparse.ArgumentParser(
description="Query ephemeris data using Horizons or Miriade services or"
" load ephemeris data from existing ECSV."
)
parser.add_argument('service', choices=['horizons', 'miriade'], help="Service to use for querying")
parser.add_argument('--ecsv',
help= "Path to ECSV file (or a list separated with ,) containing ephemeris data")
parser.add_argument(
'--ecsv', help= "Path to ECSV file (or a list separated with ,) containing ephemeris data"
)
parser.add_argument('--csv', help="Path to CSV file for batch processing")
parser.add_argument('--target', help="Target object for single query")
parser.add_argument('--target_type', help="Target object type for single query")
parser.add_argument('--start', help="Start time for single query")
parser.add_argument('--end', help="End time for single query")
parser.add_argument('--step', help="Time step for single query")
parser.add_argument('--location', default=EphemerisClient.DEFAULT_OBSERVER_LOCATION,
help="Observer location code, default: Rubin(X05)")
parser.add_argument(
'--location', default=EphemerisClient.DEFAULT_OBSERVER_LOCATION,
help="Observer location code, default: Rubin(X05)",
)
parser.add_argument('--save_data', action='store_true', help="Save query results as ECSV files")

args = parser.parse_args()
@@ -186,19 +200,22 @@ def main():
if args.csv:
results = client.query_from_csv(args.service, args.csv, args.location)
elif all([args.target, args.target_type, args.start, args.end, args.step]):
result = client.query_single(args.service, args.target, args.target_type, args.start, args.end,
args.step, args.location)
result = client.query_single(
args.service, args.target, args.target_type, args.start, args.end, args.step, args.location
)
results = [result] if result else []
elif args.ecsv:
ecsv_files = args.ecsv.split(',') # Assume multiple files are comma-separated
ecsv_files = args.ecsv.split(",") # Assume multiple files are comma-separated
if len(ecsv_files) > 1:
results = client.load_ephemeris_from_multi_ecsv(ecsv_files)
else:
results = client.load_ephemeris_from_ecsv(args.ecsv)
else:
parser.error("Either provide a CSV file or all single query parameters"
" like target, target_type,start, end, step"
" or ECSV file containing ephemeris data")
parser.error(
"Either provide a CSV file or all single query parameters"
" like target, target_type,start, end, step"
" or ECSV file containing ephemeris data"
)

if results:
print(f"Successfully queried {len(results)} object(s)")
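
As a usage note rather than part of the commit, the reflowed query_single() signature maps directly onto the CLI flags parsed in main() above. A hedged sketch of the Python-side call follows; the target and dates are illustrative, and X05 is the Rubin Observatory code the --location help text names as the default.

from forcedphot.ephemeris.ephemeris_client import EphemerisClient

client = EphemerisClient()

# Argument names follow the query_single() signature in the diff; the target,
# dates and step are illustrative only.
result = client.query_single(
    service="horizons",
    target="Ceres",
    target_type="smallbody",
    start="2024-01-01",
    end="2024-01-31",
    step="1h",
    observer_location="X05",  # Rubin Observatory, the documented default
    save_data=False,
)

results = [result] if result else []
print(f"Successfully queried {len(results)} object(s)")

On the command line the same query would go through the horizons service with --target, --target_type, --start, --end and --step.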
10 changes: 6 additions & 4 deletions src/forcedphot/ephemeris/horizons_interface.py
@@ -187,7 +187,7 @@ def save_horizons_data_to_ecsv(self, query_input, ephemeris_data):
self.logger.info(f"Ephemeris data successfully saved to {output_filename}")


def query_single_range(self, query: QueryInput, save_data: bool = False) -> QueryResult:
def query_single_range(self, query: QueryInput, save_data: bool = False) -> QueryResult:
"""
Query ephemeris for a single time range.
@@ -231,8 +231,10 @@ def query_single_range(self, query: QueryInput, save_data: bool = False) -> Que

if query.target_type == "comet_name":
mag_type = "Tmag"
ephemeris = obj.ephemerides(closest_apparition=True, no_fragments=True,
skip_daylight=True)
ephemeris = obj.ephemerides(
closest_apparition=True, no_fragments=True,
skip_daylight=True
)

else:
mag_type = "V"
@@ -281,7 +283,7 @@ def query_single_range(self, query: QueryInput, save_data: bool = False) -> Que

@classmethod
def query_ephemeris_from_csv(
cls, csv_filename: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False
cls, csv_filename: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False
):
"""
Query ephemeris for multiple celestial objects from JPL Horizons based on a CSV file and save
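
For context on the comet branch reformatted above: obj.ephemerides(...) is the astroquery.jplhorizons call. The construction of the Horizons object sits outside the hunks shown here, so the sketch below is an assumption about how such a query is built; the target and epochs are illustrative only.

from astroquery.jplhorizons import Horizons

# Illustrative target and epochs; only the ephemerides() keywords are taken
# from the diff above, the Horizons(...) construction is an assumption.
obj = Horizons(
    id="Encke",
    id_type="comet_name",
    location="X05",
    epochs={"start": "2024-01-01", "stop": "2024-06-30", "step": "1h"},
)

# Comets: keep the closest apparition, drop fragments, skip daylight epochs,
# matching the keyword arguments in the reformatted call above.
ephemeris = obj.ephemerides(
    closest_apparition=True,
    no_fragments=True,
    skip_daylight=True,
)
print(len(ephemeris), "epochs returned")

closest_apparition and no_fragments keep a comet name like this from resolving to several apparitions or fragments, which is why the diff sets them only for comet_name targets.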
72 changes: 45 additions & 27 deletions src/forcedphot/ephemeris/miriade_interface.py
@@ -92,8 +92,10 @@ def set_target_type(self, target_type):
elif target_type == "designation":
return "comet"
else:
raise ValueError(f"Unsupported target type: {target_type}. Please chose"
f"from 'smallbody', 'comet_name', 'asteroid_name', or 'designation'.")
raise ValueError(
f"Unsupported target type: {target_type}. Please chose"
f"from 'smallbody', 'comet_name', 'asteroid_name', or 'designation'."
)


def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade:
Expand Down Expand Up @@ -123,21 +125,21 @@ def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade:
Raises:
-------
ValueError
If the step unit in the input query is not recognized (valid units are 's', 'm', 'h', 'd').
If the step unit in the input query is not recognized (valid units are "s", "m", "h", "d").
Notes:
------
The method supports time steps in seconds ('s'), minutes ('m'), hours ('h'), or days ('d').
The method supports time steps in seconds ("s"), minutes ("m"), hours ("h"), or days ("d").
"""

value, unit = int(query.step[:-1]), query.step[-1]
if unit == 's':
if unit == "s":
step_freqency = value * u.s
elif unit == 'm':
elif unit == "m":
step_freqency = value * u.min
elif unit == 'h':
elif unit == "h":
step_freqency = value * u.hour
elif unit == 'd':
elif unit == "d":
step_freqency = value * u.day
else:
raise ValueError("Error in the input field.")
@@ -148,7 +150,7 @@ def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade:

if query.target_type == "comet_name":
resolved_target = ESASky.find_sso(sso_name=query.target, sso_type="COMET")
sso_name=resolved_target[0].get("sso_name")
sso_name = resolved_target[0].get("sso_name")
else:
sso_name = query.target

Expand All @@ -157,11 +159,11 @@ def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade:
objtype=self.set_target_type(query.target_type),
start=query.start,
step=query.step,
nsteps=nsteps
nsteps=nsteps,
)

return query_miriade

def save_miriade_data_to_ecsv(self, query_input, ephemeris_data):
"""
Save queried ephemeris data to an ECSV file.
@@ -254,26 +256,41 @@ def query_single_range(self, query: QueryInput, save_data: bool = False):
query_miriade = self.calc_nsteps_for_miriade_query(query)

# Query Miriade
ephemeris = Miriade.get_ephemerides(targetname = query_miriade.target,
objtype = query_miriade.objtype,
location = self.observer_location,
epoch = query_miriade.start,
epoch_step = query_miriade.step,
epoch_nsteps = query_miriade.nsteps, coordtype = 5)
ephemeris = Miriade.get_ephemerides(
targetname = query_miriade.target,
objtype = query_miriade.objtype,
location = self.observer_location,
epoch = query_miriade.start,
epoch_step = query_miriade.step,
epoch_nsteps = query_miriade.nsteps,
coordtype = 5,
)

end_time = time.time()
self.logger.info(f"Query for range {query_miriade.start} with {query_miriade.nsteps}"
f" completed in {end_time - start_time} seconds.")
self.logger.info(
f"Query for range {query_miriade.start} with {query_miriade.nsteps}"
f" completed in {end_time - start_time} seconds.")

# Selecting relevant columns
relevant_columns = ['epoch','RAJ2000', 'DECJ2000', 'RAcosD_rate', 'DEC_rate',
'AZ', 'EL', 'heldist', 'delta', 'V', 'alpha', 'posunc']
relevant_columns = [
"epoch",
"RAJ2000",
"DECJ2000",
"RAcosD_rate",
"DEC_rate",
"AZ",
"EL",
"heldist",
"delta",
"V",
"alpha",
"posunc"
]
relevant_data = ephemeris[relevant_columns]

ephemeris_data = EphemerisData()

if ephemeris_data is not None:

ephemeris_data.datetime_jd = relevant_data["epoch"]
ephemeris_data.RA_deg = relevant_data["RAJ2000"]
ephemeris_data.DEC_deg = relevant_data["DECJ2000"]
Expand All @@ -294,16 +311,17 @@ def query_single_range(self, query: QueryInput, save_data: bool = False):
return QueryResult(query.target, query.start, query.end, ephemeris_data)

except Exception as e:
self.logger.error(f"An error occurred during query for range {query_miriade.start}"
f"with {query_miriade.nsteps} for target {query_miriade.target}")
self.logger.error(
f"An error occurred during query for range {query_miriade.start}"
f"with {query_miriade.nsteps} for target {query_miriade.target}"
)
self.logger.error(f"Error details: {str(e)}")

return None


@classmethod
def query_ephemeris_from_csv(
cls, csv_file: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False
cls, csv_file: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False
):
"""
Process multiple ephemeris queries from a CSV file and save results as ECSV files.
Expand Down Expand Up @@ -370,7 +388,7 @@ def query_ephemeris_from_csv(

if save_data:
# Save the queried ephemeris data to ECSV file
miriade_interface.save_miriade_data_to_ecsv(query, query_result.ephemeris_data)
miriade_interface.save_miriade_data_to_ecsv(query, query_result.ephemeris)

total_end_time = time.time()
cls.logger.info(
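
One piece of context for the unit handling reflowed above: calc_nsteps_for_miriade_query() has to turn the start/end/step triple into the epoch_nsteps count passed to Miriade.get_ephemerides(). The arithmetic sits outside the hunks shown here, so the helper below, nsteps_for, is a hypothetical sketch of how that count could fall out of the astropy quantities; the repository's exact rounding is not visible in this diff.

import astropy.units as u
from astropy.time import Time

def nsteps_for(start: str, end: str, step: str) -> int:
    """Hypothetical helper mirroring the unit parsing shown in the diff;
    the exact rounding used in the repository is not visible here."""
    value, unit = int(step[:-1]), step[-1]
    step_frequency = value * {"s": u.s, "m": u.min, "h": u.hour, "d": u.day}[unit]
    duration = (Time(end) - Time(start)).to(u.s)
    return int((duration / step_frequency.to(u.s)).value) + 1

print(nsteps_for("2024-01-01", "2024-01-02", "1h"))  # 25 epochs, endpoints included

Whatever the exact rounding, the resulting count is what ends up in QueryInputMiriade.nsteps and then in epoch_nsteps above.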
