UserDefinedAggregateFunction — User-Defined Untyped Aggregate Functions (UDAFs)¶
UserDefinedAggregateFunction
is the <
// Custom UDAF to count rows
import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types.{DataType, LongType, StructType}
class MyCountUDAF extends UserDefinedAggregateFunction {
override def inputSchema: StructType = {
new StructType().add("id", LongType, nullable = true)
}
override def bufferSchema: StructType = {
new StructType().add("count", LongType, nullable = true)
}
override def dataType: DataType = LongType
override def deterministic: Boolean = true
override def initialize(buffer: MutableAggregationBuffer): Unit = {
println(s">>> initialize (buffer: $buffer)")
// NOTE: Scala's update used under the covers
buffer(0) = 0L
}
override def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
println(s">>> update (buffer: $buffer -> input: $input)")
buffer(0) = buffer.getLong(0) + 1
}
override def merge(buffer: MutableAggregationBuffer, row: Row): Unit = {
println(s">>> merge (buffer: $buffer -> row: $row)")
buffer(0) = buffer.getLong(0) + row.getLong(0)
}
override def evaluate(buffer: Row): Any = {
println(s">>> evaluate (buffer: $buffer)")
buffer.getLong(0)
}
}
UserDefinedAggregateFunction
is created using <
val dataset = spark.range(start = 0, end = 4, step = 1, numPartitions = 2)
// Use the UDAF
val mycount = new MyCountUDAF
val q = dataset.
withColumn("group", 'id % 2).
groupBy('group).
agg(mycount.distinct('id) as "count")
scala> q.show
+-----+-----+
|group|count|
+-----+-----+
| 0| 2|
| 1| 2|
+-----+-----+
The <UserDefinedAggregateFunction
is entirely managed using ScalaUDAF
expression container.
[NOTE]¶
Use user-defined-functions/UDFRegistration.md[UDFRegistration] to register a (temporary) UserDefinedAggregateFunction
and use it in SparkSession.md#sql[SQL mode].
[source, scala]¶
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction val mycount: UserDefinedAggregateFunction = ... spark.udf.register("mycount", mycount)
spark.sql("SELECT mycount(*) FROM range(5)")¶
====
=== [[contract]] UserDefinedAggregateFunction Contract
[source, scala]¶
package org.apache.spark.sql.expressions
abstract class UserDefinedAggregateFunction { // only required methods that have no implementation def bufferSchema: StructType def dataType: DataType def deterministic: Boolean def evaluate(buffer: Row): Any def initialize(buffer: MutableAggregationBuffer): Unit def inputSchema: StructType def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit def update(buffer: MutableAggregationBuffer, input: Row): Unit }
.(Subset of) UserDefinedAggregateFunction Contract [cols="1,2",options="header",width="100%"] |=== | Method | Description
[[bufferSchema]] bufferSchema |
---|
[[dataType]] dataType |
---|
[[deterministic]] deterministic |
---|
[[evaluate]] evaluate |
---|
[[initialize]] initialize |
---|
[[inputSchema]] inputSchema |
---|
[[merge]] merge |
---|
[[update]] update |
---|
=== |
=== [[apply]] Creating Column for UDAF -- apply
Method
apply(
exprs: Column*): Column
apply
creates a Column with ScalaUDAF
(inside AggregateExpression).
Note
AggregateExpression
uses Complete
mode and isDistinct
flag is disabled.
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
val myUDAF: UserDefinedAggregateFunction = ...
val myUdafCol = myUDAF.apply($"id", $"name")
scala> myUdafCol.explain(extended = true)
mycountudaf('id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)
scala> println(myUdafCol.expr.numberedTreeString)
00 mycountudaf('id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)
01 +- MyCountUDAF('id,'name)
02 :- 'id
03 +- 'name
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
myUdafCol.expr.asInstanceOf[AggregateExpression]
import org.apache.spark.sql.execution.aggregate.ScalaUDAF
val scalaUdaf = myUdafCol.expr.children.head.asInstanceOf[ScalaUDAF]
scala> println(scalaUdaf.toString)
MyCountUDAF('id,'name)
=== [[distinct]] Creating Column for UDAF with Distinct Values -- distinct
Method
distinct(
exprs: Column*): Column
distinct
creates a Column with ScalaUDAF
(inside AggregateExpression).
Note
AggregateExpression
uses Complete
mode and isDistinct
flag is enabled.
Note
distinct
is like apply but has isDistinct
flag enabled.
import org.apache.spark.sql.expressions.UserDefinedAggregateFunction
val myUDAF: UserDefinedAggregateFunction = ...
scala> val myUdafCol = myUDAF.distinct($"id", $"name")
myUdafCol: org.apache.spark.sql.Column = mycountudaf(DISTINCT id, name)
scala> myUdafCol.explain(extended = true)
mycountudaf(distinct 'id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)
import org.apache.spark.sql.catalyst.expressions.aggregate.AggregateExpression
val aggExpr = myUdafCol.expr
scala> println(aggExpr.numberedTreeString)
00 mycountudaf(distinct 'id, 'name, $line17.$read$$iw$$iw$MyCountUDAF@4704b66a, 0, 0)
01 +- MyCountUDAF('id,'name)
02 :- 'id
03 +- 'name
scala> aggExpr.asInstanceOf[AggregateExpression].isDistinct
res0: Boolean = true