[Code+#1]大吉大利,晚上吃鸡!

输入输出样例

输入样例#1:

7 7 1 7
1 2 2
2 4 2
4 6 2
6 7 2
1 3 2
3 5 4
5 7 2

输出样例#1:

6

输入样例#2:

5 5 1 4
1 2 1
1 3 1
2 4 1
3 4 1
4 5 1

输出样例#2:

3

输入样例#3:

6 7 1 4
1 2 1
1 3 1
2 4 1
3 4 1
4 5 1
1 6 2
6 4 2

输出样例#3:

5


这题好码农啊 写挂了好多发

这题就是让我们找出符合条件的点对的数量

符合条件的点对能够覆盖所有从\(S~T\)的最短路径并且必须不在同一条最短路径上

那么我们可以先正反跑两边最短路记录通过每个点的最短路数目

那么显然只有点对\(<u,v>\)符合\(f[u]+f[v]=f[T]\)才是合法的

然后就该考虑如何处理不在同一条最短路径上了

我们可以在跑最短路的时候顺便记录下一条从S~T的最短路径

然后可以对这个最短路径进行编号\(1~Num\)

然后我们要统计每个不在找出的这条最短路径上的每个点能对我们找出的这条最短路上的哪些点产生影响

显然能影响的点是我们找到的最短路径上的一段连续的点

因为如果这个点能在走最短路径的时候被u走到,那么一定能被u的前驱/后继走到

如果没有最短路径经过这个点,那么这个点一定会对我们找出的最短路径上的所有点产生影响

所以我们只需要求有最短路径经过的点对找到的最短路径上的点的贡献

直接求这段连续的点比较困难,我们可以正反两遍拓扑排序求

无向边怎么拓扑排序?

其实并不是真正的拓扑排序

我们只需要枚举每条边\(<u,v>\)然后查看这条边是不是最短路径经过的边,如果是就\(++d[v]\)

这样我们就可以拓扑排序了

拓扑的时候用\(l[u]/r[u]\)来更新\(l[v]/r[v]\)

最后扫一遍就好了

#include<map>
#include<queue>
#include<vector>
#include<cstdio>
#include<cstring>
#include<iostream>
#include<algorithm>
# define LL long long
const int M = 50005 ;
const LL INF = 1e15 ;
using namespace std ;
inline int read() {
    char c = getchar() ; int x = 0 , w = 1 ;
    while(c>'9'||c<'0') { if(c=='-') w = -1 ; c = getchar() ; }
    while(c>='0'&&c<='9') { x = x*10+c-'0' ; c = getchar() ; }
    return x * w ; 
}
bool vis[M] ;
int n , m , S , T , upp ;
int hea[M] , num , Nxt[M] , Num ;
int p[M] , l[M] , r[M] , d[M] ;
LL dis[2][M] , f[2][M] , e[M] , Ans ;
map < LL , int > t ;
vector < int > pl[M] , pr[M] ;
struct Node { int id ; LL v ; } ;
struct E { int Nxt , to ; LL dis ; } edge[M << 1] ;
inline bool operator < (Node a , Node b) { return a.v > b.v ; }
inline void add_edge(int from , int to , int dis) {
    edge[++num].Nxt = hea[from] ; edge[num].to = to ;
    edge[num].dis = dis ; hea[from] = num ;
}
inline void dijkstra(int t , int S) {
    priority_queue < Node > q ;
    memset(vis , false , sizeof(vis)) ;
    memset(dis[t] , 63 , sizeof(dis[t])) ;
    f[t][S] = 1 ; dis[t][S] = 0 ;
    q.push((Node) { S , 0 }) ;
    while(!q.empty()) {
        int u = q.top().id ; q.pop() ; 
        if(vis[u]) continue ; vis[u] = true ;
        for(int i = hea[u] ; i ; i = edge[i].Nxt) {
            int v = edge[i].to ;
            if(dis[t][v] > dis[t][u] + edge[i].dis) {
                f[t][v] = f[t][u] ;
                dis[t][v] = dis[t][u] + edge[i].dis ;
                Nxt[v] = u ;
                if(vis[v]) continue ;
                q.push((Node) { v , dis[t][v] }) ;
            }
            else if(dis[t][v] == dis[t][u] + edge[i].dis)
                f[t][v] += f[t][u] ;
        }
    }
}
inline void Topsort(int t) {
    queue < int > q ;
    for(int u = 1 ; u <= n ; u ++)
        for(int i = hea[u] ; i ; i = edge[i].Nxt) {
            int v = edge[i].to ;
            if(dis[t][v] + dis[t ^ 1][u] + edge[i].dis == upp)
                ++d[v] ;
        }
    while(!q.empty()) {
        int u = q.front() ; q.pop() ;
        for(int i = hea[u] ; i ; i = edge[i].Nxt) {
            int v = edge[i].to ;
            if(dis[t][v] + dis[t ^ 1][u] + edge[i].dis == upp) {
                -- d[v] ;
                if(d[v] == 0) q.push(v) ;
                if(!t) l[v] = max(l[v] , l[u]) ;
                else r[v] = max(r[v] , r[u]) ;
            }
        }
    }
}
int main() {
    n = read() ; m = read() ; S = read() ; T = read() ;
    for(int i = 1 , u , v , w ; i <= m ; i ++) {
        u = read() , v = read() , w = read() ;
        add_edge(u , v , w) ; add_edge(v , u , w) ;
    }
    dijkstra(0 , S) ;
    if(dis[0][T] > INF) { printf("%lld\n" , 1LL * n * (n - 1) / 2) ; return 0 ; }
    dijkstra(1 , T) ;
    upp = dis[0][T] ;
    Nxt[T] = 0 ;
    for(int i = S ; i ; i = Nxt[i]) {
        p[++Num] = i ;
        l[i] = Num + 1 , r[i] = Num - 1 ;
    }
    for(int i = 1 ; i <= n ; i ++)
        if(l[i] == r[i] && l[i] == 0)
            l[i] = 1 , r[i] = Num ;
    Topsort(0) ; Topsort(1) ;
    for(int i = 1 ; i <= n ; i ++) {
        if(dis[0][i] + dis[1][i] == upp) 
            e[i] = f[0][i] * f[1][i] ;
        if(l[i] > r[i]) continue ;
        pl[l[i]].push_back(i) ;
        pr[r[i]].push_back(i) ;
    }
    for(int i = 1 ; i <= Num ; i ++) {
        for(int j = 0 ; j < pl[i].size() ; j ++)
            ++t[e[pl[i][j]]] ;
        Ans += t[e[T] - e[p[i]]] ;
        for(int j = 0 ; j < pr[i].size() ; j ++)
            --t[e[pr[i][j]]] ;
    }
    printf("%lld\n",Ans) ;
    return 0 ;
}

猜你喜欢

转载自www.cnblogs.com/beretty/p/9711242.html
今日推荐