Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix PyTorch stateful RNN/LSTM gradient computation error resolves #20875 #20916

Merged

Conversation

praveenhosdrug123
Copy link
Contributor

The error occurred because the PyTorch autograd engine detected that tensors required for gradient computation were modified in-place, invalidating the computational graph.

Fix: Added explicit state cloning in RNN.step() for PyTorch backend when stateful=True to create new tensor objects with separate memory allocation.

Copy link

google-cla bot commented Feb 17, 2025

Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA).

View this failed invocation of the CLA check for more information.

For the most up to date status, view the checks section at the bottom of the pull request.

@codecov-commenter
Copy link

codecov-commenter commented Feb 17, 2025

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 82.44%. Comparing base (86873b5) to head (318c07d).
Report is 23 commits behind head on master.

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #20916      +/-   ##
==========================================
+ Coverage   82.22%   82.44%   +0.21%     
==========================================
  Files         561      561              
  Lines       52955    53219     +264     
  Branches     8205     8245      +40     
==========================================
+ Hits        43544    43876     +332     
+ Misses       7373     7336      -37     
+ Partials     2038     2007      -31     
Flag Coverage Δ
keras 82.26% <100.00%> (+0.22%) ⬆️
keras-jax 64.01% <0.00%> (-0.04%) ⬇️
keras-numpy 58.83% <0.00%> (-0.01%) ⬇️
keras-openvino 32.64% <0.00%> (+0.23%) ⬆️
keras-tensorflow 64.46% <0.00%> (-0.17%) ⬇️
keras-torch 64.08% <100.00%> (-0.03%) ⬇️

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
📢 Have feedback on the report? Share it here.

@praveenhosdrug123 praveenhosdrug123 changed the title Fix PyTorch stateful RNN/LSTM gradient computation error Fix PyTorch stateful RNN/LSTM gradient computation error resolves #20875 Feb 17, 2025
Copy link
Collaborator

@fchollet fchollet left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for the PR!

@google-ml-butler google-ml-butler bot added kokoro:force-run ready to pull Ready to be merged into the codebase labels Mar 3, 2025
@fchollet fchollet merged commit f7115c2 into keras-team:master Mar 3, 2025
7 checks passed
SaifMohammed22 pushed a commit to SaifMohammed22/keras that referenced this pull request Mar 6, 2025
…as-team#20875 (keras-team#20916)

* Fix PyTorch stateful RNN gradient computation error

* Updates post feedback
11happy pushed a commit to 11happy/keras that referenced this pull request Mar 9, 2025
…as-team#20875 (keras-team#20916)

* Fix PyTorch stateful RNN gradient computation error

* Updates post feedback
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
kokoro:force-run ready to pull Ready to be merged into the codebase size:S
Projects
Status: Merged
Development

Successfully merging this pull request may close these issues.

4 participants