@@ -2519,7 +2519,7 @@ def perform(self, node, inputs, output_storage):
2519
2519
)
2520
2520
2521
2521
def c_code_cache_version (self ):
2522
- return (6 ,)
2522
+ return (7 ,)
2523
2523
2524
2524
def c_code (self , node , name , inputs , outputs , sub ):
2525
2525
axis , * arrays = inputs
@@ -2558,16 +2558,86 @@ def c_code(self, node, name, inputs, outputs, sub):
2558
2558
code = f"""
2559
2559
int axis = { axis_def }
2560
2560
PyArrayObject* arrays[{ n } ] = {{{ ',' .join (arrays )} }};
2561
- PyObject* arrays_tuple = PyTuple_New( { n } ) ;
2561
+ int out_is_valid = { out } != NULL ;
2562
2562
2563
2563
{ axis_check }
2564
2564
2565
- Py_XDECREF({ out } );
2566
- { copy_arrays_to_tuple }
2567
- { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2568
- Py_DECREF(arrays_tuple);
2569
- if(!{ out } ){{
2570
- { fail }
2565
+ if (out_is_valid) {{
2566
+ // Check if we can reuse output
2567
+ npy_intp join_size = 0;
2568
+ npy_intp out_shape[{ ndim } ];
2569
+ npy_intp *shape = PyArray_SHAPE(arrays[0]);
2570
+
2571
+ for (int i = 0; i < { n } ; i++) {{
2572
+ if (PyArray_NDIM(arrays[i]) != { ndim } ) {{
2573
+ PyErr_SetString(PyExc_ValueError, "Input to join has wrong ndim");
2574
+ { fail }
2575
+ }}
2576
+
2577
+ join_size += PyArray_SHAPE(arrays[i])[axis];
2578
+
2579
+ if (i > 0){{
2580
+ for (int j = 0; j < { ndim } ; j++) {{
2581
+ if ((j != axis) && (PyArray_SHAPE(arrays[i])[j] != shape[j])) {{
2582
+ PyErr_SetString(PyExc_ValueError, "Arrays shape must match along non join axis");
2583
+ { fail }
2584
+ }}
2585
+ }}
2586
+ }}
2587
+ }}
2588
+
2589
+ memcpy(out_shape, shape, { ndim } * sizeof(npy_intp));
2590
+ out_shape[axis] = join_size;
2591
+
2592
+ for (int i = 0; i < { ndim } ; i++) {{
2593
+ out_is_valid &= (PyArray_SHAPE({ out } )[i] == out_shape[i]);
2594
+ }}
2595
+ }}
2596
+
2597
+ if (!out_is_valid) {{
2598
+ // Use PyArray_Concatenate
2599
+ Py_XDECREF({ out } );
2600
+ PyObject* arrays_tuple = PyTuple_New({ n } );
2601
+ { copy_arrays_to_tuple }
2602
+ { out } = (PyArrayObject *)PyArray_Concatenate(arrays_tuple, axis);
2603
+ Py_DECREF(arrays_tuple);
2604
+ if(!{ out } ){{
2605
+ { fail }
2606
+ }}
2607
+ }}
2608
+ else {{
2609
+ // Copy the data to the pre-allocated output buffer
2610
+
2611
+ // Create view into output buffer
2612
+ PyArrayObject_fields *view;
2613
+
2614
+ // PyArray_NewFromDescr steals a reference to descr, so we need to increase it
2615
+ Py_INCREF(PyArray_DESCR({ out } ));
2616
+ view = (PyArrayObject_fields *)PyArray_NewFromDescr(&PyArray_Type,
2617
+ PyArray_DESCR({ out } ),
2618
+ { ndim } ,
2619
+ PyArray_SHAPE(arrays[0]),
2620
+ PyArray_STRIDES({ out } ),
2621
+ PyArray_DATA({ out } ),
2622
+ NPY_ARRAY_WRITEABLE,
2623
+ NULL);
2624
+ if (view == NULL) {{
2625
+ { fail }
2626
+ }}
2627
+
2628
+ // Copy data into output buffer
2629
+ for (int i = 0; i < { n } ; i++) {{
2630
+ view->dimensions[axis] = PyArray_SHAPE(arrays[i])[axis];
2631
+
2632
+ if (PyArray_CopyInto((PyArrayObject*)view, arrays[i]) != 0) {{
2633
+ Py_DECREF(view);
2634
+ { fail }
2635
+ }}
2636
+
2637
+ view->data += (view->dimensions[axis] * view->strides[axis]);
2638
+ }}
2639
+
2640
+ Py_DECREF(view);
2571
2641
}}
2572
2642
"""
2573
2643
return code
0 commit comments