Skip to content

Commit 472f63e

Browse files
authored
Adding type hints to RedBlackTree (TheAlgorithms#2371)
* redblacktree type hints * fixed type hints to pass flake8
1 parent 8c191f1 commit 472f63e

File tree

1 file changed

+47
-39
lines changed

1 file changed

+47
-39
lines changed

data_structures/binary_tree/red_black_tree.py

+47-39
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
python/black : true
33
flake8 : passed
44
"""
5+
from typing import Iterator, Optional
56

67

78
class RedBlackTree:
@@ -18,7 +19,14 @@ class RedBlackTree:
1819
terms of the size of the tree.
1920
"""
2021

21-
def __init__(self, label=None, color=0, parent=None, left=None, right=None):
22+
def __init__(
23+
self,
24+
label: Optional[int] = None,
25+
color: int = 0,
26+
parent: Optional["RedBlackTree"] = None,
27+
left: Optional["RedBlackTree"] = None,
28+
right: Optional["RedBlackTree"] = None,
29+
) -> None:
2230
"""Initialize a new Red-Black Tree node with the given values:
2331
label: The value associated with this node
2432
color: 0 if black, 1 if red
@@ -34,7 +42,7 @@ def __init__(self, label=None, color=0, parent=None, left=None, right=None):
3442

3543
# Here are functions which are specific to red-black trees
3644

37-
def rotate_left(self):
45+
def rotate_left(self) -> "RedBlackTree":
3846
"""Rotate the subtree rooted at this node to the left and
3947
returns the new root to this subtree.
4048
Performing one rotation can be done in O(1).
@@ -54,7 +62,7 @@ def rotate_left(self):
5462
right.parent = parent
5563
return right
5664

57-
def rotate_right(self):
65+
def rotate_right(self) -> "RedBlackTree":
5866
"""Rotate the subtree rooted at this node to the right and
5967
returns the new root to this subtree.
6068
Performing one rotation can be done in O(1).
@@ -74,7 +82,7 @@ def rotate_right(self):
7482
left.parent = parent
7583
return left
7684

77-
def insert(self, label):
85+
def insert(self, label: int) -> "RedBlackTree":
7886
"""Inserts label into the subtree rooted at self, performs any
7987
rotations necessary to maintain balance, and then returns the
8088
new root to this subtree (likely self).
@@ -100,7 +108,7 @@ def insert(self, label):
100108
self.right._insert_repair()
101109
return self.parent or self
102110

103-
def _insert_repair(self):
111+
def _insert_repair(self) -> None:
104112
"""Repair the coloring from inserting into a tree."""
105113
if self.parent is None:
106114
# This node is the root, so it just needs to be black
@@ -131,7 +139,7 @@ def _insert_repair(self):
131139
self.grandparent.color = 1
132140
self.grandparent._insert_repair()
133141

134-
def remove(self, label):
142+
def remove(self, label: int) -> "RedBlackTree":
135143
"""Remove label from this tree."""
136144
if self.label == label:
137145
if self.left and self.right:
@@ -186,7 +194,7 @@ def remove(self, label):
186194
self.right.remove(label)
187195
return self.parent or self
188196

189-
def _remove_repair(self):
197+
def _remove_repair(self) -> None:
190198
"""Repair the coloring of the tree that may have been messed up."""
191199
if color(self.sibling) == 1:
192200
self.sibling.color = 0
@@ -250,7 +258,7 @@ def _remove_repair(self):
250258
self.parent.color = 0
251259
self.parent.sibling.color = 0
252260

253-
def check_color_properties(self):
261+
def check_color_properties(self) -> bool:
254262
"""Check the coloring of the tree, and return True iff the tree
255263
is colored in a way which matches these five properties:
256264
(wording stolen from wikipedia article)
@@ -287,7 +295,7 @@ def check_color_properties(self):
287295
# All properties were met
288296
return True
289297

290-
def check_coloring(self):
298+
def check_coloring(self) -> None:
291299
"""A helper function to recursively check Property 4 of a
292300
Red-Black Tree. See check_color_properties for more info.
293301
"""
@@ -300,7 +308,7 @@ def check_coloring(self):
300308
return False
301309
return True
302310

303-
def black_height(self):
311+
def black_height(self) -> int:
304312
"""Returns the number of black nodes from this node to the
305313
leaves of the tree, or None if there isn't one such value (the
306314
tree is color incorrectly).
@@ -322,14 +330,14 @@ def black_height(self):
322330

323331
# Here are functions which are general to all binary search trees
324332

325-
def __contains__(self, label):
333+
def __contains__(self, label) -> bool:
326334
"""Search through the tree for label, returning True iff it is
327335
found somewhere in the tree.
328336
Guaranteed to run in O(log(n)) time.
329337
"""
330338
return self.search(label) is not None
331339

332-
def search(self, label):
340+
def search(self, label: int) -> "RedBlackTree":
333341
"""Search through the tree for label, returning its node if
334342
it's found, and None otherwise.
335343
This method is guaranteed to run in O(log(n)) time.
@@ -347,7 +355,7 @@ def search(self, label):
347355
else:
348356
return self.left.search(label)
349357

350-
def floor(self, label):
358+
def floor(self, label: int) -> int:
351359
"""Returns the largest element in this tree which is at most label.
352360
This method is guaranteed to run in O(log(n)) time."""
353361
if self.label == label:
@@ -364,7 +372,7 @@ def floor(self, label):
364372
return attempt
365373
return self.label
366374

367-
def ceil(self, label):
375+
def ceil(self, label: int) -> int:
368376
"""Returns the smallest element in this tree which is at least label.
369377
This method is guaranteed to run in O(log(n)) time.
370378
"""
@@ -382,7 +390,7 @@ def ceil(self, label):
382390
return attempt
383391
return self.label
384392

385-
def get_max(self):
393+
def get_max(self) -> int:
386394
"""Returns the largest element in this tree.
387395
This method is guaranteed to run in O(log(n)) time.
388396
"""
@@ -392,7 +400,7 @@ def get_max(self):
392400
else:
393401
return self.label
394402

395-
def get_min(self):
403+
def get_min(self) -> int:
396404
"""Returns the smallest element in this tree.
397405
This method is guaranteed to run in O(log(n)) time.
398406
"""
@@ -403,15 +411,15 @@ def get_min(self):
403411
return self.label
404412

405413
@property
406-
def grandparent(self):
414+
def grandparent(self) -> "RedBlackTree":
407415
"""Get the current node's grandparent, or None if it doesn't exist."""
408416
if self.parent is None:
409417
return None
410418
else:
411419
return self.parent.parent
412420

413421
@property
414-
def sibling(self):
422+
def sibling(self) -> "RedBlackTree":
415423
"""Get the current node's sibling, or None if it doesn't exist."""
416424
if self.parent is None:
417425
return None
@@ -420,18 +428,18 @@ def sibling(self):
420428
else:
421429
return self.parent.left
422430

423-
def is_left(self):
431+
def is_left(self) -> bool:
424432
"""Returns true iff this node is the left child of its parent."""
425433
return self.parent and self.parent.left is self
426434

427-
def is_right(self):
435+
def is_right(self) -> bool:
428436
"""Returns true iff this node is the right child of its parent."""
429437
return self.parent and self.parent.right is self
430438

431-
def __bool__(self):
439+
def __bool__(self) -> bool:
432440
return True
433441

434-
def __len__(self):
442+
def __len__(self) -> int:
435443
"""
436444
Return the number of nodes in this tree.
437445
"""
@@ -442,28 +450,28 @@ def __len__(self):
442450
ln += len(self.right)
443451
return ln
444452

445-
def preorder_traverse(self):
453+
def preorder_traverse(self) -> Iterator[int]:
446454
yield self.label
447455
if self.left:
448456
yield from self.left.preorder_traverse()
449457
if self.right:
450458
yield from self.right.preorder_traverse()
451459

452-
def inorder_traverse(self):
460+
def inorder_traverse(self) -> Iterator[int]:
453461
if self.left:
454462
yield from self.left.inorder_traverse()
455463
yield self.label
456464
if self.right:
457465
yield from self.right.inorder_traverse()
458466

459-
def postorder_traverse(self):
467+
def postorder_traverse(self) -> Iterator[int]:
460468
if self.left:
461469
yield from self.left.postorder_traverse()
462470
if self.right:
463471
yield from self.right.postorder_traverse()
464472
yield self.label
465473

466-
def __repr__(self):
474+
def __repr__(self) -> str:
467475
from pprint import pformat
468476

469477
if self.left is None and self.right is None:
@@ -476,15 +484,15 @@ def __repr__(self):
476484
indent=1,
477485
)
478486

479-
def __eq__(self, other):
487+
def __eq__(self, other) -> bool:
480488
"""Test if two trees are equal."""
481489
if self.label == other.label:
482490
return self.left == other.left and self.right == other.right
483491
else:
484492
return False
485493

486494

487-
def color(node):
495+
def color(node) -> int:
488496
"""Returns the color of a node, allowing for None leaves."""
489497
if node is None:
490498
return 0
@@ -498,7 +506,7 @@ def color(node):
498506
"""
499507

500508

501-
def test_rotations():
509+
def test_rotations() -> bool:
502510
"""Test that the rotate_left and rotate_right functions work."""
503511
# Make a tree to test on
504512
tree = RedBlackTree(0)
@@ -534,7 +542,7 @@ def test_rotations():
534542
return True
535543

536544

537-
def test_insertion_speed():
545+
def test_insertion_speed() -> bool:
538546
"""Test that the tree balances inserts to O(log(n)) by doing a lot
539547
of them.
540548
"""
@@ -544,7 +552,7 @@ def test_insertion_speed():
544552
return True
545553

546554

547-
def test_insert():
555+
def test_insert() -> bool:
548556
"""Test the insert() method of the tree correctly balances, colors,
549557
and inserts.
550558
"""
@@ -565,7 +573,7 @@ def test_insert():
565573
return tree == ans
566574

567575

568-
def test_insert_and_search():
576+
def test_insert_and_search() -> bool:
569577
"""Tests searching through the tree for values."""
570578
tree = RedBlackTree(0)
571579
tree.insert(8)
@@ -583,7 +591,7 @@ def test_insert_and_search():
583591
return True
584592

585593

586-
def test_insert_delete():
594+
def test_insert_delete() -> bool:
587595
"""Test the insert() and delete() method of the tree, verifying the
588596
insertion and removal of elements, and the balancing of the tree.
589597
"""
@@ -607,7 +615,7 @@ def test_insert_delete():
607615
return True
608616

609617

610-
def test_floor_ceil():
618+
def test_floor_ceil() -> bool:
611619
"""Tests the floor and ceiling functions in the tree."""
612620
tree = RedBlackTree(0)
613621
tree.insert(-16)
@@ -623,7 +631,7 @@ def test_floor_ceil():
623631
return True
624632

625633

626-
def test_min_max():
634+
def test_min_max() -> bool:
627635
"""Tests the min and max functions in the tree."""
628636
tree = RedBlackTree(0)
629637
tree.insert(-16)
@@ -637,7 +645,7 @@ def test_min_max():
637645
return True
638646

639647

640-
def test_tree_traversal():
648+
def test_tree_traversal() -> bool:
641649
"""Tests the three different tree traversal functions."""
642650
tree = RedBlackTree(0)
643651
tree = tree.insert(-16)
@@ -655,7 +663,7 @@ def test_tree_traversal():
655663
return True
656664

657665

658-
def test_tree_chaining():
666+
def test_tree_chaining() -> bool:
659667
"""Tests the three different tree chaining functions."""
660668
tree = RedBlackTree(0)
661669
tree = tree.insert(-16).insert(16).insert(8).insert(24).insert(20).insert(22)
@@ -672,7 +680,7 @@ def print_results(msg: str, passes: bool) -> None:
672680
print(str(msg), "works!" if passes else "doesn't work :(")
673681

674682

675-
def pytests():
683+
def pytests() -> None:
676684
assert test_rotations()
677685
assert test_insert()
678686
assert test_insert_and_search()
@@ -682,7 +690,7 @@ def pytests():
682690
assert test_tree_chaining()
683691

684692

685-
def main():
693+
def main() -> None:
686694
"""
687695
>>> pytests()
688696
"""

0 commit comments

Comments
 (0)