From 849e6cd0a5327a11ac9839e6bc4d5469e014069c Mon Sep 17 00:00:00 2001 From: szilac Date: Tue, 23 Jul 2024 18:27:36 +0200 Subject: [PATCH] modified: src/forcedphot/ephemeris/data_loader.py modified: src/forcedphot/ephemeris/ephemeris_client.py modified: src/forcedphot/ephemeris/horizons_interface.py modified: src/forcedphot/ephemeris/miriade_interface.py modified: tests/forcedphot/ephemeris/test_data_loader.py modified: tests/forcedphot/ephemeris/test_ephemeris_client.py modified: tests/forcedphot/ephemeris/test_horizons_interface.py modified: tests/forcedphot/ephemeris/test_miriade_interface.py --- src/forcedphot/ephemeris/data_loader.py | 52 +++++++----- src/forcedphot/ephemeris/ephemeris_client.py | 47 +++++++---- .../ephemeris/horizons_interface.py | 10 ++- src/forcedphot/ephemeris/miriade_interface.py | 72 ++++++++++------ .../forcedphot/ephemeris/test_data_loader.py | 84 ++++++++++--------- .../ephemeris/test_ephemeris_client.py | 71 +++++++++------- .../ephemeris/test_horizons_interface.py | 14 +++- .../ephemeris/test_miriade_interface.py | 67 ++++++++------- 8 files changed, 250 insertions(+), 167 deletions(-) diff --git a/src/forcedphot/ephemeris/data_loader.py b/src/forcedphot/ephemeris/data_loader.py index 2948684..1ebb95f 100644 --- a/src/forcedphot/ephemeris/data_loader.py +++ b/src/forcedphot/ephemeris/data_loader.py @@ -54,13 +54,22 @@ def load_ephemeris_from_ecsv(file_path: str) -> EphemerisData: """ try: # Read the ECSV file - table = Table.read(file_path, format='ascii.ecsv') + table = Table.read(file_path, format="ascii.ecsv") # Check if all required columns are present required_columns = [ - 'datetime_jd', 'RA_deg', 'DEC_deg', 'RA_rate_arcsec_per_h', - 'DEC_rate_arcsec_per_h', 'AZ_deg', 'EL_deg', 'r_au', 'delta_au', - 'V_mag', 'alpha_deg', 'RSS_3sigma_arcsec' + "datetime_jd", + "RA_deg", + "DEC_deg", + "RA_rate_arcsec_per_h", + "DEC_rate_arcsec_per_h", + "AZ_deg", + "EL_deg", + "r_au", + "delta_au", + "V_mag", + "alpha_deg", + "RSS_3sigma_arcsec", ] missing_columns = [col for col in required_columns if col not in table.colnames] if missing_columns: @@ -68,25 +77,26 @@ def load_ephemeris_from_ecsv(file_path: str) -> EphemerisData: # Create and populate the EphemerisData object ephemeris_data = EphemerisData( - datetime_jd=Time(table['datetime_jd'], format='jd'), - RA_deg=np.array(table['RA_deg']), - DEC_deg=np.array(table['DEC_deg']), - RA_rate_arcsec_per_h=np.array(table['RA_rate_arcsec_per_h']), - DEC_rate_arcsec_per_h=np.array(table['DEC_rate_arcsec_per_h']), - AZ_deg=np.array(table['AZ_deg']), - EL_deg=np.array(table['EL_deg']), - r_au=np.array(table['r_au']), - delta_au=np.array(table['delta_au']), - V_mag=np.array(table['V_mag']), - alpha_deg=np.array(table['alpha_deg']), - RSS_3sigma_arcsec=np.array(table['RSS_3sigma_arcsec']) + datetime_jd=Time(table["datetime_jd"], format="jd"), + RA_deg=np.array(table["RA_deg"]), + DEC_deg=np.array(table["DEC_deg"]), + RA_rate_arcsec_per_h=np.array(table["RA_rate_arcsec_per_h"]), + DEC_rate_arcsec_per_h=np.array(table["DEC_rate_arcsec_per_h"]), + AZ_deg=np.array(table["AZ_deg"]), + EL_deg=np.array(table["EL_deg"]), + r_au=np.array(table["r_au"]), + delta_au=np.array(table["delta_au"]), + V_mag=np.array(table["V_mag"]), + alpha_deg=np.array(table["alpha_deg"]), + RSS_3sigma_arcsec=np.array(table["RSS_3sigma_arcsec"]) ) - DataLoader.logger.info(f"Loaded ephemeris data with {len(ephemeris_data.datetime_jd)} points from {file_path}.") + DataLoader.logger.info( + f"Loaded ephemeris data with {len(ephemeris_data.datetime_jd)} points 
from {file_path}." + ) return ephemeris_data - except FileNotFoundError: DataLoader.logger.error(f"The file {file_path} was not found.") raise @@ -133,8 +143,10 @@ def load_multiple_ephemeris_files(file_paths: list[str]) -> list[EphemerisData]: # print(f"Error: {str(e)}") # Example of loading multiple files - file_paths = ["./Ceres_2024-01-01_00-00-00.000_2025-12-31_23-59-00.000.ecsv", - "./Encke_2024-01-01_00-00-00.000_2024-06-30_23-59-00.000.ecsv"] + file_paths = [ + "./Ceres_2024-01-01_00-00-00.000_2025-12-31_23-59-00.000.ecsv", + "./Encke_2024-01-01_00-00-00.000_2024-06-30_23-59-00.000.ecsv" + ] try: ephemeris_list = DataLoader.load_multiple_ephemeris_files(file_paths) print(f"Loaded {len(ephemeris_list)} ephemeris files.") diff --git a/src/forcedphot/ephemeris/ephemeris_client.py b/src/forcedphot/ephemeris/ephemeris_client.py index 77ec133..396b13a 100644 --- a/src/forcedphot/ephemeris/ephemeris_client.py +++ b/src/forcedphot/ephemeris/ephemeris_client.py @@ -40,8 +40,17 @@ def __init__(self): self.logger = logging.getLogger(__name__) logging.basicConfig(level=logging.INFO) - def query_single(self, service: str, target: str, target_type: str, start: str, end: str, step: str, - observer_location: str, save_data: bool = DEFAUT_SAVE_DATA) -> Union[QueryInput, None]: + def query_single( + self, + service: str, + target: str, + target_type: str, + start: str, + end: str, + step: str, + observer_location: str, + save_data: bool = DEFAUT_SAVE_DATA, + ) -> Union[QueryInput, None]: """ Query ephemeris for a single target using the specified service. @@ -125,6 +134,7 @@ def load_ephemeris_from_multi_ecsv(self, ecsv_files: list[str]) -> EphemerisData """ return DataLoader.load_multiple_ephemeris_files(ecsv_files) + def main(): """ Main function to handle command-line arguments and execute ephemeris queries. @@ -163,20 +173,24 @@ def main(): Returns: result (list[EphemerisData]): List of ephemeris data as a dataclass. """ - parser = argparse.ArgumentParser(description= - "Query ephemeris data using Horizons or Miriade services or" - " load ephemeris data from existing ECSV.") + parser = argparse.ArgumentParser( + description="Query ephemeris data using Horizons or Miriade services or" + " load ephemeris data from existing ECSV." 
+ ) parser.add_argument('service', choices=['horizons', 'miriade'], help="Service to use for querying") - parser.add_argument('--ecsv', - help= "Path to ECSV file (or a list separated with ,) containing ephemeris data") + parser.add_argument( + '--ecsv', help= "Path to ECSV file (or a list separated with ,) containing ephemeris data" + ) parser.add_argument('--csv', help="Path to CSV file for batch processing") parser.add_argument('--target', help="Target object for single query") parser.add_argument('--target_type', help="Target object type for single query") parser.add_argument('--start', help="Start time for single query") parser.add_argument('--end', help="End time for single query") parser.add_argument('--step', help="Time step for single query") - parser.add_argument('--location', default=EphemerisClient.DEFAULT_OBSERVER_LOCATION, - help="Observer location code, default: Rubin(X05)") + parser.add_argument( + '--location', default=EphemerisClient.DEFAULT_OBSERVER_LOCATION, + help="Observer location code, default: Rubin(X05)", + ) parser.add_argument('--save_data', action='store_true', help="Save query results as ECSV files") args = parser.parse_args() @@ -186,19 +200,22 @@ def main(): if args.csv: results = client.query_from_csv(args.service, args.csv, args.location) elif all([args.target, args.target_type, args.start, args.end, args.step]): - result = client.query_single(args.service, args.target, args.target_type, args.start, args.end, - args.step, args.location) + result = client.query_single( + args.service, args.target, args.target_type, args.start, args.end, args.step, args.location + ) results = [result] if result else [] elif args.ecsv: - ecsv_files = args.ecsv.split(',') # Assume multiple files are comma-separated + ecsv_files = args.ecsv.split(",") # Assume multiple files are comma-separated if len(ecsv_files) > 1: results = client.load_ephemeris_from_multi_ecsv(ecsv_files) else: results = client.load_ephemeris_from_ecsv(args.ecsv) else: - parser.error("Either provide a CSV file or all single query parameters" - " like target, target_type,start, end, step" - " or ECSV file containing ephemeris data") + parser.error( + "Either provide a CSV file or all single query parameters" + " like target, target_type,start, end, step" + " or ECSV file containing ephemeris data" + ) if results: print(f"Successfully queried {len(results)} object(s)") diff --git a/src/forcedphot/ephemeris/horizons_interface.py b/src/forcedphot/ephemeris/horizons_interface.py index 1d9cad4..3ac9657 100644 --- a/src/forcedphot/ephemeris/horizons_interface.py +++ b/src/forcedphot/ephemeris/horizons_interface.py @@ -187,7 +187,7 @@ def save_horizons_data_to_ecsv(self, query_input, ephemeris_data): self.logger.info(f"Ephemeris data successfully saved to {output_filename}") - def query_single_range(self, query: QueryInput, save_data: bool = False) -> QueryResult: + def query_single_range(self, query: QueryInput, save_data: bool = False) -> QueryResult: """ Query ephemeris for a single time range. 
@@ -231,8 +231,10 @@ def query_single_range(self, query: QueryInput, save_data: bool = False) -> Que if query.target_type == "comet_name": mag_type = "Tmag" - ephemeris = obj.ephemerides(closest_apparition=True, no_fragments=True, - skip_daylight=True) + ephemeris = obj.ephemerides( + closest_apparition=True, no_fragments=True, + skip_daylight=True + ) else: mag_type = "V" @@ -281,7 +283,7 @@ def query_single_range(self, query: QueryInput, save_data: bool = False) -> Que @classmethod def query_ephemeris_from_csv( - cls, csv_filename: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False + cls, csv_filename: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False ): """ Query ephemeris for multiple celestial objects from JPL Horizons based on a CSV file and save diff --git a/src/forcedphot/ephemeris/miriade_interface.py b/src/forcedphot/ephemeris/miriade_interface.py index 238568d..c6e8d8c 100644 --- a/src/forcedphot/ephemeris/miriade_interface.py +++ b/src/forcedphot/ephemeris/miriade_interface.py @@ -92,8 +92,10 @@ def set_target_type(self, target_type): elif target_type == "designation": return "comet" else: - raise ValueError(f"Unsupported target type: {target_type}. Please chose" - f"from 'smallbody', 'comet_name', 'asteroid_name', or 'designation'.") + raise ValueError( + f"Unsupported target type: {target_type}. Please chose" + f"from 'smallbody', 'comet_name', 'asteroid_name', or 'designation'." + ) def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade: @@ -123,21 +125,21 @@ def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade: Raises: ------- ValueError - If the step unit in the input query is not recognized (valid units are 's', 'm', 'h', 'd'). + If the step unit in the input query is not recognized (valid units are "s", "m", "h", "d"). Notes: ------ - The method supports time steps in seconds ('s'), minutes ('m'), hours ('h'), or days ('d'). + The method supports time steps in seconds ("s"), minutes ("m"), hours ("h"), or days ("d"). """ value, unit = int(query.step[:-1]), query.step[-1] - if unit == 's': + if unit == "s": step_freqency = value * u.s - elif unit == 'm': + elif unit == "m": step_freqency = value * u.min - elif unit == 'h': + elif unit == "h": step_freqency = value * u.hour - elif unit == 'd': + elif unit == "d": step_freqency = value * u.day else: raise ValueError("Error in the input field.") @@ -148,7 +150,7 @@ def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade: if query.target_type == "comet_name": resolved_target = ESASky.find_sso(sso_name=query.target, sso_type="COMET") - sso_name=resolved_target[0].get("sso_name") + sso_name = resolved_target[0].get("sso_name") else: sso_name = query.target @@ -157,11 +159,11 @@ def calc_nsteps_for_miriade_query(self, query: QueryInput) -> QueryInputMiriade: objtype=self.set_target_type(query.target_type), start=query.start, step=query.step, - nsteps=nsteps + nsteps=nsteps, ) return query_miriade - + def save_miriade_data_to_ecsv(self, query_input, ephemeris_data): """ Save queried ephemeris data to an ECSV file. 
@@ -254,26 +256,41 @@ def query_single_range(self, query: QueryInput, save_data: bool = False): query_miriade = self.calc_nsteps_for_miriade_query(query) # Query Miriade - ephemeris = Miriade.get_ephemerides(targetname = query_miriade.target, - objtype = query_miriade.objtype, - location = self.observer_location, - epoch = query_miriade.start, - epoch_step = query_miriade.step, - epoch_nsteps = query_miriade.nsteps, coordtype = 5) + ephemeris = Miriade.get_ephemerides( + targetname = query_miriade.target, + objtype = query_miriade.objtype, + location = self.observer_location, + epoch = query_miriade.start, + epoch_step = query_miriade.step, + epoch_nsteps = query_miriade.nsteps, + coordtype = 5, + ) end_time = time.time() - self.logger.info(f"Query for range {query_miriade.start} with {query_miriade.nsteps}" - f" completed in {end_time - start_time} seconds.") + self.logger.info( + f"Query for range {query_miriade.start} with {query_miriade.nsteps}" + f" completed in {end_time - start_time} seconds.") # Selecting relevant columns - relevant_columns = ['epoch','RAJ2000', 'DECJ2000', 'RAcosD_rate', 'DEC_rate', - 'AZ', 'EL', 'heldist', 'delta', 'V', 'alpha', 'posunc'] + relevant_columns = [ + "epoch", + "RAJ2000", + "DECJ2000", + "RAcosD_rate", + "DEC_rate", + "AZ", + "EL", + "heldist", + "delta", + "V", + "alpha", + "posunc" + ] relevant_data = ephemeris[relevant_columns] ephemeris_data = EphemerisData() if ephemeris_data is not None: - ephemeris_data.datetime_jd = relevant_data["epoch"] ephemeris_data.RA_deg = relevant_data["RAJ2000"] ephemeris_data.DEC_deg = relevant_data["DECJ2000"] @@ -294,16 +311,17 @@ def query_single_range(self, query: QueryInput, save_data: bool = False): return QueryResult(query.target, query.start, query.end, ephemeris_data) except Exception as e: - self.logger.error(f"An error occurred during query for range {query_miriade.start}" - f"with {query_miriade.nsteps} for target {query_miriade.target}") + self.logger.error( + f"An error occurred during query for range {query_miriade.start}" + f"with {query_miriade.nsteps} for target {query_miriade.target}" + ) self.logger.error(f"Error details: {str(e)}") return None - @classmethod def query_ephemeris_from_csv( - cls, csv_file: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False + cls, csv_file: str, observer_location=DEFAULT_OBSERVER_LOCATION, save_data: bool = False ): """ Process multiple ephemeris queries from a CSV file and save results as ECSV files. 
@@ -370,7 +388,7 @@ def query_ephemeris_from_csv( if save_data: # Save the queried ephemeris data to ECSV file - miriade_interface.save_miriade_data_to_ecsv(query, query_result.ephemeris_data) + miriade_interface.save_miriade_data_to_ecsv(query, query_result.ephemeris) total_end_time = time.time() cls.logger.info( diff --git a/tests/forcedphot/ephemeris/test_data_loader.py b/tests/forcedphot/ephemeris/test_data_loader.py index ed3d7e1..e5128e2 100644 --- a/tests/forcedphot/ephemeris/test_data_loader.py +++ b/tests/forcedphot/ephemeris/test_data_loader.py @@ -1,35 +1,38 @@ -import pytest import numpy as np -from astropy.time import Time +import pytest from astropy.table import Table from forcedphot.ephemeris.data_loader import DataLoader from forcedphot.ephemeris.local_dataclasses import EphemerisData + @pytest.fixture def sample_ecsv_file(tmp_path): """Create a sample ECSV file for testing.""" file_path = tmp_path / "test_ephemeris.ecsv" - data = Table({ - 'datetime_jd': [2459000.5, 2459001.5], - 'RA_deg': [100.0, 101.0], - 'DEC_deg': [-20.0, -19.5], - 'RA_rate_arcsec_per_h': [0.1, 0.2], - 'DEC_rate_arcsec_per_h': [-0.1, -0.2], - 'AZ_deg': [180.0, 185.0], - 'EL_deg': [45.0, 46.0], - 'r_au': [1.0, 1.1], - 'delta_au': [0.5, 0.6], - 'V_mag': [15.0, 15.1], - 'alpha_deg': [30.0, 31.0], - 'RSS_3sigma_arcsec': [0.01, 0.02] - }) - data.write(file_path, format='ascii.ecsv') + data = Table( + { + "datetime_jd": [2459000.5, 2459001.5], + "RA_deg": [100.0, 101.0], + "DEC_deg": [-20.0, -19.5], + "RA_rate_arcsec_per_h": [0.1, 0.2], + "DEC_rate_arcsec_per_h": [-0.1, -0.2], + "AZ_deg": [180.0, 185.0], + "EL_deg": [45.0, 46.0], + "r_au": [1.0, 1.1], + "delta_au": [0.5, 0.6], + "V_mag": [15.0, 15.1], + "alpha_deg": [30.0, 31.0], + "RSS_3sigma_arcsec": [0.01, 0.02] + } + ) + data.write(file_path, format="ascii.ecsv") return file_path + def test_load_ephemeris_from_ecsv(sample_ecsv_file): """Test loading ephemeris data from a valid ECSV file.""" ephemeris_data = DataLoader.load_ephemeris_from_ecsv(sample_ecsv_file) - + assert isinstance(ephemeris_data, EphemerisData) assert len(ephemeris_data.datetime_jd) == 2 assert np.allclose(ephemeris_data.RA_deg, [100.0, 101.0]) @@ -52,8 +55,8 @@ def test_load_ephemeris_from_nonexistent_file(): def test_load_ephemeris_from_invalid_file(tmp_path): """Test loading ephemeris data from an invalid ECSV file (missing columns).""" invalid_file = tmp_path / "invalid_ephemeris.ecsv" - data = Table({'datetime_jd': [2459000.5], 'RA_deg': [100.0]}) # Missing columns - data.write(invalid_file, format='ascii.ecsv') + data = Table({"datetime_jd": [2459000.5], "RA_deg": [100.0]}) # Missing columns + data.write(invalid_file, format="ascii.ecsv") with pytest.raises(ValueError): DataLoader.load_ephemeris_from_ecsv(invalid_file) @@ -61,34 +64,37 @@ def test_load_ephemeris_from_invalid_file(tmp_path): def test_load_multiple_ephemeris_files(sample_ecsv_file, tmp_path): """Test loading multiple ephemeris files.""" second_file = tmp_path / "test_ephemeris2.ecsv" - data = Table({ - 'datetime_jd': [2459002.5], - 'RA_deg': [102.0], - 'DEC_deg': [-19.0], - 'RA_rate_arcsec_per_h': [0.3], - 'DEC_rate_arcsec_per_h': [-0.3], - 'AZ_deg': [190.0], - 'EL_deg': [47.0], - 'r_au': [1.2], - 'delta_au': [0.7], - 'V_mag': [15.2], - 'alpha_deg': [32.0], - 'RSS_3sigma_arcsec': [0.03] - }) - data.write(second_file, format='ascii.ecsv') - + data = Table( + { + "datetime_jd": [2459002.5], + "RA_deg": [102.0], + "DEC_deg": [-19.0], + "RA_rate_arcsec_per_h": [0.3], + "DEC_rate_arcsec_per_h": [-0.3], + "AZ_deg": 
[190.0], + "EL_deg": [47.0], + "r_au": [1.2], + "delta_au": [0.7], + "V_mag": [15.2], + "alpha_deg": [32.0], + "RSS_3sigma_arcsec": [0.03] + } + ) + data.write(second_file, format="ascii.ecsv") + file_paths = [sample_ecsv_file, second_file] ephemeris_list = DataLoader.load_multiple_ephemeris_files(file_paths) - + assert len(ephemeris_list) == 2 assert isinstance(ephemeris_list[0], EphemerisData) assert isinstance(ephemeris_list[1], EphemerisData) assert len(ephemeris_list[0].datetime_jd) == 2 assert len(ephemeris_list[1].datetime_jd) == 1 + def test_load_multiple_ephemeris_files_with_error(sample_ecsv_file): """Test loading multiple ephemeris files with one non-existent file.""" file_paths = [sample_ecsv_file, "nonexistent_file.ecsv"] - + with pytest.raises(FileNotFoundError): - DataLoader.load_multiple_ephemeris_files(file_paths) \ No newline at end of file + DataLoader.load_multiple_ephemeris_files(file_paths) diff --git a/tests/forcedphot/ephemeris/test_ephemeris_client.py b/tests/forcedphot/ephemeris/test_ephemeris_client.py index 4d63a5e..e74cbe3 100644 --- a/tests/forcedphot/ephemeris/test_ephemeris_client.py +++ b/tests/forcedphot/ephemeris/test_ephemeris_client.py @@ -1,8 +1,11 @@ + +from unittest.mock import Mock, patch, ANY + import pytest from astropy.time import Time -from unittest.mock import Mock, patch, ANY from forcedphot.ephemeris.ephemeris_client import EphemerisClient -from forcedphot.ephemeris.local_dataclasses import QueryInput, QueryResult, EphemerisData +from forcedphot.ephemeris.local_dataclasses import QueryResult, EphemerisData + @pytest.fixture def ephemeris_client(): @@ -11,22 +14,25 @@ def ephemeris_client(): """ return EphemerisClient() + @pytest.fixture def mock_horizons_interface(): """ Fixture to mock the HorizonsInterface class for use in tests. """ - with patch('forcedphot.ephemeris.ephemeris_client.HorizonsInterface') as mock: + with patch("forcedphot.ephemeris.ephemeris_client.HorizonsInterface") as mock: yield mock + @pytest.fixture def mock_miriade_interface(): """ Fixture to mock the MiriadeInterface class for use in tests. """ - with patch('forcedphot.ephemeris.ephemeris_client.MiriadeInterface') as mock: + with patch("forcedphot.ephemeris.ephemeris_client.MiriadeInterface") as mock: yield mock + def test_query_single_horizons(ephemeris_client, mock_horizons_interface): """ Test the query_single method of EphemerisClient when using the JPL Horizons service. @@ -34,21 +40,19 @@ def test_query_single_horizons(ephemeris_client, mock_horizons_interface): mock_horizons_instance = Mock() mock_horizons_interface.return_value = mock_horizons_instance mock_horizons_instance.query_single_range.return_value = QueryResult( - target="Ceres", - start=Time("2023-01-01"), - end=Time("2023-01-02"), - ephemeris=EphemerisData() + target="Ceres", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData() ) result = ephemeris_client.query_single( - 'horizons', 'Ceres', 'smallbody', '2023-01-01', '2023-01-02', '1h', 'X05' + "horizons", "Ceres", "smallbody", "2023-01-01", "2023-01-02", "1h", "X05" ) assert isinstance(result, QueryResult) assert result.target == "Ceres" - mock_horizons_interface.assert_called_once_with('X05') + mock_horizons_interface.assert_called_once_with("X05") mock_horizons_instance.query_single_range.assert_called_once_with(ANY, save_data=False) + def test_query_single_miriade(ephemeris_client, mock_miriade_interface): """ Test the query_single method of EphemerisClient when using the Miriade service. 
@@ -56,67 +60,76 @@ def test_query_single_miriade(ephemeris_client, mock_miriade_interface): mock_miriade_instance = Mock() mock_miriade_interface.return_value = mock_miriade_instance mock_miriade_instance.query_single_range.return_value = QueryResult( - target="Encke", - start=Time("2023-01-01"), - end=Time("2023-01-02"), - ephemeris=EphemerisData() + target="Encke", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData() ) result = ephemeris_client.query_single( - 'miriade', 'Encke', 'comet_name', '2023-01-01', '2023-01-02', '1h', 'X05' + "miriade", "Encke", "comet_name", "2023-01-01", "2023-01-02", "1h", "X05" ) assert isinstance(result, QueryResult) assert result.target == "Encke" - mock_miriade_interface.assert_called_once_with('X05') + mock_miriade_interface.assert_called_once_with("X05") mock_miriade_instance.query_single_range.assert_called_once_with(ANY, save_data=False) + def test_query_single_invalid_service(ephemeris_client): """ Test the query_single method of EphemerisClient with an invalid service. """ result = ephemeris_client.query_single( - 'invalid_service', 'Ceres', 'smallbody', '2023-01-01', '2023-01-02', '1h', 'X05' + "invalid_service", "Ceres", "smallbody", "2023-01-01", "2023-01-02", "1h", "X05" ) assert result is None -@patch('forcedphot.ephemeris.horizons_interface.HorizonsInterface.query_ephemeris_from_csv') + +@patch("forcedphot.ephemeris.horizons_interface.HorizonsInterface.query_ephemeris_from_csv") def test_query_from_csv_horizons(mock_query_csv, ephemeris_client): """ Test the query_from_csv method of EphemerisClient when using the JPL Horizons service. """ mock_query_csv.return_value = [ - QueryResult(target="Ceres", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData()), - QueryResult(target="Vesta", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData()) + QueryResult( + target="Ceres", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData() + ), + QueryResult( + target="Vesta", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData() + ), ] - results = ephemeris_client.query_from_csv('horizons', 'test.csv', 'X05') + results = ephemeris_client.query_from_csv("horizons", "test.csv", "X05") assert len(results) == 2 assert all(isinstance(result, QueryResult) for result in results) - mock_query_csv.assert_called_once_with('test.csv', 'X05', save_data=False) + mock_query_csv.assert_called_once_with("test.csv", "X05", save_data=False) -@patch('forcedphot.ephemeris.miriade_interface.MiriadeInterface.query_ephemeris_from_csv') + +@patch("forcedphot.ephemeris.miriade_interface.MiriadeInterface.query_ephemeris_from_csv") def test_query_from_csv_miriade(mock_query_csv, ephemeris_client): """ Test the query_from_csv method of EphemerisClient when using the Miriade service. 
""" mock_query_csv.return_value = [ - QueryResult(target="Encke", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData()), - QueryResult(target="Halley", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData()) + QueryResult( + target="Encke", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData() + ), + QueryResult( + target="Halley", start=Time("2023-01-01"), end=Time("2023-01-02"), ephemeris=EphemerisData() + ), ] - results = ephemeris_client.query_from_csv('miriade', 'test.csv', 'X05') + results = ephemeris_client.query_from_csv("miriade", "test.csv", "X05") assert len(results) == 2 assert all(isinstance(result, QueryResult) for result in results) - mock_query_csv.assert_called_once_with('test.csv', 'X05', save_data=False) + mock_query_csv.assert_called_once_with("test.csv", "X05", save_data=False) + def test_query_from_csv_invalid_service(ephemeris_client): """ Test the query_from_csv method of EphemerisClient with an invalid service. """ - result = ephemeris_client.query_from_csv('invalid_service', 'test.csv', 'X05') + result = ephemeris_client.query_from_csv("invalid_service", "test.csv", "X05") - assert result is None \ No newline at end of file + assert result is None diff --git a/tests/forcedphot/ephemeris/test_horizons_interface.py b/tests/forcedphot/ephemeris/test_horizons_interface.py index e6feb8f..4ce73d6 100644 --- a/tests/forcedphot/ephemeris/test_horizons_interface.py +++ b/tests/forcedphot/ephemeris/test_horizons_interface.py @@ -21,7 +21,15 @@ def mock_csv_data(): """ Fixture to provide mock CSV data for testing. """ - return pd.DataFrame({"target": ["Ceres"], "target_type": ["smallbody"], "start": ["2020-01-01"], "end": ["2020-01-02"], "step": ["1h"]}) + return pd.DataFrame( + { + "target": ["Ceres"], + "target_type": ["smallbody"], + "start": ["2020-01-01"], + "end": ["2020-01-02"], + "step": ["1h"] + } + ) def test_init(): @@ -75,7 +83,9 @@ def test_query_single_range_failure(mock_horizons): mock_horizons.side_effect = Exception("Query failed") hi = horizons_interface.HorizonsInterface() - query = local_dataclasses.QueryInput("Invalid Target", "smallbody", Time("2020-01-01"), Time("2020-01-02"), "1h") + query = local_dataclasses.QueryInput( + "Invalid Target", "smallbody", Time("2020-01-01"), Time("2020-01-02"), "1h" + ) result = hi.query_single_range(query) assert result is None diff --git a/tests/forcedphot/ephemeris/test_miriade_interface.py b/tests/forcedphot/ephemeris/test_miriade_interface.py index bfce259..1f0b1d1 100644 --- a/tests/forcedphot/ephemeris/test_miriade_interface.py +++ b/tests/forcedphot/ephemeris/test_miriade_interface.py @@ -21,7 +21,15 @@ def mock_csv_data(): """ Fixture to provide mock CSV data for testing. """ - return pd.DataFrame({"target": ["Ceres"], "target_type": ["smallbody"], "start": ["2020-01-01"], "end": ["2020-01-02"], "step": ["1h"]}) + return pd.DataFrame( + { + "target": ["Ceres"], + "target_type": ["smallbody"], + "start": ["2020-01-01"], + "end": ["2020-01-02"], + "step": ["1h"] + } + ) def test_init(): """ @@ -33,22 +41,20 @@ def test_init(): mi_custom = MiriadeInterface("500") assert mi_custom.observer_location == "500" + def test_calc_nsteps_for_miriade_query(): """ Test the calculation of the number of steps for a Miriade query. 
""" query = QueryInput( - target="Ceres", - target_type="smallbody", - start=Time("2023-01-01"), - end=Time("2023-01-02"), - step="1h" + target="Ceres", target_type="smallbody", start=Time("2023-01-01"), end=Time("2023-01-02"), step="1h" ) mi = MiriadeInterface() result = mi.calc_nsteps_for_miriade_query(query) assert isinstance(result, QueryInputMiriade) assert result.nsteps == 24 + def test_calc_nsteps_for_miriade_query_invalid_step(): """ Test the calculation of the number of steps for a Miriade query with an invalid step. @@ -58,39 +64,38 @@ def test_calc_nsteps_for_miriade_query_invalid_step(): target_type="smallbody", start=Time("2023-01-01"), end=Time("2023-01-02"), - step="1x" # Invalid step + step="1x" ) mi = MiriadeInterface() with pytest.raises(ValueError): mi.calc_nsteps_for_miriade_query(query) -@patch('forcedphot.ephemeris.miriade_interface.Miriade.get_ephemerides') + +@patch("forcedphot.ephemeris.miriade_interface.Miriade.get_ephemerides") def test_query_single_range(mock_get_ephemerides, miriade_interface): """ Test successful query of a single range using mocked Miriade data. """ - mock_get_ephemerides.return_value = Table({ - 'epoch': [2459580.5], - 'RAJ2000': [10.5], - 'DECJ2000': [20.5], - 'RAcosD_rate': [0.1], - 'DEC_rate': [0.2], - 'AZ': [30.0], - 'EL': [40.0], - 'heldist': [1.5], - 'delta': [1.0], - 'V': [5.0], - 'alpha': [60.0], - 'posunc': [0.01] - }) + mock_get_ephemerides.return_value = Table( + { + "epoch": [2459580.5], + "RAJ2000": [10.5], + "DECJ2000": [20.5], + "RAcosD_rate": [0.1], + "DEC_rate": [0.2], + "AZ": [30.0], + "EL": [40.0], + "heldist": [1.5], + "delta": [1.0], + "V": [5.0], + "alpha": [60.0], + "posunc": [0.01] + } + ) query = QueryInput( - target="Ceres", - target_type="smallbody", - start=Time("2023-01-01"), - end=Time("2023-01-02"), - step="1h" + target="Ceres", target_type="smallbody", start=Time("2023-01-01"), end=Time("2023-01-02"), step="1h" ) result = miriade_interface.query_single_range(query) @@ -113,6 +118,7 @@ def test_query_single_range(mock_get_ephemerides, miriade_interface): assert result.ephemeris.alpha_deg[0] == 60.0 assert result.ephemeris.RSS_3sigma_arcsec[0] == 0.01 + def test_ephemeris_data_creation(): """ Test creation of EphemerisData object and verify its attribute types. @@ -131,6 +137,7 @@ def test_ephemeris_data_creation(): assert isinstance(ephemeris.alpha_deg, np.ndarray) assert isinstance(ephemeris.RSS_3sigma_arcsec, np.ndarray) + @patch("pandas.read_csv") @patch("forcedphot.ephemeris.miriade_interface.MiriadeInterface.query_single_range") @patch("astropy.table.Table.from_pandas") @@ -157,9 +164,7 @@ def test_query_ephemeris_from_csv( alpha_deg=np.array([30.0]), RSS_3sigma_arcsec=np.array([0.1]), ) - mock_query_result = QueryResult( - "Ceres", Time("2020-01-01"), Time("2020-01-02"), mock_ephemeris - ) + mock_query_result = QueryResult("Ceres", Time("2020-01-01"), Time("2020-01-02"), mock_ephemeris) mock_query_single_range.return_value = mock_query_result mock_table = MagicMock() @@ -173,7 +178,7 @@ def test_query_ephemeris_from_csv( mock_query_single_range.assert_called_once() mock_table_from_pandas.assert_called_once() - expected_filename = "./data/Ceres_2020-01-01_00-00-00.000_2020-01-02_00-00-00.000.ecsv" + expected_filename = "./Ceres_2020-01-01_00-00-00.000_2020-01-02_00-00-00.000.ecsv" expected_call = call(expected_filename, format="ascii.ecsv", overwrite=True) print(f"Expected call: {expected_call}") print(f"Actual calls: {mock_table.write.mock_calls}")