diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index d9ccc151e..366d830bc 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -1,6 +1,7 @@ package sample import ( + "fmt" "slices" "github.com/ollama/ollama/model" @@ -34,7 +35,7 @@ Key JSON rules to consider: */ // TODO: / should be valid but an escape character -var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'} +var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} @@ -109,12 +110,12 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] - b.addValueConnections(stateToNodeMap[StateInColon]) + addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap) // Leads to a value stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject] - b.addValueConnections(stateToNodeMap[StateInSpaceToValue]) + addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap) stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] // Values @@ -123,7 +124,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd] // String end node - b.addEnds(stateToNodeMap[StateInStringEnd]) + addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap) stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] @@ -132,7 +133,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { for _, r := range validNumberRunes { stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] } - b.addEnds(stateToNodeMap[StateInNumber]) + addEnds(stateToNodeMap[StateInNumber], stateToNodeMap) stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] @@ -150,13 +151,13 @@ func (b *PDAGraphBuilder) BuildGraph() error { // empty list stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] - b.addValueConnections(stateToNodeMap[StateInList]) + addValueConnections(stateToNodeMap[StateInList], stateToNodeMap) // null node for _, r := range validNullRunes { stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull] } - b.addEnds(stateToNodeMap[StateInNull]) + addEnds(stateToNodeMap[StateInNull], stateToNodeMap) stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] @@ -165,7 +166,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] - b.addValueConnections(stateToNodeMap[StateInListComma]) + addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap) // list object end stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] @@ -178,7 +179,7 @@ func (b *PDAGraphBuilder) BuildGraph() error { stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] } stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] - b.addEnds(stateToNodeMap[StateInBool]) + addEnds(stateToNodeMap[StateInBool], stateToNodeMap) stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] @@ -201,21 +202,21 @@ func (b *PDAGraphBuilder) BuildGraph() error { return nil } -func (b *PDAGraphBuilder) addEnds(node *PDA) { - node.TransitionEdges[','] = b.stateToNodeMap[StateInComma] - node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd] - node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd] +func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) { + node.TransitionEdges[','] = stateToNodeMap[StateInComma] + node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd] } -func (b *PDAGraphBuilder) addValueConnections(node *PDA) { - node.TransitionEdges['"'] = b.stateToNodeMap[StateInString] +func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) { + node.TransitionEdges['"'] = stateToNodeMap[StateInString] for _, r := range validNumberRunes { - node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber] + node.TransitionEdges[r] = stateToNodeMap[StateInNumber] } // TODO(parthsareen): force the output and shift similar to structured outputs - node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool] - node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool] - node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull] + node.TransitionEdges['t'] = stateToNodeMap[StateInBool] + node.TransitionEdges['f'] = stateToNodeMap[StateInBool] + node.TransitionEdges['n'] = stateToNodeMap[StateInNull] } func (b *PDAGraphBuilder) preComputeValidStates() error { @@ -228,6 +229,9 @@ func (b *PDAGraphBuilder) preComputeValidStates() error { } func (b *PDAGraphBuilder) CreateMask(node *PDA) error { + if node == nil { + return fmt.Errorf("node cannot be nil") + } for i := range b.decodedToks { token := b.decodedToks[i] // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON @@ -264,6 +268,7 @@ func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA // Check for specific rune transition if nextNode, ok := curNode.TransitionEdges[r]; ok { + // fmt.Println("next node", nextNode) if specialRune { if curNode.State == nextNode.State { return nil, false diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index adf5ce996..2ae16d062 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -17,9 +17,13 @@ type JSONSampler struct { } func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) { + if proc == nil { + return nil, fmt.Errorf("TextProcessor cannot be nil") + } + pdaSampler, err := NewPushdownSampler(proc) if err != nil { - return nil, err + return nil, fmt.Errorf("failed to create PushdownSampler: %w", err) } if schema == nil {