package sparksql.day01
import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, DoubleType, LongType, StructField, StructType}
import org.apache.spark.sql.{Row, SparkSession}
object sparkUDAF {
def main(args: Array[String]): Unit = {
  // Silence Spark's chatty console output: only ERROR-level logs from org.*.
  Logger.getLogger("org").setLevel(Level.ERROR)
  // Windows-only workaround so Hadoop utilities can be located locally.
  System.setProperty("hadoop.home.dir", "D:\\spark")

  // Build a local SparkSession using all available cores.
  val sparkConf = new SparkConf().setAppName("spakrsql").setMaster("local[*]")
  val session = SparkSession.builder().config(sparkConf).getOrCreate()
  import session.implicits._

  // Load the sample users and expose them to SQL as the view "user".
  session.read.json("data/user.json").createOrReplaceTempView("user")

  // Register the custom UDAF under the SQL name "ageavg" and run it.
  session.udf.register("ageavg", new MyAgeAVGUDAF())
  session.sql("select ageavg(age) from user").show()
  // session.sql("select avg(age) from user").show()

  session.stop()
}
// Custom aggregate function class: computes the average age.
// Extends UserDefinedAggregateFunction and overrides its abstract members.
class MyAgeAVGUDAF extends UserDefinedAggregateFunction {

  // Input schema: a single Long column — the age values fed to the function.
  override def inputSchema: StructType =
    StructType(Array(StructField("age", LongType)))

  // Aggregation buffer schema: slot 0 = running sum of ages, slot 1 = row count.
  override def bufferSchema: StructType =
    StructType(Array(
      StructField("sum", LongType),
      StructField("count", LongType)
    ))

  // Output type. FIX: was LongType with integer division in evaluate(), which
  // silently truncated the average (e.g. avg of 30 and 31 gave 30, not 30.5).
  // DoubleType matches the behavior of the built-in avg() it is compared with.
  override def dataType: DataType = DoubleType

  // The same input rows always produce the same result.
  override def deterministic: Boolean = true

  // Initialize the buffer: sum = 0, count = 0.
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer.update(0, 0L)
    buffer.update(1, 0L)
  }

  // Fold one input row into the buffer: add the age to the sum, bump the count.
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer.update(0, buffer.getLong(0) + input.getLong(0))
    buffer.update(1, buffer.getLong(1) + 1)
  }

  // Merge two partial buffers; Spark combines per-partition results pairwise
  // into buffer1. BUG FIX: the original wrote the merged count into slot 0,
  // clobbering the sum and leaving the count stuck at buffer1's partial value.
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
    buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
  }

  // Final result: sum / count as a Double. Guards against division by zero
  // when no rows were aggregated (returns 0.0 instead of throwing).
  override def evaluate(buffer: Row): Any = {
    val count = buffer.getLong(1)
    if (count == 0L) 0.0 else buffer.getLong(0).toDouble / count
  }
}
}
/* Sample contents of data/user.json (one JSON object per line):
{"username":"shijinhua001","age":30}
{"username":"shijinhua002","age":40}
{"username":"shijinhua003","age":50}
*/