Skip to content

Commit

Permalink
Merge projection selects too many columns (#283)
Browse files Browse the repository at this point in the history
  • Loading branch information
phofl authored Aug 30, 2023
1 parent 0c193ad commit ec30959
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 3 deletions.
14 changes: 11 additions & 3 deletions dask_expr/_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
from dask_expr._expr import Blockwise, Expr, Index, PartitionsFiltered, Projection
from dask_expr._repartition import Repartition
from dask_expr._shuffle import AssignPartitioningIndex, Shuffle, _contains_index_name
from dask_expr._util import _convert_to_list

_HASH_COLUMN_NAME = "__hash_partition"

Expand Down Expand Up @@ -203,13 +204,20 @@ def _simplify_up(self, parent):
projection = [projection]

left, right = self.left, self.right
left_on, right_on = self.left_on, self.right_on
left_on = _convert_to_list(self.left_on)
if left_on is None:
left_on = []

right_on = _convert_to_list(self.right_on)
if right_on is None:
right_on = []

left_suffix, right_suffix = self.suffixes[0], self.suffixes[1]
project_left, project_right = [], []

# Find columns to project on the left
for col in left.columns:
if left_on is not None and col in left_on or col in projection:
if col in left_on or col in projection:
project_left.append(col)
elif f"{col}{left_suffix}" in projection:
project_left.append(col)
Expand All @@ -220,7 +228,7 @@ def _simplify_up(self, parent):

# Find columns to project on the right
for col in right.columns:
if right_on is not None and col in right_on or col in projection:
if col in right_on or col in projection:
project_right.append(col)
elif f"{col}{right_suffix}" in projection:
project_right.append(col)
Expand Down
12 changes: 12 additions & 0 deletions dask_expr/tests/test_merge.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,3 +164,15 @@ def test_merge_len():
query = df.merge(df2).index.optimize(fuse=False)
expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False)
assert query._name == expected._name


def test_merge_optimize_subset_strings():
    """Check that merge projection pushdown compares key names exactly.

    The join key is passed as the plain string ``"aaa"`` while the left
    frame also has a column ``"a"``, whose name is a substring of the key.
    Selecting only ``"aaa"`` after the merge should optimize to the same
    expression as projecting both inputs down to ``"aaa"`` before merging.
    """
    left_pdf = lib.DataFrame({"a": [1, 2], "aaa": 1})
    right_pdf = lib.DataFrame({"b": [1, 2], "aaa": 1})
    left = from_pandas(left_pdf)
    right = from_pandas(right_pdf)

    # Projection applied after the merge; the optimizer is expected to push
    # it down into both inputs.
    optimized = left.merge(right, on="aaa")[["aaa"]].optimize(fuse=False)

    # Same query with the projection written out by hand on each side.
    expected = left[["aaa"]].merge(right[["aaa"]], on="aaa").optimize(fuse=False)

    # Identical names mean both forms optimized to the same expression.
    assert optimized._name == expected._name

    # The computed result must still agree with the eager reference result.
    assert_eq(optimized, left_pdf.merge(right_pdf, on="aaa")[["aaa"]])

0 comments on commit ec30959

Please sign in to comment.