Skip to content

Commit 3d7361c

Browse files
committed
use variance in gradient guide instead of std
1 parent da8b87c commit 3d7361c

File tree

2 files changed

+3
-2
lines changed

2 files changed

+3
-2
lines changed

diffuser/sampling/functions.py

+2-1
Original file line numberDiff line numberDiff line change
@@ -12,13 +12,14 @@ def n_step_guided_p_sample(
1212
):
1313
model_log_variance = extract(model.posterior_log_variance_clipped, t, x.shape)
1414
model_std = torch.exp(0.5 * model_log_variance)
15+
model_var = torch.exp(model_log_variance)
1516

1617
for _ in range(n_guide_steps):
1718
with torch.enable_grad():
1819
y, grad = guide.gradients(x, cond, t)
1920

2021
if scale_grad_by_std:
21-
grad = model_std * grad
22+
grad = model_var * grad
2223

2324
grad[t < t_stopgrad] = 0
2425

slurm/plan.sh

+1-1
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ do
1414
python -u scripts/plan_guided.py \
1515
--logbase logs/pretrained \
1616
--dataset $env-$buffer-v2 \
17-
--prefix plans/reference \
17+
--prefix plans/reference_var \
1818
--vis_freq 500 \
1919
--verbose False \
2020
--suffix {1} \

0 commit comments

Comments
 (0)