题目内容
给定一个长度为 n n n 的数组 a a a ,每次操作可以选择一个数 x x x ,将所有大于 x x x 的数都下降为 x x x ,一次操作的下降总代价为 s s s ,要求 s ≤ k s\leq k s≤k ,问需要多少次操作使得数组 a a a 的所有数都相同。
数据范围
- 1 ≤ n ≤ 2 ⋅ 1 0 5 1\leq n\leq 2\cdot 10^5 1≤n≤2⋅105
- n ≤ k ≤ 1 0 9 n\leq k\leq 10^9 n≤k≤109
- 1 ≤ a i ≤ 2 ⋅ 1 0 5 1\leq a_i\leq 2\cdot 10^5 1≤ai≤2⋅105
题解
朴素做法
首先使得所有数都相同,则使得所有值都成为 m i n ( a ) min(a) min(a) 。
所以考虑怎么将所有数都变成 m i n ( a ) min(a) min(a) 。
先对 a a a 排序,然后从最大的数开始依次将前面大的数往小了降。
从大的值开始往下依次枚举,看每次的 k k k 可以将前面多少数统一下降为一个数。
因为值域很小,所以可以直接枚举值域,这样的时间复杂度是 O ( m a x ( a ) + n ) O(max(a)+n) O(max(a)+n)
优化做法
如果值域很大,是否有与值域无关的做法。
其实这种题本身就应该值域无关,考虑当前最大的数为 c u r cur cur ,有 c n t cnt cnt 个值为 c u r cur cur 的数,下一步变成一个大于 a [ j ] a[j] a[j] 的最小的数 n x t nxt nxt。
那么就有 ( c u r − n x t ) × c n t (cur-nxt)\times cnt (cur−nxt)×cnt 的代价将所有的 c u r cur cur 变成 n x t nxt nxt 。
所有相邻元素如果差距过大,考虑将统一下降的部分额外操作出来,保证我们每次循环的开始都是先用一个操作将整体的数做下降。
这样可以保证一次遍历就可以将所有元素都变成 m i n ( a ) min(a) min(a) 。
时间复杂度: O ( n log n ) O(n\log n) O(nlogn)
代码
#include <bits/stdc++.h>
using namespace std;
typedef long long ll;
void solve() {
int n, k;
cin >> n >> k;
vector<int> a(n);
for (int i = 0; i < n; ++i) cin >> a[i];
sort(a.begin(), a.end(), greater<>());
if (a.back() == a.front()) {
cout << "0\n";
return;
}
// 一次 k 能够支撑多少元素下降
ll ans = 0;
ll pre = 0;
ll cnt = 0;
ll cur = -1;
for (int i = 0; i < n; ++i) {
ll s = k;
// 第一个 s 到某个 a[j - 1]
int j = i;
while (j < n && pre - cnt * a[j] <= s) {
cur = a[j];
pre += a[j];
cnt += 1;
j += 1;
}
// pre 肯定是一个统一的值
if (j > i) {
s -= pre - cnt * a[j - 1];
// 然后把剩下部分给整了
ll dec = s / cnt;
cur = a[j - 1] - dec;
pre = cur * cnt;
ans += 1;
}
// pre -> a[j - 1]
if (j < n) {
ll single_op = k / cnt;
ll single_decrease = single_op * cnt;
ll cnt_op = (cur - a[j] - 1) * cnt / single_decrease;
ans += cnt_op;
cur = cur - cnt_op * single_op;
pre = cur * cnt;
}
i = j - 1;
}
cout << ans << "\n";
}
int main()
{
ios::sync_with_stdio(false);
cin.tie(nullptr);
int T = 1;
//cin >> T;
while (T--) {
solve();
}
return 0;
}
文章评论