Fix PyTorch RAG tests GPU OOM #16881

Merged: 5 commits into huggingface:main on Apr 25, 2022

Conversation

@ydshieh (Collaborator) commented Apr 21, 2022

What does this PR do?

Fix PyTorch RAG tests GPU OOM.

The GPU OOM

E     tensorflow.python.framework.errors_impl.ResourceExhaustedError: OOM when allocating tensor with shape[32,5,16,300,64] and type float on /job:localhost/replica:0/task:0/device:GPU:0 by allocator GPU_0_bfc [Op:GatherV2]

could be found in https://github.com/huggingface/transformers/runs/6100697349?check_suite_focus=true

Results

  • Without this PR, after the PyTorch RAG tests run, torch occupies about 9.5 GB of GPU memory, and 10 TF RAG tests fail.
  • With this PR, all PT/TF RAG tests pass without GPU OOM.

@HuggingFaceDocBuilderDev commented Apr 21, 2022

The documentation is not available anymore as the PR was closed or merged.

@ydshieh ydshieh marked this pull request as draft April 21, 2022 19:32
@ydshieh ydshieh changed the title Fix PyTorch RAG tests GPU OOM Fix (partially) PyTorch RAG tests GPU OOM Apr 22, 2022
@ydshieh ydshieh marked this pull request as ready for review April 22, 2022 09:43
@ydshieh ydshieh requested review from lhoestq, patrickvonplaten, LysandreJik and sgugger and removed request for lhoestq April 22, 2022 09:43
@ydshieh (Collaborator, Author) commented Apr 22, 2022

Also cc @patil-suraj and @stas00 to see if they have suggestions

@patrickvonplaten (Contributor) commented:
Good for me!

@sgugger (Collaborator) left a comment

What cache is this emptying since the model is not deleted? I think this is a symptom that there is a memory leak in the model, which would need to be fixed.

@ydshieh (Collaborator, Author) commented Apr 22, 2022

What cache is this emptying since the model is not deleted? I think this is a symptom that there is a memory leak in the model, which would need to be fixed.

From the documentation of torch.cuda.empty_cache and Memory management, in particular:

PyTorch uses a caching memory allocator to speed up memory allocations.
This allows fast memory deallocation without device synchronizations.
However, the unused memory managed by the allocator will still show as if used in nvidia-smi. 

and

Calling empty_cache() releases all unused cached memory from PyTorch so that those can be used by other GPU applications.

my understanding is: PyTorch keeps the allocated GPU memory for later use in order to reduce the number of memory allocations - the goal is to speed up some (memory) operations.

This doesn't mean that GPU memory is leaked - PyTorch still controls it. But for other applications (say TensorFlow, or nvidia-smi), that memory appears unavailable. Calling empty_cache() will release it.

Of course, the memory occupied by currently live tensors won't be released.
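
A minimal sketch of this behavior (for illustration only, not code from this PR):

```python
import torch

# Allocate a big tensor: the caching allocator grabs GPU memory from the driver.
x = torch.empty(1024, 1024, 1024, device="cuda")  # ~4 GB of float32
print(torch.cuda.memory_allocated())              # memory held by live tensors (~4 GB)
print(torch.cuda.memory_reserved())               # memory held by the caching allocator (~4 GB)

del x                                             # the tensor is gone ...
print(torch.cuda.memory_allocated())              # ~0
print(torch.cuda.memory_reserved())               # ... but the allocator keeps the block cached,
                                                  # so nvidia-smi / TF still see it as "used"

torch.cuda.empty_cache()                          # hand the cached blocks back to the driver
print(torch.cuda.memory_reserved())               # ~0 - other GPU applications can use it now
```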

@stas00 (Contributor) commented Apr 22, 2022

If you're trying to use multiple programs concurrently accessing the same GPU (regardless of whether they are all pytorch, or a mix of frameworks), torch.cuda.empty_cache is a definite help and is OK to use, as long as it's inside the test only and not in transformers.

But why does torch still have allocated tensors when the pytorch test finishes? Why not free them?

Often import gc; gc.collect() is needed to force garbage collection immediately. I'm not sure if this is the case.

@ydshieh (Collaborator, Author) commented Apr 22, 2022

@stas00

My words in the previous comment might be a bit confusing: we don't have the issue of torch still holding allocated tensors after the pytorch test finishes. I was just mentioning a general (and quite trivial) fact that is stated in the PyTorch docs I linked.

empty_cache() helps, but not completely. Some GPU memory remains occupied even after we leave the testing methods, as long as we stay in the same Python process (for example, when entering the TF testing module, which is launched together with the PT testing).

There are some discussions, like this one

https://discuss.pytorch.org/t/pytorch-do-not-clear-gpu-memory-when-return-to-another-function/125944/4

@stas00 (Contributor) commented Apr 22, 2022

when you do torch.ones(1) it allocates 1-2GB of cuda kernels on the gpu and they remain allocated unless the program is shut down.

In such a case the solution is not to run the program inside pytest but to use an external process. Once the external process finishes, 100% of the gpu memory is returned. (The tests are then much slower, though, because they have to launch an external program.)

I created a special framework for running external programs

def execute_subprocess_async(cmd, env=None, stdin=None, timeout=180, quiet=False, echo=True) -> _RunOutput:

You can see it extensively used in the deepspeed and extended tests.
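
A rough usage sketch (the test class and helper script below are made up; TestCasePlus, execute_subprocess_async and require_torch_gpu come from transformers.testing_utils):

```python
import sys

from transformers.testing_utils import TestCasePlus, execute_subprocess_async, require_torch_gpu


class RagGenerateInSubprocessTest(TestCasePlus):              # hypothetical test class
    @require_torch_gpu
    def test_rag_generate_isolated(self):
        # Run the GPU-heavy part in a child process; when the child exits,
        # 100% of the GPU memory (including the CUDA context) is returned.
        script = f"{self.test_file_dir}/run_rag_generate.py"  # hypothetical helper script
        cmd = [sys.executable, script, "--batch_size", "8"]
        execute_subprocess_async(cmd, env=self.get_env())     # raises if the child process fails
```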

@ydshieh (Collaborator, Author) commented Apr 22, 2022

Yeah, I know this approach, but wasn't very sure how to use it in a good way with testing.
Maybe we can discuss this!

By the way: torch.ones(1) allocates 1-2GB of cuda kernels --> I tried it and it seems to be a correct statement.

I am really surprised (and not very happy) that there is no way to free these kinds of memory allocation.
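
One rough way to observe this (a sketch; the exact number depends on the GPU, driver and torch build):

```python
import subprocess

import torch


def gpu_used_mib():
    # Ask nvidia-smi how much memory is in use on GPU 0 (in MiB).
    out = subprocess.check_output(
        ["nvidia-smi", "--id=0", "--query-gpu=memory.used", "--format=csv,noheader,nounits"]
    )
    return int(out.decode().strip())


before = gpu_used_mib()
x = torch.ones(1, device="cuda")   # first CUDA op loads the CUDA context + kernels
print(gpu_used_mib() - before)     # typically several hundred MiB up to ~2 GB

del x
torch.cuda.empty_cache()           # releases the cached tensor memory ...
print(gpu_used_mib() - before)     # ... but the context/kernels stay until the process exits
```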

@stas00 (Contributor) commented Apr 22, 2022

Chances are that there was no need for that until now and nobody asked for it. If I may propose, you could create a feature request at pytorch asking for a feature that releases the cuda env completely. It's very possible that there is a C++ API to do that and it just wasn't made available in python.

The use case can be for example this exact situation, where the same program needs to alternate between different frameworks in the same run and needs to be able to access all of gpu's memory.

Does tf give the memory fully back when it's done and the process is still running?

@sgugger (Collaborator) commented Apr 22, 2022

If the goal is to recover as much memory as possible, shouldn't we delete the model before calling the empty_cache function?

@ydshieh (Collaborator, Author) commented Apr 22, 2022

There have been some requests on the torch GH page that went without a response, e.g.

pytorch/pytorch#28829 (this one is from 2019/10)

Same situation for TF: it does not fully give back GPU memory, and the corresponding requests also go unanswered.

@ydshieh (Collaborator, Author) commented Apr 22, 2022

If the goal is to recover as much memory as possible, shouldn't we delete the model before calling the empty_cache function?

@sgugger You are right! I tried it and it is just as you said.

Maybe I can just implement tearDownModule(), which calls empty_cache(), so we don't need to del the models + empty_cache() in every testing method?

I am going to try this and see how it goes.
(Tried with toy examples, and it works as expected.)
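
A minimal sketch of the tearDownModule() idea (for illustration; not necessarily the exact code merged in this PR):

```python
import torch


def tearDownModule():
    # unittest/pytest calls this once, after all tests in this module have run.
    # Release PyTorch's cached CUDA blocks so that tests from other frameworks
    # (e.g. the TF RAG tests running in the same process) can use that memory.
    torch.cuda.empty_cache()
```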

@sgugger (Collaborator) commented Apr 22, 2022

Note that tearDownModule is only called at the end of all the tests of the module, so it won't do the same thing you implemented (clean up at the end of each test).

@stas00 (Contributor) commented Apr 22, 2022

Implement tearDown for the unittest.TestCase subclass (and make sure to call its super()) - this one will be called at the end of each test.

And before empty_cache it often helps to call gc.collect() to make it deterministic.
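
A minimal sketch of that pattern (the test class name is illustrative, not necessarily the exact code merged here):

```python
import gc
import unittest

import torch


class RagModelIntegrationTests(unittest.TestCase):  # illustrative test class name
    def tearDown(self):
        super().tearDown()
        # Force collection of models/tensors that went out of scope in the test,
        # then return the cached CUDA blocks to the driver.
        gc.collect()
        torch.cuda.empty_cache()
```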

@ydshieh (Collaborator, Author) commented Apr 22, 2022

OK, I can do that.

But I feel that while we are within the PyTorch tests themselves, we don't need to call empty_cache() --> the occupied cache is managed by torch and will be assigned to subsequent torch operations if they require the GPU.

This empty_cache() is mainly for other applications to use GPU, like TF for example, in the same process.

And since the TF tests are in other modules, tearDownModule() in PT test module should be enough.

But again, I can go for tearDown()

@stas00 (Contributor) commented Apr 22, 2022

Of course, we are discussing this particular test. I wasn't suggesting doing it for all tests.

The reason I suggested gc.collect before empty_cache is that when you free the model it's not guaranteed to be freed immediately, due to how python's GC works. So if you want a reliable, deterministic memory release inside a long-running pytest process, gc.collect followed by empty_cache is how you make things deterministic.

@ydshieh (Collaborator, Author) commented Apr 22, 2022

@sgugger @stas00

With the suggestions, all TF RAG tests pass now on GPU! 🔥 Thank you!

@stas00 (Contributor) commented Apr 22, 2022

Unrelated to this PR, but since you work a lot with tests (thank you!), in case you're not aware of it, a while ago I developed:

class TestCasePlus(unittest.TestCase):

which extends unittest.TestCase with various handy features - like automatic removal of temp dirs, accessors to file paths, and many others. It's extensively documented in the module itself and also in https://huggingface.co/docs/transformers/testing

You don't need to do anything about it - I just hope it'll save you time in the future.
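
For illustration, a tiny usage sketch of one of those features (the auto-removed temp dir helper):

```python
from transformers.testing_utils import TestCasePlus


class ExampleTest(TestCasePlus):                   # illustrative test class
    def test_write_to_tmp_dir(self):
        tmp_dir = self.get_auto_remove_tmp_dir()   # deleted automatically after the test
        with open(f"{tmp_dir}/out.txt", "w") as f:
            f.write("hello")
```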

@ydshieh (Collaborator, Author) commented Apr 22, 2022

Thank you, @stas00. Maybe I can play with it, and at some point have a discussion with the other team members to see if we should use it by default!

@stas00 (Contributor) commented Apr 22, 2022

And of course please feel free to extend it if there are other features that can be re-used.

@ydshieh ydshieh changed the title Fix (partially) PyTorch RAG tests GPU OOM Fix PyTorch RAG tests GPU OOM Apr 22, 2022
@ydshieh (Collaborator, Author) commented Apr 25, 2022

Would like to have @sgugger's and/or @LysandreJik's opinion before merging :-)

@sgugger (Collaborator) left a comment

LGTM, thanks for iterating!

@ydshieh (Collaborator, Author) commented Apr 25, 2022

Merging now - we should have 13 fewer test failures (if this PR also works well on multiple GPUs).

@ydshieh ydshieh merged commit 32adbb2 into huggingface:main Apr 25, 2022
@ydshieh ydshieh deleted the quick_fix_rag_tests_gpu_oom branch April 25, 2022 15:34
@ydshieh ydshieh mentioned this pull request Jun 1, 2022
elusenji pushed a commit to elusenji/transformers that referenced this pull request Jun 12, 2022
* add torch.cuda.empty_cache in some PT RAG tests

* torch.cuda.empty_cache in tearDownModule()

* tearDown()

* add gc.collect()

Co-authored-by: ydshieh <ydshieh@users.noreply.github.com>