N개의 마을로 이루어진 트리 형태의 나라에서
우수 마을인 두 마을이 인접하지 않고
우수 마을이 아닌 마을은 적어도 한 우수 마을과 인접해야 할 때
우수 마을로 선정된 마을 주민 수의 최댓값을 구한다.
Tree + DP
2535번과 비슷하다.
다만 2535번에서는 v가 선택될 때/아닐 때 두 가지만 고려하면 됐는데
여기서는 v의 자식노드까지 생각해야 한다. 좀 더 상위버전인 듯.
우수마을을 선정하는 조건을 파악하자
1. 우수 마을인 두 마을이 인접하지 않는다.
v가 우수마을이면 v의 모든 자식 노드는 우수 마을이 될 수 없다.
2. 우수마을이 아닌 마을은 적어도 한 우수 마을과 인접하다.
v가 우수마을이 아니라면 v의 자식 노드 중 하나가 우수 마을이거나, v의 부모 노드가 우수 마을이어야 한다.
* v가 우수마을일 때/아닐 때 두 가지만 고려해도 될 것 같지만 다음과 같은 경우가 있다.
A-B-C-D와 같은 나라가 있을 때 A와 D가 우수 마을인 경우
즉 v와 vchild(v의 자식 노드) 모두 우수 마을이 아니지만
v의 부모 노드와 vchild의 자식노드가 우수 마을일 때
따라서 세 가지 경우를 고려한다.
- v가 우수 마을인 경우
- v는 우수 마을이 아니지만 v의 자식 노드 중 적어도 하나가 우수 마을인 경우
- v와 v의 모든 자식 노드가 우수 마을이 아닌 경우
1을 root로 두고 DFS로 탐색한다.
* 처음엔 위상 정렬로 생각해봤는데 입력에서 edge가 undirectional로 주어져 매우 불편하다...
void dfs(int cur) {
// Get dp[cur]
dp[cur][0] += num[cur];
int diff = 987654;
bool no_child_selected = true;
for (int nxt : e[cur]) {
if (dp[nxt][0]) continue; //parent
dfs(nxt); //child
// Select cur
dp[cur][0] += max(dp[nxt][1],dp[nxt][2]);
// Select some child, not cur
if (dp[nxt][0] < dp[nxt][1]) {
dp[cur][1] += dp[nxt][1];
diff = min(diff, dp[nxt][1] - dp[nxt][0]);
}
else {
dp[cur][1] += dp[nxt][0];
no_child_selected = false;
}
// Select v, not child, not cur
if (dp[nxt][1] && dp[cur][2] >= 0) dp[cur][2] += dp[nxt][1];
else dp[cur][2] = -1;
}
if (no_child_selected && diff < 987654) dp[cur][1] -= diff;
}
현재 노드를 cur, 자식 노드를 nxt라 하자.
먼저 마을 cur의 주민 수를 dp[cur][0]에 더해준다. (dfs에서의 visit 역할이기도 함) line 3
1. dp[cur][0] cur이 우수 마을인 경우
어떤 nxt도 우수 마을이 아니어야 하므로
모든 nxt에 대해 dp[nxt][1]과 dp[nxt][2] 중 최댓값을 더한다. line 12
2. dp[cur][1] cur이 우수 마을이 아니고, nxt 중 적어도 하나는 우수 마을인 경우
dp[nxt][0]과 dp[nxt][1]을 비교해 더한다.
dp[nxt][0]이 더 크거나 같은 경우, nxt를 우수 마을로 선택하는 상황 line 18~21
dp[nxt][1]이 더 큰 경우에는 일단 그 값을 더하고 line 15
후에 nxt중 하나도 우수 마을이 없는 경우를 대비해 diff를 구한다. line 16
diff는 반복문이 끝난 후 상황을 판단해 더한다. line 26
3. dp[cur][2] cur와 모든 nxt가 우수 마을이 아닌 경우
nxt가 우수 마을이 아니라면 nxt의 자식 노드 중 하나가 반드시 우수 마을이어야 한다.
이런 경우가 불가능하면 dp[cur][2] 값을 음수로 만들어 준다.
즉 dp[nxt][1]이 양수인 경우에는 그 값을 더해주고 line 23
dp[nxt][1]이 0인 경우 (nxt의 자식 노드 중 우수 마을이 없음) 값을 음수처리한다. line 24
#include <cstdio>
#include <vector>
#include <algorithm>
using namespace std;
int dp[10001][3];
int num[10001];
vector<int> e[10001];
void dfs(int cur) {
// Get dp[cur]
dp[cur][0] += num[cur];
int diff = 987654;
bool no_child_selected = true;
for (int nxt : e[cur]) {
if (dp[nxt][0]) continue; //parent
dfs(nxt); //child
// Select cur
dp[cur][0] += max(dp[nxt][1],dp[nxt][2]);
// Select some child, not cur
if (dp[nxt][0] < dp[nxt][1]) {
dp[cur][1] += dp[nxt][1];
diff = min(diff, dp[nxt][1] - dp[nxt][0]);
}
else {
dp[cur][1] += dp[nxt][0];
no_child_selected = false;
}
// Select v, not child, not cur
if (dp[nxt][1] && dp[cur][2] >= 0) dp[cur][2] += dp[nxt][1];
else dp[cur][2] = -1;
}
if (no_child_selected && diff < 987654) dp[cur][1] -= diff;
}
int main() {
int N, a, b;
scanf("%d", &N);
for (int i = 1; i <= N; ++i) scanf("%d", &num[i]);
for (int i = 1; i < N; ++i) {
scanf("%d%d", &a, &b);
e[a].push_back(b), e[b].push_back(a);
}
dfs(1);
printf("%d\n", max(dp[1][0], dp[1][1])); return 0;
}
첫 도전에 뇌절(?) 와서 한 일주일 미루고 계속 안 풀다가 해결
사실 엄청 어려운 문제가 아닌데도 뭔가 예외처리하는 것들이 복잡하게 느껴졌다.
중간에 메모리 낭비도 많이 했었는데 결국 DFS + diff 이용해서 깔끔하게 풀 수 있었다고 생각한다.
'Probelm Solving > BOJ' 카테고리의 다른 글
BOJ2549 루빅의 사각형 (0) | 2021.07.25 |
---|---|
BOJ3830 교수님은 기다리지 않는다 (0) | 2021.06.19 |
BOJ1014 컨닝 (0) | 2021.05.19 |
BOJ3716 도로 네트워크 (0) | 2021.04.23 |
BOJ1086 박성원 (0) | 2021.04.21 |
댓글