remove errors from sample, add tests

ParthSareen 2025-03-17 23:16:42 -04:00
parent 586557eb5a
commit 253b3c7a25
3 changed files with 45 additions and 23 deletions

View File

@@ -2,7 +2,6 @@ package sample
 import (
 	"errors"
-	"log/slog"
 	"math"
 	"math/rand/v2"
 	"slices"
@@ -37,10 +36,7 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
 		tokens[i].value = logits[i]
 	}
-	t, err := s.sample(tokens)
-	if err != nil {
-		return -1, err
-	}
+	t := s.sample(tokens)

 	if s.grammar != nil {
 		// optimization: first check if the max logit is accepted by the grammar
@@ -60,10 +56,7 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
 			tokens[i].value = logits[i]
 		}
 		s.grammar.Apply(tokens)
-		t, err = s.sample(tokens)
-		if err != nil {
-			return -1, err
-		}
+		t = s.sample(tokens)
 		s.grammar.Accept(t.id)
 	}
@@ -84,9 +77,9 @@ func greedy(tokens []token) token
 // sample returns the highest probability token from the tokens
 // given sampler parameters. It also has the side effect of modifying the tokens
-func (s *Sampler) sample(tokens []token) (token, error) {
+func (s *Sampler) sample(tokens []token) token {
 	if s.temperature == 0 {
-		return greedy(tokens), nil
+		return greedy(tokens)
 	}

 	// topK also sorts the tokens in descending order of logits
@@ -99,12 +92,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
-
-	// fallback to greedy sampling if no tokens are left
-	if len(tokens) == 0 {
-		slog.Warn("sample: no tokens left after applying transforms, falling back to greedy sampling")
-		return greedy(tokens), nil
-	}

 	var r float32
 	if s.rng != nil {
 		r = s.rng.Float32()
@@ -127,7 +114,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		return 1
 	})

-	return tokens[idx], nil
+	return tokens[idx]
 }

 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
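Taken together, the hunks above drop the error return from the internal sample path. As a reading aid, here is a minimal, self-contained sketch of that error-free flow. The token struct, greedy, and softmax below are reconstructions from identifiers visible in the diff (assumptions, not the upstream code), and the linear cumulative-sum pick stands in for whatever search the real sampler performs over the sorted tokens.

// Sketch of the error-free sampling path; names mirror identifiers
// visible in the diff, but the bodies are reconstructions.
package main

import (
	"fmt"
	"math"
	"math/rand/v2"
)

// token is assumed from the diff: an id plus a value that holds a
// logit before softmax and a probability after.
type token struct {
	id    int32
	value float32
}

// greedy returns the token with the highest value.
func greedy(tokens []token) token {
	best := tokens[0]
	for _, t := range tokens[1:] {
		if t.value > best.value {
			best = t
		}
	}
	return best
}

// softmax converts logits to probabilities in place.
func softmax(tokens []token) {
	max := float32(math.Inf(-1))
	for _, t := range tokens {
		if t.value > max {
			max = t.value
		}
	}
	var sum float32
	for i := range tokens {
		tokens[i].value = float32(math.Exp(float64(tokens[i].value - max)))
		sum += tokens[i].value
	}
	for i := range tokens {
		tokens[i].value /= sum
	}
}

// sample picks a token proportionally to its probability. Because the
// transforms (topK/topP/minP) are assumed to always leave at least one
// token, there is no failure case and therefore no error to return.
func sample(tokens []token, temperature float32) token {
	if temperature == 0 {
		return greedy(tokens)
	}
	for i := range tokens {
		tokens[i].value /= temperature
	}
	softmax(tokens)
	r := rand.Float32()
	var cum float32
	for _, t := range tokens {
		cum += t.value
		if r < cum {
			return t
		}
	}
	return tokens[len(tokens)-1] // guard against floating-point drift
}

func main() {
	tokens := []token{{0, 2.0}, {1, 1.0}, {2, 0.1}}
	fmt.Println("picked token:", sample(tokens, 1.0).id)
}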

View File

@@ -29,6 +29,20 @@ func TestWeighted(t *testing.T) {
 	if want != got {
 		t.Errorf("index mismatch: want %d, got %d", want, got)
 	}
+
+	// Test fallback when topP filters out all but the top token
+	logits = []float32{1.0, 0.999, 0.5, 0.1}
+	// Use an extremely small topP so only the highest logit survives
+	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
+	got, err = sampler.Sample(logits)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	want = int32(0) // the only candidate left is the highest logit
+	if want != got {
+		t.Errorf("greedy fallback: want %d, got %d", want, got)
+	}
 }

 func BenchmarkSample(b *testing.B) {
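The new TestWeighted case relies on a tiny topP leaving exactly one candidate, which makes the weighted sample deterministic. A hedged sketch of the same check as a table-driven test is below; the NewSampler argument order (temperature, topK, topP, minP, seed, grammar) is inferred from the call site above, not from documentation, and should be verified against the actual signature.

// Sketch of the new edge case as a table-driven test. The NewSampler
// argument order below is an assumption inferred from this diff.
package sample

import "testing"

func TestTinyTopPKeepsArgmax(t *testing.T) {
	cases := []struct {
		name   string
		logits []float32
		topP   float32
		want   int32
	}{
		// With topP this small, only the highest-probability token
		// survives the transform, so sampling is deterministic.
		{"tiny topP", []float32{1.0, 0.999, 0.5, 0.1}, 1e-10, 0},
	}
	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			s := NewSampler(1.0, 0, tc.topP, 0, 0, nil)
			got, err := s.Sample(tc.logits)
			if err != nil {
				t.Fatal(err)
			}
			if got != tc.want {
				t.Errorf("want %d, got %d", tc.want, got)
			}
		})
	}
}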

View File

@@ -171,7 +171,7 @@ func TestTopP(t *testing.T) {
 	// Test with very high p value
 	got := topP(tokens, 1.0)
-	// Should keep almost all tokens since p is very high
+	// Should keep all tokens since p is 1
 	if len(got) != len(input) {
 		t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
 	}
@@ -179,17 +179,17 @@ func TestTopP(t *testing.T) {
 	// Test with normal p value
 	got = topP(tokens, 0.95)
 	// Should keep tokens until cumulative probability > 0.95
 	if len(got) > 3 {
 		t.Errorf("topP(0.95): kept too many tokens: got %d", len(got))
 		t.Logf("got: %v", got)
 	}

 	// Test edge case - ensure at least one token remains
-	input = []float32{-1e6, -1e6, -1e6} // One dominant token
+	input = []float32{-1e6, -1e6, -1e7}
 	tokens = toTokens(input)
 	tokens = topK(tokens, 20)
 	softmax(tokens)
-	got = topP(tokens, 0.0) // Very small p
+	got = topP(tokens, 0.0)
 	if len(got) < 1 {
 		t.Error("topP should keep at least one token")
 	}
@@ -202,10 +202,19 @@ func TestTopP(t *testing.T) {
 		t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
 		t.Logf("got: %v", got)
 	}
+
+	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	got = topP(tokens, 1e-10)
+	if len(got) == 0 {
+		t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
 }

 func TestMinP(t *testing.T) {
-	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
+	input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
 	tokens := toTokens(input)
 	// First apply temperature and softmax
@@ -242,6 +251,18 @@ func TestMinP(t *testing.T) {
 		t.Logf("got: %v", tokens)
 	}
+
+	// Test with single token
+	tokens = toTokens(input[:1])
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	tokens = minP(tokens, 0.1)
+	// Should keep only the highest probability token
+	if len(tokens) != 1 {
+		t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
+		t.Logf("got: %v", tokens)
+	}
+	input = []float32{1e-10, 1e-10, 1e-10}
 	tokens = toTokens(input)
 	softmax(tokens)
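These "at least one token" assertions are what make the error-free sample() safe: a topP cutoff that fires only after adding the current token's probability can never return an empty slice, so the removed greedy fallback has nothing left to catch. A minimal sketch of that invariant, assuming tokens arrive sorted in descending probability order (which the diff says topK guarantees):

// Sketch of a topP that can never return an empty slice. The token
// struct is assumed from the diff; tokens must already be sorted in
// descending order of probability.
package main

import "fmt"

type token struct {
	id    int32
	value float32 // probability after softmax
}

// topP keeps the smallest prefix of tokens whose cumulative probability
// exceeds p. Because the check happens after adding the current token,
// the returned slice always has at least one element, even for p = 1e-10.
func topP(tokens []token, p float32) []token {
	var cum float32
	for i, t := range tokens {
		cum += t.value
		if cum > p {
			return tokens[:i+1] // i+1 >= 1: never empty
		}
	}
	return tokens // p >= 1 keeps everything
}

func main() {
	tokens := []token{{0, 0.5}, {1, 0.25}, {2, 0.25}}
	fmt.Println(len(topP(tokens, 1e-10))) // 1: the head token always survives
	fmt.Println(len(topP(tokens, 1.0)))   // 3: p = 1 keeps all tokens
}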