mirror of
https://github.com/ollama/ollama.git
synced 2025-07-12 17:22:59 +02:00
model: Don't unconditionally add special tokens
We sometimes tokenize partial strings. For example, with multimodal inputs, we split the input string around the images and then tokenize each piece. In these cases, we should only add the special tokens on the first piece.
This commit is contained in:
@ -74,7 +74,7 @@ func TestLlama(t *testing.T) {
|
||||
t.Run("simple", func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ids, err := tokenizer.Encode("hello world")
|
||||
ids, err := tokenizer.Encode("hello world", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -92,7 +92,7 @@ func TestLlama(t *testing.T) {
|
||||
t.Errorf("got %q, want hello world", s)
|
||||
}
|
||||
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>")
|
||||
ids, err = tokenizer.Encode("hello <|end_of_text|>", true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -126,7 +126,7 @@ func TestLlama(t *testing.T) {
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s)
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -152,7 +152,7 @@ func TestLlama(t *testing.T) {
|
||||
}
|
||||
|
||||
for _, want := range cases {
|
||||
ids, err := tokenizer.Encode(want)
|
||||
ids, err := tokenizer.Encode(want, true)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
}
|
||||
@ -176,7 +176,7 @@ func TestLlama(t *testing.T) {
|
||||
}
|
||||
|
||||
for s, want := range cases {
|
||||
ids, err := tokenizer.Encode(s)
|
||||
ids, err := tokenizer.Encode(s, true)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
@ -222,7 +222,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
b.Run("encode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_, err := tokenizer.Encode(string(bts))
|
||||
_, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
@ -230,7 +230,7 @@ func BenchmarkBytePairEncoding(b *testing.B) {
|
||||
})
|
||||
|
||||
b.Run("decode"+strconv.Itoa(n), func(b *testing.B) {
|
||||
ids, err := tokenizer.Encode(string(bts))
|
||||
ids, err := tokenizer.Encode(string(bts), true)
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
Reference in New Issue
Block a user