做法:
把圖中所有的圈縮成一個點,那麼就是求是否存在一個點,使得所有的點都能到達。
遍歷所有入度為0的點,對所有遍歷到的出度為0的點的標記+1;
若出度為0的點的數目大於兩個,則輸出0。否則若標記的值等於入度點的和,那麼輸出這個點縮點之前含的點。
#include<stdio.h>
#include<iostream>
#include<string.h>
#include<algorithm>
#include<queue>
#include<stack>
#include<map>
#include<vector>
#include<stdlib.h>
#define INF_MAX 0x7fffffff
#define INF 999999
#define max3(a,b,c) (max(a,b)>c?max(a,b):c)
#define min3(a,b,c) (min(a,b)<c?min(a,b):c)
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
#define ll __int64
#define maxn 10001
#define maxm 50001
struct node
{
int u;
int v;
int w;
bool friend operator < (node a, node b){
return a.w < b.w;
}
}edge[maxn];
ll gcd(ll n,ll m){if(n<m) swap(n,m);return n%m==0?m:gcd(m,n%m);}
ll lcm(ll n,ll m){if(n<m) swap(n,m);return n/gcd(n,m)*m;}
vector<int>vec[maxn];
vector<int>vect[maxn];
stack<int>st;
int dnf[maxn],low[maxn],vis[maxn],sum[maxn],instack[maxn];
int du[maxn],du2[maxn];
int n,m;
int times,num;
void init()
{
int i;
for(i=0;i<=n;i++)
dnf[i]=low[i]=instack[i]=du[i]=du2[i]=vis[i]=sum[i]=0;
times=1;
num=1;
for(i=0;i<=n;i++)vec[i].clear();
for(i=0;i<=n;i++)vect[i].clear();
while(!st.empty())st.pop();
}
void tarjan(int x)
{
int i;
dnf[x]=low[x]=times++;
instack[x]=1;
st.push(x);
int n=vec[x].size();
for(i=0;i<n;i++)
{
int y=vec[x][i];
if(!dnf[y])
{
tarjan(y);
low[x]=min(low[x],low[y]);
}
else if (instack[y])
{
low[x]=min(low[x],dnf[y]);
}
}
if(low[x]==dnf[x])
{
int y=-1;
while(y!=x)
{
y=st.top();
st.pop();
instack[y]=0;
sum[num]++;
vis[y]=num;
}
num++;
}
}
void jiantu()
{
int i,j;
for(i=1;i<=n;i++)
{
int len=vec[i].size();
for(j=0;j<len;j++)
{
int y=vec[i][j];
if(vis[i]==vis[y])continue;
vect[vis[i]].push_back(vis[y]);
du[vis[y]]++;
du2[vis[i]]++;
}
}
}
void bfs(int x)
{
int visit[maxn],i;
for(i=0;i<=num;i++)
{
visit[i]=0;
}
queue<int>q;
q.push(x);
visit[x]=1;
while(!q.empty())
{
int y=q.front();
q.pop();
int len=vect[y].size();
int leap=0;
for(i=0;i<len;i++)
{
leap=1;
if(!visit[vect[y][i]])
{
q.push(vect[y][i]);
visit[vect[y][i]]=1;
}
}
if(leap==0)
{
vis[y]++;
}
}
}
int main()
{
int i;
while(~scanf("%d%d",&n,&m))
{
init();
int a,b;
for(i=0;i<m;i++)
{
scanf("%d%d",&a,&b);
vec[a].push_back(b);
}
for(i=1;i<=n;i++)if(!dnf[i])tarjan(i);
int ns=0;
jiantu();
mem(vis,0);
int t1,t2,ip;
t1=t2=0;
for(i=1;i<num;i++)
{
if(du[i]==0)t1++;
if(du2[i]==0)t2=i,ip=i;
}
if(t2>1)
{
cout<<"0"<<endl;
continue;
}
for(i=1;i<num;i++)
{
if(!du[i])
{
bfs(i),ns++;
}
}
int sums=0;
if(vis[ip]==t1)sums=sum[ip];
cout<<sums<<endl;
}
return 0;
}