mirror of
https://github.com/ollama/ollama.git
synced 2025-08-27 19:38:41 +02:00
model: treat 'user defined' tokens as special tokens (#11077)
This commit is contained in:
@@ -87,7 +87,7 @@ func (v *Vocabulary) Decode(id int32) string {
|
|||||||
func (v *Vocabulary) SpecialVocabulary() []string {
|
func (v *Vocabulary) SpecialVocabulary() []string {
|
||||||
v.specialOnce.Do(func() {
|
v.specialOnce.Do(func() {
|
||||||
for i := range v.Values {
|
for i := range v.Values {
|
||||||
if v.Types[i] == TOKEN_TYPE_CONTROL {
|
if v.Types[i] == TOKEN_TYPE_CONTROL || v.Types[i] == TOKEN_TYPE_USER_DEFINED {
|
||||||
v.special = append(v.special, v.Values[i])
|
v.special = append(v.special, v.Values[i])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
16
model/vocabulary_test.go
Normal file
16
model/vocabulary_test.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package model
|
||||||
|
|
||||||
|
import "testing"
|
||||||
|
|
||||||
|
func TestVocabulary_SpecialVocabulary(t *testing.T) {
|
||||||
|
vocab := &Vocabulary{
|
||||||
|
Values: []string{"<|startoftext|>", "<|endoftext|>", "<|tool_call_start|>", "<|tool_call_end|>", "hi"},
|
||||||
|
Types: []int32{TOKEN_TYPE_CONTROL, TOKEN_TYPE_CONTROL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_NORMAL},
|
||||||
|
}
|
||||||
|
|
||||||
|
specialVocab := vocab.SpecialVocabulary()
|
||||||
|
|
||||||
|
if len(specialVocab) != 4 {
|
||||||
|
t.Errorf("expected 4 special tokens, got %d", len(specialVocab))
|
||||||
|
}
|
||||||
|
}
|
Reference in New Issue
Block a user