Skip to content

vllm.v1.worker.gpu.mm.encoder_cudagraph

CUDA graph manager for vision encoder budget-batch execution.

BudgetGraphMetadata dataclass

Metadata for a single budget graph.

CUDA graph replay pattern: (1) copy new batch data into input_buffer (e.g. pixel_values); (2) copy precomputed values into metadata_buffers; (3) replay the graph; (4) read encoder outputs from output_buffer.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
@dataclass
class BudgetGraphMetadata:
    """Metadata for a single budget graph.

    Bundles one captured CUDA graph with the tensors that were used while
    capturing it, so the same tensors can be refilled before every replay.

    CUDA graph replay pattern:
    1. Copy new batch data into input_buffer (e.g. pixel_values)
    2. Copy precomputed values into metadata_buffers
    3. Replay graph
    4. Read encoder outputs from output_buffer
    """

    # Token budget this graph was captured for.
    token_budget: int
    max_batch_size: int  # Max number of images/videos per batch
    # The captured graph that is replayed for this budget.
    graph: torch.cuda.CUDAGraph
    # The input tensor updated before replay (e.g. pixel_values)
    input_buffer: torch.Tensor
    # Buffers recorded into the CUDA graph (e.g. embeddings, sequence metadata).
    # Before replay the manager overwrites these with new data: 0-dim buffers
    # are copied as-is, all others are zeroed and then slice-copied.
    metadata_buffers: dict[str, torch.Tensor]
    # Output written by graph, read after replay
    output_buffer: torch.Tensor

EncoderCudaGraphManager

Budget-based CUDA graph capture/replay for vision encoders.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
 49
 50
 51
 52
 53
 54
 55
 56
 57
 58
 59
 60
 61
 62
 63
 64
 65
 66
 67
 68
 69
 70
 71
 72
 73
 74
 75
 76
 77
 78
 79
 80
 81
 82
 83
 84
 85
 86
 87
 88
 89
 90
 91
 92
 93
 94
 95
 96
 97
 98
 99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
class EncoderCudaGraphManager:
    """Budget-based CUDA graph capture/replay for vision encoders."""

    def __init__(
        self,
        vllm_config: VllmConfig,
        device: torch.device,
        dtype: torch.dtype,
        model: SupportsEncoderCudaGraph,
    ):
        """Initialize CUDA graph manager with provided token budgets
        and max batch size."""
        self.vllm_config = vllm_config
        self.device = device
        self.dtype = dtype
        self.model = model
        self.config: EncoderCudaGraphConfig = model.get_encoder_cudagraph_config()

        comp_config = vllm_config.compilation_config
        user_budgets = comp_config.encoder_cudagraph_token_budgets
        user_max_images = comp_config.encoder_cudagraph_max_images_per_batch

        if user_budgets and user_max_images > 0:
            # Fully user-specified
            self.token_budgets = sorted(user_budgets)
            self.max_batch_size = user_max_images
        else:
            # Auto-infer missing values from model
            min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
                vllm_config
            )
            self.token_budgets = (
                sorted(user_budgets)
                if user_budgets
                else self._generate_budgets(min_budget, max_budget)
            )
            self.max_batch_size = (
                user_max_images if user_max_images > 0 else max_budget // min_budget
            )

        mm_config = vllm_config.model_config.multimodal_config
        self.use_dp = (
            mm_config is not None
            and mm_config.mm_encoder_tp_mode == "data"
            and vllm_config.parallel_config.tensor_parallel_size > 1
        )

        self.budget_graphs: dict[int, BudgetGraphMetadata] = {}
        self.graph_hits = 0
        self.graph_misses = 0
        self.log_stats_interval = 100

        logger.info(
            "EncoderCudaGraphManager initialized with "
            "budgets=%s, max_batch_size=%d, use_dp=%s",
            self.token_budgets,
            self.max_batch_size,
            self.use_dp,
        )

    @staticmethod
    def _generate_budgets(min_budget: int, max_budget: int) -> list[int]:
        """Generate power-of-2 token budgets from min_budget to max_budget."""
        budgets: list[int] = []
        b = min_budget
        while b <= max_budget:
            budgets.append(b)
            b *= 2
        # Always include max_budget if it's not already a power-of-2 boundary
        if not budgets or budgets[-1] < max_budget:
            budgets.append(max_budget)
        return budgets

    def supports_modality(self, modality: str) -> bool:
        """Check if a modality is supported by this manager."""
        return modality in self.config.modalities

    def capture(self):
        """Capture CUDA graphs for all token budgets."""
        for token_budget in self.token_budgets:
            self._capture_budget_graph(token_budget)

        logger.info(
            "Encoder CUDA graph capture complete. Captured %d budget graphs.",
            len(self.budget_graphs),
        )

    def _capture_budget_graph(self, token_budget: int) -> None:
        """Capture CUDA graph for a single token budget.

        Asks the model for capture inputs sized for ``token_budget`` and
        ``self.max_batch_size``, records one encoder forward pass into a new
        CUDA graph, and stores the graph together with its buffers in
        ``self.budget_graphs[token_budget]``.

        Args:
            token_budget: Budget to capture; also the key under which the
                resulting BudgetGraphMetadata is stored.
        """
        logger.debug(
            "Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
            token_budget,
            self.max_batch_size,
        )

        capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
            token_budget, self.max_batch_size, self.device, self.dtype
        )

        mm_kwargs = capture_inputs.mm_kwargs
        buffers = capture_inputs.buffers

        # Run the forward once outside of graph capture. Its result is only
        # used to allocate an output buffer of matching shape and dtype.
        with torch.inference_mode():
            output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
            output_buffer = torch.empty_like(output)

        # Record the forward pass and the copy into output_buffer, so that a
        # replay leaves its result in output_buffer.
        graph = torch.cuda.CUDAGraph()
        with torch.inference_mode(), torch.cuda.graph(graph):
            output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
            output_buffer.copy_(output)

        # Keep the exact tensors used during capture: the input tensor found
        # under config.input_key and the metadata buffers are refilled in
        # place before each replay.
        input_key = self.config.input_key
        self.budget_graphs[token_budget] = BudgetGraphMetadata(
            token_budget=token_budget,
            max_batch_size=self.max_batch_size,
            graph=graph,
            input_buffer=mm_kwargs[input_key],
            metadata_buffers=buffers,
            output_buffer=output_buffer,
        )

    def _find_smallest_fitting_budget_given_tokens(
        self, total_tokens: int
    ) -> int | None:
        """Find smallest budget >= total_tokens.

        Returns:
            Token budget if found, None if no fitting budget.
        """
        for budget in self.token_budgets:
            if budget >= total_tokens:
                return budget
        return None

    def _get_per_item_out_tokens(self, mm_kwargs: dict[str, Any]) -> list[int]:
        """Get per-item output token counts as plain ints."""
        return [
            int(t)
            for t in self.model.get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)
        ]

    @staticmethod
    def _scatter_output_slices(
        output: torch.Tensor,
        indices: list[int],
        per_item_out_tokens: list[int],
        dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
        clone: bool = False,
    ) -> None:
        """Slice a concatenated output tensor and scatter into dest by index."""
        offset = 0
        for idx in indices:
            n_tok = per_item_out_tokens[idx]
            sliced = output[offset : offset + n_tok]
            dest[idx] = sliced.clone() if clone else sliced
            offset += n_tok

    def _run_budget_graph(
        self,
        mm_kwargs: dict[str, Any],
        token_budget: int,
        replay_buffers: dict[str, torch.Tensor | None],
    ) -> torch.Tensor | None:
        """Execute budget graph.

        Args:
            mm_kwargs: Multimodal inputs for the batch.
            token_budget: Token budget to use.
            replay_buffers: Buffer values to copy into captured buffers.
                None values leave the corresponding buffer unchanged.

        Returns:
            Encoder outputs, or None if graph not captured.
        """
        num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
        if token_budget not in self.budget_graphs:
            self.graph_misses += num_items
            return None

        graph_meta = self.budget_graphs[token_budget]

        # Copy the input tensor. Buffers are sized for the full budget;
        # actual inputs may be smaller. Zero then slice-copy so padded
        # positions are invisible to attention (cu_seqlens masks them out).
        input_key = self.config.input_key
        src = mm_kwargs[input_key]
        n = src.shape[0]
        graph_meta.input_buffer.zero_()
        graph_meta.input_buffer[:n].copy_(src)

        # Copy metadata buffers using keys from config.buffer_keys.
        for key in self.config.buffer_keys:
            src = replay_buffers.get(key)
            if src is None:
                continue
            buf = graph_meta.metadata_buffers[key]
            if src.ndim == 0:
                buf.copy_(src)
            else:
                n = src.shape[0]
                buf.zero_()
                buf[:n].copy_(src)

        graph_meta.graph.replay()

        self.graph_hits += num_items
        return graph_meta.output_buffer

    def _execute_local(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[torch.Tensor]:
        """Execute encoder on local inputs using greedy-packed CUDA graphs.

        Sort images by output token count (smallest first), then greedily pack
        as many images as possible into each batch while staying within
        max_budget tokens and max_batch_size. Once a batch is finalised (next
        image would overflow either constraint), find the smallest fitting
        budget once for that batch.

        By exchange argument, greedy smallest-first packing minimises eager
        fallbacks -- any other ordering yields a higher token sum in some batch,
        making that batch more likely to exceed the budget.

        Stats note:
          graph_hits  -- counted inside _run_budget_graph after successful replay.
          graph_misses -- counted here for single-image batches where the image
                         exceeds max_budget. Batches split due to max_batch_size
                         always satisfy total_tokens <= max_budget and therefore
                         always find a valid budget (no miss).
        """
        num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
        max_budget = self.token_budgets[-1]

        per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)

        # Sort ascending by output token count (smallest first)
        sorted_indices = sorted(range(num_items), key=lambda i: per_item_out_tokens[i])

        # Greedy pack against max_budget and max_batch_size.
        # _find_smallest_fitting_budget_given_tokens is called once per
        # finalised batch, not per image.
        batches: list[tuple[list[int], int | None]] = []
        current_batch: list[int] = []
        current_batch_tokens = 0

        for orig_idx in sorted_indices:
            item_tokens = per_item_out_tokens[orig_idx]
            if (
                current_batch_tokens + item_tokens <= max_budget
                and len(current_batch) < self.max_batch_size
            ):
                current_batch.append(orig_idx)
                current_batch_tokens += item_tokens
            else:
                if current_batch:
                    batches.append(
                        (
                            current_batch,
                            self._find_smallest_fitting_budget_given_tokens(
                                current_batch_tokens
                            ),
                        )
                    )
                current_batch = [orig_idx]
                current_batch_tokens = item_tokens

        if current_batch:
            batches.append(
                (
                    current_batch,
                    self._find_smallest_fitting_budget_given_tokens(
                        current_batch_tokens
                    ),
                )
            )

        # outputs_by_orig_idx maps each original image index to its output
        # tensor. Needed because greedy packing reorders images; we restore
        # the original order before returning.
        outputs_by_orig_idx: dict[int, torch.Tensor] = {}

        for batch_orig_indices, token_budget in batches:
            batch_mm_kwargs = self.model.select_encoder_cudagraph_items(
                mm_kwargs, batch_orig_indices
            )
            batch_out_tokens = sum(per_item_out_tokens[i] for i in batch_orig_indices)

            if token_budget is None:
                # Single oversized image: item_tokens > max_budget.
                # graph_misses counted here for this eager fallback.
                logger.debug(
                    "Encoder CUDA graph fallback to eager: no budget for "
                    "%d tokens from %d images",
                    batch_out_tokens,
                    len(batch_orig_indices),
                )
                self.graph_misses += len(batch_orig_indices)
                with torch.inference_mode():
                    raw = self.model.encoder_eager_forward(batch_mm_kwargs)
                self._scatter_output_slices(
                    raw,
                    batch_orig_indices,
                    per_item_out_tokens,
                    outputs_by_orig_idx,
                )
            else:
                logger.debug(
                    "Encoder CUDA graph: batch_size=%d, tokens=%d, "
                    "budget=%d, waste=%.1f%%",
                    len(batch_orig_indices),
                    batch_out_tokens,
                    token_budget,
                    (token_budget - batch_out_tokens) / token_budget * 100,
                )
                replay = self.model.prepare_encoder_cudagraph_replay_buffers(
                    batch_mm_kwargs, self.max_batch_size
                )

                # graph_hits counted inside _run_budget_graph after replay.
                output = self._run_budget_graph(
                    batch_mm_kwargs, token_budget, replay.buffers
                )
                assert output is not None
                self._scatter_output_slices(
                    output,
                    batch_orig_indices,
                    per_item_out_tokens,
                    outputs_by_orig_idx,
                    clone=True,
                )

        # Return in original batch order (caller maps outputs to token positions)
        return [outputs_by_orig_idx[i] for i in range(num_items)]

    def _dp_shard(
        self,
        mm_kwargs: dict[str, Any],
        per_item_out_tokens: list[int],
    ) -> tuple[dict[str, Any], list[int], list[int], int]:
        """Distribute items across TP ranks for data-parallel execution.

        Uses get_load_balance_assignment() to balance load by input size,
        then select_encoder_cudagraph_items() to extract this rank's inputs.

        Args:
            mm_kwargs: Multimodal inputs for all items.
            per_item_out_tokens: Output token count of every item, indexed by
                original item position.

        Returns:
            local_mm_kwargs: Inputs for this rank.
            image_rank_assignment: Flattened assignment order across all ranks.
            images_per_rank: Number of items per rank.
            max_output_tokens_per_rank: Max output tokens across all ranks
                (for padding during all_gather).
        """
        tp_size = get_tensor_model_parallel_world_size()
        current_rank = get_tensor_model_parallel_rank()

        per_item_input_sizes = self.model.get_encoder_cudagraph_per_item_input_sizes(
            mm_kwargs
        )

        # The third element (input patches per rank) is not needed here.
        image_rank_assignment, images_per_rank, _ = get_load_balance_assignment(
            per_item_input_sizes, tp_size
        )

        # Prefix sums over images_per_rank: rank r owns the slice
        # image_rank_assignment[rank_offsets[r] : rank_offsets[r + 1]].
        rank_offsets = [0]
        for count in images_per_rank:
            rank_offsets.append(rank_offsets[-1] + count)

        local_indices = image_rank_assignment[
            rank_offsets[current_rank] : rank_offsets[current_rank + 1]
        ]

        # Also called with an empty index list, so a rank without work still
        # receives inputs produced by the model's selection hook.
        local_mm_kwargs = self.model.select_encoder_cudagraph_items(
            mm_kwargs, local_indices
        )

        if len(per_item_out_tokens) > 0:
            max_output_tokens_per_rank = max(
                sum(
                    per_item_out_tokens[i]
                    for i in image_rank_assignment[
                        rank_offsets[r] : rank_offsets[r + 1]
                    ]
                )
                for r in range(tp_size)
            )
        else:
            max_output_tokens_per_rank = 0

        return (
            local_mm_kwargs,
            image_rank_assignment,
            images_per_rank,
            max_output_tokens_per_rank,
        )

    def _dp_gather(
        self,
        local_outputs: list[torch.Tensor],
        per_item_out_tokens: list[int],
        image_rank_assignment: list[int],
        images_per_rank: list[int],
        max_output_tokens_per_rank: int,
    ) -> list[torch.Tensor]:
        """Collect per-rank outputs and return them in original item order.

        Steps: concatenate the local outputs, pad them to
        max_output_tokens_per_rank rows, all_gather along dim 0, then cut
        each rank's unpadded part into per-item slices. Assumes 2D outputs
        of shape [tokens, hidden].
        NOTE(review): per the original docstring this follows
        run_dp_sharded_mrope_vision_model() from the eager path -- confirm.

        Args:
            local_outputs: This rank's per-item outputs.
            per_item_out_tokens: Output token count per original item index.
            image_rank_assignment: Flattened item order across all ranks.
            images_per_rank: Number of items assigned to each rank.
            max_output_tokens_per_rank: Row count every rank pads to.

        Returns:
            Per-item outputs ordered by original index; slots that were
            never filled are dropped.
        """
        hidden_size = self.config.out_hidden_size

        if local_outputs:
            local_concat = torch.cat(local_outputs, dim=0)
        else:
            local_concat = torch.empty(
                (0, hidden_size), device=self.device, dtype=self.dtype
            )

        # Every rank must contribute the same number of rows to all_gather.
        # The padding rows are uninitialised and are sliced away again below.
        pad_rows = max_output_tokens_per_rank - local_concat.shape[0]
        if pad_rows > 0:
            padding = torch.empty(
                (pad_rows, hidden_size),
                dtype=self.dtype,
                device=self.device,
            )
            local_padded = torch.cat([local_concat, padding], dim=0)
        else:
            local_padded = local_concat

        gathered = tensor_model_parallel_all_gather(local_padded, dim=0)

        # Walk the ranks once: take the unpadded rows of each rank's segment
        # and scatter them to the original positions of that rank's items.
        result: list[torch.Tensor | None] = [None] * len(per_item_out_tokens)
        cursor = 0
        for rank, count in enumerate(images_per_rank):
            if count <= 0:
                continue
            rank_items = image_rank_assignment[cursor : cursor + count]
            cursor += count
            rank_tokens = sum(per_item_out_tokens[i] for i in rank_items)
            segment_start = rank * max_output_tokens_per_rank
            self._scatter_output_slices(
                gathered[segment_start : segment_start + rank_tokens],
                rank_items,
                per_item_out_tokens,
                result,
            )

        return [t for t in result if t is not None]

    def execute(
        self,
        mm_kwargs: dict[str, Any],
    ) -> list[torch.Tensor]:
        """Execute encoder using CUDA graph with optional DP.

        Args:
            mm_kwargs: Multimodal keyword arguments containing the
                input tensor and grid dimensions.

        Returns:
            List of encoder outputs (one per item).
        """
        if self.use_dp:
            per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)

            (
                local_mm_kwargs,
                image_rank_assignment,
                images_per_rank,
                max_output_tokens_per_rank,
            ) = self._dp_shard(mm_kwargs, per_item_out_tokens)

            local_outputs = self._execute_local(local_mm_kwargs)

            result = self._dp_gather(
                local_outputs,
                per_item_out_tokens,
                image_rank_assignment,
                images_per_rank,
                max_output_tokens_per_rank,
            )
        else:
            result = self._execute_local(mm_kwargs)

        # Log cumulative stats periodically
        stats = self.get_cumulative_stats()
        total_requests = self.graph_hits + self.graph_misses
        if total_requests > 0 and total_requests % self.log_stats_interval == 0:
            logger.debug(
                "Encoder CUDA graph cumulative stats: "
                "hits=%d, misses=%d, hit_rate=%.1f%%",
                stats["graph_hits"],
                stats["graph_misses"],
                stats["hit_rate"] * 100,
            )

        return result

    def get_cumulative_stats(self) -> dict[str, Any]:
        """Get cumulative CUDA graph statistics."""
        total_requests = self.graph_hits + self.graph_misses
        hit_rate = self.graph_hits / total_requests if total_requests > 0 else 0.0

        return {
            "graph_hits": self.graph_hits,
            "graph_misses": self.graph_misses,
            "hit_rate": hit_rate,
            "num_budgets": len(self.budget_graphs),
            "token_budgets": self.token_budgets,
        }

__init__

__init__(
    vllm_config: VllmConfig,
    device: device,
    dtype: dtype,
    model: SupportsEncoderCudaGraph,
)

Initialize CUDA graph manager with provided token budgets and max batch size.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def __init__(
    self,
    vllm_config: VllmConfig,
    device: torch.device,
    dtype: torch.dtype,
    model: SupportsEncoderCudaGraph,
):
    """Initialize the manager and resolve token budgets and max batch size.

    User-provided values from ``vllm_config.compilation_config`` take
    precedence. Whatever is missing is inferred from the budget range
    returned by ``model.get_encoder_cudagraph_budget_range``.

    Args:
        vllm_config: Engine configuration; the compilation, model and
            parallel sub-configs are read here.
        device: Device stored on the manager.
        dtype: Dtype stored on the manager.
        model: Encoder model that provides the CUDA graph hooks.
    """
    self.vllm_config = vllm_config
    self.device = device
    self.dtype = dtype
    self.model = model
    self.config: EncoderCudaGraphConfig = model.get_encoder_cudagraph_config()

    comp_config = vllm_config.compilation_config
    user_budgets = comp_config.encoder_cudagraph_token_budgets
    user_max_images = comp_config.encoder_cudagraph_max_images_per_batch

    # User budgets are de-duplicated: every entry of token_budgets is
    # captured by capture(), so a repeated value would capture the same
    # budget again and simply overwrite the previous graph.
    if user_budgets and user_max_images > 0:
        # Fully user-specified; the model is not asked for a range.
        self.token_budgets = sorted(set(user_budgets))
        self.max_batch_size = user_max_images
    else:
        # Auto-infer missing values from model
        min_budget, max_budget = model.get_encoder_cudagraph_budget_range(
            vllm_config
        )
        self.token_budgets = (
            sorted(set(user_budgets))
            if user_budgets
            else self._generate_budgets(min_budget, max_budget)
        )
        self.max_batch_size = (
            user_max_images if user_max_images > 0 else max_budget // min_budget
        )

    # Data-parallel encoder execution is only used when the multimodal
    # config asks for "data" mode and there is more than one TP rank.
    mm_config = vllm_config.model_config.multimodal_config
    self.use_dp = (
        mm_config is not None
        and mm_config.mm_encoder_tp_mode == "data"
        and vllm_config.parallel_config.tensor_parallel_size > 1
    )

    self.budget_graphs: dict[int, BudgetGraphMetadata] = {}
    self.graph_hits = 0
    self.graph_misses = 0
    self.log_stats_interval = 100

    logger.info(
        "EncoderCudaGraphManager initialized with "
        "budgets=%s, max_batch_size=%d, use_dp=%s",
        self.token_budgets,
        self.max_batch_size,
        self.use_dp,
    )

_capture_budget_graph

_capture_budget_graph(token_budget: int)

Capture CUDA graph for a single token budget.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _capture_budget_graph(self, token_budget: int) -> None:
    """Capture CUDA graph for a single token budget.

    Asks the model for capture inputs sized for ``token_budget`` and
    ``self.max_batch_size``, records one encoder forward pass into a new
    CUDA graph, and stores the graph together with its buffers in
    ``self.budget_graphs[token_budget]``.

    Args:
        token_budget: Budget to capture; also the key under which the
            resulting BudgetGraphMetadata is stored.
    """
    logger.debug(
        "Capturing encoder cudagraph for budget=%d, max_batch_size=%d",
        token_budget,
        self.max_batch_size,
    )

    capture_inputs = self.model.prepare_encoder_cudagraph_capture_inputs(
        token_budget, self.max_batch_size, self.device, self.dtype
    )

    mm_kwargs = capture_inputs.mm_kwargs
    buffers = capture_inputs.buffers

    # Run the forward once outside of graph capture. Its result is only
    # used to allocate an output buffer of matching shape and dtype.
    with torch.inference_mode():
        output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
        output_buffer = torch.empty_like(output)

    # Record the forward pass and the copy into output_buffer, so that a
    # replay leaves its result in output_buffer.
    graph = torch.cuda.CUDAGraph()
    with torch.inference_mode(), torch.cuda.graph(graph):
        output = self.model.encoder_cudagraph_forward(mm_kwargs, buffers)
        output_buffer.copy_(output)

    # Keep the exact tensors used during capture: the input tensor found
    # under config.input_key and the metadata buffers are refilled in
    # place before each replay.
    input_key = self.config.input_key
    self.budget_graphs[token_budget] = BudgetGraphMetadata(
        token_budget=token_budget,
        max_batch_size=self.max_batch_size,
        graph=graph,
        input_buffer=mm_kwargs[input_key],
        metadata_buffers=buffers,
        output_buffer=output_buffer,
    )

_dp_gather

_dp_gather(
    local_outputs: list[Tensor],
    per_item_out_tokens: list[int],
    image_rank_assignment: list[int],
    images_per_rank: list[int],
    max_output_tokens_per_rank: int,
) -> list[Tensor]

Gather outputs from all TP ranks and reorder to original sequence.

Assumes 2D output tensors [tokens, hidden]. Follows the same pad -> all_gather -> unpad -> reorder algorithm as run_dp_sharded_mrope_vision_model() in the eager path.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _dp_gather(
    self,
    local_outputs: list[torch.Tensor],
    per_item_out_tokens: list[int],
    image_rank_assignment: list[int],
    images_per_rank: list[int],
    max_output_tokens_per_rank: int,
) -> list[torch.Tensor]:
    """Collect per-rank outputs and return them in original item order.

    Steps: concatenate the local outputs, pad them to
    max_output_tokens_per_rank rows, all_gather along dim 0, then cut
    each rank's unpadded part into per-item slices. Assumes 2D outputs
    of shape [tokens, hidden].
    NOTE(review): per the original docstring this follows
    run_dp_sharded_mrope_vision_model() from the eager path -- confirm.

    Args:
        local_outputs: This rank's per-item outputs.
        per_item_out_tokens: Output token count per original item index.
        image_rank_assignment: Flattened item order across all ranks.
        images_per_rank: Number of items assigned to each rank.
        max_output_tokens_per_rank: Row count every rank pads to.

    Returns:
        Per-item outputs ordered by original index; slots that were
        never filled are dropped.
    """
    hidden_size = self.config.out_hidden_size

    if local_outputs:
        local_concat = torch.cat(local_outputs, dim=0)
    else:
        local_concat = torch.empty(
            (0, hidden_size), device=self.device, dtype=self.dtype
        )

    # Every rank must contribute the same number of rows to all_gather.
    # The padding rows are uninitialised and are sliced away again below.
    pad_rows = max_output_tokens_per_rank - local_concat.shape[0]
    if pad_rows > 0:
        padding = torch.empty(
            (pad_rows, hidden_size),
            dtype=self.dtype,
            device=self.device,
        )
        local_padded = torch.cat([local_concat, padding], dim=0)
    else:
        local_padded = local_concat

    gathered = tensor_model_parallel_all_gather(local_padded, dim=0)

    # Walk the ranks once: take the unpadded rows of each rank's segment
    # and scatter them to the original positions of that rank's items.
    result: list[torch.Tensor | None] = [None] * len(per_item_out_tokens)
    cursor = 0
    for rank, count in enumerate(images_per_rank):
        if count <= 0:
            continue
        rank_items = image_rank_assignment[cursor : cursor + count]
        cursor += count
        rank_tokens = sum(per_item_out_tokens[i] for i in rank_items)
        segment_start = rank * max_output_tokens_per_rank
        self._scatter_output_slices(
            gathered[segment_start : segment_start + rank_tokens],
            rank_items,
            per_item_out_tokens,
            result,
        )

    return [t for t in result if t is not None]

_dp_shard

_dp_shard(
    mm_kwargs: dict[str, Any],
    per_item_out_tokens: list[int],
) -> tuple[dict[str, Any], list[int], list[int], int]

Distribute items across TP ranks for data-parallel execution.

Uses get_load_balance_assignment() to balance load by input size, then select_encoder_cudagraph_items() to extract each rank's inputs.

Returns:

| Name | Type | Description |
| --- | --- | --- |
| local_mm_kwargs | dict[str, Any] | Inputs for this rank. |
| image_rank_assignment | list[int] | Flattened assignment order across all ranks. |
| images_per_rank | list[int] | Number of items per rank. |
| max_output_tokens_per_rank | int | Max output tokens across all ranks (for padding during all_gather). |

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _dp_shard(
    self,
    mm_kwargs: dict[str, Any],
    per_item_out_tokens: list[int],
) -> tuple[dict[str, Any], list[int], list[int], int]:
    """Distribute items across TP ranks for data-parallel execution.

    Uses get_load_balance_assignment() to balance load by input size,
    then select_encoder_cudagraph_items() to extract this rank's inputs.

    Args:
        mm_kwargs: Multimodal inputs for all items.
        per_item_out_tokens: Output token count of every item, indexed by
            original item position.

    Returns:
        local_mm_kwargs: Inputs for this rank.
        image_rank_assignment: Flattened assignment order across all ranks.
        images_per_rank: Number of items per rank.
        max_output_tokens_per_rank: Max output tokens across all ranks
            (for padding during all_gather).
    """
    tp_size = get_tensor_model_parallel_world_size()
    current_rank = get_tensor_model_parallel_rank()

    per_item_input_sizes = self.model.get_encoder_cudagraph_per_item_input_sizes(
        mm_kwargs
    )

    # The third element (input patches per rank) is not needed here.
    image_rank_assignment, images_per_rank, _ = get_load_balance_assignment(
        per_item_input_sizes, tp_size
    )

    # Prefix sums over images_per_rank: rank r owns the slice
    # image_rank_assignment[rank_offsets[r] : rank_offsets[r + 1]].
    rank_offsets = [0]
    for count in images_per_rank:
        rank_offsets.append(rank_offsets[-1] + count)

    local_indices = image_rank_assignment[
        rank_offsets[current_rank] : rank_offsets[current_rank + 1]
    ]

    # Also called with an empty index list, so a rank without work still
    # receives inputs produced by the model's selection hook.
    local_mm_kwargs = self.model.select_encoder_cudagraph_items(
        mm_kwargs, local_indices
    )

    if len(per_item_out_tokens) > 0:
        max_output_tokens_per_rank = max(
            sum(
                per_item_out_tokens[i]
                for i in image_rank_assignment[
                    rank_offsets[r] : rank_offsets[r + 1]
                ]
            )
            for r in range(tp_size)
        )
    else:
        max_output_tokens_per_rank = 0

    return (
        local_mm_kwargs,
        image_rank_assignment,
        images_per_rank,
        max_output_tokens_per_rank,
    )

_execute_local

_execute_local(mm_kwargs: dict[str, Any]) -> list[Tensor]

Execute encoder on local inputs using greedy-packed CUDA graphs.

Sort images by output token count (smallest first), then greedily pack as many images as possible into each batch while staying within max_budget tokens and max_batch_size. Once a batch is finalised (next image would overflow either constraint), find the smallest fitting budget once for that batch.

By exchange argument, greedy smallest-first packing minimises eager fallbacks -- any other ordering yields a higher token sum in some batch, making that batch more likely to exceed the budget.

Stats note:

- graph_hits -- counted inside _run_budget_graph after a successful replay.
- graph_misses -- counted here for single-image batches where the image exceeds max_budget. Batches split due to max_batch_size always satisfy total_tokens <= max_budget and therefore always find a valid budget (no miss).

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _execute_local(
    self,
    mm_kwargs: dict[str, Any],
) -> list[torch.Tensor]:
    """Run the encoder over the given items using packed budget graphs.

    Items are visited in ascending order of output token count and packed
    greedily: an item joins the open batch while the batch stays within
    both the largest token budget and ``max_batch_size``; otherwise the
    open batch is closed and the item starts a new one. The fitting budget
    is looked up once per closed batch, not once per item.

    Stats note:
      graph_hits   -- incremented by _run_budget_graph after a replay.
      graph_misses -- incremented here for batches with no fitting budget,
                      i.e. a single item larger than the largest budget.
                      A batch closed only because of max_batch_size has
                      total tokens <= the largest budget, so it always
                      finds a budget.

    Args:
        mm_kwargs: Multimodal inputs for the items to encode.

    Returns:
        One output tensor per item, in the original item order.
    """
    item_count = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
    largest_budget = self.token_budgets[-1]
    out_tokens = self._get_per_item_out_tokens(mm_kwargs)

    # Smallest items first; sorted() is stable, so ties keep input order.
    visit_order = sorted(range(item_count), key=out_tokens.__getitem__)

    # Each entry is (original item indices, chosen budget or None).
    packed: list[tuple[list[int], int | None]] = []
    open_batch: list[int] = []
    open_tokens = 0

    for item in visit_order:
        need = out_tokens[item]
        overflows = (
            open_tokens + need > largest_budget
            or len(open_batch) >= self.max_batch_size
        )
        if overflows:
            # Close the open batch (if any) and start a fresh one that
            # will hold just this item for now.
            if open_batch:
                budget = self._find_smallest_fitting_budget_given_tokens(
                    open_tokens
                )
                packed.append((open_batch, budget))
            open_batch, open_tokens = [], 0
        open_batch.append(item)
        open_tokens += need

    if open_batch:
        budget = self._find_smallest_fitting_budget_given_tokens(open_tokens)
        packed.append((open_batch, budget))

    # Packing reorders items, so outputs are keyed by original index and
    # put back into input order at the end.
    results: dict[int, torch.Tensor] = {}

    for members, budget in packed:
        sub_kwargs = self.model.select_encoder_cudagraph_items(mm_kwargs, members)
        member_tokens = sum(out_tokens[i] for i in members)

        if budget is None:
            # No budget fits: run the eager encoder and record the miss.
            logger.debug(
                "Encoder CUDA graph fallback to eager: no budget for "
                "%d tokens from %d images",
                member_tokens,
                len(members),
            )
            self.graph_misses += len(members)
            with torch.inference_mode():
                eager_out = self.model.encoder_eager_forward(sub_kwargs)
            self._scatter_output_slices(eager_out, members, out_tokens, results)
            continue

        logger.debug(
            "Encoder CUDA graph: batch_size=%d, tokens=%d, "
            "budget=%d, waste=%.1f%%",
            len(members),
            member_tokens,
            budget,
            (budget - member_tokens) / budget * 100,
        )
        replay = self.model.prepare_encoder_cudagraph_replay_buffers(
            sub_kwargs, self.max_batch_size
        )

        # Hit accounting happens inside _run_budget_graph.
        graph_out = self._run_budget_graph(sub_kwargs, budget, replay.buffers)
        assert graph_out is not None
        # _run_budget_graph hands back the graph's own output buffer, so
        # each slice is cloned rather than kept as a view into it.
        self._scatter_output_slices(
            graph_out, members, out_tokens, results, clone=True
        )

    # Restore the caller's item order.
    return [results[i] for i in range(item_count)]

_find_smallest_fitting_budget_given_tokens

_find_smallest_fitting_budget_given_tokens(
    total_tokens: int,
) -> int | None

Find smallest budget >= total_tokens.

Returns:

Type Description
int | None

Token budget if found, None if no fitting budget.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _find_smallest_fitting_budget_given_tokens(
    self, total_tokens: int
) -> int | None:
    """Find smallest budget >= total_tokens.

    Returns:
        Token budget if found, None if no fitting budget.
    """
    for budget in self.token_budgets:
        if budget >= total_tokens:
            return budget
    return None

_generate_budgets staticmethod

_generate_budgets(
    min_budget: int, max_budget: int
) -> list[int]

Generate token budgets by repeatedly doubling min_budget up to max_budget; max_budget itself is always included as the last budget.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
@staticmethod
def _generate_budgets(min_budget: int, max_budget: int) -> list[int]:
    """Generate power-of-2 token budgets from min_budget to max_budget."""
    budgets: list[int] = []
    b = min_budget
    while b <= max_budget:
        budgets.append(b)
        b *= 2
    # Always include max_budget if it's not already a power-of-2 boundary
    if not budgets or budgets[-1] < max_budget:
        budgets.append(max_budget)
    return budgets

_get_per_item_out_tokens

_get_per_item_out_tokens(
    mm_kwargs: dict[str, Any],
) -> list[int]

Get per-item output token counts as plain ints.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _get_per_item_out_tokens(self, mm_kwargs: dict[str, Any]) -> list[int]:
    """Return the model-reported output token count of each item as ints.

    Each value yielded by the model's
    ``get_encoder_cudagraph_per_item_output_tokens`` hook is passed through
    ``int()``, so the result is a list of plain Python integers.
    """
    raw_counts = self.model.get_encoder_cudagraph_per_item_output_tokens(mm_kwargs)
    return list(map(int, raw_counts))

_run_budget_graph

_run_budget_graph(
    mm_kwargs: dict[str, Any],
    token_budget: int,
    replay_buffers: dict[str, Tensor | None],
) -> Tensor | None

Execute budget graph.

Parameters:

Name Type Description Default
mm_kwargs dict[str, Any]

Multimodal inputs for the batch.

required
token_budget int

Token budget to use.

required
replay_buffers dict[str, Tensor | None]

Buffer values to copy into captured buffers. None values leave the corresponding buffer unchanged.

required

Returns:

Type Description
Tensor | None

Encoder outputs, or None if graph not captured.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def _run_budget_graph(
    self,
    mm_kwargs: dict[str, Any],
    token_budget: int,
    replay_buffers: dict[str, torch.Tensor | None],
) -> torch.Tensor | None:
    """Execute budget graph.

    Args:
        mm_kwargs: Multimodal inputs for the batch.
        token_budget: Token budget to use.
        replay_buffers: Buffer values to copy into captured buffers.
            None values leave the corresponding buffer unchanged.

    Returns:
        Encoder outputs, or None if graph not captured.
    """
    num_items = self.model.get_encoder_cudagraph_num_items(mm_kwargs)
    if token_budget not in self.budget_graphs:
        self.graph_misses += num_items
        return None

    graph_meta = self.budget_graphs[token_budget]

    # Copy the input tensor. Buffers are sized for the full budget;
    # actual inputs may be smaller. Zero then slice-copy so padded
    # positions are invisible to attention (cu_seqlens masks them out).
    input_key = self.config.input_key
    src = mm_kwargs[input_key]
    n = src.shape[0]
    graph_meta.input_buffer.zero_()
    graph_meta.input_buffer[:n].copy_(src)

    # Copy metadata buffers using keys from config.buffer_keys.
    for key in self.config.buffer_keys:
        src = replay_buffers.get(key)
        if src is None:
            continue
        buf = graph_meta.metadata_buffers[key]
        if src.ndim == 0:
            buf.copy_(src)
        else:
            n = src.shape[0]
            buf.zero_()
            buf[:n].copy_(src)

    graph_meta.graph.replay()

    self.graph_hits += num_items
    return graph_meta.output_buffer

_scatter_output_slices staticmethod

_scatter_output_slices(
    output: Tensor,
    indices: list[int],
    per_item_out_tokens: list[int],
    dest: dict[int, Tensor] | list[Tensor | None],
    clone: bool = False,
) -> None

Slice a concatenated output tensor and scatter into dest by index.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
@staticmethod
def _scatter_output_slices(
    output: torch.Tensor,
    indices: list[int],
    per_item_out_tokens: list[int],
    dest: dict[int, torch.Tensor] | list[torch.Tensor | None],
    clone: bool = False,
) -> None:
    """Slice a concatenated output tensor and scatter into dest by index."""
    offset = 0
    for idx in indices:
        n_tok = per_item_out_tokens[idx]
        sliced = output[offset : offset + n_tok]
        dest[idx] = sliced.clone() if clone else sliced
        offset += n_tok

capture

capture()

Capture CUDA graphs for all token budgets.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def capture(self):
    """Capture one CUDA graph per configured token budget, then log a summary.

    Budgets are taken from ``self.token_budgets`` in order and each is
    handed to ``_capture_budget_graph``. The summary reports how many
    entries ``self.budget_graphs`` holds afterwards.
    """
    for budget in self.token_budgets:
        self._capture_budget_graph(budget)

    captured_count = len(self.budget_graphs)
    logger.info(
        "Encoder CUDA graph capture complete. Captured %d budget graphs.",
        captured_count,
    )

execute

execute(mm_kwargs: dict[str, Any]) -> list[Tensor]

Execute encoder using CUDA graph with optional DP.

Parameters:

Name Type Description Default
mm_kwargs dict[str, Any]

Multimodal keyword arguments containing the input tensor and grid dimensions.

required

Returns:

Type Description
list[Tensor]

List of encoder outputs (one per item).

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def execute(
    self,
    mm_kwargs: dict[str, Any],
) -> list[torch.Tensor]:
    """Execute encoder using CUDA graph with optional DP.

    With ``self.use_dp`` set, the items are sharded via ``_dp_shard``, the
    local shard is run through ``_execute_local`` and the per-rank results
    are combined by ``_dp_gather``. Otherwise ``_execute_local`` runs on
    the full input.

    Args:
        mm_kwargs: Multimodal keyword arguments containing the
            input tensor and grid dimensions.

    Returns:
        List of encoder outputs (one per item).
    """
    if self.use_dp:
        per_item_out_tokens = self._get_per_item_out_tokens(mm_kwargs)

        (
            local_mm_kwargs,
            image_rank_assignment,
            images_per_rank,
            max_output_tokens_per_rank,
        ) = self._dp_shard(mm_kwargs, per_item_out_tokens)

        local_outputs = self._execute_local(local_mm_kwargs)

        result = self._dp_gather(
            local_outputs,
            per_item_out_tokens,
            image_rank_assignment,
            images_per_rank,
            max_output_tokens_per_rank,
        )
    else:
        result = self._execute_local(mm_kwargs)

    # Log cumulative stats periodically. The stats dict is only built when
    # the log actually fires, so ordinary calls do not pay for it.
    # NOTE(review): the counters advance by a whole batch's item count, so
    # the total can step over a multiple of log_stats_interval and skip a
    # log line.
    total_requests = self.graph_hits + self.graph_misses
    if total_requests > 0 and total_requests % self.log_stats_interval == 0:
        stats = self.get_cumulative_stats()
        logger.debug(
            "Encoder CUDA graph cumulative stats: "
            "hits=%d, misses=%d, hit_rate=%.1f%%",
            stats["graph_hits"],
            stats["graph_misses"],
            stats["hit_rate"] * 100,
        )

    return result

get_cumulative_stats

get_cumulative_stats() -> dict[str, Any]

Get cumulative CUDA graph statistics.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def get_cumulative_stats(self) -> dict[str, Any]:
    """Summarise graph hit/miss counters and the configured budgets.

    Returns:
        A dict with ``graph_hits``, ``graph_misses``, ``hit_rate`` (hits
        divided by hits + misses, or 0.0 when nothing has been counted),
        ``num_budgets`` (number of entries in ``self.budget_graphs``) and
        ``token_budgets`` (the manager's own list, not a copy).
    """
    hits = self.graph_hits
    misses = self.graph_misses
    attempts = hits + misses

    if attempts > 0:
        rate = hits / attempts
    else:
        rate = 0.0

    return {
        "graph_hits": hits,
        "graph_misses": misses,
        "hit_rate": rate,
        "num_budgets": len(self.budget_graphs),
        "token_budgets": self.token_budgets,
    }

supports_modality

supports_modality(modality: str) -> bool

Check if a modality is supported by this manager.

Source code in vllm/v1/worker/gpu/mm/encoder_cudagraph.py
def supports_modality(self, modality: str) -> bool:
    """Report whether ``modality`` appears in ``self.config.modalities``."""
    supported = self.config.modalities
    return modality in supported