Skip to content

Commit 4a2216b

Browse files
authored
Fix mypy errors at bidirectional_a_star (TheAlgorithms#4556)
1 parent 72aa4cc commit 4a2216b

File tree

1 file changed

+23
-19
lines changed

1 file changed

+23
-19
lines changed

graphs/bidirectional_a_star.py

+23-19
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,8 @@
88
from math import sqrt
99

1010
# 1 for manhattan, 0 for euclidean
11+
from typing import Optional
12+
1113
HEURISTIC = 0
1214

1315
grid = [
@@ -22,6 +24,8 @@
2224

2325
delta = [[-1, 0], [0, -1], [1, 0], [0, 1]] # up, left, down, right
2426

27+
TPosition = tuple[int, int]
28+
2529

2630
class Node:
2731
"""
@@ -39,7 +43,15 @@ class Node:
3943
True
4044
"""
4145

42-
def __init__(self, pos_x, pos_y, goal_x, goal_y, g_cost, parent):
46+
def __init__(
47+
self,
48+
pos_x: int,
49+
pos_y: int,
50+
goal_x: int,
51+
goal_y: int,
52+
g_cost: int,
53+
parent: Optional[Node],
54+
) -> None:
4355
self.pos_x = pos_x
4456
self.pos_y = pos_y
4557
self.pos = (pos_y, pos_x)
@@ -61,7 +73,7 @@ def calculate_heuristic(self) -> float:
6173
else:
6274
return sqrt(dy ** 2 + dx ** 2)
6375

64-
def __lt__(self, other) -> bool:
76+
def __lt__(self, other: Node) -> bool:
6577
return self.f_cost < other.f_cost
6678

6779

@@ -81,23 +93,22 @@ class AStar:
8193
(4, 3), (4, 4), (5, 4), (5, 5), (6, 5), (6, 6)]
8294
"""
8395

84-
def __init__(self, start, goal):
96+
def __init__(self, start: TPosition, goal: TPosition):
8597
self.start = Node(start[1], start[0], goal[1], goal[0], 0, None)
8698
self.target = Node(goal[1], goal[0], goal[1], goal[0], 99999, None)
8799

88100
self.open_nodes = [self.start]
89-
self.closed_nodes = []
101+
self.closed_nodes: list[Node] = []
90102

91103
self.reached = False
92104

93-
def search(self) -> list[tuple[int]]:
105+
def search(self) -> list[TPosition]:
94106
while self.open_nodes:
95107
# Open Nodes are sorted using __lt__
96108
self.open_nodes.sort()
97109
current_node = self.open_nodes.pop(0)
98110

99111
if current_node.pos == self.target.pos:
100-
self.reached = True
101112
return self.retrace_path(current_node)
102113

103114
self.closed_nodes.append(current_node)
@@ -118,8 +129,7 @@ def search(self) -> list[tuple[int]]:
118129
else:
119130
self.open_nodes.append(better_node)
120131

121-
if not (self.reached):
122-
return [(self.start.pos)]
132+
return [self.start.pos]
123133

124134
def get_successors(self, parent: Node) -> list[Node]:
125135
"""
@@ -147,7 +157,7 @@ def get_successors(self, parent: Node) -> list[Node]:
147157
)
148158
return successors
149159

150-
def retrace_path(self, node: Node) -> list[tuple[int]]:
160+
def retrace_path(self, node: Optional[Node]) -> list[TPosition]:
151161
"""
152162
Retrace the path from parents to parents until start node
153163
"""
@@ -173,20 +183,19 @@ class BidirectionalAStar:
173183
(2, 5), (3, 5), (4, 5), (5, 5), (5, 6), (6, 6)]
174184
"""
175185

176-
def __init__(self, start, goal):
186+
def __init__(self, start: TPosition, goal: TPosition) -> None:
177187
self.fwd_astar = AStar(start, goal)
178188
self.bwd_astar = AStar(goal, start)
179189
self.reached = False
180190

181-
def search(self) -> list[tuple[int]]:
191+
def search(self) -> list[TPosition]:
182192
while self.fwd_astar.open_nodes or self.bwd_astar.open_nodes:
183193
self.fwd_astar.open_nodes.sort()
184194
self.bwd_astar.open_nodes.sort()
185195
current_fwd_node = self.fwd_astar.open_nodes.pop(0)
186196
current_bwd_node = self.bwd_astar.open_nodes.pop(0)
187197

188198
if current_bwd_node.pos == current_fwd_node.pos:
189-
self.reached = True
190199
return self.retrace_bidirectional_path(
191200
current_fwd_node, current_bwd_node
192201
)
@@ -220,12 +229,11 @@ def search(self) -> list[tuple[int]]:
220229
else:
221230
astar.open_nodes.append(better_node)
222231

223-
if not self.reached:
224-
return [self.fwd_astar.start.pos]
232+
return [self.fwd_astar.start.pos]
225233

226234
def retrace_bidirectional_path(
227235
self, fwd_node: Node, bwd_node: Node
228-
) -> list[tuple[int]]:
236+
) -> list[TPosition]:
229237
fwd_path = self.fwd_astar.retrace_path(fwd_node)
230238
bwd_path = self.bwd_astar.retrace_path(bwd_node)
231239
bwd_path.pop()
@@ -236,9 +244,6 @@ def retrace_bidirectional_path(
236244

237245
if __name__ == "__main__":
238246
# all coordinates are given in format [y,x]
239-
import doctest
240-
241-
doctest.testmod()
242247
init = (0, 0)
243248
goal = (len(grid) - 1, len(grid[0]) - 1)
244249
for elem in grid:
@@ -252,6 +257,5 @@ def retrace_bidirectional_path(
252257

253258
bd_start_time = time.time()
254259
bidir_astar = BidirectionalAStar(init, goal)
255-
path = bidir_astar.search()
256260
bd_end_time = time.time() - bd_start_time
257261
print(f"BidirectionalAStar execution time = {bd_end_time:f} seconds")

0 commit comments

Comments
 (0)