This commit is contained in:
ParthSareen 2025-01-27 16:33:55 -08:00
parent a2a73ce5e0
commit e93db4d20e
3 changed files with 214 additions and 29 deletions

View File

@ -22,12 +22,18 @@ const (
StateInFloat
StateInBool
StateInNull
StateInArray
StateInColon
StateInComma
StateInTab
StateInSpace
StateInObjSpace
StateInList
StateInListComma
StateListEnd
StateInListEnd
StateInNewline
StateInNumber
StateInNumberEnd
StateInStringEnd
StateInObjectKeyEnd
StateTerminate
@ -42,42 +48,54 @@ func (s JSONState) String() string {
return "StateInObject"
case StateInObjectKey:
return "StateInObjectKey"
case StateInString:
return "StateInString"
case StateNewline:
return "StateNewline"
case StateTab:
return "StateTab"
case StateSpace:
return "StateSpace"
case StateInString:
return "StateInString"
case StateInInt:
return "StateInInt"
case StateInFloat:
return "StateInFloat"
case StateInColon:
return "StateInColon"
case StateInBool:
return "StateInBool"
case StateInNull:
return "StateInNull"
case StateInArray:
return "StateInArray"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateInColon:
return "StateInColon"
case StateInComma:
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateInNewline:
return "StateInNewline"
case StateInSpace:
return "StateInSpace"
case StateTerminate:
return "StateTerminate"
case StateInObjSpace:
return "StateInObjSpace"
case StateInList:
return "StateInList"
case StateInListComma:
return "StateInListComma"
case StateListEnd:
return "StateListEnd"
case StateInListEnd:
return "StateInListEnd"
case StateInNewline:
return "StateInNewline"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
return "StateInNumberEnd"
case StateInStringEnd:
return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:
return "StateInObjectEnd"
default:
return fmt.Sprintf("Unknown state: %d", s)
}
@ -264,6 +282,7 @@ func getValidStates(node *Node) []int32 {
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
// todo: this can prob be more efficient
for i := range logits {
isValid := false
for _, token := range validStates {

View File

@ -8,6 +8,15 @@ import (
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDANode struct {
State JSONState
TransitionEdges map[rune]*PDANode
@ -52,6 +61,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
spaceNode := NewPDANode(StateInSpace)
stateToNodeMap[StateInSpace] = spaceNode
spaceObjNode := NewPDANode(StateInObjSpace)
stateToNodeMap[StateInObjSpace] = spaceObjNode
tabNode := NewPDANode(StateInTab)
stateToNodeMap[StateInTab] = tabNode
@ -61,7 +73,31 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stringEndNode := NewPDANode(StateInStringEnd)
stateToNodeMap[StateInStringEnd] = stringEndNode
// terminateNode := NewNode(StateTerminate)
listNode := NewPDANode(StateInList)
stateToNodeMap[StateInList] = listNode
listCommaNode := NewPDANode(StateInListComma)
stateToNodeMap[StateInListComma] = listCommaNode
listEndNode := NewPDANode(StateListEnd)
stateToNodeMap[StateListEnd] = listEndNode
numberNode := NewPDANode(StateInNumber)
stateToNodeMap[StateInNumber] = numberNode
boolNode := NewPDANode(StateInBool)
stateToNodeMap[StateInBool] = boolNode
nullNode := NewPDANode(StateInNull)
stateToNodeMap[StateInNull] = nullNode
// Defined with structured outputs only
intNode := NewPDANode(StateInInt)
stateToNodeMap[StateInInt] = intNode
// TODO:
// consider adding a node to just point to values, could be good to compute that
// mask rather than many different nodes
// Connect nodes
// TODO: if all are single tokens then this can just be connected instead of defining the token
@ -69,34 +105,119 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
objNode.TransitionEdges['"'] = objKeyNode
objNode.TransitionEdges['\n'] = newlineNode
// objNode.TransitionEdges['\t'] = tabNode
newlineNode.TransitionEdges['"'] = objKeyNode
newlineNode.TransitionEdges['\t'] = tabNode
tabNode.TransitionEdges['"'] = objKeyNode
spaceNode.TransitionEdges['"'] = stringNode
// tabNode.TransitionEdges['\t'] = tabNode
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
objKeyNode.TransitionEdges['"'] = objKeyEndNode
objKeyNode.TransitionEdges[' '] = spaceNode
// objKeyNode.TransitionEdges['\t'] = tabNode
objKeyEndNode.TransitionEdges[':'] = colonNode
objEndNode.TransitionEdges[' '] = spaceNode
colonNode.TransitionEdges['"'] = stringNode
// where values should be
// this could be combined but the probs might change, we're alr doing a skip ahead
colonNode.TransitionEdges[' '] = spaceNode
// Leads to a value
spaceNode.TransitionEdges['"'] = stringNode
spaceNode.TransitionEdges['['] = listNode
spaceNode.TransitionEdges['{'] = objNode
for _, r := range validNumberRunes {
spaceNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
spaceNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
spaceNode.TransitionEdges[r] = nullNode
}
// Values
// string node
stringNode.TransitionEdges[rune(-1)] = stringNode
stringNode.TransitionEdges['"'] = stringEndNode
stringEndNode.TransitionEdges[','] = commaNode
stringEndNode.TransitionEdges['}'] = objEndNode
stringEndNode.TransitionEdges[']'] = listEndNode
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
numberNode.TransitionEdges[r] = numberNode
}
numberNode.TransitionEdges[','] = commaNode
numberNode.TransitionEdges['}'] = objEndNode
numberNode.TransitionEdges[']'] = listEndNode
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
// list node
listNode.TransitionEdges[','] = commaNode
listNode.TransitionEdges['"'] = stringNode
// squash states to a value
for _, r := range validNumberRunes {
listNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
listNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
listNode.TransitionEdges[r] = nullNode
}
// null node
for _, r := range validNullRunes {
nullNode.TransitionEdges[r] = nullNode
}
nullNode.TransitionEdges[','] = commaNode
nullNode.TransitionEdges['}'] = objEndNode
nullNode.TransitionEdges[']'] = listEndNode
// list comma
// should point to values
listCommaNode.TransitionEdges['"'] = stringNode
listCommaNode.TransitionEdges[' '] = listCommaNode
listCommaNode.TransitionEdges['{'] = objNode
listCommaNode.TransitionEdges['\n'] = newlineNode
for _, r := range validNumberRunes {
listCommaNode.TransitionEdges[r] = numberNode
}
for _, r := range validBoolRunes {
listCommaNode.TransitionEdges[r] = boolNode
}
for _, r := range validNullRunes {
listCommaNode.TransitionEdges[r] = nullNode
}
// bool node
for _, r := range validBoolRunes {
boolNode.TransitionEdges[r] = boolNode
}
boolNode.TransitionEdges['}'] = objEndNode
boolNode.TransitionEdges[']'] = listEndNode
boolNode.TransitionEdges[','] = commaNode
listEndNode.TransitionEdges['}'] = objEndNode
listEndNode.TransitionEdges[','] = commaNode
commaNode.TransitionEdges['{'] = objNode
commaNode.TransitionEdges['\n'] = newlineNode
commaNode.TransitionEdges['\t'] = tabNode
commaNode.TransitionEdges['"'] = objKeyNode
commaNode.TransitionEdges[' '] = spaceObjNode
spaceObjNode.TransitionEdges['"'] = objKeyNode
return startNode, stateToNodeMap, nil
}

View File

@ -3,6 +3,8 @@ package sample
import (
"fmt"
"math"
"runtime"
"time"
"github.com/ollama/ollama/model"
)
@ -13,9 +15,17 @@ type PushdownSampler struct {
proc model.TextProcessor
stateToNodeMap map[JSONState]*PDANode
braceStack []rune
stateCounter uint32
}
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc)
if err != nil {
panic(err)
@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
if err != nil {
panic(err)
}
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
// token, err := proc.Decode([]int32{int32(id)})
// if err != nil {
@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
proc: proc,
stateToNodeMap: stateToNodeMap,
braceStack: []rune{},
stateCounter: 0,
}
}
@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
}
}
return logits, nil
// return logits, nil
case StateInComma:
peek := s.braceStack[len(s.braceStack)-1]
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
fmt.Println("switching to list comma", s.curNode.State)
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate:
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
}
return logits, nil
// case StateInStringEnd:
// return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
fmt.Println("update state", s.curNode.State)
// TODO: need to handle end states and entering object case
// TODO: need to handle end states and entering object case, and list case
if s.curNode.State == StateInObjectEnd {
fmt.Println("in object end")
if len(s.braceStack) > 0 {
@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
if err != nil {
return err
}
// TODO: should force closing for all braces
for _, r := range mappedString {
if r == rune('{') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('[') {
s.braceStack = append(s.braceStack, r)
}
if r == rune('}') {
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
return fmt.Errorf("unmatched closing brace")
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
fmt.Println("popping brace stack", s.braceStack)
}
if r == rune(']') {
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
return fmt.Errorf("unmatched closing brace")
}
s.braceStack = s.braceStack[:len(s.braceStack)-1]
fmt.Println("popping brace stack", s.braceStack)
}
}
for _, tokenID := range tokenSlice {
// transition to the next node
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
if !ok {
return fmt.Errorf("invalid token: %q", mappedString)
}
fmt.Println("transitioning to", nextNode)
s.curNode = s.stateToNodeMap[nextNode]
fmt.Println("transitioning to", nextNodeState)
// TODO: add a penalty for staying in the same state too long
if nextNodeState == s.curNode.State {
s.stateCounter++
} else {
s.stateCounter = 0
}
s.curNode = s.stateToNodeMap[nextNodeState]
}
return nil
}