Aggregation Execution Planning Strategy
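Aggregation is an execution planning strategy that SparkPlanner uses to plan Aggregate logical operators as the following physical aggregate operators (in the order of preference):
- HashAggregateExec
- ObjectHashAggregateExec
- SortAggregateExec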
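In Spark, AggUtils picks the physical aggregate operator. It uses HashAggregateExec when every aggregation buffer field can be updated in place (for example LongType). It uses ObjectHashAggregateExec when `spark.sql.execution.useObjectHashAggregateExec` is enabled and some function is a TypedImperativeAggregate (such as collect_list). Otherwise it uses SortAggregateExec. The following is a minimal sketch of that preference order under a simplified model. The names AggOperator, AggInfo, bufferIsMutable, isTypedImperative and chooseOperator are hypothetical and not Spark API, and testing-only flags that force a particular operator are ignored.

```scala
// Hypothetical, simplified model of how the physical aggregate operator is chosen.
// None of these names are Spark APIs; testing-only flags are ignored.
sealed trait AggOperator
case object HashAggregate extends AggOperator
case object ObjectHashAggregate extends AggOperator
case object SortAggregate extends AggOperator

// The two properties of an aggregate function that drive the choice.
final case class AggInfo(
    name: String,
    bufferIsMutable: Boolean,  // every buffer field can be updated in place (e.g. LongType)
    isTypedImperative: Boolean // buffer is an arbitrary JVM object (TypedImperativeAggregate)
)

def chooseOperator(aggs: Seq[AggInfo], objectHashEnabled: Boolean): AggOperator =
  if (aggs.forall(_.bufferIsMutable)) HashAggregate
  else if (objectHashEnabled && aggs.exists(_.isTypedImperative)) ObjectHashAggregate
  else SortAggregate

val countAgg = AggInfo("count", bufferIsMutable = true, isTypedImperative = false)
val collectListAgg = AggInfo("collect_list", bufferIsMutable = false, isTypedImperative = true)
// Hypothetical function whose buffer is neither in-place updatable nor a JVM object.
val otherAgg = AggInfo("other", bufferIsMutable = false, isTypedImperative = false)

// collect_list rules out hash aggregation, so the object-hash variant is chosen.
val chosen = chooseOperator(Seq(countAgg, collectListAgg), objectHashEnabled = true)
```

A single function that cannot use HashAggregateExec is enough to move the whole Aggregate down the preference list, which is why the model checks `forall` first and `exists` second.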
Executing Rule
apply(plan: LogicalPlan): Seq[SparkPlan]
apply is part of the GenericStrategy abstraction.
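Returning a Seq is what lets the planner combine strategies. Aggregation returns an empty Seq (Nil) for logical operators it does not handle, and Spark's QueryPlanner asks its strategies in order, using the first candidate physical plan produced. The following is a minimal sketch of that contract with toy plan types. Logical, AggregateOp, RangeOp, Physical, Strategy, AggregationStrategy, RangeStrategy and planFirst are hypothetical names. The real planner also plans the PlanLater placeholders (like the one in the Demo) recursively, which this sketch skips.

```scala
// Toy logical and physical plans (hypothetical, for illustration only).
sealed trait Logical
final case class AggregateOp(groupingKeys: Seq[String]) extends Logical
final case class RangeOp(end: Long) extends Logical

final case class Physical(description: String)

// Same shape as GenericStrategy.apply: zero or more candidate physical plans.
trait Strategy {
  def apply(plan: Logical): Seq[Physical]
}

object AggregationStrategy extends Strategy {
  def apply(plan: Logical): Seq[Physical] = plan match {
    case AggregateOp(keys) => Seq(Physical(s"HashAggregate(keys=${keys.mkString("[", ", ", "]")})"))
    case _                 => Nil // not an Aggregate: leave it to another strategy
  }
}

object RangeStrategy extends Strategy {
  def apply(plan: Logical): Seq[Physical] = plan match {
    case RangeOp(end) => Seq(Physical(s"Range(0, $end)"))
    case _            => Nil
  }
}

// Ask the strategies in order (lazily) and take the first candidate plan.
def planFirst(plan: Logical, strategies: Seq[Strategy]): Option[Physical] = {
  val candidates = strategies.iterator.flatMap(strategy => strategy(plan))
  if (candidates.hasNext) Some(candidates.next()) else None
}

val strategies: Seq[Strategy] = Seq(AggregationStrategy, RangeStrategy)
```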
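apply handles Aggregate logical operators whose aggregate expressions are either all AggregateExpressions or all PythonUDFs. For an Aggregate logical operator that mixes the two, apply throws an AnalysisException. For any other logical operator, apply returns no physical plans (Nil).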
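The following is a minimal sketch of that check. It uses hypothetical stand-ins: RegularAgg for AggregateExpression, PandasUdf for PythonUDF and AnalysisError for Spark's AnalysisException. AggExpr, AggPath and choosePath are hypothetical names as well. The error message is the one Spark reports for a mixture.

```scala
// Hypothetical stand-ins for Catalyst expression types (not Spark API).
sealed trait AggExpr
final case class RegularAgg(name: String) extends AggExpr // plays the role of AggregateExpression
final case class PandasUdf(name: String) extends AggExpr  // plays the role of PythonUDF

// Stand-in for Spark's AnalysisException.
final class AnalysisError(message: String) extends Exception(message)

sealed trait AggPath
case object PlanAggregateExpressions extends AggPath
case object PlanPythonUDFs extends AggPath

// Both kinds of aggregate expressions can be planned, but never together.
def choosePath(aggExprs: Seq[AggExpr]): AggPath =
  if (aggExprs.forall(_.isInstanceOf[RegularAgg])) PlanAggregateExpressions
  else if (aggExprs.forall(_.isInstanceOf[PandasUdf])) PlanPythonUDFs
  else throw new AnalysisError(
    "Cannot use a mixture of aggregate function and group aggregate pandas UDF")
```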
apply destructures the Aggregate logical operator into a four-element tuple with the following:
- Grouping Expressions
- Aggregation Expressions
- Result Expressions
- Child Logical Operator
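In Spark, this destructuring is done by the PhysicalAggregation extractor object used in the pattern match of apply. The toy extractor below works over hypothetical Expr, Col, AggCall, Alias and AggregateNode types, and ToyPhysicalAggregation is a hypothetical name too. It returns the same four parts and collects each distinct aggregate call from the output columns. The child plan is represented only by its name. The real extractor also rewrites the result expressions to reference the aggregate results, which this sketch skips.

```scala
// Hypothetical toy expression tree (not Catalyst).
sealed trait Expr
final case class Col(name: String) extends Expr
final case class AggCall(fn: String, arg: Expr) extends Expr
final case class Alias(child: Expr, name: String) extends Expr

// Toy Aggregate node; the child plan is represented by its name only.
final case class AggregateNode(groupBy: Seq[Expr], output: Seq[Alias], child: String)

// Extractor returning the four-element tuple:
// (grouping expressions, aggregate expressions, result expressions, child).
object ToyPhysicalAggregation {
  def unapply(node: AggregateNode): Option[(Seq[Expr], Seq[AggCall], Seq[Alias], String)] = {
    // Find the aggregate calls used in an output expression.
    def collectAggs(e: Expr): Seq[AggCall] = e match {
      case a: AggCall  => Seq(a)
      case Alias(c, _) => collectAggs(c)
      case Col(_)      => Nil
    }
    // The same aggregate call used twice is computed once.
    val aggs = node.output.flatMap(alias => collectAggs(alias)).distinct
    Some((node.groupBy, aggs, node.output, node.child))
  }
}

val node = AggregateNode(
  groupBy = Seq(Col("group")),
  output = Seq(
    Alias(Col("group"), "group"),
    Alias(AggCall("count", Col("id")), "count"),
    Alias(AggCall("count", Col("id")), "count_again")),
  child = "Range (0, 5)")
```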
AggregateExpressions
For Aggregate logical operators with AggregateExpressions, apply splits the aggregate expressions based on their isDistinct flag.
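Splitting on the isDistinct flag partitions the aggregate expressions into those with DISTINCT and those without. The following is a minimal sketch that uses a hypothetical AggFn case class and splitByDistinct helper in place of Catalyst's AggregateExpression.

```scala
// Hypothetical model of an aggregate expression (not Catalyst's AggregateExpression).
final case class AggFn(name: String, columns: Seq[String], isDistinct: Boolean)

// Split into (with DISTINCT, without DISTINCT), keeping the original order.
def splitByDistinct(aggs: Seq[AggFn]): (Seq[AggFn], Seq[AggFn]) =
  aggs.partition(_.isDistinct)

// COUNT(DISTINCT foo), SUM(bar), MAX(DISTINCT foo)
val aggs = Seq(
  AggFn("count", Seq("foo"), isDistinct = true),
  AggFn("sum", Seq("bar"), isDistinct = false),
  AggFn("max", Seq("foo"), isDistinct = true))

val (withDistinct, withoutDistinct) = splitByDistinct(aggs)
```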
If there are no distinct aggregate functions, apply plans the aggregation with AggUtils.planAggregateWithoutDistinct. Otherwise, apply uses AggUtils.planAggregateWithOneDistinct.
In the end, apply returns the physical aggregate operators created by the chosen AggUtils method.
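The two AggUtils planning paths differ in the number of aggregation stages. planAggregateWithoutDistinct creates a partial and a final aggregate operator, which are the two HashAggregate operators in the Demo. planAggregateWithOneDistinct creates four, and the first two also group by the distinct columns. The following is a simplified model of that difference. AggFn, Stage, planWithoutDistinct, planWithOneDistinct and planAggregate are hypothetical names, not Spark's API. The stage descriptions are informal, and the exchange operators that EnsureRequirements adds later are omitted.

```scala
// Hypothetical model of the two AggUtils planning paths (not Spark code).
final case class AggFn(name: String, columns: Seq[String], isDistinct: Boolean)
final case class Stage(description: String, groupingKeys: Seq[String])

// No DISTINCT: a partial and a final aggregate operator.
def planWithoutDistinct(keys: Seq[String]): Seq[Stage] = Seq(
  Stage("partial aggregate", keys),
  Stage("final aggregate", keys))

// One set of DISTINCT columns: four aggregate operators. The first two also group
// by the distinct columns, so duplicates collapse before the distinct functions run.
def planWithOneDistinct(keys: Seq[String], distinctColumns: Seq[String]): Seq[Stage] = Seq(
  Stage("partial aggregate (non-distinct functions)", keys ++ distinctColumns),
  Stage("partial merge (non-distinct functions)", keys ++ distinctColumns),
  Stage("partial aggregate (distinct functions)", keys),
  Stage("final aggregate", keys))

// Stages are listed bottom-up, starting with the operator closest to the child.
def planAggregate(keys: Seq[String], aggs: Seq[AggFn]): Seq[Stage] = {
  val (withDistinct, _) = aggs.partition(_.isDistinct)
  if (withDistinct.isEmpty) planWithoutDistinct(keys)
  else planWithOneDistinct(keys, withDistinct.head.columns) // distinct columns assumed shared
}

val stages = planAggregate(Seq("g"), Seq(AggFn("count", Seq("foo"), isDistinct = true)))
stages.foreach(stage => println(s"${stage.description} keys=${stage.groupingKeys.mkString(",")}"))
```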
Note
It is assumed that all the distinct aggregate functions have the same column expressions. For example, the following is valid:
COUNT(DISTINCT foo), MAX(DISTINCT foo)
The following is not valid due to different column expressions:
COUNT(DISTINCT bar), COUNT(DISTINCT foo)
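The assumption amounts to a check that all DISTINCT functions share one set of column expressions. In Spark, the RewriteDistinctAggregates optimizer rule rewrites aggregations with several distinct column sets (using an Expand operator) before this strategy runs, which is why the assumption normally holds. The following sketch applies such a check to the two examples from the note. AggFn and distinctColumnsAgree are hypothetical names, not Spark API.

```scala
// Hypothetical model of an aggregate expression (not Spark code).
final case class AggFn(name: String, columns: Seq[String], isDistinct: Boolean)

// True when every DISTINCT function uses the same set of column expressions.
def distinctColumnsAgree(aggs: Seq[AggFn]): Boolean =
  aggs.filter(_.isDistinct).map(_.columns.toSet).distinct.size <= 1

// COUNT(DISTINCT foo), MAX(DISTINCT foo)
val valid = Seq(
  AggFn("COUNT", Seq("foo"), isDistinct = true),
  AggFn("MAX", Seq("foo"), isDistinct = true))

// COUNT(DISTINCT bar), COUNT(DISTINCT foo)
val invalid = Seq(
  AggFn("COUNT", Seq("bar"), isDistinct = true),
  AggFn("COUNT", Seq("foo"), isDistinct = true))
```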
PythonUDFs
For Aggregate logical operators with PythonUDFs (PySpark)...FIXME
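AnalysisException
apply throws an AnalysisException when the aggregate expressions of an Aggregate logical operator are a mixture of AggregateExpressions and PythonUDFs:
Cannot use a mixture of aggregate function and group aggregate pandas UDF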
Demo
scala> :type spark
org.apache.spark.sql.SparkSession
// structured query with count aggregate function
val q = spark
.range(5)
.groupBy($"id" % 2 as "group")
.agg(count("id") as "count")
val plan = q.queryExecution.optimizedPlan
scala> println(plan.numberedTreeString)
00 Aggregate [(id#0L % 2)], [(id#0L % 2) AS group#3L, count(1) AS count#8L]
01 +- Range (0, 5, step=1, splits=Some(8))
import spark.sessionState.planner.Aggregation
val physicalPlan = Aggregation.apply(plan)
// HashAggregateExec selected
scala> println(physicalPlan.head.numberedTreeString)
00 HashAggregate(keys=[(id#0L % 2)#12L], functions=[count(1)], output=[group#3L, count#8L])
01 +- HashAggregate(keys=[(id#0L % 2) AS (id#0L % 2)#12L], functions=[partial_count(1)], output=[(id#0L % 2)#12L, count#14L])
02    +- PlanLater Range (0, 5, step=1, splits=Some(8))