Mistral is a popular research lab making open-source models. This change updates the forward pass of Llama-architecture models to support both Llama and Mistral models by accounting for the additional metadata present in Mistral models and by finding the correct dimensions for the output projection.
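For context, here is a minimal, hypothetical sketch of the dimension logic that description refers to: a Llama-style checkpoint can derive the attention head size from the embedding width, while a Mistral-style checkpoint may carry an explicit head_dim that differs from that ratio, so the output projection has to be sized from num_heads × head_dim rather than from the embedding width. All names and numbers below are illustrative and are not the actual ollama code; the file that follows is the tokenizer test associated with this change.

// Hypothetical sketch only; not the ollama implementation.
package main

import "fmt"

// modelOptions stands in for the metadata read from a model file
// (field names are illustrative, not real keys).
type modelOptions struct {
	hiddenSize int // embedding width
	numHeads   int // attention heads
	headDim    int // 0 when the checkpoint does not specify it
}

// attnOutputDim returns the input width of the attention output projection:
// hiddenSize/numHeads is the Llama-style fallback, while an explicit headDim
// (as in Mistral checkpoints) takes precedence and may yield a different size.
func attnOutputDim(o modelOptions) int {
	headDim := o.headDim
	if headDim == 0 {
		headDim = o.hiddenSize / o.numHeads
	}
	return o.numHeads * headDim
}

func main() {
	llamaLike := modelOptions{hiddenSize: 4096, numHeads: 32}                 // 32 * 128 = 4096
	mistralLike := modelOptions{hiddenSize: 5120, numHeads: 32, headDim: 128} // 32 * 128 = 4096 != 5120
	fmt.Println(attnOutputDim(llamaLike), attnOutputDim(mistralLike))
}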
package model

import (
	"bufio"
	"encoding/json"
	"math"
	"os"
	"path/filepath"
	"slices"
	"strconv"
	"strings"
	"testing"

	"github.com/google/go-cmp/cmp"
)

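// llama builds a BytePairEncoding for the Llama 3.2 tokenizer from the fixtures in
// testdata/llama3.2: encoder.json supplies the vocabulary and vocab.bpe the merge
// rules, and <|begin_of_text|> / <|end_of_text|> are appended if they are missing.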
func llama(t testing.TB) BytePairEncoding {
	t.Helper()

	f, err := os.Open(filepath.Join("testdata", "llama3.2", "encoder.json"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	vocab := make(map[string]int32)
	if err := json.NewDecoder(f).Decode(&vocab); err != nil {
		t.Fatal(err)
	}

	types := make([]uint32, len(vocab))
	tokens := make([]string, len(vocab))
	for token, id := range vocab {
		tokens[id] = token
		types[id] = 1
	}

	for _, token := range []string{"<|begin_of_text|>", "<|end_of_text|>"} {
		if _, ok := vocab[token]; !ok {
			tokens = append(tokens, token) //nolint:makezero
			types = append(types, 3)       //nolint:makezero
			vocab[token] = int32(len(vocab))
		}
	}

	f, err = os.Open(filepath.Join("testdata", "llama3.2", "vocab.bpe"))
	if err != nil {
		t.Fatal(err)
	}
	defer f.Close()

	merges := make([]string, 0, 50000)

	scanner := bufio.NewScanner(f)
	for scanner.Scan() {
		if !strings.HasPrefix(scanner.Text(), "#") {
			merges = append(merges, scanner.Text())
		}
	}

	return NewBytePairEncoding(
		`(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`,
		&Vocabulary{
			Values: tokens,
			Types:  types,
			Merges: merges,
		},
	)
}

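// TestLlama checks the Llama 3.2 tokenizer against known token IDs: simple
// encode/decode, grouping of repeated digits, whitespace round trips, special
// token handling, and the pretokenizer split.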
func TestLlama(t *testing.T) {
	tokenizer := llama(t)

	t.Run("simple", func(t *testing.T) {
		t.Parallel()

		ids, err := tokenizer.Encode("hello world", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 1917}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}

		s, err := tokenizer.Decode([]int32{15339, 1917})
		if err != nil {
			t.Fatal(err)
		}

		if s != "hello world" {
			t.Errorf("got %q, want hello world", s)
		}

		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
		if err != nil {
			t.Error(err)
		}

		if diff := cmp.Diff([]int32{15339, 220, 128001}, ids); diff != "" {
			t.Errorf("no match (-theirs +ours):\n%s", diff)
		}
	})

	t.Run("simple repeated", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			strings.Repeat("0", 1):  {15},
			strings.Repeat("0", 2):  {410},
			strings.Repeat("0", 3):  {931},
			strings.Repeat("0", 4):  {931, 15},
			strings.Repeat("0", 5):  {931, 410},
			strings.Repeat("0", 6):  {931, 931},
			strings.Repeat("0", 7):  {931, 931, 15},
			strings.Repeat("0", 8):  {931, 931, 410},
			strings.Repeat("0", 9):  {931, 931, 931},
			strings.Repeat("0", 10): {931, 931, 931, 15},
			strings.Repeat("0", 11): {931, 931, 931, 410},
			strings.Repeat("0", 12): {931, 931, 931, 931},
			strings.Repeat("0", 13): {931, 931, 931, 931, 15},
			strings.Repeat("0", 14): {931, 931, 931, 931, 410},
			strings.Repeat("0", 15): {931, 931, 931, 931, 931},
			strings.Repeat("0", 16): {931, 931, 931, 931, 931, 15},
			strings.Repeat("0", 17): {931, 931, 931, 931, 931, 410},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Error(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("%q no match (-theirs +ours):\n%s", s, diff)
			}
		}
	})

	t.Run("basic roundtrip", func(t *testing.T) {
		t.Parallel()

		cases := []string{
			"hello",
			"hello ",
			"hello  ",
			" hello",
			" hello ",
			" hello  ",
			"hello world",
			"请考试我的软件!12345",
		}

		for _, want := range cases {
			ids, err := tokenizer.Encode(want, true)
			if err != nil {
				t.Error(err)
			}

			if got, err := tokenizer.Decode(ids); err != nil {
				t.Fatal(err)
			} else if got != want {
				t.Errorf("got %q, want %q", got, want)
			}
		}
	})

	t.Run("special", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]int32{
			"<|begin_of_text|>A B!":                                {128000, 32, 426, 0},
			"<|begin_of_text|>A<|end_of_text|>B!":                  {128000, 32, 128001, 33, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!": {128000, 32, 128001, 33, 128000, 0},
			"<|begin_of_text|>A<|end_of_text|>B<|begin_of_text|>!<|end_of_text|>": {128000, 32, 128001, 33, 128000, 0, 128001},
		}

		for s, want := range cases {
			ids, err := tokenizer.Encode(s, true)
			if err != nil {
				t.Fatal(err)
			}

			if diff := cmp.Diff(want, ids); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})

	t.Run("split", func(t *testing.T) {
		t.Parallel()

		cases := map[string][]string{
			"Hello World!":                   {"Hello", " World", "!"},
			"I'm don't won't":                {"I", "'m", " don", "'t", " won", "'t"},
			"In 2024 there are 366 days":     {"In", " ", "202", "4", " there", " are", " ", "366", " days"},
			"Hello!! ...world":               {"Hello", "!!", " ...", "world"},
			"Hello  World":                   {"Hello", " ", " World"},
			"Hello\nWorld":                   {"Hello", "\n", "World"},
			"Hello, WORLD!! How's it going?": {"Hello", ",", " WORLD", "!!", " How", "'s", " it", " going", "?"},
		}

		for s, want := range cases {
			got := slices.Collect(tokenizer.split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("no match (-theirs +ours):\n%s", diff)
			}
		}
	})
}

// tekken loads the Tekken tokenizer for testing
func tekken(t testing.TB) TextProcessor {
	t.Helper()

	// Load tokenizer config from mistral-small
	tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
	configFile, err := os.Open(tokenizerConfigPath)
	if err != nil {
		t.Fatal(err)
	}
	defer configFile.Close()

	var config struct {
		AddBosToken bool `json:"add_bos_token"`
		AddEosToken bool `json:"add_eos_token"`
		BosToken    struct {
			Content string `json:"content"`
		} `json:"bos_token"`
		EosToken struct {
			Content string `json:"content"`
		} `json:"eos_token"`
	}
	if err := json.NewDecoder(configFile).Decode(&config); err != nil {
		t.Fatal(err)
	}

	// Load tokenizer.json which contains the vocabulary and other settings
	tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
	tokenizerFile, err := os.Open(tokenizerJsonPath)
	if err != nil {
		t.Fatal(err)
	}
	defer tokenizerFile.Close()

	var tokenizerData struct {
		Model struct {
			Type   string           `json:"type"`
			Vocab  map[string]int32 `json:"vocab"`
			Merges []string         `json:"merges"`
		} `json:"model"`
		AddedTokens []struct {
			Id      int32  `json:"id"`
			Content string `json:"content"`
			Special bool   `json:"special"`
		} `json:"added_tokens"`
		PreTokenizer struct {
			Type          string `json:"type"`
			Pretokenizers []struct {
				Type    string `json:"type"`
				Pattern struct {
					String string `json:"String"`
				} `json:"pattern"`
				Behavior string `json:"behavior"`
			} `json:"pretokenizers"`
		} `json:"pre_tokenizer"`
	}
	if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
		t.Fatal(err)
	}

	// Extract the pattern from pre_tokenizer if available
	var pattern string
	if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
		pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
	}

	// Combine regular vocab and added tokens
	vocab := tokenizerData.Model.Vocab

	// Add special tokens from added_tokens
	for _, token := range tokenizerData.AddedTokens {
		vocab[token.Content] = token.Id
	}

	// Create vocabulary arrays
	maxId := int32(-1)
	for _, id := range vocab {
		if id > maxId {
			maxId = id
		}
	}

	vocabSize := int(maxId + 1)
	types := make([]uint32, vocabSize)
	tokens := make([]string, vocabSize)
	scores := make([]float32, vocabSize)

	for token, id := range vocab {
		tokens[id] = token
		types[id] = TOKEN_TYPE_NORMAL

		// Assign appropriate token types for special tokens
		if token == "<s>" {
			types[id] = TOKEN_TYPE_CONTROL
		} else if token == "</s>" {
			types[id] = TOKEN_TYPE_CONTROL
		} else if token == "[INST]" || token == "[/INST]" {
			types[id] = TOKEN_TYPE_CONTROL
		}
	}

	// In Tekken, we don't need to load merges separately as they're part of the model
	var merges []string

	// Create vocabulary object
	vocabObj := &Vocabulary{
		Values: tokens,
		Types:  types,
		Scores: scores,
		Merges: merges,
		BOS:    vocab[config.BosToken.Content],
		EOS:    vocab[config.EosToken.Content],
		AddBOS: config.AddBosToken,
		AddEOS: config.AddEosToken,
	}

	// Use pattern from tokenizer.json if available
	if pattern != "" {
		// Ensure pattern has proper escaping for Go regexp
		pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
		return NewBytePairEncoding(pattern, vocabObj)
	}

	// Fallback pattern if not found
	return NewBytePairEncoding(
		`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
		vocabObj,
	)
}

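// TestTekken covers the Tekken (Mistral) tokenizer built by tekken: whitespace is
// preserved rather than prepended, no space is expected after [INST]/[/INST] in chat
// templates, special tokens survive a round trip, varied text samples decode back to
// their inputs, and the pretokenizer split is checked. The test is skipped when the
// testdata/mistral-small fixtures are absent.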
func TestTekken(t *testing.T) {
	// Skip if the test data isn't available
	if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
		t.Skip("Mistral-small test data not available")
	}

	tokenizer := tekken(t)

	t.Run("whitespace_handling", func(t *testing.T) {
		t.Parallel()

		// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
		cases := []struct {
			input    string
			expected string
		}{
			{" hello", " hello"},
			{"hello ", "hello "},
			{"hello world", "hello world"},
			{" hello world ", " hello world "},
		}

		for _, tc := range cases {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			if decoded != tc.expected {
				t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
			}
		}
	})

	t.Run("chat_templates", func(t *testing.T) {
		t.Parallel()

		// Test the Tekken chat template format which doesn't have spaces after special tokens
		templates := []struct {
			input       string
			expectSpace bool // whether we expect a space after special tokens
		}{
			{"<s>[INST]user message[/INST]", false},
			{"<s>[INST] user message[/INST]", true},
			{"<s>[INST]user message [/INST]", true},
		}

		for _, tc := range templates {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			// Check if there's a space after special tokens
			hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")

			if hasSpaceAfterINST != tc.expectSpace {
				t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
					hasSpaceAfterINST, tc.expectSpace, tc.input)
			}
		}
	})

	t.Run("special_tokens", func(t *testing.T) {
		t.Parallel()

		// Test how Tekken handles special tokens
		cases := []struct {
			input    string
			expected []string // We'll check if these tokens are in the decoded output
		}{
			{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
			{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
			{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
		}

		for _, tc := range cases {
			ids, err := tokenizer.Encode(tc.input, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", tc.input, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
				continue
			}

			for _, expected := range tc.expected {
				if !strings.Contains(decoded, expected) {
					t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
				}
			}
		}
	})

	t.Run("vocabulary_coverage", func(t *testing.T) {
		t.Parallel()

		// Tekken has a larger vocabulary, so test coverage of various token types
		samples := []string{
			"Hello world!",
			"This is a test of the Tekken tokenizer.",
			"It has a considerably larger vocabulary size.",
			"Special characters: !@#$%^&*()",
			"Numbers: 1234567890",
			"Multiple languages: こんにちは 你好 안녕하세요",
			"Code snippets: def function(): return True",
		}

		for _, sample := range samples {
			ids, err := tokenizer.Encode(sample, false)
			if err != nil {
				t.Errorf("Failed to encode %q: %v", sample, err)
				continue
			}

			decoded, err := tokenizer.Decode(ids)
			if err != nil {
				t.Errorf("Failed to decode tokens for %q: %v", sample, err)
				continue
			}

			if decoded != sample {
				t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
			}
		}
	})

	t.Run("splitting_behavior", func(t *testing.T) {
		t.Parallel()

		// Test the splitting behavior which might differ from SentencePiece
		cases := map[string][]string{
			"Hello World!": {"Hello", " World", "!"},
			"user message": {"user", " message"},
			"[INST]hello":  {"[INST]", "hello"},
			"hello[/INST]": {"hello", "[/INST]"},
		}

		for s, want := range cases {
			got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
			if diff := cmp.Diff(want, got); diff != "" {
				t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
			}
		}
	})

	t.Run("full_chat_sequence", func(t *testing.T) {
		t.Parallel()

		// Test a complete chat sequence with Tekken's format
		chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"

		ids, err := tokenizer.Encode(chatSequence, false)
		if err != nil {
			t.Fatalf("Failed to encode chat sequence: %v", err)
		}

		decoded, err := tokenizer.Decode(ids)
		if err != nil {
			t.Fatalf("Failed to decode chat sequence tokens: %v", err)
		}

		// In Tekken, the whitespace shouldn't be added after special tokens
		if strings.Contains(decoded, "[INST] ") {
			t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
		}

		if strings.Contains(decoded, "[/INST] ") {
			t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
		}
	})
}

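// BenchmarkBytePairEncoding measures Encode, Decode, and split on prefixes of
// war-and-peace.txt whose lengths grow by powers of ten (capped at the file size).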
func BenchmarkBytePairEncoding(b *testing.B) {
	tokenizer := llama(b)
	bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
	if err != nil {
		b.Fatal(err)
	}

	for i := range 8 {
		n := min(int(math.Pow10(i)), len(bts))
		bts := bts[:n]
		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Encode(string(bts), true)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
			ids, err := tokenizer.Encode(string(bts), true)
			if err != nil {
				b.Fatal(err)
			}

			b.ResetTimer()
			for range b.N {
				_, err := tokenizer.Decode(ids)
				if err != nil {
					b.Fatal(err)
				}
			}
		})

		b.Run("split"+strconv.Itoa(n), func(b *testing.B) {
			b.ResetTimer()
			for range b.N {
				slices.Collect(tokenizer.split(string(bts)))
			}
		})
	}
}