TaskContext

TaskContext is an abstraction of task contexts that give running tasks access to information about themselves (e.g. partition ID, stage ID, attempt numbers) and allow registering task completion and failure listeners.

Contract (Subset)

addTaskCompletionListener

addTaskCompletionListener[U](
  f: (TaskContext) => U): TaskContext
addTaskCompletionListener(
  listener: TaskCompletionListener): TaskContext

Registers a TaskCompletionListener to be executed when the task completes (successfully or with a failure)

val rdd = sc.range(0, 5, numSlices = 1)

import org.apache.spark.TaskContext
val printTaskInfo = (tc: TaskContext) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}

// Register the listener inside the task (on an executor);
// it is executed when the task completes
rdd.foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskCompletionListener(printTaskInfo)
}

addTaskFailureListener

addTaskFailureListener(
  f: (TaskContext, Throwable) => Unit): TaskContext
addTaskFailureListener(
  listener: TaskFailureListener): TaskContext

Registers a TaskFailureListener to be executed when the task fails

val rdd = sc.range(0, 2, numSlices = 2)

import org.apache.spark.TaskContext
val printTaskErrorInfo = (tc: TaskContext, error: Throwable) => {
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |error:         ${error.toString}
                |-------------------""".stripMargin
  println(msg)
}

val throwExceptionForOddNumber = (n: Long) => {
  if (n % 2 == 1) {
    throw new Exception(s"No way it will pass for odd number: $n")
  }
}

// FIXME This won't trigger the listener: foreachPartition ignores the
// iterator, so the lazy map closure never runs and the task never fails.
rdd.map(throwExceptionForOddNumber).foreachPartition { _ =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
}

// Where the listener is registered matters: it has to be registered before
// the transformation that can fail (and the iterator has to be consumed,
// which count does).
rdd.mapPartitions { (it: Iterator[Long]) =>
  val tc = TaskContext.get
  tc.addTaskFailureListener(printTaskErrorInfo)
  it
}.map(throwExceptionForOddNumber).count

fetchFailed

fetchFailed: Option[FetchFailedException]

The FetchFailedException the task has failed with (if any), recorded so that a shuffle fetch failure is detected even when user code swallows the exception

Used when:

  • TaskRunner is requested to run

getKillReason

getKillReason(): Option[String]

The reason the task was killed, if it was interrupted (undefined otherwise)

getLocalProperty

getLocalProperty(
  key: String): String

Looks up a local property (set on the driver with SparkContext.setLocalProperty) by key; returns null if the key is not set
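
A quick demo (for spark-shell with the default sc): local properties are set on the calling thread on the driver and propagated to the tasks submitted from that thread. The key my.custom.property is made up for this example.

import org.apache.spark.TaskContext

// Set a local property on the driver (calling thread)
sc.setLocalProperty("my.custom.property", "demo-value")

// Read it back inside a task
sc.range(0, 2, numSlices = 2).foreachPartition { _ =>
  val value = TaskContext.get.getLocalProperty("my.custom.property")
  println(s"my.custom.property = $value")
}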

getMetricsSources

getMetricsSources(
  sourceName: String): Seq[Source]

Looks up Sources by name
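
A hedged sketch: assuming the JVM metrics source is enabled for executors (e.g. executor.source.jvm.class=org.apache.spark.metrics.source.JvmSource in the metrics configuration), a task can look the source up by its name (jvm) and list the registered metric names. Without such a source the result is simply empty.

import org.apache.spark.TaskContext

sc.range(0, 1, numSlices = 1).foreachPartition { _ =>
  // Empty when no source with the given name is registered on this executor
  TaskContext.get.getMetricsSources("jvm").foreach { source =>
    println(s"${source.sourceName}: ${source.metricRegistry.getMetrics.keySet}")
  }
}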

Registering Accumulator

registerAccumulator(
  a: AccumulatorV2[_, _]): Unit

Registers an AccumulatorV2

Used when:

  • AccumulatorV2 is requested to deserialize itself (on executors)
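
registerAccumulator is not called by user code directly; it runs when an accumulator referenced in a task closure is deserialized on an executor. A minimal demo of that code path (for spark-shell):

// Created on the driver
val acc = sc.longAccumulator("odd-numbers")

// The task closure references acc, so deserializing the closure on
// executors registers the accumulator with the active TaskContext
sc.range(0, 5).foreach { n =>
  if (n % 2 == 1) acc.add(1)
}

assert(acc.value == 2)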

Resources

resources(): Map[String, ResourceInformation]

Resources (names) allocated to this task

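A hedged sketch: with executor resources configured (e.g. spark.executor.resource.gpu.amount, spark.task.resource.gpu.amount and a discovery script), a task can inspect the resource addresses assigned to it. With no custom resources configured the map is empty.

import org.apache.spark.TaskContext

sc.range(0, 1, numSlices = 1).foreachPartition { _ =>
  TaskContext.get.resources().foreach { case (name, info) =>
    println(s"$name: ${info.addresses.mkString(", ")}")
  }
}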

taskMetrics

taskMetrics(): TaskMetrics

TaskMetrics
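
A hedged sketch of peeking at a task's TaskMetrics from a completion listener (which metrics are already populated at that point depends on when they are updated during the task's lifecycle):

import org.apache.spark.TaskContext

val printMetrics = (tc: TaskContext) => {
  val m = tc.taskMetrics()
  println(s"peakExecutionMemory: ${m.peakExecutionMemory}")
}

sc.range(0, 2, numSlices = 2).foreachPartition { _ =>
  TaskContext.get.addTaskCompletionListener(printMetrics)
}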

others

Important

There are other methods, but they do not seem very interesting.

Implementations

  • BarrierTaskContext
  • TaskContextImpl

Serializable

TaskContext is Serializable (java.io.Serializable).

Accessing TaskContext

get(): TaskContext

get returns the thread-local TaskContext of the task being run by the current thread (null when no task is running, e.g. on the driver).

import org.apache.spark.TaskContext
// On the driver no task is running, so get returns null
val tc = TaskContext.get
assert(tc == null)
val rdd = sc.range(0, 3, numSlices = 3)

assert(rdd.partitions.size == 3)

rdd.foreach { n =>
  // Inside a task, get returns that task's TaskContext
  val tc = TaskContext.get
  val msg = s"""|-------------------
                |partitionId:   ${tc.partitionId}
                |stageId:       ${tc.stageId}
                |attemptNum:    ${tc.attemptNumber}
                |taskAttemptId: ${tc.taskAttemptId}
                |-------------------""".stripMargin
  println(msg)
}