Skip to content

Commit ea93767

Browse files
committed
1804 RNN fix
1 parent 1737ce1 commit ea93767

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

apex/amp/wrap.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -138,7 +138,7 @@ def rnn_wrapper(*args, **kwargs):
138138
# autograd graph correctly backprops from the wgrads computed
139139
# inside cuDNN (on fp16 weights) into the fp32 weights.
140140
assert utils.type_string(flat_weight) == 'FloatTensor'
141-
if compat.tensor_is_float_tensor():
141+
if compat.tensor_is_float_tensor() or compat.tensor_is_variable():
142142
# Pre-0.4. A little slower, since it zeros out memory.
143143
flat_weight_fp16 = flat_weight.new().half().resize_(flat_weight.shape)
144144
else:

0 commit comments

Comments
 (0)