Skip to content

Commit c359768

Browse files
Update bitonic_sort with type hints, doctest, snake_case names (#4016)
* Updated input * Fix pre-commit error * Add type hints, doctests, black, snake_case Co-authored-by: Dhruv Manilawala <[email protected]>
1 parent 860d4f5 commit c359768

File tree

1 file changed

+91
-53
lines changed

1 file changed

+91
-53
lines changed

Diff for: sorts/bitonic_sort.py

+91-53
Original file line numberDiff line numberDiff line change
@@ -1,58 +1,96 @@
1-
# Python program for Bitonic Sort. Note that this program
2-
# works only when size of input is a power of 2.
3-
4-
5-
# The parameter dir indicates the sorting direction, ASCENDING
6-
# or DESCENDING; if (a[i] > a[j]) agrees with the direction,
7-
# then a[i] and a[j] are interchanged.
8-
def compAndSwap(a, i, j, dire):
9-
if (dire == 1 and a[i] > a[j]) or (dire == 0 and a[i] < a[j]):
10-
a[i], a[j] = a[j], a[i]
11-
12-
# It recursively sorts a bitonic sequence in ascending order,
13-
14-
15-
# if dir = 1, and in descending order otherwise (means dir=0).
16-
# The sequence to be sorted starts at index position low,
17-
# the parameter cnt is the number of elements to be sorted.
18-
def bitonic_merge(a, low, cnt, dire):
19-
if cnt > 1:
20-
k = int(cnt / 2)
21-
for i in range(low, low + k):
22-
compAndSwap(a, i, i + k, dire)
23-
bitonic_merge(a, low, k, dire)
24-
bitonic_merge(a, low + k, k, dire)
25-
26-
# This function first produces a bitonic sequence by recursively
27-
28-
29-
# sorting its two halves in opposite sorting orders, and then
30-
# calls bitonic_merge to make them in the same order
31-
def bitonic_sort(a, low, cnt, dire):
32-
if cnt > 1:
33-
k = int(cnt / 2)
34-
bitonic_sort(a, low, k, 1)
35-
bitonic_sort(a, low + k, k, 0)
36-
bitonic_merge(a, low, cnt, dire)
37-
38-
# Caller of bitonic_sort for sorting the entire array of length N
39-
40-
41-
# in ASCENDING order
42-
def sort(a, N, up):
43-
bitonic_sort(a, 0, N, up)
1+
"""
2+
Python program for Bitonic Sort.
3+
4+
Note that this program works only when size of input is a power of 2.
5+
"""
6+
from typing import List
7+
8+
9+
def comp_and_swap(array: List[int], index1: int, index2: int, direction: int) -> None:
10+
"""Compare the value at given index1 and index2 of the array and swap them as per
11+
the given direction.
12+
13+
The parameter direction indicates the sorting direction, ASCENDING(1) or
14+
DESCENDING(0); if (a[i] > a[j]) agrees with the direction, then a[i] and a[j] are
15+
interchanged.
16+
17+
>>> arr = [12, 42, -21, 1]
18+
>>> comp_and_swap(arr, 1, 2, 1)
19+
>>> print(arr)
20+
[12, -21, 42, 1]
21+
22+
>>> comp_and_swap(arr, 1, 2, 0)
23+
>>> print(arr)
24+
[12, 42, -21, 1]
25+
26+
>>> comp_and_swap(arr, 0, 3, 1)
27+
>>> print(arr)
28+
[1, 42, -21, 12]
29+
30+
>>> comp_and_swap(arr, 0, 3, 0)
31+
>>> print(arr)
32+
[12, 42, -21, 1]
33+
"""
34+
if (direction == 1 and array[index1] > array[index2]) or (
35+
direction == 0 and array[index1] < array[index2]
36+
):
37+
array[index1], array[index2] = array[index2], array[index1]
38+
39+
40+
def bitonic_merge(array: List[int], low: int, length: int, direction: int) -> None:
41+
"""
42+
It recursively sorts a bitonic sequence in ascending order, if direction = 1, and in
43+
descending if direction = 0.
44+
The sequence to be sorted starts at index position low, the parameter length is the
45+
number of elements to be sorted.
46+
47+
>>> arr = [12, 42, -21, 1]
48+
>>> bitonic_merge(arr, 0, 4, 1)
49+
>>> print(arr)
50+
[-21, 1, 12, 42]
51+
52+
>>> bitonic_merge(arr, 0, 4, 0)
53+
>>> print(arr)
54+
[42, 12, 1, -21]
55+
"""
56+
if length > 1:
57+
middle = int(length / 2)
58+
for i in range(low, low + middle):
59+
comp_and_swap(array, i, i + middle, direction)
60+
bitonic_merge(array, low, middle, direction)
61+
bitonic_merge(array, low + middle, middle, direction)
62+
63+
64+
def bitonic_sort(array: List[int], low: int, length: int, direction: int) -> None:
65+
"""
66+
This function first produces a bitonic sequence by recursively sorting its two
67+
halves in opposite sorting orders, and then calls bitonic_merge to make them in the
68+
same order.
69+
70+
>>> arr = [12, 34, 92, -23, 0, -121, -167, 145]
71+
>>> bitonic_sort(arr, 0, 8, 1)
72+
>>> arr
73+
[-167, -121, -23, 0, 12, 34, 92, 145]
74+
75+
>>> bitonic_sort(arr, 0, 8, 0)
76+
>>> arr
77+
[145, 92, 34, 12, 0, -23, -121, -167]
78+
"""
79+
if length > 1:
80+
middle = int(length / 2)
81+
bitonic_sort(array, low, middle, 1)
82+
bitonic_sort(array, low + middle, middle, 0)
83+
bitonic_merge(array, low, length, direction)
4484

4585

4686
if __name__ == "__main__":
87+
user_input = input("Enter numbers separated by a comma:\n").strip()
88+
unsorted = [int(item.strip()) for item in user_input.split(",")]
4789

48-
a = []
49-
50-
n = int(input().strip())
51-
for i in range(n):
52-
a.append(int(input().strip()))
53-
up = 1
90+
bitonic_sort(unsorted, 0, len(unsorted), 1)
91+
print("\nSorted array in ascending order is: ", end="")
92+
print(*unsorted, sep=", ")
5493

55-
sort(a, n, up)
56-
print("\n\nSorted array is")
57-
for i in range(n):
58-
print("%d" % a[i])
94+
bitonic_merge(unsorted, 0, len(unsorted), 0)
95+
print("Sorted array in descending order is: ", end="")
96+
print(*unsorted, sep=", ")

0 commit comments

Comments
 (0)