[백준 2228번] 구간 나누기 (java)

2228번: 구간 나누기
N(1≤N≤100)개의 수로 이루어진 1차원 배열이 있다. 이 배열에서 M(1≤M≤⌈(N/2)⌉)개의 구간을 선택해서, 구간에 속한 수들의 총 합이 최대가 되도록 하려 한다. 단, 다음의 조건들이 만족되어야
www.acmicpc.net
dp[n][m] : n개의 수를 m개의 구간으로 나누었을 때 최대 합
arr[i] : 입력받은 수들을 저장
n번째 수가 m번째 구간에 포함되어 있는지, 포함되어 있지 않는지 두 가지 경우로 나눌 수 있다.
- n번째 수가 구간 m에 포함되지 않은 경우 : dp[n][m] = dp[n-1][m]
n-1번째 수가 구간 m에 포함된 경우의 최대 합(dp[n-1][m])과 동일하다.
- n번째 수가 구간 m에 포함된 경우 :
dp[n][m] = max(dp[k][m-1]) + sum(arr[k+2] ~ arr[n]) (0 <= k <= n-2)
n번째 수가 구간 m에 포함된 경우, 구간 m에 포함된 수들의 범위를 알 수가 없다.
따라서 k를 이동해가며(0<=k<=n-2) k번째 수가 구간 m-1에 포함된다고 했을 때(dp[k][m-1])
구간 m에 포함된 수는 k+2번째 수부터 n번째 수가 된다. 왜냐하면 구간과 구간은 붙어있을 수 없기
때문에 k+2번째 수부터 포함시켜야 하기 때문이다.
따라서, dp[k][m-1] + (k+2번째 수부터 n번째 수까지의 합)가 dp[n][m]의 후보가 된다.
그리고 k에 따라서(0<=k<=n-2) 구한 값들 중, 최대가 되는 값이 최종 dp[n][m]가 된다.

int min = 0;
if (m == 1)
min = -1;
for (int k = n - 2; k >= min; k--) {
if (k < 0)
dp[n][m] = Math.max(dp[n][m], sum[n]);
else
dp[n][m] = Math.max(dp[n][m], dp[k][m - 1] + sum[n] - sum[k + 1]);
}
위 코드에서 k+2번째 수부터 n번째 수까지의 합은 누적합을 이용하여 구한 것이다.
k+2번째 수부터 n번째 수까지의 합은 (n번째 수까지의 합) - (k+1번째 수까지의 합)과 동일하다.
(n번째 수까지의 합) - (k+1번째 수까지의 합) = sum[n] - sum[k+1]
k의 범위를 자세히 보면,
m = 1일 때는 min = -1,
m != 1일 때는 min = 0
으로 설정하여서 k의 범위를 다르게 해주었다.
이렇게 해준 이유는, m = 1인 경우에는 입력받은 수가 모두 양수일 경우 모든 수를 구간에 포함시킨 것이 최대 합이 될 것이다. 그런데 k의 최소 값을 0으로 설정해버리면 위 코드의 점화식
dp[n][m] = Math.max(dp[n][m], dp[k][m-1] + sum[n] - sum[k+1])에서
k = 0일 때 구간의 합이 sum[n] - sum[1]이 되는데 그러면 첫번째 수인 arr[1]를 포함시킬 수 없기 때문이다.
그래서 구간이 하나뿐인 경우(m = 1)에는 k의 최소 값을 -1로 설정하여서 sum[n] - sum[0]
즉, 구간 합에 첫번째 수를 포함시키는 경우도 고려했다.
그리고 dp 2차원 배열에서 m = 0인 경우에는 값을 0으로 초기화하고,
그 외의 경우에는 Integer.MIN_VALUE/2 로 초기화했다.
최종적으로 dp배열의 메모이제이션을 완료했을 때,
임의의 n과 m에 대하여 dp[n][m] = Integer.MIN_VALUE/2 라는 것은, n개의 수를 m개의 구간으로 나눌 수 없는 경우를 의미한다. 예를 들어 n<m일 때는 n개의 수를 m개의 구간으로 나누는 것이 불가능하다.
for (int n = 0; n <= N; n++) {
for (int m = 1; m <= M; m++) { // m이 1부터 시작
dp[n][m] = Integer.MIN_VALUE / 2;
}
}
전체 코드 :
import java.io.*; | |
public class n02228 { | |
static int N, M; | |
static int[] arr, sum; | |
static int[][] dp; | |
public static void main(String[] args) throws Exception { | |
BufferedReader br = new BufferedReader(new InputStreamReader(System.in)); | |
BufferedWriter bw = new BufferedWriter(new OutputStreamWriter(System.out)); | |
String[] input = br.readLine().split(" "); | |
N = Integer.parseInt(input[0]); | |
M = Integer.parseInt(input[1]); | |
arr = new int[N + 1]; | |
sum = new int[N + 1]; | |
dp = new int[N + 1][M + 1]; | |
for (int i = 1; i <= N; i++) { | |
arr[i] = Integer.parseInt(br.readLine()); | |
sum[i] = sum[i - 1] + arr[i]; // 누적합 구하기 | |
} | |
for (int n = 0; n <= N; n++) { | |
for (int m = 1; m <= M; m++) { | |
dp[n][m] = Integer.MIN_VALUE / 2; | |
} | |
} | |
dp[1][1] = arr[1]; | |
for (int n = 2; n <= N; n++) { | |
for (int m = 1; m <= M; m++) { | |
dp[n][m] = dp[n - 1][m]; // n번째 수가 구간에 포함안되는 경우 | |
int min = 0; | |
if (m == 1) | |
min = -1; | |
for (int k = n - 2; k >= min; k--) { | |
if (k < 0) | |
dp[n][m] = Math.max(dp[n][m], sum[n]); | |
else | |
dp[n][m] = Math.max(dp[n][m], dp[k][m - 1] + sum[n] - sum[k + 1]); | |
} | |
} | |
} | |
bw.write(dp[N][M] + "\n"); | |
bw.flush(); | |
} | |
} |