CostBasedJoinReorder Logical Optimization -- Join Reordering in Cost-Based Optimization

CostBasedJoinReorder is a base logical optimization that reorders joins in Cost-Based Optimization.

CostBasedJoinReorder is part of the Join Reorder once-executed batch in the standard batches of the Logical Optimizer.

CostBasedJoinReorder is simply a Catalyst rule for transforming LogicalPlans, i.e. Rule[LogicalPlan].

CostBasedJoinReorder applies the join optimizations on a logical plan with 2 or more consecutive inner or cross joins (possibly separated by Project operators) when spark.sql.cbo.enabled and spark.sql.cbo.joinReorder.enabled configuration properties are both enabled.

// Use shortcuts to read the values of the properties
scala> spark.sessionState.conf.cboEnabled
res0: Boolean = true

scala> spark.sessionState.conf.joinReorderEnabled
res1: Boolean = true
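
If either property reads false, it can be switched on for the current session. The following is a minimal sketch that assumes an active SparkSession named spark (as in spark-shell) and sets both properties through the runtime configuration before re-reading them with the shortcuts above.

```scala
// Assumption: `spark` is the active SparkSession (e.g. in spark-shell)
spark.conf.set("spark.sql.cbo.enabled", true)
spark.conf.set("spark.sql.cbo.joinReorder.enabled", true)

// Both shortcuts should now report that the properties are enabled
assert(spark.sessionState.conf.cboEnabled)
assert(spark.sessionState.conf.joinReorderEnabled)
```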

CostBasedJoinReorder uses the row count statistic, which is computed by the ANALYZE TABLE COMPUTE STATISTICS SQL command when it is run without the NOSCAN option.

// Create tables and compute their row count statistics
// There have to be at least 2 joins
// Make the example reproducible
val tableNames = Seq("t1", "t2", "tiny")
import org.apache.spark.sql.catalyst.TableIdentifier
val tableIds = tableNames.map(TableIdentifier.apply)
val sessionCatalog = spark.sessionState.catalog
tableIds.foreach { tableId =>
  sessionCatalog.dropTable(tableId, ignoreIfNotExists = true, purge = true)
}

val belowBroadcastJoinThreshold = spark.sessionState.conf.autoBroadcastJoinThreshold - 1
spark.range(belowBroadcastJoinThreshold).write.saveAsTable("t1")
// t2 is twice as big as t1
spark.range(2 * belowBroadcastJoinThreshold).write.saveAsTable("t2")
spark.range(5).write.saveAsTable("tiny")

// Compute row count statistics
tableNames.foreach { t =>
  sql(s"ANALYZE TABLE $t COMPUTE STATISTICS")
}
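
Before moving on, it may be worth confirming that the row counts were actually recorded, since CostBasedJoinReorder relies on them. The sketch below assumes the three tables created above exist in the session catalog and reads the row count from each table's catalog metadata.

```scala
import org.apache.spark.sql.catalyst.TableIdentifier

// Assumption: t1, t2 and tiny were created and analyzed as shown above
val catalog = spark.sessionState.catalog
Seq("t1", "t2", "tiny").foreach { name =>
  // stats is an Option; rowCount inside it is an Option too
  val rowCount = catalog.getTableMetadata(TableIdentifier(name)).stats.flatMap(_.rowCount)
  println(s"$name rowCount = $rowCount")
}
```

A table whose row count prints as None was most likely analyzed with NOSCAN or not analyzed at all.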

// Load the tables
val t1 = spark.table("t1")
val t2 = spark.table("t2")
val tiny = spark.table("tiny")

// Example: Inner join with join condition
val q = t1.join(t2, Seq("id")).join(tiny, Seq("id"))
val plan = q.queryExecution.analyzed
scala> println(plan.numberedTreeString)
00 Project [id#51L]
01 +- Join Inner, (id#51L = id#57L)
02    :- Project [id#51L]
03    :  +- Join Inner, (id#51L = id#54L)
04    :     :- SubqueryAlias t1
05    :     :  +- Relation[id#51L] parquet
06    :     +- SubqueryAlias t2
07    :        +- Relation[id#54L] parquet
08    +- SubqueryAlias tiny
09       +- Relation[id#57L] parquet

// Eliminate SubqueryAlias logical operators as they are no longer needed
// and would only "confuse" CostBasedJoinReorder
// CostBasedJoinReorder cares about how deep the Joins are and reorders consecutive joins only
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
val noAliasesPlan = EliminateSubqueryAliases(plan)
scala> println(noAliasesPlan.numberedTreeString)
00 Project [id#51L]
01 +- Join Inner, (id#51L = id#57L)
02    :- Project [id#51L]
03    :  +- Join Inner, (id#51L = id#54L)
04    :     :- Relation[id#51L] parquet
05    :     +- Relation[id#54L] parquet
06    +- Relation[id#57L] parquet

// Let's go pro and create a custom RuleExecutor (i.e. an Optimizer)
import org.apache.spark.sql.catalyst.rules.RuleExecutor
import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
import org.apache.spark.sql.catalyst.analysis.EliminateSubqueryAliases
import org.apache.spark.sql.catalyst.optimizer.CostBasedJoinReorder
object Optimize extends RuleExecutor[LogicalPlan] {
  val batches =
    Batch("EliminateSubqueryAliases", Once, EliminateSubqueryAliases) ::
    Batch("Join Reorder", Once, CostBasedJoinReorder) :: Nil
}

val joinsReordered = Optimize.execute(plan)
scala> println(joinsReordered.numberedTreeString)
00 Project [id#51L]
01 +- Join Inner, (id#51L = id#54L)
02    :- Project [id#51L]
03    :  +- Join Inner, (id#51L = id#57L)
04    :     :- Relation[id#51L] parquet
05    :     +- Relation[id#57L] parquet
06    +- Relation[id#54L] parquet
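
The effect of the rule can also be checked programmatically rather than by reading the two trees side by side. The following sketch assumes noAliasesPlan and joinsReordered from the steps above are still in scope; joinConditions is a hypothetical helper (not part of Spark) that lists the join conditions of a plan from top to bottom.

```scala
import org.apache.spark.sql.catalyst.plans.logical.{Join, LogicalPlan}

// Hypothetical helper: the SQL text of every Join condition, top-down
def joinConditions(plan: LogicalPlan): Seq[String] =
  plan.collect { case j: Join => j.condition.map(_.sql) }.flatten

println(joinConditions(noAliasesPlan).mkString("before: ", ", ", ""))
println(joinConditions(joinsReordered).mkString("after:  ", ", ", ""))
```

If the joins were reordered, the two lines list the same conditions in a different order.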

// Execute the plans
// Compare the plans as diagrams in web UI @ http://localhost:4040/SQL
// We'd have to use too many internals so let's turn CBO on and off
// Moreover, please remember that the query "phases" are cached
// That's why we copy and paste the entire query for execution
import org.apache.spark.sql.internal.SQLConf
val cc = SQLConf.get

// Run the query with CBO disabled
cc.setConf(SQLConf.CBO_ENABLED, false)
val qCboOff = t1.join(t2, Seq("id")).join(tiny, Seq("id"))
qCboOff.collect.foreach(_ => ())

// Run the same query (built from scratch) with CBO enabled
cc.setConf(SQLConf.CBO_ENABLED, true)
val qCboOn = t1.join(t2, Seq("id")).join(tiny, Seq("id"))
qCboOn.collect.foreach(_ => ())
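
As an alternative to comparing diagrams in the web UI, the optimized logical plans can be printed directly. This is a sketch under the assumption that the tables t1, t2 and tiny from above exist; optimizedPlanWith is a hypothetical helper that builds the query from scratch each time so that no cached query "phases" are reused.

```scala
// Hypothetical helper: toggle CBO, rebuild the query, return its optimized plan as text
def optimizedPlanWith(cboEnabled: Boolean): String = {
  spark.conf.set("spark.sql.cbo.enabled", cboEnabled)
  val query = spark.table("t1")
    .join(spark.table("t2"), Seq("id"))
    .join(spark.table("tiny"), Seq("id"))
  query.queryExecution.optimizedPlan.numberedTreeString
}

println(optimizedPlanWith(cboEnabled = false))
println(optimizedPlanWith(cboEnabled = true))
```

Note that the full optimizer applies many other rules as well, so these plans will not match the output of the two-batch Optimize executor line for line.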