sample: add numerical stability to temperature/softmax transform (#9631)

This commit is contained in:
Parth Sareen 2025-03-10 14:43:53 -07:00 committed by GitHub
parent fe776293f7
commit 7e34f4fbfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
3 changed files with 28 additions and 42 deletions

View File

@ -90,8 +90,9 @@ func (s *Sampler) sample(tokens []token) (token, error) {
sortLogits(tokens)
}
// token logit values are updated to probabilities
tokens = temperature(tokens, s.temperature)
tokens = softmax(tokens)
tokens = topP(tokens, s.topP)
tokens = minP(tokens, s.minP)

View File

@ -5,13 +5,25 @@ import (
"slices"
)
func softmax(ts []token) []token {
// temperature applies scaling and softmax to the logits
func temperature(ts []token, temp float32) []token {
// Find max logit for numerical stability
maxLogit := float32(math.Inf(-1))
for _, t := range ts {
if t.value > maxLogit {
maxLogit = t.value
}
}
// Apply temperature and compute exp(x - max)
temp = max(temp, 1e-7)
var sum float32
for i, v := range ts {
ts[i].value = float32(math.Exp(float64(v.value)))
ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp)))
sum += ts[i].value
}
// Normalize
for i := range ts {
ts[i].value /= sum
}
@ -19,27 +31,6 @@ func softmax(ts []token) []token {
return ts
}
func temperature(ti []token, t float32) []token {
if t == 1 {
return ti
}
temp := max(t, 1e-7)
maxLogit := float32(math.Inf(-1))
for _, token := range ti {
if token.value > maxLogit {
maxLogit = token.value
}
}
// subtracting max logit to avoid under/overflow
for i := range ti {
ti[i].value = (ti[i].value - maxLogit) / temp
}
return ti
}
// siftDown maintains a min-heap property by recursively moving larger elements down the heap.
//
// The heap is represented as an array where for any node at index i:
@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token {
}
// TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
// Conting sort implementation to sort tokens by logits
// sortLogits sorts implementation to sort tokens by logits using counting sort
// counting sort is faster than built-in sort for this use case
func sortLogits(tokens []token) {
if len(tokens) <= 1 {
return

View File

@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) {
}
}
func TestTemperature(t *testing.T) {
input := []float64{2, -1, 4, -3, 1, -2, 0}
want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
func TestTemperatureAndSoftmax(t *testing.T) {
input := []float64{1, 4, -2, 0}
got := temperature(toTokens(input), 0.5)
compareLogits(t, "Temperature", want, got)
}
func TestSoftmax(t *testing.T) {
input := []float64{-3, -2, -1, 0, 1, 2, 4}
got := softmax(toTokens(input))
// Check probabilities sum to 1
var sum float32
@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
// Check relative ordering is preserved
for i := 1; i < len(got); i++ {
if got[i].value < got[i-1].value {
t.Errorf("probability ordering not preserved at index %d", i)
}
got = temperature(toTokens(input), 1)
// Check probabilities sum to 1
sum = 0.0
for _, token := range got {
sum += token.value
}
if math.Abs(float64(sum)-1.0) > 1e-6 {
t.Errorf("probabilities don't sum to 1: got %f", sum)
}
}
@ -84,7 +79,6 @@ func TestTopP(t *testing.T) {
// First apply temperature and softmax to get probabilities
tokens = temperature(tokens, 1)
tokens = softmax(tokens)
sortLogits(tokens)
// Then apply topP
@ -103,7 +97,6 @@ func TestMinP(t *testing.T) {
// First apply temperature and softmax
tokens = temperature(tokens, 1)
tokens = softmax(tokens)
// Then apply minP
got := minP(tokens, 0.2)