前言

又是一条新鲜出炉的字节跳动算法题,这道题如果你没有见过的话,有很大概率当场是没有办法解决这个问题的,因为这个算法题本质上是在考你是否知道蓄水池采样算法,还记得自己当时的一个学长就因为在字节跳动的终面中被问到这题目,但是由于他没有接触过这个算法,所以很遗憾的挂在了终面,下面让我们来学习下这个算法吧。

题目

给定一个单链表,随机选择链表的一个节点,并返回相应的节点值。保证每个节点被选的概率一样。
进阶:
如果链表十分大且长度未知,如何解决这个问题?你能否使用常数级空间复杂度实现?

示例:

1
2
3
4
5
6
7
8
// 初始化一个单链表 [1,2,3].
ListNode head = new ListNode(1);
head.next = new ListNode(2);
head.next.next = new ListNode(3);
Solution solution = new Solution(head);

// getRandom()方法应随机返回1,2,3中的一个,保证每个元素被返回的概率相等。
solution.getRandom();

思路

第一眼看到这个题目,大部分人的想法是遍历获得单链表的长度,接下来的工作就很简单了,但是从进阶的要求来看,显然这个题目是没有办法这样解决的,因为看到进阶的要求有一句话:如果链表十分大且长度未知,这个时候我们可能觉得束手无策,不过不要紧,下面我们就要开始介绍蓄水池采样算法来解决这个问题。

蓄水池采样算法

首先给出蓄水池采样算法的过程:

  1. 假设数据序列的规模为 𝑛,需要采样的数量的为 𝑘

  2. 首先构建一个可容纳 𝑘 个元素的数组,将序列的前 𝑘 个元素放入数组中

  3. 然后从第k+1个元素开始,以 $ k \over n $的概率来决定该元素是否被替换到数组中(数组中的元素被替换的概率是相同的)。 当遍历完所有元素之后,数组中剩下的元素即为所需采取的样本

知道这个算法之后,其实解决本题就很简单了,本题实际上是k=1时的一个特例,所以解决这个问题就很简单了,不过在给出题解的代码之前,先简单给出这个算法的数学证明,加深一下印象同时复习一下概率学的知识。

证明

这里我们需要用到经典的数学归纳法来证明这个问题:

  1. 当n=k时,k个元素被全部选中,没有问题

  2. 当链表长度为n的时候,每个元素被选中的概率为$ k \over n $

  3. 试证明当链表长度为n+1时,以$ k \over (n+1) $的概率选中第n+1个元素,能使得n+1个元素被选中的概率都是$ k \over (n+1) $

证明其实本身不难,但是需要注意的是要将各种情况考虑清楚,下面我从n+1个元素中任一元素被选中和未被选中两个角度来证明:

  • 先从这个元素被选中的情况下来计算,首先考虑一个元素在什么情况下会被选中呢,其实很简单,必然是在之前就被选中的元素,因为第n+1步中不可能选中前面的元素,所以$ k \over n $的概率是这一步计算的大前提,进一步考虑,当最后一个元素被选中需要随机替换前面k个元素中一个的时候,这个元素肯定不能是被选中的,因此这种情况的概率是$ k \over(n+1)$ * $(k-1)\over k $,再考虑一下,如果最后一个元素没有被选中,那么这个元素最后是不是也是被选中了,这种情况的概率是1 -$k \over(n+1)$,综上,一个元素被选中的概率是$ k \over n $ * {( $ k \over (n+1) $ * $ (k-1) \over k $ + 1 - $ k \over (n+1) $ )} = $ k \over (n+1) $

  • 再从这个元素没被选中的情况下来计算,我们考虑一个元素在哪些情况下最终会没被选中呢,有两种情况,第一种情况是在前面n个节点中这个元素就没被选中,一开始就比较背(笑),第二种情况是在前面n个元素中这个元素被选中了,但是第n+1个元素被选中了,且此元素恰好被第n+1个元素替换掉了,这次更背了(笑),这两种情况的概率加在一起的概率就是该元素最终未被选中的概率,综上,一个元素未被选中的概率为1 - $ k \over n $ + $ k \over n $ * $ k \over (n+1) $ * $ 1 \over k $ = 1 - $ k \over (n+1) $

题解代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
class Solution {
public:
/** @param head The linked list's head.
Note that the head is guaranteed to be not null, so it contains at least
one node. */
Solution(ListNode *head) { head_ = head; }

/** Returns a random node's value. */
int getRandom() {
int count = 0;
auto temp = head_;
int res = 0;
while (temp) {
count++;
auto rand_num = rand() % count;
if (rand_num == 0) {
res = temp->val;
}
temp = temp->next;
}
return res;
}

private:
ListNode *head_;
};