1. 第K小的和
前置知识点:二分,排序
1. 问题描述
给定两个序列 A , B A, B A,B,长度分别为 n , m n, m n,m。
设另有一个序列 C C C 中包含了 A , B A, B A,B 中的数两两相加的结果 ( C C C 中共有 n × m n \times m n×m 个数)。问 C C C 中第 K K K 小的数是多少。请注意重复的数需要计算多次。例如 1 , 1 , 2 , 3 1,1,2,3 1,1,2,3 中,最小和次小都是 1 1 1,而 3 3 3 是第 4 4 4 小。
2. 输入格式
输入的第一行包含三个整数 n , m , K n, m, K n,m,K,相邻两个整数之间使用一个空格分隔。
第二行包含 n n n 个整数,分别表示 A 1 , A 2 , … , A n A_1, A_2, \ldots, A_n A1,A2,…,An,相邻两个整数之间使用一个空格分隔。
第三行包含 m m m 个整数,分别表示 B 1 , B 2 , … , B m B_1, B_2, \ldots, B_m B1,B2,…,Bm,相邻两个整数之间使用一个空格分隔。
3. 输出格式
输出一行包含一个整数表示答案。
4. 样例输入
3 4 5
1 3 4
2 3 5 6
5. 样例输出
6
6. 评测用例规模与约定
- 对于 40 % 40\% 40% 的评测用例, n , m ≤ 5000 n, m \leq 5000 n,m≤5000, A i , B i ≤ 1000 A_i, B_i \leq 1000 Ai,Bi≤1000;
- 对于所有评测用例, 1 ≤ n , m ≤ 1 0 5 1 \leq n, m \leq 10^5 1≤n,m≤105, 1 ≤ A i , B i ≤ 1 0 9 1 \leq A_i, B_i \leq 10^9 1≤Ai,Bi≤109, 1 ≤ K ≤ n × m 1 \leq K \leq n \times m 1≤K≤n×m。
7. 原题链接
2. 解题思路
让我们先考虑一个暴力解法:我们可以枚举出所有 n × m n \times m n×m 个数,将其存入数组并进行排序,然后输出第 K K K 大的数即可求解。然而,这种解法的时间复杂度为 O ( n m log ( n m ) ) O(nm \log(nm)) O(nmlog(nm)),空间复杂度为 O ( n m ) O(nm) O(nm),这将导致超时和内存溢出。
于是,我们需要寻找一个更高效的解决方案。让我们来思考一个问题:在一个序列中,第 K K K 大的数 x x x 有什么特性?
这需要满足数组中小于等于 x x x 的数至少有 K K K 个。同时,我们可以观察到,对于一个大于 x x x 的数 y y y,序列中小于等于 y y y 的数也一定至少有 K K K 个。对于一个小于 x x x 的数 z z z,序列中小于等于 z z z 的数肯定少于 K K K 个。
这让我们发现,寻找序列中的第 K K K 大的数,其实等同于寻找序列中小于等于当前数的数量至少有 K K K 个的最小值。因此,我们可以采用二分查找的策略。
当考虑答案的下界时, A A A 和 B B B 都取 1 1 1,那么 C C C 中的数全部为 2 2 2,所以答案的下界是 2 2 2。而对于上界,当 A A A 和 B B B 都取 1 0 9 10^9 109, C C C 中的数全部为 2 × 1 0 9 2 \times 10^9 2×109,所以答案的上界是 2 × 1 0 9 2 \times 10^9 2×109。
我们的关键问题在于如何编写二分查找中的 check
函数,即如何确定在 n × m n \times m n×m 个数中,有多少个数比一个给定的整数 x x x 小。直接遍历这 n × m n \times m n×m 个数显然是不可行的,我们需要找到一个优化的方法。一个可行的策略是先将 B B B 数组排序,然后遍历数组 A A A。对于每个 A i A_i Ai,我们需要在 B B B 数组中找到满足 B j + A i ≤ x B_j + A_i \leq x Bj+Ai≤x 的最大下标 j ( j ∈ [ 1 , m ] ) j(j \in[1,m]) j(j∈[1,m])。
这个问题等价于在 B B B 数组中找到最大的数,使其小于等于 x − A i x - A_i x−Ai,这显然是一个基础的二分查找问题。这样我们就能以 O ( log m ) O(\log m) O(logm) 的复杂度统计每个 A i A_i Ai 的贡献,加上遍历数组 A A A 的复杂度为 O ( n ) O(n) O(n),因此每次 check
函数的复杂度为 O ( n log m ) O(n \log m) O(nlogm)。
对于每次二分查找的数 x x x,在 check
函数中,如果我们统计到有至少 K K K 个数小于等于 x x x,我们就返回 true
,否则返回 false
。
因此,这种解法的时间复杂度为: O ( n log ( m ) log ( 2 × 1 0 9 ) ) O(n \log(m) \log(2 \times 10^9)) O(nlog(m)log(2×109))。
3. AC_Code
- C++
#include <bits/stdc++.h>
using namespace std;
typedef long long LL;
LL n, m, k;
void solve() {
cin >> n >> m >> k;
vector<int> a(n), b(m);
for (int i = 0; i < n; ++i) {
cin >> a[i];
}
for (int i = 0; i < m; ++i) {
cin >> b[i];
}
sort(b.begin(), b.end());
LL l = 2, r = 2e9;
auto check = [&](LL x) {
LL res = 0;
for (int i = 0; i < n; ++i) {
res += upper_bound(b.begin(), b.end(), x - a[i]) - b.begin();
}
return res >= k;
};
while (l < r) {
LL mid = l + r >> 1;
if (check(mid))
r = mid;
else
l = mid + 1;
}
cout << r << '\n';
}
int main() {
ios_base ::sync_with_stdio(false);
cin.tie(0);
cout << setiosflags(ios::fixed) << setprecision(2);
int t = 1;
while (t--) {
solve();
}
return 0;
}
- Java
import java.util.*;
import java.io.*;
public class Main {
static long n, m, k;
static long ans = 0;
static boolean check(long x, long[] a, long[] b) {
long res = 0;
for (int i = 0; i < n; ++i) {
res += upperBound(b, x - a[i]);
}
return res >= k;
}
static int upperBound(long[] a, long x) {
int l = 0, r = a.length;
while (l < r) {
int mid = (l + r) / 2;
if (a[mid] <= x) l = mid + 1;
else r = mid;
}
return l;
}
public static void main(String[] args) {
Scanner sc = new Scanner(System.in);
n = sc.nextLong();
m = sc.nextLong();
k = sc.nextLong();
long[] a = new long[(int) n];
long[] b = new long[(int) m];
for (int i = 0; i < n; ++i) {
a[i] = sc.nextLong();
}
for (int i = 0; i < m; ++i) {
b[i] = sc.nextLong();
}
Arrays.sort(b);
long l = 2, r = (long) 2e9;
while (l < r) {
long mid = l + (r - l) / 2;
if (check(mid, a, b))
r = mid;
else
l = mid + 1;
}
System.out.println(r);
}
}
- Python
import bisect
n, m, k = map(int, input().split())
a = list(map(int, input().split()))
b = sorted(list(map(int, input().split())))
l, r = 2, int(2e9)
def check(x):
res = 0
for i in range(n):
res += bisect.bisect_right(b, x - a[i])
return res >= k
while l < r:
mid = l + (r - l) // 2
if check(mid):
r = mid
else:
l = mid + 1
print(r)
文章评论