sample: use grammar interface without model loading

This commit is contained in:
ParthSareen 2025-04-02 14:36:44 -07:00
parent 2d64c195a2
commit dde185b86d

View File

@ -5,9 +5,9 @@ import (
"math"
"math/rand/v2"
"slices"
"sync"
"github.com/ollama/ollama/llama"
"github.com/ollama/ollama/model"
)
// token represents information about a single token during sampling
@ -165,22 +165,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
type Grammar struct {
vocab *Vocab
grammar string
sampler *llama.Sampler
vocab *model.Vocabulary
grammar *llama.Grammar
}
func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
v, err := vocab.Load()
if err != nil {
return nil, err
func NewGrammar(vocab *model.Vocabulary, grammarStr string) (*Grammar, error) {
grammar := llama.InitGrammarChain(grammarStr)
for _, s := range vocab.Values {
id := vocab.Encode(s)
grammar.AddSymbol(s, uint32(id))
grammar.AddTokenPiece(uint32(id), s)
}
return &Grammar{
vocab: vocab,
grammar: grammar,
sampler: llama.NewGrammarSampler(v, grammar),
}, nil
grammar.SetEOGToken(uint32(vocab.EOS))
return &Grammar{vocab: vocab, grammar: grammar}, nil
}
func (g *Grammar) Apply(tokens []token) {
@ -189,8 +186,7 @@ func (g *Grammar) Apply(tokens []token) {
tds[i].Id = token.id
tds[i].Logit = token.value
}
g.sampler.Apply(tds)
g.grammar.Apply(tds)
for i := range tokens {
tokens[i].value = tds[i].Logit
@ -198,29 +194,5 @@ func (g *Grammar) Apply(tokens []token) {
}
func (g *Grammar) Accept(token int32) {
g.sampler.Accept(token)
}
// Vocab holds a llama vocabulary that is read from a file on first use
// (see Load) rather than when the Vocab is constructed.
type Vocab struct {
// once guards the loader in Load so it runs a single time.
once sync.Once
// vocab is the value returned by llama.LoadVocabFromFile; set by Load on success.
vocab *llama.Vocab
// err is the error recorded by Load when loading fails.
err error
// path is the file handed to llama.LoadVocabFromFile.
path string
}
// NewVocab returns a Vocab for the vocabulary file at path. Nothing is read
// here; only the path is stored, and the remaining fields keep their zero values.
func NewVocab(path string) *Vocab {
return &Vocab{path: path}
}
// Load returns the lazily-loaded vocabulary
func (v *Vocab) Load() (*llama.Vocab, error) {
v.once.Do(func() {
vocab, err := llama.LoadVocabFromFile(v.path)
if err != nil {
v.err = err
return
}
v.vocab = vocab
})
return v.vocab, v.err
g.grammar.Accept(token)
}