mirror of https://github.com/ollama/ollama.git
sample: use grammar interface without model loading
parent 2d64c195a2
commit dde185b86d
@@ -5,9 +5,9 @@ import (
     "math"
     "math/rand/v2"
     "slices"
-    "sync"
 
     "github.com/ollama/ollama/llama"
+    "github.com/ollama/ollama/model"
 )
 
 // token represents information about a single token during sampling
@@ -165,22 +165,19 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
 }
 
 type Grammar struct {
-    vocab   *Vocab
-    grammar string
-    sampler *llama.Sampler
+    vocab   *model.Vocabulary
+    grammar *llama.Grammar
 }
 
-func NewGrammar(vocab *Vocab, grammar string) (*Grammar, error) {
-    v, err := vocab.Load()
-    if err != nil {
-        return nil, err
+func NewGrammar(vocab *model.Vocabulary, grammarStr string) (*Grammar, error) {
+    grammar := llama.InitGrammarChain(grammarStr)
+    for _, s := range vocab.Values {
+        id := vocab.Encode(s)
+        grammar.AddSymbol(s, uint32(id))
+        grammar.AddTokenPiece(uint32(id), s)
     }
-
-    return &Grammar{
-        vocab:   vocab,
-        grammar: grammar,
-        sampler: llama.NewGrammarSampler(v, grammar),
-    }, nil
+    grammar.SetEOGToken(uint32(vocab.EOS))
+    return &Grammar{vocab: vocab, grammar: grammar}, nil
 }
 
 func (g *Grammar) Apply(tokens []token) {
@@ -189,8 +186,7 @@ func (g *Grammar) Apply(tokens []token) {
         tds[i].Id = token.id
         tds[i].Logit = token.value
     }
-
-    g.sampler.Apply(tds)
+    g.grammar.Apply(tds)
 
     for i := range tokens {
         tokens[i].value = tds[i].Logit
@@ -198,29 +194,5 @@ func (g *Grammar) Apply(tokens []token) {
 }
 
 func (g *Grammar) Accept(token int32) {
-    g.sampler.Accept(token)
-}
-
-type Vocab struct {
-    once  sync.Once
-    vocab *llama.Vocab
-    err   error
-    path  string
-}
-
-func NewVocab(path string) *Vocab {
-    return &Vocab{path: path}
-}
-
-// Load returns the lazily-loaded vocabulary
-func (v *Vocab) Load() (*llama.Vocab, error) {
-    v.once.Do(func() {
-        vocab, err := llama.LoadVocabFromFile(v.path)
-        if err != nil {
-            v.err = err
-            return
-        }
-        v.vocab = vocab
-    })
-    return v.vocab, v.err
+    g.grammar.Accept(token)
 }
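
For readers skimming the diff, here is a minimal sketch of how the reworked Grammar type might be driven from a sampling loop. Only NewGrammar, Apply, Accept, and the token fields id/value come from the change above; the helper name sampleWithGrammar, the greedy pick, and the placement in the sample package are illustrative assumptions, not part of this commit.

package sample

import "github.com/ollama/ollama/model"

// sampleWithGrammar is a hypothetical helper (not in this commit) showing the
// intended flow: build a Grammar directly from an in-memory vocabulary and a
// grammar string, constrain the logits, pick a token, then advance the
// grammar state.
func sampleWithGrammar(vocab *model.Vocabulary, grammarStr string, logits []float32) (int32, error) {
    // No vocabulary file or model weights need to be loaded at this point.
    g, err := NewGrammar(vocab, grammarStr)
    if err != nil {
        return -1, err
    }

    // Wrap the raw logits in the package's token representation (id + value).
    tokens := make([]token, len(logits))
    for i, l := range logits {
        tokens[i] = token{id: int32(i), value: l}
    }

    // Apply rewrites token values so that pieces the grammar cannot accept
    // are suppressed.
    g.Apply(tokens)

    // Greedy pick over the constrained logits; the real sampler would apply
    // temperature/top-k/top-p here instead.
    best := tokens[0]
    for _, t := range tokens[1:] {
        if t.value > best.value {
            best = t
        }
    }

    // Report the chosen token back so the grammar can advance its state.
    g.Accept(best.id)
    return best.id, nil
}

The design point of the change: the removed Vocab/Load machinery went through llama.LoadVocabFromFile, so the old grammar path could not be exercised without a model file on disk, whereas the new interface builds the grammar from an already-available *model.Vocabulary, which is what the commit title means by using the grammar interface without model loading.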