Skip to content

Commit 6ec4768

Browse files
Ensure that we use no more SLM than is available
1 parent a6118bd commit 6ec4768

File tree

1 file changed

+80
-57
lines changed

1 file changed

+80
-57
lines changed

dpctl/tensor/libtensor/include/kernels/sorting.hpp

Lines changed: 80 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -378,6 +378,60 @@ struct GetReadWriteAccess<sycl::buffer<ElementType, Dim, AllocatorT>>
378378
}
379379
};
380380

381+
template <typename T1, typename T2, typename Comp>
382+
class sort_base_step_contig_krn;
383+
384+
template <typename InpAcc, typename OutAcc, typename Comp>
385+
sycl::event
386+
sort_base_step_contig_impl(sycl::queue &q,
387+
size_t iter_nelems,
388+
size_t sort_nelems,
389+
const InpAcc input,
390+
OutAcc output,
391+
const Comp &comp,
392+
size_t &conseq_nelems_sorted,
393+
const std::vector<sycl::event> &depends = {})
394+
{
395+
396+
using inpT = typename GetValueType<InpAcc>::value_type;
397+
using outT = typename GetValueType<OutAcc>::value_type;
398+
using KernelName = sort_base_step_contig_krn<inpT, outT, Comp>;
399+
400+
conseq_nelems_sorted = (q.get_device().has(sycl::aspect::cpu) ? 16 : 4);
401+
402+
size_t n_segments =
403+
quotient_ceil<size_t>(sort_nelems, conseq_nelems_sorted);
404+
405+
sycl::event base_sort = q.submit([&](sycl::handler &cgh) {
406+
cgh.depends_on(depends);
407+
408+
sycl::range<1> gRange{iter_nelems * n_segments};
409+
410+
auto input_acc = GetReadOnlyAccess<InpAcc>{}(input, cgh);
411+
auto output_acc = GetWriteDiscardAccess<OutAcc>{}(output, cgh);
412+
413+
cgh.parallel_for<KernelName>(gRange, [=](sycl::id<1> id) {
414+
const size_t iter_id = id[0] / n_segments;
415+
const size_t segment_id = id[0] - iter_id * n_segments;
416+
417+
const size_t iter_offset = iter_id * sort_nelems;
418+
const size_t beg_id =
419+
iter_offset + segment_id * conseq_nelems_sorted;
420+
const size_t end_id =
421+
iter_offset +
422+
std::min<size_t>((segment_id + 1) * conseq_nelems_sorted,
423+
sort_nelems);
424+
for (size_t i = beg_id; i < end_id; ++i) {
425+
output_acc[i] = input_acc[i];
426+
}
427+
428+
leaf_sort_impl(output_acc, beg_id, end_id, comp);
429+
});
430+
});
431+
432+
return base_sort;
433+
}
434+
381435
template <typename T1, typename T2, typename Comp>
382436
class sort_over_work_group_contig_krn;
383437

@@ -393,8 +447,8 @@ sort_over_work_group_contig_impl(sycl::queue &q,
393447
const std::vector<sycl::event> &depends = {})
394448
{
395449
using inpT = typename GetValueType<InpAcc>::value_type;
396-
using outT = typename GetValueType<InpAcc>::value_type;
397-
using KernelName = sort_over_work_group_contig_krn<inpT, outT, Comp>;
450+
using T = typename GetValueType<OutAcc>::value_type;
451+
using KernelName = sort_over_work_group_contig_krn<inpT, T, Comp>;
398452

399453
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
400454

@@ -405,12 +459,30 @@ sort_over_work_group_contig_impl(sycl::queue &q,
405459

406460
auto krn = kb.template get_kernel(kernel_id);
407461

408-
std::uint32_t max_sg_size = krn.template get_info<
462+
const std::uint32_t max_sg_size = krn.template get_info<
409463
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
464+
const std::uint64_t device_local_memory_size =
465+
dev.get_info<sycl::info::device::local_mem_size>();
466+
467+
// leave 512 bytes of local memory for RT
468+
const std::uint64_t safety_margin = 512;
469+
470+
const std::uint64_t nelems_per_slm =
471+
(device_local_memory_size - safety_margin) / (2 * sizeof(T));
410472

411-
const size_t lws = 4 * max_sg_size;
412-
const std::uint32_t chunk_size = dev.has(sycl::aspect::cpu) ? 8 : 2;
413-
nelems_wg_sorts = chunk_size * lws;
473+
constexpr std::uint32_t sub_groups_per_work_group = 4;
474+
const std::uint32_t elems_per_wi = dev.has(sycl::aspect::cpu) ? 8 : 2;
475+
476+
const size_t lws = sub_groups_per_work_group * max_sg_size;
477+
478+
nelems_wg_sorts = elems_per_wi * lws;
479+
480+
if (nelems_wg_sorts > nelems_per_slm) {
481+
nelems_wg_sorts = 0;
482+
return sort_base_step_contig_impl<InpAcc, OutAcc, Comp>(
483+
q, iter_nelems, sort_nelems, input, output, comp, nelems_wg_sorts,
484+
depends);
485+
}
414486

415487
// This assumption permits doing away with using a loop
416488
assert(nelems_wg_sorts % lws == 0);
@@ -421,11 +493,12 @@ sort_over_work_group_contig_impl(sycl::queue &q,
421493
sycl::event base_sort_ev = q.submit([&](sycl::handler &cgh) {
422494
cgh.depends_on(depends);
423495

496+
cgh.use_kernel_bundle(kb);
497+
424498
sycl::range<1> global_range{iter_nelems * n_segments * lws};
425499
sycl::range<1> local_range{lws};
426500

427501
sycl::range<1> slm_range{nelems_wg_sorts};
428-
using T = typename GetValueType<OutAcc>::value_type;
429502
sycl::local_accessor<T, 1> work_space(slm_range, cgh);
430503
sycl::local_accessor<T, 1> scratch_space(slm_range, cgh);
431504

@@ -512,56 +585,6 @@ sort_over_work_group_contig_impl(sycl::queue &q,
512585
return base_sort_ev;
513586
}
514587

515-
template <typename T1, typename T2, typename Comp>
516-
class sort_base_step_contig_krn;
517-
518-
template <typename InpAcc, typename OutAcc, typename Comp>
519-
sycl::event
520-
sort_base_step_contig_impl(sycl::queue &q,
521-
size_t iter_nelems,
522-
size_t sort_nelems,
523-
const InpAcc input,
524-
OutAcc output,
525-
const Comp &comp,
526-
size_t &conseq_nelems_sorted,
527-
const std::vector<sycl::event> &depends = {})
528-
{
529-
530-
using inpT = typename GetValueType<InpAcc>::value_type;
531-
using outT = typename GetValueType<InpAcc>::value_type;
532-
using KernelName = sort_base_step_contig_krn<inpT, outT, Comp>;
533-
534-
conseq_nelems_sorted = (q.get_device().has(sycl::aspect::cpu) ? 16 : 4);
535-
536-
size_t n_segments =
537-
quotient_ceil<size_t>(sort_nelems, conseq_nelems_sorted);
538-
539-
sycl::event base_sort = q.submit([&](sycl::handler &cgh) {
540-
cgh.depends_on(depends);
541-
542-
sycl::range<1> gRange{iter_nelems * n_segments};
543-
cgh.parallel_for<KernelName>(gRange, [=](sycl::id<1> id) {
544-
const size_t iter_id = id[0] / n_segments;
545-
const size_t segment_id = id[0] - iter_id * n_segments;
546-
547-
const size_t iter_offset = iter_id * sort_nelems;
548-
const size_t beg_id =
549-
iter_offset + segment_id * conseq_nelems_sorted;
550-
const size_t end_id =
551-
iter_offset +
552-
std::min<size_t>((segment_id + 1) * conseq_nelems_sorted,
553-
sort_nelems);
554-
for (size_t i = beg_id; i < end_id; ++i) {
555-
output[i] = input[i];
556-
}
557-
558-
leaf_sort_impl(output, beg_id, end_id, comp);
559-
});
560-
});
561-
562-
return base_sort;
563-
}
564-
565588
class vacuous_krn;
566589

567590
inline sycl::event tie_events(sycl::queue &q,

0 commit comments

Comments
 (0)