Skip to content

Commit dc62177

Browse files
authored
Preserve order in left join for cudf-polars (#16268)
Unlike all other joins, polars provides an ordering guarantee for left joins: the result preserves the order of the left keys and is stable with respect to the right keys. libcudf does not guarantee any output order by default, so the gather maps must be reordered in this case. While here, because doing so would require yet another hard-coding of `int32` for something that should be `size_type`, expose `type_to_id` in Cython and plumb it through.
Authors:
- Lawrence Mitchell (https://github.com/wence-)
Approvers:
- Vyas Ramasubramani (https://github.com/vyasr)
URL: #16268
1 parent d5ab48d commit dc62177

File tree

7 files changed

+78
-18
lines changed

7 files changed

+78
-18
lines changed

python/cudf/cudf/_lib/pylibcudf/join.pyx

Lines changed: 3 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -10,28 +10,19 @@ from rmm._lib.device_buffer cimport device_buffer
1010
from cudf._lib.pylibcudf.libcudf cimport join as cpp_join
1111
from cudf._lib.pylibcudf.libcudf.column.column cimport column
1212
from cudf._lib.pylibcudf.libcudf.table.table cimport table
13-
from cudf._lib.pylibcudf.libcudf.types cimport (
14-
data_type,
15-
null_equality,
16-
size_type,
17-
type_id,
18-
)
13+
from cudf._lib.pylibcudf.libcudf.types cimport null_equality
1914

2015
from .column cimport Column
2116
from .table cimport Table
2217

2318

2419
cdef Column _column_from_gather_map(cpp_join.gather_map_type gather_map):
2520
# helper to convert a gather map to a Column
26-
cdef device_buffer c_empty
27-
cdef size_type size = dereference(gather_map.get()).size()
2821
return Column.from_libcudf(
2922
move(
3023
make_unique[column](
31-
data_type(type_id.INT32),
32-
size,
33-
dereference(gather_map.get()).release(),
34-
move(c_empty),
24+
move(dereference(gather_map.get())),
25+
device_buffer(),
3526
0
3627
)
3728
)

python/cudf/cudf/_lib/pylibcudf/libcudf/utilities/type_dispatcher.pxd

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
# Copyright (c) 2024, NVIDIA CORPORATION.
2+
3+
from cudf._lib.pylibcudf.libcudf.types cimport type_id
4+
5+
6+
cdef extern from "cudf/utilities/type_dispatcher.hpp" namespace "cudf" nogil:
7+
cdef type_id type_to_id[T]()

python/cudf/cudf/_lib/pylibcudf/types.pyx

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,8 @@
22

33
from libc.stdint cimport int32_t
44

5-
from cudf._lib.pylibcudf.libcudf.types cimport data_type, type_id
5+
from cudf._lib.pylibcudf.libcudf.types cimport data_type, size_type, type_id
6+
from cudf._lib.pylibcudf.libcudf.utilities.type_dispatcher cimport type_to_id
67

78
from cudf._lib.pylibcudf.libcudf.types import type_id as TypeId # no-cython-lint, isort:skip
89
from cudf._lib.pylibcudf.libcudf.types import nan_policy as NanPolicy # no-cython-lint, isort:skip
@@ -67,3 +68,7 @@ cdef class DataType:
6768
cdef DataType ret = DataType.__new__(DataType, type_id.EMPTY)
6869
ret.c_obj = dt
6970
return ret
71+
72+
73+
SIZE_TYPE = DataType(type_to_id[size_type]())
74+
SIZE_TYPE_ID = SIZE_TYPE.id()

python/cudf/cudf/_lib/types.pyx

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,6 @@ from cudf._lib.types cimport (
2121
import cudf
2222
from cudf._lib import pylibcudf
2323

24-
size_type_dtype = np.dtype("int32")
25-
2624

2725
class TypeId(IntEnum):
2826
EMPTY = <underlying_type_t_type_id> libcudf_types.type_id.EMPTY
@@ -150,6 +148,8 @@ datetime_unit_map = {
150148
TypeId.TIMESTAMP_NANOSECONDS: "ns",
151149
}
152150

151+
size_type_dtype = LIBCUDF_TO_SUPPORTED_NUMPY_TYPES[pylibcudf.types.SIZE_TYPE_ID]
152+
153153

154154
class Interpolation(IntEnum):
155155
LINEAR = (

python/cudf_polars/cudf_polars/containers/column.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -185,8 +185,7 @@ def nan_count(self) -> int:
185185
plc.reduce.reduce(
186186
plc.unary.is_nan(self.obj),
187187
plc.aggregation.sum(),
188-
# TODO: pylibcudf needs to have a SizeType DataType singleton
189-
plc.DataType(plc.TypeId.INT32),
188+
plc.types.SIZE_TYPE,
190189
)
191190
).as_py()
192191
return 0

python/cudf_polars/cudf_polars/dsl/ir.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -653,6 +653,59 @@ def _joiners(
653653
else:
654654
assert_never(how)
655655

656+
def _reorder_maps(
657+
self,
658+
left_rows: int,
659+
lg: plc.Column,
660+
left_policy: plc.copying.OutOfBoundsPolicy,
661+
right_rows: int,
662+
rg: plc.Column,
663+
right_policy: plc.copying.OutOfBoundsPolicy,
664+
) -> list[plc.Column]:
665+
"""
666+
Reorder gather maps to satisfy polars join order restrictions.
667+
668+
Parameters
669+
----------
670+
left_rows
671+
Number of rows in left table
672+
lg
673+
Left gather map
674+
left_policy
675+
Nullify policy for left map
676+
right_rows
677+
Number of rows in right table
678+
rg
679+
Right gather map
680+
right_policy
681+
Nullify policy for right map
682+
683+
Returns
684+
-------
685+
list of reordered left and right gather maps.
686+
687+
Notes
688+
-----
689+
For a left join, the polars result preserves the order of the
690+
left keys, and is stable wrt the right keys. For all other
691+
joins, there is no order obligation.
692+
"""
693+
dt = plc.interop.to_arrow(plc.types.SIZE_TYPE)
694+
init = plc.interop.from_arrow(pa.scalar(0, type=dt))
695+
step = plc.interop.from_arrow(pa.scalar(1, type=dt))
696+
left_order = plc.copying.gather(
697+
plc.Table([plc.filling.sequence(left_rows, init, step)]), lg, left_policy
698+
)
699+
right_order = plc.copying.gather(
700+
plc.Table([plc.filling.sequence(right_rows, init, step)]), rg, right_policy
701+
)
702+
return plc.sorting.stable_sort_by_key(
703+
plc.Table([lg, rg]),
704+
plc.Table([*left_order.columns(), *right_order.columns()]),
705+
[plc.types.Order.ASCENDING, plc.types.Order.ASCENDING],
706+
[plc.types.NullOrder.AFTER, plc.types.NullOrder.AFTER],
707+
).columns()
708+
656709
def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
657710
"""Evaluate and return a dataframe."""
658711
left = self.left.evaluate(cache=cache)
@@ -693,6 +746,11 @@ def evaluate(self, *, cache: MutableMapping[int, DataFrame]) -> DataFrame:
693746
result = DataFrame.from_table(table, left.column_names)
694747
else:
695748
lg, rg = join_fn(left_on.table, right_on.table, null_equality)
749+
if how == "left":
750+
# Order of left table is preserved
751+
lg, rg = self._reorder_maps(
752+
left.num_rows, lg, left_policy, right.num_rows, rg, right_policy
753+
)
696754
if coalesce and how == "inner":
697755
right = right.discard_columns(right_on.column_names_set)
698756
left = DataFrame.from_table(

python/cudf_polars/tests/test_join.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def test_join(how, coalesce, join_nulls, join_expr):
5353
query = left.join(
5454
right, on=join_expr, how=how, join_nulls=join_nulls, coalesce=coalesce
5555
)
56-
assert_gpu_result_equal(query, check_row_order=False)
56+
assert_gpu_result_equal(query, check_row_order=how == "left")
5757

5858

5959
def test_cross_join():

0 commit comments

Comments
 (0)