Implementation (mostly adapted from the original implementation) and explanation of VPSDE
- Forward process
- Predictor-Corrector sampling (PC sampling)
- Train and sample script with MNIST
- ODE sampling
- Likelihood estimation
- BPD evaluation
- Parameter tuning with Ray
To tune hyperparameters:
```
python tune_script.py
```
You can modify tune_config in tune_script.py if you want to use a different parameter range or add/remove hyperparameters from the search space. No samples are generated during tuning due to time constraints.
For this script, I tuned the learning rate, batch size, number of residual blocks in the U-Net, channel multiplier of the U-Net, and the sampling eps used to estimate bits per dim.
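For illustration, a hedged sketch of what such a search space could look like with the Ray Tune API is below; the actual key names and ranges in tune_script.py may differ.

```python
from ray import tune

# Hypothetical search space; key names and ranges are illustrative only.
tune_config = {
    "lr": tune.loguniform(1e-5, 1e-2),                   # learning rate
    "batch_size": tune.choice([32, 64, 128]),            # mini-batch size
    "num_res_blocks": tune.choice([1, 2, 3]),            # residual blocks per U-Net stage
    "ch_mult": tune.choice([(1, 2, 2), (1, 2, 2, 2)]),   # U-Net channel multiplier
    "eps": tune.loguniform(1e-5, 1e-3),                  # sampling eps used for BPD estimation
}
```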
To train and sample:
```
python run_script.py
```
You can set hyperparameters in config.py based on the tuning results; otherwise, they default to the values used in the original paper. Note that you should set self.sampler in config.py to 'ode', 'pc', or 'both' if you want to generate samples every epoch during training and testing.
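For reference, the relevant part of config.py might look roughly like the sketch below; apart from self.sampler, which is described above, the field names are assumptions.

```python
class Config:
    def __init__(self):
        # Hyperparameters (set from tuning results, otherwise the paper's defaults)
        self.lr = 2e-4
        self.batch_size = 128
        # Set to 'ode', 'pc', or 'both' to generate samples every epoch; anything else skips sampling
        self.sampler = 'pc'
```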
Training Log
BPD | Test DSM loss | Train DSM loss
---|---|---
![]() | ![]() | ![]()
Generated samples
PC sampling | ODE sampling
---|---
![]() | ![]()
- Dependencies: torch, pytorch-lightning, torchvision
- Original paper: Score-Based Generative Modeling through Stochastic Differential Equations
- Official implementation
- Unet: Score network
The two base equations are:

$$\mathbf{x}_i = \sqrt{1-\beta_i}\,\mathbf{x}_{i-1} + \sqrt{\beta_i}\,\mathbf{z}_{i-1}, \qquad \mathbf{z}_{i-1} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$$

$$d\mathbf{x} = -\tfrac{1}{2}\beta(t)\,\mathbf{x}\,dt + \sqrt{\beta(t)}\,d\mathbf{w}$$

The first equation is from DDPM, where the noise levels are assumed to be discrete. The second equation is more general, with the noise level assumed to be continuous. In the official implementation, a linear schedule is chosen:

$$\beta(t) = \beta_{min} + \frac{\beta_{max}-\beta_{min}}{T}\,t$$

With the VPSDE, the authors formulated the transition probability

$$p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = \mathcal{N}\Big(\mathbf{x}(t);\, \mathbf{x}(0)\, e^{-\frac{1}{2}\int_{0}^{t}\beta(s)\,ds},\, \mathbf{I}\big(1 - e^{-\int_{0}^{t}\beta(s)\,ds}\big)\Big)$$

so we need the integral of $\beta(s)$:
$$
\begin{align*}
\int_{0}^{t} \beta(s)\,ds &= \int_{0}^{t} \beta_{min}+\frac{\beta_{max}-\beta_{min}}{T}s\,ds \\
&= \frac{s^{2}}{2}\cdot\frac{\beta_{max}-\beta_{min}}{T}+\beta_{min}s \,\Big|_{0}^{t} \\
&= \frac{t^{2}}{2}\cdot\frac{\beta_{max}-\beta_{min}}{T}+\beta_{min}t
\end{align*}
$$
Hence, the perturbation kernel has

$$\text{mean}(t) = \mathbf{x}(0)\, e^{-\frac{1}{4}t^{2}\frac{\beta_{max}-\beta_{min}}{T} - \frac{1}{2}t\beta_{min}}, \qquad \sigma(t) = \sqrt{1 - e^{-\frac{1}{2}t^{2}\frac{\beta_{max}-\beta_{min}}{T} - t\beta_{min}}}$$

which is what marginal_prob_mean_std returns in the perturb function below:
```python
def perturb(self, x):
    # Sample a time uniformly from [eps, T] for each example in the batch
    batch_size = x.shape[0]
    t = torch.rand(batch_size).cuda() * (self.T - self.eps) + self.eps
    # Sample Gaussian noise and form x(t) = mean(t) + sigma(t) * z
    z = torch.randn_like(x).cuda()
    mean, std = self.marginal_prob_mean_std(x, t)
    x_tilda = mean + std.view(-1, 1, 1, 1) * z
    return x_tilda, t, z, mean, std
```
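perturb relies on marginal_prob_mean_std for the closed-form mean and standard deviation derived above. That helper is not shown here, but a minimal sketch consistent with the formulas (assuming self.beta_min, self.beta_max, and self.T attributes; the actual code may differ) would be:

```python
def marginal_prob_mean_std(self, x, t):
    # -1/2 * int_0^t beta(s) ds for the linear schedule
    log_mean_coeff = -0.25 * t ** 2 * (self.beta_max - self.beta_min) / self.T \
                     - 0.5 * t * self.beta_min
    mean = torch.exp(log_mean_coeff).view(-1, 1, 1, 1) * x    # e^{-1/2 int beta} * x(0)
    std = torch.sqrt(1.0 - torch.exp(2.0 * log_mean_coeff))   # sqrt(1 - e^{-int beta})
    return mean, std
```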
The DSM objective is stated clearly in the paper:
$$J(\theta)=\mathbb{E}_{t\sim \mathcal{U}(0, T)} \Big[\lambda(t)\, \mathbb{E}_{\mathbf{x}(0) \sim p_0(\mathbf{x})}\mathbb{E}_{\mathbf{x}(t) \sim p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))}\big[ \|s_\theta(\mathbf{x}(t), t) - \nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))\|_2^2\big]\Big]$$
So we sample $t \sim \mathcal{U}(\epsilon, T)$, $\mathbf{x}(0)$ from the data, and $\mathbf{x}(t) \sim p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))$, exactly as in the perturb function.
Now let's derive $\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))$.
By recalling how we performed the data perturbation in the perturb function,
`x_tilda = mean + std.view(-1, 1, 1, 1) * z`
i.e. $\mathbf{x}(t) = \text{mean}(t) + \sigma(t)\,\mathbf{z}$ with $\mathbf{z} \sim \mathcal{N}(\mathbf{0}, \mathbf{I})$, and since the transition kernel is Gaussian, we have derived:

$$\nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0)) = -\frac{\mathbf{x}(t)-\text{mean}(t)}{\sigma(t)^{2}} = -\frac{\mathbf{z}}{\sigma(t)}$$

With the weighting $\lambda(t)=\sigma(t)^{2}$ and a Monte Carlo estimate over $N$ sampled time steps, the objective becomes
$$
\begin{align*}
J(\theta)&=\mathbb{E}_{t\sim \mathcal{U}(0, T)} \Big[\lambda(t)\, \mathbb{E}_{\mathbf{x}(0) \sim p_0(\mathbf{x})}\mathbb{E}_{\mathbf{x}(t) \sim p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))}\big[ \|s_\theta(\mathbf{x}(t), t) - \nabla_{\mathbf{x}(t)}\log p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))\|_2^2\big]\Big] \\
&=\mathbb{E}_{t\sim \mathcal{U}(0, T)} \Big[\mathbb{E}_{\mathbf{x}(0) \sim p_0(\mathbf{x})}\mathbb{E}_{\mathbf{x}(t) \sim p_{0t}(\mathbf{x}(t) \mid \mathbf{x}(0))}\big[ \lambda(t)\,\big\|s_\theta(\mathbf{x}(t), t) + \tfrac{\mathbf{z}}{\sigma(t)}\big\|_2^2\big]\Big] \\
&\approx \frac{1}{N} \sum_{i=1}^{N}\mathbb{E}_{\mathbf{x}(0) \sim p_0(\mathbf{x})}\mathbb{E}_{\mathbf{x}(t_i) \sim p_{0t_i}(\mathbf{x}(t_i) \mid \mathbf{x}(0))}\big[ \big\|\sigma(t_i)\, s_\theta(\mathbf{x}(t_i), t_i) + \mathbf{z}\big\|_2^2\big]
\end{align*}
$$
Here is my implementation of the DSM loss:
```python
def forward(self, x):
    # Perturb the data, then evaluate the score network.
    # Dividing by std turns the raw network output into the score estimate s_theta.
    x_tilda, t, z, mean, std = self.perturb(x)
    normed_score = self.score_func(x_tilda, t) / std.view(-1, 1, 1, 1)
    return normed_score, std, t, z

def dsm_loss(self, x):
    # || sigma(t) * s_theta(x(t), t) + z ||^2, summed over pixels and averaged over the batch
    normed_score, std, t, z = self(x)
    dsm_loss = torch.mean(torch.sum((normed_score * std.view(-1, 1, 1, 1) + z)**2, dim=(1, 2, 3)))
    return dsm_loss
```
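Since pytorch-lightning is used, the loss drops straight into a training step. A hedged sketch, assuming the model is a LightningModule:

```python
def training_step(self, batch, batch_idx):
    x, _ = batch                       # MNIST images; labels are unused
    loss = self.dsm_loss(x)            # denoising score matching loss
    self.log('train_dsm_loss', loss)
    return loss
```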
The equation for the reverse SDE is:

$$d\mathbf{x} = \big[\mathbf{f}(\mathbf{x}, t) - g(t)^{2}\,\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\big]\,dt + g(t)\,d\bar{\mathbf{w}}$$

where $\bar{\mathbf{w}}$ is a reverse-time Brownian motion and $dt$ is a negative infinitesimal time step.
I'm not quite certain about why
With this equation (the predictor), we can already do sampling. Since the score network gives us an estimate of $\nabla_{\mathbf{x}}\log p_t(\mathbf{x})$ at every noise level, each predictor step can additionally be followed by a few steps of Langevin MCMC as a corrector, which gives Predictor-Corrector (PC) sampling.
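To make this concrete, here is a hedged sketch of a single reverse step for the VPSDE: an Euler-Maruyama predictor followed by one Langevin corrector step. It reuses score_func and marginal_prob_mean_std from above, while self.beta_min, self.beta_max, and the snr value are assumptions; the actual PC sampler in this repo may be organised differently.

```python
def pc_step(self, x, t, dt, snr=0.16):
    """One predictor-corrector step of the reverse-time VPSDE (sketch, not the repo's exact code)."""
    beta_t = (self.beta_min + t * (self.beta_max - self.beta_min) / self.T).view(-1, 1, 1, 1)
    std = self.marginal_prob_mean_std(x, t)[1].view(-1, 1, 1, 1)
    score = self.score_func(x, t) / std                     # s_theta(x, t)

    # Predictor: Euler-Maruyama discretization of the reverse SDE (dt > 0, stepping backwards)
    drift = -0.5 * beta_t * x - beta_t * score              # f(x,t) - g(t)^2 * score
    x_mean = x - drift * dt
    x = x_mean + torch.sqrt(beta_t * dt) * torch.randn_like(x)

    # Corrector: one Langevin MCMC step at the current noise level
    std = self.marginal_prob_mean_std(x, t)[1].view(-1, 1, 1, 1)
    score = self.score_func(x, t) / std
    noise = torch.randn_like(x)
    grad_norm = torch.norm(score.reshape(score.shape[0], -1), dim=-1).mean()
    noise_norm = torch.norm(noise.reshape(noise.shape[0], -1), dim=-1).mean()
    step_size = 2 * (snr * noise_norm / grad_norm) ** 2
    x = x + step_size * score + torch.sqrt(2 * step_size) * noise
    return x, x_mean
```

Returning x_mean as well is convenient because the final sample is usually taken without the last noise injection.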
For every SDE there exists a corresponding deterministic probability flow ODE whose trajectories share the same marginal densities $p_t(\mathbf{x})$:

$$d\mathbf{x} = \Big[\mathbf{f}(\mathbf{x}, t) - \frac{1}{2}g(t)^{2}\,\nabla_{\mathbf{x}}\log p_t(\mathbf{x})\Big]\,dt$$
Some implementation details of ODE sampling:
- scipy.integrate.solve_ivp needs a callable function of the form f(t, y) that returns dy/dt, so the score network has to be wrapped accordingly.
- solve_ivp only accepts 1D ndarrays, so remember to convert the torch.Tensor to a flattened 1D ndarray before passing it to the solver, and to reshape it back afterwards.
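Putting those two points together, a hedged sketch of the glue code around solve_ivp (reusing score_func and marginal_prob_mean_std; the beta parameters, tolerances, and function name are assumptions) could be:

```python
import numpy as np
import torch
from scipy.integrate import solve_ivp

def ode_sample(model, shape, beta_min=0.1, beta_max=20., T=1., eps=1e-3, device='cuda'):
    """Sample by integrating the probability flow ODE from t = T down to t = eps (sketch)."""
    x_T = torch.randn(shape, device=device)                   # prior sample ~ N(0, I)

    def ode_func(t, x_flat):
        # solve_ivp hands us a flat float64 ndarray: reshape to an image batch first
        x = torch.tensor(x_flat, dtype=torch.float32, device=device).reshape(shape)
        t_vec = torch.full((shape[0],), t, device=device)
        with torch.no_grad():
            std = model.marginal_prob_mean_std(x, t_vec)[1].view(-1, 1, 1, 1)
            score = model.score_func(x, t_vec) / std
        beta_t = beta_min + t * (beta_max - beta_min) / T
        drift = -0.5 * beta_t * x - 0.5 * beta_t * score       # f(x,t) - 1/2 g(t)^2 * score
        return drift.cpu().numpy().reshape(-1)                 # back to a flat ndarray

    sol = solve_ivp(ode_func, (T, eps), x_T.cpu().numpy().reshape(-1), rtol=1e-5, atol=1e-5)
    return torch.tensor(sol.y[:, -1], dtype=torch.float32).reshape(shape)
```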
With the associated ODE, we can also compute the exact log-likelihood through the instantaneous change-of-variables formula:

$$\log p_0(\mathbf{x}(0)) = \log p_T(\mathbf{x}(T)) + \int_{0}^{T} \nabla \cdot \tilde{\mathbf{f}}_\theta(\mathbf{x}(t), t)\,dt$$

where $\tilde{\mathbf{f}}_\theta(\mathbf{x}, t) = \mathbf{f}(\mathbf{x}, t) - \frac{1}{2}g(t)^{2}s_\theta(\mathbf{x}, t)$ is the drift of the probability flow ODE. How? Remember that what we throw into the solver is a flat 1D ndarray, so, as in the official implementation, the state is the flattened $\mathbf{x}(t)$ concatenated with one extra slot per example that accumulates the divergence integral (estimated with the Skilling-Hutchinson trace estimator); the accumulated term then comes out of the solver together with $\mathbf{x}(T)$.
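A hedged sketch of how that augmented state can be packed and unpacked around solve_ivp is below. drift_and_div is a hypothetical helper that returns the probability-flow drift (an ndarray with the same shape as x) and the per-example divergence estimate (an ndarray of length batch); the real likelihood code in this repo may be structured differently.

```python
import numpy as np
from scipy.integrate import solve_ivp

def likelihood_ode(drift_and_div, x0, shape, T=1., eps=1e-5):
    """Jointly integrate x(t) and the running divergence integral from eps to T (sketch)."""
    batch = shape[0]

    def ode_func(t, state):
        x_flat = state[:-batch]                               # flattened image part
        drift, div = drift_and_div(x_flat.reshape(shape), t)
        return np.concatenate([drift.reshape(-1), div])       # d/dt of the augmented state

    init = np.concatenate([x0.reshape(-1), np.zeros(batch)])  # divergence accumulator starts at 0
    sol = solve_ivp(ode_func, (eps, T), init, rtol=1e-5, atol=1e-5)
    x_T = sol.y[:-batch, -1].reshape(shape)                   # terminal latent x(T)
    delta_logp = sol.y[-batch:, -1]                           # integral of the divergence
    return x_T, delta_logp
```

log p_0(x(0)) is then the prior log-density of x(T) plus delta_logp, and dividing −log p_0(x(0)) by the number of dimensions times ln 2 (plus any offset for data rescaling/dequantization) converts it to bits per dim.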