Hive — User-Defined Aggregate Functions (UDAF)


Please credit the source when reposting: https://blog.csdn.net/l1028386804/article/details/88536189

Depending on the transformation a UDAF performs, its return type may differ from one processing stage to another.
When writing a UDAF, pay close attention to memory usage. The memory available to each task can be adjusted through the mapred.child.java.opts configuration parameter, although this approach does not always work:

<property>
	<name>mapred.child.java.opts</name>
	<value>-Xmx200m</value>
</property>
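
If changing the cluster-wide configuration is not an option, the same parameter can usually be overridden for a single session from the Hive CLI; the heap size below is only an illustrative value:

hive> set mapred.child.java.opts=-Xmx512m;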

Creating a COLLECT UDAF to Emulate GROUP_CONCAT
MySQL has a very useful function called GROUP_CONCAT, which joins all of the elements in a group into a single string using a user-specified separator.
GROUP_CONCAT is used in MySQL as follows:

mysql> create table people(name varchar(100), friendname varchar(100));

mysql> select * from people;
bob		sara
bob		john
bob		ted
john	bob
ted		sara

mysql> select name, group_concat(friendname separator ',') from people group by name;

We can implement the same transformation in Hive without adding any new syntax.
Next, we create the GenericUDAFCollect and GenericUDAFMkListEvaluator classes:

package com.lyz.hadoop.hive.udaf;

import org.apache.hadoop.hive.ql.exec.Description;
import org.apache.hadoop.hive.ql.exec.UDFArgumentTypeException;
import org.apache.hadoop.hive.ql.parse.SemanticException;
import org.apache.hadoop.hive.ql.udf.generic.AbstractGenericUDAFResolver;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.typeinfo.TypeInfo;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;

/**
 * Custom aggregate function (UDAF) resolver
 * @author liuyazhuang
 *
 */
@Description(name="collect", 
value = "_FUNC_(x) - Returns a list of objects. CAUTION will easily OOM on large data sets ")
public class GenericUDAFCollect extends AbstractGenericUDAFResolver {
	public GenericUDAFCollect() {
		
	}
	
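	// Check that exactly one primitive-typed argument was passed, then return
	// the evaluator that implements the actual aggregation logic.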
	@Override
	public GenericUDAFEvaluator getEvaluator(TypeInfo[] parameters) throws SemanticException {
		if(parameters.length != 1) {
			throw new UDFArgumentTypeException(parameters.length - 1, "Exactly one argument is expected.");
		}
		if(parameters[0].getCategory() != ObjectInspector.Category.PRIMITIVE) {
			throw new UDFArgumentTypeException(0, "Only primitive type arguments are accepted but " + parameters[0].getTypeName() + " was passed as parameter 1.");
		}
		return new GenericUDAFMkListEvaluator(); 
	}
}

package com.lyz.hadoop.hive.udaf;

import java.util.ArrayList;
import java.util.List;

import org.apache.hadoop.hive.ql.metadata.HiveException;
import org.apache.hadoop.hive.ql.udf.generic.GenericUDAFEvaluator;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorFactory;
import org.apache.hadoop.hive.serde2.objectinspector.ObjectInspectorUtils;
import org.apache.hadoop.hive.serde2.objectinspector.PrimitiveObjectInspector;
import org.apache.hadoop.hive.serde2.objectinspector.StandardListObjectInspector;

/**
 * Custom aggregate function (UDAF) evaluator
 * @author liuyazhuang
 *
 */
public class GenericUDAFMkListEvaluator extends GenericUDAFEvaluator {
	private PrimitiveObjectInspector inputOI;
	private StandardListObjectInspector loi;
	private StandardListObjectInspector internalMergeOI;
	
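	// init() is called once per mode. In PARTIAL1 and COMPLETE the input is the
	// original data (a primitive); in PARTIAL2 and FINAL it is the partial result
	// produced by terminatePartial() (a list), so the object inspectors differ.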
	@Override
	public ObjectInspector init(Mode m, ObjectInspector[] parameters) throws HiveException {
		super.init(m, parameters);
		if(m == Mode.PARTIAL1) {
			inputOI = (PrimitiveObjectInspector) parameters[0];
			return ObjectInspectorFactory.getStandardListObjectInspector((PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(inputOI));
		}else {
			if(!(parameters[0] instanceof StandardListObjectInspector)) {
				inputOI = (PrimitiveObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(parameters[0]);
				return (StandardListObjectInspector) ObjectInspectorFactory.getStandardListObjectInspector(inputOI);
			}else {
				internalMergeOI = (StandardListObjectInspector) parameters[0];
				inputOI = (PrimitiveObjectInspector) internalMergeOI.getListElementObjectInspector();
				loi = (StandardListObjectInspector) ObjectInspectorUtils.getStandardObjectInspector(internalMergeOI);
				return loi;
			}
		}
	}
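	// Per-group aggregation state: the list of values collected so far.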
	static class MkArrayAggregationBuffer implements AggregationBuffer{
		List<Object> container;
	}
	
	@Override
	public AggregationBuffer getNewAggregationBuffer() throws HiveException {
		MkArrayAggregationBuffer ret = new MkArrayAggregationBuffer();
		reset(ret);
		return ret;
	}

	@Override
	public void reset(AggregationBuffer agg) throws HiveException {
		((MkArrayAggregationBuffer) agg).container = new ArrayList<Object>();
	}

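	// iterate() is called once per input row with the original arguments;
	// non-null values are copied into the buffer's list.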
	@Override
	public void iterate(AggregationBuffer agg, Object[] parameters) throws HiveException {
		assert(parameters.length == 1);
		Object p = parameters[0];
		if(p != null) {
			MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
			putInfoList(p, myagg);
		}
	}
	private void putInfoList(Object p, MkArrayAggregationBuffer myagg) {
		Object pCopy = ObjectInspectorUtils.copyToStandardObject(p, this.inputOI);
		myagg.container.add(pCopy);
	}

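	// terminatePartial() returns the current buffer contents as the partial
	// result that will later be handed to merge().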
	@Override
	public Object terminatePartial(AggregationBuffer agg) throws HiveException {
		MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
		ArrayList<Object> ret = new ArrayList<Object>(myagg.container.size());
		ret.addAll(myagg.container);
		return ret;
	}

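	// merge() folds a partial result (the list produced by terminatePartial()
	// on another task) into the current buffer.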
	@Override
	public void merge(AggregationBuffer agg, Object partial) throws HiveException {
		MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
		ArrayList<Object> partialResult = (ArrayList<Object>) internalMergeOI.getList(partial);
		for(Object i : partialResult) {
			putInfoList(i, myagg);
		}

	}

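	// terminate() returns the final aggregated list for the group.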
	@Override
	public Object terminate(AggregationBuffer agg) throws HiveException {
		MkArrayAggregationBuffer myagg = (MkArrayAggregationBuffer) agg;
		ArrayList<Object> ret = new ArrayList<Object>(myagg.container.size());
		ret.addAll(myagg.container);
		return ret;
	}

}

Next, we create a data file named afile.txt:

twelve	12
twelve	12
eleven	11
eleven	11

Then create the table collecttest and load the data:

hive> create table collecttest(str string, countVal int) row format delimited fields terminated by '\t' lines terminated by '\n';
hive> load data local inpath '/root/afile.txt' into table collecttest;

Export these two classes as udaf.jar, upload the JAR to /usr/local/src/ on the server, and register the function:

hive> add jar /usr/local/src/udaf.jar;
hive> create temporary function collect as 'com.lyz.hadoop.hive.udaf.GenericUDAFCollect';
hive> select collect(str) from collecttest;
["twelve","twelve","eleven","eleven"]

The first argument to concat_ws() is a separator; the remaining arguments can be strings or arrays of strings. It returns a single string in which all of the inputs are joined by the specified separator.
For example, we can use a comma to join a group of strings into one string:

hive> select concat_ws(',',collect(str)) from collecttest;
twelve,twelve,eleven,eleven

The effect of GROUP_CONCAT can be achieved by combining GROUP BY, collect, and concat_ws() as follows (countVal is cast to a string so that the collected array contains strings, which is what concat_ws() expects):

hive> select str, concat_ws(',', collect(cast(countVal as string))) from collecttest group by str;
eleven  11,11
twelve  12,12

Note the TEMPORARY keyword in the CREATE FUNCTION statement: a function declared this way is only valid in the current session, so the JAR must be added and the function created again in each new session. If the same JAR file and function are used frequently, the relevant statements can instead be added to the $HOME/.hiverc file.
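
For example, a $HOME/.hiverc that registers the function at the start of every session might contain the two statements used above:

add jar /usr/local/src/udaf.jar;
create temporary function collect as 'com.lyz.hadoop.hive.udaf.GenericUDAFCollect';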
