mirror of
https://github.com/ollama/ollama.git
synced 2025-04-02 09:00:28 +02:00
Merge c7c751647dc5c0e0012a127bf2bab1923ed41bad into 108fe021657ea1f7a012299f424d673e86512df2
This commit is contained in:
commit
5f133ddacd
@ -13,9 +13,9 @@ import (
|
||||
)
|
||||
|
||||
type Options struct {
|
||||
hiddenSize, numHeads, numKVHeads int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
hiddenSize, numHeads, numKVHeads, headDim int
|
||||
eps, ropeBase, ropeScale float32
|
||||
ropeDim uint32
|
||||
}
|
||||
|
||||
type Model struct {
|
||||
@ -37,6 +37,8 @@ func New(c ml.Config) (model.Model, error) {
|
||||
|
||||
m := Model{
|
||||
BytePairEncoding: model.NewBytePairEncoding(
|
||||
// TODO: need to set this in the conversion for mistral:
|
||||
// tokenizer.ggml.pretokenizer = [^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]*[\p{Ll}\p{Lm}\p{Lo}\p{M}]+|[^\r\n\p{L}\p{N}]?[\p{Lu}\p{Lt}\p{Lm}\p{Lo}\p{M}]+[\p{Ll}\p{Lm}\p{Lo}\p{M}]*|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n/]*|\s*[\r\n]+|\s+(?!\S)|\s+
|
||||
c.String("tokenizer.ggml.pretokenizer", `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}{1,3}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`),
|
||||
&model.Vocabulary{
|
||||
Values: c.Strings("tokenizer.ggml.tokens"),
|
||||
@ -53,6 +55,7 @@ func New(c ml.Config) (model.Model, error) {
|
||||
hiddenSize: int(c.Uint("embedding_length")),
|
||||
numHeads: int(c.Uint("attention.head_count")),
|
||||
numKVHeads: int(c.Uint("attention.head_count_kv")),
|
||||
headDim: int(c.Uint("attention.key_length")),
|
||||
eps: c.Float("attention.layer_norm_rms_epsilon"),
|
||||
ropeBase: c.Float("rope.freq_base"),
|
||||
ropeScale: c.Float("rope.freq_scale", 1),
|
||||
@ -75,24 +78,36 @@ type SelfAttention struct {
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
ropeType := uint32(0)
|
||||
// Get head dimension - use explicit value if available, otherwise calculate
|
||||
headDim := opts.headDim
|
||||
if headDim == 0 {
|
||||
headDim = opts.hiddenSize / opts.numHeads
|
||||
}
|
||||
|
||||
// Query projection and reshape
|
||||
q := sa.Query.Forward(ctx, hiddenState)
|
||||
q = q.Reshape(ctx, headDim, opts.numHeads, batchSize)
|
||||
q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
// Key projection and reshape
|
||||
k := sa.Key.Forward(ctx, hiddenState)
|
||||
k = k.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
|
||||
|
||||
// Value projection and reshape
|
||||
v := sa.Value.Forward(ctx, hiddenState)
|
||||
v = v.Reshape(ctx, headDim, opts.numKVHeads, batchSize)
|
||||
|
||||
// Attention computation
|
||||
scaleFactor := 1.0 / math.Sqrt(float64(headDim))
|
||||
kqv := nn.Attention(ctx, q, k, v, scaleFactor, cache)
|
||||
kqv = kqv.Reshape(ctx, opts.hiddenSize, batchSize)
|
||||
|
||||
// Reshape attention output for final projection
|
||||
outputDim := headDim * opts.numHeads
|
||||
kqv = kqv.Reshape(ctx, outputDim, batchSize)
|
||||
|
||||
// Apply output projection
|
||||
return sa.Output.Forward(ctx, kqv)
|
||||
}
|
||||
|
||||
|
@ -209,6 +209,326 @@ func TestLlama(t *testing.T) {
|
||||
})
|
||||
}
|
||||
|
||||
// tekken loads the Tekken tokenizer for testing
|
||||
func tekken(t testing.TB) TextProcessor {
|
||||
t.Helper()
|
||||
|
||||
// Load tokenizer config from mistral-small
|
||||
tokenizerConfigPath := filepath.Join("testdata", "mistral-small", "tokenizer_config.json")
|
||||
configFile, err := os.Open(tokenizerConfigPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer configFile.Close()
|
||||
|
||||
var config struct {
|
||||
AddBosToken bool `json:"add_bos_token"`
|
||||
AddEosToken bool `json:"add_eos_token"`
|
||||
BosToken struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"bos_token"`
|
||||
EosToken struct {
|
||||
Content string `json:"content"`
|
||||
} `json:"eos_token"`
|
||||
}
|
||||
if err := json.NewDecoder(configFile).Decode(&config); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Load tokenizer.json which contains the vocabulary and other settings
|
||||
tokenizerJsonPath := filepath.Join("testdata", "mistral-small", "tokenizer.json")
|
||||
tokenizerFile, err := os.Open(tokenizerJsonPath)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
defer tokenizerFile.Close()
|
||||
|
||||
var tokenizerData struct {
|
||||
Model struct {
|
||||
Type string `json:"type"`
|
||||
Vocab map[string]int32 `json:"vocab"`
|
||||
Merges []string `json:"merges"`
|
||||
} `json:"model"`
|
||||
AddedTokens []struct {
|
||||
Id int32 `json:"id"`
|
||||
Content string `json:"content"`
|
||||
Special bool `json:"special"`
|
||||
} `json:"added_tokens"`
|
||||
PreTokenizer struct {
|
||||
Type string `json:"type"`
|
||||
Pretokenizers []struct {
|
||||
Type string `json:"type"`
|
||||
Pattern struct {
|
||||
String string `json:"String"`
|
||||
} `json:"pattern"`
|
||||
Behavior string `json:"behavior"`
|
||||
} `json:"pretokenizers"`
|
||||
} `json:"pre_tokenizer"`
|
||||
}
|
||||
if err := json.NewDecoder(tokenizerFile).Decode(&tokenizerData); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Extract the pattern from pre_tokenizer if available
|
||||
var pattern string
|
||||
if tokenizerData.PreTokenizer.Type == "Sequence" && len(tokenizerData.PreTokenizer.Pretokenizers) > 0 {
|
||||
pattern = tokenizerData.PreTokenizer.Pretokenizers[0].Pattern.String
|
||||
}
|
||||
|
||||
// Combine regular vocab and added tokens
|
||||
vocab := tokenizerData.Model.Vocab
|
||||
|
||||
// Add special tokens from added_tokens
|
||||
for _, token := range tokenizerData.AddedTokens {
|
||||
vocab[token.Content] = token.Id
|
||||
}
|
||||
|
||||
// Create vocabulary arrays
|
||||
maxId := int32(-1)
|
||||
for _, id := range vocab {
|
||||
if id > maxId {
|
||||
maxId = id
|
||||
}
|
||||
}
|
||||
|
||||
vocabSize := int(maxId + 1)
|
||||
types := make([]uint32, vocabSize)
|
||||
tokens := make([]string, vocabSize)
|
||||
scores := make([]float32, vocabSize)
|
||||
|
||||
for token, id := range vocab {
|
||||
tokens[id] = token
|
||||
types[id] = TOKEN_TYPE_NORMAL
|
||||
|
||||
// Assign appropriate token types for special tokens
|
||||
if token == "<s>" {
|
||||
types[id] = TOKEN_TYPE_CONTROL
|
||||
} else if token == "</s>" {
|
||||
types[id] = TOKEN_TYPE_CONTROL
|
||||
} else if token == "[INST]" || token == "[/INST]" {
|
||||
types[id] = TOKEN_TYPE_CONTROL
|
||||
}
|
||||
}
|
||||
|
||||
// In Tekken, we don't need to load merges separately as they're part of the model
|
||||
var merges []string
|
||||
|
||||
// Create vocabulary object
|
||||
vocabObj := &Vocabulary{
|
||||
Values: tokens,
|
||||
Types: types,
|
||||
Scores: scores,
|
||||
Merges: merges,
|
||||
BOS: vocab[config.BosToken.Content],
|
||||
EOS: vocab[config.EosToken.Content],
|
||||
AddBOS: config.AddBosToken,
|
||||
AddEOS: config.AddEosToken,
|
||||
}
|
||||
|
||||
// Use pattern from tokenizer.json if available
|
||||
if pattern != "" {
|
||||
// Ensure pattern has proper escaping for Go regexp
|
||||
pattern = strings.ReplaceAll(pattern, "p{", "\\p{")
|
||||
return NewBytePairEncoding(pattern, vocabObj)
|
||||
}
|
||||
|
||||
// Fallback pattern if not found
|
||||
return NewBytePairEncoding(
|
||||
`\p{L}+|\p{N}+|[^\s\p{L}\p{N}]+|\s+`,
|
||||
vocabObj,
|
||||
)
|
||||
}
|
||||
|
||||
func TestTekken(t *testing.T) {
|
||||
// Skip if the test data isn't available
|
||||
if _, err := os.Stat(filepath.Join("testdata", "mistral-small")); os.IsNotExist(err) {
|
||||
t.Skip("Mistral-small test data not available")
|
||||
}
|
||||
|
||||
tokenizer := tekken(t)
|
||||
|
||||
t.Run("whitespace_handling", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// The key difference from SentencePiece is that Tekken doesn't prepend whitespace
|
||||
cases := []struct {
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{" hello", " hello"},
|
||||
{"hello ", "hello "},
|
||||
{"hello world", "hello world"},
|
||||
{" hello world ", " hello world "},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
ids, err := tokenizer.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != tc.expected {
|
||||
t.Errorf("Whitespace handling: got %q, want %q", decoded, tc.expected)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("chat_templates", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test the Tekken chat template format which doesn't have spaces after special tokens
|
||||
templates := []struct {
|
||||
input string
|
||||
expectSpace bool // whether we expect a space after special tokens
|
||||
}{
|
||||
{"<s>[INST]user message[/INST]", false},
|
||||
{"<s>[INST] user message[/INST]", true},
|
||||
{"<s>[INST]user message [/INST]", true},
|
||||
}
|
||||
|
||||
for _, tc := range templates {
|
||||
ids, err := tokenizer.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Check if there's a space after special tokens
|
||||
hasSpaceAfterINST := strings.Contains(decoded, "[INST] ")
|
||||
|
||||
if hasSpaceAfterINST != tc.expectSpace {
|
||||
t.Errorf("Chat template space handling: got space=%v, want space=%v for %q",
|
||||
hasSpaceAfterINST, tc.expectSpace, tc.input)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("special_tokens", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test how Tekken handles special tokens
|
||||
cases := []struct {
|
||||
input string
|
||||
expected []string // We'll check if these tokens are in the decoded output
|
||||
}{
|
||||
{"<s>[INST]hello[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]"}},
|
||||
{"[INST]hello[/INST]</s>", []string{"[INST]", "hello", "[/INST]", "</s>"}},
|
||||
{"<s>[INST]hello[/INST]</s>[INST]again[/INST]", []string{"<s>", "[INST]", "hello", "[/INST]", "</s>", "[INST]", "again", "[/INST]"}},
|
||||
}
|
||||
|
||||
for _, tc := range cases {
|
||||
ids, err := tokenizer.Encode(tc.input, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", tc.input, err)
|
||||
continue
|
||||
}
|
||||
|
||||
for _, expected := range tc.expected {
|
||||
if !strings.Contains(decoded, expected) {
|
||||
t.Errorf("Special token handling: %q missing in decoded output %q", expected, decoded)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("vocabulary_coverage", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Tekken has a larger vocabulary, so test coverage of various token types
|
||||
samples := []string{
|
||||
"Hello world!",
|
||||
"This is a test of the Tekken tokenizer.",
|
||||
"It has a considerably larger vocabulary size.",
|
||||
"Special characters: !@#$%^&*()",
|
||||
"Numbers: 1234567890",
|
||||
"Multiple languages: こんにちは 你好 안녕하세요",
|
||||
"Code snippets: def function(): return True",
|
||||
}
|
||||
|
||||
for _, sample := range samples {
|
||||
ids, err := tokenizer.Encode(sample, false)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to encode %q: %v", sample, err)
|
||||
continue
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Errorf("Failed to decode tokens for %q: %v", sample, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if decoded != sample {
|
||||
t.Errorf("Vocabulary coverage: got %q, want %q", decoded, sample)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("splitting_behavior", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test the splitting behavior which might differ from SentencePiece
|
||||
cases := map[string][]string{
|
||||
"Hello World!": {"Hello", " World", "!"},
|
||||
"user message": {"user", " message"},
|
||||
"[INST]hello": {"[INST]", "hello"},
|
||||
"hello[/INST]": {"hello", "[/INST]"},
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
got := slices.Collect(tokenizer.(*BytePairEncoding).split(s))
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("Splitting behavior no match (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
t.Run("full_chat_sequence", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
// Test a complete chat sequence with Tekken's format
|
||||
chatSequence := "<s>[INST]user message[/INST]assistant message</s>[INST]new user message[/INST]"
|
||||
|
||||
ids, err := tokenizer.Encode(chatSequence, false)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to encode chat sequence: %v", err)
|
||||
}
|
||||
|
||||
decoded, err := tokenizer.Decode(ids)
|
||||
if err != nil {
|
||||
t.Fatalf("Failed to decode chat sequence tokens: %v", err)
|
||||
}
|
||||
|
||||
// In Tekken, the whitespace shouldn't be added after special tokens
|
||||
if strings.Contains(decoded, "[INST] ") {
|
||||
t.Errorf("Tekken chat sequence has unexpected space after [INST]: %q", decoded)
|
||||
}
|
||||
|
||||
if strings.Contains(decoded, "[/INST] ") {
|
||||
t.Errorf("Tekken chat sequence has unexpected space after [/INST]: %q", decoded)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
tokenizer := llama(b)
|
||||
bts, err := os.ReadFile(filepath.Join("testdata", "war-and-peace.txt"))
|
||||
|
Loading…
x
Reference in New Issue
Block a user