From 253b3c7a2562355ae77edd4c862d2c1901b36f63 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Mon, 17 Mar 2025 23:16:42 -0400 Subject: [PATCH] remove errors from sample, add tests --- sample/samplers.go | 23 +++++------------------ sample/samplers_test.go | 14 ++++++++++++++ sample/transforms_test.go | 31 ++++++++++++++++++++++++++----- 3 files changed, 45 insertions(+), 23 deletions(-) diff --git a/sample/samplers.go b/sample/samplers.go index a3fcd404c..2596148af 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -2,7 +2,6 @@ package sample import ( "errors" - "log/slog" "math" "math/rand/v2" "slices" @@ -37,10 +36,7 @@ func (s *Sampler) Sample(logits []float32) (int32, error) { tokens[i].value = logits[i] } - t, err := s.sample(tokens) - if err != nil { - return -1, err - } + t := s.sample(tokens) if s.grammar != nil { // optimization: first check if the max logit is accepted by the grammar @@ -60,10 +56,7 @@ func (s *Sampler) Sample(logits []float32) (int32, error) { tokens[i].value = logits[i] } s.grammar.Apply(tokens) - t, err = s.sample(tokens) - if err != nil { - return -1, err - } + t = s.sample(tokens) s.grammar.Accept(t.id) } @@ -84,9 +77,9 @@ func greedy(tokens []token) token { // sample returns the highest probability token from the tokens // given sampler parameters. 
It also has side effects of modifying the tokens -func (s *Sampler) sample(tokens []token) (token, error) { +func (s *Sampler) sample(tokens []token) token { if s.temperature == 0 { - return greedy(tokens), nil + return greedy(tokens) } // topK also sorts the tokens in descending order of logits @@ -99,12 +92,6 @@ func (s *Sampler) sample(tokens []token) (token, error) { tokens = topP(tokens, s.topP) tokens = minP(tokens, s.minP) - // fallback to greedy sampling if no tokens are left - if len(tokens) == 0 { - slog.Warn("sample: no tokens left after applying transforms, falling back to greedy sampling") - return greedy(tokens), nil - } - var r float32 if s.rng != nil { r = s.rng.Float32() @@ -127,7 +114,7 @@ func (s *Sampler) sample(tokens []token) (token, error) { return 1 }) - return tokens[idx], nil + return tokens[idx] } // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 diff --git a/sample/samplers_test.go b/sample/samplers_test.go index 38b9b352a..e399d8a75 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -29,6 +29,20 @@ func TestWeighted(t *testing.T) { if want != got { t.Errorf("index mismatch: want %d, got %d", want, got) } + + // Test sampling when an aggressive topP filters out nearly all tokens + logits = []float32{1.0, 0.999, 0.5, 0.1} + // Use extremely small topP; topP always retains at least one token, so sampling still succeeds + sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil) + got, err = sampler.Sample(logits) + if err != nil { + t.Error(err) + return + } + want = int32(0) // The highest-probability token is the only one retained + if want != got { t.Errorf("greedy fallback: want %d, got %d", want, got) + } } func BenchmarkSample(b *testing.B) { diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 69c0792ab..5307c5f8a 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -171,7 +171,7 @@ func TestTopP(t *testing.T) { // Test with very high p value got := 
topP(tokens, 1.0) - // Should keep almost all tokens since p is very high + // Should keep all tokens since p is 1 if len(got) != len(input) { t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input)) } @@ -179,17 +179,17 @@ func TestTopP(t *testing.T) { // Test with normal p value got = topP(tokens, 0.95) - // Should keep tokens until cumulative probability > 0.95 if len(got) > 3 { t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) t.Logf("got: %v", got) } // Test edge case - ensure at least one token remains - input = []float32{-1e6, -1e6, -1e6} // One dominant token + input = []float32{-1e6, -1e6, -1e7} tokens = toTokens(input) + tokens = topK(tokens, 20) softmax(tokens) - got = topP(tokens, 0.0) // Very small p + got = topP(tokens, 0.0) if len(got) < 1 { t.Error("topP should keep at least one token") } @@ -202,10 +202,19 @@ func TestTopP(t *testing.T) { t.Errorf("topP(0.0): should keep only one token, got %d", len(got)) t.Logf("got: %v", got) } + + tokens = toTokens(input) + tokens = topK(tokens, 20) + softmax(tokens) + got = topP(tokens, 1e-10) + if len(got) == 0 { + t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got)) + t.Logf("got: %v", got) + } } func TestMinP(t *testing.T) { - input := []float32{-3, -2, -1, 0, 1, 2, 4, 3} + input := []float32{-2, 0, -1, -3, 2, 1, 4, 3} tokens := toTokens(input) // First apply temperature and softmax @@ -242,6 +251,18 @@ func TestMinP(t *testing.T) { t.Logf("got: %v", tokens) } + // Test with single token + tokens = toTokens(input[:1]) + tokens = topK(tokens, 20) + softmax(tokens) + tokens = minP(tokens, 0.1) + + // Should keep only the highest probability token + if len(tokens) != 1 { + t.Errorf("minP(0.1): should return single token, got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + input = []float32{1e-10, 1e-10, 1e-10} tokens = toTokens(input) softmax(tokens)