This commit is contained in:
ParthSareen 2025-01-23 20:21:50 -08:00
parent 6ba557f25b
commit a2a73ce5e0
3 changed files with 336 additions and 4 deletions

View File

@ -25,10 +25,13 @@ const (
StateInArray
StateInColon
StateInComma
StateInTab
StateInSpace
StateInNewline
StateInStringEnd
StateInObjectKeyEnd
StateTerminate
StateEnd
StateInObjectEnd
)
func (s JSONState) String() string {
@ -59,12 +62,18 @@ func (s JSONState) String() string {
return "StateInNull"
case StateInArray:
return "StateInArray"
case StateEnd:
return "StateEnd"
case StateInObjectEnd:
return "StateInObjectEnd"
case StateInComma:
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateInNewline:
return "StateInNewline"
case StateInSpace:
return "StateInSpace"
case StateTerminate:
return "StateTerminate"
case StateInStringEnd:
@ -124,13 +133,14 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
// fmt.Println("tokenSlice", tokenSlice)
// todo: account for strings here
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
if err != nil {
return err
}
// only move to terminate state if stack is empty
if s.curNode.State == StateEnd {
if s.curNode.State == StateInObjectEnd {
fmt.Println("debug: node.State", s.curNode.State)
if len(s.stack) > 0 {
s.stack = s.stack[:len(s.stack)-1]

175
sample/pushdown_automata.go Normal file
View File

@ -0,0 +1,175 @@
package sample
import (
"slices"
"github.com/ollama/ollama/model"
)
// stringInvalidRunes lists the runes that isRuneValid treats as special.
// A special rune is rejected outright while the current node is
// StateInString or StateInObjectKey; in any other state it is accepted at
// most once per consumed-rune set handed to isRuneValid.
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
// PDANode is a single node of the JSON pushdown automaton graph assembled by
// BuildGraph.
type PDANode struct {
// State is the JSON state this node stands for.
State JSONState
// TransitionEdges maps a rune to the node reached by consuming it. The key
// rune(-1) is a sentinel that isRuneValid treats as "any rune".
TransitionEdges map[rune]*PDANode
// MaskTokenIDToNode is filled by PreComputeValidStates: each key is a token
// id accepted from this node and the value is the state reached once every
// rune of that token has been consumed.
MaskTokenIDToNode map[int32]JSONState
}
// NewPDANode allocates a node for the given state. Both of its maps are
// created empty (non-nil) so callers can insert into them immediately.
func NewPDANode(state JSONState) *PDANode {
	node := new(PDANode)
	node.State = state
	node.TransitionEdges = map[rune]*PDANode{}
	node.MaskTokenIDToNode = map[int32]JSONState{}
	return node
}
// BuildGraph creates one PDANode per JSON state used by the automaton,
// records each of them in a state-to-node lookup map, and wires the rune
// transitions between them.
//
// It returns the start node, the lookup map, and an error that is always nil.
// proc is accepted but not used while building the graph. No node is created
// for StateTerminate.
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
	stateToNodeMap := make(map[JSONState]*PDANode)

	// register creates the node for a state and stores it in the lookup map.
	register := func(state JSONState) *PDANode {
		n := NewPDANode(state)
		stateToNodeMap[state] = n
		return n
	}

	startNode := register(StateStart)
	objNode := register(StateInObject)
	objEndNode := register(StateInObjectEnd)
	objKeyNode := register(StateInObjectKey)
	objKeyEndNode := register(StateInObjectKeyEnd)
	colonNode := register(StateInColon)
	commaNode := register(StateInComma)
	newlineNode := register(StateInNewline)
	spaceNode := register(StateInSpace)
	tabNode := register(StateInTab)
	stringNode := register(StateInString)
	stringEndNode := register(StateInStringEnd)

	// anyRune is the sentinel edge key that isRuneValid accepts for any rune
	// lacking a more specific edge.
	const anyRune = rune(-1)

	// TODO: if all are single tokens then this can just be connected instead
	// of defining the token.
	edges := []struct {
		from *PDANode
		on   rune
		to   *PDANode
	}{
		{startNode, '{', objNode},

		{objNode, '"', objKeyNode},
		{objNode, '\n', newlineNode},

		{newlineNode, '"', objKeyNode},
		{newlineNode, '\t', tabNode},

		{tabNode, '"', objKeyNode},

		{spaceNode, '"', stringNode},

		{objKeyNode, anyRune, objKeyNode},
		{objKeyNode, '"', objKeyEndNode},
		{objKeyNode, ' ', spaceNode},

		{objKeyEndNode, ':', colonNode},

		{colonNode, '"', stringNode},
		{colonNode, ' ', spaceNode},

		{stringNode, anyRune, stringNode},
		{stringNode, '"', stringEndNode},

		{stringEndNode, ',', commaNode},
		{stringEndNode, '}', objEndNode},

		{commaNode, '{', objNode},
		{commaNode, '\n', newlineNode},
		{commaNode, '\t', tabNode},
		{commaNode, '"', objKeyNode},
	}
	for _, e := range edges {
		e.from.TransitionEdges[e.on] = e.to
	}

	return startNode, stateToNodeMap, nil
}
// PreComputeValidStates fills MaskTokenIDToNode on every node in
// stateToNodeMap.
//
// Every vocabulary id is decoded once. For each node, a token is accepted
// when all of its runes can be consumed in sequence via isRuneValid starting
// from that node; the token id is then mapped to the state reached after the
// last rune. EOS tokens, BOS tokens and tokens that decode to the empty
// string are never recorded.
//
// It returns the first error reported by proc.Decode or isRuneValid.
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
	vocab := proc.GetVocabulary()

	// candidate is a vocabulary entry that passed the node-independent filter.
	type candidate struct {
		id   int32
		text string
	}

	// The EOS/BOS/empty filter does not depend on the node, so evaluate it
	// once here instead of once per node.
	candidates := make([]candidate, 0, len(vocab.Values))
	for i := range vocab.Values {
		text, err := proc.Decode([]int32{int32(i)})
		if err != nil {
			return err
		}
		// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON.
		if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || text == "" {
			continue
		}
		candidates = append(candidates, candidate{id: int32(i), text: text})
	}

	// One scratch set, emptied before each token, replaces a fresh map
	// allocation per (node, token) pair.
	consumedSpecialRunes := make(map[rune]bool)
	for _, node := range stateToNodeMap {
		for _, c := range candidates {
			for r := range consumedSpecialRunes {
				delete(consumedSpecialRunes, r)
			}

			valid := true
			curNode := node
			for _, r := range c.text {
				var err error
				valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
				if err != nil {
					return err
				}
				if !valid {
					break
				}
			}
			// c.text is non-empty, so curNode was set by a successful
			// isRuneValid call whenever valid is still true here.
			if valid {
				node.MaskTokenIDToNode[c.id] = curNode.State
			}
		}
	}
	return nil
}
// isRuneValid reports whether rune r may be consumed from curNode and, if so,
// which node it leads to.
//
// A rune already present in consumedSpecialRunes is rejected. A rune listed
// in stringInvalidRunes is rejected while curNode is StateInString or
// StateInObjectKey, and is also rejected when its explicit edge would stay in
// the same state; otherwise taking it records it in consumedSpecialRunes.
// When no explicit edge exists, the rune(-1) sentinel edge (if any) accepts
// the rune.
//
// On rejection the returned node is nil. The error result is always nil.
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
	const anyRune = rune(-1)

	if consumedSpecialRunes[r] {
		return false, nil, nil
	}

	isSpecial := slices.Contains(stringInvalidRunes, r)
	inText := curNode.State == StateInString || curNode.State == StateInObjectKey
	if isSpecial && inText {
		return false, nil, nil
	}

	target, hasEdge := curNode.TransitionEdges[r]
	switch {
	case hasEdge && !isSpecial:
		return true, target, nil
	case hasEdge:
		// A special rune must move the automaton to a different state.
		if target.State == curNode.State {
			return false, nil, nil
		}
		consumedSpecialRunes[r] = true
		return true, target, nil
	}

	// No explicit edge: fall back to the wildcard edge when the node has one.
	if wildcard, hasWildcard := curNode.TransitionEdges[anyRune]; hasWildcard {
		return true, wildcard, nil
	}
	return false, nil, nil
}

147
sample/pushdown_runner.go Normal file
View File

@ -0,0 +1,147 @@
package sample
import (
"fmt"
"math"
"github.com/ollama/ollama/model"
)
// PushdownSampler constrains token sampling using the PDANode graph. It is
// stateful: Sample and UpdateState both read and modify curNode, and
// UpdateState also modifies braceStack.
type PushdownSampler struct {
// curNode is the node the sampler is currently at; NewPushdownSampler sets
// it to the graph's start node.
curNode *PDANode
// proc supplies Encode, Decode and the special-token check used by the
// sampler.
proc model.TextProcessor
// stateToNodeMap resolves the state stored in MaskTokenIDToNode back to
// its node.
stateToNodeMap map[JSONState]*PDANode
// braceStack holds the '{' runes that UpdateState has seen and not yet
// matched with a '}'.
braceStack []rune
}
// NewPushdownSampler builds the automaton graph, precomputes the per-node
// token masks for proc's vocabulary, and returns a sampler positioned at the
// start node with an empty brace stack.
//
// It panics if BuildGraph or PreComputeValidStates returns an error.
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
	start, nodes, err := BuildGraph(proc)
	if err != nil {
		panic(err)
	}
	if err := PreComputeValidStates(nodes, proc); err != nil {
		panic(err)
	}

	sampler := new(PushdownSampler)
	sampler.curNode = start
	sampler.proc = proc
	sampler.stateToNodeMap = nodes
	sampler.braceStack = make([]rune, 0)
	return sampler
}
// Sample constrains logits in place for the sampler's current state and
// returns the same slice. Disallowed entries are overwritten with NaN.
//
//   - StateInObjectEnd with an empty brace stack: the sampler switches to a
//     fresh StateTerminate node and only EOS tokens stay selectable (set to
//     1.0).
//   - StateInObjectEnd with open braces: only the token ids returned by
//     proc.Encode("}") keep their logits.
//   - StateTerminate: only EOS tokens stay selectable (set to 1.0).
//   - any other state: ids missing from the current node's MaskTokenIDToNode
//     are masked via maskLogits.
//
// It returns an error if encoding "}" fails or yields no ids, or if
// maskLogits fails.
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
	switch s.curNode.State {
	case StateInObjectEnd:
		if len(s.braceStack) > 0 {
			closing, err := s.proc.Encode("}")
			if err != nil {
				return nil, err
			}
			if len(closing) == 0 {
				return nil, fmt.Errorf("encoding %q produced no tokens", "}")
			}

			// Collect the ids first so that an index is masked only when it
			// matches none of them. Comparing each index against every id in
			// turn would mask everything as soon as there are two ids.
			// NOTE(review): if Encode returns a token sequence rather than
			// alternatives, only the first id should stay allowed - confirm
			// against the TextProcessor contract.
			allowed := make(map[int]struct{}, len(closing))
			for _, id := range closing {
				allowed[int(id)] = struct{}{}
			}
			for i := range logits {
				if _, ok := allowed[i]; !ok {
					logits[i] = math.NaN()
				}
			}
			return logits, nil
		}

		// No braces left: force the sequence to finish.
		s.curNode = NewPDANode(StateTerminate)
		fallthrough

	case StateTerminate:
		for i := range logits {
			if s.proc.Is(uint32(i), model.SpecialEOS) {
				logits[i] = 1.0
			} else {
				logits[i] = math.NaN()
			}
		}
		return logits, nil

	default:
		masked, err := s.maskLogits(logits, s.curNode)
		if err != nil {
			return nil, err
		}
		return masked, nil
	}
}
// UpdateState advances the sampler past the accepted tokens in tokenSlice.
//
// If the current state is StateInObjectEnd and the brace stack is non-empty,
// one entry is popped and the call returns without inspecting tokenSlice. If
// the stack is empty, the sampler moves to a fresh StateTerminate node and
// processing continues; that node has no precomputed tokens, so any token in
// tokenSlice is then reported as invalid.
//
// Otherwise the decoded text is scanned to maintain braceStack ('{' pushes,
// '}' pops), and each token id is looked up in the current node's
// MaskTokenIDToNode to find the next state.
//
// It returns an error when decoding fails, when a '}' has no matching '{',
// when a token is not valid for the current node, or when the next state has
// no node in stateToNodeMap. Changes made before an error are not rolled
// back.
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
	// TODO: need to handle end states and entering object case
	if s.curNode.State == StateInObjectEnd {
		if len(s.braceStack) > 0 {
			s.braceStack = s.braceStack[:len(s.braceStack)-1]
			return nil
		}
		s.curNode = NewPDANode(StateTerminate)
		// TODO: return here?
	}

	// Decode the whole slice: one token may contain several braces.
	mappedString, err := s.proc.Decode(tokenSlice)
	if err != nil {
		return err
	}
	for _, r := range mappedString {
		switch r {
		case '{':
			s.braceStack = append(s.braceStack, r)
		case '}':
			if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != '{' {
				return fmt.Errorf("unmatched closing brace")
			}
			s.braceStack = s.braceStack[:len(s.braceStack)-1]
		}
	}

	for _, tokenID := range tokenSlice {
		nextState, ok := s.curNode.MaskTokenIDToNode[tokenID]
		if !ok {
			return fmt.Errorf("invalid token: %q", mappedString)
		}
		// Never store a nil node: the next Sample or UpdateState call
		// dereferences curNode unconditionally.
		nextNode, ok := s.stateToNodeMap[nextState]
		if !ok || nextNode == nil {
			return fmt.Errorf("no node registered for state %v", nextState)
		}
		s.curNode = nextNode
	}
	return nil
}
// maskLogits overwrites with NaN every logit whose index is not a key of
// node.MaskTokenIDToNode. The slice is modified in place and returned; the
// error result is always nil.
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
	allowed := node.MaskTokenIDToNode
	for id := range logits {
		if _, ok := allowed[int32(id)]; ok {
			continue
		}
		logits[id] = math.NaN()
	}
	return logits, nil
}
// TODO: add penalties for string \n stuff