From c7c751647dc5c0e0012a127bf2bab1923ed41bad Mon Sep 17 00:00:00 2001 From: Bruce MacDonald Date: Fri, 14 Mar 2025 16:56:32 -0700 Subject: [PATCH] model: support for mistral-small in the ollama runner Mistral is a popular research lab making open source models. This updates the forward pass of llama architecture models to support both llama models and mistral models by accounting for additional metadata present in mistral models, and finding the correct dimensions for the output projection. --- model/models/llama/model.go | 25 ++- model/process_text_test.go | 320 ++++++++++++++++++++++++++++++++++++ 2 files changed, 340 insertions(+), 5 deletions(-) diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 19a2ab8c4..47a88043e 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -13,9 +13,9 @@ import ( ) type Options struct { - hiddenSize, numHeads, numKVHeads int - eps, ropeBase, ropeScale float32 - ropeDim uint32 + hiddenSize, numHeads, numKVHeads, headDim int + eps, ropeBase, ropeScale float32 + ropeDim uint32 } type Model struct { @@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) { m := Model{ BytePairEncoding: model.NewBytePairEncoding( + // TODO: need to set this in the conversion for mistral: + // tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+ c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`), &model.Vocabulary{ Values: c.Strings("tokenizer.ggml.tokens"), @@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) { hiddenSize: int(c.Uint("embedding_length")), numHeads: int(c.Uint("attention.head_count")), numKVHeads: int(c.Uint("attention.head_count_kv")), + headDim: int(c.Uint("attention.key_length")), eps: c.Float("attention.layer_norm_rms_epsilon"), ropeBase: c.Float("rope.freq_base"), ropeScale: c.Float("rope.freq_scale", 1), @@ -75,24 +78,36 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(1) - headDim := opts.hiddenSize / opts.numHeads ropeType := uint32(0) + // Get head dimension - use explicit value if available, otherwise calculate + headDim := opts.headDim + if headDim == 0 { + headDim = opts.hiddenSize / opts.numHeads + } + // Query projection and reshape q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, headDim, opts.numHeads, batchSize) q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + // Key projection and reshape k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize) k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + // Value projection and reshape v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize) + // Attention computation scaleFactor := 1.0 / math.Sqrt(float64(headDim)) kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache) - kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize) + // Reshape attention output for final projection + outputDim := headDim * opts.numHeads + kqv = kqv.Reshape(ctx, outputDim, batchSize) + + // Apply output projection return sa.Output.Forward(ctx, kqv) } diff --git a/model/process_text_test.go b/model/process_text_test.go index f48303212..8654f6d27 100644 --- a/model/process_text_test.go +++ b/model/process_text_test.go @@ -209,6 +209,326 @@ func TestLlama(t *testing.T) { }) } +// tekken loads the Tekken tokenizer for testing +func tekken(t testing.TB) TextProcessor { + t.Helper() + + // Load tokenizer config from mistral-small + tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json") + configFile, err := os.Open(tokenizerConfigPath) + if err != nil { + t.Fatal(err) + } + defer configFile.Close() + + var config struct { + AddBosToken bool `json:"add_bos_token"` + AddEosToken bool `json:"add_eos_token"` + BosToken struct { + Content string `json:"content"` + } `json:"bos_token"` + EosToken struct { + Content string `json:"content"` + } `json:"eos_token"` + } + if err := json.NewDecoder(configFile).Decode(&config); err != nil { + t.Fatal(err) + } + + // Load tokenizer.json which contains the vocabulary and other settings + tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json") + tokenizerFile, err := os.Open(tokenizerJsonPath) + if err != nil { + t.Fatal(err) + } + defer tokenizerFile.Close() + + var tokenizerData struct { + Model struct { + Type string `json:"type"` + Vocab map[string]int32 `json:"vocab"` + Merges []string `json:"merges"` + } `json:"model"` + AddedTokens []struct { + Id int32 `json:"id"` + Content string `json:"content"` + Special bool `json:"special"` + } `json:"added_tokens"` + PreTokenizer struct { + Type string `json:"type"` + Pretokenizers []struct { + Type string `json:"type"` + Pattern struct { + String string `json:"String"` + } `json:"pattern"` + Behavior string `json:"behavior"` + } `json:"pretokenizers"` + } `json:"pre_tokenizer"` + } + if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil { + t.Fatal(err) + } + + // Extract the pattern from pre_tokenizer if available + var pattern string + if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 { + pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String + } + + // Combine regular vocab and added tokens + vocab := tokenizerData.Model.Vocab + + // Add special tokens from added_tokens + for _, token := range tokenizerData.AddedTokens { + vocab[token.Content] = token.Id + } + + // Create vocabulary arrays + maxId := int32(-1) + for _, id := range vocab { + if id > maxId { + maxId = id + } + } + + vocabSize := int(maxId + 1) + types := make([]uint32, vocabSize) + tokens := make([]string, vocabSize) + scores := make([]float32, vocabSize) + + for token, id := range vocab { + tokens[id] = token + types[id] = TOKEN_TYPE_NORMAL + + // Assign appropriate token types for special tokens + if token == "" { + types[id] = TOKEN_TYPE_CONTROL + } else if token == "" { + types[id] = TOKEN_TYPE_CONTROL + } else if token == "[INST]" || token == "[/INST]" { + types[id] = TOKEN_TYPE_CONTROL + } + } + + // In Tekken, we don't need to load merges separately as they're part of the model + var merges []string + + // Create vocabulary object + vocabObj := &Vocabulary{ + Values: tokens, + Types: types, + Scores: scores, + Merges: merges, + BOS: vocab[config.BosToken.Content], + EOS: vocab[config.EosToken.Content], + AddBOS: config.AddBosToken, + AddEOS: config.AddEosToken, + } + + // Use pattern from tokenizer.json if available + if pattern != "" { + // Ensure pattern has proper escaping for Go regexp + pattern = strings.ReplaceAll(pattern, "p{", "\\p{") + return NewBytePairEncoding(pattern, vocabObj) + } + + // Fallback pattern if not found + return NewBytePairEncoding( + `\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`, + vocabObj, + ) +} + +func TestTekken(t *testing.T) { + // Skip if the test data isn't available + if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) { + t.Skip("Mistral-small test data not available") + } + + tokenizer := tekken(t) + + t.Run("whitespace_handling", func(t *testing.T) { + t.Parallel() + + // The key difference from SentencePiece is that Tekken doesn't prepend whitespace + cases := []struct { + input string + expected string + }{ + {" hello", " hello"}, + {"hello ", "hello "}, + {"hello world", "hello world"}, + {" hello world ", " hello world "}, + } + + for _, tc := range cases { + ids, err := tokenizer.Encode(tc.input, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", tc.input, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) + continue + } + + if decoded != tc.expected { + t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected) + } + } + }) + + t.Run("chat_templates", func(t *testing.T) { + t.Parallel() + + // Test the Tekken chat template format which doesn't have spaces after special tokens + templates := []struct { + input string + expectSpace bool // whether we expect a space after special tokens + }{ + {"[INST]user message[/INST]", false}, + {"[INST] user message[/INST]", true}, + {"[INST]user message [/INST]", true}, + } + + for _, tc := range templates { + ids, err := tokenizer.Encode(tc.input, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", tc.input, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) + continue + } + + // Check if there's a space after special tokens + hasSpaceAfterINST := strings.Contains(decoded, "[INST] ") + + if hasSpaceAfterINST != tc.expectSpace { + t.Errorf("Chat template space handling: got space=%v, want space=%v for %q", + hasSpaceAfterINST, tc.expectSpace, tc.input) + } + } + }) + + t.Run("special_tokens", func(t *testing.T) { + t.Parallel() + + // Test how Tekken handles special tokens + cases := []struct { + input string + expected []string // We'll check if these tokens are in the decoded output + }{ + {"[INST]hello[/INST]", []string{"", "[INST]", "hello", "[/INST]"}}, + {"[INST]hello[/INST]", []string{"[INST]", "hello", "[/INST]", ""}}, + {"[INST]hello[/INST][INST]again[/INST]", []string{"", "[INST]", "hello", "[/INST]", "", "[INST]", "again", "[/INST]"}}, + } + + for _, tc := range cases { + ids, err := tokenizer.Encode(tc.input, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", tc.input, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", tc.input, err) + continue + } + + for _, expected := range tc.expected { + if !strings.Contains(decoded, expected) { + t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded) + } + } + } + }) + + t.Run("vocabulary_coverage", func(t *testing.T) { + t.Parallel() + + // Tekken has a larger vocabulary, so test coverage of various token types + samples := []string{ + "Hello world!", + "This is a test of the Tekken tokenizer.", + "It has a considerably larger vocabulary size.", + "Special characters: !@#$%^&*()", + "Numbers: 1234567890", + "Multiple languages: こんにちは 你好 안녕하세요", + "Code snippets: def function(): return True", + } + + for _, sample := range samples { + ids, err := tokenizer.Encode(sample, false) + if err != nil { + t.Errorf("Failed to encode %q: %v", sample, err) + continue + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Errorf("Failed to decode tokens for %q: %v", sample, err) + continue + } + + if decoded != sample { + t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample) + } + } + }) + + t.Run("splitting_behavior", func(t *testing.T) { + t.Parallel() + + // Test the splitting behavior which might differ from SentencePiece + cases := map[string][]string{ + "Hello World!": {"Hello", " World", "!"}, + "user message": {"user", " message"}, + "[INST]hello": {"[INST]", "hello"}, + "hello[/INST]": {"hello", "[/INST]"}, + } + + for s, want := range cases { + got := slices.Collect(tokenizer.(*BytePairEncoding).split(s)) + if diff := cmp.Diff(want, got); diff != "" { + t.Errorf("Splitting behavior no match (-want +got):\n%s", diff) + } + } + }) + + t.Run("full_chat_sequence", func(t *testing.T) { + t.Parallel() + + // Test a complete chat sequence with Tekken's format + chatSequence := "[INST]user message[/INST]assistant message[INST]new user message[/INST]" + + ids, err := tokenizer.Encode(chatSequence, false) + if err != nil { + t.Fatalf("Failed to encode chat sequence: %v", err) + } + + decoded, err := tokenizer.Decode(ids) + if err != nil { + t.Fatalf("Failed to decode chat sequence tokens: %v", err) + } + + // In Tekken, the whitespace shouldn't be added after special tokens + if strings.Contains(decoded, "[INST] ") { + t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded) + } + + if strings.Contains(decoded, "[/INST] ") { + t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded) + } + }) +} + func BenchmarkBytePairEncoding(b *testing.B) { tokenizer := llama(b) bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))