Skip to content

Commit 678535b

Browse files
authored
[mypy] Fix type annotations in non_recursive_segment_tree (TheAlgorithms#5652)
1 parent e7565f8 commit 678535b

File tree

1 file changed

+10
-7
lines changed

1 file changed

+10
-7
lines changed

data_structures/binary_tree/non_recursive_segment_tree.py

+10-7
Original file line numberDiff line numberDiff line change
@@ -37,12 +37,12 @@
3737
"""
3838
from __future__ import annotations
3939

40-
from typing import Callable, TypeVar
40+
from typing import Any, Callable, Generic, TypeVar
4141

4242
T = TypeVar("T")
4343

4444

45-
class SegmentTree:
45+
class SegmentTree(Generic[T]):
4646
def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None:
4747
"""
4848
Segment Tree constructor, it works just with commutative combiner.
@@ -55,8 +55,10 @@ def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None:
5555
... lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2)
5656
(6, 9)
5757
"""
58-
self.N = len(arr)
59-
self.st = [None for _ in range(len(arr))] + arr
58+
any_type: Any | T = None
59+
60+
self.N: int = len(arr)
61+
self.st: list[T] = [any_type for _ in range(self.N)] + arr
6062
self.fn = fnc
6163
self.build()
6264

@@ -83,7 +85,7 @@ def update(self, p: int, v: T) -> None:
8385
p = p // 2
8486
self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1])
8587

86-
def query(self, l: int, r: int) -> T: # noqa: E741
88+
def query(self, l: int, r: int) -> T | None: # noqa: E741
8789
"""
8890
Get range query value in log(N) time
8991
:param l: left element index
@@ -101,7 +103,8 @@ def query(self, l: int, r: int) -> T: # noqa: E741
101103
7
102104
"""
103105
l, r = l + self.N, r + self.N # noqa: E741
104-
res = None
106+
107+
res: T | None = None
105108
while l <= r: # noqa: E741
106109
if l % 2 == 1:
107110
res = self.st[l] if res is None else self.fn(res, self.st[l])
@@ -135,7 +138,7 @@ def query(self, l: int, r: int) -> T: # noqa: E741
135138
max_segment_tree = SegmentTree(test_array, max)
136139
sum_segment_tree = SegmentTree(test_array, lambda a, b: a + b)
137140

138-
def test_all_segments():
141+
def test_all_segments() -> None:
139142
"""
140143
Test all possible segments
141144
"""

0 commit comments

Comments
 (0)