Fix PyTorch stateful RNN/LSTM gradient computation error resolves #20875 #20916
Conversation
Thanks for your pull request! It looks like this may be your first contribution to a Google open source project. Before we can look at your pull request, you'll need to sign a Contributor License Agreement (CLA). View this failed invocation of the CLA check for more information. For the most up-to-date status, view the checks section at the bottom of the pull request.
Codecov Report: All modified and coverable lines are covered by tests ✅

```
@@            Coverage Diff             @@
##           master   #20916      +/-   ##
==========================================
+ Coverage   82.22%   82.44%   +0.21%
==========================================
  Files         561      561
  Lines       52955    53219     +264
  Branches     8205     8245      +40
==========================================
+ Hits        43544    43876     +332
+ Misses       7373     7336      -37
+ Partials     2038     2007      -31
```
Force-pushed from 4bfd4a8 to e0c4415
Force-pushed from e0c4415 to 48e20f6
Thanks for the PR!
…as-team#20875 (keras-team#20916)
* Fix PyTorch stateful RNN gradient computation error
* Updates post feedback
The error occurred because the PyTorch autograd engine detected that tensors required for gradient computation were modified in-place, invalidating the computational graph.
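For context, here is a minimal standalone sketch of that failure class in plain PyTorch (illustrative only, not the Keras code): the carried state is saved by autograd for the backward pass of the recurrent matmul, and the subsequent in-place stateful update invalidates it.

```python
import torch

kernel = torch.randn(4, 4, requires_grad=True)
state = torch.zeros(2, 4)                # persistent "stateful" hidden state

output = torch.tanh(state @ kernel)      # `state` is saved to compute the gradient w.r.t. `kernel`
state.copy_(output.detach())             # stateful update mutates the saved tensor in place

try:
    output.sum().backward()
except RuntimeError as err:
    # "... one of the variables needed for gradient computation has been
    # modified by an inplace operation"
    print(err)
```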
Fix: added explicit state cloning in RNN.step() for the PyTorch backend when stateful=True, so each step operates on new tensor objects with separate memory instead of mutating the tensors autograd has saved.
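And a rough sketch of the cloning idea under the same simplified setup (the `step` function and dense recurrence here are illustrative, not the actual Keras RNN.step() implementation): cloning the carried state means autograd saves a tensor whose memory the later in-place state update never touches.

```python
import torch

def step(inputs, state, kernel, recurrent_kernel):
    # Clone the carried state so autograd saves a fresh tensor whose
    # storage is never mutated by the subsequent stateful update.
    state = state.clone()
    output = torch.tanh(inputs @ kernel + state @ recurrent_kernel)
    return output, output

kernel = torch.randn(3, 4, requires_grad=True)
recurrent_kernel = torch.randn(4, 4, requires_grad=True)
persistent_state = torch.zeros(2, 4)
inputs = torch.randn(2, 3)

output, new_state = step(inputs, persistent_state, kernel, recurrent_kernel)
persistent_state.copy_(new_state.detach())  # in-place update touches only the original buffer
output.sum().backward()                     # succeeds: autograd saved the clone, not the buffer
```

The extra clone costs one copy of the state per call, but it keeps the tensors saved for backward() valid even though the persistent state buffer is still updated in place.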