A Guide to Classify 35,000 Videos with a 72B Multimodal LLM on the vLLM Engine

GPUs
vLLM
Transformers
Author: Tomas Ruiz
Affiliation: Ludwig-Maximilians-Universität München
Published: September 22, 2025

Summary

In this post, I describe how we used the vLLM inference engine to classify 35k videos collected from TikTok for a research project. I share lessons learned about computing on SLURM, parallelism strategies (tensor and data parallelism), and practical tips for fault tolerance and video inference with vLLM. I also dive into concrete token-throughput statistics of our workload, which is extremely prefill-heavy. Finally, I show how we corrected for the bias of the Qwen2.5-VL-72B model to make statistically valid statements about our dataset. As a bonus, I share Python code for video classification with the vLLM library in the Python Code section, and the SLURM file for our workload in the SLURM File section.

Motivation

For this research project, I collaborated with my colleague at LMU Munich, Renata Topinkova, a computational social scientist. Our goal was to detect subtle product sales by influencers on TikTok. We presented a poster with our findings at the IC2S2 conference in 2025.

This post focuses mainly on our large inference workload, which required running a large multimodal LLM (72B parameters) on 35k videos as quickly as possible before a deadline. The only way to finish in a reasonable time was to use a dedicated inference engine. I used vLLM (Kwon et al. 2023) because it is the engine I am most familiar with. Alternatives such as SGLang (Zheng et al. 2024) or NVIDIA Dynamo are also worth trying.

Dataset and Model

Our dataset consists of 35,000 multimodal TikTok posts made up of videos, images, audio, and text. It is around 180 GB in size, most of which is video. The dataset is a specific subsample of a much larger and more general dataset of ~1M posts that we have.

To analyze this dataset, we needed a model capable of consuming multiple input modalities. On top of that, the cues we looked for in the videos were often subtle and implicit, for example:

“Does the post suggest a follow-up transaction or sales interaction?”

So we also needed a model that grasps subtle communication cues in the videos. Larger models, typically with 70B+ parameters, are much better at this than the widely used 8B models.

We decided to use the Qwen2.5-VL-72B-Instruct model by the Qwen team (Bai et al. 2025), available on Hugging Face. Many other multimodal models require you to cut videos into individual frames and pass them as a “stack of images”, but Qwen2.5-VL consumes videos natively. The key innovation enabling this is Multimodal Rotary Position Embedding (M-RoPE), which decomposes the positional embedding into separate temporal, height, and width components, so that the time position of each video frame is encoded in its own dimension (Wang et al. 2024). Thus, Qwen2.5-VL can model the temporal relationship between individual video frames, unlike models that treat a video as a stack of independent images.
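To make the idea concrete, here is a simplified sketch of how M-RoPE assigns position ids, following the description in the Qwen2-VL paper (Wang et al. 2024). It is not the actual Qwen or vLLM implementation: the function name mrope_position_ids is hypothetical, each grid cell stands for one visual token, and the sketch ignores Qwen2.5-VL's additional alignment of temporal ids to absolute time.

```python
def mrope_position_ids(n_text_before, frames, height, width, n_text_after):
    """Simplified M-RoPE position ids as (temporal, height, width) triples.

    Sketch only: follows the Qwen2-VL description and omits Qwen2.5-VL's
    alignment of temporal ids to absolute time.
    """
    ids = []
    # Text tokens use the same id in all three components, i.e. ordinary 1D RoPE.
    for i in range(n_text_before):
        ids.append((i, i, i))
    start = n_text_before
    # Video tokens: the temporal id advances once per frame, while the height and
    # width ids follow the token's row and column in the patch grid.
    for t in range(frames):
        for h in range(height):
            for w in range(width):
                ids.append((start + t, start + h, start + w))
    # Text after the video resumes from the largest id used so far, plus one.
    next_id = max((max(triple) for triple in ids), default=-1) + 1
    for i in range(n_text_after):
        ids.append((next_id + i,) * 3)
    return ids
```

For example, mrope_position_ids(2, 2, 2, 2, 1) assigns (2, 2, 2) to the top-left token of the first frame and (3, 2, 2) to the same spatial position in the second frame: the two differ only in their temporal id, which is how the model can tell frames apart in time.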

Audio: An important shortcoming of this model is that it does not process audio natively. To work around this, we transcribed the audio with Whisper (Radford et al. 2023) and included the transcript in the text input. Some models that process audio natively have been released recently, such as Qwen2.5-Omni and Phi-4 (Xu et al. 2025; Abdin et al. 2024). However, vLLM support for them was limited until recently, the models remain relatively small, and some of them output speech, which we did not need. Nevertheless, they are a very promising direction to explore.
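A minimal sketch of this transcription step, assuming the open-source openai-whisper package (which needs ffmpeg to decode the audio track of a video file). The helper names transcribe_audio and build_prompt, and the prompt layout, are illustrative assumptions rather than the exact setup we used.

```python
def transcribe_audio(model, media_path):
    """Transcribe the audio track of a video file with a loaded openai-whisper model."""
    # model.transcribe decodes the audio via ffmpeg, so a path to an .mp4 file works.
    return model.transcribe(media_path)["text"].strip()


def build_prompt(question, transcript):
    """Prepend the transcript to the classification question (hypothetical format)."""
    transcript = transcript.strip() or "(no speech detected)"
    return f"Audio transcript of the video:\n{transcript}\n\nQuestion: {question}"
```

Typical use: import whisper, load the model once (e.g., whisper.load_model("turbo"); the model size is an assumption), call transcribe_audio(model, path) for each video, and send the output of build_prompt as the text part of the request. Loading the model once and reusing it avoids reloading the weights for every video.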

Video Input: Feeding a video to the model through vLLM requires a specific setup. Many examples of image analysis with vLLM or the OpenAI API serialize the image into base64 and send it in the request payload (e.g., in the OpenAI docs, the Google docs, and even the vLLM examples). However, I found that this did not scale to videos: the server blocked while processing the large payloads, leaving the GPU idle. Instead, we send only the local file path of the video in the payload. vLLM then reads the file from disk, processes it, and feeds it to the model, which removes the serialization and deserialization overhead. For this to work, you must give the vLLM server explicit permission to read from a specific directory (and its subdirectories) using this option:

vllm serve ... --allowed-local-media-path=/path/to/your/files/

The way to include a video file path in the payload is shown below. Note that the file:// prefix is required. Both this file path and the path passed to --allowed-local-media-path must be absolute paths. A full example is available in the vLLM documentation.

filepath = "/path/to/your/files/my-video.mp4"
messages = [
  {"role": "user", "content": [
    {"type": "text", "text": "What is in the video?"},
    {"type": "video_url", "video_url": {"url": f"file://{filepath}"}}
  ]}
]
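A slightly fuller sketch of the same request, assuming a vLLM server on its default port 8000 that serves the model under the name Qwen/Qwen2.5-VL-72B-Instruct, and the official openai Python SDK as the client. The helper names and max_tokens=512 are hypothetical choices for illustration.

```python
from pathlib import Path


def video_messages(video_path, prompt):
    """Chat messages that reference a local video by path instead of embedding it."""
    path = Path(video_path)
    if not path.is_absolute():
        raise ValueError("vLLM expects an absolute path under --allowed-local-media-path")
    return [{"role": "user", "content": [
        {"type": "text", "text": prompt},
        {"type": "video_url", "video_url": {"url": f"file://{path}"}},
    ]}]


def classify_video(video_path, prompt, base_url="http://localhost:8000/v1",
                   model="Qwen/Qwen2.5-VL-72B-Instruct"):
    """Send one request to a running vLLM server (assumed URL and model name)."""
    from openai import OpenAI  # imported here so video_messages works without the SDK
    # vLLM does not check the key unless the server was started with --api-key.
    client = OpenAI(base_url=base_url, api_key="EMPTY")
    response = client.chat.completions.create(
        model=model, messages=video_messages(video_path, prompt), max_tokens=512)
    return response.choices[0].message.content
```

Rejecting relative paths on the client side surfaces a misconfiguration immediately, instead of as a failed request on the server.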

GPU Cluster

For our workload, we needed to run multiple instances of this 72B model in parallel to process the dataset quickly (i.e., within a few hours). To this end, we used the LRZ AI Systems cluster, specifically the BayernKI partition, which has 30 nodes, each with 4 Hopper H100 GPUs (94 GB VRAM per GPU) connected over NVLink. These are powerful GPUs, but the 72B model does not fit on a single one, even in bfloat16 precision: the model weights alone require around 144 GB of VRAM (72B parameters × 2 bytes). The weights would fit on a single Blackwell B200 GPU (180 GB VRAM), but those were not available on our cluster. Beyond the weights, the model also needs memory for the KV cache, activations, etc. Therefore, we allocated 4 GPUs, totaling 376 GB VRAM, per model instance. Jobs on the cluster are launched with the SLURM job scheduler (Yoo, Jette, and Grondona 2003).
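The memory arithmetic can be sketched in a few lines. The weight estimate uses only the numbers above; the KV-cache estimate additionally assumes the language backbone's configuration (80 layers, 8 key-value heads, head dimension 128), which you should verify against the model's config.json, plus the ~15.8k tokens per request and 16 concurrent requests reported later in this post. It ignores the vision encoder, activations, and vLLM's --gpu-memory-utilization fraction (0.9 by default), so the real headroom is smaller.

```python
# Back-of-the-envelope VRAM budget for one tensor-parallel model instance.
# The KV-cache hyperparameters below are assumptions (check the model's config.json).
NUM_PARAMS = 72e9        # parameters
BYTES_PER_PARAM = 2      # bfloat16
NUM_LAYERS = 80          # assumed
NUM_KV_HEADS = 8         # assumed (grouped-query attention)
HEAD_DIM = 128           # assumed


def kv_cache_bytes_per_token(layers=NUM_LAYERS, kv_heads=NUM_KV_HEADS,
                             head_dim=HEAD_DIM, bytes_per_value=BYTES_PER_PARAM):
    # Factor 2: one key and one value vector per layer and KV head.
    return 2 * layers * kv_heads * head_dim * bytes_per_value


def memory_budget(num_gpus=4, vram_per_gpu_gb=94.0,
                  tokens_per_request=15_800, concurrent_requests=16):
    weights_gb = NUM_PARAMS * BYTES_PER_PARAM / 1e9
    total_vram_gb = num_gpus * vram_per_gpu_gb
    kv_gb = kv_cache_bytes_per_token() * tokens_per_request * concurrent_requests / 1e9
    return {
        "weights_gb": weights_gb,                       # 144 GB: too big for one 94 GB GPU
        "weights_per_gpu_gb": weights_gb / num_gpus,    # tensor parallelism shards the weights
        "total_vram_gb": total_vram_gb,
        "kv_cache_gb": kv_gb,                           # for all in-flight requests
        "headroom_gb": total_vram_gb - weights_gb - kv_gb,
    }
```

Under these assumptions, 16 concurrent requests of ~15.8k tokens need roughly 83 GB of KV cache, which fits next to the 144 GB of weights in 376 GB; thousands of concurrent requests clearly would not.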

Containers: On the LRZ AI Systems, you cannot install or customize a Python environment directly on the worker nodes. Instead, you have to use NVIDIA enroot containers. These are conceptually similar to Docker containers, and the LRZ provides introductory documentation on enroot. The most useful commands to create and start containers are enroot import, enroot create, and enroot start. The SLURM File section shows how we used these commands in our SLURM file.
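The container lifecycle can be sketched as a sequence of enroot commands. This hypothetical helper only builds the command strings and does not run them; the image URI, the .sqsh file name, and the mount are placeholders, and depending on the cluster you may need extra enroot start flags (e.g., --rw for a writable root filesystem).

```python
import shlex


def enroot_lifecycle(image_uri, sqsh_path, container_name, command, mounts=()):
    """Return shell commands to import, create, run, and remove an enroot container."""
    start = ["enroot", "start"]
    for src, dst in mounts:
        start += ["--mount", f"{src}:{dst}"]  # make host directories visible inside
    start += [container_name, *command]
    steps = [
        ["enroot", "import", "--output", sqsh_path, image_uri],     # pull image into a squashfs file
        ["enroot", "create", "--name", container_name, sqsh_path],  # unpack it as a named container
        start,                                                      # run the command inside it
        ["enroot", "remove", "--force", container_name],            # clean up after the job
    ]
    return [shlex.join(step) for step in steps]
```

In practice, the import step only needs to run once, since the resulting .sqsh file can be reused by every job; the create, start, and remove steps are per job.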

Data Transfer: To move data onto the LRZ AI Systems, the LRZ provides Globus Data Transfer, software designed for transferring large amounts of data between servers. It is preinstalled on the LRZ AI Systems, and we installed it on our own server as well. We then used the Globus web interface to transfer our dataset to the GPU cluster. Transfers run asynchronously, and Globus emails you when a transfer is done. It also works in the other direction, to transfer data back from the LRZ to your server.

Parallelism Strategies

vLLM can split a single model across multiple GPUs. The strategy we used is called Tensor Parallelism: each weight tensor of the model is split across several GPUs, so that no single GPU must hold the entire model in memory. The drawback is that the GPUs must communicate with each other in every layer to compute the model outputs. With the high-bandwidth NVLink, however, this overhead is an acceptable price for being able to run larger models. Enabling it in vLLM takes a single option: vllm serve ... --tensor-parallel-size=4.
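The core idea of tensor parallelism can be sketched in plain Python for a single linear layer \(Y = XW\). Each simulated "GPU" holds a shard of \(W\): with a column split, each shard produces a block of output columns, and the blocks are concatenated (an all-gather); with a row split, each shard produces a partial sum, and the partial sums are added (an all-reduce). This is a toy illustration of the technique, not how vLLM implements it, and it assumes the split dimension is divisible by the number of shards.

```python
def matmul(a, b):
    """Plain-Python product of an (n x k) matrix a and a (k x m) matrix b."""
    return [[sum(a[i][p] * b[p][j] for p in range(len(b))) for j in range(len(b[0]))]
            for i in range(len(a))]


def column_parallel(x, w, shards):
    """Each shard holds a block of W's columns and computes a block of output columns."""
    step = len(w[0]) // shards
    w_shards = [[row[s * step:(s + 1) * step] for row in w] for s in range(shards)]
    partials = [matmul(x, ws) for ws in w_shards]
    # All-gather: concatenate the output blocks along the column dimension.
    return [sum((p[i] for p in partials), []) for i in range(len(x))]


def row_parallel(x, w, shards):
    """Each shard holds a block of W's rows and the matching columns of X."""
    step = len(w) // shards
    w_shards = [w[s * step:(s + 1) * step] for s in range(shards)]
    x_shards = [[row[s * step:(s + 1) * step] for row in x] for s in range(shards)]
    partials = [matmul(xs, ws) for xs, ws in zip(x_shards, w_shards)]
    # All-reduce: sum the partial outputs element-wise.
    return [[sum(p[i][j] for p in partials) for j in range(len(w[0]))]
            for i in range(len(x))]
```

Both functions return the same result as matmul(x, w). In Megatron-LM-style tensor parallelism, which vLLM uses, the two splits are paired so that an MLP block needs only one all-reduce: the first linear layer is split by columns and the second by rows.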

In our workload, each model instance processed only a portion of the full dataset, a strategy called Data Parallelism. I implemented data parallelism on the client side, outside vLLM: each client decides which data to send to its vLLM server based on its own rank within the full workload. SLURM supports a feature called Job Arrays, which lets you submit a collection of similar jobs. Each job (task) in the collection gets a unique integer ID (e.g., 0, 1, …, 100), which SLURM sets in the environment variable SLURM_ARRAY_TASK_ID. The client reads this ID (its rank) to determine which partition of the dataset it must process. For starters, I found the descriptions of these parallelism strategies on the Hugging Face Parallelism Methods page useful. The SLURM File section shows how we used SLURM_ARRAY_TASK_ID in our SLURM file.
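A minimal sketch of client-side data parallelism, assuming the array indices run from 0 to N-1 (e.g., --array=0-9). It reads SLURM_ARRAY_TASK_ID and SLURM_ARRAY_TASK_COUNT, both set by SLURM for job arrays; the contiguous split and the function names are illustrative choices, not necessarily what our code does.

```python
import os


def shard_for_rank(items, rank, world_size):
    """Contiguous, near-equal split: the first len(items) % world_size ranks get one extra item."""
    base, extra = divmod(len(items), world_size)
    start = rank * base + min(rank, extra)
    end = start + base + (1 if rank < extra else 0)
    return items[start:end]


def my_shard(items):
    """The part of the dataset this SLURM array task should process."""
    rank = int(os.environ.get("SLURM_ARRAY_TASK_ID", "0"))
    world_size = int(os.environ.get("SLURM_ARRAY_TASK_COUNT", "1"))
    return shard_for_rank(items, rank, world_size)
```

All ranks must list the dataset in the same order (e.g., by sorting the file paths) so that the shards are disjoint and together cover every post.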


Different Parallelism Strategies: Data Parallelism partitions the dataset so that \(N\) different model instances process different parts of the dataset. Tensor Parallelism distributes a model instance across \(K\) separate GPUs. In our workload \(N=10\) and \(K=4\).

Fault Tolerance

vLLM supports two inference modes. In batch mode, it takes a batch of requests at once, processes them, and returns all results together. Alternatively, it can run as an inference server behind an OpenAI-compatible API, which accepts requests asynchronously and returns each result as soon as it is available. We expected batch processing to be more efficient for our large workload, but at the time of writing, the batch processing code blocked until all requests had completed. If any error occurred during processing, the entire batch failed, and failures were more common than one might expect: for example, an exception was raised mid-workload because a long video did not fit into the context window.

Therefore, I opted for the asynchronous inference mode, which offered better fault tolerance. As soon as a request completed, we appended its result to a file in JSON Lines format. If the workload failed halfway, we still had all previous results on disk. If a single request failed (e.g., because the video was too long), the failure was also logged to the results file. This made it possible to restart a failed workload later by determining which requests were still missing from the results file. The JSON Lines format allows efficient, append-only incremental writes, but it is inefficient for downstream use because it takes up a lot of disk space and is slow to read. Therefore, once the result files were complete, we converted them to Parquet, a columnar format with efficient lossless compression that is very fast to read.
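A minimal sketch of this append-and-resume pattern. The record layout (an id field plus either a result or an error field) and the function names are assumptions for illustration. Failed requests are written to the file too, so they count as done and are not retried on restart, matching the behavior described above; the Parquet conversion assumes pandas with pyarrow installed.

```python
import json
import os


def append_result(path, record):
    """Append one result as a JSON line and flush it to disk right away."""
    with open(path, "a", encoding="utf-8") as f:
        f.write(json.dumps(record, ensure_ascii=False) + "\n")
        f.flush()
        os.fsync(f.fileno())


def completed_ids(lines):
    """Ids already present in a results file (successes and logged failures alike)."""
    done = set()
    for line in lines:
        try:
            done.add(json.loads(line)["id"])
        except (json.JSONDecodeError, KeyError, TypeError):
            continue  # e.g., a last line truncated by a crash mid-write
    return done


def pending_ids(all_ids, results_path):
    """Ids that still need a request, preserving the original order."""
    if not os.path.exists(results_path):
        return list(all_ids)
    with open(results_path, encoding="utf-8") as f:
        done = completed_ids(f)
    return [i for i in all_ids if i not in done]


def jsonl_to_parquet(jsonl_path, parquet_path):
    import pandas as pd  # needs pandas plus pyarrow (or fastparquet)
    pd.read_json(jsonl_path, lines=True).to_parquet(parquet_path, index=False)
```

Because asyncio runs all coroutines in a single thread, calling the synchronous append_result when a request completes cannot interleave two lines in the file.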

Sending thousands of multimodal requests to the server concurrently choked it: the server accepted all requests and started preprocessing all the videos (e.g., loading them from disk), but made little progress on generating tokens. On top of this, processing too many concurrent requests quickly filled up the KV cache. This was visible in the vLLM logs: KV cache utilization spiked to 100% and stayed there for most of the workload. To cope, the server evicted blocks from the KV cache and put most requests back in the queue, where they effectively sat idle. To prevent this, I implemented a client-side throttling mechanism that limits the number of requests sent to the server at the same time. We set the limit to 16 concurrent requests for Qwen2.5-VL-72B. Getting this to work properly with Python's asyncio library and correct timeouts was tricky, but you can copy the code from my repo or use the code from the Python Code section directly.
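The core of such a throttle can be sketched with asyncio.Semaphore. This is not the code from my repo: send stands for any coroutine that issues one request, and the limit and timeout are parameters. The timeout starts only after a request acquires a slot, so time spent waiting in the client-side queue does not count against it.

```python
import asyncio


async def run_throttled(requests, send, max_concurrent=16, timeout_s=600.0):
    """Run send(req) for every request, with at most max_concurrent in flight.

    Results come back in input order. A request that raises or times out yields
    its exception instead of a result, so one bad video cannot abort the run.
    """
    semaphore = asyncio.Semaphore(max_concurrent)

    async def one(req):
        async with semaphore:  # wait here while max_concurrent requests are in flight
            try:
                # The timeout starts only after a slot is acquired.
                return await asyncio.wait_for(send(req), timeout=timeout_s)
            except Exception as exc:  # includes asyncio.TimeoutError
                return exc

    return await asyncio.gather(*(one(r) for r in requests))


async def demo():
    """Usage example with a fake request in place of the vLLM server."""
    in_flight = peak = 0

    async def fake_send(i):
        nonlocal in_flight, peak
        in_flight += 1
        peak = max(peak, in_flight)
        await asyncio.sleep(0.01)  # stands in for the HTTP round trip
        in_flight -= 1
        if i == 3:
            raise ValueError("video too long")
        return 2 * i

    results = await run_throttled(range(10), fake_send, max_concurrent=4)
    return results, peak
```

In practice, send would wrap a call to the vLLM server, e.g. through openai.AsyncOpenAI. Running asyncio.run(demo()) shows that at most four fake requests are ever in flight and that the failing request only produces an exception object in its own result slot.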

Preventing Clashes

Multiple model instances might be co-located on the same node; e.g., an 8-GPU node can host two 4-GPU jobs. To prevent co-located jobs from clashing over the same vLLM port, we set a different port for each model instance based on its rank (SLURM_ARRAY_TASK_ID). For the same reason, we gave each model instance a unique enroot container name based on SLURM_ARRAY_TASK_ID, and we deleted the enroot container after the job finished. To diagnose the failure of an individual task within a job array, it's useful to redirect the stdout and stderr of each task to separate files, as shown in lines 5 and 6 of our SLURM file in the SLURM File section. SLURM can also restart individual tasks of a job array: sbatch --array=<failed-task-id> <slurm-file>. The --array option must come before the script name, because sbatch passes any arguments after the script to the script itself.
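Deriving the per-instance settings from the task ID takes only a few lines. In this sketch, the base port 8000, the container-name pattern, and the function names are hypothetical choices. Offsetting ports by task ID avoids clashes only among tasks of the same job array, not with unrelated jobs that happen to share the node.

```python
import os
import shlex


def instance_settings(base_port=8000, task_id=None):
    """Port and enroot container name that are unique within one job array."""
    if task_id is None:
        task_id = int(os.environ["SLURM_ARRAY_TASK_ID"])
    return {"task_id": task_id,
            "port": base_port + task_id,
            "container": f"vllm-task-{task_id}"}


def vllm_serve_command(model, port, media_dir, tp_size=4):
    """The vllm serve command line for one model instance."""
    return shlex.join(["vllm", "serve", model,
                       "--port", str(port),
                       "--tensor-parallel-size", str(tp_size),
                       f"--allowed-local-media-path={media_dir}"])
```

The client of each task must then send its requests to the same port, e.g. http://localhost:<port>/v1.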

Prefill-Decode Ratio

We processed over 552 million input tokens and generated over 9 million output tokens. On average, each request had 15.8k input tokens and produced over 250 output tokens. The ratio of input to output tokens is highly skewed (61 to 1), which means that most of the compute in our workload was spent not on decoding long answers but on processing the many input tokens. Such a workload is called “prefill-heavy”, in contrast to “decode-heavy” workloads such as solving mathematical problems, where the model generates a long answer including its reasoning. Prefill is generally compute-bound rather than memory-bandwidth-bound, so our workload likely makes good use of the GPUs' compute capacity. Figure 1 and Figure 2 show the distribution of input and output tokens for the entire workload.
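The per-request averages and the ratio follow from the totals (approximately, since the totals are lower bounds):

\[
\frac{552 \times 10^{6}}{35{,}000} \approx 15{,}800 \ \text{input tokens/request}, \qquad
\frac{9 \times 10^{6}}{35{,}000} \approx 257 \ \text{output tokens/request}, \qquad
\frac{552 \times 10^{6}}{9 \times 10^{6}} \approx 61.
\]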

Throughput Stats

From the logs, I estimated that each model instance completed a request every ~2.95 seconds on average (~3 s for simplicity). Dividing the ~15k input tokens and ~250 output tokens per request by 3 s gives a throughput of ~5k input tokens/s and ~83 output tokens/s per model instance (≈5,080 tokens/s in total).
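The same calculation as a small helper, using the rounded per-request figures above. The function name is illustrative.

```python
def throughput(input_tokens_per_req, output_tokens_per_req, seconds_per_req):
    """Tokens per second for one model instance, given the average time per request."""
    input_tok_s = input_tokens_per_req / seconds_per_req
    output_tok_s = output_tokens_per_req / seconds_per_req
    return {"input_tok_s": input_tok_s,
            "output_tok_s": output_tok_s,
            "total_tok_s": input_tok_s + output_tok_s}
```

For example, throughput(15_000, 250, 3.0) gives 5,000 input and about 83 output tokens per second, i.e., about 5,083 tokens/s in total.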

This throughput is high compared to other benchmarks. The vLLM documentation reports a synthetic benchmark with a ratio of ~9:1 input to output tokens, which achieves a lower throughput of 1461 tokens/s. Our higher throughput may simply be due to better GPUs (H100s vs. A100s). NVIDIA has also reported performance benchmarks for the similarly sized Llama-3.3-70B-Instruct model running on 4 H100 GPUs with their NIM inference framework. However, the peak throughputs reported by NVIDIA vary a lot, depending not only on the ratio of input to output tokens but also on the total number of tokens per request, ranging from 349 to 6085 tokens/s (Table 1 lists their best result). The throughput we achieved is competitive without being “too good to be true” 🚀.

Table 1: Throughput Comparison. All benchmarks use instruct models. Input and output tokens are per request. The token ratio indicates whether a workload is prefill- or decode-heavy. The NVIDIA numbers are the most optimistic from their benchmarks.

| Benchmark | Model | Input tokens | Output tokens | Token ratio | GPUs | Throughput \(\uparrow\) |
|---|---|---|---|---|---|---|
| NVIDIA NIM | Llama-3.3-70b | 1000 | 1000 | 1:1 | 4× H100 80GB | 6085 tok/s |
| Ours (Videos) | Qwen2.5-VL-72B | 15000 | 250 | 61:1 | 4× H100 96GB | 5020 tok/s |
| vLLM (Synthetic) | Qwen2.5-VL-72B | 8000 | 900 | 9:1 | 4× A100 80GB | 1461 tok/s |

The benchmarks ran on different GPUs. I looked up the difference between the H100 80GB and H100 96GB cards and found a comparison stating that the former has more bfloat16 compute (1979 vs. 1671 TFLOPS), while the latter has higher memory bandwidth (3.9 vs. 3.35 TB/s), so it is not clear whether either one is better.

Costs

In terms of total runtime, the workload required 3 s/post × 35k posts = 105k s \(\approx\) 29 hr on a single model instance (4 GPUs). Running \(N=10\) instances in parallel shortens the wall-clock time to roughly 3 hours, but leaves the total GPU-hours, and therefore the cost, unchanged. Renting an equivalent H100 GPU on Runpod.io costs $3.07/hr per GPU (as of 2025-09-24), so the total cost of the workload would be $3.07 × 4 GPUs × 29 hr = $356.12 (not a very expensive workload). However, this does not include the time spent setting up the workload, debugging failures, and tweaking parameters, which required running at least one model instance (~$12/hr).
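A sketch of the cost arithmetic, with the number of data-parallel instances as a parameter to show that it changes the wall-clock time but not the bill. The function name is illustrative, and the price is the Runpod.io figure quoted above.

```python
def workload_cost(num_posts, seconds_per_post, gpus_per_instance,
                  usd_per_gpu_hour, num_instances=1):
    """Runtime and rental cost of a data-parallel inference workload."""
    instance_hours = num_posts * seconds_per_post / 3600  # runtime on a single instance
    gpu_hours = instance_hours * gpus_per_instance
    return {
        "instance_hours": instance_hours,
        # Data parallelism divides the wall-clock time...
        "wall_clock_hours": instance_hours / num_instances,
        # ...but not the GPU-hours that you pay for.
        "cost_usd": gpu_hours * usd_per_gpu_hour,
    }
```

For example, workload_cost(35_000, 3.0, 4, 3.07, num_instances=10) gives about 29.2 instance-hours, about 2.9 hours of wall-clock time, and about $358; the $356.12 above comes from rounding the runtime down to 29 hours.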

Results

In terms of social science, we found interesting results; e.g., many posts promoted a product (42%), and the most common product sold was dietary supplements (16%). For quality control, we compared the Qwen labels with human labels on a random subset of ~150 posts for a specific product-detection question:

“Does the post mention a product or service that is being offered, promoted, or recommended? Answer with ‘yes’ or ‘no’.”

On this subset, we also classified the posts with Gemini-2.5-Pro, intended as an upper performance reference. Table 2 compares both models in terms of precision, recall, and F1 score. In plain words, Qwen was too strict about what constitutes a product mention, while Gemini was somewhat too lax but achieved a better balance overall.

Table 2: Model Comparison. Gemini outperformed Qwen in recall and F1, but Qwen had higher precision.

| Model | Precision | Recall | F1 Score |
|---|---|---|---|
| Qwen2.5-VL-72B | 0.90 | 0.69 | 0.78 |
| Gemini-2.5-Pro | 0.81 | 0.98 | 0.89 |
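For reference, a minimal sketch of how these metrics are computed from paired human and model labels, treating "yes" (a product mention) as the positive class. Labels are assumed to be booleans, and the function names are illustrative.

```python
def f1_from(precision, recall):
    """Harmonic mean of precision and recall."""
    return 2 * precision * recall / (precision + recall) if precision + recall else 0.0


def precision_recall_f1(y_true, y_pred):
    """Metrics for boolean labels, with True ("yes") as the positive class."""
    pairs = list(zip(y_true, y_pred))
    tp = sum(1 for t, p in pairs if t and p)       # model and human both say "yes"
    fp = sum(1 for t, p in pairs if not t and p)   # model says "yes", human says "no"
    fn = sum(1 for t, p in pairs if t and not p)   # model misses a human "yes"
    precision = tp / (tp + fp) if tp + fp else 0.0
    recall = tp / (tp + fn) if tp + fn else 0.0
    return precision, recall, f1_from(precision, recall)
```

Plugging in the precision and recall from Table 2, f1_from(0.90, 0.69) and f1_from(0.81, 0.98) round to 0.78 and 0.89, matching the F1 column.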

Correcting for AI Bias: These numbers show that each model was biased in a specific direction (Qwen toward answering “no”, Gemini toward “yes”), which needed to be corrected to answer a question like the following:

“If humans were to label all 35k posts, what fraction would they label as mentioning a product?”

It is possible to combine the few human labels with the thousands of systematically biased AI labels to answer this question. We followed the Confidence-Driven Inference method described by Gligoric et al. (2025) (also presented at IC2S2). When computing an average over the combined human and AI labels, the method weights the AI labels with a weight \(\lambda \in [0,1]\). The weight is 0 when the AI labels are uncorrelated with the human labels, and 1 when they are perfectly correlated. This method allowed us to correct for the Qwen model's bias in our estimates, and also to make statements about confidence intervals like the following:

“The percentage of the 35k posts mentioning a product is ~53%, with the 95% confidence interval being 47% to 59%.”

Confidence intervals are useful for hypothesis testing and for making statements about statistical significance. More details are in our IC2S2 poster.
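To illustrate the weighting idea, here is a sketch of the power-tuned prediction-powered mean estimator (PPI++), which is closely related to the estimator described above. It is not the full Confidence-Driven Inference procedure: that method additionally uses the LLM's confidence to choose which posts humans should label, which this sketch omits. Labels are 0/1, ai_on_unlabeled holds the AI labels of the posts without human labels, the confidence interval uses a normal approximation, and λ is the variance-minimizing weight clipped to [0, 1].

```python
import math
from statistics import NormalDist, mean, variance


def ppi_mean(y_human, ai_on_labeled, ai_on_unlabeled, alpha=0.05):
    """Estimate the human-label mean from n human labels plus N AI labels.

    Returns (estimate, lambda, (ci_low, ci_high)).
    """
    n, N = len(y_human), len(ai_on_unlabeled)
    my, mf = mean(y_human), mean(ai_on_labeled)
    cov = sum((y - my) * (f - mf) for y, f in zip(y_human, ai_on_labeled)) / (n - 1)
    var_f = variance(ai_on_labeled)
    # Variance-minimizing weight for the AI labels, clipped to [0, 1].
    lam = 0.0 if var_f == 0 else cov / ((1 + n / N) * var_f)
    lam = min(max(lam, 0.0), 1.0)
    # Human mean, corrected by the weighted gap between the AI labels on all
    # unlabeled posts and on the human-labeled subset.
    estimate = my + lam * (mean(ai_on_unlabeled) - mf)
    # Standard error from the two independent samples.
    residuals = [y - lam * f for y, f in zip(y_human, ai_on_labeled)]
    se = math.sqrt(variance(residuals) / n + lam ** 2 * variance(ai_on_unlabeled) / N)
    z = NormalDist().inv_cdf(1 - alpha / 2)
    return estimate, lam, (estimate - z * se, estimate + z * se)
```

If the AI labels on the human-labeled subset are constant (uninformative), λ is 0 and the estimate reduces to the plain mean of the human labels; if they match the human labels exactly and N is much larger than n, λ approaches 1 and the estimate leans almost entirely on the AI labels.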

Conclusion

I showed a practical recipe for classifying a dataset of 35,000 videos with a 72B-parameter model. We went over the required GPU infrastructure, tensor- and data-parallelism strategies, SLURM commands, concrete precautions for fault tolerance, and tips for video inference with vLLM. Our workload was prefill-heavy, and I showed that the throughput we achieved is competitive with existing benchmarks by NVIDIA and vLLM. Finally, I explained how to use thousands of potentially biased AI labels to make statistically valid arguments. I hope this guide is useful to other researchers aiming to analyze large video datasets with multimodal LLMs that require multiple GPUs.

Outlook: In the future, I hope to scale this pipeline to 1M videos and push the throughput limits with quantization methods and speculative decoding. Testing audio-capable models is also high on my agenda. If you are interested in collaborating on large-scale inference, reach out to me! 🚀

SLURM File

Below is a GitHub Gist of the SLURM file we used to submit our workload on the LRZ AI Systems. It shows how to request multiple GPUs, configure SLURM job arrays, set vLLM ports, and set up and tear down the enroot container. Unlike the Python code, it will not run without our dataset, but it shows the general workflow.

Python Code

The code below shows how to use Python to (1) start up the vLLM server, and (2) send concurrent requests to the server. It uses a library I wrote for this purpose: GitHub/tomasruizt/llm_app. The snippet below is in the repo’s examples/ directory.

References

Abdin, Marah, Jyoti Aneja, Harkirat Behl, Sébastien Bubeck, Ronen Eldan, Suriya Gunasekar, Michael Harrison, et al. 2024. “Phi-4 Technical Report.” https://arxiv.org/abs/2412.08905.
Bai, Shuai, Keqin Chen, Xuejing Liu, Jialin Wang, Wenbin Ge, Sibo Song, Kai Dang, et al. 2025. “Qwen2.5-VL Technical Report.” https://arxiv.org/abs/2502.13923.
Gligoric, Kristina, Tijana Zrnic, Cinoo Lee, Emmanuel Candes, and Dan Jurafsky. 2025. “Can Unconfident LLM Annotations Be Used for Confident Conclusions?” In Proceedings of the 2025 Conference of the Nations of the Americas Chapter of the Association for Computational Linguistics: Human Language Technologies (Volume 1: Long Papers), edited by Luis Chiruzzo, Alan Ritter, and Lu Wang, 3514–33. Albuquerque, New Mexico: Association for Computational Linguistics. https://doi.org/10.18653/v1/2025.naacl-long.179.
Kwon, Woosuk, Zhuohan Li, Siyuan Zhuang, Ying Sheng, Lianmin Zheng, Cody Hao Yu, Joseph Gonzalez, Hao Zhang, and Ion Stoica. 2023. “Efficient Memory Management for Large Language Model Serving with PagedAttention.” In SOSP, 611–26. https://doi.org/10.1145/3600006.3613165.
Radford, Alec, Jong Wook Kim, Tao Xu, Greg Brockman, Christine Mcleavey, and Ilya Sutskever. 2023. “Robust Speech Recognition via Large-Scale Weak Supervision.” In Proceedings of the 40th International Conference on Machine Learning, edited by Andreas Krause, Emma Brunskill, Kyunghyun Cho, Barbara Engelhardt, Sivan Sabato, and Jonathan Scarlett, 202:28492–518. Proceedings of Machine Learning Research. PMLR. https://proceedings.mlr.press/v202/radford23a.html.
Wang, Peng, Shuai Bai, Sinan Tan, Shijie Wang, Zhihao Fan, Jinze Bai, Keqin Chen, et al. 2024. “Qwen2-VL: Enhancing Vision-Language Model’s Perception of the World at Any Resolution.” https://arxiv.org/abs/2409.12191.
Xu, Jin, Zhifang Guo, Jinzheng He, Hangrui Hu, Ting He, Shuai Bai, Keqin Chen, et al. 2025. “Qwen2.5-Omni Technical Report.” https://arxiv.org/abs/2503.20215.
Yoo, Andy B, Morris A Jette, and Mark Grondona. 2003. “Slurm: Simple Linux Utility for Resource Management.” In Workshop on Job Scheduling Strategies for Parallel Processing, 44–60. Springer.
Zheng, Lianmin, Liangsheng Yin, Zhiqiang Xie, Chuyue Sun, Jeff Huang, Cody Hao Yu, Shiyi Cao, et al. 2024. “SGLang: Efficient Execution of Structured Language Model Programs.” In The Thirty-Eighth Annual Conference on Neural Information Processing Systems. https://openreview.net/forum?id=VqkAKQibpq.

Citation

BibTeX citation:
@online{ruiz2025,
  author = {Ruiz, Tomas},
  title = {A {Guide} to {Classify} 35,000 {Videos} with a {72B}
    {Multimodal} {LLM} on the {vLLM} {Engine}},
  date = {2025-09-22},
  url = {https://tomasruizt.github.io/posts/mm-inference-on-vllm/},
  langid = {en}
}
For attribution, please cite this work as:
Ruiz, Tomas. 2025. “A Guide to Classify 35,000 Videos with a 72B Multimodal LLM on the vLLM Engine.” September 22, 2025. https://tomasruizt.github.io/posts/mm-inference-on-vllm/.