TorchDistributor¶
TorchDistributor is a Distributor to run PyTorch's torch.distributed.run module on Apache Spark clusters.
TorchDistributor is a PySpark translation of torchrun (from Torch Distributed Elastic).
Demo¶
from pyspark.ml.torch.distributor import TorchDistributor
distributor = TorchDistributor(
    num_processes=1,
    local_mode=False,
    use_gpu=False)

# Use a path to a training script
# and variable-length kwargs
distributor.run(
    "train.py",
    "--learning-rate=1e-3",
    "--batch-size=64",
    "--my-key=my-value")
# Started local training with 1 processes
# NOTE: Redirects are currently not supported in Windows or MacOs.
# Finished local training with 1 processes
# Use a Callable (function)
# The function must accept as many positional parameters
# as there are variable-length args passed to run
def train(a, b, c):
    print(f"Got a={a}, b={b}, c={c}")
    return 'success'
distributor.run(
    train,
    "--learning-rate=1e-3",
    "--batch-size=64",
    "--my-key=my-value")
# Started distributed training with 1 executor proceses
# NOTE: Redirects are currently not supported in Windows or MacOs.
# NOTE: Redirects are currently not supported in Windows or MacOs.
# Got a=--learning-rate=1e-3, b=--batch-size=64, c=--my-key=my-value
# Got a=--learning-rate=1e-3, b=--batch-size=64, c=--my-key=my-value
# Finished distributed training with 1 executor proceses
# 'success'
Running Distributed Training¶
run(
    self,
    train_object: Union[Callable, str],
    *args: Any) -> Optional[Any]
run determines what to run (a function or a script) based on the given train_object.
- With a function, run uses _run_training_on_pytorch_function
- With a script, run uses _run_training_on_pytorch_file
In the end, run runs a local or distributed training.
- In local mode, run runs local training
- In non-local mode, run runs distributed training
Local Training¶
_run_local_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    *args: Any,
) -> Optional[Any]
_run_local_training looks up CUDA_VISIBLE_DEVICES among the environment variables.
With use_gpu, _run_local_training...FIXME
_run_local_training prints out the following INFO message to the logs:
Started local training with [num_processes] processes
_run_local_training executes the given framework_wrapper_fn function (with the input_params, the given train_object and the args).
In the end, _run_local_training prints out the following INFO message to the logs:
Finished local training with [num_processes] processes
Distributed Training¶
_run_distributed_training(
    self,
    framework_wrapper_fn: Callable,
    train_object: Union[Callable, str],
    *args: Any,
) -> Optional[Any]
_run_distributed_training...FIXME
_run_training_on_pytorch_function¶
_run_training_on_pytorch_function(
    input_params: Dict[str, Any],
    train_fn: Callable,
    *args: Any
) -> Any
_run_training_on_pytorch_function prepares train and output files.
_run_training_on_pytorch_function...FIXME
Setting Up Files¶
# @contextmanager
_setup_files(
    train_fn: Callable,
    *args: Any
) -> Generator[Tuple[str, str], None, None]
_setup_files gives the paths of a TorchRun train file and output.pickle output file.
_setup_files creates a save directory.
_setup_files saves train_fn function to the save directory (that gives a pickle_file_path).
_setup_files uses the save directory and output.pickle name for the output file path.
_setup_files creates a torchrun_train_file with the following:
- Save directory
- pickle_file_path
- output.pickle output file path
In the end, _setup_files yields (gives) the torchrun_train_file and the output.pickle output file path.
Creating TorchRun Train File¶
_create_torchrun_train_file(
    save_dir_path: str,
    pickle_file_path: str,
    output_file_path: str
) -> str
_create_torchrun_train_file creates train.py in the given save_dir_path with the following content (based on the given pickle_file_path and the output_file_path):
import cloudpickle
import os

if __name__ == "__main__":
    with open("[pickle_file_path]", "rb") as f:
        train_fn, args = cloudpickle.load(f)
    output = train_fn(*args)
    with open("[output_file_path]", "wb") as f:
        cloudpickle.dump(output, f)
_run_training_on_pytorch_file¶
_run_training_on_pytorch_file(
    input_params: Dict[str, Any],
    train_path: str,
    *args: Any
) -> None
_run_training_on_pytorch_file looks up the log_streaming_client in the given input_params (or assumes None).
FIXME What's log_streaming_client?
_run_training_on_pytorch_file creates a torchrun command.
_run_training_on_pytorch_file executes the command.
_create_torchrun_command¶
_create_torchrun_command(
    input_params: Dict[str, Any],
    path_to_train_file: str,
    *args: Any
) -> List[str]
_create_torchrun_command takes the value of the following parameters (from the given input_params):
- local_mode
- num_processes
_create_torchrun_command determines the torchrun_args and processes_per_node based on local_mode.
| local_mode | torchrun_args | processes_per_node |
|---|---|---|
| True | | num_processes (from the given input_params) |
| False | | 1 |
In the end, _create_torchrun_command returns a Python command to execute torch_run_process_wrapper module (python -m) with the following positional arguments:
- torchrun_args
- --nproc_per_node=[processes_per_node]
- The given path_to_train_file
- The given args