Distributed Training with TorchElastic on Kubernetes


Distributed training speeds up training a model on a large dataset by utilizing multiple nodes (computers). In the past few years, the technical difficulty of distributed training has dropped so drastically that it is no longer reserved for engineers working at a large AI institution.

At a quick glance, distributed training involves multiple nodes running the same training job. With N nodes, the nodes are assigned ranks from 0 to N-1, with 0 being the master. Typically:

  • all nodes have a copy of the model

  • each node will sample a partition of the full dataset (e.g. by sampling data with index % num_nodes == node_rank) to train its model copy

  • the gradients are synchronized across nodes after each backward pass (with DDP they are averaged via all-reduce, so every model copy applies the same update)
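To make the partitioning and gradient steps above concrete, here is a minimal pure-Python sketch. The function names shard_indices and average_gradients are hypothetical, and gradient synchronization is modeled as a plain element-wise mean, which is an assumption about the DDP-style setup used later in this post rather than a description of any library's internals.

```python
from typing import List, Sequence


def shard_indices(dataset_size: int, num_nodes: int, node_rank: int) -> List[int]:
    """Return the sample indices for one node: index % num_nodes == node_rank."""
    if not 0 <= node_rank < num_nodes:
        raise ValueError("node_rank must be in [0, num_nodes)")
    return [i for i in range(dataset_size) if i % num_nodes == node_rank]


def average_gradients(per_node_grads: Sequence[Sequence[float]]) -> List[float]:
    """Element-wise mean of the gradient vectors computed by each node."""
    num_nodes = len(per_node_grads)
    return [sum(values) / num_nodes for values in zip(*per_node_grads)]


# 10 samples split across 3 nodes: every sample lands on exactly one node.
shards = [shard_indices(10, 3, rank) for rank in range(3)]
print(shards)

# Each node contributes one gradient vector; all copies apply the same mean.
print(average_gradients([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]))
```

In practice you do not write this yourself: the sampler and the gradient synchronization are handled by the training framework.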

Let's look at one way of doing this simply using 3 components:

  • PyTorch Lightning: handles all the engineering code such as device placement, so that there are little to no changes to the model code.

  • TorchElastic: handles the distributed communication/interface in a fault-tolerant manner.

  • Kubernetes: compute cluster with multiple nodes (see this guide to set up your own Kubernetes cluster).

PyTorch Lightning

If you're using PyTorch and manually doing a lot of "engineering chores" that are not directly related to your model/loss function, you should use PyTorch Lightning. Primarily, it takes away the need for boilerplate code by handling the training/evaluation loop, device placement, logging, checkpointing, etc. You only need to focus on the model architecture, loss function, and the dataset.

The device placement handling is very useful for running your model on different hardware (with or without GPUs). It also supports distributed training with TorchElastic.

Suppose you already have a model written as a LightningModule that is not distributed yet. You only need to make a few modifications:

  • set the Trainer to use Distributed Data Parallel (ddp) by specifying accelerator='ddp'. This takes care of the distributed training logic, including the appropriate data sampling for each node by automatically switching to DistributedSampler.

  • in your LightningModule, assign a self.is_dist attribute based on whether accelerator='ddp' is specified.

  • synchronize your logging across multiple nodes by specifying sync_dist=self.is_dist to prevent conflicts, e.g. self.log('train/loss', loss, sync_dist=self.is_dist)

  • likewise, ensure that any step that should run only once (such as uploading artifacts or the model file) is executed only on the master node by checking os.environ.get('RANK', '0') == '0'. The RANK environment variable is set by TorchElastic when it runs.

  • for more comprehensive tips on updating your LightningModule to support distributed training, check out this page.
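The rank check from the list above can be wrapped in a small helper. This is a minimal sketch using only the standard library: is_rank_zero and run_once_on_master are hypothetical names, not part of PyTorch Lightning or TorchElastic. Since os.environ is a mapping, the lookup goes through .get with a default of '0' so the code also works when no launcher has set RANK.

```python
import os
from typing import Callable, Mapping, Optional


def is_rank_zero(env: Optional[Mapping[str, str]] = None) -> bool:
    """True on the master process, or when RANK is unset (non-distributed run)."""
    env = os.environ if env is None else env
    return env.get("RANK", "0") == "0"


def run_once_on_master(action: Callable[[], None],
                       env: Optional[Mapping[str, str]] = None) -> bool:
    """Run `action` only on rank 0; return whether it was executed."""
    if is_rank_zero(env):
        action()
        return True
    return False


# Simulate three processes; only the one with RANK=0 performs the upload.
uploads = []
for rank in ("0", "1", "2"):
    run_once_on_master(lambda: uploads.append("model.ckpt"), {"RANK": rank})
print(uploads)
```

Passing the environment as an argument is only there to make the helper easy to test; in training code you would call is_rank_zero() with no arguments.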

Next, test your changes by simulating distributed training on a single node:

from pytorch_lightning import Trainer

# use num_processes to simulate 2 nodes on one machine
trainer = Trainer(accelerator="ddp", num_processes=2)


TorchElastic

If you're using PyTorch >= 1.9.0, TorchElastic is already included under torch.distributed and is launched with python -m torch.distributed.run. It will be used to launch the distributed training process on every node (container); the containers themselves are defined in the Kubernetes manifest files covered later.

TorchElastic uses an etcd server for rendezvous, i.e. for the nodes to discover and coordinate with each other, so:

  • in your Conda environment, install python-etcd.

  • in your Dockerfile, install etcd on the system using this script.

PyTorch Lightning works nicely with TorchElastic, so no further code changes are needed.


Kubernetes

Whether you use a vendor's managed offering or set up your own, a Kubernetes cluster with multiple nodes is ready to run distributed training. Here's an overview:

  • spawn an etcd server on Kubernetes

  • spawn multiple pods, each with a container that maximizes the amount of compute resources available on a node

  • run the torchelastic command on each container with the etcd parameters
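For orientation, this is roughly what the launch command on each container looks like once the etcd parameters are filled in. The sketch below only assembles the argument list; build_launch_command is a hypothetical helper, and the endpoint, job id and script path are placeholder values. The flags themselves (--nnodes, --nproc_per_node, --rdzv_backend, --rdzv_endpoint, --rdzv_id) are options of torch.distributed.run.

```python
from typing import List, Sequence


def build_launch_command(script: str,
                         rdzv_endpoint: str,
                         rdzv_id: str,
                         min_nodes: int,
                         max_nodes: int,
                         nproc_per_node: int = 1,
                         script_args: Sequence[str] = ()) -> List[str]:
    """Assemble a torch.distributed.run command that rendezvouses through etcd."""
    return [
        "python", "-m", "torch.distributed.run",
        f"--nnodes={min_nodes}:{max_nodes}",   # elastic range of nodes
        f"--nproc_per_node={nproc_per_node}",  # worker processes per node
        "--rdzv_backend=etcd",
        f"--rdzv_endpoint={rdzv_endpoint}",    # host:port of the etcd server
        f"--rdzv_id={rdzv_id}",                # shared by all nodes of one job
        script,
        *script_args,
    ]


cmd = build_launch_command(
    script="my_model/train.py",
    rdzv_endpoint="MY-MODEL-etcd-service:2379",  # placeholder endpoint
    rdzv_id="MY-MODEL",                          # placeholder job id
    min_nodes=1,
    max_nodes=2,
    script_args=["trainer.accelerator=ddp"],
)
print(" ".join(cmd))
```

Every node must use the same rendezvous endpoint and id, which is why this is easier to delegate to a controller than to assemble by hand.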

The logic for spawning pods, managing them elastically, and inserting the right arguments such as RANK/WORLD_SIZE can be tricky. Luckily, TorchElastic has an ElasticJob Controller that can be installed on your cluster as a Custom Resource Definition (CRD) to manage these pods elastically.

Install the TorchElastic CRDs; you will need the cluster admin role to do so. By default this creates an elastic-job namespace to run the training in, but you can customize it by modifying the config.


In your Dockerfile, as mentioned earlier, install etcd.

Additionally, configure it to run torch.distributed.run as the entrypoint. The ElasticJob Controller will automatically append the relevant arguments (such as the rendezvous settings) when creating pods.


ENTRYPOINT ["python", "-m", "torch.distributed.run"]
CMD ["--help"]
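The reason for splitting ENTRYPOINT and CMD this way is that Kubernetes container args replace the image's CMD while leaving the ENTRYPOINT in place. The sketch below models that rule in plain Python; container_argv is a hypothetical helper, and the example args are placeholders taken from this post.

```python
from typing import List, Optional, Sequence

ENTRYPOINT = ["python", "-m", "torch.distributed.run"]
CMD = ["--help"]


def container_argv(entrypoint: Sequence[str],
                   cmd: Sequence[str],
                   args: Optional[Sequence[str]] = None) -> List[str]:
    """Final process argv: ENTRYPOINT plus CMD, unless pod args override CMD."""
    return list(entrypoint) + list(cmd if args is None else args)


# No args supplied: the container just prints the launcher's help text.
print(container_argv(ENTRYPOINT, CMD))

# Args supplied by the pod spec: they replace CMD and are passed to the launcher.
print(container_argv(ENTRYPOINT, CMD, ["--nproc_per_node=1", "my_model/train.py"]))
```

So the image stays generic, and everything job-specific arrives through the pod's args.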

Manifest Files

We only need two Kubernetes manifest files: one for etcd and one for the ElasticJob.

Note that each etcd server is reserved for a single distributed training session. If multiple engineers want to run different models with distributed training on the same cluster, each of them needs to spawn their own instance, i.e. a new pair of etcd server and ElasticJob, so that the sessions do not conflict.

Here are example manifest files, modified from the original TorchElastic examples so that several instances can run simultaneously without conflict. Replace the placeholder "MY-MODEL" with a different name for each instance (Kubernetes resource names must be lowercase).

# Manifest 1: etcd Service and Pod
apiVersion: v1
kind: Service
metadata:
  namespace: elastic-job
  name: MY-MODEL-etcd-service
spec:
  ports:
    - name: etcd-client-port
      port: 2379
      protocol: TCP
      targetPort: 2379
  selector:
    app: MY-MODEL-etcd
---
apiVersion: v1
kind: Pod
metadata:
  namespace: elastic-job
  name: MY-MODEL-etcd
  labels:
    app: MY-MODEL-etcd
spec:
  containers:
    - name: etcd
      image: quay.io/coreos/etcd:latest
      command:
        - /usr/local/bin/etcd
        - --data-dir
        - /var/lib/etcd
        - --enable-v2
        - --listen-client-urls
        - http://0.0.0.0:2379 # value as in the upstream TorchElastic etcd example
        - --advertise-client-urls
        - http://0.0.0.0:2379 # value as in the upstream TorchElastic etcd example
        - --initial-cluster-state
        - new
      ports:
        - containerPort: 2379
          name: client
          protocol: TCP
        - containerPort: 2380
          name: server
          protocol: TCP
  restartPolicy: Always

# Manifest 2: ElasticJob
apiVersion: elastic.pytorch.org/v1alpha1
kind: ElasticJob
metadata:
  namespace: elastic-job
  name: MY-MODEL
spec:
  rdzvEndpoint: MY-MODEL-etcd-service:2379
  minReplicas: 1
  maxReplicas: 2
  replicaSpecs:
    Worker:
      replicas: 2
      restartPolicy: ExitCode
      template:
        apiVersion: v1
        kind: Pod
        spec:
          containers:
          - name: elasticjob-worker
            image: YOUR_DOCKER_IMAGE:0.0.1
            imagePullPolicy: Always
            args:
              - "--nproc_per_node=1"
              - "my_model/train.py"
              - "trainer.accelerator=ddp" # if you can pass it to argparse/Hydra config
            resources:
              limits:
                cpu: 4
                memory: 12Gi
                nvidia.com/gpu: 1
            volumeMounts:
            - name: dshm # increase shared memory for dataloader
              mountPath: /dev/shm
          volumes:
          - name: dshm
            emptyDir:
              medium: Memory

Then apply these manifest files and you'll have distributed training running, with the ElasticJob Controller managing pods elastically to ensure uptime.

You can also parametrize these manifest files with Helm to create multiple distributed training instances.
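Whatever templating tool you use, the key point is that every resource name is derived from one per-instance value, so two instances never share an etcd server. The sketch below illustrates that naming scheme in plain Python as a stand-in for a Helm template. instance_names is a hypothetical helper, the "-etcd" and "-etcd-service" suffixes follow the example manifests in this post, and lowercasing the name is an assumption made to satisfy Kubernetes naming rules.

```python
from typing import Dict


def instance_names(model_name: str,
                   namespace: str = "elastic-job",
                   etcd_port: int = 2379) -> Dict[str, str]:
    """Derive all per-instance resource names from a single model name."""
    name = model_name.lower()  # Kubernetes resource names must be lowercase
    return {
        "namespace": namespace,
        "job_name": name,
        "etcd_pod": f"{name}-etcd",
        "etcd_service": f"{name}-etcd-service",
        "rdzv_endpoint": f"{name}-etcd-service:{etcd_port}",
    }


# Two engineers, two models: each gets its own etcd server and ElasticJob.
for model in ("resnet", "bert"):
    print(instance_names(model))
```

In a Helm chart, the same idea would be expressed as a single value substituted into both manifest files.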

Tying it all together

That's all you need to run distributed training. To summarize:

  • use PyTorch Lightning, specify accelerator='ddp', and apply the logging fixes

  • install python-etcd in your Conda environment and etcd in your Docker image

  • build your Docker image, setting the ENTRYPOINT to torch.distributed.run

  • install the TorchElastic CRDs on your Kubernetes cluster

  • create and apply the etcd and elasticjob manifest files on Kubernetes

  • you have distributed training running!

