TorchDistributor¶
TorchDistributor is a Distributor to run PyTorch's torch.distributed.run module on Apache Spark clusters.
TorchDistributor is a PySpark translation of torchrun (from Torch Distributed Elastic).
Demo¶
from pyspark.ml.torch.distributor import TorchDistributor
distributor = TorchDistributor(
    num_processes=1,
    local_mode=False,
    use_gpu=False)
# Use a path to a training script
# and variable-length kwargs
distributor.run(
    "train.py",
    "--learning-rate=1e-3",
    "--batch-size=64",
    "--my-key=my-value")
# Started local training with 1 processes
# NOTE: Redirects are currently not supported in Windows or MacOs.
# Finished local training with 1 processes
# Use a Callable (function)
# The function must declare one positional parameter per argument passed to run
def train(a, b, c):
    print(f"Got a={a}, b={b}, c={c}")
    return 'success'
distributor.run(
    train,
    "--learning-rate=1e-3",
    "--batch-size=64",
    "--my-key=my-value")
# Started distributed training with 1 executor proceses
# NOTE: Redirects are currently not supported in Windows or MacOs.
# NOTE: Redirects are currently not supported in Windows or MacOs.
# Got a=--learning-rate=1e-3, b=--batch-size=64, c=--my-key=my-value
# Got a=--learning-rate=1e-3, b=--batch-size=64, c=--my-key=my-value
# Finished distributed training with 1 executor proceses
# 'success'
Running Distributed Training¶
run(
    self,
    train_object: Union[Callable, str],
    *args: Any) -> Optional[Any]
run determines what to run (a function or a script) based on the given train_object:
- With a function, run uses _run_training_on_pytorch_function
- With a script, run uses _run_training_on_pytorch_file
In the end, run runs a local or distributed training (see the sketch after this list):
- In local mode, run runs local training
- In non-local mode, run runs distributed training
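The dispatch logic can be sketched as follows (a simplified sketch; the self.local_mode attribute access is an assumption based on the constructor arguments shown in the demo):

```python
# A simplified sketch of TorchDistributor.run's dispatch logic
def run(self, train_object, *args):
    # Pick the framework wrapper based on what was passed in
    if callable(train_object):
        framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_function
    else:
        framework_wrapper_fn = TorchDistributor._run_training_on_pytorch_file
    # Pick local vs. distributed execution based on local_mode
    if self.local_mode:
        return self._run_local_training(framework_wrapper_fn, train_object, *args)
    return self._run_distributed_training(framework_wrapper_fn, train_object, *args)
```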
Local Training¶
_run_local_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    *args: Any,
) -> Optional[Any]
_run_local_training looks up CUDA_VISIBLE_DEVICES among the environment variables.
With use_gpu, _run_local_training...FIXME
_run_local_training prints out the following INFO message to the logs:
Started local training with [num_processes] processes
_run_local_training executes the given framework_wrapper_fn function (with the input_params, the given train_object and the args).
In the end, _run_local_training prints out the following INFO message to the logs:
Finished local training with [num_processes] processes
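A minimal sketch of the flow described above (self.logger, self.input_params and the GPU handling are assumptions beyond what this page states):

```python
import os

# A simplified sketch of _run_local_training; attribute names beyond
# num_processes are assumptions, and the use_gpu handling is elided
def _run_local_training(self, framework_wrapper_fn, train_object, *args):
    cuda_visible_devices = os.environ.get("CUDA_VISIBLE_DEVICES")  # looked up first
    try:
        # FIXME (per above): with use_gpu, CUDA_VISIBLE_DEVICES would be adjusted here
        self.logger.info(f"Started local training with {self.num_processes} processes")
        output = framework_wrapper_fn(self.input_params, train_object, *args)
        self.logger.info(f"Finished local training with {self.num_processes} processes")
        return output
    finally:
        if cuda_visible_devices is not None:
            os.environ["CUDA_VISIBLE_DEVICES"] = cuda_visible_devices  # restore
```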
Distributed Training¶
_run_distributed_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    *args: Any,
) -> Optional[Any]
_run_distributed_training...FIXME
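Pending the FIXME above, a conceptual sketch of how such a method typically dispatches training to executors is a barrier-mode Spark job; the helper and attribute names below (_get_spark_task_function, self.sc, self.num_tasks) are assumptions, not confirmed by this page:

```python
# Conceptual sketch only: run one barrier-mode Spark task per training process
# and collect the first task's result back on the driver
def _run_distributed_training(self, framework_wrapper_fn, train_object, *args):
    spark_task_fn = self._get_spark_task_function(framework_wrapper_fn, train_object, *args)
    self.logger.info(f"Started distributed training with {self.num_processes} executor processes")
    output = (
        self.sc.parallelize(range(self.num_tasks), self.num_tasks)
        .barrier()
        .mapPartitions(spark_task_fn)
        .collect()[0]
    )
    self.logger.info(f"Finished distributed training with {self.num_processes} executor processes")
    return output
```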
_run_training_on_pytorch_function¶
_run_training_on_pytorch_function(
    input_params: Dict[str, Any],
    train_fn: Callable,
    *args: Any
) -> Any
_run_training_on_pytorch_function prepares train and output files.
_run_training_on_pytorch_function...FIXME
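A minimal sketch of how the file preparation and the file-based runner (both described in the following sections) could fit together; error handling is elided and the flow is an approximation, not the actual implementation:

```python
import cloudpickle

# A simplified sketch: pickle the function, run it through the file-based path,
# then unpickle the result written to output.pickle
def _run_training_on_pytorch_function(input_params, train_fn, *args):
    with TorchDistributor._setup_files(train_fn, *args) as (train_file_path, output_file_path):
        # Delegate to the file-based runner (described below)
        TorchDistributor._run_training_on_pytorch_file(input_params, train_file_path)
        with open(output_file_path, "rb") as f:
            output = cloudpickle.load(f)  # whatever train_fn returned
    return output
```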
Setting Up Files¶
# @contextmanager
_setup_files(
    train_fn: Callable,
    *args: Any
) -> Generator[Tuple[str, str], None, None]
_setup_files gives the paths of a TorchRun train file and an output.pickle output file.
_setup_files creates a save directory.
_setup_files saves the train_fn function (along with the args) to the save directory (that gives a pickle_file_path).
_setup_files uses the save directory and output.pickle name for the output file path.
_setup_files creates a torchrun_train_file with the following:
- Save directory
- pickle_file_path
- output.pickle output file path
In the end, _setup_files yields (gives) the torchrun_train_file and the output.pickle output file path.
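A minimal sketch of what such a context manager could look like (the TemporaryDirectory handling and the train_fn.pickle file name are assumptions; only _create_torchrun_train_file and output.pickle come from this page):

```python
import os
from contextlib import contextmanager
from tempfile import TemporaryDirectory

import cloudpickle

# A simplified sketch of _setup_files; directory management and file names
# other than output.pickle are assumptions
@contextmanager
def _setup_files(train_fn, *args):
    with TemporaryDirectory() as save_dir_path:              # save directory
        pickle_file_path = os.path.join(save_dir_path, "train_fn.pickle")
        with open(pickle_file_path, "wb") as f:
            cloudpickle.dump((train_fn, args), f)            # pickled (train_fn, args)
        output_file_path = os.path.join(save_dir_path, "output.pickle")
        torchrun_train_file = TorchDistributor._create_torchrun_train_file(
            save_dir_path, pickle_file_path, output_file_path
        )
        yield torchrun_train_file, output_file_path
```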
Creating TorchRun Train File¶
_create_torchrun_train_file(
    save_dir_path: str,
    pickle_file_path: str,
    output_file_path: str
) -> str
_create_torchrun_train_file creates train.py in the given save_dir_path with the following content (based on the given pickle_file_path and the output_file_path):
import cloudpickle
import os
if __name__ == "__main__":
    with open("[pickle_file_path]", "rb") as f:
        train_fn, args = cloudpickle.load(f)
    output = train_fn(*args)
    with open("[output_file_path]", "wb") as f:
        cloudpickle.dump(output, f)
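Based on the description above, a sketch of the file creation could look like this (the exact formatting of the generated file is an assumption):

```python
import os
import textwrap

# A simplified sketch: write the template above into train.py in save_dir_path
def _create_torchrun_train_file(save_dir_path, pickle_file_path, output_file_path):
    train_file_path = os.path.join(save_dir_path, "train.py")
    content = textwrap.dedent(f"""\
        import cloudpickle
        import os
        if __name__ == "__main__":
            with open("{pickle_file_path}", "rb") as f:
                train_fn, args = cloudpickle.load(f)
            output = train_fn(*args)
            with open("{output_file_path}", "wb") as f:
                cloudpickle.dump(output, f)
        """)
    with open(train_file_path, "w") as f:
        f.write(content)
    return train_file_path
```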
_run_training_on_pytorch_file¶
_run_training_on_pytorch_file(
    input_params: Dict[str, Any],
    train_path: str,
    *args: Any
) -> None
_run_training_on_pytorch_file looks up the log_streaming_client in the given input_params (or assumes None).
FIXME What's log_streaming_client?
_run_training_on_pytorch_file creates a torchrun command (see _create_torchrun_command below).
_run_training_on_pytorch_file executes the command.
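A minimal sketch of the flow (streaming the output through log_streaming_client is elided; executing the command with subprocess is an assumption):

```python
import subprocess

# A simplified sketch of _run_training_on_pytorch_file
def _run_training_on_pytorch_file(input_params, train_path, *args):
    # Look up log_streaming_client (or assume None), as described above;
    # its use is elided in this sketch
    log_streaming_client = input_params.get("log_streaming_client", None)
    # Build the torchrun command (see _create_torchrun_command below)
    training_command = TorchDistributor._create_torchrun_command(input_params, train_path, *args)
    # Execute the command; the real implementation also redirects/streams the output
    subprocess.run(training_command, check=True)
```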
_create_torchrun_command¶
_create_torchrun_command(
    input_params: Dict[str, Any],
    path_to_train_file: str,
    *args: Any
) -> List[str]
_create_torchrun_command takes the value of the following parameters (from the given input_params):
- local_mode
- num_processes
_create_torchrun_command determines the torchrun_args and processes_per_node based on local_mode.
| local_mode | torchrun_args | processes_per_node |
|---|---|---|
| True | | num_processes (from the given input_params) |
| False | | 1 |
In the end, _create_torchrun_command returns a Python command to execute the torch_run_process_wrapper module (python -m) with the following positional arguments (see the sketch after this list):
- torchrun_args
- --nproc_per_node=[processes_per_node]
- The given path_to_train_file
- The given args
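A hedged sketch of how the command could be assembled (the torchrun_args values per mode and the full module path of torch_run_process_wrapper are assumptions, not taken from this page):

```python
import sys

# A simplified sketch of _create_torchrun_command; the torchrun_args values and
# the full module path of torch_run_process_wrapper are assumptions
def _create_torchrun_command(input_params, path_to_train_file, *args):
    local_mode = input_params["local_mode"]
    num_processes = input_params["num_processes"]

    if local_mode:
        torchrun_args = ["--standalone", "--nnodes=1"]    # assumed torchrun flags
        processes_per_node = num_processes
    else:
        torchrun_args = [f"--nnodes={num_processes}"]     # assumed; rendezvous flags elided
        processes_per_node = 1

    return [
        sys.executable, "-m", "pyspark.ml.torch.torch_run_process_wrapper",
        *torchrun_args,
        f"--nproc_per_node={processes_per_node}",
        path_to_train_file,
        *[str(arg) for arg in args],
    ]
```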