mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 11:28:17 +02:00
wip!
This commit is contained in:
parent
6ba557f25b
commit
a2a73ce5e0
@ -25,10 +25,13 @@ const (
|
||||
StateInArray
|
||||
StateInColon
|
||||
StateInComma
|
||||
StateInTab
|
||||
StateInSpace
|
||||
StateInNewline
|
||||
StateInStringEnd
|
||||
StateInObjectKeyEnd
|
||||
StateTerminate
|
||||
StateEnd
|
||||
StateInObjectEnd
|
||||
)
|
||||
|
||||
func (s JSONState) String() string {
|
||||
@ -59,12 +62,18 @@ func (s JSONState) String() string {
|
||||
return "StateInNull"
|
||||
case StateInArray:
|
||||
return "StateInArray"
|
||||
case StateEnd:
|
||||
return "StateEnd"
|
||||
case StateInObjectEnd:
|
||||
return "StateInObjectEnd"
|
||||
case StateInComma:
|
||||
return "StateInComma"
|
||||
case StateInTab:
|
||||
return "StateInTab"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateInNewline:
|
||||
return "StateInNewline"
|
||||
case StateInSpace:
|
||||
return "StateInSpace"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInStringEnd:
|
||||
@ -124,13 +133,14 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||
|
||||
// fmt.Println("tokenSlice", tokenSlice)
|
||||
// todo: account for strings here
|
||||
|
||||
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only move to terminate state if stack is empty
|
||||
if s.curNode.State == StateEnd {
|
||||
if s.curNode.State == StateInObjectEnd {
|
||||
fmt.Println("debug: node.State", s.curNode.State)
|
||||
if len(s.stack) > 0 {
|
||||
s.stack = s.stack[:len(s.stack)-1]
|
||||
|
175
sample/pushdown_automata.go
Normal file
175
sample/pushdown_automata.go
Normal file
@ -0,0 +1,175 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
// stringInvalidRunes lists the runes that isRuneValid treats as "special":
// they are rejected outright while the automaton is in StateInString or
// StateInObjectKey, and in any other state each of them may follow an explicit
// transition edge at most once per consumedSpecialRunes map.
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||
|
||||
type PDANode struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDANode
|
||||
MaskTokenIDToNode map[int32]JSONState
|
||||
}
|
||||
|
||||
func NewPDANode(state JSONState) *PDANode {
|
||||
return &PDANode{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||
}
|
||||
}
|
||||
|
||||
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||
|
||||
startNode := NewPDANode(StateStart)
|
||||
stateToNodeMap[StateStart] = startNode
|
||||
|
||||
objNode := NewPDANode(StateInObject)
|
||||
stateToNodeMap[StateInObject] = objNode
|
||||
|
||||
objEndNode := NewPDANode(StateInObjectEnd)
|
||||
stateToNodeMap[StateInObjectEnd] = objEndNode
|
||||
|
||||
objKeyNode := NewPDANode(StateInObjectKey)
|
||||
stateToNodeMap[StateInObjectKey] = objKeyNode
|
||||
|
||||
objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
|
||||
stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
|
||||
|
||||
colonNode := NewPDANode(StateInColon)
|
||||
stateToNodeMap[StateInColon] = colonNode
|
||||
|
||||
commaNode := NewPDANode(StateInComma)
|
||||
stateToNodeMap[StateInComma] = commaNode
|
||||
|
||||
newlineNode := NewPDANode(StateInNewline)
|
||||
stateToNodeMap[StateInNewline] = newlineNode
|
||||
|
||||
spaceNode := NewPDANode(StateInSpace)
|
||||
stateToNodeMap[StateInSpace] = spaceNode
|
||||
|
||||
tabNode := NewPDANode(StateInTab)
|
||||
stateToNodeMap[StateInTab] = tabNode
|
||||
|
||||
stringNode := NewPDANode(StateInString)
|
||||
stateToNodeMap[StateInString] = stringNode
|
||||
|
||||
stringEndNode := NewPDANode(StateInStringEnd)
|
||||
stateToNodeMap[StateInStringEnd] = stringEndNode
|
||||
|
||||
// terminateNode := NewNode(StateTerminate)
|
||||
|
||||
// Connect nodes
|
||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||
startNode.TransitionEdges['{'] = objNode
|
||||
|
||||
objNode.TransitionEdges['"'] = objKeyNode
|
||||
objNode.TransitionEdges['\n'] = newlineNode
|
||||
|
||||
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||
newlineNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
tabNode.TransitionEdges['"'] = objKeyNode
|
||||
|
||||
spaceNode.TransitionEdges['"'] = stringNode
|
||||
|
||||
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||
objKeyNode.TransitionEdges[' '] = spaceNode
|
||||
// objKeyNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||
|
||||
colonNode.TransitionEdges['"'] = stringNode
|
||||
colonNode.TransitionEdges[' '] = spaceNode
|
||||
|
||||
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||
stringNode.TransitionEdges['"'] = stringEndNode
|
||||
|
||||
stringEndNode.TransitionEdges[','] = commaNode
|
||||
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||
|
||||
commaNode.TransitionEdges['{'] = objNode
|
||||
commaNode.TransitionEdges['\n'] = newlineNode
|
||||
commaNode.TransitionEdges['\t'] = tabNode
|
||||
commaNode.TransitionEdges['"'] = objKeyNode
|
||||
|
||||
return startNode, stateToNodeMap, nil
|
||||
}
|
||||
|
||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||
|
||||
vocab := proc.GetVocabulary()
|
||||
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, node := range stateToNodeMap {
|
||||
for i := range vocab.Values {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
||||
continue
|
||||
}
|
||||
valid := true
|
||||
curNode := node
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||
return false, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific rune transition
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return false, nil, nil
|
||||
}
|
||||
// fmt.Println("special rune", r, "consumed")
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return true, nextNode, nil
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return true, nextNode, nil
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
147
sample/pushdown_runner.go
Normal file
147
sample/pushdown_runner.go
Normal file
@ -0,0 +1,147 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type PushdownSampler struct {
|
||||
// stateful
|
||||
curNode *PDANode
|
||||
proc model.TextProcessor
|
||||
stateToNodeMap map[JSONState]*PDANode
|
||||
braceStack []rune
|
||||
}
|
||||
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
err = PreComputeValidStates(stateToNodeMap, proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
||||
// token, err := proc.Decode([]int32{int32(id)})
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
// fmt.Println("id", id, "node", node, "token", token)
|
||||
// }
|
||||
// time.Sleep(10 * time.Second)
|
||||
return &PushdownSampler{
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
stateToNodeMap: stateToNodeMap,
|
||||
braceStack: []rune{},
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
fmt.Println("sample:", s.curNode.State)
|
||||
|
||||
switch s.curNode.State {
|
||||
case StateInObjectEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
valid, err := s.proc.Encode("}")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for i := range logits {
|
||||
for _, token := range valid {
|
||||
if i != int(token) {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
// return logits, nil
|
||||
case StateTerminate:
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
// case StateInStringEnd:
|
||||
|
||||
// return logits, nil
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
|
||||
// TODO: need to handle end states and entering object case
|
||||
if s.curNode.State == StateInObjectEnd {
|
||||
fmt.Println("in object end")
|
||||
if len(s.braceStack) > 0 {
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
return nil
|
||||
}
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
// TODO: return here?
|
||||
}
|
||||
// need this cause there could be multiple transitions
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
for _, r := range mappedString {
|
||||
if r == rune('{') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
}
|
||||
}
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
fmt.Println("transitioning to", nextNode)
|
||||
s.curNode = s.stateToNodeMap[nextNode]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
for i := range logits {
|
||||
_, exists := node.MaskTokenIDToNode[int32(i)]
|
||||
if !exists {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
// TODO: add penalties for string \n stuff
|
Loading…
x
Reference in New Issue
Block a user