tested with so

This commit is contained in:
ParthSareen
2025-01-31 17:12:39 -08:00
parent b973dedb4b
commit 524029cd6d
5 changed files with 193 additions and 64 deletions

View File

@@ -106,11 +106,31 @@ func temp() error {
} }
} }
pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) // pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
// simple schema
// This schema maps to JSON like:
// {
// "name": "some string value"
// }
schema := &sample.Schema{
Name: "root",
Type: "object",
Properties: []*sample.Schema{
{Name: "name", Type: "string"},
},
}
pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
if err != nil {
return err
}
var offset int var offset int
var stringBuffer string var stringBuffer string
var firstTokenTime time.Duration var firstTokenTime time.Duration
var totalSamplingTime time.Duration
count := 0
for range args.n { for range args.n {
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...) logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
if err != nil { if err != nil {
@@ -122,7 +142,6 @@ func temp() error {
for i, f32 := range f32s { for i, f32 := range f32s {
f64s[i] = float64(f32) f64s[i] = float64(f32)
} }
sampleTime := time.Now()
samplers := []sample.Sampler{ samplers := []sample.Sampler{
pushdownSampler, pushdownSampler,
// sample.Weighed(), // sample.Weighed(),
@@ -131,12 +150,16 @@ func temp() error {
sample.Greedy(), sample.Greedy(),
} }
samplingStart := time.Now()
f64s, err = sample.Sample(f64s, samplers...) f64s, err = sample.Sample(f64s, samplers...)
if err != nil { if err != nil {
return err return err
} }
finishTime := time.Now() samplingTime := time.Since(samplingStart)
fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) totalSamplingTime += samplingTime
fmt.Println("sampling time", samplingTime)
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
var outputIDs []int32 var outputIDs []int32
for _, f64 := range f64s { for _, f64 := range f64s {
@@ -164,6 +187,7 @@ func temp() error {
// fmt.Printf("--- token: %q\n", s) // fmt.Printf("--- token: %q\n", s)
// fmt.Printf("--- outputIDs: %v\n", outputIDs) // fmt.Printf("--- outputIDs: %v\n", outputIDs)
stringBuffer += s stringBuffer += s
count++
fmt.Println("--- stringBuffer", stringBuffer) fmt.Println("--- stringBuffer", stringBuffer)
err = pushdownSampler.UpdateState(outputIDs) err = pushdownSampler.UpdateState(outputIDs)
@@ -179,7 +203,7 @@ func temp() error {
fmt.Println("\n------ Output: ------") fmt.Println("\n------ Output: ------")
fmt.Println(stringBuffer) fmt.Println(stringBuffer)
fmt.Println("--------------------") fmt.Println("--------------------")
fmt.Println("sample average time", totalSamplingTime/time.Duration(count))
return nil return nil
} }

View File

@@ -10,6 +10,8 @@ const (
StateStart JSONState = iota StateStart JSONState = iota
StateInObject StateInObject
StateInObjectKey StateInObjectKey
StateInStructuredKey
StateInStructuredValue
StateNewline StateNewline
StateTab StateTab
StateSpace StateSpace
@@ -43,6 +45,7 @@ var JSONStates = []JSONState{
StateStart, StateStart,
StateInObject, StateInObject,
StateInObjectKey, StateInObjectKey,
StateInStructuredKey,
StateNewline, StateNewline,
StateTab, StateTab,
StateSpace, StateSpace,
@@ -80,6 +83,8 @@ func (s JSONState) String() string {
return "StateInObject" return "StateInObject"
case StateInObjectKey: case StateInObjectKey:
return "StateInObjectKey" return "StateInObjectKey"
case StateInStructuredKey:
return "StateInStructuredKey"
case StateNewline: case StateNewline:
return "StateNewline" return "StateNewline"
case StateTab: case StateTab:

View File

@@ -21,14 +21,14 @@ var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDANode struct { type PDANode struct {
State JSONState State JSONState
TransitionEdges map[rune]*PDANode TransitionEdges map[rune]*PDANode
MaskTokenIDToNode map[int32]JSONState MaskTokenIDToNode map[int32]*PDANode
} }
func NewPDANode(state JSONState) *PDANode { func NewPDANode(state JSONState) *PDANode {
return &PDANode{ return &PDANode{
State: state, State: state,
TransitionEdges: make(map[rune]*PDANode), TransitionEdges: make(map[rune]*PDANode),
MaskTokenIDToNode: make(map[int32]JSONState), MaskTokenIDToNode: make(map[int32]*PDANode),
} }
} }
@@ -103,6 +103,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap) addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
// null node // null node
@@ -162,6 +164,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
// TODO: tough life fr. plz fix. // TODO: tough life fr. plz fix.
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
// TODO: should come from top level
vocab := proc.GetVocabulary() vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values)) decodedToks := make([]string, len(vocab.Values))
@@ -175,7 +178,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
var err error var err error
for _, node := range stateToNodeMap { for _, node := range stateToNodeMap {
err = createMask(node, proc, decodedToks, vocab) err = CreateMask(node, proc, decodedToks, vocab)
if err != nil { if err != nil {
return err return err
} }
@@ -183,7 +186,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
return nil return nil
} }
func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error { func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
for i := range vocab.Values { for i := range vocab.Values {
token := decodedToks[i] token := decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
@@ -204,7 +207,8 @@ func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, v
} }
} }
if valid { if valid {
node.MaskTokenIDToNode[int32(i)] = curNode.State // cur node allows skipping
node.MaskTokenIDToNode[int32(i)] = curNode
} }
} }
return nil return nil

View File

@@ -4,7 +4,6 @@ import (
"fmt" "fmt"
"math" "math"
"runtime" "runtime"
"time"
"github.com/ollama/ollama/model" "github.com/ollama/ollama/model"
) )
@@ -22,12 +21,15 @@ type PushdownSampler struct {
// graph should be built once and reused per tokenizer // graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
start := time.Now() // start := time.Now()
// fmt.Println("--------------------------------")
// fmt.Println("PDA sampler")
// fmt.Println("--------------------------------")
var m runtime.MemStats var m runtime.MemStats
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
before := m.Alloc // before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) // fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc) startNode, stateToNodeMap, err := BuildGraph(proc)
if err != nil { if err != nil {
@@ -38,10 +40,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
panic(err) panic(err)
} }
runtime.ReadMemStats(&m) runtime.ReadMemStats(&m)
after := m.Alloc // after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) // fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) // fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start)) // fmt.Printf("Graph build time = %v\n", time.Since(start))
return &PushdownSampler{ return &PushdownSampler{
curNode: startNode, curNode: startNode,
@@ -53,6 +55,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
} }
// TODO: need to add resampling logic if the first sample was not good // TODO: need to add resampling logic if the first sample was not good
// greedy sample + backtrack?
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
// fmt.Println(">>> sample:", s.curNode.State) // fmt.Println(">>> sample:", s.curNode.State)
switch s.curNode.State { switch s.curNode.State {
@@ -60,7 +63,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
return s.maskLogits(logits, s.curNode) return s.maskLogits(logits, s.curNode)
case StateInListEnd: case StateInListEnd:
fmt.Println("in list end", s.braceStack) // fmt.Println("in list end", s.braceStack)
// force finish if no braces left // force finish if no braces left
if len(s.braceStack) == 0 { if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate) s.curNode = NewPDANode(StateTerminate)
@@ -139,12 +142,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
} }
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
fmt.Println("update state", s.curNode.State) // fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice) mappedString, err := s.proc.Decode(tokenSlice)
if err != nil { if err != nil {
return err return err
} }
fmt.Println("mappedString", mappedString) // fmt.Println("mappedString", mappedString)
// TODO: should force closing for all braces - not doing square yet // TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString { for _, r := range mappedString {
@@ -183,23 +186,25 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
for _, tokenID := range tokenSlice { for _, tokenID := range tokenSlice {
// transition to the next node // transition to the next node
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID] nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok { if !ok {
return fmt.Errorf("invalid token: %q", mappedString) return fmt.Errorf("invalid token: %q", mappedString)
} }
// fmt.Println("transitioning to", nextNodeState) // fmt.Println("transitioning to", nextNodeState)
// TODO: add a penalty for staying in the same state too long // TODO: add a penalty for staying in the same state too long
if nextNodeState == s.curNode.State { if nextNode.State == s.curNode.State {
s.stateCounter++ s.stateCounter++
} else { } else {
s.stateCounter = 0 s.stateCounter = 0
} }
s.curNode = s.stateToNodeMap[nextNodeState] s.curNode = nextNode
// fmt.Println("updated curNode state", s.curNode.State)
} }
return nil return nil
} }
// greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode // TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
// Should be possible through bitwise ops as well // Should be possible through bitwise ops as well

View File

@@ -1,54 +1,145 @@
package sample package sample
import "github.com/ollama/ollama/model" import (
"fmt"
"runtime"
"time"
type StructuredOutput struct { "github.com/ollama/ollama/model"
schema *Schema )
stateToNodeMap map[JSONState]*PDANode
type SOSampler struct {
schema *Schema
propIdx int
propStateMap map[string]*PDANode
pdaSampler *PushdownSampler
} }
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput { func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
_, stateToNodeMap, err := BuildGraph(proc) pdaSampler := NewPushdownSampler(proc)
if err != nil {
panic(err) so := &SOSampler{
schema: schema,
propIdx: -1,
propStateMap: make(map[string]*PDANode),
pdaSampler: pdaSampler,
} }
return &StructuredOutput{ so.schemaToGraph()
schema: schema,
stateToNodeMap: stateToNodeMap, vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
} }
}
func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode { fmt.Println("--------------------------------")
fmt.Println("SOSampler")
fmt.Println("--------------------------------")
// Benchmark this section
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
schemaType := so.schema.EffectiveType() // TODO: still messed up
switch schemaType { for _, node := range so.propStateMap {
case "object": // propName -> node
// each prop is a key curState := node.State
// prevState := StateInObjectKey fromNode := node
for _, prop := range so.schema.Properties { CreateMask(fromNode, proc, decodedToks, vocab)
// name of key for curState == StateInStructuredKey {
name := prop.Name // there is only one edge
prevState := StateInObjectKey for r, toNode := range fromNode.TransitionEdges {
for i, r := range name { // fmt.Println("rune", r, "edge", toNode.State)
newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune CreateMask(toNode, proc, decodedToks, vocab)
fmt.Printf("created mask for %c\n", r)
// Create new node for this state if it doesn't exist curState = toNode.State
if _, exists := so.stateToNodeMap[newState]; !exists { fmt.Println("next state", curState)
so.stateToNodeMap[newState] = &PDANode{ // TODO: theres an extra gen for " right now
State: newState, fromNode = toNode
TransitionEdges: make(map[rune]*PDANode),
MaskTokenIDToNode: make(map[int32]JSONState),
}
}
// Connect previous state to this state via the rune
so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
prevState = newState
} }
// type of value
// propType := prop.Type
} }
} }
return nil
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Mask creation time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
return so, nil
}
// schemaToGraph builds one chain of PDA nodes per property of an "object"
// schema and stores the head of each chain in s.propStateMap, keyed by the
// property name.
//
// For a property name, every rune gets its own node, reached from the previous
// node through a transition edge labelled with that rune. The last node of the
// chain transitions on '"' to the pushdown sampler's StateInObjectKeyEnd node.
// Schemas whose effective type is not "object" are left untouched.
func (s *SOSampler) schemaToGraph() {
	if s.schema.EffectiveType() != "object" {
		return
	}

	// Every node in a key chain carries the same, unchanging state; this will
	// impact sampling because the state alone does not identify the position
	// within the key.
	newKeyNode := func() *PDANode {
		return &PDANode{
			State:             StateInStructuredKey,
			TransitionEdges:   make(map[rune]*PDANode),
			MaskTokenIDToNode: make(map[int32]*PDANode),
		}
	}

	// TODO: see if we need to connect these to the JSON graph
	for _, prop := range s.schema.Properties {
		head := newKeyNode()
		tail := head

		// Extend the chain by one node per rune of the property name. The
		// nodes are pointers, so links made here remain valid after the loop.
		for _, r := range prop.Name {
			next := newKeyNode()
			fmt.Println("runeNode created", next.State)
			fmt.Printf("runeNode created %c\n", r)
			tail.TransitionEdges[r] = next
			tail = next
		}

		// After the final rune, the closing quote leads to the end-of-key node.
		tail.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]

		// The map entry points at the start of the key chain.
		s.propStateMap[prop.Name] = head
		fmt.Println("name", prop.Name, "keyNode", head.State)
	}
}
// Sample masks logits for the next token.
//
// When the underlying pushdown sampler is in StateInObjectKey, the sampler
// advances to the next schema property, moves the pushdown sampler onto the
// head of that property's key chain (from s.propStateMap) and masks the logits
// against that node. In every other state the call is delegated unchanged to
// the pushdown sampler.
//
// An error is returned, and no sampler state is modified, when an object key is
// requested but the schema has no further property, or when no key chain was
// built for the property. Previously the first case panicked with an
// index-out-of-range and the second passed a nil node to maskLogits.
//
// Known limitations (unchanged): keys spanning multiple runes per token are not
// accounted for, and nested objects are not tracked.
func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
	if s.pdaSampler.curNode.State != StateInObjectKey {
		return s.pdaSampler.Sample(logits)
	}

	// TODO: this tracking should probably be coming from a stack to track nested objects
	next := s.propIdx + 1
	if next < 0 || next >= len(s.schema.Properties) {
		return nil, fmt.Errorf("structured output: no schema property at index %d (schema has %d)", next, len(s.schema.Properties))
	}

	prop := s.schema.Properties[next]
	node, ok := s.propStateMap[prop.Name]
	if !ok || node == nil {
		return nil, fmt.Errorf("structured output: no key graph for property %q", prop.Name)
	}

	// Commit the state change only after validation so a failed call can be
	// retried or reported without leaving propIdx past the end of the schema.
	s.propIdx = next
	s.pdaSampler.curNode = node

	masked, err := s.pdaSampler.maskLogits(logits, node)
	if err != nil {
		return nil, err
	}
	return masked, nil
}
func (s *SOSampler) UpdateState(tokenSlice []int32) error {
return s.pdaSampler.UpdateState(tokenSlice)
} }