From b0349370b5c75aef72a3051ed2cb6a3f45b552d9 Mon Sep 17 00:00:00 2001 From: Lawrence Mitchell Date: Fri, 4 Oct 2024 17:02:47 +0000 Subject: [PATCH] WIP: Implement inequality joins by translating to cross + filter Before working through the plumbing in pylibcudf for mixed and conditional joins and the AST evaluator, let's just support inequality joins by doing the basic thing: a cross join followed by a filter. --- .../cudf_polars/cudf_polars/dsl/translate.py | 22 +++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/python/cudf_polars/cudf_polars/dsl/translate.py b/python/cudf_polars/cudf_polars/dsl/translate.py index a0291037f01..6037f179ea0 100644 --- a/python/cudf_polars/cudf_polars/dsl/translate.py +++ b/python/cudf_polars/cudf_polars/dsl/translate.py @@ -5,6 +5,7 @@ from __future__ import annotations +import functools import json from contextlib import AbstractContextManager, nullcontext from functools import singledispatch @@ -182,6 +183,27 @@ def _( with set_node(visitor, node.input_right): inp_right = translate_ir(visitor, n=None) right_on = [translate_named_expr(visitor, n=e) for e in node.right_on] + if node.options[0] == "inequality": + # No exposure of mixed/conditional joins in pylibcudf yet, so in + # the first instance, implement by doing a cross join followed by + # a filter. + _, *options, op1, op2 = node.options + cross = ir.Join(schema, inp_left, inp_right, [], [], ("cross", *options)) + dtype = plc.DataType(plc.TypeId.BOOL8) + if op2 is None: + ops = [op1] + else: + ops = [op1, op2] + mask = functools.reduce( + functools.partial( + expr.BinOp, dtype, plc.binaryop.BinaryOperator.LOGICAL_AND + ), + ( + expr.BinOp(dtype, expr.BinOp._MAPPING[op], left.value, right.value) + for op, left, right in zip(ops, left_on, right_on, strict=True) + ), + ) + return ir.Filter(schema, cross, expr.NamedExpr("mask", mask)) return ir.Join(schema, inp_left, inp_right, left_on, right_on, node.options)