37
37
"""
38
38
from __future__ import annotations
39
39
40
- from typing import Callable , TypeVar
40
+ from typing import Any , Callable , Generic , TypeVar
41
41
42
42
T = TypeVar ("T" )
43
43
44
44
45
- class SegmentTree :
45
+ class SegmentTree ( Generic [ T ]) :
46
46
def __init__ (self , arr : list [T ], fnc : Callable [[T , T ], T ]) -> None :
47
47
"""
48
48
Segment Tree constructor, it works just with commutative combiner.
@@ -55,8 +55,10 @@ def __init__(self, arr: list[T], fnc: Callable[[T, T], T]) -> None:
55
55
... lambda a, b: (a[0] + b[0], a[1] + b[1])).query(0, 2)
56
56
(6, 9)
57
57
"""
58
- self .N = len (arr )
59
- self .st = [None for _ in range (len (arr ))] + arr
58
+ any_type : Any | T = None
59
+
60
+ self .N : int = len (arr )
61
+ self .st : list [T ] = [any_type for _ in range (self .N )] + arr
60
62
self .fn = fnc
61
63
self .build ()
62
64
@@ -83,7 +85,7 @@ def update(self, p: int, v: T) -> None:
83
85
p = p // 2
84
86
self .st [p ] = self .fn (self .st [p * 2 ], self .st [p * 2 + 1 ])
85
87
86
- def query (self , l : int , r : int ) -> T : # noqa: E741
88
+ def query (self , l : int , r : int ) -> T | None : # noqa: E741
87
89
"""
88
90
Get range query value in log(N) time
89
91
:param l: left element index
@@ -101,7 +103,8 @@ def query(self, l: int, r: int) -> T: # noqa: E741
101
103
7
102
104
"""
103
105
l , r = l + self .N , r + self .N # noqa: E741
104
- res = None
106
+
107
+ res : T | None = None
105
108
while l <= r : # noqa: E741
106
109
if l % 2 == 1 :
107
110
res = self .st [l ] if res is None else self .fn (res , self .st [l ])
@@ -135,7 +138,7 @@ def query(self, l: int, r: int) -> T: # noqa: E741
135
138
max_segment_tree = SegmentTree (test_array , max )
136
139
sum_segment_tree = SegmentTree (test_array , lambda a , b : a + b )
137
140
138
- def test_all_segments ():
141
+ def test_all_segments () -> None :
139
142
"""
140
143
Test all possible segments
141
144
"""
0 commit comments