Hadoop 自定义输入输出

Hadoop 自定义输入输出

这里以MySQL为输入、MySQL为输出作为测试例

一、输入端

自定义的输入需要继承InputFormat,并实现数据分片(getSplits())和创建记录读取对象(createRecordReader())

1. 数据读取抽象类

/**
 * Base type for values read from MySQL: a {@code Writable} that can also
 * populate its fields from a JDBC {@link ResultSet}.
 */
public abstract class MySQLInputWritable implements Writable {
    /**
     * Reads this record's fields from the given result set.
     *
     * @param rs result set to read the column values from
     * @throws SQLException if reading a column fails
     */
    public abstract void readFieldsFromResultSet(ResultSet rs) throws SQLException;

}

2. 自定义MySQL输入类

public class MySQLInputFormat<V extends MySQLInputWritable> extends InputFormat<LongWritable, V> {

    private static final Logger LOG = Logger.getLogger(MySQLInputFormat.class);

    /** Config key - JDBC driver class of the input database. */
    public static final String MYSQL_INPUT_DRIVER = "mysql.input.driver";
    /** Config key - JDBC URL of the input database. */
    public static final String MYSQL_INPUT_URL = "mysql.input.url";
    /** Config key - user name for the input database. */
    public static final String MYSQL_INPUT_USERNAME = "mysql.input.username";
    /** Config key - password for the input database. */
    public static final String MYSQL_INPUT_PASSWORD = "mysql.input.password";

    /** Config key - SQL statement that returns the total record count. */
    public static final String MYSQL_INPUT_SELECT_COUNT_SQL = "mysql.input.select.count";
    /** Config key - SQL statement that selects the records (LIMIT/OFFSET is appended per split). */
    public static final String MYSQL_INPUT_SELECT_RECORD_SQL = "mysql.input.select.record";
    /** Config key - number of records per input split (defaults to 100 when unset). */
    public static final String MYSQL_INPUT_SPLIT_PRE_SIZE = "mysql.input.split.pre.size";
    /**
     * Config key - class of the value object the record reader instantiates.
     * Despite the "output" in its name, this key is read on the input side.
     */
    public static final String MYSQL_OUTPUT_VALUE_CLASS = "mysql.output.value.class";


    /**
     * Computes the input splits, which determines the number of map tasks.
     *
     * <p>Runs the configured count query, then cuts the row range
     * {@code [0, recordCount)} into consecutive chunks of at most
     * {@link #MYSQL_INPUT_SPLIT_PRE_SIZE} rows (default 100). The last chunk
     * holds the remainder. An empty table yields no splits.
     *
     * @param context job context supplying the configuration
     * @return one {@link MySQLInputSplit} per chunk, in ascending row order
     * @throws IOException if the count query fails or the configured split size is not positive
     */
    @Override
    public List<InputSplit> getSplits(JobContext context)
            throws IOException, InterruptedException {

        Configuration conf = context.getConfiguration();
        Connection conn = null;
        Statement stmt = null;
        ResultSet rs = null;

        long recordCount = 0;

        try {
            conn = this.getConnection(conf);
            stmt = conn.createStatement();
            rs = stmt.executeQuery(conf.get(MYSQL_INPUT_SELECT_COUNT_SQL));
            if (rs.next()) {
                recordCount = rs.getLong(1);
            }
        } catch (Exception e) {
            throw new IOException("查询数据总量失败", e);
        } finally {
            // Release in reverse order of acquisition: result set, statement, connection.
            this.closeAutoCloseable(rs);
            this.closeAutoCloseable(stmt);
            this.closeAutoCloseable(conn);
        }

        long preSplitCount = conf.getLong(MYSQL_INPUT_SPLIT_PRE_SIZE, 100);
        if (preSplitCount <= 0) {
            // A zero or negative chunk size can never cover the table.
            throw new IOException("配置项 " + MYSQL_INPUT_SPLIT_PRE_SIZE
                    + " 必须为正数, 当前值: " + preSplitCount);
        }

        // Walk the row range chunk by chunk. This avoids the ceil-division
        // expression, whose missing parentheses previously collapsed the
        // split count to 0 or 1.
        List<InputSplit> splits = new ArrayList<InputSplit>();
        long start = 0;
        while (start < recordCount) {
            long end = start + Math.min(preSplitCount, recordCount - start);
            splits.add(new MySQLInputSplit(start, end));
            start = end;
        }

        return splits;
    }

    /**
     * Creates the record reader for the given split and initializes it
     * with that split and task context before returning it.
     */
    @Override
    public RecordReader<LongWritable, V> createRecordReader(InputSplit split, TaskAttemptContext context)
            throws IOException, InterruptedException {

        MySQLRecordReader recordReader = new MySQLRecordReader();
        recordReader.initialize(split, context);
        return recordReader;
    }


    /**
     * Opens a JDBC connection to the input database.
     *
     * @param conf configuration holding the driver class, URL, user name and password
     * @return a newly opened connection
     * @throws Exception if the driver class cannot be loaded or the connection cannot be opened
     */
    private Connection getConnection(Configuration conf) throws Exception {
        // Load the driver class before asking DriverManager for a connection.
        Class.forName(conf.get(MYSQL_INPUT_DRIVER));

        return DriverManager.getConnection(
                conf.get(MYSQL_INPUT_URL),
                conf.get(MYSQL_INPUT_USERNAME),
                conf.get(MYSQL_INPUT_PASSWORD));
    }

    /**
     * Closes the given resource if it is non-null. A failure to close is
     * logged and not propagated.
     *
     * @param autoCloseable resource to close, may be {@code null}
     */
    private void closeAutoCloseable(AutoCloseable autoCloseable) {
        if (autoCloseable == null) {
            return;
        }
        try {
            autoCloseable.close();
        } catch (Exception closeFailure) {
            LOG.error("关闭失败"+closeFailure.getMessage());
        }
    }

    /**
     * Describes one slice of the MySQL input as a row range: {@code start} is
     * the offset of the first row and {@code end - start} is the row count.
     */
    public static class MySQLInputSplit extends InputSplit implements Writable {
        // The data does not live in HDFS, so no host locations are reported.
        private String[] locations = new String[0];
        // Offset of the first row of this split.
        private long start;
        // Offset just past the last row of this split.
        private long end;

        /** No-arg constructor; fields are then filled in by {@link #readFields} or the setters. */
        public MySQLInputSplit() {
        }

        /**
         * @param start offset of the first row
         * @param end offset just past the last row
         */
        public MySQLInputSplit(long start, long end) {
            this.start = start;
            this.end = end;
        }

        /** Serializes the range as two longs: start, then end. */
        @Override
        public void write(DataOutput out) throws IOException {
            out.writeLong(start);
            out.writeLong(end);
        }

        /** Restores the range in the order written by {@link #write}. */
        @Override
        public void readFields(DataInput in) throws IOException {
            start = in.readLong();
            end = in.readLong();
        }

        /** Returns the number of rows covered by this split. */
        @Override
        public long getLength() throws IOException, InterruptedException {
            return end - start;
        }

        /** Returns the (empty) location array, so no locality hint is given. */
        @Override
        public String[] getLocations() throws IOException, InterruptedException {
            return locations;
        }

        public long getStart() {
            return this.start;
        }

        public void setStart(long start) {
            this.start = start;
        }

        public long getEnd() {
            return this.end;
        }

        public void setEnd(long end) {
            this.end = end;
        }
    }

    /**
     * Record reader that streams the rows of one {@link MySQLInputSplit}.
     *
     * <p>The connection and query are opened lazily on the first call to
     * {@link #nextKeyValue()}. The key is the zero-based index of the row
     * within the split. The same value object is reused for every row.
     */
    public class MySQLRecordReader extends RecordReader<LongWritable, V> {

        private Connection conn;
        // Kept as a field so it can be closed together with the result set.
        private Statement stmt;
        private ResultSet rs = null;
        private Configuration conf;
        private MySQLInputSplit split;
        private LongWritable key = null;
        private V value = null;
        // Number of rows returned so far; drives both the key and the progress.
        private long position = 0;

        /**
         * Remembers the split and configuration; no database access happens here.
         */
        @Override
        public void initialize(InputSplit split, TaskAttemptContext context)
                throws IOException, InterruptedException {
            this.split = (MySQLInputSplit) split;
            this.conf = context.getConfiguration();
        }

        /**
         * Instantiates the value class named by {@link #MYSQL_OUTPUT_VALUE_CLASS}
         * via reflection, falling back to {@link MySQLNullWritable}.
         *
         * @return a new, empty value object
         */
        @SuppressWarnings("unchecked")
        private V createValue() {
            Class<? extends MySQLInputWritable> clazz = this.conf.getClass(MYSQL_OUTPUT_VALUE_CLASS,
                    MySQLNullWritable.class, MySQLInputWritable.class);
            // The configured class is trusted to match V; this cannot be checked at runtime.
            return (V) ReflectionUtils.newInstance(clazz, this.conf);
        }

        /**
         * Builds the query for this split by appending LIMIT/OFFSET to the
         * configured select statement.
         *
         * <p>NOTE(review): without an ORDER BY in the configured statement the
         * row order across separate queries is presumably not guaranteed, so
         * splits may overlap or miss rows - confirm against the query in use.
         *
         * @return the SQL text to execute
         */
        private String getQuerySql() {
            // Compute the length from the bounds directly. getLength() declares
            // checked exceptions, and swallowing one here would silently drop
            // the LIMIT clause.
            long length = this.split.getEnd() - this.split.getStart();
            return conf.get(MYSQL_INPUT_SELECT_RECORD_SQL)
                    + " LIMIT " + length
                    + " OFFSET " + this.split.getStart();
        }

        /**
         * Advances to the next row of the split.
         *
         * @return {@code true} if a row was read into the current key/value,
         *         {@code false} when the split is exhausted
         * @throws IOException if the connection cannot be opened or the query fails
         */
        @Override
        public boolean nextKeyValue()
                throws IOException, InterruptedException {
            if (this.key == null) {
                this.key = new LongWritable();
            }
            if (this.value == null) {
                this.value = createValue();
            }
            if (this.conn == null) {
                try {
                    this.conn = MySQLInputFormat.this.getConnection(this.conf);
                } catch (Exception e) {
                    throw new IOException("获取数据库连接失败", e);
                }
            }

            try {
                if (this.rs == null) {
                    this.stmt = this.conn.createStatement();
                    this.rs = this.stmt.executeQuery(this.getQuerySql());
                }

                if (!this.rs.next()) {
                    return false; // no more rows in this split
                }

                this.value.readFieldsFromResultSet(this.rs);
                this.key.set(this.position);
                this.position++;
                return true;
            } catch (SQLException e) {
                throw new IOException("获取数据失败", e);
            }
        }

        @Override
        public LongWritable getCurrentKey()
                throws IOException, InterruptedException {

            return this.key;
        }

        @Override
        public V getCurrentValue()
                throws IOException, InterruptedException {

            return this.value;
        }

        /**
         * Returns the fraction of the split consumed so far, in {@code [0, 1]}.
         * An empty split has nothing to read and reports {@code 1}.
         */
        @Override
        public float getProgress()
                throws IOException, InterruptedException {

            long length = this.split.getEnd() - this.split.getStart();
            if (length <= 0) {
                return 1.0f;
            }
            // Floating-point division: long / long would only ever yield 0 or 1.
            return Math.min(1.0f, this.position / (float) length);
        }

        /**
         * Releases the JDBC resources in reverse order of acquisition.
         */
        @Override
        public void close() throws IOException {
            MySQLInputFormat.this.closeAutoCloseable(this.rs);
            MySQLInputFormat.this.closeAutoCloseable(this.stmt);
            MySQLInputFormat.this.closeAutoCloseable(this.conn);
        }
    }

    /**
     * 空数据类型
     */
    public class MySQLNullWritable extends MySQLInputWritable {

        @Override
        public void write(DataOutput out) throws IOException {

        }

        @Override
        public void readFields(DataInput in) throws IOException {

        }

        @Override
        public void readFieldsFromResultSet(ResultSet rs) throws SQLException {

        }
    }
}

二、输出端

1. 数据输出抽象类

/**
 * Base type for values written to MySQL: a {@code Writable} that supplies
 * its own insert/update statement and binds its fields to it.
 */
public abstract class MySQLOutputWritable implements Writable {

    /**
     * Returns the insert or update SQL statement for this record.
     *
     * @return the SQL text, used to create the {@link PreparedStatement}
     */
    public abstract String fetchInsertOrUpdateSql();

    /**
     * Binds this record's fields to the parameters of the prepared statement.
     *
     * @param pstmt statement to set the parameters on
     * @throws SQLException if setting a parameter fails
     */
    public abstract void setPreparedStatementParameters(PreparedStatement pstmt) throws SQLException;
}

2. 自定义MySQL输出类

public class MySQLOutputFormat<V extends MySQLOutputWritable> extends OutputFormat<NullWritable, V> {

    private static final Logger LOG = Logger.getLogger(MySQLOutputFormat.class);

    /**
     * Config key - JDBC driver class of the output database.
     * NOTE(review): the key text is spelled "dirver"; it only works because callers
     * go through this constant. Confirm no external config relies on it before fixing.
     */
    public static final String MYSQL_OUTPUT_DRIVER = "mysql.output.dirver";
    /** Config key - JDBC URL of the output database. */
    public static final String MYSQL_OUTPUT_URL = "mysql.output.url";
    /** Config key - user name for the output database. */
    public static final String MYSQL_OUTPUT_USERNAME = "mysql.output.username";
    /** Config key - password for the output database. */
    public static final String MYSQL_OUTPUT_PASSWORD = "mysql.output.password";
    /** Config key - number of records per batch commit. */
    public static final String MYSQL_OUTPUT_BATCH_SIZE = "mysql.output.batch.size";

    /**
     * Creates the record writer, handing it the task's configuration.
     */
    @Override
    public RecordWriter<NullWritable, V> getRecordWriter(TaskAttemptContext context)
            throws IOException, InterruptedException {

        Configuration taskConf = context.getConfiguration();
        return new MySQLRecordWriter(taskConf);
    }

    /**
     * Validates the output target by opening a connection to the output
     * database and closing it again.
     *
     * @throws IOException if the connection cannot be opened
     */
    @Override
    public void checkOutputSpecs(JobContext context)
            throws IOException, InterruptedException {
        Connection probe;
        try {
            probe = this.getConnection(context.getConfiguration());
        } catch (Exception e) {
            throw new IOException("连接数据库失败", e);
        }
        // The connection was only needed as a reachability check.
        this.closeAutoCloseable(probe);
    }

    /**
     * Returns a {@code FileOutputCommitter} created with a {@code null}
     * output path and the given task context.
     */
    @Override
    public OutputCommitter getOutputCommitter(TaskAttemptContext context)
            throws IOException, InterruptedException {

        OutputCommitter committer = new FileOutputCommitter(null, context);
        return committer;
    }


    /**
     * Record writer that inserts values into MySQL in batches.
     *
     * <p>Each distinct SQL text returned by the values gets its own cached
     * {@link PreparedStatement} and pending-row counter. A statement's batch
     * is executed and committed once its counter reaches the configured batch
     * size. {@link #close} flushes what is left and releases all JDBC
     * resources, even if that final flush fails.
     */
    public class MySQLRecordWriter extends RecordWriter<NullWritable, V> {
        private Configuration conf;
        private Connection conn;
        // Prepared statements keyed by their SQL text.
        private Map<String, PreparedStatement> pstmtCache = new HashMap<String, PreparedStatement>();
        // Number of rows added to each statement's batch since its last flush.
        private Map<String, Integer> batchCache = new HashMap<String, Integer>();
        private int batchSize = 100; // rows per batch commit

        public MySQLRecordWriter() {}

        /**
         * @param conf configuration providing the connection settings and,
         *             optionally, {@link #MYSQL_OUTPUT_BATCH_SIZE}
         */
        public MySQLRecordWriter(Configuration conf) {
            this.conf = conf;
            this.batchSize = conf.getInt(MYSQL_OUTPUT_BATCH_SIZE, this.batchSize);
        }

        /**
         * Adds one value to the batch of its statement, flushing and
         * committing that batch when it reaches the batch size.
         *
         * @param key ignored
         * @param value record supplying the SQL text and the parameter values
         * @throws IOException if connecting, preparing or writing fails
         */
        @Override
        public void write(NullWritable key, V value) throws IOException, InterruptedException {

            if (this.conn == null) {
                try {
                    this.conn = MySQLOutputFormat.this.getConnection(this.conf);
                    this.conn.setAutoCommit(false); // commits are issued per batch
                } catch (Exception e) {
                    throw new IOException("连接数据库失败", e);
                }
            }

            String sql = value.fetchInsertOrUpdateSql();
            PreparedStatement pstmt = this.pstmtCache.get(sql);
            if (pstmt == null) {
                try {
                    // Prepare with the same text used as the cache key.
                    pstmt = conn.prepareStatement(sql);
                    this.pstmtCache.put(sql, pstmt);
                } catch (SQLException e) {
                    throw new IOException("创建PreparedStatement对象产生异常", e);
                }
            }

            Integer count = this.batchCache.get(sql);
            if (count == null) {
                count = 0;
            }

            try {
                value.setPreparedStatementParameters(pstmt);
                pstmt.addBatch();
                count++;
                if (count >= this.batchSize) {
                    pstmt.executeBatch();
                    this.conn.commit();
                    count = 0;
                }
                this.batchCache.put(sql, count);
            } catch (SQLException e) {
                throw new IOException("向数据库写入数据出现异常", e);
            }
        }

        /**
         * Flushes every cached statement once more so rows below the batch
         * size are not lost, then closes the statements and the connection.
         *
         * @throws IOException if the final flush or commit fails
         */
        @Override
        public void close(TaskAttemptContext context) throws IOException, InterruptedException {

            try {
                for (PreparedStatement pstmt : pstmtCache.values()) {
                    pstmt.executeBatch();
                    this.conn.commit();
                }
            } catch (SQLException e) {
                throw new IOException("向数据库写入数据出现异常", e);
            } finally {
                // Always release resources, including when the flush above failed.
                for (PreparedStatement pstmt : pstmtCache.values()) {
                    MySQLOutputFormat.this.closeAutoCloseable(pstmt);
                }
                pstmtCache.clear();
                batchCache.clear();
                MySQLOutputFormat.this.closeAutoCloseable(this.conn);
            }
        }
    }

    /**
     * Opens a JDBC connection to the output database.
     *
     * @param conf configuration holding the driver class, URL, user name and password
     * @return a newly opened connection
     * @throws Exception if the driver class cannot be loaded or the connection cannot be opened
     */
    private Connection getConnection(Configuration conf) throws Exception {
        // Load the driver class before asking DriverManager for a connection.
        Class.forName(conf.get(MYSQL_OUTPUT_DRIVER));

        return DriverManager.getConnection(
                conf.get(MYSQL_OUTPUT_URL),
                conf.get(MYSQL_OUTPUT_USERNAME),
                conf.get(MYSQL_OUTPUT_PASSWORD));
    }

    /**
     * Closes the given resource if it is non-null. A failure to close is
     * logged and not propagated.
     *
     * @param autoCloseable resource to close, may be {@code null}
     */
    private void closeAutoCloseable(AutoCloseable autoCloseable) {
        if (autoCloseable == null) {
            return;
        }
        try {
            autoCloseable.close();
        } catch (Exception closeFailure) {
            LOG.error("关闭失败"+closeFailure.getMessage());
        }
    }
}

三、测试例

1. 目的

统计某一URL单日用户访问量

2. 数据库表结构

  • 数据输入表(event_logs)

| 字段名 | 字段类型 | 字段说明 |
| --- | --- | --- |
| uid | varchar | 用户id |
| sid | varchar | 会话id |
| url | varchar | URL |
| time | decimal | 时间戳 |

  • 数据输出表(stats_uv)

| 字段名 | 字段类型 | 字段说明 |
| --- | --- | --- |
| url | varchar | URL |
| date | date | 日期 |
| uv | int | 用户访问量 |

3. 编写测试例

3.1 Map 输入Value类

/**
 * Map input value holding one row of the event log: user id, session id,
 * URL and timestamp.
 *
 * <p>NOTE(review): {@link #write} passes the string fields straight to
 * {@code writeUTF}, which presumably fails on {@code null}; a NULL column
 * value would then break serialization - confirm whether the columns are
 * nullable.
 */
public class MySQLMapperInputValue extends MySQLInputWritable {

    private String uid;
    private String sid;
    private String url;
    private long time;

    /** Fills the fields from the columns {@code uid}, {@code sid}, {@code url} and {@code time}. */
    @Override
    public void readFieldsFromResultSet(ResultSet rs) throws SQLException {
        uid = rs.getString("uid");
        sid = rs.getString("sid");
        url = rs.getString("url");
        time = rs.getLong("time");
    }

    /** Serializes the fields in the order uid, sid, url, time. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(this.uid);
        out.writeUTF(this.sid);
        out.writeUTF(this.url);
        out.writeLong(this.time);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        uid = in.readUTF();
        sid = in.readUTF();
        url = in.readUTF();
        time = in.readLong();
    }

    public String getUid() {
        return this.uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    public String getSid() {
        return this.sid;
    }

    public void setSid(String sid) {
        this.sid = sid;
    }

    public String getUrl() {
        return this.url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public long getTime() {
        return this.time;
    }

    public void setTime(long time) {
        this.time = time;
    }
}

3.2 Map 输出Key

/**
 * Map output key made of a URL and a date string. Keys order by URL first
 * and by date second.
 */
public class MySQLMapperOutputKey implements WritableComparable<MySQLMapperOutputKey>  {

    private String url;
    private String date;

    /** Serializes the url followed by the date. */
    @Override
    public void write(DataOutput out) throws IOException {
        out.writeUTF(this.url);
        out.writeUTF(this.date);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        url = in.readUTF();
        date = in.readUTF();
    }

    /** Compares by url, and by date only when the urls are equal. */
    @Override
    public int compareTo(MySQLMapperOutputKey o) {
        int byUrl = this.url.compareTo(o.url);
        if (byUrl != 0) {
            return byUrl;
        }
        return this.date.compareTo(o.date);
    }

    public String getUrl() {
        return this.url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public String getDate() {
        return this.date;
    }

    public void setDate(String date) {
        this.date = date;
    }

    /** Hash over date and url; a {@code null} field contributes 0. */
    @Override
    public int hashCode() {
        int dateHash = (date == null) ? 0 : date.hashCode();
        int urlHash = (url == null) ? 0 : url.hashCode();
        return 31 * (31 + dateHash) + urlHash;
    }

    /** Two keys are equal when they have the same class, date and url; {@code null} fields match only {@code null}. */
    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MySQLMapperOutputKey that = (MySQLMapperOutputKey) obj;
        boolean sameDate = (date == null) ? that.date == null : date.equals(that.date);
        boolean sameUrl = (url == null) ? that.url == null : url.equals(that.url);
        return sameDate && sameUrl;
    }
}

3.3 Map 输出Value

/**
 * Map output value carrying the user id and session id of one event.
 * Both fields may be {@code null}; each is serialized behind a presence flag.
 */
public class MySQLMapperOutputValue implements Writable {

    private String uid;
    private String sid;

    /** Writes a presence flag, followed by the string itself when it is non-null. */
    private static void writeNullable(DataOutput out, String text) throws IOException {
        boolean present = text != null;
        out.writeBoolean(present);
        if (present) {
            out.writeUTF(text);
        }
    }

    /** Reads a string written by {@link #writeNullable}, yielding {@code null} when the flag is false. */
    private static String readNullable(DataInput in) throws IOException {
        if (in.readBoolean()) {
            return in.readUTF();
        }
        return null;
    }

    /** Serializes uid, then sid. */
    @Override
    public void write(DataOutput out) throws IOException {
        writeNullable(out, this.uid);
        writeNullable(out, this.sid);
    }

    /** Restores uid, then sid. */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.uid = readNullable(in);
        this.sid = readNullable(in);
    }

    public String getUid() {
        return this.uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    public String getSid() {
        return this.sid;
    }

    public void setSid(String sid) {
        this.sid = sid;
    }

    @Override
    public String toString() {
        StringBuilder text = new StringBuilder();
        text.append("[uid=").append(uid).append(", sid=").append(sid).append("]");
        return text.toString();
    }
}

3.4 Map 任务

/**
 * Mapper that turns each event-log row into a (url, day) key and a
 * (uid, sid) value, so the reducer can count distinct users per URL and day.
 */
public class MySQLMapper extends Mapper<LongWritable, MySQLMapperInputValue, MySQLMapperOutputKey, MySQLMapperOutputValue> {

    // Both use the JVM's default time zone and locale.
    private final Calendar calendar = Calendar.getInstance();
    private final SimpleDateFormat sdf = new SimpleDateFormat("yyyy-MM-dd");

    /**
     * Emits one key/value pair per input row.
     *
     * @param key row index supplied by the record reader (unused)
     * @param value the event-log row; its time is interpreted as epoch milliseconds
     * @param context sink for the output pair
     */
    @Override
    protected void map(LongWritable key, MySQLMapperInputValue value, Context context)
            throws IOException, InterruptedException {

        // Per-record objects are locals: they carry no state between calls.
        MySQLMapperOutputKey outputKey = new MySQLMapperOutputKey();
        outputKey.setUrl(value.getUrl());
        calendar.setTimeInMillis(value.getTime());
        outputKey.setDate(sdf.format(calendar.getTime()));

        MySQLMapperOutputValue outputValue = new MySQLMapperOutputValue();
        outputValue.setUid(value.getUid());
        outputValue.setSid(value.getSid());

        context.write(outputKey, outputValue);
    }
}

3.5 Reduce 输出Value

/**
 * Reduce output value: one row of the {@code stats_uv} table, consisting of
 * url, date and the user count.
 */
public class MySQLReducerOutputValue extends MySQLOutputWritable {

    private String url;
    private String date;
    private int uv;

    /** Writes a presence flag, followed by the string itself when it is non-null. */
    private static void writeNullable(DataOutput out, String text) throws IOException {
        boolean present = text != null;
        out.writeBoolean(present);
        if (present) {
            out.writeUTF(text);
        }
    }

    /** Reads a string written by {@link #writeNullable}, yielding {@code null} when the flag is false. */
    private static String readNullable(DataInput in) throws IOException {
        if (in.readBoolean()) {
            return in.readUTF();
        }
        return null;
    }

    /** Returns the insert statement for one {@code stats_uv} row. */
    @Override
    public String fetchInsertOrUpdateSql() {
        final String insertSql = "insert into stats_uv(url, date, uv) values(?, ?, ?)";
        return insertSql;
    }

    /** Binds url, date and uv to parameters 1, 2 and 3. */
    @Override
    public void setPreparedStatementParameters(PreparedStatement pstmt) throws SQLException {
        pstmt.setString(1, url);
        pstmt.setString(2, date);
        pstmt.setInt(3, uv);
    }

    /** Serializes url and date as nullable strings, then uv. */
    @Override
    public void write(DataOutput out) throws IOException {
        writeNullable(out, this.url);
        writeNullable(out, this.date);
        out.writeInt(this.uv);
    }

    /** Restores the fields in the order written by {@link #write}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.url = readNullable(in);
        this.date = readNullable(in);
        this.uv = in.readInt();
    }

    public String getUrl() {
        return this.url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public String getDate() {
        return this.date;
    }

    public void setDate(String date) {
        this.date = date;
    }

    public int getUv() {
        return this.uv;
    }

    public void setUv(int uv) {
        this.uv = uv;
    }
}

3.6 Reduce 任务

/**
 * Reducer that counts the distinct user ids seen for one (url, date) key
 * and emits them as a single output row.
 */
public class MySQLReducer extends Reducer<MySQLMapperOutputKey, MySQLMapperOutputValue, NullWritable, MySQLReducerOutputValue> {

    private NullWritable outputKey = NullWritable.get();
    private MySQLReducerOutputValue outputValue;


    /**
     * Emits one row with the key's url and date and the number of distinct
     * uids among the values.
     */
    @Override
    protected void reduce(MySQLMapperOutputKey key, Iterable<MySQLMapperOutputValue> values, Context context)
            throws IOException, InterruptedException {
        // Copy the key fields before walking the values.
        MySQLReducerOutputValue row = new MySQLReducerOutputValue();
        this.outputValue = row;
        row.setUrl(key.getUrl());
        row.setDate(key.getDate());

        Set<String> distinctUids = new HashSet<String>();
        for (MySQLMapperOutputValue visit : values) {
            distinctUids.add(visit.getUid());
        }

        row.setUv(distinctUids.size());
        context.write(this.outputKey, row);
    }
}

3.7 Runner

/**
 * Driver for the sample job: reads event logs from MySQL, counts distinct
 * users per URL and day, and writes the result back to MySQL.
 */
public class MySQLRunner implements Tool {

    private Configuration conf = new Configuration();

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return this.conf;
    }

    /**
     * Configures and runs the job.
     *
     * @param args command-line arguments (unused)
     * @return 0 if the job completed successfully, 1 otherwise
     */
    @Override
    public int run(String[] args) throws Exception {

        Job job = Job.getInstance(this.conf, "test-custom-format");

        // From here on, settings go to the configuration owned by the job.
        conf = job.getConfiguration();

        job.setJarByClass(MySQLRunner.class);

        // inputFormat
        // NOTE(review): connection settings and credentials are hard-coded
        // sample values; replace them before using this outside a test setup.
        conf.set(MySQLInputFormat.MYSQL_INPUT_DRIVER, "com.mysql.jdbc.Driver");
        conf.set(MySQLInputFormat.MYSQL_INPUT_URL, "jdbc:mysql://192.168.100.1:3306/data");
        conf.set(MySQLInputFormat.MYSQL_INPUT_USERNAME, "username");
        conf.set(MySQLInputFormat.MYSQL_INPUT_PASSWORD, "password");
        conf.set(MySQLInputFormat.MYSQL_INPUT_SELECT_COUNT_SQL, "select count(*) from event_logs");
        conf.set(MySQLInputFormat.MYSQL_INPUT_SELECT_RECORD_SQL, "select uid,sid,url,time from event_logs");
        conf.setLong(MySQLInputFormat.MYSQL_INPUT_SPLIT_PRE_SIZE, 10L);
        conf.setClass(MySQLInputFormat.MYSQL_OUTPUT_VALUE_CLASS, MySQLMapperInputValue.class, MySQLInputWritable.class);
        job.setInputFormatClass(MySQLInputFormat.class);

        // map
        job.setMapperClass(MySQLMapper.class);
        job.setMapOutputKeyClass(MySQLMapperOutputKey.class);
        job.setMapOutputValueClass(MySQLMapperOutputValue.class);

        // reduce
        job.setReducerClass(MySQLReducer.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(MySQLReducerOutputValue.class);

        // outputFormat
        conf.set(MySQLOutputFormat.MYSQL_OUTPUT_DRIVER, "com.mysql.jdbc.Driver");
        conf.set(MySQLOutputFormat.MYSQL_OUTPUT_URL, "jdbc:mysql://192.168.100.1:3306/data");
        conf.set(MySQLOutputFormat.MYSQL_OUTPUT_USERNAME, "username");
        conf.set(MySQLOutputFormat.MYSQL_OUTPUT_PASSWORD, "password");
        conf.setInt(MySQLOutputFormat.MYSQL_OUTPUT_BATCH_SIZE, 2);
        job.setOutputFormatClass(MySQLOutputFormat.class);

        return job.waitForCompletion(true) ? 0 : 1;
    }

    /**
     * Entry point. Propagates the job's result as the process exit code so
     * a failed job is visible to the calling shell or scheduler.
     */
    public static void main(String[] args) throws Exception {
        int exitCode = ToolRunner.run(new MySQLRunner(), args);
        System.exit(exitCode);
    }
}

4. 运行结果

  • 输入结果
    自定义输入输出-数据库输入

  • 输出结果
    自定义输入输出-数据库输出

猜你喜欢

转载自blog.csdn.net/goldlone/article/details/82431734
今日推荐