Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Sample hypothesis testing. #64

Merged
merged 1 commit into from
Jul 28, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ build-backend = "setuptools.build_meta"

[project.optional-dependencies]
dev = [
"pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils"
"pre-commit", "pytest", "pytest-cov", "pytest-subtests", "rootutils", "hypothesis"
]
profiling = ["psutil"]

Expand Down
117 changes: 117 additions & 0 deletions tests/test_aggregate_hypothesis.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import rootutils

# NOTE: setup_root must run *before* the `aces` imports below — it is called with
# pythonpath=True / cwd=True, presumably so the project package resolves when tests
# are run from an arbitrary directory (TODO confirm). This is why these imports
# intentionally violate E402 (module-level import not at top of file).
root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)

from datetime import datetime, timedelta

import polars as pl
import polars.selectors as cs
from hypothesis import given, settings
from hypothesis import strategies as st
from polars.testing import assert_series_equal
from polars.testing.parametric import column, dataframes

from aces.aggregate import aggregate_temporal_window
from aces.types import TemporalWindowBounds

Comment on lines +1 to +16
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Reorder imports to adhere to Python's convention.

The imports should be placed at the top of the file before any other code, including the setup of root utilities. This follows Python's convention for better readability and to avoid potential issues with import shadowing.

- import rootutils
- root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)
+ import rootutils
+ from datetime import datetime, timedelta
+ import polars as pl
+ import polars.selectors as cs
+ from hypothesis import given, settings
+ from hypothesis import strategies as st
+ from polars.testing import assert_series_equal
+ from polars.testing.parametric import column, dataframes
+ from aces.aggregate import aggregate_temporal_window
+ from aces.types import TemporalWindowBounds
+ root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)
Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
import rootutils
root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)
from datetime import datetime, timedelta
import polars as pl
import polars.selectors as cs
from hypothesis import given, settings
from hypothesis import strategies as st
from polars.testing import assert_series_equal
from polars.testing.parametric import column, dataframes
from aces.aggregate import aggregate_temporal_window
from aces.types import TemporalWindowBounds
import rootutils
from datetime import datetime, timedelta
import polars as pl
import polars.selectors as cs
from hypothesis import given, settings
from hypothesis import strategies as st
from polars.testing import assert_series_equal
from polars.testing.parametric import column, dataframes
from aces.aggregate import aggregate_temporal_window
from aces.types import TemporalWindowBounds
root = rootutils.setup_root(__file__, dotenv=True, pythonpath=True, cwd=True)
Tools
Ruff

5-5: Module level import not at top of file (E402)


7-7: Module level import not at top of file (E402)


8-8: Module level import not at top of file (E402)


9-9: Module level import not at top of file (E402)


10-10: Module level import not at top of file (E402)


11-11: Module level import not at top of file (E402)


12-12: Module level import not at top of file (E402)


14-14: Module level import not at top of file (E402)


15-15: Module level import not at top of file (E402)

# Strategy for event timestamps: any datetime in [1989-12-01, 1999-12-31].
datetime_st = st.datetimes(min_value=datetime(1989, 12, 1), max_value=datetime(1999, 12, 31))

# Number of synthetic predicate columns generated per dataframe.
N_PREDICATES = 5

# Column specs assembled up front: a subject ID, a timestamp drawn from
# `datetime_st`, and N_PREDICATES small unsigned predicate-count columns.
_predicate_cols = [
    column("subject_id", allow_null=False, dtype=pl.UInt32),
    column("timestamp", allow_null=False, dtype=pl.Datetime("ms"), strategy=datetime_st),
]
_predicate_cols.extend(
    column(f"predicate_{i}", allow_null=False, dtype=pl.UInt8) for i in range(1, N_PREDICATES + 1)
)

# Hypothesis strategy producing small (1–50 row) predicate dataframes.
PREDICATE_DATAFRAMES = dataframes(cols=_predicate_cols, min_size=1, max_size=50)


@given(
    df=PREDICATE_DATAFRAMES,
    left_inclusive=st.booleans(),
    right_inclusive=st.booleans(),
    window_size=st.timedeltas(min_value=timedelta(days=1), max_value=timedelta(days=365 * 5)),
    offset=st.timedeltas(min_value=timedelta(days=0), max_value=timedelta(days=365)),
)
@settings(max_examples=50)
def test_aggregate_temporal_window(
    df: pl.DataFrame, left_inclusive: bool, right_inclusive: bool, window_size: timedelta, offset: timedelta
):
    """Tests that `aggregate_temporal_window` produces a consistent output.

    For every aggregated row, the reported predicate sums must equal the sums recomputed
    directly from the raw rows that fall inside that row's window
    (``timestamp_at_start`` .. ``timestamp_at_end``), honoring the same
    ``left_inclusive`` / ``right_inclusive`` endpoint semantics under test.
    """

    # Collapse subject IDs into a small set so each subject has several rows to aggregate.
    max_n_subjects = 3
    df = df.with_columns(
        (pl.col("subject_id") % max_n_subjects).alias("subject_id"),
        cs.starts_with("predicate_").cast(pl.Int32).name.keep(),
    ).sort("subject_id", "timestamp")

    endpoint_expr = TemporalWindowBounds(
        left_inclusive=left_inclusive, right_inclusive=right_inclusive, window_size=window_size, offset=offset
    )

    # Should run without error and yield a collectible lazy frame:
    agg_df = aggregate_temporal_window(df.lazy(), endpoint_expr)
    assert agg_df is not None
    agg_df = agg_df.collect()

    # This will return something of the below form:
    #
    # shape: (6, 7)
    # ┌────────────┬─────────────────────┬─────────────────────┬─────────────────────┬──────┬──────┬──────┐
    # │ subject_id ┆ timestamp           ┆ timestamp_at_start  ┆ timestamp_at_end    ┆ is_A ┆ is_B ┆ is_C │
    # │ ---        ┆ ---                 ┆ ---                 ┆ ---                 ┆ ---  ┆ ---  ┆ ---  │
    # │ i64        ┆ datetime[μs]        ┆ datetime[μs]        ┆ datetime[μs]        ┆ i64  ┆ i64  ┆ i64  │
    # ╞════════════╪═════════════════════╪═════════════════════╪═════════════════════╪══════╪══════╪══════╡
    # │ 1          ┆ 1989-12-01 12:03:00 ┆ 1989-12-02 12:03:00 ┆ 1989-12-01 12:03:00 ┆ 1    ┆ 1    ┆ 2    │
    # │ 1          ┆ 1989-12-02 05:17:00 ┆ 1989-12-03 05:17:00 ┆ 1989-12-02 05:17:00 ┆ 1    ┆ 1    ┆ 1    │
    # │ 1          ┆ 1989-12-02 12:03:00 ┆ 1989-12-03 12:03:00 ┆ 1989-12-02 12:03:00 ┆ 1    ┆ 0    ┆ 0    │
    # │ 1          ┆ 1989-12-06 11:00:00 ┆ 1989-12-07 11:00:00 ┆ 1989-12-06 11:00:00 ┆ 0    ┆ 1    ┆ 0    │
    # │ 2          ┆ 1989-12-01 13:14:00 ┆ 1989-12-02 13:14:00 ┆ 1989-12-01 13:14:00 ┆ 0    ┆ 1    ┆ 1    │
    # │ 2          ┆ 1989-12-03 15:17:00 ┆ 1989-12-04 15:17:00 ┆ 1989-12-03 15:17:00 ┆ 0    ┆ 0    ┆ 0    │
    # └────────────┴─────────────────────┴─────────────────────┴─────────────────────┴──────┴──────┴──────┘
    #
    # We're going to validate this by asserting that the sums of the predicate columns between the rows
    # for a given subject are consistent.

    # Structural checks: aggregation preserves the input columns/rows and adds exactly
    # the two window-boundary columns.
    assert set(df.columns).issubset(set(agg_df.columns))
    assert len(agg_df.columns) == len(df.columns) + 2
    assert "timestamp_at_start" in agg_df.columns
    assert "timestamp_at_end" in agg_df.columns
    assert_series_equal(agg_df["subject_id"], df["subject_id"])
    assert_series_equal(agg_df["timestamp"], df["timestamp"])

    # Semantic check: per subject, re-derive each row's window sums from the raw data
    # and compare against the aggregated output.
    for subject_id in range(max_n_subjects):
        if subject_id not in df["subject_id"]:
            # Subjects absent from the input must also be absent from the output.
            assert subject_id not in agg_df["subject_id"]
            continue

        raw_subj = df.filter(pl.col("subject_id") == subject_id)
        agg_subj = agg_df.filter(pl.col("subject_id") == subject_id)

        for row in agg_subj.iter_rows(named=True):
            start = row["timestamp_at_start"]
            end = row["timestamp_at_end"]

            # Rebuild the window membership filter with the same endpoint inclusivity.
            st_filter = (pl.col("timestamp") >= start) if left_inclusive else (pl.col("timestamp") > start)
            et_filter = (pl.col("timestamp") <= end) if right_inclusive else (pl.col("timestamp") < end)

            raw_filtered = raw_subj.filter(st_filter & et_filter)
            if len(raw_filtered) == 0:
                # An empty window may be reported as null or zero counts.
                for i in range(1, N_PREDICATES + 1):
                    # TODO: Is this right? Or should it always be one or the other?
                    assert (row[f"predicate_{i}"] is None) or (row[f"predicate_{i}"] == 0)
            else:
                raw_sums = raw_filtered.select(cs.starts_with("predicate_")).sum()
                for i in range(1, N_PREDICATES + 1):
                    assert raw_sums[f"predicate_{i}"].item() == row[f"predicate_{i}"]
mmcdermott marked this conversation as resolved.
Show resolved Hide resolved
Loading