mirror of
https://github.com/ollama/ollama.git
synced 2025-04-16 15:41:20 +02:00
247 lines
6.0 KiB
Go
247 lines
6.0 KiB
Go
package model
|
|
|
|
import (
|
|
"iter"
|
|
"log/slog"
|
|
"strings"
|
|
|
|
"github.com/dlclark/regexp2"
|
|
queue "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
|
)
|
|
|
|
// spmWhitespaceSep is the SentencePiece meta symbol (U+2581, "▁") that
// stands in for an ASCII space inside vocabulary pieces.
const spmWhitespaceSep = "▁"

// spmSpaceReplacer rewrites ASCII spaces as spmWhitespaceSep. A Replacer is
// safe for concurrent use, so a single shared instance is enough.
var spmSpaceReplacer = strings.NewReplacer(" ", spmWhitespaceSep)

// replaceWhitespaceBySeperator returns a copy of s in which every ASCII space
// (' ') is replaced by spmWhitespaceSep. Other whitespace (tabs, newlines, …)
// is left as-is.
func replaceWhitespaceBySeperator(s string) string {
	return spmSpaceReplacer.Replace(s)
}
|
|
|
|
type SentencePieceModel struct {
|
|
maxTokenLen int
|
|
pre *regexp2.Regexp
|
|
vocab *Vocabulary
|
|
}
|
|
|
|
var _ TextProcessor = (*SentencePieceModel)(nil)
|
|
|
|
func NewSentencePieceModel(pre string, vocab *Vocabulary) SentencePieceModel {
|
|
slog.Debug("Tokens", "num tokens", len(vocab.Values), "vals", vocab.Values[:5], "scores", vocab.Scores[:5], "types", vocab.Types[:5])
|
|
|
|
counter := map[int]int{}
|
|
var maxTokenLen int
|
|
for cnt := range vocab.Types {
|
|
switch vocab.Types[cnt] {
|
|
case TOKEN_TYPE_NORMAL, TOKEN_TYPE_USER_DEFINED, TOKEN_TYPE_UNUSED:
|
|
maxTokenLen = max(maxTokenLen, len(vocab.Values[cnt]))
|
|
fallthrough
|
|
default:
|
|
counter[int(vocab.Types[cnt])] += 1
|
|
}
|
|
}
|
|
|
|
slog.Debug("Token counts", "normal", counter[TOKEN_TYPE_NORMAL], "unknown", counter[TOKEN_TYPE_UNKNOWN], "control", counter[TOKEN_TYPE_CONTROL],
|
|
"user defined", counter[TOKEN_TYPE_USER_DEFINED], "unused", counter[TOKEN_TYPE_UNUSED], "byte", counter[TOKEN_TYPE_BYTE],
|
|
"max token len", maxTokenLen)
|
|
|
|
return SentencePieceModel{
|
|
maxTokenLen: maxTokenLen,
|
|
pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2),
|
|
vocab: vocab,
|
|
}
|
|
}
|
|
|
|
func (spm SentencePieceModel) Is(id int32, special Special) bool {
|
|
return spm.vocab.Is(id, special)
|
|
}
|
|
|
|
func (spm *SentencePieceModel) split(s string) iter.Seq[string] {
|
|
return func(yield func(string) bool) {
|
|
for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) {
|
|
if !yield(m.String()) {
|
|
break
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
func (spm SentencePieceModel) Encode(s string, addSpecial bool) ([]int32, error) {
|
|
fragments := []fragment{{value: s}}
|
|
for _, special := range spm.vocab.SpecialVocabulary() {
|
|
// TODO: process special tokens concurrently
|
|
id := spm.vocab.Encode(special)
|
|
for i := 0; i < len(fragments); i++ {
|
|
frag := fragments[i]
|
|
if len(frag.ids) > 0 {
|
|
continue
|
|
}
|
|
|
|
var middle []fragment
|
|
switch i := strings.Index(frag.value, special); {
|
|
case i < 0:
|
|
middle = append(middle, frag)
|
|
case i > 0:
|
|
middle = append(middle, fragment{value: frag.value[:i]})
|
|
fallthrough
|
|
default:
|
|
middle = append(middle, fragment{value: special, ids: []int32{id}})
|
|
if rest := frag.value[i+len(special):]; rest != "" {
|
|
middle = append(middle, fragment{value: rest})
|
|
}
|
|
}
|
|
|
|
fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...)
|
|
}
|
|
}
|
|
slog.Debug("fragments", "frags", fragments)
|
|
|
|
var ids []int32
|
|
for _, frag := range fragments {
|
|
if len(frag.ids) > 0 {
|
|
ids = append(ids, frag.ids...)
|
|
continue
|
|
}
|
|
|
|
for split := range spm.split(frag.value) {
|
|
split = replaceWhitespaceBySeperator(split)
|
|
|
|
var sb strings.Builder
|
|
sb.Write([]byte(split))
|
|
if id := spm.vocab.Encode(sb.String()); id >= 0 {
|
|
ids = append(ids, id)
|
|
continue
|
|
}
|
|
|
|
runes := []rune(sb.String())
|
|
pq := queue.NewWith(func(a, b any) int {
|
|
priA := a.(*candidate)
|
|
priB := b.(*candidate)
|
|
if priA.score > priB.score || (priA.score == priB.score && priA.a < priB.a) {
|
|
return -1
|
|
}
|
|
return 1
|
|
})
|
|
|
|
merges := make([]merge, len(runes))
|
|
for r := range runes {
|
|
merges[r] = merge{
|
|
p: r - 1,
|
|
n: r + 1,
|
|
runes: []rune{runes[r]},
|
|
}
|
|
}
|
|
|
|
slog.Debug("tokenizer", "merges", merges)
|
|
|
|
pairwise := func(a, b int) *candidate {
|
|
if a < 0 || b >= len(runes) {
|
|
return nil
|
|
}
|
|
|
|
left, right := string(merges[a].runes), string(merges[b].runes)
|
|
if id := spm.vocab.Encode(left + right); id >= 0 {
|
|
return &candidate{
|
|
a: a,
|
|
b: b,
|
|
score: spm.vocab.Scores[id],
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
for i := range len(runes) - 1 {
|
|
if pair := pairwise(i, i+1); pair != nil {
|
|
pq.Enqueue(pair)
|
|
}
|
|
}
|
|
|
|
pqv := pq.Values()
|
|
for _, v := range pqv {
|
|
e := v.(*candidate)
|
|
slog.Debug("candidate", "candidate", e)
|
|
}
|
|
|
|
for !pq.Empty() {
|
|
v, _ := pq.Dequeue()
|
|
pair := v.(*candidate)
|
|
left, right := merges[pair.a], merges[pair.b]
|
|
|
|
slog.Debug("pair", "left", left, "right", right)
|
|
if len(left.runes) == 0 || len(right.runes) == 0 {
|
|
continue
|
|
}
|
|
|
|
if id := spm.vocab.Encode(string(left.runes) + string(right.runes)); id < 0 {
|
|
continue
|
|
}
|
|
|
|
merges[pair.a].runes = append(left.runes, right.runes...)
|
|
merges[pair.b].runes = nil
|
|
merges[pair.a].n = right.n
|
|
if right.n < len(merges) {
|
|
merges[right.n].p = pair.a
|
|
}
|
|
|
|
if pair := pairwise(merges[pair.a].p, pair.a); pair != nil {
|
|
pq.Enqueue(pair)
|
|
}
|
|
|
|
if pair := pairwise(pair.a, merges[pair.a].n); pair != nil {
|
|
pq.Enqueue(pair)
|
|
}
|
|
}
|
|
|
|
slog.Debug("merges", "merges", merges)
|
|
|
|
for _, merge := range merges {
|
|
if len(merge.runes) > 0 {
|
|
if id := spm.vocab.Encode(string(merge.runes)); id >= 0 {
|
|
ids = append(ids, id)
|
|
} else {
|
|
slog.Debug("missing token", "token", string(merge.runes))
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
if addSpecial && len(ids) > 0 {
|
|
if spm.vocab.AddBOS {
|
|
if ids[0] == spm.vocab.BOS {
|
|
slog.Warn("adding bos token to prompt which already has it", "id", spm.vocab.BOS)
|
|
}
|
|
|
|
slog.Debug("adding bos token to prompt", "id", spm.vocab.BOS)
|
|
ids = append([]int32{spm.vocab.BOS}, ids...)
|
|
}
|
|
|
|
if spm.vocab.AddEOS {
|
|
if ids[len(ids)-1] == spm.vocab.EOS {
|
|
slog.Warn("adding eos token to prompt which already has it", "id", spm.vocab.EOS)
|
|
}
|
|
|
|
slog.Debug("adding eos token to prompt", "id", spm.vocab.EOS)
|
|
ids = append(ids, spm.vocab.EOS)
|
|
}
|
|
}
|
|
|
|
return ids, nil
|
|
}
|
|
|
|
type candidate struct {
|
|
a, b int
|
|
score float32
|
|
}
|
|
|
|
func (spm SentencePieceModel) Decode(ids []int32) (string, error) {
|
|
var sb strings.Builder
|
|
for _, id := range ids {
|
|
data := spm.vocab.Decode(id)
|
|
data = strings.ReplaceAll(data, spmWhitespaceSep, " ")
|
|
if _, err := sb.WriteString(data); err != nil {
|
|
return "", err
|
|
}
|
|
}
|
|
|
|
slog.Debug("decoded", "ids", ids, "text", sb.String())
|
|
return sb.String(), nil
|
|
}
|