MySQL语句解析处理
应用场景: 项目中, 如数据权限拦截处理, 会需要在SQL语句中根据设置的权限字段, 添加WHERE条件语句
Sql解析类: (目前只适用于SELECT查询语句, 且仅支持常见的简单查询SQL语句)
package com.richfun.boot.common.dao;
import cn.hutool.core.util.StrUtil;
import com.richfun.boot.common.util.BlankUtil;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.regex.Pattern;
/**
 * MySQL statement parser used to inspect and modify simple MySQL statements.
 *
 * <p>It targets ordinary SELECT statements, e.g. a data-permission interceptor that appends a
 * WHERE condition, or a paginator that derives a COUNT statement. It is a lightweight text
 * scanner rather than a SQL grammar: clause keywords are located by searching for their
 * lower-case, blank-delimited form outside of parentheses, so input should first be normalised
 * with {@link #formatSqlStr(String)}.
 *
 * <p>Only JDK classes are used, so the parser can be used and tested on its own.
 */
public class MySqlParser {

    /** Any run of whitespace (blanks, tabs, line breaks). */
    private static final Pattern BLANK_PATTERN = Pattern.compile("\\s+");

    /** A block comment. Line comments are not recognised, so mapper SQL must not use them. */
    private static final Pattern BLOCK_COMMENT_PATTERN =
            Pattern.compile("(/\\*([^*]|(\\*+[^*/]))*\\*+/)");

    /*
     * SQL keywords in lower case. Clause keywords are surrounded by blanks so identifiers such as
     * "from_1" do not match; this relies on formatSqlStr having collapsed all whitespace.
     */
    private static final String SELECT = "select ";
    private static final String DISTINCT = "distinct ";
    private static final String FROM = " from ";
    private static final String WHERE = " where ";
    private static final String GROUP_BY = " group by ";
    private static final String ORDER_BY = " order by ";
    private static final String LIMIT = " limit ";

    /**
     * Normalises a SQL string: collapses whitespace to single blanks, strips block comments and
     * removes the blank in front of commas.
     *
     * @param SQLStr SQL to format; may be {@code null}
     * @return the formatted SQL, or {@code SQLStr} itself when it is {@code null} or blank
     */
    public static String formatSqlStr(String SQLStr) {
        if (isBlank(SQLStr)) {
            return SQLStr;
        }
        String sql = removeBlank(SQLStr);
        sql = BLOCK_COMMENT_PATTERN.matcher(sql).replaceAll(" ");
        sql = removeBlank(sql);
        // " , " -> ", "
        sql = sql.replace(" , ", ", ");
        return sql.trim();
    }

    /** Replaces every run of whitespace (including line breaks) with a single blank. */
    private static String removeBlank(String SQLStr) {
        return BLANK_PATTERN.matcher(SQLStr).replaceAll(" ");
    }

    /** Returns {@code true} when {@code str} is {@code null}, empty or whitespace only. */
    private static boolean isBlank(String str) {
        return str == null || str.trim().isEmpty();
    }

    /**
     * Lower-cases ASCII letters only. Unlike {@link String#toLowerCase()} the result does not
     * depend on the default locale and always has the same length as the input, so an index
     * found in the lower-cased copy is valid in the original string.
     */
    private static String toLowerAscii(String str) {
        char[] chars = str.toCharArray();
        for (int i = 0; i < chars.length; i++) {
            char c = chars[i];
            if (c >= 'A' && c <= 'Z') {
                chars[i] = (char) (c + ('a' - 'A'));
            }
        }
        return new String(chars);
    }

    /**
     * Tells whether {@code SQLStr} looks like a (formatted) SELECT statement: it starts with
     * {@code "select "} and contains {@code " from "}, ignoring case.
     *
     * @param SQLStr SQL to test; may be {@code null}
     * @return {@code true} for a SELECT statement, {@code false} otherwise (including blank input)
     */
    public static boolean isSelectSQL(String SQLStr) {
        if (isBlank(SQLStr)) {
            return false;
        }
        String sqlStr = toLowerAscii(SQLStr);
        return sqlStr.startsWith(SELECT) && sqlStr.contains(FROM);
    }

    /** Negation of {@link #isSelectSQL(String)}. */
    public static boolean isNotSelectSQL(String SQLStr) {
        return !isSelectSQL(SQLStr);
    }

    /**
     * Lists the output column names of the top-level SELECT list.
     *
     * <p>For every comma-separated item the last blank-separated token is taken (the alias when
     * present), and single quotes are removed, so {@code t.f 'alias'} yields {@code alias}.
     * Parenthesised parts (function arguments, sub-queries) are ignored.
     *
     * @param SQLStr formatted SQL statement; may be {@code null}
     * @return the column names in order; empty for blank input
     */
    public static List<String> listField(String SQLStr) {
        List<String> listField = new ArrayList<>();
        if (isBlank(SQLStr)) {
            return listField;
        }
        for (String field : getFieldSQL(SQLStr).split(",")) {
            field = field.trim();
            int lastIndex = field.lastIndexOf(' ');
            if (lastIndex > -1) {
                field = field.substring(lastIndex + 1);
            }
            // handles the "t.filed 'filed1'" form
            listField.add(field.replace("'", ""));
        }
        return listField;
    }

    /**
     * Returns the SELECT list (everything before the top-level FROM) with all parenthesised
     * parts removed; the whole bracket-free string when there is no FROM.
     */
    private static String getFieldSQL(String SQLStr) {
        String stripped = removeParenthesis(SQLStr);
        int fromIndex = toLowerAscii(stripped).indexOf(FROM);
        return fromIndex > -1 ? stripped.substring(0, fromIndex) : stripped;
    }

    /**
     * Resolves a bare column name to the qualified name used in the SELECT list.
     *
     * @param SQLStr SQL statement
     * @param field column name, e.g. {@code staff_name}
     * @return the usable column name, e.g. {@code t1.staff_name}; {@code field} itself when it does
     *     not occur in the SQL; {@code null} when {@code SQLStr} is not a SELECT statement
     */
    public static String getFieldName(String SQLStr, String field) {
        if (isNotSelectSQL(SQLStr)) {
            return null;
        }
        if (field == null || !SQLStr.contains(field)) {
            return field;
        }
        return getFieldName(field, listField(SQLStr));
    }

    /**
     * Finds {@code field} in a column list as {@code field}, {@code x.field} or {@code x.`field`}.
     *
     * @param field column name to look for
     * @param listField column list, e.g. from {@link #listField(String)}; may be {@code null}
     * @return the first matching list entry, or {@code field} when nothing matches
     */
    public static String getFieldName(String field, List<String> listField) {
        if (listField == null || listField.isEmpty()) {
            return field;
        }
        for (String f : listField) {
            boolean isEq = f.equals(field) || f.endsWith("." + field) || f.endsWith(".`" + field + "`");
            if (isEq) {
                return f;
            }
        }
        return field;
    }

    /** Removes every top-level parenthesised group, brackets included, from {@code SQLStr}. */
    private static String removeParenthesis(final String SQLStr) {
        List<int[]> ranges = listArrParenthesisIndex(SQLStr);
        if (ranges.isEmpty()) {
            return SQLStr;
        }
        // ranges are ascending and non-overlapping, so copy the gaps between them
        StringBuilder sb = new StringBuilder(SQLStr.length());
        int cursor = 0;
        for (int[] range : ranges) {
            sb.append(SQLStr, cursor, range[0]);
            cursor = range[1] + 1;
        }
        sb.append(SQLStr, cursor, SQLStr.length());
        return sb.toString();
    }

    /**
     * Finds the outermost bracket pairs of {@code SQLStr}; nested pairs are not reported.
     *
     * <p>Brackets inside quoted text ('...', "..." or `...`) are ignored. A {@code )} without a
     * matching {@code (}, or a {@code (} that is never closed, produces no pair.
     *
     * @param SQLStr SQL to scan; may be {@code null}
     * @return ascending, non-overlapping {@code [openIndex, closeIndex]} pairs; never {@code null}
     */
    private static List<int[]> listArrParenthesisIndex(final String SQLStr) {
        List<int[]> listArr = new ArrayList<>();
        if (isBlank(SQLStr)) {
            return listArr;
        }
        int depth = 0;
        int start = -1;
        char quote = 0; // 0 when outside quoted text, otherwise the opening quote character
        for (int i = 0; i < SQLStr.length(); i++) {
            char c = SQLStr.charAt(i);
            if (quote != 0) {
                if (c == '\\' && quote != '`') {
                    i++; // MySQL backslash escape inside '...' and "...": skip the escaped char
                } else if (c == quote) {
                    quote = 0; // a doubled quote ('') closes and immediately reopens: same effect
                }
                continue;
            }
            switch (c) {
                case '\'':
                case '"':
                case '`':
                    quote = c;
                    break;
                case '(':
                    if (depth == 0) {
                        start = i;
                    }
                    depth++;
                    break;
                case ')':
                    if (depth > 0) {
                        depth--;
                        if (depth == 0) {
                            listArr.add(new int[] {start, i});
                        }
                    }
                    break;
                default:
                    break;
            }
        }
        return listArr;
    }

    /**
     * Appends a condition to the statement's top-level WHERE clause.
     *
     * <p>When a WHERE clause already exists, both the existing and the new condition are wrapped
     * in parentheses so that an {@code OR} in either cannot weaken the other, e.g.
     * {@code WHERE a OR b} becomes {@code WHERE (a OR b) AND (cond)}. Otherwise
     * {@code WHERE cond} is inserted before GROUP BY / ORDER BY / LIMIT, or appended at the end.
     *
     * @param SQLStr formatted SQL statement
     * @param whereCondition condition to add
     * @return the combined SQL with whitespace collapsed; {@code SQLStr} unchanged when either
     *     argument is {@code null} or blank
     */
    public static String addWhereCondition(String SQLStr, String whereCondition) {
        if (isBlank(SQLStr) || isBlank(whereCondition)) {
            return SQLStr;
        }
        List<int[]> listIndex = listArrParenthesisIndex(SQLStr);
        String sqlStr = toLowerAscii(SQLStr);
        int whereIndex = getIndex(sqlStr, listIndex, WHERE);
        int addWhereIndex = getAddWhereIndex(sqlStr, listIndex);
        String sqlWhere = addWhereCondition(SQLStr, whereCondition, whereIndex, addWhereIndex);
        return removeBlank(sqlWhere);
    }

    /**
     * Splices {@code whereCondition} into {@code SQLStr}.
     *
     * @param whereIndex index of {@code " where "} in {@code SQLStr}, or -1
     * @param addWhereIndex index of the first GROUP BY / ORDER BY / LIMIT keyword, or -1
     */
    private static String addWhereCondition(String SQLStr, String whereCondition, int whereIndex, int addWhereIndex) {
        if (whereIndex == -1) {
            int tailIndex = addWhereIndex > -1 ? addWhereIndex : SQLStr.length();
            return SQLStr.substring(0, tailIndex) + " WHERE " + whereCondition + SQLStr.substring(tailIndex);
        }
        int conditionStart = whereIndex + WHERE.length();
        // the existing condition runs up to the next clause keyword (which starts with a blank)
        int tailIndex = addWhereIndex >= conditionStart ? addWhereIndex : SQLStr.length();
        String existing = SQLStr.substring(conditionStart, tailIndex).trim();
        return SQLStr.substring(0, conditionStart)
                + "(" + existing + ") AND (" + whereCondition + ")"
                + SQLStr.substring(tailIndex);
    }

    /** Index of the first top-level GROUP BY, ORDER BY or LIMIT keyword, or -1. */
    private static int getAddWhereIndex(String sqlStr, List<int[]> listIndex) {
        for (String keyword : new String[] {GROUP_BY, ORDER_BY, LIMIT}) {
            int index = getIndex(sqlStr, listIndex, keyword);
            if (index > -1) {
                return index;
            }
        }
        return -1;
    }

    /**
     * Index of the first occurrence of {@code keyword} in {@code sqlStr} that is not inside any
     * of the bracket pairs in {@code listIndex}, or -1.
     */
    private static int getIndex(final String sqlStr, List<int[]> listIndex, String keyword) {
        int index = sqlStr.indexOf(keyword);
        while (index > -1 && isContainsIndex(index, listIndex)) {
            index = sqlStr.indexOf(keyword, index + 1);
        }
        return index;
    }

    /** Tells whether {@code index} lies strictly inside one of the bracket pairs. */
    private static boolean isContainsIndex(int index, List<int[]> listIndex) {
        if (listIndex == null) {
            return false;
        }
        for (int[] ints : listIndex) {
            if (ints[0] < index && index < ints[1]) {
                return true;
            }
        }
        return false;
    }

    /**
     * Derives the COUNT statement used for pagination.
     *
     * <p>The top-level ORDER BY / LIMIT tail is dropped. Statements with GROUP BY, SELECT
     * DISTINCT or no top-level FROM are wrapped as {@code SELECT COUNT(*) FROM (...) t_page_count},
     * because replacing their SELECT list would change the number of rows. Any other statement
     * simply has its SELECT list replaced by {@code COUNT(*)}.
     *
     * @param SQLStr formatted SELECT statement
     * @return the COUNT statement; {@code SQLStr} unchanged when it is {@code null} or blank
     */
    public static String getCountSqlStr(String SQLStr) {
        if (isBlank(SQLStr)) {
            return SQLStr;
        }
        List<int[]> listIndex = listArrParenthesisIndex(SQLStr);
        String sqlStr = toLowerAscii(SQLStr);
        int orderByIndex = getIndex(sqlStr, listIndex, ORDER_BY);
        int endIndex = orderByIndex > -1 ? orderByIndex : getIndex(sqlStr, listIndex, LIMIT);
        String body = endIndex > -1 ? SQLStr.substring(0, endIndex) : SQLStr;

        int fromIndex = getIndex(sqlStr, listIndex, FROM);
        boolean needsWrap = fromIndex < 0
                || fromIndex >= body.length()
                || getIndex(sqlStr, listIndex, GROUP_BY) > -1
                || sqlStr.startsWith(SELECT + DISTINCT);
        if (needsWrap) {
            return "SELECT COUNT(*) FROM (" + body + ") t_page_count";
        }
        return "SELECT COUNT(*)" + body.substring(fromIndex);
    }
}
测试类
package com.richfun.boot;
import com.richfun.boot.common.dao.MySqlParser;
import org.junit.Test;
import java.util.List;
public class MySqlParserTest {

    /**
     * Sample mapper SQL covering block comments, a sub-query column, function calls, joins on
     * parenthesised conditions and on a sub-query, and a string literal containing "--".
     */
    private static final String TEST_SQL =
            "SELECT\n" +
            " t1.`gf`,t.fg, t.f2 '测试测试'\n" +
            " , t1.field1\n" +
            " /*, (SELECT field, (SELECT * FROM t8) t1 FROM t2) field2*/\n" +
            " , (SELECT field FROM t3) field3\n" +
            " , t2.field4\n" +
            " , t3.field5\n" +
            " , t4.field6\n" +
            " , IFNULL(t1.field7, t2.field7) field7\n" +
            " , t1.field8\n" +
            " FROM t\n" +
            " LEFT JOIN t4 ON (t4.field = t.field1 AND t4.x = 1)\n" +
            " RIGHT JOIN t5 ON t5.field = t.field2\n" +
            " INNER JOIN t6 ON t6.field = t.field3\n" +
            " LEFT JOIN (SELECT * FROM t8 WHERE t8.f1 = 1) t8 ON t8.f2 = t.field1\n" +
            " WHERE t.field = 'ff dd --ss'\n" +
            " /*AND t1.fileid2 IN (SELECT f1 FROM t9 WHERE t9.f2 = 1 GROUP BY t9.f2)*/\n" +
            " ORDER BY t2.f1, t3.from_1\n" +
            " LIMIT 1";

    /**
     * Formats the sample SQL, resolves a column name and appends a WHERE condition, checking
     * each intermediate result.
     */
    @Test
    public void testIsSelectSQL() {
        String formatted = MySqlParser.formatSqlStr(TEST_SQL);
        check(MySqlParser.isSelectSQL(formatted), "formatted SQL should be a SELECT: " + formatted);
        check(!formatted.contains("/*"), "block comments should be removed: " + formatted);
        check(MySqlParser.isNotSelectSQL("UPDATE t SET a = 1"), "UPDATE must not be treated as SELECT");

        List<String> listField = MySqlParser.listField(formatted);
        check(listField.contains("t1.`gf`") && listField.contains("field3") && listField.contains("field7"),
                "unexpected field list: " + listField);

        String fieldName = MySqlParser.getFieldName("gf", listField);
        check("t1.`gf`".equals(fieldName), "unexpected field name: " + fieldName);

        String whereCondition = fieldName + " = '1'";
        String withWhere = MySqlParser.addWhereCondition(formatted, whereCondition);
        int conditionIndex = withWhere.indexOf(whereCondition);
        check(conditionIndex > -1 && conditionIndex < withWhere.indexOf(" ORDER BY "),
                "condition must be inserted before ORDER BY: " + withWhere);
        check(withWhere.contains("t.field = 'ff dd --ss'"), "existing condition must be kept: " + withWhere);
        check(withWhere.endsWith(" LIMIT 1"), "LIMIT must stay last: " + withWhere);
    }

    /** Derives the COUNT statement for the sample SQL: SELECT list replaced, ORDER BY/LIMIT dropped. */
    @Test
    public void testCountSqlStr() {
        String formatted = MySqlParser.formatSqlStr(TEST_SQL);
        String countSqlStr = MySqlParser.getCountSqlStr(formatted);
        check(countSqlStr.startsWith("SELECT COUNT(*) FROM t LEFT JOIN"), "unexpected count SQL: " + countSqlStr);
        check(!countSqlStr.contains("ORDER BY") && !countSqlStr.contains("LIMIT"),
                "ORDER BY / LIMIT must be dropped: " + countSqlStr);
    }

    /** A GROUP BY statement must be counted through a wrapping sub-query. */
    @Test
    public void testCountSqlStrWithGroupBy() {
        String formatted = MySqlParser.formatSqlStr(
                "SELECT t.f1, COUNT(*) total\n FROM t\n GROUP BY t.f1\n ORDER BY total\n LIMIT 10");
        String countSqlStr = MySqlParser.getCountSqlStr(formatted);
        check("SELECT COUNT(*) FROM (SELECT t.f1, COUNT(*) total FROM t GROUP BY t.f1) t_page_count".equals(countSqlStr),
                "unexpected count SQL: " + countSqlStr);
    }

    /** Fails the current test with {@code message} when {@code condition} is false. */
    private static void check(boolean condition, String message) {
        if (!condition) {
            throw new AssertionError(message);
        }
    }
}
源码地址: https://gitee.com/ge.yang/spring-demo/tree/master/boot-demo/
来源:oschina
链接:https://my.oschina.net/u/3681868/blog/4834833