본문 바로가기

알고리즘

2021 KAKAO 채용연계형 인턴쉽 표 편집

https://programmers.co.kr/learn/courses/30/lessons/81303

 

코딩테스트 연습 - 표 편집

8 2 ["D 2","C","U 3","C","D 4","C","U 2","Z","Z"] "OOOOXOOO" 8 2 ["D 2","C","U 3","C","D 4","C","U 2","Z","Z","U 1","C"] "OOXOXOOO"

programmers.co.kr

 

첫번째로 생각한 것은 Segment Tree + Binary Search이다.


U X, D X 명령어는 현재 위치로 부터 임의의 위치 사이에 존재하는 살아있는 Node 개수를 알아야 구할 수 있다.
따라서 구간합을 Log n 만에 구할 수 있는 Segment Tree를 사용할 수 있다고 생각했다.

하지만, 문제는, 임의의 위치를 구하는 것이다. Linear Search를 이용하면, TLE가 날 것이다.
여기에 Binary Search로 이용하면 Log n 만에 구할 수 있다고 결론 내었다.

Binary Search 시 Index 처리로 인한 무한 Loop 로 인해서 난감했는데, 예전에 들었던 무한 Loop 없는 Binary Search 구현법이 떠올랐고, 이를 적용했다.
(구간 [Start, End] 를 Binary Search 시 나눠지는 구간을 [Start, Mid], [Mid+1, End] ( Mid = (Start+End//2) ) 로 구성하고, 이를 재귀적으로 처리하는 것이다. )

복구/삭제의 경우, Segment Tree의  Update를 사용하면 log n 만에 처리가 가능하다.

즉, 각 쿼리마다 적절한 다음 위치를 Binary Search로 찾는데 Log n, Binary Serach에서 구간 [left,right] 사이의 살아있는 node 개수를 탐색하는데 log n, 쿼리의 수가 m개이므로, 시간복잡도는 다음과 같다.
 $$ O( m log_{2} ^2 n ) $$

코드는 다음과 같다.

#include <string>
#include <vector>
#include <stack>
#include <cassert>
using namespace std;

class SegmentTree{
public:
    int arr[4000000];
    int init(int now, int start, int end){
        if(start == end) return this->arr[now] = 1;
        int mid = (start + end) /2;
        return this->arr[now] = init(now*2, start, mid) + init(now*2+1, mid+1, end);
    }
    
    void update(int now, int start, int end, int target, int value){
        if(target < start || end < target) return;
        this->arr[now] += value;
        if(start != end){
            int mid = (start + end) / 2;
            this->update(now*2, start, mid, target, value);
            this->update(now*2+1, mid+1, end, target, value);
        }
    }
    
    int search(int now, int start, int end, int first, int last){
        if(last < start || end < first) return 0; 
        else if(first <= start && end <= last) return this->arr[now];
        else{
            int mid = (start+end) / 2;
            return this->search(now*2, start, mid, first, last) + 
                    this->search(now*2+1, mid+1, end, first, last);  
        }
    }
};

class Table{
public:
    int now;
    int n;
    stack<int> deleted;
    SegmentTree segtree;
    Table(int n, int k){
        this->n = n;
        this->now = k+1;
        this->segtree.init(1,1,n);
    }
    
    void Up(int x){
        this->now -= 1;
        int mid = 0, start = 1, end = (this->now);
        while(start != end){
            mid = (start + end) / 2;
            int t1 = this->segtree.search(1,1,this->n,mid+1, this->now);
            if(t1 >= x){start = mid + 1;}
            else {end = mid;}
        }
        this->now = start;
    }
    void Down(int x){
        this->now += 1;
        int mid = 0, start = this->now, end = this->n;
        while(start != end){
            mid = (start + end) / 2;
            int t1 = this->segtree.search(1,1,this->n,this->now,mid);
            if(t1 >= x){end = mid;}
            else {start = mid+1;}
        }
        this->now = start;
    }
    void Delete(){
        this->segtree.update(1,1,this->n,this->now, -1);
        this->deleted.push(this->now);
        if(this->segtree.search(1,1,this->n,this->now,this->n)) this->Down(1);
        else this->Up(1);
    }
    void Revert(){
        int t = this->deleted.top();
        this->deleted.pop();
        this->segtree.update(1,1,this->n,t,1);
    }
    
    string S(){
        string s;
        for(int i=1; i<=this->n;i++) s += this->segtree.search(1,1,this->n,i,i) ? 'O' : 'X';
        return s;
    }
};

int str_to_int(string s){
    int ret = 0;
    for(int i=2;s[i];i++){
        ret *= 10;
        ret += (int)(s[i]-'0');
    }
    return ret;
}


string solution(int n, int k, vector<string> cmd) {

    Table t(n,k);
    
    for(string s : cmd){
        switch(s[0]){
        case 'U': t.Up(str_to_int(s));   break;
        case 'D': t.Down(str_to_int(s)); break;
        case 'C': t.Delete();            break;
        default : t.Revert();            break;
        }
    }
    
    return t.S();
}

 

하지만, 정확성은 다 맞췃지만, 효율성에서 TLE가 나왔다.
아무리 줄여봐도 여기서 안줄여지던데, 재귀적으로 부르는 함수 call cost 가 높아서 상수 시간 복잡도에서 초과한 것이 아닌가 생각해본다. 

 

두번재로 생각한 방법은 Linked List 이다.
문제 조건에 이동하는 칸의 총 합은 1,000,000 이하로만 주어지기 때문에, 탐색에 시간을 조금 더 써도 된다.
그렇다면, 탐색에는 불리하지만, 삽입/삭제에 유리한 Linked List 를 사용할 수 있겠다고 생각했다.

코드는 다음과 같다.

class node:
    def __init__(self, i):
        self.prev = None
        self.next = None
        self.num  = i
        self.deleted = False
        
class linked_list:
    def __init__(self, n, k):
        self.nodes = [node(i) for i in range(n)]
        for idx in range(1, n-1):
            self.nodes[idx].prev = self.nodes[idx-1]
            self.nodes[idx].next = self.nodes[idx+1]
        self.nodes[0].next = self.nodes[1]
        self.nodes[-1].prev = self.nodes[-2]
        self.now = self.nodes[k]
        self.deleted = []
        
    def Up(self, x):
        for _ in range(x): 
            self.now = self.now.prev
            
    def Down(self, x):
        for _ in range(x): 
            self.now = self.now.next
    
    def Delete(self):        
        self.deleted.append(self.now.num)
        self.now.deleted = True
        if self.now.prev: self.now.prev.next = self.now.next
        if self.now.next : 
            self.now.next.prev = self.now.prev
            self.now = self.now.next
        else: self.now = self.now.prev
        
    
    def Revert(self):
        d = self.deleted.pop()
        self.nodes[d].deleted = False
        if self.nodes[d].next : self.nodes[d].next.prev = self.nodes[d]
        if self.nodes[d].prev : self.nodes[d].prev.next = self.nodes[d]
        
    def __str__(self):
        return ''.join(['X' if n.deleted else 'O' for n in self.nodes])

def solution(n, k, cmd_list):
    
    ll = linked_list(n,k)
    
    for cmd in cmd_list:
        c = cmd.split(' ')
        if   c[0] == 'U' : ll.Up(int(c[1]))
        elif c[0] == 'D' : ll.Down(int(c[1]))
        elif c[0] == 'C' : ll.Delete()
        else             : ll.Revert()
    
        
    return str(ll)

 

처음 __str__ 구현 시 다음과 같이 했었는데, O(n^2)나 걸리는 멍청한 코드였다.

def __str__(self):
	return ''.join(['X' if i in self.deleted else 'O' for i in range(len(self.nodes))])