虚树

Deadly 提交于 2020-12-17 02:10:17

虚树做法

虚树是把原树中少量的有效节点和他们两两的LCA拿出来,去除一些无效节点,从而降低复杂度。

如果有效节点是K个的话 加上LCA 虚树中的的点最多为min(n,2*K)个

建虚树的复杂度为O(k*logn) 虚树上树形dp的复杂度为O(k)

证明:

按dfn值从小到大加入有效节点

假设现在加入的节点为X 前面的节点为Y 再前面的节点为Z

如果X与Y的LCA为LCA1 X与Z的LCA为LCA2

LCA1与LCA2里面深度浅的为LCA3 则Y与Z的LCA就为LCA3 是之前存在过的

所以每加入一个点 多出来的LCA只会是一个

建树步骤:

1.dfs得到节点的必要信息 比如dfn deep fa[i][j] minedge[i][j]/maxedge[i][j].

2.建立一个栈 表示从根要栈顶这条链 每次新加入的节点为X 栈顶为P 则LCA为LCA(X,P)

然后就有两种情况 第一种是X与P不在一条链上  另一种是X与P在一条链上

则dfn(LCA(X,P))<=dfn(P)<dfn(X)

对于第一种情况要仔细讨论 第二种情况直接把X压入栈就行了

第一种情况表示P以下包括P的子树已经没有有效的点了 因为我们是按照dfn的大小来遍历的

void insert(int x){
    if (!top){
        st[++top]=x;
        return;
    }
    int ll=lca(st[top],x);
    while (dep[st[top-1]]>dep[ll]&&top>1){
        add(st[top-1],st[top],dist(st[top-1],st[top]));
        top--;
    }
    if (dep[ll]<dep[st[top]]){
        add(ll,st[top],dist(ll,st[top]));
        top--; 
    }
    if (!top||dep[st[top]]<dep[ll]) st[++top]=ll;
    st[++top]=x;
}

P2495 [SDOI2011]消耗战

O(n)dp O(n*m)的做法

dp中对于要删除的点 如果能到达此点的话当前最佳情况为删除当前点与父亲节点之间的边

对于不用删除的点 有两种删法 一种是所有子树中需要删除的点返回值的和 一种是直接删除该点与父亲节点之间的边

/*Huyyt*/
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const double eps = 1e-8;
const int dir[8][2] = {{0, 1}, {1, 0}, {0, -1}, { -1, 0}, {1, 1}, {1, -1}, { -1, -1}, { -1, 1}};
const int mod = 1e9 + 7, gakki = 5 + 2 + 1 + 19880611 + 1e9;
const int MAXN = 3e5 + 5, MAXM = 3e5 + 5;
const int MAXQ = 100010, INF = 1e9;
int to[MAXM << 1], nxt[MAXM << 1], Head[MAXN], tot = 1;
ll cost[MAXM << 1];
inline void addedge(int u, int v, ll c)
{
        to[++tot] = v;
        nxt[tot] = Head[u];
        cost[tot] = c;
        Head[u] = tot;
}
inline void read(int &v)
{
        v = 0;
        char c = 0;
        int p = 1;
        while (c < '0' || c > '9')
        {
                if (c == '-')
                {
                        p = -1;
                }
                c = getchar();
        }
        while (c >= '0' && c <= '9')
        {
                v = (v << 3) + (v << 1) + c - '0';
                c = getchar();
        }
        v *= p;
}
int k[500005];
bool del[500005];
ll ans = 0;
ll faedge[500005];
void dfs1(int x, int fa)
{
        for (int v, i = Head[x]; i; i = nxt[i])
        {
                if (v = to[i], v != fa)
                {
                        faedge[v] = cost[i];
                        dfs1(v, x);
                }
        }
}
ll get_ans(int x, int fa)
{
        ll anser = 0;
        for (int v, i = Head[x]; i; i = nxt[i])
        {
                v = to[i];
                if (v != fa)
                {
                        anser += get_ans(v, x);
                }
                //cout<<x<<" "<<v<<" "<<anser<<endl;
        }
        if (del[x])
        {
                return faedge[x];
        }
        return min(faedge[x], anser);
}
int main()
{
        ios_base::sync_with_stdio(0);
        cin.tie(0);

        //freopen("jqkout.txt", "w", stdout);
        int n;
        int u, v, c;
        read(n);
        for (int i = 1; i <= n - 1; i++)
        {
                read(u), read(v), read(c);
                addedge(u, v, c);
                addedge(v, u, c);
        }
        dfs1(1, -1);
        faedge[1] = LLONG_MAX;
        int m, number;
        read(m);
        while (m--)
        {
                ans = 0;
                read(number);
                for (int i = 1; i <= number; i++)
                {
                        read(k[i]);
                        del[k[i]] = true;
                }
                cout << get_ans(1, -1) << endl;
                for (int i = 1; i <= number; i++)
                {
                        del[k[i]] = false;
                }
        }
        return 0;
}
View Code

虚树做法

/*Huyyt*/
#include<bits/stdc++.h>
#define mem(a,b) memset(a,b,sizeof(a))
using namespace std;
typedef long long ll;
typedef unsigned long long ull;
const double eps = 1e-8;
const int dir[8][2] = {{0, 1}, {1, 0}, {0, -1}, { -1, 0}, {1, 1}, {1, -1}, { -1, -1}, { -1, 1}};
const int mod = 1e9 + 7, gakki = 5 + 2 + 1 + 19880611 + 1e9;
const int MAXN = 3e5 + 5, MAXM = 3e5 + 5, MAXQ = 100010, INF = 1e9;
const ll LLINF = (1LL << 50);
int to[MAXM << 1], nxt[MAXM << 1], Head[MAXN], tot = 1;
ll cost[MAXM << 1];
struct node
{
        int u, v;
        ll c;
} edge[MAXM << 1];
inline void addedge(int u, int v, int c)
{
        if (u == v)
        {
                return ;
        }
        to[++tot] = v;
        nxt[tot] = Head[u];
        cost[tot] = c;
        Head[u] = tot;
}
template <typename T> inline void read(T&x)
{
        char cu = getchar();
        x = 0;
        bool fla = 0;
        while (!isdigit(cu))
        {
                if (cu == '-')
                {
                        fla = 1;
                }
                cu = getchar();
        }
        while (isdigit(cu))
        {
                x = x * 10 + cu - '0', cu = getchar();
        }
        if (fla)
        {
                x = -x;
        }
}
int n;
int jumpfa[MAXN][20];
int jumpdis[MAXN][20];
int dfn[MAXN], deep[MAXN];
int cnt = 0;
int xutot = 0;
int xusta[MAXN];
int k[MAXN];
bool del[MAXN];
ll minedge[MAXN];
ll finalans[MAXN];
void dfs1(int x, int fa)
{
        jumpfa[x][0] = fa;
        dfn[x] = ++cnt;
        for (int i = 1; i <= 18; i++)
        {
                jumpfa[x][i] = jumpfa[jumpfa[x][i - 1]][i - 1];
        }
        for (int v, i = Head[x]; i; i = nxt[i])
        {
                if (v = to[i], v != fa)
                {
                        deep[v] = deep[x] + 1;
                        minedge[v] = min(minedge[x], cost[i]);
                        dfs1(v, x);
                }
        }
}
inline int lca(int x, int y)
{
        if (deep[x] < deep[y])
        {
                swap(x, y);
        }
        int t = 0;
        while ((1 << t) <= deep[x])
        {
                t++;
        }
        t--;
        for (int i = t; i >= 0; i--)
        {
                if (deep[x] - (1 << i) >= deep[y])
                {
                        x = jumpfa[x][i];
                }
        }
        if (x == y)
        {
                return x;
        }
        for (int i = t; i >= 0; i--)
        {
                if (jumpfa[x][i] != jumpfa[y][i])
                {
                        x = jumpfa[x][i], y = jumpfa[y][i];
                }
        }
        return jumpfa[x][0];
}
inline void get_ans(int x, int fa)
{
        ll anser = 0;
        //cout << x << " del " << del[x] <<" faedge "<<faedge[x]<< endl;
        finalans[x] = minedge[x];
        for (int i = Head[x]; i; i = nxt[i])
        {
                if (to[i] != fa)
                {
                        get_ans(to[i], x);
                        anser += finalans[to[i]];
                        //cout << x << " to " << to[i] << " value " << anser << endl;
                }
        }
        Head[x] = 0;
        if (!anser)
        {
                finalans[x] = minedge[x];
        }
        else
        {
                finalans[x] = min(finalans[x], anser);
        }
}
inline bool cmp(int a, int b)
{
        return dfn[a] < dfn[b];
}
void init()
{
        minedge[1] = LLONG_MAX;
        //mem(jumpdis, 0x3f3f3f3f);
        //minedge[1] = INT_MAX;
        //faedge[1] = INT_MAX;
        deep[1] = 1;
}
int main()
{
        ios_base::sync_with_stdio(0);
        cin.tie(0);

        //freopen("jqkout.txt", "w", stdout);
        int u, v, c;
        read(n);
        for (int i = 1; i <= n - 1; i++)
        {
                read(u), read(v), read(c);
                edge[i].u = u, edge[i].v = v, edge[i].c = c;
                addedge(u, v, c);
                addedge(v, u, c);
        }
        init();
        dfs1(1, 0);
        //cout << lcadist(5, 6) << endl;
        //        for (int i = 1; i <= n; i++)
        //        {
        //                cout << i << " " << dfn[i] << endl;
        //        }
        mem(Head, 0);
        int m, number;
        read(m);
        while (m--)
        {
                tot = 1;
                int nowtot = 1;
                xutot = 0;
                read(number);
                for (int i = 1; i <= number; i++)
                {
                        read(k[i]);
                }
                sort(k + 1, k + number + 1, cmp);
                for (int i = 2; i <= number; i++)
                {
                        if (lca(k[i], k[nowtot]) != k[nowtot])
                        {
                                k[++nowtot] = k[i];
                        }
                }
                //                for(int i=1;i<=nowtot;i++)
                //                {
                //                        cout<<k[i]<<" ";
                //                }
                //                cout<<endl;
                number = nowtot;
                xusta[++xutot] = 1;
                for (int i = 1; i <= number; i++)
                {
                        int grand = lca(xusta[xutot], k[i]);
                        //cout<<"grand "<<grand<<endl;
                        while (1)
                        {
                                if (deep[xusta[xutot - 1]] <= deep[grand]) //分别处在两个子树,grand深度更大!!!
                                {
                                        //cout << "add" << grand << " " << xusta[xutot] << endl;
                                        addedge(grand, xusta[xutot], 0);
                                        xutot--;
                                        if (xusta[xutot] != grand)
                                        {
                                                xusta[++xutot] = grand;
                                        }
                                        break;
                                }
                                //cout << "add" <<  xusta[xutot - 1] << " " << xusta[xutot] << endl;
                                addedge(xusta[xutot - 1], xusta[xutot], 0);
                                xutot--;
                        }
                        if (xusta[xutot] != k[i])
                        {
                                xusta[++xutot] = k[i];        //在同一子树
                        }
                }
                xutot--;
                while (xutot)
                {
                        //cout << "add" <<  xusta[xutot] << " " << xusta[xutot - 1] << endl;
                        addedge(xusta[xutot], xusta[xutot + 1], 0);
                        xutot--;
                }
                get_ans(1, 0);
                cout << finalans[1] << endl;
        }
        return 0;
}
View Code

 

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!