diff --git a/dask_expr/_merge.py b/dask_expr/_merge.py index cacc25c8..20e3fc81 100644 --- a/dask_expr/_merge.py +++ b/dask_expr/_merge.py @@ -9,6 +9,7 @@ from dask_expr._expr import Blockwise, Expr, Index, PartitionsFiltered, Projection from dask_expr._repartition import Repartition from dask_expr._shuffle import AssignPartitioningIndex, Shuffle, _contains_index_name +from dask_expr._util import _convert_to_list _HASH_COLUMN_NAME = "__hash_partition" @@ -203,13 +204,20 @@ def _simplify_up(self, parent): projection = [projection] left, right = self.left, self.right - left_on, right_on = self.left_on, self.right_on + left_on = _convert_to_list(self.left_on) + if left_on is None: + left_on = [] + + right_on = _convert_to_list(self.right_on) + if right_on is None: + right_on = [] + left_suffix, right_suffix = self.suffixes[0], self.suffixes[1] project_left, project_right = [], [] # Find columns to project on the left for col in left.columns: - if left_on is not None and col in left_on or col in projection: + if col in left_on or col in projection: project_left.append(col) elif f"{col}{left_suffix}" in projection: project_left.append(col) @@ -220,7 +228,7 @@ def _simplify_up(self, parent): # Find columns to project on the right for col in right.columns: - if right_on is not None and col in right_on or col in projection: + if col in right_on or col in projection: project_right.append(col) elif f"{col}{right_suffix}" in projection: project_right.append(col) diff --git a/dask_expr/tests/test_merge.py b/dask_expr/tests/test_merge.py index 96365473..dacb92b3 100644 --- a/dask_expr/tests/test_merge.py +++ b/dask_expr/tests/test_merge.py @@ -164,3 +164,15 @@ def test_merge_len(): query = df.merge(df2).index.optimize(fuse=False) expected = df[["x"]].merge(df2[["x"]]).index.optimize(fuse=False) assert query._name == expected._name + + +def test_merge_optimize_subset_strings(): + pdf = lib.DataFrame({"a": [1, 2], "aaa": 1}) + pdf2 = lib.DataFrame({"b": [1, 2], "aaa": 1}) + df = from_pandas(pdf) + df2 = from_pandas(pdf2) + + query = df.merge(df2, on="aaa")[["aaa"]].optimize(fuse=False) + exp = df[["aaa"]].merge(df2[["aaa"]], on="aaa").optimize(fuse=False) + assert query._name == exp._name + assert_eq(query, pdf.merge(pdf2, on="aaa")[["aaa"]])