본문 바로가기

백준 문제 풀이

백준 2887번 - 행성 터널 (C++)

728x90

https://www.acmicpc.net/problem/2887

 

 

문제 자체는 값에 따라 그래프를 구성하고 MST를 구하면 되는 문제이다. 그러나 노드의 개수가 10만이며 간선의 수도 만만치 않다. 그래서 모든 간선을 각자 비교하여 구하면 10만 * 10만 * 4byte 이므로

 

약 38146 MB

 

말도안되는 메모리를 사용해버린다..!

 

그래서 이 문제의 핵심은 어떻게 3개의 좌표를 가진 노드들을 효율적으로 비교할 것인가?

로 귀결되고 이 것이 문제이다.

 

A: 2, 3, 5

B: 6, 9 10

C: 11, 1, 3

 

이렇게 있다고 해보자. 만약 모두 구하려면 9번의 연산이 필요하고 각각 X, Y, Z를 비교한 값을 가지고 있어야한다.

하지만 결국 간선이 될 수 있는 후보는 X, Y, Z의 최소차이가 간선의 후보이다. 

 

좌표끼리 한 번 모아보자.

    A B C

X: 2, 6, 11

Y: 3, 9, 1

Z: 5, 10, 3

 

여기서 각 좌표마다 오름차순 정렬을 해보면

X: 2 6 11

Y: 1 3 9

Z: 3 5 10

 

이 된다. 물론 A,B,C 순서는 달라졌지만 이제 우리는 각 좌표마다 최소값을 구할 수 있다.

문제 자체는 어렵지 않지만 10만개의 값을 메모리를 최적화해야하는 방법을 찾아야하는데 이것이 정렬이라는 아이디어를 떠올리기 힘들다. O(N)으로 해결하도록 만들어야하는 것이 어려운 문제이다.

 

메모리 초과 코드

struct Node{
	int x, y, z;
};


for(int i = 1; i <= N; i++){
		for(int j = 1; j <= N; j++){
			if(i == j) continue;

			int X = abs(node[i].x - node[j].x);
			int Y = abs(node[i].y - node[j].y);
			int Z = abs(node[i].z - node[j].z);

			graph.push_back({min(X, min(Y, Z)), {  i, j  }});
		}
	}

 

이러면 10만 * 3개의 메모리를 모두 갖고 있어야하고 N * N 번의 연산이 필요하기 때문에 메모리초과가 발생한다. 필요없는 값들까지 모두 계산하게 된다.

 

위 예제에서 A의 X와 C의 X는 절대 행성 터널 후보가 될 수없는대도 말이다!

 

	sort(pos_x.begin(), pos_x.end());
	sort(pos_y.begin(), pos_y.end());
	sort(pos_z.begin(), pos_z.end());

	for(int i = 0; i < N - 1; i++){
		graph.push_back({ abs(pos_x[i].first - pos_x[i + 1].first) , {pos_x[i].second, pos_x[i + 1].second}});
		graph.push_back({ abs(pos_y[i].first - pos_y[i + 1].first) , {pos_y[i].second, pos_y[i + 1].second}});
		graph.push_back({ abs(pos_z[i].first - pos_z[i + 1].first) , {pos_z[i].second, pos_z[i + 1].second}});
	}

 

 

x, y, z 포지션을 입력받고 진짜 후보가 될 수 있는 값들만 O(N)으로 만들어낼 수 있다.

 

전체코드

#include <iostream>
#include <vector>
#include <algorithm>
using namespace std; 

int N;

vector<pair<int, int>> pos_x;
vector<pair<int, int>> pos_y;
vector<pair<int, int>> pos_z;

int answer = 0;

vector<pair<int, pair<int, int>>> graph;
int par[100001];

int find(int x){
	if(x == par[x]) return x;
	else return par[x] = find(par[x]);
}

void Union(int x, int y){
	int px = find(x);
	int py = find(y);

	if(px < py) par[py] = px;
	else par[px] = py; 
}

int main(){
	ios_base::sync_with_stdio(false);
	cin.tie(NULL);

	cin >> N;

	for(int i = 0; i < N; i++){
		par[i + 1] = i + 1;
		int x, y, z;
		cin >> x >> y >> z;

		pos_x.push_back({x, i + 1});
		pos_y.push_back({y, i + 1});
		pos_z.push_back({z, i + 1});
	}

	sort(pos_x.begin(), pos_x.end());
	sort(pos_y.begin(), pos_y.end());
	sort(pos_z.begin(), pos_z.end());

	for(int i = 0; i < N - 1; i++){
		graph.push_back({ abs(pos_x[i].first - pos_x[i + 1].first) , {pos_x[i].second, pos_x[i + 1].second}});
		graph.push_back({ abs(pos_y[i].first - pos_y[i + 1].first) , {pos_y[i].second, pos_y[i + 1].second}});
		graph.push_back({ abs(pos_z[i].first - pos_z[i + 1].first) , {pos_z[i].second, pos_z[i + 1].second}});
	}	

	sort(graph.begin(), graph.end());

	for(int i = 0; i < graph.size(); i++){
		int x = graph[i].second.first;
		int y = graph[i].second.second;
		int cost = graph[i].first;

		//cout << x << " " << y <<" "<< cost << endl;

		if(find(x) == find(y)) continue;
		Union(x, y);
		answer += cost;
	}

	cout << answer << '\n';
}