对抗搜索(一)

梦想与她 提交于 2020-08-13 19:16:21

        最近看刘汝佳老师的《算法经典入门训练指南》,搜相关的算法博客时,发现一本神书《人工智能  一种现代的方法》(以下简称《人工智能》),里面囊括的算法,也让我对算法有了新的认知。在学校时,相信大家都学过数据结构和算法,这些算法是大家接触的最基础的算法,再往上走,大家做人工智能,又涉及到机器学习和深度学习。鉴于自己的认知视野优先,总感觉这两类之间的算法中间跳过了很多东西。因此最近看算法时,不免要搜到很多博客,才知道算法的种类很多,每个类下面又有很多分支,图大类下的路径搜索,约束满足问题,以及这篇要讲的对抗搜索等等。东西太多但是在了解的过程中,也解答了以前的一些疑惑,比如路径搜索在游戏中的应用,约束满足的应用,以及对抗搜索在棋牌类的应用。最后感叹在校期间没有参加ACM训练,没有抓住机会扩展自己的视野。强烈建议在校生有机会参加参加ACM训练,拿不到奖牌也可以扩展自己的视野。真的很重要。正文开始:

      对抗搜索

        (本文很多东西都是参考上面两本书,文末会贴上自己的代码。)

        相信大家在网上或多或少的都玩过很多对抗类游戏,比如五子棋、象棋、国际象棋、围棋等。初期时,可能很多时候是和电脑进行人机联系。电脑方的出牌策略就是应用了很多对抗搜索的算法(对抗类游戏可能有多个人参与,本文只讨论二人对抗的游戏,多人对抗的请参考《人工智能》这本书)。

        这些问题非常难于求解,例如国际象棋的平均分支因子大约是35,一盘棋一般每个游戏者走50步,所以搜索树大约有35100即10154个结点(尽管整个状态空间“只”约1040个不同结点)。和现实世界一样,游戏要求即使无法找到最优决策也必须能做某种决策,而不能花费太多的时间。换句话说,这些游戏有严格的时间限制(time limit)。所以对博弈的研究也产生了一些有趣的思想,如何尽可能充分的利用好时间。

        我们知道对抗类游戏是参与者轮流出招,我们可以将其写成类似于路径寻找的问题:

        初始状态:包括棋盘局面和确定该哪个游戏者出招
        后继函数:返回(move, state)列表,每一项表示一个合法招数和对应的结果状态。
        终止测试:判断游戏是否结束。游戏结束的状态称为终止状态。
        效用函数:也称目标函数或收益函数,是终止状态的得分。国际象棋中赢、输、平分别是1,-1和0分,而围棋、黑白棋等可以有更多的结果。

        考虑一个简单的游戏:井字棋。现在有两个参与者MAX和MIN,在3×3的棋盘上,MAX划叉,MIN划圆圈。任何一种图案占据了一行或者一列或者一整条斜对角线(主副对角线),那么判定相应的游戏者获胜。如下图(摘自《人工智能》)。初始状态棋盘为空,然后依次由MAX、MIN方轮流走,这样就形成了一颗类似于搜索树的博弈树。

           

       这个图列举了双方能接受的所有选择。我们可以看到只有叶子节点才有评价函数,可以看到从根节点到叶子节点是双方按照当前路径走下来的最终结果(赢、输、平局)。每条路径都对应一个结果,双反不论在什么时候,肯定都要选择“最利于”自己获胜的步骤。此时的核心问题就是在每一步的时候,MAX/MIN如何来设计评价函数来选择“最利于”自己的下棋步骤。比如在第一步时,MAX到底该如何得知自己要选择9个选择中的哪一个。

        这里采用极大极小值方法: 对MAX方来说,评价函数越大越好,而对MIN方来说,评价函数越小越好。也就是在每一步中,MAX方选择所有节点中评价函数最大的节点,作为自己当前的落棋选择,而MIN则相反。如果一个MIN结点有三个儿子,评价值分别为3,4,-1。最聪明的对手一定会选择那个-1的儿子(这样对MAX最不利),而如果对手并没有发现这个走步(或者并不觉得它的后继状态对MAX最不利),它可能选择的是3或者4。

        可惜由于博弈树太大,如果要直接追踪到最终状态,这对于计算机来说也是一个超大的负荷,因此合理的方案是在固定深度截断,在这个深度内的“叶子节点”双发按照极大极小值方法来选择自己每一步的落棋选择。对于井字棋游戏,一个可能的评价函数是:
      e(s) = (MAX可能占有的行/列/对角线数) - (MIN可能占有的行/列/对角线数)

其中“可能占有”的意思是“此行/列/对角线”不含对方的符号。更复杂的评价函数往往是对各种特征进行加权计算。 下图是深度为2时的评价函数计算。

可以验证对max的第一步来说,选择走中间那个节点是最优的选择。如果此时MIN选择走第一行正中,那么此时节点的部分搜索树如下。

刚才所述的算法成为MAXMIN算法,我们采取递归的计算方式来描述整个算法:

int max_value ( int dep , state s ){
    if ( terminal ( s )) return e ( s ); //终止状态
    if ( dep == maxdepth ) return e ( s ); //深度截断,返回评价函数
    v = - inf ; //初始化为负无穷
    succ = make_successors ( s ); // succ [ i为第]个后继状态i
    for ( i = 0; i < succ . count ; i ++)
        v = max (v , min_value ( succ [ i ])); //计算所有儿子的最大值
    return v ;
}
int min_value ( int dep , state s ){
    if ( terminal ( s )) return e ( s ); //终止状态
    if ( dep == maxdepth ) return e ( s ); //深度截断,返回评价函数
    v = inf ; //初始化为无穷大
    succ = make_successors ( s ); // succ [ i为第]个后继状态i
    for ( i = 0; i < succ . count ; i ++)
        v = min (v , max_value ( succ [ i ])); //计算所有儿子的最小值(刘汝佳老师的书中是错的)
    return v ;
}

 

文末附上“井字棋”的完整JAVA代码

package search;

import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;

import javax.xml.parsers.FactoryConfigurationError;
import java.util.ArrayList;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

/**
 * 对抗搜索-井字棋游戏
 */
public class Search_Ant {
    private static int n = 3;
    private static final int maxPlayer = 1;
    private static final int minPlayer = 2;
    private static int maxDepth = 3; //对抗预测的最大深度
    private static int[][] initArray = new int[n][n];
    private static int depth = 0;
    private static int alpha = -10000;
    private static int beta = 10000;


    public static void main(String[] args) {
        State initState = new State();
        initState.setCurrentState(initArray);
        antSearch(initState, depth);
    }

    /**
     * 开始对抗搜索
     * 偶数max执行;
     * 奇数min执行;
     * 共对抗执行的最大次数: n * n ;
     */
    public static void antSearch(State state, int depth) {
        if(isSuccess(state)){
            System.out.println("某一方成功赢了");
            return;
        }
        if (depth >= n * n) {
            System.out.println("双方平局");
            return;
        }

        //偶数次由max方走,奇数次由min方走;
        if (depth % 2 == 0) {
            maxValue(state, 0);
            int rowIndex = state.getNextBestState().getRowIndex();
            int columnIndex = state.getNextBestState().getColumnIndex();
            state.setRowIndex(rowIndex);
            state.setColumnIndex(columnIndex);
            state.getCurrentState()[rowIndex][columnIndex] = 1;
            display(state, maxPlayer, depth);
//            System.out.println(String.format("max方执行: (%s,%s)", state.getRowIndex(), state.getColumnIndex()));
        } else if (depth % 2 == 1) {
            minValue(state, 0);
            int rowIndex = state.getNextBestState().getRowIndex();
            int columnIndex = state.getNextBestState().getColumnIndex();
            state.setRowIndex(rowIndex);
            state.setColumnIndex(columnIndex);
            state.getCurrentState()[rowIndex][columnIndex] = 2;
            display(state, minPlayer, depth);
//            System.out.println(String.format("min方执行: (%s,%s)", state.getRowIndex(), state.getColumnIndex()));
        }
//        ++depth;
//        开始下一次迭代
//        antSearch(state, depth);
    }

    /**
     * 获得当前状态的后继节点
     *
     * @param currentState
     * @return
     */
    public static List<State> getSuccessor(State currentState, int player) {
        List<State> successorList = new ArrayList<State>();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (currentState.getCurrentState()[i][j] == 0) {
                    int[][] array = copyArray(currentState.getCurrentState());
                    if (player == 1) {
                        array[i][j] = 1;
                    } else {
                        array[i][j] = 2;
                    }
                    State nextState = new State(i, j, array);
                    successorList.add(nextState);
                }
            }
        }
        return successorList;
    }

    public static int[][] copyArray(int[][] array) {
        int[][] copyArray = new int[n][n];
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                copyArray[i][j] = array[i][j];
            }
        }
        return copyArray;
    }

    /**
     * max方执行
     *
     * @return
     */
    public static int maxValue(State state, int currentDepth) {
        int[][] currentState = state.getCurrentState();
        if (currentDepth >= maxDepth) {
            //计算差值
            return evalFunction(currentState);
        }
        if (isSuccess(state)) {
            //计算差值
            return evalFunction(currentState);
        }

        List<State> successor = getSuccessor(state, 1);
        int v = -10000;
        int target = -1;
        for (int i = 0; i < successor.size(); i++) {
            //这里需要优化,Alpha-Beta剪枝,缩小检索空间
            int value = minValue(successor.get(i), currentDepth + 1);
            if (v < value) {
                v = value;
                target = i;
            }
//            System.out.println(String.format("depth=%s, value=%s, maxPlayer,数组: %s", currentDepth, value,display((successor.get(i)))));
        }
        if (target == -1) { //已经此时是平局
            return 0;
        }
        System.out.println(String.format("depth=%s, value=%s, maxPlayer,数组: %s", currentDepth, v,display((successor.get(target)))));
        state.setNextBestState(successor.get(target));

        return v;
    }

    /**
     * min方执行
     *
     * @return
     */
    public static int minValue(State state, int currentDepth) {
        int[][] currentState = state.getCurrentState();
        if (currentDepth >= maxDepth) {
            //计算差值
            return evalFunction(currentState);
        }
        if (isSuccess(state)) {
            //计算差值
            return evalFunction(currentState);
        }

        List<State> successor = getSuccessor(state, 2);
        int v = 10000;
        int target = -1;
        for (int i = 0; i < successor.size(); i++) {
            int value = maxValue(successor.get(i), currentDepth + 1);
            //这里需要优化,Alpha-Beta剪枝,缩小检索空间
            if (v > value) {
                target = i;
                v = value;
            }
//            System.out.println(String.format("depth=%s, value=%s, minPlayer,数组: %s", currentDepth,value,display((successor.get(i)))));
        }
        if (target == -1) { //已经此时是平局
            return 0;
        }
        System.out.println(String.format("depth=%s, value=%s, minPlayer,数组: %s", currentDepth, v,display((successor.get(target)))));
        state.setNextBestState(successor.get(target));
        return v;
    }

    /**
     * 当前状态的评估函数
     *
     * @param currentState
     * @return
     */
    public static int evalFunction(int[][] currentState) {
        int minPlayerResult = getPlayerOccupy(currentState, maxPlayer);
        int maxPlayerResult = getPlayerOccupy(currentState, minPlayer);

        return maxPlayerResult - minPlayerResult;
    }

    /**
     * 获得某一方所占用的坐标
     *
     * @param currentState
     * @return
     */
    public static int getPlayerOccupy(int[][] currentState, int palyer) {
        Set<Integer> rowOccupy = new HashSet<Integer>();
        Set<Integer> columnOccupy = new HashSet<Integer>();
        boolean mainDiag = false;
        boolean viceDiag = false;
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                if (currentState[i][j] == palyer) {
                    rowOccupy.add(i);
                    columnOccupy.add(j);
                    if (i == j) { //在主对角线上
                        mainDiag = true;
                    }
                    if (i + j == n - 1) {//在副对角线上
                        viceDiag = true;
                    }
                }

            }
        }

        int result = 0;
        result += n - rowOccupy.size();
        result += n - columnOccupy.size();
        result += mainDiag ? 0 : 1;
        result += viceDiag ? 0 : 1;

        return result;
    }

    /**
     * 判断当前状态是否可以判断某一方已经胜利
     *
     * @return
     */
    public static boolean isSuccess(State state) {
        if (isRowSame(state)
                || isColumnSame(state)
                || isMainDiagSame(state)
                || isViceDiagSame(state)) {
            return true;
        }
        return false;
    }

    /**
     * 一行为相同
     *
     * @return
     */
    public static boolean isRowSame(State state) {
        int[][] currentState = state.getCurrentState();
        int rowIndex = state.getRowIndex();
        int preValue = currentState[rowIndex][0];
        if (preValue == 0) {
            return false;
        }
        for (int i = 1; i < n; i++) {
            if (currentState[rowIndex][i] != preValue) {
                return false;
            }
        }
        return true;
    }

    /**
     * 列相同
     *
     * @return
     */
    public static boolean isColumnSame(State state) {
        int[][] currentState = state.getCurrentState();
        int columnIndex = state.getColumnIndex();
        int preValue = currentState[0][columnIndex];
        if (preValue == 0) {
            return false;
        }
        for (int i = 1; i < n; i++) {
            if (currentState[i][columnIndex] != preValue) {
                return false;
            }
        }
        return true;
    }

    /**
     * 主对角线是否相同
     *
     * @return
     */
    public static boolean isMainDiagSame(State state) {
        int[][] currentState = state.getCurrentState();
        int rowIndex = state.getRowIndex();
        int columnIndex = state.getColumnIndex();
        if (rowIndex == columnIndex) {
            int preValue = currentState[0][0];
            if (preValue == 0) {
                return false;
            }
            for (int i = 1; i < n; i++) {
                if (currentState[i][i] != preValue) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    /**
     * 副对角线是否相同
     *
     * @return
     */
    public static boolean isViceDiagSame(State state) {
        int[][] currentState = state.getCurrentState();
        int rowIndex = state.getRowIndex();
        int columnIndex = state.getColumnIndex();
        if (rowIndex + columnIndex == n - 1) {
            int preValue = currentState[0][n - 1];
            if (preValue == 0) {
                return false;
            }
            int m = 0;
            int k = n - 1;
            for (int i = 1; i < n; i++) {
                if (currentState[m + i][k - i] != preValue) {
                    return false;
                }
            }
            return true;
        }
        return false;
    }

    public static void display(State state, int player, int depth) {
        int[][] array = state.getCurrentState();
        System.out.println("===============================================");
        System.out.println(String.format("第%s步: 当前方以及走的坐标: %s --> (%s,%s)", depth, player, state.getRowIndex(), state.getColumnIndex()));
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
                System.out.print(array[i][j] + " ");
            }
            System.out.println();
        }
    }
    public static String display(State state) {
        StringBuffer buffer = new StringBuffer();
        int[][] array = state.getCurrentState();
        for (int i = 0; i < n; i++) {
            for (int j = 0; j < n; j++) {
//                System.out.print(array[i][j]+" ");
                buffer.append(array[i][j] + " ");
            }
        }
        return buffer.toString();
    }

}

@Getter
@Setter
@AllArgsConstructor
@NoArgsConstructor
class State {
    private int rowIndex;
    private int columnIndex;
    private int[][] currentState;
    private State nextBestState; //当前最好的状态


    public State(int i, int j, int[][] array) {
        this.rowIndex = i;
        this.columnIndex = j;
        this.currentState = array;
    }
}

 

 

 

 

 

易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!