Skip to content

Commit

Permalink
test and not-quite-right fix for assertion error in specialize()
Browse files Browse the repository at this point in the history
  • Loading branch information
rebkwok committed Jan 10, 2025
1 parent 32f477b commit 9a56cae
Show file tree
Hide file tree
Showing 2 changed files with 48 additions and 7 deletions.
14 changes: 8 additions & 6 deletions ehrql/dummy_data_nextgen/query_info.py
Original file line number Diff line number Diff line change
Expand Up @@ -328,12 +328,14 @@ def specialize(query, column) -> Node | None:
if lhs is None or rhs is None:
return None
return type(comp)(lhs=lhs, rhs=rhs)
case SelectColumn() as q:
if column == q:
assert len(columns_for_query(q)) == 1
return q
else:
return None
case SelectColumn() as q1:
# a SelectColumn() query can be a simple select from a SelectTable source,
# but if it is from an EventTable, it can will be a PickOneRowPerPatient
# with a Sort on a source SelectTable.
if set(columns_for_query(q1)) == {column}:
assert len(columns_for_query(q1)) == 1
return q1
return None
case _:
fields = query.__dataclass_fields__
specialized = {}
Expand Down
41 changes: 40 additions & 1 deletion tests/unit/dummy_data_nextgen/test_query_info.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import datetime

from ehrql import Dataset, days
from ehrql import Dataset, case, days, when
from ehrql.codes import CTV3Code
from ehrql.dummy_data_nextgen.query_info import ColumnInfo, QueryInfo, TableInfo
from ehrql.tables import (
Expand Down Expand Up @@ -32,6 +32,16 @@ class events(EventFrame):
code = Series(CTV3Code)


@table
class address(PatientFrame):
start_date = Series(datetime.date)


@table
class addresses(EventFrame):
start_date = Series(datetime.date)


def test_query_info_from_dataset():
dataset = Dataset()
dataset.define_population(events.exists_for_patient())
Expand Down Expand Up @@ -147,3 +157,32 @@ def test_query_info_ignores_complex_comparisons():
column_info = query_info.tables["patients"].columns["date_of_birth"]

assert column_info.values_used == [datetime.date(2022, 10, 5)]


def test_query_info_with_nested_case_statements():
# This test reproduces an error encountered in real-world ehrQL which was over-using
# case statements for boolean series (i.e. using case to return True/False on an
# already bool series, and then also filtering by == True/False on that case statement).
# QueryInfo.specialize turns those sorts of queries into e.g. EQ(lhs=Value(True), rhs=Value(False))
# (with no column references).
# This exposed a bug where a SelectColumn on an EventTable rather than a PatientTable
# was reduced to None, and we ended up with a resulting query with no column references
# in it.
dataset = Dataset()

has_dob = case(
when(patients.date_of_birth.is_not_null()).then(True), otherwise=False
)
first_date = events.sort_by(events.date).first_for_patient().date

query = case(
when(
(has_dob == False) | ((has_dob == True) & (first_date == "2020-01-01"))
).then(True),
otherwise=False,
)
dataset.define_population(patients.exists_for_patient() & query)

query_info = QueryInfo.from_dataset(dataset._compile())
column_info = query_info.tables["events"].columns["date"]
assert column_info.values_used == [datetime.date(2020, 1, 1)]

0 comments on commit 9a56cae

Please sign in to comment.