Skip to content

Commit

Permalink
fix: segfault when logits_all=False. Closes #1319
Browse files Browse the repository at this point in the history
  • Loading branch information
abetlen committed Apr 3, 2024
1 parent f96de6d commit 8649d76
Showing 1 changed file with 10 additions and 8 deletions.
18 changes: 10 additions & 8 deletions llama_cpp/llama.py
Original file line number Diff line number Diff line change
Expand Up @@ -535,14 +535,16 @@ def eval(self, tokens: Sequence[int]):
# Save tokens
self.input_ids[n_past : n_past + n_tokens] = batch
# Save logits
rows = n_tokens
cols = self._n_vocab
offset = (
0 if self.context_params.logits_all else n_tokens - 1
) # NOTE: Only save the last token logits if logits_all is False
self.scores[n_past + offset : n_past + n_tokens, :].reshape(-1)[
:
] = self._ctx.get_logits()[offset * cols : rows * cols]
if self.context_params.logits_all:
rows = n_tokens
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
self.scores[n_past : n_past + n_tokens, :].reshape(-1)[: :] = logits
else:
rows = 1
cols = self._n_vocab
logits = self._ctx.get_logits()[: rows * cols]
self.scores[n_past + n_tokens - 1, :].reshape(-1)[: :] = logits
# Update n_tokens
self.n_tokens += n_tokens

Expand Down

0 comments on commit 8649d76

Please sign in to comment.