Skip to content

Reduce the complexity of graphs/bi_directional_dijkstra.py #8165

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

95 changes: 52 additions & 43 deletions graphs/bi_directional_dijkstra.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,32 @@
import numpy as np


def pass_and_relaxation(
graph: dict,
v: str,
visited_forward: set,
visited_backward: set,
cst_fwd: dict,
cst_bwd: dict,
queue: PriorityQueue,
parent: dict,
shortest_distance: float | int,
) -> float | int:
for nxt, d in graph[v]:
if nxt in visited_forward:
continue
old_cost_f = cst_fwd.get(nxt, np.inf)
new_cost_f = cst_fwd[v] + d
if new_cost_f < old_cost_f:
queue.put((new_cost_f, nxt))
cst_fwd[nxt] = new_cost_f
parent[nxt] = v
if nxt in visited_backward:
if cst_fwd[v] + d + cst_bwd[nxt] < shortest_distance:
shortest_distance = cst_fwd[v] + d + cst_bwd[nxt]
return shortest_distance


def bidirectional_dij(
source: str, destination: str, graph_forward: dict, graph_backward: dict
) -> int:
Expand Down Expand Up @@ -51,53 +77,36 @@ def bidirectional_dij(
if source == destination:
return 0

while queue_forward and queue_backward:
while not queue_forward.empty():
_, v_fwd = queue_forward.get()

if v_fwd not in visited_forward:
break
else:
break
while not queue_forward.empty() and not queue_backward.empty():
_, v_fwd = queue_forward.get()
visited_forward.add(v_fwd)

while not queue_backward.empty():
_, v_bwd = queue_backward.get()

if v_bwd not in visited_backward:
break
else:
break
_, v_bwd = queue_backward.get()
visited_backward.add(v_bwd)

# forward pass and relaxation
for nxt_fwd, d_forward in graph_forward[v_fwd]:
if nxt_fwd in visited_forward:
continue
old_cost_f = cst_fwd.get(nxt_fwd, np.inf)
new_cost_f = cst_fwd[v_fwd] + d_forward
if new_cost_f < old_cost_f:
queue_forward.put((new_cost_f, nxt_fwd))
cst_fwd[nxt_fwd] = new_cost_f
parent_forward[nxt_fwd] = v_fwd
if nxt_fwd in visited_backward:
if cst_fwd[v_fwd] + d_forward + cst_bwd[nxt_fwd] < shortest_distance:
shortest_distance = cst_fwd[v_fwd] + d_forward + cst_bwd[nxt_fwd]

# backward pass and relaxation
for nxt_bwd, d_backward in graph_backward[v_bwd]:
if nxt_bwd in visited_backward:
continue
old_cost_b = cst_bwd.get(nxt_bwd, np.inf)
new_cost_b = cst_bwd[v_bwd] + d_backward
if new_cost_b < old_cost_b:
queue_backward.put((new_cost_b, nxt_bwd))
cst_bwd[nxt_bwd] = new_cost_b
parent_backward[nxt_bwd] = v_bwd

if nxt_bwd in visited_forward:
if cst_bwd[v_bwd] + d_backward + cst_fwd[nxt_bwd] < shortest_distance:
shortest_distance = cst_bwd[v_bwd] + d_backward + cst_fwd[nxt_bwd]
shortest_distance = pass_and_relaxation(
graph_forward,
v_fwd,
visited_forward,
visited_backward,
cst_fwd,
cst_bwd,
queue_forward,
parent_forward,
shortest_distance,
)

shortest_distance = pass_and_relaxation(
graph_backward,
v_bwd,
visited_backward,
visited_forward,
cst_bwd,
cst_fwd,
queue_backward,
parent_backward,
shortest_distance,
)

if cst_fwd[v_fwd] + cst_bwd[v_bwd] >= shortest_distance:
break
Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ show-source = true
target-version = "py311"

[tool.ruff.mccabe] # DO NOT INCREASE THIS VALUE
max-complexity = 20 # default: 10
max-complexity = 17 # default: 10

[tool.ruff.pylint] # DO NOT INCREASE THESE VALUES
max-args = 10 # default: 5
Expand Down