1
1
from __future__ import annotations
2
2
3
+ from typing import Generic , TypeVar
3
4
4
- class DisjointSetTreeNode :
5
+ T = TypeVar ("T" )
6
+
7
+
8
+ class DisjointSetTreeNode (Generic [T ]):
5
9
# Disjoint Set Node to store the parent and rank
6
- def __init__ (self , key : int ) -> None :
7
- self .key = key
10
+ def __init__ (self , data : T ) -> None :
11
+ self .data = data
8
12
self .parent = self
9
13
self .rank = 0
10
14
11
15
12
- class DisjointSetTree :
16
+ class DisjointSetTree ( Generic [ T ]) :
13
17
# Disjoint Set DataStructure
14
- def __init__ (self ):
18
+ def __init__ (self ) -> None :
15
19
# map from node name to the node object
16
- self .map = {}
20
+ self .map : dict [ T , DisjointSetTreeNode [ T ]] = {}
17
21
18
- def make_set (self , x : int ) -> None :
22
+ def make_set (self , data : T ) -> None :
19
23
# create a new set with x as its member
20
- self .map [x ] = DisjointSetTreeNode (x )
24
+ self .map [data ] = DisjointSetTreeNode (data )
21
25
22
- def find_set (self , x : int ) -> DisjointSetTreeNode :
26
+ def find_set (self , data : T ) -> DisjointSetTreeNode [ T ] :
23
27
# find the set x belongs to (with path-compression)
24
- elem_ref = self .map [x ]
28
+ elem_ref = self .map [data ]
25
29
if elem_ref != elem_ref .parent :
26
- elem_ref .parent = self .find_set (elem_ref .parent .key )
30
+ elem_ref .parent = self .find_set (elem_ref .parent .data )
27
31
return elem_ref .parent
28
32
29
- def link (self , x : int , y : int ) -> None :
33
+ def link (
34
+ self , node1 : DisjointSetTreeNode [T ], node2 : DisjointSetTreeNode [T ]
35
+ ) -> None :
30
36
# helper function for union operation
31
- if x .rank > y .rank :
32
- y .parent = x
37
+ if node1 .rank > node2 .rank :
38
+ node2 .parent = node1
33
39
else :
34
- x .parent = y
35
- if x .rank == y .rank :
36
- y .rank += 1
40
+ node1 .parent = node2
41
+ if node1 .rank == node2 .rank :
42
+ node2 .rank += 1
37
43
38
- def union (self , x : int , y : int ) -> None :
44
+ def union (self , data1 : T , data2 : T ) -> None :
39
45
# merge 2 disjoint sets
40
- self .link (self .find_set (x ), self .find_set (y ))
46
+ self .link (self .find_set (data1 ), self .find_set (data2 ))
41
47
42
48
43
- class GraphUndirectedWeighted :
44
- def __init__ (self ):
49
+ class GraphUndirectedWeighted ( Generic [ T ]) :
50
+ def __init__ (self ) -> None :
45
51
# connections: map from the node to the neighbouring nodes (with weights)
46
- self .connections = {}
52
+ self .connections : dict [ T , dict [ T , int ]] = {}
47
53
48
- def add_node (self , node : int ) -> None :
54
+ def add_node (self , node : T ) -> None :
49
55
# add a node ONLY if its not present in the graph
50
56
if node not in self .connections :
51
57
self .connections [node ] = {}
52
58
53
- def add_edge (self , node1 : int , node2 : int , weight : int ) -> None :
59
+ def add_edge (self , node1 : T , node2 : T , weight : int ) -> None :
54
60
# add an edge with the given weight
55
61
self .add_node (node1 )
56
62
self .add_node (node2 )
57
63
self .connections [node1 ][node2 ] = weight
58
64
self .connections [node2 ][node1 ] = weight
59
65
60
- def kruskal (self ) -> GraphUndirectedWeighted :
66
+ def kruskal (self ) -> GraphUndirectedWeighted [ T ] :
61
67
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
62
68
"""
63
69
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
64
70
65
71
Example:
66
-
67
- >>> graph = GraphUndirectedWeighted()
68
- >>> graph.add_edge(1, 2, 1)
69
- >>> graph.add_edge(2, 3, 2)
70
- >>> graph.add_edge(3, 4, 1)
71
- >>> graph.add_edge(3, 5, 100) # Removed in MST
72
- >>> graph.add_edge(4, 5, 5)
73
- >>> assert 5 in graph.connections[3]
74
- >>> mst = graph.kruskal()
72
+ >>> g1 = GraphUndirectedWeighted[int]()
73
+ >>> g1.add_edge(1, 2, 1)
74
+ >>> g1.add_edge(2, 3, 2)
75
+ >>> g1.add_edge(3, 4, 1)
76
+ >>> g1.add_edge(3, 5, 100) # Removed in MST
77
+ >>> g1.add_edge(4, 5, 5)
78
+ >>> assert 5 in g1.connections[3]
79
+ >>> mst = g1.kruskal()
75
80
>>> assert 5 not in mst.connections[3]
81
+
82
+ >>> g2 = GraphUndirectedWeighted[str]()
83
+ >>> g2.add_edge('A', 'B', 1)
84
+ >>> g2.add_edge('B', 'C', 2)
85
+ >>> g2.add_edge('C', 'D', 1)
86
+ >>> g2.add_edge('C', 'E', 100) # Removed in MST
87
+ >>> g2.add_edge('D', 'E', 5)
88
+ >>> assert 'E' in g2.connections["C"]
89
+ >>> mst = g2.kruskal()
90
+ >>> assert 'E' not in mst.connections['C']
76
91
"""
77
92
78
93
# getting the edges in ascending order of weights
@@ -84,26 +99,23 @@ def kruskal(self) -> GraphUndirectedWeighted:
84
99
seen .add ((end , start ))
85
100
edges .append ((start , end , self .connections [start ][end ]))
86
101
edges .sort (key = lambda x : x [2 ])
102
+
87
103
# creating the disjoint set
88
- disjoint_set = DisjointSetTree ()
89
- [disjoint_set .make_set (node ) for node in self .connections ]
104
+ disjoint_set = DisjointSetTree [T ]()
105
+ for node in self .connections :
106
+ disjoint_set .make_set (node )
107
+
90
108
# MST generation
91
109
num_edges = 0
92
110
index = 0
93
- graph = GraphUndirectedWeighted ()
111
+ graph = GraphUndirectedWeighted [ T ] ()
94
112
while num_edges < len (self .connections ) - 1 :
95
113
u , v , w = edges [index ]
96
114
index += 1
97
- parentu = disjoint_set .find_set (u )
98
- parentv = disjoint_set .find_set (v )
99
- if parentu != parentv :
115
+ parent_u = disjoint_set .find_set (u )
116
+ parent_v = disjoint_set .find_set (v )
117
+ if parent_u != parent_v :
100
118
num_edges += 1
101
119
graph .add_edge (u , v , w )
102
120
disjoint_set .union (u , v )
103
121
return graph
104
-
105
-
106
- if __name__ == "__main__" :
107
- import doctest
108
-
109
- doctest .testmod ()
0 commit comments