HDU 2865 Birthday Toy

题目链接

题意:n个小圆组成的正n边形,中间有一个大圆。有木棍相连的两个圆不能有相同的颜色,旋转后相同视为相同的方案,求着色方案数。

\(\\\)

先选定一种颜色放在中间,剩下的\(k-1\)种颜色再摆在环上。下面直接令\(k=k-1\)

根据Burnside引理,\(ans=\sum_{a|n}f(a)\phi(\frac{n}{a})\)\(f(a)\)表示最多使用\(k\)种颜色且长度为\(a\)的,首尾以及相邻珠子颜色互不相同的方案数。计算\(f(n)\)时,假设\(n-1\)号珠子与\(1\)号珠子相同,则对答案的贡献为\((k-1)\cdot f(n-2)\);若不同,贡献为\((k-2)\cdot f(n-1)\)。所以\(f(n)=(k-1)\cdot f(n-2)+(k-2)\cdot f(n-1)\)。用矩阵快速幂就好了。

代码:

#include<bits/stdc++.h>
#define ll long long
#define N 1000005
#define mod 1000000007

using namespace std;
inline int Get() {int x=0,f=1;char ch=getchar();while(ch<'0'||ch>'9') {if(ch=='-') f=-1;ch=getchar();}while('0'<=ch&&ch<='9') {x=(x<<1)+(x<<3)+ch-'0';ch=getchar();}return x*f;}

ll n,k;
int p[N];
bool vis[N];

void pre(int n) {
    for(int i=2;i<=n;i++) {
        if(!vis[i]) p[++p[0]]=i;
        for(int j=1;j<=p[0]&&1ll*i*p[j]<=n;j++) {
            vis[i*p[j]]=1;
            if(i%p[j]==0) break;
        }
    }
}

int phi(int n) {
    ll ans=n;
    for(int i=1;1ll*p[i]*p[i]<=n;i++) {
        if(n%p[i]==0) ans=(ans-ans/p[i]);
        while(n%p[i]==0) n/=p[i];
    }
    if(n>1) ans=(ans-ans/n);
    return ans;
}

struct matrix {
    ll f[2][2];
    void Init() {memset(f,0,sizeof(f));}
}tem,g,t;

matrix operator *(const matrix &a,const matrix &b) {
    tem.Init();
    for(int i=0;i<2;i++)
        for(int j=0;j<2;j++)
            for(int k=0;k<2;k++)
                (tem.f[i][j]+=a.f[i][k]*b.f[k][j])%=mod;
    return tem;
}

matrix ksm(matrix g,int x) {
    matrix ans;
    ans.Init();
    for(int i=0;i<2;i++) ans.f[i][i]=1;
    for(;x;x>>=1,g=g*g)
        if(x&1) ans=ans*g;
    return ans;
}

ll ksm(ll t,ll x) {
    ll ans=1;
    for(;x;x>>=1,t=t*t%mod)
        if(x&1) ans=ans*t%mod;
    return ans;
}

ll cal(ll n) {
    if(n==1) return 0;
    matrix ans=t*ksm(g,n-2);
    return ans.f[0][1];
}

int main() {
    pre(1000000);
    while(scanf("%lld%lld",&n,&k)!=EOF) {
        g.f[0][0]=0,g.f[0][1]=k-2;
        g.f[1][0]=1,g.f[1][1]=k-3;
        t.f[0][0]=0,t.f[0][1]=(k-1)*(k-2)%mod;
        
        ll ans=0;
        for(int i=1,maxx=sqrt(n);i<=maxx;i++) {
            if(n%i==0) {
                (ans+=cal(i)*phi(n/i)%mod)%=mod;
                if(i*i!=n) (ans+=cal(n/i)*phi(i)%mod)%=mod;
            }
        }
        ans=ans*ksm(n,mod-2)%mod;
        ans=ans*k%mod;
        cout<<ans<<"\n";
    }
    return 0;
}

猜你喜欢

转载自www.cnblogs.com/hchhch233/p/10197519.html