最近班里开始玩24点了。起因是一个在计算器上两人比赛24点的程序,但计算器判断一组数据是否有解需要15秒,于是这个程序就没有判定有解这一功能。
这么慢的速度我当然看不下去,但去优化那个BASIC程序是不可能的,我就开始写自己的24点程序。正好之前的算法课中递归一章提到过24点,我就理所当然地用开始写递归求解算法。
第一个版本是一个运行在PC上的非常复杂的C++程序,用上了十多个头文件。由于它太烂了,我把它注释掉以后又删掉了。这个程序最后算出1820组数据中有1362组有解数据,与网上查到的数字是一致的,不过算得很慢,要二十几秒。
这个时候我的想法还是单片机上的程序通过标准库随机数函数产生数据然后跑一遍求解算法。于是我就把这个程序修改了一下,标准库容器换成数组替换掉了,排序就随便写了个冒泡。PC上所有数据遍历一遍需要十几秒。
1 #include <iostream> 2 #include <cstdint> 3 4 using Integer = std::uint16_t; 5 6 template <typename T> 7 inline void swap(T& lhs, T& rhs) 8 { 9 auto temp = lhs; 10 lhs = rhs; 11 rhs = temp; 12 } 13 14 template <typename I> 15 inline void sort(I begin, I end) 16 { 17 for (auto pass_end = end - 1; pass_end != begin; --pass_end) 18 { 19 bool changed = false; 20 for (auto iter = begin; iter != pass_end; ++iter) 21 if (*(iter + 1) < *iter) 22 { 23 swap(*iter, *(iter + 1)); 24 changed = true; 25 } 26 if (!changed) 27 break; 28 } 29 } 30 31 int divide_count = 0, modulo_count = 0; 32 33 class Rational 34 { 35 public: 36 Integer num, den; 37 Rational(Integer num = 0, Integer den = 1) 38 : num(num), den(den) 39 { 40 // make every object reduced 41 reduce(); 42 } 43 Rational& operator=(Integer i) 44 { 45 num = i; 46 den = 1; 47 return *this; 48 } 49 Rational operator+(const Rational& rhs) const 50 { 51 // assume it won't overflow 52 return Rational(num * rhs.den + rhs.num * den, den * rhs.den); 53 } 54 Rational operator-(const Rational& rhs) const 55 { 56 // assume *this >= rhs 57 return Rational(num * rhs.den - rhs.num * den, den * rhs.den); 58 } 59 Rational operator*(const Rational& rhs) const 60 { 61 return Rational(num * rhs.num, den * rhs.den); 62 } 63 Rational operator/(const Rational& rhs) const 64 { 65 // assume rhs != 0 66 return Rational(num * rhs.den, den * rhs.num); 67 } 68 bool operator==(const Rational& rhs) const 69 { 70 return num == rhs.num && den == rhs.den; 71 } 72 bool operator==(Integer rhs) const 73 { 74 return num == rhs && den == 1; 75 } 76 bool operator<(const Rational& rhs) const 77 { 78 return num * rhs.den < rhs.num * den; 79 } 80 explicit operator bool() 81 { 82 return num; 83 } 84 private: 85 void reduce() 86 { 87 if (num == 1 || den == 1) 88 return; 89 if (num == 0) 90 { 91 den = 1; 92 return; 93 } 94 Integer gcd = 1; 95 auto a = num, b = den; 96 while (1) 97 { 98 if (a == 0 || a == b) 99 { 100 gcd = b; 101 break; 102 } 103 if (b == 0) 104 { 105 gcd = a; 106 break; 107 } 108 if (a > b) 109 { 110 ++modulo_count; 111 a %= b; 112 } 113 else 114 { 115 ++modulo_count; 116 b %= a; 117 } 118 } 119 if (gcd > 1) 120 { 121 divide_count += 2; 122 num /= gcd; 123 den /= gcd; 124 } 125 } 126 }; 127 128 template <typename S> 129 S& operator<<(S& lhs, const Rational& rhs) 130 { 131 lhs << rhs.num; 132 if (rhs.den > 1) 133 lhs << '/' << rhs.den; 134 return lhs; 135 } 136 137 struct Expression 138 { 139 Expression() = default; 140 Expression(const Rational& lhs, char op, const Rational& rhs, 141 const Rational& res) 142 : lhs(lhs), rhs(rhs), res(res), op(op) { } 143 char op = ' '; 144 Rational lhs, rhs, res; 145 }; 146 147 template <typename S> 148 S& operator<<(S& lhs, const Expression& rhs) 149 { 150 lhs << rhs.lhs << ' ' << rhs.op << ' ' << rhs.rhs << " = " << rhs.res; 151 return lhs; 152 } 153 154 constexpr Integer target = 24; 155 constexpr Integer max_count = 4; 156 157 bool solve(Integer count, const Rational* data, Expression* expr) 158 { 159 // assume data is ordered 160 if (count == 1) 161 return *data == target; 162 auto end = data + count; 163 auto before_end = end - 1; 164 --count; 165 Rational new_data[max_count - 1]; 166 auto new_end = new_data + count; 167 for (auto lhs = data; lhs != before_end; ++lhs) 168 for (auto rhs = lhs + 1; rhs != end; ++rhs) 169 { 170 auto dst = new_data; 171 for (auto src = data; src != end; ++src) 172 if (src != lhs && src != rhs) 173 *dst++ = *src; 174 *dst = *lhs + *rhs; 175 Expression temp(*lhs, '+', *rhs, *dst); 176 sort(new_data, new_end); 177 if (solve(count, new_data, expr + 1)) 178 { 179 *expr = temp; 180 return true; 181 } 182 } 183 for (auto lhs = data + 1; lhs != end; ++lhs) 184 for (auto rhs = data; rhs != lhs; ++rhs) 185 { 186 auto dst = new_data; 187 for (auto src = data; src != end; ++src) 188 if (src != lhs && src != rhs) 189 *dst++ = *src; 190 *dst = *lhs - *rhs; 191 Expression temp(*lhs, '-', *rhs, *dst); 192 sort(new_data, new_end); 193 if (solve(count, new_data, expr + 1)) 194 { 195 *expr = temp; 196 return true; 197 } 198 } 199 for (auto lhs = data; lhs != before_end; ++lhs) 200 for (auto rhs = lhs + 1; rhs != end; ++rhs) 201 { 202 auto dst = new_data; 203 for (auto src = data; src != end; ++src) 204 if (src != lhs && src != rhs) 205 *dst++ = *src; 206 *dst = *lhs * *rhs; 207 Expression temp(*lhs, '*', *rhs, *dst); 208 sort(new_data, new_end); 209 if (solve(count, new_data, expr + 1)) 210 { 211 *expr = temp; 212 return true; 213 } 214 } 215 for (auto lhs = data; lhs != end; ++lhs) 216 for (auto rhs = data; rhs != end; ++rhs) 217 { 218 if (lhs == rhs || *rhs == Rational(0)) 219 continue; 220 auto dst = new_data; 221 for (auto src = data; src != end; ++src) 222 if (src != lhs && src != rhs) 223 *dst++ = *src; 224 *dst = *lhs / *rhs; 225 Expression temp(*lhs, '/', *rhs, *dst); 226 sort(new_data, new_end); 227 if (solve(count, new_data, expr + 1)) 228 { 229 *expr = temp; 230 return true; 231 } 232 } 233 return false; 234 } 235 236 bool test(Integer a, Integer b, Integer c, Integer d) 237 { 238 Rational data[6]; 239 Expression expr[3]; 240 data[0] = a; 241 data[1] = b; 242 data[2] = c; 243 data[3] = d; 244 std::cout << a << ", " << b << ", " << c << ", " << d 245 << ':' << std::endl; 246 bool solved = solve(4, data, expr); 247 if (solved) 248 for (const auto& e : expr) 249 std::cout << '\t' << e << std::endl; 250 else 251 std::cout << "\tno solution" << std::endl; 252 return solved; 253 } 254 255 int main() 256 { 257 int count = 0; 258 constexpr Integer max_num = 13; 259 for (int a = 1; a <= max_num; ++a) 260 for (int b = a; b <= max_num; ++b) 261 for (int c = b; c <= max_num; ++c) 262 for (int d = c; d <= max_num; ++d) 263 if (test(a, b, c, d)) 264 ++count; 265 std::cout << count << ' ' << divide_count << ' ' 266 << modulo_count << std::endl; 267 268 test(1, 3, 4, 6); 269 test(1, 4, 5, 6); 270 test(1, 5, 5, 5); 271 test(1, 6, 11, 13); 272 test(2, 2, 11, 11); 273 test(2, 2, 13, 13); 274 test(2, 7, 7, 10); 275 test(3, 3, 7, 7); 276 test(3, 3, 8, 8); 277 test(3, 7, 9, 13); 278 test(4, 4, 7, 7); 279 test(5, 5, 7, 11); 280 }
稍微解释一下这个程序。先定义了Integer类型别名,本来应该是uint8_t,意思是单片机是8位字长的,但流输入输出中会被当成char处理,而我想要的是整数,就换成了uint16_t。swap和sort用于替换标准库中同名函数,后者是冒泡排序,反正待排序的元素至多4个。然后是Rational类,表示非负有理数,自动约分,用的是辗转相除。由于单片机算除法和取模极慢,我把特殊情况排除掉了,类似于剪枝的思想。另有divide_count和modulo_count两变量用于统计除法和取模次数。Expression类表示表达式,24点的解法由3个表达式组成。以及Rational类和Expression类的流插入运算符重载。
递归求解函数参数有3个:要求解的数字个数count、输入数据,data以及表达式存放位置expr。函数假设输入数据已经有序。递归出口为count == 1,检查唯一的数据是否为24。其他情况下,程序会选两个数相加、相减、相乘、相除(排除除数为0的情况),把这两个数从序列中删除再加入新的运算结果,然后排序传给递归子程序。如果子程序返回true,就把这一步运算存入*expr。
这里还有一个小插曲。写完这个程序后,第一次运行的结果是有1320多组有解,然而答案应该是1362。我试着debug,但递归函数相当难debug,我只能在递归中写输出语句,最后发现是排序算法出了问题。那么简单的冒泡排序我当然不会写错,问题在于*(iter + 1) < *iter这句,一开始写的是*iter > *(iter + 1),然而Rational类并没有重载operator>,实际调用的是两个operator bool。我把比较方向换了过来,又在operator bool前加了explicit关键字以防万一,结果就正确了。这个故事告诉我们重载关系运算符要么乖乖全部写出来,要么写using namespace std::rel_ops。
然后移植到了AVR单片机上(没有C++标准库,这就是为什么前一个程序刻意避免了那么好用的标准库),开发板用的是写系列教程那块。输入是硬编码,输出是几个LED,3组数据各跑100遍,通过LED和秒表测运行速度。运行结果是300遍有解的数据求解一共用了15秒,其中最后一组大约用了一半时间。我有充足理由估计对无解的数据跑一遍求解需要至少100毫秒。
如果我要保证提供给用户的数据有解,一定会出现200ms以上的延迟,就算不用保证,也至少要100毫秒计算时间,我认为这是不能接受的。所以这个求解算法太慢了。为了用户体验,需要一种更快的算法,然而我并没有办法把求解算法优化掉一个量级。
那唯一的方法就是不求解了。等等,不求解?那数据怎么来?要手写数据?还是要模板元编程把数据在编译器就存起来?
想多了。让PC端程序数据一定格式的数据作为单片机程序的代码,然后直接读取即可。由于数据量比较大,单片机2KB内存放不下,必须放在flash中(就算数据量不大放在RAM中也是浪费)。数据长成这样:
1 #ifndef DATA_H 2 #define DATA_H 3 4 #include <avr/pgmspace.h> 5 6 const uint8_t valid_data[][5] PROGMEM = 7 { 8 17, 129, 32, 130, 117, 9 17, 177, 32, 98, 180, 10 17, 193, 32, 146, 117, 11 17, 209, 32, 75, 180, 12 17, 98, 32, 130, 117, 13 // ... 14 // 1362 lines in all 15 // ... 16 219, 221, 32, 130, 109, 17 204, 204, 32, 130, 109, 18 204, 220, 32, 75, 149, 19 204, 221, 32, 130, 109, 20 220, 221, 32, 122, 172, 21 }; 22 23 const uint8_t invalid_data[][2] PROGMEM = 24 { 25 17, 17, 26 17, 33, 27 17, 49, 28 17, 65, 29 17, 81, 30 // ... 31 // 458 lines in all 32 // ... 33 186, 187, 34 186, 221, 35 187, 187, 36 187, 221, 37 221, 221, 38 }; 39 40 #endif
数据格式是我一拍脑袋定的:前两字节的低4位、高4位分别是从小到大的4个数字;后三字节是3个表达式,最低3位为LHS下标,中间2位为运算符,最高3位为RHS下标。运算符域从0到3分别是加减乘除,LHS和RHS的下标定义为依次存放输入数字和中间结果的数组中相等元素的下标;对于无解的数据,只有前两个字节。
这样单片机就无需对输入数据求解,只需根据压缩成字节码的表达式复原出原来的表达式,涉及到少量分数运算而已。实验证明这样的算法是足够快的,至少没有超过16毫秒的显示屏刷新时间。
单片机端的程序框架无非是定时器中断中更新显示,其他如硬件驱动等函数都是现成的。程序用到两个按键,分别用于切换刷新状态与显示答案。左边的按键按一下开始刷新数据,再按一下停止刷新,显示一组数据;对于有解的数据,右边的按键按第一下会显示最后一行,第二下会显示完整答案;对于无解的数据,按右边的按键会显示“no solution”。
我以前玩的24点都是1~10的数字,这次是真实模拟扑克牌环境的A~K。10这个数字如果正常显示需要2位,不美观,因此我用画点和画线的操作组合出了在一个字符空间内画10的操作。
为了增加可玩性,我还加入了无解的数据,概率大约为1/6。一群人围着一道题想了半分钟后发现是“no solution”是最爽的事。
实际上这个24点程序还远不完美。单片机经常在屏幕上输出诡异的解法,比如10 * 12 = 120,120 / 5 = 24,这些是不符合人类计算逻辑的,正常人想到的都是10 / 5 = 2,2 * 12 = 24。一个可行的方法是把递归搜索的顺序换一下,先减再加,先除后乘,在除法中优先用最大的数除以最小的数。但还是会出现12 / 5 = 12/5,12/5 * 10 = 24这样的式子,最根本的算法还是根据表达式建立树,在树上调整顺序。也许4个数算24点的情况不需要这么复杂,但这是万能的、具有可扩展性的做法(也有可能是我想多了)。
点一下,玩一年。24点这么好玩,我肯定不能止步于4个数加减乘除算出24这种简单的游戏。这句话暗示得很清楚了吧,我们中篇再见。