程序師世界是廣大編程愛好者互助、分享、學習的平台,程序師世界有你更精彩!
首頁
編程語言
C語言|JAVA編程
Python編程
網頁編程
ASP編程|PHP編程
JSP編程
數據庫知識
MYSQL數據庫|SqlServer數據庫
Oracle數據庫|DB2數據庫
 程式師世界 >> 編程語言 >> C語言 >> C++ >> C++入門知識 >> bzoj 1036 樹鏈剖分

bzoj 1036 樹鏈剖分

編輯:C++入門知識

最近有點小頹廢,每天休息也不是很好,實在沒狀態CF,好久沒寫題解了,寫個水題來除除草

每天被寢室大神各種高端數據結構嚇傻不已,前段時間學了下樹鏈剖分,感覺還算比較好寫。

size[u]表示以u為根的節點個數,dep[u]表示u的深度,pre[u]表示u的父節點,top[u]表示u所在重鏈的頂端節點,son[u]表示與在同一重鏈上的u的兒子節點,id[u]表示u在線段樹中的位置,w[id[u]]表示線段樹中的點權。

名詞解釋:  

    重兒子:siz[u]為v的子節點中siz值最大的,那麼u就是v的重兒子。
    輕兒子:v的其它子節點。
    重邊:點v與其重兒子的連邊。
    輕邊:點v與其輕兒子的連邊。
    重鏈:由重邊連成的路徑。
    輕鏈:輕邊


重點性質:

    性質1:如果(v,u)為輕邊,則siz[u] * 2 < siz[v];

    性質2:從根到某一點的路徑上輕鏈、重鏈的個數都不大於logn


算法實現:

顯然可以dfs求出size,dep,pre,top,son,id,但是大數據可能出現爆棧的情況,於是我們可以用3次循環來求。

第一次正向循環求出pre,dep,再逆向循環求出size,再正向循環求出top,son,id。然後就可以將權值建立到線段樹中

建完樹就要開始我們核心算法了,比如更新u到v路徑,設f1=top[u],f2=top[v],f1!=f2時,若dep[f1]>=dep[f2],更新u到pre[f1]的權值,然後u=pre[f1].當f1=f2時,再更新u到v的權值即可,操作均為logn的。

上代碼:


[cpp] view plaincopyprint?
#include<cstdio>  
#include<algorithm>  
#include<iostream>  
#include<climits>  
#include<cmath>  
using namespace std; 
const int N=30005; 
int n,m,tot,head[N],next[N*2],to[N*2],val[N],pre[N],dep[N],q[N],size[N],son[N],top[N],w[N],id[N]; 
int MAX[N*4],sum[N*4]; 
bool v[N]; 
char s[20]; 
#define lc x << 1  
#define rc (lc) + 1  
inline void add(int u,int v) 

    to[++tot]=v; 
    next[tot]=head[u]; 
    head[u]=tot; 

void update(int x) 

    MAX[x]=max(MAX[lc],MAX[rc]); 
    sum[x]=sum[lc]+sum[rc];  

void build(int x,int l,int r) 

    if(l==r)  
    { 
        MAX[x]=sum[x]=w[l]; 
        return ;     
    } 
    int mid=(l+r)>>1; 
    build(lc,l,mid); 
    build(rc,mid+1,r); 
    update(x); 

void change(int x,int l,int r,int u,int val) 

    if(l==r) 
    { 
        MAX[x]=sum[x]=val; 
        return ; 
    } 
    int mid=(l+r)>>1; 
    if(u<=mid)  change(lc,l,mid,u,val); 
    else change(rc,mid+1,r,u,val); 
    update(x);    

int qmax(int x,int l,int r,int ql,int qr) 

    if(ql<=l && qr>=r)  return MAX[x]; 
    int mid=(l+r)>>1; 
    int tmp1=INT_MIN;int tmp2=INT_MIN; 
    if(ql<=mid) tmp1=qmax(lc,l,mid,ql,qr); 
    if(qr>mid)  tmp2=qmax(rc,mid+1,r,ql,qr); 
    return max(tmp1,tmp2); 

int qsum(int x,int l,int r,int ql,int qr) 

    if(ql<=l && qr>=r)  return sum[x]; 
    int mid=(l+r)>>1; 
    int tmp1=0;int tmp2=0; 
    if(ql<=mid) tmp1=qsum(lc,l,mid,ql,qr); 
    if(qr>mid)  tmp2=qsum(rc,mid+1,r,ql,qr); 
    return tmp1+tmp2; 

inline int gmax(int a,int b) 

    int tmp=INT_MIN; 
    while(top[a]!=top[b]) 
    { 
        if(dep[top[a]]<dep[top[b]]) swap(a,b); 
        tmp=max(tmp,qmax(1,1,n,id[top[a]],id[a]));  
        a=pre[top[a]]; 
    } 
    if(id[a]>id[b])   swap(a,b); 
    tmp=max(tmp,qmax(1,1,n,id[a],id[b])); 
    return tmp; 

inline int gsum(int a,int b) 

    int tmp=0; 
    while(top[a]!=top[b]) 
    { 
        if(dep[top[a]]<dep[top[b]]) swap(a,b); 
        tmp+=qsum(1,1,n,id[top[a]],id[a]);  
        a=pre[top[a]]; 
    } 
    if(id[a]>id[b])   swap(a,b); 
    tmp+=qsum(1,1,n,id[a],id[b]); 
    return tmp; 

int main() 

    scanf("%d",&n); 
    int a,b; 
    for(int i=1;i<n;++i) 
    { 
        scanf("%d%d",&a,&b); 
        add(a,b); 
        add(b,a); 
    } 
    for(int i=1;i<=n;++i)   scanf("%d",&val[i]); 
    v[dep[1]=q[0]=1]=true; 
    int r=0;tot=0; 
    for(int l=0;l<=r;++l) 
     for(int k=head[q[l]];k;k=next[k]) 
     { 
        if(!v[to[k]]) 
        { 
            v[to[k]]=true; 
            dep[q[++r]=to[k]]=dep[q[l]]+1; 
            pre[q[r]]=q[l]; 
        } 
     } 
    for(int i=r;i>=0;--i) 
    { 
        size[pre[q[i]]]+=++size[q[i]]; 
        if(size[q[i]]>size[son[pre[q[i]]]]) 
        son[pre[q[i]]]=q[i]; 
    }  
    for(int i=0;i<=r;++i) 
    { 
        if(!top[q[i]])   
         for(int k=q[i];k;k=son[k]) 
         { 
            top[k]=q[i]; 
            w[id[k]=++tot]=val[k]; 
         } 
    } 
    build(1,1,n); 
    scanf("%d",&m); 
    for(int i=1;i<=m;++i) 
    { 
        scanf("%s%d%d",s,&a,&b); 
        if(s[0]=='C')   change(1,1,n,id[a],b); 
        else if(s[1]=='M')  printf("%d\n",gmax(a,b)); 
        else printf("%d\n",gsum(a,b)); 
    } 
    return 0; 

#include<cstdio>
#include<algorithm>
#include<iostream>
#include<climits>
#include<cmath>
using namespace std;
const int N=30005;
int n,m,tot,head[N],next[N*2],to[N*2],val[N],pre[N],dep[N],q[N],size[N],son[N],top[N],w[N],id[N];
int MAX[N*4],sum[N*4];
bool v[N];
char s[20];
#define lc x << 1
#define rc (lc) + 1
inline void add(int u,int v)
{
    to[++tot]=v;
    next[tot]=head[u];
    head[u]=tot;
}
void update(int x)
{
    MAX[x]=max(MAX[lc],MAX[rc]);
    sum[x]=sum[lc]+sum[rc];
}
void build(int x,int l,int r)
{
    if(l==r)
    {
        MAX[x]=sum[x]=w[l];
        return ;   
    }
    int mid=(l+r)>>1;
    build(lc,l,mid);
    build(rc,mid+1,r);
    update(x);
}
void change(int x,int l,int r,int u,int val)
{
    if(l==r)
    {
        MAX[x]=sum[x]=val;
        return ;
    }
    int mid=(l+r)>>1;
    if(u<=mid)  change(lc,l,mid,u,val);
    else change(rc,mid+1,r,u,val);
    update(x);  
}
int qmax(int x,int l,int r,int ql,int qr)
{
    if(ql<=l && qr>=r)  return MAX[x];
    int mid=(l+r)>>1;
    int tmp1=INT_MIN;int tmp2=INT_MIN;
    if(ql<=mid) tmp1=qmax(lc,l,mid,ql,qr);
    if(qr>mid)  tmp2=qmax(rc,mid+1,r,ql,qr);
    return max(tmp1,tmp2);
}
int qsum(int x,int l,int r,int ql,int qr)
{
    if(ql<=l && qr>=r)  return sum[x];
    int mid=(l+r)>>1;
    int tmp1=0;int tmp2=0;
    if(ql<=mid) tmp1=qsum(lc,l,mid,ql,qr);
    if(qr>mid)  tmp2=qsum(rc,mid+1,r,ql,qr);
    return tmp1+tmp2;
}
inline int gmax(int a,int b)
{
    int tmp=INT_MIN;
    while(top[a]!=top[b])
    {
        if(dep[top[a]]<dep[top[b]]) swap(a,b);
        tmp=max(tmp,qmax(1,1,n,id[top[a]],id[a]));
        a=pre[top[a]];
    }
    if(id[a]>id[b])   swap(a,b);
    tmp=max(tmp,qmax(1,1,n,id[a],id[b]));
    return tmp;
}
inline int gsum(int a,int b)
{
    int tmp=0;
    while(top[a]!=top[b])
    {
        if(dep[top[a]]<dep[top[b]]) swap(a,b);
        tmp+=qsum(1,1,n,id[top[a]],id[a]);
        a=pre[top[a]];
    }
    if(id[a]>id[b])   swap(a,b);
    tmp+=qsum(1,1,n,id[a],id[b]);
    return tmp;
}
int main()
{
    scanf("%d",&n);
    int a,b;
    for(int i=1;i<n;++i)
    {
        scanf("%d%d",&a,&b);
        add(a,b);
        add(b,a);
    }
    for(int i=1;i<=n;++i)   scanf("%d",&val[i]);
    v[dep[1]=q[0]=1]=true;
    int r=0;tot=0;
    for(int l=0;l<=r;++l)
     for(int k=head[q[l]];k;k=next[k])
     {
        if(!v[to[k]])
        {
            v[to[k]]=true;
            dep[q[++r]=to[k]]=dep[q[l]]+1;
            pre[q[r]]=q[l];
        }
     }
    for(int i=r;i>=0;--i)
    {
        size[pre[q[i]]]+=++size[q[i]];
        if(size[q[i]]>size[son[pre[q[i]]]])
        son[pre[q[i]]]=q[i];
    }
    for(int i=0;i<=r;++i)
    {
        if(!top[q[i]]) 
         for(int k=q[i];k;k=son[k])
         {
            top[k]=q[i];
            w[id[k]=++tot]=val[k];
         }
    }
    build(1,1,n);
    scanf("%d",&m);
    for(int i=1;i<=m;++i)
    {
        scanf("%s%d%d",s,&a,&b);
        if(s[0]=='C')   change(1,1,n,id[a],b);
        else if(s[1]=='M')  printf("%d\n",gmax(a,b));
        else printf("%d\n",gsum(a,b));
    }
    return 0;
}


 

  1. 上一頁:
  2. 下一頁:
Copyright © 程式師世界 All Rights Reserved