From a2a73ce5e059fb58159eea770f5dad8e1cfeb703 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 23 Jan 2025 20:21:50 -0800 Subject: [PATCH] wip! --- sample/fast_json.go | 18 +++- sample/pushdown_automata.go | 175 ++++++++++++++++++++++++++++++++++++ sample/pushdown_runner.go | 147 ++++++++++++++++++++++++++++++ 3 files changed, 336 insertions(+), 4 deletions(-) create mode 100644 sample/pushdown_automata.go create mode 100644 sample/pushdown_runner.go diff --git a/sample/fast_json.go b/sample/fast_json.go index 886efb1f7..b5f9088cc 100644 --- a/sample/fast_json.go +++ b/sample/fast_json.go @@ -25,10 +25,13 @@ const ( StateInArray StateInColon StateInComma + StateInTab + StateInSpace + StateInNewline StateInStringEnd StateInObjectKeyEnd StateTerminate - StateEnd + StateInObjectEnd ) func (s JSONState) String() string { @@ -59,12 +62,18 @@ func (s JSONState) String() string { return "StateInNull" case StateInArray: return "StateInArray" - case StateEnd: - return "StateEnd" + case StateInObjectEnd: + return "StateInObjectEnd" case StateInComma: return "StateInComma" + case StateInTab: + return "StateInTab" case StateInObjectKeyEnd: return "StateInObjectKeyEnd" + case StateInNewline: + return "StateInNewline" + case StateInSpace: + return "StateInSpace" case StateTerminate: return "StateTerminate" case StateInStringEnd: @@ -124,13 +133,14 @@ func (s *JSONSampler) UpdateState(tokenSlice []int32) error { // fmt.Println("tokenSlice", tokenSlice) // todo: account for strings here + objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc) if err != nil { return err } // only move to terminate state if stack is empty - if s.curNode.State == StateEnd { + if s.curNode.State == StateInObjectEnd { fmt.Println("debug: node.State", s.curNode.State) if len(s.stack) > 0 { s.stack = s.stack[:len(s.stack)-1] diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go new file mode 100644 index 000000000..111a91037 --- /dev/null +++ b/sample/pushdown_automata.go @@ -0,0 +1,175 @@ +package sample + +import ( + "slices" + + "github.com/ollama/ollama/model" +) + +var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','} + +type PDANode struct { + State JSONState + TransitionEdges map[rune]*PDANode + MaskTokenIDToNode map[int32]JSONState +} + +func NewPDANode(state JSONState) *PDANode { + return &PDANode{ + State: state, + TransitionEdges: make(map[rune]*PDANode), + MaskTokenIDToNode: make(map[int32]JSONState), + } +} + +func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) { + stateToNodeMap := make(map[JSONState]*PDANode) + + startNode := NewPDANode(StateStart) + stateToNodeMap[StateStart] = startNode + + objNode := NewPDANode(StateInObject) + stateToNodeMap[StateInObject] = objNode + + objEndNode := NewPDANode(StateInObjectEnd) + stateToNodeMap[StateInObjectEnd] = objEndNode + + objKeyNode := NewPDANode(StateInObjectKey) + stateToNodeMap[StateInObjectKey] = objKeyNode + + objKeyEndNode := NewPDANode(StateInObjectKeyEnd) + stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode + + colonNode := NewPDANode(StateInColon) + stateToNodeMap[StateInColon] = colonNode + + commaNode := NewPDANode(StateInComma) + stateToNodeMap[StateInComma] = commaNode + + newlineNode := NewPDANode(StateInNewline) + stateToNodeMap[StateInNewline] = newlineNode + + spaceNode := NewPDANode(StateInSpace) + stateToNodeMap[StateInSpace] = spaceNode + + tabNode := NewPDANode(StateInTab) + stateToNodeMap[StateInTab] = tabNode + + stringNode := NewPDANode(StateInString) + stateToNodeMap[StateInString] = stringNode + + stringEndNode := NewPDANode(StateInStringEnd) + stateToNodeMap[StateInStringEnd] = stringEndNode + + // terminateNode := NewNode(StateTerminate) + + // Connect nodes + // TODO: if all are single tokens then this can just be connected instead of defining the token + startNode.TransitionEdges['{'] = objNode + + objNode.TransitionEdges['"'] = objKeyNode + objNode.TransitionEdges['\n'] = newlineNode + + newlineNode.TransitionEdges['"'] = objKeyNode + newlineNode.TransitionEdges['\t'] = tabNode + + tabNode.TransitionEdges['"'] = objKeyNode + + spaceNode.TransitionEdges['"'] = stringNode + + objKeyNode.TransitionEdges[rune(-1)] = objKeyNode + objKeyNode.TransitionEdges['"'] = objKeyEndNode + objKeyNode.TransitionEdges[' '] = spaceNode + // objKeyNode.TransitionEdges['\t'] = tabNode + + objKeyEndNode.TransitionEdges[':'] = colonNode + + colonNode.TransitionEdges['"'] = stringNode + colonNode.TransitionEdges[' '] = spaceNode + + stringNode.TransitionEdges[rune(-1)] = stringNode + stringNode.TransitionEdges['"'] = stringEndNode + + stringEndNode.TransitionEdges[','] = commaNode + stringEndNode.TransitionEdges['}'] = objEndNode + + commaNode.TransitionEdges['{'] = objNode + commaNode.TransitionEdges['\n'] = newlineNode + commaNode.TransitionEdges['\t'] = tabNode + commaNode.TransitionEdges['"'] = objKeyNode + + return startNode, stateToNodeMap, nil +} + +func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { + + vocab := proc.GetVocabulary() + + decodedToks := make([]string, len(vocab.Values)) + for i := range vocab.Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return err + } + decodedToks[i] = token + } + + var err error + for _, node := range stateToNodeMap { + for i := range vocab.Values { + token := decodedToks[i] + // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON + if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" { + continue + } + valid := true + curNode := node + consumedSpecialRunes := make(map[rune]bool) + for _, r := range token { + valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes) + if err != nil { + return err + } + if !valid { + break + } + } + if valid { + node.MaskTokenIDToNode[int32(i)] = curNode.State + } + } + } + return nil +} + +func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) { + if consumedSpecialRunes[r] { + return false, nil, nil + } + + specialRune := slices.Contains(stringInvalidRunes, r) + if specialRune { + if curNode.State == StateInString || curNode.State == StateInObjectKey { + return false, nil, nil + } + } + + // Check for specific rune transition + if nextNode, ok := curNode.TransitionEdges[r]; ok { + if specialRune { + if curNode.State == nextNode.State { + return false, nil, nil + } + // fmt.Println("special rune", r, "consumed") + consumedSpecialRunes[r] = true + } + return true, nextNode, nil + } + + // Check for sentinel value - if present, any rune is valid + if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok { + return true, nextNode, nil + } + + return false, nil, nil +} diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go new file mode 100644 index 000000000..4d27f93dc --- /dev/null +++ b/sample/pushdown_runner.go @@ -0,0 +1,147 @@ +package sample + +import ( + "fmt" + "math" + + "github.com/ollama/ollama/model" +) + +type PushdownSampler struct { + // stateful + curNode *PDANode + proc model.TextProcessor + stateToNodeMap map[JSONState]*PDANode + braceStack []rune +} + +func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { + startNode, stateToNodeMap, err := BuildGraph(proc) + if err != nil { + panic(err) + } + err = PreComputeValidStates(stateToNodeMap, proc) + if err != nil { + panic(err) + } + // for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode { + // token, err := proc.Decode([]int32{int32(id)}) + // if err != nil { + // panic(err) + // } + // fmt.Println("id", id, "node", node, "token", token) + // } + // time.Sleep(10 * time.Second) + return &PushdownSampler{ + curNode: startNode, + proc: proc, + stateToNodeMap: stateToNodeMap, + braceStack: []rune{}, + } +} + +func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { + fmt.Println("sample:", s.curNode.State) + + switch s.curNode.State { + case StateInObjectEnd: + // force finish if no braces left + if len(s.braceStack) == 0 { + s.curNode = NewPDANode(StateTerminate) + for i := range logits { + if s.proc.Is(uint32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = math.NaN() + } + } + return logits, nil + } + valid, err := s.proc.Encode("}") + if err != nil { + return nil, err + } + for i := range logits { + for _, token := range valid { + if i != int(token) { + logits[i] = math.NaN() + } + } + } + return logits, nil + // return logits, nil + case StateTerminate: + for i := range logits { + if s.proc.Is(uint32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = math.NaN() + } + } + return logits, nil + + // case StateInStringEnd: + + // return logits, nil + default: + fmt.Println("masking logits current state", s.curNode.State) + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + } +} + +func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { + fmt.Println("update state", s.curNode.State) + + // TODO: need to handle end states and entering object case + if s.curNode.State == StateInObjectEnd { + fmt.Println("in object end") + if len(s.braceStack) > 0 { + s.braceStack = s.braceStack[:len(s.braceStack)-1] + return nil + } + s.curNode = NewPDANode(StateTerminate) + // TODO: return here? + } + // need this cause there could be multiple transitions + mappedString, err := s.proc.Decode(tokenSlice) + if err != nil { + return err + } + for _, r := range mappedString { + if r == rune('{') { + s.braceStack = append(s.braceStack, r) + } + if r == rune('}') { + if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') { + return fmt.Errorf("unmatched closing brace") + } + s.braceStack = s.braceStack[:len(s.braceStack)-1] + } + } + for _, tokenID := range tokenSlice { + // transition to the next node + nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] + if !ok { + return fmt.Errorf("invalid token: %q", mappedString) + } + fmt.Println("transitioning to", nextNode) + s.curNode = s.stateToNodeMap[nextNode] + } + return nil +} + +func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { + for i := range logits { + _, exists := node.MaskTokenIDToNode[int32(i)] + if !exists { + logits[i] = math.NaN() + } + } + return logits, nil +} + +// TODO: add penalties for string \n stuff