diff --git a/llama/llama.go b/llama/llama.go
index bb5028bd9..a026bee24 100644
--- a/llama/llama.go
+++ b/llama/llama.go
@@ -245,6 +245,20 @@ func LoadModelFromFile(modelPath string, params ModelParams) (*Model, error) {
     return &m, nil
 }
 
+func LoadVocabFromFile(path string) (*Vocab, error) {
+    mp := C.CString(path)
+    defer C.free(unsafe.Pointer(mp))
+    v := Vocab{c: C.llama_load_vocab_from_file(mp)}
+    if v.c == nil {
+        return nil, fmt.Errorf("unable to load vocab: %s", path)
+    }
+    return &v, nil
+}
+
+func FreeVocab(vocab *Vocab) {
+    C.llama_free_vocab(vocab.c)
+}
+
 func FreeModel(model *Model) {
     C.llama_model_free(model.c)
 }
@@ -293,6 +307,10 @@ func (m *Model) ApplyLoraFromFile(context *Context, loraPath string, scale float
     return nil
 }
 
+type Vocab struct {
+    c *C.struct_llama_vocab
+}
+
 func (m *Model) Vocab() *C.struct_llama_vocab {
     return C.llama_model_get_vocab(m.c)
 }
@@ -669,3 +687,53 @@ func SchemaToGrammar(schema []byte) []byte {
     }
     return buf[:n]
 }
+
+type Sampler struct {
+    c *C.struct_llama_sampler
+}
+
+func NewGrammarSampler(vocab *Vocab, grammar string) *Sampler {
+    cGrammar := C.CString(grammar)
+    cRoot := C.CString("root")
+    defer C.free(unsafe.Pointer(cGrammar))
+    defer C.free(unsafe.Pointer(cRoot))
+
+    sampler := &Sampler{c: C.llama_sampler_init_grammar(vocab.c, cGrammar, cRoot)}
+
+    return sampler
+}
+
+func (s *Sampler) Accept(token int32) {
+    C.llama_sampler_accept(s.c, C.llama_token(token))
+}
+
+type TokenData struct {
+    Id    int32
+    Logit float32
+}
+
+func (s *Sampler) Apply(tokens []TokenData) {
+    tds := make([]C.struct_llama_token_data, len(tokens))
+    for i, token := range tokens {
+        tds[i] = C.struct_llama_token_data{
+            id:    C.int32_t(token.Id),
+            logit: C.float(token.Logit),
+            p:     C.float(0.0),
+        }
+    }
+    tda := &C.llama_token_data_array{
+        data:     (*C.struct_llama_token_data)(unsafe.Pointer(&tds[0])),
+        size:     C.size_t(len(tokens)),
+        selected: C.int64_t(-1),
+        sorted:   C.bool(false),
+    }
+
+    var pinner runtime.Pinner
+    pinner.Pin(&tds[0])
+    defer pinner.Unpin()
+
+    C.llama_sampler_apply(s.c, tda)
+    for i := range tokens {
+        tokens[i].Logit = float32(tds[i].logit)
+    }
+}
diff --git a/llama/sampling_ext.cpp b/llama/sampling_ext.cpp
index 0f137dc8d..b816cedd4 100644
--- a/llama/sampling_ext.cpp
+++ b/llama/sampling_ext.cpp
@@ -2,6 +2,9 @@
 #include "sampling.h"
 #include "sampling_ext.h"
 #include "json-schema-to-grammar.h"
+#include "llama.h"
+#include "llama-model.h"
+#include "llama-model-loader.h"
 
 struct common_sampler *common_sampler_cinit(const struct llama_model *model, struct common_sampler_cparams *params)
 {
     try {
@@ -64,3 +67,22 @@ int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len)
         return 0;
     }
 }
+
+struct llama_vocab * llama_load_vocab_from_file(const char * fname) {
+    llama_vocab * vocab = new llama_vocab();
+    try {
+        const auto kv = LLM_KV(LLM_ARCH_UNKNOWN);
+        std::vector<std::string> splits = {};
+        llama_model_loader ml(std::string(fname), splits, false, false, nullptr);
+        vocab->load(ml, kv);
+    } catch (const std::exception & err) {
+        LLAMA_LOG_ERROR("%s: error loading model: %s\n", __func__, err.what());
+        return nullptr;
+    }
+
+    return vocab;
+}
+
+void llama_free_vocab(struct llama_vocab * vocab) {
+    delete vocab;
+}
diff --git a/llama/sampling_ext.h b/llama/sampling_ext.h
index 39f499f19..9be7c100e 100644
--- a/llama/sampling_ext.h
+++ b/llama/sampling_ext.h
@@ -35,6 +35,9 @@ extern "C"
     int schema_to_grammar(const char *json_schema, char *grammar, size_t max_len);
 
+    struct llama_vocab * llama_load_vocab_from_file(const char * fname);
+    void llama_free_vocab(struct llama_vocab * vocab);
+
 #ifdef __cplusplus
 }
 #endif
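For reference, the surface added above (`LoadVocabFromFile`/`FreeVocab` in llama/llama.go, `llama_load_vocab_from_file`/`llama_free_vocab` on the C++ side, and the grammar `Sampler`) composes like this from Go. This is a minimal sketch, not part of the patch; the helper name, model path, grammar string, and logit values are placeholders:

```go
// Sketch only (hypothetical helper): drive the grammar sampler directly
// against a vocab loaded from a GGUF file, without loading the full model.
// Assumes: import "github.com/ollama/ollama/llama"
func grammarSamplerExample() error {
    vocab, err := llama.LoadVocabFromFile("./model.gguf") // placeholder path
    if err != nil {
        return err
    }
    defer llama.FreeVocab(vocab)

    // Constrain output to the literal strings "yes" or "no".
    sampler := llama.NewGrammarSampler(vocab, `root ::= "yes" | "no"`)

    // Apply mutates Logit in place; tokens the grammar cannot accept in its
    // current state come back with Logit forced to -Inf.
    tokens := []llama.TokenData{{Id: 0, Logit: 2.1}, {Id: 1, Logit: 0.4}}
    sampler.Apply(tokens)

    // Once a token is chosen, advance the grammar state.
    sampler.Accept(tokens[0].Id)
    return nil
}
```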
diff --git a/llm/server.go b/llm/server.go
index 9553ba8f0..a53306fb0 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -729,29 +729,24 @@ func (s *llmServer) Completion(ctx context.Context, req CompletionRequest, fn fu
     }
 
     if len(req.Format) > 0 {
-        format := string(req.Format)
-        if format != `null` && format != `""` {
-            if s.textProcessor != nil {
-                // New engine handles this on the backend
-                request["format"] = req.Format
-            } else {
-                // old engine
-                switch format {
-                case `"json"`:
-                    request["grammar"] = grammarJSON
-                default:
-                    if req.Format[0] != '{' {
-                        return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
-                    }
-
-                    // User provided a JSON schema
-                    g := llama.SchemaToGrammar(req.Format)
-                    if g == nil {
-                        return fmt.Errorf("invalid JSON schema in format")
-                    }
-                    request["grammar"] = string(g)
-                }
-            }
+        switch string(req.Format) {
+        case `null`, `""`:
+            // Field was set, but "missing" a value. We accept
+            // these as "not set".
+            break
+        case `"json"`:
+            request["grammar"] = grammarJSON
+        default:
+            if req.Format[0] != '{' {
+                return fmt.Errorf("invalid format: %q; expected \"json\" or a valid JSON Schema object", req.Format)
+            }
+
+            // User provided a JSON schema
+            g := llama.SchemaToGrammar(req.Format)
+            if g == nil {
+                return fmt.Errorf("invalid JSON schema in format")
+            }
+            request["grammar"] = string(g)
         }
     }
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index c8383a5dd..c1475cbb2 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -254,6 +254,12 @@ type Server struct {
     // multimodalHash generates hashes for comparing equality
     // of non-text data
     multimodalHash maphash.Hash
+
+    // vocab is a llama.cpp vocab required for grammar-based
+    // constrained generation (json mode, structured outputs)
+    // TODO: this is temporary until Ollama sampling supports
+    // constrained generation
+    vocab *sample.Vocab
 }
 
 func (s *Server) allNil() bool {
@@ -574,18 +580,25 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) {
         return
     }
 
+    var grammar *sample.Grammar
+    var err error
+    if req.Grammar != "" {
+        grammar, err = sample.NewGrammar(s.vocab, req.Grammar)
+        if err != nil {
+            http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError)
+            return
+        }
+    }
+
     sampler := sample.NewSampler(
         req.Temperature,
         req.TopK,
         req.TopP,
         req.MinP,
         req.Seed,
+        grammar,
     )
 
-    if req.Grammar != "" {
-        panic("grammars are not yet supported")
-    }
-
     seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{
         numPredict: req.NumPredict,
         stop:       req.Stop,
@@ -797,6 +810,8 @@ func (s *Server) loadModel(
         panic(err)
     }
 
+    s.vocab = sample.NewVocab(mpath)
+
     // TODO(jessegross): LoRA loading
     if lpath.String() != "" {
         panic("loras are not yet implemented")
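Two wiring changes above: `Completion` in llm/server.go now maps every accepted shape of `req.Format` to a grammar in a single switch, and the runner builds a `sample.Grammar` from `req.Grammar` instead of panicking. To make the accepted shapes concrete, an illustrative summary — not part of the patch, and the example values are made up:

```go
// Sketch only: the shapes the "format" field can take after this change
// (keys are example raw values; the strings describe the behavior).
var formatBehavior = map[string]string{
    `null`:              "set but empty: accepted and ignored",
    `""`:                "set but empty: accepted and ignored",
    `"json"`:            "uses the built-in JSON grammar (grammarJSON)",
    `{"type":"object"}`: "JSON Schema, compiled with llama.SchemaToGrammar",
    `"yaml"`:            `rejected: expected "json" or a JSON Schema object`,
}
```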
diff --git a/sample/samplers.go b/sample/samplers.go
index a5a0507ca..a9d90692d 100644
--- a/sample/samplers.go
+++ b/sample/samplers.go
@@ -2,43 +2,88 @@ package sample
 
 import (
     "errors"
+    "math"
     "math/rand/v2"
     "slices"
+    "sync"
+
+    "github.com/ollama/ollama/llama"
 )
 
-// Sampler is not thread-safe. Each goroutine should have its own instance
-type Sampler interface {
-    Sample([]float32) (int32, error)
-}
-
-// logit represents information about a single token during sampling
-type logit struct {
+// token represents information about a single token during sampling
+type token struct {
     id    int32   // The token's unique identifier
     value float32 // The raw logit or probability from the model
 }
 
-type weighted struct {
+type Sampler struct {
     rng         *rand.Rand
-    tokens      []logit
     topK        int
     topP        float32
     minP        float32
     temperature float32
+    grammar     *Grammar
 }
 
-func (s *weighted) Sample(logits []float32) (int32, error) {
-    if len(s.tokens) < len(logits) {
-        s.tokens = make([]logit, len(logits))
-    }
-
-    tokens := s.tokens[:len(logits)]
-
-    for i, v := range logits {
+func (s *Sampler) Sample(logits []float32) (int32, error) {
+    tokens := make([]token, len(logits))
+    for i := range logits {
         tokens[i].id = int32(i)
-        tokens[i].value = v
+        tokens[i].value = logits[i]
+    }
+
+    t, err := s.sample(tokens)
+    if err != nil {
+        return -1, err
+    }
+
+    if s.grammar != nil {
+        // optimization: first check if the max logit is accepted by the grammar
+        // if the max logit is rejected, apply the grammar to all logits (slower)
+        top := []token{t}
+        s.grammar.Apply(top)
+        if !math.IsInf(float64(top[0].value), -1) {
+            s.grammar.Accept(top[0].id)
+            return top[0].id, nil
+        }
+
+        // since .sample has side effects of modifying the tokens
+        // we need to reset them before applying the grammar and
+        // sampling again
+        for i := range logits {
+            tokens[i].id = int32(i)
+            tokens[i].value = logits[i]
+        }
+        s.grammar.Apply(tokens)
+        t, err = s.sample(tokens)
+        if err != nil {
+            return -1, err
+        }
+        s.grammar.Accept(t.id)
+    }
+
+    return t.id, nil
+}
+
+// greedy returns the highest probability token from the tokens
+func greedy(tokens []token) token {
+    max := tokens[0]
+    for i := 1; i < len(tokens); i++ {
+        if tokens[i].value > max.value {
+            max = tokens[i]
+        }
+    }
+
+    return max
+}
+
+// sample returns the highest probability token from the tokens
+// given sampler parameters. It also has side effects of modifying the tokens
+func (s *Sampler) sample(tokens []token) (token, error) {
+    if s.temperature == 0 {
+        return greedy(tokens), nil
     }
 
-    // Tokens are sorted by logits in TopK or SortTokens
     if s.topK > 0 {
         tokens = topK(tokens, s.topK)
     } else {
@@ -47,12 +92,14 @@
     tokens = temperature(tokens, s.temperature)
     tokens = softmax(tokens)
-
     tokens = topP(tokens, s.topP)
     tokens = minP(tokens, s.minP)
 
+    // TODO: this should fall back to greedy sampling
+    // or topP, topK values etc should be such that
+    // there are always tokens to sample from
     if len(tokens) == 0 {
-        return -1, errors.New("no valid logits found for weighted sampling")
+        return token{}, errors.New("no tokens to sample from")
     }
 
     var r float32
@@ -70,48 +117,18 @@
     }
     r *= tokens[len(tokens)-1].value
 
-    idx, _ := slices.BinarySearchFunc(tokens, r, func(token logit, target float32) int {
-        // Compare cumulative probabilities
+    idx, _ := slices.BinarySearchFunc(tokens, r, func(token token, target float32) int {
         if token.value < target {
             return -1
         }
-        // First token that exceeds target
         return 1
     })
 
-    if idx >= len(tokens) {
-        idx = len(tokens) - 1
-    }
-
-    return tokens[idx].id, nil
-}
-
-type greedy struct{}
-
-// Greedy sample returns the index of the maximum value in logits.
-func (s greedy) Sample(logits []float32) (int32, error) {
-    if len(logits) == 0 {
-        return -1, errors.New("no logits provided for greedy sampling")
-    }
-
-    maxIdx := 0
-    maxVal := logits[0]
-    for i := 1; i < len(logits); i++ {
-        if logits[i] > maxVal {
-            maxVal = logits[i]
-            maxIdx = i
-        }
-    }
-
-    return int32(maxIdx), nil
+    return tokens[idx], nil
 }
 
 // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
-func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int) Sampler {
-    if temperature == 0 {
-        return &greedy{}
-    }
-
+func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
     var rng *rand.Rand
     if seed != -1 {
         // PCG requires two parameters: sequence and stream
@@ -120,7 +137,9 @@
         // Use golden ratio hash to generate statistically independent seeds
         rng = rand.New(rand.NewPCG(sequence, sequence^0x9E3779B9))
     }
-    temperature = max(temperature, 1)
+    if temperature < 0.0 {
+        temperature = 0.0
+    }
 
     if topP < 0.0 {
         topP = 0.0
@@ -136,11 +155,73 @@
         minP = 1.0
     }
 
-    return &weighted{
+    return Sampler{
         rng:         rng,
         topK:        topK,
         topP:        topP,
         minP:        minP,
         temperature: temperature,
+        grammar:     grammar,
     }
 }
+
+type Grammar struct {
+    vocab   *Vocab
+    grammar string
+    sampler *llama.Sampler
+}
+
+func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
+    v, err := vocab.Load()
+    if err != nil {
+        return nil, err
+    }
+
+    return &Grammar{
+        vocab:   vocab,
+        grammar: grammar,
+        sampler: llama.NewGrammarSampler(v, grammar),
+    }, nil
+}
+
+func (g *Grammar) Apply(tokens []token) {
+    tds := make([]llama.TokenData, len(tokens))
+    for i, token := range tokens {
+        tds[i].Id = token.id
+        tds[i].Logit = token.value
+    }
+
+    g.sampler.Apply(tds)
+
+    for i := range tokens {
+        tokens[i].value = tds[i].Logit
+    }
+}
+
+func (g *Grammar) Accept(token int32) {
+    g.sampler.Accept(token)
+}
+
+type Vocab struct {
+    once  sync.Once
+    vocab *llama.Vocab
+    err   error
+    path  string
+}
+
+func NewVocab(path string) *Vocab {
+    return &Vocab{path: path}
+}
+
+// Load returns the lazily-loaded vocabulary
+func (v *Vocab) Load() (*llama.Vocab, error) {
+    v.once.Do(func() {
+        vocab, err := llama.LoadVocabFromFile(v.path)
+        if err != nil {
+            v.err = err
+            return
+        }
+        v.vocab = vocab
+    })
+    return v.vocab, v.err
+}
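End to end, the new `sample` package pieces above wire together like this — an illustrative sketch, not part of the patch; the helper name, model path, grammar, and parameter values are placeholders:

```go
// Sketch only (hypothetical helper): grammar-constrained sampling end to end.
// Assumes: import "github.com/ollama/ollama/sample"
func sampleWithGrammar(logits []float32) (int32, error) {
    vocab := sample.NewVocab("/path/to/model.gguf") // lazy: nothing loads yet

    // NewGrammar triggers the one-time vocab load via sync.Once.
    grammar, err := sample.NewGrammar(vocab, `root ::= "yes" | "no"`)
    if err != nil {
        return -1, err
    }

    // temperature, topK, topP, minP, seed, then the optional grammar.
    s := sample.NewSampler(0.7, 40, 0.9, 0.05, 42, grammar)

    // logits has one entry per vocab token; the returned id has already
    // been Accept()ed into the grammar state.
    return s.Sample(logits)
}
```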
diff --git a/sample/samplers_benchmark_test.go b/sample/samplers_benchmark_test.go
index 41c0b487f..cd1380141 100644
--- a/sample/samplers_benchmark_test.go
+++ b/sample/samplers_benchmark_test.go
@@ -16,13 +16,10 @@ func BenchmarkWeightedSampler(b *testing.B) {
             logits[i] = float32(rand.Float64()*10 - 5)
         }
 
-        sampler := NewSampler(0.8, 0, 0, 0, 42)
+        sampler := NewSampler(0.8, 0, 0, 0, 42, nil)
         b.ResetTimer()
         for b.Loop() {
-            _, err := sampler.Sample(logits)
-            if err != nil {
-                b.Fatalf("Sampling failed: %v", err)
-            }
+            sampler.Sample(logits)
         }
     })
 }
@@ -52,30 +49,24 @@
     for _, tc := range configs {
         b.Run("Config"+tc.name, func(b *testing.B) {
-            sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed)
+            sampler := NewSampler(tc.temperature, tc.topK, tc.topP, tc.minP, tc.seed, nil)
             sampler.Sample(logits)
 
             b.ResetTimer()
             for b.Loop() {
-                _, err := sampler.Sample(logits)
-                if err != nil {
-                    b.Fatalf("Sampling failed: %v", err)
-                }
+                sampler.Sample(logits)
             }
         })
     }
 
     // Test with combined transforms separately - topK influences performance greatly
     b.Run("TransformCombined", func(b *testing.B) {
-        sampler := NewSampler(0.8, 50, 0.9, 0.05, 42)
+        sampler := NewSampler(0.8, 50, 0.9, 0.05, 42, nil)
         b.ResetTimer()
         for b.Loop() {
-            _, err := sampler.Sample(logits)
-            if err != nil {
-                b.Fatalf("Sampling failed: %v", err)
-            }
+            sampler.Sample(logits)
         }
     })
 }
@@ -90,14 +81,11 @@ func BenchmarkGreedySampler(b *testing.B) {
             logits[i] = float32(rand.Float64()*10 - 5)
         }
 
-        sampler := NewSampler(0, -1, 0, 0, -1)
+        sampler := NewSampler(0, -1, 0, 0, -1, nil)
         b.ResetTimer()
        for b.Loop() {
-            _, err := sampler.Sample(logits)
-            if err != nil {
-                b.Fatalf("Sampling failed: %v", err)
-            }
+            sampler.Sample(logits)
         }
     })
 }
diff --git a/sample/samplers_test.go b/sample/samplers_test.go
index dbbee17bb..38b9b352a 100644
--- a/sample/samplers_test.go
+++ b/sample/samplers_test.go
@@ -7,7 +7,7 @@ import (
 )
 
 func TestWeighted(t *testing.T) {
     logits := []float32{-10, 3, -10, -10}
-    sampler := NewSampler(0, 0, 0, 0, 0)
+    sampler := NewSampler(0, 0, 0, 0, 0, nil)
     got, err := sampler.Sample(logits)
     if err != nil {
         t.Error(err)
@@ -19,7 +19,7 @@ func TestWeighted(t *testing.T) {
     }
 
     logits = []float32{-100, -10, 0, 10}
-    sampler = NewSampler(0, 0, 0, 0, 0)
+    sampler = NewSampler(0, 0, 0, 0, 0, nil)
     got, err = sampler.Sample(logits)
     if err != nil {
         t.Error(err)
@@ -31,94 +31,10 @@ func TestWeighted(t *testing.T) {
     }
 }
 
-func TestNewSampler(t *testing.T) {
-    tests := []struct {
-        name        string
-        temperature float32
-        topK        int
-        topP        float32
-        minP        float32
-        seed        int
-        wantGreedy  bool // Instead of wantErr, check if we get greedy sampler
-    }{
-        {
-            name:        "temperature",
-            temperature: 0.5,
-            wantGreedy:  false,
-        },
-        {
-            name:        "zero temperature - greedy",
-            temperature: 0,
-            wantGreedy:  true,
-        },
-        {
-            name:        "top k",
-            temperature: 0.1,
-            topK:        10,
-            wantGreedy:  false,
-        },
-        {
-            name:        "top p",
-            temperature: 0.1,
-            topP:        0.9,
-            wantGreedy:  false,
-        },
-        {
-            name:        "min p",
-            temperature: 0.1,
-            minP:        0.2,
-            wantGreedy:  false,
-        },
-        {
-            name:        "seed - weighted",
-            temperature: 0.1,
-            seed:        42,
-            wantGreedy:  false,
-        },
-        {
-            name:        "default values",
-            temperature: 0.8,
-            topK:        40,
-            topP:        0.9,
-            minP:        0.0,
-            seed:        0,
-            wantGreedy:  false,
-        },
-        {
-            name:        "all zeroes - greedy",
-            temperature: 0.0,
-            topK:        0,
-            topP:        0.0,
-            minP:        0.0,
-            seed:        0,
-            wantGreedy:  true,
-        },
-        {
-            name:        "all transforms",
-            temperature: 0.8,
-            topK:        50,
-            topP:        0.95,
-            minP:        0.1,
-            seed:        42,
-            wantGreedy:  false,
-        },
-    }
-
-    for _, tt := range tests {
-        t.Run(tt.name, func(t *testing.T) {
-            sampler := NewSampler(tt.temperature, tt.topK, tt.topP, tt.minP, tt.seed)
-            _, isGreedy := sampler.(*greedy)
-            if isGreedy != tt.wantGreedy {
-                t.Errorf("NewSampler() got greedy = %v, want %v", isGreedy, tt.wantGreedy)
-            }
-        })
-    }
-}
-
 func BenchmarkSample(b *testing.B) {
-    weighted := NewSampler(0.5, 10, 0.9, 0.2, -1)
     samplers := map[string]Sampler{
-        "Greedy":   NewSampler(0, 0, 0, 0, 0), // Use NewSampler with temp=0 for greedy
-        "Weighted": weighted,
+        "Greedy":   NewSampler(0, 0, 0, 0, 0, nil), // Use NewSampler with temp=0 for greedy
+        "Weighted": NewSampler(0.5, 10, 0.9, 0.2, -1, nil),
     }
 
     // Generate random logits for benchmarking
@@ -132,7 +48,7 @@
             b.ResetTimer()
             for b.Loop() {
                 if _, err := s.Sample(logits); err != nil {
-                    b.Error(err)
+                    b.Fatalf("error sampling: %v", err)
                 }
             }
         })
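`TestNewSampler` is deleted above because greedy sampling is no longer a separate `Sampler` implementation to type-assert on; it is simply the `temperature == 0` path inside `Sampler.sample`. A sketch of the equivalent check, for illustration only (helper name and logit values are made up):

```go
// Sketch only: temperature == 0 now selects the greedy path.
func greedyExample() int32 {
    s := sample.NewSampler(0, 0, 0, 0, 0, nil) // no grammar
    id, _ := s.Sample([]float32{-10, 3, -10, -10})
    return id // 1: the argmax, chosen deterministically with no rng involved
}
```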
diff --git a/sample/transforms.go b/sample/transforms.go
index f1f4f3b19..496252975 100644
--- a/sample/transforms.go
+++ b/sample/transforms.go
@@ -5,7 +5,7 @@ import (
     "slices"
 )
 
-func softmax(ts []logit) []logit {
+func softmax(ts []token) []token {
     var sum float32
     for i, v := range ts {
         ts[i].value = float32(math.Exp(float64(v.value)))
@@ -19,7 +19,7 @@
     return ts
 }
 
-func temperature(ti []logit, t float32) []logit {
+func temperature(ti []token, t float32) []token {
     if t == 1 {
         return ti
     }
@@ -51,7 +51,7 @@
 // 1. Finds the smallest value between the node and its children
 // 2. If the node is not the smallest, swaps it with its smallest child
 // 3. Continues this process down the affected path until the min-heap property is restored
-func siftDown(data []logit, start, end int) {
+func siftDown(data []token, start, end int) {
     root := start
     for {
         child := 2*root + 1
@@ -73,7 +73,7 @@
 }
 
 // topK limits the number of tokens considered to the k highest logits
-func topK(ts []logit, k int) []logit {
+func topK(ts []token, k int) []token {
     if k >= len(ts) {
         return ts
     }
@@ -99,7 +99,7 @@
 }
 
 // topP limits tokens to those with cumulative probability p
-func topP(ts []logit, p float32) []logit {
+func topP(ts []token, p float32) []token {
     if p == 1.0 {
         return ts
     }
@@ -118,7 +118,7 @@
 }
 
 // minP limits tokens to those with cumulative probability p
-func minP(ts []logit, p float32) []logit {
+func minP(ts []token, p float32) []token {
     if p == 1.0 {
         return ts
     }
@@ -146,7 +146,7 @@
 
 // TODO(parthsareen): possibly replace with simpler implementation https://github.com/ollama/ollama/issues/9584
 // Counting sort implementation to sort tokens by logits
-func sortLogits(tokens []logit) {
+func sortLogits(tokens []token) {
     if len(tokens) <= 1 {
         return
     }
@@ -187,7 +187,7 @@
     }
 
     // Second pass: place elements in correct position
-    output := make([]logit, len(tokens))
+    output := make([]token, len(tokens))
 
     // Track current positions
     countsCopy := counts
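The renamed transforms behave exactly as before. For reference, a worked pass through `temperature` and `softmax`, matching the expectations in `TestTemperature` and `TestSoftmax` below (illustration only; softmax values rounded, and the `toTokens` test helper is the one defined below):

```go
// Sketch only: temperature rescales against the max logit, softmax normalizes.
func transformExample() []token {
    ts := toTokens([]float64{2, -1, 4})
    ts = temperature(ts, 0.5) // (v-max)/t: {(2-4)/0.5, (-1-4)/0.5, (4-4)/0.5} = {-4, -10, 0}
    ts = softmax(ts)          // exp(v)/sum: approximately {0.018, 0.000045, 0.982}
    return ts                 // low temperature concentrates mass on the max logit
}
```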
diff --git a/sample/transforms_test.go b/sample/transforms_test.go
index 950d79b35..1065231dc 100644
--- a/sample/transforms_test.go
+++ b/sample/transforms_test.go
@@ -7,10 +7,10 @@ import (
 )
 
 // Helper to convert float64 slice to logit slice
-func toLogits(values []float64) []logit {
-    tokens := make([]logit, len(values))
+func toTokens(values []float64) []token {
+    tokens := make([]token, len(values))
     for i, v := range values {
-        tokens[i] = logit{
+        tokens[i] = token{
             id:    int32(i),
             value: float32(v),
         }
@@ -19,7 +19,7 @@
 }
 
 // Helper to compare logit slices
-func compareLogits(t *testing.T, name string, want []float64, got []logit) {
+func compareLogits(t *testing.T, name string, want []float64, got []token) {
     t.Helper()
     if len(want) != len(got) {
         t.Errorf("%s: length mismatch: want %d, got %d", name, len(want), len(got))
@@ -36,13 +36,13 @@
     input := []float64{2, -1, 4, -3, 1, -2, 0}
     want := []float64{-4, -10, 0, -14, -6, -12, -8} // (logit - max logit) / temp
 
-    got := temperature(toLogits(input), 0.5)
+    got := temperature(toTokens(input), 0.5)
     compareLogits(t, "Temperature", want, got)
 }
 
 func TestSoftmax(t *testing.T) {
     input := []float64{-3, -2, -1, 0, 1, 2, 4}
-    got := softmax(toLogits(input))
+    got := softmax(toTokens(input))
 
     // Check probabilities sum to 1
     var sum float32
@@ -65,7 +65,7 @@
     input := []float64{-3, -2, -1, 0, 1, 2, 4}
 
     // Test k=3
-    got := topK(toLogits(input), 3)
+    got := topK(toTokens(input), 3)
     if len(got) != 3 {
         t.Errorf("topK(3): wrong length: want 3, got %d", len(got))
     }
@@ -74,13 +74,13 @@
     compareLogits(t, "topK(3)", want, got)
 
     // Test k > len
-    got = topK(toLogits(input), 10)
+    got = topK(toTokens(input), 10)
     compareLogits(t, "topK(10)", input, got)
 }
 
 func TestTopP(t *testing.T) {
     input := []float64{-3, -2, -1, 0, 1, 2, 4}
-    tokens := toLogits(input)
+    tokens := toTokens(input)
 
     // First apply temperature and softmax to get probabilities
     tokens = temperature(tokens, 1)
@@ -99,7 +99,7 @@
 
 func TestMinP(t *testing.T) {
     input := []float64{-3, -2, -1, 0, 1, 2, 4, 3}
-    tokens := toLogits(input)
+    tokens := toTokens(input)
 
     // First apply temperature and softmax
     tokens = temperature(tokens, 1)
@@ -116,7 +116,7 @@
 
 func TestSortLogits(t *testing.T) {
     input := []float64{3, 1, 4, 2, -1, 0, -2}
-    tokens := toLogits(input)
+    tokens := toTokens(input)
 
     sortLogits(tokens)
 
@@ -133,15 +133,15 @@
 
 func BenchmarkTransforms(b *testing.B) {
     // Generate random logits
-    tokens := make([]logit, 1<<16)
+    tokens := make([]token, 1<<16)
     for i := range tokens {
-        tokens[i] = logit{
+        tokens[i] = token{
             id:    int32(i),
             value: rand.Float32(),
         }
     }
 
-    tokensCopy := make([]logit, len(tokens))
+    tokensCopy := make([]token, len(tokens))
 
     b.Run("Temperature", func(b *testing.B) {
         b.ResetTimer()