Skip to content

Commit

Permalink
Fix sampling not ending on EOS token.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 616827061
Change-Id: Ie14c6954669e7c71120c33f4f7eaf7d1d006906d
  • Loading branch information
ddsh authored and copybara-github committed Mar 18, 2024
1 parent 9937890 commit 4ecc4e7
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion gemma/sampler.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ def _sample_step(
logits_buffer = sampler_state.logits_buffer

done = sampler_state.done | jnp.equal(
sampler_state.token_buffer[:, decoding_step + 1], self.vocab.eos_id()
token_buffer[:, decoding_step + 1], self.vocab.eos_id()
)

return _SamplingState(
Expand Down

0 comments on commit 4ecc4e7

Please sign in to comment.