diff --git a/data_structures/binary_tree/non_recursive_segment_tree.py b/data_structures/binary_tree/non_recursive_segment_tree.py index c914079e0a8d..b04a6e5cacb7 100644 --- a/data_structures/binary_tree/non_recursive_segment_tree.py +++ b/data_structures/binary_tree/non_recursive_segment_tree.py @@ -37,12 +37,12 @@ """ from __future__ import annotations -from typing import Callable, TypeVar +from typing import Any, Callable, Generic, TypeVar T = TypeVar("T") -class SegmentTree: +class SegmentTree(Generic[T]): def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None: """ Segment Tree constructor, it works just with commutative combiner. @@ -55,8 +55,10 @@ def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None: ... lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2) (6, 9) """ - self.N = len(arr) - self.st = [None for _ in range(len(arr))] + arr + any_type: Any | T = None + + self.N: int = len(arr) + self.st: list[T] = [any_type for _ in range(self.N)] + arr self.fn = fnc self.build() @@ -83,7 +85,7 @@ def update(self, p: int, v: T) -> None: p = p // 2 self.st[p] = self.fn(self.st[p * 2], self.st[p * 2 + 1]) - def query(self, l: int, r: int) -> T: # noqa: E741 + def query(self, l: int, r: int) -> T | None: # noqa: E741 """ Get range query value in log(N) time :param l: left element index @@ -101,7 +103,8 @@ def query(self, l: int, r: int) -> T: # noqa: E741 7 """ l, r = l + self.N, r + self.N # noqa: E741 - res = None + + res: T | None = None while l <= r: # noqa: E741 if l % 2 == 1: res = self.st[l] if res is None else self.fn(res, self.st[l]) @@ -135,7 +138,7 @@ def query(self, l: int, r: int) -> T: # noqa: E741 max_segment_tree = SegmentTree(test_array, max) sum_segment_tree = SegmentTree(test_array, lambda a, b: a + b) - def test_all_segments(): + def test_all_segments() -> None: """ Test all possible segments """