Persistent Segment Tree

티스토리 메뉴 펼치기 댓글수1

Problem Solving/Data Structure

Persistent Segment Tree

hongjun7
댓글수1

들어가기 이전에 다음 글이 해당 자료구조에 대해 영어로 잘 설명해 놓았다.


Persistent Segment Tree를 이용하는 문제로 BOJ 11932 트리와 K번째 수 문제가 있다. 노드의 개수가 N개이고, 각 정점마다 가중치가 있는 트리가 있을 때에 M개의 쿼리에 대해 두 노드 사이를 잇는 경로 상의 K번째 정점 가중치 값을 출력하는 문제이다. 편의를 위해서 가중치들은 다 좌표압축을 하여 1에서부터 N까지의 값만 가진다고 가정하자.


한 쿼리에 대해서 답을 이분탐색하면서 정한 답을 Ans라고 한다면, 두 정점 u와 v 사이를 잇는 경로에 Ans보다 작거나 같은 원소의 개수를 counting해서 그 개수가 K보다 크다면 Ans를 줄이고, K보다 작으면 Ans를 늘이면 된다. 답을 정하는 데에 O(log N)이 필요하고, counting 하는 데에 T만큼의 시간이 걸린다면 쿼리마다 O(T log N)의 시간이 소요된다.


함수 F(p, val)를 트리의 루트 노드부터 p번 노드까지의 경로 상에서 val보다 작거나 같은 원소의 개수를 리턴하는 함수라고 정의하자. 그리고 x를 u와 v의 LCA이고, p(x)를 x의 부모 노드라고 정의하자. 두 정점 u와 v 사이의 경로에서 Ans보다 작거나 같은 원소의 개수 Q(u, v, Ans) = F(u, ans) + F(v, ans) - F(x, ans) - F(p(x), ans)이다.


각 정점마다, root 노드에서 각 노드에 이르는 경로까지의 정점들만을 고려했을 때에 구간 [L, R]에 해당하는 원소 가중치의 값이 몇 개 있는지를 저장하는 Segment tree를 생각해보자. 총 N개의 Segment Tree가 있는 것이다. 그러면 위의 F함수를 계산하는 것을 [1, Ans]까지의 구간합을 구하는 것과 같으므로 O(log N)의 시간이 소요된다. 따라서 T = log N이 되는 것이다. 하지만 메모리 사용량이 각 정점마다 O(N)개의 공간이 필요해서 공간복잡도가 O(N*N)이 된다.


여기서 Persistent Segment Tree가 등장하는데, 어떤 노드 u의 자식 노드 c를 생각해보자. 노드 c에서의 Segment Tree에서는 노드 u에서의 Segment tree에 하나의 원소만이 업데이트 된다. 따라서 총 O(log N)의 구간이 업데이트가 되고 나머지는 다 같다. 따라서 노드 c의 Segment Tree는 노드 u의 Segment Tree를 가리키면서, 새로운 업데이트가 된 구간만 추가해주는 식이다. 그림으로 표현하면 다음과 같다.


각 노드마다 새로운 O(log N)의 구간이 추가되어서 공간복잡도마저 O(N log N)이 되고, Persistent Segment Tree를 구현하는 시간복잡도 또한 공간복잡도에 비례하기 때문에 O(N log N)이 된다. 따라서 전체적으로 시간복잡도는 트리를 구성하는 것과 쿼리에 답하는 것을 합해 O(N log N + M log^2 N)이 된다.


여기서 굳이 답을 결정할 필요가 있을지에 대해 의문을 가져보자. 구간 트리 [L, R]이 왼쪽 부트리 [L, (L+R)/2]와 오른쪽 부트리 [(L+R)/2+1, R]로 나뉜다고 할 때, 트리의 루트에서부터 내려오면서 왼쪽 부트리에 속하는 원소의 개수의 합이 K보다 작다면 우리가 관심있는 K번째 원소는 오른쪽 부트리에 있는 것이고 아니면 왼쪽 부트리에 있는 것이다. 이 Top-Down 방식으로 답을 찾아나가면 트리의 depth는 최대 log N이기 때문에 O(log^2 N)에서 O(log N)으로 시간복잡도를 줄일 수 있다.


따라서 전체적인 시간복잡도는 O(N log N + M log N)이 된다.


#include <stdio.h>
#include <vector>
#include <map>
#include <algorithm>
using namespace std;
#define MAXN 100005
struct node {
	int c;
	node *left, *right;
	node(int _c, node *a, node *b) { c = _c, left = a, right = b; }
	node *upd(int l, int r, int w);
} *root[MAXN], *null;
//[L, R] -> [L, (L+R)/2], [(L+R)/2+1, R]
node *node::upd(int l, int r, int w) {
	//현재 추가하는 원소가 구간 안에 있다면, 구간에 속하는 원소의 개수는 +1
	if (l <= w && w <= r) {
		if (l == r) return new node(this->c + 1, null, null);
		int m = (l + r) / 2;
		return new node(this->c + 1, this->left->upd(l, m, w), this->right->upd(m + 1, r, w));
	}
	return this;
}
int N, M, w[MAXN], X[MAXN], Xn, depth[MAXN];
map  dx;
vector  v[MAXN];
int p[MAXN][17];
void dfs(int cur, int prv) {
	if (prv == -1) root[cur] = null->upd(1, Xn, w[cur]);
	else root[cur] = root[prv]->upd(1, Xn, w[cur]);
	for (auto &nxt : v[cur]) {
		if (nxt == prv) continue;
		p[nxt][0] = cur; //직계 부모
		depth[nxt] = depth[cur] + 1;
		dfs(nxt, cur);
	}
}
int lca(int a, int b) {
	if (depth[a] < depth[b]) swap(a, b);
	int diff = depth[a] - depth[b];
	for (int i = 16; i >= 0; i--) if ((diff >> i) & 1) a = p[a][i];
	if (a == b) return a;
	for (int i = 16; i >= 0; i--) if (p[a][i] != p[b][i]) a = p[a][i], b = p[b][i];
	return p[a][0];
}
//a, b : u, v, c : lca, d : parent(c)
int query(node *a, node *b, node *c, node *d, int l, int r, int k) {
	if (l == r) return l;
	int cnt = a->left->c + b->left->c - c->left->c - d->left->c;
	int m = (l + r) / 2;
	if (cnt >= k) return query(a->left, b->left, c->left, d->left, l, m, k);
	return query(a->right, b->right, c->right, d->right, m + 1, r, k - cnt);
}
int main() {
	freopen("input.txt", "r", stdin);
	freopen("output.txt", "w", stdout);
	scanf("%d%d", &N, &M);
	for (int i = 1; i <= N; i++) scanf("%d", &w[i]), X[i] = w[i];
	sort(X + 1, X + N + 1);
	Xn = unique(X + 1, X + N + 1) - (X + 1);
	for (int i = 1; i <= Xn; i++) dx[X[i]] = i;
	for (int i = 1; i <= N; i++) w[i] = dx[w[i]];
	for (int i = 1; i < N; i++) {
		int a, b; scanf("%d%d", &a, &b);
		v[a].emplace_back(b);
		v[b].emplace_back(a);
	}
	for (int i = 0; i <= N; i++) for (int j = 0; j <= 16; j++) p[i][j] = -1;
	null = new node(0, NULL, NULL);
	null->left = null->right = null;
	dfs(1, -1);
	for (int j = 1; j <= 16; j++) {
		for (int i = 1; i <= N; i++) {
			if (p[i][j - 1] == -1) continue;
			p[i][j] = p[p[i][j - 1]][j - 1];
		}
	}
	for (int i = 1; i <= M; i++) {
		int u, v, k; scanf("%d%d%d", &u, &v, &k);
		int x = lca(u, v);
		int res = query(root[u], root[v], root[x], (p[x][0] == -1 ? null : root[p[x][0]]), 1, Xn, k);
		printf("%d\n", X[res]);
	}
}


맨위로

https://hongjun7.tistory.com/64

신고하기