Spring Boot 基于 MyBatis 的分页拦截器（Interceptor）




import java.lang.reflect.Array;
import java.lang.reflect.Field;
import java.sql.Connection;
import java.sql.DatabaseMetaData;
import java.sql.PreparedStatement;
import java.sql.ResultSet;
import java.sql.SQLException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Map.Entry;
import java.util.Properties;
import java.util.UUID;
import java.util.regex.Pattern;

import org.apache.ibatis.executor.parameter.ParameterHandler;
import org.apache.ibatis.executor.statement.StatementHandler;
import org.apache.ibatis.mapping.BoundSql;
import org.apache.ibatis.mapping.MappedStatement;
import org.apache.ibatis.mapping.ParameterMapping;
import org.apache.ibatis.mapping.SqlCommandType;
import org.apache.ibatis.plugin.Interceptor;
import org.apache.ibatis.plugin.Intercepts;
import org.apache.ibatis.plugin.Invocation;
import org.apache.ibatis.plugin.Plugin;
import org.apache.ibatis.plugin.Signature;
import org.apache.ibatis.reflection.MetaObject;
import org.apache.ibatis.reflection.SystemMetaObject;
import org.apache.ibatis.scripting.defaults.DefaultParameterHandler;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@Intercepts({@Signature(
    type = StatementHandler.class,
    method = "prepare",
    args = {Connection.class, Integer.class}
)})
public class PageInterceptor implements Interceptor {
    private static final String pageFlag = "paged";
    private Logger log = LoggerFactory.getLogger(this.getClass());

    public PageInterceptor() {
    }

    public Object intercept(Invocation arg0) throws Throwable {
        if (arg0.getTarget() instanceof StatementHandler) {
            StatementHandler statementHandler = (StatementHandler)arg0.getTarget();
            MetaObject metaObject = SystemMetaObject.forObject(statementHandler);
            MappedStatement mappedStatement = (MappedStatement)metaObject.getValue("delegate.mappedStatement");
            String selectId = mappedStatement.getId();
            SqlCommandType sqlCommandType = mappedStatement.getSqlCommandType();
            if (selectId.substring(selectId.lastIndexOf(".") + 1).toLowerCase().contains("paged") && sqlCommandType == SqlCommandType.SELECT) {
                BoundSql boundSql = (BoundSql)metaObject.getValue("delegate.boundSql");
                String sql = boundSql.getSql();
                Map param = (Map)boundSql.getParameterObject();
                Connection connection = (Connection)arg0.getArgs()[0];
                String countSql = this.concatCountSql(sql, param);
                String pageSql = this.getPageSql(connection, sql, param);
                PreparedStatement statement = null;
                ResultSet rs = null;
                int totalCount = 0;

                try {
                    statement = connection.prepareStatement(countSql);
                    List<ParameterMapping> parameterMappings = boundSql.getParameterMappings();
                    Object parameterObject = boundSql.getParameterObject();
                    BoundSql countBoundSql = new BoundSql(mappedStatement.getConfiguration(), countSql, parameterMappings, parameterObject);
                    Field additionalParametersField = BoundSql.class.getDeclaredField("additionalParameters");
                    additionalParametersField.setAccessible(true);
                    Map<String, Object> additionalParameters = (Map)additionalParametersField.get(boundSql);
                    Iterator var21 = additionalParameters.keySet().iterator();

                    while(var21.hasNext()) {
                        String key = (String)var21.next();
                        countBoundSql.setAdditionalParameter(key, additionalParameters.get(key));
                    }

                    ParameterHandler parameterHandler = new DefaultParameterHandler(mappedStatement, parameterObject, countBoundSql);
                    parameterHandler.setParameters(statement);
                    rs = statement.executeQuery();
                    if (rs.next()) {
                        totalCount = rs.getInt(1);
                    }
                } catch (SQLException var31) {
                    this.log.error("执行sql时发生错误:" + var31.getMessage());
                    BaseUtil.saveLog(0, "执行sql时发生错误", var31.getMessage());
                } finally {
                    try {
                        if (rs != null) {
                            rs.close();
                        }

                        if (statement != null) {
                            statement.close();
                        }
                    } catch (SQLException var30) {
                        this.log.error("执行sql时发生错误:" + var30.getMessage());
                        BaseUtil.saveLog(0, "执行sql时发生错误", var30.getMessage());
                    }

                }

                metaObject.setValue("delegate.boundSql.sql", pageSql);
                param.put("total", totalCount);
                int size = Integer.valueOf(param.get("rows").toString());
                double pageCount = Math.ceil((double)totalCount / (double)size);
                param.put("pageCount", (int)pageCount);
            }
        }

        return arg0.proceed();
    }

    public Object plugin(Object arg0) {
        return arg0 instanceof StatementHandler ? Plugin.wrap(arg0, this) : arg0;
    }

    public void setProperties(Properties arg0) {
    }

    private String getPageSql(Connection conn, String sql, Map param) throws SQLException {
        StringBuffer sqlBuffer = new StringBuffer(sql);
        DatabaseMetaData dbmd = conn.getMetaData();
        String dataBaseType = dbmd.getDatabaseProductName().toLowerCase();
        if (dataBaseType.contains("mysql")) {
            return param.get("filter") == null && param.get("sort") == null ? this.concatMysqlPageSql(sql, param) : this.kendoMysqlPageSql(sql, param);
        } else if (dataBaseType.contains("oracle")) {
            return param.get("filter") == null && param.get("sort") == null ? this.concatOraclePageSql(sql, param) : this.kendoCountSql(sql, param);
        } else {
            return sqlBuffer.toString();
        }
    }

    private String concatMysqlPageSql(String sql, Map param) {
        String sqlPage = String.format(" %s limit %d,%d", sql, param.get("offset"), param.get("rows"));
        return sqlPage;
    }

    private String concatOraclePageSql(String sql, Map param) {
        StringBuffer buffer = new StringBuffer();
        return buffer.toString();
    }

    private String concatCountSql(String sql, Map param) {
        StringBuffer buffer = new StringBuffer();
        buffer.append("select count(1) from (");
        buffer.append(sql);
        buffer.append(") v");
        return buffer.toString();
    }

    private String kendoMysqlPageSql(String sql, Map param) {
        StringBuffer buffer = new StringBuffer();
        Map filter = (Map)param.get("filter");
        Map condition = buildCondition(filter);
        String clauseSql = condition.get("clause").toString();
        ArrayList sortField = (ArrayList)param.get("sort");
        if (clauseSql.length() > 0) {
            buffer.append("select * from (");
            buffer.append(sql);
            buffer.append(") v where 1 = 1 and ");
            buffer.append(clauseSql);
        } else {
            buffer.append("select * from (");
            buffer.append(sql);
            buffer.append(") v");
        }

        String sortStr;
        if (sortField != null && sortField.size() > 0) {
            sortStr = "";

            Map field;
            for(Iterator var9 = sortField.iterator(); var9.hasNext(); sortStr = sortStr + String.format("%s %s,", field.get("field"), field.get("dir"))) {
                Object item = var9.next();
                field = (Map)item;
            }

            String sort = String.format(" order by %s", sortStr.substring(0, sortStr.length() - 1));
            buffer.append(sort);
        }

        sortStr = String.format(" limit %d,%d", param.get("offset"), param.get("rows"));
        buffer.append(sortStr);
        return buffer.toString();
    }

    private String kendoCountSql(String sql, Map param) {
        StringBuffer buffer = new StringBuffer();
        Map filter = (Map)param.get("filter");
        Map condition = buildCondition(filter);
        String clauseSql = condition.get("clause").toString();
        if (clauseSql.length() > 0) {
            buffer.append("select count(1) from (");
            buffer.append(sql);
            buffer.append(") v where 1 = 1 and ");
            buffer.append(clauseSql);
        } else {
            buffer.append("select count(1) from (");
            buffer.append(sql);
            buffer.append(") v");
        }

        return buffer.toString();
    }

    private static Map buildCondition(Map filter) {
        Map data = new HashMap();
        if (filter == null) {
            data.put("clause", "");
            return data;
        } else {
            StringBuilder clause = new StringBuilder();
            ArrayList filters = (ArrayList)filter.get("filters");
            if (filters != null) {
                clause.append("(");

                for(int i = 0; i < filters.size(); ++i) {
                    Map f = (Map)filters.get(i);
                    Map temp;
                    if (f.get("logic") == null && f.get("filters") == null) {
                        temp = (Map)filters.get(i);
                        Map temp = createSimpleCondition(temp);
                        clause.append(temp.get("clause"));
                        if (i != filters.size() - 1) {
                            clause.append(String.format(" %s ", filter.get("logic")));
                        }
                    } else {
                        temp = createCondition(f);
                        clause.append(temp.get("clause"));
                        if (i < filters.size() - 1) {
                            clause.append(String.format(" %s ", filter.get("logic")));
                        }
                    }
                }

                clause.append(")");
            }

            data.put("clause", clause.toString());
            return data;
        }
    }

    private static Map createCondition(Map filter) {
        Map data = new HashMap();
        if (filter == null) {
            data.put("clause", "");
            return data;
        } else {
            ArrayList filters = (ArrayList)filter.get("filters");
            StringBuilder clause = new StringBuilder();
            if (filters != null) {
                clause.append("(");

                for(int i = 0; i < filters.size(); ++i) {
                    Map f = (Map)filters.get(i);
                    Map temp = createSimpleCondition(f);
                    clause.append(temp.get("clause"));
                    if (i < filters.size() - 1) {
                        clause.append(String.format(" %s ", filter.get("logic")));
                    }
                }

                clause.append(")");
            }

            data.put("clause", clause.toString());
            return data;
        }
    }

    private static Map createSimpleCondition(Map filterDesc) {
        Map data = new HashMap();
        String key = String.format("%s_%s", filterDesc.get("field"), UUID.randomUUID());
        String val = filterDesc.get("value") == null ? "" : filterDesc.get("value").toString();
        String value = "'" + val + "'";
        String op = filterDesc.get("operator") == null ? "" : filterDesc.get("operator").toString();
        if (op.equals("eq")) {
            data.put("clause", filterDesc.get("field") + " = " + value);
        } else if (op.equals("contains")) {
            data.put("clause", filterDesc.get("field") + " LIKE '%" + val + "%'");
        } else if (op.equals("doesnotcontain")) {
            data.put("clause", filterDesc.get("field") + " NOT LIKE '%" + val + "%'");
        } else if (op.equals("startswith")) {
            data.put("clause", filterDesc.get("field") + " LIKE '" + val + "%'");
        } else if (op.equals("endswith")) {
            data.put("clause", filterDesc.get("field") + " LIKE '%" + val + "'");
        } else if (op.equals("iscontainedin")) {
            data.put("clause", filterDesc.get("field") + " IN " + val);
        } else if (op.equals("gt")) {
            data.put("clause", filterDesc.get("field") + " > " + value);
        } else if (op.equals("gte")) {
            data.put("clause", filterDesc.get("field") + " >= " + value);
        } else if (op.equals("lt")) {
            data.put("clause", filterDesc.get("field") + " < " + value);
        } else if (op.equals("lte")) {
            data.put("clause", filterDesc.get("field") + " <= " + value);
        } else if (op.equals("neq")) {
            data.put("clause", filterDesc.get("field") + " != " + value);
        } else if (op.equals("in")) {
            if (filterDesc.get("value").getClass().getName() == "java.util.ArrayList") {
                ArrayList array = (ArrayList)filterDesc.get("value");
                String str = "";

                for(int i = 0; i < array.size(); ++i) {
                    if (i == array.size() - 1) {
                        str = str + "'" + array.get(i).toString() + "'";
                    } else {
                        str = str + "'" + array.get(i).toString() + "',";
                    }
                }

                data.put("clause", filterDesc.get("field") + " in (" + str + ")");
            } else if (filterDesc.get("value").getClass().getName() == "java.lang.String") {
                data.put("clause", filterDesc.get("field") + " in (" + value + ")");
            }
        } else {
            data.put("clause", filterDesc.get("field") + " = " + value);
        }

        return data;
    }

    private static Map buildParams(Map params) {
        Iterator var1 = params.entrySet().iterator();

        while(true) {
            Entry entry;
            do {
                if (!var1.hasNext()) {
                    return params;
                }

                Object item = var1.next();
                entry = (Entry)item;
            } while(!entry.getValue().getClass().isArray());

            Object[] arr = (Object[])((Object[])entry.getValue());
            String res = "";

            for(int i = 0; i < arr.length; ++i) {
                if (i < i - 1) {
                    res = res + arr[i] + ",";
                } else {
                    res = res + arr[i];
                }
            }

            params.put(entry.getKey(), res);
        }
    }
}
发布了 116 篇原创文章 · 获赞 37 · 访问量 11 万+

猜你喜欢

转载自 blog.csdn.net/samHuangLiang/article/details/105042379