8
8
"""
9
9
10
10
from sys import maxsize
11
- from typing import Dict , Optional , Tuple , Union
11
+ from typing import Generic , Optional , TypeVar
12
+
13
+ T = TypeVar ("T" )
12
14
13
15
14
16
def get_parent_position (position : int ) -> int :
@@ -43,7 +45,7 @@ def get_child_right_position(position: int) -> int:
43
45
return (2 * position ) + 2
44
46
45
47
46
- class MinPriorityQueue :
48
+ class MinPriorityQueue ( Generic [ T ]) :
47
49
"""
48
50
Minimum Priority Queue Class
49
51
@@ -80,9 +82,9 @@ class MinPriorityQueue:
80
82
"""
81
83
82
84
def __init__ (self ) -> None :
83
- self .heap = []
84
- self .position_map = {}
85
- self .elements = 0
85
+ self .heap : list [ tuple [ T , int ]] = []
86
+ self .position_map : dict [ T , int ] = {}
87
+ self .elements : int = 0
86
88
87
89
def __len__ (self ) -> int :
88
90
return self .elements
@@ -94,14 +96,14 @@ def is_empty(self) -> bool:
94
96
# Check if the priority queue is empty
95
97
return self .elements == 0
96
98
97
- def push (self , elem : Union [ int , str ] , weight : int ) -> None :
99
+ def push (self , elem : T , weight : int ) -> None :
98
100
# Add an element with given priority to the queue
99
101
self .heap .append ((elem , weight ))
100
102
self .position_map [elem ] = self .elements
101
103
self .elements += 1
102
104
self ._bubble_up (elem )
103
105
104
- def extract_min (self ) -> Union [ int , str ] :
106
+ def extract_min (self ) -> T :
105
107
# Remove and return the element with lowest weight (highest priority)
106
108
if self .elements > 1 :
107
109
self ._swap_nodes (0 , self .elements - 1 )
@@ -113,7 +115,7 @@ def extract_min(self) -> Union[int, str]:
113
115
self ._bubble_down (bubble_down_elem )
114
116
return elem
115
117
116
- def update_key (self , elem : Union [ int , str ] , weight : int ) -> None :
118
+ def update_key (self , elem : T , weight : int ) -> None :
117
119
# Update the weight of the given key
118
120
position = self .position_map [elem ]
119
121
self .heap [position ] = (elem , weight )
@@ -127,7 +129,7 @@ def update_key(self, elem: Union[int, str], weight: int) -> None:
127
129
else :
128
130
self ._bubble_down (elem )
129
131
130
- def _bubble_up (self , elem : Union [ int , str ] ) -> None :
132
+ def _bubble_up (self , elem : T ) -> None :
131
133
# Place a node at the proper position (upward movement) [to be used internally
132
134
# only]
133
135
curr_pos = self .position_map [elem ]
@@ -141,7 +143,7 @@ def _bubble_up(self, elem: Union[int, str]) -> None:
141
143
return self ._bubble_up (elem )
142
144
return
143
145
144
- def _bubble_down (self , elem : Union [ int , str ] ) -> None :
146
+ def _bubble_down (self , elem : T ) -> None :
145
147
# Place a node at the proper position (downward movement) [to be used
146
148
# internally only]
147
149
curr_pos = self .position_map [elem ]
@@ -182,7 +184,7 @@ def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
182
184
self .position_map [node2_elem ] = node1_pos
183
185
184
186
185
- class GraphUndirectedWeighted :
187
+ class GraphUndirectedWeighted ( Generic [ T ]) :
186
188
"""
187
189
Graph Undirected Weighted Class
188
190
@@ -192,24 +194,22 @@ class GraphUndirectedWeighted:
192
194
"""
193
195
194
196
def __init__ (self ) -> None :
195
- self .connections = {}
196
- self .nodes = 0
197
+ self .connections : dict [ T , dict [ T , int ]] = {}
198
+ self .nodes : int = 0
197
199
198
200
def __repr__ (self ) -> str :
199
201
return str (self .connections )
200
202
201
203
def __len__ (self ) -> int :
202
204
return self .nodes
203
205
204
- def add_node (self , node : Union [ int , str ] ) -> None :
206
+ def add_node (self , node : T ) -> None :
205
207
# Add a node in the graph if it is not in the graph
206
208
if node not in self .connections :
207
209
self .connections [node ] = {}
208
210
self .nodes += 1
209
211
210
- def add_edge (
211
- self , node1 : Union [int , str ], node2 : Union [int , str ], weight : int
212
- ) -> None :
212
+ def add_edge (self , node1 : T , node2 : T , weight : int ) -> None :
213
213
# Add an edge between 2 nodes in the graph
214
214
self .add_node (node1 )
215
215
self .add_node (node2 )
@@ -218,8 +218,8 @@ def add_edge(
218
218
219
219
220
220
def prims_algo (
221
- graph : GraphUndirectedWeighted ,
222
- ) -> Tuple [ Dict [ str , int ], Dict [ str , Optional [str ]]]:
221
+ graph : GraphUndirectedWeighted [ T ] ,
222
+ ) -> tuple [ dict [ T , int ], dict [ T , Optional [T ]]]:
223
223
"""
224
224
>>> graph = GraphUndirectedWeighted()
225
225
@@ -239,10 +239,13 @@ def prims_algo(
239
239
13
240
240
"""
241
241
# prim's algorithm for minimum spanning tree
242
- dist = {node : maxsize for node in graph .connections }
243
- parent = {node : None for node in graph .connections }
244
- priority_queue = MinPriorityQueue ()
245
- [priority_queue .push (node , weight ) for node , weight in dist .items ()]
242
+ dist : dict [T , int ] = {node : maxsize for node in graph .connections }
243
+ parent : dict [T , Optional [T ]] = {node : None for node in graph .connections }
244
+
245
+ priority_queue : MinPriorityQueue [T ] = MinPriorityQueue ()
246
+ for node , weight in dist .items ():
247
+ priority_queue .push (node , weight )
248
+
246
249
if priority_queue .is_empty ():
247
250
return dist , parent
248
251
@@ -254,6 +257,7 @@ def prims_algo(
254
257
dist [neighbour ] = dist [node ] + graph .connections [node ][neighbour ]
255
258
priority_queue .update_key (neighbour , dist [neighbour ])
256
259
parent [neighbour ] = node
260
+
257
261
# running prim's algorithm
258
262
while not priority_queue .is_empty ():
259
263
node = priority_queue .extract_min ()
@@ -263,9 +267,3 @@ def prims_algo(
263
267
priority_queue .update_key (neighbour , dist [neighbour ])
264
268
parent [neighbour ] = node
265
269
return dist , parent
266
-
267
-
268
- if __name__ == "__main__" :
269
- from doctest import testmod
270
-
271
- testmod ()
0 commit comments