牛客练习赛42 C 出题的诀窍
链接:https://ac.nowcoder.com/acm/contest/393/C来源:牛客网
题目描述
给定m个长为n的序列a1,a2,…,ama_1 , a_2 , \dots , a_ma1,a2,…,am。
小Z想问你:
其中SUM(一个序列)\texttt{SUM}(\text{一个序列})SUM(一个序列)表示这个序列中所有不同的数的和,相当于先sort,unique\tt sort,uniquesort,unique再求和。
输入描述:
第一行两个整数n,m。接下来m行,每行n个整数,第i行第j个表示ai,ja_{i,j}ai,j
输出描述:
一行一个整数,表示答案。
示例1
输入
2 3 1 2 2 3 1 3
输出
36
题意
就是求题面中给定的公式。
思路:
计算贡献的题目。
把所有的数放入一个集合S(去重)
那么集合S中的每一个元素x,对答案的贡献就是x*num,num为含有x的一组数的个数
那么如何求num呢?
\(num=n^m-cnt\)
cnt为不含有x的一组数的个数
那么只需要m行,每一行中(n-x的个数)乘起来即可。
对于那些不含有x的行。我们用预处理n的幂次来解决。
并且这题比较卡常,
需要用快速读入+pbds的hash来离散化。
能用int的地方不要用longlong
代码:
#include <bits/stdc++.h> #include <cstdio> #include<ext/pb_ds/assoc_container.hpp> #include<ext/pb_ds/hash_policy.hpp> #include <cstring> #include <algorithm> #include <cmath> #include <queue> #include <stack> #include <map> #include <set> #include <vector> #include <iomanip> #define ALL(x) (x).begin(), (x).end() #define sz(a) int(a.size()) #define rep(i,x,n) for(int i=x;i<n;i++) #define repd(i,x,n) for(int i=x;i<=n;i++) #define pii pair<int,int> #define pll pair<long long ,long long> #define gbtb ios::sync_with_stdio(false),cin.tie(0),cout.tie(0) #define MS0(X) memset((X), 0, sizeof((X))) #define MSC0(X) memset((X), '\0', sizeof((X))) #define pb push_back #define mp make_pair #define fi first #define se second #define eps 1e-6 #define gg(x) getInt(&x) #define chu(x) cout<<"["<<#x<<" "<<(x)<<"]"<<endl #define du3(a,b,c) scanf("%d %d %d",&(a),&(b),&(c)) #define du2(a,b) scanf("%d %d",&(a),&(b)) #define du1(a) scanf("%d",&(a)); using namespace std; typedef long long ll; ll gcd(ll a, ll b) {return b ? gcd(b, a % b) : a;} ll lcm(ll a, ll b) {return a / gcd(a, b) * b;} ll powmod(ll a, ll b, ll MOD) {a %= MOD; if (a == 0ll) {return 0ll;} ll ans = 1; while (b) {if (b & 1) {ans = ans * a % MOD;} a = a * a % MOD; b >>= 1;} return ans;} void Pv(const vector<int> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%d", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}} void Pvl(const vector<ll> &V) {int Len = sz(V); for (int i = 0; i < Len; ++i) {printf("%lld", V[i] ); if (i != Len - 1) {printf(" ");} else {printf("\n");}}} inline void getInt(int* p); const int maxn = 4000010; const int inf = 0x3f3f3f3f; /*** TEMPLATE CODE * * STARTS HERE ***/ namespace IO { #define BUF_SIZE 100000 #define OUT_SIZE 100000 #define ll long long //fread->read bool IOerror = 0; inline char nc() { static char buf[BUF_SIZE], *p1 = buf + BUF_SIZE, *pend = buf + BUF_SIZE; if (p1 == pend) { p1 = buf; pend = buf + fread(buf, 1, BUF_SIZE, stdin); if (pend == p1) {IOerror = 1; return -1;} //{printf("IO error!\n");system("pause");for (;;);exit(0);} } return *p1++; } inline bool blank(char ch) {return ch == ' ' || ch == '\n' || ch == '\r' || ch == '\t';} inline void read(int &x) { bool sign = 0; char ch = nc(); x = 0; for (; blank(ch); ch = nc()); if (IOerror)return; if (ch == '-')sign = 1, ch = nc(); for (; ch >= '0' && ch <= '9'; ch = nc())x = x * 10 + ch - '0'; if (sign)x = -x; } //fwrite->write struct Ostream_fwrite { char *buf, *p1, *pend; Ostream_fwrite() {buf = new char[BUF_SIZE]; p1 = buf; pend = buf + BUF_SIZE;} void out(char ch) { if (p1 == pend) { fwrite(buf, 1, BUF_SIZE, stdout); p1 = buf; } *p1++ = ch; } void print(int x) { static char s[15], *s1; s1 = s; if (!x)*s1++ = '0'; if (x < 0)out('-'), x = -x; while (x)*s1++ = x % 10 + '0', x /= 10; while (s1-- != s)out(*s1); } void print(char *s) {while (*s)out(*s++);} } Ostream; inline void print(int x) {Ostream.print(x);} inline void print(char *s) {Ostream.print(s);} }; using namespace IO; int a[2005][2005]; int b[2005][2005]; int n, m; ll base; const ll mod = 1000000007ll; ll ans; int vis[maxn]; __gnu_pbds::gp_hash_table<int, int> w; bool wvis[maxn]; bool solved[maxn]; int cnt[maxn]; int p[5000]; int id = 0; __gnu_pbds::gp_hash_table<int, int> lsh; int main() { read(n); read(m); repd(i, 1, m) { repd(j, 1, n) { read(a[i][j]); } } repd(i, 1, m) { repd(j, 1, n) { int q = lsh[a[i][j]]; if (q == 0) { lsh[a[i][j]] = ++id; b[i][j] = id; } else { b[i][j] = q; } vis[b[i][j]] = 1ll; } } repd(i, 1, m) { repd(j, 1, n) { w[b[i][j]] += 1; } repd(j, 1, n) { if (wvis[b[i][j]] == 0) { wvis[b[i][j]] = 1; vis[b[i][j]] = 1ll * vis[b[i][j]] * (n - w[b[i][j]]) % mod; cnt[b[i][j]]++; } } repd(j, 1, n) { wvis[b[i][j]] = 0; w[b[i][j]] -= 1; } } base = powmod(n, m, mod); p[0] = 1ll; repd(i, 1, n) { p[i] = (1ll * p[i - 1] * n) % mod; } repd(i, 1, m) { repd(j, 1, n) { if (solved[b[i][j]] == 0) { solved[b[i][j]] = 1; vis[b[i][j]] = (1ll * vis[b[i][j]] * p[ m - cnt[b[i][j]]]) % mod; ans = (ans + ( base - vis[b[i][j]] + mod) % mod * a[i][j] % mod) % mod; } } } printf("%lld\n", ans); return 0; } inline void getInt(int* p) { char ch; do { ch = getchar(); } while (ch == ' ' || ch == '\n'); if (ch == '-') { *p = -(getchar() - '0'); while ((ch = getchar()) >= '0' && ch <= '9') { *p = *p * 10 - ch + '0'; } } else { *p = ch - '0'; while ((ch = getchar()) >= '0' && ch <= '9') { *p = *p * 10 + ch - '0'; } } }