题意是,有$n$个石头,每个石头有初始能量$E_i$,每秒能量增长$L_i$,以及能量上限$C_i$,有$m$个收能量的时间点,每次把区间$\left[S_i, T_i\right]$石头的能量都给收掉,石头的能量都置零重新开始增长。问最后收了多少能量。
看完题解觉得好有道理...我好菜...
考虑每个石头在多少个时间点收能量,然后每次收的能量就和这些时间点的时间间隔有关。
若时间间隔大于等于$\dfrac {C_i}{L_i}$,那么这一段对答案的贡献就是$C_i$了,统计有多少这样的段即可。
若时间间隔小于$\dfrac {C_i}{L_i}$那么对答案的贡献就是时间长度$t \times L_i$。
用两个权值树状数组可以维护对应时间长度的和及个数。
时间点可以用set维护。从前到后遍历,遇到一个$S_i$就把对应是时间加入,遇到一个$T_i + 1$就把时间删去,同时维护树状数组即可。感觉看代码就很好懂?
#include <bits/stdc++.h> #define ll long long using namespace std; const int N = 2e5 + 7; int n, m; ll E[N], C[N], L[N]; set<int> st; vector<int> G[N]; struct BIT { ll tree1[N], tree2[N]; inline void clear() { memset(tree1, 0, sizeof tree1); memset(tree2, 0, sizeof tree2); } inline int lowbit(int x) { return x & -x; } inline void add(int x, int val) { if (!x) return; for (int i = x; i < N; i += lowbit(i)) { if (val > 0) tree1[i]++; else tree1[i]--; tree2[i] += val; } } inline int query1(int x) { int ans = 0; for (int i = x; i; i -= lowbit(i)) ans += tree1[i]; return ans; } inline int query2(int x) { int ans = 0; for (int i = x; i; i -= lowbit(i)) ans += tree2[i]; return ans; } } bit; inline void init() { st.clear(); bit.clear(); for (int i = 0; i <= n; i++) G[i].clear(); } void add(int x) { if (st.empty()) { st.insert(x); return; } auto p = st.lower_bound(x); if (p == st.begin()) { bit.add((*p - x), (*p - x)); st.insert(x); return; } if (p == st.end()) { bit.add(x - (*prev(p)), x - (*prev(p))); st.insert(x); return; } int x1 = (*p) - x, x2 = x - (*prev(p)); bit.add(x1, x1); bit.add(x2, x2); bit.add(x1 + x2, -x1 - x2); st.insert(x); } void del(int x) { auto p = st.find(x); if (st.size() == 1) { st.erase(p); return; } if (p == st.begin()) { bit.add((*next(p)) - x, x -(*next(p))); st.erase(p); return; } if (p == prev(st.end())) { bit.add(x - (*prev(p)), (*prev(p)) - x); st.erase(p); return; } int x1 = (*next(p)) - x, x2 = x - (*prev(p)); bit.add(x1, -x1); bit.add(x2, -x2); bit.add(x1 + x2, x1 + x2); st.erase(p); } int main() { int T, kase = 0; scanf("%d", &T); while (T--) { scanf("%d", &n); init(); for (int i = 1; i <= n; i++) scanf("%lld%lld%lld", &E[i], &L[i], &C[i]); scanf("%d", &m); for (int i = 1; i <= m; i++) { int l, r, t; scanf("%d%d%d", &t, &l, &r); G[l].push_back(t); G[r + 1].push_back(-t); } ll ans = 0; for (int i = 1; i <= n; i++) { for (auto x: G[i]) { if (x > 0) add(x); else del(-x); } if (st.empty()) continue; ans += min(C[i], 1LL * (*st.begin()) * L[i] + E[i]); if (!L[i]) continue; ans += (st.size() - 1 - bit.query1(C[i] / L[i])) * C[i] + bit.query2(C[i] / L[i]) * L[i]; } printf("Case #%d: %lld\n", ++kase, ans); } return 0; }