diff --git a/model/bytepairencoding.go b/model/bytepairencoding.go
index 7ade497daa..e4083dfceb 100644
--- a/model/bytepairencoding.go
+++ b/model/bytepairencoding.go
@@ -109,7 +109,9 @@ func (bpe BytePairEncoding) Encode(s string, addSpecial bool) ([]int32, error) {
 					r = 0x0143
 				case r <= 0x0020:
 					r = r + 0x0100
-				case r >= 0x007e && r <= 0x00a0:
+				case r >= 0x007f && r <= 0x00a0:
+					// The range starts at 0x7f, not 0x7e: 0x7e + 0x00a2 is 0x0120,
+					// the same rune the branch above produces for 0x20.
 					r = r + 0x00a2
 				}
 
diff --git a/model/bytepairencoding_test.go b/model/bytepairencoding_test.go
index 7e310b56e5..71947be993 100644
--- a/model/bytepairencoding_test.go
+++ b/model/bytepairencoding_test.go
@@ -207,6 +207,42 @@ func TestLlama(t *testing.T) {
 			}
 		}
 	})
+
+	// Encode followed by Decode must reproduce each rune in 0x00-0xFF
+	t.Run("roundtriping 0x00-0xFF", func(t *testing.T) {
+		t.Parallel()
+
+		for b := 0x00; b <= 0xFF; b++ {
+			// string(rune(b)) is the UTF-8 encoding of code point b, so
+			// values from 0x80 upward are passed to Encode as two bytes.
+			input := string(rune(b))
+			ids, err := tokenizer.Encode(input, false)
+			if err != nil {
+				t.Errorf("failed to encode rune 0x%02X: %v", b, err)
+				continue
+			}
+
+			decoded, err := tokenizer.Decode(ids)
+			if err != nil {
+				t.Errorf("failed to decode rune 0x%02X: %v", b, err)
+				continue
+			}
+
+			// 0x00 is the one exception: it is expected to decode to an
+			// empty string rather than to itself.
+			if b == 0x00 {
+				if len(decoded) != 0 {
+					// Report the decoded text that the check inspects; ids are kept for context.
+					t.Errorf("Decode(Encode(0x00)) should be empty, got %q (ids %v)", decoded, ids)
+				}
+				continue
+			}
+
+			if decoded != input {
+				t.Errorf("rune 0x%02X failed roundtrip: got %q, want %q", b, decoded, input)
+			}
+		}
+	})
 }
 
 func BenchmarkBytePairEncoding(b *testing.B) {