@@ -384,28 +384,26 @@ class sort_base_step_contig_krn;
384
384
template <typename InpAcc, typename OutAcc, typename Comp>
385
385
sycl::event
386
386
sort_base_step_contig_impl (sycl::queue &q,
387
- size_t iter_nelems,
388
- size_t sort_nelems,
387
+ const size_t iter_nelems,
388
+ const size_t sort_nelems,
389
389
const InpAcc input,
390
390
OutAcc output,
391
391
const Comp &comp,
392
- size_t & conseq_nelems_sorted,
392
+ const size_t conseq_nelems_sorted,
393
393
const std::vector<sycl::event> &depends = {})
394
394
{
395
395
396
396
using inpT = typename GetValueType<InpAcc>::value_type;
397
397
using outT = typename GetValueType<OutAcc>::value_type;
398
398
using KernelName = sort_base_step_contig_krn<inpT, outT, Comp>;
399
399
400
- conseq_nelems_sorted = (q.get_device ().has (sycl::aspect::cpu) ? 16 : 4 );
401
-
402
- size_t n_segments =
400
+ const size_t n_segments =
403
401
quotient_ceil<size_t >(sort_nelems, conseq_nelems_sorted);
404
402
405
403
sycl::event base_sort = q.submit ([&](sycl::handler &cgh) {
406
404
cgh.depends_on (depends);
407
405
408
- sycl::range<1 > gRange {iter_nelems * n_segments};
406
+ const sycl::range<1 > gRange {iter_nelems * n_segments};
409
407
410
408
auto input_acc = GetReadOnlyAccess<InpAcc>{}(input, cgh);
411
409
auto output_acc = GetWriteDiscardAccess<OutAcc>{}(output, cgh);
@@ -478,7 +476,8 @@ sort_over_work_group_contig_impl(sycl::queue &q,
478
476
nelems_wg_sorts = elems_per_wi * lws;
479
477
480
478
if (nelems_wg_sorts > nelems_per_slm) {
481
- nelems_wg_sorts = 0 ;
479
+ nelems_wg_sorts = (q.get_device ().has (sycl::aspect::cpu) ? 16 : 4 );
480
+
482
481
return sort_base_step_contig_impl<InpAcc, OutAcc, Comp>(
483
482
q, iter_nelems, sort_nelems, input, output, comp, nelems_wg_sorts,
484
483
depends);
@@ -781,24 +780,38 @@ sycl::event stable_sort_axis1_contig_impl(
781
780
782
781
auto comp = Comp{};
783
782
784
- static constexpr size_t determine_automatically = 0 ;
785
- size_t sorted_block_size =
786
- (sort_nelems >= 512 ) ? 512 : determine_automatically;
783
+ constexpr size_t sequential_sorting_threshold = 64 ;
787
784
788
- // Sort segments of the array
789
- sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl<
790
- const argTy *, argTy *, Comp>(
791
- exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
792
- sorted_block_size, // modified in place with size of sorted block size
793
- depends);
794
-
795
- // Merge segments in parallel until all elements are sorted
796
- sycl::event merges_ev =
797
- sort_detail::merge_sorted_block_contig_impl<argTy *, Comp>(
798
- exec_q, iter_nelems, sort_nelems, res_tp, comp, sorted_block_size,
799
- {base_sort_ev});
785
+ if (sort_nelems < sequential_sorting_threshold) {
786
+ // each work-item sorts an entire row
787
+ sycl::event sequential_sorting_ev =
788
+ sort_detail::sort_base_step_contig_impl<const argTy *, argTy *,
789
+ Comp>(
790
+ exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
791
+ sort_nelems, depends);
800
792
801
- return merges_ev;
793
+ return sequential_sorting_ev;
794
+ }
795
+ else {
796
+ size_t sorted_block_size{};
797
+
798
+ // Sort segments of the array
799
+ sycl::event base_sort_ev =
800
+ sort_detail::sort_over_work_group_contig_impl<const argTy *,
801
+ argTy *, Comp>(
802
+ exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
803
+ sorted_block_size, // set in place to the size of each sorted
804
+ // block
805
+ depends);
806
+
807
+ // Merge segments in parallel until all elements are sorted
808
+ sycl::event merges_ev =
809
+ sort_detail::merge_sorted_block_contig_impl<argTy *, Comp>(
810
+ exec_q, iter_nelems, sort_nelems, res_tp, comp,
811
+ sorted_block_size, {base_sort_ev});
812
+
813
+ return merges_ev;
814
+ }
802
815
}
803
816
804
817
template <typename T1, typename T2, typename T3>
0 commit comments