Spark SQL custom UDAF (strongly-typed Aggregator approach), Spark 3.x

package sparksql.day01

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
import org.apache.spark.sql.expressions.Aggregator


object sparkUDAF1 {

  def main(args: Array[String]): Unit = {
    // Silence Spark's verbose logging; point Hadoop at the winutils dir (Windows only).
    Logger.getLogger("org").setLevel(Level.ERROR)
    System.setProperty("hadoop.home.dir", "D:\\spark")

    val conf = new SparkConf().setAppName("spakrsql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Source data: JSON lines file; each record is expected to carry an "age" field.
    val df = spark.read.json("data/user.json")
    //df.show()

    df.createOrReplaceTempView("user")

    // Spark 3.x way: wrap the strongly-typed Aggregator into an untyped UDAF
    // via functions.udaf, then register it for use in SQL.
    spark.udf.register("ageavg", functions.udaf(new MyAgeAVGUDAF()))

    spark.sql("select ageavg(age) from user").show()
    // spark.sql("select avg(age) from user").show()

    spark.stop()
  }

  /**
   * Aggregation buffer for the average computation.
   * Count  = running sum of ages (despite the name — kept for source compatibility).
   * Number = number of rows folded in so far.
   */
  case class MyBuff(var Count: Long, var Number: Long)

  /**
   * Strongly-typed custom aggregate function computing the (integer-truncated)
   * average age. Type parameters of Aggregator: input = Long (an age value),
   * buffer = MyBuff, output = Long.
   */
  class MyAgeAVGUDAF extends Aggregator[Long, MyBuff, Long] {

    /** Initial (zero) value of the buffer. */
    override def zero: MyBuff = {
      MyBuff(0L, 0L)
    }

    /** Fold one input value into the partition-local buffer. */
    override def reduce(buff1: MyBuff, buff2: Long): MyBuff = {
      buff1.Count = buff1.Count + buff2
      buff1.Number = buff1.Number + 1
      buff1
    }

    /** Merge two partial buffers coming from different partitions. */
    override def merge(buff1: MyBuff, buff2: MyBuff): MyBuff = {
      buff1.Count = buff1.Count + buff2.Count
      buff1.Number = buff1.Number + buff2.Number
      buff1
    }

    /**
     * Produce the final result from the merged buffer.
     * BUG FIX: the original divided Number by itself, which yields 1 for any
     * non-empty input. The average is sum / count, i.e. Count / Number.
     * Guard against division by zero when no rows were aggregated.
     */
    override def finish(reduction: MyBuff): Long = {
      if (reduction.Number == 0L) 0L
      else reduction.Count / reduction.Number
    }

    /** Encoder for the buffer: case classes use the product encoder. */
    override def bufferEncoder: Encoder[MyBuff] = Encoders.product

    /** Encoder for the output: a plain Scala Long. */
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}

Source: blog.csdn.net/weixin_38638777/article/details/114463970