|
| 1 | +//=== sorting.hpp - Implementation of sorting kernels ---*-C++-*--/===// |
| 2 | +// |
| 3 | +// Data Parallel Control (dpctl) |
| 4 | +// |
| 5 | +// Copyright 2020-2023 Intel Corporation |
| 6 | +// |
| 7 | +// Licensed under the Apache License, Version 2.0 (the "License"); |
| 8 | +// you may not use this file except in compliance with the License. |
| 9 | +// You may obtain a copy of the License at |
| 10 | +// |
| 11 | +// http://www.apache.org/licenses/LICENSE-2.0 |
| 12 | +// |
| 13 | +// Unless required by applicable law or agreed to in writing, software |
| 14 | +// distributed under the License is distributed on an "AS IS" BASIS, |
| 15 | +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 16 | +// See the License for the specific language governing permissions and |
| 17 | +// limitations under the License. |
| 18 | +// |
| 19 | +//===----------------------------------------------------------------------===// |
| 20 | +/// |
| 21 | +/// \file |
| 22 | +/// This file defines kernels for tensor sort/argsort operations. |
| 23 | +//===----------------------------------------------------------------------===// |
| 24 | + |
1 | 25 | #pragma once
|
2 | 26 |
|
3 | 27 | #include "pybind11/pybind11.h"
|
4 | 28 |
|
| 29 | +#include <cassert> |
5 | 30 | #include <functional>
|
6 | 31 | #include <iterator>
|
7 | 32 | #include <sycl/sycl.hpp>
|
@@ -537,115 +562,6 @@ sort_base_step_contig_impl(sycl::queue &q,
|
537 | 562 | return base_sort;
|
538 | 563 | }
|
539 | 564 |
|
540 |
| -template <typename T1, typename T2, typename Comp> |
541 |
| -class exp_sort_over_work_group_contig_krn; |
542 |
| - |
543 |
| -template <typename InpAcc, typename OutAcc, typename Comp> |
544 |
| -sycl::event exp_sort_over_work_group_contig_impl( |
545 |
| - sycl::queue &q, |
546 |
| - size_t iter_nelems, |
547 |
| - size_t sort_nelems, |
548 |
| - const InpAcc input, |
549 |
| - OutAcc output, |
550 |
| - const Comp &comp, |
551 |
| - size_t &conseq_nelems_sorted, |
552 |
| - const std::vector<sycl::event> &depends = {}) |
553 |
| -{ |
554 |
| - |
555 |
| - using inpT = typename GetValueType<InpAcc>::value_type; |
556 |
| - using outT = typename GetValueType<InpAcc>::value_type; |
557 |
| - using KernelName = exp_sort_over_work_group_contig_krn<inpT, outT, Comp>; |
558 |
| - |
559 |
| - const auto &kernel_id = sycl::get_kernel_id<KernelName>(); |
560 |
| - |
561 |
| - auto const &ctx = q.get_context(); |
562 |
| - auto const &dev = q.get_device(); |
563 |
| - auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>( |
564 |
| - ctx, {dev}, {kernel_id}); |
565 |
| - |
566 |
| - auto krn = kb.template get_kernel(kernel_id); |
567 |
| - |
568 |
| - std::uint32_t max_sg_size = krn.template get_info< |
569 |
| - sycl::info::kernel_device_specific::max_sub_group_size>(dev); |
570 |
| - |
571 |
| - const size_t lws = 4 * max_sg_size; |
572 |
| - const std::uint32_t chunk_size = dev.has(sycl::aspect::cpu) ? 8 : 2; |
573 |
| - conseq_nelems_sorted = chunk_size * lws; |
574 |
| - |
575 |
| - // This assumption permits doing away with using a loop |
576 |
| - assert(nelems_wg_sorts % lws == 0); |
577 |
| - |
578 |
| - sycl::event exp_default_sort_ev = q.submit([&](sycl::handler &cgh) { |
579 |
| - cgh.depends_on(depends); |
580 |
| - |
581 |
| - const size_t n_chunks = |
582 |
| - quotient_ceil<size_t>(sort_nelems, conseq_nelems_sorted); |
583 |
| - |
584 |
| - sycl::range<1> global_range{iter_nelems * n_chunks * lws}; |
585 |
| - sycl::range<1> local_range{lws}; |
586 |
| - |
587 |
| - sycl::range<1> slm_size_range{conseq_nelems_sorted}; |
588 |
| - |
589 |
| - using Sorter = sycl::ext::oneapi::experimental::default_sorter<Comp>; |
590 |
| - |
591 |
| - // calculate required local memory size |
592 |
| - // MUST pass local_range, not integer. |
593 |
| - // Have different meanings |
594 |
| - const size_t temp_memory_size = Sorter::template memory_required<outT>( |
595 |
| - sycl::memory_scope::work_group, conseq_nelems_sorted); |
596 |
| - |
597 |
| - sycl::local_accessor<outT, 1> workspace(slm_size_range, cgh); |
598 |
| - sycl::local_accessor<std::byte, 1> scratch({temp_memory_size}, cgh); |
599 |
| - |
600 |
| - cgh.parallel_for<KernelName>( |
601 |
| - sycl::nd_range<1>(global_range, local_range), |
602 |
| - [=](sycl::nd_item<1> it) { |
603 |
| - auto sorter_op = Sorter(sycl::span<std::byte>{ |
604 |
| - scratch |
605 |
| - .template get_multi_ptr<sycl::access::decorated::no>() |
606 |
| - .get(), |
607 |
| - temp_memory_size}); |
608 |
| - |
609 |
| - const size_t gr_id = it.get_group_linear_id(); |
610 |
| - const size_t iter_id = gr_id / n_chunks; |
611 |
| - const size_t sort_chunk_id = gr_id - iter_id * n_chunks; |
612 |
| - |
613 |
| - const std::uint32_t lid = it.get_local_linear_id(); |
614 |
| - |
615 |
| - const size_t iter_offset = iter_id * sort_nelems; |
616 |
| - const size_t chunk_offset = |
617 |
| - sort_chunk_id * conseq_nelems_sorted; |
618 |
| - const size_t global_start_offset = iter_offset + chunk_offset; |
619 |
| - const size_t workspace_size = |
620 |
| - std::min<size_t>(sort_nelems, |
621 |
| - chunk_offset + conseq_nelems_sorted) - |
622 |
| - chunk_offset; |
623 |
| - for (std::uint32_t i = lid; i < workspace_size; i += lws) { |
624 |
| - workspace[i] = input[global_start_offset + i]; |
625 |
| - } |
626 |
| - sycl::group_barrier(it.get_group()); |
627 |
| - |
628 |
| - sycl::ext::oneapi::experimental::joint_sort( |
629 |
| - it.get_group(), |
630 |
| - workspace |
631 |
| - .template get_multi_ptr<sycl::access::decorated::no>() |
632 |
| - .get(), |
633 |
| - workspace |
634 |
| - .template get_multi_ptr< |
635 |
| - sycl::access::decorated::no>() |
636 |
| - .get() + |
637 |
| - workspace_size, |
638 |
| - sorter_op); |
639 |
| - |
640 |
| - for (std::uint32_t i = lid; i < workspace_size; i += lws) { |
641 |
| - output[global_start_offset + i] = workspace[i]; |
642 |
| - } |
643 |
| - }); |
644 |
| - }); |
645 |
| - |
646 |
| - return exp_default_sort_ev; |
647 |
| -} |
648 |
| - |
649 | 565 | class vacuous_krn;
|
650 | 566 |
|
651 | 567 | inline sycl::event tie_events(sycl::queue &q,
|
@@ -847,20 +763,11 @@ sycl::event stable_sort_axis1_contig_impl(
|
847 | 763 | (sort_nelems >= 512) ? 512 : determine_automatically;
|
848 | 764 |
|
849 | 765 | // Sort segments of the array
|
850 |
| -#if 1 |
851 | 766 | sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl<
|
852 | 767 | const argTy *, argTy *, Comp>(
|
853 | 768 | exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
|
854 | 769 | sorted_block_size, // modified in place with size of sorted block size
|
855 | 770 | depends);
|
856 |
| -#else |
857 |
| - sycl::event base_sort_ev = |
858 |
| - sort_detail::sort_base_step_contig_impl<const argTy *, argTy *, Comp>( |
859 |
| - exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp, |
860 |
| - sorted_block_size, // modified in place with size of sorted block |
861 |
| - // size |
862 |
| - depends); |
863 |
| -#endif |
864 | 771 |
|
865 | 772 | // Merge segments in parallel until all elements are sorted
|
866 | 773 | sycl::event merges_ev =
|
|
0 commit comments