예전에 틀린 문제들을 다시 풀어봤다.
이 문제는 가중치가 있는 tree에서 정점들의 거리를 구하는 문제이다. LCA를 구해서 cost[a] + cost[b] - 2*cost[lca] 를 구해주면 된다. (cost[i]는 root node 로부터 node i까지의 거리, dfs로 먼저 구해놓는다.)
여기까지는 쉽게 생각할 수 있을 것이다. 하지만, query의 존재 때문에, 하나하나 올라가서 LCA를 찾는다면, O(40000 * 10000 ) 로 시간초과가 나기 때문에, 시간을 단축시킬 무언가가 필요했다.
이때, DP를 이용하여 자신의 조상에 대해 미리 구해놓는다면, LCA 시 시간을 단축시킬 수 있다.
DP 점화식은 다음과 같다.
dp[ i ][ j ] = dp[ i-1 ][ dp[ i-a ][ j ] ] ( i : node j 의 2^i 위의 조상 node, j : node 번호 )
i,j 순서를 바꾸면 좀더 이해하기 쉬운 코드가 되지만, 나는 cache locality 때문에 위와 같이 정의했다. (만약 이 내용이 틀리다면 댓글 부탁드립니다...)
이 dp table을 이용하여 공통 조상을 구해주면 된다.
(LCA로 검색하면 더 자세한 설명이 있는 블로그를 찾을 수 있을 것이다.)
아래는 내 코드이다.
#include <cstdio>
#include <vector>
using namespace std;
vector<vector<pair<int, int>>> v;
int n,m;
//1-based
int cost[40001];
int depth[40001];
//16이면 충분하지만, 여유롭게 할당함
int table[21][40001];
bool visited[40001];
void dfs(int now, int tcost, int d, int p){
if(visited[now]) return;
visited[now] = true;
cost[now] = tcost;
depth[now] = d;
table[0][now] = p;
for(auto edge : v[now]){
dfs(edge.first, tcost+edge.second, d+1, now);
}
}
void make_table(){
for(int i=1;i<=20;i++){
for(int j=1; j<= n;j++){
table[i][j] = table[i-1][table[i-1][j]];
}
}
}
int get_lca(int a, int b){
//must be depth a < depth b
if(depth[a] > depth[b]){
int t = a;
a = b;
b = t;
}
for(int i=20;i>=0;i--){
if(depth[b] - depth[a] >= (1<<i)){
b = table[i][b];
}
}
if(a==b) return b;
for(int i=20;i>=0;i--){
if(table[i][b] != table[i][a]){
a = table[i][a];
b = table[i][b];
}
}
return table[0][a];
}
int main(){
int a,b,c;
scanf("%d",&n);
v.resize(n+1);
for(int i=0;i<n-1;i++){
scanf("%d%d%d",&a,&b,&c);
v[a].push_back({b,c});
v[b].push_back({a,c});
}
dfs(1,0,0, 0);
make_table();
scanf("%d",&m);
while(m--){
scanf("%d%d",&a,&b);
int lca = get_lca(a,b);
printf("lca : %d\n", lca);
printf("%d\n", cost[a]+cost[b]-2*cost[lca]);
}
return 0;
}
(사실 저 visited 배열을 없애고 싶었는데, 어떻게 해도 root node에 대해 예외처리를 해야 해서 가만히 두었다.)
예전 제출 코드는 다음과 같다.
#include <cstdio>
int Node[40001][3] = { 0, };
//0 : parent node
//1 : depth
//2 : distance
int main() {
int ns; scanf("%d", &ns);
int p, c, d;
for (int i = 0; i < ns-1; i++) {
scanf("%d%d%d", &p, &c, &d);
Node[c][0] = p;
if (Node[p][0] != 0)
Node[c][1] = Node[p][1] + 1;
else
Node[p][1] = Node[c][1] - 1;
Node[c][2] = d;
}
int t; scanf("%d", &t);
while (t--) {
int answer = 0;
int depth;
int n1, n2; scanf("%d%d", &n1, &n2);
if (Node[n1][1] > Node[n2][1]) {
int a = n1; n1 = n2; n2 = t;
}
while (Node[n1][1] != Node[n2][1]) {
answer += Node[n2][2];
n2 = Node[n2][0];
}
while (n1 != n2) {
answer += (Node[n2][2] + Node[n1][2]);
n1 = Node[n1][0];
n2 = Node[n2][0];
}
printf("%d\n", answer);
}
}
LCA로 접근했긴 했는데, 저 시간을 줄일 방법을 생각하지 못해서 TLE를 맞은 것 같다. 그래도 예전보단 조금이라도 나아져서 다행이다.
'알고리즘' 카테고리의 다른 글
BOJ 14226 이모티콘 (0) | 2021.05.01 |
---|---|
BOJ 1445 일요일 아침의 데이트 (0) | 2021.04.20 |
BOJ 1074 Z (0) | 2021.04.15 |
BOJ 2014 소수의 곱 (0) | 2021.04.08 |
BOJ 2629 양팔 저울 (0) | 2021.04.07 |