[백준] No.1761 - 정점들의 거리 (C++, 최소 공통 조상)
문제
https://www.acmicpc.net/problem/1761
1761번: 정점들의 거리
첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩
www.acmicpc.net
풀이
solved.ac 난이도: Platium 5
트리에는 두 노드의 경로는 유일하다. 두 노드의 쌍을 받은 후 한 노드에서 다른 노드로 직접 이동하며 거리합을 구하는 것이 가장 간단한 해결 방법이겠지만, 어느 경로로 이동해야 다른 노드로 이동할 수 있는지 알 수 없다.
따라서 최소 공통 조상(LCA)을 구하여 각 노드에서의 LCA로의 거리의 합을 구하면 두 노드 사이의 거리를 알 수 있다.
또한 두 노드를 조상 노드로 한 단계씩 이동하며 구하는 거리합은 O(Tree_height)의 시간 복잡도를 가지기 때문에 입력의 크기가 크고 치우쳐진 트리에서는 오랜 시간이 소모된다.
따라서 다이나믹 프로그래밍을 응용한 최적화로 시간 복잡도를 O(log Tree_height)로 줄이자.
해당 알고리즘에 대해서는 백준(11483) LCA 2에서 다루었으니 참고하자.
최적화한 LCA에서는 x의 2^k(2의 k승) 번째 조상을 parent[x][k]에 저장하였다.
parent[x][k]는 parent[parent[idx][k-1]][k-1]로 구해낼 수 있었다.
이번 문제에서는 두 노드의 LCA 까지의 거리 합도 구해야 한다.
따라서 x의 2^k(2의 k승) 번째 조상까지의 거리를 dist[x][k]에 저장하자.
x의 2^k(2의 k승)번째 조상까지의 거리, dist[x][k]는 아래와 같은 점화식으로 구할 수 있다.
dist[x][k] = dist[x][k-1] + dist[parent[x][k-1]][k-1]
=> x의 2^(k-1)조상까지의 거리 + x의 2^(k-1)조상의 2^(k-1)조상까지의 거리
예를 들어 아래와 같은 트리에서 5번 노드의 4번째 조상까지의 거리(dist[5][2])를 구해보자.
(화살표는 부모와 자식관계를 나타낼 뿐 방향성을 없습니다.)

parent[5][0] = 4, parent[4][0]=3, parent[3][0] = 2, parent[2][0]=1,
parent[5][1] = parent[ parent[5][0] ][0] = parent[4][0]=3,
parent[5][2] = parent[ parent[5][1] ][1] = parent[3][1] = parent[ parent[3][0] ][0] = parent[2][0] = 1
dist[5][0] = 40, dist[4][0] = 30, dist[3][0] = 20, dist[2][0]=10
dist[5][1] = dist[ parent[5][0] ][0] + dist[5][0] = dist[4][0] + 40 = 30 + 40 = 70,
5번 노드에서 2^1=2번째 조상까지의 거리는 4번 노드에서 첫 조상까지의 거리 + 5번 노드에서 첫 조상까지의 거리 이다.
dist[5][2] = dist[ parent[5][1] ][1] + dist[5][1] = dist[3][1] + 70 = dist[ parent[3][0] ][0] + dist[3][0] + 70
= dist[2][0] + 20 + 70 = 10 + 20 + 70 = 100
5번 노드에서 2^2=4번째 조상까지의 거리는 3번 노드(5번 노드의 2^1 조상)에서 2^1 조상까지의 거리 + 5번 노드에서 2^1 조상까지의 거리이다.
즉, x에서 2^k 번째 조상까지의 거리는
x에서 2^(k-1) 번째 조상까지의 거리와 x의 2^(k-1) 번째 조상에서 2^(k-1) 조상까지의 거리
두 가지로 계속 나눌 수 있다.
따라서 LCA까지의 거리는 노드에서 LCA까지의 깊이 차이를 h라고 할 때,
O(log h)의 복잡도로 계산 가능하다.
이러한 계산을 위한 bottom-up방식의 전처리는 O(n log Tree_height)의 시간 복잡도가 소모된다.
1
2
3
4
5
6
7
8
|
for(int k=1; k<TREE_HIGTH ; ++k){
for(int idx = 2 ; idx<=node_num ; ++idx){
if(parent[idx][k-1] != 0){
parent[idx][k] = parent[parent[idx][k-1]][k-1];
dist[idx][k] = dist[idx][k-1] + dist[parent[idx][k-1]][k-1];
}
}
}
|
cs |
두 노드의 LCA까지의 거리합은 LCA를 구하는 함수에 거리합 sum을 계산해주는 작업만 추가해주면 된다.
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
|
int DistNodePair(int a, int b){
int sum=0;
if(depth[a] != depth[b]){
if(depth[a] < depth[b]) //깊이가 다르다면 a가 항상 더 깊게
swap(a,b);
int dif = depth[a] - depth[b];
for(int i=0; dif>0 ; ++i){
if(dif %2 ==1){
sum += dist[a][i];
a = parent[a][i];
}
dif = dif>>1;
}
}
if(a != b){
for(int k = TREE_HIGTH-1; k>=0 ; --k){
if(parent[a][k] != 0 && parent[a][k] != parent[b][k]){
sum+= (dist[a][k] + dist[b][k]);
a = parent[a][k];
b = parent[b][k];
}
}
sum += dist[a][0] + dist[b][0];
}
return sum;
}
|
cs |
코드
#include <iostream> | |
#include<vector> | |
#include<cstdio> | |
#include<string.h> | |
using namespace std; | |
int node_num; | |
const int TREE_HIGTH = 20; | |
int depth[40001]; | |
int parent[40001][TREE_HIGTH]; | |
int dist[40001][TREE_HIGTH]; | |
vector<pair<int,int>> adj[40001]; | |
void FindParent(int par, int now, int dep, int cost){ | |
depth[now] = dep; | |
parent[now][0] = par; | |
dist[now][0] = cost; | |
for(int i=0; i<adj[now].size() ; ++i){ | |
if(adj[now][i].first != par) | |
FindParent(now, adj[now][i].first, dep+1,adj[now][i].second); | |
} | |
return; | |
} | |
int DistNodePair(int a, int b){ | |
int sum=0; | |
if(depth[a] != depth[b]){ | |
if(depth[a] < depth[b]) //깊이가 다르다면 a가 항상 더 깊게 | |
swap(a,b); | |
int dif = depth[a] - depth[b]; | |
for(int i=0; dif>0 ; ++i){ | |
if(dif %2 ==1){ | |
sum += dist[a][i]; | |
a = parent[a][i]; | |
} | |
dif = dif>>1; | |
} | |
} | |
if(a != b){ | |
for(int k = TREE_HIGTH-1; k>=0 ; --k){ | |
if(parent[a][k] != 0 && parent[a][k] != parent[b][k]){ | |
sum+= (dist[a][k] + dist[b][k]); | |
a = parent[a][k]; | |
b = parent[b][k]; | |
} | |
} | |
sum += dist[a][0] + dist[b][0]; | |
} | |
return sum; | |
} | |
int main(){ | |
scanf("%d",&node_num); | |
int a,b,cost; | |
for(int i=0; i<node_num-1 ; ++i){ | |
scanf("%d %d %d",&a, &b, &cost); | |
adj[a].push_back(make_pair(b, cost)); | |
adj[b].push_back(make_pair(a, cost)); | |
} | |
memset(parent, 0 , sizeof(parent)); | |
memset(dist, 0, sizeof(dist)); | |
FindParent(0, 1,0,0); | |
for(int k=1; k<TREE_HIGTH ; ++k){ | |
for(int idx = 2 ; idx<=node_num ; ++idx){ | |
if(parent[idx][k-1] != 0){ | |
parent[idx][k] = parent[parent[idx][k-1]][k-1]; | |
dist[idx][k] = dist[idx][k-1] + dist[parent[idx][k-1]][k-1]; | |
} | |
} | |
} | |
int pair_num; | |
scanf("%d", &pair_num); | |
while(pair_num--){ | |
scanf("%d %d", &a,&b); | |
printf("%d\n",DistNodePair(a,b)); | |
} | |
return 0; | |
} |