闲碎记事本 闲碎记事本
首页
  • JAVA
  • Cloudflare
  • 学完再改一遍UI
友链
  • 分类
  • 标签
  • 归档
GitHub (opens new window)

YAN

我要偷偷记录...
首页
  • JAVA
  • Cloudflare
  • 学完再改一遍UI
友链
  • 分类
  • 标签
  • 归档
GitHub (opens new window)
  • java

    • SpringBoot

    • SpringSecurity

    • MybatisPlus

      • 租户拦截器
      • 动态表名
        • 核心类
        • 简单使用
        • 源码阅读
          • DynamicTableNameInnerInterceptor
          • TableNameParser
        • 单元测试
          • 搬运MP测试类源码
      • SQL阻断器
    • Netty

    • sip

    • 其他

  • linux

  • docker

  • redis

  • nginx

  • mysql

  • 其他

  • 环境搭建

  • 知识库
  • java
  • MybatisPlus
Yan
2023-03-06
目录

动态表名

# 核心类

  • DynamicTableNameInnerInterceptor
  • TableNameParser
  • TableNameHandler

# 简单使用

DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
// The TableNameHandler strategy is pluggable; this example appends "_r" to every table name.
interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");

# 源码阅读

# DynamicTableNameInnerInterceptor

/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.extension.plugins.inner;

import com.baomidou.mybatisplus.core.plugins.InterceptorIgnoreHelper;
import com.baomidou.mybatisplus.core.toolkit.ExceptionUtils;
import com.baomidou.mybatisplus.core.toolkit.PluginUtils;
import com.baomidou.mybatisplus.core.toolkit.TableNameParser;
import com.baomidou.mybatisplus.extension.plugins.handler.TableNameHandler;
import lombok.AllArgsConstructor;
import lombok.Getter;
import lombok.NoArgsConstructor;
import lombok.Setter;
import org.apache.ibatis.executor.Executor;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.session.ResultHandler;
import org.apache.ibatis.session.RowBounds;

import java.sql.Connection;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.List;

/**
 * Inner interceptor that swaps the table names in outgoing SQL for dynamically computed ones.
 * <p>
 * Table names are located with {@link TableNameParser}; each one is handed to the configured
 * {@link TableNameHandler}, and the SQL is rebuilt with whatever name the handler returns.
 *
 * @author jobob
 * @since 3.4.0
 */
@Getter
@Setter
@NoArgsConstructor
@AllArgsConstructor
@SuppressWarnings({"rawtypes"})
public class DynamicTableNameInnerInterceptor implements InnerInterceptor {
    /**
     * Optional callback run at the end of every {@link #changeTable(String)} call; may be {@code null}.
     */
    private Runnable hook;

    /**
     * Sets the callback that is run after each SQL rewrite.
     *
     * @param hook callback to run, or {@code null} for none
     */
    public void setHook(Runnable hook) {
        this.hook = hook;
    }

    /**
     * Table name handler. Whether a given table name is changed at all is decided inside the handler.
     */
    private TableNameHandler tableNameHandler;

    /**
     * Rewrites the table names of a query's SQL, unless
     * {@code InterceptorIgnoreHelper.willIgnoreDynamicTableName} reports that the statement is excluded.
     */
    @Override
    public void beforeQuery(Executor executor, MappedStatement ms, Object parameter, RowBounds rowBounds, ResultHandler resultHandler, BoundSql boundSql) throws SQLException {
        // Statements flagged as ignored for this plugin are left untouched.
        if (InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId())) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = PluginUtils.mpBoundSql(boundSql);
        mpBs.sql(this.changeTable(mpBs.sql()));
    }

    /**
     * Rewrites the table names of INSERT / UPDATE / DELETE statements before they are prepared.
     * Other command types are not touched here.
     */
    @Override
    public void beforePrepare(StatementHandler sh, Connection connection, Integer transactionTimeout) {
        PluginUtils.MPStatementHandler mpSh = PluginUtils.mpStatementHandler(sh);
        MappedStatement ms = mpSh.mappedStatement();
        SqlCommandType sct = ms.getSqlCommandType();
        boolean isWrite = sct == SqlCommandType.INSERT || sct == SqlCommandType.UPDATE || sct == SqlCommandType.DELETE;
        // The ignore check is only evaluated for write statements, as before.
        if (!isWrite || InterceptorIgnoreHelper.willIgnoreDynamicTableName(ms.getId())) {
            return;
        }
        PluginUtils.MPBoundSql mpBs = mpSh.mPBoundSql();
        mpBs.sql(this.changeTable(mpBs.sql()));
    }

    /**
     * Rebuilds {@code sql} with every table name replaced by the value the
     * {@link TableNameHandler} returns for it, then runs the hook if one is set.
     * <p>
     * Example with a handler that appends {@code _?} and
     * {@code sql = "select * from t_user,t_user_role"}:
     * <pre>
     * token 1: value = t_user,      start = 14, end = 20  ->  builder = "select * from t_user_?"
     * token 2: value = t_user_role, start = 21, end = 32  ->  builder = "select * from t_user_?,t_user_role_?"
     * </pre>
     *
     * @param sql the original SQL
     * @return the SQL with table names replaced
     */
    protected String changeTable(String sql) {
        // Fails when no handler has been configured.
        ExceptionUtils.throwMpe(null == tableNameHandler, "Please implement TableNameHandler processing logic");
        // Collect the table-name tokens; each carries its start/end offsets in sql and its text.
        TableNameParser parser = new TableNameParser(sql);
        List<TableNameParser.SqlToken> names = new ArrayList<>();
        parser.accept(names::add);
        StringBuilder builder = new StringBuilder();
        // End offset of the previously handled table name (0 before the first one).
        int last = 0;
        for (TableNameParser.SqlToken name : names) {
            // Copy the untouched text between the previous table name and this one. When the token
            // starts exactly at 'last' the gap is empty, but the replacement name is still emitted;
            // previously such a token was silently dropped from the result.
            builder.append(sql, last, name.getStart());
            builder.append(tableNameHandler.dynamicTableName(sql, name.getValue()));
            last = name.getEnd();
        }
        if (last != sql.length()) {
            // Append whatever follows the last table name, e.g. a WHERE clause.
            builder.append(sql, last, sql.length());
        }
        if (hook != null) {
            hook.run();
        }
        return builder.toString();
    }
}

# TableNameParser

/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.core.toolkit;

import java.util.*;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Extracts table names from a SQL statement.
 * <p>
 * https://github.com/mnadeem/sql-table-name-parser
 * <p>
 * Ultra light, Ultra fast parser to extract table name out SQLs, supports oracle dialect SQLs as well.
 * USE: new TableNameParser(sql).tables()
 *
 * @author Nadeem Mohammad, hcl
 * @since 2019-04-22
 */
public final class TableNameParser {

    private static final String TOKEN_SET = "set";
    private static final String TOKEN_OF = "of";
    private static final String TOKEN_DUAL = "dual";
    private static final String TOKEN_DELETE = "delete";
    private static final String TOKEN_CREATE = "create";
    private static final String TOKEN_INDEX = "index";

    private static final String KEYWORD_JOIN = "join";
    private static final String KEYWORD_INTO = "into";
    private static final String KEYWORD_TABLE = "table";
    private static final String KEYWORD_FROM = "from";
    private static final String KEYWORD_USING = "using";
    private static final String KEYWORD_UPDATE = "update";
    private static final String KEYWORD_DUPLICATE = "duplicate";

    /** Keywords whose following token is treated as a table name. */
    private static final List<String> concerned = Arrays.asList(KEYWORD_TABLE, KEYWORD_INTO, KEYWORD_JOIN, KEYWORD_USING, KEYWORD_UPDATE);
    /** Token values that are never reported as table names. */
    private static final List<String> ignored = Arrays.asList(StringPool.LEFT_BRACKET, TOKEN_SET, TOKEN_OF, TOKEN_DUAL);

    /**
     * Matches the parts of a SQL string that are not SQL tokens; everything between two
     * matches becomes one token.
     * <p>
     * Matched (and therefore excluded) parts:
     * 1. comments introduced by two dashes
     * 2. the semicolon
     * 3. whitespace
     * 4. C-style block comments
     * 5. the zero-width positions before and after {@code , ( )}, so those characters
     *    are split out as tokens of their own
     */
    private static final Pattern NON_SQL_TOKEN_PATTERN = Pattern.compile("(--[^\\v]+)|;|(\\s+)|((?s)/[*].*?[*]/)"
            + "|(((\\b|\\B)(?=[,()]))|((?<=[,()])(\\b|\\B)))"
    );

    private final List<SqlToken> tokens;

    /**
     * Tokenizes the given SQL so that its table names can be visited later.
     *
     * @param sql the SQL statement to parse
     */
    public TableNameParser(String sql) {
        tokens = fetchAllTokens(sql);
    }

    /**
     * Visits every table-name token of the SQL with the given visitor.
     * <p>
     * The SQL itself is never modified; each visited token carries the start and end
     * offsets of the table name in the original string. SQL that yields no tokens
     * (empty, whitespace only, comments only) results in no visits.
     *
     * @param visitor receives each table-name token
     */
    public void accept(TableNameVisitor visitor) {
        // Previously tokens.get(0) was read unconditionally and threw on token-less SQL.
        if (tokens.isEmpty()) {
            return;
        }
        int index = 0;
        String first = tokens.get(index).getValue();
        // Two special statement shapes are recognised from the first token.
        if (isOracleSpecialDelete(first, tokens, index)) {
            // DELETE <table> ... : the table is the token right after DELETE.
            visitNameToken(tokens.get(index + 1), visitor);
        } else if (isCreateIndex(first, tokens, index)) {
            // CREATE INDEX <name> ON <table> : the table is the fifth token.
            visitNameToken(tokens.get(index + 4), visitor);
        } else {
            while (hasMoreTokens(tokens, index)) {
                // Post-increment: after this line, index points at the token following 'current'.
                String current = tokens.get(index++).getValue();
                if (isFromToken(current)) {
                    // index is passed by value, so this loop continues right after FROM.
                    processFromToken(tokens, index, visitor);
                } else if (isOnDuplicateKeyUpdate(current, index)) {
                    // Skip "key update" so that UPDATE is not mistaken for a table-introducing keyword.
                    index = skipDuplicateKeyUpdateIndex(index);
                } else if (concerned.contains(current.toLowerCase(Locale.ROOT))) {
                    // Locale.ROOT keeps keyword matching independent of the default locale
                    // (e.g. "INTO" lower-cases to a dotless i under a Turkish locale).
                    if (hasMoreTokens(tokens, index)) {
                        SqlToken next = tokens.get(index++);
                        visitNameToken(next, visitor);
                    }
                }
            }
        }
    }

    /**
     * Visitor for table-name tokens.
     */
    public interface TableNameVisitor {
        /**
         * @param name the token that represents a table name
         */
        void visit(SqlToken name);
    }

    /**
     * Splits a SQL string into tokens: the non-empty stretches of text between
     * consecutive matches of {@link #NON_SQL_TOKEN_PATTERN}.
     *
     * @param sql SQL
     * @return the tokens in order of appearance
     */
    protected List<SqlToken> fetchAllTokens(String sql) {
        List<SqlToken> tokens = new ArrayList<>();
        Matcher matcher = NON_SQL_TOKEN_PATTERN.matcher(sql);
        int last = 0;
        while (matcher.find()) {
            int start = matcher.start();
            if (start != last) {
                tokens.add(new SqlToken(last, start, sql.substring(last, start)));
            }
            last = matcher.end();
        }
        if (last != sql.length()) {
            tokens.add(new SqlToken(last, sql.length(), sql.substring(last)));
        }
        return tokens;
    }

    /**
     * Returns {@code true} when {@code current} is DELETE and the token after it exists
     * and is neither FROM nor {@code *}.
     *
     * @param current the current token
     * @param tokens  the token list
     * @param index   index of {@code current}
     * @return whether this is the Oracle-style DELETE without FROM
     */
    private static boolean isOracleSpecialDelete(String current, List<SqlToken> tokens, int index) {
        if (TOKEN_DELETE.equalsIgnoreCase(current)) {
            // Check the index that is actually read; the old code checked 'index' but read
            // 'index + 1' and threw on a statement consisting of DELETE alone.
            int nextIndex = index + 1;
            if (hasMoreTokens(tokens, nextIndex)) {
                String next = tokens.get(nextIndex).getValue();
                return !KEYWORD_FROM.equalsIgnoreCase(next) && !StringPool.ASTERISK.equals(next);
            }
        }
        return false;
    }

    /**
     * Returns {@code true} when {@code current} is CREATE, the next token is INDEX and the
     * list is long enough to contain the table token read by {@link #accept(TableNameVisitor)}.
     */
    private boolean isCreateIndex(String current, List<SqlToken> tokens, int index) {
        index++; // Point to next token
        if (TOKEN_CREATE.equalsIgnoreCase(current) && hasIthToken(tokens, index)) {
            String next = tokens.get(index).getValue();
            return TOKEN_INDEX.equalsIgnoreCase(next);
        }
        return false;
    }

    /**
     * @param current the current token
     * @param index   index of the token that follows {@code current}
     * @return whether this is the MySQL "on duplicate key update" construct, i.e.
     * {@code current} is DUPLICATE and the token at {@code index + 1} is UPDATE
     */
    private boolean isOnDuplicateKeyUpdate(String current, int index) {
        if (KEYWORD_DUPLICATE.equalsIgnoreCase(current)) {
            // Check the index that is actually read; the old code checked 'index' but read
            // 'index + 1' and threw when the SQL ended right after "duplicate key".
            int updateIndex = index + 1;
            if (hasMoreTokens(tokens, updateIndex)) {
                String next = tokens.get(updateIndex).getValue();
                return KEYWORD_UPDATE.equalsIgnoreCase(next);
            }
        }
        return false;
    }

    /**
     * Returns {@code true} when {@code currentIndex} is in range and the list has more than
     * {@code currentIndex + 3} tokens.
     */
    private static boolean hasIthToken(List<SqlToken> tokens, int currentIndex) {
        return hasMoreTokens(tokens, currentIndex) && tokens.size() > currentIndex + 3;
    }

    private static boolean isFromToken(String currentToken) {
        return KEYWORD_FROM.equalsIgnoreCase(currentToken);
    }

    private int skipDuplicateKeyUpdateIndex(int index) {
        // "on duplicate key update" is a fixed MySQL phrase, so the two tokens are simply skipped.
        return index + 2;
    }

    /**
     * Handles the tokens after a FROM keyword: visits the first table and then any
     * further comma-separated tables, with or without aliases.
     *
     * @param tokens  the token list
     * @param index   index of the token that follows FROM
     * @param visitor receives each table-name token
     */
    private static void processFromToken(List<SqlToken> tokens, int index, TableNameVisitor visitor) {
        // FROM as the very last token has nothing to visit; previously this threw.
        if (!hasMoreTokens(tokens, index)) {
            return;
        }
        SqlToken sqlToken = tokens.get(index++);
        visitNameToken(sqlToken, visitor);

        String next = null;
        if (hasMoreTokens(tokens, index)) {
            next = tokens.get(index++).getValue();
        }
        // A comma right after the table means an un-aliased list; otherwise 'next' is taken as an alias.
        if (shouldProcessMultipleTables(next)) {
            processNonAliasedMultiTables(tokens, index, next, visitor);
        } else {
            processAliasedMultiTables(tokens, index, sqlToken, visitor);
        }
    }

    /**
     * Visits the remaining tables of a comma-separated list without aliases.
     */
    private static void processNonAliasedMultiTables(List<SqlToken> tokens, int index, String nextToken, TableNameVisitor visitor) {
        // The bounds check also covers a trailing comma, which previously threw.
        while (StringPool.COMMA.equals(nextToken) && hasMoreTokens(tokens, index)) {
            // The token after a comma is the next table name.
            visitNameToken(tokens.get(index++), visitor);
            nextToken = hasMoreTokens(tokens, index) ? tokens.get(index++).getValue() : null;
        }
    }

    /**
     * Visits the remaining tables of a comma-separated list in which each table has an alias.
     */
    private static void processAliasedMultiTables(List<SqlToken> tokens, int index, SqlToken current, TableNameVisitor visitor) {
        String nextNextToken = null;
        if (hasMoreTokens(tokens, index)) {
            nextNextToken = tokens.get(index++).getValue();
        }

        if (shouldProcessMultipleTables(nextNextToken)) {
            while (hasMoreTokens(tokens, index) && nextNextToken.equals(StringPool.COMMA)) {
                // The loop condition guarantees a token here: the table name after the comma.
                current = tokens.get(index++);
                // Step over the alias, if present.
                if (hasMoreTokens(tokens, index)) {
                    index++;
                }
                // Read the separator that decides whether another table follows.
                if (hasMoreTokens(tokens, index)) {
                    nextNextToken = tokens.get(index++).getValue();
                }
                visitNameToken(current, visitor);
            }
        }
    }

    private static boolean shouldProcessMultipleTables(final String nextToken) {
        return nextToken != null && nextToken.equals(StringPool.COMMA);
    }

    private static boolean hasMoreTokens(List<SqlToken> tokens, int index) {
        return index < tokens.size();
    }

    /**
     * Passes the token to the visitor unless its lower-cased value is in {@link #ignored}.
     */
    private static void visitNameToken(SqlToken token, TableNameVisitor visitor) {
        String value = token.getValue().toLowerCase(Locale.ROOT);
        if (!ignored.contains(value)) {
            visitor.visit(token);
        }
    }

    /**
     * parser tables
     *
     * @return table names extracted out of sql, de-duplicated case-insensitively
     * (the first spelling encountered is kept)
     * @see #accept(TableNameVisitor)
     */
    public Collection<String> tables() {
        Map<String, String> tableMap = new HashMap<>();
        accept(token -> {
            String name = token.getValue();
            tableMap.putIfAbsent(name.toLowerCase(Locale.ROOT), name);
        });
        return new HashSet<>(tableMap.values());
    }

    /**
     * A SQL token: a piece of the SQL text together with its start (inclusive) and
     * end (exclusive) offsets in the original string.
     */
    public static class SqlToken implements Comparable<SqlToken> {
        private final int start;
        private final int end;
        private final String value;

        private SqlToken(int start, int end, String value) {
            this.start = start;
            this.end = end;
            this.value = value;
        }

        public int getStart() {
            return start;
        }

        public int getEnd() {
            return end;
        }

        public String getValue() {
            return value;
        }

        /** Orders tokens by their start offset. */
        @Override
        public int compareTo(SqlToken o) {
            return Integer.compare(start, o.start);
        }

        @Override
        public String toString() {
            return value;
        }

    }

}
/*
 * Copyright (c) 2011-2022, baomidou (jobob@qq.com).
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package com.baomidou.mybatisplus.extension.plugins.handler;

/**
 * Strategy that decides which table name is actually used in a SQL statement.
 * <p>
 * The interface has a single abstract method and is meant to be supplied as a lambda,
 * e.g. {@code (sql, tableName) -> tableName + "_r"}.
 *
 * @author miemie
 * @since 3.4.0
 */
@FunctionalInterface
public interface TableNameHandler {

    /**
     * Produces the dynamic table name for one table found in a SQL statement.
     *
     * @param sql       the SQL currently being executed
     * @param tableName the table name as it appears in that SQL
     * @return the table name to use in place of {@code tableName}
     */
    String dynamicTableName(String sql, String tableName);
}

# 单元测试

# 搬运MP测试类源码

package com.baomidou.mybatisplus.extension.plugins.inner;

import org.intellij.lang.annotations.Language;
import org.junit.jupiter.api.Test;

import static org.junit.jupiter.api.Assertions.assertEquals;

/**
 * Tests for the dynamic table name inner interceptor.
 *
 * @author miemie, hcl
 * @since 2020-07-16
 */
class DynamicTableNameInnerInterceptorTest {

    /**
     * Verifies dynamic table name replacement in SQL, using a handler that appends {@code _r}.
     */
    @Test
    @SuppressWarnings({"SqlDialectInspection", "SqlNoDataSourceInspection"})
    void doIt() {
        DynamicTableNameInnerInterceptor interceptor = new DynamicTableNameInnerInterceptor();
        interceptor.setTableNameHandler((sql, tableName) -> tableName + "_r");

        // One table name is contained in the other.
        assertRewritten(interceptor,
            "SELECT * FROM t_user, t_user_role",
            "SELECT * FROM t_user_r, t_user_role_r");

        // Table name at the very end of the SQL.
        assertRewritten(interceptor,
            "SELECT * FROM t_user",
            "SELECT * FROM t_user_r");

        // Comments directly before and after the table name.
        assertRewritten(interceptor,
            "SELECT * FROM /**/t_user/* t_user */",
            "SELECT * FROM /**/t_user_r/* t_user */");

        // A string value that equals the table name.
        assertRewritten(interceptor,
            "SELECT * FROM t_user WHERE name = 't_user'",
            "SELECT * FROM t_user_r WHERE name = 't_user'");

        // An alias that looks like a table name.
        assertRewritten(interceptor,
            "SELECT t_user.* FROM t_user_real t_user",
            "SELECT t_user.* FROM t_user_real_r t_user");

        // An aliased table followed by a join.
        assertRewritten(interceptor,
            "SELECT t.* FROM t_user_real t left join entity e on e.id = t.id",
            "SELECT t.* FROM t_user_real_r t left join entity e on e.id = t.id");
    }

    /**
     * Asserts that {@code changeTable} turns {@code origin} into {@code expected}.
     */
    private static void assertRewritten(DynamicTableNameInnerInterceptor interceptor,
                                        @Language("SQL") String origin,
                                        String expected) {
        assertEquals(expected, interceptor.changeTable(origin));
    }
}

上次更新: 2025/05/14, 01:34:05
租户拦截器
SQL阻断器

← 租户拦截器 SQL阻断器→

最近更新
01
Caddy操作指南
04-25
02
Swap空间
04-22
03
Alist使用
04-21
更多文章>
Theme by Vdoing | Copyright © 2022-2025 YAN | MIT License
  • 跟随系统
  • 浅色模式
  • 深色模式
  • 阅读模式