Skip to content

Commit e321b1e

Browse files
Added TSP
1 parent 3ceccfb commit e321b1e

File tree

1 file changed

+226
-0
lines changed

1 file changed

+226
-0
lines changed

travelling_salesman_problem.py

+226
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,226 @@
1+
""" Travelling Salesman Problem (TSP) """
2+
3+
import itertools
4+
import math
5+
6+
class InvalidGraphError(ValueError):
7+
"""Custom error for invalid graph inputs."""
8+
9+
def euclidean_distance(point1: list[float], point2: list[float]) -> float:
10+
"""
11+
Calculate the Euclidean distance between two points in 2D space.
12+
13+
:param point1: Coordinates of the first point [x, y]
14+
:param point2: Coordinates of the second point [x, y]
15+
:return: The Euclidean distance between the two points
16+
17+
>>> euclidean_distance([0, 0], [3, 4])
18+
5.0
19+
>>> euclidean_distance([1, 1], [1, 1])
20+
0.0
21+
>>> euclidean_distance([1, 1], ['a', 1])
22+
Traceback (most recent call last):
23+
...
24+
ValueError: Invalid input: Points must be numerical coordinates
25+
"""
26+
try:
27+
return math.sqrt((point2[0] - point1[0]) ** 2 + (point2[1] - point1[1]) ** 2)
28+
except TypeError:
29+
raise ValueError("Invalid input: Points must be numerical coordinates")
30+
31+
def validate_graph(graph_points: dict[str, list[float]]) -> None:
32+
"""
33+
Validate the input graph to ensure it has valid nodes and coordinates.
34+
35+
:param graph_points: A dictionary where the keys are node names,
36+
and values are 2D coordinates as [x, y]
37+
:raises InvalidGraphError: If the graph points are not valid
38+
39+
>>> validate_graph({"A": [10, 20], "B": [30, 21], "C": [15, 35]}) # Valid graph
40+
>>> validate_graph({"A": [10, 20], "B": [30, "invalid"], "C": [15, 35]})
41+
Traceback (most recent call last):
42+
...
43+
InvalidGraphError: Each node must have a valid 2D coordinate [x, y]
44+
45+
>>> validate_graph([10, 20]) # Invalid input type
46+
Traceback (most recent call last):
47+
...
48+
InvalidGraphError: Graph must be a dictionary with node names and coordinates
49+
50+
>>> validate_graph({"A": [10, 20], "B": [30, 21], "C": [15]}) # Missing coordinate
51+
Traceback (most recent call last):
52+
...
53+
InvalidGraphError: Each node must have a valid 2D coordinate [x, y]
54+
"""
55+
if not isinstance(graph_points, dict):
56+
raise InvalidGraphError(
57+
"Graph must be a dictionary with node names and coordinates"
58+
)
59+
60+
for node, coordinates in graph_points.items():
61+
if (
62+
not isinstance(node, str)
63+
or not isinstance(coordinates, list)
64+
or len(coordinates) != 2
65+
or not all(isinstance(c, (int, float)) for c in coordinates)
66+
):
67+
raise InvalidGraphError("Each node must have a valid 2D coordinate [x, y]")
68+
69+
# TSP in Brute Force Approach
70+
def travelling_salesman_brute_force(
71+
graph_points: dict[str, list[float]],
72+
) -> tuple[list[str], float]:
73+
"""
74+
Solve the Travelling Salesman Problem using brute force.
75+
76+
:param graph_points: A dictionary of nodes and their coordinates {node: [x, y]}
77+
:return: The shortest path and its total distance
78+
79+
>>> graph = {"A": [10, 20], "B": [30, 21], "C": [15, 35]}
80+
>>> travelling_salesman_brute_force(graph)
81+
(['A', 'C', 'B', 'A'], 56.35465722402587)
82+
"""
83+
validate_graph(graph_points)
84+
85+
nodes = list(graph_points.keys()) # Extracting the node names (keys)
86+
87+
# There shoukd be atleast 2 nodes for a valid TSP
88+
if len(nodes) < 2:
89+
raise InvalidGraphError("Graph must have at least two nodes")
90+
91+
min_path = [] # List that stores shortest path
92+
min_distance = float("inf") # Initialize minimum distance to infinity
93+
94+
start_node = nodes[0]
95+
other_nodes = nodes[1:]
96+
97+
# Iterating over all permutations of the other nodes
98+
for perm in itertools.permutations(other_nodes):
99+
path = [start_node, *perm, start_node]
100+
101+
# Calculating the total distance
102+
total_distance = sum(
103+
euclidean_distance(graph_points[path[i]], graph_points[path[i + 1]])
104+
for i in range(len(path) - 1)
105+
)
106+
107+
# Update minimum distance if shorter path found
108+
if total_distance < min_distance:
109+
min_distance = total_distance
110+
min_path = path
111+
112+
return min_path, min_distance
113+
114+
# TSP in Dynamic Programming approach
115+
def travelling_salesman_dynamic_programming(
116+
graph_points: dict[str, list[float]],
117+
) -> tuple[list[str], float]:
118+
"""
119+
Solve the Travelling Salesman Problem using dynamic programming.
120+
121+
:param graph_points: A dictionary of nodes and their coordinates {node: [x, y]}
122+
:return: The shortest path and its total distance
123+
124+
>>> graph = {"A": [10, 20], "B": [30, 21], "C": [15, 35]}
125+
>>> travelling_salesman_dynamic_programming(graph)
126+
(['A', 'C', 'B', 'A'], 56.35465722402587)
127+
"""
128+
validate_graph(graph_points)
129+
130+
n = len(graph_points) # Extracting the node names (keys)
131+
132+
# There shoukd be atleast 2 nodes for a valid TSP
133+
if n < 2:
134+
raise InvalidGraphError("Graph must have at least two nodes")
135+
136+
nodes = list(graph_points.keys()) # Extracting the node names (keys)
137+
138+
# Initialize distance matrix with float values
139+
dist = [[euclidean_distance(graph_points[nodes[i]], graph_points[nodes[j]]) for j in range(n)] for i in range(n)]
140+
141+
# Initialize a dynamic programming table with infinity
142+
dp = [[float("inf")] * n for _ in range(1 << n)]
143+
dp[1][0] = 0 # Only visited node is the starting point at node 0
144+
145+
# Iterate through all masks of visited nodes
146+
for mask in range(1 << n):
147+
for u in range(n):
148+
# If current node 'u' is visited
149+
if mask & (1 << u):
150+
# Traverse nodes 'v' such that u->v
151+
for v in range(n):
152+
if mask & (1 << v) == 0: # If v is not visited
153+
next_mask = mask | (1 << v) # Upodate mask to include 'v'
154+
# Update dynamic programming table with minimum distance
155+
dp[next_mask][v] = min(dp[next_mask][v], dp[mask][u] + dist[u][v])
156+
157+
final_mask = (1 << n) - 1
158+
min_cost = float("inf")
159+
end_node = -1 # Track the last node in the optimal path
160+
161+
for u in range(1, n):
162+
if min_cost > dp[final_mask][u] + dist[u][0]:
163+
min_cost = dp[final_mask][u] + dist[u][0]
164+
end_node = u
165+
166+
path = []
167+
mask = final_mask
168+
while end_node != 0:
169+
path.append(nodes[end_node])
170+
for u in range(n):
171+
# If current state corresponds to optimal state before visiting end node
172+
if (
173+
mask & (1 << u)
174+
and dp[mask][end_node]
175+
== dp[mask ^ (1 << end_node)][u] + dist[u][end_node]
176+
):
177+
mask ^= 1 << end_node # Update mask to remove end node
178+
end_node = u # Set the previous node as end node
179+
break
180+
181+
path.append(nodes[0]) # Bottom-up Order
182+
path.reverse() # Top-Down Order
183+
path.append(nodes[0])
184+
185+
return path, min_cost
186+
187+
188+
# Demo Graph
189+
# C (15, 35)
190+
# |
191+
# |
192+
# |
193+
# F (5, 15) --- A (10, 20)
194+
# | |
195+
# | |
196+
# | |
197+
# | |
198+
# E (25, 5) --- B (30, 21)
199+
# |
200+
# |
201+
# |
202+
# D (40, 10)
203+
# |
204+
# |
205+
# |
206+
# G (50, 25)
207+
208+
209+
if __name__ == "__main__":
210+
demo_graph = {
211+
"A": [10.0, 20.0],
212+
"B": [30.0, 21.0],
213+
"C": [15.0, 35.0],
214+
"D": [40.0, 10.0],
215+
"E": [25.0, 5.0],
216+
"F": [5.0, 15.0],
217+
"G": [50.0, 25.0],
218+
}
219+
220+
# Brute force
221+
brute_force_result = travelling_salesman_brute_force(demo_graph)
222+
print(f"Brute force result: {brute_force_result}")
223+
224+
# Dynamic programming
225+
dp_result = travelling_salesman_dynamic_programming(demo_graph)
226+
print(f"Dynamic programming result: {dp_result}")

0 commit comments

Comments
 (0)