Skip to content

Commit b82d1b4

Browse files
authored
Update travelling_salesman_problem.py
1 parent f178fa9 commit b82d1b4

File tree

1 file changed

+11
-10
lines changed

1 file changed

+11
-10
lines changed

dynamic_programming/travelling_salesman_problem.py

+11-10
Original file line numberDiff line numberDiff line change
@@ -5,47 +5,48 @@
55

66
def tsp(distances: list[list[int]]) -> int:
77
"""
8-
Solves the Travelling Salesman Problem (TSP) using
8+
Solves the Travelling Salesman Problem (TSP) using
99
dynamic programming and bitmasking.
1010
Args:
11-
distances: A 2D list where distances[i][j] represents the
11+
distances: 2D list where distances[i][j] is the
1212
distance between city i and city j.
1313
Returns:
14-
The minimum cost to complete the tour visiting all cities.
14+
Minimum cost to complete the
15+
tour visiting all cities.
1516
Raises:
1617
ValueError: If any distance is negative.
1718
1819
>>> tsp([[0, 10, 15, 20], [10, 0, 35, 25], [15, 35, 0, 30], [20, 25, 30, 0]])
1920
80
2021
>>> tsp([[0, 29, 20, 21], [29, 0, 15, 17], [20, 15, 0, 28], [21, 17, 28, 0]])
2122
69
22-
>>> tsp([[0, 10, -15, 20], [10, 0, 35, 25], [15, 35, 0, 30], [20, 25, 30, 0]])
23+
>>> tsp([[0, 10, -15, 20], [10, 0, 35, 25], [15, 35, 0, 30], [20, 25, 30, 0]])
2324
ValueError: Distance cannot be negative
2425
"""
2526
n = len(distances)
2627
if any(distances[i][j] < 0 for i in range(n) for j in range(n)):
2728
raise ValueError("Distance cannot be negative")
29+
2830
visited_all = (1 << n) - 1
2931

3032
@lru_cache(None)
3133
def visit(city: int, mask: int) -> int:
32-
"""Recursively calculates the minimum cost of visiting all cities."""
34+
"""Recursively calculates the minimum cost to visit all cities."""
3335
if mask == visited_all:
3436
return distances[city][0] # Return to start
3537

36-
min_cost = float("inf")
38+
min_cost = float('inf') # Large value to compare against
3739
for next_city in range(n):
38-
if not mask & (1 << next_city): # If next_city is unvisited
40+
if not mask & (1 << next_city): # If unvisited
3941
new_cost = distances[city][next_city] + visit(
4042
next_city, mask | (1 << next_city)
4143
)
4244
min_cost = min(min_cost, new_cost)
43-
return min_cost
45+
return int(min_cost) # Ensure returning an integer
4446

45-
return visit(0, 1) # Start from city 0 with only city 0 visited
47+
return visit(0, 1) # Start from city 0 with city 0 visited
4648

4749

4850
if __name__ == "__main__":
4951
import doctest
50-
5152
doctest.testmod()

0 commit comments

Comments
 (0)