diff --git a/sample/samplers.go b/sample/samplers.go index a9d90692d..aea99b3f2 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -90,8 +90,9 @@ func (s *Sampler) sample(tokens []token) (token, error) { sortLogits(tokens) } + // token logit values are updated to probabilities tokens = temperature(tokens, s.temperature) - tokens = softmax(tokens) + tokens = topP(tokens, s.topP) tokens = minP(tokens, s.minP) diff --git a/sample/transforms.go b/sample/transforms.go index 496252975..ab62455f3 100644 --- a/sample/transforms.go +++ b/sample/transforms.go @@ -5,13 +5,25 @@ import ( "slices" ) -func softmax(ts []token) []token { +// temperature applies scaling and softmax to the logits +func temperature(ts []token, temp float32) []token { + // Find max logit for numerical stability + maxLogit := float32(math.Inf(-1)) + for _, t := range ts { + if t.value > maxLogit { + maxLogit = t.value + } + } + + // Apply temperature and compute exp(x - max) + temp = max(temp, 1e-7) var sum float32 for i, v := range ts { - ts[i].value = float32(math.Exp(float64(v.value))) + ts[i].value = float32(math.Exp(float64((v.value - maxLogit) / temp))) sum += ts[i].value } + // Normalize for i := range ts { ts[i].value /= sum } @@ -19,27 +31,6 @@ func softmax(ts []token) []token { return ts } -func temperature(ti []token, t float32) []token { - if t == 1 { - return ti - } - - temp := max(t, 1e-7) - maxLogit := float32(math.Inf(-1)) - for _, token := range ti { - if token.value > maxLogit { - maxLogit = token.value - } - } - - // subtracting max logit to avoid under/overflow - for i := range ti { - ti[i].value = (ti[i].value - maxLogit) / temp - } - - return ti -} - // siftDown maintains a min-heap property by recursively moving larger elements down the heap. // // The heap is represented as an array where for any node at index i: @@ -145,7 +136,8 @@ func minP(ts []token, p float32) []token { } // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584 -// Conting sort implementation to sort tokens by logits +// sortLogits sorts implementation to sort tokens by logits using counting sort +// counting sort is faster than built-in sort for this use case func sortLogits(tokens []token) { if len(tokens) <= 1 { return diff --git a/sample/transforms_test.go b/sample/transforms_test.go index 1065231dc..81e8849b7 100644 --- a/sample/transforms_test.go +++ b/sample/transforms_test.go @@ -32,17 +32,9 @@ func compareLogits(t *testing.T, name string, want []float64, got []token) { } } -func TestTemperature(t *testing.T) { - input := []float64{2, -1, 4, -3, 1, -2, 0} - want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp - +func TestTemperatureAndSoftmax(t *testing.T) { + input := []float64{1, 4, -2, 0} got := temperature(toTokens(input), 0.5) - compareLogits(t, "Temperature", want, got) -} - -func TestSoftmax(t *testing.T) { - input := []float64{-3, -2, -1, 0, 1, 2, 4} - got := softmax(toTokens(input)) // Check probabilities sum to 1 var sum float32 @@ -53,11 +45,14 @@ func TestSoftmax(t *testing.T) { t.Errorf("probabilities don't sum to 1: got %f", sum) } - // Check relative ordering is preserved - for i := 1; i < len(got); i++ { - if got[i].value < got[i-1].value { - t.Errorf("probability ordering not preserved at index %d", i) - } + got = temperature(toTokens(input), 1) + // Check probabilities sum to 1 + sum = 0.0 + for _, token := range got { + sum += token.value + } + if math.Abs(float64(sum)-1.0) > 1e-6 { + t.Errorf("probabilities don't sum to 1: got %f", sum) } } @@ -84,7 +79,6 @@ func TestTopP(t *testing.T) { // First apply temperature and softmax to get probabilities tokens = temperature(tokens, 1) - tokens = softmax(tokens) sortLogits(tokens) // Then apply topP @@ -103,7 +97,6 @@ func TestMinP(t *testing.T) { // First apply temperature and softmax tokens = temperature(tokens, 1) - tokens = softmax(tokens) // Then apply minP got := minP(tokens, 0.2)