// 「题解」树(未完待续:全部) — solution write-up for the problem "Tree" (title note: "to be continued").

#include<cstdio>
#include<algorithm>
using std::swap;
using std::sort;
#include<vector>
using std::vector;
typedef long long ll;
// Buffered input: stdin is read in chunks of sizeof(buf) bytes; p1 is the next
// unread byte and p2 is one past the last valid byte in buf.
static char buf[100000],*p1=buf,*p2=buf;
// Return the next input byte as an unsigned char value, or EOF once stdin is
// exhausted. (A helper function instead of a macro named `getchar`, which a
// translation unit that includes <cstdio> is not allowed to define.)
static inline int Fast_Getchar(void){
    if(p1==p2){
        p2=(p1=buf)+fread(buf,1,sizeof(buf),stdin);
        if(p1==p2)
            return EOF;
    }
    return (unsigned char)*p1++;
}
/*
 * Read the next non-negative decimal integer from stdin.
 * Any non-digit characters before the number are skipped; the first
 * non-digit after it is consumed. If input ends before a digit is found,
 * returns 0 instead of spinning forever on EOF.
 */
inline int read(void){
    int ch=Fast_Getchar();
    while(ch!=EOF&&(ch<'0'||'9'<ch))
        ch=Fast_Getchar();
    int res=0;
    while('0'<=ch&&ch<='9'){
        res=10*res+(ch-'0');
        ch=Fast_Getchar();
    }
    return res;
}
// Capacity of per-vertex arrays: room for about 1e5 vertices plus slack.
const int MAXN=100005;
// Levels of the binary-lifting table: jumps of 2^0 .. 2^17 (2^17 > 1e5).
const int MAXLOG2N=18;
/*
 * One vertical edge of a rectangle in the sweep: at coordinate x, add val
 * (+1 or -1) to the cover count of every row in [y1, y2].
 * Events compare by x only.
 */
struct Matrix{
    int x,y1,y2,val;
    Matrix(int px,int py1,int py2,int delta):x(px),y1(py1),y2(py2),val(delta){}
    bool operator<(const Matrix &other)const{
        return other.x>x;
    }
};
/*
 * Coverage-counting segment tree for the rectangle-union sweep.
 * Update(ID,l,r,posl,posr,val) adds val to the cover count of positions
 * [posl,posr] within node ID's range [l,r] (call with ID=1, l=1, r=n).
 * Afterwards unit[1].val is the number of positions covered at least once.
 *   unit[ID].tag - cover count applied to the whole range of node ID
 *   unit[ID].val - number of covered positions inside node ID's range
 * Tags are never pushed down. This is only correct if every -1 update undoes
 * an earlier +1 on the same range (as the sweep in main does), so that no tag
 * goes negative. NOTE(review): callers must keep that invariant.
 */
struct SegmentTree{
    struct Segment{
        int val,tag;
    };
    Segment unit[MAXN*8+5];
    void Update(int ID,int l,int r,int posl,int posr,int val){
        if(posl<=l&&r<=posr){
            unit[ID].tag+=val;
            Pull(ID,l,r);
            return;
        }
        const int mid=(l+r)>>1;
        if(posl<=mid)
            Update(ID<<1,l,mid,posl,posr,val);
        if(posr>mid)
            Update(ID<<1|1,mid+1,r,posl,posr,val);
        Pull(ID,l,r);
    }
    // Recompute unit[ID].val: a tagged node is fully covered; an untagged
    // leaf is empty; otherwise the coverage is the sum of both children.
    void Pull(int ID,int l,int r){
        if(unit[ID].tag)
            unit[ID].val=r-l+1;
        else if(l==r)
            unit[ID].val=0;
        else
            unit[ID].val=unit[ID<<1].val+unit[ID<<1|1].val;
    }
};
// Input sizes: n vertices, m vertex pairs.
int n;
int m;
// Adjacency lists stored as linked lists of directed edges: head[x] is the
// first edge id of x (0 = none), Next[e] the following edge, to[e] its target,
// and cnt the last edge id handed out. Each tree edge uses two slots.
int cnt;
int head[MAXN];
int to[MAXN<<1];
int Next[MAXN<<1];
// dep[x]: depth of x (main roots the tree at vertex 1 with depth 1);
// fa[x][k]: the 2^k-th ancestor of x.
int dep[MAXN];
int fa[MAXN][MAXLOG2N];
// DFS clock plus each vertex's entry time and the last entry time in its subtree.
int Time;
int first[MAXN];
int last[MAXN];
// Rectangle-edge events for the sweep, and the segment tree it runs on.
vector<Matrix> V;
SegmentTree T;
// Helpers defined after main.
void Add_Matrix(int x1,int x2,int y1,int y2);
void Add_Edge(int u,int v);
void DFS(int ID,int father,int depth);
int LCA(int u,int v);
/*
 * Reads a tree with n vertices and m vertex pairs. Prints how many unordered
 * pairs of distinct vertices have a tree path that does NOT pass through both
 * vertices u and v of any input pair (u, v).
 *
 * After DFS, every subtree is the contiguous range [first[x], last[x]] of
 * entry times. For each input pair, the endpoint pairs (a, b) with
 * first[a] < first[b] whose path covers both u and v form one or two
 * rectangles in (first[a], first[b]) coordinates. The answer is n*(n-1)/2
 * minus the area of the union of all rectangles. That area is computed by
 * sweeping x = 1..n with a coverage-counting segment tree over y.
 */
int main(void){
    //freopen("tree.in","r",stdin);
    //freopen("tree.out","w",stdout);
    n=read(),m=read();
    // Start from every unordered pair of distinct vertices.
    ll ans=(ll)n*(n-1)/2;
    for(int i=1;i<n;++i){
        const int u=read();
        const int v=read();
        Add_Edge(u,v);
        Add_Edge(v,u);
    }
    DFS(1,0,1);
    for(int i=1;i<=m;++i){
        int u=read();
        int v=read();
        // Order the pair so u is entered first: v is then either inside u's
        // subtree or in a subtree that starts after u's subtree ends.
        if(first[u]>first[v])
            swap(u,v);
        if(first[v]<=last[u]&&first[v]>first[u]){
            // u is a proper ancestor of v. LCA() returns u's child on the way
            // to v (not the LCA itself, which is u). A covering path has one
            // end in subtree(v) and the other outside subtree(child).
            const int child=LCA(u,v);
            if(first[child]>1)
                Add_Matrix(1,first[child]-1,first[v],last[v]);
            if(last[child]<n)
                Add_Matrix(first[v],last[v],last[child]+1,n);
        }
        else{
            // Disjoint subtrees: one end in subtree(u), the other in subtree(v).
            // NOTE(review): u == v also lands here, and that square rectangle
            // includes cells with x >= y; presumably inputs never have u == v.
            Add_Matrix(first[u],last[u],first[v],last[v]);
        }
    }
    sort(V.begin(),V.end());
    // Sweep x (entry time of the earlier endpoint). After applying every event
    // with coordinate <= x, unit[1].val is the number of covered y in row x.
    int j=0;
    for(int x=1;x<=n;++x){
        while(j<(int)V.size()&&V[j].x<=x){
            T.Update(1,1,n,V[j].y1,V[j].y2,V[j].val);
            ++j;
        }
        ans-=T.unit[1].val;
    }
    printf("%lld\n",ans);
    return 0;
}
void Add_Matrix(int x1,int x2,int y1,int y2){
    V.push_back(Matrix(x1,y1,y2,1));
    V.push_back(Matrix(x2+1,y1,y2,-1));
}
/*
 * Prepend the directed edge u -> v to u's adjacency list. Edge ids start at 1;
 * id 0 terminates a list (head[] starts zeroed).
 */
void Add_Edge(int u,int v){
    ++cnt;
    to[cnt]=v;
    Next[cnt]=head[u];
    head[u]=cnt;
}
void DFS(int ID,int father,int depth){
    register int i;
    first[ID]=++Time;
    dep[ID]=depth;
    fa[ID][0]=father;
    for(i=1;i<MAXLOG2N;++i)
        fa[ID][i]=fa[fa[ID][i-1]][i-1];
    for(i=head[ID];i;i=Next[i])
        if(to[i]!=father)
            DFS(to[i],ID,depth+1);
    last[ID]=Time;
    return;
}
/*
 * Despite its name, this is not a general lowest-common-ancestor query.
 * Precondition (as used in main): u is a proper ancestor of v.
 * Returns the ancestor of v at depth dep[u]+1, i.e. the child of u whose
 * subtree contains v. It lifts v by binary jumps while the landing vertex is
 * still strictly deeper than u. A jump past the root lands on vertex 0, and
 * dep[0] is never written by DFS (assuming vertices are numbered from 1), so
 * it stays 0 and such jumps are never taken because dep[u] >= 1.
 */
int LCA(int u,int v){
    for(int i=MAXLOG2N-1;i>=0;--i)
        if(dep[fa[v][i]]>dep[u])
            v=fa[v][i];
    return v;
}