Skip to content

Fix mypy errors at prims_algo_2 #4527

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 27 additions & 29 deletions graphs/minimum_spanning_tree_prims2.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,9 @@
"""

from sys import maxsize
from typing import Dict, Optional, Tuple, Union
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


def get_parent_position(position: int) -> int:
Expand Down Expand Up @@ -43,7 +45,7 @@ def get_child_right_position(position: int) -> int:
return (2 * position) + 2


class MinPriorityQueue:
class MinPriorityQueue(Generic[T]):
"""
Minimum Priority Queue Class

Expand Down Expand Up @@ -80,9 +82,9 @@ class MinPriorityQueue:
"""

def __init__(self) -> None:
self.heap = []
self.position_map = {}
self.elements = 0
self.heap: list[tuple[T, int]] = []
self.position_map: dict[T, int] = {}
self.elements: int = 0

def __len__(self) -> int:
return self.elements
Expand All @@ -94,14 +96,14 @@ def is_empty(self) -> bool:
# Check if the priority queue is empty
return self.elements == 0

def push(self, elem: Union[int, str], weight: int) -> None:
def push(self, elem: T, weight: int) -> None:
# Add an element with given priority to the queue
self.heap.append((elem, weight))
self.position_map[elem] = self.elements
self.elements += 1
self._bubble_up(elem)

def extract_min(self) -> Union[int, str]:
def extract_min(self) -> T:
# Remove and return the element with lowest weight (highest priority)
if self.elements > 1:
self._swap_nodes(0, self.elements - 1)
Expand All @@ -113,7 +115,7 @@ def extract_min(self) -> Union[int, str]:
self._bubble_down(bubble_down_elem)
return elem

def update_key(self, elem: Union[int, str], weight: int) -> None:
def update_key(self, elem: T, weight: int) -> None:
# Update the weight of the given key
position = self.position_map[elem]
self.heap[position] = (elem, weight)
Expand All @@ -127,7 +129,7 @@ def update_key(self, elem: Union[int, str], weight: int) -> None:
else:
self._bubble_down(elem)

def _bubble_up(self, elem: Union[int, str]) -> None:
def _bubble_up(self, elem: T) -> None:
# Place a node at the proper position (upward movement) [to be used internally
# only]
curr_pos = self.position_map[elem]
Expand All @@ -141,7 +143,7 @@ def _bubble_up(self, elem: Union[int, str]) -> None:
return self._bubble_up(elem)
return

def _bubble_down(self, elem: Union[int, str]) -> None:
def _bubble_down(self, elem: T) -> None:
# Place a node at the proper position (downward movement) [to be used
# internally only]
curr_pos = self.position_map[elem]
Expand Down Expand Up @@ -182,7 +184,7 @@ def _swap_nodes(self, node1_pos: int, node2_pos: int) -> None:
self.position_map[node2_elem] = node1_pos


class GraphUndirectedWeighted:
class GraphUndirectedWeighted(Generic[T]):
"""
Graph Undirected Weighted Class

Expand All @@ -192,24 +194,22 @@ class GraphUndirectedWeighted:
"""

def __init__(self) -> None:
self.connections = {}
self.nodes = 0
self.connections: dict[T, dict[T, int]] = {}
self.nodes: int = 0

def __repr__(self) -> str:
return str(self.connections)

def __len__(self) -> int:
return self.nodes

def add_node(self, node: Union[int, str]) -> None:
def add_node(self, node: T) -> None:
# Add a node in the graph if it is not in the graph
if node not in self.connections:
self.connections[node] = {}
self.nodes += 1

def add_edge(
self, node1: Union[int, str], node2: Union[int, str], weight: int
) -> None:
def add_edge(self, node1: T, node2: T, weight: int) -> None:
# Add an edge between 2 nodes in the graph
self.add_node(node1)
self.add_node(node2)
Expand All @@ -218,8 +218,8 @@ def add_edge(


def prims_algo(
graph: GraphUndirectedWeighted,
) -> Tuple[Dict[str, int], Dict[str, Optional[str]]]:
graph: GraphUndirectedWeighted[T],
) -> tuple[dict[T, int], dict[T, Optional[T]]]:
"""
>>> graph = GraphUndirectedWeighted()

Expand All @@ -239,10 +239,13 @@ def prims_algo(
13
"""
# prim's algorithm for minimum spanning tree
dist = {node: maxsize for node in graph.connections}
parent = {node: None for node in graph.connections}
priority_queue = MinPriorityQueue()
[priority_queue.push(node, weight) for node, weight in dist.items()]
dist: dict[T, int] = {node: maxsize for node in graph.connections}
parent: dict[T, Optional[T]] = {node: None for node in graph.connections}

priority_queue: MinPriorityQueue[T] = MinPriorityQueue()
for node, weight in dist.items():
priority_queue.push(node, weight)

if priority_queue.is_empty():
return dist, parent

Expand All @@ -254,6 +257,7 @@ def prims_algo(
dist[neighbour] = dist[node] + graph.connections[node][neighbour]
priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node

# running prim's algorithm
while not priority_queue.is_empty():
node = priority_queue.extract_min()
Expand All @@ -263,9 +267,3 @@ def prims_algo(
priority_queue.update_key(neighbour, dist[neighbour])
parent[neighbour] = node
return dist, parent


if __name__ == "__main__":
from doctest import testmod

testmod()