假设我们有一个数组,如何生成它的所有排列(即集合{A,B,C}上的所有置换)? 比如,ABC的所有排列是:
- ABC
- ACB
- BAC
- BCA
- CAB
- CBA
字典序
假设我们有一个数组,如何按照字典序生成下一个排列? 比如,假如输入的数组是{1,3,2},那么下一个排列就是{2,1,3}。
对于这个问题,我是这么想的:
- 首先考虑一个极端的特殊情形:数组里所有元素都是按降序排列,如7654321。那么它的下一个排列是不存在的。因为没有办法把它排列成一个更大的数字。
- 在此基础上考虑下一种情形:除了第一个元素外,剩下的都是按降序排列。 即,在上面的数组的左边再加一个额外的数组。 如47654321。那么这种情况下我们只需要把第一个元素变得稍微更大一些。即,把它和一个比它大的元素swap一下。 在这个例子中比4 大的元素中最小的是5。 所以我把第一个元素和5交换。 然后我们需要把从2个元素开始到末尾的所有元素逆序排列。即把第二个数字和最后一个数字交换。 把第三个数字和倒数第二个交换。 如此类推。因为这个操作能产生一个更小的序列。
- 然后就是更一般的情形: 在第二种情形的数组的左边再加上任意其它数字,算法依然保持不变。 多加的那些数字并不对逻辑造成任何影响,我们不需要去查看那些元素。 至此,我们已经列出了所有的可能性。
下面是按照上述思想实现的代码:
| 1 | // F's signature is bool(std::span<const int>) |
| 2 | #include <span> |
| 3 | template <typename F> |
| 4 | inline void VisitAllPermutation(std::span<int> a, F &&f) { |
| 5 | if (a.size() == 1) { |
| 6 | if (!f(a)) return; |
| 7 | return; |
| 8 | } |
| 9 | if (a.size() == 2) { |
| 10 | if (!f(a)) return; |
| 11 | std::swap(a[0], a[1]); |
| 12 | if (!f(a)) return; |
| 13 | return; |
| 14 | } |
| 15 | const size_t n = a.size(); |
| 16 | while (true) { |
| 17 | if (!f(a)) return; |
| 18 | size_t j = n - 1; |
| 19 | while (a[j - 1] >= a[j]) { |
| 20 | --j; |
| 21 | if (j == 0) return; |
| 22 | } |
| 23 | size_t l = n - 1; |
| 24 | --j; |
| 25 | while (a[j] >= a[l]) --l; |
| 26 | std::swap(a[j], a[l]); |
| 27 | // a[j+1] to the end |
| 28 | std::reverse(a.begin() + j + 1, a.end()); |
| 29 | } |
| 30 | return; |
| 31 | } |
写完之后我尝试如何改进它。于是我就想,在上述的第二种情况里,我们需要在一个有序的数组中做查找。那为何不用二分查找呢?上面的代码是从右往左扫描该有序数组,如果换成二分查找,效率会不会更高? 于是我写了下面的代码片段:
| 1 | // 把下面的代码稍作修改放在一个for循环里就能实现VisitAllPermutation的功能。 |
| 2 | #include <cassert> |
| 3 | #include <span> |
| 4 | inline void NextPermutation(std::span<int> nums) { |
| 5 | if (nums.size() <= 1) return; |
| 6 | if (nums.size() == 2) { |
| 7 | std::swap(nums[0], nums[1]); |
| 8 | return; |
| 9 | } |
| 10 | size_t i = nums.size() - 1; |
| 11 | while (i > 0) { |
| 12 | --i; |
| 13 | if (nums[i] >= nums[i + 1]) continue; |
| 14 | // numbers in range [i+1, end) are sorted |
| 15 | // 先reverse, 再做二分查找 |
| 16 | std::reverse(nums.begin() + i + 1, nums.end()); |
| 17 | auto iter = std::upper_bound(nums.begin() + i + 1, nums.end(), nums[i]); |
| 18 | // at least nums[i+1] is bigger than nums[i] |
| 19 | assert(iter != nums.end())(static_cast <bool> (iter != nums.end()) ? void (0) : __assert_fail ("iter != nums.end()", __builtin_FILE (), __builtin_LINE (), __extension__ __PRETTY_FUNCTION__)); |
| 20 | std::swap(*iter, nums[i]); |
| 21 | return; |
| 22 | } |
| 23 | |
| 24 | std::reverse(nums.begin(), nums.end()); |
| 25 | } |
然而,经测试发现比第一个程序慢了一倍多。因为,在这个算法里当我们做查找的时候,大部分时候要找到元素就在该数组的末尾。我觉得这个很有趣。它提醒我们:当选择算法的时候,不仅要考虑算法复杂度,还要考虑数据的分布是否均匀。
于是,这就提示我们: 可以通过loop-unrolling的方式提高热点代码的执行效率。比如下面的代码就要比使用STL的next_permutation来访问所有排列要快一倍多。
| 1 | // F's signature is bool(std::span<const int>) |
| 2 | #include <span> |
| 3 | template <typename F> |
| 4 | inline void VisitAllPermutation2(std::span<int> a, F&& f) { |
| 5 | if (a.size() == 1) { |
| 6 | if (!f(a)) return; |
| 7 | return; |
| 8 | } |
| 9 | if (a.size() == 2) { |
| 10 | if (!f(a)) return; |
| 11 | std::swap(a[0], a[1]); |
| 12 | if (!f(a)) return; |
| 13 | return; |
| 14 | } |
| 15 | const size_t n = a.size(); |
| 16 | while (true) { |
| 17 | if (!f(a)) return; |
| 18 | int* p = a.data() + n - 1; |
| 19 | int z = *p--; |
| 20 | int y = *p; |
| 21 | // int y = a[n - 2]; |
| 22 | // int z = a[n - 1]; |
| 23 | if (y < z) { |
| 24 | a[n - 2] = z; |
| 25 | a[n - 1] = y; |
| 26 | continue; |
| 27 | } |
| 28 | // int* p = a.data() + n - 3; |
| 29 | int x = *--p; |
| 30 | if (x < y) { |
| 31 | if (x < z) { |
| 32 | *p++ = z; |
| 33 | *p++ = x; |
| 34 | *p = y; |
| 35 | } else { |
| 36 | *p++ = y; |
| 37 | *p++ = z; |
| 38 | *p = x; |
| 39 | } |
| 40 | continue; |
| 41 | } |
| 42 | |
| 43 | size_t j = n - 3; |
| 44 | if (j == 0) { |
| 45 | return; |
| 46 | } |
| 47 | y = a[j - 1]; |
| 48 | |
| 49 | while (y >= x) { |
| 50 | --j; |
| 51 | if (j == 0) { |
| 52 | return; |
| 53 | } |
| 54 | x = y; |
| 55 | y = a[j - 1]; |
| 56 | } |
| 57 | if (y < z) { |
| 58 | a[j - 1] = z; |
| 59 | a[j] = y; |
| 60 | a[n - 1] = x; |
| 61 | std::reverse(a.begin() + j + 1, a.begin() + n - 1); |
| 62 | continue; |
| 63 | } |
| 64 | |
| 65 | size_t l = n - 2; |
| 66 | while (y >= a[l]) { |
| 67 | --l; |
| 68 | } |
| 69 | a[j - 1] = a[l]; |
| 70 | a[l] = y; |
| 71 | a[n - 1] = a[j]; |
| 72 | a[j] = z; |
| 73 | std::reverse(a.begin() + j + 1, a.begin() + n - 1); |
| 74 | } |
| 75 | return; |
| 76 | } |
非字典序算法
未完待续
不含三个连续下降数字的的置换
定义: 是集合上的一个置换。假设是置换的一个子序列(即)且里的数字是严格递减排列。那么被称为是一个长度为k的递减子序列。
下面我们特别关心k=3这种情况。我们考察集合上的所有置换里不含长度为3的递减子序列的置换有多少个。这个答案恰好是Catalan Numbers。
比如,321是一个长度为3的递减子序列。这个集合上所有其它的置换都不存在长度为3的递减子序列。这样的置换一共有3!-1=6-1=5个。
比如,4321是一个长度为4的递减子序列。这个集合上有如下置换存在长度3的递减子序列:
- 1,4,3,2
- 2,4,3,1
- 3,2,1,4
- 3,2,4,1
- 3,4,2,1
- 4,1,3,2
- 4,2,1,3
- 4,2,3,1
- 4,3,1,2
- 4,3,2,1
以上一共是7个。而有4!=24种置换。所以其中有24-7=14个不存在长度为3的递减子序列。