diff --git a/sample/samplers.go b/sample/samplers.go index aea99b3f2..8b0de3f54 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -84,11 +84,8 @@ func (s *Sampler) sample(tokens []token) (token, error) { return greedy(tokens), nil } - if s.topK > 0 { - tokens = topK(tokens, s.topK) - } else { - sortLogits(tokens) - } + // topK also sorts the tokens in descending order of logits + tokens = topK(tokens, s.topK) // token logit values are updated to probabilities tokens = temperature(tokens, s.temperature) diff --git a/sample/transforms.go b/sample/transforms.go index 05fd4533c..b65917afd 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -53,8 +53,17 @@ func temperature(ts []token, temp float32) []token { // topK limits the number of tokens considered to the k highest logits func topK(ts []token, k int) []token { - if k >= len(ts) { - sortLogits(ts) + if k >= len(ts) || k <= 0 { + slices.SortFunc(ts, func(a, b token) int { + switch { + case a.value < b.value: + return 1 + case a.value > b.value: + return -1 + default: + return 0 + } + }) return ts } @@ -125,17 +134,3 @@ func minP(ts []token, p float32) []token { ts = validTokens return ts } - -// sortLogits sorts the tokens in descending order of logits -func sortLogits(ts []token) { - slices.SortFunc(ts, func(a, b token) int { - switch { - case a.value < b.value: - return 1 - case a.value > b.value: - return -1 - default: - return 0 - } - }) -} diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 8ed6be3e0..8f0a58b60 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -59,7 +59,7 @@ func TestTemperatureAndSoftmax(t *testing.T) { func TestTopK(t *testing.T) { input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} - // Test k=3 + // Test k=5 got := topK(toTokens(input), 5) if len(got) != 5 { t.Errorf("topK(5): wrong length: want 5, got %d", len(got)) @@ -72,6 +72,24 @@ func TestTopK(t *testing.T) { if len(got) != len(input) { t.Errorf("topK(20): wrong length: want %d, got %d", len(input), len(got)) } + + // Test k=-1 + input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} + want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} + got = topK(toTokens(input), -1) + if len(got) != len(input) { + t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) + } + compareLogits(t, "topK(-1)", want, got) + + // Test k=0 + input = []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} + want = []float32{0.27755088, 0.20409796, 0.15720603, 0.08582123, 0.045046154, 0.043722924, 0.036774673, 0.026986899, 0.01681367, 0.0046718004, 0.00412893, 0.0030491839} + got = topK(toTokens(input), 0) + if len(got) != len(input) { + t.Errorf("topK(-1): wrong length: want %d, got %d", len(input), len(got)) + } + compareLogits(t, "topK(-1)", want, got) } func TestTopP(t *testing.T) { @@ -80,7 +98,7 @@ func TestTopP(t *testing.T) { // First apply temperature and softmax to get probabilities tokens = temperature(tokens, 1) - sortLogits(tokens) + tokens = topK(tokens, 20) // Then apply topP got := topP(tokens, 0.95) @@ -112,7 +130,7 @@ func TestSortLogits(t *testing.T) { input := []float32{0.026986899, 0.043722924, 0.036774673, 0.27755088, 0.0046718004, 0.08582123, 0.20409796, 0.00412893, 0.15720603, 0.045046154, 0.0030491839, 0.01681367} tokens := toTokens(input) - sortLogits(tokens) + tokens = topK(tokens, 20) for i := 1; i < len(tokens); i++ { if tokens[i].value > tokens[i-1].value { @@ -173,7 +191,7 @@ func BenchmarkTransforms(b *testing.B) { b.ResetTimer() for b.Loop() { copy(tokensCopy, tokens) - sortLogits(tokensCopy) + topK(tokensCopy, 200000) } }) }