思路:首先一个最重要是思路:ai*aj%mod=k%mod 即a[i]=k*inv[a[j]]%mod。
所以对于一条路径上的乘积是否为k,我们考虑将其分成2部分,第一部分乘起来放map里,第二部分乘了以后去map里找,如果找到,说明有这样的组合可以乘积为k,既满足题意。
具体实现的思路自然是点分治,一次找重心做为根节点,需要注意的是,每次重新选根节点,都要清空map,因为重新找一颗子树时,相互之间没有关系。
具体实现看代码和注释:
#include<iostream>
#include<map>
#include<string>
#include<cstring>
#include<vector>
#include<algorithm>
#include<set>
#include<sstream>
#include<cstdio>
#include<cmath>
#include<climits>
using namespace std;
const int maxn=1e5+7;
const int inf=0x3f3f3f3f;
typedef long long ll;
const int mod=1e6+3;
int n,k,allnode;
int head[maxn*2];
int num;
int dp[maxn];
int size[maxn];
int Focus,M;
int a[maxn],tmp[maxn];
bool vis[maxn];
int ansx,ansy;
int inv[mod];
int id[maxn],cnt,mp[mod];
struct Edge
{
int u,v,w,next;
}edge[maxn<<2];
void addEdge(int u,int v,int w)
{
edge[num].u=u;
edge[num].v=v;
edge[num].w=w;
edge[num].next=head[u];
head[u]=num++;
}
void init()
{
memset(head,-1,sizeof(head));
memset(vis,0,sizeof(vis));
memset(mp,0,sizeof(mp));
num=0;
}
void getFocus(int u,int pre)
{
size[u]=1;
dp[u]=0;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==pre||vis[v]) continue;
getFocus(v,u);
size[u]+=size[v];
dp[u]=max(dp[u],size[v]);
}
dp[u]=max(dp[u],allnode-size[u]);
if(M>dp[u])
{
M=dp[u];
Focus=u;
}
}
void dfs(int u,int pre,int val)
{
tmp[cnt]=val,id[cnt++]=u;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(v==pre||vis[v]) continue;
dfs(v,u,(ll)a[v]*val%mod);
}
}
void query(int val,int ed)
{
//printf("%d %d\n",val,ed);
int tmp=(ll)inv[val]*k%mod;
int st=mp[tmp];
//printf("%d\n\n",st);
if(st==0||st==ed) return;
if(st>ed) swap(st,ed);
if(st<ansx||(st==ansx&&ansy>ed))
{
ansx=st;
ansy=ed;
//printf("%d %d...\n",st,ed);
//printf("%d... %d\n",ansx,ansy);
}
}
void solve(int u)
{
vis[u]=1;
mp[a[u]]=u;
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v]) continue;
cnt=0;
dfs(v,u,a[v]);
for(int j=0;j<cnt;j++)
{
query(tmp[j],id[j]);//为了保证根节点只被乘一次,所以在map中放乘根节点的值
}
for(int j=0;j<cnt;j++)
{
//printf("id %d\n",id[j]);
tmp[j]=(ll)a[u]*tmp[j]%mod;
if(mp[tmp[j]]==0||mp[tmp[j]]>id[j])
{
mp[tmp[j]]=id[j];
}
}
}
mp[a[u]]=0;//清空
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v]) continue;
cnt=0;
dfs(v,u,(ll)a[u]*a[v]%mod);
for(int j=0;j<cnt;j++)
{
mp[tmp[j]]=0;
}
}
for(int i=head[u];i!=-1;i=edge[i].next)
{
int v=edge[i].v;
if(vis[v]) continue;
M=1e9,Focus=0,allnode=size[v];
getFocus(v,u);
solve(Focus);
}
}
void pre() //线性预处理
{
inv[1]=1;
for(int i=2;i<mod;i++)
{
int x=mod/i,y=mod%i;
inv[i]=((ll)inv[y]*(-x)%mod+mod)%mod;
}
}
int main()
{
#ifndef ONLINE_JUDGE
freopen("in.txt","r",stdin);
freopen("out.txt","w",stdout);
#endif
pre();
while(scanf("%d%d",&n,&k)!=EOF)
{
init();
for(int i=1;i<=n;i++)
{
scanf("%d",&a[i]);
}
for(int i=1;i<n;i++)
{
int u,v;
scanf("%d%d",&u,&v);
addEdge(u,v,0);
addEdge(v,u,0);
}
allnode=n,M=1e9,Focus=0;
ansx=ansy=mod;
getFocus(1,0);
solve(Focus);
if(ansx==mod)
{
puts("No solution");
}
else
{
printf("%d %d\n",ansx,ansy);
}
}
return 0;
}