Fix features expand: expand() -> expand().clone() #110
I found a bug in saving and loading Pixyz objects.
1. Problem
First, I create a Normal distribution instance and save its parameters with `torch.save()`. Next, when I load the saved file into an object of the same class, it raises a RuntimeError. The error message says that the parameter's dimensions in the model and in the checkpoint are different, although both appear to be the same size, `torch.Size([1, 2])`.
I also tested another way of constructing the Normal distribution. It is a valid Normal distribution with the same dimension, and it loads the saved parameters correctly.
2. Change
This is caused by the `features.expand()` call in the `_check_features_shape()` method, which is called when an object is created. When the tensor size of the given parameter is empty, the DistributionBase class automatically expands its dimensions without allocating memory (ref: https://pytorch.org/docs/stable/tensors.html#torch.Tensor.expand).
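As a minimal sketch of this behaviour in plain PyTorch (the variable names are illustrative and not taken from Pixyz), an expanded tensor is a view that shares the scalar's storage and has stride 0, whereas a clone owns its own contiguous memory:

```python
import torch

# A parameter given as a scalar has an empty size, torch.Size([]).
scalar = torch.tensor(0.0)

# expand() returns a view: no new memory, both elements alias the scalar.
expanded = scalar.expand(2)

# clone() materializes the expanded view into its own contiguous storage.
cloned = scalar.expand(2).clone()

shares_storage = expanded.data_ptr() == scalar.data_ptr()
print(expanded.shape, expanded.stride(), shares_storage)
print(cloned.shape, cloned.stride())
```

The stride of 0 is what makes the expanded tensor cheap: every index along that dimension reads the same single element.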
However, once the parameters are saved to a checkpoint file, loading the saved tensors seems to require fully allocated memory (no reference found).
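A likely explanation, offered here as an assumption rather than something confirmed by a reference: `load_state_dict()` copies each checkpoint tensor into the existing buffer in place, and PyTorch refuses in-place writes to a tensor whose elements all refer to a single memory location. The sketch below uses a hypothetical `Holder` module (not a Pixyz class) and an in-memory state dict instead of a saved file to compare a buffer built with `expand()` against one built with `expand().clone()`:

```python
import torch
from torch import nn


class Holder(nn.Module):
    """Hypothetical stand-in for a distribution that stores `loc` as a buffer."""

    def __init__(self, use_clone: bool):
        super().__init__()
        loc = torch.tensor(0.0).expand(2)  # view over a single element
        if use_clone:
            loc = loc.clone()  # allocate real memory for both elements
        self.register_buffer("loc", loc.unsqueeze(0))  # shape [1, 2]


def can_load(use_clone: bool) -> bool:
    """Return True if a state dict loads into a fresh instance without error."""
    source = Holder(use_clone)
    target = Holder(use_clone)
    try:
        target.load_state_dict(source.state_dict())
        return True
    except RuntimeError:
        return False


expand_only_loads = can_load(use_clone=False)
expand_clone_loads = can_load(use_clone=True)
print(expand_only_loads, expand_clone_loads)
```

Under this assumption, the shapes in the error message match because the failure comes from the in-place copy, not from a real size mismatch.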
Therefore, I added a `clone()` call when expanding the features' shape. Although it uses a little extra memory, it works correctly.
I hope this pull request is helpful.
Thank you.