Segment Tree
配列の各区間での計算結果をノードに持つことで任意の区間での計算をO(log n)で行えるようにした完全二分木。 以下の例では和を持っているが最小値にすればRange Minimum Query (RMQ)を解くことができ、ソートした列を持てばマージソートになる。値を更新する際は下から順に再計算していく。
#include <bits/stdc++.h>
using namespace std;
template <typename S, typename T>
class SegmentTree {
S default_value;
int leaf_num = 1;
vector<S> nodes;
/*
[0]
/ \
[1] [2]
/ \ / \
[3][4][5][6]
*/
function<int(int, int)> calc;
public:
SegmentTree(vector<S> data, S default_value, function<T(S, S)> calc) {
while (leaf_num < data.size()) leaf_num *= 2;
this->calc = calc;
this->default_value = default_value;
nodes = vector<S>(2 * data.size() - 1, default_value);
for (int i = 0; i < data.size(); i++) update(i, data[i]);
}
void update(int idx, S value) {
idx += leaf_num - 1;
nodes[idx] = value;
while (idx > 0) {
idx = (idx - 1) / 2; // parent
nodes[idx] = this->calc(nodes[idx * 2 + 1], nodes[idx * 2 + 2]);
}
}
T query(int from, int to, int node_idx=0, int left=0, int right=-1) {
if (right == -1) right = leaf_num-1;
if (right < from || to < left) return this->calc(default_value, default_value);
if (from <= left && right <= to) return this->calc(default_value, nodes[node_idx]);
return this->calc(
query(from, to, node_idx * 2 + 1, left, (left + right) / 2), // left child
query(from, to, node_idx * 2 + 2, (left + right) / 2 + 1, right) // right child
);
}
};
int main() {
SegmentTree<int, int> segtree({6, 4, 2, 3, 5, 1, 3, 2}, 0, [](int a, int b) { return a + b; });
cout << segtree.query(0, 6) << endl; // 24
segtree.update(0, 8);
cout << segtree.query(0, 6) << endl; // 26
}
区間を更新する場合はその長さを m とすると O(m log n) となってしまうが、もう一つのSegment Treeに区間の更新分を持つなどして遅延評価すれば改善できる。
Binary Indexed Tree (BIT)
Segment Treeで部分和を求める場合右の子ノードの値は親ノードに含まれているので省くことができる。 残ったノードに次のようにインデックスを振っていったのがBinary Indexed Treeで、その2進数表現での末尾の1のビットを足したり引いたりすることで値の更新や0からidxまでの和を求めることができる。Segment Treeと比べてできることが限られるが実装がシンプルで速い。
末尾の1のビットは x & -x
で得られる。
#include <bits/stdc++.h>
using namespace std;
class BinaryIndexedTree {
vector<int> nodes;
/*
[4]
/ \
[2] -
/ \ / \
[1] - [3] -
*/
public:
BinaryIndexedTree(vector<int> data) {
this->nodes = vector<int>(data.size()+1, 0);
for(int i = 0; i < data.size(); i++) this->add(i, data[i]);
}
void add(int idx, int value) {
idx++;
while (idx < nodes.size()) {
nodes[idx] += value;
idx += idx & -idx; // -----> e.g. 0001 & 1001 = 0001
}
}
int sum(int to) {
to++;
int ret = 0;
while (to > 0) {
ret += nodes[to];
to -= to & -to; // - - ->
}
return ret;
}
};
int main() {
BinaryIndexedTree bit({6, 4, 2, 3, 5, 1, 3, 2});
cout << bit.sum(6) << endl; // 24
bit.add(0, 2);
cout << bit.sum(6) << endl; // 26
}