COME ON CODE ON

A blog about programming and more programming.

Recurrence Relation and Matrix Exponentiation

Recurrence relations appear often in computer science. With dynamic programming we can compute the nth term of a recurrence in O(n) time, but sometimes we need it in O(log n) time. This is where matrix exponentiation comes to the rescue.

We will look specifically at linear recurrences. A linear recurrence is a sequence of vectors defined by the equation X_{i+1} = M X_i for some constant matrix M. So our aim is to find this constant matrix M for a given recurrence relation.

Let’s start by looking at the common structure of our three matrices X_{i+1}, X_i and M. For a recurrence relation where the next term depends on the last k terms, X_{i+1} and X_i are column vectors of size k x 1 and M is a matrix of size k x k.

| f(n+1)   |       | f(n)   |
|  f(n)    |       | f(n-1) |
| f(n-1)   | = M x | f(n-2) |
| ......   |       | ...... |
| f(n-k+1) |       | f(n-k) |

Let’s look at different types of recurrence relations and how to find M for each.

1. Let’s start with the most common recurrence relation in computer science, the Fibonacci sequence: f(n+1) = f(n) + f(n-1).

| f(n+1) | = M x | f(n)   |
| f(n)   |       | f(n-1) |

Now we know that M is a 2 x 2 matrix. Let it be

| f(n+1) | = | a b | x | f(n)   |
| f(n)   |   | c d |   | f(n-1) |

Now a*f(n) + b*f(n-1) = f(n+1) and c*f(n) + d*f(n-1) = f(n). These must hold for every n, so comparing coefficients with f(n+1) = f(n) + f(n-1) gives a=1, b=1, c=1 and d=0. So,

| f(n+1) | = | 1 1 | x | f(n)   |
| f(n)   |   | 1 0 |   | f(n-1) |
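
As a quick sanity check, the matrix above can be applied one step at a time. This is only a sketch: `fibStep` and `fibBySteps` are hypothetical helper names, and it assumes the starting values F(0) = 0 and F(1) = 1.

```cpp
#include <array>
#include <cstdint>

// State vector (f(n), f(n-1)); the names here are illustrative only.
using State = std::array<std::uint64_t, 2>;

// One application of M = | 1 1 | to the state (f(n), f(n-1)).
//                        | 1 0 |
State fibStep(const State &x)
{
    return { x[0] + x[1],   // row (1 1): f(n+1) = f(n) + f(n-1)
             x[0] };        // row (1 0): f(n) moves down one slot
}

// Applies M n times to (F(1), F(0)) = (1, 0). This is still O(n);
// it only checks that M encodes the recurrence correctly.
std::uint64_t fibBySteps(int n)
{
    State x = { 1, 0 };
    for (int i = 0; i < n; ++i)
        x = fibStep(x);
    return x[1];            // after n steps the state is (F(n+1), F(n))
}
```

Under those assumptions, `fibBySteps(10)` walks the state ten times and returns the second component of the final vector.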

2. For the recurrence relation f(n) = a*f(n-1) + b*f(n-2) + c*f(n-3), we get

| f(n+1) | = | a b c | x | f(n)   |
| f(n)   |   | 1 0 0 |   | f(n-1) |
| f(n-1) |   | 0 1 0 |   | f(n-2) |
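
The same pattern extends to any number of terms: the first row holds the coefficients and every other row just shifts a term down. A minimal sketch of building such a matrix, where `companionMatrix` and the `Mat` alias are hypothetical names introduced for illustration:

```cpp
#include <cstddef>
#include <vector>

using Mat = std::vector<std::vector<long long>>;

// Builds the k x k matrix for
// f(n) = c[0]*f(n-1) + c[1]*f(n-2) + ... + c[k-1]*f(n-k).
Mat companionMatrix(const std::vector<long long> &c)
{
    const std::size_t k = c.size();
    Mat M(k, std::vector<long long>(k, 0));
    for (std::size_t j = 0; j < k; ++j)
        M[0][j] = c[j];      // first row holds the coefficients
    for (std::size_t i = 1; i < k; ++i)
        M[i][i - 1] = 1;     // each remaining row copies the term above it
    return M;
}
```

Calling `companionMatrix({a, b, c})` with three coefficients should reproduce the 3 x 3 matrix shown above.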

3. What if the recurrence relation is f(n) = a*f(n-1) + b*f(n-2) + c, where c is a constant? We can add the constant to the state vector as an extra entry.

| f(n+1) | = | a b 1 | x | f(n)   |
| f(n)   |   | 1 0 0 |   | f(n-1) |
| c      |   | 0 0 1 |   | c      |
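
To see that the extra row really carries the constant along, here is a small sketch of one step of this matrix written out by hand. `stepWithConstant` and `State3` are hypothetical names, and the values used to try it out are arbitrary.

```cpp
#include <array>

// State (f(n), f(n-1), c) for f(n) = a*f(n-1) + b*f(n-2) + c.
using State3 = std::array<long long, 3>;

// One multiplication by the 3 x 3 matrix above, row by row.
State3 stepWithConstant(const State3 &x, long long a, long long b)
{
    return { a * x[0] + b * x[1] + x[2],  // row (a b 1): f(n+1)
             x[0],                        // row (1 0 0): f(n)
             x[2] };                      // row (0 0 1): c is carried unchanged
}
```

For example, with a = 2, b = 3, c = 4, f(0) = 0 and f(1) = 1, stepping the state (1, 0, 4) once gives (6, 1, 4), which matches f(2) = 2*1 + 3*0 + 4.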

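4. If a recurrence relation is given piecewise, such as f(n) = f(n-1) if n is odd, f(n-2) otherwise, we can convert it to f(n) = (n&1) * f(n-1) + (!(n&1)) * f(n-2) and substitute the values into the matrix accordingly. The matrix then depends on the parity of n, so there is one matrix for odd steps and one for even steps.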
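
A rough sketch of what "substituting the value" could look like in code, assuming the state vector is (f(n-1), f(n-2)). `stepMatrix` and `Mat2` are hypothetical names. Because the matrix is no longer constant, one option (an assumption, not spelled out in the text) is to multiply an odd-step and an even-step matrix together and exponentiate that product.

```cpp
#include <array>

using Mat2 = std::array<std::array<int, 2>, 2>;

// Matrix that maps (f(n-1), f(n-2)) to (f(n), f(n-1)) for the rule
// f(n) = f(n-1) if n is odd, f(n-2) otherwise.
Mat2 stepMatrix(int n)
{
    const int odd = n & 1;
    Mat2 M{};              // value-initialized to all zeros
    M[0][0] = odd;         // coefficient of f(n-1): (n&1)
    M[0][1] = 1 - odd;     // coefficient of f(n-2): !(n&1)
    M[1][0] = 1;           // second row carries f(n-1) down
    return M;
}
```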


5. If there is more than one recurrence relation, say g(n) = a*g(n-1) + b*g(n-2) + c*f(n) and f(n) = d*f(n-1) + e*f(n-2), we can still define the matrix M in the following way:

| g(n+1) |   | a b c 0 |   | g(n)   |
| g(n)   | = | 1 0 0 0 | x | g(n-1) |
| f(n+2) |   | 0 0 d e |   | f(n+1) |
| f(n+1) |   | 0 0 1 0 |   | f(n)   |
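
The rows of this 4 x 4 matrix can be checked by writing one step out explicitly. The sketch below is illustrative; `coupledStep` and `State4` are hypothetical names, and the coefficients are passed in as plain parameters.

```cpp
#include <array>

// State (g(n), g(n-1), f(n+1), f(n)) for the two coupled recurrences.
using State4 = std::array<long long, 4>;

// One multiplication by the 4 x 4 matrix above, row by row.
State4 coupledStep(const State4 &x, long long a, long long b,
                   long long c, long long d, long long e)
{
    return { a * x[0] + b * x[1] + c * x[2],  // row (a b c 0): g(n+1)
             x[0],                            // row (1 0 0 0): g(n)
             d * x[2] + e * x[3],             // row (0 0 d e): f(n+2)
             x[2] };                          // row (0 0 1 0): f(n+1)
}
```

With all five coefficients set to 1, f(0) = g(0) = 0 and f(1) = g(1) = 1, the state (g(1), g(0), f(2), f(1)) = (1, 0, 1, 1) steps to (2, 1, 2, 1), which agrees with g(2) = g(1) + g(0) + f(2) = 2.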

Now that we have our matrix M, how do we find the nth term?
X_{i+1} = M X_i
(Multiplying both sides by M)
M X_{i+1} = M^2 X_i
X_{i+2} = M^2 X_i
...
X_{i+k} = M^k X_i

So all we need is the matrix M^k, which takes us k steps forward from X_i to X_{i+k}. For example, in the case of the Fibonacci sequence,

M^2 = | 2 1 |
      | 1 1 |

Applying M^2 to X_0 = (F(1), F(0)) = (1, 0) gives X_2 = (F(3), F(2)) = (2, 1). Hence F(3) = 2 and F(2) = 1.
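
The same pattern holds for every power of M. A short induction sketch, assuming the usual indexing F_0 = 0, F_1 = 1 (the indexing is an assumption made here):

```latex
M^n = \begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}^{n}
    = \begin{pmatrix} F_{n+1} & F_n \\ F_n & F_{n-1} \end{pmatrix},
    \qquad n \ge 1,
\quad\text{since}\quad
\begin{pmatrix} 1 & 1 \\ 1 & 0 \end{pmatrix}
\begin{pmatrix} F_{n+1} & F_n \\ F_n & F_{n-1} \end{pmatrix}
= \begin{pmatrix} F_{n+1} + F_n & F_n + F_{n-1} \\ F_{n+1} & F_n \end{pmatrix}
= \begin{pmatrix} F_{n+2} & F_{n+1} \\ F_{n+1} & F_n \end{pmatrix}.
```

So under that indexing, F_n can be read off the off-diagonal entry of M^n.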

Now we need to learn how to find M^b in O(n^3 log b) time, where n is the dimension of M. The brute-force approach to calculating a^b takes O(b) multiplications, but a recursive divide-and-conquer algorithm takes only O(log b):

  • If b = 0, then the answer is 1.
  • If b = 2k is even, then a^b = (a^k)^2.
  • If b is odd, then a^b = a * a^(b-1).
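
For plain integers, the three cases above translate almost line for line into code. This is a minimal sketch; `fastPow` is a hypothetical name, and it makes no attempt to handle overflow.

```cpp
#include <cstdint>

// Recursive a^b following the three cases; overflow is not handled.
std::uint64_t fastPow(std::uint64_t a, unsigned b)
{
    if (b == 0)
        return 1;                                   // a^0 = 1
    if (b % 2 == 0) {
        const std::uint64_t h = fastPow(a, b / 2);  // a^k with b = 2k
        return h * h;                               // (a^k)^2
    }
    return a * fastPow(a, b - 1);                   // a * a^(b-1)
}
```

Each odd step is followed by an even one, so the exponent is at least halved every two calls.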

We take a similar approach for matrix exponentiation. Each multiplication takes O(n^3) time, and hence the overall complexity is O(n^3 log b). Here’s sample code in C++ using a template class:

#include <cstdio>   // scanf, printf
using namespace std;

template< class T >
class Matrix
{
    public:
        int m, n;   // number of rows and columns
        T *data;    // m*n elements stored in row-major order

        Matrix( int m, int n );
        Matrix( const Matrix< T > &matrix );

        const Matrix< T > &operator=( const Matrix< T > &A );
        const Matrix< T > operator*( const Matrix< T > &A );
        const Matrix< T > operator^( int P );

        ~Matrix();
};

template< class T >
Matrix< T >::Matrix( int m, int n )
{
    this->m = m;
    this->n = n;
    data = new T[m*n];
}

template< class T >
Matrix< T >::Matrix( const Matrix< T > &A )
{
    this->m = A.m;
    this->n = A.n;
    data = new T[m*n];
    for( int i = 0; i < m * n; i++ )
		data[i] = A.data[i];
}

template< class T >
Matrix< T >::~Matrix()
{
    delete [] data;
}

template< class T >
const Matrix< T > &Matrix< T >::operator=( const Matrix< T > &A )
{
    if( &A != this )
    {
        delete [] data;
        m = A.m;
        n = A.n;
        data = new T[m*n];
        for( int i = 0; i < m * n; i++ )
			data[i] = A.data[i];
    }
    return *this;
}

template< class T >
const Matrix< T > Matrix< T >::operator*( const Matrix< T > &A )
{
    // Standard triple-loop product; assumes n == A.m.
    Matrix C( m, A.n );
    for( int i = 0; i < m; ++i )
        for( int j = 0; j < A.n; ++j )
        {
            C.data[i*C.n+j] = 0;
            for( int k = 0; k < n; ++k )
                C.data[i*C.n+j] += data[i*n+k] * A.data[k*A.n+j];
        }
    return C;
}

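template< class T >
const Matrix< T > Matrix< T >::operator^( int P )
{
    if( P == 0 )    // M^0 is the identity (assumes a square matrix, P >= 0)
    {
        Matrix I( m, n );
        for( int i = 0; i < m * n; ++i )
            I.data[i] = ( i / n == i % n ) ? 1 : 0;
        return I;
    }
    if( P == 1 ) return (*this);
    if( P & 1 ) return (*this) * ((*this) ^ (P-1));
    Matrix B = (*this) ^ (P/2);
    return B*B;
}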

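int main()
{
    // M = | 1 1 |
    //     | 1 0 |
    Matrix<long long> M(2,2);
    M.data[0] = 1; M.data[1] = 1;
    M.data[2] = 1; M.data[3] = 0;

    long long F[2] = {0,1};   // F(0) and F(1)
    int N;
    // M^N = | F(N+1) F(N)   |
    //       | F(N)   F(N-1) |
    // so F(N) is data[1]. F(N+1) overflows long long for N > 91.
    while( scanf("%d",&N) == 1 )
        if( N > 1 )
            printf("%lld\n",(M^N).data[1]);
        else if( N >= 0 )
            printf("%lld\n",F[N]);
}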
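
Fibonacci numbers grow quickly, so in practice the answer is often wanted modulo some number. The sketch below is one possible variation rather than part of the code above: it uses a fixed 2 x 2 `std::array` matrix and an iterative square-and-multiply loop. The modulus `MOD` and the names `mulMod` and `fibMod` are assumptions made for illustration.

```cpp
#include <array>
#include <cstdint>

using u64 = std::uint64_t;
using Mat2 = std::array<std::array<u64, 2>, 2>;

const u64 MOD = 1000000007ULL;   // assumed modulus, not from the post

// 2 x 2 product with every entry reduced modulo MOD.
// Entries stay below MOD, so each product fits in 64 bits.
Mat2 mulMod(const Mat2 &A, const Mat2 &B)
{
    Mat2 C{};                    // value-initialized to zeros
    for (int i = 0; i < 2; ++i)
        for (int j = 0; j < 2; ++j)
            for (int k = 0; k < 2; ++k)
                C[i][j] = (C[i][j] + A[i][k] * B[k][j]) % MOD;
    return C;
}

// F(n) mod MOD with F(0) = 0, F(1) = 1, using O(log n) multiplications.
u64 fibMod(u64 n)
{
    Mat2 result{};               // becomes the identity matrix
    result[0][0] = result[1][1] = 1;
    Mat2 base{};                 // becomes M = | 1 1 ; 1 0 |
    base[0][0] = base[0][1] = base[1][0] = 1;
    while (n > 0) {
        if (n & 1)
            result = mulMod(result, base);  // use this bit of the exponent
        base = mulMod(base, base);          // square for the next bit
        n >>= 1;
    }
    return result[0][1];         // off-diagonal entry of M^n is F(n)
}
```

The iterative loop avoids recursion and temporary heap allocations, at the cost of being tied to the 2 x 2 case.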

Written by fR0DDY

May 8, 2011 at 5:26 PM
