Distributed Training with TorchElastic on Kubernetes
2021/07/18
Distributed training is useful for speeding up the training of a model on a large dataset by utilizing multiple nodes (computers). In the past few years, the technical difficulty of doing distributed training has dropped drastically, so it is no longer reserved for engineers working at large AI institutions.
At a quick glance, distributed training involves multiple nodes running a training job. The nodes are assigned ranks from 0 to N, with 0 being the master. Typically:
- all nodes have a copy of the model
- each node samples a partition of the full dataset (e.g. by taking data with index % num_nodes == node_rank, as sketched below) to train its model copy
- the gradients are accumulated, then synced to the master for the model update
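As a rough illustration of that sampling rule, here is a minimal sketch; in practice PyTorch's DistributedSampler handles this for you, and num_nodes/node_rank below are placeholder names, not a specific API:

```python
# A minimal sketch of rank-based data partitioning.
# In real setups, torch.utils.data.DistributedSampler does this automatically.
from torch.utils.data import Dataset, Subset

def partition_for_rank(dataset: Dataset, num_nodes: int, node_rank: int) -> Subset:
    """Keep only the samples whose index maps to this node's rank."""
    indices = [i for i in range(len(dataset)) if i % num_nodes == node_rank]
    return Subset(dataset, indices)

# e.g. node 1 in a 4-node setup trains on indices 1, 5, 9, ...
# subset = partition_for_rank(full_dataset, num_nodes=4, node_rank=1)
```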
Let's look at one way of doing this simply using 3 components:
- PyTorch Lightning: handles all the engineering code such as device placement, so that there's little to no change to the model code.
- TorchElastic: handles the distributed communication/interface in a fault-tolerant manner.
- Kubernetes: compute cluster with multiple nodes (see this guide to set up your own Kubernetes cluster).
PyTorch Lightning
If you're using PyTorch and manually doing a lot of "engineering chores" that are not directly related to your model/loss function, you should use PyTorch Lightning. Primarily, it takes away the need for boilerplate code by handling the training/evaluation loop, device placement, logging, checkpointing, etc. You only need to focus on the model architecture, loss function, and the dataset.
The device placement handling is very useful for running your model on different hardware (with or without GPUs). It also supports distributed training with TorchElastic.
Suppose you already have a model written as a LightningModule that is not distributed yet. You only need to make a few modifications:
- set the Trainer to use Distributed Data Parallel (ddp) by specifying accelerator='ddp'. This takes care of the distributed training logic, including the appropriate data sampling for each node by automatically switching to DistributedSampler.
- in your LightningModule, assign a self.is_dist attribute based on whether accelerator='ddp' is specified (see the sketch after this list).
- synchronize your logging across multiple nodes by specifying sync_dist=self.is_dist to prevent conflicts, e.g. self.log('train/loss', loss, sync_dist=self.is_dist)
- likewise, ensure that any process that should run only once (such as uploading artifacts/model files) is called only on the master node by checking os.environ.get('RANK', '0') == '0'. The RANK environment variable will be set by TorchElastic when it runs.
- for more comprehensive tips on updating your LightningModule to support distributed training, check out this page.
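Here is a minimal sketch of what these changes can look like, assuming a 2021-era PyTorch Lightning release where the Trainer accepts accelerator='ddp'; the model architecture and the is_dist wiring are illustrative placeholders, not a prescribed pattern:

```python
import os

import torch
import torch.nn.functional as F
import pytorch_lightning as pl


class MyModel(pl.LightningModule):
    def __init__(self, is_dist: bool = False):
        super().__init__()
        self.layer = torch.nn.Linear(28 * 28, 10)  # placeholder architecture
        self.is_dist = is_dist  # set to True when training with accelerator='ddp'

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.layer(x.view(x.size(0), -1)), y)
        # sync the logged metric across nodes only when running distributed
        self.log('train/loss', loss, sync_dist=self.is_dist)
        return loss

    def on_fit_end(self):
        # run once-only side effects (e.g. uploading the model file) on the master node
        if os.environ.get('RANK', '0') == '0':
            pass  # upload artifacts here

    def configure_optimizers(self):
        return torch.optim.Adam(self.parameters(), lr=1e-3)


# trainer = pl.Trainer(accelerator='ddp', gpus=-1, max_epochs=1)
# trainer.fit(MyModel(is_dist=True), train_dataloader)
```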
Next, test your changes by simulating distributed training on a single node:
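For example (a sketch assuming PyTorch >= 1.9 and a hypothetical train.py entry script), you can spawn two local worker processes with a standalone rendezvous:

```bash
# simulate 2 distributed workers on one machine (no etcd needed locally)
python -m torch.distributed.run --standalone --nnodes=1 --nproc_per_node=2 train.py
```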
TorchElastic
If you're using PyTorch >= 1.9.0, TorchElastic is already included as torch.distributed.run. It will be used to launch the distributed training process on every node (container), which we will define with our Kubernetes manifest files next.
TorchElastic works by using an etcd server for communication, so:
- in your Conda environment, install python-etcd.
- in your Dockerfile, install etcd on the system. Use this script.
PyTorch Lightning works nicely with TorchElastic, so that's all you need.
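For reference, a launch command of this shape is what each node (container) ends up running; the etcd endpoint, node counts, job id, and train.py here are placeholders:

```bash
python -m torch.distributed.run \
    --nnodes=1:$MAX_NODES \
    --nproc_per_node=1 \
    --rdzv_backend=etcd \
    --rdzv_endpoint=$ETCD_HOST:2379 \
    --rdzv_id=MY-MODEL \
    train.py
```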
Kubernetes
Whether you use a managed vendor or set up your own, a Kubernetes cluster with multiple nodes is ready to run distributed training. Here's the overview:
- spawn an etcd server on Kubernetes
- spawn multiple pods, each with a container that maximizes the amount of compute resources available on a node
- run the torchelastic command on each container with the etcd parameters
The logic for spawning pods (managing them elastically, inserting the right arguments such as RANK/WORLD_SIZE) can be tricky. Luckily, TorchElastic has an ElasticJob Controller that can be installed on your cluster as a Custom Resource Definition (CRD) to manage these pods elastically.
Install the TorchElastic CRDs (you will need the cluster admin role to do so). By default this will create an elastic-job namespace to run the training in, but you can customize it by modifying the config.
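The install is roughly a kubectl apply of the controller's kustomize config (a sketch; check the pytorch/elastic repository for the current layout and namespace config):

```bash
git clone https://github.com/pytorch/elastic.git
kubectl apply -k elastic/kubernetes/config/default
```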
Dockerfile
In your Dockerfile, as mentioned earlier, install etcd. Additionally, configure it to run torch.distributed.run as the entrypoint; the CRD will automatically append the relevant rank arguments when creating pods.
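A sketch of such a Dockerfile, assuming a PyTorch >= 1.9 base image and an install_etcd.sh script like the one referenced earlier; the image tag and paths are placeholders:

```dockerfile
FROM pytorch/pytorch:1.9.0-cuda11.1-cudnn8-runtime

# install etcd on the system (script referenced in the TorchElastic section)
COPY install_etcd.sh /tmp/install_etcd.sh
RUN bash /tmp/install_etcd.sh && rm /tmp/install_etcd.sh

# project code and Python dependencies (including python-etcd)
WORKDIR /workspace
COPY . /workspace
RUN pip install -r requirements.txt

# the ElasticJob controller appends the rendezvous/rank arguments and the
# training script when it creates each pod
ENTRYPOINT ["python", "-m", "torch.distributed.run"]
```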
Manifest Files
We only need 2 Kubernetes manifest files: one for etcd, and one for the elasticjob.
Note that each etcd server is reserved for a single distributed training session; if multiple engineers want to run different models with distributed training on the same cluster, each needs to spawn their own instance with a new pair of etcd server and elasticjob to avoid conflicts.
Here are the example manifest files modified from the original TorchElastic examples to run simultaneously without conflict. Replace the example "MY-MODEL" with a different name for each instance.
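Roughly, the pair can look like the following, adapted loosely from the TorchElastic Kubernetes samples; the image, GPU limits, replica counts, and script path are placeholders, so cross-check the upstream examples for the exact fields:

```yaml
# etcd-MY-MODEL.yaml: a single-pod etcd server plus a Service used for rendezvous
apiVersion: v1
kind: Pod
metadata:
  name: etcd-MY-MODEL
  namespace: elastic-job
  labels:
    app: etcd-MY-MODEL
spec:
  containers:
    - name: etcd
      image: quay.io/coreos/etcd:v3.4.3
      command: ["/usr/local/bin/etcd"]
      args:
        - "--enable-v2"
        - "--listen-client-urls=http://0.0.0.0:2379"
        - "--advertise-client-urls=http://0.0.0.0:2379"
      ports:
        - containerPort: 2379
---
apiVersion: v1
kind: Service
metadata:
  name: etcd-MY-MODEL
  namespace: elastic-job
spec:
  selector:
    app: etcd-MY-MODEL
  ports:
    - port: 2379
      targetPort: 2379
```

```yaml
# elasticjob-MY-MODEL.yaml: the ElasticJob that spawns and manages the training pods
apiVersion: elastic.pytorch.org/v1alpha1
kind: ElasticJob
metadata:
  name: MY-MODEL
  namespace: elastic-job
spec:
  rdzvEndpoint: etcd-MY-MODEL:2379
  minReplicas: 1
  maxReplicas: 3
  replicaSpecs:
    Worker:
      replicas: 3
      restartPolicy: ExitCode
      template:
        apiVersion: v1
        kind: Pod
        spec:
          containers:
            - name: worker
              image: MY-REGISTRY/MY-MODEL:latest
              args: ["--nproc_per_node=1", "train.py"]
              resources:
                limits:
                  nvidia.com/gpu: 1
```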
Then apply these manifest files and you'll have distributed training running, with the ElasticJob Controller managing pods elastically to ensure uptime.
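Applying them could look like this (using the hypothetical file names above):

```bash
kubectl apply -f etcd-MY-MODEL.yaml
kubectl apply -f elasticjob-MY-MODEL.yaml
```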
You can also parametrize these manifest files with Helm to create multiple distributed training instances.
Tying it all together
That's all you need to run distributed training. To summarize:
- use PyTorch Lightning, specify accelerator='ddp' and make some logging fixes
- install python-etcd in your Conda environment, and etcd in your Docker image
- build your Docker image, specifying the ENTRYPOINT with torch.distributed.run
- install the TorchElastic CRDs on your Kubernetes cluster
- create and apply the etcd and elasticjob manifest files on Kubernetes
- you have distributed training running!