PyTorch and CPU-GPU Synchronizations

Writing Fast PyTorch Code

GPUs
PyTorch
Triton

Author: Tomas Ruiz
Affiliation: Ludwig-Maximilians-Universität München
Published: January 7, 2026
Modified: January 8, 2026

TL;DR: This post is a guide to understanding and preventing CPU-GPU synchronizations, which will help you write fast and efficient PyTorch programs 🚀. I explain the concept with concrete PyTorch examples from this GitHub Gist and with profiles from NVIDIA Nsight Systems.

Introduction

CPU-GPU synchronizations (also known as host-device synchronizations) are a subtle but important concept for writing fast and efficient PyTorch programs. A CPU-GPU synchronization is a blocking operation: the CPU waits for the GPU to finish the work queued so far, and in the meantime it cannot schedule new work (PyTorch ops) on the GPU. The PyTorch Tuning Guide makes some basic suggestions for avoiding CPU-GPU synchronizations, e.g. avoiding calls to tensor.item(), printing a device tensor, or calling tensor.cpu() / tensor.cuda(), but it doesn't explain in depth what happens under the hood. Furthermore, there are other, subtler ways to run into CPU-GPU synchronizations. In this post, I dive deeper into synchronizations and explain how they slow down a PyTorch program. I show how to use the NVIDIA Nsight Systems profiler to identify and diagnose CPU-GPU synchronization issues. Finally, I show how to use the experimental PyTorch API torch.cuda.set_sync_debug_mode in unit tests to verify the absence of CPU-GPU synchronizations in production code.

Figure 1: Profile showing CPU-GPU synchronization issues (here triggered by printing a CUDA tensor). The profile is analyzed in detail in Section 4 (Profiling Analysis).

What is a CPU-GPU Synchronization?

To take full advantage of a GPU's speed, it must be kept busy (high GPU utilization). This depends on the CPU scheduling enough work. In a well-performing PyTorch program, the CPU quickly enqueues many operations (CUDA kernels), which the GPU then executes. The CPU is said to run ahead of the GPU, because it issues new work without waiting for previous work to complete. This is what is meant when PyTorch's CUDA operations are called asynchronous. If the CPU fails to schedule enough work, the GPU sits idle, waiting for work.
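To make this asynchrony visible, here is a minimal sketch (my own illustration, not code from the Gist; the function name and matrix sizes are arbitrary). It times how long a loop of matrix multiplications takes to hand control back to Python, and how long it takes until the work has actually finished. Note that the explicit torch.cuda.synchronize() call is itself a CPU-GPU synchronization.

```python
import time

import torch


def enqueue_vs_completion(size: int = 4096, repeats: int = 10) -> tuple[float, float]:
    """Time a loop of matmuls: (seconds until the loop returns, seconds until the work is done)."""
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.randn(size, size, device=device)
    out = a @ a  # warm-up: CUDA context creation, cuBLAS initialization
    if device == "cuda":
        torch.cuda.synchronize()

    start = time.perf_counter()
    for _ in range(repeats):
        out = a @ a  # on CUDA this only enqueues a kernel and returns immediately
    enqueued = time.perf_counter() - start

    if device == "cuda":
        torch.cuda.synchronize()  # explicit sync: block until the GPU has finished
    finished = time.perf_counter() - start
    return enqueued, finished
```

On a CUDA device, the first number is typically much smaller than the second, because the loop only enqueues kernels. On a CPU, the operations run synchronously, so both numbers are essentially equal.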

The intuition: As an analogy, think of a restaurant where a chef (CPU) schedules all the steps (PyTorch ops) for a big dinner party in advance, and the cooks (GPU) carry them out. A CPU-GPU synchronization is like the chef waiting to see how one specific step of one specific dish turned out before scheduling the remaining dishes and telling the cooks what to do next. This communication latency leaves the cooks idle while they wait for the chef. Instead, the chef should schedule the entire dinner plan in advance (run ahead).

Example

Let’s look at a concrete example. The example program, cpu_gpu_sync_example.py, executes a slow_operation() on the GPU, followed by a quick_operation() on the GPU. The quick operation simply counts the number of 1s in the result tensor, and its result is gathered in a results tensor. Both functions run in a loop to simulate a long-running program such as LLM inference, where the slow operation could be the LLM forward pass and the quick operation could be bookkeeping on the sampled tokens (a real vLLM example is explained in Section 7, In Practice: vLLM). To warm up the PyTorch code, I run slow_operation() and quick_operation() once before the hot loop. I annotated regions of the code with nvtx.annotate() to track GPU and CPU runtime during profiling.
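For readers who want to reproduce profiles like the ones below, here is a hypothetical stand-in for cpu_gpu_sync_example.py that follows the structure just described (it is not the script from the Gist). Assumptions: the bodies of slow_operation() and quick_operation() are invented placeholders, run() and annotate() are hypothetical helpers, and the NVTX ranges use torch.cuda.nvtx.range instead of the nvtx package's nvtx.annotate(), so no extra dependency is needed. The --do-print flag adds the print of a CUDA tensor that causes the synchronization.

```python
import argparse
import contextlib

import torch


def annotate(name: str):
    """NVTX range shown in Nsight Systems; a no-op context when CUDA is unavailable."""
    if torch.cuda.is_available():
        return torch.cuda.nvtx.range(name)
    return contextlib.nullcontext()


def slow_operation(x: torch.Tensor) -> torch.Tensor:
    # Placeholder for a heavy op such as an LLM forward pass: large matmuls,
    # then every entry is mapped to an integer in {0, 1, 2}.
    for _ in range(2):
        x = torch.tanh(x @ x)
    return (x * 1000).long().remainder(3)


def quick_operation(y: torch.Tensor, do_print: bool) -> torch.Tensor:
    # Cheap bookkeeping: count the 1s. The count stays on y's device.
    n_ones = (y == 1).sum()
    if do_print:
        print(n_ones)  # printing a CUDA tensor copies it to the CPU: a synchronization
    return n_ones


def run(do_print: bool, steps: int = 10, size: int = 1024) -> torch.Tensor:
    device = "cuda" if torch.cuda.is_available() else "cpu"
    x = torch.randn(size, size, device=device) / size**0.5
    results = torch.zeros(steps, dtype=torch.long, device=device)

    quick_operation(slow_operation(x), do_print=False)  # warm-up, outside the profiled region

    with annotate("interleaved-code"):
        for i in range(steps):
            with annotate("slow_operation"):
                y = slow_operation(x)
            with annotate("quick_operation"):
                results[i] = quick_operation(y, do_print)  # device-to-device copy, no sync
    return results


if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument("--do-print", action="store_true")
    args = parser.parse_args()
    results = run(do_print=args.do_print)
    # A single synchronization at the very end, after the annotated region.
    print("total ones:", results.sum().item())
```

Without --do-print, nothing inside the loop forces the CPU to wait for the GPU, so the CPU can run ahead.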

Profiling Analysis

We can run the NVIDIA Nsight Systems profiler with the command below to produce the timelines shown in Figure 2 and Figure 3. The script takes a flag --do-print to include a CPU-GPU synchronization in the quick operation: a print of a GPU tensor, which forces data back from the GPU to the CPU.

Without CPU-GPU synchronization:

nsys profile \
  -o profile-no-print.nsys-rep \
  python cpu_gpu_sync_example.py

With CPU-GPU synchronization:

nsys profile \
  -o profile-do-print.nsys-rep \
  python cpu_gpu_sync_example.py --do-print

The resulting profiles are shown and annotated in Figure 2 and Figure 3. I ran the code on an NVIDIA RTX 3090 GPU with Nsight Systems 2025.3.1.90.

To inspect the report:

  • Open the .nsys-rep file in the Nsight Systems UI.
  • Zoom into the NVTX-annotated region.
  • Look for long CPU-side CUDA API calls such as cudaStreamSynchronize, and correlate them with gaps in GPU utilization.
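The last step can also be scripted instead of done by eye. A minimal sketch, assuming the report was exported to SQLite with `nsys export --type sqlite` and that the export contains the tables CUPTI_ACTIVITY_KIND_RUNTIME (CUDA runtime API calls with start and end timestamps in nanoseconds) and StringIds (the function names). Table and column names may differ between Nsight Systems versions, so check them against your own export; the function name long_sync_calls is hypothetical.

```python
import contextlib
import sqlite3


def long_sync_calls(db_path: str, min_ms: float = 0.1) -> list[tuple[str, float]]:
    """Return (API name, duration in ms) for runtime calls containing 'Synchronize'.

    Assumed schema (Nsight Systems SQLite export, may vary by version):
    CUPTI_ACTIVITY_KIND_RUNTIME(start, "end", nameId) and StringIds(id, value).
    """
    query = """
        SELECT s.value, (r."end" - r.start) / 1e6 AS ms
        FROM CUPTI_ACTIVITY_KIND_RUNTIME AS r
        JOIN StringIds AS s ON r.nameId = s.id
        WHERE s.value LIKE '%Synchronize%'
          AND (r."end" - r.start) >= ?
        ORDER BY ms DESC
    """
    with contextlib.closing(sqlite3.connect(db_path)) as conn:
        rows = conn.execute(query, (min_ms * 1e6,)).fetchall()  # threshold in ns
    return [(str(name), float(ms)) for name, ms in rows]
```

Applied to an export of the run with the print, this should surface the long cudaStreamSynchronize calls seen as green bars in the CPU timeline.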

Run With CPU-GPU Synchronization

Figure 2: Profile with CPU-GPU synchronization.
  • We see gaps in the GPU utilization (the first blue bar from the top). These gaps indicate that the GPU is idle, waiting for work.
  • The CPU and GPU runtimes are similarly long: both timelines have roughly the same length along the horizontal axis. In a well-performing program, the CPU runtime should be shorter than the GPU runtime (the CPU runs ahead of the GPU).
  • In the CPU ops timeline, we see long green bars (cudaStreamSynchronize), which means that the CPU is blocked waiting for the GPU to return data.
  • In the GPU ops timeline, the slow operation takes longer than the quick operation, whereas in the CPU ops timeline it is the reverse. This means that the quick operation is the one blocking the CPU.

Healthy Run

Figure 3: Profile without CPU-GPU synchronization.
  • The CPU runs ahead of the GPU: it dispatches all the work quickly and finishes well before the GPU.
  • The GPU utilization has no gaps; the blue line showing utilization is continuous.
  • In the GPU ops timeline, the slow operation takes longer than the quick operation.
  • The full interleaved-code region runs faster than in the run with CPU-GPU synchronization (4.434 ms vs. 4.827 ms).

Dynamic Shapes Can Trigger Synchronizations

Besides the operations mentioned in the PyTorch Tuning Guide, there is another common pattern that triggers CPU-GPU synchronizations: dynamic shapes. One example of a dynamic shape is boolean indexing, x = t[bool_mask], where bool_mask is a boolean tensor on the GPU. The size of x cannot be determined by the CPU alone, because it depends on the number of True values in bool_mask. Therefore, PyTorch typically needs to fetch data from the GPU to determine how much memory to allocate for x. This creates a CPU-GPU synchronization.
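When only a reduction over the selected elements is needed, the dynamic shape can often be avoided altogether. A minimal sketch (my own illustration; both function names are hypothetical): torch.where keeps the full, static shape of t and zeroes out the unselected elements, so no allocation size depends on data residing on the GPU.

```python
import torch


def sum_selected_dynamic(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # t[mask] has a data-dependent length: on CUDA, PyTorch must read the
    # number of True values back to the CPU before allocating the result.
    return t[mask].sum()


def sum_selected_static(t: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    # torch.where keeps the shape of t, so every allocation size is known on the CPU;
    # unselected elements are replaced by zeros and do not change the sum.
    return torch.where(mask, t, torch.zeros_like(t)).sum()
```

The same idea gives a masked mean as sum_selected_static(t, mask) / mask.sum(), where both operands stay on the GPU.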

Another example of the same problem is slicing, x = t[:gpu_index], where gpu_index is a scalar integer tensor on the GPU. Once again, the size of x cannot be determined by the CPU alone, because it depends on the value of gpu_index.
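A similar trick applies to slicing with a GPU index: instead of materializing t[:gpu_index], keep the full length and zero out the tail. In this sketch (masked_prefix is a hypothetical helper), the mask is built from torch.arange and compared against gpu_index on the device, and it broadcasts over any trailing dimensions of t.

```python
import torch


def masked_prefix(t: torch.Tensor, gpu_index: torch.Tensor) -> torch.Tensor:
    """Return a tensor shaped like t whose rows at positions >= gpu_index are zero.

    Sums over the result along dim 0 match sums over t[:gpu_index], but the
    shape is static, so the CPU never needs to read the value of gpu_index.
    """
    positions = torch.arange(t.shape[0], device=t.device)
    keep = positions < gpu_index  # comparison runs on the device, no sync
    keep = keep.view(-1, *([1] * (t.dim() - 1)))  # broadcast over trailing dims
    return torch.where(keep, t, torch.zeros_like(t))
```

The trade-off is some wasted work on the masked tail in exchange for shapes that the CPU knows in advance.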

There are other ways to slice a torch tensor, e.g. x = t.narrow(0, start, length). However, this operation requires length to be a Python int. Passing a GPU tensor as length will typically trigger a conversion to a Python int on the CPU, which introduces a CPU-GPU synchronization.
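If an upper bound on length is known on the CPU, one sync-free workaround (my own sketch, not a PyTorch API) is to gather a fixed-size window of that bound and carry a validity mask alongside it. Here start and length are 0-dim integer tensors that may live on the GPU, while max_len is a Python int; window_no_sync and its parameters are hypothetical names.

```python
import torch


def window_no_sync(t: torch.Tensor, start: torch.Tensor, length: torch.Tensor, max_len: int):
    """Sync-free stand-in for t.narrow(0, int(start), int(length)) on a 1-D tensor.

    Returns (values, valid), both of shape (max_len,). Entries where valid is
    False are padding and are filled with zeros. Assumes start >= 0.
    """
    offsets = torch.arange(max_len, device=t.device)
    valid = offsets < length  # computed on the device, no sync
    idx = (start + offsets).clamp(max=t.shape[0] - 1)  # keep gather indices in bounds
    values = t[idx].masked_fill(~valid, 0)  # integer-tensor indexing: static output shape
    return values, valid
```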

A similar situation arises with x = torch.repeat_interleave(y, repeats) when repeats is a GPU tensor: the length of x equals repeats.sum(), which lives on the GPU. Interestingly, this API exposes an optional argument, output_size, that can be used to prevent the synchronization.
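A sketch of the output_size pattern in an LLM-inference-like setting (hypothetical function and argument names): the per-request token counts live on the GPU, but the total number of tokens in the batch is assumed to be known on the CPU already, e.g. tracked by a scheduler.

```python
import torch


def token_to_request_ids(num_tokens_per_req: torch.Tensor, total_tokens: int) -> torch.Tensor:
    """Map every token to the index of its request, e.g. [2, 3] -> [0, 0, 1, 1, 1].

    total_tokens is assumed to be a Python int already known on the CPU and equal
    to num_tokens_per_req.sum(). Passing it as output_size means PyTorch does not
    need to read the sum back from the GPU to size the output.
    """
    req_ids = torch.arange(num_tokens_per_req.shape[0], device=num_tokens_per_req.device)
    return torch.repeat_interleave(req_ids, num_tokens_per_req, output_size=total_tokens)
```

If total_tokens does not match the actual sum, the result is wrong, so the CPU-side bookkeeping must be kept consistent with the GPU tensor.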

What all these synchronization triggers have in common is that the length of the resulting tensor depends on data residing on the GPU. This is why dynamic shapes are so problematic. If, instead, the length of the resulting tensor can be known statically on the CPU side, then there is usually a way to avoid the CPU-GPU synchronization. PyTorch might not provide the exact API you need to do so. In that case, you can write a custom kernel in Triton. A further advantage of this approach is that you can fuse many sequential PyTorch operations into a single Triton kernel, which reduces the CPU-side overhead of dispatching many small kernels (around 2 µs to 3 µs per kernel). A concrete example is explained in Section 7 (In Practice: vLLM). The accompanying GitHub Gist also uses Triton kernels to fuse PyTorch ops together and prevent CPU-GPU synchronizations.
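Before reaching for Triton, a PyTorch-only alternative for compaction-style problems (my own illustration, not from the Gist; compact_static is a hypothetical name) is a fixed-capacity output: the selected rows are moved to the front with a stable sort, the output keeps the static shape of the input, and the number of valid rows stays on the GPU as a 0-dim tensor.

```python
import torch


def compact_static(t: torch.Tensor, mask: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Move the rows selected by a 1-D bool mask to the front, keeping a static shape.

    Returns (compacted, count): compacted has the same shape as t, its first
    `count` rows are t[mask] in their original order, and count is a 0-dim
    tensor on t.device. No output shape depends on the contents of mask.
    """
    # Keys: 0 for selected rows, 1 for the rest. A stable sort puts the selected
    # rows first and preserves the original order within each group.
    keys = (~mask).to(torch.int32)
    order = torch.sort(keys, stable=True).indices
    return torch.index_select(t, 0, order), mask.sum()
```

Downstream code then processes all rows and uses count (or a mask derived from it) to ignore the padding, which is the same static-shape idea a fused Triton kernel would exploit.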

Unit Testing for CPU-GPU Synchronizations

How can we make sure our code is free of synchronizations without running a profiler on every code section of interest? PyTorch provides an experimental feature that raises warnings or errors when CPU-GPU synchronizations occur: the torch.cuda.set_sync_debug_mode() function. Because it can raise an error at the synchronizing call, it can be used in unit tests to verify that code is free of CPU-GPU synchronizations. For example, it can be wrapped in a context manager:

import functools
from contextlib import contextmanager

import torch

@contextmanager
def fail_if_gpu_cpu_synchronization(fail: bool):
    """Within this context, GPU-CPU synchronization raises an error if `fail` is True."""

    new_mode = 2 if fail else 0  # 2 = "error": raise on sync; 0 = "default": no checks
    old_mode = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode(new_mode)
    try:
        yield
    finally:
        torch.cuda.set_sync_debug_mode(old_mode)


x = torch.arange(10, device="cuda")
with fail_if_gpu_cpu_synchronization(fail=True):
    print(x)  # Raises an error

The context manager can also be wrapped in a decorator for functions that should fail if they contain CPU-GPU synchronizations.

def on_gpu_cpu_synchronization(fail: bool):
    """
    Wrap a function to raise an error on GPU-CPU synchronization if `fail` is True.
    """

    def decorator(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            with fail_if_gpu_cpu_synchronization(fail):
                return fn(*args, **kwargs)

        return wrapper

    return decorator


@on_gpu_cpu_synchronization(fail=True)
def test_should_not_sync():
    x = torch.arange(10, device="cuda")
    y = x ** 2

Note that the decorator is applied only to test functions, so the experimental PyTorch API for raising synchronization errors never runs in production code, only in the tests. This keeps the experimental feature isolated from production code while still providing quick and valuable feedback on CPU-GPU synchronizations. I used this pattern to verify the absence of synchronizations in my GitHub Gist about vLLM TokenGroup. The error-raising mechanism itself is also unit tested in the test_helpers.py file in the Gist.
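A sketch of what tests of the mechanism itself could look like (my own illustration, not the contents of test_helpers.py). To stay self-contained, it re-declares a compact context manager, sync_debug_mode (a hypothetical name), that passes the string modes "error" and "default" accepted by torch.cuda.set_sync_debug_mode. The tests are skipped on machines without a CUDA device, and the boolean-indexing test assumes, as described in the dynamic-shapes section, that t[mask] synchronizes on your PyTorch version.

```python
import contextlib
import unittest

import torch


@contextlib.contextmanager
def sync_debug_mode(mode: str):
    """Temporarily set the CUDA sync debug mode ("default", "warn" or "error")."""
    old_mode = torch.cuda.get_sync_debug_mode()
    torch.cuda.set_sync_debug_mode(mode)
    try:
        yield
    finally:
        torch.cuda.set_sync_debug_mode(old_mode)


@unittest.skipUnless(torch.cuda.is_available(), "requires a CUDA device")
class TestSyncDebugMode(unittest.TestCase):
    def test_item_raises(self):
        x = torch.arange(10, device="cuda")
        with sync_debug_mode("error"), self.assertRaises(RuntimeError):
            x.sum().item()  # copies a scalar to the CPU: synchronization

    def test_boolean_indexing_raises(self):
        x = torch.arange(10, device="cuda")
        with sync_debug_mode("error"), self.assertRaises(RuntimeError):
            _ = x[x > 3]  # data-dependent output size: synchronization

    def test_elementwise_op_is_allowed(self):
        x = torch.arange(10, device="cuda")
        with sync_debug_mode("error"):
            _ = x**2  # plain kernel launch, no synchronization

    def test_mode_is_restored(self):
        before = torch.cuda.get_sync_debug_mode()
        with sync_debug_mode("error"):
            pass
        self.assertEqual(torch.cuda.get_sync_debug_mode(), before)
```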

Warning

This PyTorch API is experimental at the time of writing and does not cover all CPU-GPU synchronizations. For example, it does not cover the torch.distributed and torch.sparse namespaces (see docs).

In Practice: vLLM

I encountered CPU-GPU synchronizations while contributing speculative decoding to vLLM (PR #24322). The LLM forward passes are heavy operations, and preparing the input tokens for the forward pass requires many small PyTorch operations, which can introduce CPU-GPU synchronizations if not written carefully. Senior NVIDIA engineer Benjamin Chislett helped me understand what causes the synchronizations, in particular dynamic shapes, and the value of writing custom Triton kernels that fuse multiple sequential PyTorch operations (example: PR #28597). I thank him for his support and feedback! 💪

Summary

In this post, I explained what CPU-GPU synchronizations are and how to identify and diagnose them with the NVIDIA Nsight Systems profiler. I discussed dynamic shapes as a common trigger for CPU-GPU synchronizations, and the value of writing custom Triton kernels that fuse multiple sequential PyTorch operations. I also showed how to use the experimental PyTorch API torch.cuda.set_sync_debug_mode in unit tests to verify the absence of CPU-GPU synchronizations in production code. This guide aims to help engineers prevent CPU-GPU synchronizations, which is key to writing fast and efficient PyTorch programs.

Further References

For a deep dive into kernel benchmarking in practice, I recommend this YouTube lecture by NVIDIA engineering manager Georgii Evtushenko.

Citation

BibTeX citation:
@online{ruiz2026,
  author = {Ruiz, Tomas},
  title = {PyTorch and {CPU-GPU} {Synchronizations}},
  date = {2026-01-07},
  url = {https://tomasruizt.github.io/posts/08_cpu_gpu_synchronization/},
  langid = {en}
}
For attribution, please cite this work as:
Ruiz, Tomas. 2026. “PyTorch and CPU-GPU Synchronizations.” January 7, 2026. https://tomasruizt.github.io/posts/08_cpu_gpu_synchronization/.