Matepoint is a fork of PyTorch's `torch.utils.checkpoint` that lets you use CPU RAM when you're low on GPU VRAM. While standard checkpointing trades computation for memory by recomputing activations during the backward pass, Matepoint takes this further by:
- Automatically offloading activation tensors to CPU after the forward pass
- Efficiently moving tensors back to GPU only when needed during the backward pass
- Supporting pipelined tensor transfers for better performance
- Providing optional CPU memory pooling for large, similarly-shaped tensors
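To make the offload/reload idea concrete, here is a minimal sketch using PyTorch's built-in `torch.autograd.graph.saved_tensors_hooks`. This illustrates the underlying technique only; it is not Matepoint's implementation, which layers checkpointing, pipelining, and pooling on top:

```python
import torch

def pack_to_cpu(tensor):
    # Called when autograd saves an activation for backward:
    # ship it off to CPU RAM instead of keeping it on the GPU.
    return tensor.to("cpu")

def unpack_to_gpu(packed):
    # Called when backward actually needs the activation:
    # bring it back to the GPU just in time.
    return packed.to("cuda")

x = torch.randn(8, 1024, device="cuda", requires_grad=True)
with torch.autograd.graph.saved_tensors_hooks(pack_to_cpu, unpack_to_gpu):
    y = x.sin().sum()  # the saved activation lives in CPU RAM
y.backward()           # and returns to the GPU only during backward
```

PyTorch also ships a ready-made context manager for this pattern, `torch.autograd.graph.save_on_cpu(pin_memory=True)`; Matepoint's value is doing the same offload inside checkpointing, with transfers pipelined so they overlap with computation.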
Replace your existing `torch.utils.checkpoint` calls with `matepoint`:
```python
from matepoint import checkpoint
# Instead of:
# from torch.utils.checkpoint import checkpoint

def forward(self, x):
    # Use exactly like torch.utils.checkpoint
    out = checkpoint(self.layer, x, use_reentrant=False)
    return out
```
- PyTorch >= 2.4.0
- CUDA-capable GPU
- Sufficient CPU memory for activation storage
```bash
pip install --index-url https://test.pypi.org/simple/ matepoint
```
To build and publish a new release to TestPyPI:

```bash
rm -rf dist/ build/ *.egg-info
python setup.py sdist bdist_wheel
twine upload --repository testpypi dist/*
# if needed, specify the exact version:
# twine upload --repository testpypi dist/matepoint-0.1.7*
```
Refer to the Matepoint section in this blog post for more details on the implementation and performance benefits.
This project is licensed under the MIT License - see the LICENSE file for details.
We actually built Matepoint when we were running out of VRAM trying to solve weather(™) with transformers. While our model, WeatherMesh, isn't huge (~180M parameters), forecasting weather for the entire planet over 6 days means running through 200+ transformer layers.
Without some clever tricks, we'd need hundreds of GiB of VRAM. Even regular checkpointing wasn't enough: storing a ~200 MiB latent tensor for each of those 200+ transformer blocks would eat up around 40 GiB of VRAM, more than even an RTX 4090 can handle.
Matepoint ships those tensors off to CPU RAM when we don't need them, then brings them back just in time during the backward pass. Adding more forecast days costs almost nothing in VRAM terms. This meant we could train our whole weather model on consumer RTX 4090s instead of shelling out for pricier hardware.
Check out these visualizations to see Matepoint in action:
Matepoint pipelines tensor transfers by default, overlapping CPU-GPU data movement with computation for better performance. You can disable this optimization if needed:
```python
import matepoint

matepoint.NOPIPELINE = True  # Disable pipelined tensor transfers
```
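For the curious, overlap like this is conventionally built on a side CUDA stream plus pinned host memory. The sketch below shows the general technique in plain PyTorch; it is illustrative only and not Matepoint's internal code:

```python
import torch

copy_stream = torch.cuda.Stream()  # side stream dedicated to transfers

# Pinned (page-locked) host memory is required for truly async H2D copies
cpu_activation = torch.randn(1024, 1024).pin_memory()
weight = torch.randn(1024, 1024, device="cuda")

with torch.cuda.stream(copy_stream):
    # The copy runs asynchronously on the side stream...
    gpu_activation = cpu_activation.to("cuda", non_blocking=True)

# ...while unrelated compute proceeds on the default stream
partial = weight @ weight

# Synchronize before consuming the transferred tensor
torch.cuda.current_stream().wait_stream(copy_stream)
out = partial + gpu_activation
```

The `wait_stream` call is what keeps this safe: the default stream won't touch `gpu_activation` until the copy stream has finished writing it.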