Skip to content

Commit

Permalink
Added more tests and cleaned up some small syntax things.
Browse files Browse the repository at this point in the history
  • Loading branch information
mmcdermott committed Jun 12, 2024
1 parent a22c493 commit 7e72c8c
Show file tree
Hide file tree
Showing 3 changed files with 57 additions and 20 deletions.
8 changes: 4 additions & 4 deletions src/aces/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,10 +4,8 @@
from importlib.resources import files

import hydra
import hydra.core
import hydra.core.hydra_config
from loguru import logger
from omegaconf import DictConfig, OmegaConf
from omegaconf import DictConfig

config_yaml = files("aces").joinpath("configs/aces.yaml")
if not config_yaml.is_file():
Expand All @@ -26,9 +24,11 @@ def main(cfg: DictConfig):
from datetime import datetime
from pathlib import Path

from omegaconf import OmegaConf

from . import config, predicates, query, utils

utils.hydra_loguru_init(f"{hydra.core.hydra_config.HydraConfig.get().job.name}.log")
utils.hydra_loguru_init(f"{cfg.hydra.job.name}.log")

st = datetime.now()

Expand Down
40 changes: 39 additions & 1 deletion src/aces/predicates.py
Original file line number Diff line number Diff line change
Expand Up @@ -464,7 +464,10 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
... end="start + 24h",
... start_inclusive=False,
... end_inclusive=True,
... has={"death_or_dis": "(None, 0)", "adm": "(None, 0)"},
... has={
... "death_or_dis": "(None, 0)",
... "adm": "(None, 0)",
... },
... ),
... "target": WindowConfig(
... start="gap.end",
Expand Down Expand Up @@ -513,6 +516,41 @@ def get_predicates_df(cfg: TaskExtractorConfig, data_config: DictConfig) -> pl.D
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 0 ┆ 0 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 0 ┆ 1 ┆ 1 ┆ 1 │
└────────────┴─────────────────────┴─────┴─────┴───────┴──────────────┴────────────┘
>>> any_event_trigger = EventConfig("_ANY_EVENT")
>>> adm_only_predicates = {"adm": PlainPredicateConfig("adm")}
>>> st_end_windows = {
... "input": WindowConfig(
... start="end - 365d",
... end="trigger + 24h",
... start_inclusive=True,
... end_inclusive=True,
... has={
... "_RECORD_END": "(None, 0)", # These are added just to show start/end predicates
... "_RECORD_START": "(None, 0)", # These are added just to show start/end predicates
... },
... ),
... }
>>> st_end_config = TaskExtractorConfig(
... predicates=adm_only_predicates, trigger=any_event_trigger, windows=st_end_windows
... )
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f:
... data_path = Path(f.name)
... data.write_csv(data_path)
... data_config = DictConfig({
... "path": str(data_path), "standard": "direct", "ts_format": "%m/%d/%Y %H:%M"
... })
... get_predicates_df(st_end_config, data_config)
shape: (4, 6)
┌────────────┬─────────────────────┬─────┬────────────┬───────────────┬─────────────┐
│ subject_id ┆ timestamp ┆ adm ┆ _ANY_EVENT ┆ _RECORD_START ┆ _RECORD_END │
│ --- ┆ --- ┆ --- ┆ --- ┆ --- ┆ --- │
│ i64 ┆ datetime[μs] ┆ i64 ┆ i64 ┆ i64 ┆ i64 │
╞════════════╪═════════════════════╪═════╪════════════╪═══════════════╪═════════════╡
│ 1 ┆ 2021-01-01 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 │
│ 1 ┆ 2021-01-01 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │
│ 2 ┆ 2021-01-02 00:00:00 ┆ 1 ┆ 1 ┆ 1 ┆ 0 │
│ 2 ┆ 2021-01-02 12:00:00 ┆ 0 ┆ 1 ┆ 0 ┆ 1 │
└────────────┴─────────────────────┴─────┴────────────┴───────────────┴─────────────┘
>>> with tempfile.NamedTemporaryFile(mode="w", suffix=".csv") as f:
... data_path = Path(f.name)
... data.write_csv(data_path)
Expand Down
29 changes: 14 additions & 15 deletions src/aces/query.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,20 +23,22 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
Returns:
polars.DataFrame: The result of the task query, containing subjects who satisfy the conditions
defined in cfg. Timestamps for the start/end boundaries of each window specified in the task
configuration, as well as predicate counts for each window, are provided.
defined in cfg. Timestamps for the start/end boundaries of each window specified in the task
configuration, as well as predicate counts for each window, are provided.
Raises:
TypeError: If predicates_df is not a polars.DataFrame.
ValueError: If the (subject_id, timestamp) columns are not unique.
"""
if not isinstance(predicates_df, pl.DataFrame):
raise TypeError(f"Predicates dataframe type must be a polars.DataFrame. Got: {type(predicates_df)}.")

logger.info("Checking if '(subject_id, timestamp)' columns are unique...")
try:
assert (
predicates_df.n_unique(subset=["subject_id", "timestamp"]) == predicates_df.shape[0]
), "The (subject_id, timestamp) columns must be unique."
except AssertionError as e:
logger.error(str(e))
return pl.DataFrame()

is_unique = predicates_df.n_unique(subset=["subject_id", "timestamp"]) == predicates_df.shape[0]

if not is_unique:
raise ValueError("The (subject_id, timestamp) columns must be unique.")

log_tree(cfg.window_tree)

Expand All @@ -45,12 +47,9 @@ def query(cfg: TaskExtractorConfig, predicates_df: pl.DataFrame) -> pl.DataFrame
prospective_root_anchors = check_constraints({cfg.trigger.predicate: (1, None)}, predicates_df).select(
"subject_id", pl.col("timestamp").alias("subtree_anchor_timestamp")
)
try:
assert (
not prospective_root_anchors.is_empty()
), f"No valid rows found for the trigger event '{cfg.trigger.predicate}'. Exiting."
except AssertionError as e:
logger.error(str(e))

if prospective_root_anchors.is_empty():
logger.warning(f"No valid rows found for the trigger event '{cfg.trigger.predicate}'. Exiting.")
return pl.DataFrame()

result = extract_subtree(cfg.window_tree, prospective_root_anchors, predicates_df)
Expand Down

0 comments on commit 7e72c8c

Please sign in to comment.