OOM Error with roi_align in PyTorch 2.1.1 but fine in PyTorch 2.0.1 #8168
Comments
We need to understand where the regression is coming from, but it sounds a bit like a torchvision problem, doesn't it? Also, I wonder if this is a CUDA 11.8 vs. CUDA 12.1 regression (2.0.1 shipped with CUDA 11.8 by default, but 2.1 ships with 12.1).
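To tell these two hypotheses apart, it helps to record exactly which CUDA build each install uses. PyTorch ships `python -m torch.utils.collect_env` for a full environment report. The sketch below is a smaller, hypothetical helper that only collects the version strings; it imports torch and torchvision lazily, so it also runs (reporting `None`) where they are missing.

```python
import importlib


def environment_report():
    """Collect torch/torchvision/CUDA/cuDNN versions; None marks anything missing."""
    report = {"torch": None, "torchvision": None, "cuda": None, "cudnn": None}
    try:
        torch = importlib.import_module("torch")
    except ImportError:
        return report  # torch not installed: nothing else to query
    report["torch"] = torch.__version__
    # CUDA version the wheel was built with; None on CPU-only builds.
    report["cuda"] = torch.version.cuda
    # cuDNN version as an integer, or None if cuDNN is unavailable.
    report["cudnn"] = torch.backends.cudnn.version()
    try:
        report["torchvision"] = importlib.import_module("torchvision").__version__
    except ImportError:
        pass
    return report


if __name__ == "__main__":
    for name, version in environment_report().items():
        print(f"{name}: {version}")
```

Running it in both the 2.0.1 and the 2.1.1 environment shows at a glance whether the CUDA version changed along with the PyTorch version.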
Hello, I've created a minimal toy example to demonstrate the issue in detail. This example relies on the ultralytics library. For context, the code runs as expected with PyTorch 2.0.1 and torchvision 0.15.2+cu118, but runs out of memory with PyTorch 2.1.1. Please take a look at the gist, and let me know if you need any more information or if there's anything else I can do to help resolve this issue. Thank you!
I just want to chime in and mention that I have the same problem. A very large memory allocation is attempted on both the GPU and the CPU. I observe the problem in the following environment:
But everything works with:
Below is a little snippet which leads to the OOM error with
Can someone else replicate this?
Oh, you know what, it's probably because of torch.use_deterministic_algorithms. We added a deterministic implementation, but it is very memory hungry.
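To get a feel for how a fully vectorized deterministic version can need so much more memory than a fused kernel, a back-of-the-envelope sketch helps. It assumes (as an illustration, not read off the torchvision source) an implementation that materializes one float32 value per (RoI, channel, bin row, bin column, sample row, sample column) before averaging, while the pooled output itself only has K x C x PH x PW values. With roi_align's default sampling_ratio=-1, the number of samples per bin adapts to the RoI size (ceil(roi_height / pooled_height), and likewise for width), so large boxes on a high-resolution feature map make the last two factors grow. All sizes below are illustrative.

```python
def sampled_tensor_bytes(num_rois, channels, pooled_h, pooled_w,
                         samples_h, samples_w, bytes_per_elem=4):
    """Bytes for one fully materialized sampling tensor of shape
    [num_rois, channels, pooled_h, pooled_w, samples_h, samples_w]."""
    return (num_rois * channels * pooled_h * pooled_w
            * samples_h * samples_w * bytes_per_elem)


def output_bytes(num_rois, channels, pooled_h, pooled_w, bytes_per_elem=4):
    """Bytes for the pooled output [num_rois, channels, pooled_h, pooled_w]."""
    return num_rois * channels * pooled_h * pooled_w * bytes_per_elem


# Illustrative sizes: 512 RoIs, 256 channels, 7x7 output bins.
pooled = output_bytes(512, 256, 7, 7)
fixed = sampled_tensor_bytes(512, 256, 7, 7, 2, 2)       # 2x2 samples per bin
adaptive = sampled_tensor_bytes(512, 256, 7, 7, 32, 32)  # 32x32 samples per bin

print(f"{pooled / 2**20:.1f} MiB")    # 24.5 MiB (pooled output only)
print(f"{fixed / 2**20:.1f} MiB")     # 98.0 MiB
print(f"{adaptive / 2**30:.1f} GiB")  # 24.5 GiB
```

Under these assumptions the intermediate grows linearly with samples_h * samples_w, and a kernel that fuses sampling and averaging never has to materialize it at all.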
Hey, I just came across this issue. Is there any update? I understand the appeal of a deterministic implementation, but calling it "very memory hungry" is an understatement :D When I call the problematic
Has this actually been tested or run in a benchmark? If so, how? I fail to see how this is the intended behavior unless I'm missing something fundamental 😅 Thanks for any help.
The implementation doesn't OOM if we torch.compile it, so I think I will fix this by making torch.compile mandatory for it.
Jesus, how is that possible? Sounds great.
Thank you for the quick response and the fix, @ezyang 🚀 I applied your patch manually on my system and can confirm that it eliminates the OOM issue! A Mask R-CNN ResNet-50 FPN now consumes ~5500 MB for batch_size 2 on COCO and ~38000 MB for batch_size 16. I ran some quick tests using the torchvision reference implementation and can further confirm that we now have deterministic training (see below). In addition, I've appended some timings in case this helps moving forward.
Test Setup
For reproducible determinism I set:
import os
import torch
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True
torch.use_deterministic_algorithms(True)
os.environ["CUBLAS_WORKSPACE_CONFIG"] = ":4096:8"
# seed for main process/thread (torch, random, numpy)
seed = 617971023
# seed for sampler and dataloader generators (num_workers=4)
seed_data = 0
Deterministic Training works :)
# nn.Module run 1
Epoch: [0] [ 0/58633] eta: 80 days, 11:30:44 lr: 0.000000 loss: 6.0988 (6.0988) loss_classifier: 4.5028 (4.5028) loss_box_reg: 0.0175 (0.0175) loss_mask: 0.7927 (0.7927) loss_objectness: 0.6928 (0.6928) loss_rpn_box_reg: 0.0930 (0.0930) time: 118.5927 data: 0.2771 max mem: 4346
Epoch: [0] [20/58633] eta: 12 days, 17:24:42 lr: 0.000002 loss: 6.0132 (6.0607) loss_classifier: 4.4406 (4.4362) loss_box_reg: 0.0319 (0.0412) loss_mask: 0.7437 (0.7650) loss_objectness: 0.6933 (0.6934) loss_rpn_box_reg: 0.0742 (0.1249) time: 13.7666 data: 0.0025 max mem: 5533
Epoch: [0] [40/58633] eta: 11 days, 12:51:42 lr: 0.000004 loss: 5.5885 (5.7992) loss_classifier: 3.9741 (4.1683) loss_box_reg: 0.0536 (0.0510) loss_mask: 0.7486 (0.7632) loss_objectness: 0.6906 (0.6920) loss_rpn_box_reg: 0.0687 (0.1248) time: 15.1755 data: 0.0028 max mem: 5533
Epoch: [0] [60/58633] eta: 11 days, 21:56:55 lr: 0.000006 loss: 3.4103 (4.9821) loss_classifier: 1.5982 (3.3080) loss_box_reg: 0.0917 (0.0753) loss_mask: 0.8951 (0.8050) loss_objectness: 0.6612 (0.6780) loss_rpn_box_reg: 0.0843 (0.1159) time: 18.7317 data: 0.0029 max mem: 5533
# nn.Module run 2
Epoch: [0] [ 0/58633] eta: 79 days, 04:24:00 lr: 0.000000 loss: 6.0988 (6.0988) loss_classifier: 4.5028 (4.5028) loss_box_reg: 0.0175 (0.0175) loss_mask: 0.7927 (0.7927) loss_objectness: 0.6928 (0.6928) loss_rpn_box_reg: 0.0930 (0.0930) time: 116.6824 data: 0.2564 max mem: 4346
Epoch: [0] [20/58633] eta: 12 days, 14:39:22 lr: 0.000002 loss: 6.0132 (6.0607) loss_classifier: 4.4406 (4.4362) loss_box_reg: 0.0319 (0.0412) loss_mask: 0.7437 (0.7650) loss_objectness: 0.6933 (0.6934) loss_rpn_box_reg: 0.0742 (0.1249) time: 13.6844 data: 0.0023 max mem: 5400
Epoch: [0] [40/58633] eta: 11 days, 11:22:22 lr: 0.000004 loss: 5.5885 (5.7992) loss_classifier: 3.9741 (4.1683) loss_box_reg: 0.0536 (0.0510) loss_mask: 0.7486 (0.7632) loss_objectness: 0.6906 (0.6920) loss_rpn_box_reg: 0.0687 (0.1248) time: 15.1657 data: 0.0026 max mem: 5400
Epoch: [0] [60/58633] eta: 11 days, 21:05:58 lr: 0.000006 loss: 3.4103 (4.9821) loss_classifier: 1.5982 (3.3080) loss_box_reg: 0.0917 (0.0753) loss_mask: 0.8951 (0.8050) loss_objectness: 0.6612 (0.6780) loss_rpn_box_reg: 0.0843 (0.1159) time: 18.7601 data: 0.0027 max mem: 5400
# DDP (world_size 1) run 1
Epoch: [0] [ 0/58633] eta: 70 days, 18:22:51 lr: 0.000000 loss: 6.0988 (6.0988) loss_classifier: 4.5028 (4.5028) loss_box_reg: 0.0175 (0.0175) loss_mask: 0.7927 (0.7927) loss_objectness: 0.6928 (0.6928) loss_rpn_box_reg: 0.0930 (0.0930) time: 104.2787 data: 0.4824 max mem: 4515
Epoch: [0] [20/58633] eta: 12 days, 02:47:31 lr: 0.000002 loss: 6.0132 (6.0607) loss_classifier: 4.4406 (4.4362) loss_box_reg: 0.0319 (0.0412) loss_mask: 0.7437 (0.7650) loss_objectness: 0.6933 (0.6934) loss_rpn_box_reg: 0.0742 (0.1249) time: 13.5395 data: 0.0021 max mem: 5572
Epoch: [0] [40/58633] eta: 11 days, 04:49:32 lr: 0.000004 loss: 5.5885 (5.7992) loss_classifier: 3.9741 (4.1683) loss_box_reg: 0.0536 (0.0510) loss_mask: 0.7486 (0.7632) loss_objectness: 0.6906 (0.6920) loss_rpn_box_reg: 0.0687 (0.1248) time: 15.1061 data: 0.0027 max mem: 5572
Epoch: [0] [60/58633] eta: 11 days, 16:31:04 lr: 0.000006 loss: 3.4103 (4.9821) loss_classifier: 1.5982 (3.3080) loss_box_reg: 0.0917 (0.0753) loss_mask: 0.8951 (0.8050) loss_objectness: 0.6612 (0.6780) loss_rpn_box_reg: 0.0843 (0.1159) time: 18.7259 data: 0.0028 max mem: 5572
# DDP (world_size 1) run 2
Epoch: [0] [ 0/58633] eta: 71 days, 13:13:18 lr: 0.000000 loss: 6.0988 (6.0988) loss_classifier: 4.5028 (4.5028) loss_box_reg: 0.0175 (0.0175) loss_mask: 0.7927 (0.7927) loss_objectness: 0.6928 (0.6928) loss_rpn_box_reg: 0.0930 (0.0930) time: 105.4355 data: 0.4228 max mem: 4515
Epoch: [0] [20/58633] eta: 12 days, 03:27:48 lr: 0.000002 loss: 6.0132 (6.0607) loss_classifier: 4.4406 (4.4362) loss_box_reg: 0.0319 (0.0412) loss_mask: 0.7437 (0.7650) loss_objectness: 0.6933 (0.6934) loss_rpn_box_reg: 0.0742 (0.1249) time: 13.5249 data: 0.0024 max mem: 5566
Epoch: [0] [40/58633] eta: 11 days, 05:10:02 lr: 0.000004 loss: 5.5885 (5.7992) loss_classifier: 3.9741 (4.1683) loss_box_reg: 0.0536 (0.0510) loss_mask: 0.7486 (0.7632) loss_objectness: 0.6906 (0.6920) loss_rpn_box_reg: 0.0687 (0.1248) time: 15.1059 data: 0.0027 max mem: 5566
Epoch: [0] [60/58633] eta: 11 days, 16:41:03 lr: 0.000006 loss: 3.4103 (4.9821) loss_classifier: 1.5982 (3.3080) loss_box_reg: 0.0917 (0.0753) loss_mask: 0.8951 (0.8050) loss_objectness: 0.6612 (0.6780) loss_rpn_box_reg: 0.0843 (0.1159) time: 18.7140 data: 0.0026 max mem: 5566
Time and memory overhead for DDP models
The following values are for DDP models with batch_size 2 per GPU (world_size) and show the average time per batch and max_mem after 20 batches (measured with MetricLogger from the reference implementation).
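Going back to the determinism logs above: the loss columns are identical between runs, while eta, time and max mem differ (for example 5533 vs. 5400 MB for the two nn.Module runs). A small sketch for checking this mechanically, assuming the MetricLogger line format shown in those logs; the helper names are hypothetical. It compares the printed loss values as strings, so it only detects differences at the logged 4-decimal precision.

```python
import re

# Matches fields such as "loss_mask: 0.7927 (0.7927)" -> (name, value, running average).
LOSS_FIELD = re.compile(r"\b(loss\w*): (\d+\.\d+) \((\d+\.\d+)\)")


def loss_fields(line):
    """Return [(name, value, avg), ...] for every loss field in one log line."""
    return LOSS_FIELD.findall(line)


def runs_match(run_a, run_b):
    """True if two runs log identical loss values, line by line.

    eta, time, data and max mem are deliberately ignored: they vary
    between runs even when training itself is deterministic.
    """
    run_a, run_b = list(run_a), list(run_b)
    if len(run_a) != len(run_b):
        return False
    return all(loss_fields(a) == loss_fields(b) for a, b in zip(run_a, run_b))
```

Feeding it the lines of "nn.Module run 1" and "nn.Module run 2" should return True, while any single differing loss value makes it return False.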
Misc
I get the following deprecation warnings with
UserWarning: 'has_cuda' is deprecated, please use 'torch.backends.cuda.is_built()'
UserWarning: 'has_cudnn' is deprecated, please use 'torch.backends.cudnn.is_available()'
UserWarning: 'has_mps' is deprecated, please use 'torch.backends.mps.is_built()'
UserWarning: 'has_mkldnn' is deprecated, please use 'torch.backends.mkldnn.is_available()'
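The warnings name their non-deprecated replacements (for example torch.backends.cuda.is_built()), so in your own code the fix is simply to call those. If the warnings come from inside a library instead, here is a minimal stdlib sketch (the helper name is illustrative) that hides only these four messages while leaving every other UserWarning visible:

```python
import re
import warnings

# The four deprecated attribute names from the warnings quoted above.
DEPRECATED_TORCH_FLAGS = ("has_cuda", "has_cudnn", "has_mps", "has_mkldnn")


def silence_torch_flag_deprecations():
    """Ignore only the "'has_*' is deprecated" UserWarnings; others still show."""
    for name in DEPRECATED_TORCH_FLAGS:
        warnings.filterwarnings(
            "ignore",
            # `message` is a regex matched against the start of the warning text;
            # the closing quote keeps 'has_cuda' from also matching 'has_cudnn'.
            message=re.escape(f"'{name}' is deprecated"),
            category=UserWarning,
        )
```

Call it once at startup, before the code that triggers the warnings runs.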
By the way, I think Inductor could potentially generate much better code for this and bring down the time/memory overhead; it just needs some concerted elbow grease.
I'd love to help, but I have to admit that this part of the code base is a little over my head 😅 However, if I can support you with some isolated testing (that does not require in-depth knowledge), let me know!
Quick update: I ran some more tests to see whether newer torch/torchvision versions improve things, but it appears that I've been lucky with
With
Both are the prebuilt versions from PyPI, running DDP models on 1 GPU (the same errors occur with nn.Module). I replaced my env path with ... for brevity. Hope this helps to narrow things down.
OOM with
Just for clarity on your setup, did you manually patch in the change to the prebuilt binaries of torchvision to test them? |
Yes, I have three separate conda envs with the mentioned torch and torchvision versions (installed from PyPI with pip), and I manually patched the torchvision/ops/roi_align.py file in each env's site-packages to match the file from #8436. This naive approach resulted in the success and errors mentioned above. Let me know if you need more info, or if this was too naive and is better tested in a different way. As mentioned, I'm not very familiar with torch.compile/dynamo/inductor.
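When patching installed packages by hand across several environments, it can help to confirm that every env actually picked up the same file. A small stdlib sketch (the helper name is hypothetical) that locates a file inside an installed package without importing the package and returns its SHA-256, so the digests from the three conda envs can be compared:

```python
import hashlib
import importlib.util
import os


def package_file_digest(package, relative_path):
    """Return (path, sha256 hex digest) for a file inside an installed package,
    or None if the package or file cannot be found. The package is not imported."""
    spec = importlib.util.find_spec(package)  # top-level lookup only, no import
    if spec is None or not spec.submodule_search_locations:
        return None  # not installed, or a plain module rather than a package
    for root in spec.submodule_search_locations:
        path = os.path.join(root, relative_path)
        if os.path.isfile(path):
            with open(path, "rb") as f:
                return path, hashlib.sha256(f.read()).hexdigest()
    return None


# Run in each env and compare the printed digests, e.g.:
# print(package_file_digest("torchvision", os.path.join("ops", "roi_align.py")))
```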
Reopening for torch version incompatibility |
Bah, I don't have a ready-to-go Mask R-CNN setup that I can use to easily test this.
@JohannesTheo, do you have a suggested way of reproducing your problems? Alternatively, if you are able to do runs with TORCH_TRACE and upload them here, that would also be very helpful.
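TORCH_TRACE is an environment variable pointing PyTorch at a directory for structured compiler trace logs (which can then be rendered with the separate tlparse tool). Since it only needs to be set in the environment of the training process, one minimal way to do a traced run from Python is to launch the script as a subprocess. The helper below is a hypothetical convenience, not a PyTorch API.

```python
import os
import subprocess
import sys


def run_with_torch_trace(script, trace_dir, *script_args):
    """Run `script` with the current interpreter and TORCH_TRACE=trace_dir set
    only for that child process. Returns the CompletedProcess even if the run
    fails (e.g. with an OOM), so the trace directory can still be collected."""
    os.makedirs(trace_dir, exist_ok=True)
    env = dict(os.environ, TORCH_TRACE=os.path.abspath(trace_dir))
    return subprocess.run(
        [sys.executable, script, *script_args],
        env=env,
        capture_output=True,  # stdout/stderr are available on the result
        text=True,
        check=False,  # inspect .returncode instead of raising
    )


# Hypothetical usage: reproduce the OOM with tracing on, then zip and upload the directory.
# result = run_with_torch_trace("train.py", "torch_trace_logs")
```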
Hey @ezyang, I will put something together over the weekend.
🐛 Describe the bug
Description
I am encountering an Out of Memory (OOM) error when using torchvision's roi_align function with PyTorch 2.1.1 and torchvision 0.16.1. This issue does not occur with PyTorch 2.0.1 and torchvision 0.15.2. The error happens regardless of the GPU used (tested on NVIDIA A2000 and RTX 4090).
Note that when I downgrade PyTorch and torchvision back to 2.0.1 and 0.15.2, the function works properly.
I am seeking assistance in understanding why this OOM error occurs in the newer versions of PyTorch and torchvision and whether this is a bug or a change in how roi_align manages memory.
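For reference, torchvision.ops.roi_align accepts boxes as a Tensor[K, 5] whose rows are (batch_index, x1, y1, x2, y2), and spatial_scale maps those box coordinates onto the input feature map. The sketch below shows one way to turn YOLO-style labels into that row format directly in feature-map pixels (so spatial_scale=1.0). The label layout (batch_index, class_id, cx, cy, w, h) with normalized coordinates, the helper name, and the class filter are assumptions for illustration, not the issue author's object_roi_align code.

```python
def yolo_labels_to_rois(labels, feat_w, feat_h, allowed_classes=None):
    """Convert YOLO-style labels to roi_align box rows in feature-map pixels.

    labels: iterable of (batch_index, class_id, cx, cy, w, h), with cx, cy, w, h
            normalized to [0, 1] (assumed layout).
    Returns a list of (batch_index, x1, y1, x2, y2) rows for spatial_scale=1.0.
    """
    rois = []
    for batch_index, class_id, cx, cy, w, h in labels:
        if allowed_classes is not None and class_id not in allowed_classes:
            continue  # optional class constraint
        x1 = (cx - w / 2) * feat_w
        y1 = (cy - h / 2) * feat_h
        x2 = (cx + w / 2) * feat_w
        y2 = (cy + h / 2) * feat_h
        rois.append((batch_index, x1, y1, x2, y2))
    return rois


# With torch/torchvision available, the rows become roi_align's boxes argument, e.g.:
# boxes = torch.tensor(yolo_labels_to_rois(labels, 80, 80), dtype=torch.float32)
# pooled = torchvision.ops.roi_align(features, boxes, output_size=7, spatial_scale=1.0)
```

For example, a box centered in the image covering half its width and a quarter of its height maps to (0, 20.0, 15.0, 60.0, 25.0) on an 80x40 feature map.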
Background
Function
The object_roi_align function crops feature maps based on YOLO's object detection labels and uses RoI Align to extract features for each object. The function accepts feature maps, YOLO detection labels, and several optional parameters for noise and class constraints.
Error messages (A2000)
Error message (RTX 4090)
Versions
Versions (RTX 4090)
Versions (A2000)
cc @ezyang @gchanan @zou3519 @kadeng @ptrblck