mirror of https://github.com/ollama/ollama.git
remove errors from sample, add tests
This commit is contained in:
parent
586557eb5a
commit
253b3c7a25
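
The exported API is unchanged by this commit: Sampler.Sample(logits) still returns (int32, error); only the internal sample helper drops its error return. A hypothetical caller for orientation, where the NewSampler argument order (temperature, topK, topP, minP, seed, grammar) is inferred from the new TestWeighted case below and is an assumption rather than documented API:

// pickToken is illustrative only; it mirrors the call pattern used in the
// new test and assumes the same NewSampler argument order.
func pickToken(logits []float32) (int32, error) {
	sampler := NewSampler(1.0, 0, 1e-10, 0, 0, nil)
	return sampler.Sample(logits) // Sample still returns (int32, error)
}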
@@ -2,7 +2,6 @@ package sample
 
 import (
-	"errors"
 	"log/slog"
 	"math"
 	"math/rand/v2"
 	"slices"
@@ -37,10 +36,7 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
 		tokens[i].value = logits[i]
 	}
 
-	t, err := s.sample(tokens)
-	if err != nil {
-		return -1, err
-	}
+	t := s.sample(tokens)
 
 	if s.grammar != nil {
 		// optimization: first check if the max logit is accepted by the grammar
@@ -60,10 +56,7 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
 			tokens[i].value = logits[i]
 		}
 		s.grammar.Apply(tokens)
-		t, err = s.sample(tokens)
-		if err != nil {
-			return -1, err
-		}
+		t = s.sample(tokens)
 		s.grammar.Accept(t.id)
 	}
 
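
The comment in the hunk above describes an optimization: before constraining the whole candidate list with the grammar, check whether the unconstrained best token is already acceptable. A generic sketch of that pattern, not ollama's grammar API; the allowed predicate is a hypothetical stand-in for a grammar acceptance check:

package main

import "fmt"

type token struct {
	id    int32
	value float32
}

// constrainedArgmax: if the unconstrained argmax is already allowed, return it
// (fast path); otherwise scan only the allowed tokens. "allowed" is a
// hypothetical predicate, not the grammar interface used in this repository.
func constrainedArgmax(tokens []token, allowed func(int32) bool) token {
	best := tokens[0]
	for _, t := range tokens[1:] {
		if t.value > best.value {
			best = t
		}
	}
	if allowed(best.id) {
		return best // fast path: no filtering needed
	}
	picked, found := best, false
	for _, t := range tokens {
		if allowed(t.id) && (!found || t.value > picked.value) {
			picked, found = t, true
		}
	}
	return picked // falls back to the raw argmax if nothing is allowed
}

func main() {
	toks := []token{{id: 0, value: 0.1}, {id: 1, value: 0.8}, {id: 2, value: 0.1}}
	onlyEven := func(id int32) bool { return id%2 == 0 }
	fmt.Println(constrainedArgmax(toks, onlyEven).id) // 0: token 1 is rejected
}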
@@ -84,9 +77,9 @@ func greedy(tokens []token) token {
 
 // sample returns the highest probability token from the tokens
 // given sampler parameters. It also has side effects of modifying the tokens
-func (s *Sampler) sample(tokens []token) (token, error) {
+func (s *Sampler) sample(tokens []token) token {
 	if s.temperature == 0 {
-		return greedy(tokens), nil
+		return greedy(tokens)
 	}
 
 	// topK also sorts the tokens in descending order of logits
@@ -99,12 +92,6 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 	tokens = topP(tokens, s.topP)
 	tokens = minP(tokens, s.minP)
 
-	// fallback to greedy sampling if no tokens are left
-	if len(tokens) == 0 {
-		slog.Warn("sample: no tokens left after applying transforms, falling back to greedy sampling")
-		return greedy(tokens), nil
-	}
-
 	var r float32
 	if s.rng != nil {
 		r = s.rng.Float32()
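
Dropping the greedy fallback means the transforms are trusted to never empty the token list, which the new topP and minP test cases below assert explicitly. A minimal sketch, not the repository's implementation, of a cumulative-probability cutoff with that guarantee; it assumes the tokens already hold softmax probabilities and arrive sorted in descending order, as the topK comment above indicates:

package main

import "fmt"

type token struct {
	id    int32
	value float32 // softmax probability in this sketch
}

// topP keeps the shortest prefix whose cumulative probability exceeds p.
// The check runs after adding the current token, so the highest-probability
// token always survives, even for p = 0 or p = 1e-10, and the removed
// greedy fallback is never needed.
func topP(tokens []token, p float32) []token {
	var cum float32
	for i, t := range tokens {
		cum += t.value
		if cum > p {
			return tokens[:i+1]
		}
	}
	return tokens // p at or above the total mass: keep everything
}

func main() {
	toks := []token{{0, 0.7}, {1, 0.2}, {2, 0.1}} // sorted descending
	fmt.Println(len(topP(toks, 1e-10))) // 1
	fmt.Println(len(topP(toks, 0.95)))  // 3
}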
@@ -127,7 +114,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
 		return 1
 	})
 
-	return tokens[idx], nil
+	return tokens[idx]
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
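
The remaining body of sample draws r from s.rng and indexes tokens[idx], i.e. a weighted pick proportional to the token probabilities. A sketch of that idea under stated assumptions; the real code appears to use a comparator-based search (the "return 1" line above), so the linear scan here is an illustrative simplification:

package main

import (
	"fmt"
	"math/rand/v2"
)

type token struct {
	id    int32
	value float32 // probability after the transforms
}

// weightedPick draws r in [0, total) and walks the cumulative sum until it
// passes r, choosing each token with probability proportional to its value.
func weightedPick(tokens []token, rng *rand.Rand) token {
	var r float32
	if rng != nil {
		r = rng.Float32()
	} else {
		r = rand.Float32()
	}

	// scale r by the remaining mass in case the transforms left a slice
	// that no longer sums exactly to 1
	var total float32
	for _, t := range tokens {
		total += t.value
	}
	r *= total

	var cum float32
	for _, t := range tokens {
		cum += t.value
		if r < cum {
			return t
		}
	}
	return tokens[len(tokens)-1] // guard against floating-point drift
}

func main() {
	toks := []token{{0, 0.7}, {1, 0.2}, {2, 0.1}}
	fmt.Println(weightedPick(toks, nil).id)
}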
@@ -29,6 +29,20 @@ func TestWeighted(t *testing.T) {
 	if want != got {
 		t.Errorf("index mismatch: want %d, got %d", want, got)
 	}
+
+	// Test greedy fallback when sample() returns error
+	logits = []float32{1.0, 0.999, 0.5, 0.1}
+	// Use extremely small topP to filter out all tokens, forcing error and greedy fallback
+	sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil)
+	got, err = sampler.Sample(logits)
+	if err != nil {
+		t.Error(err)
+		return
+	}
+	want = int32(0) // Should fall back to greedy and pick highest logit
+	if want != got {
+		t.Errorf("greedy fallback: want %d, got %d", want, got)
+	}
 }
 
 func BenchmarkSample(b *testing.B) {
@@ -171,7 +171,7 @@ func TestTopP(t *testing.T) {
 	// Test with very high p value
 	got := topP(tokens, 1.0)
 
-	// Should keep almost all tokens since p is very high
+	// Should keep all tokens since p is 1
 	if len(got) != len(input) {
 		t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
 	}
@@ -179,17 +179,17 @@ func TestTopP(t *testing.T) {
 	// Test with normal p value
 	got = topP(tokens, 0.95)
 
 	// Should keep tokens until cumulative probability > 0.95
 	if len(got) > 3 {
 		t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
 		t.Logf("got: %v", got)
 	}
 
 	// Test edge case - ensure at least one token remains
-	input = []float32{-1e6, -1e6, -1e6} // One dominant token
+	input = []float32{-1e6, -1e6, -1e7}
 	tokens = toTokens(input)
 	tokens = topK(tokens, 20)
 	softmax(tokens)
-	got = topP(tokens, 0.0) // Very small p
+	got = topP(tokens, 0.0)
 	if len(got) < 1 {
 		t.Error("topP should keep at least one token")
 	}
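
The edge-case input changes from three equal logits to {-1e6, -1e6, -1e7}: after a numerically stable softmax the first two tokens each get roughly 0.5 and the third effectively 0, so topP(0.0) keeps exactly one token. A quick check of that arithmetic, using the usual max-subtraction form of softmax rather than the package's own implementation:

package main

import (
	"fmt"
	"math"
)

// Stable softmax: subtract the max before exponentiating, so the two -1e6
// logits become exp(0) = 1 and the -1e7 logit underflows to ~0.
func softmax(logits []float64) []float64 {
	max := logits[0]
	for _, v := range logits[1:] {
		if v > max {
			max = v
		}
	}
	out := make([]float64, len(logits))
	var sum float64
	for i, v := range logits {
		out[i] = math.Exp(v - max)
		sum += out[i]
	}
	for i := range out {
		out[i] /= sum
	}
	return out
}

func main() {
	fmt.Println(softmax([]float64{-1e6, -1e6, -1e7})) // [0.5 0.5 0]
}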
@@ -202,10 +202,19 @@ func TestTopP(t *testing.T) {
 		t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
 		t.Logf("got: %v", got)
 	}
+
+	tokens = toTokens(input)
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	got = topP(tokens, 1e-10)
+	if len(got) == 0 {
+		t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got))
+		t.Logf("got: %v", got)
+	}
 }
 
 func TestMinP(t *testing.T) {
-	input := []float32{-3, -2, -1, 0, 1, 2, 4, 3}
+	input := []float32{-2, 0, -1, -3, 2, 1, 4, 3}
 	tokens := toTokens(input)
 
 	// First apply temperature and softmax
@@ -242,6 +251,18 @@ func TestMinP(t *testing.T) {
 		t.Logf("got: %v", tokens)
 	}
 
+	// Test with single token
+	tokens = toTokens(input[:1])
+	tokens = topK(tokens, 20)
+	softmax(tokens)
+	tokens = minP(tokens, 0.1)
+
+	// Should keep only the highest probability token
+	if len(tokens) != 1 {
+		t.Errorf("minP(0.1): should return single token, got %d", len(tokens))
+		t.Logf("got: %v", tokens)
+	}
+
 	input = []float32{1e-10, 1e-10, 1e-10}
 	tokens = toTokens(input)
 	softmax(tokens)
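
For reference, min-p filtering is conventionally defined relative to the most probable token: keep a token only if its probability is at least p times the maximum. A sketch under that assumption (not necessarily this package's exact rule); it also explains the new single-token case, since the top token always meets the threshold and the result is never empty:

package main

import "fmt"

type token struct {
	id    int32
	value float32 // softmax probability
}

// minP keeps tokens whose probability is at least p times the highest
// probability. The maximum itself always passes, so at least one token
// remains. This is the conventional min-p rule, assumed rather than copied
// from the repository.
func minP(tokens []token, p float32) []token {
	var max float32
	for _, t := range tokens {
		if t.value > max {
			max = t.value
		}
	}
	kept := tokens[:0]
	for _, t := range tokens {
		if t.value >= p*max {
			kept = append(kept, t)
		}
	}
	return kept
}

func main() {
	toks := []token{{0, 0.6}, {1, 0.3}, {2, 0.1}}
	fmt.Println(len(minP(toks, 0.2))) // 2: 0.1 < 0.2*0.6, so the last token is dropped
}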