package sparksql.day01
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, SparkSession, functions}
object sparkUDAF1 {
  def main(args: Array[String]): Unit = {
    // Silence Spark/Hadoop's verbose INFO logging.
    Logger.getLogger("org").setLevel(Level.ERROR)
    // Windows-only workaround: point Hadoop at a local winutils installation.
    System.setProperty("hadoop.home.dir", "D:\\spark")

    val conf = new SparkConf().setAppName("spakrsql").setMaster("local[*]")
    val spark = SparkSession.builder().config(conf).getOrCreate()

    // Expects a JSON-lines file with (at least) a numeric `age` field.
    val df = spark.read.json("data/user.json")
    //df.show()
    df.createOrReplaceTempView("user")

    // Wrap the strongly-typed Aggregator as an untyped SQL UDAF (Spark 3.x API)
    // and make it callable from SQL under the name "ageavg".
    spark.udf.register("ageavg", functions.udaf(new MyAgeAVGUDAF()))
    spark.sql("select ageavg(age) from user").show()
    // spark.sql("select avg(age) from user").show()

    spark.stop()
  }

  // Mutable aggregation buffer for the average-age UDAF.
  // NOTE: despite its name, `Count` holds the running SUM of ages;
  // `Number` holds the number of rows seen so far. Field names are kept
  // as-is to preserve the existing buffer schema/interface.
  case class MyBuff(var Count: Long, var Number: Long)

  /**
   * Custom aggregate function: computes the (integer) average of `age`.
   *
   * Type parameters of Aggregator[IN, BUF, OUT]:
   *   IN  = Long   — one age value per input row
   *   BUF = MyBuff — running (sum, count) pair
   *   OUT = Long   — truncated integer average
   */
  class MyAgeAVGUDAF extends Aggregator[Long, MyBuff, Long] {
    // Initial (empty) buffer: zero sum, zero rows.
    override def zero: MyBuff = {
      MyBuff(0L, 0L)
    }

    // Fold one input row into the buffer: add the age to the sum,
    // bump the row count. Mutating and returning buff1 is the
    // conventional (allocation-free) Aggregator pattern.
    override def reduce(buff1: MyBuff, age: Long): MyBuff = {
      buff1.Count = buff1.Count + age
      buff1.Number = buff1.Number + 1
      buff1
    }

    // Merge two partial buffers (one per partition): sums and counts add.
    override def merge(buff1: MyBuff, buff2: MyBuff): MyBuff = {
      buff1.Count = buff1.Count + buff2.Count
      buff1.Number = buff1.Number + buff2.Number
      buff1
    }

    // Final result: sum / count.
    // BUG FIX: the original divided Number by itself, which always yields 1
    // (or throws on empty input) instead of the average.
    // Guard against zero rows to avoid ArithmeticException; note this is
    // integer division, so the average is truncated toward zero.
    override def finish(reduction: MyBuff): Long = {
      if (reduction.Number == 0L) 0L
      else reduction.Count / reduction.Number
    }

    // Encoder for the custom buffer type (any Product/case class).
    override def bufferEncoder: Encoder[MyBuff] = Encoders.product

    // Encoder for the primitive Long output.
    override def outputEncoder: Encoder[Long] = Encoders.scalaLong
  }
}