Skip to content

Commit 256c319

Browse files
authored
Fix mypy errors at kruskal_2 (TheAlgorithms#4528)
1 parent 4412eaf commit 256c319

File tree

1 file changed

+58
-46
lines changed

1 file changed

+58
-46
lines changed

graphs/minimum_spanning_tree_kruskal2.py

+58-46
Original file line numberDiff line numberDiff line change
@@ -1,78 +1,93 @@
11
from __future__ import annotations
22

3+
from typing import Generic, TypeVar
34

4-
class DisjointSetTreeNode:
5+
T = TypeVar("T")
6+
7+
8+
class DisjointSetTreeNode(Generic[T]):
59
# Disjoint Set Node to store the parent and rank
6-
def __init__(self, key: int) -> None:
7-
self.key = key
10+
def __init__(self, data: T) -> None:
11+
self.data = data
812
self.parent = self
913
self.rank = 0
1014

1115

12-
class DisjointSetTree:
16+
class DisjointSetTree(Generic[T]):
1317
# Disjoint Set DataStructure
14-
def __init__(self):
18+
def __init__(self) -> None:
1519
# map from node name to the node object
16-
self.map = {}
20+
self.map: dict[T, DisjointSetTreeNode[T]] = {}
1721

18-
def make_set(self, x: int) -> None:
22+
def make_set(self, data: T) -> None:
1923
# create a new set with x as its member
20-
self.map[x] = DisjointSetTreeNode(x)
24+
self.map[data] = DisjointSetTreeNode(data)
2125

22-
def find_set(self, x: int) -> DisjointSetTreeNode:
26+
def find_set(self, data: T) -> DisjointSetTreeNode[T]:
2327
# find the set x belongs to (with path-compression)
24-
elem_ref = self.map[x]
28+
elem_ref = self.map[data]
2529
if elem_ref != elem_ref.parent:
26-
elem_ref.parent = self.find_set(elem_ref.parent.key)
30+
elem_ref.parent = self.find_set(elem_ref.parent.data)
2731
return elem_ref.parent
2832

29-
def link(self, x: int, y: int) -> None:
33+
def link(
34+
self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T]
35+
) -> None:
3036
# helper function for union operation
31-
if x.rank > y.rank:
32-
y.parent = x
37+
if node1.rank > node2.rank:
38+
node2.parent = node1
3339
else:
34-
x.parent = y
35-
if x.rank == y.rank:
36-
y.rank += 1
40+
node1.parent = node2
41+
if node1.rank == node2.rank:
42+
node2.rank += 1
3743

38-
def union(self, x: int, y: int) -> None:
44+
def union(self, data1: T, data2: T) -> None:
3945
# merge 2 disjoint sets
40-
self.link(self.find_set(x), self.find_set(y))
46+
self.link(self.find_set(data1), self.find_set(data2))
4147

4248

43-
class GraphUndirectedWeighted:
44-
def __init__(self):
49+
class GraphUndirectedWeighted(Generic[T]):
50+
def __init__(self) -> None:
4551
# connections: map from the node to the neighbouring nodes (with weights)
46-
self.connections = {}
52+
self.connections: dict[T, dict[T, int]] = {}
4753

48-
def add_node(self, node: int) -> None:
54+
def add_node(self, node: T) -> None:
4955
# add a node ONLY if its not present in the graph
5056
if node not in self.connections:
5157
self.connections[node] = {}
5258

53-
def add_edge(self, node1: int, node2: int, weight: int) -> None:
59+
def add_edge(self, node1: T, node2: T, weight: int) -> None:
5460
# add an edge with the given weight
5561
self.add_node(node1)
5662
self.add_node(node2)
5763
self.connections[node1][node2] = weight
5864
self.connections[node2][node1] = weight
5965

60-
def kruskal(self) -> GraphUndirectedWeighted:
66+
def kruskal(self) -> GraphUndirectedWeighted[T]:
6167
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
6268
"""
6369
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
6470
6571
Example:
66-
67-
>>> graph = GraphUndirectedWeighted()
68-
>>> graph.add_edge(1, 2, 1)
69-
>>> graph.add_edge(2, 3, 2)
70-
>>> graph.add_edge(3, 4, 1)
71-
>>> graph.add_edge(3, 5, 100) # Removed in MST
72-
>>> graph.add_edge(4, 5, 5)
73-
>>> assert 5 in graph.connections[3]
74-
>>> mst = graph.kruskal()
72+
>>> g1 = GraphUndirectedWeighted[int]()
73+
>>> g1.add_edge(1, 2, 1)
74+
>>> g1.add_edge(2, 3, 2)
75+
>>> g1.add_edge(3, 4, 1)
76+
>>> g1.add_edge(3, 5, 100) # Removed in MST
77+
>>> g1.add_edge(4, 5, 5)
78+
>>> assert 5 in g1.connections[3]
79+
>>> mst = g1.kruskal()
7580
>>> assert 5 not in mst.connections[3]
81+
82+
>>> g2 = GraphUndirectedWeighted[str]()
83+
>>> g2.add_edge('A', 'B', 1)
84+
>>> g2.add_edge('B', 'C', 2)
85+
>>> g2.add_edge('C', 'D', 1)
86+
>>> g2.add_edge('C', 'E', 100) # Removed in MST
87+
>>> g2.add_edge('D', 'E', 5)
88+
>>> assert 'E' in g2.connections["C"]
89+
>>> mst = g2.kruskal()
90+
>>> assert 'E' not in mst.connections['C']
7691
"""
7792

7893
# getting the edges in ascending order of weights
@@ -84,26 +99,23 @@ def kruskal(self) -> GraphUndirectedWeighted:
8499
seen.add((end, start))
85100
edges.append((start, end, self.connections[start][end]))
86101
edges.sort(key=lambda x: x[2])
102+
87103
# creating the disjoint set
88-
disjoint_set = DisjointSetTree()
89-
[disjoint_set.make_set(node) for node in self.connections]
104+
disjoint_set = DisjointSetTree[T]()
105+
for node in self.connections:
106+
disjoint_set.make_set(node)
107+
90108
# MST generation
91109
num_edges = 0
92110
index = 0
93-
graph = GraphUndirectedWeighted()
111+
graph = GraphUndirectedWeighted[T]()
94112
while num_edges < len(self.connections) - 1:
95113
u, v, w = edges[index]
96114
index += 1
97-
parentu = disjoint_set.find_set(u)
98-
parentv = disjoint_set.find_set(v)
99-
if parentu != parentv:
115+
parent_u = disjoint_set.find_set(u)
116+
parent_v = disjoint_set.find_set(v)
117+
if parent_u != parent_v:
100118
num_edges += 1
101119
graph.add_edge(u, v, w)
102120
disjoint_set.union(u, v)
103121
return graph
104-
105-
106-
if __name__ == "__main__":
107-
import doctest
108-
109-
doctest.testmod()

0 commit comments

Comments
 (0)