Skip to content

Commit 41117ea

Browse files
committed
Redesign APIs of binary-indexed-tree
1 parent 3b04cfe commit 41117ea

File tree

5 files changed

+147
-198
lines changed

5 files changed

+147
-198
lines changed

lib/cpalgo/ds/binary-indexed-tree-2d-sparse.hpp

Lines changed: 61 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -1,108 +1,94 @@
11
#pragma once
22

3-
// Finding Sum in Two-Dimensional Array
4-
5-
#include <map>
63
#include <functional>
4+
#include <map>
5+
#include <stdexcept>
76

87
// BinaryIndexedTree
9-
// Memory: O(NM)
10-
// Query: O(logN logM logNM)
11-
// Update: O(logN logM logNM)
12-
template<class T = long long>
13-
class BinaryIndexedTree {
8+
//
9+
// Space: O(NM)
10+
// Time:
11+
// Query: O(logN logM logNM)
12+
// Update: O(logN logM logNM)
13+
//
14+
template<class T = int64_t>
15+
class BinaryIndexedTree2D {
1416
public:
15-
using F = std::function<T(const T&, const T&)>;
17+
using F = std::function<T(T const&, T const&)>;
1618

1719
protected:
18-
int N, M;
19-
std::map<std::pair<int,int>,T> bit;
20-
T id;
21-
F plus;
22-
F minus;
20+
size_t N, M;
21+
std::map<std::pair<size_t,size_t>,T> bit;
22+
T empty;
23+
F combine_func;
24+
F remove_func;
2325

2426
public:
25-
// O(NM)
26-
BinaryIndexedTree(
27-
int n = 0,
28-
int m = 0,
29-
T id = T(),
30-
F plus = std::plus<T>(),
31-
F minus = std::minus<T>()
27+
// Time: O(1)
28+
BinaryIndexedTree2D(
29+
size_t n = 0,
30+
size_t m = 0,
31+
T empty = T(),
32+
F combine = std::plus<T>(),
33+
F remove = std::minus<T>()
3234
)
3335
: N(n + 1)
3436
, M(m + 1)
3537
, bit()
36-
, id(id)
37-
, plus(plus)
38-
, minus(minus)
38+
, empty(empty)
39+
, combine_func(combine)
40+
, remove_func(remove)
3941
{
4042
// Do nothing
4143
}
42-
// O(1)
43-
std::pair<int,int> size() {
44+
// Time: O(1)
45+
std::pair<size_t,size_t> size() const {
4446
return { N - 1, M - 1 };
4547
}
46-
// Sum of array[[0,x),[0,y)]
47-
// O(logN logM logNM)
48+
// Fold elements of array[[0,x),[0,y)]
4849
// x = [0,N], y = [0,M]
49-
T sum(int x, int y) {
50-
if (x < 0 || x > N) throw;
51-
if (y < 0 || y > M) throw;
52-
T ans = id;
53-
for (int i = x; i > 0; i -= i & -i) {
54-
for (int j = y; j > 0; j -= j & -j) {
55-
if (!bit.count({i,j})) {
56-
bit[{i,j}] = id;
50+
// Time: O(logN logM logNM)
51+
T fold(size_t x, size_t y) const {
52+
if (x > N) throw std::out_of_range("x");
53+
if (y > M) throw std::out_of_range("y");
54+
T ans = empty;
55+
for (size_t i = x; i > 0; i -= i & -i) {
56+
for (size_t j = y; j > 0; j -= j & -j) {
57+
if (bit.count({i,j})) {
58+
ans = combine_func(ans, bit.at({i,j}));
5759
}
58-
ans = plus(ans, bit[{i,j}]);
5960
}
6061
}
6162
return ans;
6263
}
63-
// Sum of array[[xl,xr),[yl,yr)]
64-
// O(logN logM logNM)
65-
// xl = [0,N], xr = [0,N]
66-
// yl = [0,M], yr = [0,M]
67-
T sum(int xl, int xr, int yl, int yr) {
68-
if (xl > xr || yl > yr) throw;
69-
if (xl < 0 || xl > N || xr < 0 || xr > N) throw;
70-
if (yl < 0 || yl > M || yr < 0 || yr > M) throw;
71-
T ans = id;
72-
ans = plus(ans, sum(xr, yr));
73-
ans = minus(ans, sum(xl, yr));
74-
ans = minus(ans, sum(xr, yl));
75-
ans = plus(ans, sum(xl, yl));
64+
// Fold elements of array[[xl,xr),[yl,yr)]
65+
// xl = [0,N), xr = [xl,N]
66+
// yl = [0,M), yr = [yl,M]
67+
// Time: O(logN logM logNM)
68+
T fold(size_t xl, size_t xr, size_t yl, size_t yr) const {
69+
if (xl >= N) throw std::out_of_range("xl");
70+
if (xr < xl || xr > N) throw std::out_of_range("xr");
71+
if (yl >= M) throw std::out_of_range("yl");
72+
if (yr < yl || yr > M) std::out_of_range("yr");
73+
T ans = empty;
74+
ans = combine_func(ans, fold(xr, yr));
75+
ans = remove_func(ans, fold(xl, yr));
76+
ans = remove_func(ans, fold(xr, yl));
77+
ans = combine_func(ans, fold(xl, yl));
7678
return ans;
7779
}
78-
// Add value at array[x,y]
79-
// O(logN logN logNM)
80-
// x = [0,N), y = [0,M)
81-
void add(int x, int y, const T& value) {
82-
if (x < 0 || x >= N) throw;
83-
if (y < 0 || y >= M) throw;
84-
for (int i = x + 1; i < N; i += i & -i) {
85-
for (int j = y + 1; j < M; j += j & -j) {
86-
if (!bit.count({i,j})) {
87-
bit[{i,j}] = id;
88-
}
89-
bit[{i,j}] = plus(bit[{i,j}], value);
90-
}
91-
}
92-
}
93-
// Set value at array[x,y]
94-
// O(logN logN logNM)
80+
// Combine given value at array[x,y]
9581
// x = [0,N), y = [0,M)
96-
void set(int x, int y, const T& value) {
97-
if (x < 0 || x >= N) throw;
98-
if (y < 0 || y >= M) throw;
99-
int new_value = value - sum(x,x+1,y,y+1);
100-
for (int i = x + 1; i < N; i += i & -i) {
101-
for (int j = y + 1; j < M; j += j & -j) {
82+
// Time: O(logN logN logNM)
83+
void combine(size_t x, size_t y, T const& value) {
84+
if (x >= N) throw std::out_of_range("x");
85+
if ( y >= M) throw std::out_of_range("y");
86+
for (size_t i = x + 1; i < N; i += i & -i) {
87+
for (size_t j = y + 1; j < M; j += j & -j) {
10288
if (!bit.count({i,j})) {
103-
bit[{i,j}] = id;
89+
bit[{i,j}] = empty;
10490
}
105-
bit[{i,j}] = plus(bit[{i,j}], new_value);
91+
bit[{i,j}] = combine_func(bit[{i,j}], value);
10692
}
10793
}
10894
}

lib/cpalgo/ds/binary-indexed-tree.hpp

Lines changed: 47 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,80 +1,78 @@
11
#pragma once
22

3-
// Finding Sum in One-Dimensional Array
4-
// Verified
5-
// http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_2_B
6-
7-
#include <vector>
83
#include <functional>
4+
#include <stdexcept>
5+
#include <vector>
96

107
// Binary Indexed Tree
11-
// Memory: O(N)
12-
// Query: O(logN)
13-
// Update: O(logN)
14-
template <class T = long long>
8+
//
9+
// Space: O(N)
10+
// Time:
11+
// Query: O(logN)
12+
// Update: O(logN)
13+
//
14+
// Verified:
15+
// - http://judge.u-aizu.ac.jp/onlinejudge/description.jsp?id=DSL_2_B
16+
//
17+
template <class T = int64_t>
1518
class BinaryIndexedTree {
1619
public:
17-
using F = std::function<T(const T&, const T&)>;
20+
using F = std::function<T(T const&, T const&)>;
1821

1922
protected:
20-
int N;
23+
size_t N;
2124
std::vector<T> bit;
22-
T id;
23-
F plus;
24-
F minus;
25+
T empty;
26+
F combine_func;
27+
F remove_func;
2528

2629
public:
27-
// O(N)
30+
// Time: O(N)
2831
BinaryIndexedTree(
29-
int n = 0,
30-
T id = T(),
31-
F plus = std::plus<T>(),
32-
F minus = std::minus<T>()
32+
size_t n = 0,
33+
T empty = T(),
34+
F combine = std::plus<T>(),
35+
F remove = std::minus<T>()
3336
)
3437
: N(n+1)
35-
, bit(n+1, id)
36-
, id(id)
37-
, plus(plus)
38-
, minus(minus)
38+
, bit(n+1, empty)
39+
, empty(empty)
40+
, combine_func(combine)
41+
, remove_func(remove)
3942
{
4043
// Do nothing
4144
}
42-
// O(1)
43-
int size() {
45+
// Time: O(1)
46+
size_t size() const {
4447
return N - 1;
4548
}
46-
// Sum of array[0..index)
47-
// O(logN)
49+
// Fold elements of array[0..index)
4850
// index = [0,N]
49-
T sum(int index) {
50-
if (index < 0 || index > N) throw;
51-
T ans = id;
51+
// Time: O(logN)
52+
T fold(size_t index) const {
53+
if (index > N) throw std::out_of_range("index");
54+
T ans = empty;
5255
for (; index > 0; index -= index & -index) {
53-
ans = plus(ans, bit[index]);
56+
ans = combine_func(ans, bit[index]);
5457
}
5558
return ans;
5659
}
57-
// Sum of array[l, r)
58-
// O(logN)
59-
T sum(int l, int r) {
60-
if (l > r) throw;
61-
return minus(sum(r), sum(l));
60+
// Fold elements of array[l, r)
61+
// l = [0,N)
62+
// r = [l,N]
63+
// Time: O(logN)
64+
T fold(size_t l, size_t r) const {
65+
if (l >= N) throw std::out_of_range("l");
66+
if (r < l || r > N) throw std::out_of_range("r");
67+
return remove_func(fold(r), fold(l));
6268
}
63-
// Add value to array[index]
64-
// O(logN)
69+
// Combine given value to array[index]
6570
// index = [0,N)
66-
void add(int index, const T& value) {
67-
if (index < 0 || index >= N) throw;
71+
// Time: O(logN)
72+
void combine(size_t index, T const& value) {
73+
if (index >= N) throw std::out_of_range("index");
6874
for (++index; index < N; index += index & -index) {
69-
bit[index] = plus(bit[index], value);
75+
bit[index] = combine_func(bit[index], value);
7076
}
7177
}
72-
// Set value to array[index]
73-
// O(logN)
74-
// index = [0,N)
75-
void set(int index, const T& value) {
76-
if (index < 0 || index >= N) throw;
77-
T new_value = minus(value, sum(index, index+1));
78-
add(index, new_value);
79-
}
8078
};

lib/main/ds/binary-indexed-tree/main-binary-indexed-tree-2d-sparse.cpp

Lines changed: 15 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -7,67 +7,51 @@
77

88
using namespace std;
99

10-
using BIT = BinaryIndexedTree<long long>;
10+
using BIT = BinaryIndexedTree2D<int>;
1111
BIT bit;
1212

1313
void action_init() {
14-
int size_n, size_m;
14+
size_t size_n, size_m;
1515
cin >> size_n >> size_m;
16-
if (size_n <= 0 || size_m <= 0) {
17-
cout << "false" << endl;
18-
return;
19-
}
2016
bit = BIT(size_n, size_m);
2117
cout << "true" << endl;
2218
}
2319

2420
void action_sum() {
25-
int N, M;
21+
size_t N, M;
2622
tie(N, M) = bit.size();
27-
int x, y;
23+
size_t x, y;
2824
cin >> x >> y;
29-
if (x < 0 || x > N || y < 0 || y > M) {
25+
if (x > N || y > M) {
3026
cout << "false" << endl;
3127
return;
3228
}
33-
auto ans = bit.sum(x, y);
29+
auto ans = bit.fold(x, y);
3430
cout << ans << endl;
3531
}
3632

3733
void action_add() {
38-
int N, M;
39-
tie(N, M) = bit.size();
40-
int x, y, value;
41-
cin >> x >> y >> value;
42-
if (x < 0 || x >= N || y < 0 || y >= M) {
43-
cout << "false" << endl;
44-
return;
45-
}
46-
bit.add(x, y, value);
47-
cout << "true" << endl;
48-
}
49-
50-
void action_set() {
51-
int N, M;
34+
size_t N, M;
5235
tie(N, M) = bit.size();
53-
int x, y, value;
36+
size_t x, y;
37+
int value;
5438
cin >> x >> y >> value;
55-
if (x < 0 || x >= N || y < 0 || y >= M) {
39+
if (x >= N || y >= M) {
5640
cout << "false" << endl;
5741
return;
5842
}
59-
bit.set(x, y, value);
43+
bit.combine(x, y, value);
6044
cout << "true" << endl;
6145
}
6246

6347
void action_dump() {
64-
int N, M;
48+
size_t N, M;
6549
tie(N,M) = bit.size();
66-
for (int y = 0; y < M; ++y) {
50+
for (size_t y = 0; y < M; ++y) {
6751
if (y > 0) cout << endl;
68-
for (int x = 0; x < N; ++x) {
52+
for (size_t x = 0; x < N; ++x) {
6953
if (x > 0) cout << " ";
70-
cout << bit.sum(x, x+1, y, y+1);
54+
cout << bit.fold(x, x+1, y, y+1);
7155
}
7256
}
7357
cout << endl;
@@ -78,6 +62,5 @@ void setup(string& header, map<string,Command>& commands) {
7862
commands["init"] = { "init {size_n} {size_m}", action_init };
7963
commands["sum"] = { "sum {x} {y}", action_sum };
8064
commands["add"] = { "add {x} {y} {value}", action_add };
81-
commands["set"] = { "set {x} {y} {value}", action_add };
8265
commands["dump"] = { "dump", action_dump };
8366
}

0 commit comments

Comments
 (0)