Skip to content

Commit a3d0d08

Browse files
Use extended comparison operators to define weak order on real/complex FP types
We use extended comparison operators compatible with NumPy's behavior: https://numpy.org/devdocs/reference/generated/numpy.sort.html Specifically, we use [R, nan] block ordering for reals, and [(R, R), (R, nan), (nan, R), (nan, nan)] for complexes.
1 parent 08e5dac commit a3d0d08

File tree

1 file changed

+68
-12
lines changed

1 file changed

+68
-12
lines changed

dpctl/tensor/libtensor/source/sorting/sorting_common.hpp

Lines changed: 68 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,8 @@
2424

2525
#pragma once
2626

27-
#include "utils/math_utils.hpp"
27+
#include "sycl/sycl.hpp"
28+
#include <type_traits>
2829

2930
namespace dpctl
3031
{
@@ -33,44 +34,99 @@ namespace tensor
3334
namespace py_internal
3435
{
3536

36-
template <typename cT> struct ComplexLess
37+
namespace
3738
{
38-
bool operator()(const cT &v1, const cT &v2) const
39+
template <typename fpT> struct ExtendedRealFPLess
40+
{
41+
/* [R, nan] */
42+
bool operator()(const fpT v1, const fpT v2) const
3943
{
40-
using dpctl::tensor::math_utils::less_complex;
44+
return (!sycl::isnan(v1) && (sycl::isnan(v2) || (v1 < v2)));
45+
}
46+
};
4147

42-
return less_complex(v1, v2);
48+
template <typename fpT> struct ExtendedRealFPGreater
49+
{
50+
bool operator()(const fpT v1, const fpT v2) const
51+
{
52+
return (!sycl::isnan(v2) && (sycl::isnan(v1) || (v2 < v1)));
4353
}
4454
};
4555

46-
template <typename cT> struct ComplexGreater
56+
template <typename cT> struct ExtendedComplexFPLess
4757
{
58+
/* [(R, R), (R, nan), (nan, R), (nan, nan)] */
59+
4860
bool operator()(const cT &v1, const cT &v2) const
4961
{
50-
using dpctl::tensor::math_utils::greater_complex;
62+
using realT = typename cT::value_type;
63+
64+
const realT real1 = std::real(v1);
65+
const realT real2 = std::real(v2);
66+
67+
const bool r1_nan = sycl::isnan(real1);
68+
const bool r2_nan = sycl::isnan(real2);
69+
70+
const realT imag1 = std::imag(v1);
71+
const realT imag2 = std::imag(v2);
72+
73+
const bool i1_nan = sycl::isnan(imag1);
74+
const bool i2_nan = sycl::isnan(imag2);
5175

52-
return greater_complex(v1, v2);
76+
const int idx1 = ((r1_nan) ? 2 : 0) + ((i1_nan) ? 1 : 0);
77+
const int idx2 = ((r2_nan) ? 2 : 0) + ((i2_nan) ? 1 : 0);
78+
79+
const bool res =
80+
!(r1_nan && i1_nan) &&
81+
((idx1 < idx2) ||
82+
((idx1 == idx2) &&
83+
((r1_nan && !i1_nan && (imag1 < imag2)) ||
84+
(!r1_nan && i1_nan && (real1 < real2)) ||
85+
(!r1_nan && !i1_nan &&
86+
((real1 < real2) || (!(real2 < real1) && (imag1 < imag2)))))));
87+
88+
return res;
89+
}
90+
};
91+
92+
template <typename cT> struct ExtendedComplexFPGreater
93+
{
94+
bool operator()(const cT &v1, const cT &v2) const
95+
{
96+
auto less_ = ExtendedComplexFPLess<cT>{};
97+
return less_(v2, v1);
5398
}
5499
};
55100

101+
template <typename T>
102+
inline constexpr bool is_fp_v = (std::is_same_v<T, sycl::half> ||
103+
std::is_same_v<T, float> ||
104+
std::is_same_v<T, double>);
105+
106+
} // end of anonymous namespace
107+
56108
template <typename argTy> struct AscendingSorter
57109
{
58-
using type = std::less<argTy>;
110+
using type = std::conditional_t<is_fp_v<argTy>,
111+
ExtendedRealFPLess<argTy>,
112+
std::less<argTy>>;
59113
};
60114

61115
template <typename T> struct AscendingSorter<std::complex<T>>
62116
{
63-
using type = ComplexLess<std::complex<T>>;
117+
using type = ExtendedComplexFPLess<std::complex<T>>;
64118
};
65119

66120
template <typename argTy> struct DescendingSorter
67121
{
68-
using type = std::greater<argTy>;
122+
using type = std::conditional_t<is_fp_v<argTy>,
123+
ExtendedRealFPGreater<argTy>,
124+
std::greater<argTy>>;
69125
};
70126

71127
template <typename T> struct DescendingSorter<std::complex<T>>
72128
{
73-
using type = ComplexGreater<std::complex<T>>;
129+
using type = ExtendedComplexFPGreater<std::complex<T>>;
74130
};
75131

76132
} // end of namespace py_internal

0 commit comments

Comments
 (0)