From b70fc4d51e76fc023afcd005c467d415c0c62750 Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Wed, 5 Mar 2025 13:27:53 -0800
Subject: [PATCH] model: Don't unconditionally add special tokens

We sometimes tokenize partial strings. For example, with multimodal
inputs, we split the input string around the images and then tokenize
each piece. In these cases, we should only add the special tokens on
the first piece.
---
 llm/server.go                 |  2 +-
 model/process_text.go         |  6 +++---
 model/process_text_test.go    | 14 +++++++-------
 runner/ollamarunner/runner.go |  2 +-
 4 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/llm/server.go b/llm/server.go
index 09690a5ff..9553ba8f0 100644
--- a/llm/server.go
+++ b/llm/server.go
@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
 		return s.llamaModel.Tokenize(content, false, true)
 	}
 	if s.textProcessor != nil {
-		tokens, err := s.textProcessor.Encode(content)
+		tokens, err := s.textProcessor.Encode(content, false)
 		if err != nil {
 			return nil, err
 		}
diff --git a/model/process_text.go b/model/process_text.go
index 7083f36fd..bfb0a5f20 100644
--- a/model/process_text.go
+++ b/model/process_text.go
@@ -19,7 +19,7 @@ const (
 )
 
 type TextProcessor interface {
-	Encode(string) ([]int32, error)
+	Encode(s string, addSpecial bool) ([]int32, error)
 	Decode([]int32) (string, error)
 	Is(int32, Special) bool
 }
@@ -144,7 +144,7 @@ type merge struct {
 	runes []rune
 }
 
-func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
+func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
 	fragments := []fragment{{value: s}}
 	for _, special := range bpe.vocab.SpecialVocabulary() {
 		// TODO: process special tokens concurrently
@@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
 		}
 	}
 
-	if len(ids) > 0 {
+	if addSpecial && len(ids) > 0 {
 		if bpe.vocab.AddBOS {
 			if ids[0] == bpe.vocab.BOS {
 				slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
diff --git a/model/process_text_test.go b/model/process_text_test.go
index cad1f94ff..f48303212 100644
--- a/model/process_text_test.go
+++ b/model/process_text_test.go
@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
 	t.Run("simple", func(t *testing.T) {
 		t.Parallel()
 
-		ids, err := tokenizer.Encode("hello world")
+		ids, err := tokenizer.Encode("hello world", true)
 		if err != nil {
 			t.Error(err)
 		}
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
 			t.Errorf("got %q, want hello world", s)
 		}
 
-		ids, err = tokenizer.Encode("hello <|end_of_text|>")
+		ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
 		if err != nil {
 			t.Error(err)
 		}
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
 		}
 
 		for s, want := range cases {
-			ids, err := tokenizer.Encode(s)
+			ids, err := tokenizer.Encode(s, true)
 			if err != nil {
 				t.Error(err)
 			}
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
 		}
 
 		for _, want := range cases {
-			ids, err := tokenizer.Encode(want)
+			ids, err := tokenizer.Encode(want, true)
 			if err != nil {
 				t.Error(err)
 			}
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
 		}
 
 		for s, want := range cases {
-			ids, err := tokenizer.Encode(s)
+			ids, err := tokenizer.Encode(s, true)
 			if err != nil {
 				t.Fatal(err)
 			}
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
 		b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
 			b.ResetTimer()
 			for range b.N {
-				_, err := tokenizer.Encode(string(bts))
+				_, err := tokenizer.Encode(string(bts), true)
 				if err != nil {
 					b.Fatal(err)
 				}
@@ -230,7 +230,7 @@
 		})
 
 		b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
-			ids, err := tokenizer.Encode(string(bts))
+			ids, err := tokenizer.Encode(string(bts), true)
 			if err != nil {
 				b.Fatal(err)
 			}
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index 1a4bbf19e..9ba6563f0 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
 
 	for i, part := range parts {
 		// text - tokenize
-		tokens, err := s.model.(model.TextProcessor).Encode(part)
+		tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
 		if err != nil {
 			return nil, err
 		}
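Note (not part of the patch): the calling pattern this new addSpecial flag supports
looks roughly like the sketch below. It is an illustrative example only; the
tokenizeParts helper and the example package are hypothetical stand-ins for the
multimodal splitting logic in runner.go, where the prompt is split around image
placeholders and only the first text piece should receive BOS/EOS tokens.

// Package example is an illustrative sketch, not code from this patch.
package example

import "fmt"

// TextProcessor mirrors the interface as changed by this patch.
type TextProcessor interface {
	Encode(s string, addSpecial bool) ([]int32, error)
	Decode([]int32) (string, error)
}

// tokenizeParts is a hypothetical helper: parts is a prompt already split
// around image placeholders. Only the first piece is encoded with special
// tokens; later pieces are continuations of the same prompt, mirroring
// Encode(part, i == 0) in runner.go above.
func tokenizeParts(tp TextProcessor, parts []string) ([][]int32, error) {
	var out [][]int32
	for i, part := range parts {
		tokens, err := tp.Encode(part, i == 0)
		if err != nil {
			return nil, fmt.Errorf("encoding part %d: %w", i, err)
		}
		out = append(out, tokens)
	}
	return out, nil
}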