Count on a tree(树上路径第K小)

旧时模样 提交于 2019-12-05 20:59:31

题目链接:https://www.spoj.com/problems/COT/en/

题意:求树上A,B两点路径上第K小的数

思路:主席树实际上是维护的一个前缀和,而前缀和不一定要出现在一个线性表上。

比如说我们从一棵树的根节点进行DFS,得到根节点到各节点的距离dist[x]——这是一个根-x路径上点与根节点距离的前缀和。

利用这个前缀和,我们可以解决一些树上任意路径的问题,比如在线询问[a,b]点对的距离——答案自然是dist[a]+dist[b]-2*dist[lca(a,b)]。

DFS遍历整棵树,然后在每个节点上建立一棵线段树,某一棵线段树的“前一版本”是位于该节点父亲节点fa的线段树。

利用与之前类似的方法插入点权(排序离散)。那么对于询问[a,b],答案就是root[a]+root[b]-root[lca(a,b)]-root[fa[lca(a,b)]]上的第k大。

#include<cstdio>
#include<cstring>
#include<queue>
#include<cmath>
#include<algorithm>
#include<map>
#include<vector>
#include<string>
#include<set>
#define ll long long
#define maxn 100007
using namespace std;
const int MAXN=1e5+100;
const int POW=18;
int num[MAXN],node[MAXN];
struct point
{
    int l;
    int r;
    int sum;
}T[MAXN*20];
int root[MAXN];
vector<int> G[MAXN];
int d[MAXN];
int p[MAXN][POW];
int tot;
int f[MAXN];
int n,m;
void build(int l,int r,int& rt)
{
    rt=++tot;
    T[rt].sum=0;
    if(l>=r)return;
    int m=(l+r)>>1;
    build(l,m,T[rt].l);
    build(m+1,r,T[rt].r);
}
void update(int last,int p,int l,int r,int &rt)
{
    rt=++tot;
    T[rt].l=T[last].l;
    T[rt].r=T[last].r;
    T[rt].sum=T[last].sum+1;
    if(l>=r)return ;
    int m=(l+r)>>1;
    if(p<=m)update(T[last].l,p,l,m,T[rt].l);
    else update(T[last].r,p,m+1,r,T[rt].r);
}
int query(int left_rt,int right_rt,int lca_rt,int lca_frt,int l,int r,int k)
{
    if(l>=r)return l;
    int m=(l+r)>>1;
    int cnt=T[T[right_rt].l].sum+T[T[left_rt].l].sum-T[T[lca_rt].l].sum-T[T[lca_frt].l].sum;
    if(k<=cnt)
        return query(T[left_rt].l,T[right_rt].l,T[lca_rt].l,T[lca_frt].l,l,m,k);
    else
        return query(T[left_rt].r,T[right_rt].r,T[lca_rt].r,T[lca_frt].r,m+1,r,k-cnt);
}
void dfs(int u,int fa,int cnt)
{
    f[u]=fa;
    d[u]=d[fa]+1;
    p[u][0]=fa;
    for(int i=1;i<POW;i++)
        p[u][i]=p[p[u][i-1]][i-1];
    update(root[fa],num[u],1,cnt,root[u]);
    for(int i=0;i<(int)G[u].size();i++)
    {
        int v=G[u][i];
        if(v==fa)continue;
        dfs(v,u,cnt);
    }
}
int lca(int a,int b)
{
    if(d[a]>d[b])
        a^=b,b^=a,a^=b;
    if(d[a]<d[b])
    {
        int del=d[b]-d[a];
        for(int i=0;i<POW;i++)
            if(del&(1<<i))b=p[b][i];
    }
    if(a!=b)
    {
        for(int i=POW-1;i>=0;i--)
        {
            if(p[a][i]!=p[b][i])
            {
                a=p[a][i],b=p[b][i];
            }
        }
        a=p[a][0],b=p[b][0];
    }
    return a;
}
void init()
{
    for(int i=0;i<=n;i++)
    {
        G[i].clear();
    }
    memset(d,0,sizeof(d));
    memset(p,0,sizeof(p));
    memset(f,0,sizeof(f));
}
int main()
{
    while(~scanf("%d%d",&n,&m))
    {
        init();
        for(int i=1;i<=n;i++)
        {
            scanf("%d",&num[i]);
            node[i]=num[i];
        }
        tot=0;
        sort(node+1,node+1+n);
        int cnt=unique(node+1,node+n+1)-node-1;
        for(int i=1;i<=n;i++)
        {
            num[i]=lower_bound(node+1,node+cnt+1,num[i])-node;
        }
        int a,b,c;
        for(int i=1;i<=n-1;i++)
        {
            scanf("%d%d",&a,&b);
            G[a].push_back(b);
            G[b].push_back(a);
        }
        build(1,cnt,root[0]);
        dfs(1,0,cnt );
        while(m--)
        {
            scanf("%d%d%d",&a,&b,&c);
            int t=lca(a,b);
            int id=query(root[a],root[b],root[t],root[f[t]],1,cnt,c);
            printf("%d\n",node[id]);
        }
    }
    return 0;
}

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!