mirror of
https://github.com/ollama/ollama.git
synced 2025-04-16 07:31:35 +02:00
173 lines
3.8 KiB
Go
173 lines
3.8 KiB
Go
package model
|
||
|
||
import (
|
||
"log/slog"
|
||
"os"
|
||
"path/filepath"
|
||
"slices"
|
||
"testing"
|
||
|
||
"google.golang.org/protobuf/proto"
|
||
|
||
"github.com/ollama/ollama/convert/sentencepiece"
|
||
)
|
||
|
||
func loadSentencePieceVocab(t *testing.T) SentencePieceModel {
|
||
t.Helper()
|
||
|
||
bts, err := os.ReadFile(filepath.Join("testdata", "gemma2", "tokenizer.model"))
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
var spm sentencepiece.ModelProto
|
||
if err := proto.Unmarshal(bts, &spm); err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
var v Vocabulary
|
||
|
||
for _, piece := range spm.GetPieces() {
|
||
v.Values = append(v.Values, piece.GetPiece())
|
||
v.Scores = append(v.Scores, piece.GetScore())
|
||
switch t := piece.GetType(); t {
|
||
case sentencepiece.ModelProto_SentencePiece_UNKNOWN,
|
||
sentencepiece.ModelProto_SentencePiece_CONTROL,
|
||
sentencepiece.ModelProto_SentencePiece_UNUSED,
|
||
sentencepiece.ModelProto_SentencePiece_BYTE:
|
||
v.Types = append(v.Types, uint32(t))
|
||
default:
|
||
tt := uint32(sentencepiece.ModelProto_SentencePiece_NORMAL)
|
||
// todo parse the special tokens file
|
||
// - this will roundtrip correctly but the <start_of_turn> and
|
||
// <end_of_turn> tokens aren't processed
|
||
v.Types = append(v.Types, tt)
|
||
}
|
||
}
|
||
|
||
return NewSentencePieceModel(&v)
|
||
}
|
||
|
||
func TestSentencePieceEncode(t *testing.T) {
|
||
logger := slog.New(slog.NewTextHandler(os.Stdout, &slog.HandlerOptions{Level: slog.LevelDebug}))
|
||
slog.SetDefault(logger)
|
||
|
||
tokenizer := loadSentencePieceVocab(t)
|
||
|
||
t.Run("basic roundtrip", func(t *testing.T) {
|
||
t.Parallel()
|
||
|
||
cases := []string{
|
||
"hello",
|
||
"hello ",
|
||
"hello ",
|
||
" hello",
|
||
" hello ",
|
||
" hello ",
|
||
"hello world",
|
||
"请考试我的软件!12345",
|
||
"你好",
|
||
"Hello 你好 world!",
|
||
"Special characters: !@#$%^&*()_+-=[]{}|;':\",./<>?",
|
||
"Multilingual: 你好 こんにちは Привет Hola مرحبا",
|
||
"Numbers and symbols: 123456789 +- */",
|
||
"Special tokens: <bos> text <eos>",
|
||
"Code snippets: func main() { fmt.Println(\"Hello World\") }",
|
||
"Long text: " + "Lorem ipsum dolor sit amet, consectetur adipiscing elit. " +
|
||
"Sed do eiusmod tempor incididunt ut labore et dolore magna aliqua. " +
|
||
"Ut enim ad minim veniam, quis nostrud exercitation ullamco laboris.",
|
||
}
|
||
|
||
for _, want := range cases {
|
||
ids, err := tokenizer.Encode(want, true)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
|
||
if got, err := tokenizer.Decode(ids); err != nil {
|
||
t.Fatal(err)
|
||
} else if got != want {
|
||
t.Errorf("got %q, want %q [%#v]", got, want, ids)
|
||
}
|
||
}
|
||
})
|
||
|
||
t.Run("special tokens", func(t *testing.T) {
|
||
type candidate struct {
|
||
token string
|
||
ids []int32
|
||
}
|
||
|
||
cases := []candidate{
|
||
{"<bos>", []int32{2}},
|
||
{"<eos>", []int32{1}},
|
||
}
|
||
|
||
for _, want := range cases {
|
||
ids, err := tokenizer.Encode(want.token, true)
|
||
if err != nil {
|
||
t.Fatal(err)
|
||
}
|
||
if !slices.Equal(ids, want.ids) {
|
||
t.Errorf("got %#v, want %#v", ids, want.ids)
|
||
}
|
||
}
|
||
})
|
||
}
|
||
|
||
func TestSentencePieceModelDecodeByteTokens(t *testing.T) {
|
||
vocab := &Vocabulary{
|
||
Values: []string{
|
||
"normal",
|
||
"<0xEA>",
|
||
"<0x41>",
|
||
"<0xC3>",
|
||
"<0xA3>",
|
||
},
|
||
Types: []uint32{
|
||
TOKEN_TYPE_NORMAL,
|
||
TOKEN_TYPE_BYTE,
|
||
TOKEN_TYPE_BYTE,
|
||
TOKEN_TYPE_BYTE,
|
||
TOKEN_TYPE_BYTE,
|
||
},
|
||
Scores: []float32{0, 0, 0, 0, 0},
|
||
}
|
||
|
||
spm := NewSentencePieceModel(vocab)
|
||
|
||
tests := []struct {
|
||
name string
|
||
ids []int32
|
||
expected string
|
||
}{
|
||
{
|
||
name: "single byte token",
|
||
ids: []int32{1},
|
||
expected: "\xea",
|
||
},
|
||
{
|
||
name: "ASCII byte token",
|
||
ids: []int32{2},
|
||
expected: "A",
|
||
},
|
||
{
|
||
name: "multiple byte tokens forming UTF-8 character",
|
||
ids: []int32{3, 4},
|
||
expected: "ã",
|
||
},
|
||
}
|
||
|
||
for _, tt := range tests {
|
||
t.Run(tt.name, func(t *testing.T) {
|
||
result, err := spm.Decode(tt.ids)
|
||
if err != nil {
|
||
t.Errorf("failed to decode token IDs %v: %v", tt.ids, err)
|
||
}
|
||
if result != tt.expected {
|
||
t.Errorf("got %q, want %q", result, tt.expected)
|
||
}
|
||
})
|
||
}
|
||
}
|