Spark SQL custom UDAF (weakly-typed API), Spark 1.x / Spark 2.x

package sparksql.day01

import org.apache.log4j.{Level, Logger}
import org.apache.spark.SparkConf
import org.apache.spark.sql.{Row, SparkSession}
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructField, StructType}

object SparkUDAF {

  def main(args: Array[String]): Unit = {
    Logger.getLogger("org").setLevel(Level.ERROR)
    System.setProperty("hadoop.home.dir", "D:\\spark")
    val conf = new SparkConf().setAppName("sparksql").setMaster("local[*]")

    val spark = SparkSession.builder().config(conf).getOrCreate()

    val df = spark.read.json("data/user.json")
    //df.show()

    df.createOrReplaceTempView("user")

    // register the UDAF under the name "ageavg" so SQL queries can call it
    spark.udf.register("ageavg", new MyAgeAVGUDAF())

    spark.sql("select ageavg(age) from user").show()
    //spark.sql("select avg(age) from user").show()

    spark.stop()
  }
  // Custom aggregate function class: computes the average age.
  // Extends UserDefinedAggregateFunction; in IntelliJ, Ctrl+I generates the override stubs.
  class MyAgeAVGUDAF extends UserDefinedAggregateFunction {
    // schema of the input arguments
    override def inputSchema: StructType = {
      StructType(
        Array(StructField("age", LongType))
      )
    }
    // schema of the intermediate aggregation buffer: running sum and row count
    override def bufferSchema: StructType = {
      StructType(
        Array(
          StructField("sum", LongType),
          StructField("count", LongType)
        )
      )
    }
    // data type of the final result (LongType here, so the average is truncated to a whole number)
    override def dataType: DataType = LongType

    // whether the function always returns the same result for the same input
    override def deterministic: Boolean = true
    // initialize the buffer; MutableAggregationBuffer is an abstract class,
    // but Spark hands us a concrete instance, so we can use it directly
    override def initialize(buffer: MutableAggregationBuffer): Unit = {
      buffer.update(0, 0L) // slot 0: running sum of ages, initialized to 0
      buffer.update(1, 0L) // slot 1: row count, initialized to 0
    }
    // fold one input row into the buffer
    override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
      // add the incoming age to the running sum in slot 0
      buffer.update(0, buffer.getLong(0) + input.getLong(0))
      // increment the row count in slot 1
      buffer.update(1, buffer.getLong(1) + 1)
    }
    // merge two buffers; in Spark's distributed execution, partial buffers are
    // combined pairwise, with buffer1 accumulating the result
    override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
      buffer1.update(0, buffer1.getLong(0) + buffer2.getLong(0))
      buffer1.update(1, buffer1.getLong(1) + buffer2.getLong(1))
      // for the sample data the merged buffer ends up as (120, 3)
    }
    // compute the final result from the merged buffer
    override def evaluate(buffer: Row): Any = {
      // integer division (sum / count), e.g. 120 / 3 = 40
      buffer.getLong(0) / buffer.getLong(1)
    }
  }
}
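
For reference, Spark 3.0 deprecated UserDefinedAggregateFunction in favor of the strongly typed org.apache.spark.sql.expressions.Aggregator. Below is a minimal sketch of the same average written as an Aggregator; the AvgBuffer case class and the MyAgeAvgAggregator name are illustrative, not from the original post.

import org.apache.spark.sql.{Encoder, Encoders}
import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.functions.udaf

// buffer: running sum and count, mirroring the two buffer slots above
case class AvgBuffer(var sum: Long, var count: Long)

object MyAgeAvgAggregator extends Aggregator[Long, AvgBuffer, Long] {
  override def zero: AvgBuffer = AvgBuffer(0L, 0L)                // like initialize
  override def reduce(b: AvgBuffer, age: Long): AvgBuffer = {     // like update
    b.sum += age; b.count += 1; b
  }
  override def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { // like merge
    b1.sum += b2.sum; b1.count += b2.count; b1
  }
  override def finish(b: AvgBuffer): Long = b.sum / b.count       // like evaluate
  override def bufferEncoder: Encoder[AvgBuffer] = Encoders.product[AvgBuffer]
  override def outputEncoder: Encoder[Long] = Encoders.scalaLong
}

// Spark 3.x registration:
// spark.udf.register("ageavg", udaf(MyAgeAvgAggregator))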

user.json data

{"username":"shijinhua001","age":30}
{"username":"shijinhua002","age":40}
{"username":"shijinhua003","age":50}

Reposted from blog.csdn.net/weixin_38638777/article/details/114462524