Mybatis: multi-tenant sql-interceptor 系统多租户实现

Mybatis: multi-tenant sql-interceptor

系统租户隔离有多种实现方式:

  • 完全隔离(不同数据库): 没啥好讲的, 看作是多个系统就成, 此方式毫无疑问, 成本最高 玩不起 玩不起…
  • 共享隔离(共享同一个数据库), 又分为以下两种:
    • 多个Schema, 表完全隔离:一般通过中间件, 根据会话标识路由到指定schema即可
    • 同一个Schema, 表上添加租户标识:比较底层, 必须通过拦截方式实现SQL重构方可实现

下面介绍的是同Schema,表上添加租户标识的具体实现代码

1. 添加依赖
<dependency> <!-- 需要借助MyBatis拦截器插件, implements org.apache.ibatis.plugin.Interceptor -->
    <groupId>org.mybatis</groupId>
    <artifactId>mybatis</artifactId>
</dependency>
<dependency> <!-- 需要一个SQL识别的插件, Github上刚好有: https://github.com/JSQLParser/JSqlParser, Star 2.5k -->
    <groupId>com.github.jsqlparser</groupId>
    <artifactId>jsqlparser</artifactId>
    <version>1.4</version>
</dependency>
2. 代码片段及解析
class imports
 
  import net.sf.jsqlparser.expression.*;
  import net.sf.jsqlparser.expression.operators.conditional.AndExpression;
  import net.sf.jsqlparser.expression.operators.relational.EqualsTo;
  import net.sf.jsqlparser.expression.operators.relational.ExpressionList;
  import net.sf.jsqlparser.expression.operators.relational.MultiExpressionList;
  import net.sf.jsqlparser.parser.CCJSqlParserUtil;
  import net.sf.jsqlparser.schema.Column;
  import net.sf.jsqlparser.schema.Table;
  import net.sf.jsqlparser.statement.Statement;
  import net.sf.jsqlparser.statement.delete.Delete;
  import net.sf.jsqlparser.statement.insert.Insert;
  import net.sf.jsqlparser.statement.select.PlainSelect;
  import net.sf.jsqlparser.statement.select.Select;
  import net.sf.jsqlparser.statement.update.Update;
  import net.sf.jsqlparser.util.TablesNamesFinder;
  import org.apache.ibatis.executor.statement.StatementHandler;
  import org.apache.ibatis.mapping.BoundSql;
  import org.apache.ibatis.mapping.MappedStatement;
  import org.apache.ibatis.plugin.*;
  import org.apache.ibatis.reflection.MetaObject;
  import org.apache.ibatis.reflection.SystemMetaObject;
  
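/**
 * Background:
 * MyBatis plugins can intercept only four types: ParameterHandler, ResultSetHandler,
 * StatementHandler and Executor.
 * A plugin can be registered in two ways:
 * a. In the configuration file
 * <configuration>
 *     <plugins>
 *         <plugin interceptor="..."/>
 *     </plugins>
 * </configuration>
 * b. In code
 * {@literal @}Resource
 * SqlSessionFactory sqlSessionFactory;
 * sqlSessionFactory.getConfiguration().addInterceptor(interceptor);
 *
 * The interceptor only takes effect after plugin() has wrapped the target in a proxy,
 * which is how the interceptor chain of responsibility is registered.
 **/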

@Log4j2
// @Intercepts declares, one @Signature per entry, which MyBatis type (type) and which of its methods (method) this interceptor targets; here it is StatementHandler's prepare method
@Intercepts({@Signature(method = "prepare", type = StatementHandler.class, args = {Connection.class, Integer.class})})
// 亦可@Intercepts({
// @Signature(type = Executor.class, method = "update", args = {
//        MappedStatement.class, Object.class }),
// @Signature(type = Executor.class, method = "query", args = {
//        MappedStatement.class, Object.class, RowBounds.class,
//        ResultHandler.class }) }), 此处不作展示
public class TenantInterceptor implements Interceptor {
    private static final String SQL_TENANT_ID = "tenant_id"; // 名称自定
    /**
     * Decides whether the SQL of a mapped statement should be rewritten with the tenant
     * condition. {@code intercept} only rewrites when this returns {@code true}.
     *
     * @param statementId the value of {@code MappedStatement.getId()}; per the original
     *                    note this is the id declared in the mapper, e.g.
     *                    {@code <select id="selectXXX" ../>}
     * @return {@code true} when the statement must go through the tenant rewrite
     */
    private boolean onFilter(String statementId) {
        // TODO: implement per project needs; the main purpose is to exclude statements that do not need tenant filtering
    	...
    }
    
    /**
     * Interception entry point: logs the SQL bound to the intercepted
     * {@code StatementHandler}, replaces it with its tenant-aware rewrite when
     * {@code onFilter} accepts the mapped statement, then lets the invocation continue.
     *
     * @param invocation the intercepted {@code StatementHandler.prepare} call
     * @return whatever the intercepted call returns
     * @throws Throwable anything raised by the rewrite or by the intercepted call
     */
    @Override
    public Object intercept(Invocation invocation) throws Throwable {
        // The signature targets StatementHandler, so the invocation target is cast to it.
        final StatementHandler target = (StatementHandler) invocation.getTarget();

        final String originalSql = target.getBoundSql().getSql();
        log.debug("Intercept SQL: {}", originalSql);

        // Reflective view used both to read the mapped statement and to write the SQL back.
        final MetaObject meta = SystemMetaObject.forObject(target);
        final MappedStatement mapped = (MappedStatement) meta.getValue("delegate.mappedStatement");

        String effectiveSql = originalSql;
        if (this.onFilter(mapped.getId())) {
            // Inject the tenant condition and overwrite the SQL held by the bound statement.
            effectiveSql = this.delegate(originalSql);
            meta.setValue("delegate.boundSql.sql", effectiveSql);
        }
        log.debug("Delegate SQL: {}", effectiveSql);

        return invocation.proceed();
    }
    
    /**
     * Produces the tenant-aware form of {@code originSql}; currently a direct hand-off
     * to {@code rewrite}.
     *
     * @param originSql the SQL text taken from the bound statement
     * @return the rewritten SQL text
     * @throws Exception whatever {@code rewrite} raises
     */
    private String delegate(String originSql) throws Exception {
        final String tenantScopedSql = this.rewrite(originSql);
        return tenantScopedSql;
    }
    
    /**
     * Rewrite router: parses the SQL and dispatches to the rewrite routine that matches
     * the statement type (insert, delete, update or select).
     *
     * @param originSql the SQL text to parse
     * @return the rewritten SQL text
     * @throws Exception when parsing fails, or {@code SQLNotSupportedException} for any
     *                   other statement type
     */
    private String rewrite(String originSql) throws Exception {
        final Statement parsed = CCJSqlParserUtil.parse(originSql);

        if (parsed instanceof Insert) {
            return this.rewriteInsertSql(parsed);
        }
        if (parsed instanceof Delete) {
            return this.rewriteDeleteSql(parsed);
        }
        if (parsed instanceof Update) {
            return this.rewriteUpdateSql(parsed);
        }
        if (parsed instanceof Select) {
            return this.rewriteSelectSql(parsed);
        }

        // The exception type is left for the project to implement.
        throw new SQLNotSupportedException();
    }
    
    /**
     * Appends the tenant column and the current tenant id to an INSERT statement, e.g.
     * {@code insert into A (a, b, tenant_id) values ('a', 'b', 'T'), ('a', 'b', 'T') ...}.
     *
     * <p>The statement is validated before it is touched, so an unsupported shape fails
     * with a descriptive message instead of a bare NullPointerException on a
     * half-modified statement.
     *
     * @param statement the parsed statement; must be an {@code Insert}
     * @return the rewritten SQL text
     * @throws UnsupportedOperationException when the INSERT has no explicit column list
     *         (the tenant column cannot be positioned) or no VALUES list to extend
     */
    private String rewriteInsertSql(Statement statement) {
        Insert insert = (Insert) statement;

        // "insert into A values (...)" has no column list to extend.
        if (insert.getColumns() == null) {
            throw new UnsupportedOperationException(
                    "Insert-Statement must declare an explicit column list");
        }
        // Only single-row and multi-row VALUES lists can take an extra value per row.
        boolean multiRow = insert.getItemsList() instanceof MultiExpressionList;
        if (!multiRow && !(insert.getItemsList() instanceof ExpressionList)) {
            throw new UnsupportedOperationException(
                    "Insert-Statement must provide a VALUES list");
        }

        // Read the tenant id once so every row of the statement gets the same value.
        String tenantId = this.tenantSupport.getTenantId();

        insert.getColumns().add(new Column(SQL_TENANT_ID));
        if (multiRow) {
            for (ExpressionList row : ((MultiExpressionList) insert.getItemsList()).getExprList()) {
                row.getExpressions().add(new StringValue(tenantId));
            }
        } else {
            ((ExpressionList) insert.getItemsList()).getExpressions().add(new StringValue(tenantId));
        }

        return insert.toString();
    }

    /**
     * Restricts a DELETE statement to the current tenant by rewriting its WHERE clause to
     * {@code tenant_id = '<tenant>' AND (<original where>)}.
     *
     * <p>The tenant predicate is added for every kind of WHERE expression. Limiting it to
     * binary expressions would let conditions such as {@code id IN (...)},
     * {@code x IS NULL} or {@code EXISTS (...)} through with no tenant filter at all.
     *
     * @param statement the parsed statement; must be a {@code Delete}
     * @return the rewritten SQL text
     * @throws SQLInterceptException when the statement has no WHERE clause
     */
    private String rewriteDeleteSql(Statement statement) {
        Delete deleteStatement = (Delete) statement;
        Expression whereExpression = deleteStatement.getWhere();
        if (whereExpression == null) {
            throw new SQLInterceptException("Delete-Statement must be set conditions");
        }

        // Parenthesise the original condition so an OR inside it cannot escape the tenant filter.
        deleteStatement.setWhere(new AndExpression(this.newEqualTo(), new Parenthesis(whereExpression)));
        return deleteStatement.toString();
    }

    /**
     * Restricts an UPDATE statement to the current tenant, e.g.
     * {@code update A set name='' where tenant_id = ''}.
     *
     * <p>One tenant predicate is AND-ed onto the WHERE clause for every table name the
     * statement references; when no table name is found the statement is returned as parsed.
     * NOTE(review): {@code getTableAlias} only resolves the first table of an UPDATE, so
     * any further table name contributes an unqualified {@code tenant_id} predicate;
     * confirm this is intended for statements that reference several tables.
     *
     * @param statement the parsed statement; must be an {@code Update}
     * @return the rewritten SQL text
     * @throws SQLInterceptException when the statement has no WHERE clause
     */
    private String rewriteUpdateSql(Statement statement) {
        final Update update = (Update) statement;
        if (null == update.getWhere()) {
            throw new SQLInterceptException("Update-Statement must be set conditions");
        }

        final List<String> referencedTables = new TablesNamesFinder().getTableList(statement);

        // An empty list simply skips the loop and leaves the statement untouched.
        for (String referencedTable : referencedTables) {
            update.setWhere(this.newAndExpression(statement, referencedTable, update.getWhere()));
        }

        return update.toString();
    }

    /**
     * Restricts a SELECT statement to the current tenant by adding a tenant predicate on
     * the main (FROM) table; joined tables are not constrained.
     *
     * <p>Statements that reference no table (e.g. {@code select 1}, {@code select now()})
     * are returned as parsed. Shapes this rewrite cannot handle are rejected with a
     * descriptive exception instead of failing on a blind cast.
     *
     * @param statement the parsed statement; must be a {@code Select}
     * @return the rewritten SQL text
     * @throws UnsupportedOperationException when the select body is not a plain select
     *         (e.g. a UNION) or its FROM item is not a plain table (e.g. a sub-select)
     */
    private String rewriteSelectSql(Statement statement) {
        Select selectStatement = (Select) statement;

        TablesNamesFinder tablesNameFinder = new TablesNamesFinder();
        List<String> tableNames = tablesNameFinder.getTableList(selectStatement);
        // select 1 OR select now()
        if (tableNames.isEmpty()) {
            return selectStatement.toString();
        }

        // Only plain selects are supported; set operations would need each branch rewritten.
        if (!(selectStatement.getSelectBody() instanceof PlainSelect)) {
            throw new UnsupportedOperationException(
                    "Select-Statement must be a plain select to apply the tenant condition");
        }
        PlainSelect plainSelect = (PlainSelect) selectStatement.getSelectBody();

        // The tenant column is qualified by the main table, so FROM must be a real table.
        if (!(plainSelect.getFromItem() instanceof Table)) {
            throw new UnsupportedOperationException(
                    "Select-Statement must select from a plain table to apply the tenant condition");
        }
        String mainTableName = ((Table) plainSelect.getFromItem()).getName();

        // Complex queries (JOINs, multi-table selects): only the main table is constrained for now.
        if (plainSelect.getWhere() == null) {
            plainSelect.setWhere(this.newEqualTo(statement, mainTableName));
        } else {
            plainSelect.setWhere(this.newAndExpression(statement, mainTableName, plainSelect.getWhere()));
        }

        return selectStatement.toString();
    }
    
    /**
     * Builds {@code <tenant predicate> AND (<whereExpression>)}, i.e. the replacement for
     * a statement's existing WHERE clause.
     *
     * @param statement       the statement the table belongs to, used to resolve its alias
     * @param tableName       the table whose tenant column is constrained
     * @param whereExpression the existing WHERE condition, kept inside parentheses
     * @return the combined AND expression
     */
    private AndExpression newAndExpression(Statement statement, String tableName, Expression whereExpression) {
        return new AndExpression(
                this.newEqualTo(statement, tableName),
                new Parenthesis(whereExpression));
    }

    /**
     * Builds the unqualified tenant predicate {@code tenant_id = '<current tenant id>'}.
     *
     * @return a new equality expression on the tenant column
     */
    private EqualsTo newEqualTo() {
        final Column tenantColumn = new Column(SQL_TENANT_ID);
        final StringValue tenantValue = new StringValue(tenantSupport.getTenantId());

        final EqualsTo tenantMatch = new EqualsTo();
        tenantMatch.setLeftExpression(tenantColumn);
        tenantMatch.setRightExpression(tenantValue);
        return tenantMatch;
    }

    /**
     * Builds the tenant predicate for one table of a statement. The tenant column is
     * prefixed with whatever {@code getTableAlias} resolves for the table, or left
     * unqualified when nothing is resolved.
     *
     * @param statement the statement the table belongs to
     * @param tableName the table whose tenant column is constrained
     * @return a new equality expression on the (possibly qualified) tenant column
     */
    private EqualsTo newEqualTo(Statement statement, String tableName) {
        final EqualsTo tenantMatch = new EqualsTo();

        final String qualifier = this.getTableAlias(statement, tableName);
        final String columnName;
        if (qualifier == null) {
            columnName = SQL_TENANT_ID;
        } else {
            columnName = qualifier + '.' + SQL_TENANT_ID;
        }

        tenantMatch.setLeftExpression(new Column(columnName));
        tenantMatch.setRightExpression(new StringValue(tenantSupport.getTenantId()));
        return tenantMatch;
    }

    /**
     * Resolves the qualifier to use for {@code tableName} inside {@code stmt}.
     *
     * <p>For an INSERT the table name itself is returned. For DELETE, UPDATE and SELECT
     * only the statement's primary table is considered (the DELETE table, the first
     * UPDATE table, the SELECT's FROM item): when its name matches {@code tableName}
     * ignoring case, its alias is returned, or the table name when it has no alias.
     *
     * @param stmt      the parsed statement
     * @param tableName the table name to look up
     * @return the alias or table name, or {@code null} when the statement type is not
     *         handled or its primary table does not match
     */
    private String getTableAlias(Statement stmt, String tableName) {
        if (stmt instanceof Insert) {
            return tableName;
        }

        // Pick the single table each statement type is matched against.
        final Table primary;
        if (stmt instanceof Delete) {
            primary = ((Delete) stmt).getTable();
        } else if (stmt instanceof Update) {
            primary = ((Update) stmt).getTables().get(0);
        } else if (stmt instanceof Select) {
            final PlainSelect body = (PlainSelect) ((Select) stmt).getSelectBody();
            primary = (Table) body.getFromItem();
        } else {
            return null;
        }

        if (!primary.getName().equalsIgnoreCase(tableName)) {
            return null;
        }
        final Alias declared = primary.getAlias();
        return declared == null ? tableName : declared.getName();
    }

    /**
     * Creates the wrapping proxy: only {@code StatementHandler} targets are wrapped with
     * this interceptor, every other target is handed back untouched.
     *
     * @param target the object MyBatis offers for wrapping
     * @return the proxy for a {@code StatementHandler}, otherwise {@code target} itself
     */
    @Override
    public Object plugin(Object target) {
        return (target instanceof StatementHandler) ? Plugin.wrap(target, this) : target;
    }
    
    /**
     * Receives the plugin properties; this interceptor reads none of them.
     *
     * @param properties the plugin properties, unused here
     */
    @Override
    public void setProperties(Properties properties) {
        // Intentionally empty: no property is consumed.
    }
    
    
    ...
} 

小提示
<!-- 最好隐性操作租户标识,可在Mybatis生成基础文件时屏蔽 -->
<table ...>
	<ignoreColumn column="tenant_id"/>
</table>

猜你喜欢

转载自blog.csdn.net/weixin_49689128/article/details/107636972