diff --git a/sample/samplers.go b/sample/samplers.go index ef8033691..8c0690f9d 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -5,9 +5,9 @@ import ( "math" "math/rand/v2" "slices" - "sync" "github.com/ollama/ollama/llama" + "github.com/ollama/ollama/model" ) // token represents information about a single token during sampling @@ -165,22 +165,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed } type Grammar struct { - vocab *Vocab - grammar string - sampler *llama.Sampler + vocab *model.Vocabulary + grammar *llama.Grammar } -func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) { - v, err := vocab.Load() - if err != nil { - return nil, err +func NewGrammar(vocab *model.Vocabulary, grammarStr string) (*Grammar, error) { + grammar := llama.InitGrammarChain(grammarStr) + for _, s := range vocab.Values { + id := vocab.Encode(s) + grammar.AddSymbol(s, uint32(id)) + grammar.AddTokenPiece(uint32(id), s) } - - return &Grammar{ - vocab: vocab, - grammar: grammar, - sampler: llama.NewGrammarSampler(v, grammar), - }, nil + grammar.SetEOGToken(uint32(vocab.EOS)) + return &Grammar{vocab: vocab, grammar: grammar}, nil } func (g *Grammar) Apply(tokens []token) { @@ -189,8 +186,7 @@ func (g *Grammar) Apply(tokens []token) { tds[i].Id = token.id tds[i].Logit = token.value } - - g.sampler.Apply(tds) + g.grammar.Apply(tds) for i := range tokens { tokens[i].value = tds[i].Logit @@ -198,29 +194,5 @@ func (g *Grammar) Apply(tokens []token) { } func (g *Grammar) Accept(token int32) { - g.sampler.Accept(token) -} - -type Vocab struct { - once sync.Once - vocab *llama.Vocab - err error - path string -} - -func NewVocab(path string) *Vocab { - return &Vocab{path: path} -} - -// Load returns the lazily-loaded vocabulary -func (v *Vocab) Load() (*llama.Vocab, error) { - v.once.Do(func() { - vocab, err := llama.LoadVocabFromFile(v.path) - if err != nil { - v.err = err - return - } - v.vocab = vocab - }) - return v.vocab, v.err + g.grammar.Accept(token) }