SparkSQL自定义函数

一:自定义函数分类

在Spark中,也支持Hive中的自定义函数。自定义函数大致可以分为三种:

1.UDF(User-Defined-Function),即最基本的自定义函数,类似to_char,to_date等
2.UDAF(User- Defined Aggregation Funcation),用户自定义聚合函数,类似在group by之后使用的sum,avg等
3.UDTF(User-Defined Table-Generating Functions),用户自定义生成函数,有点像stream里面的flatMap

二:自定义函数的使用UDF

(一)定义case class

         case class Emp(empno:Int,ename:String,job:String,mgr:String,hiredate:String,sal:Int,comm:String,deptno:Int)

(二)导入emp.csv的文件

          val lineRDD = sc.textFile("/emp.csv").map(_.split(","))

(三)生成DataFrame

         val allEmp = lineRDD.map(x=>Emp(x(0).toInt,x(1),x(2),x(3),x(4),x(5).toInt,x(6),x(7).toInt))
         val empDF = allEmp.toDF

(四)注册成一个临时视图

         empDF.createOrReplaceTempView("emp")

(五)自定义一个函数,拼加字符串

         spark.sqlContext.udf.register("concatstr",(s1:String,s2:String)=>s1+"***"+s2)

(六)调用自定义函数,将ename和job这两个字段拼接在一起

         spark.sql("select concatstr(ename,job) from emp").show

三:用户自定义聚合函数UDAF,需要继承UserDefinedAggregateFunction类,并实现其中的8个方法

UDAF就是用户自定义聚合函数,比如平均值,最大最小值,累加,拼接等。这里以求平均数为例,并用Java实现

(一)实现自定义聚合函数

package SparkUDAF;
 
import org.apache.spark.sql.Row;
import org.apache.spark.sql.expressions.MutableAggregationBuffer;
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
 
import java.util.ArrayList;
import java.util.List;
 
public class MyAvg extends UserDefinedAggregateFunction {
 
    @Override
    public StructType inputSchema() {
        //输入数据的类型,输入的是字符串
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("InputData", DataTypes.StringType, true));
 
        return DataTypes.createStructType(structFields);
    }
 
    @Override
    public StructType bufferSchema() {
 
        //聚合操作时,所处理的数据的数据类型,在这个例子里求平均数,要先求和(Sum),然后除以个数(Amount),所以这里需要处理两个字段
        //注意因为用了ArrayList,所以是有序的
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("Amount", DataTypes.IntegerType, true));
        structFields.add(DataTypes.createStructField("Sum", DataTypes.IntegerType, true));
 
        return DataTypes.createStructType(structFields);
    }
 
    @Override
    public DataType dataType() {
        //UDAF计算后的返回值类型
        return DataTypes.IntegerType;
    }
 
    @Override
    public boolean deterministic() {
        //判断输入和输出的类型是否一致,如果返回的是true则表示一致,false表示不一致,自行设置
        return false;
    }
 
    @Override
    public void initialize(MutableAggregationBuffer buffer) {
        /*
        对辅助字段进行初始化,就是上面定义的field1和field2
        第一个辅助字段的下标为0,初始值为0
        第二个辅助字段的下标为1,初始值为0
        */
        buffer.update(0, 0);
        buffer.update(1, 0);
    }
 
    @Override
    public void update(MutableAggregationBuffer buffer, Row input) {
        /*
        update可以认为是在每一个节点上都会对数据执行的操作,UDAF函数执行的时候,数据会被分发到每一个节点上,就是每一个分区
        buffer.getInt(0)获取的是上一次聚合后的值,input就是当前获取的数据
        */
 
        //修改辅助字段的值,buffer.getInt(x)获取的是上一次聚合后的值,x表示
        buffer.update(0, buffer.getInt(0) + 1); //表示某个数字的个数
        buffer.update(1, buffer.getInt(1) + Integer.parseInt(input.getString(0))); //表示某个数字的总和
    }
 
    @Override
    public void merge(MutableAggregationBuffer buffer1, Row buffer2) {
        /*
        merge:对每个分区的结果进行合并,每个分布式的节点上做完update之后就要做一个全局合并的操作
        合并每一个update操作的结果,将各个节点上的数据合并起来
        buffer1.getInt(0) : 上一次聚合后的值
        buffer2.getInt(0) : 这次计算传入进来的update的结果
        */
 
        //对第一个字段Amount进行求和,求出总个数
        buffer1.update(0, buffer1.getInt(0) + buffer2.getInt(0));
        //对第二个字段Sum进行求和,求出总和
        buffer1.update(1, buffer1.getInt(1) + buffer2.getInt(1));
    }
 
    @Override
    public Object evaluate(Row buffer) {
        //表示最终计算的结果,第二个参数表示和值,第一个参数表示个数
        return buffer.getInt(1) / buffer.getInt(0);
    }
}

(二)注册并使用UDAF

package SparkUDAF;
 
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SQLContext;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
 
import java.util.ArrayList;
import java.util.List;
 
public class TestMain {
    public static void main(String[] args) {
        SparkConf conf =new SparkConf();
        conf.setMaster("local").setAppName("MyAvg");
        JavaSparkContext sc= new JavaSparkContext(conf);
        //得到SQLContext对象
        SQLContext sqlContext = new SQLContext(sc);
 
        //注册自定义函数
        sqlContext.udf().register("my_avg",new MyAvg());
 
        //读入数据
        JavaRDD<String> lines = sc.textFile("d:\\test.txt");
        //分词
        JavaRDD<Row> rows=lines.map(line-> RowFactory.create(line.split("\\^")));
 
        //定义schema的结构,a字段是字母,b字段是value
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("a",DataTypes.StringType,true));
        structFields.add(DataTypes.createStructField("b",DataTypes.StringType,true));
        StructType structType = DataTypes.createStructType(structFields);
 
        //创建DataFrame
        Dataset ds=sqlContext.createDataFrame(rows,structType);
        ds.registerTempTable("test");
 
        //执行查询
        sqlContext.sql("select a,my_avg(b) from test group by a").show();
        sc.stop();
    }
}

猜你喜欢

转载自www.cnblogs.com/ssyfj/p/12620368.html