mirror of
https://github.com/ollama/ollama.git
synced 2025-07-25 21:32:38 +02:00
model: Don't unconditionally add special tokens
We sometimes tokenize partial strings. For example, with multimodal inputs, we split the input string around the images and then tokenize each piece. In these cases, we should only add the special tokens on the first piece.
This commit is contained in:
@@ -973,7 +973,7 @@ func (s *llmServer) Tokenize(ctx context.Context, content string) ([]int, error)
|
|||||||
return s.llamaModel.Tokenize(content, false, true)
|
return s.llamaModel.Tokenize(content, false, true)
|
||||||
}
|
}
|
||||||
if s.textProcessor != nil {
|
if s.textProcessor != nil {
|
||||||
tokens, err := s.textProcessor.Encode(content)
|
tokens, err := s.textProcessor.Encode(content, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
@@ -19,7 +19,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type TextProcessor interface {
|
type TextProcessor interface {
|
||||||
Encode(string) ([]int32, error)
|
Encode(s string, addSpecial bool) ([]int32, error)
|
||||||
Decode([]int32) (string, error)
|
Decode([]int32) (string, error)
|
||||||
Is(int32, Special) bool
|
Is(int32, Special) bool
|
||||||
}
|
}
|
||||||
@@ -144,7 +144,7 @@ type merge struct {
|
|||||||
runes []rune
|
runes []rune
|
||||||
}
|
}
|
||||||
|
|
||||||
func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
|
||||||
fragments := []fragment{{value: s}}
|
fragments := []fragment{{value: s}}
|
||||||
for _, special := range bpe.vocab.SpecialVocabulary() {
|
for _, special := range bpe.vocab.SpecialVocabulary() {
|
||||||
// TODO: process special tokens concurrently
|
// TODO: process special tokens concurrently
|
||||||
@@ -282,7 +282,7 @@ func (bpe BytePairEncoding) Encode(s string) ([]int32, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ids) > 0 {
|
if addSpecial && len(ids) > 0 {
|
||||||
if bpe.vocab.AddBOS {
|
if bpe.vocab.AddBOS {
|
||||||
if ids[0] == bpe.vocab.BOS {
|
if ids[0] == bpe.vocab.BOS {
|
||||||
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
|
slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS)
|
||||||
|
@@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
|
|||||||
t.Run("simple", func(t *testing.T) {
|
t.Run("simple", func(t *testing.T) {
|
||||||
t.Parallel()
|
t.Parallel()
|
||||||
|
|
||||||
ids, err := tokenizer.Encode("hello world")
|
ids, err := tokenizer.Encode("hello world", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
|
|||||||
t.Errorf("got %q, want hello world", s)
|
t.Errorf("got %q, want hello world", s)
|
||||||
}
|
}
|
||||||
|
|
||||||
ids, err = tokenizer.Encode("hello <|end_of_text|>")
|
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for s, want := range cases {
|
for s, want := range cases {
|
||||||
ids, err := tokenizer.Encode(s)
|
ids, err := tokenizer.Encode(s, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, want := range cases {
|
for _, want := range cases {
|
||||||
ids, err := tokenizer.Encode(want)
|
ids, err := tokenizer.Encode(want, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Error(err)
|
t.Error(err)
|
||||||
}
|
}
|
||||||
@@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for s, want := range cases {
|
for s, want := range cases {
|
||||||
ids, err := tokenizer.Encode(s)
|
ids, err := tokenizer.Encode(s, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|||||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for range b.N {
|
for range b.N {
|
||||||
_, err := tokenizer.Encode(string(bts))
|
_, err := tokenizer.Encode(string(bts), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||||
ids, err := tokenizer.Encode(string(bts))
|
ids, err := tokenizer.Encode(string(bts), true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
@@ -161,7 +161,7 @@ func (s *Server) inputs(prompt string, images []ImageData) ([]input, error) {
|
|||||||
|
|
||||||
for i, part := range parts {
|
for i, part := range parts {
|
||||||
// text - tokenize
|
// text - tokenize
|
||||||
tokens, err := s.model.(model.TextProcessor).Encode(part)
|
tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
Reference in New Issue
Block a user