Skip to content

Fix mypy errors at kruskal_2 #4528

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 58 additions & 46 deletions graphs/minimum_spanning_tree_kruskal2.py
Original file line number Diff line number Diff line change
@@ -1,78 +1,93 @@
from __future__ import annotations

from typing import Generic, TypeVar

class DisjointSetTreeNode:
T = TypeVar("T")


class DisjointSetTreeNode(Generic[T]):
# Disjoint Set Node to store the parent and rank
def __init__(self, key: int) -> None:
self.key = key
def __init__(self, data: T) -> None:
self.data = data
self.parent = self
self.rank = 0


class DisjointSetTree:
class DisjointSetTree(Generic[T]):
# Disjoint Set DataStructure
def __init__(self):
def __init__(self) -> None:
# map from node name to the node object
self.map = {}
self.map: dict[T, DisjointSetTreeNode[T]] = {}

def make_set(self, x: int) -> None:
def make_set(self, data: T) -> None:
# create a new set with x as its member
self.map[x] = DisjointSetTreeNode(x)
self.map[data] = DisjointSetTreeNode(data)

def find_set(self, x: int) -> DisjointSetTreeNode:
def find_set(self, data: T) -> DisjointSetTreeNode[T]:
# find the set x belongs to (with path-compression)
elem_ref = self.map[x]
elem_ref = self.map[data]
if elem_ref != elem_ref.parent:
elem_ref.parent = self.find_set(elem_ref.parent.key)
elem_ref.parent = self.find_set(elem_ref.parent.data)
return elem_ref.parent

def link(self, x: int, y: int) -> None:
def link(
self, node1: DisjointSetTreeNode[T], node2: DisjointSetTreeNode[T]
) -> None:
# helper function for union operation
if x.rank > y.rank:
y.parent = x
if node1.rank > node2.rank:
node2.parent = node1
else:
x.parent = y
if x.rank == y.rank:
y.rank += 1
node1.parent = node2
if node1.rank == node2.rank:
node2.rank += 1

def union(self, x: int, y: int) -> None:
def union(self, data1: T, data2: T) -> None:
# merge 2 disjoint sets
self.link(self.find_set(x), self.find_set(y))
self.link(self.find_set(data1), self.find_set(data2))


class GraphUndirectedWeighted:
def __init__(self):
class GraphUndirectedWeighted(Generic[T]):
def __init__(self) -> None:
# connections: map from the node to the neighbouring nodes (with weights)
self.connections = {}
self.connections: dict[T, dict[T, int]] = {}

def add_node(self, node: int) -> None:
def add_node(self, node: T) -> None:
# add a node ONLY if its not present in the graph
if node not in self.connections:
self.connections[node] = {}

def add_edge(self, node1: int, node2: int, weight: int) -> None:
def add_edge(self, node1: T, node2: T, weight: int) -> None:
# add an edge with the given weight
self.add_node(node1)
self.add_node(node2)
self.connections[node1][node2] = weight
self.connections[node2][node1] = weight

def kruskal(self) -> GraphUndirectedWeighted:
def kruskal(self) -> GraphUndirectedWeighted[T]:
# Kruskal's Algorithm to generate a Minimum Spanning Tree (MST) of a graph
"""
Details: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm

Example:

>>> graph = GraphUndirectedWeighted()
>>> graph.add_edge(1, 2, 1)
>>> graph.add_edge(2, 3, 2)
>>> graph.add_edge(3, 4, 1)
>>> graph.add_edge(3, 5, 100) # Removed in MST
>>> graph.add_edge(4, 5, 5)
>>> assert 5 in graph.connections[3]
>>> mst = graph.kruskal()
>>> g1 = GraphUndirectedWeighted[int]()
>>> g1.add_edge(1, 2, 1)
>>> g1.add_edge(2, 3, 2)
>>> g1.add_edge(3, 4, 1)
>>> g1.add_edge(3, 5, 100) # Removed in MST
>>> g1.add_edge(4, 5, 5)
>>> assert 5 in g1.connections[3]
>>> mst = g1.kruskal()
>>> assert 5 not in mst.connections[3]

>>> g2 = GraphUndirectedWeighted[str]()
>>> g2.add_edge('A', 'B', 1)
>>> g2.add_edge('B', 'C', 2)
>>> g2.add_edge('C', 'D', 1)
>>> g2.add_edge('C', 'E', 100) # Removed in MST
>>> g2.add_edge('D', 'E', 5)
>>> assert 'E' in g2.connections["C"]
>>> mst = g2.kruskal()
>>> assert 'E' not in mst.connections['C']
"""

# getting the edges in ascending order of weights
Expand All @@ -84,26 +99,23 @@ def kruskal(self) -> GraphUndirectedWeighted:
seen.add((end, start))
edges.append((start, end, self.connections[start][end]))
edges.sort(key=lambda x: x[2])

# creating the disjoint set
disjoint_set = DisjointSetTree()
[disjoint_set.make_set(node) for node in self.connections]
disjoint_set = DisjointSetTree[T]()
for node in self.connections:
disjoint_set.make_set(node)

# MST generation
num_edges = 0
index = 0
graph = GraphUndirectedWeighted()
graph = GraphUndirectedWeighted[T]()
while num_edges < len(self.connections) - 1:
u, v, w = edges[index]
index += 1
parentu = disjoint_set.find_set(u)
parentv = disjoint_set.find_set(v)
if parentu != parentv:
parent_u = disjoint_set.find_set(u)
parent_v = disjoint_set.find_set(v)
if parent_u != parent_v:
num_edges += 1
graph.add_edge(u, v, w)
disjoint_set.union(u, v)
return graph


if __name__ == "__main__":
import doctest

doctest.testmod()