
RuntimeErrors with "torch_two_sample.statistics_diff.SmoothKNNStatistic" #5

Open
adrienchaton opened this issue Sep 7, 2018 · 2 comments


@adrienchaton

Hello!
Thank you very much for sharing these functions with the PyTorch community.

I am successfully using "torch_two_sample.statistics_diff.MMDStatistic", both during training (backprop) and at evaluation, running on the GPU.

I am trying to use an alternative criterion, the SmoothKNN statistic, which I set up the same way as the MMD (with the additional True boolean for CUDA and the k parameter).

The code is otherwise identical to the MMD version, but with the SmoothKNN function it raises "RuntimeError: one of the variables needed for gradient computation has been modified by an inplace operation".
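
For reference, the setup looks roughly like this (a minimal sketch; the dimensions, alphas, and k below are placeholders, and the constructor arguments follow my reading of the torch_two_sample docs):

    import torch
    from torch_two_sample.statistics_diff import MMDStatistic, SmoothKNNStatistic

    n_1, n_2 = 128, 128  # sizes of the two samples
    sample_1 = torch.randn(n_1, 16, device="cuda", requires_grad=True)
    sample_2 = torch.randn(n_2, 16, device="cuda")

    # MMD criterion: backward() runs fine
    mmd = MMDStatistic(n_1, n_2)
    mmd(sample_1, sample_2, alphas=[0.5]).backward()

    # SmoothKNN criterion, set up the same way (True = cuda, k = 1)
    knn = SmoothKNNStatistic(n_1, n_2, True, 1)
    knn(sample_1, sample_2, alphas=[0.1]).backward()
    # -> RuntimeError: one of the variables needed for gradient computation
    #    has been modified by an inplace operation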

Does anyone have a fix for this, please?

Thanks in advance!

@rekcahpassyla commented Jul 27, 2021

This may be a bit late for you, but I fixed it by changing SmoothKNNStatistic.__call__ in statistics_diff.py as follows (the original lines are commented out above their replacements):

        if margs_.is_cuda:
            #margs.masked_scatter_(indices, margs_.view(-1))
            margs.masked_scatter_(indices, margs_.detach().view(-1))
        else:
            #margs.masked_scatter_(indices_cpu, margs_.view(-1))
            margs.masked_scatter_(indices_cpu, margs_.detach().view(-1))

@nkamath5

Late to the party here, but for posterity:

@rekcahpassyla do you think gradients from the returned tensor in

    return - (t_stat - mean) / std

will still propagate backwards (all the way to sample_1 and sample_2) correctly if you use detach() on margs_?
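
detach() cuts a tensor out of the autograd graph, so any gradient flowing through that branch is silently dropped. A minimal illustration (unrelated to torch_two_sample):

    import torch

    a = torch.ones(3, requires_grad=True)
    b = a * 2
    (b.detach() + a).sum().backward()
    print(a.grad)  # tensor([1., 1., 1.]), not [3., 3., 3.]:
                   # the detached branch contributed no gradient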

I think the real culprit here is the line

    margs_ /= len(alphas)

which is an in-place operation and directly affects the computation graph involving margs_.

I only changed margs_ /= len(alphas) to margs_ = margs_ / len(alphas) and got things working.
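
To see why, here is a minimal standalone repro of the same failure mode (again unrelated to torch_two_sample) and the out-of-place fix:

    import torch

    x = torch.randn(4, requires_grad=True)
    y = torch.exp(x)    # exp saves its output for the backward pass
    y /= 2              # in-place division bumps y's version counter
    y.sum().backward()  # RuntimeError: one of the variables needed for gradient
                        # computation has been modified by an inplace operation

    x = torch.randn(4, requires_grad=True)
    y = torch.exp(x)
    y = y / 2           # out-of-place: a new tensor, saved output left intact
    y.sum().backward()  # works; x.grad is populated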
