In Spark, custom functions in Hive are also supported. Custom functions can be roughly divided into three types:
- UDF (User-Defined-Function), the most basic custom function, similar to to_char, to_date, etc.
- UDAF (User-Defined Aggregation Function), user-defined aggregation function, similar to sum, avg, etc. used after group by
- UDTF (User-Defined Table-Generating Functions), user-defined generating function, a bit like flatMap in stream
To customize a UDAF, you need to extend the UserDefinedAggregateFunction class and implement 8 methods.
Example
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, StringType, StructField, StructType}

/**
 * UDAF that collects the distinct input strings of a group into a single
 * comma-separated string (e.g. "beijing,shanghai").
 */
object GetDistinctCityUDF extends UserDefinedAggregateFunction {

  /** Input data type: one (nullable) string column. */
  override def inputSchema: StructType =
    StructType(StructField("status", StringType, true) :: Nil)

  /** Aggregation buffer type: the comma-separated list built so far. */
  override def bufferSchema: StructType =
    StructType(Array(StructField("buffer_city_info", StringType, true)))

  /** Output result type. */
  override def dataType: DataType = StringType

  /**
   * True: given the same input rows, this function always produces the same
   * result. (NOTE: the original comment claimed this flag means "input and
   * output types are the same" — that is not what `deterministic` means.)
   */
  override def deterministic: Boolean = true

  /** Initialize the buffer to an empty list. */
  override def initialize(buffer: MutableAggregationBuffer): Unit =
    buffer.update(0, "")

  /**
   * Fold one input row into the buffer, appending the value (with a ","
   * separator) only if the buffer does not already contain it.
   */
  override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    val accumulated = buffer.getString(0)
    val current = input.getString(0)
    if (!accumulated.contains(current)) {
      val appended =
        if (accumulated.isEmpty) current
        else accumulated + "," + current
      buffer.update(0, appended)
    }
  }

  /**
   * Merge two partition buffers, keeping each value only once.
   * buffer1 / buffer2 are partial results from different partitions.
   *
   * BUG FIX: the original split on " , " (space-padded) while `update`
   * joins with ",", so the split never matched; it also appended values
   * without the "," separator, corrupting the list. Both are fixed here.
   */
  override def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    var merged = buffer1.getString(0)
    for (value <- buffer2.getString(0).split(",") if value.nonEmpty) {
      if (!merged.contains(value)) {
        merged =
          if (merged.isEmpty) value
          else merged + "," + value
      }
    }
    buffer1.update(0, merged)
  }

  /** Final result: the accumulated comma-separated string. */
  override def evaluate(buffer: Row): Any = buffer.getString(0)
}
Register the custom UDAF (and a plain UDF) as temporary functions
def main(args: Array[String]): Unit = {
  /**
   * Step 1: create the program entry points (SparkContext + HiveContext).
   */
  val conf = new SparkConf().setAppName("AralHotProductSpark")
  val sc = new SparkContext(conf)
  val hiveContext = new HiveContext(sc)

  // Register the UDAF as a temporary function.
  // BUG FIX: the original registered the name with leading/trailing spaces
  // (" get_distinct_city "), which would have to be quoted oddly in SQL.
  hiveContext.udf.register("get_distinct_city", GetDistinctCityUDF)

  // Register a plain UDF as a temporary function: extracts the integer value
  // of the "product_status" entry from a "key:value,key:value" string.
  hiveContext.udf.register("get_product_status", (str: String) => {
    var status = 0
    // BUG FIX: split on "," and match the bare key "product_status" — the
    // original's space-padded " , " / " product_status " tokens never match
    // a typical unpadded input string.
    for (s <- str.split(",")) {
      if (s.contains("product_status")) {
        status = s.split(":")(1).toInt
      }
    }
    // BUG FIX: the original lambda ended with the for-loop and therefore
    // returned Unit; return the computed status so the UDF yields an Int.
    status
  })
}