From ac111ee463065e372ad148dbafba630045ecf94c Mon Sep 17 00:00:00 2001 From: Maxim Smolskiy Date: Wed, 29 Mar 2023 00:41:54 +0300 Subject: [PATCH] Reduce the complexity of graphs/bi_directional_dijkstra.py (#8165) * Reduce the complexity of graphs/bi_directional_dijkstra.py * Try to lower the --max-complexity threshold in the file .flake8 * Lower the --max-complexity threshold in the file .flake8 * updating DIRECTORY.md * updating DIRECTORY.md * Try to lower max-complexity * Try to lower max-complexity * Try to lower max-complexity --------- Co-authored-by: github-actions <${GITHUB_ACTOR}@users.noreply.github.com> --- graphs/bi_directional_dijkstra.py | 93 +++++++++++++++++-------------- pyproject.toml | 2 +- 2 files changed, 52 insertions(+), 43 deletions(-) diff --git a/graphs/bi_directional_dijkstra.py b/graphs/bi_directional_dijkstra.py index fc53e2f0..a4489026 100644 --- a/graphs/bi_directional_dijkstra.py +++ b/graphs/bi_directional_dijkstra.py @@ -17,6 +17,32 @@ from typing import Any import numpy as np +def pass_and_relaxation( + graph: dict, + v: str, + visited_forward: set, + visited_backward: set, + cst_fwd: dict, + cst_bwd: dict, + queue: PriorityQueue, + parent: dict, + shortest_distance: float | int, +) -> float | int: + for nxt, d in graph[v]: + if nxt in visited_forward: + continue + old_cost_f = cst_fwd.get(nxt, np.inf) + new_cost_f = cst_fwd[v] + d + if new_cost_f < old_cost_f: + queue.put((new_cost_f, nxt)) + cst_fwd[nxt] = new_cost_f + parent[nxt] = v + if nxt in visited_backward: + if cst_fwd[v] + d + cst_bwd[nxt] < shortest_distance: + shortest_distance = cst_fwd[v] + d + cst_bwd[nxt] + return shortest_distance + + def bidirectional_dij( source: str, destination: str, graph_forward: dict, graph_backward: dict ) -> int: @@ -51,53 +77,36 @@ def bidirectional_dij( if source == destination: return 0 - while queue_forward and queue_backward: - while not queue_forward.empty(): - _, v_fwd = queue_forward.get() - - if v_fwd not in visited_forward: - break - else: - break + while not queue_forward.empty() and not queue_backward.empty(): + _, v_fwd = queue_forward.get() visited_forward.add(v_fwd) - while not queue_backward.empty(): - _, v_bwd = queue_backward.get() - - if v_bwd not in visited_backward: - break - else: - break + _, v_bwd = queue_backward.get() visited_backward.add(v_bwd) - # forward pass and relaxation - for nxt_fwd, d_forward in graph_forward[v_fwd]: - if nxt_fwd in visited_forward: - continue - old_cost_f = cst_fwd.get(nxt_fwd, np.inf) - new_cost_f = cst_fwd[v_fwd] + d_forward - if new_cost_f < old_cost_f: - queue_forward.put((new_cost_f, nxt_fwd)) - cst_fwd[nxt_fwd] = new_cost_f - parent_forward[nxt_fwd] = v_fwd - if nxt_fwd in visited_backward: - if cst_fwd[v_fwd] + d_forward + cst_bwd[nxt_fwd] < shortest_distance: - shortest_distance = cst_fwd[v_fwd] + d_forward + cst_bwd[nxt_fwd] + shortest_distance = pass_and_relaxation( + graph_forward, + v_fwd, + visited_forward, + visited_backward, + cst_fwd, + cst_bwd, + queue_forward, + parent_forward, + shortest_distance, + ) - # backward pass and relaxation - for nxt_bwd, d_backward in graph_backward[v_bwd]: - if nxt_bwd in visited_backward: - continue - old_cost_b = cst_bwd.get(nxt_bwd, np.inf) - new_cost_b = cst_bwd[v_bwd] + d_backward - if new_cost_b < old_cost_b: - queue_backward.put((new_cost_b, nxt_bwd)) - cst_bwd[nxt_bwd] = new_cost_b - parent_backward[nxt_bwd] = v_bwd - - if nxt_bwd in visited_forward: - if cst_bwd[v_bwd] + d_backward + cst_fwd[nxt_bwd] < shortest_distance: - shortest_distance = cst_bwd[v_bwd] + d_backward + cst_fwd[nxt_bwd] + shortest_distance = pass_and_relaxation( + graph_backward, + v_bwd, + visited_backward, + visited_forward, + cst_bwd, + cst_fwd, + queue_backward, + parent_backward, + shortest_distance, + ) if cst_fwd[v_fwd] + cst_bwd[v_bwd] >= shortest_distance: break diff --git a/pyproject.toml b/pyproject.toml index 23fe45e9..48c3fbd4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -61,7 +61,7 @@ show-source = true target-version = "py311" [tool.ruff.mccabe] # DO NOT INCREASE THIS VALUE -max-complexity = 20 # default: 10 +max-complexity = 17 # default: 10 [tool.ruff.pylint] # DO NOT INCREASE THESE VALUES max-args = 10 # default: 5