题意:
给出一个有n个数的数列,并定义mex(l, r)表示数列中第l个元素到第r个元素中第一个没有出现的最小非负整数。
求出这个数列中所有mex的值。
思路:
可以看出对于一个数列,mex(r, r~l)是一个递增序列
mex(0, 0~n-1)是很好求的,只需要遍历找出第一个没有出现的最小非负整数就好了。这里有一个小技巧:
1 tmp = 0; 2 for (int i = 1; i <= n; ++i) { 3 mp[arr[i]] = 1; 4 while (mp.find(tmp) != mp.end()) tmp++; 5 mex[i] = tmp; 6 }
这样可以利用map中的红黑树很快找到第一个没有出现的最小非负整数。
然后在求mex(1~n-1, 0~n-1)的过程中,我们可以看出,每消除当前值arr[i],会影响到的是在下一个arr[i]出现前 往后的mex值中比arr[i]大的值,即如果当前这个值不存在了,那么在这个值下一次出现前,mex值比当前值大的mex值都应被替换成arr[i]。
所以我们可以再一次利用map的红黑树找到当前值下一次出现的位置,然后利用线段树成段更新往后的mex值和求出会影响到的mex值的个数。
1 for (int i = n; i >= 1; --i) { 2 if (mp.find(arr[i]) == mp.end()) next[i] = n+1; 3 else next[i] = mp[arr[i]]; 4 mp[arr[i]] = i; 5 }
这里我们还需要利用线段树求出第一个比当前arr[i]大的mex值的位置,以便成段更新区间的mex值。
Tips:
※ 这里有一个小小优化的地方,就是当更新的时候,可以先查看mx[1]是否比当前值大,如果是,则表示往后的区间里有比当前值大的mex值,则需要线段树是需要更新的,否则不用更新。
※ 还有一个要注意的地方是:只有求出的左边界比右边界小的时候才能更新。
Code:
1 #include <stdio.h> 2 #include <cstring> 3 #include <map> 4 #include <algorithm> 5 using namespace std; 6 7 const int MAXN = 200010; 8 long long sum[MAXN<<2]; 9 int mx[MAXN<<2], arr[MAXN], next[MAXN], mex[MAXN]; 10 int lazy[MAXN<<2]; 11 12 void Pushup(int rt) 13 { 14 sum[rt] = sum[rt<<1]+sum[rt<<1|1]; 15 mx[rt] = max(mx[rt<<1], mx[rt<<1|1]); 16 } 17 18 void Pushdown(int rt, int x) 19 { 20 if (lazy[rt] != -1) { 21 lazy[rt<<1] = lazy[rt<<1|1] = lazy[rt]; 22 sum[rt<<1] = (x-x/2)*lazy[rt]; 23 sum[rt<<1|1] = x/2*lazy[rt]; 24 mx[rt<<1] = mx[rt<<1|1] = lazy[rt]; 25 lazy[rt] = -1; 26 } 27 } 28 29 void Creat(int l, int r, int rt) 30 { 31 lazy[rt] = -1; 32 if (l == r) { 33 sum[rt] = mx[rt] = mex[l]; 34 return; 35 } 36 int mid = (l+r)/2; 37 Creat(l, mid, rt<<1); 38 Creat(mid+1, r, rt<<1|1); 39 Pushup(rt); 40 } 41 42 void Modify(int l, int r, int x, int L, int R, int rt) 43 { 44 if (l <= L && r >= R) { 45 lazy[rt] = x; 46 sum[rt] = x*(R-L+1); 47 mx[rt] = x; 48 return; 49 } 50 Pushdown(rt, R-L+1); 51 int mid = (L+R)/2; 52 if (l <= mid) Modify(l, r, x, L, mid, rt<<1); 53 if (r > mid) Modify(l, r, x, mid+1, R, rt<<1|1); 54 Pushup(rt); 55 } 56 57 int Get(int rt, int l, int r, int x) 58 { 59 if(l == r) return l; 60 Pushdown(rt, r-l+1); 61 int mid = (l+r)/2; 62 if (mx[rt<<1] > x) return Get(rt<<1, l, mid, x); 63 else return Get(rt<<1|1, mid+1, r, x); 64 } 65 66 int main() 67 { 68 //freopen("in.txt", "r", stdin); 69 int n, tmp; 70 long long ans_sum; 71 map<int, int> mp; 72 while (~scanf("%d", &n)) { 73 if (n == 0) break; 74 ans_sum = tmp = 0; 75 mp.clear(); 76 memset(sum, 0, sizeof(sum)); 77 memset(next, 0, sizeof(next)); 78 79 for (int i = 1; i <= n; ++i) 80 scanf("%d", &arr[i]); 81 for (int i = 1; i <= n; ++i) { 82 mp[arr[i]] = 1; 83 while (mp.find(tmp) != mp.end()) tmp++; 84 mex[i] = tmp; 85 } 86 87 Creat(1, n, 1); 88 mp.clear(); 89 for (int i = n; i >= 1; --i) { 90 if (mp.find(arr[i]) == mp.end()) next[i] = n+1; 91 else next[i] = mp[arr[i]]; 92 mp[arr[i]] = i; 93 } 94 95 for (int i = 1; i <= n; ++i) { 96 ans_sum += sum[1]; 97 if (mx[1] > arr[i]) { 98 int l = Get(1, 1, n, arr[i]); 99 int r = next[i]; 100 // printf("%d %d %d\n", l, r, sum[1]); 101 if (l < r) Modify(l, r-1, arr[i], 1, n, 1); 102 } 103 104 Modify(i, i, 0, 1, n, 1); 105 } 106 printf("%I64d\n", ans_sum); 107 } 108 return 0; 109 }
链接:http://acm.hdu.edu.cn/showproblem.php?pid=4747
来源:https://www.cnblogs.com/Griselda/p/3433595.html