Distributed Training with PyTorch

Shagun Sodhani (~shagunsodhani)




PyTorch is one of the most popular ML frameworks, and its recent releases have focused on enhanced support for distributed training. This tutorial is a code-lab-style guide to the distributed training mechanisms that PyTorch provides. It should be useful for practitioners and researchers who want to train larger models, faster.


  • The tutorial is designed for practitioners/engineers/researchers/students who are familiar with Machine Learning (ML), ideally also with PyTorch, and want to scale their machine learning training systems.

  • Prior experience with PyTorch is recommended but not required.

  • Prior experience in ML is required.

  • Experience with distributed computing is not required.

By the end of the tutorial/session, attendees will be able to take a simple PyTorch model and scale it to work with dozens of machines. For straightforward use cases, this requires writing only a few extra lines of code.
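To make the "few extra lines" concrete, here is a minimal sketch assuming a single machine with zero or more GPUs. The toy `nn.Sequential` model, its layer sizes, and the random batch are hypothetical stand-ins for the ResNet50/ImageNet pipeline used in the workshop. The only distributed-specific change is wrapping the model in `torch.nn.DataParallel` when more than one GPU is visible; otherwise the sketch runs on a single GPU or the CPU.

```python
import torch
import torch.nn as nn

# Hypothetical toy model standing in for ResNet50.
model = nn.Sequential(nn.Linear(32, 64), nn.ReLU(), nn.Linear(64, 10))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The "few extra lines": wrap the model if more than one GPU is visible.
# DataParallel splits each input batch across the visible GPUs.
if torch.cuda.device_count() > 1:
    model = nn.DataParallel(model)
model = model.to(device)

optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
loss_fn = nn.CrossEntropyLoss()

# The training step itself is unchanged (random data for illustration only).
inputs = torch.randn(16, 32, device=device)
targets = torch.randint(0, 10, (16,), device=device)
optimizer.zero_grad()
loss = loss_fn(model(inputs), targets)
loss.backward()
optimizer.step()
```

Everything after the wrapping step (optimizer, loss, and the training loop) stays exactly as in the single-GPU version.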

Content URLs:

The tutorial will be organized as a code-lab, and most of the time will be spent on hands-on programming.

The code will be available as Jupyter Notebooks (hosted on Colab). Since we are talking about distributed training, we need access to multiple GPUs, so I am considering AWS, Google Cloud Platform, or Microsoft Azure. All of these platforms offer compute resources on a trial basis, which will be sufficient for the tutorial.

Tentative timeline

Part 1: Setup and Data Parallel Mode

  • 00:00-10:00 -> Setting up the stage: speaker introduction, a 5 min intro to PyTorch, and the motivation for using distributed training.
  • 10:00-15:00 -> Walking through a simple, standard ML training/evaluation pipeline. Concretely, I will use the example of training ResNet50 on the ImageNet dataset on a single GPU. The code will be a modification of https://github.com/pytorch/examples/tree/master/imagenet.
  • 15:00-20:00 -> Start training this model, describe the bottleneck in the training process, and motivate the use of distributed training.
  • 20:00-25:00 -> 5 min buffer time to make sure everyone can start training their models.
  • 25:00-35:00 -> Introduce Data Parallelism based on https://pytorch.org/docs/stable/nn.html?highlight=dataparallel#torch.nn.DataParallel. This part of the presentation will be based entirely on slides (no coding) and will focus on building the intuition for Data-Parallel training (irrespective of the actual framework used).
  • 35:00-45:00 -> Show how to use Data-Parallel. It requires changing only a few lines; the rest of the slot is buffer time to answer questions (I expect many) and make sure everyone can start training the model.
  • 45:00-55:00 -> Gotchas when using Data-Parallel, e.g., how to correctly save and load the model and perform checkpointing (code examples).
  • 55:00-60:00 -> Advanced use cases, e.g., placing part of the model on the CPU and part on the GPU.
  • 60:00-70:00 -> Talk about the limitations of DataParallel (DP). Specifically, how far we can scale using DataParallel (e.g., in PyTorch, DataParallel can only be used with GPUs on the same node (machine)), and alternative ways of implementing data parallelism (e.g., using multiprocessing). This portion of the talk will largely be based on slides.
  • 70:00-80:00 -> 10 min break. Attendees can use this time to ask questions.
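The checkpointing gotcha from Part 1 can be sketched as follows, assuming a hypothetical small model, a hypothetical epoch value, and a temporary file path. `nn.DataParallel` keeps the original model as `.module`, so the wrapper's `state_dict()` keys gain a `module.` prefix that a plain, unwrapped model will not accept. Saving `dp_model.module.state_dict()` avoids the problem; `strip_module_prefix` is a hypothetical helper for checkpoints that were already saved from the wrapper itself.

```python
import os
import tempfile

import torch
import torch.nn as nn


def make_model():
    # Hypothetical small model standing in for ResNet50.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))


# With no GPU visible, DataParallel simply calls the wrapped module,
# so this sketch also runs on a CPU-only machine.
dp_model = nn.DataParallel(make_model())
optimizer = torch.optim.SGD(dp_model.parameters(), lr=0.01)

# Gotcha: the wrapper's keys look like "module.0.weight" instead of "0.weight".
wrapped_keys = list(dp_model.state_dict().keys())

ckpt_path = os.path.join(tempfile.mkdtemp(), "checkpoint.pt")
torch.save(
    {
        "epoch": 3,  # hypothetical value
        "model": dp_model.module.state_dict(),  # save the unwrapped model
        "optimizer": optimizer.state_dict(),
    },
    ckpt_path,
)

# Later (e.g., on a single GPU or the CPU): load into a plain, unwrapped model.
checkpoint = torch.load(ckpt_path, map_location="cpu")
plain_model = make_model()
plain_model.load_state_dict(checkpoint["model"])


def strip_module_prefix(state_dict):
    # For checkpoints saved from the wrapper itself (dp_model.state_dict()).
    prefix = "module."
    return {
        (k[len(prefix):] if k.startswith(prefix) else k): v
        for k, v in state_dict.items()
    }


plain_model.load_state_dict(strip_module_prefix(dp_model.state_dict()))
```

Saving the unwrapped module is usually the simpler choice, since the resulting checkpoint loads the same way with or without `DataParallel`.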

Part 2: Distributed Data-Parallel

  • 80:00-90:00 -> Quick recap, then introduce Distributed Data Parallelism (DDP) based on https://pytorch.org/docs/stable/notes/ddp.html?highlight=data%20parallel. Relate DDP to DP and explain how DDP fixes various shortcomings of DP.
  • 90:00-105:00 -> Show how to use Distributed Data-Parallel. This is more involved than Data-Parallel.
  • 105:00-110:00 -> 5 min buffer time to ensure the audience can start training their models.
  • 110:00-120:00 -> Common recipes and pitfalls, such as saving and loading checkpoints and how special layers like batch-norm work with DDP.
  • 120:00-125:00 -> How to use a model trained with DDP for inference on a single GPU. While this is generally straightforward, it is also easy to get wrong.
  • 125:00-135:00 -> 10 min buffer in case any of the previous steps takes longer than expected, and for questions so far.
  • 135:00-145:00 -> What other solutions does PyTorch provide for distributed training? These include the torch.distributed package, which offers fine-grained control over distributed model training, model parallelism (with or without DDP), RPC, etc. This part will be based entirely on slides and can be thought of as "How can I push the boundaries of model scaling?" It will also cover "Where to go from here?", with links to tutorials and existing implementations.
  • 145:00-150:00 -> Questions.
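The DDP workflow from Part 2 can be sketched in a single script. Assumptions: a hypothetical toy model and synthetic dataset, the `gloo` backend with a file-based rendezvous so it runs on the CPU, and `world_size=1` so it runs as one process. A real multi-GPU job would start one process per GPU (for example with `torchrun` or `torch.multiprocessing.spawn`), would typically use the `nccl` backend, and would pass `device_ids=[rank]` to DDP. The sketch covers wrapping the model, sharding data with `DistributedSampler`, saving only from rank 0, and reloading into a plain model for single-device inference.

```python
import os
import tempfile

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.nn.parallel import DistributedDataParallel as DDP
from torch.utils.data import DataLoader, TensorDataset
from torch.utils.data.distributed import DistributedSampler


def make_model():
    # Hypothetical toy model standing in for ResNet50.
    return nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))


def train(rank, world_size, init_file, ckpt_path):
    # One process per device; "gloo" works on CPU, "nccl" is the usual GPU choice.
    dist.init_process_group(
        "gloo", init_method=f"file://{init_file}", rank=rank, world_size=world_size
    )
    model = DDP(make_model())  # on GPU: DDP(make_model().to(rank), device_ids=[rank])
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

    # Synthetic data (hypothetical). DistributedSampler gives each rank its own shard.
    dataset = TensorDataset(torch.randn(64, 8), torch.randint(0, 4, (64,)))
    sampler = DistributedSampler(dataset, num_replicas=world_size, rank=rank, shuffle=True)
    loader = DataLoader(dataset, batch_size=8, sampler=sampler)

    for epoch in range(2):
        sampler.set_epoch(epoch)  # a different shuffle every epoch
        for x, y in loader:
            optimizer.zero_grad()
            loss = nn.functional.cross_entropy(model(x), y)
            loss.backward()  # DDP all-reduces gradients across ranks here
            optimizer.step()

    if rank == 0:
        # Save the unwrapped module so the checkpoint has no "module." prefix.
        torch.save(model.module.state_dict(), ckpt_path)
    dist.barrier()  # ensure the checkpoint exists before any rank moves on
    dist.destroy_process_group()


def load_for_inference(ckpt_path):
    # Single-device inference: a plain model, no process group needed.
    model = make_model()
    model.load_state_dict(torch.load(ckpt_path, map_location="cpu"))
    return model.eval()


tmp_dir = tempfile.mkdtemp()
init_file = os.path.join(tmp_dir, "dist_init")
ckpt_path = os.path.join(tmp_dir, "ddp_model.pt")

# world_size=1 keeps the sketch runnable as a single process.
train(rank=0, world_size=1, init_file=init_file, ckpt_path=ckpt_path)
inference_model = load_for_inference(ckpt_path)
with torch.no_grad():
    out = inference_model(torch.randn(2, 8))
print(out.shape)  # torch.Size([2, 4])
```

Saving `model.module.state_dict()` from rank 0 is what keeps the single-GPU inference step simple: the checkpoint has no `module.` prefix and no dependency on a process group. For the batch-norm question, `torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)` is PyTorch's utility for computing batch statistics across processes; it is aimed at GPU-based DDP training, so it is left out of this CPU sketch.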

One thing I am worried about is the setup for downloading and processing the ImageNet data. If I cannot figure out a reasonable strategy in the coming 2-3 weeks, I will replace ImageNet with the CIFAR100 dataset (which is very easy to download).

Speaker Info:

Hi! I am Shagun, a Research Engineer with Facebook AI Research. Before that, I was an MSc student at Mila (Quebec Artificial Intelligence Institute) with Prof Yoshua Bengio and Prof Jian Tang. My research focuses on lifelong reinforcement learning - training AI systems that can interact with and learn from the physical world (reinforcement learning) and consistently improve as they do so without forgetting the previous knowledge (lifelong learning).

My stack primarily consists of Python (and related ML/DS/visualization toolkits). I love to play with new technology and look forward to meeting new people at PyCon :).

Speaker Links:

Section: Data Science, Machine Learning and AI
Type: Workshop
Target Audience: Advanced
Last Updated: