Skip to content

Commit

Permalink
MAINT: raise error earlier if there is no sufficient data points to s…
Browse files Browse the repository at this point in the history
…uggest LR

Now LR finder will raise a RuntimeError if there is no sufficient data
points to calculate gradient for suggested LR when
`lr_finder.plot(..., suggest_lr=True)` is called.

The error message will clarify the details of failure, so users can fix
the issue earlier as well.
  • Loading branch information
NaleRaphael committed Aug 25, 2024
1 parent ad5971b commit f642cd1
Show file tree
Hide file tree
Showing 2 changed files with 36 additions and 24 deletions.
18 changes: 15 additions & 3 deletions tests/test_lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -401,9 +401,21 @@ def test_plot_with_skip_and_suggest_lr(suggest_lr, skip_start, skip_end):
)

fig, ax = plt.subplots()
results = lr_finder.plot(
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
)

results = None
if suggest_lr and num_iter < (skip_start + skip_end + 2):
# No sufficient data points to calculate gradient, so this call should fail
with pytest.raises(RuntimeError, match="Need at least"):
results = lr_finder.plot(
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
)

# No need to proceed then
return
else:
results = lr_finder.plot(
skip_start=skip_start, skip_end=skip_end, suggest_lr=suggest_lr, ax=ax
)

# NOTE:
# - ax.lines[0]: the lr-loss curve. It should be always available once
Expand Down
42 changes: 21 additions & 21 deletions torch_lr_finder/lr_finder.py
Original file line number Diff line number Diff line change
Expand Up @@ -510,6 +510,13 @@ def plot(
if show_lr is not None and not isinstance(show_lr, float):
raise ValueError("show_lr must be float")

# Make sure there are enough data points to suggest a learning rate
if suggest_lr and len(self.history["lr"]) < (skip_start + skip_end + 2):
raise RuntimeError(
f"Need at least {skip_start + skip_end + 2} iterations to suggest a "
f"learning rate. Got {len(self.history['lr'])}"
)

# Get the data to plot from the history dictionary. Also, handle skip_end=0
# properly so the behaviour is the expected
lrs = self.history["lr"]
Expand All @@ -533,26 +540,19 @@ def plot(
if suggest_lr:
# 'steepest': the point with steepest gradient (minimal gradient)
print("LR suggestion: steepest gradient")
min_grad_idx = None
try:
min_grad_idx = (np.gradient(np.array(losses))).argmin()
except ValueError:
print(
"Failed to compute the gradients, there might not be enough points. "
"Please check whether num_iter >= (skip_start + skip_end + 2)."
)
if min_grad_idx is not None:
print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
ax.scatter(
lrs[min_grad_idx],
losses[min_grad_idx],
s=75,
marker="o",
color="red",
zorder=3,
label="steepest gradient",
)
ax.legend()
min_grad_idx = (np.gradient(np.array(losses))).argmin()

print("Suggested LR: {:.2E}".format(lrs[min_grad_idx]))
ax.scatter(
lrs[min_grad_idx],
losses[min_grad_idx],
s=75,
marker="o",
color="red",
zorder=3,
label="steepest gradient",
)
ax.legend()

if log_lr:
ax.set_xscale("log")
Expand All @@ -568,7 +568,7 @@ def plot(

if suggest_lr:
# If suggest_lr is set, then we should always return 2 values.
suggest_lr = None if min_grad_idx is None else lrs[min_grad_idx]
suggest_lr = lrs[min_grad_idx]
return ax, suggest_lr
else:
return ax
Expand Down

0 comments on commit f642cd1

Please sign in to comment.