From 42a14f7f633110ab83343848865d4612cfefb398 Mon Sep 17 00:00:00 2001 From: Parth Sareen Date: Thu, 20 Mar 2025 11:11:18 -0700 Subject: [PATCH] sample: add error handling for empty logits (#9740) --- sample/samplers.go | 14 +++---- sample/samplers_test.go | 24 +++++++++++ sample/transforms_test.go | 88 +++++++++++++++++++++++++++++---------- 3 files changed, 97 insertions(+), 29 deletions(-) diff --git a/sample/samplers.go b/sample/samplers.go index 7c12da08b..ef8033691 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -26,6 +26,10 @@ type Sampler struct { } func (s *Sampler) Sample(logits []float32) (int32, error) { + if len(logits) == 0 { + return -1, errors.New("sample: no logits provided to sample") + } + tokens := make([]token, len(logits)) for i := range logits { tokens[i].id = int32(i) @@ -94,13 +98,6 @@ func (s *Sampler) sample(tokens []token) (token, error) { tokens = topP(tokens, s.topP) tokens = minP(tokens, s.minP) - // TODO: this should fall back to greedy sampling - // or topP, topK values etc should be such that - // there are always tokens to sample from - if len(tokens) == 0 { - return token{}, errors.New("no tokens to sample from") - } - var r float32 if s.rng != nil { r = s.rng.Float32() @@ -123,6 +120,9 @@ func (s *Sampler) sample(tokens []token) (token, error) { return 1 }) + if math.IsNaN(float64(sum)) { + return token{}, errors.New("sample: logits sum to NaN, check model output") + } return tokens[idx], nil } diff --git a/sample/samplers_test.go b/sample/samplers_test.go index 38b9b352a..d79dce474 100644 --- a/sample/samplers_test.go +++ b/sample/samplers_test.go @@ -1,6 +1,7 @@ package sample import ( + "math" "math/rand/v2" "testing" ) @@ -29,6 +30,29 @@ func TestWeighted(t *testing.T) { if want != got { t.Errorf("index mismatch: want %d, got %d", want, got) } + + // Test very high p + logits = []float32{1.0, 0.9999999999999999, 0.5, 0.1} + // Use extremely small topP to filter out all tokens + sampler = NewSampler(1.0, 0, 1e-10, 0, 0, nil) + got, err = sampler.Sample(logits) + if err != nil { + t.Error(err) + return + } + // Should get the token with the highest logit + want = int32(0) + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } + + logits = []float32{float32(math.NaN()), float32(math.NaN()), float32(math.NaN())} + sampler = NewSampler(1, 0, 0.95, 0.05, 0, nil) + got, err = sampler.Sample(logits) + if err == nil { + t.Errorf("expected error, got %d", got) + return + } } func BenchmarkSample(b *testing.B) { diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 7faf30a55..5307c5f8a 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -168,27 +168,53 @@ func TestTopP(t *testing.T) { softmax(tokens) tokens = topK(tokens, 20) - // Then apply topP - tokens = topP(tokens, 0.95) + // Test with very high p value + got := topP(tokens, 1.0) - // Should keep tokens until cumsum > 0.95 - if len(tokens) > 3 { + // Should keep all tokens since p is 1 + if len(got) != len(input) { + t.Errorf("topP(1.0): should keep all tokens, got %d, want %d", len(got), len(input)) + } + + // Test with normal p value + got = topP(tokens, 0.95) + + if len(got) > 3 { t.Errorf("topP(0.95): kept too many tokens: got %d", len(tokens)) - t.Logf("got: %v", tokens) + t.Logf("got: %v", got) } // Test edge case - ensure at least one token remains - input = []float32{-1e6, -1e6, -1e6} // One dominant token + input = []float32{-1e6, -1e6, -1e7} tokens = toTokens(input) + tokens = topK(tokens, 20) softmax(tokens) - tokens = topP(tokens, 0.0) // Very small p - if len(tokens) < 1 { + got = topP(tokens, 0.0) + if len(got) < 1 { t.Error("topP should keep at least one token") } + + // Test with zero p value + got = topP(tokens, 0.0) + + // Should keep only the highest probability token + if len(got) != 1 { + t.Errorf("topP(0.0): should keep only one token, got %d", len(got)) + t.Logf("got: %v", got) + } + + tokens = toTokens(input) + tokens = topK(tokens, 20) + softmax(tokens) + got = topP(tokens, 1e-10) + if len(got) == 0 { + t.Errorf("topP(1e-10): should keep at least one token, got %d", len(got)) + t.Logf("got: %v", got) + } } func TestMinP(t *testing.T) { - input := []float32{-3, -2, -1, 0, 1, 2, 4, 3} + input := []float32{-2, 0, -1, -3, 2, 1, 4, 3} tokens := toTokens(input) // First apply temperature and softmax @@ -225,30 +251,48 @@ func TestMinP(t *testing.T) { t.Logf("got: %v", tokens) } + // Test with single token + tokens = toTokens(input[:1]) + tokens = topK(tokens, 20) + softmax(tokens) + tokens = minP(tokens, 0.1) + + // Should keep only the highest probability token + if len(tokens) != 1 { + t.Errorf("minP(0.1): should return single token, got %d", len(tokens)) + t.Logf("got: %v", tokens) + } + input = []float32{1e-10, 1e-10, 1e-10} tokens = toTokens(input) softmax(tokens) tokens = minP(tokens, 1.0) if len(tokens) < 1 { t.Error("minP should keep at least one token even with extreme probabilities") - } -} + got := minP(tokens, 1.0) -func TestSortLogits(t *testing.T) { - input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} - tokens := toTokens(input) + if len(got) != 1 { + t.Errorf("minP(1.0): should keep all tokens, got %d, want %d", len(got), len(tokens)) + } - tokens = topK(tokens, 20) + // Test with normal p value + got = minP(tokens, 0.2) - for i := 1; i < len(tokens); i++ { - if tokens[i].value > tokens[i-1].value { - t.Errorf("sortLogits: tokens not sorted in descending order at index %d: %f > %f", - i, tokens[i].value, tokens[i-1].value) + // Should keep tokens with prob >= 0.2 * max_prob + if len(got) > 3 { + t.Errorf("minP(0.2): kept too many tokens: got %d", len(got)) + t.Logf("got: %v", got) + } + + // Test with zero p value + got = minP(tokens, 0.0) + + // Should keep only the highest probability token + if len(got) != len(tokens) { + t.Errorf("minP(0.0): should keep only one token, got %d", len(got)) + t.Logf("got: %v", got) } } - - want := []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} - compareLogits(t, "sortLogits", want, tokens) } func BenchmarkTransforms(b *testing.B) {