본문 바로가기

알고리즘

BOJ 1761 정점들의 거리

www.acmicpc.net/problem/1761

 

1761번: 정점들의 거리

첫째 줄에 노드의 개수 N이 입력되고 다음 N-1개의 줄에 트리 상에 연결된 두 점과 거리를 입력받는다. 그 다음 줄에 M이 주어지고, 다음 M개의 줄에 거리를 알고 싶은 노드 쌍이 한 줄에 한 쌍씩

www.acmicpc.net

 

예전에 틀린 문제들을 다시 풀어봤다.

 

이 문제는 가중치가 있는 tree에서 정점들의 거리를 구하는 문제이다. LCA를 구해서 cost[a] + cost[b] - 2*cost[lca] 를 구해주면 된다. (cost[i]는 root node 로부터 node i까지의 거리, dfs로 먼저 구해놓는다.)

여기까지는 쉽게 생각할 수 있을 것이다. 하지만, query의 존재 때문에, 하나하나 올라가서 LCA를 찾는다면, O(40000 * 10000 ) 로 시간초과가 나기 때문에, 시간을 단축시킬 무언가가 필요했다.

이때,  DP를 이용하여 자신의 조상에 대해 미리 구해놓는다면, LCA 시 시간을 단축시킬 수 있다. 

DP 점화식은 다음과 같다.

dp[ i ][ j ] = dp[ i-1 ][ dp[ i-a ][ j ] ]  ( i : node j 의 2^i 위의 조상 node, j : node 번호 )

i,j 순서를 바꾸면 좀더 이해하기 쉬운 코드가 되지만, 나는 cache locality 때문에 위와 같이 정의했다. (만약 이 내용이 틀리다면 댓글 부탁드립니다...)

이 dp table을 이용하여 공통 조상을 구해주면 된다. 

(LCA로 검색하면 더 자세한 설명이 있는 블로그를 찾을 수 있을 것이다.)

 

아래는 내 코드이다.

#include <cstdio>
#include <vector>

using namespace std;

vector<vector<pair<int, int>>> v;

int n,m;
//1-based
int cost[40001];
int depth[40001];
//16이면 충분하지만, 여유롭게 할당함
int table[21][40001];
bool visited[40001];

void dfs(int now, int tcost, int d, int p){

    if(visited[now]) return;

    visited[now] = true;
    cost[now] = tcost;
    depth[now] = d;
    table[0][now] = p;

    for(auto edge : v[now]){
        dfs(edge.first, tcost+edge.second, d+1, now);
    }
}

void make_table(){
    for(int i=1;i<=20;i++){
        for(int j=1; j<= n;j++){
            table[i][j] = table[i-1][table[i-1][j]];
        }
    }
}

int get_lca(int a, int b){

    //must be depth a < depth b
    if(depth[a] > depth[b]){
        int t = a;
        a = b;
        b = t;
    }

    for(int i=20;i>=0;i--){
        if(depth[b] - depth[a] >= (1<<i)){
            b = table[i][b];
        }
    }
    if(a==b) return b;
    for(int i=20;i>=0;i--){
        if(table[i][b] != table[i][a]){
            a = table[i][a];
            b = table[i][b];
        }
    }

    return table[0][a];
}

int main(){

    int a,b,c;
    
    scanf("%d",&n);
    v.resize(n+1);

    for(int i=0;i<n-1;i++){
        scanf("%d%d%d",&a,&b,&c);
        v[a].push_back({b,c});
        v[b].push_back({a,c});
    }

    dfs(1,0,0, 0);
    make_table();

    scanf("%d",&m);
    while(m--){
        scanf("%d%d",&a,&b);
        int lca = get_lca(a,b);
        printf("lca : %d\n", lca);
        printf("%d\n", cost[a]+cost[b]-2*cost[lca]);
    }

    return 0;
}

(사실 저 visited 배열을 없애고 싶었는데, 어떻게 해도 root node에 대해 예외처리를 해야 해서 가만히 두었다.)

 

예전 제출 코드는 다음과 같다.

#include <cstdio>

int Node[40001][3] = { 0, };
//0 : parent node
//1 : depth
//2 : distance

int main() {
	int ns; scanf("%d", &ns);
	int p, c, d;
	for (int i = 0; i < ns-1; i++) {
		scanf("%d%d%d", &p, &c, &d);
		Node[c][0] = p;
		if (Node[p][0] != 0)
			Node[c][1] = Node[p][1] + 1;
		else
			Node[p][1] = Node[c][1] - 1;
		Node[c][2] = d;
	}
	int t; scanf("%d", &t);
	
	while (t--) {
		int answer = 0;
		int depth;
		int n1, n2; scanf("%d%d", &n1, &n2);
		if (Node[n1][1] > Node[n2][1]) {
			int a = n1; n1 = n2; n2 = t;
		}
		while (Node[n1][1] != Node[n2][1]) {
			answer += Node[n2][2];
			n2 = Node[n2][0];
		}
		while (n1 != n2) {
			answer += (Node[n2][2] + Node[n1][2]);
			n1 = Node[n1][0];
			n2 = Node[n2][0];
		}
		printf("%d\n", answer);

	}

}

LCA로 접근했긴 했는데, 저 시간을 줄일 방법을 생각하지 못해서 TLE를 맞은 것 같다. 그래도 예전보단 조금이라도 나아져서 다행이다.

'알고리즘' 카테고리의 다른 글

BOJ 14226 이모티콘  (0) 2021.05.01
BOJ 1445 일요일 아침의 데이트  (0) 2021.04.20
BOJ 1074 Z  (0) 2021.04.15
BOJ 2014 소수의 곱  (0) 2021.04.08
BOJ 2629 양팔 저울  (0) 2021.04.07