Skip to content

Commit 9586230

Browse files
authored
Fix mypy at prims_algo_2 (#4527)
1 parent 86baec0 commit 9586230

File tree

1 file changed

+27
-29
lines changed

1 file changed

+27
-29
lines changed

graphs/minimum_spanning_tree_prims2.py

+27-29
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,9 @@
88
"""
99

1010
from sys import maxsize
11-
from typing import Dict, Optional, Tuple, Union
11+
from typing import Generic, Optional, TypeVar
12+
13+
T = TypeVar("T")
1214

1315

1416
def get_parent_position(position: int) -> int:
@@ -43,7 +45,7 @@ def get_child_right_position(position: int) -> int:
4345
return (2 * position) + 2
4446

4547

46-
class MinPriorityQueue:
48+
class MinPriorityQueue(Generic[T]):
4749
"""
4850
Minimum Priority Queue Class
4951
@@ -80,9 +82,9 @@ class MinPriorityQueue:
8082
"""
8183

8284
def __init__(self) -> None:
83-
self.heap = []
84-
self.position_map = {}
85-
self.elements = 0
85+
self.heap: list[tuple[T, int]] = []
86+
self.position_map: dict[T, int] = {}
87+
self.elements: int = 0
8688

8789
def __len__(self) -> int:
8890
return self.elements
@@ -94,14 +96,14 @@ def is_empty(self) -> bool:
9496
# Check if the priority queue is empty
9597
return self.elements == 0
9698

97-
def push(self, elem: Union[int, str], weight: int) -> None:
99+
def push(self, elem: T, weight: int) -> None:
98100
# Add an element with given priority to the queue
99101
self.heap.append((elem, weight))
100102
self.position_map[elem] = self.elements
101103
self.elements += 1
102104
self._bubble_up(elem)
103105

104-
def extract_min(self) -> Union[int, str]:
106+
def extract_min(self) -> T:
105107
# Remove and return the element with lowest weight (highest priority)
106108
if self.elements > 1:
107109
self._swap_nodes(0, self.elements - 1)
@@ -113,7 +115,7 @@ def extract_min(self) -> Union[int, str]:
113115
self._bubble_down(bubble_down_elem)
114116
return elem
115117

116-
def update_key(self, elem: Union[int, str], weight: int) -> None:
118+
def update_key(self, elem: T, weight: int) -> None:
117119
# Update the weight of the given key
118120
position = self.position_map[elem]
119121
self.heap[position] = (elem, weight)
@@ -127,7 +129,7 @@ def update_key(self, elem: Union[int, str], weight: int) -> None:
127129
else:
128130
self._bubble_down(elem)
129131

130-
def _bubble_up(self, elem: Union[int, str]) -> None:
132+
def _bubble_up(self, elem: T) -> None:
131133
# Place a node at the proper position (upward movement) [to be used internally
132134
# only]
133135
curr_pos = self.position_map[elem]
@@ -141,7 +143,7 @@ def _bubble_up(self, elem: Union[int, str]) -> None:
141143
return self._bubble_up(elem)
142144
return
143145

144-
def _bubble_down(self, elem: Union[int, str]) -> None:
146+
def _bubble_down(self, elem: T) -> None:
145147
# Place a node at the proper position (downward movement) [to be used
146148
# internally only]
147149
curr_pos = self.position_map[elem]
@@ -182,7 +184,7 @@ def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
182184
self.position_map[node2_elem] = node1_pos
183185

184186

185-
class GraphUndirectedWeighted:
187+
class GraphUndirectedWeighted(Generic[T]):
186188
"""
187189
Graph Undirected Weighted Class
188190
@@ -192,24 +194,22 @@ class GraphUndirectedWeighted:
192194
"""
193195

194196
def __init__(self) -> None:
195-
self.connections = {}
196-
self.nodes = 0
197+
self.connections: dict[T, dict[T, int]] = {}
198+
self.nodes: int = 0
197199

198200
def __repr__(self) -> str:
199201
return str(self.connections)
200202

201203
def __len__(self) -> int:
202204
return self.nodes
203205

204-
def add_node(self, node: Union[int, str]) -> None:
206+
def add_node(self, node: T) -> None:
205207
# Add a node in the graph if it is not in the graph
206208
if node not in self.connections:
207209
self.connections[node] = {}
208210
self.nodes += 1
209211

210-
def add_edge(
211-
self, node1: Union[int, str], node2: Union[int, str], weight: int
212-
) -> None:
212+
def add_edge(self, node1: T, node2: T, weight: int) -> None:
213213
# Add an edge between 2 nodes in the graph
214214
self.add_node(node1)
215215
self.add_node(node2)
@@ -218,8 +218,8 @@ def add_edge(
218218

219219

220220
def prims_algo(
221-
graph: GraphUndirectedWeighted,
222-
) -> Tuple[Dict[str, int], Dict[str, Optional[str]]]:
221+
graph: GraphUndirectedWeighted[T],
222+
) -> tuple[dict[T, int], dict[T, Optional[T]]]:
223223
"""
224224
>>> graph = GraphUndirectedWeighted()
225225
@@ -239,10 +239,13 @@ def prims_algo(
239239
13
240240
"""
241241
# prim's algorithm for minimum spanning tree
242-
dist = {node: maxsize for node in graph.connections}
243-
parent = {node: None for node in graph.connections}
244-
priority_queue = MinPriorityQueue()
245-
[priority_queue.push(node, weight) for node, weight in dist.items()]
242+
dist: dict[T, int] = {node: maxsize for node in graph.connections}
243+
parent: dict[T, Optional[T]] = {node: None for node in graph.connections}
244+
245+
priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
246+
for node, weight in dist.items():
247+
priority_queue.push(node, weight)
248+
246249
if priority_queue.is_empty():
247250
return dist, parent
248251

@@ -254,6 +257,7 @@ def prims_algo(
254257
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
255258
priority_queue.update_key(neighbour, dist[neighbour])
256259
parent[neighbour] = node
260+
257261
# running prim's algorithm
258262
while not priority_queue.is_empty():
259263
node = priority_queue.extract_min()
@@ -263,9 +267,3 @@ def prims_algo(
263267
priority_queue.update_key(neighbour, dist[neighbour])
264268
parent[neighbour] = node
265269
return dist, parent
266-
267-
268-
if __name__ == "__main__":
269-
from doctest import testmod
270-
271-
testmod()

0 commit comments

Comments
 (0)