程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> C++入門知識 >> 矩陣快速冪 poj3070 3233 3735 3150

矩陣快速冪 poj3070 3233 3735 3150

編輯:C++入門知識

一、矩陣的基礎知識

1.結合性 (AB)C=A(BC).

2.對加法的分配性 (A+B)C=AC+BC,C(A+B)=CA+CB .

3.對數乘的結合性 k(AB)=(kA)B =A(kB).

4.關於轉置 (AB)'=B'A'.

一個矩陣就是一個二維數組,為了方便聲明多個矩陣,我們一般會將矩陣封裝一個類或定義一個矩陣的結構體,我采用的是後者。

最特殊的矩陣應該就是單位矩陣e了,它的對角線的元素為1,非對角線元素為0。一個n*n的矩陣的0次冪就是單位矩陣。

若A為n×k矩陣,B為k×m矩陣,則它們的乘積AB(有時記做A·B)將是一個n×m矩陣。其乘積矩陣AB的第i行第j列的元素為:

 

一般矩陣乘法采用樸素的O(n^3)的算法,但是對於一些比較稀疏的矩陣(就是矩陣中0比較多),對於這樣的矩陣我們可以采用矩陣的優化,這個算法也適用於一般的矩陣,0特別多時,復雜度可能會降低到O(n^2),實現如下:

 


還要注意的是,我們要盡可能的減少取模運算,因為取模的復雜度很高,這樣我們就可以節約時間了。

矩陣加法就是簡單地將對應的位置的兩個矩陣的元素相加。

我們一般考慮的是n階方陣之間的乘法以及n階方陣與n維向量(把向量看成n×1的矩陣)的乘法。矩陣乘法最重要的性質就是滿足結合律,同時它另一個很重要的性質就是不滿足交換率,這保證了矩陣的冪運算滿足快速冪取模(A^k % MOD)算法,矩陣快速冪其實就是二分指數,避免重復的計算。我們可以采用遞歸的方式很容易的寫出來,但是當指數比較大,或者矩陣比較大得時候,我們就會出現棧溢出的狀況,不斷RE(我就被坑過)。所以還是寫成迭代的方式比較好。

 

制作矩陣圖一般要遵循以下幾個步驟:

1、列出質量因素:

2、把成對對因素排列成行和列,表示其對應關系

3、選擇合適的矩陣圖類型

4、在成對因素交點處表示其關系程度,一般憑經驗進行定性判斷,可分為三種:關系密切、關系較密切、關系一般(或可能有關系),並用不同符號表示

5、根據關系程度確定必須控制的重點因素

6、針對重點因素作對策表。


二、矩陣快速冪的應用

7、poj3070 是求解菲波那切數列,f(n)=f(n-1)+f(n-2),如果我們一個個遞推求解,當n特別大的時候復雜度就會變的很高,對於f(n)= a*f(n-1)+b*f(n-2),在矩陣運算中我們會發現這樣一組公式:


到知道這個公式後我們就采用矩陣快速冪的方法可以求解f(n)

[cpp]
#include <iostream>  
#include <cstdio>  
#include <cstring>  
using namespace std; 
struct mat{ 
    int at[2][2]; 
}; 
mat d; 
int n,mod; 
mat mul(mat a,mat b) 

    mat t; 
    memset(t.at,0,sizeof(t.at)); 
    for(int i=0;i<n;++i) 
    { 
        for(int k=0;k<n;++k) 
        { 
            if(a.at[i][k]) 
            for(int j=0;j<n;++j) 
            { 
                t.at[i][j]+=a.at[i][k]*b.at[k][j]; 
                if(t.at[i][j]>=mod){t.at[i][j]%=mod;} 
            } 
        } 
    } 
    return t; 

mat expo(mat p,int k) 

    if(k==1)return p; 
    mat e; 
    memset(e.at,0,sizeof(e.at)); 
    for(int i=0;i<n;++i){e.at[i][i]=1;} 
    if(k==0)return e; 
    while(k) 
    { 
        if(k&1)e=mul(p,e); 
        p=mul(p,p); 
        k>>=1; 
    } 
    return e; 

int main() 

    n=2;mod=10000; 
    d.at[1][1]=0; 
    d.at[0][0]=d.at[1][0]=d.at[0][1]=1; 
    int k; 
    while(~scanf("%d",&k)) 
    { 
        if(k==-1)break; 
        mat ret=expo(d,k); 
        int ans=ret.at[0][1]%mod; 
        printf("%d\n",ans); 
    } 
    return 0; 

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
struct mat{
 int at[2][2];
};
mat d;
int n,mod;
mat mul(mat a,mat b)
{
 mat t;
 memset(t.at,0,sizeof(t.at));
 for(int i=0;i<n;++i)
 {
  for(int k=0;k<n;++k)
  {
   if(a.at[i][k])
   for(int j=0;j<n;++j)
   {
    t.at[i][j]+=a.at[i][k]*b.at[k][j];
    if(t.at[i][j]>=mod){t.at[i][j]%=mod;}
   }
  }
 }
 return t;
}
mat expo(mat p,int k)
{
 if(k==1)return p;
 mat e;
 memset(e.at,0,sizeof(e.at));
 for(int i=0;i<n;++i){e.at[i][i]=1;}
 if(k==0)return e;
 while(k)
 {
  if(k&1)e=mul(p,e);
  p=mul(p,p);
  k>>=1;
 }
 return e;
}
int main()
{
 n=2;mod=10000;
 d.at[1][1]=0;
 d.at[0][0]=d.at[1][0]=d.at[0][1]=1;
 int k;
 while(~scanf("%d",&k))
 {
  if(k==-1)break;
  mat ret=expo(d,k);
  int ans=ret.at[0][1]%mod;
  printf("%d\n",ans);
 }
 return 0;
}

 

2、poj3233題意:給出矩陣A,求S = A + A^2 + A^3 + … + A^k 二分和

 

[cpp]
#include <iostream>  
#include <cstdio>  
#include <cstring>  
using namespace std; 
#define LL long long  
int n,m,k;  
int MOD; 
struct mat { 
    int at[40][40]; 
}; 
mat d; 
mat mul(mat a, mat b)  

    mat ret; 
    memset(ret.at,0,sizeof(ret.at)); 
    for (int i=0;i<n;++i) 
    { 
        for (int k=0;k<n;++k)  
        { 
            if(a.at[i][k]) 
            for (int j=0;j<n;++j) 
            { 
                ret.at[i][j]+=a.at[i][k]*b.at[k][j]; 
                if(ret.at[i][j]>=MOD){ret.at[i][j]%=MOD;} 
            } 
        } 
    } 
    return ret; 

 
mat expo(mat a, int k)  

    if(k==1)return a; 
    mat e; 
    memset(e.at,0,sizeof(e.at)); 
    for(int i=0;i<n;++i){e.at[i][i]=1;} 
    if(k==0)return e; 
    while(k) 
    { 
        if(k&1)e=mul(a,e); 
        a=mul(a,a); 
        k>>=1; 
    } 
    return e; 

 
mat add(mat a,mat b) 

    mat t; 
    for(int i=0;i<n;++i) 
    { 
        for(int j=0;j<n;++j) 
        {  
            t.at[i][j]=(a.at[i][j]+b.at[i][j]); 
            if(t.at[i][j]>=MOD){t.at[i][j]%=MOD;} 
        } 
    } 
    return t; 

 
void print(mat ans) 

    for(int i=0;i<n;++i) 
    { 
        for(int j=0;j<n;++j) 
        { 
            if(j==0){printf("%d",ans.at[i][j]);continue;} 
            printf(" %d",ans.at[i][j]); 
        } 
        printf("\n"); 
    } 

 
mat sum(int k) 

    if(k==1){return d;} 
    if(k&1) 
    { 
        return add(sum(k-1),expo(d,k)); 
    } 
    else 
    { 
        mat s=sum(k>>1); 
        return add(s,mul(s,expo(d,k>>1))); 
    } 

int main() 

    while(~scanf("%d%d%d",&n,&k,&m)) 
    { 
        MOD=m; 
        mat ans,t; 
        for(int i=0;i<n;++i) 
        { 
            for(int j=0;j<n;++j) 
            { 
                scanf("%d",&d.at[i][j]); 
                if(d.at[i][j]>=m) 
                { 
                    d.at[i][j]%=m; 
                } 
            } 
        } 
        ans=sum(k); 
        print(ans); 
    } 
    return 0;  

#include <iostream>
#include <cstdio>
#include <cstring>
using namespace std;
#define LL long long
int n,m,k;
int MOD;
struct mat {
 int at[40][40];
};
mat d;
mat mul(mat a, mat b)
{
 mat ret;
 memset(ret.at,0,sizeof(ret.at));
 for (int i=0;i<n;++i)
 {
  for (int k=0;k<n;++k)
  {
   if(a.at[i][k])
   for (int j=0;j<n;++j)
   {
    ret.at[i][j]+=a.at[i][k]*b.at[k][j];
    if(ret.at[i][j]>=MOD){ret.at[i][j]%=MOD;}
   }
  }
 }
 return ret;
}

mat expo(mat a, int k)
{
 if(k==1)return a;
 mat e;
 memset(e.at,0,sizeof(e.at));
 for(int i=0;i<n;++i){e.at[i][i]=1;}
 if(k==0)return e;
 while(k)
 {
  if(k&1)e=mul(a,e);
  a=mul(a,a);
  k>>=1;
 }
 return e;
}

mat add(mat a,mat b)
{
 mat t;
 for(int i=0;i<n;++i)
 {
  for(int j=0;j<n;++j)
  {
   t.at[i][j]=(a.at[i][j]+b.at[i][j]);
   if(t.at[i][j]>=MOD){t.at[i][j]%=MOD;}
  }
 }
 return t;
}

void print(mat ans)
{
 for(int i=0;i<n;++i)
 {
  for(int j=0;j<n;++j)
  {
   if(j==0){printf("%d",ans.at[i][j]);continue;}
   printf(" %d",ans.at[i][j]);
  }
  printf("\n");
 }
}

mat sum(int k)
{
 if(k==1){return d;}
 if(k&1)
 {
  return add(sum(k-1),expo(d,k));
 }
 else
 {
  mat s=sum(k>>1);
  return add(s,mul(s,expo(d,k>>1)));
 }
}
int main()
{
 while(~scanf("%d%d%d",&n,&k,&m))
 {
  MOD=m;
  mat ans,t;
  for(int i=0;i<n;++i)
  {
   for(int j=0;j<n;++j)
   {
    scanf("%d",&d.at[i][j]);
    if(d.at[i][j]>=m)
    {
     d.at[i][j]%=m;
    }
   }
  }
  ans=sum(k);
  print(ans);
 }
 return 0;
}

 


3、poj3735

   題意:有n只貓咪,開始時每只貓咪有花生0顆,現有一組操作,由下面三個中的k個操作組成:

   1. g i 給i只貓咪一顆花生米

   2. e i 讓第i只貓咪吃掉它擁有的所有花生米

   3. s i j 將貓咪i與貓咪j的擁有的花生米交換

   現將上述一組操作做m次後,問每只貓咪有多少顆花生?

分析:剛開始每只貓都沒有花生,所以我們要在單位矩陣上構建矩陣。給第i只貓一個花生米,那麼++met[0][i],讓第i只貓吃掉所有的花生米,就令第i列清空,喵咪i與貓咪j交換花生米,就令第i列和第j列互換。矩陣就這樣構造完畢,操作m次,我們就可以矩陣快速冪計算了。

[cpp]
#include <iostream>  
#include <cstring>  
#include <cstdio>  
#define LL long long  
using namespace std; 
struct met{ 
    LL at[105][105]; 
}; 
met ret,d; 
LL n,m,k; 
met mul(met a,met b) 

    memset(ret.at,0,sizeof(ret.at)); 
    for(int i=0;i<=n;++i) 
    { 
        for(int k=0;k<=n;++k) 
        { 
            if(a.at[i][k]) 
            { 
                for(int j=0;j<=n;++j) 
                { 
                    ret.at[i][j]+=a.at[i][k]*b.at[k][j]; 
                } 
            } 
        } 
    } 
    return ret; 

 
met expo(met a,LL k) 

    if(k==1) return a; 
    met e; 
    memset(e.at,0,sizeof(e.at)); 
    for(int i=0;i<=n;++i){e.at[i][i]=1;} 
    if(k==0)return e; 
    while(k) 
    { 
        if(k&1)e=mul(e,a); 
        k>>=1; 
        a=mul(a,a); 
    } 
    return e; 

 
 
int main() 

    while(~scanf("%lld%lld%lld",&n,&m,&k)) 
    { 
        LL a,b; 
        char ch[5]; 
        if(!n&&!k&&!m)break; 
        memset(d.at,0,sizeof(d.at)); 
        for(int i=0;i<=n;++i) 
        {d.at[i][i]=1;} 
        while(k--) 
        { 
            scanf("%s",ch); 
            if(ch[0]=='g') 
            { 
                scanf("%lld",&a); 
                d.at[0][a]++;        
            } 
            else if(ch[0]=='e') 
            { 
                scanf("%lld",&a); 
                for(int i=0;i<=n;++i) 
                { 
                    d.at[i][a]=0;    
                } 
            } 
            else { 
                scanf("%lld%lld",&a,&b); 
                for(int i=0;i<=n;++i) 
                { 
                    LL t=d.at[i][a]; 
                    d.at[i][a]=d.at[i][b]; 
                    d.at[i][b]=t; 
                } 
 
            } 
        } 
        met ans=expo(d,m); 
        printf("%lld",ans.at[0][1]); 
        for(int i=2;i<=n;++i) 
        { 
            printf(" %lld",ans.at[0][i]); 
        } 
        printf("\n"); 
 
    } 
    return 0;  

#include <iostream>
#include <cstring>
#include <cstdio>
#define LL long long
using namespace std;
struct met{
 LL at[105][105];
};
met ret,d;
LL n,m,k;
met mul(met a,met b)
{
 memset(ret.at,0,sizeof(ret.at));
 for(int i=0;i<=n;++i)
 {
  for(int k=0;k<=n;++k)
  {
   if(a.at[i][k])
   {
    for(int j=0;j<=n;++j)
    {
     ret.at[i][j]+=a.at[i][k]*b.at[k][j];
    }
   }
  }
 }
 return ret;
}

met expo(met a,LL k)
{
 if(k==1) return a;
 met e;
 memset(e.at,0,sizeof(e.at));
 for(int i=0;i<=n;++i){e.at[i][i]=1;}
 if(k==0)return e;
 while(k)
 {
  if(k&1)e=mul(e,a);
  k>>=1;
  a=mul(a,a);
 }
 return e;
}


int main()
{
 while(~scanf("%lld%lld%lld",&n,&m,&k))
 {
  LL a,b;
  char ch[5];
  if(!n&&!k&&!m)break;
  memset(d.at,0,sizeof(d.at));
  for(int i=0;i<=n;++i)
  {d.at[i][i]=1;}
  while(k--)
  {
   scanf("%s",ch);
   if(ch[0]=='g')
   {
    scanf("%lld",&a);
    d.at[0][a]++;  
   }
   else if(ch[0]=='e')
   {
    scanf("%lld",&a);
    for(int i=0;i<=n;++i)
    {
     d.at[i][a]=0; 
    }
   }
   else {
    scanf("%lld%lld",&a,&b);
    for(int i=0;i<=n;++i)
    {
     LL t=d.at[i][a];
     d.at[i][a]=d.at[i][b];
     d.at[i][b]=t;
    }

   }
  }
  met ans=expo(d,m);
  printf("%lld",ans.at[0][1]);
  for(int i=2;i<=n;++i)
  {
   printf(" %lld",ans.at[0][i]);
  }
  printf("\n");

 }
 return 0;
}

 

4、poj3150題目大意:給定n(1<=n<=500)個數字和一個數字m,這n個數字組成一個環(a0,a1.....an-1)。如果對ai進行一次d-step操作,那麼ai的值變為與ai的距離小於d的所有數字之和模m。求對此環進行K次d-step(K<=10000000)後這個環的數字會變為多少。

分析:首先我們要構造矩陣,我們會得到一個500*500的矩陣,那麼代碼的復雜度就會變成O(log(k)*n^3),很明顯這麼高的復雜度會超時的。但是我們發現這個矩陣是一個循環矩陣, 第i行都是第i-1行,右移一位得到的,即a[i][j]=a[i-1][j-1]。很容易我們就可以發現循環矩陣a和循環矩陣b的乘積矩陣c,c[i][j]=sum(a[i][k]*b[k][j])=sum(a[i-1][k-1]*b[j-1][k-1])=c[i-1][j-1]。那麼矩陣c也是一個循環矩陣,在做矩陣乘法的時候我們只需要算出第一行的值,其余行直接右移就可以得到,那麼算法的復雜度就會變為O(log(k)*n^2)。還需注意的是對於數據范圍會超int,要用long long,還有由於矩陣太大了,在函數中申請不了那麼大得空間,所以采用指針的方法去寫函數。

[cpp]
#include <iostream>  
#include <cstdio>  
#include <cstring>  
#define LL long long  
using namespace std; 
const int maxn=502; 
int n,m,d,k; 
LL tmp[maxn][maxn],e[maxn][maxn],c[maxn][maxn]; 
void mul(LL a[][maxn],LL b[][maxn]) 

    memset(c,0,sizeof(c)); 
    for(int k=0;k<n;++k) 
    { 
        if(a[0][k]) 
        for(int j=0;j<n;++j) 
        { 
            c[0][j]+=a[0][k]*b[k][j]; 
            if(c[0][j]>=m){c[0][j]%=m;} 
        } 
    } 
    for(int i=1;i<n;++i) 
    { 
        for(int j=0;j<n;++j)  
        { 
            c[i][j]=c[i-1][(j-1+n)%n]; 
        } 
    } 
    for(int i=0;i<n;++i) 
    { 
        for(int j=0;j<n;++j) 
        { 
            b[i][j]=c[i][j]; 
        } 
    } 

 
void expo(LL a[][maxn],int k) 

    if(k==1){ 
        for(int i=0;i<n;++i) 
        { 
            for(int j=0;j<n;++j) 
            { 
                e[i][j]=a[i][j]; 
            } 
        } 
        return; 
    } 
    memset(e,0,sizeof(e)); 
    for(int i=0;i<n;++i){e[i][i]=1;} 
    while(k) 
    { 
        if(k&1){mul(a,e);} 
        mul(a,a); 
        k>>=1; 
    } 

int main() 

    LL dat[maxn]; 
    scanf("%d%d%d%d",&n,&m,&d,&k); 
    for(int i=0;i<n;++i) 
    { 
        scanf("%lld",&dat[i]); 
        tmp[0][i]=0; 
    } 
    tmp[0][0]=1; 
    for(int i=1;i<=d;++i) 
    { 
        tmp[0][i]=tmp[0][n-i]=1; 
    } 
    for(int i=1;i<n;++i) 
    { 
        for(int j=0;j<n;++j) 
        { 
            tmp[i][j]=tmp[i-1][(j-1+n)%n]; 
        } 
    } 
    expo(tmp,k); 
    LL ans[maxn]; 
    memset(ans,0,sizeof(ans)); 
    for(int i=0;i<n;++i) 
    { 
        for(int j=0;j<n;++j) 
        { 
            ans[i]+=e[i][j]*dat[j]; 
            if(ans[i]>=m){ans[i]%=m;} 
        } 
    } 
    printf("%lld",ans[0]); 
    for(int i=1;i<n;++i) 
    { 
        printf(" %lld",ans[i]); 
    } 
    printf("\n"); 
    return 0; 

#include <iostream>
#include <cstdio>
#include <cstring>
#define LL long long
using namespace std;
const int maxn=502;
int n,m,d,k;
LL tmp[maxn][maxn],e[maxn][maxn],c[maxn][maxn];
void mul(LL a[][maxn],LL b[][maxn])
{
 memset(c,0,sizeof(c));
 for(int k=0;k<n;++k)
 {
  if(a[0][k])
  for(int j=0;j<n;++j)
  {
   c[0][j]+=a[0][k]*b[k][j];
   if(c[0][j]>=m){c[0][j]%=m;}
  }
 }
 for(int i=1;i<n;++i)
 {
  for(int j=0;j<n;++j)
  {
   c[i][j]=c[i-1][(j-1+n)%n];
  }
 }
 for(int i=0;i<n;++i)
 {
  for(int j=0;j<n;++j)
  {
   b[i][j]=c[i][j];
  }
 }
}

void expo(LL a[][maxn],int k)
{
 if(k==1){
  for(int i=0;i<n;++i)
  {
   for(int j=0;j<n;++j)
   {
    e[i][j]=a[i][j];
   }
  }
  return;
 }
 memset(e,0,sizeof(e));
 for(int i=0;i<n;++i){e[i][i]=1;}
 while(k)
 {
  if(k&1){mul(a,e);}
  mul(a,a);
  k>>=1;
 }
}
int main()
{
 LL dat[maxn];
 scanf("%d%d%d%d",&n,&m,&d,&k);
 for(int i=0;i<n;++i)
 {
  scanf("%lld",&dat[i]);
  tmp[0][i]=0;
 }
 tmp[0][0]=1;
 for(int i=1;i<=d;++i)
 {
  tmp[0][i]=tmp[0][n-i]=1;
 }
 for(int i=1;i<n;++i)
 {
  for(int j=0;j<n;++j)
  {
   tmp[i][j]=tmp[i-1][(j-1+n)%n];
  }
 }
 expo(tmp,k);
 LL ans[maxn];
 memset(ans,0,sizeof(ans));
 for(int i=0;i<n;++i)
 {
  for(int j=0;j<n;++j)
  {
   ans[i]+=e[i][j]*dat[j];
   if(ans[i]>=m){ans[i]%=m;}
  }
 }
 printf("%lld",ans[0]);
 for(int i=1;i<n;++i)
 {
  printf(" %lld",ans[i]);
 }
 printf("\n");
 return 0;
}

 

對於這道題,網上還有一段神代碼,在這裡同樣學習一下

[cpp]
#include <iostream>  
#include <cstdio>  
#include <cstring>  
#define LL long long  
using namespace std; 
int n,m,d,k; 
void mul(LL a[],LL b[]) 

      int i,j; 
      LL c[501]; 
      for(i=0;i<n;++i)for(c[i]=j=0;j<n;++j)c[i]+=a[j]*b[i>=j?(i-j):(n+i-j)]; 
      for(i=0;i<n;b[i]=c[i++]%m);                      

LL init[501],tmp[501]; 
int main() 

    int i,j; 
    scanf("%d%d%d%d",&n,&m,&d,&k); 
    for(i=0;i<n;++i)scanf("%lld",&init[i]); 
    for(tmp[0]=i=1;i<=d;++i)tmp[i]=tmp[n-i]=1; 
    while(k) 
    { 
            if(k&1)mul(tmp,init); 
            mul(tmp,tmp); 
            k>>=1;      
    } 
    for(i=0;i<n;++i)if(i)printf(" %lld",init[i]);else printf("%lld",init[i]); 
    printf("\n"); 
    return 0; 

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved