• 연쇄 행렬 곱셈 (Matrix chain multiplication) :: 마이구미
    알고리즘 2017. 11. 27. 00:14
    반응형

    이 글은 연쇄 행렬 곱셈(Matrix chain multiplication) 알고리즘을 다룬다.

    동적계획법이 기반으로 된 알고리즘이다.

    위키 내용과 관련 알고리즘 문제를 참고했다. (위키는 번역본이 아직 없다)

    참고 링크 - https://en.wikipedia.org/wiki/Matrix_chain_multiplication


    연쇄 행렬 곱셈은 최적화 문제를 동적계획법(DP) 을 이용하여 해결할 수 있다.

    행렬의 순서가 주어질 때, 행렬의 곱셈에서 가장 효율적인 방법을 찾는 것이 목표이다.

    문제는 실제로는 곱셈을 수행하는 것이 아니라 행렬의 곱셈 순서를 결정하는 것이다.


    행렬 곱셈에 있어서, 괄호를 어디에 넣어도 같은 결과를 만든다.

    이것이 의미하는 것은 예를 통해 확인해보자.

    4개의 행렬 A, B, C, D 를 가정한다면 다음과 같다.


    ((AB)C)D = (A(BC))D = (AB)(CD) = A((BC)D) = A(B(CD))


    위와 같이 결과는 같다.

    하지만 각 경우에 따라 연산 횟수는 서로 다르다.

    예를 통해 의미를 파악해보자. (11049번 행렬 곱셈 순서 문제)


    예를 들어, A의 크기가 5×3이고, B의 크기가 3×2, C의 크기가 2×6인 경우에 행렬의 곱 ABC를 구하는 경우를 생각해보자.

    • AB를 먼저 곱하고 C를 곱하는 경우 (AB)C에 필요한 곱셈 연산의 수는 5×3×2 + 5×2×6 = 30 + 60 = 90번이다.
    • BC를 먼저 곱하고 A를 곱하는 경우 A(BC)에 필요한 곱셈 연산의 수는 3×2×6 + 5×3×6 = 36 + 90 = 126번이다.

    같은 곱셈이지만, 곱셈을 하는 순서에 따라서 곱셈 연산의 수가 달라진다.


    위의 경우에는 첫번째 경우가 더 적은 횟수를 통해 효율적인 방법이 된다.

    이렇게 최소의 연산 횟수를 구하기 위해서는 어떻게 해야할까?

    괄호에 대한 모든 경우(브루트 포스) 를 확인하여 해결할 수도 있지만, 행렬이 커질수록 느리고 비효율적이다.

    그래서 나온 것이 동적계획법을 이용한 연쇄 행렬 곱셈 알고리즘이 된다.


    연쇄 행렬 곱셈의 구현의 핵심은 부분수열(subsequence) 을 이용하는 것이다.

    간략히 설명한다면, 다음과 같다.


    1. 전체 행렬에 있어, 2개의 부분수열로 분리한다.
    2. 각 부분수열에 있어, 최소 비용을 구한 후 합쳐준다.
    3. 분리할 수 있을 때까지 부분수열의 길이를 늘려주면서 이 과정을 반복한다.


    연쇄 행렬 곱셈의 점화식은 다음과 같다.


    m[i][j] = 행렬 i번에서 j번까지의 최소 비용, d = 행렬 크기

    => m[i][j] = m[i][k] + m[k + 1][j] + d[i - 1] + d[k] + d[j]


    연쇄 행렬 곱셈의 의사코드는 다음과 같다.


    // Matrix A[i] has dimension dims[i-1] x dims[i] for i = 1..n MatrixChainOrder(int dims[]) { // length[dims] = n + 1 n = dims.length - 1; // m[i,j] = Minimum number of scalar multiplications (i.e., cost) // needed to compute the matrix A[i]A[i+1]...A[j] = A[i..j] // The cost is zero when multiplying one matrix for (i = 1; i <= n; i++) m[i, i] = 0; for (len = 2; len <= n; len++) { // Subsequence lengths for (i = 1; i <= n - len + 1; i++) { j = i + len - 1; m[i, j] = MAXINT; for (k = i; k <= j - 1; k++) { cost = m[i, k] + m[k+1, j] + dims[i-1]*dims[k]*dims[j]; if (cost < m[i, j]) { m[i, j] = cost; s[i, j] = k; // Index of the subsequence split that achieved minimal cost } } } } }


    시간복잡도는 O(n^3) 이 된다.

    실제로 동작하는 것을 보면 쉽게 이해할 수 있다.

    행렬 A, B, C, D 가 주어졌을 때 다음과 같다.


    m[1][2] = 행렬 A~B

    m[2][3] = 행렬 B~C

    m[3][4] = 행렬 C~D


    m[1][3] = 행렬 A~C, m[1][1] + m[2][2] + ....

    m[1][3] = 행렬 A~C, m[1][2] + m[3][3] + ....

    m[2][4] = 행렬 B~D, m[2][1] + m[2][4] + ....

    m[2][4] = 행렬 B~D, m[2][2] + m[3][4] + ....


    m[1][4] = 행렬 A~D, m[1][1] + m[2][4] + ....

    m[1][4] = 행렬 A~D, m[1][2] + m[3][4] + ....

    m[1][4] = 행렬 A~D, m[1][3] + m[4][4] + ....


    쪼개서 동작하는 것을 본다면, 쉽게 흐름을 이해할 수 있다.

    코드가 이해되지 않는다면, 흐름을 이해한 후 코드를 맞춰보면 도움이 될 것이다.


    11049번 "행렬 곱셈 순서" 문제는 이 알고리즘을 위해 만들어졌다고 해도 무방하다.

    그렇기에 이해한 후 풀어보면 좋다.

    응용을 위한 문제로는 11066번 "파일 합치기" 를 풀어보면 된다.


    11049번 문제 링크 - https://www.acmicpc.net/problem/11049

    11066번 문제 링크 - https://www.acmicpc.net/problem/11066

    Github - https://github.com/hotehrud/acmicpc/tree/master/dp


    11049번 "행렬 곱셈 순서" 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    private void solve() {
        int n = sc.nextInt();
        int[][] m = new int[501][501];
        int[] d = new int[1001];
     
        for (int i = 0; i < n; i++) {
            d[i] = sc.nextInt();
            d[i + 1= sc.nextInt();
        }
     
        for (int len = 2; len <= n; len++) {
            for (int i = 1; i <= n - len + 1; i++) {
                int j = i + len - 1;
                m[i][j] = Integer.MAX_VALUE;
                for (int k = i; k < j; k++) {
                    int cost = m[i][k] + m[k + 1][j] + d[i - 1* d[k] * d[j];
                    m[i][j] = Math.min(m[i][j], cost);
                }
            }
        }
        System.out.println(m[1][n]);
    }
    cs


    11066번 "파일 합치기" 코드


    1
    2
    3
    4
    5
    6
    7
    8
    9
    10
    11
    12
    13
    14
    15
    16
    17
    18
    19
    20
    21
    22
    23
    24
    25
    26
    27
    28
    private void solve() {
        int t = sc.nextInt();
        StringBuilder sb = new StringBuilder();
     
        while (t-- > 0) {
            int n = sc.nextInt();
            int[] sum = new int[n + 1];
            int[][] dp = new int[n + 1][n + 1];
     
            for (int i = 1; i <= n; i++) {
                sum[i] = sum[i - 1+ sc.nextInt();
            }
     
            for (int len = 2; len <= n; len++) {
                for (int i = 1; i <= n - len + 1; i++) {
                    int j = i + len - 1;
                    dp[i][j] = Integer.MAX_VALUE;
     
                    for(int k = i; k < j; k++) {
                        int cost = dp[i][k] + dp[k + 1][j] + sum[j] - sum[i - 1];
                        dp[i][j] = Math.min(dp[i][j], cost);
                    }
                }
            }
            sb.append(dp[1][n] + "\n");
        }
        System.out.println(sb.toString());
    }
    cs


    반응형

    댓글 1

Designed by Tistory.