Hadoop 自定义输入输出
本文以 MySQL 同时作为输入源和输出目标,演示如何自定义 Hadoop 的输入输出格式。
一、输入端
自定义的输入需要继承 `InputFormat`,并实现数据分片(`getSplits()`)和创建记录读取对象(`createRecordReader()`)两个方法。
1. 数据读取抽象类
/**
 * Base type for map input values that can be filled directly from a JDBC row.
 *
 * <p>Subclasses implement the usual {@code Writable} serialization plus
 * {@link #readFieldsFromResultSet(ResultSet)}, which copies the columns of the current row into
 * the object's fields.
 */
public abstract class MySQLInputWritable implements Writable {

    /**
     * Copies the columns of the row the result set is currently positioned on into this object.
     *
     * @param rs result set positioned on the row to read
     * @throws SQLException if a column cannot be read
     */
    public abstract void readFieldsFromResultSet(ResultSet rs) throws SQLException;
}
2. 自定义MySQL输入类
/**
 * {@link InputFormat} that reads rows from MySQL through JDBC.
 *
 * <p>The total number of rows is obtained with {@link #MYSQL_INPUT_SELECT_COUNT_SQL}. The rows are
 * then divided into splits of {@link #MYSQL_INPUT_SPLIT_PRE_SIZE} records each. Every split reads its
 * rows by appending {@code LIMIT n OFFSET m} to {@link #MYSQL_INPUT_SELECT_RECORD_SQL}.
 *
 * @param <V> map input value type, instantiated from {@link #MYSQL_OUTPUT_VALUE_CLASS}
 */
public class MySQLInputFormat<V extends MySQLInputWritable> extends InputFormat<LongWritable, V> {
    private static final Logger LOG = Logger.getLogger(MySQLInputFormat.class);
    /** Config key: JDBC driver class of the input database. */
    public static final String MYSQL_INPUT_DRIVER = "mysql.input.driver";
    /** Config key: JDBC URL of the input database. */
    public static final String MYSQL_INPUT_URL = "mysql.input.url";
    /** Config key: user name for the input database. */
    public static final String MYSQL_INPUT_USERNAME = "mysql.input.username";
    /** Config key: password for the input database. */
    public static final String MYSQL_INPUT_PASSWORD = "mysql.input.password";
    /** Config key: SQL whose first column of the first row is the total record count. */
    public static final String MYSQL_INPUT_SELECT_COUNT_SQL = "mysql.input.select.count";
    /**
     * Config key: SQL that selects the records.
     *
     * <p>{@code LIMIT}/{@code OFFSET} is appended per split, so the statement must not contain them
     * itself. It should also carry an {@code ORDER BY} on a unique key. Without one, MySQL does not
     * guarantee a stable row order, and splits may overlap or miss rows.
     */
    public static final String MYSQL_INPUT_SELECT_RECORD_SQL = "mysql.input.select.record";
    /** Config key: number of records per split (defaults to 100). */
    public static final String MYSQL_INPUT_SPLIT_PRE_SIZE = "mysql.input.split.pre.size";
    /**
     * Config key: concrete {@link MySQLInputWritable} class used as the map input value.
     *
     * <p>Despite "output" in its name, it configures the input side. The name is kept for
     * compatibility.
     */
    public static final String MYSQL_OUTPUT_VALUE_CLASS = "mysql.output.value.class";

    /** Records per split when {@link #MYSQL_INPUT_SPLIT_PRE_SIZE} is not configured. */
    private static final long DEFAULT_SPLIT_PRE_SIZE = 100L;

    /**
     * Computes the input splits, which determines the number of map tasks.
     *
     * <p>Each split covers the half-open record range {@code [start, end)}. Every split except the
     * last holds exactly {@code MYSQL_INPUT_SPLIT_PRE_SIZE} records. An empty result yields no
     * splits.
     *
     * @throws IOException if the record count cannot be queried or the split size is not positive
     */
    @Override
    public List<InputSplit> getSplits(JobContext context)
            throws IOException, InterruptedException {
        Configuration conf = context.getConfiguration();
        long recordCount = this.queryRecordCount(conf);
        long preSplitCount = conf.getLong(MYSQL_INPUT_SPLIT_PRE_SIZE, DEFAULT_SPLIT_PRE_SIZE);
        if (preSplitCount <= 0) {
            throw new IOException("分片大小必须为正数: " + MYSQL_INPUT_SPLIT_PRE_SIZE + "=" + preSplitCount);
        }
        List<InputSplit> splits = new ArrayList<InputSplit>();
        long start = 0;
        while (start < recordCount) {
            // Compare the remaining count instead of computing start + preSplitCount first,
            // so the bound cannot overflow.
            long end = (recordCount - start > preSplitCount) ? start + preSplitCount : recordCount;
            splits.add(new MySQLInputSplit(start, end));
            start = end;
        }
        return splits;
    }

    /**
     * Runs the configured count query and returns its first column, or 0 if it returns no rows.
     */
    private long queryRecordCount(Configuration conf) throws IOException {
        Connection conn = null;
        Statement stmt = null;
        ResultSet rs = null;
        try {
            conn = this.getConnection(conf);
            stmt = conn.createStatement();
            rs = stmt.executeQuery(conf.get(MYSQL_INPUT_SELECT_COUNT_SQL));
            return rs.next() ? rs.getLong(1) : 0L;
        } catch (Exception e) {
            throw new IOException("查询数据总量失败", e);
        } finally {
            // Release in reverse order of acquisition.
            this.closeAutoCloseable(rs);
            this.closeAutoCloseable(stmt);
            this.closeAutoCloseable(conn);
        }
    }

    /**
     * Creates the record reader for one split and initializes it with that split.
     *
     * <p>NOTE(review): the MapReduce framework normally calls {@code initialize()} itself as well.
     * Calling it twice is harmless here because {@code initialize()} only stores references.
     */
    @Override
    public RecordReader<LongWritable, V> createRecordReader(InputSplit split, TaskAttemptContext context)
            throws IOException, InterruptedException {
        RecordReader<LongWritable, V> reader = new MySQLRecordReader();
        reader.initialize(split, context);
        return reader;
    }

    /**
     * Opens a JDBC connection using the input-side driver, URL and credentials from {@code conf}.
     *
     * @throws Exception if the driver class cannot be loaded or the connection fails
     */
    private Connection getConnection(Configuration conf) throws Exception {
        String driver = conf.get(MYSQL_INPUT_DRIVER);
        String url = conf.get(MYSQL_INPUT_URL);
        String username = conf.get(MYSQL_INPUT_USERNAME);
        String password = conf.get(MYSQL_INPUT_PASSWORD);
        Class.forName(driver);
        return DriverManager.getConnection(url, username, password);
    }

    /**
     * Closes {@code autoCloseable} on a best-effort basis.
     *
     * <p>Null is ignored. A failure is logged with its stack trace and never rethrown.
     */
    private void closeAutoCloseable(AutoCloseable autoCloseable) {
        try {
            if (autoCloseable != null) {
                autoCloseable.close();
            }
        } catch (Exception e) {
            LOG.error("关闭失败" + e.getMessage(), e);
        }
    }

    /**
     * Split descriptor: the half-open record range {@code [start, end)} of the select query.
     */
    public static class MySQLInputSplit extends InputSplit implements Writable {
        // The data is not stored in HDFS, so there are no locality hints.
        private String[] locations = new String[0];
        // First record index (inclusive).
        private long start;
        // Last record index (exclusive).
        private long end;

        /** No-arg constructor required for deserialization. */
        public MySQLInputSplit() {
        }

        public MySQLInputSplit(long start, long end) {
            this.start = start;
            this.end = end;
        }

        /** Number of records in this split. */
        @Override
        public long getLength() throws IOException, InterruptedException {
            return this.end - this.start;
        }

        /** Always empty: the scheduler uses this to decide on data locality. */
        @Override
        public String[] getLocations() throws IOException, InterruptedException {
            return this.locations;
        }

        public long getStart() {
            return start;
        }

        public void setStart(long start) {
            this.start = start;
        }

        public long getEnd() {
            return end;
        }

        public void setEnd(long end) {
            this.end = end;
        }

        @Override
        public void write(DataOutput out) throws IOException {
            out.writeLong(this.start);
            out.writeLong(this.end);
        }

        @Override
        public void readFields(DataInput in) throws IOException {
            this.start = in.readLong();
            this.end = in.readLong();
        }
    }

    /**
     * Reads the records of one {@link MySQLInputSplit}.
     *
     * <p>The key is the 0-based position of the record within the split.
     */
    public class MySQLRecordReader extends RecordReader<LongWritable, V> {
        private Connection conn;
        private Statement stmt = null;
        private ResultSet rs = null;
        private Configuration conf;
        private MySQLInputSplit split;
        private LongWritable key = null;
        private V value = null;
        // Number of records returned so far; used as key and for progress.
        private long position = 0;

        @Override
        public void initialize(InputSplit split, TaskAttemptContext context)
                throws IOException, InterruptedException {
            this.split = (MySQLInputSplit) split;
            this.conf = context.getConfiguration();
        }

        /**
         * Instantiates the configured value class by reflection.
         *
         * <p>Falls back to {@link MySQLNullWritable} when no value class is configured.
         */
        @SuppressWarnings("unchecked") // the configured class is checked against MySQLInputWritable only
        private V createValue() {
            Class<? extends MySQLInputWritable> clazz = this.conf.getClass(MYSQL_OUTPUT_VALUE_CLASS,
                    MySQLNullWritable.class, MySQLInputWritable.class);
            return (V) ReflectionUtils.newInstance(clazz, this.conf);
        }

        /** Appends this split's LIMIT/OFFSET window to the configured select statement. */
        private String getQuerySql() {
            long offset = this.split.getStart();
            long limit = this.split.getEnd() - offset;
            return conf.get(MYSQL_INPUT_SELECT_RECORD_SQL) + " LIMIT " + limit + " OFFSET " + offset;
        }

        /**
         * Advances to the next row, opening the connection and running the query lazily on first use.
         */
        @Override
        public boolean nextKeyValue()
                throws IOException, InterruptedException {
            if (this.key == null) {
                this.key = new LongWritable();
            }
            if (this.value == null) {
                this.value = createValue();
            }
            if (this.conn == null) {
                try {
                    this.conn = MySQLInputFormat.this.getConnection(this.conf);
                } catch (Exception e) {
                    throw new IOException("获取数据库连接失败", e);
                }
            }
            try {
                if (this.rs == null) {
                    // Keep the statement in a field so close() can release it.
                    this.stmt = this.conn.createStatement();
                    this.rs = this.stmt.executeQuery(this.getQuerySql());
                }
                if (!this.rs.next()) {
                    return false; // no more rows
                }
                this.value.readFieldsFromResultSet(this.rs);
                this.key.set(this.position);
                this.position++;
                return true;
            } catch (SQLException e) {
                throw new IOException("获取数据失败", e);
            }
        }

        @Override
        public LongWritable getCurrentKey()
                throws IOException, InterruptedException {
            return this.key;
        }

        @Override
        public V getCurrentValue()
                throws IOException, InterruptedException {
            return this.value;
        }

        /** Fraction of the split consumed, in [0, 1]; an empty split reports 1. */
        @Override
        public float getProgress()
                throws IOException, InterruptedException {
            long length = this.split.getEnd() - this.split.getStart();
            if (length <= 0) {
                return 1.0f;
            }
            // Floating-point division; long/long would truncate to 0 until the very end.
            return Math.min(1.0f, (float) this.position / length);
        }

        /** Releases the result set, statement and connection, in reverse order of acquisition. */
        @Override
        public void close() throws IOException {
            MySQLInputFormat.this.closeAutoCloseable(this.rs);
            MySQLInputFormat.this.closeAutoCloseable(this.stmt);
            MySQLInputFormat.this.closeAutoCloseable(this.conn);
            this.rs = null;
            this.stmt = null;
            this.conn = null;
        }
    }

    /**
     * Value type that ignores all columns.
     *
     * <p>It is the default when no value class is configured. It is {@code static} so that
     * {@code ReflectionUtils.newInstance} can call its no-arg constructor.
     */
    public static class MySQLNullWritable extends MySQLInputWritable {
        @Override
        public void write(DataOutput out) throws IOException {
        }

        @Override
        public void readFields(DataInput in) throws IOException {
        }

        @Override
        public void readFieldsFromResultSet(ResultSet rs) throws SQLException {
        }
    }
}
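二、输出端
自定义的输出需要继承 `OutputFormat`,并实现获取记录写入对象(`getRecordWriter()`)、检查输出空间(`checkOutputSpecs()`)和获取输出提交器(`getOutputCommitter()`)三个方法。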
1. 数据输出抽象类
/**
 * Base type for reduce output values that know how to write themselves to a database.
 *
 * <p>Each value names the SQL statement it should be written with and binds its own fields to
 * that statement's parameters.
 */
public abstract class MySQLOutputWritable implements Writable {

    /**
     * Binds this object's fields to the parameters of a statement prepared from
     * {@link #fetchInsertOrUpdateSql()}.
     *
     * @param pstmt the prepared statement to fill
     * @throws SQLException if a parameter cannot be set
     */
    public abstract void setPreparedStatementParameters(PreparedStatement pstmt) throws SQLException;

    /**
     * Returns the INSERT/UPDATE statement, with {@code ?} placeholders, used to write this value.
     *
     * @return the SQL text
     */
    public abstract String fetchInsertOrUpdateSql();
}
2. 自定义MySQL输出类
/**
 * {@link OutputFormat} that writes reduce output values into MySQL with JDBC batches.
 *
 * <p>Each value supplies its own SQL ({@link MySQLOutputWritable#fetchInsertOrUpdateSql()}) and
 * binds its own parameters. Records are batched per distinct SQL text and committed every
 * {@link #MYSQL_OUTPUT_BATCH_SIZE} records. Any remainder is flushed and committed when the writer
 * is closed.
 *
 * @param <V> reduce output value type
 */
public class MySQLOutputFormat<V extends MySQLOutputWritable> extends OutputFormat<NullWritable, V> {
    private static final Logger LOG = Logger.getLogger(MySQLOutputFormat.class);
    /** Config key: JDBC driver class of the output database. */
    public static final String MYSQL_OUTPUT_DRIVER = "mysql.output.driver";
    /** Misspelled driver key used by earlier versions; still read as a fallback. */
    private static final String LEGACY_MYSQL_OUTPUT_DRIVER = "mysql.output.dirver";
    /** Config key: JDBC URL of the output database. */
    public static final String MYSQL_OUTPUT_URL = "mysql.output.url";
    /** Config key: user name for the output database. */
    public static final String MYSQL_OUTPUT_USERNAME = "mysql.output.username";
    /** Config key: password for the output database. */
    public static final String MYSQL_OUTPUT_PASSWORD = "mysql.output.password";
    /** Config key: number of records per statement that triggers a batch execute + commit (default 100). */
    public static final String MYSQL_OUTPUT_BATCH_SIZE = "mysql.output.batch.size";

    /** Returns a writer that uses the task's configuration for the connection and batch size. */
    @Override
    public RecordWriter<NullWritable, V> getRecordWriter(TaskAttemptContext context)
            throws IOException, InterruptedException {
        return new MySQLRecordWriter(context.getConfiguration());
    }

    /**
     * Verifies that the output database is reachable by opening and immediately closing a connection.
     *
     * @throws IOException if the connection cannot be established
     */
    @Override
    public void checkOutputSpecs(JobContext context)
            throws IOException, InterruptedException {
        Connection conn = null;
        try {
            conn = this.getConnection(context.getConfiguration());
        } catch (Exception e) {
            throw new IOException("连接数据库失败", e);
        } finally {
            this.closeAutoCloseable(conn);
        }
    }

    /**
     * Returns a {@code FileOutputCommitter} with no output path.
     *
     * <p>Database writes are committed by {@link MySQLRecordWriter} itself.
     * NOTE(review): a dedicated no-op {@code OutputCommitter} would state this intent more clearly.
     */
    @Override
    public OutputCommitter getOutputCommitter(TaskAttemptContext context)
            throws IOException, InterruptedException {
        return new FileOutputCommitter(null, context);
    }

    /**
     * Batches records per SQL statement and writes them to MySQL.
     */
    public class MySQLRecordWriter extends RecordWriter<NullWritable, V> {
        private Configuration conf;
        private Connection conn;
        // One prepared statement per distinct SQL text.
        private Map<String, PreparedStatement> pstmtCache = new HashMap<String, PreparedStatement>();
        // Number of records added to each statement's batch since its last execute.
        private Map<String, Integer> batchCache = new HashMap<String, Integer>();
        private int batchSize = 100; // records per statement before executing the batch

        public MySQLRecordWriter() {}

        public MySQLRecordWriter(Configuration conf) {
            this.conf = conf;
            this.batchSize = conf.getInt(MYSQL_OUTPUT_BATCH_SIZE, this.batchSize);
        }

        /** Opens the connection on first use, with auto-commit disabled. */
        private Connection ensureConnection() throws IOException {
            if (this.conn == null) {
                Connection candidate = null;
                try {
                    candidate = MySQLOutputFormat.this.getConnection(this.conf);
                    candidate.setAutoCommit(false);
                } catch (Exception e) {
                    // Do not leak a connection whose auto-commit could not be disabled.
                    MySQLOutputFormat.this.closeAutoCloseable(candidate);
                    throw new IOException("连接数据库失败", e);
                }
                this.conn = candidate;
            }
            return this.conn;
        }

        /**
         * Adds {@code value} to its statement's batch.
         *
         * <p>Once {@code batchSize} records have accumulated for that statement, the batch is
         * executed and committed.
         */
        @Override
        public void write(NullWritable key, V value) throws IOException, InterruptedException {
            Connection connection = this.ensureConnection();
            String sql = value.fetchInsertOrUpdateSql();
            PreparedStatement pstmt = this.pstmtCache.get(sql);
            if (pstmt == null) {
                try {
                    pstmt = connection.prepareStatement(sql);
                } catch (SQLException e) {
                    throw new IOException("创建PreparedStatement对象产生异常", e);
                }
                this.pstmtCache.put(sql, pstmt);
            }
            Integer cached = this.batchCache.get(sql);
            int count = (cached == null) ? 0 : cached;
            try {
                value.setPreparedStatementParameters(pstmt);
                pstmt.addBatch();
                count++;
                if (count >= this.batchSize) {
                    pstmt.executeBatch();
                    connection.commit();
                    count = 0;
                }
                this.batchCache.put(sql, count);
            } catch (SQLException e) {
                throw new IOException("向数据库写入数据出现异常", e);
            }
        }

        /**
         * Flushes every statement's remaining batch and commits once.
         *
         * <p>The statements and the connection are always released, even when the flush fails.
         */
        @Override
        public void close(TaskAttemptContext context) throws IOException, InterruptedException {
            try {
                if (this.conn == null) {
                    return; // nothing was ever written
                }
                // Execute partially filled batches so records below batchSize are not lost.
                for (PreparedStatement pstmt : this.pstmtCache.values()) {
                    pstmt.executeBatch();
                }
                this.conn.commit();
            } catch (SQLException e) {
                throw new IOException("向数据库写入数据出现异常", e);
            } finally {
                for (PreparedStatement pstmt : this.pstmtCache.values()) {
                    MySQLOutputFormat.this.closeAutoCloseable(pstmt);
                }
                this.pstmtCache.clear();
                this.batchCache.clear();
                MySQLOutputFormat.this.closeAutoCloseable(this.conn);
                this.conn = null;
            }
        }
    }

    /**
     * Opens a JDBC connection using the output-side driver, URL and credentials from {@code conf}.
     *
     * <p>The driver is read from {@link #MYSQL_OUTPUT_DRIVER}, falling back to the legacy misspelled
     * key.
     *
     * @throws Exception if the driver class cannot be loaded or the connection fails
     */
    private Connection getConnection(Configuration conf) throws Exception {
        String driver = conf.get(MYSQL_OUTPUT_DRIVER, conf.get(LEGACY_MYSQL_OUTPUT_DRIVER));
        String url = conf.get(MYSQL_OUTPUT_URL);
        String username = conf.get(MYSQL_OUTPUT_USERNAME);
        String password = conf.get(MYSQL_OUTPUT_PASSWORD);
        Class.forName(driver);
        return DriverManager.getConnection(url, username, password);
    }

    /**
     * Closes {@code autoCloseable} on a best-effort basis.
     *
     * <p>Null is ignored. A failure is logged with its stack trace and never rethrown.
     */
    private void closeAutoCloseable(AutoCloseable autoCloseable) {
        try {
            if (autoCloseable != null) {
                autoCloseable.close();
            }
        } catch (Exception e) {
            LOG.error("关闭失败" + e.getMessage(), e);
        }
    }
}
三、测试例
1. 目的
统计每个 URL 每天的独立用户访问量(UV,按 uid 去重计数)
2. 数据库表结构
- 数据输入表(event_logs)
字段名 | 字段类型 | 字段说明 |
---|---|---|
uid | varchar | 用户id |
sid | varchar | 会话id |
url | varchar | URL |
time | decimal | 时间戳 |
- 数据输出表(stats_uv)
字段名 | 字段类型 | 字段说明 |
---|---|---|
url | varchar | URL |
date | date | 日期 |
uv | int | 用户访问量 |
3. 编写测试例
3.1 Map 输入Value类
/**
 * Map input value holding one {@code event_logs} row: uid, sid, url and time.
 */
public class MySQLMapperInputValue extends MySQLInputWritable {
    private String uid;
    private String sid;
    private String url;
    private long time;

    /**
     * Serializes the row.
     *
     * <p>{@code ResultSet.getString} returns null for SQL NULL, so each string is written with a
     * presence flag rather than passed straight to {@code writeUTF}, which rejects null.
     */
    @Override
    public void write(DataOutput out) throws IOException {
        writeNullableString(out, this.uid);
        writeNullableString(out, this.sid);
        writeNullableString(out, this.url);
        out.writeLong(this.time);
    }

    /** Reads the fields in the order written by {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.uid = readNullableString(in);
        this.sid = readNullableString(in);
        this.url = readNullableString(in);
        this.time = in.readLong();
    }

    /** Copies the {@code uid}, {@code sid}, {@code url} and {@code time} columns of the current row. */
    @Override
    public void readFieldsFromResultSet(ResultSet rs) throws SQLException {
        this.uid = rs.getString("uid");
        this.sid = rs.getString("sid");
        this.url = rs.getString("url");
        this.time = rs.getLong("time");
    }

    public String getUid() {
        return uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    public String getSid() {
        return sid;
    }

    public void setSid(String sid) {
        this.sid = sid;
    }

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public long getTime() {
        return time;
    }

    public void setTime(long time) {
        this.time = time;
    }

    /** Writes a presence flag, followed by the UTF string when it is non-null. */
    private static void writeNullableString(DataOutput out, String s) throws IOException {
        if (s == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeUTF(s);
        }
    }

    /** Reads a value written by {@link #writeNullableString(DataOutput, String)}. */
    private static String readNullableString(DataInput in) throws IOException {
        return in.readBoolean() ? in.readUTF() : null;
    }
}
3.2 Map 输出Key
/**
 * Map output key grouping events by (url, date).
 *
 * <p>Ordering is by url, then by date. Null fields are allowed everywhere and sort first, which
 * keeps {@code compareTo}, {@code equals} and {@code hashCode} consistent with each other.
 */
public class MySQLMapperOutputKey implements WritableComparable<MySQLMapperOutputKey> {
    private String url;
    private String date;

    /** Writes url and date, each with a presence flag so that null values survive the shuffle. */
    @Override
    public void write(DataOutput out) throws IOException {
        writeNullableString(out, this.url);
        writeNullableString(out, this.date);
    }

    /** Reads the fields in the order written by {@link #write(DataOutput)}. */
    @Override
    public void readFields(DataInput in) throws IOException {
        this.url = readNullableString(in);
        this.date = readNullableString(in);
    }

    /** Compares by url first, then by date; nulls sort before non-null values. */
    @Override
    public int compareTo(MySQLMapperOutputKey o) {
        int byUrl = compareNullable(this.url, o.url);
        return (byUrl != 0) ? byUrl : compareNullable(this.date, o.date);
    }

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public String getDate() {
        return date;
    }

    public void setDate(String date) {
        this.date = date;
    }

    /**
     * Hash over (date, url).
     *
     * <p>{@code Objects.hash} yields the same value as the classic 31-multiplier formula, so
     * partition assignment is unchanged.
     */
    @Override
    public int hashCode() {
        return java.util.Objects.hash(this.date, this.url);
    }

    @Override
    public boolean equals(Object obj) {
        if (this == obj) {
            return true;
        }
        if (obj == null || getClass() != obj.getClass()) {
            return false;
        }
        MySQLMapperOutputKey other = (MySQLMapperOutputKey) obj;
        return java.util.Objects.equals(this.date, other.date)
                && java.util.Objects.equals(this.url, other.url);
    }

    /** Null-safe string comparison with nulls first. */
    private static int compareNullable(String a, String b) {
        if (a == null) {
            return (b == null) ? 0 : -1;
        }
        if (b == null) {
            return 1;
        }
        return a.compareTo(b);
    }

    /** Writes a presence flag, followed by the UTF string when it is non-null. */
    private static void writeNullableString(DataOutput out, String s) throws IOException {
        if (s == null) {
            out.writeBoolean(false);
        } else {
            out.writeBoolean(true);
            out.writeUTF(s);
        }
    }

    /** Reads a value written by {@link #writeNullableString(DataOutput, String)}. */
    private static String readNullableString(DataInput in) throws IOException {
        return in.readBoolean() ? in.readUTF() : null;
    }
}
3.3 Map 输出Value
/**
 * Map output value carrying the user id and session id of one event.
 *
 * <p>Both fields may be null. Each is serialized as a boolean presence flag followed, when
 * present, by its UTF-encoded text.
 */
public class MySQLMapperOutputValue implements Writable {
    private String uid;
    private String sid;

    @Override
    public void write(DataOutput out) throws IOException {
        putOptional(out, this.uid);
        putOptional(out, this.sid);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        this.uid = takeOptional(in);
        this.sid = takeOptional(in);
    }

    public String getUid() {
        return uid;
    }

    public void setUid(String uid) {
        this.uid = uid;
    }

    public String getSid() {
        return sid;
    }

    public void setSid(String sid) {
        this.sid = sid;
    }

    @Override
    public String toString() {
        return "[uid=" + uid + ", sid=" + sid + "]";
    }

    /** Emits the presence flag, then the text only if it is present. */
    private static void putOptional(DataOutput out, String text) throws IOException {
        boolean present = text != null;
        out.writeBoolean(present);
        if (present) {
            out.writeUTF(text);
        }
    }

    /** Inverse of {@link #putOptional(DataOutput, String)}. */
    private static String takeOptional(DataInput in) throws IOException {
        if (!in.readBoolean()) {
            return null;
        }
        return in.readUTF();
    }
}
3.4 Map 任务
/**
 * Maps each event row to ((url, day), (uid, sid)), so the reducer can count distinct users per
 * URL per day.
 */
public class MySQLMapper extends Mapper<LongWritable, MySQLMapperInputValue, MySQLMapperOutputKey, MySQLMapperOutputValue> {
    /**
     * Renders an instant as a {@code yyyy-MM-dd} day in the JVM's default time zone.
     *
     * <p>{@code DateTimeFormatter} is immutable and thread-safe, unlike {@code SimpleDateFormat}.
     */
    private static final java.time.format.DateTimeFormatter DAY_FORMAT =
            java.time.format.DateTimeFormatter.ofPattern("yyyy-MM-dd")
                    .withZone(java.time.ZoneId.systemDefault());
    // Reused across map() calls, following the usual Hadoop idiom of reusing Writables passed to context.write().
    private final MySQLMapperOutputKey outputKey = new MySQLMapperOutputKey();
    private final MySQLMapperOutputValue outputValue = new MySQLMapperOutputValue();

    @Override
    protected void map(LongWritable key, MySQLMapperInputValue value, Context context)
            throws IOException, InterruptedException {
        outputKey.setUrl(value.getUrl());
        // NOTE(review): assumes event_logs.time holds epoch milliseconds; confirm against the data.
        outputKey.setDate(DAY_FORMAT.format(java.time.Instant.ofEpochMilli(value.getTime())));
        outputValue.setUid(value.getUid());
        outputValue.setSid(value.getSid());
        context.write(outputKey, outputValue);
    }
}
3.5 Reduce 输出Value
/**
 * Reduce output value: one {@code stats_uv} row (url, date, uv) together with the SQL that
 * inserts it.
 *
 * <p>url and date may be null. Each is serialized as a presence flag plus its UTF text.
 */
public class MySQLReducerOutputValue extends MySQLOutputWritable {
    private String url;
    private String date;
    private int uv;

    /** The parameterized INSERT into {@code stats_uv}. */
    @Override
    public String fetchInsertOrUpdateSql() {
        return "insert into stats_uv(url, date, uv) values(?, ?, ?)";
    }

    /** Binds url, date and uv, in that order, to the three placeholders. */
    @Override
    public void setPreparedStatementParameters(PreparedStatement pstmt) throws SQLException {
        int index = 1;
        pstmt.setString(index++, this.url);
        pstmt.setString(index++, this.date);
        pstmt.setInt(index, this.uv);
    }

    @Override
    public void write(DataOutput out) throws IOException {
        putOptional(out, this.url);
        putOptional(out, this.date);
        out.writeInt(uv);
    }

    @Override
    public void readFields(DataInput in) throws IOException {
        this.url = takeOptional(in);
        this.date = takeOptional(in);
        this.uv = in.readInt();
    }

    public String getUrl() {
        return url;
    }

    public void setUrl(String url) {
        this.url = url;
    }

    public String getDate() {
        return date;
    }

    public void setDate(String date) {
        this.date = date;
    }

    public int getUv() {
        return uv;
    }

    public void setUv(int uv) {
        this.uv = uv;
    }

    /** Emits the presence flag, then the text only if it is present. */
    private static void putOptional(DataOutput out, String text) throws IOException {
        boolean present = text != null;
        out.writeBoolean(present);
        if (present) {
            out.writeUTF(text);
        }
    }

    /** Inverse of {@link #putOptional(DataOutput, String)}. */
    private static String takeOptional(DataInput in) throws IOException {
        if (!in.readBoolean()) {
            return null;
        }
        return in.readUTF();
    }
}
3.6 Reduce 任务
/**
 * Counts the distinct uids seen for each (url, date) key and emits one {@code stats_uv} row per key.
 *
 * <p>A null uid is counted as one distinct value.
 */
public class MySQLReducer extends Reducer<MySQLMapperOutputKey, MySQLMapperOutputValue, NullWritable, MySQLReducerOutputValue> {
    private NullWritable outputKey = NullWritable.get();
    private MySQLReducerOutputValue outputValue;

    @Override
    protected void reduce(MySQLMapperOutputKey key, Iterable<MySQLMapperOutputValue> values, Context context)
            throws IOException, InterruptedException {
        // Capture the key fields before iterating the values.
        String groupUrl = key.getUrl();
        String groupDate = key.getDate();

        Set<String> distinctUids = new HashSet<String>();
        for (MySQLMapperOutputValue event : values) {
            distinctUids.add(event.getUid());
        }

        MySQLReducerOutputValue row = new MySQLReducerOutputValue();
        row.setUrl(groupUrl);
        row.setDate(groupDate);
        row.setUv(distinctUids.size());
        this.outputValue = row;
        context.write(this.outputKey, row);
    }
}
3.7 Runner
/**
 * Job driver: reads {@code event_logs} from MySQL, computes per-URL daily UV and writes it to
 * {@code stats_uv}.
 *
 * <p>The connection settings below are example values.
 */
public class MySQLRunner implements Tool {
    private Configuration conf = new Configuration();

    @Override
    public void setConf(Configuration conf) {
        this.conf = conf;
    }

    @Override
    public Configuration getConf() {
        return this.conf;
    }

    /**
     * Configures and runs the job.
     *
     * @return 0 on success, 1 on failure
     */
    @Override
    public int run(String[] args) throws Exception {
        Job job = Job.getInstance(this.conf, "test-custom-format");
        // Job.getInstance copies the configuration, so settings must go into the job's own copy.
        Configuration jobConf = job.getConfiguration();
        job.setJarByClass(MySQLRunner.class);
        // inputFormat
        jobConf.set(MySQLInputFormat.MYSQL_INPUT_DRIVER, "com.mysql.jdbc.Driver");
        jobConf.set(MySQLInputFormat.MYSQL_INPUT_URL, "jdbc:mysql://192.168.100.1:3306/data");
        jobConf.set(MySQLInputFormat.MYSQL_INPUT_USERNAME, "username");
        jobConf.set(MySQLInputFormat.MYSQL_INPUT_PASSWORD, "password");
        jobConf.set(MySQLInputFormat.MYSQL_INPUT_SELECT_COUNT_SQL, "select count(*) from event_logs");
        // NOTE(review): the input format pages this query with LIMIT/OFFSET. Without an ORDER BY on a
        // unique key, MySQL does not guarantee a stable order, so splits may overlap or skip rows.
        jobConf.set(MySQLInputFormat.MYSQL_INPUT_SELECT_RECORD_SQL, "select uid,sid,url,time from event_logs");
        jobConf.setLong(MySQLInputFormat.MYSQL_INPUT_SPLIT_PRE_SIZE, 10L);
        jobConf.setClass(MySQLInputFormat.MYSQL_OUTPUT_VALUE_CLASS, MySQLMapperInputValue.class, MySQLInputWritable.class);
        job.setInputFormatClass(MySQLInputFormat.class);
        // map
        job.setMapperClass(MySQLMapper.class);
        job.setMapOutputKeyClass(MySQLMapperOutputKey.class);
        job.setMapOutputValueClass(MySQLMapperOutputValue.class);
        // reduce
        job.setReducerClass(MySQLReducer.class);
        job.setOutputKeyClass(NullWritable.class);
        job.setOutputValueClass(MySQLReducerOutputValue.class);
        // outputFormat
        jobConf.set(MySQLOutputFormat.MYSQL_OUTPUT_DRIVER, "com.mysql.jdbc.Driver");
        jobConf.set(MySQLOutputFormat.MYSQL_OUTPUT_URL, "jdbc:mysql://192.168.100.1:3306/data");
        jobConf.set(MySQLOutputFormat.MYSQL_OUTPUT_USERNAME, "username");
        jobConf.set(MySQLOutputFormat.MYSQL_OUTPUT_PASSWORD, "password");
        jobConf.setInt(MySQLOutputFormat.MYSQL_OUTPUT_BATCH_SIZE, 2);
        job.setOutputFormatClass(MySQLOutputFormat.class);
        return job.waitForCompletion(true) ? 0 : 1;
    }

    /** Entry point; the process exit status reflects whether the job succeeded. */
    public static void main(String[] args) throws Exception {
        System.exit(ToolRunner.run(new MySQLRunner(), args));
    }
}
4. 运行结果
输入结果
输出结果