mirror of
https://github.com/ollama/ollama.git
synced 2025-04-08 03:48:21 +02:00
WIP
This commit is contained in:
parent
a2a73ce5e0
commit
e93db4d20e
@ -22,12 +22,18 @@ const (
|
||||
StateInFloat
|
||||
StateInBool
|
||||
StateInNull
|
||||
StateInArray
|
||||
StateInColon
|
||||
StateInComma
|
||||
StateInTab
|
||||
StateInSpace
|
||||
StateInObjSpace
|
||||
StateInList
|
||||
StateInListComma
|
||||
StateListEnd
|
||||
StateInListEnd
|
||||
StateInNewline
|
||||
StateInNumber
|
||||
StateInNumberEnd
|
||||
StateInStringEnd
|
||||
StateInObjectKeyEnd
|
||||
StateTerminate
|
||||
@ -42,42 +48,54 @@ func (s JSONState) String() string {
|
||||
return "StateInObject"
|
||||
case StateInObjectKey:
|
||||
return "StateInObjectKey"
|
||||
case StateInString:
|
||||
return "StateInString"
|
||||
case StateNewline:
|
||||
return "StateNewline"
|
||||
case StateTab:
|
||||
return "StateTab"
|
||||
case StateSpace:
|
||||
return "StateSpace"
|
||||
case StateInString:
|
||||
return "StateInString"
|
||||
case StateInInt:
|
||||
return "StateInInt"
|
||||
case StateInFloat:
|
||||
return "StateInFloat"
|
||||
case StateInColon:
|
||||
return "StateInColon"
|
||||
case StateInBool:
|
||||
return "StateInBool"
|
||||
case StateInNull:
|
||||
return "StateInNull"
|
||||
case StateInArray:
|
||||
return "StateInArray"
|
||||
case StateInObjectEnd:
|
||||
return "StateInObjectEnd"
|
||||
case StateInColon:
|
||||
return "StateInColon"
|
||||
case StateInComma:
|
||||
return "StateInComma"
|
||||
case StateInTab:
|
||||
return "StateInTab"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateInNewline:
|
||||
return "StateInNewline"
|
||||
case StateInSpace:
|
||||
return "StateInSpace"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInObjSpace:
|
||||
return "StateInObjSpace"
|
||||
case StateInList:
|
||||
return "StateInList"
|
||||
case StateInListComma:
|
||||
return "StateInListComma"
|
||||
case StateListEnd:
|
||||
return "StateListEnd"
|
||||
case StateInListEnd:
|
||||
return "StateInListEnd"
|
||||
case StateInNewline:
|
||||
return "StateInNewline"
|
||||
case StateInNumber:
|
||||
return "StateInNumber"
|
||||
case StateInNumberEnd:
|
||||
return "StateInNumberEnd"
|
||||
case StateInStringEnd:
|
||||
return "StateInStringEnd"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInObjectEnd:
|
||||
return "StateInObjectEnd"
|
||||
default:
|
||||
return fmt.Sprintf("Unknown state: %d", s)
|
||||
}
|
||||
@ -264,6 +282,7 @@ func getValidStates(node *Node) []int32 {
|
||||
|
||||
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
|
||||
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
|
||||
// todo: this can prob be more efficient
|
||||
for i := range logits {
|
||||
isValid := false
|
||||
for _, token := range validStates {
|
||||
|
@ -8,6 +8,15 @@ import (
|
||||
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||
|
||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||
|
||||
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||
|
||||
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||
|
||||
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||
|
||||
type PDANode struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDANode
|
||||
@ -52,6 +61,9 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
spaceNode := NewPDANode(StateInSpace)
|
||||
stateToNodeMap[StateInSpace] = spaceNode
|
||||
|
||||
spaceObjNode := NewPDANode(StateInObjSpace)
|
||||
stateToNodeMap[StateInObjSpace] = spaceObjNode
|
||||
|
||||
tabNode := NewPDANode(StateInTab)
|
||||
stateToNodeMap[StateInTab] = tabNode
|
||||
|
||||
@ -61,7 +73,31 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
stringEndNode := NewPDANode(StateInStringEnd)
|
||||
stateToNodeMap[StateInStringEnd] = stringEndNode
|
||||
|
||||
// terminateNode := NewNode(StateTerminate)
|
||||
listNode := NewPDANode(StateInList)
|
||||
stateToNodeMap[StateInList] = listNode
|
||||
|
||||
listCommaNode := NewPDANode(StateInListComma)
|
||||
stateToNodeMap[StateInListComma] = listCommaNode
|
||||
|
||||
listEndNode := NewPDANode(StateListEnd)
|
||||
stateToNodeMap[StateListEnd] = listEndNode
|
||||
|
||||
numberNode := NewPDANode(StateInNumber)
|
||||
stateToNodeMap[StateInNumber] = numberNode
|
||||
|
||||
boolNode := NewPDANode(StateInBool)
|
||||
stateToNodeMap[StateInBool] = boolNode
|
||||
|
||||
nullNode := NewPDANode(StateInNull)
|
||||
stateToNodeMap[StateInNull] = nullNode
|
||||
|
||||
// Defined with structured outputs only
|
||||
intNode := NewPDANode(StateInInt)
|
||||
stateToNodeMap[StateInInt] = intNode
|
||||
|
||||
// TODO:
|
||||
// consider adding a node to just point to values, could be good to compute that
|
||||
// mask rather than many different nodes
|
||||
|
||||
// Connect nodes
|
||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||
@ -69,34 +105,119 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
|
||||
objNode.TransitionEdges['"'] = objKeyNode
|
||||
objNode.TransitionEdges['\n'] = newlineNode
|
||||
// objNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||
newlineNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
tabNode.TransitionEdges['"'] = objKeyNode
|
||||
|
||||
spaceNode.TransitionEdges['"'] = stringNode
|
||||
// tabNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||
objKeyNode.TransitionEdges[' '] = spaceNode
|
||||
// objKeyNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||
objEndNode.TransitionEdges[' '] = spaceNode
|
||||
|
||||
colonNode.TransitionEdges['"'] = stringNode
|
||||
// where values should be
|
||||
// this could be combined but the probs might change, we're alr doing a skip ahead
|
||||
colonNode.TransitionEdges[' '] = spaceNode
|
||||
|
||||
// Leads to a value
|
||||
spaceNode.TransitionEdges['"'] = stringNode
|
||||
spaceNode.TransitionEdges['['] = listNode
|
||||
spaceNode.TransitionEdges['{'] = objNode
|
||||
|
||||
for _, r := range validNumberRunes {
|
||||
spaceNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
spaceNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
|
||||
for _, r := range validNullRunes {
|
||||
spaceNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// Values
|
||||
// string node
|
||||
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||
stringNode.TransitionEdges['"'] = stringEndNode
|
||||
|
||||
stringEndNode.TransitionEdges[','] = commaNode
|
||||
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||
stringEndNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||
// number node
|
||||
for _, r := range validNumberRunes {
|
||||
numberNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
numberNode.TransitionEdges[','] = commaNode
|
||||
numberNode.TransitionEdges['}'] = objEndNode
|
||||
numberNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
|
||||
// list node
|
||||
listNode.TransitionEdges[','] = commaNode
|
||||
listNode.TransitionEdges['"'] = stringNode
|
||||
// squash states to a value
|
||||
for _, r := range validNumberRunes {
|
||||
listNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
listNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
for _, r := range validNullRunes {
|
||||
listNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
nullNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
nullNode.TransitionEdges[','] = commaNode
|
||||
nullNode.TransitionEdges['}'] = objEndNode
|
||||
nullNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
listCommaNode.TransitionEdges['"'] = stringNode
|
||||
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||
listCommaNode.TransitionEdges['{'] = objNode
|
||||
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||
|
||||
for _, r := range validNumberRunes {
|
||||
listCommaNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
listCommaNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
for _, r := range validNullRunes {
|
||||
listCommaNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
boolNode.TransitionEdges['}'] = objEndNode
|
||||
boolNode.TransitionEdges[']'] = listEndNode
|
||||
boolNode.TransitionEdges[','] = commaNode
|
||||
|
||||
listEndNode.TransitionEdges['}'] = objEndNode
|
||||
listEndNode.TransitionEdges[','] = commaNode
|
||||
|
||||
commaNode.TransitionEdges['{'] = objNode
|
||||
commaNode.TransitionEdges['\n'] = newlineNode
|
||||
commaNode.TransitionEdges['\t'] = tabNode
|
||||
commaNode.TransitionEdges['"'] = objKeyNode
|
||||
commaNode.TransitionEdges[' '] = spaceObjNode
|
||||
|
||||
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
||||
|
||||
return startNode, stateToNodeMap, nil
|
||||
}
|
||||
|
@ -3,6 +3,8 @@ package sample
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@ -13,9 +15,17 @@ type PushdownSampler struct {
|
||||
proc model.TextProcessor
|
||||
stateToNodeMap map[JSONState]*PDANode
|
||||
braceStack []rune
|
||||
stateCounter uint32
|
||||
}
|
||||
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
start := time.Now()
|
||||
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
@ -24,6 +34,11 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
// for id, node := range stateToNodeMap[StateInComma].MaskTokenIDToNode {
|
||||
// token, err := proc.Decode([]int32{int32(id)})
|
||||
// if err != nil {
|
||||
@ -37,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
proc: proc,
|
||||
stateToNodeMap: stateToNodeMap,
|
||||
braceStack: []rune{},
|
||||
stateCounter: 0,
|
||||
}
|
||||
}
|
||||
|
||||
@ -69,7 +85,19 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
// return logits, nil
|
||||
|
||||
case StateInComma:
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||
fmt.Println("switching to list comma", s.curNode.State)
|
||||
}
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateTerminate:
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
@ -80,9 +108,6 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
// case StateInStringEnd:
|
||||
|
||||
// return logits, nil
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
@ -96,7 +121,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
|
||||
// TODO: need to handle end states and entering object case
|
||||
// TODO: need to handle end states and entering object case, and list case
|
||||
if s.curNode.State == StateInObjectEnd {
|
||||
fmt.Println("in object end")
|
||||
if len(s.braceStack) > 0 {
|
||||
@ -111,25 +136,45 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// TODO: should force closing for all braces
|
||||
for _, r := range mappedString {
|
||||
if r == rune('{') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('[') {
|
||||
s.braceStack = append(s.braceStack, r)
|
||||
}
|
||||
if r == rune('}') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('{') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
fmt.Println("popping brace stack", s.braceStack)
|
||||
}
|
||||
|
||||
if r == rune(']') {
|
||||
if len(s.braceStack) == 0 || s.braceStack[len(s.braceStack)-1] != rune('[') {
|
||||
return fmt.Errorf("unmatched closing brace")
|
||||
}
|
||||
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
||||
fmt.Println("popping brace stack", s.braceStack)
|
||||
}
|
||||
}
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
fmt.Println("transitioning to", nextNode)
|
||||
s.curNode = s.stateToNodeMap[nextNode]
|
||||
fmt.Println("transitioning to", nextNodeState)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNodeState == s.curNode.State {
|
||||
s.stateCounter++
|
||||
} else {
|
||||
s.stateCounter = 0
|
||||
}
|
||||
s.curNode = s.stateToNodeMap[nextNodeState]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user