Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use generic propagator, new adam-core, tests with ASSIST #167

Open
wants to merge 10 commits into
base: main
Choose a base branch
from
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -165,3 +165,7 @@ cython_debug/

.volumes/*
.docker_bash_history.txt

# pdm stuff
.pdm-build/
.pdm-python
37 changes: 19 additions & 18 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ authors = [
description = "Tracklet-less Heliocentric Orbit Recovery"
readme = "README.md"
license = { file = "LICENSE.md" }
requires-python = ">=3.10"
requires-python = "<3.13,>=3.10"
classifiers = [
"Development Status :: 3 - Alpha",
"Intended Audience :: Science/Research",
Expand All @@ -25,8 +25,7 @@ classifiers = [
keywords = ["astronomy", "astrophysics", "space", "science", "asteroids", "comets", "solar system"]

dependencies = [
"adam-core>=0.2.5",
"adam-pyoorb@git+https://github.com/B612-Asteroid-Institute/adam-pyoorb.git@main#egg=adam-pyoorb",
"adam-core>=0.3.4",
"astropy>=5.3.1",
"astroquery",
"difi",
Expand Down Expand Up @@ -71,7 +70,7 @@ typecheck = "mypy --strict ./src/thor"
test = "pytest --benchmark-disable {args}"
doctest = "pytest --doctest-plus --doctest-only"
benchmark = "pytest --benchmark-only"
coverage = "pytest --cov=thor --cov-report=xml"
coverage = "pytest --cov=thor --cov-report=xml --benchmark-disable"

[project.urls]
"Documentation" = "https://github.com/moeyensj/thor#README.md"
Expand All @@ -80,19 +79,21 @@ coverage = "pytest --cov=thor --cov-report=xml"

[project.optional-dependencies]
dev = [
"black",
"ipython",
"matplotlib",
"isort",
"mypy",
"pdm",
"pytest-benchmark",
"pytest-cov",
"pytest-doctestplus",
"pytest-mock",
"pytest-memray",
"pytest",
"ruff",
"black",
"ipython",
"matplotlib",
"isort",
"mypy",
"pdm",
"pytest-benchmark",
"pytest-cov",
"pytest-doctestplus",
"pytest-mock",
"pytest-memray",
"pytest",
"ruff",
"adam-assist>=0.2.0",
"adam-pyoorb @ git+https://github.com/B612-Asteroid-Institute/adam-pyoorb@0697eeb871f8d2f8577bf545f5da3966c473662e",
]

[tool.black]
Expand All @@ -112,7 +113,7 @@ ignore_missing_imports = true

[tool.pytest.ini_options]
python_functions = "test_*"
addopts = "-m 'not (integration or memory)'"
addopts = "-m 'not (memory)' --ignore=__pypackages__"
markers = [
"integration: Mark a test as an integration test.",
"memory: Mark a test as a memory test."
Expand Down
2 changes: 1 addition & 1 deletion src/thor/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
class Config(BaseModel):
max_processes: Optional[int] = None
ray_memory_bytes: int = 0
propagator: Literal["PYOORB"] = "PYOORB"
propagator_namespace: str = "adam_assist.ASSISTPropagator"
cell_radius: float = 10
vx_min: float = -0.1
vx_max: float = 0.1
Expand Down
24 changes: 11 additions & 13 deletions src/thor/main.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import importlib
import logging
import os
import pathlib
Expand All @@ -7,7 +8,6 @@

import quivr as qv
import ray
from adam_core.propagator.adam_pyoorb import PYOORBPropagator
from adam_core.ray_cluster import initialize_use_ray

from .checkpointing import create_checkpoint_data, load_initial_checkpoint_values
Expand Down Expand Up @@ -113,10 +113,9 @@ def link_test_orbit(

initialize_config(config, working_dir)

if config.propagator == "PYOORB":
propagator = PYOORBPropagator
else:
raise ValueError(f"Unknown propagator: {config.propagator}")
module_path, class_name = config.propagator_namespace.rsplit(".", 1)
propagator_module = importlib.import_module(module_path)
propagator_class = getattr(propagator_module, class_name)

use_ray = initialize_use_ray(
num_cpus=config.max_processes,
Expand Down Expand Up @@ -182,7 +181,7 @@ def link_test_orbit(
transformed_detections = range_and_transform(
test_orbit,
filtered_observations,
propagator=propagator,
propagator_class=propagator_class,
max_processes=config.max_processes,
)

Expand Down Expand Up @@ -278,19 +277,18 @@ def link_test_orbit(
iod_orbits, iod_orbit_members = initial_orbit_determination(
filtered_observations,
cluster_members,
propagator_class=propagator_class,
min_obs=config.iod_min_obs,
min_arc_length=config.iod_min_arc_length,
contamination_percentage=config.iod_contamination_percentage,
rchi2_threshold=config.iod_rchi2_threshold,
observation_selection_method=config.iod_observation_selection_method,
propagator=propagator,
propagator_kwargs={},
chunk_size=config.iod_chunk_size,
max_processes=config.max_processes,
# TODO: investigate whether these should be configurable
iterate=False,
light_time=True,
linkage_id_col="cluster_id",
propagator_kwargs={},
chunk_size=config.iod_chunk_size,
max_processes=config.max_processes,
)

iod_orbits_path = None
Expand Down Expand Up @@ -345,7 +343,7 @@ def link_test_orbit(
rchi2_threshold=config.od_rchi2_threshold,
delta=config.od_delta,
max_iter=config.od_max_iter,
propagator=propagator,
propagator_class=propagator_class,
propagator_kwargs={},
chunk_size=config.od_chunk_size,
max_processes=config.max_processes,
Expand Down Expand Up @@ -406,7 +404,7 @@ def link_test_orbit(
radius=config.arc_extension_radius,
delta=config.od_delta,
max_iter=config.od_max_iter,
propagator=propagator,
propagator_class=propagator_class,
propagator_kwargs={},
orbits_chunk_size=config.arc_extension_chunk_size,
max_processes=config.max_processes,
Expand Down
19 changes: 14 additions & 5 deletions src/thor/observations/filters.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
import abc
import importlib
import logging
import multiprocessing as mp
import time
from typing import TYPE_CHECKING, List, Optional, Union
from typing import TYPE_CHECKING, List, Optional, Type, Union

import numpy as np
import pyarrow.parquet as pq
import quivr as qv
import ray
from adam_core.coordinates import SphericalCoordinates
from adam_core.propagator import Propagator
from adam_core.ray_cluster import initialize_use_ray

from thor.config import Config
Expand All @@ -34,6 +36,7 @@ def apply(
self,
observations: Observations,
test_orbit: TestOrbits,
propagator_class: Type[Propagator],
) -> "Observations":
"""
Apply the filter to a collection of observations.
Expand Down Expand Up @@ -77,6 +80,7 @@ def apply(
self,
observations: Union["Observations", ray.ObjectRef],
test_orbit: TestOrbits,
propagator_class: Type[Propagator],
) -> "Observations":
"""
Apply the filter to a collection of observations.
Expand All @@ -103,7 +107,7 @@ def apply(
logger.info(f"Using radius = {self.radius:.5f} deg")

# Generate an ephemeris for every observer time/location in the dataset
ephemeris = test_orbit.generate_ephemeris_from_observations(observations)
ephemeris = test_orbit.generate_ephemeris_from_observations(observations, propagator_class)

filtered_observations = Observations.empty()
state_ids = observations.state_id.unique()
Expand Down Expand Up @@ -198,6 +202,7 @@ def filter_observations_worker(
observations: Observations,
test_orbit: TestOrbits,
filters: List[ObservationFilter],
propagator_class: Type[Propagator],
) -> Observations:
"""
Apply a list of filters to the observations.
Expand All @@ -222,6 +227,7 @@ def filter_observations_worker(
observations = filter_i.apply(
observations,
test_orbit,
propagator_class,
)

# Defragment the observations
Expand Down Expand Up @@ -271,6 +277,10 @@ def filter_observations(
time_start = time.perf_counter()
logger.info("Running observation filters...")

module_path, class_name = config.propagator_namespace.rsplit(".", 1)
propagator_module = importlib.import_module(module_path)
propagator_class = getattr(propagator_module, class_name)

if len(test_orbit) != 1:
raise ValueError(f"filter_observations received {len(test_orbit)} orbits but expected 1.")

Expand Down Expand Up @@ -303,9 +313,7 @@ def filter_observations(
for observations_chunk in observations_iterator(observations, chunk_size=chunk_size):
futures.append(
filter_observations_worker_remote.remote(
observations_chunk,
test_orbit,
filters,
observations_chunk, test_orbit, filters, propagator_class
)
)
if len(futures) > max_processes * 1.5:
Expand All @@ -330,6 +338,7 @@ def filter_observations(
observations_chunk,
test_orbit,
filters,
propagator_class,
)
filtered_observations = qv.concatenate([filtered_observations, filtered_observations_chunk])
if filtered_observations.fragmented():
Expand Down
3 changes: 2 additions & 1 deletion src/thor/observations/tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import pyarrow as pa
import pytest
import quivr as qv
from adam_assist import ASSISTPropagator
from adam_core.coordinates import CartesianCoordinates, Origin
from adam_core.observations import Exposures, PointSourceDetections
from adam_core.observers import Observers
Expand Down Expand Up @@ -49,7 +50,7 @@ def fixed_observers() -> Observers:

@pytest.fixture
def fixed_ephems(fixed_test_orbit: TestOrbits, fixed_observers: Observers) -> Ephemeris:
return fixed_test_orbit.generate_ephemeris(fixed_observers)
return fixed_test_orbit.generate_ephemeris(fixed_observers, ASSISTPropagator)


@pytest.fixture
Expand Down
3 changes: 2 additions & 1 deletion src/thor/observations/tests/test_filters.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from unittest import mock

import pyarrow.compute as pc
from adam_assist import ASSISTPropagator

from ...config import Config
from ..filters import TestOrbitRadiusObservationFilter, filter_observations
Expand All @@ -16,7 +17,7 @@ def test_orbit_radius_observation_filter(fixed_test_orbit, fixed_observations):
fos = TestOrbitRadiusObservationFilter(
radius=0.5,
)
have = fos.apply(fixed_observations, fixed_test_orbit)
have = fos.apply(fixed_observations, fixed_test_orbit, ASSISTPropagator)
assert len(pc.unique(have.exposure_id)) == 5
assert pc.all(
pc.equal(
Expand Down
Loading
Loading