这是一个我参考网上资料自己写的 Hive UDAF,期间踩了各种坑,终于跑通了。它的作用是:输入一个表字段,输出该字段的统计总行数、为空行数,以及出现频次最高的前十条去重样例数据。各方法均有注释说明,完整代码如下:
package com.zh.hive;
import net.sf.json.JSONObject;
import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
import org.eclipse.jetty.util.ajax.JSON;
import java.util.*;
/**
 * Hive UDAF: given one column, returns a String of the form
 * "totalRows|nullRows|top10", where top10 lists the ten most frequent
 * distinct values as "value@count," pairs.
 */
public class QcUdf extends AbstractGenericUDAFResolver {

    @Override
    public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
        if (parameters.length != 1) {
            throw new SemanticException("Exactly one argument is expected, got " + parameters.length);
        }
        // Fail early with a clear message instead of a ClassCastException at runtime.
        ObjectInspector oi = TypeInfoUtils.getStandardJavaObjectInspectorFromTypeInfo(parameters[0]);
        if (!(oi instanceof PrimitiveObjectInspector)) {
            throw new SemanticException("Argument must be a primitive type");
        }
        return new GenericUDAFHistogramNumericEvaluator();
    }

    public static class GenericUDAFHistogramNumericEvaluator extends GenericUDAFEvaluator {

        PrimitiveObjectInspector inputOI;   // OI for the raw column value (map side)
        ObjectInspector outputOI;           // OI for the String output of every stage
        PrimitiveObjectInspector integerOI; // OI for the partial "sum#@nulls#@json" string

        @Override
        public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
            assert (parameters.length == 1);
            super.init(m, parameters);
            if (m == Mode.PARTIAL1 || m == Mode.COMPLETE) {
                // Map stage: parameter is the original column value.
                inputOI = (PrimitiveObjectInspector) parameters[0];
            } else {
                // Merge stages: parameter is the String produced by terminatePartial().
                integerOI = (PrimitiveObjectInspector) parameters[0];
            }
            // Every stage emits a plain Java String.
            outputOI = ObjectInspectorFactory.getReflectionObjectInspector(String.class,
                    ObjectInspectorFactory.ObjectInspectorOptions.JAVA);
            return outputOI;
        }

        /**
         * Aggregation buffer: total row count, null-row count and a
         * value -> frequency map used to render the top-10 sample.
         */
        static class LetterSumAgg implements AggregationBuffer {
            int sum = 0;    // total rows seen
            int count1 = 0; // rows whose value is null / blank / literal "null"
            Map<String, Integer> map = new HashMap<String, Integer>();

            /** Record one column value; null/blank values are bucketed under "null_key". */
            void put(String str) {
                // BUGFIX: original used `str!=null||str!=""` (always true, and a
                // reference comparison), called trim() before the null check (NPE
                // on null), and reset the null bucket to 1 instead of incrementing.
                String key = (str == null) ? "" : str.trim();
                if (!key.isEmpty()) {
                    map.merge(key, 1, Integer::sum);
                } else {
                    map.merge("null_key", 1, Integer::sum);
                }
            }

            /** Merge another frequency map; values may be any Number (JSON yields Long). */
            void put(Map<String, ?> targetMap) {
                for (Map.Entry<String, ?> e : targetMap.entrySet()) {
                    // BUGFIX: original cast JSON values straight to Integer, which
                    // throws ClassCastException when the parser produces Long.
                    map.merge(e.getKey(), ((Number) e.getValue()).intValue(), Integer::sum);
                }
            }

            /** Add to the running totals: num -> total rows, count -> null rows. */
            void add(int num, int count) {
                sum += num;
                count1 += count;
            }

            /**
             * Render the ten most frequent values as "value@count," pairs.
             * BUGFIX: the original sorted the counts as Strings (lexicographic,
             * so "9" > "10"), re-scanned the whole map once per count (duplicating
             * entries with equal counts), and mutated the map as a side effect.
             */
            String getTop10() {
                List<Map.Entry<String, Integer>> entries =
                        new ArrayList<Map.Entry<String, Integer>>(map.entrySet());
                Collections.sort(entries, new Comparator<Map.Entry<String, Integer>>() {
                    @Override
                    public int compare(Map.Entry<String, Integer> a, Map.Entry<String, Integer> b) {
                        return b.getValue().compareTo(a.getValue()); // descending by frequency
                    }
                });
                StringBuilder sb = new StringBuilder();
                int limit = Math.min(10, entries.size());
                for (int i = 0; i < limit; i++) {
                    Map.Entry<String, Integer> e = entries.get(i);
                    // Strip characters that would corrupt the "|" / "@" / "," framing.
                    String key = e.getKey()
                            .replace("null_key", "null")
                            .replace("\n", "").replace("\t", "").replace("|", "");
                    sb.append(key).append('@').append(e.getValue()).append(',');
                }
                return sb.toString();
            }
        }

        @Override
        public AggregationBuffer getNewAggregationBuffer() throws HiveException {
            return new LetterSumAgg();
        }

        @Override
        public void reset(AggregationBuffer aggregationBuffer) throws HiveException {
            // BUGFIX: the original allocated a fresh buffer and discarded it,
            // leaving the real buffer dirty when Hive reuses it across groups.
            LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
            myagg.sum = 0;
            myagg.count1 = 0;
            myagg.map.clear();
        }

        @Override
        public void iterate(AggregationBuffer aggregationBuffer, Object[] objects) throws HiveException {
            assert (objects.length == 1);
            LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
            // BUGFIX: the original compared Strings with != against "null"/"",
            // which is a reference comparison and never matched, so blank and
            // literal "null" values were silently counted as non-null.
            String value = (objects[0] == null) ? null : objects[0].toString().trim();
            if (value != null && !value.isEmpty() && !"null".equalsIgnoreCase(value)) {
                myagg.put(value);
                myagg.add(1, 0); // total rows +1
            } else {
                myagg.put("null_key");
                myagg.add(1, 1); // total rows +1, null rows +1
            }
        }

        @Override
        public String terminatePartial(AggregationBuffer aggregationBuffer) throws HiveException {
            // Serialize the partial state as "totalRows#@nullRows#@{json map}".
            // (Original needlessly copied the buffer into a fresh one first.)
            LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
            JSONObject jsonObject = JSONObject.fromObject(myagg.map);
            return myagg.sum + "#@" + myagg.count1 + "#@" + jsonObject;
        }

        @Override
        public void merge(AggregationBuffer aggregationBuffer, Object o) throws HiveException {
            if (o != null) {
                LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
                String agg = (String) integerOI.getPrimitiveJavaObject(o);
                String[] result = agg.split("#@");
                // Guard against a malformed partial with no JSON segment.
                if (result.length > 2 && result[2] != null) {
                    @SuppressWarnings("unchecked")
                    Map<String, ?> maps = (Map<String, ?>) JSON.parse(result[2]);
                    myagg.put(maps);
                }
                myagg.add(Integer.parseInt(result[0]), Integer.parseInt(result[1]));
            }
        }

        @Override
        public Object terminate(AggregationBuffer aggregationBuffer) throws HiveException {
            // Final output: "totalRows|nullRows|top10".
            LetterSumAgg myagg = (LetterSumAgg) aggregationBuffer;
            return myagg.sum + "|" + myagg.count1 + "|" + myagg.getTop10();
        }
    }
}
各位朋友使用请直接copy即可。附上maven依赖
<dependencies>
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-jdbc</artifactId>
<version>2.1.1</version>
</dependency>
<dependency>
<groupId>org.apache.hive</groupId>
<artifactId>hive-exec</artifactId>
<version>2.1.1</version>
</dependency>
</dependencies>
大功告成,测试结果样例如下:
38386|0|[3522963, 3383561, 3517824, 3505051, 3037673, 3523778, 3300084, 3483628, 3525325, 3514324]
执行代码如下:
use databases_name;
add jar /home/zhangheng/hive.jar;
create temporary function tj as 'com.zh.hive.QcUdf';
select tj(c1) ,tj(c2),tj(c3) from table;