Skip to content

Commit d19cb40

Browse files
authored
Remove floating point types from radix sort fast-path (#7215)
Closes #7212. Reference: #7167 (comment). Using radix sort for all fixed-width types causes an [error in Spark when floating point columns contain NaN elements](NVIDIA/spark-rapids#1585). This PR removes floating-point column types from the radix fast-path, so the original `relational_compare` row operator is used to sort floating-point columns, since they may contain NaN elements. The existing `NANSorting` gtest included null elements, so it was routed through the comparator path and did not catch the fast-path output discrepancy. This PR adds a `NANSortingNonNull` gtest to check for the desired NaN sorting behavior. Authors: - David (@davidwendt) Approvers: - Jake Hemstad (@jrhemstad) - Conor Hoekstra (@codereport) URL: #7215
1 parent ccf4ffa commit d19cb40

File tree

2 files changed

+40
-7
lines changed

2 files changed

+40
-7
lines changed

cpp/src/sort/sort_column.cu

Lines changed: 18 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,17 @@ namespace {
2424
* @brief Type-dispatched functor for sorting a single column.
2525
*/
2626
struct column_sorted_order_fn {
27+
/**
28+
* @brief Compile time check for allowing radix sort for column type.
29+
*
30+
* Floating point is removed here for special handling of NaNs.
* Columns whose type is rejected by this check are sorted through the
* comparator-based path instead of the radix fast-path.
*
* @tparam T Column element type to check.
* @return `true` if `T` is a fixed-width type that is not floating point,
*         `false` otherwise.
31+
*/
32+
template <typename T>
33+
static constexpr bool is_radix_sort_supported()
34+
{
35+
return cudf::is_fixed_width<T>() && !cudf::is_floating_point<T>();
36+
}
37+
2738
/**
2839
* @brief Sorts fixed-width columns using faster thrust sort.
2940
*
@@ -32,15 +43,15 @@ struct column_sorted_order_fn {
3243
* @param ascending True if sort order is ascending
3344
* @param stream CUDA stream used for device memory operations and kernel launches
3445
*/
35-
template <typename T, typename std::enable_if_t<cudf::is_fixed_width<T>()>* = nullptr>
46+
template <typename T, typename std::enable_if_t<is_radix_sort_supported<T>()>* = nullptr>
3647
void radix_sort(column_view const& input,
3748
mutable_column_view& indices,
3849
bool ascending,
3950
rmm::cuda_stream_view stream)
4051
{
41-
// A non-stable sort on a fixed-width column with no nulls will use a radix sort
42-
// if using only the thrust::less or thrust::greater comparators but also
43-
// requires making a copy of the input data.
52+
// A non-stable sort on a column of arithmetic type with no nulls will use a radix sort
53+
// if specifying only the `thrust::less` or `thrust::greater` comparators.
54+
// But this also requires making a copy of the input data.
4455
auto temp_col = column(input, stream);
4556
auto d_col = temp_col.mutable_view();
4657
using DeviceT = device_storage_type_t<T>;
@@ -58,7 +69,7 @@ struct column_sorted_order_fn {
5869
thrust::greater<DeviceT>());
5970
}
6071
}
61-
template <typename T, typename std::enable_if_t<!cudf::is_fixed_width<T>()>* = nullptr>
72+
template <typename T, typename std::enable_if_t<!is_radix_sort_supported<T>()>* = nullptr>
6273
void radix_sort(column_view const&, mutable_column_view&, bool, rmm::cuda_stream_view)
6374
{
6475
CUDF_FAIL("Only fixed-width types are suitable for faster sorting");
@@ -83,8 +94,8 @@ struct column_sorted_order_fn {
8394
null_order null_precedence,
8495
rmm::cuda_stream_view stream)
8596
{
86-
// column with nulls or non-fixed-width column will also use a comparator
87-
if (input.has_nulls() || !cudf::is_fixed_width<T>()) {
97+
// column with nulls or non-supported types will also use a comparator
98+
if (input.has_nulls() || !is_radix_sort_supported<T>()) {
8899
auto keys = column_device_view::create(input, stream);
89100
thrust::sort(rmm::exec_policy(stream),
90101
indices.begin<size_type>(),

cpp/tests/table/row_operators_tests.cpp

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,3 +65,25 @@ TEST_F(RowOperatorTestForNAN, NANSorting)
6565

6666
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected2, got2->view());
6767
}
68+
69+
// Verifies cudf::sorted_order on a single double column that contains NaN and
// +/-infinity and is built without a validity mask (no null elements).
// Input by row index: 0 -> 0, 1 -> NaN, 2 -> -1, 3 -> 7, 4 -> +inf, 5 -> 1, 6 -> -inf.
TEST_F(RowOperatorTestForNAN, NANSortingNonNull)
70+
{
71+
cudf::test::fixed_width_column_wrapper<double> input{
72+
{0.,
73+
double(NAN),
74+
-1.,
75+
7.,
76+
std::numeric_limits<double>::infinity(),
77+
1.,
78+
-1 * std::numeric_limits<double>::infinity()}};
79+
80+
cudf::table_view input_table{{input}};
81+
82+
auto result = cudf::sorted_order(input_table, {cudf::order::ASCENDING});
83+
// Expected ascending order of values: -inf, -1, 0, 1, 7, +inf, NaN
// (NaN is expected to be placed after +infinity).
cudf::test::fixed_width_column_wrapper<int32_t> expected_asc{{6, 2, 0, 5, 3, 4, 1}};
84+
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_asc, result->view());
85+
86+
result = cudf::sorted_order(input_table, {cudf::order::DESCENDING});
87+
// Expected descending order is the exact reverse: NaN, +inf, 7, 1, 0, -1, -inf
// (NaN is expected to be placed before +infinity).
cudf::test::fixed_width_column_wrapper<int32_t> expected_desc{{1, 4, 3, 5, 0, 2, 6}};
88+
CUDF_TEST_EXPECT_COLUMNS_EQUAL(expected_desc, result->view());
89+
}

0 commit comments

Comments
 (0)