Skip to content

Commit ae51489

Browse files
author
Christian Bender
authored
Merge pull request TheAlgorithms#316 from AnshulMalik/segment-tree-refactor
refactor-segment-tree
2 parents ee92291 + 0d19edb commit ae51489

File tree

1 file changed

+34
-28
lines changed

1 file changed

+34
-28
lines changed

data_structures/Binary Tree/SegmentTree.py

+34-28
Original file line numberDiff line numberDiff line change
@@ -3,63 +3,69 @@
33

44
class SegmentTree:
55

6-
def __init__(self, N):
7-
self.N = N
8-
self.st = [0 for i in range(0,4*N)] # approximate the overall size of segment tree with array N
6+
def __init__(self, A):
7+
self.N = len(A)
8+
self.st = [0] * (4 * self.N) # approximate the overall size of segment tree with array N
9+
self.build(1, 0, self.N - 1)
910

1011
def left(self, idx):
11-
return idx*2
12+
return idx * 2
1213

1314
def right(self, idx):
14-
return idx*2 + 1
15+
return idx * 2 + 1
1516

16-
def build(self, idx, l, r, A):
17-
if l==r:
18-
self.st[idx] = A[l-1]
19-
else :
20-
mid = (l+r)//2
21-
self.build(self.left(idx),l,mid, A)
22-
self.build(self.right(idx),mid+1,r, A)
17+
def build(self, idx, l, r):
18+
if l == r:
19+
self.st[idx] = A[l]
20+
else:
21+
mid = (l + r) // 2
22+
self.build(self.left(idx), l, mid)
23+
self.build(self.right(idx), mid + 1, r)
2324
self.st[idx] = max(self.st[self.left(idx)] , self.st[self.right(idx)])
2425

25-
def update(self, idx, l, r, a, b, val): # update(1, 1, N, a, b, v) for update val v to [a,b]
26+
def update(self, a, b, val):
27+
return self.update_recursive(1, 0, self.N - 1, a - 1, b - 1, val)
28+
29+
def update_recursive(self, idx, l, r, a, b, val): # update(1, 1, N, a, b, v) for update val v to [a,b]
2630
if r < a or l > b:
2731
return True
2832
if l == r :
2933
self.st[idx] = val
3034
return True
3135
mid = (l+r)//2
32-
self.update(self.left(idx),l,mid,a,b,val)
33-
self.update(self.right(idx),mid+1,r,a,b,val)
36+
self.update_recursive(self.left(idx), l, mid, a, b, val)
37+
self.update_recursive(self.right(idx), mid+1, r, a, b, val)
3438
self.st[idx] = max(self.st[self.left(idx)] , self.st[self.right(idx)])
3539
return True
3640

37-
def query(self, idx, l, r, a, b): #query(1, 1, N, a, b) for query max of [a,b]
41+
def query(self, a, b):
42+
return self.query_recursive(1, 0, self.N - 1, a - 1, b - 1)
43+
44+
def query_recursive(self, idx, l, r, a, b): #query(1, 1, N, a, b) for query max of [a,b]
3845
if r < a or l > b:
3946
return -math.inf
4047
if l >= a and r <= b:
4148
return self.st[idx]
4249
mid = (l+r)//2
43-
q1 = self.query(self.left(idx),l,mid,a,b)
44-
q2 = self.query(self.right(idx),mid+1,r,a,b)
45-
return max(q1,q2)
50+
q1 = self.query_recursive(self.left(idx), l, mid, a, b)
51+
q2 = self.query_recursive(self.right(idx), mid + 1, r, a, b)
52+
return max(q1, q2)
4653

4754
def showData(self):
4855
showList = []
4956
for i in range(1,N+1):
50-
showList += [self.query(1, 1, self.N, i, i)]
57+
showList += [self.query(i, i)]
5158
print (showList)
5259

5360

5461
if __name__ == '__main__':
5562
A = [1,2,-4,7,3,-5,6,11,-20,9,14,15,5,2,-8]
5663
N = 15
57-
segt = SegmentTree(N)
58-
segt.build(1,1,N,A)
59-
print (segt.query(1,1,N,4,6))
60-
print (segt.query(1,1,N,7,11))
61-
print (segt.query(1,1,N,7,12))
62-
segt.update(1,1,N,1,3,111)
63-
print (segt.query(1,1,N,1,15))
64-
segt.update(1,1,N,7,8,235)
64+
segt = SegmentTree(A)
65+
print (segt.query(4, 6))
66+
print (segt.query(7, 11))
67+
print (segt.query(7, 12))
68+
segt.update(1,3,111)
69+
print (segt.query(1, 15))
70+
segt.update(7,8,235)
6571
segt.showData()

0 commit comments

Comments
 (0)