24
24
25
25
#pragma once
26
26
27
- #include " utils/math_utils.hpp"
27
+ #include " sycl/sycl.hpp"
28
+ #include < type_traits>
28
29
29
30
namespace dpctl
30
31
{
@@ -33,44 +34,99 @@ namespace tensor
33
34
namespace py_internal
34
35
{
35
36
36
- template < typename cT> struct ComplexLess
37
+ namespace
37
38
{
38
- bool operator ()(const cT &v1, const cT &v2) const
39
+ template <typename fpT> struct ExtendedRealFPLess
40
+ {
41
+ /* [R, nan] */
42
+ bool operator ()(const fpT v1, const fpT v2) const
39
43
{
40
- using dpctl::tensor::math_utils::less_complex;
44
+ return (!sycl::isnan (v1) && (sycl::isnan (v2) || (v1 < v2)));
45
+ }
46
+ };
41
47
42
- return less_complex (v1, v2);
48
+ template <typename fpT> struct ExtendedRealFPGreater
49
+ {
50
+ bool operator ()(const fpT v1, const fpT v2) const
51
+ {
52
+ return (!sycl::isnan (v2) && (sycl::isnan (v1) || (v2 < v1)));
43
53
}
44
54
};
45
55
46
- template <typename cT> struct ComplexGreater
56
+ template <typename cT> struct ExtendedComplexFPLess
47
57
{
58
+ /* [(R, R), (R, nan), (nan, R), (nan, nan)] */
59
+
48
60
bool operator ()(const cT &v1, const cT &v2) const
49
61
{
50
- using dpctl::tensor::math_utils::greater_complex;
62
+ using realT = typename cT::value_type;
63
+
64
+ const realT real1 = std::real (v1);
65
+ const realT real2 = std::real (v2);
66
+
67
+ const bool r1_nan = sycl::isnan (real1);
68
+ const bool r2_nan = sycl::isnan (real2);
69
+
70
+ const realT imag1 = std::imag (v1);
71
+ const realT imag2 = std::imag (v2);
72
+
73
+ const bool i1_nan = sycl::isnan (imag1);
74
+ const bool i2_nan = sycl::isnan (imag2);
51
75
52
- return greater_complex (v1, v2);
76
+ const int idx1 = ((r1_nan) ? 2 : 0 ) + ((i1_nan) ? 1 : 0 );
77
+ const int idx2 = ((r2_nan) ? 2 : 0 ) + ((i2_nan) ? 1 : 0 );
78
+
79
+ const bool res =
80
+ !(r1_nan && i1_nan) &&
81
+ ((idx1 < idx2) ||
82
+ ((idx1 == idx2) &&
83
+ ((r1_nan && !i1_nan && (imag1 < imag2)) ||
84
+ (!r1_nan && i1_nan && (real1 < real2)) ||
85
+ (!r1_nan && !i1_nan &&
86
+ ((real1 < real2) || (!(real2 < real1) && (imag1 < imag2)))))));
87
+
88
+ return res;
89
+ }
90
+ };
91
+
92
+ template <typename cT> struct ExtendedComplexFPGreater
93
+ {
94
+ bool operator ()(const cT &v1, const cT &v2) const
95
+ {
96
+ auto less_ = ExtendedComplexFPLess<cT>{};
97
+ return less_ (v2, v1);
53
98
}
54
99
};
55
100
101
+ template <typename T>
102
+ inline constexpr bool is_fp_v = (std::is_same_v<T, sycl::half> ||
103
+ std::is_same_v<T, float > ||
104
+ std::is_same_v<T, double >);
105
+
106
+ } // end of anonymous namespace
107
+
56
108
template <typename argTy> struct AscendingSorter
57
109
{
58
- using type = std::less<argTy>;
110
+ using type = std::conditional_t <is_fp_v<argTy>,
111
+ ExtendedRealFPLess<argTy>,
112
+ std::less<argTy>>;
59
113
};
60
114
61
115
template <typename T> struct AscendingSorter <std::complex<T>>
62
116
{
63
- using type = ComplexLess <std::complex<T>>;
117
+ using type = ExtendedComplexFPLess <std::complex<T>>;
64
118
};
65
119
66
120
template <typename argTy> struct DescendingSorter
67
121
{
68
- using type = std::greater<argTy>;
122
+ using type = std::conditional_t <is_fp_v<argTy>,
123
+ ExtendedRealFPGreater<argTy>,
124
+ std::greater<argTy>>;
69
125
};
70
126
71
127
template <typename T> struct DescendingSorter <std::complex<T>>
72
128
{
73
- using type = ComplexGreater <std::complex<T>>;
129
+ using type = ExtendedComplexFPGreater <std::complex<T>>;
74
130
};
75
131
76
132
} // end of namespace py_internal
0 commit comments