package model import ( "cmp" "iter" "log/slog" "slices" "strings" "sync" "github.com/dlclark/regexp2" heap "github.com/emirpasic/gods/v2/trees/binaryheap" ) type Special int32 const ( SpecialBOS Special = iota SpecialEOS ) const ( TOKEN_TYPE_NORMAL = iota + 1 TOKEN_TYPE_UNKNOWN TOKEN_TYPE_CONTROL TOKEN_TYPE_USER_DEFINED TOKEN_TYPE_UNUSED TOKEN_TYPE_BYTE ) type TextProcessor interface { Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool } type Vocabulary struct { Values []string Types []uint32 Scores []float32 Merges []string BOS, EOS, EOT int32 AddBOS, AddEOS, AddEOT bool specialOnce sync.Once special []string valuesOnce sync.Once values map[string]int32 mergeOnce sync.Once merge map[string]int32 } func (v *Vocabulary) Is(id int32, special Special) bool { switch special { case SpecialBOS: return id == v.BOS case SpecialEOS: return id == v.EOS || id == v.EOT default: return false } } func (v *Vocabulary) Encode(s string) int32 { v.valuesOnce.Do(func() { v.values = make(map[string]int32, len(v.Values)) for i, value := range v.Values { v.values[value] = int32(i) } }) if id, ok := v.values[s]; ok { return id } return -1 } func (v *Vocabulary) Decode(id int32) string { return v.Values[id] } func (v *Vocabulary) SpecialVocabulary() []string { v.specialOnce.Do(func() { for i := range v.Values { if slices.Contains([]int{105, 106}, i) { v.special = append(v.special, v.Values[i]) } else if v.Types[i] == TOKEN_TYPE_CONTROL { v.special = append(v.special, v.Values[i]) } } }) return v.special } func (v *Vocabulary) Merge(left, right string) int { v.mergeOnce.Do(func() { v.merge = make(map[string]int32, len(v.Merges)) for i, merge := range v.Merges { v.merge[merge] = int32(i) } }) if id, ok := v.merge[left+" "+right]; ok { return int(id) } return -1 } type BytePairEncoding struct { pre *regexp2.Regexp vocab *Vocabulary } func NewBytePairEncoding(pre string, vocab *Vocabulary) BytePairEncoding { return BytePairEncoding{ pre: regexp2.MustCompile(pre, regexp2.Unicode|regexp2.RE2), vocab: vocab, } } func (bpe BytePairEncoding) Is(id int32, special Special) bool { return bpe.vocab.Is(id, special) } func (bpe *BytePairEncoding) split(s string) iter.Seq[string] { return func(yield func(string) bool) { for m, _ := bpe.pre.FindStringMatch(s); m != nil; m, _ = bpe.pre.FindNextMatch(m) { if !yield(m.String()) { break } } } } // fragment is a string fragment and their corresponding token IDs type fragment struct { value string ids []int32 } // pair is a pair of runes and its rank type pair struct { a, b int rank int value string } type merge struct { p, n int runes []rune } func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) { fragments := []fragment{{value: s}} for _, special := range bpe.vocab.SpecialVocabulary() { // TODO: process special tokens concurrently id := bpe.vocab.Encode(special) for i := 0; i < len(fragments); i++ { frag := fragments[i] if len(frag.ids) > 0 { continue } var middle []fragment switch i := strings.Index(frag.value, special); { case i < 0: middle = append(middle, frag) case i > 0: middle = append(middle, fragment{value: frag.value[:i]}) fallthrough default: middle = append(middle, fragment{value: special, ids: []int32{id}}) if rest := frag.value[i+len(special):]; rest != "" { middle = append(middle, fragment{value: rest}) } } fragments = append(fragments[:i], append(middle, fragments[i+1:]...)...) } } var ids []int32 for _, frag := range fragments { if len(frag.ids) > 0 { ids = append(ids, frag.ids...) continue } for split := range bpe.split(frag.value) { // TODO: process splits concurrently var sb strings.Builder for _, b := range []byte(split) { r := rune(b) switch { case r == 0x00ad: r = 0x0143 case r <= 0x0020: r = r + 0x0100 case r >= 0x007e && r <= 0x00a0: r = r + 0x00a2 } sb.WriteRune(r) } // short circuit if the fragment is in the vocabulary if id := bpe.vocab.Encode(sb.String()); id >= 0 { ids = append(ids, id) continue } runes := []rune(sb.String()) merges := make([]merge, len(runes)) for r := range runes { merges[r] = merge{ p: r - 1, n: r + 1, runes: []rune{runes[r]}, } } pairwise := func(a, b int) *pair { if a < 0 || b >= len(runes) { return nil } left, right := string(merges[a].runes), string(merges[b].runes) rank := bpe.vocab.Merge(left, right) if rank < 0 { return nil } return &pair{ a: a, b: b, rank: rank, value: left + right, } } pairs := heap.NewWith(func(i, j *pair) int { return cmp.Compare(i.rank, j.rank) }) for i := range len(runes) - 1 { if pair := pairwise(i, i+1); pair != nil { pairs.Push(pair) } } for !pairs.Empty() { pair, _ := pairs.Pop() left, right := merges[pair.a], merges[pair.b] if len(left.runes) == 0 || len(right.runes) == 0 || string(left.runes)+string(right.runes) != pair.value { continue } merges[pair.a].runes = append(left.runes, right.runes...) merges[pair.b].runes = nil merges[pair.a].n = right.n if right.n < len(merges) { merges[right.n].p = pair.a } if pair := pairwise(merges[pair.a].p, pair.a); pair != nil { pairs.Push(pair) } if pair := pairwise(pair.a, merges[pair.a].n); pair != nil { pairs.Push(pair) } } for _, merge := range merges { if len(merge.runes) > 0 { // TODO: handle the edge case where the rune isn't in the vocabulary if id := bpe.vocab.Encode(string(merge.runes)); id >= 0 { ids = append(ids, id) } } } } } if addSpecial && len(ids) > 0 { if bpe.vocab.AddBOS { if ids[0] == bpe.vocab.BOS { slog.Warn("adding bos token to prompt which already has it", "id", bpe.vocab.BOS) } slog.Debug("adding bos token to prompt", "id", bpe.vocab.BOS) ids = append([]int32{bpe.vocab.BOS}, ids...) } if bpe.vocab.AddEOS { if ids[len(ids)-1] == bpe.vocab.EOS { slog.Warn("adding eos token to prompt which already has it", "id", bpe.vocab.EOS) } slog.Debug("adding eos token to prompt", "id", bpe.vocab.EOS) ids = append(ids, bpe.vocab.EOS) } } return ids, nil } func (bpe BytePairEncoding) Decode(ids []int32) (string, error) { var sb strings.Builder for _, id := range ids { for _, r := range bpe.vocab.Decode(id) { switch { case r == 0x0100: // this produces 0x00 aka NULL continue case r == 0x0143: r = 0x00ad case r > 0x0100 && r <= 0x0120: r = r - 0x0100 case r > 0x0120 && r <= 0x0142: r = r - 0x00a2 } // NOTE: not using WriteRune here because it writes the UTF-8 // encoding of the rune which is _not_ what we want if err := sb.WriteByte(byte(r)); err != nil { return "", err } } } return sb.String(), nil }