알고리즘 문제풀이

최소 공통 조상(Lowest Common Ancestor, LCA) 알고리즘 본문

자료구조 + 알고리즘/기초

최소 공통 조상(Lowest Common Ancestor, LCA) 알고리즘

JoonDev 2021. 8. 16. 00:58

최소 공통 조상?

트리에서 임의의 두 노드(u, v)에서의 최소 공통 조상은 "u와 v의 공통 조상노드들 중 가장 깊이가 깊은(=루트로 부터 거리가 먼)노드"로 정의할 수가 있다. 

최소 공통 조상은 여러가지 분야에서 응용될 수 있지만, 트리 상에서 두 노드의 거리를 구할 때 많이 이용되는 알고리즘이다.

 


아래와 같은 트리가 있다고 가정해보자.

트리의 최상위 루트 노드는 루트노드를 제외한 모든 노드들의 조상 노드임을 알 수 있다.

깊이가 다른 노드 3과 6의 조상 노드들은 무엇이 있을까? (0, 1)번 노드가 3,6의 공통 조상일 것이다.

그 중, 최소 공통 조상은 공통 조상 중 depth가 가장 깊은 1이 된다.

다음으로는 깊이가 같은 4,5번 노드의 최소 공통 조상은 무엇이 있을까? 

최상위 루트 노드인 0번 노드일 것이다.

 

이것을 일반화 해보도록 하자.

u와 v의 LCA를 구하기 위해 u와 v가 제 각기 부모 노드로 이동하면서, 방문한 부모 노드들을 기록하여 비교&대조 해 보는 방법을 떠올릴 수가 있다. 이 방식은 O( depth(u) + depth(v) ) 의 시간을 요구하고 이 후, 비교&대조 하는 부분에 있어 추가적인 시간을 요구한다.

 

위 방법보다는 효율적인 방법이 있다.

바로, 임의의 노드 u, v의 depth를 통일 시킨다음 u와 v를 동시에 한칸 씩 올리면서 확인하는 방법이다.

이 때, (u의 현재위치) == (v의 현재위치) 가 같으면 이것이 최소 공통 조상임이 보장이 된다. 

이를 구현하기 위해, 우리는 각각의 노드가 가지는 깊이(depth) 정보와 각각의 부모 노드(parent)를 알아야한다.

 

위 문제에서 루트 노드는 0번 이므로 DFS를 통해 노드마다 순회하며, 부모와 깊이를 초기화한다.

class LCA{
public:
    int parent[8]={0}, depth[8];
    void getParentAndDepth(int current, int dep){
        this->depth[current] = dep;
        for(auto child : tree[current] ){
            parent[child] = current;
            getParentAndDepth(child, dep+1);
        }
    }
    int find(int u, int v);
};
// 호출 시 -> (LCA의 인스턴스).getParentAndDepth(최상위 루트 노드, 0)

우리는 모든 노드의 (부모, 깊이) 정보를 알고 있다.

이를 기반으로, LCA를 찾아보자. LCA를 찾는 방법은 위에서 설명한 바와 같다.

class LCA{
public:
    int parent[8] = {0,}, depth[8];
    void getParentAndDepth(int current, int dep);
    int find(int u, int v){
        // depth[u] >= depth[v] 임을 보장
        if( depth[u] < depth[v] )
            swap(u, v);
        // u와 v의 깊이를 맞추기 위함
        int diff = depth[u] - depth[v];
        while(diff--){
            u = parent[u];
        }
        // u, v를 부모 노드 쪽으로 한칸씩 움직인다.
        // 이 과정에서 u == v 라면, 이것이 최소 공통 조상이다.
        while( u != v ){
            u = parent[u];
            v = parent[v];
        }
        return u;
    }
};

두 노드의 LCA를 찾는 find함수의 시간 복잡도는 최악의 경우, 루트 노드까지 두 노드가 움직여야 하므로 O( max(depth(u), depth(v)) ) 가 되겠다. 

 

전체 코드는 아래와 같다.

#include <bits/stdc++.h>
using namespace std;
vector<int> tree[8];
void makeTree(){
    tree[0].push_back(1), tree[0].push_back(2);
    tree[1].push_back(3), tree[1].push_back(4);
    tree[2].push_back(5);
    tree[4].push_back(6), tree[4].push_back(7);
}
class LCA{
public:
    int parent[8] = {0,}, depth[8];
    void getParentAndDepth(int current, int dep){
        this->depth[current] = dep;
        for(auto child : tree[current] ){
            parent[child] = current;
            getParentAndDepth(child, dep+1);
        }
    }
    int find(int u, int v){
        // depth[u] >= depth[v] 임을 보장
        if( depth[u] < depth[v] )
            swap(u, v);
        // u와 v의 깊이를 맞추기 위함
        int diff = depth[u] - depth[v];
        while(diff--){
            u = parent[u];
        }
        // u, v를 부모 노드 쪽으로 한칸씩 움직인다.
        // 이 과정에서 u == v 라면, 이것이 최소 공통 조상이다.
        while( u != v ){
            u = parent[u];
            v = parent[v];
        }
        return u;
    }
};
int main(void){
    makeTree();

    LCA lca;
    lca.getParentAndDepth(0, 0);
    for(int i=0; i<8; i++){
        for(int j=i+1; j<8; j++){
            cout << i << "와 " << j << "의 LCA : " << lca.find(i, j) << '\n';
        }
    }

    return 0;
}

 

시간복잡도

전처리를 하기 위한, getParentAndDepth() 는 O(N)의 시간이 소요된다.

LCA를 구하기 위해, 매 번 선형탐색을 기본으로 하기 때문에 O(N)의 시간이 소요된다. 

 


 

 

 

개선된 방법

위와 같은 방식은 u와 v의 깊이를 맞추기 위해 선형적으로 부모노드를 방문한다. 이에 따른 시간을 개선할 순 없을까?

또한, depth(u) == depth(v) 일 때 선형적으로 부모노드를 방문하면서 비교한다. 이에 따른 시간을 개선할 순 없을까?

 

기본 접근에서 부모를 방문하는 것을 exponential 하게 탐색한다는 idea에서 출발할 수 있다.

깊이 차이(diff) = depth[u] - depth[v] ( depth[u] >= depth[v] ) 일 때, 이것을 다음과 같은 형태로 표현할 수 있다.

${a_0, a_1, ..., a_n}$은 diff를 이진수로 변환한다음 최하위비트의 계수로 볼 수 있다.

 

예를 들어 diff = 13 이라고 가정해보자.

이진수 표현은 1101가 되고 이를 위의 Notation으로 바라 보았을 때는 다음과 같다.

보다 더 직관적으로 살펴보자.

v와 동일하게 depth를 맞추기 위해 우리는 총 13칸 움직이면 된다. 

위 그림에서 우리는 13칸을 linear하게 올라가는 것이 아닌, (1+4+8) 총 3번에 거쳐서 올라갈 수 있다는 것을 의미한다. 

( 각각의 크기는 2^k 꼴이라는 것을 인지하자. )

(u -> u(1), 1칸) | (u(1) -> u(2), 4칸) | (u(2) -> u(3), 8칸) 

그렇다면 u의 2^k 번째 부모를 저장해놓는다면, 이와 같이 O(logN) 시간만에 두 노드의 depth를 맞출 수 있다는 것을 의미한다.

 

이를 위한 sparse table 을 정의해주자.

vector<vector<int>> parent; // parent[u][k] : u노드의 2^k번째 부모노드

이 때, 부모노드의 최대 개수는 트리의 높이와 같고 연결리스트 형태의 트리일 경우 트리의 높이와 노드의 총 개수가 동일하기 때문에 

parent의 최대 크기는 넉넉잡아

다음과 같을 것이다. 내부의 ${log_2{N} + 1}$값은 올림 처리 한 것이다.

자세한 구현은 나중에 살펴보고 지금은 흐름만 파악하도록 하자.


depth를 통일 시켰다고 해도, 이전의 Naive한 방법대로 선형적으로 부모노드로 이동할 경우 O(N)의 시간이 소요되므로 최악의 경우에는 개선되기 이전과 동일한 시간복잡도를 가질 것이다. 이를 개선하기 위해서 u와 v의 depth를 통일시켰을 경우의 상황을 살펴보자.

 

 

u와 v는 LCA를 기준으로 윗 부분(k)은 동일한 조상을 가진다.

parent[u][i] != parent[v][i] 인 최소 깊이의 u'를 찾는다. 즉, i의 최댓값을 찾는다.

이 때, 수렴성을 보장하기 위해서 i는 내림차순으로 탐색하는 것이 중요하다.

우리는 u에서 ${2^i}$ 번째 부모 노드로 이동할 수 있다. 그림과 같이 ${2^k}$ 만큼 분기를 하고, 최적해(LCA)까지 ${2^l}$이 남은 상황을 가정해보자. (${2^k >= 2^l}$) 

오름차순으로 i를 탐색할 경우, 이전(${2^k}$) 보다는 더 크게 분기해야하므로, 이 경우는 최적해를 찾지 못하게 된다.

 


위의 과정들을 이해했다면, 소스코드로 구현해보자.

먼저, 각 노드의 깊이를 통일시키기 위해 각각의 깊이 정보를 저장하는 depth배열을 선언한다.

또한, 각각의 노드의 2^k 번째 부모 정보를 저장해주는 parent 배열을 선언한다.

class LCA{
public:
	int N;
    vector<int> depth;
    vector<vector<int>> parent;
    LCA(int n=8){ N = n, depth.resize(n), parent.resize(n, vector<int>(ceil(log2(n))));}
};

depth와 i노드의 첫번째 부모(즉, parent[i][0])은 DFS를 통해서 쉽게 구할 수 있다.

class LCA{
public:
	int N;
    vector<int> depth;
    vector<vector<int>> parent;
    LCA(int n=8){ N = n, depth.resize(n), parent.resize(n, vector<int>(ceil(log2(n))));}
    
    
    void DFS(int current, int previous, int d){
        depth[current] = d;
        for(auto next : adj[current]){
            if( next == previous ) continue; // (양방향 그래프에서) 사이클 방지
            parent[next][0] = current;
            DFS(next, current, d+1);
        }
    }
};

나머지 parent배열의 원소를 채우기 위해 다음과 같은 특성을 활용하여 bottom-up으로 채워준다.

위 사실을 일반화 하여 parent[i][j] = parent[ parent[i][j-1] ][ j - 1 ] 라는 점화식을 세울 수 있다.

parent 테이블은 ${2^0}$ 부모를 알고 있는 것을 base 로 하여 노드 별로 ${2^1}$ 부모, ${2^1}$ 부모를 알고있는 것을 base로 하여 노드별로 ${2^2}$ 부모, ... 이런 식으로 채워준다. 트리에서 벗어나는 범위에 대해서 예외처리를 해준다.

class LCA{
public:
    int N;
    vector<int> depth;
    vector<vector<int>> parent;
    LCA(int n = 8){ N = n, depth.resize(n), parent.resize(n, vector<int>(ceil(log2(n))));}
    void DFS(int current, int previous, int d){
        depth[current] = d;
        for(auto next : tree[current]){
            if( next == previous ) continue; // (양방향 그래프에서) 사이클 방지
            parent[next][0] = current;
            DFS(next, current, d+1);
        }
    }
    void setParent(){
        int sz = ceil(log2(N));
        for(int j=1; j<sz; j++) {
            for (int i = 0; i < N; i++) {
            	if( parent[i][j-1] == -1 ) continue;
                parent[i][j] = parent[parent[i][j - 1]][j - 1];
            }
        }
    }
};

 

마지막으로, LCA를 구하는 find함수를 작성하자.

동작 방식은, 앞 서 말한 것과 동일하다.

class LCA{
public:
    int N;
    vector<int> depth;
    vector<vector<int>> parent;
    LCA(int n = 8){ N = n, depth.resize(n), parent.resize(n, vector<int>(ceil(log2(n))));}
    void DFS(int current, int previous, int d){
        depth[current] = d;
        for(auto next : tree[current]){
            if( next == previous ) continue; // (양방향 그래프에서) 사이클 방지
            parent[next][0] = current;
            DFS(next, current, d+1);
        }
    }
    void setParent(){
        int sz = ceil(log2(N));
        for(int j=1; j<sz; j++) {
            for (int i = 0; i < N; i++) {
                parent[i][j] = parent[parent[i][j - 1]][j - 1];
            }
        }
    }
    int find(int u, int v){
        if( depth[u] < depth[v] )
            swap(u, v);
        // depth 통일
        int diff = depth[u] - depth[v];
        int pos = 0;
        while(diff){
            if( diff & 1 ) {
                u = parent[u][pos];
            }
            pos += 1;
            diff >>= 1;
        }
        // depth[u] == depth[v] 보장
        if( u == v )
            return u;
        int sz = ceil(log2(N));
        for(int i=sz-1; i>=0; i--){
            if( parent[u][i] != parent[v][i] ){
                u = parent[u][i];
                v = parent[v][i];
            }
        }
        return parent[u][0];
    }
};

테스트를 포함한 코드는 아래와 같다.

#include <bits/stdc++.h>
using namespace std;
vector<int> tree[8];
void makeTree(){
    tree[0].push_back(1), tree[0].push_back(2);
    tree[1].push_back(3), tree[1].push_back(4);
    tree[2].push_back(5);
    tree[4].push_back(6), tree[4].push_back(7);
}
class LCA{
public:
    int N;
    vector<int> depth;
    vector<vector<int>> parent;
    LCA(int n = 8){ N = n, depth.resize(n), parent.resize(n, vector<int>(ceil(log2(n))));}
    void DFS(int current, int previous, int d){
        depth[current] = d;
        for(auto next : tree[current]){
            if( next == previous ) continue; // (양방향 그래프에서) 사이클 방지
            parent[next][0] = current;
            DFS(next, current, d+1);
        }
    }
    void setParent(){
        int sz = ceil(log2(N));
        for(int j=1; j<sz; j++) {
            for (int i = 0; i < N; i++) {
                if( parent[i][j-1] == -1 ) continue;
                parent[i][j] = parent[parent[i][j - 1]][j - 1];
            }
        }
    }
    int find(int u, int v){
        if( depth[u] < depth[v] )
            swap(u, v);
        // depth 통일
        int diff = depth[u] - depth[v];
        int pos = 0;
        while(diff){
            if( diff & 1 ) {
                u = parent[u][pos];
            }
            pos += 1;
            diff >>= 1;
        }
        // depth[u] == depth[v] 보장
        if( u == v )
            return u;
        int sz = ceil(log2(N));
        for(int i=sz-1; i>=0; i--){
            if( parent[u][i] != parent[v][i] ){
                u = parent[u][i];
                v = parent[v][i];
            }
        }
        return parent[u][0];
    }
};
int main(void){
    makeTree();
    LCA lca;
    lca.DFS(0, -1, 0);
    lca.setParent();

    for(int i=0; i<8; i++){
        for(int j=i+1; j<8; j++){
            cout << i << "와 " << j <<"의 LCA : " << lca.find(i, j) << '\n';
        }
    }

    return 0;
}

시간복잡도

parent 테이블은 N * logN 크기를 가지는 sparse table이고, 이것을 채우는데에도 O(NlogN) 시간이 소요된다. 나머지는 1차원 배열이므로 시간복잡도에 연산되지 않는다. 

고로, 전처리 과정에 소요되는 시간은 O(NlogN) 이다.

 

전처리가 끝난 후, LCA 쿼리 1개당 소요되는 시간은 O(logN)이다.

Comments