mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 11:28:17 +02:00
sample: add greedy sample as fallback
This commit is contained in:
parent
108fe02165
commit
586557eb5a
@ -2,6 +2,7 @@ package sample
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"log/slog"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"slices"
|
||||
@ -26,6 +27,10 @@ type Sampler struct {
|
||||
}
|
||||
|
||||
func (s *Sampler) Sample(logits []float32) (int32, error) {
|
||||
if len(logits) == 0 {
|
||||
return -1, errors.New("sample: no logits provided to sample")
|
||||
}
|
||||
|
||||
tokens := make([]token, len(logits))
|
||||
for i := range logits {
|
||||
tokens[i].id = int32(i)
|
||||
@ -94,11 +99,10 @@ func (s *Sampler) sample(tokens []token) (token, error) {
|
||||
tokens = topP(tokens, s.topP)
|
||||
tokens = minP(tokens, s.minP)
|
||||
|
||||
// TODO: this should fall back to greedy sampling
|
||||
// or topP, topK values etc should be such that
|
||||
// there are always tokens to sample from
|
||||
// fallback to greedy sampling if no tokens are left
|
||||
if len(tokens) == 0 {
|
||||
return token{}, errors.New("no tokens to sample from")
|
||||
slog.Warn("sample: no tokens left after applying transforms, falling back to greedy sampling")
|
||||
return greedy(tokens), nil
|
||||
}
|
||||
|
||||
var r float32
|
||||
|
@ -168,23 +168,40 @@ func TestTopP(t *testing.T) {
|
||||
softmax(tokens)
|
||||
tokens = topK(tokens, 20)
|
||||
|
||||
// Then apply topP
|
||||
tokens = topP(tokens, 0.95)
|
||||
// Test with very high p value
|
||||
got := topP(tokens, 1.0)
|
||||
|
||||
// Should keep tokens until cumsum > 0.95
|
||||
if len(tokens) > 3 {
|
||||
// Should keep almost all tokens since p is very high
|
||||
if len(got) != len(input) {
|
||||
t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input))
|
||||
}
|
||||
|
||||
// Test with normal p value
|
||||
got = topP(tokens, 0.95)
|
||||
|
||||
// Should keep tokens until cumulative probability > 0.95
|
||||
if len(got) > 3 {
|
||||
t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens))
|
||||
t.Logf("got: %v", tokens)
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
// Test edge case - ensure at least one token remains
|
||||
input = []float32{-1e6, -1e6, -1e6} // One dominant token
|
||||
tokens = toTokens(input)
|
||||
softmax(tokens)
|
||||
tokens = topP(tokens, 0.0) // Very small p
|
||||
if len(tokens) < 1 {
|
||||
got = topP(tokens, 0.0) // Very small p
|
||||
if len(got) < 1 {
|
||||
t.Error("topP should keep at least one token")
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
got = topP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(got) != 1 {
|
||||
t.Errorf("topP(0.0): should keep only one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP(t *testing.T) {
|
||||
@ -231,24 +248,30 @@ func TestMinP(t *testing.T) {
|
||||
tokens = minP(tokens, 1.0)
|
||||
if len(tokens) < 1 {
|
||||
t.Error("minP should keep at least one token even with extreme probabilities")
|
||||
}
|
||||
}
|
||||
got := minP(tokens, 1.0)
|
||||
|
||||
func TestSortLogits(t *testing.T) {
|
||||
input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367}
|
||||
tokens := toTokens(input)
|
||||
if len(got) != 1 {
|
||||
t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens))
|
||||
}
|
||||
|
||||
tokens = topK(tokens, 20)
|
||||
// Test with normal p value
|
||||
got = minP(tokens, 0.2)
|
||||
|
||||
for i := 1; i < len(tokens); i++ {
|
||||
if tokens[i].value > tokens[i-1].value {
|
||||
t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f",
|
||||
i, tokens[i].value, tokens[i-1].value)
|
||||
// Should keep tokens with prob >= 0.2 * max_prob
|
||||
if len(got) > 3 {
|
||||
t.Errorf("minP(0.2): kept too many tokens: got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
|
||||
// Test with zero p value
|
||||
got = minP(tokens, 0.0)
|
||||
|
||||
// Should keep only the highest probability token
|
||||
if len(got) != len(tokens) {
|
||||
t.Errorf("minP(0.0): should keep only one token, got %d", len(got))
|
||||
t.Logf("got: %v", got)
|
||||
}
|
||||
}
|
||||
|
||||
want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839}
|
||||
compareLogits(t, "sortLogits", want, tokens)
|
||||
}
|
||||
|
||||
func BenchmarkTransforms(b *testing.B) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user