Skip to content

Commit 707d5d6

Browse files
committed
fixed longdouble comparison casting issue
1 parent 4f0a604 commit 707d5d6

File tree

2 files changed

+54
-9
lines changed

2 files changed

+54
-9
lines changed

quaddtype/numpy_quaddtype/src/casts.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,6 @@ numpy_to_quad_resolve_descriptors(PyObject *NPY_UNUSED(self), PyArray_DTypeMeta
274274
{
275275
// todo: here it is converting this to SLEEF, losing data and getting 0
276276
if (given_descrs[1] == NULL) {
277-
printf("called\n");
278277
loop_descrs[1] = (PyArray_Descr *)new_quaddtype_instance(BACKEND_SLEEF);
279278
if (loop_descrs[1] == nullptr) {
280279
return (NPY_CASTING)-1;

quaddtype/numpy_quaddtype/src/umath.cpp

+54-8
Original file line numberDiff line numberDiff line change
@@ -282,6 +282,7 @@ quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtyp
282282
// Determine target backend and if casting is needed
283283
NPY_CASTING casting = NPY_NO_CASTING;
284284
if (descr_in1->backend != descr_in2->backend) {
285+
285286
target_backend = BACKEND_LONGDOUBLE;
286287
casting = NPY_SAFE_CASTING;
287288
}
@@ -397,12 +398,12 @@ static int
397398
quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
398399
PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
399400
{
400-
printf("called comparison promoter\n");
401+
401402
int nin = ufunc->nin;
402403
int nargs = ufunc->nargs;
403404
PyArray_DTypeMeta *common = NULL;
404405
bool has_quad = false;
405-
printf("dtyp1: %s dtype2: %s\n", get_dtype_name(op_dtypes[0]), get_dtype_name(op_dtypes[1]));
406+
406407
// Handle the special case for reductions
407408
if (op_dtypes[0] == NULL) {
408409
assert(nin == 2 && ufunc->nout == 1); /* must be reduction */
@@ -416,7 +417,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
416417
// Check if any input or signature is QuadPrecision
417418
for (int i = 0; i < nin; i++) {
418419
if (op_dtypes[i] == &QuadPrecDType) {
419-
printf("Quaddtype found at index: %d\n", i);
420+
420421
has_quad = true;
421422
}
422423
}
@@ -460,7 +461,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
460461
else {
461462
// Otherwise, use the common dtype
462463
Py_INCREF(common);
463-
printf("setting output to %s dtype\n", get_dtype_name(common));
464+
464465
new_op_dtypes[i] = common;
465466
}
466467
}
@@ -560,6 +561,47 @@ init_quad_binary_ops(PyObject *numpy)
560561

561562
// comparison functions
562563

564+
static NPY_CASTING
565+
quad_comparison_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtypes[],
566+
PyArray_Descr *const given_descrs[],
567+
PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED(view_offset))
568+
{
569+
QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0];
570+
QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1];
571+
QuadBackendType target_backend;
572+
573+
// As dealing with different backends then cast to boolean
574+
NPY_CASTING casting = NPY_NO_CASTING;
575+
if (descr_in1->backend != descr_in2->backend) {
576+
target_backend = BACKEND_LONGDOUBLE;
577+
casting = NPY_SAFE_CASTING;
578+
}
579+
else {
580+
target_backend = descr_in1->backend;
581+
}
582+
583+
// Set up input descriptors, casting if necessary
584+
for (int i = 0; i < 2; i++) {
585+
if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
586+
loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance(target_backend);
587+
if (!loop_descrs[i]) {
588+
return (NPY_CASTING)-1;
589+
}
590+
}
591+
else {
592+
Py_INCREF(given_descrs[i]);
593+
loop_descrs[i] = given_descrs[i];
594+
}
595+
}
596+
597+
// Set up output descriptor
598+
loop_descrs[2] = PyArray_DescrFromType(NPY_BOOL);
599+
if (!loop_descrs[2]) {
600+
return (NPY_CASTING)-1;
601+
}
602+
return casting;
603+
}
604+
563605
template <cmp_quad_def sleef_comp, cmp_londouble_def ld_comp>
564606
int
565607
quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[],
@@ -581,15 +623,18 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
581623
while (N--) {
582624
memcpy(&in1, in1_ptr, elem_size);
583625
memcpy(&in2, in2_ptr, elem_size);
626+
npy_bool result;
584627

585628
if (backend == BACKEND_SLEEF) {
586-
*((npy_bool *)out_ptr) = sleef_comp(&in1.sleef_value, &in2.sleef_value);
629+
result = sleef_comp(&in1.sleef_value, &in2.sleef_value);
587630
}
588631
else {
589-
printf("%Lf % Lf\n", in1.longdouble_value, in2.longdouble_value);
590-
*((npy_bool *)out_ptr) = ld_comp(&in1.longdouble_value, &in2.longdouble_value);
632+
633+
result = ld_comp(&in1.longdouble_value, &in2.longdouble_value);
591634
}
592635

636+
*((npy_bool *)out_ptr) = result;
637+
593638
in1_ptr += in1_stride;
594639
in2_ptr += in2_stride;
595640
out_ptr += out_stride;
@@ -624,6 +669,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
624669
PyArray_DTypeMeta *dtypes[3] = {&QuadPrecDType, &QuadPrecDType, &PyArray_BoolDType};
625670

626671
PyType_Slot slots[] = {
672+
{NPY_METH_resolve_descriptors, (void *)&quad_comparison_op_resolve_descriptors},
627673
{NPY_METH_strided_loop, (void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
628674
{NPY_METH_unaligned_strided_loop,
629675
(void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
@@ -633,7 +679,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
633679
.name = "quad_comp",
634680
.nin = 2,
635681
.nout = 1,
636-
.casting = NPY_NO_CASTING,
682+
.casting = NPY_SAFE_CASTING,
637683
.flags = NPY_METH_SUPPORTS_UNALIGNED,
638684
.dtypes = dtypes,
639685
.slots = slots,

0 commit comments

Comments
 (0)