Skip to content

[mypy] Add/fix type annotations for binary trees in data structures #4085

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
99 changes: 62 additions & 37 deletions data_structures/binary_tree/binary_search_tree_recursive.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,21 +8,22 @@
python binary_search_tree_recursive.py
"""
import unittest
from typing import Iterator, Optional


class Node:
def __init__(self, label: int, parent):
def __init__(self, label: int, parent: Optional["Node"]) -> None:
self.label = label
self.parent = parent
self.left = None
self.right = None
self.left: Optional[Node] = None
self.right: Optional[Node] = None


class BinarySearchTree:
def __init__(self):
self.root = None
def __init__(self) -> None:
self.root: Optional[Node] = None

def empty(self):
def empty(self) -> None:
"""
Empties the tree

Expand All @@ -46,7 +47,7 @@ def is_empty(self) -> bool:
"""
return self.root is None

def put(self, label: int):
def put(self, label: int) -> None:
"""
Put a new node in the tree

Expand All @@ -65,7 +66,9 @@ def put(self, label: int):
"""
self.root = self._put(self.root, label)

def _put(self, node: Node, label: int, parent: Node = None) -> Node:
def _put(
self, node: Optional[Node], label: int, parent: Optional[Node] = None
) -> Node:
if node is None:
node = Node(label, parent)
else:
Expand Down Expand Up @@ -95,7 +98,7 @@ def search(self, label: int) -> Node:
"""
return self._search(self.root, label)

def _search(self, node: Node, label: int) -> Node:
def _search(self, node: Optional[Node], label: int) -> Node:
if node is None:
raise Exception(f"Node with label {label} does not exist")
else:
Expand All @@ -106,7 +109,7 @@ def _search(self, node: Node, label: int) -> Node:

return node

def remove(self, label: int):
def remove(self, label: int) -> None:
"""
Removes a node in the tree

Expand All @@ -122,22 +125,22 @@ def remove(self, label: int):
Exception: Node with label 3 does not exist
"""
node = self.search(label)
if not node.right and not node.left:
self._reassign_nodes(node, None)
elif not node.right and node.left:
self._reassign_nodes(node, node.left)
elif node.right and not node.left:
self._reassign_nodes(node, node.right)
else:
if node.right and node.left:
lowest_node = self._get_lowest_node(node.right)
lowest_node.left = node.left
lowest_node.right = node.right
node.left.parent = lowest_node
if node.right:
node.right.parent = lowest_node
self._reassign_nodes(node, lowest_node)
elif not node.right and node.left:
self._reassign_nodes(node, node.left)
elif node.right and not node.left:
self._reassign_nodes(node, node.right)
else:
self._reassign_nodes(node, None)

def _reassign_nodes(self, node: Node, new_children: Node):
def _reassign_nodes(self, node: Node, new_children: Optional[Node]) -> None:
if new_children:
new_children.parent = node.parent

Expand Down Expand Up @@ -192,7 +195,7 @@ def get_max_label(self) -> int:
>>> t.get_max_label()
10
"""
if self.is_empty():
if self.root is None:
raise Exception("Binary search tree is empty")

node = self.root
Expand All @@ -216,7 +219,7 @@ def get_min_label(self) -> int:
>>> t.get_min_label()
8
"""
if self.is_empty():
if self.root is None:
raise Exception("Binary search tree is empty")

node = self.root
Expand All @@ -225,7 +228,7 @@ def get_min_label(self) -> int:

return node.label

def inorder_traversal(self) -> list:
def inorder_traversal(self) -> Iterator[Node]:
"""
Return the inorder traversal of the tree

Expand All @@ -241,13 +244,13 @@ def inorder_traversal(self) -> list:
"""
return self._inorder_traversal(self.root)

def _inorder_traversal(self, node: Node) -> list:
def _inorder_traversal(self, node: Optional[Node]) -> Iterator[Node]:
if node is not None:
yield from self._inorder_traversal(node.left)
yield node
yield from self._inorder_traversal(node.right)

def preorder_traversal(self) -> list:
def preorder_traversal(self) -> Iterator[Node]:
"""
Return the preorder traversal of the tree

Expand All @@ -263,7 +266,7 @@ def preorder_traversal(self) -> list:
"""
return self._preorder_traversal(self.root)

def _preorder_traversal(self, node: Node) -> list:
def _preorder_traversal(self, node: Optional[Node]) -> Iterator[Node]:
if node is not None:
yield node
yield from self._preorder_traversal(node.left)
Expand All @@ -272,7 +275,7 @@ def _preorder_traversal(self, node: Node) -> list:

class BinarySearchTreeTest(unittest.TestCase):
@staticmethod
def _get_binary_search_tree():
def _get_binary_search_tree() -> BinarySearchTree:
r"""
8
/ \
Expand All @@ -298,14 +301,15 @@ def _get_binary_search_tree():

return t

def test_put(self):
def test_put(self) -> None:
t = BinarySearchTree()
assert t.is_empty()

t.put(8)
r"""
8
"""
assert t.root is not None
assert t.root.parent is None
assert t.root.label == 8

Expand All @@ -315,6 +319,7 @@ def test_put(self):
\
10
"""
assert t.root.right is not None
assert t.root.right.parent == t.root
assert t.root.right.label == 10

Expand All @@ -324,6 +329,7 @@ def test_put(self):
/ \
3 10
"""
assert t.root.left is not None
assert t.root.left.parent == t.root
assert t.root.left.label == 3

Expand All @@ -335,6 +341,7 @@ def test_put(self):
\
6
"""
assert t.root.left.right is not None
assert t.root.left.right.parent == t.root.left
assert t.root.left.right.label == 6

Expand All @@ -346,13 +353,14 @@ def test_put(self):
/ \
1 6
"""
assert t.root.left.left is not None
assert t.root.left.left.parent == t.root.left
assert t.root.left.left.label == 1

with self.assertRaises(Exception):
t.put(1)

def test_search(self):
def test_search(self) -> None:
t = self._get_binary_search_tree()

node = t.search(6)
Expand All @@ -364,7 +372,7 @@ def test_search(self):
with self.assertRaises(Exception):
t.search(2)

def test_remove(self):
def test_remove(self) -> None:
t = self._get_binary_search_tree()

t.remove(13)
Expand All @@ -379,6 +387,9 @@ def test_remove(self):
\
5
"""
assert t.root is not None
assert t.root.right is not None
assert t.root.right.right is not None
assert t.root.right.right.right is None
assert t.root.right.right.left is None

Expand All @@ -394,6 +405,9 @@ def test_remove(self):
\
5
"""
assert t.root.left is not None
assert t.root.left.right is not None
assert t.root.left.right.left is not None
assert t.root.left.right.right is None
assert t.root.left.right.left.label == 4

Expand All @@ -407,6 +421,8 @@ def test_remove(self):
\
5
"""
assert t.root.left.left is not None
assert t.root.left.right.right is not None
assert t.root.left.left.label == 1
assert t.root.left.right.label == 4
assert t.root.left.right.right.label == 5
Expand All @@ -422,6 +438,7 @@ def test_remove(self):
/ \ \
1 5 14
"""
assert t.root is not None
assert t.root.left.label == 4
assert t.root.left.right.label == 5
assert t.root.left.left.label == 1
Expand All @@ -437,13 +454,15 @@ def test_remove(self):
/ \
1 14
"""
assert t.root.left is not None
assert t.root.left.left is not None
assert t.root.left.label == 5
assert t.root.left.right is None
assert t.root.left.left.label == 1
assert t.root.left.parent == t.root
assert t.root.left.left.parent == t.root.left

def test_remove_2(self):
def test_remove_2(self) -> None:
t = self._get_binary_search_tree()

t.remove(3)
Expand All @@ -456,6 +475,12 @@ def test_remove_2(self):
/ \ /
5 7 13
"""
assert t.root is not None
assert t.root.left is not None
assert t.root.left.left is not None
assert t.root.left.right is not None
assert t.root.left.right.left is not None
assert t.root.left.right.right is not None
assert t.root.left.label == 4
assert t.root.left.right.label == 6
assert t.root.left.left.label == 1
Expand All @@ -466,25 +491,25 @@ def test_remove_2(self):
assert t.root.left.left.parent == t.root.left
assert t.root.left.right.left.parent == t.root.left.right

def test_empty(self):
def test_empty(self) -> None:
t = self._get_binary_search_tree()
t.empty()
assert t.root is None

def test_is_empty(self):
def test_is_empty(self) -> None:
t = self._get_binary_search_tree()
assert not t.is_empty()

t.empty()
assert t.is_empty()

def test_exists(self):
def test_exists(self) -> None:
t = self._get_binary_search_tree()

assert t.exists(6)
assert not t.exists(-1)

def test_get_max_label(self):
def test_get_max_label(self) -> None:
t = self._get_binary_search_tree()

assert t.get_max_label() == 14
Expand All @@ -493,7 +518,7 @@ def test_get_max_label(self):
with self.assertRaises(Exception):
t.get_max_label()

def test_get_min_label(self):
def test_get_min_label(self) -> None:
t = self._get_binary_search_tree()

assert t.get_min_label() == 1
Expand All @@ -502,20 +527,20 @@ def test_get_min_label(self):
with self.assertRaises(Exception):
t.get_min_label()

def test_inorder_traversal(self):
def test_inorder_traversal(self) -> None:
t = self._get_binary_search_tree()

inorder_traversal_nodes = [i.label for i in t.inorder_traversal()]
assert inorder_traversal_nodes == [1, 3, 4, 5, 6, 7, 8, 10, 13, 14]

def test_preorder_traversal(self):
def test_preorder_traversal(self) -> None:
t = self._get_binary_search_tree()

preorder_traversal_nodes = [i.label for i in t.preorder_traversal()]
assert preorder_traversal_nodes == [8, 3, 1, 6, 4, 5, 7, 10, 14, 13]


def binary_search_tree_example():
def binary_search_tree_example() -> None:
r"""
Example
8
Expand Down
9 changes: 5 additions & 4 deletions data_structures/binary_tree/lazy_segment_tree.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from __future__ import annotations

import math
from typing import List, Union


class SegmentTree:
Expand Down Expand Up @@ -37,7 +38,7 @@ def right(self, idx: int) -> int:
return idx * 2 + 1

def build(
self, idx: int, left_element: int, right_element: int, A: list[int]
self, idx: int, left_element: int, right_element: int, A: List[int]
) -> None:
if left_element == right_element:
self.segment_tree[idx] = A[left_element - 1]
Expand Down Expand Up @@ -88,7 +89,7 @@ def update(
# query with O(lg n)
def query(
self, idx: int, left_element: int, right_element: int, a: int, b: int
) -> int:
) -> Union[int, float]:
"""
query(1, 1, size, a, b) for query max of [a,b]
>>> A = [1, 2, -4, 7, 3, -5, 6, 11, -20, 9, 14, 15, 5, 2, -8]
Expand Down Expand Up @@ -118,8 +119,8 @@ def query(
q2 = self.query(self.right(idx), mid + 1, right_element, a, b)
return max(q1, q2)

def __str__(self) -> None:
return [self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)]
def __str__(self) -> str:
return str([self.query(1, 1, self.size, i, i) for i in range(1, self.size + 1)])


if __name__ == "__main__":
Expand Down
Loading