diff --git a/keras/src/layers/rnn/rnn.py b/keras/src/layers/rnn/rnn.py index b0cbc795aeb5..e595eb52637c 100644 --- a/keras/src/layers/rnn/rnn.py +++ b/keras/src/layers/rnn/rnn.py @@ -331,6 +331,12 @@ def inner_loop(self, sequences, initial_state, mask, training=False): cell_kwargs["training"] = training def step(inputs, states): + # Create new tensor copies when using PyTorch backend + # with stateful=True. This prevents in-place modifications + # that would otherwise break PyTorch's autograd functionality + # by modifying tensors needed for gradient computation. + if backend.backend() == "torch" and self.stateful: + states = tree.map_structure(ops.copy, states) output, new_states = self.cell(inputs, states, **cell_kwargs) if not tree.is_nested(new_states): new_states = [new_states]