mirror of
https://github.com/ollama/ollama.git
synced 2025-07-21 07:03:00 +02:00
300 lines
8.5 KiB
Go
300 lines
8.5 KiB
Go
package sample
|
|
|
|
import (
|
|
"fmt"
|
|
"log/slog"
|
|
"runtime"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/grammar/jsonschema"
|
|
"github.com/ollama/ollama/model"
|
|
)
|
|
|
|
type JSONSampler struct {
|
|
schema *jsonschema.Schema
|
|
propIdx int
|
|
propToNodeMap map[string]*PDA
|
|
pdaSampler *PushdownSampler
|
|
decodedToks []string
|
|
}
|
|
|
|
func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) {
|
|
slog.Info("NewJSONSampler", "schema", schema)
|
|
if proc == nil {
|
|
return nil, fmt.Errorf("TextProcessor cannot be nil")
|
|
}
|
|
|
|
pdaSampler, err := NewPushdownSampler(proc)
|
|
if err != nil {
|
|
return nil, fmt.Errorf("failed to create PushdownSampler: %w", err)
|
|
}
|
|
|
|
if schema == nil {
|
|
return &JSONSampler{
|
|
schema: nil,
|
|
propIdx: -1,
|
|
propToNodeMap: nil,
|
|
pdaSampler: pdaSampler,
|
|
}, nil
|
|
}
|
|
|
|
// fmt.Println("schema not nil")
|
|
so := &JSONSampler{
|
|
schema: schema,
|
|
propIdx: -1,
|
|
propToNodeMap: make(map[string]*PDA),
|
|
pdaSampler: pdaSampler,
|
|
}
|
|
|
|
so.schemaToGraph()
|
|
|
|
// Benchmark token decoding
|
|
start := time.Now()
|
|
var m runtime.MemStats
|
|
runtime.ReadMemStats(&m)
|
|
before := m.Alloc
|
|
|
|
vocab := proc.Vocab()
|
|
decodedToks := make([]string, len(vocab.Values))
|
|
for i := range vocab.Values {
|
|
token, err := proc.Decode([]int32{int32(i)})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
decodedToks[i] = token
|
|
}
|
|
so.decodedToks = decodedToks
|
|
|
|
runtime.ReadMemStats(&m)
|
|
after := m.Alloc
|
|
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
|
fmt.Printf("Token decode time = %v\n", time.Since(start))
|
|
|
|
fmt.Println("--------------------------------")
|
|
fmt.Println("SOSampler")
|
|
fmt.Println("--------------------------------")
|
|
// Benchmark this section
|
|
start = time.Now()
|
|
runtime.ReadMemStats(&m)
|
|
before = m.Alloc
|
|
|
|
// TODO: still messed up
|
|
// TODO: recursion use case
|
|
// key masks
|
|
for _, prop := range so.schema.Properties {
|
|
node := so.propToNodeMap[prop.Name]
|
|
// propName -> node
|
|
curState := node.State
|
|
fromNode := node
|
|
so.pdaSampler.CreateMask(fromNode)
|
|
for curState == StateInStructuredKey {
|
|
// there is only one edge
|
|
for r, toNode := range fromNode.TransitionEdges {
|
|
fmt.Println("rune", r, "edge", toNode.State)
|
|
so.pdaSampler.CreateMask(toNode)
|
|
fmt.Printf("created mask for %c\n", r)
|
|
curState = toNode.State
|
|
fmt.Println("next state", curState)
|
|
// TODO: theres an extra gen for " right now
|
|
fromNode = toNode
|
|
}
|
|
}
|
|
|
|
if curState != StateInColon {
|
|
return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState)
|
|
}
|
|
|
|
// so.pdaSampler.CreateMask(fromNode)
|
|
|
|
fromNode = fromNode.TransitionEdges[' ']
|
|
|
|
so.pdaSampler.CreateMask(fromNode)
|
|
curState = fromNode.State
|
|
for _, toNode := range fromNode.TransitionEdges {
|
|
fmt.Println("toNode", toNode.State)
|
|
}
|
|
}
|
|
|
|
// runtime.ReadMemStats(&m)
|
|
// after = m.Alloc
|
|
// fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
|
// fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
|
// fmt.Println("--------------------------------")
|
|
|
|
return so, nil
|
|
}
|
|
|
|
func (s *JSONSampler) schemaToGraph() {
|
|
schemaType := s.schema.EffectiveType()
|
|
switch schemaType {
|
|
case "object":
|
|
// TODO: see if we need to connect these to the JSON graph
|
|
|
|
// each prop is a key
|
|
for _, prop := range s.schema.Properties {
|
|
// name of key
|
|
name := prop.Name
|
|
keyNode := &PDA{
|
|
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
|
TransitionEdges: make(map[rune]*PDA),
|
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
|
}
|
|
|
|
prevNode := keyNode
|
|
for _, r := range name {
|
|
runeNode := &PDA{
|
|
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
|
TransitionEdges: make(map[rune]*PDA),
|
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
|
}
|
|
// fmt.Println("runeNode created", runeNode.State)
|
|
// fmt.Printf("runeNode created %c\n", r)
|
|
|
|
// since alloc on heap connections wil still map
|
|
prevNode.TransitionEdges[r] = runeNode
|
|
prevNode = runeNode
|
|
}
|
|
|
|
// point to end of object key node after all chars are done
|
|
// prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
|
|
|
// link to value node
|
|
// Create a node for the end of the key (after the closing quote)
|
|
stringEndNode := &PDA{
|
|
State: StateInStructuredKey,
|
|
TransitionEdges: make(map[rune]*PDA),
|
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
|
}
|
|
prevNode.TransitionEdges['"'] = stringEndNode
|
|
prevNode = stringEndNode
|
|
|
|
// Add transition for colon after key
|
|
colonNode := &PDA{
|
|
State: StateInColon,
|
|
TransitionEdges: make(map[rune]*PDA),
|
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
|
}
|
|
prevNode.TransitionEdges[':'] = colonNode
|
|
prevNode = colonNode
|
|
|
|
// Add transition for space after colon
|
|
spaceNode := &PDA{
|
|
State: StateInSpaceToValue,
|
|
TransitionEdges: make(map[rune]*PDA),
|
|
MaskTokenIDToNode: make(map[int32]*PDA),
|
|
}
|
|
prevNode.TransitionEdges[' '] = spaceNode
|
|
prevNode = spaceNode
|
|
|
|
value := prop.Type
|
|
switch value {
|
|
case "object":
|
|
fmt.Println("object under key: ", name)
|
|
case "array":
|
|
fmt.Println("array under key: ", name)
|
|
case "string":
|
|
fmt.Println("string under key: ", name)
|
|
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString]
|
|
case "number":
|
|
fmt.Println("number under key: ", name)
|
|
for _, r := range validNumberRunes {
|
|
prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber]
|
|
}
|
|
case "boolean":
|
|
fmt.Println("boolean under key: ", name)
|
|
prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
|
prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool]
|
|
prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull]
|
|
}
|
|
|
|
// points to start of the key
|
|
s.propToNodeMap[name] = keyNode
|
|
fmt.Println("name", name, "keyNode", keyNode.State)
|
|
}
|
|
}
|
|
// TODO: do values + recursion
|
|
}
|
|
|
|
func (s *JSONSampler) Apply(logits []float32) ([]float32, error) {
|
|
if s.schema == nil {
|
|
return s.pdaSampler.Apply(logits)
|
|
}
|
|
|
|
switch s.pdaSampler.curNode.State {
|
|
// TODO: doesnt account for multi rune case
|
|
case StateInObjectKey:
|
|
if s.propIdx > len(s.schema.Properties)-1 {
|
|
return nil, fmt.Errorf("propIdx out of bounds")
|
|
}
|
|
// fmt.Println("in object key - structured outputs")
|
|
// TODO: this tracking should probably be coming from a stack to track nested objects
|
|
// simple case
|
|
s.propIdx++
|
|
fmt.Println("propIdx", s.propIdx)
|
|
prop := s.schema.Properties[s.propIdx]
|
|
fmt.Println("prop", prop.Name)
|
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
|
fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
|
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return logits, nil
|
|
|
|
default:
|
|
|
|
// Will only happen for the last prop - can also be precomputed.
|
|
if s.propIdx == len(s.schema.Properties)-1 {
|
|
// todo: if i incremenet propidx then i know im in last value as well
|
|
switch s.pdaSampler.curNode.State {
|
|
case StateInObjectEnd:
|
|
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
|
|
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
|
|
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
|
s.propIdx++
|
|
|
|
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
|
|
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
|
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
|
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
|
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
|
|
|
|
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
|
|
s.propIdx++
|
|
}
|
|
}
|
|
return s.pdaSampler.Apply(logits)
|
|
}
|
|
}
|
|
|
|
func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
|
tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if s.schema == nil {
|
|
// Don't need to update state for unconstrained JSON sampling
|
|
return tokenSlice, nil
|
|
}
|
|
|
|
switch s.pdaSampler.curNode.State {
|
|
case StateInObjectKey:
|
|
s.propIdx++
|
|
fmt.Println("propIdx", s.propIdx)
|
|
prop := s.schema.Properties[s.propIdx]
|
|
fmt.Println("prop", prop.Name)
|
|
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
|
// TODO: this does not work - mike
|
|
// str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
|
// if err != nil {
|
|
// return nil, err
|
|
// }
|
|
// fmt.Println("str", str)
|
|
|
|
return tokenSlice, nil
|
|
default:
|
|
return tokenSlice, nil
|
|
}
|
|
}
|