
TorchDistributor

TorchDistributor is a Distributor to run PyTorch's torch.distributed.run module on Apache Spark clusters.

TorchDistributor is a PySpark translation of torchrun (from Torch Distributed Elastic).

Demo

from pyspark.ml.torch.distributor import TorchDistributor

distributor = TorchDistributor(
    num_processes=1,
    local_mode=False,
    use_gpu=False)
# Use a path to a training script
# and variable-length positional args
distributor.run(
    "train.py",
    "--learning-rate=1e-3",
    "--batch-size=64",
    "--my-key=my-value")

# Started local training with 1 processes
# NOTE: Redirects are currently not supported in Windows or MacOs.
# Finished local training with 1 processes
# Use a Callable (function)
# The number of the function's positional parameters matches the number of args
def train(a, b, c):
    print(f"Got a={a}, b={b}, c={c}")
    return 'success'

distributor.run(
    train,
    "--learning-rate=1e-3",
    "--batch-size=64",
    "--my-key=my-value")

# Started distributed training with 1 executor processes
# NOTE: Redirects are currently not supported in Windows or MacOs.
# NOTE: Redirects are currently not supported in Windows or MacOs.
# Got a=--learning-rate=1e-3, b=--batch-size=64, c=--my-key=my-value
# Got a=--learning-rate=1e-3, b=--batch-size=64, c=--my-key=my-value
# Finished distributed training with 1 executor processes
# 'success'

Running Distributed Training

run(
    self,
    train_object: Union[Callable, str],
    *args: Any) -> Optional[Any]

run determines what to run (a function or a script) based on the given train_object.

In the end, run runs local or distributed training based on local_mode (see the sketch below).
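The following is a minimal sketch of that dispatch (not the exact Spark source; it assumes the local_mode attribute and the two framework wrappers described later on this page):

def run(self, train_object, *args):
    # Pick the framework wrapper based on what is to be trained:
    # a script path (str) or a function (Callable)
    if isinstance(train_object, str):
        framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file
    else:
        framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_function
    # Dispatch to local or distributed training based on local_mode
    if self.local_mode:
        return self._run_local_training(framework_wrapper_fn, train_object, *args)
    return self._run_distributed_training(framework_wrapper_fn, train_object, *args)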

Local Training

_run_local_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    *args: Any,
) -> Optional[Any]

_run_local_training looks up CUDA_VISIBLE_DEVICES among the environment variables.

With use_gpu, _run_local_training...FIXME

_run_local_training prints out the following INFO message to the logs:

Started local training with [num_processes] processes

_run_local_training executes the given framework_wrapper_fn function (with the input_params, the given train_object and the args).

In the end, _run_local_training prints out the following INFO message to the logs:

Finished local training with [num_processes] processes
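The overall flow can be summarized with the following sketch (a simplification; the exact CUDA_VISIBLE_DEVICES and use_gpu handling is left as a FIXME, and names like self.logger and self.input_params are assumptions):

import os

def _run_local_training(self, framework_wrapper_fn, train_object, *args):
    # Remember the current CUDA device visibility to restore it afterwards
    old_cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")
    output = None
    try:
        if self.use_gpu:
            ...  # FIXME restrict CUDA_VISIBLE_DEVICES to the GPUs of this job
        self.logger.info(f"Started local training with {self.num_processes} processes")
        output = framework_wrapper_fn(self.input_params, train_object, *args)
        self.logger.info(f"Finished local training with {self.num_processes} processes")
    finally:
        # Restore CUDA_VISIBLE_DEVICES to its original value
        if old_cuda_visible_devices is None:
            os.environ.pop("CUDA_VISIBLE_DEVICES", None)
        else:
            os.environ["CUDA_VISIBLE_DEVICES"] = old_cuda_visible_devices
    return output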

Distributed Training

_run_distributed_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    *args: Any,
) -> Optional[Any]

_run_distributed_training...FIXME
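Conceptually, _run_distributed_training runs the given framework_wrapper_fn on Spark executors using barrier execution mode and collects the result back to the driver. A rough, assumed sketch only (the actual task function, GPU scheduling and log streaming are more involved; _get_spark_task_function and num_tasks are assumed names):

def _run_distributed_training(self, framework_wrapper_fn, train_object, *args):
    # Build the function to run inside every barrier task (assumed helper)
    spark_task_fn = self._get_spark_task_function(framework_wrapper_fn, train_object, *args)
    self.logger.info(f"Started distributed training with {self.num_processes} executor processes")
    # One barrier-mode Spark task per training process; collect the first result
    result = (
        self.spark.sparkContext
        .parallelize(range(self.num_tasks), self.num_tasks)
        .barrier()
        .mapPartitions(spark_task_fn)
        .collect()[0]
    )
    self.logger.info(f"Finished distributed training with {self.num_processes} executor processes")
    return result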

_run_training_on_pytorch_function

_run_training_on_pytorch_function(
    input_params: Dict[str, Any],
    train_fn: Callable,
    *args: Any
) -> Any

_run_training_on_pytorch_function prepares train and output files.

_run_training_on_pytorch_function...FIXME
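A sketch of how the pieces could fit together (it relies on the _setup_files context manager and _run_training_on_pytorch_file described below; reading the result back from output.pickle is an assumption):

import os
import cloudpickle

def _run_training_on_pytorch_function(input_params, train_fn, *args):
    # Write the pickled train_fn plus a generated train.py, then torchrun it
    with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
        TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path)
        if not os.path.exists(output_file_path):
            raise RuntimeError("Training failed: no output file was produced")
        # Return whatever the training function returned (pickled by train.py)
        with open(output_file_path, "rb") as f:
            output = cloudpickle.load(f)
    return output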

Setting Up Files

# @contextmanager
_setup_files(
    train_fn: Callable,
    *args: Any
) -> Generator[Tuple[str, str], None, None]

_setup_files gives the paths of a TorchRun train file and output.pickle output file.


_setup_files creates a save directory.

_setup_files saves train_fn function to the save directory (that gives a pickle_file_path).

_setup_files uses the save directory and output.pickle name for the output file path.

_setup_files creates a torchrun_train_file (using _create_torchrun_train_file).

In the end, _setup_files yields (gives) the torchrun_train_file and the output.pickle output file path.
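A sketch of the context manager (the temporary-directory handling and the pickle file name are assumptions):

import os
import shutil
import tempfile
from contextlib import contextmanager

import cloudpickle

@contextmanager
def _setup_files(train_fn, *args):
    # Create a save directory for the pickled function, train.py and output.pickle
    save_dir_path = tempfile.mkdtemp()
    try:
        # Save train_fn (together with its args) to the save directory
        pickle_file_path = os.path.join(save_dir_path, "train_fn.pickle")
        with open(pickle_file_path, "wb") as f:
            cloudpickle.dump((train_fn, args), f)
        # The generated train.py writes its result to output.pickle
        output_file_path = os.path.join(save_dir_path, "output.pickle")
        torchrun_train_file = TorchDistributor._create_torchrun_train_file(
            save_dir_path, pickle_file_path, output_file_path)
        yield torchrun_train_file, output_file_path
    finally:
        shutil.rmtree(save_dir_path, ignore_errors=True)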

Creating TorchRun Train File

_create_torchrun_train_file(
    save_dir_path: str,
    pickle_file_path: str,
    output_file_path: str
) -> str

_create_torchrun_train_file creates train.py in the given save_dir_path with the following content (based on the given pickle_file_path and the output_file_path):

import cloudpickle
import os

if __name__ == "__main__":
    with open("[pickle_file_path]", "rb") as f:
        train_fn, args = cloudpickle.load(f)
    output = train_fn(*args)
    with open("[output_file_path]", "wb") as f:
        cloudpickle.dump(output, f)

_run_training_on_pytorch_file

_run_training_on_pytorch_file(
    input_params: Dict[str, Any],
    train_path: str,
    *args: Any
) -> None

_run_training_on_pytorch_file looks up the log_streaming_client in the given input_params (or assumes None).

FIXME What's log_streaming_client?

_run_training_on_pytorch_file creates a torchrun command (using _create_torchrun_command).

_run_training_on_pytorch_file executes the command.
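A sketch of that flow (execution via subprocess is an assumption; the real implementation also redirects and streams the child process output):

import subprocess

def _run_training_on_pytorch_file(input_params, train_path, *args):
    # Log streaming is optional; default to None when not configured
    log_streaming_client = input_params.get("log_streaming_client", None)
    # Build the torchrun command for the given training script
    training_command = TorchDistributor._create_torchrun_command(
        input_params, train_path, *args)
    # Run the command as a child process and fail on a non-zero exit code
    subprocess.run(training_command, check=True)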

_create_torchrun_command

_create_torchrun_command(
    input_params: Dict[str, Any],
    path_to_train_file: str,
    *args: Any
) -> List[str]

_create_torchrun_command takes the value of the following parameters (from the given input_params):

  • local_mode
  • num_processes

_create_torchrun_command determines the torchrun_args and processes_per_node based on local_mode.

local_mode=True:
  • torchrun_args:
      • --standalone
      • --nnodes=1
  • processes_per_node: num_processes (from the given input_params)

local_mode=False:
  • torchrun_args:
      • --nnodes=[num_processes]
      • --node_rank=[node_rank]
      • --rdzv_endpoint=[MASTER_ADDR]:[MASTER_PORT]
      • --rdzv_id=0
  • processes_per_node: 1

In the end, _create_torchrun_command returns a Python command to execute the torch_run_process_wrapper module (python -m) with the following positional arguments:

  • torchrun_args
  • --nproc_per_node=[processes_per_node]
  • The given path_to_train_file
  • The given args
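A sketch of the command construction (where MASTER_ADDR, MASTER_PORT and the node rank come from is an assumption, as is the fully-qualified module name):

import os
import sys

def _create_torchrun_command(input_params, path_to_train_file, *args):
    local_mode = input_params["local_mode"]
    num_processes = input_params["num_processes"]
    if local_mode:
        torchrun_args = ["--standalone", "--nnodes=1"]
        processes_per_node = num_processes
    else:
        # Assumed to be set up by the distributed environment
        master_addr = os.environ["MASTER_ADDR"]
        master_port = os.environ["MASTER_PORT"]
        node_rank = os.environ["RANK"]
        torchrun_args = [
            f"--nnodes={num_processes}",
            f"--node_rank={node_rank}",
            f"--rdzv_endpoint={master_addr}:{master_port}",
            "--rdzv_id=0",
        ]
        processes_per_node = 1
    # python -m torch_run_process_wrapper <torchrun_args> --nproc_per_node=... <train file> <args>
    return [
        sys.executable, "-m", "pyspark.ml.torch.torch_run_process_wrapper",
        *torchrun_args,
        f"--nproc_per_node={processes_per_node}",
        path_to_train_file,
        *[str(arg) for arg in args],
    ]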