假设我们有一个数组,如何生成它的所有排列(即集合{A,B,C}上的所有置换)? 比如,ABC的所有排列是:
- ABC
- ACB
- BAC
- BCA
- CAB
- CBA
字典序
假设我们有一个数组,如何按照字典序生成下一个排列? 比如,假如输入的数组是{1,3,2},那么下一个排列就是{2,1,3}。
对于这个问题,我是这么想的:
- 首先考虑一个极端的特殊情形:数组里所有元素都是按降序排列,如7654321。那么它的下一个排列是不存在的。因为没有办法把它排列成一个更大的数字。
- 在此基础上考虑下一种情形:除了第一个元素外,剩下的都是按降序排列。 即,在上面的数组的左边再加一个额外的数组。 如47654321。那么这种情况下我们只需要把第一个元素变得稍微更大一些。即,把它和一个比它大的元素swap一下。 在这个例子中比4 大的元素中最小的是5。 所以我把第一个元素和5交换。 然后我们需要把从2个元素开始到末尾的所有元素逆序排列。即把第二个数字和最后一个数字交换。 把第三个数字和倒数第二个交换。 如此类推。因为这个操作能产生一个更小的序列。
- 然后就是更一般的情形: 在第二种情形的数组的左边再加上任意其它数字,算法依然保持不变。 多加的那些数字并不对逻辑造成任何影响,我们不需要去查看那些元素。 至此,我们已经列出了所有的可能性。
下面是按照上述思想实现的代码:
1 | // F's signature is bool(std::span<const int>) |
2 | #include <span> |
3 | template <typename F> |
4 | inline void VisitAllPermutation(std::span<int> a, F &&f) { |
5 | if (a.size() == 1) { |
6 | if (!f(a)) return; |
7 | return; |
8 | } |
9 | if (a.size() == 2) { |
10 | if (!f(a)) return; |
11 | std::swap(a[0], a[1]); |
12 | if (!f(a)) return; |
13 | return; |
14 | } |
15 | const size_t n = a.size(); |
16 | while (true) { |
17 | if (!f(a)) return; |
18 | size_t j = n - 1; |
19 | while (a[j - 1] >= a[j]) { |
20 | --j; |
21 | if (j == 0) return; |
22 | } |
23 | size_t l = n - 1; |
24 | --j; |
25 | while (a[j] >= a[l]) --l; |
26 | std::swap(a[j], a[l]); |
27 | // a[j+1] to the end |
28 | std::reverse(a.begin() + j + 1, a.end()); |
29 | } |
30 | return; |
31 | } |
写完之后我尝试如何改进它。于是我就想,在上述的第二种情况里,我们需要在一个有序的数组中做查找。那为何不用二分查找呢?上面的代码是从右往左扫描该有序数组,如果换成二分查找,效率会不会更高? 于是我写了下面的代码片段:
1 | // 把下面的代码稍作修改放在一个for循环里就能实现VisitAllPermutation的功能。 |
2 | #include <cassert> |
3 | #include <span> |
4 | inline void NextPermutation(std::span<int> nums) { |
5 | if (nums.size() <= 1) return; |
6 | if (nums.size() == 2) { |
7 | std::swap(nums[0], nums[1]); |
8 | return; |
9 | } |
10 | size_t i = nums.size() - 1; |
11 | while (i > 0) { |
12 | --i; |
13 | if (nums[i] >= nums[i + 1]) continue; |
14 | // numbers in range [i+1, end) are sorted |
15 | // 先reverse, 再做二分查找 |
16 | std::reverse(nums.begin() + i + 1, nums.end()); |
17 | auto iter = std::upper_bound(nums.begin() + i + 1, nums.end(), nums[i]); |
18 | // at least nums[i+1] is bigger than nums[i] |
19 | assert(iter != nums.end())(static_cast <bool> (iter != nums.end()) ? void (0) : __assert_fail ("iter != nums.end()", __builtin_FILE (), __builtin_LINE (), __extension__ __PRETTY_FUNCTION__)); |
20 | std::swap(*iter, nums[i]); |
21 | return; |
22 | } |
23 | |
24 | std::reverse(nums.begin(), nums.end()); |
25 | } |
然而,经测试发现比第一个程序慢了一倍多。因为,在这个算法里当我们做查找的时候,大部分时候要找到元素就在该数组的末尾。我觉得这个很有趣。它提醒我们:当选择算法的时候,不仅要考虑算法复杂度,还要考虑数据的分布是否均匀。
于是,这就提示我们: 可以通过loop-unrolling的方式提高热点代码的执行效率。比如下面的代码就要比使用STL的next_permutation
来访问所有排列要快一倍多。
1 | // F's signature is bool(std::span<const int>) |
2 | #include <span> |
3 | template <typename F> |
4 | inline void VisitAllPermutation2(std::span<int> a, F&& f) { |
5 | if (a.size() == 1) { |
6 | if (!f(a)) return; |
7 | return; |
8 | } |
9 | if (a.size() == 2) { |
10 | if (!f(a)) return; |
11 | std::swap(a[0], a[1]); |
12 | if (!f(a)) return; |
13 | return; |
14 | } |
15 | const size_t n = a.size(); |
16 | while (true) { |
17 | if (!f(a)) return; |
18 | int* p = a.data() + n - 1; |
19 | int z = *p--; |
20 | int y = *p; |
21 | // int y = a[n - 2]; |
22 | // int z = a[n - 1]; |
23 | if (y < z) { |
24 | a[n - 2] = z; |
25 | a[n - 1] = y; |
26 | continue; |
27 | } |
28 | // int* p = a.data() + n - 3; |
29 | int x = *--p; |
30 | if (x < y) { |
31 | if (x < z) { |
32 | *p++ = z; |
33 | *p++ = x; |
34 | *p = y; |
35 | } else { |
36 | *p++ = y; |
37 | *p++ = z; |
38 | *p = x; |
39 | } |
40 | continue; |
41 | } |
42 | |
43 | size_t j = n - 3; |
44 | if (j == 0) { |
45 | return; |
46 | } |
47 | y = a[j - 1]; |
48 | |
49 | while (y >= x) { |
50 | --j; |
51 | if (j == 0) { |
52 | return; |
53 | } |
54 | x = y; |
55 | y = a[j - 1]; |
56 | } |
57 | if (y < z) { |
58 | a[j - 1] = z; |
59 | a[j] = y; |
60 | a[n - 1] = x; |
61 | std::reverse(a.begin() + j + 1, a.begin() + n - 1); |
62 | continue; |
63 | } |
64 | |
65 | size_t l = n - 2; |
66 | while (y >= a[l]) { |
67 | --l; |
68 | } |
69 | a[j - 1] = a[l]; |
70 | a[l] = y; |
71 | a[n - 1] = a[j]; |
72 | a[j] = z; |
73 | std::reverse(a.begin() + j + 1, a.begin() + n - 1); |
74 | } |
75 | return; |
76 | } |
非字典序算法
未完待续
不含三个连续下降数字的的置换
定义: 是集合上的一个置换。假设是置换的一个子序列(即)且里的数字是严格递减排列。那么被称为是一个长度为k的递减子序列。
下面我们特别关心k=3这种情况。我们考察集合上的所有置换里不含长度为3的递减子序列的置换有多少个。这个答案恰好是Catalan Numbers。
比如,321是一个长度为3的递减子序列。这个集合上所有其它的置换都不存在长度为3的递减子序列。这样的置换一共有3!-1=6-1=5个。
比如,4321是一个长度为4的递减子序列。这个集合上有如下置换存在长度3的递减子序列:
- 1,4,3,2
- 2,4,3,1
- 3,2,1,4
- 3,2,4,1
- 3,4,2,1
- 4,1,3,2
- 4,2,1,3
- 4,2,3,1
- 4,3,1,2
- 4,3,2,1
以上一共是7个。而有4!=24种置换。所以其中有24-7=14个不存在长度为3的递减子序列。