APTX Blog

A Moe Blog Set Up By APTX

POJ 3233 Matrix Power Series(矩阵快速幂+二分)题解

洛谷:https://www.luogu.org/problemnew/show/U50124

Description

给定一个 n*n 的矩阵 A 以及一个正整数 k,计算\(S = A^1 + A^2 + A^3+…+A^k\)

Input

输入只有一组测试数据。输入的第一行包括三个正整数 nkm。接下来的 n 行每行包括 n 个非负整数,按照行优先的顺序输入矩阵 A 的元素。

Output

输出 S 中每一个元素 mod(%)m 以后的值

Sample Input

2 2 4
0 1
1 1

Sample Output

1 2
2 3

解析及代码

暴力:

我们考虑暴力解法,根据题意模拟,每次令i从1枚举到k,矩阵快速幂求解\(power(A,i)\),进行矩阵加法将\(power(A,i)\)累加起来。我觉得复杂度是\(O(k(n^3logn + n^2))\)的,反正肯定过不了。

//matrix.cpp
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;

const int MAXN = 35;

struct Mat {
	int a[MAXN][MAXN];
};

int n,k,MOD;
Mat x;
Mat Ans;

inline Mat jz(Mat a,Mat b){
    Mat tmp;
    memset(tmp.a,0,sizeof tmp.a);
    for(int i = 1;i <= n;++i)
        for(int j = 1;j <= n;++j)
            for(int k = 1;k <= n;++k)
               tmp.a[i][j] = (tmp.a[i][j] + (a.a[i][k] * b.a[k][j]) % MOD) % MOD;
    return tmp;
}
inline Mat power(Mat a,long long b){
	Mat ans;
	memset(ans.a,0,sizeof ans.a);
    for(int i = 1;i <= n;++i)
            ans.a[i][i]=1;
    while(b){
        if (b & 1) ans = jz(ans,a);
        a = jz(a,a);
        b >>= 1;
    }
    return ans;
} 

inline void Add(Mat T) {
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= n;++j)
			Ans.a[i][j] += T.a[i][j],Ans.a[i][j] %= MOD;
}

int main() {
	freopen("matrix.in","r",stdin);
	freopen("matrix.out","w",stdout);
	scanf("%d%d%d",&n,&k,&MOD);
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= n;++j)
			scanf("%d",&x.a[i][j]);
	for(int i = 1;i <= k;++i)
		Add(power(x,i));
	for(int i = 1;i <= n;++i) {
		for(int j = 1;j <= n;++j)
			printf("%d ",Ans.a[i][j] % MOD);
		puts("");
	}
	return 0;
}

正解:

考虑分类讨论,将k分为奇数和偶数进行分类讨论,对于k为偶数的情况,我们可以推出如下规律(例如k为6)

\(S(6)=A^1+A^2+A^3+A^4+A^5+A^6\)

\(=A^1+A^2+A^3+A^3*(A+A^2+A^3)\)

\(=S(3)*(1 + A^3)\)

所以当k为偶数时,就有:\(S(k)=S(k/2)*(1+A^\frac{k}{2})\)

那么当k为奇数时,显然有:\(S(k)=S(k-1)+A^k\)

根据以上规律,二分进行求解 复杂度不知道多少,反正是:\(O(AC)\)

//matrix.cpp
#include <cstdio>
#include <cmath>
#include <cstring>
#include <algorithm>
#include <iostream>
#include <vector>
#include <queue>
using namespace std;

const int MAXN = 35;

struct Mat {
	int a[MAXN][MAXN];
	void clear() {
		memset(a,0,sizeof a);
	}
};

int n,k,MOD;
Mat x;
Mat Ans;

inline Mat Mul(Mat a,Mat b){
    Mat tmp;
    tmp.clear();
    for(int i = 1;i <= n;++i)
        for(int j = 1;j <= n;++j)
            for(int k = 1;k <= n;++k)
               tmp.a[i][j] = (tmp.a[i][j] + (a.a[i][k] * b.a[k][j]) % MOD) % MOD;
    return tmp;
}
inline Mat power(Mat a,long long b){
	Mat ans;
	ans.clear();
    for(int i = 1;i <= n;++i)
            ans.a[i][i]=1;
    while(b){
        if (b & 1) ans = Mul(ans,a);
        a = Mul(a,a);
        b >>= 1;
    }
    return ans;
} 

inline Mat Add(Mat T_1,Mat T_2) {
	Mat Tmp;
	Tmp.clear();
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= n;++j)
			Tmp.a[i][j] = T_1.a[i][j] + T_2.a[i][j],Tmp.a[i][j] %= MOD;
	return Tmp;
}

inline Mat Work(Mat a,int k) {
	if(k == 1) return a;
	if(k & 1) return Add(Work(a,k - 1),power(a,k));
	else return Mul(Add(power(a,0),power(a,k >> 1)),Work(a,k >> 1));
}

int main() {
	freopen("matrix.in","r",stdin);
	freopen("matrix.out","w",stdout);
	scanf("%d%d%d",&n,&k,&MOD);
	for(int i = 1;i <= n;++i)
		for(int j = 1;j <= n;++j)
			scanf("%d",&x.a[i][j]);
	Ans = Work(x,k); 
	for(int i = 1;i <= n;++i) {
		for(int j = 1;j <= n;++j)
			printf("%d ",Ans.a[i][j] % MOD);
		puts("");
	}
	return 0;
}

 

点赞

发表评论

电子邮件地址不会被公开。 必填项已用*标注