MySQL语句字符串解析处理工具类

こ雲淡風輕ζ 提交于 2020-12-23 18:50:01

MySQL语句解析处理

应用场景: 项目中, 如数据权限拦截处理, 会需要在SQL语句中根据设置的权限字段, 添加WHERE条件语句

SQL解析类: (目前只适用于SELECT查询语句, 且仅支持常用的简单查询SQL语句)

package com.richfun.boot.common.dao;

import cn.hutool.core.util.StrUtil;
import com.richfun.boot.common.util.BlankUtil;

import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;

/**
 * Lightweight helper for inspecting and rewriting simple MySQL {@code SELECT} statements,
 * e.g. to inject an extra {@code WHERE} condition for data-permission filtering or to derive
 * a {@code COUNT(*)} statement for pagination.
 *
 * <p>This is a text-based helper, not a real SQL parser. Known limitations:
 * <ul>
 *   <li>Keywords are recognised only when surrounded by single spaces, so callers should pass
 *       statements through {@link #formatSqlStr(String)} first (only
 *       {@link #addWhereCondition(String, String)} normalises whitespace on its own).</li>
 *   <li>String literals are not understood: parentheses, commas, keywords or comment markers
 *       inside quoted text are treated like ordinary SQL, and whitespace inside literals is
 *       collapsed by the normalising methods.</li>
 *   <li>{@code --} line comments are not removed.</li>
 * </ul>
 */
public class MySqlParser {

    /* SQL keywords. They are matched case-insensitively and include their surrounding spaces. */
    private static final String SELECT = "select ";
    private static final String SELECT_DISTINCT = "select distinct ";
    private static final String FROM = " from ";
    private static final String WHERE = " where ";
    private static final String GROUP_BY = " group by ";
    private static final String HAVING = " having ";
    private static final String ORDER_BY = " order by ";
    private static final String LIMIT = " limit ";
    private static final String UNION = " union ";

    /** Clauses that may follow the WHERE clause; a new condition must be inserted before them. */
    private static final String[] TAIL_KEYWORDS = {GROUP_BY, HAVING, ORDER_BY, LIMIT};

    /**
     * Normalises a SQL string: collapses all whitespace (including line breaks) to single
     * spaces, removes block comments, removes the space in front of commas and trims the result.
     *
     * @param SQLStr the raw SQL text, may be {@code null} or blank
     * @return the normalised SQL, or {@code SQLStr} unchanged when it is {@code null} or blank
     */
    public static String formatSqlStr(String SQLStr) {
        if (isBlank(SQLStr)) {
            return SQLStr;
        }
        String sql = removeBlank(SQLStr);
        // A reluctant match is used instead of a nested alternation: the latter recurses once
        // per character in java.util.regex and can overflow the stack on long comments.
        sql = sql.replaceAll("(?s)/\\*.*?\\*/", " ");
        // Removing a comment can leave several consecutive spaces behind.
        sql = removeBlank(sql);
        sql = sql.replace(" , ", ", ");
        return sql.trim();
    }

    /**
     * Collapses every run of whitespace (spaces, tabs, line breaks) into a single space.
     */
    private static String removeBlank(String SQLStr) {
        return SQLStr.replaceAll("\\s+", " ");
    }

    /**
     * Tells whether the text looks like a SELECT statement: it starts with {@code "select "}
     * and contains {@code " from "} (both case-insensitive).
     *
     * @param SQLStr the SQL text, may be {@code null}
     * @return {@code true} if the text looks like a SELECT statement
     */
    public static boolean isSelectSQL(String SQLStr) {
        if (isBlank(SQLStr)) {
            return false;
        }
        return SQLStr.regionMatches(true, 0, SELECT, 0, SELECT.length())
                && indexOfIgnoreCase(SQLStr, FROM, 0) > -1;
    }

    /**
     * Negation of {@link #isSelectSQL(String)}.
     */
    public static boolean isNotSelectSQL(String SQLStr) {
        return !isSelectSQL(SQLStr);
    }

    /**
     * Lists the output names of the top-level select list. For every comma separated item the
     * last space-separated token is used (i.e. the alias when one is present, otherwise the
     * column expression such as {@code t1.staff_name}); single quotes around an alias are removed.
     * Parenthesised parts (sub-queries, function arguments) are ignored.
     *
     * @param SQLStr the SQL statement
     * @return the field names in select-list order; empty when the statement is blank or has no
     *         top-level {@code FROM}
     */
    public static List<String> listField(String SQLStr) {
        List<String> fields = new ArrayList<>();
        if (isBlank(SQLStr)) {
            return fields;
        }
        String fieldSql = getFieldSQL(SQLStr);
        if (fieldSql.isEmpty()) {
            return fields;
        }
        for (String item : fieldSql.split(",")) {
            String field = item.trim();
            // Keep only the last token: "t.col alias" -> "alias", "SELECT t.col" -> "t.col".
            int lastSpace = field.lastIndexOf(' ');
            if (lastSpace > -1) {
                field = field.substring(lastSpace + 1);
            }
            // Handles aliases written as: t.field 'alias'
            field = field.replace("'", "");
            fields.add(field);
        }
        return fields;
    }

    /**
     * Returns the text in front of the top-level {@code FROM} with all parenthesised groups
     * removed, or an empty string when no {@code FROM} is left after removing the groups.
     */
    private static String getFieldSQL(String SQLStr) {
        String flat = removeParenthesis(SQLStr);
        int fromIndex = indexOfIgnoreCase(flat, FROM, 0);
        return fromIndex > -1 ? flat.substring(0, fromIndex) : "";
    }

    /**
     * Resolves the name under which {@code field} appears in the select list of the statement.
     *
     * @param SQLStr the SQL statement
     * @param field  the plain field name, e.g. {@code staff_name}
     * @return the name as used in the statement, e.g. {@code t1.staff_name}; {@code field}
     *         itself when it is not found; {@code null} when {@code SQLStr} is not a SELECT
     */
    public static String getFieldName(String SQLStr, String field) {
        if (isNotSelectSQL(SQLStr)) {
            return null;
        }
        if (field == null || !SQLStr.contains(field)) {
            return field;
        }
        return getFieldName(field, listField(SQLStr));
    }

    /**
     * Finds {@code field} in a list of select-list names. An entry matches when it equals the
     * field, or is the field qualified by a table alias, with or without back-ticks
     * ({@code field}, {@code `field`}, {@code t.field}, {@code t.`field`}).
     *
     * @param field     the plain field name
     * @param listField the select-list names, may be {@code null} or empty
     * @return the first matching entry, or {@code field} when nothing matches
     */
    public static String getFieldName(String field, List<String> listField) {
        if (field == null || listField == null || listField.isEmpty()) {
            return field;
        }
        String quoted = "`" + field + "`";
        for (String candidate : listField) {
            if (candidate == null) {
                continue;
            }
            boolean matches = candidate.equals(field)
                    || candidate.equals(quoted)
                    || candidate.endsWith("." + field)
                    || candidate.endsWith("." + quoted);
            if (matches) {
                return candidate;
            }
        }
        return field;
    }

    /**
     * Returns the text with every outermost parenthesised group (including the parentheses and
     * everything nested inside) removed.
     */
    private static String removeParenthesis(final String SQLStr) {
        List<int[]> groups = listArrParenthesisIndex(SQLStr);
        if (groups.isEmpty()) {
            return SQLStr;
        }
        StringBuilder flat = new StringBuilder(SQLStr.length());
        int cursor = 0;
        for (int[] group : groups) {
            flat.append(SQLStr, cursor, group[0]);
            cursor = group[1] + 1;
        }
        flat.append(SQLStr, cursor, SQLStr.length());
        return flat.toString();
    }

    /**
     * Locates the outermost parenthesised groups of the text.
     *
     * @param SQLStr the SQL to scan, may be {@code null}
     * @return {@code [openIndex, closeIndex]} pairs in ascending order, never {@code null};
     *         nested groups are covered by their outermost pair, an unmatched {@code ')'} is
     *         ignored and an unclosed {@code '('} yields no pair
     */
    private static List<int[]> listArrParenthesisIndex(final String SQLStr) {
        List<int[]> groups = new ArrayList<>();
        if (SQLStr == null || SQLStr.indexOf('(') < 0) {
            return groups;
        }
        int depth = 0;
        int openIndex = -1;
        for (int i = 0; i < SQLStr.length(); i++) {
            char c = SQLStr.charAt(i);
            if (c == '(') {
                if (depth == 0) {
                    openIndex = i;
                }
                depth++;
            } else if (c == ')' && depth > 0) {
                depth--;
                if (depth == 0) {
                    groups.add(new int[] {openIndex, i});
                }
            }
        }
        return groups;
    }

    /**
     * Appends an extra condition to the top-level WHERE clause of a statement, creating the
     * clause when the statement has none. The condition is placed in front of a top-level
     * {@code GROUP BY}, {@code HAVING}, {@code ORDER BY} or {@code LIMIT}.
     *
     * <p>When a WHERE clause already exists, both the existing clause and the new condition are
     * wrapped in parentheses so that an {@code OR} on either side cannot weaken the added
     * filter. Whitespace of the statement is collapsed to single spaces.
     *
     * @param SQLStr         the SQL statement
     * @param whereCondition the condition to add, without a leading {@code WHERE}/{@code AND}
     * @return the rewritten statement; the whitespace-normalised statement when
     *         {@code whereCondition} is {@code null} or blank
     * @throws IllegalArgumentException if {@code SQLStr} is {@code null}
     */
    public static String addWhereCondition(String SQLStr, String whereCondition) {
        if (SQLStr == null) {
            throw new IllegalArgumentException("SQLStr must not be null");
        }
        // Normalise first so that keywords separated by line breaks or tabs are recognised.
        String sql = removeBlank(SQLStr);
        if (isBlank(whereCondition)) {
            return sql;
        }
        String condition = whereCondition.trim();
        List<int[]> groups = listArrParenthesisIndex(sql);
        int whereIndex = getIndex(sql, groups, WHERE);
        // Trailing clauses can only start after the WHERE keyword; its trailing space may be
        // shared with the leading space of the next keyword.
        int searchFrom = whereIndex > -1 ? whereIndex + WHERE.length() - 1 : 0;
        int tailIndex = getAddWhereIndex(sql, groups, searchFrom);
        String head = tailIndex > -1 ? sql.substring(0, tailIndex) : sql;
        String tail = tailIndex > -1 ? sql.substring(tailIndex) : "";

        String joined;
        if (whereIndex == -1) {
            joined = head + " WHERE " + condition + tail;
        } else {
            int bodyStart = whereIndex + WHERE.length();
            String existing = bodyStart <= head.length() ? head.substring(bodyStart).trim() : "";
            if (existing.isEmpty()) {
                // Degenerate "... WHERE <nothing>": just supply the condition.
                joined = sql.substring(0, whereIndex) + " WHERE " + condition + tail;
            } else {
                joined = sql.substring(0, bodyStart) + "(" + existing + ") AND (" + condition + ")" + tail;
            }
        }
        return removeBlank(joined);
    }

    /**
     * Finds the position where a WHERE condition has to end: the earliest top-level occurrence
     * of a clause that follows WHERE.
     *
     * @param from index at which the search starts
     * @return the index of the leading space of that clause, or {@code -1} when there is none
     */
    private static int getAddWhereIndex(String sqlStr, List<int[]> listIndex, int from) {
        int earliest = -1;
        for (String keyword : TAIL_KEYWORDS) {
            int found = getIndex(sqlStr, listIndex, from, keyword);
            if (found > -1 && (earliest == -1 || found < earliest)) {
                earliest = found;
            }
        }
        return earliest;
    }

    /**
     * Finds the first top-level (not inside parentheses) occurrence of a keyword.
     *
     * @return the index of the keyword, or {@code -1} when it does not occur at top level
     */
    private static int getIndex(final String sqlStr, List<int[]> listIndex, String keyword) {
        return getIndex(sqlStr, listIndex, 0, keyword);
    }

    /**
     * Same as {@link #getIndex(String, List, String)} but starts searching at {@code index}.
     * Matching is case-insensitive and done on the original text, so the returned index is
     * always valid for {@code sqlStr} regardless of the default locale.
     */
    private static int getIndex(final String sqlStr, List<int[]> listIndex, int index, String keyword) {
        int from = Math.max(index, 0);
        while (true) {
            int found = indexOfIgnoreCase(sqlStr, keyword, from);
            if (found == -1 || !isContainsIndex(found, listIndex)) {
                return found;
            }
            // The hit lies inside a parenthesised group: continue behind it.
            from = found + keyword.length();
        }
    }

    /**
     * Tells whether {@code index} lies strictly inside one of the given parenthesis pairs.
     */
    private static boolean isContainsIndex(int index, List<int[]> listIndex) {
        if (listIndex == null) {
            return false;
        }
        for (int[] group : listIndex) {
            if (group[0] < index && index < group[1]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Derives a {@code COUNT(*)} statement from a SELECT statement for pagination. A top-level
     * {@code ORDER BY}/{@code LIMIT} tail is dropped. Plain statements are rewritten to
     * {@code SELECT COUNT(*) FROM ...}; statements whose row count depends on the select list
     * or on grouping ({@code DISTINCT}, {@code GROUP BY}, {@code HAVING}, {@code UNION}) or
     * that have no top-level {@code FROM} are wrapped in a sub-query instead.
     *
     * @param SQLStr the SELECT statement, expected to be whitespace-normalised
     * @return the count statement
     * @throws IllegalArgumentException if {@code SQLStr} is {@code null} or blank
     */
    public static String getCountSqlStr(String SQLStr) {
        if (isBlank(SQLStr)) {
            throw new IllegalArgumentException("SQLStr must not be blank");
        }
        List<int[]> groups = listArrParenthesisIndex(SQLStr);
        // Sorting and paging do not influence the number of rows to count.
        int orderByIndex = getIndex(SQLStr, groups, ORDER_BY);
        int limitIndex = getIndex(SQLStr, groups, LIMIT);
        int endIndex = orderByIndex > -1 ? orderByIndex : limitIndex;
        String body = endIndex > -1 ? SQLStr.substring(0, endIndex) : SQLStr;

        int fromIndex = getIndex(SQLStr, groups, FROM);
        String trimmed = SQLStr.trim();
        boolean distinct = trimmed.regionMatches(true, 0, SELECT_DISTINCT, 0, SELECT_DISTINCT.length());
        boolean needsWrap = distinct
                || fromIndex == -1
                || (endIndex > -1 && fromIndex >= endIndex)
                || getIndex(SQLStr, groups, GROUP_BY) > -1
                || getIndex(SQLStr, groups, HAVING) > -1
                || getIndex(SQLStr, groups, UNION) > -1;
        if (needsWrap) {
            return "SELECT COUNT(*) FROM (" + body + ") t_page_count";
        }
        return "SELECT COUNT(*)" + body.substring(fromIndex);
    }

    /**
     * Case-insensitive, locale-independent {@code indexOf}.
     *
     * @return the first index {@code >= from} at which {@code keyword} occurs, or {@code -1}
     */
    private static int indexOfIgnoreCase(String text, String keyword, int from) {
        int last = text.length() - keyword.length();
        for (int i = Math.max(from, 0); i <= last; i++) {
            if (text.regionMatches(true, i, keyword, 0, keyword.length())) {
                return i;
            }
        }
        return -1;
    }

    /**
     * Tells whether the text is {@code null}, empty or whitespace only.
     */
    private static boolean isBlank(String text) {
        return text == null || text.trim().isEmpty();
    }

}

测试类

package com.richfun.boot;

import com.richfun.boot.common.dao.MySqlParser;
import org.junit.Test;

import java.util.List;

public class MySqlParserTest {

    private static final String SQLStr =
                "SELECT\n" +
            "        t1.`gf`,t.fg, t.f2 '测试测试'\n" +
            "        , t1.field1\n" +
            "        /*, (SELECT field, (SELECT * FROM t8) t1 FROM t2) field2*/\n" +
            "        , (SELECT field FROM t3) field3\n" +
            "        , t2.field4\n" +
            "        , t3.field5\n" +
            "        , t4.field6\n" +
            "        , IFNULL(t1.field7, t2.field7) field7\n" +
            "        , t1.field8\n" +
            "    FROM t\n" +
            "    LEFT JOIN t4 ON (t4.field = t.field1 AND t4.x = 1)\n" +
            "    RIGHT JOIN t5 ON t5.field = t.field2\n" +
            "    INNER JOIN t6 ON t6.field = t.field3\n" +
            "    LEFT JOIN (SELECT * FROM t8 WHERE t8.f1 = 1) t8 ON t8.f2 = t.field1\n" +
            "    WHERE t.field = 'ff dd --ss'\n" +
            "    /*AND t1.fileid2 IN (SELECT f1 FROM t9 WHERE t9.f2 = 1 GROUP BY t9.f2)*/\n" +
//            "    GROUP BY t1.f1\n" +
            "    ORDER BY t2.f1, t3.from_1" +
            "    LIMIT 1"
            ;


    @Test
    public void testIsSelectSQL() {
        long s1 = System.currentTimeMillis();
        String SQLStr2 = MySqlParser.formatSqlStr(SQLStr);
        System.out.println(SQLStr2);

        List<String> listField = MySqlParser.listField(SQLStr2);
        System.out.println(listField);

        String fieldName = MySqlParser.getFieldName("gf", listField);
//        System.out.println(fieldName);
        String whereCondition = fieldName + " = '1'";
        System.out.println(whereCondition);

        String SQLStr3 = MySqlParser.addWhereCondition(SQLStr2, whereCondition);
        System.out.println(SQLStr3);

        System.out.println("耗时: " + (System.currentTimeMillis() - s1) + "ms");
    }

    @Test
    public void testCountSqlStr() {
        long s1 = System.currentTimeMillis();
        String SQLStr2 = MySqlParser.formatSqlStr(SQLStr);
        System.out.println(SQLStr2);

        String countSqlStr = MySqlParser.getCountSqlStr(SQLStr2);
        System.out.println(countSqlStr);

        System.out.println("耗时: " + (System.currentTimeMillis() - s1) + "ms");
    }

}

源码地址: https://gitee.com/ge.yang/spring-demo/tree/master/boot-demo/

标签
易学教程内所有资源均来自网络或用户发布的内容,如有违反法律规定的内容欢迎反馈
该文章没有解决你所遇到的问题?点击提问,说说你的问题,让更多的人一起探讨吧!