动态表名
# 核心类
- DynamicTableNameInnerInterceptor
- TableNameParser
- TableNameHandler
# 简单使用
DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
// The TableNameHandler strategy can be swapped for your own; this one appends "_r" to every table name.
interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");
# 源码阅读
# DynamicTableNameInnerInterceptor
/*
* Copyright (c) 2011-2022, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.inner;
import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.TableNameParser;
import com.baomidou.mybatisplus.extension.plugins.handler.TableNameHandler;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;
import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;
/**
 * Inner interceptor that replaces the table names found in a SQL statement
 * with names computed by a {@link TableNameHandler}.
 *
 * @author jobob
 * @since 3.4.0
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@SuppressWarnings({"rawtypes"})
public class DynamicTableNameInnerInterceptor implements InnerInterceptor {

    /**
     * Optional callback executed at the end of every {@link #changeTable(String)} call.
     */
    private Runnable hook;

    public void setHook(Runnable hook) {
        this.hook = hook;
    }

    /**
     * Table name handler; whether a given table name is actually changed is decided by the handler itself.
     */
    private TableNameHandler tableNameHandler;

    /**
     * Rewrites the table names of a query unless the mapped statement opted out of this plugin.
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        boolean optedOut = InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId());
        if (!optedOut) {
            PluginUtils.MPBoundSql wrapped = PluginUtils.mpBoundSql(boundSql);
            String rewritten = this.changeTable(wrapped.sql());
            wrapped.sql(rewritten);
        }
    }

    /**
     * Rewrites the table names of INSERT / UPDATE / DELETE statements unless the mapped statement
     * opted out of this plugin. Other command types are left untouched here.
     */
    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler handler = PluginUtils.mpStatementHandler(sh);
        MappedStatement statement = handler.mappedStatement();
        SqlCommandType type = statement.getSqlCommandType();

        boolean isWrite = type == SqlCommandType.INSERT
            || type == SqlCommandType.UPDATE
            || type == SqlCommandType.DELETE;
        if (!isWrite) {
            return;
        }
        if (InterceptorIgnoreHelper.willIgnoreDynamicTableName(statement.getId())) {
            return;
        }

        PluginUtils.MPBoundSql wrapped = handler.mPBoundSql();
        wrapped.sql(this.changeTable(wrapped.sql()));
    }

    /**
     * Builds a copy of {@code sql} in which every table name reported by {@link TableNameParser}
     * is replaced by the value returned from the configured {@link TableNameHandler}.
     * <p>
     * Example with a handler that appends {@code _r}:
     * {@code SELECT * FROM t_user, t_user_role} becomes {@code SELECT * FROM t_user_r, t_user_role_r}.
     *
     * @param sql the original SQL
     * @return the SQL with table names substituted
     */
    protected String changeTable(String sql) {
        // Fails when no handler has been configured.
        ExceptionUtils.throwMpe(null == tableNameHandler, "Please implement TableNameHandler processing logic");

        // Collect the table-name tokens; each carries its start/end offsets in the SQL and its text.
        List<TableNameParser.SqlToken> tableTokens = new ArrayList<>();
        new TableNameParser(sql).accept(tableTokens::add);

        StringBuilder rewritten = new StringBuilder();
        // Offset in the original SQL up to which text has already been consumed.
        int cursor = 0;
        for (TableNameParser.SqlToken token : tableTokens) {
            int begin = token.getStart();
            if (begin != cursor) {
                // Copy the untouched text between the previous table name and this one,
                // then emit the replacement name instead of the original.
                rewritten.append(sql, cursor, begin);
                rewritten.append(tableNameHandler.dynamicTableName(sql, token.getValue()));
            }
            // Note: a token starting exactly at the cursor is not emitted; only the cursor moves.
            cursor = token.getEnd();
        }

        // Copy whatever follows the last table name (e.g. a WHERE clause).
        if (cursor != sql.length()) {
            rewritten.append(sql.substring(cursor));
        }

        if (hook != null) {
            hook.run();
        }
        return rewritten.toString();
    }
}
# TableNameParser
/*
* Copyright (c) 2011-2022, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.core.toolkit;
import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;
/**
 * SQL table name parser.
 * <p>
 * https://github.com/mnadeem/sql-table-name-parser
 * <p>
 * Ultra light, ultra fast parser to extract table names out of SQL statements; supports Oracle dialect SQL as well.
 * Usage: {@code new TableNameParser(sql).tables()}
 * <p>
 * Malformed or truncated SQL (blank input, a trailing {@code FROM}, a trailing comma, ...) never
 * raises {@link IndexOutOfBoundsException}; the parser simply reports the table names it could find.
 *
 * @author Nadeem Mohammad, hcl
 * @since 2019-04-22
 */
public final class TableNameParser {

    private static final String TOKEN_SET = "set";
    private static final String TOKEN_OF = "of";
    private static final String TOKEN_DUAL = "dual";
    private static final String TOKEN_DELETE = "delete";
    private static final String TOKEN_CREATE = "create";
    private static final String TOKEN_INDEX = "index";

    private static final String KEYWORD_JOIN = "join";
    private static final String KEYWORD_INTO = "into";
    private static final String KEYWORD_TABLE = "table";
    private static final String KEYWORD_FROM = "from";
    private static final String KEYWORD_USING = "using";
    private static final String KEYWORD_UPDATE = "update";
    private static final String KEYWORD_DUPLICATE = "duplicate";

    /**
     * Keywords whose immediately following token is treated as a table name.
     * {@code FROM} is handled separately because it may introduce a comma separated list.
     */
    private static final List<String> concerned = Arrays.asList(KEYWORD_TABLE, KEYWORD_INTO, KEYWORD_JOIN, KEYWORD_USING, KEYWORD_UPDATE);

    /**
     * Tokens that may appear in a table-name position but are never reported as table names.
     */
    private static final List<String> ignored = Arrays.asList(StringPool.LEFT_BRACKET, TOKEN_SET, TOKEN_OF, TOKEN_DUAL);

    /**
     * Matches the parts of a SQL string that are NOT SQL tokens, i.e. the separators between tokens:
     * <ol>
     * <li>line comments introduced by {@code --}</li>
     * <li>{@code ;}</li>
     * <li>whitespace</li>
     * <li>C-style block comments</li>
     * <li>the zero-width boundaries around {@code ,}, {@code (} and {@code )}, so that these
     * characters become tokens of their own</li>
     * </ol>
     */
    private static final Pattern NON_SQL_TOKEN_PATTERN = Pattern.compile("(--[^\\v]+)|;|(\\s+)|((?s)/[*].*?[*]/)"
        + "|(((\\b|\\B)(?=[,()]))|((?<=[,()])(\\b|\\B)))"
    );

    private final List<SqlToken> tokens;

    /**
     * Tokenizes the given SQL so that table names can be extracted from it.
     *
     * @param sql the SQL statement to parse
     */
    public TableNameParser(String sql) {
        tokens = fetchAllTokens(sql);
    }

    /**
     * Walks the SQL tokens and hands every table-name token to the visitor.
     * <p>
     * The SQL itself is never modified; each visited token exposes its start and end offsets
     * so callers can locate the table name in the original string.
     *
     * @param visitor receives each table-name token
     */
    public void accept(TableNameVisitor visitor) {
        if (tokens.isEmpty()) {
            // Blank or comment-only SQL: there is nothing to visit.
            return;
        }
        int index = 0;
        String first = tokens.get(index).getValue();
        if (isOracleSpecialDelete(first, tokens, index)) {
            // "DELETE tbl ...": the table name directly follows DELETE.
            visitNameToken(tokens.get(index + 1), visitor);
        } else if (isCreateIndex(first, tokens, index)) {
            // "CREATE INDEX idx ON tbl": the table name is the fifth token.
            visitNameToken(tokens.get(index + 4), visitor);
        } else {
            while (hasMoreTokens(tokens, index)) {
                // Read the current token, then advance: from here on index points at the NEXT token.
                String current = tokens.get(index++).getValue();
                if (isFromToken(current)) {
                    // index is passed by value, so the main loop re-scans the tokens after FROM.
                    processFromToken(tokens, index, visitor);
                } else if (isOnDuplicateKeyUpdate(current, index)) {
                    // Skip "KEY UPDATE" so that this UPDATE is not mistaken for an UPDATE statement.
                    index = skipDuplicateKeyUpdateIndex(index);
                } else if (concerned.contains(current.toLowerCase(Locale.ROOT))) {
                    // Locale.ROOT keeps keyword matching independent of the JVM default locale.
                    if (hasMoreTokens(tokens, index)) {
                        // The token right after TABLE / INTO / JOIN / USING / UPDATE is the table name.
                        SqlToken next = tokens.get(index++);
                        visitNameToken(next, visitor);
                    }
                }
            }
        }
    }

    /**
     * Visitor invoked for every table name found by {@link #accept(TableNameVisitor)}.
     */
    public interface TableNameVisitor {
        /**
         * @param name the token holding a table name
         */
        void visit(SqlToken name);
    }

    /**
     * Splits the SQL into tokens, using everything matched by {@link #NON_SQL_TOKEN_PATTERN} as separators.
     *
     * @param sql the SQL statement
     * @return the tokens in order of appearance, each with its offsets in {@code sql}
     */
    protected List<SqlToken> fetchAllTokens(String sql) {
        List<SqlToken> tokens = new ArrayList<>();
        Matcher matcher = NON_SQL_TOKEN_PATTERN.matcher(sql);
        int last = 0;
        while (matcher.find()) {
            int start = matcher.start();
            if (start != last) {
                // Text between two separators is one token.
                tokens.add(new SqlToken(last, start, sql.substring(last, start)));
            }
            last = matcher.end();
        }
        if (last != sql.length()) {
            // Trailing token after the final separator.
            tokens.add(new SqlToken(last, sql.length(), sql.substring(last)));
        }
        return tokens;
    }

    /**
     * Returns {@code true} when {@code current} is DELETE and the token after it exists and is
     * neither FROM nor {@code *}.
     *
     * @param current the current token
     * @param tokens  all tokens
     * @param index   index of {@code current}
     * @return whether this is the Oracle style {@code DELETE tbl} form
     */
    private static boolean isOracleSpecialDelete(String current, List<SqlToken> tokens, int index) {
        int nextIndex = index + 1;
        // Check the slot that is actually read, so a bare "DELETE" is not an out-of-bounds access.
        if (TOKEN_DELETE.equalsIgnoreCase(current) && hasMoreTokens(tokens, nextIndex)) {
            String next = tokens.get(nextIndex).getValue();
            return !KEYWORD_FROM.equalsIgnoreCase(next) && !StringPool.ASTERISK.equals(next);
        }
        return false;
    }

    /**
     * Returns {@code true} when the statement starts with CREATE INDEX and has enough tokens
     * for the table name to be read at position {@code index + 4}.
     */
    private boolean isCreateIndex(String current, List<SqlToken> tokens, int index) {
        index++; // Point to next token
        if (TOKEN_CREATE.equalsIgnoreCase(current) && hasIthToken(tokens, index)) {
            String next = tokens.get(index).getValue();
            return TOKEN_INDEX.equalsIgnoreCase(next);
        }
        return false;
    }

    /**
     * @param current the current token
     * @param index   index of the token following {@code current}
     * @return whether the tokens form the MySQL specific {@code ON DUPLICATE KEY UPDATE} clause,
     * i.e. {@code current} is DUPLICATE and the token two positions after it is UPDATE
     */
    private boolean isOnDuplicateKeyUpdate(String current, int index) {
        // tokens[index] is expected to be KEY and tokens[index + 1] UPDATE.
        int updateIndex = index + 1;
        if (KEYWORD_DUPLICATE.equalsIgnoreCase(current) && hasMoreTokens(tokens, updateIndex)) {
            String next = tokens.get(updateIndex).getValue();
            return KEYWORD_UPDATE.equalsIgnoreCase(next);
        }
        return false;
    }

    /**
     * Returns {@code true} when both {@code currentIndex} and {@code currentIndex + 3} are valid token indexes.
     */
    private static boolean hasIthToken(List<SqlToken> tokens, int currentIndex) {
        return hasMoreTokens(tokens, currentIndex) && tokens.size() > currentIndex + 3;
    }

    private static boolean isFromToken(String currentToken) {
        return KEYWORD_FROM.equalsIgnoreCase(currentToken);
    }

    /**
     * {@code ON DUPLICATE KEY UPDATE} is a fixed MySQL phrase: skip the KEY and UPDATE tokens.
     */
    private int skipDuplicateKeyUpdateIndex(int index) {
        return index + 2;
    }

    /**
     * Handles the tokens following a FROM keyword: visits the first table and, when more tables
     * follow in a comma separated list, visits those as well.
     *
     * @param tokens  all tokens
     * @param index   index of the token right after FROM
     * @param visitor receives the table-name tokens
     */
    private static void processFromToken(List<SqlToken> tokens, int index, TableNameVisitor visitor) {
        if (!hasMoreTokens(tokens, index)) {
            // SQL ends right after FROM: no table name to report.
            return;
        }
        SqlToken sqlToken = tokens.get(index++);
        visitNameToken(sqlToken, visitor);

        String next = null;
        if (hasMoreTokens(tokens, index)) {
            next = tokens.get(index++).getValue();
        }

        // A comma right after the table means a list without aliases ("a, b");
        // anything else is treated as an alias that may be followed by a comma ("a x, b y").
        if (shouldProcessMultipleTables(next)) {
            processNonAliasedMultiTables(tokens, index, next, visitor);
        } else {
            processAliasedMultiTables(tokens, index, sqlToken, visitor);
        }
    }

    /**
     * Visits the tables of a comma separated list without aliases: as long as the pending token
     * is a comma, the token after it is visited as a table name.
     */
    private static void processNonAliasedMultiTables(List<SqlToken> tokens, int index, String nextToken, TableNameVisitor visitor) {
        // The hasMoreTokens guard stops cleanly when the SQL ends with a dangling comma.
        while (nextToken.equals(StringPool.COMMA) && hasMoreTokens(tokens, index)) {
            visitNameToken(tokens.get(index++), visitor);
            if (!hasMoreTokens(tokens, index)) {
                break;
            }
            nextToken = tokens.get(index++).getValue();
        }
    }

    /**
     * Visits the tables of a comma separated list in which tables carry aliases ("a x, b y").
     * <p>
     * NOTE(review): after each table one token is skipped as its alias without checking what it is,
     * so a list mixing aliased and non-aliased tables ("a x, b, c") stops after {@code b}.
     */
    private static void processAliasedMultiTables(List<SqlToken> tokens, int index, SqlToken current, TableNameVisitor visitor) {
        String nextNextToken = null;
        if (hasMoreTokens(tokens, index)) {
            nextNextToken = tokens.get(index++).getValue();
        }

        if (shouldProcessMultipleTables(nextNextToken)) {
            while (hasMoreTokens(tokens, index) && nextNextToken.equals(StringPool.COMMA)) {
                // The loop condition guarantees this token exists: it is the next table name.
                current = tokens.get(index++);
                if (hasMoreTokens(tokens, index)) {
                    // Skip the alias.
                    index++;
                }
                if (hasMoreTokens(tokens, index)) {
                    // Token after the alias: another comma continues the list.
                    nextNextToken = tokens.get(index++).getValue();
                }
                visitNameToken(current, visitor);
            }
        }
    }

    private static boolean shouldProcessMultipleTables(final String nextToken) {
        return nextToken != null && nextToken.equals(StringPool.COMMA);
    }

    private static boolean hasMoreTokens(List<SqlToken> tokens, int index) {
        return index < tokens.size();
    }

    /**
     * Passes the token to the visitor unless its lower-cased value is one of the {@link #ignored} tokens.
     */
    private static void visitNameToken(SqlToken token, TableNameVisitor visitor) {
        String value = token.getValue().toLowerCase(Locale.ROOT);
        if (!ignored.contains(value)) {
            visitor.visit(token);
        }
    }

    /**
     * Collects the distinct table names of the SQL, compared case-insensitively; for names that
     * differ only by case the first spelling encountered is kept.
     *
     * @return table names extracted out of sql
     * @see #accept(TableNameVisitor)
     */
    public Collection<String> tables() {
        Map<String, String> tableMap = new HashMap<>();
        accept(token -> {
            String name = token.getValue();
            tableMap.putIfAbsent(name.toLowerCase(Locale.ROOT), name);
        });
        return new HashSet<>(tableMap.values());
    }

    /**
     * A SQL token together with its position in the original SQL string.
     */
    public static class SqlToken implements Comparable<SqlToken> {
        /** Offset of the first character of the token (inclusive). */
        private final int start;
        /** Offset just past the last character of the token (exclusive). */
        private final int end;
        /** The token text. */
        private final String value;

        private SqlToken(int start, int end, String value) {
            this.start = start;
            this.end = end;
            this.value = value;
        }

        public int getStart() {
            return start;
        }

        public int getEnd() {
            return end;
        }

        public String getValue() {
            return value;
        }

        /**
         * Orders tokens by their start offset.
         */
        @Override
        public int compareTo(SqlToken o) {
            return Integer.compare(start, o.start);
        }

        @Override
        public String toString() {
            return value;
        }
    }
}
# TableNameHandler
/*
* Copyright (c) 2011-2022, baomidou (jobob@qq.com).
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package com.baomidou.mybatisplus.extension.plugins.handler;
/**
 * Dynamic table name handler.
 * <p>
 * Declared as a functional interface so that it can safely be supplied as a lambda, e.g.
 * {@code (sql, tableName) -> tableName + "_r"}.
 *
 * @author miemie
 * @since 3.4.0
 */
@FunctionalInterface
public interface TableNameHandler {

    /**
     * Produces the dynamic table name.
     *
     * @param sql       the SQL currently being executed
     * @param tableName the table name
     * @return the table name to use
     */
    String dynamicTableName(String sql, String tableName);
}
# 单元测试
# 搬运自 MyBatis-Plus（MP）的测试类源码
package com.baomidou.mybatisplus.extension.plugins.inner;
import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;
import static org.junit.jupiter.api.Assertions.assertEquals;
/**
 * Tests for the dynamic table name inner interceptor.
 *
 * @author miemie, hcl
 * @since 2020-07-16
 */
class DynamicTableNameInnerInterceptorTest {

    /**
     * Verifies that table names in SQL are replaced by their dynamic counterparts.
     */
    @Test
    @SuppressWarnings({"SqlDialectInspection", "SqlNoDataSourceInspection"})
    void doIt() {
        DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
        interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");

        // Each row is { original SQL, expected SQL after replacement }.
        String[][] cases = {
            // one table name is a prefix of the other
            {"SELECT * FROM t_user, t_user_role",
                "SELECT * FROM t_user_r, t_user_role_r"},
            // table name at the very end
            {"SELECT * FROM t_user",
                "SELECT * FROM t_user_r"},
            // comments directly before and after the table name
            {"SELECT * FROM /**/t_user/* t_user */",
                "SELECT * FROM /**/t_user_r/* t_user */"},
            // a string value that contains the table name
            {"SELECT * FROM t_user WHERE name = 't_user'",
                "SELECT * FROM t_user_r WHERE name = 't_user'"},
            // the alias equals a name that would otherwise be replaced
            {"SELECT t_user.* FROM t_user_real t_user",
                "SELECT t_user.* FROM t_user_real_r t_user"},
            // aliased table followed by a join
            {"SELECT t.* FROM t_user_real t left join entity e on e.id = t.id",
                "SELECT t.* FROM t_user_real_r t left join entity e on e.id = t.id"},
        };

        for (String[] sqlCase : cases) {
            @Language("SQL")
            String origin = sqlCase[0];
            String expected = sqlCase[1];
            assertEquals(expected, interceptor.changeTable(origin));
        }
    }
}
上次更新: 2024/11/05, 08:29:31