Merge pull request #1146 from lsst/tickets/DM-46248
DM-46248: Few small improvements for the new query system
andy-slac authored Jan 31, 2025
2 parents d858534 + 5903c60 commit 440f5ac
Showing 14 changed files with 174 additions and 33 deletions.
47 changes: 35 additions & 12 deletions python/lsst/daf/butler/direct_query_driver/_driver.py
@@ -217,6 +217,7 @@ def execute(self, result_spec: ResultSpec, tree: qt.QueryTree) -> Iterator[Resul
final_columns=result_spec.get_result_columns(),
order_by=result_spec.order_by,
find_first_dataset=result_spec.find_first_dataset,
allow_duplicate_overlaps=result_spec.allow_duplicate_overlaps,
)
sql_select, sql_columns = builder.finish_select()
if result_spec.order_by:
@@ -290,12 +291,15 @@ def materialize(
tree: qt.QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
key: qt.MaterializationKey | None = None,
) -> qt.MaterializationKey:
# Docstring inherited.
if self._exit_stack is None:
raise RuntimeError("QueryDriver context must be entered before 'materialize' is called.")
plan = self.build_query(tree, qt.ColumnSet(dimensions))
plan = self.build_query(
tree, qt.ColumnSet(dimensions), allow_duplicate_overlaps=allow_duplicate_overlaps
)
# Current implementation ignores 'datasets' aside from remembering
# them, because figuring out what to put in the temporary table for
# them is tricky, especially if calibration collections are involved.
@@ -311,7 +315,9 @@ def materialize(
#
sql_select, _ = plan.finish_select(return_columns=False)
table = self._exit_stack.enter_context(
self.db.temporary_table(make_table_spec(plan.final_columns, self.db, plan.postprocessing))
self.db.temporary_table(
make_table_spec(plan.final_columns, self.db, plan.postprocessing, make_indices=True)
)
)
self.db.insert(table, select=sql_select)
if key is None:
@@ -401,7 +407,7 @@ def count(

def any(self, tree: qt.QueryTree, *, execute: bool, exact: bool) -> bool:
# Docstring inherited.
builder = self.build_query(tree, qt.ColumnSet(tree.dimensions))
builder = self.build_query(tree, qt.ColumnSet(tree.dimensions), allow_duplicate_overlaps=True)
if not all(d.collection_records for d in builder.joins_analysis.datasets.values()):
return False
if not execute:
@@ -447,6 +453,7 @@ def build_query(
order_by: Iterable[qt.OrderExpression] = (),
find_first_dataset: str | qt.AnyDatasetType | None = None,
analyze_only: bool = False,
allow_duplicate_overlaps: bool = False,
) -> QueryBuilder:
"""Convert a query description into a nearly-complete builder object
for the SQL version of that query.
@@ -470,6 +477,9 @@
builder, but do not call methods that build its SQL form. This can
be useful for obtaining diagnostic information about the query that
would be generated.
allow_duplicate_overlaps : `bool`, optional
If set to `True`, the query is allowed to generate non-distinct
rows for spatial overlaps.
Returns
-------
@@ -542,7 +552,7 @@ def build_query(
# SqlSelectBuilder and Postprocessing with spatial/temporal constraints
# potentially transformed by the dimensions manager (but none of the
# rest of the analysis reflected in that SqlSelectBuilder).
query_tree_analysis = self._analyze_query_tree(tree)
query_tree_analysis = self._analyze_query_tree(tree, allow_duplicate_overlaps)
# The "projection" columns differ from the final columns by not
# omitting any dimension keys (this keeps queries for different result
# types more similar during construction), including any columns needed
Expand Down Expand Up @@ -589,7 +599,7 @@ def build_query(
builder.apply_find_first(self)
return builder

def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
def _analyze_query_tree(self, tree: qt.QueryTree, allow_duplicate_overlaps: bool) -> QueryTreeAnalysis:
"""Analyze a `.queries.tree.QueryTree` as the first step in building
a SQL query.
@@ -603,6 +613,9 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
tree_analysis : `QueryTreeAnalysis`
Struct containing additional information needed to build the joins
stage of a query.
allow_duplicate_overlaps : `bool`
If set to `True`, the query is allowed to generate non-distinct
rows for spatial overlaps.
Notes
-----
Expand Down Expand Up @@ -632,6 +645,7 @@ def _analyze_query_tree(self, tree: qt.QueryTree) -> QueryTreeAnalysis:
tree.predicate,
tree.get_joined_dimension_groups(),
collection_analysis.calibration_dataset_types,
allow_duplicate_overlaps,
)
# Extract the data ID implied by the predicate; we can use the governor
# dimensions in that to constrain the collections we search for
@@ -799,13 +813,22 @@ def apply_initial_query_joins(
select_builder.joins, materialization_key, materialization_dimensions
)
)
# Process dataset joins (not including any union dataset).
for dataset_search in joins_analysis.datasets.values():
self.join_dataset_search(
select_builder.joins,
dataset_search,
joins_analysis.columns.dataset_fields[dataset_search.name],
)
# Process dataset joins (not including any union dataset). Dataset
# searches included in a materialization can be skipped unless we need
# something from their tables.
materialized_datasets = set()
for m_state in self._materializations.values():
materialized_datasets.update(m_state.datasets)
for dataset_type_name, dataset_search in joins_analysis.datasets.items():
if (
dataset_type_name not in materialized_datasets
or dataset_type_name in select_builder.columns.dataset_fields
):
self.join_dataset_search(
select_builder.joins,
dataset_search,
joins_analysis.columns.dataset_fields[dataset_search.name],
)
# Join in dimension element tables that we know we need relationships
# or columns from.
for element in joins_analysis.iter_mandatory(union_dataset_dimensions):
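
The `allow_duplicate_overlaps` flag threaded through `build_query()` above lets existence checks such as `any()` skip the DISTINCT that spatial-overlap joins normally require. Below is a toy sketch in plain Python (not butler API; all names invented) of why an overlap join yields duplicate rows and why an existence check can tolerate them:

visit_pixels = {"visit-1": {100, 101, 102}}  # visit -> htm7 pixels it touches
tract_pixels = {"tract-9": {101, 102, 200}}  # tract -> htm7 pixels it touches

# Joining via the common skypix dimension yields one row per shared pixel.
rows = [
    (visit, tract)
    for visit, v_pix in visit_pixels.items()
    for tract, t_pix in tract_pixels.items()
    for _ in v_pix & t_pix
]
print(rows)        # [('visit-1', 'tract-9'), ('visit-1', 'tract-9')]
print(bool(rows))  # True either way, so any()-style checks can allow duplicates
print(set(rows))   # {('visit-1', 'tract-9')} - what DISTINCT would return
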
26 changes: 24 additions & 2 deletions python/lsst/daf/butler/direct_query_driver/_sql_builders.py
@@ -37,6 +37,8 @@
import sqlalchemy

from .. import ddl
from ..dimensions import DimensionGroup
from ..dimensions._group import SortedSequenceSet
from ..nonempty_mapping import NonemptyMapping
from ..queries import tree as qt
from ._postprocessing import Postprocessing
@@ -638,7 +640,7 @@ def to_select_builder(


def make_table_spec(
columns: qt.ColumnSet, db: Database, postprocessing: Postprocessing | None
columns: qt.ColumnSet, db: Database, postprocessing: Postprocessing | None, *, make_indices: bool = False
) -> ddl.TableSpec:
"""Make a specification that can be used to create a table to store
this query's outputs.
@@ -652,18 +654,22 @@
postprocessing : `Postprocessing`
Struct representing post-query processing in Python, which may
require additional columns in the query results.
make_indices : `bool`, optional
If `True`, add indices for groups of columns.
Returns
-------
spec : `.ddl.TableSpec`
Table specification for this query's result columns (including
those from `postprocessing` and `SqlJoinsBuilder.special`).
"""
indices = _make_table_indices(columns.dimensions) if make_indices else []
results = ddl.TableSpec(
[
columns.get_column_spec(logical_table, field).to_sql_spec(name_shrinker=db.name_shrinker)
for logical_table, field in columns
]
],
indexes=indices,
)
if postprocessing:
for element in postprocessing.iter_missing(columns):
Expand All @@ -679,3 +685,19 @@ def make_table_spec(
ddl.FieldSpec(name=SqlSelectBuilder.EMPTY_COLUMNS_NAME, dtype=SqlSelectBuilder.EMPTY_COLUMNS_TYPE)
)
return results


def _make_table_indices(dimensions: DimensionGroup) -> list[ddl.IndexSpec]:
"""Make index specifications for the dimension key columns of a table,
one index per group of related dimensions (see the sketch below).
"""
index_columns: list[SortedSequenceSet] = []
for dimension in dimensions.required:
minimal_group = dimensions.universe[dimension].minimal_group.required

for idx in range(len(index_columns)):
if index_columns[idx] <= minimal_group:
index_columns[idx] = minimal_group
break
else:
index_columns.append(minimal_group)

return [ddl.IndexSpec(*columns) for columns in index_columns]
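
A standalone sketch of the subset-widening loop in `_make_table_indices`, using plain frozensets in place of `SortedSequenceSet`; the dimension names are illustrative only:

# Hypothetical minimal groups for required dimensions, e.g. "instrument",
# "visit", and "skymap" in some dimension universe.
minimal_groups = [
    frozenset({"instrument"}),
    frozenset({"instrument", "visit"}),
    frozenset({"skymap"}),
]
index_columns: list[frozenset[str]] = []
for group in minimal_groups:
    for idx in range(len(index_columns)):
        if index_columns[idx] <= group:  # existing index covers a subset: widen it
            index_columns[idx] = group
            break
    else:
        index_columns.append(group)  # unrelated columns: start a new index
print(index_columns)
# [frozenset({'instrument', 'visit'}), frozenset({'skymap'})]
# (element order within each frozenset may vary)
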
24 changes: 20 additions & 4 deletions python/lsst/daf/butler/queries/_query.py
@@ -106,6 +106,12 @@ def __init__(self, driver: QueryDriver, tree: QueryTree | None = None):
tree = make_identity_query_tree(driver.universe)
super().__init__(driver, tree)

# If ``_allow_duplicate_overlaps`` is set to `True`, the query is
# allowed to generate non-distinct rows for spatial overlaps. This is
# not part of the public API for now; it is intended for use by the
# graph builder as an optimization.
self._allow_duplicate_overlaps: bool = False

@property
def constraint_dataset_types(self) -> Set[str]:
"""The names of all dataset types joined into the query.
@@ -218,7 +224,11 @@ def data_ids(
dimensions = self._driver.universe.conform(dimensions)
if not dimensions <= self._tree.dimensions:
tree = tree.join_dimensions(dimensions)
result_spec = DataCoordinateResultSpec(dimensions=dimensions, include_dimension_records=False)
result_spec = DataCoordinateResultSpec(
dimensions=dimensions,
include_dimension_records=False,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return DataCoordinateQueryResults(self._driver, tree, result_spec)

def datasets(
@@ -284,6 +294,7 @@ def datasets(
storage_class_name=storage_class_name,
include_dimension_records=False,
find_first=find_first,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return DatasetRefQueryResults(self._driver, tree=query._tree, spec=spec)

@@ -308,7 +319,9 @@ def dimension_records(self, element: str) -> DimensionRecordQueryResults:
tree = self._tree
if element not in tree.dimensions.elements:
tree = tree.join_dimensions(self._driver.universe[element].minimal_group)
result_spec = DimensionRecordResultSpec(element=self._driver.universe[element])
result_spec = DimensionRecordResultSpec(
element=self._driver.universe[element], allow_duplicate_overlaps=self._allow_duplicate_overlaps
)
return DimensionRecordQueryResults(self._driver, tree, result_spec)

def general(
@@ -445,6 +458,7 @@ def general(
dimension_fields=dimension_fields_dict,
dataset_fields=dataset_fields_dict,
find_first=find_first,
allow_duplicate_overlaps=self._allow_duplicate_overlaps,
)
return GeneralQueryResults(self._driver, tree=tree, spec=result_spec)

@@ -495,7 +509,9 @@ def materialize(
dimensions = self._tree.dimensions
else:
dimensions = self._driver.universe.conform(dimensions)
key = self._driver.materialize(self._tree, dimensions, datasets)
key = self._driver.materialize(
self._tree, dimensions, datasets, allow_duplicate_overlaps=self._allow_duplicate_overlaps
)
tree = make_identity_query_tree(self._driver.universe).join_materialization(
key, dimensions=dimensions
)
@@ -508,7 +524,7 @@
"Expand the dimensions or drop this dataset type in the arguments to materialize to "
"avoid this error."
)
tree = tree.join_dataset(dataset_type_name, self._tree.datasets[dataset_type_name])
tree = tree.join_dataset(dataset_type_name, dataset_search)
return Query(self._driver, tree)

def join_dataset_search(
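
A hedged sketch of how the private flag is meant to flow end to end, assuming a configured repository at "repo" and that this daf_butler version exposes `Butler.query()` as a context manager; the attribute is deliberately non-public, so this is illustration rather than a supported recipe:

from lsst.daf.butler import Butler

butler = Butler("repo")  # hypothetical repository path
with butler.query() as query:
    query._allow_duplicate_overlaps = True  # graph-builder-only optimization
    # Result specs created from this Query now carry the flag, so the SQL
    # builder may drop DISTINCT from spatial-overlap subqueries.
    for data_id in query.data_ids(["visit", "tract"]):
        pass
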
4 changes: 4 additions & 0 deletions python/lsst/daf/butler/queries/driver.py
@@ -209,6 +209,7 @@ def materialize(
tree: QueryTree,
dimensions: DimensionGroup,
datasets: frozenset[str],
allow_duplicate_overlaps: bool = False,
) -> MaterializationKey:
"""Execute a query tree, saving results to temporary storage for use
in later queries.
@@ -222,6 +223,9 @@
datasets : `frozenset` [ `str` ]
Names of dataset types whose ID columns may be materialized. It
is implementation-defined whether they actually are.
allow_duplicate_overlaps : `bool`, optional
If set to `True`, the query is allowed to generate non-distinct
rows for spatial overlaps.
Returns
-------
5 changes: 5 additions & 0 deletions python/lsst/daf/butler/queries/result_specs.py
@@ -62,6 +62,11 @@ class ResultSpecBase(pydantic.BaseModel, ABC):
limit: int | None = None
"""Maximum number of rows to return, or `None` for no bound."""

allow_duplicate_overlaps: bool = False
"""If set to `True`, queries are allowed to return duplicate rows for
spatial overlaps.
"""

def validate_tree(self, tree: QueryTree) -> None:
"""Check that this result object is consistent with a query tree.
21 changes: 11 additions & 10 deletions python/lsst/daf/butler/registry/databases/postgresql.py
@@ -246,20 +246,21 @@ def _transaction(
# PostgreSQL actually considers SET TRANSACTION to be a
# fundamentally different statement from SET (they have their
# own distinct doc pages, at least).
if not (self.isWriteable() or for_temp_tables):
with closing(connection.connection.cursor()) as cursor:
# PostgreSQL permits writing to temporary tables inside
# read-only transactions, but it doesn't permit creating
# them.
with closing(connection.connection.cursor()) as cursor:
if not (self.isWriteable() or for_temp_tables):
cursor.execute("SET TRANSACTION READ ONLY")
cursor.execute("SET TIME ZONE 0")
else:
with closing(connection.connection.cursor()) as cursor:
# Make timestamps UTC, because we didn't use TIMESTAMPTZ
# for the column type. When we can tolerate a schema
# change, we should change that type and remove this
# line.
cursor.execute("SET TIME ZONE 0")
# Make timestamps UTC, because we didn't use TIMESTAMPTZ
# for the column type. When we can tolerate a schema
# change, we should change that type and remove this
# line.
cursor.execute("SET TIME ZONE 0")
# Using server-side cursors with complex queries frequently
# generates suboptimal query plans; setting
# cursor_tuple_fraction=1 helps in those cases.
cursor.execute("SET cursor_tuple_fraction = 1")
yield is_new, connection

@contextmanager
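
For reference, a minimal psycopg2 sketch (DSN hypothetical; assumes a reachable PostgreSQL server) of the session settings applied above. `cursor_tuple_fraction = 1` tells the planner to optimize for retrieving the whole result set rather than the first few rows, which suits the server-side cursors used for large query results:

import psycopg2

conn = psycopg2.connect("dbname=butler")  # hypothetical DSN
with conn.cursor() as cursor:
    cursor.execute("SET TIME ZONE 0")  # match the non-TIMESTAMPTZ timestamp columns
    cursor.execute("SET cursor_tuple_fraction = 1")  # plan for full retrieval
conn.close()
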
@@ -1635,7 +1635,7 @@ def _get_calibs_table(self, table: DynamicTables) -> sqlalchemy.Table:

def _create_case_expression_for_collections(
collections: Iterable[CollectionRecord], id_column: sqlalchemy.ColumnElement
) -> sqlalchemy.Case | sqlalchemy.Null:
) -> sqlalchemy.ColumnElement:
"""Return a SQLAlchemy Case expression that converts collection IDs to
collection names for the given set of collections.
@@ -1661,6 +1661,6 @@ def _create_case_expression_for_collections(
# cases, e.g. we start with a list of valid collections but they are
# all filtered out by higher-level code on the basis of collection
# summaries.
return sqlalchemy.null()
return sqlalchemy.cast(sqlalchemy.null(), sqlalchemy.String)

return sqlalchemy.case(mapping, value=id_column)
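
The switch from a bare `sqlalchemy.null()` to a typed cast matters because a selected column of SQLAlchemy's `NullType` carries no type information for result processing, which some backends handle poorly. A standalone sketch (SQLAlchemy only, illustrative):

import sqlalchemy

untyped = sqlalchemy.null()
typed = sqlalchemy.cast(sqlalchemy.null(), sqlalchemy.String)
print(untyped.type)  # NULL - SQLAlchemy's NullType, no concrete type
print(typed.type)    # VARCHAR - concrete string type for result columns
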
10 changes: 7 additions & 3 deletions python/lsst/daf/butler/registry/dimensions/static.py
@@ -487,9 +487,10 @@ def process_query_overlaps(
predicate: qt.Predicate,
join_operands: Iterable[DimensionGroup],
calibration_dataset_types: Set[str | qt.AnyDatasetType],
allow_duplicates: bool = False,
) -> tuple[qt.Predicate, SqlSelectBuilder, Postprocessing]:
overlaps_visitor = _CommonSkyPixMediatedOverlapsVisitor(
self._db, dimensions, calibration_dataset_types, self._overlap_tables
self._db, dimensions, calibration_dataset_types, self._overlap_tables, allow_duplicates
)
new_predicate = overlaps_visitor.run(predicate, join_operands)
return new_predicate, overlaps_visitor.builder, overlaps_visitor.postprocessing
@@ -1025,13 +1026,15 @@ def __init__(
dimensions: DimensionGroup,
calibration_dataset_types: Set[str | qt.AnyDatasetType],
overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]],
allow_duplicates: bool,
):
super().__init__(dimensions, calibration_dataset_types)
self.builder: SqlSelectBuilder = SqlJoinsBuilder(db=db).to_select_builder(qt.ColumnSet(dimensions))
self.postprocessing = Postprocessing()
self.common_skypix = dimensions.universe.commonSkyPix
self.overlap_tables: Mapping[str, tuple[sqlalchemy.Table, sqlalchemy.Table]] = overlap_tables
self.common_skypix_overlaps_done: set[DatabaseDimensionElement] = set()
self.allow_duplicates = allow_duplicates

def visit_spatial_constraint(
self,
@@ -1081,7 +1084,8 @@ def visit_spatial_constraint(
joins_builder.where(sqlalchemy.or_(*sql_where_or))
self.builder.join(
joins_builder.to_select_builder(
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(), distinct=True
qt.ColumnSet(element.minimal_group).drop_implied_dimension_keys(),
distinct=not self.allow_duplicates,
).into_joins_builder(postprocessing=None)
)
# Short circuit here since the SQL WHERE clause has already
@@ -1145,7 +1149,7 @@ def visit_spatial_join(
.join(self._make_common_skypix_overlap_joins_builder(b))
.to_select_builder(
qt.ColumnSet(a.minimal_group | b.minimal_group).drop_implied_dimension_keys(),
distinct=True,
distinct=not self.allow_duplicates,
)
.into_joins_builder(postprocessing=None)
)
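
The `distinct=not self.allow_duplicates` toggle above ultimately controls whether the overlap subquery is wrapped in SELECT DISTINCT. An illustrative SQLAlchemy sketch with a hypothetical overlap table:

import sqlalchemy

metadata = sqlalchemy.MetaData()
overlap = sqlalchemy.Table(
    "visit_htm7_overlap",
    metadata,
    sqlalchemy.Column("visit", sqlalchemy.BigInteger),
    sqlalchemy.Column("htm7", sqlalchemy.BigInteger),
)
allow_duplicates = False
stmt = sqlalchemy.select(overlap.c.visit)
if not allow_duplicates:
    stmt = stmt.distinct()  # dedupe rows repeated once per overlapping pixel
print(stmt)  # SELECT DISTINCT visit_htm7_overlap.visit FROM visit_htm7_overlap
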