A very naive implementation of multiscale deformable attention in Triton.
Here is a performance comparison with the PyTorch-native multiscale deformable attention:
The results are also in line with the original CUDA implementation from Deformable DETR. Running the same benchmark for CUDA, I get:
- FWD with 10k queries: 5.37 ms in CUDA vs. 4.81 ms in Triton.
- FWD+BWD with 10k queries: 28.04 ms in CUDA vs. 24.95 ms in Triton.
- Memory with 10k queries: 166.14 MB in CUDA vs. 166.14 MB in Triton (same).
Results obtained on my RTX 2060 (gpu poor).
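If you want a quick sanity check of the forward timing on your own GPU without running the full benchmark script, a sketch along these lines should work (shapes are illustrative and not the exact benchmark configuration):

```python
import torch
import triton.testing
from msda_triton import multiscale_deformable_attention

# illustrative shapes, roughly matching the 10k-query setting above
batch, num_queries, num_heads, head_dim, num_points = 1, 10_000, 8, 32, 4
img_shapes = torch.tensor([(64, 64), (32, 32), (16, 16), (8, 8)], device="cuda")
num_levels = len(img_shapes)
num_pixels = int((img_shapes[:, 0] * img_shapes[:, 1]).sum())

img = torch.randn(batch, num_pixels, num_heads, head_dim, device="cuda")
points = torch.rand(batch, num_queries, num_heads, num_levels, num_points, 2, device="cuda")
weights = torch.rand(batch, num_queries, num_heads, num_levels, num_points, device="cuda")

# mean runtime in milliseconds over several warmed-up repetitions
ms = triton.testing.do_bench(
    lambda: multiscale_deformable_attention(img, img_shapes, points, weights, "zeros", False)
)
print(f"forward: {ms:.2f} ms")
```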
Replacing the original implementation with this kernel in a Deformable DETR-like model yields the same results.
Model: Grounding DINO, Image source: COCO
Note
The original version uses `padding_mode="zeros"` and `align_corners=False`. Make sure you match these arguments when using this implementation.
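These two arguments have the same meaning as in `torch.nn.functional.grid_sample`, which the PyTorch-native implementation relies on for bilinear sampling. A minimal illustration (shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

# Reference sampling settings: out-of-bounds samples read as zeros, and corner
# pixels are not treated as aligned with the extremes of the [-1, 1] grid.
value = torch.randn(1, 32, 16, 16)       # (N, C, H, W) feature map
grid = torch.rand(1, 100, 4, 2) * 2 - 1  # sampling locations in [-1, 1]
sampled = F.grid_sample(value, grid, mode="bilinear",
                        padding_mode="zeros", align_corners=False)
print(sampled.shape)  # torch.Size([1, 32, 100, 4])
```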
This package requires PyTorch and Triton, but does not install them automatically. Make sure you have them installed before proceeding.
Check if PyTorch is installed with:
python -c "import torch; print(torch.__version__)"
It should print something like `2.6.0+cu124`.
If it does not, follow the install instructions.
Triton may already come bundled with PyTorch. Check if it is installed with:
python -c "import triton; print(triton.__version__)"
It should print something like `3.2.0`.
If it does not, follow the install instructions.
Clone the repository:
git clone https://github.com/rziga/msda-triton
cd msda-triton
and install via `pip`:
pip install .
You need to install `pytest` to run the tests in the `tests` directory.
Then run:
pytest ./tests
To run the benchmark, you also need `matplotlib` and `pandas`.
Then run:
python scripts/benchmark.py
The results will be printed to the terminal and saved in the `outputs/benchmark_results` folder.
Since Triton can be pretty finicky, I also provide the dependencies that I used for development.
Install via `pip`:
pip install -e .[dev]
or with `uv`:
uv sync --dev
The package exposes two things:
- `multiscale_deformable_attention` - a differentiable PyTorch function implementing the multiscale deformable attention operator proposed in Deformable DETR. Usage:
import torch
from msda_triton import multiscale_deformable_attention
# input params
batch = 2
head_dim = 32
num_queries = 900
num_heads = 8
num_points = 4
img_shapes = [(64, 64), (32, 32), (16, 16), (8, 8)]
num_pixels = sum(h * w for h, w in img_shapes)
num_levels = len(img_shapes)
device = "cuda" # "cpu" uses fallback native torch version
# generate random inputs
img = torch.randn(batch, num_pixels, num_heads, head_dim, device=device)
img_shapes = torch.tensor(img_shapes, device=device)
sampling_points = torch.rand(batch, num_queries, num_heads, num_levels, num_points, 2, device=device)
attention_weights = torch.rand(batch, num_queries, num_heads, num_levels, num_points, device=device)
padding_mode = "zeros" # or "border"
align_corners = False # or True
# perform MSDA
output = multiscale_deformable_attention(
    img, img_shapes, sampling_points, attention_weights,
    padding_mode, align_corners,
)
assert output.shape == (batch, num_queries, num_heads, head_dim)
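For reference, the operator is expected to match the standard `grid_sample`-based formulation from Deformable DETR (the fallback used on CPU). Below is a rough sketch of that reference computation, adapted to the tensor layout above; the function name and details are illustrative rather than part of this package, and it assumes the last dimension of `sampling_points` holds (x, y) coordinates in [0, 1]:

```python
import torch
import torch.nn.functional as F

def msda_reference(img, img_shapes, sampling_points, attention_weights,
                   padding_mode="zeros", align_corners=False):
    """Sketch of the grid_sample-based reference computation (illustrative only)."""
    B, S, H, D = img.shape
    _, Q, _, L, P, _ = sampling_points.shape
    # split the flattened multi-scale feature map back into per-level maps
    shapes = [(int(h), int(w)) for h, w in img_shapes.tolist()]
    value_list = img.split([h * w for h, w in shapes], dim=1)
    # grid_sample expects coordinates in [-1, 1]; points assumed to be (x, y) in [0, 1]
    grids = 2 * sampling_points - 1
    sampled = []
    for lvl, (h, w) in enumerate(shapes):
        # (B, h*w, H, D) -> (B*H, D, h, w)
        value = value_list[lvl].permute(0, 2, 3, 1).reshape(B * H, D, h, w)
        # (B, Q, H, P, 2) -> (B*H, Q, P, 2)
        grid = grids[:, :, :, lvl].permute(0, 2, 1, 3, 4).reshape(B * H, Q, P, 2)
        # (B*H, D, Q, P)
        sampled.append(F.grid_sample(value, grid, mode="bilinear",
                                     padding_mode=padding_mode,
                                     align_corners=align_corners))
    sampled = torch.stack(sampled, dim=3)                         # (B*H, D, Q, L, P)
    weights = attention_weights.permute(0, 2, 1, 3, 4).reshape(B * H, 1, Q, L, P)
    out = (sampled * weights).sum(dim=(3, 4))                     # (B*H, D, Q)
    return out.reshape(B, H, D, Q).permute(0, 3, 1, 2)            # (B, Q, H, D)
```

When the `padding_mode` and `align_corners` arguments match, this sketch should agree with the Triton kernel up to numerical tolerance.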
- `MultiscaleDeformableAttention` - a PyTorch `nn.Module`, which also handles the input and output projections. Usage:
import torch
from msda_triton import MultiscaleDeformableAttention
# input params
batch = 2
emb_dim = 256
hidden_dim = 512
num_queries = 900
num_heads = 8
num_points = 4
img_shapes = [(64, 64), (32, 32), (16, 16), (8, 8)]
num_pixels = sum(h * w for h, w in img_shapes)
num_levels = len(img_shapes)
device = "cuda" # "cpu" uses fallback native torch version
# generate random inputs
img = torch.randn(batch, num_pixels, emb_dim, device=device)
img_shapes = torch.tensor(img_shapes, device=device)
queries = torch.rand(batch, num_queries, emb_dim, device=device)
reference_points = torch.rand(batch, num_queries, 2, device=device)
# init module
msda = MultiscaleDeformableAttention(
    emb_dim,
    hidden_dim,
    num_levels,
    num_heads,
    num_points,
    padding_mode="border", # or "zeros"
    align_corners=True, # or False
).to(device)
# perform MSDA
output = msda(img, img_shapes, queries, reference_points)
assert output.shape == (batch, num_queries, emb_dim)
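To make the role of the input and output projections concrete, here is a rough sketch of the usual Deformable-DETR-style wiring around the core operator. It is illustrative only and not necessarily identical to the module shipped by this package (the class name and details are mine):

```python
import torch
from torch import nn
from msda_triton import multiscale_deformable_attention

class SketchMSDABlock(nn.Module):
    """Illustrative Deformable-DETR-style wiring around the core op (not the package's exact code)."""

    def __init__(self, emb_dim, hidden_dim, num_levels, num_heads, num_points):
        super().__init__()
        assert hidden_dim % num_heads == 0
        self.num_heads, self.num_levels, self.num_points = num_heads, num_levels, num_points
        self.head_dim = hidden_dim // num_heads
        # input/output projections around the core sampling op
        self.value_proj = nn.Linear(emb_dim, hidden_dim)
        self.sampling_offsets = nn.Linear(emb_dim, num_heads * num_levels * num_points * 2)
        self.attention_weights = nn.Linear(emb_dim, num_heads * num_levels * num_points)
        self.output_proj = nn.Linear(hidden_dim, emb_dim)

    def forward(self, img, img_shapes, queries, reference_points):
        B, S, _ = img.shape
        Q = queries.shape[1]
        H, L, P = self.num_heads, self.num_levels, self.num_points
        value = self.value_proj(img).view(B, S, H, self.head_dim)
        offsets = self.sampling_offsets(queries).view(B, Q, H, L, P, 2)
        weights = self.attention_weights(queries).view(B, Q, H, L * P).softmax(-1).view(B, Q, H, L, P)
        # offsets are normalized by each level's (w, h) and added to the reference points
        normalizer = img_shapes.flip(-1)[None, None, None, :, None, :]   # (1, 1, 1, L, 1, 2)
        points = reference_points[:, :, None, None, None, :] + offsets / normalizer
        out = multiscale_deformable_attention(value, img_shapes, points, weights, "zeros", False)
        return self.output_proj(out.flatten(2))                          # (B, Q, emb_dim)
```

Treat this as a mental model of the data flow; the actual module may differ in initialization and argument handling.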
The kernels are quite basic, as this is my first experience with Triton. I have tested the functions as much as I could, but there could still be some issues. Performance can definitely be improved. Feel free to open an issue and/or submit improvements.