1
1
from __future__ import annotations
2
2
3
+ from typing import Generic , TypeVar
3
4
4
- class DisjointSetTreeNode :
5
+ T = TypeVar ("T" )
6
+
7
+
8
+ class DisjointSetTreeNode (Generic [T ]):
5
9
# Disjoint Set Node to store the parent and rank
6
- def __init__ (self , key : int ) -> None :
10
+ def __init__ (self , key : T ) -> None :
7
11
self .key = key
8
12
self .parent = self
9
13
self .rank = 0
10
14
11
15
12
- class DisjointSetTree :
16
+ class DisjointSetTree ( Generic [ T ]) :
13
17
# Disjoint Set DataStructure
14
- def __init__ (self ):
18
+ def __init__ (self ) -> None :
15
19
# map from node name to the node object
16
- self .map = {}
20
+ self .map : dict [ T , DisjointSetTreeNode [ T ]] = {}
17
21
18
- def make_set (self , x : int ) -> None :
22
+ def make_set (self , x : T ) -> None :
19
23
# create a new set with x as its member
20
24
self .map [x ] = DisjointSetTreeNode (x )
21
25
22
- def find_set (self , x : int ) -> DisjointSetTreeNode :
26
+ def find_set (self , x : T ) -> DisjointSetTreeNode [ T ] :
23
27
# find the set x belongs to (with path-compression)
24
28
elem_ref = self .map [x ]
25
29
if elem_ref != elem_ref .parent :
26
30
elem_ref .parent = self .find_set (elem_ref .parent .key )
27
31
return elem_ref .parent
28
32
29
- def link (self , x : int , y : int ) -> None :
33
+ def link (self , x : DisjointSetTreeNode [ T ] , y : DisjointSetTreeNode [ T ] ) -> None :
30
34
# helper function for union operation
31
35
if x .rank > y .rank :
32
36
y .parent = x
@@ -35,44 +39,53 @@ def link(self, x: int, y: int) -> None:
35
39
if x .rank == y .rank :
36
40
y .rank += 1
37
41
38
- def union (self , x : int , y : int ) -> None :
42
+ def union (self , x : T , y : T ) -> None :
39
43
# merge 2 disjoint sets
40
44
self .link (self .find_set (x ), self .find_set (y ))
41
45
42
46
43
- class GraphUndirectedWeighted :
44
- def __init__ (self ):
47
+ class GraphUndirectedWeighted ( Generic [ T ]) :
48
+ def __init__ (self ) -> None :
45
49
# connections: map from the node to the neighbouring nodes (with weights)
46
- self .connections = {}
50
+ self .connections : dict [ T , dict [ T , int ]] = {}
47
51
48
- def add_node (self , node : int ) -> None :
52
+ def add_node (self , node : T ) -> None :
49
53
# add a node ONLY if its not present in the graph
50
54
if node not in self .connections :
51
55
self .connections [node ] = {}
52
56
53
- def add_edge (self , node1 : int , node2 : int , weight : int ) -> None :
57
+ def add_edge (self , node1 : T , node2 : T , weight : int ) -> None :
54
58
# add an edge with the given weight
55
59
self .add_node (node1 )
56
60
self .add_node (node2 )
57
61
self .connections [node1 ][node2 ] = weight
58
62
self .connections [node2 ][node1 ] = weight
59
63
60
- def kruskal (self ) -> GraphUndirectedWeighted :
64
+ def kruskal (self ) -> GraphUndirectedWeighted [ T ] :
61
65
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
62
66
"""
63
67
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
64
68
65
69
Example:
66
-
67
- >>> graph = GraphUndirectedWeighted()
68
- >>> graph.add_edge(1, 2, 1)
69
- >>> graph.add_edge(2, 3, 2)
70
- >>> graph.add_edge(3, 4, 1)
71
- >>> graph.add_edge(3, 5, 100) # Removed in MST
72
- >>> graph.add_edge(4, 5, 5)
73
- >>> assert 5 in graph.connections[3]
74
- >>> mst = graph.kruskal()
70
+ >>> g1 = GraphUndirectedWeighted[int]()
71
+ >>> g1.add_edge(1, 2, 1)
72
+ >>> g1.add_edge(2, 3, 2)
73
+ >>> g1.add_edge(3, 4, 1)
74
+ >>> g1.add_edge(3, 5, 100) # Removed in MST
75
+ >>> g1.add_edge(4, 5, 5)
76
+ >>> assert 5 in g1.connections[3]
77
+ >>> mst = g1.kruskal()
75
78
>>> assert 5 not in mst.connections[3]
79
+
80
+ >>> g2 = GraphUndirectedWeighted[str]()
81
+ >>> g2.add_edge('A', 'B', 1)
82
+ >>> g2.add_edge('B', 'C', 2)
83
+ >>> g2.add_edge('C', 'D', 1)
84
+ >>> g2.add_edge('C', 'E', 100) # Removed in MST
85
+ >>> g2.add_edge('D', 'E', 5)
86
+ >>> assert 'E' in g2.connections["C"]
87
+ >>> mst = g2.kruskal()
88
+ >>> assert 'E' not in mst.connections['C']
76
89
"""
77
90
78
91
# getting the edges in ascending order of weights
@@ -84,26 +97,23 @@ def kruskal(self) -> GraphUndirectedWeighted:
84
97
seen .add ((end , start ))
85
98
edges .append ((start , end , self .connections [start ][end ]))
86
99
edges .sort (key = lambda x : x [2 ])
100
+
87
101
# creating the disjoint set
88
- disjoint_set = DisjointSetTree ()
89
- [disjoint_set .make_set (node ) for node in self .connections ]
102
+ disjoint_set = DisjointSetTree [T ]()
103
+ for node in self .connections :
104
+ disjoint_set .make_set (node )
105
+
90
106
# MST generation
91
107
num_edges = 0
92
108
index = 0
93
- graph = GraphUndirectedWeighted ()
109
+ graph = GraphUndirectedWeighted [ T ] ()
94
110
while num_edges < len (self .connections ) - 1 :
95
111
u , v , w = edges [index ]
96
112
index += 1
97
- parentu = disjoint_set .find_set (u )
98
- parentv = disjoint_set .find_set (v )
99
- if parentu != parentv :
113
+ parent_u = disjoint_set .find_set (u )
114
+ parent_v = disjoint_set .find_set (v )
115
+ if parent_u != parent_v :
100
116
num_edges += 1
101
117
graph .add_edge (u , v , w )
102
118
disjoint_set .union (u , v )
103
119
return graph
104
-
105
-
106
- if __name__ == "__main__" :
107
- import doctest
108
-
109
- doctest .testmod ()
0 commit comments