Fix features expand: expand() -> expand().clone() #110

Merged 2 commits into masa-su:develop/v0.2.0 on Mar 31, 2020

Conversation

@rnagumo (Contributor) commented on Mar 9, 2020

I found a bug in saving/loading Pixyz distribution objects.

1. Problem

First, I create a Normal distribution instance and save its parameters with torch.save().

>>> import torch
>>> from pixyz.distributions import Normal
>>> z_dim = 2
>>> p = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), features_shape=[z_dim])
>>> torch.save(p.state_dict(), "./tmp.pt")

Next, when I load the saved file into an identically constructed object, a RuntimeError is raised. The error message claims that the parameter's dimensions in the model differ from those in the checkpoint, although both are reported as the same size, torch.Size([1, 2]).

>>> q = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), features_shape=[z_dim])
>>> q.load_state_dict(torch.load("./tmp.pt"))
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-8-0b2e959a1927> in <module>
----> 1 q.load_state_dict(torch.load("./tmp.pt"))

~/pixyz/.venv/lib/python3.7/site-packages/torch/nn/modules/module.py in load_state_dict(self, state_dict, strict)
    828         if len(error_msgs) > 0:
    829             raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
--> 830                                self.__class__.__name__, "\n\t".join(error_msgs)))
    831         return _IncompatibleKeys(missing_keys, unexpected_keys)
    832 

RuntimeError: Error(s) in loading state_dict for Normal:
        While copying the parameter named "loc", whose dimensions in the model are torch.Size([1, 2]) and whose dimensions in the checkpoint are torch.Size([1, 2]).
        While copying the parameter named "scale", whose dimensions in the model are torch.Size([1, 2]) and whose dimensions in the checkpoint are torch.Size([1, 2]).

I tested another way of constructing the Normal distribution. The following is also a valid Normal distribution with the same dimensions, and it loads the saved parameters correctly.

>>> q = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim))
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>

2. Change

This is caused by the features.expand() call in the _check_features_shape() method, which is invoked when an object is created. When the given parameter is a scalar (its size is torch.Size([])), the DistributionBase class automatically expands it to features_shape without allocating new memory.

ref) https://pytorch.org/docs/stable/tensors.html#torch.Tensor.expand

However, once the parameters are saved to a checkpoint file, loading the saved tensors back appears to require buffers with fully allocated memory (I could not find a reference for this).
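
For illustration (this snippet is not part of the change itself), expand() on a scalar returns a stride-0 view instead of allocating new memory, so the in-place copy that load_state_dict performs on the registered buffer fails:

>>> t = torch.tensor(0.).expand(2)
>>> t.is_contiguous(), t.stride()
(False, (0,))
>>> t.copy_(torch.zeros(2))  # raises a RuntimeError: multiple elements of the expanded view share one memory location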

Therefore, I added a clone() call after expanding to the features shape. Although it uses a little extra memory, it works correctly.
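
A minimal sketch of the change in _check_features_shape() (other lines of the method are omitted); the only difference is the added clone():

    def _check_features_shape(self, features):
        # scalar
        if features.size() == torch.Size():
            # clone() materializes the expanded view into its own memory,
            # so load_state_dict can later copy checkpoint values into the buffer
            features = features.expand(self.features_shape).clone()
        ...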

>>> q = Normal(loc=torch.tensor(0.), scale=torch.tensor(1.), features_shape=[z_dim]) 
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>
>>> q = Normal(loc=torch.zeros(z_dim), scale=torch.ones(z_dim))
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>

I hope this pull request is helpful to you.

Thank you.

@ktaaaki (Collaborator) commented on Mar 10, 2020

Thank you not only for the bug report, but also for the easy-to-read pull request! After reading your comment, I found that the bug fix can be generalized.

If you pass an already expanded tensor as a parameter, e.g. Normal(loc=torch.tensor(0.).expand(1, 2), scale=torch.ones(1, 2)), the error still occurs. How about converting the tensor passed to torch.nn.Module.register_buffer to a contiguous one, as follows?

    def _check_features_shape(self, features):
        # scalar
        if features.size() == torch.Size():
            features = features.expand(self.features_shape)

        if self.features_shape == torch.Size():
            self._features_shape = features.shape

        # for the issue of torch.load (#110)
        if not features.is_contiguous():
            features = features.contiguous()

        if features.size() == self.features_shape:
            batches = features.unsqueeze(0)
            return batches

        raise ValueError("the shape of a given parameter {} and features_shape {} "
                         "do not match.".format(features.size(), self.features_shape))

@rnagumo (Contributor, Author) commented on Mar 10, 2020

Thank you for your reply. Your suggestion is more general, so I changed the code accordingly. I checked that all of the following Normal distribution constructions can load the saved parameters.

>>> q = Normal(loc=torch.tensor(0.).expand(2), scale=torch.ones(2))
>>> q.load_state_dict(torch.load("./tmp.pt"))                                             
<All keys matched successfully>
>>> q = Normal(loc=torch.tensor(0.), scale=torch.ones(2), features_shape=[2])
>>> q.load_state_dict(torch.load("./tmp.pt"))
<All keys matched successfully>

@masa-su merged commit e5a7d2f into masa-su:develop/v0.2.0 on Mar 31, 2020