SparkSQL User-Defined Functions

UDFs

Register a UDF with spark.udf.register("funcName", func); it can then be called directly in SQL, e.g. select funcName(name) from <table>.

scala> val df = spark.read.json("1.json")
df: org.apache.spark.sql.DataFrame = [age: bigint, name: string]
 
scala> spark.udf.register("addName", (x:String)=> "No:"+x)
res5: org.apache.spark.sql.expressions.UserDefinedFunction = UserDefinedFunction(<function1>,StringType,Some(List(StringType)))
 
scala> df.createOrReplaceTempView("people")
 
scala> spark.sql("Select addName(name), age from people").show()
+-----------------+----+
|UDF:addName(name)| age|
+-----------------+----+
|       No:Michael|null|
|          No:Andy|  30|
|        No:Justin|  19|
+-----------------+----+
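The same UDF can also be applied through the DataFrame API instead of SQL. A minimal sketch, assuming the df from the transcript above (the column helpers come from org.apache.spark.sql.functions):

import org.apache.spark.sql.functions.{col, udf}

// Wrap the same lambda as a column-level UDF; no SQL registration is needed
val addName = udf((x: String) => "No:" + x)
df.select(addName(col("name")).as("name"), col("age")).show()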

UDAFs

Both the strongly typed Dataset and the untyped DataFrame provide common aggregate functions such as count(), countDistinct(), avg(), max(), and min(). Beyond these, users can define their own aggregate functions.
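For comparison, the built-in aggregates can be used directly; a quick sketch, assuming a DataFrame with the employees schema used further below:

import org.apache.spark.sql.functions.{avg, countDistinct, max, min}

// Built-in aggregates over a whole DataFrame
df.select(avg("salary"), max("salary"), min("salary"), countDistinct("name")).show()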

Untyped (Weakly Typed) User-Defined Aggregate Functions

Implement an untyped aggregate function by extending UserDefinedAggregateFunction. The following example computes the average salary.

import org.apache.spark.sql.expressions.MutableAggregationBuffer
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
import org.apache.spark.sql.types._
import org.apache.spark.sql.Row
import org.apache.spark.sql.SparkSession

object MyAverage extends UserDefinedAggregateFunction {
  // Data type of the aggregate function's input arguments
  def inputSchema: StructType = StructType(StructField("inputColumn", LongType) :: Nil)
  // Data types of the values in the aggregation buffer
  def bufferSchema: StructType = {
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  }
  // Data type of the return value
  def dataType: DataType = DoubleType
  // Whether this function always returns the same output for the same input
  def deterministic: Boolean = true
  // Initialize the buffer
  def initialize(buffer: MutableAggregationBuffer): Unit = {
    // Running total of salaries
    buffer(0) = 0L
    // Number of salaries seen so far
    buffer(1) = 0L
  }
  // Fold a new input row into the buffer (within a single executor/partition)
  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1
    }
  }
  // Merge buffers coming from different executors/partitions
  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }
  // Compute the final result
  def evaluate(buffer: Row): Double = buffer.getLong(0).toDouble / buffer.getLong(1)
}
 
// Register the function
spark.udf.register("myAverage", MyAverage)
 
val df = spark.read.json("examples/src/main/resources/employees.json")
df.createOrReplaceTempView("employees")
df.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+
 
val result = spark.sql("SELECT myAverage(salary) as average_salary FROM employees")
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
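Once registered, the UDAF can also be invoked from the DataFrame API by name via callUDF; a sketch:

import org.apache.spark.sql.functions.callUDF

// Invoke the registered UDAF without writing SQL
df.select(callUDF("myAverage", df("salary")).as("average_salary")).show()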

Strongly Typed User-Defined Aggregate Functions

Implement a strongly typed aggregate function by extending Aggregator; again, the example computes the average salary.

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.Encoders
import org.apache.spark.sql.SparkSession
// Being strongly typed, the input and buffer are modeled as case classes
case class Employee(name: String, salary: Long)
case class Average(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Employee, Average, Double] {
  // The zero value for this aggregation: salary total and count, both 0
  def zero: Average = Average(0L, 0L)
  // Combine two values to produce a new value. For performance, the function may modify `buffer`
  // and return it instead of constructing a new object
  def reduce(buffer: Average, employee: Employee): Average = {
    buffer.sum += employee.salary
    buffer.count += 1
    buffer
  }
  // Merge intermediate results from different executors/partitions
  def merge(b1: Average, b2: Average): Average = {
    b1.sum += b2.sum
    b1.count += b2.count
    b1
  }
  // Compute the final output
  def finish(reduction: Average): Double = reduction.sum.toDouble / reduction.count
  // Encoder for the intermediate value type (the buffer case class);
  // Encoders.product handles Scala tuples and case classes
  def bufferEncoder: Encoder[Average] = Encoders.product
  // Encoder for the final output value
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}
 
val ds = spark.read.json("examples/src/main/resources/employees.json").as[Employee]
ds.show()
// +-------+------+
// |   name|salary|
// +-------+------+
// |Michael|  3000|
// |   Andy|  4500|
// | Justin|  3500|
// |  Berta|  4000|
// +-------+------+
 
// Convert the function to a `TypedColumn` and give it a name
val averageSalary = MyAverage.toColumn.name("average_salary")
val result = ds.select(averageSalary)
result.show()
// +--------------+
// |average_salary|
// +--------------+
// |        3750.0|
// +--------------+
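The same TypedColumn also composes with typed groupByKey aggregation; a sketch computing the average per name (each name forms its own single-row group in this toy data):

// Per-group aggregation with the typed column (relies on import spark.implicits._,
// which the spark-shell provides automatically)
ds.groupByKey(_.name).agg(averageSalary).show()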

Window Functions

The over() window clause:

A plain aggregate function collapses many rows into a single row, whereas a window function keeps every row and attaches the aggregate result to each of them. With a plain aggregate, any other column you want to display must be added to the group by; with a window function, no group by is needed and all columns can be shown directly. Window functions are therefore suited to appending an aggregate result as an extra column on every row.

Common windowing patterns (see the sketch after this list):
1. Show the overall aggregate on every row: agg_func() over()
2. Attach the per-group aggregate to every row: agg_func() over(partition by col) as alias
   -- rows are grouped by the column, and the aggregate is computed within each group
3. Combine with a ranking function: row_number() over(order by col) as alias
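A minimal sketch of these patterns in Spark SQL, using a small hypothetical emp table (dept, name, and salary are made-up columns for illustration):

import spark.implicits._

// Hypothetical data: rank employees by salary within each department,
// and append the per-department salary total to every row
case class Emp(dept: String, name: String, salary: Long)
Seq(Emp("a", "Tom", 3000), Emp("a", "Jerry", 4500),
    Emp("b", "Bob", 3500), Emp("b", "Alice", 4000))
  .toDF().createOrReplaceTempView("emp")

spark.sql(
  """SELECT dept, name, salary,
    |       row_number() OVER (PARTITION BY dept ORDER BY salary DESC) AS rk,
    |       sum(salary)  OVER (PARTITION BY dept)                      AS dept_total
    |FROM emp""".stripMargin).show()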

Commonly used window functions:

   1. row_number() over(partition by ... order by ...)
   2. rank() over(partition by ... order by ...)
   3. dense_rank() over(partition by ... order by ...)
   4. count() over(partition by ... order by ...)
   5. max() over(partition by ... order by ...)
   6. min() over(partition by ... order by ...)
   7. sum() over(partition by ... order by ...)
   8. avg() over(partition by ... order by ...)
   9. first_value() over(partition by ... order by ...)
   10. last_value() over(partition by ... order by ...)
   11. lag() over(partition by ... order by ...)
   12. lead() over(partition by ... order by ...)
   // lag and lead fetch, for the current row, a column value from a row a given
   // offset before or after it in the window's ordering (no self-join of the result set needed)
   // lag looks at preceding rows; lead looks at following rows
   // Both take three arguments: the column name, the offset, and a default value
   // returned when the offset falls outside the window
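A sketch of lag and lead on the hypothetical emp table above, fetching each row's previous and next salary within its department:

spark.sql(
  """SELECT dept, name, salary,
    |       lag(salary, 1, 0)  OVER (PARTITION BY dept ORDER BY salary) AS prev_salary,
    |       lead(salary, 1, 0) OVER (PARTITION BY dept ORDER BY salary) AS next_salary
    |FROM emp""".stripMargin).show()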

Reprinted from blog.csdn.net/ThreeAspects/article/details/105947526