prototyping

This commit is contained in:
ParthSareen
2025-03-25 15:00:14 -07:00
parent 1fd9967558
commit 5ec6bb52a0
11 changed files with 1647 additions and 13 deletions

View File

@@ -17,12 +17,14 @@ type token struct {
}
type Sampler struct {
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
rng *rand.Rand
topK int
topP float32
minP float32
temperature float32
grammar *Grammar
JSONSampler *JSONSampler
PythonSampler *PythonSampler
}
func (s *Sampler) Sample(logits []float32) (int32, error) {
@@ -30,6 +32,19 @@ func (s *Sampler) Sample(logits []float32) (int32, error) {
return -1, errors.New("sample: no logits provided to sample")
}
var err error
if s.JSONSampler != nil {
logits, err = s.JSONSampler.Apply(logits)
if err != nil {
return -1, err
}
}
if s.PythonSampler != nil {
logits, err = s.PythonSampler.ApplyMask(logits)
if err != nil {
return -1, err
}
}
tokens := make([]token, len(logits))
for i := range logits {
tokens[i].id = int32(i)
@@ -127,7 +142,7 @@ func (s *Sampler) sample(tokens []token) (token, error) {
}
// TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler {
func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler {
var rng *rand.Rand
if seed != -1 {
// PCG requires two parameters: sequence and stream
@@ -155,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed
}
return Sampler{
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
rng: rng,
topK: topK,
topP: topP,
minP: minP,
temperature: temperature,
grammar: grammar,
JSONSampler: jsonSampler,
PythonSampler: pythonSampler,
}
}