@@ -282,6 +282,7 @@ quad_binary_op_resolve_descriptors(PyObject *self, PyArray_DTypeMeta *const dtyp
282
282
// Determine target backend and if casting is needed
283
283
NPY_CASTING casting = NPY_NO_CASTING;
284
284
if (descr_in1->backend != descr_in2->backend ) {
285
+
285
286
target_backend = BACKEND_LONGDOUBLE;
286
287
casting = NPY_SAFE_CASTING;
287
288
}
@@ -397,12 +398,12 @@ static int
397
398
quad_ufunc_promoter (PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
398
399
PyArray_DTypeMeta *signature[], PyArray_DTypeMeta *new_op_dtypes[])
399
400
{
400
- printf ( " called comparison promoter \n " );
401
+
401
402
int nin = ufunc->nin ;
402
403
int nargs = ufunc->nargs ;
403
404
PyArray_DTypeMeta *common = NULL ;
404
405
bool has_quad = false ;
405
- printf ( " dtyp1: %s dtype2: %s \n " , get_dtype_name (op_dtypes[ 0 ]), get_dtype_name (op_dtypes[ 1 ]));
406
+
406
407
// Handle the special case for reductions
407
408
if (op_dtypes[0 ] == NULL ) {
408
409
assert (nin == 2 && ufunc->nout == 1 ); /* must be reduction */
@@ -416,7 +417,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
416
417
// Check if any input or signature is QuadPrecision
417
418
for (int i = 0 ; i < nin; i++) {
418
419
if (op_dtypes[i] == &QuadPrecDType) {
419
- printf ( " Quaddtype found at index: %d \n " , i);
420
+
420
421
has_quad = true ;
421
422
}
422
423
}
@@ -460,7 +461,7 @@ quad_ufunc_promoter(PyUFuncObject *ufunc, PyArray_DTypeMeta *op_dtypes[],
460
461
else {
461
462
// Otherwise, use the common dtype
462
463
Py_INCREF (common);
463
- printf ( " setting output to %s dtype \n " , get_dtype_name (common));
464
+
464
465
new_op_dtypes[i] = common;
465
466
}
466
467
}
@@ -560,6 +561,47 @@ init_quad_binary_ops(PyObject *numpy)
560
561
561
562
// comparison functions
562
563
564
+ static NPY_CASTING
565
+ quad_comparison_op_resolve_descriptors (PyObject *self, PyArray_DTypeMeta *const dtypes[],
566
+ PyArray_Descr *const given_descrs[],
567
+ PyArray_Descr *loop_descrs[], npy_intp *NPY_UNUSED (view_offset))
568
+ {
569
+ QuadPrecDTypeObject *descr_in1 = (QuadPrecDTypeObject *)given_descrs[0 ];
570
+ QuadPrecDTypeObject *descr_in2 = (QuadPrecDTypeObject *)given_descrs[1 ];
571
+ QuadBackendType target_backend;
572
+
573
+ // As dealing with different backends then cast to boolean
574
+ NPY_CASTING casting = NPY_NO_CASTING;
575
+ if (descr_in1->backend != descr_in2->backend ) {
576
+ target_backend = BACKEND_LONGDOUBLE;
577
+ casting = NPY_SAFE_CASTING;
578
+ }
579
+ else {
580
+ target_backend = descr_in1->backend ;
581
+ }
582
+
583
+ // Set up input descriptors, casting if necessary
584
+ for (int i = 0 ; i < 2 ; i++) {
585
+ if (((QuadPrecDTypeObject *)given_descrs[i])->backend != target_backend) {
586
+ loop_descrs[i] = (PyArray_Descr *)new_quaddtype_instance (target_backend);
587
+ if (!loop_descrs[i]) {
588
+ return (NPY_CASTING)-1 ;
589
+ }
590
+ }
591
+ else {
592
+ Py_INCREF (given_descrs[i]);
593
+ loop_descrs[i] = given_descrs[i];
594
+ }
595
+ }
596
+
597
+ // Set up output descriptor
598
+ loop_descrs[2 ] = PyArray_DescrFromType (NPY_BOOL);
599
+ if (!loop_descrs[2 ]) {
600
+ return (NPY_CASTING)-1 ;
601
+ }
602
+ return casting;
603
+ }
604
+
563
605
template <cmp_quad_def sleef_comp, cmp_londouble_def ld_comp>
564
606
int
565
607
quad_generic_comp_strided_loop (PyArrayMethod_Context *context, char *const data[],
@@ -581,15 +623,18 @@ quad_generic_comp_strided_loop(PyArrayMethod_Context *context, char *const data[
581
623
while (N--) {
582
624
memcpy (&in1, in1_ptr, elem_size);
583
625
memcpy (&in2, in2_ptr, elem_size);
626
+ npy_bool result;
584
627
585
628
if (backend == BACKEND_SLEEF) {
586
- *((npy_bool *)out_ptr) = sleef_comp (&in1.sleef_value , &in2.sleef_value );
629
+ result = sleef_comp (&in1.sleef_value , &in2.sleef_value );
587
630
}
588
631
else {
589
- printf ( " %Lf % Lf \n " , in1. longdouble_value , in2. longdouble_value );
590
- *((npy_bool *)out_ptr) = ld_comp (&in1.longdouble_value , &in2.longdouble_value );
632
+
633
+ result = ld_comp (&in1.longdouble_value , &in2.longdouble_value );
591
634
}
592
635
636
+ *((npy_bool *)out_ptr) = result;
637
+
593
638
in1_ptr += in1_stride;
594
639
in2_ptr += in2_stride;
595
640
out_ptr += out_stride;
@@ -624,6 +669,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
624
669
PyArray_DTypeMeta *dtypes[3 ] = {&QuadPrecDType, &QuadPrecDType, &PyArray_BoolDType};
625
670
626
671
PyType_Slot slots[] = {
672
+ {NPY_METH_resolve_descriptors, (void *)&quad_comparison_op_resolve_descriptors},
627
673
{NPY_METH_strided_loop, (void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
628
674
{NPY_METH_unaligned_strided_loop,
629
675
(void *)&quad_generic_comp_strided_loop<sleef_comp, ld_comp>},
@@ -633,7 +679,7 @@ create_quad_comparison_ufunc(PyObject *numpy, const char *ufunc_name)
633
679
.name = " quad_comp" ,
634
680
.nin = 2 ,
635
681
.nout = 1 ,
636
- .casting = NPY_NO_CASTING ,
682
+ .casting = NPY_SAFE_CASTING ,
637
683
.flags = NPY_METH_SUPPORTS_UNALIGNED,
638
684
.dtypes = dtypes,
639
685
.slots = slots,
0 commit comments