图论 test solution

旧巷老猫 提交于 2019-11-27 07:01:09

图论 test solution

T1:潜伏

题目背景

小悠回家之后,跟着母亲看了很多抗日神剧,其中不乏一些谍战片。

题目描述

解放前夕,北平城内潜伏着若干名地下党员,他们居住在城市的不同位置。现在身为地下党第一指挥官的你,想知道地下党员之间的最小通信距离,即从某一地下党员住处前往另一地下党员住处的距离的最小值。

我们可以将北平城视为一张N个点M条边的无向图,每条边连接两个点 ,且长度为\(w_i\)

输入格式

每个测试点包含多组数据。
第一行,给出数据组数 ,之后依次输入每组数据。
每组数据的第一行,N,M,K,分别表示点数,边数,地下党员数。
之后M行,每\(u_i,v_i,w_i\)表示第i条边。
之后一行,K个整数代表地下党员所在结点。
结点编号为1到N,保证N>=K。

输出格式

对于每组数据,输出一行一个整数,表示地下党员之间的最小通信距离。
如果最小通信距离为∞,请输出-1代替。

样例输入

3
5 10 3
1 3 437
1 2 282
1 5 328
1 2 519
1 2 990
2 3 837
2 4 267
2 3 502
3 5 613
4 5 132
1 3 4
10 13 4
1 6 484
1 3 342
2 3 695
2 3 791
2 8 974
3 9 526
4 9 584
4 7 550
5 9 914
6 7 444
6 8 779
6 10 350
8 8 394
9 10 3 7
10 9 4
1 2 330
1 3 374
1 6 194
2 4 395
2 5 970
2 10 117
3 8 209
4 9 253
5 7 864
8 5 10 6

样例输出

437
526
641

数据范围

对于所有测试点,T<=10,时间限制2000ms,内存限制1GB 。
对于50%的测试点,N<=1000,M<=2000。
对于20%的测试点,N<=100000,输无的无向图无环。
对于30%的测试点,N<=100000,M<=200000。
所有 0<=\(w_i\)<=1000。
50pts:

进行k次单源最短路计算(用dijkstra比较好吧,堆优化的话复杂度可以降到O((n+m)logn)*k;

然后Floyd是不行的,会死的很惨。

以下是50pts程序

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cstdlib>
#include<string>
#include<queue>
#define pr pair<int,int>

using namespace std;

const int mxn=100005;
const int mxm=200005;

inline int read(){
    int ans=0;
    char last=' ',ch=getchar();
    while(ch>'9'||ch<'0') last=ch,ch=getchar();
    while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar();
    if(last=='-') ans=-ans;
    return ans;
}

int n,m,k;
int head[mxn],ecnt;
struct edge{
    int to,dis,nxt;
}e[mxm<<1];

void add(int from,int to,int dis){
    ++ecnt;
    e[ecnt].to=to;
    e[ecnt].dis=dis;
    e[ecnt].nxt=head[from];
    head[from]=ecnt;
}

int dis[mxn];
int p[mxn];
bool bo[mxn],vis[mxn];

int dijkstra(int s){
    priority_queue<pr,vector<pr>,greater<pr> > q;
    memset(dis,0x3f,sizeof(dis));
    memset(vis,0,sizeof(vis));
    dis[s]=0;
    q.push(make_pair(dis[s],s));
    int ans=2147483647;
    while(!q.empty()){
        int u=q.top().second;
        q.pop();
        if(vis[u]) continue;
        vis[u]=1;
        for(int i=head[u];i;i=e[i].nxt){
            int v=e[i].to;
            if(dis[v]>dis[u]+e[i].dis){
                dis[v]=dis[u]+e[i].dis;
                if(bo[v]) ans=min(ans,dis[v]);
                q.push(make_pair(dis[v],v));
            }
        }
    }
    return ans;
}

int main(){
    int T,u,v,w;
    
    T=read();
    while(T--){
        int ans=2147483647;
        for(int i=1;i<=n;i++)head[i] =p[i]=bo[i]=0;
        n=read();m=read();k=read();
        for(int i=1;i<=m;i++){
            u=read();v=read();w=read();
            add(u,v,w);add(v,u,w);
        }
        for(int i=1;i<=k;i++){
            p[i]=read();bo[p[i]]=1;
        }
        for(int i=1;i<=k;i++)
            ans=min(ans,dijkstra(p[i]));
        printf("%d\n",ans);
    }
    return 0;
}

对于20%的测试点,N<=100000,输无的无向图无环。

是森林,然后可以考虑入门级树形DP。

但是我不会

100pts:

优化枚举量

显然我们所求答案的两个点是不同的点,所以编号的二进制表示中至少一位不同

考虑枚举二进制每一位:

假设枚举到第i位,把这一位是1的设为源点,0的设为汇点,跑多源多汇最短路(怎么跑?)

(我们可以建一个超级源点,一个超级汇点,将所有第i位为1的都和超级源点连一条长度为0的边,所有第i位为0的点,向超级汇点连一条长度为0的点,那么我们跑多源多汇最短路就变成了求超级源点到超级汇点的最短路。)

多源多汇最短路的源点与终点的选取既可以从1~n中选取,也可以从1~k选取。

显然1~k更优一些

因为int只有32位,所以跑32次即可得到最优解;

#include <queue>
#include <cstdio>
#include <cstring>

template <class cls>
inline cls min(const cls & a, const cls & b) {
    return a < b ? a : b;
}

const int mxn = 100005;
const int mxm = 500005;
const int inf = 0x3f3f3f3f;

int n, m, k;

int points[mxn];

int tot;
int hd[mxn];
int nt[mxm];
int to[mxm];
int vl[mxm];

inline void add_edge(int u, int v, int w) {
    nt[++tot] = hd[u];
    to[tot] = v;
    vl[tot] = w;
    hd[u] = tot;
}

int dis[mxn];

struct data {
    int u, d;
    //u:特工原编号 
    //d: 由源点到终点的最短路径 

    data(int _u, int _d) :
        u(_u), d(_d) {}
    
    bool operator < (const data & that) const {
        return d > that.d;
    }
};

std::priority_queue<data> heap;

int main() {
    int cas;
    scanf("%d", &cas);
    for (int c = 0; c < cas; ++c) {
        scanf("%d%d%d", &n, &m, &k);
        memset(hd, 0, sizeof(int) * (n + 5)); tot = 0;//大概是memset了n+5个位置来减少时间; 
        for (int i = 0, u, v, w; i < m; ++i) {
            scanf("%d%d%d", &u, &v, &w);
            add_edge(u, v, w);
            add_edge(v, u, w);
        }
        for (int i = 0; i < k; ++i)
            scanf("%d", points + i);//指针型写法 
        int ans = inf;
        for (int i = 1; i < k; i <<= 1) { //枚举第i位是1 
        //以下本质是个dijkstra 
            memset(dis, inf, sizeof(int) * (n + 5));
            for (int j = 0, p; j < k; ++j)
                if (p = points[j], (j & i) == 0)//相当于对每个地下党从0~p-1重新编了个号 
                    heap.push(data(p, dis[p] = 0));//判断这k个特工的第i位是0还是1 
            //这里dis[p]=0可以看做是从超级源点到p的距离为0 
            while (!heap.empty()) {
                int u = heap.top().u;
                int d = heap.top().d;
                heap.pop();
                if (dis[u] != d)//和判断vis是相同的 
                    continue;
                for (int e = hd[u], v, w; e; e = nt[e])
                    if (v = to[e], w = vl[e], dis[v] > d + w)
                        heap.push(data(v, dis[v] = d + w));
            }
            for (int j = 0, p; j < k; ++j)//然后求这其中最小的(可以看做是求到超级汇点的最短路) 
                if (p = points[j], (j & i) != 0)
                    ans = min(ans, dis[p]);
        }
        printf("%d\n", ans == inf ? -1 : ans);
    }
    return 0;
}

T2:神经

题目背景

神经网络在人工智能领域大火,所以小悠学习了些许关于神经的知识。

题目描述

神经网络可以表示为N个点M条边的有向图,每个结点带有一个兴奋权值\(w_i\)

如果以x单位大小的电流刺激结点u,设\(v_1,v_2,……,v_k\)是从u出发可以到达的结点,则神经网络会产生\(x×max\{w_{v_1},w_{v_2}……w_{v_k}\}\)的兴奋度,请注意,我们认为从u出发可以到达u。

现在请你回答若干询问,每个询问表示为:在以\(x_i\)大小的电流刺激\(u_i\)点后,网络的兴奋度是多少。

输入格式

每个测试点包含多组数据。

第一行,一个整数T表示数据组数,之后依次输入每组数据。

每组数据第一行,N,M,K分别表示点数,边数,询问次数。

之后一行,N个整数,\(w_1,w_2,…,w_N\)表示每个点的兴奋权值。

之后M行,每行\(u_i,v_i\)表示一条从\(u_i\)\(v_i\)的单向边。

之后K行,每行\(u_i,x_i\)表示一次刺激。

输出格式

对于每个测试点中的每个询问,输出一行一个整数表示答案。

样例输入

3
5 10 5
4 10 8 1 10
1 3
1 4
1 5
1 3
2 1
2 5
4 3
4 3
4 5
5 1
1 4
4 6
1 9
4 7
2 9
5 10 5
2 8 8 10 10
2 1
2 3
3 2
3 4
3 1
3 2
3 4
4 1
5 4
5 1
1 4
2 3
4 7
3 10
1 5
5 10 5
9 9 8 2 1
1 5
1 5
2 1
2 4
2 4
2 4
3 2
3 1
4 3
4 3
5 9
3 9
2 7
5 1
5 4

样例输出

40
60
90
70
90
8
30
70
100
10
9
81
63
14

数据范围

对于所有数据,T<=5,时间限制 1000ms,内存限制1GB 。
对于50%的测试点,N<=1000,M<=2000,K<=1000。
对于20%的测试点,N<=100000,M<=200000,K<=100000,输入的有向图构成DAG。
对于30%的测试点,N<=200000,M<=400000,K<=100000。
所有 \(0<=x,w_i<=10^9\)

50pts:

暴力bfs,注意要开long long

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cstdlib>
#include<string>
#include<queue>
#define ll long long

using namespace std;

inline ll read(){
    ll ans=0;
    char last=' ',ch=getchar();
    while(ch>'9'||ch<'0') last=ch,ch=getchar();
    while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar();
    if(last=='-') ans=-ans;
    return ans;
}
const int mxn=200010;
const int mxm=400010;
ll n,m,k;
ll w[mxn];
ll cnt;
struct node{
    ll to,nxt;
}e[mxm];
ll ecnt,head[mxn];
void add(ll from,ll to){
    ++ecnt;
    e[ecnt].to=to;
    e[ecnt].nxt=head[from];
    head[from]=ecnt;
}

ll maxn[mxn];
bool vis[mxn];

ll solve(ll s){
    memset(vis,0,sizeof(vis));
    queue<ll> q;
    q.push(s);
    vis[s]=1;
    ll ans=w[s];
    while(!q.empty()){
        ll u=q.front();
        q.pop();
        for(int i=head[u];i;i=e[i].nxt){
            ll v=e[i].to;
            if(!vis[v]){
                q.push(v);
                vis[v]=1;
                ans=max(ans,w[v]);
            }
        }
    }
    return ans;
}

int main(){
    ll T,u,v,x;
    T=read();
    while(T--){
        for(int i=1;i<=n;i++)
            w[i]=head[i]=0;
        n=read();m=read();k=read();
        for(int i=1;i<=n;i++) w[i]=read();

        for(int i=1;i<=m;i++) {
            u=read();v=read();
            add(u,v);
        }
            for(int i=1;i<=k;i++){
                u=read();x=read();
                printf("%lld\n",solve(u)*x);
            }
    }
    return 0;
}

2.对于20%的数据,是DAG,拓扑排序,dp。

尝试手写代码rwr:
(啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊啊写不对了)

神仙wz帮忙改,终于改对了(细节锅,awsl);

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cstdlib>
#include<string>
#include<queue>
#define ll long long

using namespace std;

inline ll read(){
    ll ans=0;
    char last=' ',ch=getchar();
    while(ch>'9'||ch<'0') last=ch,ch=getchar();
    while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar();
    if(last=='-') ans=-ans;
    return ans;
}
const int mxn=200010;
const int mxm=400010;
ll n,m,k;
ll w[mxn];
ll cnt;
struct node{
    ll to,nxt;
}e[mxm];
ll ecnt,head[mxn];
void add(ll from,ll to){
    ++ecnt;
    e[ecnt].to=to;
    e[ecnt].nxt=head[from];
    head[from]=ecnt;
}

ll in[mxn];
bool vis[mxn];

ll solve(ll s){
    memset(vis,0,sizeof(vis));
    queue<ll> q;
    q.push(s);
    vis[s]=1;
    ll ans=w[s];
    while(!q.empty()){
        ll u=q.front();
        q.pop();
        for(int i=head[u];i;i=e[i].nxt){
            ll v=e[i].to;
            if(!vis[v]){
                q.push(v);
                vis[v]=1;
                ans=max(ans,w[v]);
            }
        }
    }
    return ans;
}
ll mx[mxn];

void topu(){
    memset(mx,0,sizeof(mx));
    queue<int> q;
    for(int i=1;i<=n;i++){
        if(in[i]==0)
            q.push(i);
        mx[i]=w[i];
    }
    while(!q.empty()){
        int u=q.front();
        q.pop();
        for(int i=head[u];i;i=e[i].nxt){
            int v=e[i].to;
            in[v]--;
            mx[v]=max(mx[v],mx[u]);
            if(in[v]==0) q.push(v);
        }
    }
}

int main(){
    ll T,u,v,x;
    T=read();
    while(T--){
        ecnt=0;
        for(int i=1;i<=n;i++)
            w[i]=head[i]=in[i]=0;
        n=read();m=read();k=read();
        for(int i=1;i<=n;i++) w[i]=read();
        for(int i=1;i<=m;i++) {
            u=read();v=read();
            if(n<=1000) add(u,v);
            else {
                add(v,u);
                in[u]++;
            }
        }
        if(n>1000) topu();
        for(int i=1;i<=k;i++){
            u=read();x=read();
            if(n<=1000) printf("%lld\n",solve(u)*x);
            else 
                printf("%lld\n",mx[u]*x);
        }
    }
    return 0;
}

3.tarjan缩点,变成DAG,然后再拓扑排序dp

然后不知道为什么,lemon把std和自己的代码都卡死了

std:

#include <cstdio>
#include <cstring>

template <class cls>
inline cls min(const cls & a, const cls & b) {
    return a < b ? a : b;
}

template <class cls>
inline cls max(const cls & a, const cls & b) {
    return a > b ? a : b;
}

const int mxn = 200005;
const int mxm = 400005;

int n, m, k, w[mxn];

struct edge {
    int u, v;
} edges[mxm];

int tot;
int hd[mxn];
int to[mxm << 1];
int nt[mxm << 1];

inline void add_edge(int u, int v) {
    nt[++tot] = hd[u];
    to[tot] = v;
    hd[u] = tot;
}

int tim;
int cnt;
int top;
int dfn[mxn];
int low[mxn];
int stk[mxn];
int scc[mxn];

void tarjan(int u) {
    dfn[u] = low[u] = ++tim; stk[++top] = u;
    for (int e = hd[u], v; e; e = nt[e])
        if (v = to[e], scc[v] == 0) {
            if (dfn[v] == 0)tarjan(v),
                low[u] = min(low[u], low[v]);
            else
                low[u] = min(low[u], dfn[v]);
        }
    if (dfn[u] == low[u]) {
        cnt += 1;
        do {
            scc[stk[top]] = cnt;
        } while (stk[top--] != u);
    }
}

int oe[mxn];
int mx[mxn];

int que[mxn];

void bfs() {
    int l = 0, r = 0;
    for (int i = 1; i <= cnt; ++i)
        if (oe[i] == 0)
            que[r++] = i;
    while (l < r) {
        int u = que[l++];
        for (int e = hd[u], v; e; e = nt[e])
            if (v = to[e], mx[v] = max(mx[v], mx[u]), --oe[v] == 0)
                que[r++] = v;
    }
}

int main() {
    int cas;
    scanf("%d", &cas);
    for (int c = 0; c < cas; ++c) {
        scanf("%d%d%d", &n, &m, &k);
        for (int i = 1; i <= n; ++i)
            scanf("%d", w + i);
        memset(hd, 0, sizeof(int) * (n + 5)); tot = 0;
        for (int i = 0; i < m; ++i) {
            scanf("%d%d", &edges[i].u, &edges[i].v);
            add_edge(edges[i].u, edges[i].v);
        }
        tim = cnt = top = 0;
        memset(scc, 0, sizeof(int) * (n + 5));
        memset(dfn, 0, sizeof(int) * (n + 5));
        for (int i = 1; i <= n; ++i)
            if (scc[i] == 0)
                tarjan(i);
        memset(hd, 0, sizeof(int) * (cnt + 5)); tot = 0;
        memset(oe, 0, sizeof(int) * (cnt + 5));
        memset(mx, 0, sizeof(int) * (cnt + 5));
        for (int i = 0; i < m; ++i) {
            int u = scc[edges[i].u];
            int v = scc[edges[i].v];
            if (u != v) 
                add_edge(v, u), oe[u] += 1;
        }
        for (int i = 1; i <= n; ++i)
            mx[scc[i]] = max(mx[scc[i]], w[i]);
        bfs();
        for (int i = 0, u, x; i < k; ++i) {
            scanf("%d%d", &u, &x);
            printf("%lld\n", 1LL * x * mx[scc[u]]);
        }
    }
    return 0;
}

lz's:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cstdlib>
#include<string>
#include<queue>
#define ll long long

using namespace std;

inline ll read(){
    ll ans=0;
    char last=' ',ch=getchar();
    while(ch>'9'||ch<'0') last=ch,ch=getchar();
    while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar();
    if(last=='-') ans=-ans;
    return ans;
}
const int mxn=200010;
const int mxm=400010;
ll n,m,k;
ll w[mxn];
ll cnt;
struct node{
    ll to,nxt;
}e[mxm];
ll ecnt,head[mxn];
void add(ll from,ll to){
    ++ecnt;
    e[ecnt].to=to;
    e[ecnt].nxt=head[from];
    head[from]=ecnt;
}

ll in[mxn],u[mxm],v[mxm];
bool vis[mxn];

int dfn[mxn],low[mxn],stk[mxn<<1];
int tim,top;
int scc[mxn];

void tarjan(int u) {
    dfn[u]=low[u]=++tim;
    stk[++top]=u;
    vis[u]=1;
    for (int i=head[u],v;i;i=e[i].nxt){
        v=e[i].to;
            if (!dfn[v]) {
                tarjan(v);
                low[u]=min(low[u],low[v]);
            }
            else if(vis[v]) low[u]=min(low[u],dfn[v]);
    }
    if (dfn[u]==low[u]) {
        ++cnt;
        do {
            scc[stk[top]]=cnt;
            vis[stk[top]]=0;
        } while (stk[top--]!=u);
    }
}

ll mx[mxn];

void topu(){
    queue<int> q;
    for(int i=1;i<=cnt;i++){
        if(in[i]==0)
            q.push(i);
    }
    while(!q.empty()){
        int u=q.front();
        q.pop();
        for(int i=head[u];i;i=e[i].nxt){
            int v=e[i].to;
            in[v]--;
            mx[v]=max(mx[v],mx[u]);
            if(in[v]==0) q.push(v);
        }
    }
}

void clean(){
    for(int i=1;i<=n+5;i++)
        w[i]=head[i]=in[i]=mx[i]=0;
    for(int i=1;i<=m+5;i++)
        v[i]=u[i]=0;
    ecnt=0;
}

void vclean(){
    top=tim=ecnt=0;
    memset(dfn,0,sizeof(int)*(n+5));
    memset(vis,0,sizeof(int)*(n+5));
    memset(stk,0,sizeof(int)*(n+5));
}

int main(){
    ll T,u_,x_;
    T=read();
    while(T--){
        n=read();m=read();k=read();
        clean();
        for(int i=1;i<=n;i++) w[i]=read();
        for(int i=1;i<=m;i++) {
            u[i]=read();v[i]=read();
            add(u[i],v[i]);
        }
        vclean();
        for(int i=1;i<=n;i++) if(!dfn[i]) tarjan(i);
        for(int i=1;i<=n;i++)
            mx[scc[i]]=max(mx[scc[i]],w[i]);
        memset(head,0,sizeof(head));
        for(int i=1;i<=m;i++){
            if(scc[u[i]]!=scc[v[i]]){
                add(scc[v[i]],scc[u[i]]);
                in[scc[u[i]]]++;
            }
        }
//      cout<<cnt<<endl;
        topu();
        for(int i=1;i<=k;i++){
            u_=read();x_=read();
            printf("%lld\n",mx[scc[u_]]*x_);
        }
    }
    return 0;
}

rqy的神仙做法:读入所有点点权,按点权从大到小排序,反向建图,从权值最大的点出发,覆盖点,然后再从第2大的点出发

T3:计数

题目背景

小悠的导师对于树叶颜色有很深的研究,但是碍于眼神不好,总是要请小悠来帮忙数一数树上某种颜色的叶子有多少片。

题目描述

给出一棵N个结点的树,每个结点初始具有一个颜色\(c_i\),现在有如下两种操作——

  1. 更改结点u的颜色为x。
  2. 询问从u到v的路径上有多少个颜色为x的结点。

现在请你按顺序完成若干次操作。

输入格式

每个测试点包含多组数据。
第一行,一个整数T表示数据组数,之后依次输入每组数据。
每组数据第一行,N,M分别表示点数、操作次数。
之后一行,N个整数\(c_1,c_2,……,c_N\)表示每个点的初始颜色。
之后N-1行,每行两个整数,\(u_i,v_i\),表示树上存在一条连接\(u_i,v_i\)的边。
之后M行,表示每个操作。
如果是操作1,则形如1 u x。
如果是操作2,则形如2 u v x。

输出格式

对于每组数据的每个操作2,输出一行一个整数表示答案。

样例输入

3
5 5
3 2 1 1 1
1 2
2 3
2 5
3 4
2 3 4 1
1 2 1
2 3 5 1
2 1 5 3
2 4 4 3
5 5
1 2 1 2 2
1 2
2 3
2 4
3 5
1 1 2
1 3 2
1 2 2
2 4 2 2
2 1 4 1
5 5
2 1 1 1 1
1 2
1 4
2 3
4 5
2 4 2 1
1 1 1
2 1 4 1
2 3 3 1
2 2 2 2

样例输出

2
3
1
0
2
0
2
2
1
0

数据范围

对于所有测试点,T<=5,时间限制1000ms,内存限制1GB。
对于50%的测试点,N<=1000,M<=1000。
对于20%的测试点,N<=100000,M<=100000,颜色种类数不超过5。
对于30%的测试点,N<=100000,M<=100000。
所有测试点颜色种类数均不超过100。

对于前50%的数据:每次询问bfs一遍:

#include<cstdio>
#include<algorithm>
#include<cstring>
#include<iostream>
#include<cmath>
#include<cstdlib>
#include<string>
#include<queue>

using namespace std;

inline int read(){
    int ans=0;
    char last=' ',ch=getchar();
    while(ch>'9'||ch<'0') last=ch,ch=getchar();
    while(ch>='0'&&ch<='9') ans=(ans<<1)+(ans<<3)+ch-'0',ch=getchar();
    if(last=='-') ans=-ans;
    return ans;
}

const int mxn=100010;
const int mxm=200010;

struct node{
    int to,nxt;
}e[mxm];
int ecnt,head[mxn];
void add(int from,int to){
    ++ecnt;
    e[ecnt].to=to;
    e[ecnt].nxt=head[from];
    head[from]=ecnt;
}
struct Node{
    int cnt,u;  
};
Node _n(int cnt,int u){
    Node rtn;
    rtn.cnt=cnt;
    rtn.u=u;
    return rtn;
}

bool vis[mxn];
int c[mxn],n,m,u,v,x;

int solve(int u,int k,int x){
    memset(vis,0,sizeof(vis));
    queue<Node> q;
    if(c[u]==x) q.push({1,u});
    else q.push({0,u});
    vis[u]=1;
    while(!q.empty()){
        Node z=q.front();
        q.pop();
        for(int i=head[z.u];i;i=e[i].nxt){
            int v=e[i].to;
            if(v==k){
                if(c[v]==x) return z.cnt+1;
                else return z.cnt;
            }
            if(!vis[v]){
                vis[v]=1;
                if(c[v]==x) q.push(_n(z.cnt+1,v));
                else q.push({z.cnt,v});
            }
        }
    }
}

int main(){
//  freopen("count.in","r",stdin);
//  freopen("count.out","w",stdout);
    int T,opt;
    T=read();
    while(T--){
        for(int i=1;i<=n;i++)
            head[i]=c[i]=0;
        n=read();
        m=read();
        for(int i=1;i<=n;i++)
            c[i]=read();
        for(int i=1;i<n;i++){
            u=read();v=read();
            add(u,v);
            add(v,u);
        }
        for(int i=1;i<=m;i++){
            opt=read();
            if(opt==1){
                u=read();x=read();
                c[u]=x;
            }
            else {
                u=read();v=read();x=read();
                if(u==v){
                    if(c[u]==x) 
                        printf("1\n");
                    else printf("0\n");
                }
                else 
                printf("%d\n",solve(u,v,x));
            }
        }
    }
    return 0;
}

对于剩下20%的数据,树链剖分,开五棵线段树,颜色为1的就在第一棵线段树的位置+1,颜色由1到2,2+1,1-1;

对于100%的数据,暴力:开100个树状数组,和刚才没什么区别
如果线段树在每一个节点上维护一个100的数组
合并的时候可以直接暴力统计节点次数,这样代价是区间长度
如果每一位枚举则是n*100
每一层访问的点是n的,一共log层
o(nlogn)
离线操作
-1和+1分别隶属于x和y棵线段树
把操作分类,每一次处理每一棵的线段树
有多少个颜色就有多少棵
所有操作次数相加就是2m
所以操作还是o(m)

#include <cstdio>
#include <cstring>
#include <iostream>
#include <algorithm>

using namespace std;

inline int getint()
{
    int r = 0, c = getchar();
    
    for (; c < 48; c = getchar());
    for (; c > 47; c = getchar())
        r = r * 10 + c - 48;
    
    return r;
}

const int mxc = 100005;
const int mxn = 100005;
const int mxm = 200005;

int n, m, c;

int tt;
int hd[mxn];
int to[mxm];
int nt[mxm];

inline void addedge(int x, int y)
{
    nt[++tt] = hd[x], to[tt] = y, hd[x] = tt;
    nt[++tt] = hd[y], to[tt] = x, hd[y] = tt;
}

struct data
{
    int k, x, y;
    
    data() {} ;
    data(int a, int b, int c)
        : k(a), x(b), y(c) {} ;
};

int color[mxn];

#include <vector>

vector<data> vec[mxc];

int tim;
int dfn[mxn];
int top[mxn];
int fat[mxn];
int dep[mxn];
int son[mxn];
int siz[mxn];

void dfs1(int u, int f)
{
    siz[u] = 1;
    son[u] = 0;
    fat[u] = f;
    dep[u] = dep[f] + 1;
    
    for (int i = hd[u], v; i; i = nt[i])
        if (v = to[i], v != f)
        {
            dfs1(v, u);
            siz[u] += siz[v];
            if (siz[v] > siz[son[u]])
                son[u] = v;
        }
}

void dfs2(int u, int f)
{
    dfn[u] = ++tim;
    
    if (son[f] == u)
        top[u] = top[f];
    else
        top[u] = u;
    
    if (son[u])
        dfs2(son[u], u);
    
    for (int i = hd[u], v; i; i = nt[i])
        if (v = to[i], v != f && v != son[u])
            dfs2(v, u);
}

int bit[mxn];

inline void add(int p, int v)
{
    for (; p <= n; p += p & -p)
        bit[p] += v;
}

inline int ask(int l, int r)
{
    int sum = 0; --l;
    
    for (; r; r -= r & -r)
        sum += bit[r];
    
    for (; l; l -= l & -l)
        sum -= bit[l];
    
    return sum;
}

int ans[mxn];

signed main()
{
    int cas = getint();
    
    while (cas--) 
    {
        n = getint();
        m = getint();
        
        for (int i = 1; i <= n; ++i)
            vec[color[i] = getint()].push_back(data(0, i, +1));
        
        c = 0;
        
        for (int i = 1; i <= n; ++i)
            c = max(c, color[i]);

        memset(hd, 0, sizeof(int) * (n + 5)); tt = 0;
        
        for (int i = 1; i < n; ++i)
        {
            int x = getint();
            int y = getint();
            
            addedge(x, y);
        }
        
        for (int i = 1; i <= m; ++i)
        {
            if (getint() == 1)
            {
                int p = getint();
                int a = color[p];
                int b = color[p] = getint();
                
                vec[a].push_back(data(0, p, -1));
                vec[b].push_back(data(0, p, +1));
            }
            else
            {
                int x = getint();
                int y = getint();
                int k = getint();
                
                vec[k].push_back(data(i, x, y));
            }
        }
        
        dfs1(1, 0);
        dfs2(1, 0);
        
        memset(ans, -1, sizeof ans);
        
        for (int k = 1; k <= c; ++k)
        {
            int sz = vec[k].size();
            
            memset(bit, 0, sizeof bit);
            
            for (int i = 0; i < sz; ++i)
            {
                const data &d = vec[k][i];
                
                ans[d.k] = 0;
                
                if (d.k == 0)
                    add(dfn[d.x], d.y);
                else
                {
                    int a = d.x, ta = top[a];
                    int b = d.y, tb = top[b];
                    
                    while (ta != tb)
                    {
                        if (dep[ta] >= dep[tb])
                            ans[d.k] += ask(dfn[ta], dfn[a]), ta = top[a = fat[ta]];
                        else
                            ans[d.k] += ask(dfn[tb], dfn[b]), tb = top[b = fat[tb]];
                    }
                    
                    if (dep[a] <= dep[b])
                        ans[d.k] += ask(dfn[a], dfn[b]);
                    else
                        ans[d.k] += ask(dfn[b], dfn[a]);
                }
            }
        }
        
        for (int i = 1; i <= m; ++i)
            if (ans[i] >= 0)
                printf("%d\n", ans[i]);

        for (int i = 1; i <= c; ++i)
            vec[i].clear();
        
        tim = 0;
    }
    
    return 0;
}
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!