PyTorch: An Imperative Style, High-Performance Deep Learning Library (1912.01703v1)

Published 3 Dec 2019 in cs.LG, cs.MS, and stat.ML

Abstract: Deep learning frameworks have often focused on either usability or speed, but not both. PyTorch is a machine learning library that shows that these two goals are in fact compatible: it provides an imperative and Pythonic programming style that supports code as a model, makes debugging easy and is consistent with other popular scientific computing libraries, while remaining efficient and supporting hardware accelerators such as GPUs. In this paper, we detail the principles that drove the implementation of PyTorch and how they are reflected in its architecture. We emphasize that every aspect of PyTorch is a regular Python program under the full control of its user. We also explain how the careful and pragmatic implementation of the key components of its runtime enables them to work together to achieve compelling performance. We demonstrate the efficiency of individual subsystems, as well as the overall speed of PyTorch on several common benchmarks.

Citations (38,235)

Summary

  • The paper introduces PyTorch as an imperative, Pythonic deep learning library that simplifies debugging and model development within the Python ecosystem.
  • It details a design leveraging operator overloading, efficient GPU execution, and a custom CUDA memory allocator to achieve competitive performance.
  • The study also highlights PyTorch’s extensibility through seamless interoperability with libraries like NumPy and robust support for distributed and parallel computation.

The paper introduces PyTorch, a deep learning library that combines usability and speed by employing an imperative, Pythonic programming style with efficient execution and hardware acceleration. The library supports code-as-a-model, simplifies debugging, and is consistent with scientific computing libraries, while maintaining performance through careful implementation and pragmatic design choices.

The authors detail the principles that guided PyTorch's implementation and their reflection in its architecture. They emphasize that every aspect of PyTorch operates as a standard Python program, fully controllable by the user. The paper also explains how the runtime components are implemented to achieve compelling performance. The efficiency of individual subsystems and the overall speed of PyTorch are demonstrated using common benchmarks.

The design principles of PyTorch include:

  • Being Pythonic, which means being a first-class member of the Python ecosystem and integrating with standard plotting, debugging, and data processing tools.
  • Prioritizing researchers by making model writing, data loaders, and optimizers as easy and productive as possible.
  • Providing pragmatic performance: accepting added internal complexity where necessary to deliver compelling performance, while also offering tools that give researchers manual control so they can optimize their own code.
  • Embracing the "Worse is Better" philosophy, which prioritizes simplicity and maintainability to enable faster adaptation and feature implementation.

The usability-centric design of PyTorch is reflected in treating deep learning models as ordinary Python programs. This approach keeps pace with the growing complexity of neural networks, from simple digit recognition to playing StarCraft, by forgoing graph metaprogramming and preserving Python's imperative programming model. Layers are expressed as Python classes whose constructors create and initialize parameters and whose forward methods process input activations. The same approach extends to optimizers and data loaders, which makes it easy to experiment with new techniques such as generative adversarial networks.
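The class-based style described above can be sketched as follows. This is a minimal example assuming a recent PyTorch install; the class name TwoLayerNet and the layer sizes are hypothetical choices for illustration, not code taken from the paper.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class TwoLayerNet(nn.Module):
    """Hypothetical two-layer perceptron written as an ordinary Python class."""

    def __init__(self, in_features: int, hidden: int, out_features: int):
        super().__init__()
        # The constructor creates and initializes the parameters.
        self.fc1 = nn.Linear(in_features, hidden)
        self.fc2 = nn.Linear(hidden, out_features)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # forward() is plain Python: any control flow or debugger works here.
        return self.fc2(F.relu(self.fc1(x)))


torch.manual_seed(0)
model = TwoLayerNet(in_features=4, hidden=8, out_features=1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

x, y = torch.randn(16, 4), torch.randn(16, 1)
loss = F.mse_loss(model(x), y)  # running forward records the graph on the fly
optimizer.zero_grad()
loss.backward()                 # reverse-mode AD fills each parameter's .grad
optimizer.step()                # the optimizer is also just a Python object
```

Because the model, loss, and optimizer are all regular Python objects, a breakpoint or print statement can be placed anywhere in forward() without special tooling.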

Interoperability and extensibility are key priorities. PyTorch supports bidirectional exchange of data with external libraries such as NumPy, and with other frameworks through the DLPack format, without copying the data. Users can add custom differentiable functions to the automatic differentiation system by subclassing torch.autograd.Function and implementing its forward() and backward() methods. New datasets can be added by subclassing torch.utils.data.Dataset and implementing the __getitem__ and __len__ methods. More generally, users can replace any component of PyTorch to meet the specific needs or performance requirements of their project.
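A minimal sketch of the three extension points mentioned above, assuming a recent PyTorch and NumPy install. Square and SquaresDataset are hypothetical names chosen for illustration; the base classes, ctx.save_for_backward, torch.from_numpy, and DataLoader are standard PyTorch APIs.

```python
import numpy as np
import torch
from torch.utils.data import DataLoader, Dataset


class Square(torch.autograd.Function):
    """Hypothetical custom differentiable function: y = x**2, dy/dx = 2x."""

    @staticmethod
    def forward(ctx, x):
        ctx.save_for_backward(x)  # keep the input for the backward pass
        return x * x

    @staticmethod
    def backward(ctx, grad_output):
        (x,) = ctx.saved_tensors
        return grad_output * 2 * x  # chain rule: upstream grad times local grad


class SquaresDataset(Dataset):
    """Hypothetical map-style dataset yielding (x, x**2) pairs."""

    def __init__(self, n: int):
        self.n = n

    def __len__(self):
        return self.n

    def __getitem__(self, idx):
        x = torch.tensor(float(idx))
        return x, x * x


# Zero-copy interop: the tensor and the NumPy array share one buffer.
arr = np.arange(3, dtype=np.float32)
t = torch.from_numpy(arr)
arr[0] = 10.0  # the write is visible through t as well

# Use the custom Function through .apply, never by calling forward directly.
x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
Square.apply(x).sum().backward()  # x.grad becomes 2 * x

# The default collate function stacks per-item tensors into batches.
batches = list(DataLoader(SquaresDataset(5), batch_size=2))
```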

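PyTorch uses operator overloading to build a representation of the computed function every time it is executed. It performs reverse-mode automatic differentiation, which computes the gradient of a scalar output (such as a loss) with respect to a multivariate input in a single backward pass. The system can also differentiate through code that mutates tensors in place: each tensor carries a version counter that is incremented on every in-place modification, and if a tensor saved for the backward pass has been modified, PyTorch raises an error instead of silently computing incorrect gradients.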
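To make the mechanism concrete, here is a toy, pure-Python sketch of reverse-mode differentiation through operator overloading, with a per-value version counter. It is not PyTorch's implementation: the Var class and its add_ method are hypothetical, values are scalars instead of tensors, and the sketch conservatively checks the version of every input, whereas PyTorch only checks tensors it actually saved for the backward pass.

```python
class Var:
    """Hypothetical scalar that records the operations applied to it."""

    def __init__(self, value, parents=(), local_grads=()):
        self.value = value
        self.grad = 0.0
        self.version = 0                 # bumped by every in-place update
        self._parents = parents
        self._local_grads = local_grads  # closures, evaluated during backward
        # Snapshot input versions so backward can detect later mutation.
        self._saved_versions = tuple(p.version for p in parents)

    def __add__(self, other):
        return Var(self.value + other.value, (self, other),
                   (lambda: 1.0, lambda: 1.0))

    def __mul__(self, other):
        # d(a*b)/da = b and d(a*b)/db = a, read lazily like saved tensors.
        return Var(self.value * other.value, (self, other),
                   (lambda: other.value, lambda: self.value))

    def add_(self, delta):
        """In-place update, mirroring PyTorch's trailing-underscore methods."""
        self.value += delta
        self.version += 1
        return self

    def backward(self):
        order, seen = [], set()

        def visit(v):  # depth-first topological sort of the recorded graph
            if id(v) not in seen:
                seen.add(id(v))
                for p in v._parents:
                    visit(p)
                order.append(v)

        visit(self)
        self.grad = 1.0
        for v in reversed(order):  # outputs before inputs: reverse mode
            for p, saved, local in zip(v._parents, v._saved_versions,
                                       v._local_grads):
                if p.version != saved:
                    raise RuntimeError("an input needed for gradient "
                                       "computation was modified in place")
                p.grad += v.grad * local()


a, b = Var(2.0), Var(3.0)
c = a * b + a      # the graph is recorded as the expression runs
c.backward()       # a.grad == b + 1 == 4.0, b.grad == a == 2.0

x, y = Var(2.0), Var(5.0)
z = x * y
x.add_(1.0)        # mutate x after z recorded it
caught = False
try:
    z.backward()
except RuntimeError:
    caught = True  # the version check rejects the stale graph
```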

Achieving efficient execution from the Python interpreter requires PyTorch to optimize every aspect of its execution, while also letting users apply additional optimization strategies of their own. Most of PyTorch is written in C++ for performance. The core libtorch library implements the tensor data structure, the CPU and GPU operators, and basic parallel primitives, along with the automatic differentiation system and the gradient formulas for built-in functions. Its Python bindings are generated from YAML metadata files, an approach that has also enabled bindings for other languages, such as NimTorch and HaskTorch. First-class C++ bindings and modeling libraries can be used in environments where Python is inconvenient, and the TorchScript engine allows PyTorch models described in Python code to run without Python.
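As an illustration of running Python-defined models without Python, the following sketch assumes a PyTorch version that still ships torch.jit. It compiles a hypothetical Gate module with data-dependent control flow to TorchScript and round-trips it through serialization; a saved file of this kind is what libtorch can load from C++ with torch::jit::load.

```python
import io

import torch
import torch.nn as nn


class Gate(nn.Module):
    """Hypothetical module whose branch depends on the input data."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if bool(x.sum() > 0):  # control flow is captured by the compiler
            return torch.relu(x)
        return -x


scripted = torch.jit.script(Gate())  # compile Python source to TorchScript
buffer = io.BytesIO()
torch.jit.save(scripted, buffer)     # serialize; no Python needed to run it
buffer.seek(0)
restored = torch.jit.load(buffer)

x = torch.tensor([-1.0, 2.0])
out = restored(x)                    # same result as Gate()(x)
```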

PyTorch maintains a separation between control flow (program branches, loops) and data flow (tensors and the operations performed on them). Control flow is handled by Python and optimized C++ code running on the host CPU, and it produces a sequence of operator invocations on the device. Operators can run on either CPU or GPU. On the GPU, PyTorch executes operators asynchronously using CUDA streams, so that host-side Python and C++ code keeps running while previously queued tensor operations execute on the device.
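The overlap between host and device can be illustrated with a toy model in plain Python, under the explicit assumption that a worker thread stands in for the GPU and a FIFO queue stands in for a CUDA stream. ToyStream and fake_kernel are hypothetical; nothing here calls CUDA.

```python
import queue
import threading
import time


class ToyStream:
    """Hypothetical stand-in for a device stream: jobs run in FIFO order
    on a worker thread while the host thread keeps going."""

    def __init__(self):
        self._jobs = queue.Queue()
        threading.Thread(target=self._run, daemon=True).start()

    def _run(self):
        while True:
            fn, args = self._jobs.get()
            try:
                fn(*args)
            finally:
                self._jobs.task_done()  # always mark done so join() can return

    def launch(self, fn, *args):
        """Enqueue work and return immediately (asynchronous launch)."""
        self._jobs.put((fn, args))

    def synchronize(self):
        """Block the host until all enqueued work has finished."""
        self._jobs.join()


results = []


def fake_kernel(i):
    time.sleep(0.01)  # pretend the device is busy
    results.append(i * i)


stream = ToyStream()
start = time.perf_counter()
for i in range(5):
    stream.launch(fake_kernel, i)  # the host does not wait here
launch_time = time.perf_counter() - start
stream.synchronize()               # analogous to torch.cuda.synchronize()
total_time = time.perf_counter() - start
```

In real PyTorch code, the host likewise blocks only when it needs a result on the CPU, for example when calling .item() on a CUDA tensor or torch.cuda.synchronize().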

Because CUDA's memory management routines are expensive (cudaFree, for example, may block its caller until all previously queued GPU work has completed), PyTorch implements a custom caching allocator. It incrementally builds up a cache of CUDA memory and reassigns freed blocks to later allocations without calling the CUDA APIs again. The allocator is tuned for the memory usage patterns of deep learning: it rounds allocations up to multiples of 512 bytes to avoid fragmentation and maintains a distinct memory pool for every CUDA stream.
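The reuse strategy can be sketched in a few lines of Python. This ToyCachingAllocator is hypothetical: device_malloc stands in for cudaMalloc, blocks are reused only when their rounded size matches exactly, and the real allocator's block splitting and other refinements are omitted. It does show the two tuning choices named above: 512-byte rounding and a separate pool per stream.

```python
from collections import defaultdict


def round_up(nbytes, granularity=512):
    """Round a request up to the next multiple of `granularity` bytes."""
    return -(-nbytes // granularity) * granularity


class ToyCachingAllocator:
    """Hypothetical caching allocator with one free-block pool per stream."""

    def __init__(self, device_malloc):
        self._device_malloc = device_malloc    # the expensive call being avoided
        # Cached free blocks, keyed by stream, then by rounded size.
        self._free = defaultdict(lambda: defaultdict(list))
        self._owner = {}                       # block -> (stream, size)

    def malloc(self, nbytes, stream=0):
        size = round_up(nbytes)
        cached = self._free[stream][size]
        block = cached.pop() if cached else self._device_malloc(size)
        self._owner[block] = (stream, size)
        return block

    def free(self, block):
        # Return the block to its stream's pool instead of releasing it.
        stream, size = self._owner.pop(block)
        self._free[stream][size].append(block)


calls = []


def fake_device_malloc(size):
    calls.append(size)
    return len(calls)                # opaque handle standing in for a pointer


alloc = ToyCachingAllocator(fake_device_malloc)
a = alloc.malloc(1000)               # rounded to 1024: calls the device
alloc.free(a)
b = alloc.malloc(700)                # also 1024 bytes: served from the cache
c = alloc.malloc(700, stream=1)      # other stream, separate pool: new call
```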

To work around Python's global interpreter lock (GIL), which prevents threads from executing Python code in parallel, PyTorch extends the Python multiprocessing module into torch.multiprocessing, a drop-in replacement that automatically moves the data of tensors sent to other processes into shared memory instead of sending it over the communication channel. This improves performance and makes it easier to write parallel programs in which each process drives its own GPU and gradients are synchronized using all-reduce primitives. The module also transparently handles the sharing of CUDA tensors, which makes techniques such as Hogwild (lock-free asynchronous parameter updates) easy to implement.
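The core idea, sending a small handle instead of the tensor's bytes, can be sketched with the standard library's multiprocessing.shared_memory module (Python 3.8+). This is an analogy rather than torch.multiprocessing itself, which applies the same idea to tensors automatically (see also Tensor.share_memory_()). For brevity both ends of the exchange live in one process; in practice the name would be sent to another process through a pipe or queue.

```python
import struct
from multiprocessing import shared_memory

# Producer: place three float64 values in a named shared-memory segment.
shm = shared_memory.SharedMemory(create=True, size=3 * 8)
struct.pack_into("3d", shm.buf, 0, 1.0, 2.0, 3.0)

# Only this short name would need to travel over the communication channel.
handle = shm.name

# Consumer: attach by name and access the same bytes without copying them.
view = shared_memory.SharedMemory(name=handle)
struct.pack_into("d", view.buf, 0, 42.0)        # write via the second mapping
values = struct.unpack_from("3d", shm.buf, 0)   # visible via the first one

view.close()
shm.close()
shm.unlink()  # the creator releases the segment once everyone is done
```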

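PyTorch uses reference counting to track the uses of each tensor and frees the underlying memory immediately when the count reaches zero. It accounts for both references inside libtorch and references held by user Python code by integrating with Python's own reference counting mechanism, so memory is released as soon as a tensor becomes unneeded rather than at some later garbage-collection pass, which keeps memory usage low.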
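Immediate release can be observed directly in CPython with weakref.finalize, whose callback runs as soon as the last strong reference disappears. FakeTensor is a hypothetical stand-in for a tensor that owns memory; the behavior shown relies on CPython's reference counting, and implementations without it (such as PyPy) may run the callback later.

```python
import weakref


class FakeTensor:
    """Hypothetical stand-in for a tensor that owns a memory buffer."""


released = []
t = FakeTensor()
alias = t                                # a second reference to the same object
weakref.finalize(t, released.append, "buffer released")

del t
after_first_del = list(released)         # []: `alias` still keeps it alive
del alias
after_last_del = list(released)          # ['buffer released']: freed at once
```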

In performance evaluations, PyTorch's ability to asynchronously execute dataflow on GPU was quantified using the built-in profiler, demonstrating near-perfect device utilization. NVIDIA profiler traces showed that the caching memory allocator reuses previously allocated regions, reducing the overhead of CUDA memory management functions. Benchmarks comparing PyTorch with other deep learning frameworks, including CNTK, MXNet, TensorFlow, Chainer, and PaddlePaddle, showed that PyTorch achieves competitive performance across various tasks.

The impact on ease of use was assessed through adoption, by measuring how often various machine learning tools are mentioned in arXiv e-prints since the initial release of PyTorch in January 2017. The results report the monthly number of mentions of the word "PyTorch" as a percentage of all mentions among these deep learning frameworks.

Future plans for PyTorch include improving speed and scalability with the PyTorch JIT, which allows PyTorch programs to be executed outside of the Python interpreter for further optimization. The authors also plan to enhance support for distributed computation by providing efficient primitives for data parallelism, as well as a Pythonic library for model parallelism based around remote procedure calls.
