diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index d11eba820..d39981204 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -65,8 +65,8 @@ type Sequence struct { // number of tokens to predict numPredict int - // set of samplers to run on generated logits - samplers []sample.Sampler + // sampler with transforms to run on generated logits + sampler sample.Sampler // channel to send back the embedding if embedding only embedding chan []float32 @@ -93,7 +93,7 @@ type NewSequenceParams struct { numPredict int stop []string numKeep int32 - samplers []sample.Sampler + sampler sample.Sampler embedding bool } @@ -136,7 +136,7 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen responses: make(chan string, 100), quit: make(chan bool, 1), embedding: make(chan []float32, 1), - samplers: params.samplers, + sampler: params.sampler, embeddingOnly: params.embedding, stop: params.stop, numKeep: params.numKeep, @@ -393,13 +393,7 @@ func (s *Server) processBatch() error { return fmt.Errorf("failed to decode batch: %w", err) } - f32s := modelOutput.Floats() - - // TODO(jessegross): This will no longer be necessary once the sampling interface takes f32s - logits := make([]float64, len(f32s)) - for i, f32 := range f32s { - logits[i] = float64(f32) - } + logits := modelOutput.Floats() for i, seq := range s.seqs { if seq == nil { @@ -433,14 +427,12 @@ func (s *Server) processBatch() error { } // sample a token - vocabSize := len(f32s) / len(options.Outputs) - tokens, err := sample.Sample(logits[seq.iBatch*vocabSize:(seq.iBatch+1)*vocabSize], seq.samplers...) - if err != nil { - return err - } + vocabSize := len(logits) / len(options.Outputs) - // TODO(jessegross): Sampler will output a single int32 in the future - token := int32(tokens[0]) + token, err := seq.sampler.Sample(logits[seq.iBatch*vocabSize : (seq.iBatch+1)*vocabSize]) + if err != nil { + return fmt.Errorf("failed to sample token: %w", err) + } // if it's an end of sequence token, break if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { @@ -565,27 +557,6 @@ type CompletionResponse struct { Timings Timings `json:"timings"` } -func getSamplers(_ CompletionRequest) []sample.Sampler { - // TODO(jessegross): Waiting for sampling code - - /*samplingParams.TopK = req.TopK - samplingParams.TopP = req.TopP - samplingParams.MinP = req.MinP - samplingParams.TypicalP = req.TypicalP - samplingParams.Temp = req.Temperature - samplingParams.RepeatLastN = req.RepeatLastN - samplingParams.PenaltyRepeat = req.RepeatPenalty - samplingParams.PenaltyFreq = req.FrequencyPenalty - samplingParams.PenaltyPresent = req.PresencePenalty - samplingParams.Mirostat = req.Mirostat - samplingParams.MirostatTau = req.MirostatTau - samplingParams.MirostatEta = req.MirostatEta - samplingParams.Seed = uint32(req.Seed) - samplingParams.Grammar = req.Grammar*/ - - return []sample.Sampler{sample.Greedy()} -} - func (s *Server) completion(w http.ResponseWriter, r *http.Request) { var req CompletionRequest req.Options = Options(api.DefaultOptions()) @@ -604,11 +575,23 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { return } + sampler, err := sample.NewSampler( + req.Temperature, + req.TopK, + req.TopP, + req.MinP, + req.Seed, + ) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to create sampler: %v", err), http.StatusInternalServerError) + return + } + seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ numPredict: req.NumPredict, stop: req.Stop, numKeep: int32(req.NumKeep), - samplers: getSamplers(req), + sampler: sampler, embedding: false, }) if err != nil { diff --git a/sample/greedy.go b/sample/greedy.go deleted file mode 100644 index 206f5544d..000000000 --- a/sample/greedy.go +++ /dev/null @@ -1,13 +0,0 @@ -package sample - -import "gonum.org/v1/gonum/floats" - -type greedy struct{} - -func Greedy() Sampler { - return greedy{} -} - -func (s greedy) Sample(t []float64) ([]float64, error) { - return []float64{float64(floats.MaxIdx(t))}, nil -} diff --git a/sample/sample.go b/sample/sample.go deleted file mode 100644 index 44c08caed..000000000 --- a/sample/sample.go +++ /dev/null @@ -1,74 +0,0 @@ -package sample - -import ( - "slices" - - "gonum.org/v1/gonum/floats" - "gonum.org/v1/gonum/stat/sampleuv" -) - -type Sampler interface { - Sample([]float64) ([]float64, error) -} - -type Temperature float64 - -func (s Temperature) Sample(t []float64) ([]float64, error) { - floats.Div(t, slices.Repeat([]float64{float64(s)}, len(t))) - return t, nil -} - -type softmax struct{} - -func Softmax() Sampler { - return softmax{} -} - -func (softmax) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type TopK int - -func (s TopK) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type TopP float32 - -func (s TopP) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type MinP float32 - -func (s MinP) Sample(t []float64) ([]float64, error) { - return t, nil -} - -type weighed struct{} - -func Weighed() Sampler { - return weighed{} -} - -func (s weighed) Sample(t []float64) ([]float64, error) { - w := sampleuv.NewWeighted(t, nil) - if v, ok := w.Take(); ok { - return []float64{float64(v)}, nil - } - - return t, nil -} - -func Sample(floats []float64, samplers ...Sampler) ([]float64, error) { - var err error - for _, sampler := range samplers { - floats, err = sampler.Sample(floats) - if err != nil { - return nil, err - } - } - - return floats, nil -} diff --git a/sample/samplers.go b/sample/samplers.go new file mode 100644 index 000000000..836c6e4d9 --- /dev/null +++ b/sample/samplers.go @@ -0,0 +1,139 @@ +package sample + +import ( + "errors" + "math" + + "golang.org/x/exp/rand" + "gonum.org/v1/gonum/stat/sampleuv" +) + +type Sampler interface { + Sample([]float32) (int32, error) +} + +type weighted struct { + src rand.Source + transforms []Transform +} + +// TODO(parthsareen): remove uv sample dependency https://github.com/ollama/ollama/issues/9279 +func Weighted(seed *uint64, transforms ...Transform) Sampler { + var src rand.Source + if seed != nil { + src = rand.NewSource(*seed) + } + return weighted{src: src, transforms: transforms} +} + +func (s weighted) Sample(logits []float32) (int32, error) { + logits64 := make([]float64, len(logits)) + for i, v := range logits { + logits64[i] = float64(v) + } + + for _, t := range s.transforms { + logits64 = t.Apply(logits64) + } + + logitsCopy := make([]float64, 0, len(logits)) + indices := make([]int, 0, len(logits)) + for i, logit := range logits64 { + if !math.IsInf(logit, -1) { + logitsCopy = append(logitsCopy, logit) + indices = append(indices, i) + } + } + + if len(logitsCopy) == 0 { + return -1, errors.New("no valid logits found for weighed sampling") + } + + probs := softmax(logitsCopy) + w := sampleuv.NewWeighted(probs, s.src) + if idx, ok := w.Take(); ok { + return int32(indices[idx]), nil + } + return -1, errors.New("weighed sampler failed, no valid token found") +} + +type greedy struct { + transforms []Transform +} + +func Greedy(transforms ...Transform) Sampler { + return greedy{transforms: transforms} +} + +func (s greedy) Sample(logits []float32) (int32, error) { + logits64 := make([]float64, len(logits)) + for i, v := range logits { + logits64[i] = float64(v) + } + + for _, t := range s.transforms { + logits64 = t.Apply(logits64) + } + + var maxIdx int + var maxLogit float64 + for i, logit := range logits64 { + if logit > maxLogit { + maxLogit = logit + maxIdx = i + } + } + + if maxLogit == math.Inf(-1) { + return -1, errors.New("no valid logits found for greedy sampling") + } + + return int32(maxIdx), nil +} + +// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 +func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) (Sampler, error) { + transforms := []Transform{} + if temperature < 0 || temperature > 2 { + return nil, errors.New("temperature must be between 0 and 2") + } + + if temperature != 0 { + transforms = append(transforms, Temperature(temperature)) + } + + if topK != 0 { + if topK <= 0 { + return nil, errors.New("topK must be greater than 0") + } + transforms = append(transforms, TopK(topK)) + } + + if topP != 0 { + if topP < 0 || topP >= 1 { + return nil, errors.New("topP must be between 0 and 1") + } + transforms = append(transforms, TopP(topP)) + } + + if minP != 0 { + if minP < 0 || minP >= 1 { + return nil, errors.New("minP must be between 0 and 1") + } + transforms = append(transforms, MinP(minP)) + } + + if len(transforms) == 0 { + return nil, errors.New("at least one transform is required") + } + + if temperature == 0 { + return Greedy(transforms...), nil + } + + if seed != 0 { + seed64 := uint64(seed) + return Weighted(&seed64, transforms...), nil + } + return Weighted(nil, transforms...), nil +} diff --git a/sample/samplers_test.go b/sample/samplers_test.go new file mode 100644 index 000000000..aaa8d99c4 --- /dev/null +++ b/sample/samplers_test.go @@ -0,0 +1,238 @@ +package sample + +import ( + "math" + "math/rand/v2" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestWeighted(t *testing.T) { + got, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))}) + if err != nil { + t.Error(err) + return + } + want := int32(1) + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } + + got, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))}) + if err == nil { + t.Error("expected error for no valid tokens, got index", got) + } + + seed := uint64(42) + got, err = Weighted(&seed).Sample([]float32{1, 2, 3, 4}) + if err != nil { + t.Error(err) + return + } + // With seed 42, we expect a consistent sample + want = int32(3) // This will be deterministic due to the seed + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } +} + +type testTransform struct { + id int + callOrder *[]int +} + +func (ts *testTransform) Apply(logits []float64) []float64 { + if ts.callOrder != nil { + *ts.callOrder = append(*ts.callOrder, ts.id) + } + return logits +} + +func TestSample(t *testing.T) { + input := []float32{1, 2, 3, 4} + + var callOrder []int + mock1 := &testTransform{ + id: 1, + callOrder: &callOrder, + } + mock2 := &testTransform{ + id: 2, + callOrder: &callOrder, + } + mock3 := &testTransform{ + id: 3, + callOrder: &callOrder, + } + + got, err := Greedy(mock1, mock2, mock3).Sample(input) + if err != nil { + t.Error(err) + return + } + + want := int32(3) // Greedy sampler should pick highest logit + if want != got { + t.Errorf("index mismatch: want %d, got %d", want, got) + } + wantOrder := []int{1, 2, 3} + if diff := cmp.Diff(wantOrder, callOrder); diff != "" { + t.Errorf("call order mismatch (-want +got):\n%s", diff) + } + + callOrder = nil + + _, err = Weighted(nil, mock1, mock2, mock3).Sample(input) + if err != nil { + t.Error(err) + return + } + wantOrder = []int{1, 2, 3} + if diff := cmp.Diff(wantOrder, callOrder); diff != "" { + t.Errorf("call order mismatch (-want +got):\n%s", diff) + } +} + +func TestNewSampler(t *testing.T) { + tests := []struct { + name string + temperature float32 + topK int + topP float32 + minP float32 + seed int + wantErr bool + }{ + { + name: "no transforms", + wantErr: true, + }, + { + name: "temperature", + temperature: 0.5, + wantErr: false, + }, + { + name: "invalid temperature negative", + temperature: -1, + wantErr: true, + }, + { + name: "invalid temperature too high", + temperature: 2.1, + wantErr: true, + }, + { + name: "top k", + topK: 10, + wantErr: false, + }, + { + name: "invalid top k negative", + topK: -1, + wantErr: true, + }, + { + name: "top p", + topP: 0.9, + wantErr: false, + }, + { + name: "invalid top p negative", + topP: -0.1, + wantErr: true, + }, + { + name: "invalid top p one", + topP: 1.0, + wantErr: true, + }, + { + name: "min p", + minP: 0.2, + wantErr: false, + }, + { + name: "invalid min p negative", + minP: -0.1, + wantErr: true, + }, + { + name: "invalid min p one", + minP: 1.0, + wantErr: true, + }, + { + name: "seed", + seed: 42, + wantErr: true, // seed alone is not valid without other transforms + }, + { + name: "default values", + temperature: 0.8, + topK: 40, + topP: 0.9, + minP: 0.0, + seed: 0, + wantErr: false, + }, + { + name: "all zeroes", + temperature: 0.0, + topK: 0, + topP: 0.0, + minP: 0.0, + seed: 0, + wantErr: true, // all zeroes means no transforms + }, + { + name: "all transforms", + temperature: 0.8, + topK: 50, + topP: 0.95, + minP: 0.1, + seed: 42, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + _, err := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed) + if (err != nil) != tt.wantErr { + t.Errorf("NewSampler() error = %v, wantErr %v", err, tt.wantErr) + } + }) + } +} + +func BenchmarkSample(b *testing.B) { + transforms := []Transform{ + Temperature(0.5), + TopK(10), + TopP(0.9), + MinP(0.2), + } + + samplers := map[string]Sampler{ + "Greedy": Greedy(transforms...), + "Weighted": Weighted(nil, transforms...), + } + + logits := make([]float32, 1<<16) + for i := range logits { + logits[i] = rand.Float32() + } + + for name, s := range samplers { + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for range b.N { + if _, err := s.Sample(logits); err != nil { + b.Error(err) + } + } + }) + } +} diff --git a/sample/transforms.go b/sample/transforms.go new file mode 100644 index 000000000..2dc6ebae1 --- /dev/null +++ b/sample/transforms.go @@ -0,0 +1,120 @@ +package sample + +import ( + "cmp" + "math" + "slices" + + pq "github.com/emirpasic/gods/v2/queues/priorityqueue" +) + +type Transform interface { + Apply([]float64) []float64 +} + +// TODO(parthsareen): potentially cache softmax values +func softmax(logits []float64) []float64 { + var sum float64 + probs := make([]float64, len(logits)) + for i, v := range logits { + probs[i] = math.Exp(v) + sum += probs[i] + } + + for i := range probs { + probs[i] /= sum + } + + return probs +} + +type Temperature float64 + +func (t Temperature) Apply(logits []float64) []float64 { + temp := math.Max(float64(t), 1e-7) + + // subtracting max logit to avoid under/overflow + maxLogit := slices.Max(logits) + for i := range logits { + logits[i] = (logits[i] - maxLogit) / temp + } + + return logits +} + +type logitMap struct { + index int + logit float64 +} + +type TopK int + +// TODO(parthsareen): avoid having to check all logits after this transform +func (k TopK) Apply(logits []float64) []float64 { + if int(k) >= len(logits) { + return logits + } + q := pq.NewWith(func(a, b logitMap) int { + return -cmp.Compare(a.logit, b.logit) + }) + + for i, logit := range logits { + q.Enqueue(logitMap{index: i, logit: logit}) + } + + validLogits := make(map[int]float64) + for range k { + logitMap, _ := q.Dequeue() + validLogits[logitMap.index] = logitMap.logit + } + + for i := range logits { + if _, ok := validLogits[i]; !ok { + logits[i] = math.Inf(-1) + } + } + + return logits +} + +type TopP float64 + +func (p TopP) Apply(logits []float64) []float64 { + probs := softmax(logits) + indices := make([]int, len(probs)) + for i := range indices { + indices[i] = i + } + + // sort in descending order + slices.SortFunc(indices, func(i, j int) int { + return cmp.Compare(probs[j], probs[i]) + }) + + var sum float64 + for i, idx := range indices { + sum += probs[idx] + if sum > float64(p) { + for _, idx := range indices[i+1:] { + logits[idx] = math.Inf(-1) + } + break + } + } + return logits +} + +type MinP float64 + +func (p MinP) Apply(logits []float64) []float64 { + probs := softmax(logits) + threshold := slices.Max(probs) * float64(p) + + for i, prob := range probs { + if prob < threshold { + logits[i] = math.Inf(-1) + } + } + + return logits +} diff --git a/sample/transforms_test.go b/sample/transforms_test.go new file mode 100644 index 000000000..05f76a274 --- /dev/null +++ b/sample/transforms_test.go @@ -0,0 +1,80 @@ +package sample + +import ( + "math" + "math/rand/v2" + "testing" + + "github.com/google/go-cmp/cmp" +) + +func TestTemperature(t *testing.T) { + got := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0}) + want := []float64{-4, -10, 0, -14, -6, -12, -8} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func TestSoftmax(t *testing.T) { + got := softmax([]float64{-3, -2, -1, 0, 1, 2, 4}) + + want := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("probs mismatch (-want +got):\n%s", diff) + } +} + +func TestTopK(t *testing.T) { + got := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } + + got = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + + want = []float64{-3, -2, -1, 0, 1, 2, 4} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func TestTopP(t *testing.T) { + got := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func TestMinP(t *testing.T) { + got := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3}) + want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3} + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("logits mismatch (-want +got):\n%s", diff) + } +} + +func BenchmarkTransform(b *testing.B) { + transforms := map[string]Transform{ + "Temperature": Temperature(0.5), + "TopK": TopK(10), + "TopP": TopP(0.9), + "MinP": MinP(0.2), + } + + logits := make([]float64, 1<<16) + for i := range logits { + logits[i] = rand.Float64() + } + + for name, transform := range transforms { + b.Run(name, func(b *testing.B) { + b.ResetTimer() + for range b.N { + transform.Apply(logits) + } + }) + } +}