1. UDF (User-Defined Function): Spark SQL supports custom UDFs; the most common case
- DSL style: val small2big1: UserDefinedFunction = functions.udf((word: String) => word.toUpperCase) // define the function
- SQL style: spark.udf.register("small2big2", (word: String) => word.toUpperCase) // register the UDF
- The most basic kind of custom function, similar to to_char and to_date
- One row in, one row out
2. UDAF (User-Defined Aggregate Function): Spark SQL supports custom UDAFs
- User-defined aggregate functions, similar to sum and avg applied after group by (full example in section 4 below)
- Many rows in, one row out
3. UDTF (User-Defined Table-Generating Function): Spark SQL does not support custom UDTFs
- User-defined table-generating functions, somewhat like flatMap
- One row in, many rows out (see the explode sketch below)
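Although custom UDTFs cannot be registered in Spark SQL, the built-in explode function (or flatMap on a Dataset) gives the same one-row-to-many-rows behavior. A minimal sketch, assuming a SparkSession named spark with import spark.implicits._ in scope (as in the demo below):

import org.apache.spark.sql.functions.{explode, split}

// one row in: "hello spark sql"; many rows out: one word per row
val lineDF = Seq("hello spark sql").toDF("line")
lineDF.select(explode(split($"line", " ")).as("word")).show()
// the equivalent with the typed Dataset API:
lineDF.as[String].flatMap(_.split(" ")).show()

The full runnable UDF demo for item 1, in both DSL and SQL styles: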
package cn.hanjiaxiaozhi.sql

import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.UserDefinedFunction
import org.apache.spark.sql.{DataFrame, SparkSession}

object UDFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("sql").master("local[*]").getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")
    import spark.implicits._

    // Load each line of the text file into a single "value" column
    val df: DataFrame = spark.read.text("D:\\data\\spark\\udf.txt")
    df.show(false)

    // DSL style: wrap a Scala function with functions.udf and use it in select
    import org.apache.spark.sql.functions._
    val small2big1: UserDefinedFunction = udf((word: String) => word.toUpperCase)
    df.select($"value", small2big1($"value")).show(false)

    // SQL style: register the function by name, then call it in a SQL statement
    df.createOrReplaceTempView("t_word")
    spark.udf.register("small2big2", (word: String) => word.toUpperCase)
    val sql: String =
      """
        |select value, small2big2(value)
        |from t_word
        |""".stripMargin
    spark.sql(sql).show(false)
  }
}
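The demo expects a plain-text file at D:\data\spark\udf.txt; its exact contents are not shown here, but any file with one lower-case word per line works, for example (hypothetical sample):

hello
spark
sql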
4. Custom UDAF
Given a file udaf.json with the following contents:
{"name":"Michael","salary":3000}
{"name":"Andy","salary":4500}
{"name":"Justin","salary":3500}
{"name":"Berta","salary":4000}
Compute the average salary. For the data above the expected result is (3000 + 4500 + 3500 + 4000) / 4 = 3750.0.
inputSchema: the schema of the input data
bufferSchema: the schema of the intermediate (buffer) result
dataType: the type of the final result
deterministic: whether the function always returns the same output for the same input; usually true
initialize: sets the initial values of the aggregation buffer
update: folds each incoming row into the intermediate result (update runs within each partition)
merge: global aggregation (combines the per-partition buffers)
evaluate: computes the final result from the merged buffer
package cn.hanjiaxiaozhi.sql

import org.apache.spark.SparkContext
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._
import org.apache.spark.sql.{DataFrame, Row, SparkSession}

object UDAFDemo {
  def main(args: Array[String]): Unit = {
    val spark: SparkSession = SparkSession.builder().appName("SparkSQL").master("local[*]").getOrCreate()
    val sc: SparkContext = spark.sparkContext
    sc.setLogLevel("WARN")

    val employeeDF: DataFrame = spark.read.json("D:\\data\\sql\\udaf.json")
    employeeDF.createOrReplaceTempView("t_employee")

    // Register the custom UDAF and compare its result with the built-in avg
    spark.udf.register("avgsalary", new SparkFunctionUDAF)
    spark.sql("select avgsalary(salary) from t_employee").show()
    spark.sql("select avg(salary) from t_employee").show()
  }
}
class SparkFunctionUDAF extends UserDefinedAggregateFunction {
  // Schema of the input: a single Long column (the salary)
  override def inputSchema: StructType =
    StructType(StructField("input", LongType) :: Nil)

  // Schema of the intermediate buffer: running sum and row count
  override def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("total", LongType) :: Nil)

  // Type of the final result: the average as a Double
  override def dataType: DataType = DoubleType

  // Same input always yields the same output
  override def deterministic: Boolean = true

  // Start with sum = 0 and count = 0
  override def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  // Fold one row into the buffer (runs within each partition)
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    buffer(0) = buffer.getLong(0) + input.getLong(0)
    buffer(1) = buffer.getLong(1) + 1
  }

  // Combine the per-partition buffers into one
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  // Final result: sum / count
  override def evaluate(buffer: Row): Any =
    buffer.getLong(0).toDouble / buffer.getLong(1)
}
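Note that UserDefinedAggregateFunction is deprecated since Spark 3.0 in favor of the typed Aggregator API. A minimal sketch of the same average under Spark 3.x (the names AvgBuffer and MyAverage are illustrative, not from the original code):

import org.apache.spark.sql.expressions.Aggregator
import org.apache.spark.sql.{Encoder, Encoders, functions}

case class AvgBuffer(var sum: Long, var count: Long)

object MyAverage extends Aggregator[Long, AvgBuffer, Double] {
  // Empty buffer: sum = 0, count = 0
  def zero: AvgBuffer = AvgBuffer(0L, 0L)
  // Fold one salary into the buffer (runs within each partition)
  def reduce(b: AvgBuffer, a: Long): AvgBuffer = { b.sum += a; b.count += 1; b }
  // Combine per-partition buffers
  def merge(b1: AvgBuffer, b2: AvgBuffer): AvgBuffer = { b1.sum += b2.sum; b1.count += b2.count; b1 }
  // Final result: sum / count
  def finish(b: AvgBuffer): Double = b.sum.toDouble / b.count
  def bufferEncoder: Encoder[AvgBuffer] = Encoders.product
  def outputEncoder: Encoder[Double] = Encoders.scalaDouble
}

// Register for SQL use, same as before:
// spark.udf.register("avgsalary2", functions.udaf(MyAverage))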