1
- // http ://judge .u-aizu.ac.jp/onlinejudge/description.jsp?id= DSL_2_B
1
+ // https ://onlinejudge .u-aizu.ac.jp/problems/ DSL_2_B
2
2
3
- #include < vector>
4
- #include < functional>
3
+ #include < bits/stdc++.h>
5
4
6
- template <class T = long long >
5
+ template <class T = int64_t >
7
6
class BinaryIndexedTree {
8
7
public:
9
- using F = std::function<T(const T &, const T &)>;
8
+ using F = std::function<T(T const &, T const &)>;
10
9
11
10
protected:
12
- int N;
11
+ size_t N;
13
12
std::vector<T> bit;
14
- T id ;
15
- F plus ;
16
- F minus ;
13
+ T empty ;
14
+ F combine_func ;
15
+ F remove_func ;
17
16
18
17
public:
19
- // O(N)
18
+ // Time: O(N)
20
19
BinaryIndexedTree (
21
- int n ,
22
- T id = T(),
23
- F plus = std::plus<T>(),
24
- F minus = std::minus<T>()
20
+ size_t n = 0 ,
21
+ T empty = T(),
22
+ F combine = std::plus<T>(),
23
+ F remove = std::minus<T>()
25
24
)
26
25
: N(n+1 )
27
- , bit(n+1 , id )
28
- , id(id )
29
- , plus(plus )
30
- , minus(minus )
26
+ , bit(n+1 , empty )
27
+ , empty(empty )
28
+ , combine_func(combine )
29
+ , remove_func(remove )
31
30
{
32
31
// Do nothing
33
32
}
34
- // O(1)
35
- int size () {
33
+ // Time: O(1)
34
+ size_t size () const {
36
35
return N - 1 ;
37
36
}
38
- // Sum of array[0..index)
39
- // O(logN)
37
+ // Fold elements of array[0..index)
40
38
// index = [0,N]
41
- T sum (int index) {
42
- if (index < 0 || index > N) throw ;
43
- T ans = id;
39
+ // Time: O(logN)
40
+ T fold (size_t index) const {
41
+ if (index > N) throw std::out_of_range (" index" );
42
+ T ans = empty;
44
43
for (; index > 0 ; index -= index & -index ) {
45
- ans = plus (ans, bit[index ]);
44
+ ans = combine_func (ans, bit[index ]);
46
45
}
47
46
return ans;
48
47
}
49
- // Sum of array[l, r)
50
- // O(logN)
51
- T sum (int l, int r) {
52
- if (l > r) throw ;
53
- return minus (sum (r), sum (l));
48
+ // Fold elements of array[l, r)
49
+ // l = [0,N]
50
+ // r = [l,N]
51
+ // Time: O(logN)
52
+ T fold (size_t l, size_t r) const {
53
+ if (l > N) throw std::out_of_range (" l" );
54
+ if (r < l || r > N) throw std::out_of_range (" r" );
55
+ return remove_func (fold (r), fold (l));
54
56
}
55
- // Add value to array[index]
56
- // O(logN)
57
+ // Combine given value to array[index]
57
58
// index = [0,N)
58
- void add (int index, const T& value) {
59
- if (index < 0 || index >= N) throw ;
59
+ // Time: O(logN)
60
+ void combine (size_t index, T const & value) {
61
+ if (index >= N) throw std::out_of_range (" index" );
60
62
for (++index ; index < N; index += index & -index ) {
61
- bit[index ] = plus (bit[index ], value);
63
+ bit[index ] = combine_func (bit[index ], value);
62
64
}
63
65
}
64
- // Set value to array[index]
65
- // O(logN)
66
- // index = [0,N)
67
- void set (int index, const T& value) {
68
- if (index < 0 || index >= N) throw ;
69
- T new_value = minus (value, sum (index , index +1 ));
70
- add (index , new_value);
71
- }
72
66
};
73
67
74
68
75
- #include < iostream>
76
-
77
69
using namespace std ;
78
70
79
71
int main () {
@@ -89,10 +81,10 @@ int main() {
89
81
cin >> c >> x >> y;
90
82
if (c == 0 ) {
91
83
--x;
92
- bit.add (x, y);
84
+ bit.combine (x, y);
93
85
} else if (c == 1 ) {
94
86
--x, --y;
95
- cout << bit.sum (x, y+1 ) << endl;
87
+ cout << bit.fold (x, y+1 ) << endl;
96
88
} else throw ;
97
89
}
98
90
0 commit comments