Skip to content

Commit a6118bd

Browse files
Corrected variable name used in assertion
Removed dead code, added include of cassert, added lincese header comment.
1 parent 2df8f6a commit a6118bd

File tree

1 file changed

+25
-118
lines changed

1 file changed

+25
-118
lines changed

dpctl/tensor/libtensor/include/kernels/sorting.hpp

Lines changed: 25 additions & 118 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,32 @@
1+
//=== sorting.hpp - Implementation of sorting kernels ---*-C++-*--/===//
2+
//
3+
// Data Parallel Control (dpctl)
4+
//
5+
// Copyright 2020-2023 Intel Corporation
6+
//
7+
// Licensed under the Apache License, Version 2.0 (the "License");
8+
// you may not use this file except in compliance with the License.
9+
// You may obtain a copy of the License at
10+
//
11+
// http://www.apache.org/licenses/LICENSE-2.0
12+
//
13+
// Unless required by applicable law or agreed to in writing, software
14+
// distributed under the License is distributed on an "AS IS" BASIS,
15+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
16+
// See the License for the specific language governing permissions and
17+
// limitations under the License.
18+
//
19+
//===----------------------------------------------------------------------===//
20+
///
21+
/// \file
22+
/// This file defines kernels for tensor sort/argsort operations.
23+
//===----------------------------------------------------------------------===//
24+
125
#pragma once
226

327
#include "pybind11/pybind11.h"
428

29+
#include <cassert>
530
#include <functional>
631
#include <iterator>
732
#include <sycl/sycl.hpp>
@@ -537,115 +562,6 @@ sort_base_step_contig_impl(sycl::queue &q,
537562
return base_sort;
538563
}
539564

540-
template <typename T1, typename T2, typename Comp>
541-
class exp_sort_over_work_group_contig_krn;
542-
543-
template <typename InpAcc, typename OutAcc, typename Comp>
544-
sycl::event exp_sort_over_work_group_contig_impl(
545-
sycl::queue &q,
546-
size_t iter_nelems,
547-
size_t sort_nelems,
548-
const InpAcc input,
549-
OutAcc output,
550-
const Comp &comp,
551-
size_t &conseq_nelems_sorted,
552-
const std::vector<sycl::event> &depends = {})
553-
{
554-
555-
using inpT = typename GetValueType<InpAcc>::value_type;
556-
using outT = typename GetValueType<InpAcc>::value_type;
557-
using KernelName = exp_sort_over_work_group_contig_krn<inpT, outT, Comp>;
558-
559-
const auto &kernel_id = sycl::get_kernel_id<KernelName>();
560-
561-
auto const &ctx = q.get_context();
562-
auto const &dev = q.get_device();
563-
auto kb = sycl::get_kernel_bundle<sycl::bundle_state::executable>(
564-
ctx, {dev}, {kernel_id});
565-
566-
auto krn = kb.template get_kernel(kernel_id);
567-
568-
std::uint32_t max_sg_size = krn.template get_info<
569-
sycl::info::kernel_device_specific::max_sub_group_size>(dev);
570-
571-
const size_t lws = 4 * max_sg_size;
572-
const std::uint32_t chunk_size = dev.has(sycl::aspect::cpu) ? 8 : 2;
573-
conseq_nelems_sorted = chunk_size * lws;
574-
575-
// This assumption permits doing away with using a loop
576-
assert(nelems_wg_sorts % lws == 0);
577-
578-
sycl::event exp_default_sort_ev = q.submit([&](sycl::handler &cgh) {
579-
cgh.depends_on(depends);
580-
581-
const size_t n_chunks =
582-
quotient_ceil<size_t>(sort_nelems, conseq_nelems_sorted);
583-
584-
sycl::range<1> global_range{iter_nelems * n_chunks * lws};
585-
sycl::range<1> local_range{lws};
586-
587-
sycl::range<1> slm_size_range{conseq_nelems_sorted};
588-
589-
using Sorter = sycl::ext::oneapi::experimental::default_sorter<Comp>;
590-
591-
// calculate required local memory size
592-
// MUST pass local_range, not integer.
593-
// Have different meanings
594-
const size_t temp_memory_size = Sorter::template memory_required<outT>(
595-
sycl::memory_scope::work_group, conseq_nelems_sorted);
596-
597-
sycl::local_accessor<outT, 1> workspace(slm_size_range, cgh);
598-
sycl::local_accessor<std::byte, 1> scratch({temp_memory_size}, cgh);
599-
600-
cgh.parallel_for<KernelName>(
601-
sycl::nd_range<1>(global_range, local_range),
602-
[=](sycl::nd_item<1> it) {
603-
auto sorter_op = Sorter(sycl::span<std::byte>{
604-
scratch
605-
.template get_multi_ptr<sycl::access::decorated::no>()
606-
.get(),
607-
temp_memory_size});
608-
609-
const size_t gr_id = it.get_group_linear_id();
610-
const size_t iter_id = gr_id / n_chunks;
611-
const size_t sort_chunk_id = gr_id - iter_id * n_chunks;
612-
613-
const std::uint32_t lid = it.get_local_linear_id();
614-
615-
const size_t iter_offset = iter_id * sort_nelems;
616-
const size_t chunk_offset =
617-
sort_chunk_id * conseq_nelems_sorted;
618-
const size_t global_start_offset = iter_offset + chunk_offset;
619-
const size_t workspace_size =
620-
std::min<size_t>(sort_nelems,
621-
chunk_offset + conseq_nelems_sorted) -
622-
chunk_offset;
623-
for (std::uint32_t i = lid; i < workspace_size; i += lws) {
624-
workspace[i] = input[global_start_offset + i];
625-
}
626-
sycl::group_barrier(it.get_group());
627-
628-
sycl::ext::oneapi::experimental::joint_sort(
629-
it.get_group(),
630-
workspace
631-
.template get_multi_ptr<sycl::access::decorated::no>()
632-
.get(),
633-
workspace
634-
.template get_multi_ptr<
635-
sycl::access::decorated::no>()
636-
.get() +
637-
workspace_size,
638-
sorter_op);
639-
640-
for (std::uint32_t i = lid; i < workspace_size; i += lws) {
641-
output[global_start_offset + i] = workspace[i];
642-
}
643-
});
644-
});
645-
646-
return exp_default_sort_ev;
647-
}
648-
649565
class vacuous_krn;
650566

651567
inline sycl::event tie_events(sycl::queue &q,
@@ -847,20 +763,11 @@ sycl::event stable_sort_axis1_contig_impl(
847763
(sort_nelems >= 512) ? 512 : determine_automatically;
848764

849765
// Sort segments of the array
850-
#if 1
851766
sycl::event base_sort_ev = sort_detail::sort_over_work_group_contig_impl<
852767
const argTy *, argTy *, Comp>(
853768
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
854769
sorted_block_size, // modified in place with size of sorted block size
855770
depends);
856-
#else
857-
sycl::event base_sort_ev =
858-
sort_detail::sort_base_step_contig_impl<const argTy *, argTy *, Comp>(
859-
exec_q, iter_nelems, sort_nelems, arg_tp, res_tp, comp,
860-
sorted_block_size, // modified in place with size of sorted block
861-
// size
862-
depends);
863-
#endif
864771

865772
// Merge segments in parallel until all elements are sorted
866773
sycl::event merges_ev =

0 commit comments

Comments
 (0)