본문으로 바로가기

https://www.acmicpc.net/problem/11049

 

11049번: 행렬 곱셈 순서

첫째 줄에 입력으로 주어진 행렬을 곱하는데 필요한 곱셈 연산의 최솟값을 출력한다. 정답은 231-1 보다 작거나 같은 자연수이다. 또한, 최악의 순서로 연산해도 연산 횟수가 231-1보다 작거나 같

www.acmicpc.net

 

 

 

 

 

DP 문제

주어진 행렬은 서로 인접한 행렬끼리만 곱할 수 있고 다이나믹 프로그래밍으로 풀 수 있다. 

 

dp[i][j] 를 i 번째 행렬부터 j 번째 행렬까지의 연산 횟수 중 최소의 값을 저장하고 있다고 할 때, 

dp[i][j] = d[i][k] + d[k][j] + i, k, j 간의 곱셈 연산 수 로 점화식을 표현할 수 있다.

 

 

초기화 작업 >>

 

우선 초기화 작업으로 행과 열이 같을 때, 그리고 간격이 1인 행렬에 대해서만 dp 값을 초기화 한다. 

MAX 값은 987654321 로 임의 설정했다. 

 

 

DP 코드 >>

 

i 라는 간격 (= 곱할 행렬 갯수) 을 설정하고 

j 부터 j + i 까지의 행렬 곱셈에서 중간 지점인 k 를 변경하면서 인접한  j ~ j+i 번째 행렬 간의 최소 곱셈 수를 구할 수 있다. 

 

 

전체 행렬의 최소 연산 수는 dp[0][N-1] 에 저장되어 있다.

 

import java.io.BufferedReader;
import java.io.IOException;
import java.io.InputStreamReader;

public class Main {
    static int[][] matrix;
    static int[][] dp;
    static int N;
    
    public static void main(String[] args) throws IOException {
        BufferedReader br = new BufferedReader(new InputStreamReader(System.in));
       
        // dp[i][j] : i 부터 j 까지 곱했을 때 최적의 값
        // dp[i][j] = d[i][k] + d[k+1][j], i~k 행렬과 k+1 ~ j 행렬 곱 수
        
        N = Integer.parseInt(br.readLine());
        matrix = new int[N][2];
        dp = new int[N][N];
        
        for(int i = 0; i<N; i++) {
            String[] input = br.readLine().split(" ");
            matrix[i][0] = Integer.parseInt(input[0]);
            matrix[i][1] = Integer.parseInt(input[1]);
        }
        
        for(int i = 0; i<N; i++) {
            for(int j = 0; j<N; j++) {
                if (i == j) {
                    dp[i][j] = 0;
                } else if (j == i + 1) {
                    dp[i][j] = matrix[i][0] * matrix[i][1] * matrix[j][1];
                } else {
                    dp[i][j] = 987654321;
                }
            }
        }
        
        /*
            i : 간격
            j : 시작 지점
            k : 중간 지점
         */
        for(int i = 2; i<N; i++) {
            for(int j = 0; j<N-i; j++) {
                for(int k = j; k<j+i; k++) {
                    int count = matrix[j][0] * matrix[k][1] * matrix[j+i][1];
                    dp[j][j+i] = Math.min(dp[j][j+i], dp[j][k] + dp[k+1][j+i] + count);
                }
            }
        }
        
        System.out.println(dp[0][N-1]);
    }
}