Java根据数据库生成所有实体

备忘:很多代码是从别人的博客里找的,我在此基础上做了一些修改。

基本情况是:现在我有一个现成的数据库,里面有上百个表,每个表有十几到二十个字段。

在java代码里使用需要每个表对应一个实体,如何获取这些实体类呢?最直接的方式,一个个表,一个个字段手工coding。这种方式有一个不好的地方:如果表和字段很多,怕是实体没写完,人进医院了。

所以我想写一个小程序:只需提供数据库连接的必要信息,就能用它生成这些实体类,使用时把调用代码放到main方法里运行即可。

/**
 * Generates one Java entity class per table of a MySQL database.
 *
 * <p>Every table listed in {@code INFORMATION_SCHEMA.TABLES} for the configured schema gets a
 * class named after the table (first letter upper-cased). Each class has one private field plus a
 * setter/getter pair per column. The file is written under
 * {@code ./src/com/eric/learning/entity/<dbname>/}.
 */
public class EntitiesGenerator {
    // Database connection settings
    private String URL; // full JDBC URL: base url + "/" + database name
    private String DBName;
    private String NAME;
    private String PASS;

    private String authorName = "eric"; // author written into the generated header comment
    private String[] colnames; // column names of the table currently being generated
    private String[] colTypes; // SQL type names, parallel to colnames
    private int[] colSizes; // column display sizes, parallel to colnames
    private boolean f_util = false; // generated class needs "import java.util.Date;"
    private boolean f_sql = false; // generated class needs "import java.sql.*;"

    private SqlHelper sqlHelper = null;

    /**
     * Creates a generator for one database.
     *
     * @param url      JDBC URL of the server without the database, e.g. {@code jdbc:mysql://host:port}
     * @param dbname   database (schema) name; appended to {@code url} and used as the package suffix
     * @param username database user
     * @param password database password
     */
    public EntitiesGenerator(String url, String dbname, String username, String password) {
        this.URL = url + "/" + dbname;
        this.DBName = dbname;
        this.NAME = username;
        this.PASS = password;

        sqlHelper = new SqlHelper(this.URL, this.NAME, this.PASS);
    }

    /**
     * Generates an entity class for every table of the configured database.
     *
     * <p>Errors are printed. A failing table does not stop the tables after it, and only tables
     * whose file was actually written are reported as generated.
     */
    public void Generate() {
        // The schema name is embedded in a SQL string literal, so double any single quotes.
        List<String> tableNames = sqlHelper.Get(
                "SELECT * FROM INFORMATION_SCHEMA.TABLES where TABLE_SCHEMA='" + DBName.replace("'", "''") + "';",
                "TABLE_NAME");
        Connection con = null;
        try {
            con = sqlHelper.getConnection();
        } catch (ClassNotFoundException | SQLException e) {
            e.printStackTrace();
        }
        if (con == null) {
            // Without a connection nothing can be generated; report it once and stop.
            System.out.println("------------------Connection to database was not set up------------------");
            return;
        }
        try {
            for (String table : tableNames) {
                if (Generate(table, con)) {
                    System.out.println("generated: " + table);
                }
                resetTableInfo();
            }
        } finally {
            try {
                sqlHelper.closeConnection(con);
            } catch (ClassNotFoundException | SQLException e) {
                e.printStackTrace();
            }
        }
    }

    /** Clears the per-table column data and import flags before the next table. */
    private void resetTableInfo() {
        colnames = null;
        colTypes = null;
        colSizes = null;
        f_util = false;
        f_sql = false;
    }

    /** Package of the generated classes: {@code com.eric.learning.entity.<dbname>}. */
    private String getPackageOutPath() {
        return "com.eric.learning.entity." + DBName;
    }

    /**
     * Generates the entity class for a single table.
     *
     * @param tablename table to generate the class for
     * @param con       open connection to the database
     * @return {@code true} if the source file was written, {@code false} on any error (printed)
     */
    private boolean Generate(String tablename, Connection con) {
        if (con == null) {
            System.out.println("------------------Connection to database was not set up------------------");
            return false;
        }
        try {
            readColumns(tablename, con);
        } catch (SQLException e) {
            e.printStackTrace();
            return false;
        }
        String content = parse(colnames, colTypes, colSizes, tablename);
        try {
            writeEntityFile(tablename, content);
            return true;
        } catch (IOException e) {
            e.printStackTrace();
            return false;
        }
    }

    /**
     * Loads column data for {@code tablename} into the per-table fields and sets the import flags.
     * Fills colnames (spaces removed), colTypes (SQL type names) and colSizes.
     */
    private void readColumns(String tablename, Connection con) throws SQLException {
        // Backtick-quote the name (MySQL syntax) so reserved words such as `order` and names with
        // special characters work; embedded backticks are doubled.
        String sql = "SELECT * FROM `" + tablename.replace("`", "``") + "` limit 0, 1;";
        try (PreparedStatement pStemt = con.prepareStatement(sql)) {
            ResultSetMetaData rsmd = pStemt.getMetaData();
            if (rsmd == null) {
                throw new SQLException("No result set metadata available for table " + tablename);
            }
            int size = rsmd.getColumnCount();
            colnames = new String[size];
            colTypes = new String[size];
            colSizes = new int[size];
            for (int i = 0; i < size; i++) {
                colnames[i] = rsmd.getColumnName(i + 1).replace(" ", "");
                colTypes[i] = rsmd.getColumnTypeName(i + 1);
                colSizes[i] = rsmd.getColumnDisplaySize(i + 1);

                // Derive the imports from the Java type that will actually be emitted,
                // so the generated file never misses or carries an unneeded import.
                String javaType = sqlType2JavaType(colTypes[i]);
                if (javaType.equals("Date")) {
                    f_util = true;
                } else if (javaType.equals("Blob") || javaType.equals("Timestamp")) {
                    f_sql = true;
                }
            }
        }
    }

    /**
     * Writes {@code content} to {@code ./src/<package path>/<Tablename>.java}.
     * Missing directories are created.
     *
     * @throws IOException if the directory cannot be created or the file cannot be written
     */
    private void writeEntityFile(String tablename, String content) throws IOException {
        File dir = new File(new File("").getAbsolutePath() + "/src/" + this.getPackageOutPath().replace(".", "/"));
        if (!dir.isDirectory() && !dir.mkdirs()) {
            throw new IOException("Could not create directory " + dir.getAbsolutePath());
        }

        String outputPath = dir.getAbsolutePath() + "/" + initcap(tablename) + ".java";
        try (PrintWriter pw = new PrintWriter(new FileWriter(outputPath))) {
            pw.println(content);
            // PrintWriter swallows I/O errors; checkError() flushes and reports them.
            if (pw.checkError()) {
                throw new IOException("Failed to write " + outputPath);
            }
        }
    }

    /**
     * Builds the complete source text of the entity class.
     *
     * @param colnames  column names (used as field names)
     * @param colTypes  SQL type names, parallel to {@code colnames}
     * @param colSizes  column display sizes (currently not used in the output)
     * @param tablename table name; its capitalized form becomes the class name
     * @return the Java source code
     */
    private String parse(String[] colnames, String[] colTypes, int[] colSizes, String tablename) {
        StringBuilder sb = new StringBuilder();
        sb.append("package ").append(this.getPackageOutPath()).append(";\r\n");
        // Only import what the emitted field types need.
        if (f_util) {
            sb.append("import java.util.Date;\r\n");
        }
        if (f_sql) {
            sb.append("import java.sql.*;\r\n");
        }

        sb.append("\r\n");
        // Header comment
        sb.append("/**\r\n");
        sb.append(" * ").append(tablename);

        // Hours, then minutes, then seconds.
        SimpleDateFormat formater = new SimpleDateFormat("yyyy-MM-dd HH:mm:ss");
        sb.append("  generated at ").append(formater.format(new Date()))
                .append(" by: ").append(this.authorName).append("\r\n");

        sb.append(" */");
        // Class body
        sb.append("\r\n\r\npublic class ").append(initcap(tablename)).append("{\r\n");
        processAllAttrs(sb); // fields
        processAllMethod(sb); // setters and getters
        sb.append("}");

        return sb.toString();
    }

    /** Appends one {@code private <type> <column>;} declaration per column, then a blank line. */
    private void processAllAttrs(StringBuilder sb) {
        for (int i = 0; i < colnames.length; i++) {
            sb.append("\tprivate ").append(sqlType2JavaType(colTypes[i]))
                    .append(" ").append(colnames[i]).append(";\r\n");
        }
        sb.append("\r\n");
    }

    /** Appends a setter and a getter for every column. */
    private void processAllMethod(StringBuilder sb) {
        for (int i = 0; i < colnames.length; i++) {
            String field = colnames[i];
            String type = sqlType2JavaType(colTypes[i]);
            String capitalized = initcap(field);

            sb.append("\tpublic void set").append(capitalized).append("(").append(type).append(" ")
                    .append(field).append("){\r\n");
            sb.append("\t\tthis.").append(field).append("=").append(field).append(";\r\n");
            sb.append("\t}\r\n\r\n");
            sb.append("\tpublic ").append(type).append(" get").append(capitalized).append("(){\r\n");
            sb.append("\t\treturn ").append(field).append(";\r\n");
            sb.append("\t}\r\n\r\n");
        }
    }

    /**
     * Upper-cases the first character if it is an ASCII lower-case letter.
     *
     * @param str input; {@code null} or empty strings are returned unchanged
     * @return the capitalized string
     */
    private String initcap(String str) {
        if (str == null || str.isEmpty()) {
            return str;
        }
        char first = str.charAt(0);
        if (first >= 'a' && first <= 'z') {
            return Character.toUpperCase(first) + str.substring(1);
        }
        return str;
    }

    /**
     * Maps a JDBC column type name (as reported by {@code getColumnTypeName}) to a Java field type.
     *
     * <p>Unsigned integer types are widened so every value fits. Unknown types fall back to
     * {@code Object}, so the generated class always compiles.
     *
     * @param sqlType SQL type name, case-insensitive
     * @return the Java type name to emit; never {@code null}
     */
    private String sqlType2JavaType(String sqlType) {
        if (sqlType == null) {
            return "Object";
        }
        switch (sqlType.trim().toUpperCase(java.util.Locale.ROOT)) {
            case "BIT":
                return "boolean";
            case "TINYINT":
                return "byte";
            case "TINYINT UNSIGNED": // 0..255 does not fit in byte
            case "SMALLINT":
                return "short";
            case "SMALLINT UNSIGNED": // 0..65535 does not fit in short
            case "MEDIUMINT":
            case "MEDIUMINT UNSIGNED":
            case "INT":
            case "INTEGER":
                return "int";
            case "INT UNSIGNED": // 0..4294967295 does not fit in int
            case "INTEGER UNSIGNED":
            case "BIGINT":
                return "long";
            case "BIGINT UNSIGNED": // does not fit in long; fully qualified, so no import is needed
                return "java.math.BigInteger";
            case "FLOAT":
                return "float";
            case "DECIMAL":
            case "NUMERIC":
            case "REAL":
            case "MONEY":
            case "SMALLMONEY":
            case "DOUBLE":
                return "double";
            case "VARCHAR":
            case "CHAR":
            case "NVARCHAR":
            case "NCHAR":
            case "TEXT":
            case "TINYTEXT":
            case "MEDIUMTEXT":
            case "LONGTEXT":
                return "String";
            case "DATETIME":
            case "DATE":
                return "Date";
            case "IMAGE":
                return "Blob"; // java.sql.Blob
            case "TIMESTAMP":
                return "Timestamp"; // java.sql.Timestamp
            default:
                return "Object";
        }
    }

}
View Code

用到的SqlHelper类,并没有完整实现各种查询方法(暂时用不到):

/**
 * Minimal JDBC helper for a MySQL database.
 *
 * <p>It runs ad-hoc queries, each on its own connection, and manages one cached connection
 * handed out by {@link #getConnection()}. Only the query methods needed so far are implemented.
 */
public class SqlHelper {
    private String url;
    private String username;
    private String password;

    // Connection cached by getConnection(); cleared when it is closed through this helper.
    private Connection connection;

    // Connector/J 5.x driver class (Connector/J 8 names it com.mysql.cj.jdbc.Driver).
    private static final String DRIVER = "com.mysql.jdbc.Driver";

    /**
     * @param url      JDBC URL including the database, e.g. {@code jdbc:mysql://host:port/db}
     * @param username database user
     * @param password database password
     */
    public SqlHelper(String url, String username, String password) {
        this.url = url;
        this.username = username;
        this.password = password;
    }

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public String getUsername() {
        return username;
    }

    public void setUsername(String username) {
        this.username = username;
    }

    public String getPassword() {
        return password;
    }

    public void setPassword(String password) {
        this.password = password;
    }

    /**
     * Runs {@code sql} on a fresh connection and returns every row as a map.
     *
     * <p>Each map is keyed by column label, i.e. the {@code AS} alias if one is given. Errors are
     * printed, and the rows read so far are returned (usually an empty list). The statement and its
     * connection are always closed.
     *
     * @param sql query to execute
     * @return list of rows, never {@code null}
     */
    public List<HashMap<String, Object>> Get(String sql) {
        List<HashMap<String, Object>> result = new ArrayList<HashMap<String, Object>>();
        Statement statement = null;
        try {
            statement = getStatement();
            try (ResultSet set = statement.executeQuery(sql)) {
                ResultSetMetaData meta = set.getMetaData();
                int columnCount = meta.getColumnCount();
                while (set.next()) {
                    HashMap<String, Object> row = new HashMap<>();
                    for (int i = 1; i <= columnCount; i++) {
                        // Read by index: lookups by name break on duplicate names and aliases.
                        row.put(meta.getColumnLabel(i), set.getObject(i));
                    }
                    result.add(row);
                }
            }
        } catch (ClassNotFoundException | SQLException e) {
            e.printStackTrace();
        } finally {
            closeQuietly(statement);
        }

        return result;
    }

    /**
     * Runs {@code sql} on a fresh connection and collects one column from every row.
     *
     * <p>The values are cast unchecked to {@code T}; a wrong {@code T} fails later, at the caller.
     * Errors are printed and the values read so far are returned.
     *
     * @param sql        query to execute
     * @param columnName label of the column to collect
     * @return list of values, never {@code null}
     */
    @SuppressWarnings("unchecked")
    public <T> List<T> Get(String sql, String columnName) {
        List<T> result = new ArrayList<T>();
        Statement statement = null;
        try {
            statement = getStatement();
            try (ResultSet set = statement.executeQuery(sql)) {
                while (set.next()) {
                    result.add((T) set.getObject(columnName));
                }
            }
        } catch (ClassNotFoundException | SQLException e) {
            e.printStackTrace();
        } finally {
            closeQuietly(statement);
        }

        return result;
    }

    /**
     * Opens a new connection and creates a statement on it.
     * The caller should release both with {@link #closeStatement(Statement)}.
     */
    public Statement getStatement() throws ClassNotFoundException, SQLException {
        Class.forName(DRIVER);
        Connection con = DriverManager.getConnection(url, username, password);
        try {
            return con.createStatement();
        } catch (SQLException e) {
            // Don't leak the freshly opened connection when no statement could be created.
            try {
                con.close();
            } catch (SQLException closeError) {
                e.addSuppressed(closeError);
            }
            throw e;
        }
    }

    /**
     * Returns the cached connection, opening a new one if none is cached or it has been closed.
     */
    public Connection getConnection() throws ClassNotFoundException, SQLException {
        if (connection == null || connection.isClosed()) {
            Class.forName(DRIVER);
            connection = DriverManager.getConnection(url, username, password);
        }

        return connection;
    }

    /**
     * Closes the cached connection and {@code conn}, if given. Each one is closed only once.
     * A later {@link #getConnection()} opens a new connection.
     */
    public void closeConnection(Connection conn) throws ClassNotFoundException, SQLException {
        Connection cached = connection;
        connection = null; // forget it even if close() fails, so it is never handed out again
        try {
            if (cached != null) {
                cached.close();
            }
        } finally {
            if (conn != null && conn != cached) {
                conn.close();
            }
        }

        System.out.println("-----------Connection closed now-----------");
    }

    /**
     * Closes {@code statement} and the connection it belongs to.
     * The connection is closed even if closing the statement fails. {@code null} is ignored.
     */
    public void closeStatement(Statement statement) throws SQLException {
        if (statement == null) {
            return;
        }
        Connection con = statement.getConnection();
        try {
            statement.close();
        } finally {
            if (con != null) {
                if (con == connection) {
                    connection = null; // the cached connection is being closed; don't reuse it
                }
                con.close();
            }
        }
    }

    /** Closes {@code statement} and its connection, printing rather than throwing any SQLException. */
    private void closeQuietly(Statement statement) {
        try {
            closeStatement(statement);
        } catch (SQLException e) {
            e.printStackTrace();
        }
    }
}
View Code

在main方法里调用:

EntitiesGenerator gen = new EntitiesGenerator("jdbc:mysql://xx.xx.xx.xx:xxxx", "xxx", "xxx", "xxx");
gen.Generate();

这里是连接MySQL来生成实体的,所以项目中必须引入MySQL的JDBC驱动jar包(代码中使用的驱动类为com.mysql.jdbc.Driver)。

猜你喜欢

转载自www.cnblogs.com/lihan829/p/9615210.html