CostBasedJoinReorder Logical Optimization -- Join Reordering in Cost-Based Optimization
CostBasedJoinReorder
is a base logical optimization that reorders joins in Cost-Based Optimization.
CostBasedJoinReorder
is part of the Join Reorder once-executed batch in the standard batches of the Logical Optimizer.
CostBasedJoinReorder
is simply a Catalyst rule for transforming LogicalPlans, i.e. Rule[LogicalPlan]
.
CostBasedJoinReorder
applies the join optimizations on a logical plan with 2 or more consecutive inner or cross joins (possibly separated by Project
operators) when spark.sql.cbo.enabled and spark.sql.cbo.joinReorder.enabled configuration properties are both enabled.
// Use shortcuts to read the values of the properties
scala> spark.sessionState.conf.cboEnabled
res0: Boolean = true
scala> spark.sessionState.conf.joinReorderEnabled
res1: Boolean = true
CostBasedJoinReorder
uses the row count statistic, which is computed using the ANALYZE TABLE COMPUTE STATISTICS SQL command without the NOSCAN
option.
// Set up three tables and compute their row count statistics.
// Three relations are used so that the example query contains at least 2 joins.
// Leftover tables from earlier runs are dropped first so the example is reproducible.
import org.apache.spark.sql.catalyst.TableIdentifier
val tableNames = Seq("t1", "t2", "tiny")
val tableIds = tableNames.map(name => TableIdentifier(name))
val sessionCatalog = spark.sessionState.catalog
for (tableId <- tableIds) {
  sessionCatalog.dropTable(tableId, ignoreIfNotExists = true, purge = true)
}
// Row count of t1: one less than the configured auto-broadcast join threshold value
val belowBroadcastJoinThreshold = spark.sessionState.conf.autoBroadcastJoinThreshold - 1
spark.range(belowBroadcastJoinThreshold).write.saveAsTable("t1")
// t2 gets twice as many rows as t1
spark.range(2 * belowBroadcastJoinThreshold).write.saveAsTable("t2")
// tiny has just 5 rows
spark.range(5).write.saveAsTable("tiny")
// Compute row count statistics for every table
for (t <- tableNames) {
  sql(s"ANALYZE TABLE $t COMPUTE STATISTICS")
}
// Load the tables
val t1 = spark.table("t1")
val t2 = spark.table("t2")
val tiny = spark.table("tiny")
// Example: Inner join with join condition
// t1 is joined with t2 and the result is joined with tiny, both times on the "id" column,
// which gives the two consecutive inner joins that CostBasedJoinReorder requires
val q = t1.join(t2, Seq("id")).join(tiny, Seq("id"))
// Take the analyzed (not yet optimized) logical plan of the query
val plan = q.queryExecution.analyzed
scala> println(plan.numberedTreeString)
00 Project [id#51L]
01 +- Join Inner, (id#51L = id#57L)
02 :- Project [id#51L]
03 : +- Join Inner, (id#51L = id#54L)
04 : :- SubqueryAlias t1
05 : : +- Relation[id#51L] parquet
06 : +- SubqueryAlias t2
07 : +- Relation[id#54L] parquet
08 +- SubqueryAlias tiny
09 +- Relation[id#57L] parquet
// Eliminate SubqueryAlias logical operators as they are no longer needed
// and would only "confuse" CostBasedJoinReorder
// CostBasedJoinReorder cares about how deep Joins are and reorders consecutive joins only
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
val noAliasesPlan = EliminateSubqueryAliases(plan)
scala> println(noAliasesPlan.numberedTreeString)
00 Project [id#51L]
01 +- Join Inner, (id#51L = id#57L)
02 :- Project [id#51L]
03 : +- Join Inner, (id#51L = id#54L)
04 : :- Relation[id#51L] parquet
05 : +- Relation[id#54L] parquet
06 +- Relation[id#57L] parquet
// Let's go pro and create a custom RuleExecutor (i.e. an Optimizer)
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder
// A minimal custom optimizer with two batches, each executed once:
// the first removes SubqueryAlias operators, the second applies CostBasedJoinReorder.
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches = List(
    Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases),
    Batch("Join Reorder", Once, CostBasedJoinReorder))
}
val joinsReordered = Optimize.execute(plan)
scala> println(joinsReordered.numberedTreeString)
00 Project [id#51L]
01 +- Join Inner, (id#51L = id#54L)
02 :- Project [id#51L]
03 : +- Join Inner, (id#51L = id#57L)
04 : :- Relation[id#51L] parquet
05 : +- Relation[id#57L] parquet
06 +- Relation[id#54L] parquet
// Execute the query twice -- once with CBO disabled and once with CBO enabled --
// and compare the plans as diagrams in web UI @ http://localhost:4040/SQL
// Reordering the joins by hand would require too many internals,
// so we simply turn CBO off and on instead
// Moreover, please remember that the query "phases" are cached
// That's why the entire query is copied and pasted (rebuilt) before each execution
import org.apache.spark.sql.internal.SQLConf
val cc = SQLConf.get
// Run 1: cost-based optimization disabled
cc.setConf(SQLConf.CBO_ENABLED, false)
val q = t1.join(t2, Seq("id")).join(tiny, Seq("id"))
// collect triggers the execution; the collected rows are simply discarded
q.collect.foreach(_ => ())
// Run 2: cost-based optimization enabled, with the same query built afresh
cc.setConf(SQLConf.CBO_ENABLED, true)
val q = t1.join(t2, Seq("id")).join(tiny, Seq("id"))
q.collect.foreach(_ => ())