This commit is contained in:
ParthSareen 2025-01-30 15:05:25 -08:00
parent 198fde82aa
commit c56a8b7749
11 changed files with 316 additions and 650 deletions

View File

@ -104,6 +104,8 @@ func temp() error {
}
}
pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
var stringBuffer string
var offset int
for range args.n {
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
@ -118,7 +120,10 @@ func temp() error {
}
// do sampling
f64s, err = sample.Sample(f64s, sample.Greedy())
// []ints back
// ints map to sampled logits
f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
if err != nil {
return err
}
@ -129,6 +134,7 @@ func temp() error {
outputIDs = append(outputIDs, int32(f64))
}
}
pdaSampler.UpdateState(outputIDs)
if len(outputIDs) == 0 {
break
@ -141,8 +147,9 @@ func temp() error {
return err
}
fmt.Print(s)
// fmt.Print(s)
stringBuffer += s
fmt.Println("--- stringBuffer", stringBuffer)
inputIDs = append(inputIDs, outputIDs...)
if args.cache {
offset = len(inputIDs) - 1

1
model/cmd/test.go Normal file
View File

@ -0,0 +1 @@
package main

View File

@ -21,6 +21,7 @@ type TextProcessor interface {
Encode(string) ([]int32, error)
Decode([]int32) (string, error)
Is(uint32, Special) bool
GetVocabulary() *Vocabulary
}
type Vocabulary struct {
@ -104,6 +105,10 @@ type BytePairEncoding struct {
*Vocabulary
}
// GetVocabulary returns the vocabulary backing this byte-pair encoder.
func (bpe BytePairEncoding) GetVocabulary() *Vocabulary {
	return bpe.Vocabulary
}
func (bpe BytePairEncoding) split(s string) ([]string, error) {
re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
if err != nil {

View File

@ -1,11 +1,7 @@
package sample
import (
"errors"
"fmt"
"math"
"github.com/ollama/ollama/model"
)
type JSONState int
@ -136,219 +132,3 @@ func (s JSONState) String() string {
return fmt.Sprintf("Unknown state: %d", s)
}
}
// JSONSampler constrains token sampling to valid JSON by walking a
// token-labeled state machine as tokens are emitted.
type JSONSampler struct {
	curNode *Node // current state-machine node
	proc    model.TextProcessor
	stack   []*Node // open-object stack used for nesting
	// bracketCounter is initialized to zero by the constructor.
	// NOTE(review): not visibly read or written by the methods in this
	// file — confirm whether it is still needed.
	bracketCounter int
}
// NewJSONSampler builds the JSON token state machine for proc and
// returns a sampler positioned at the machine's start node.
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
	// fmt.Println("Creating new JSON sampler")
	startNode, err := buildStateMachine(proc)
	if err != nil {
		return nil, err
	}
	js := &JSONSampler{
		curNode:        startNode,
		proc:           proc,
		stack:          []*Node{},
		bracketCounter: 0,
	}
	return js, nil
}
// isTokenSubset reports whether subset is a multiset subset of
// superset: every token ID in subset must occur in superset at least
// as many times as it occurs in subset.
func isTokenSubset(subset, superset []int32) bool {
	// Count what superset offers, then spend those counts on subset.
	remaining := make(map[int32]int, len(superset))
	for _, id := range superset {
		remaining[id]++
	}
	for _, id := range subset {
		remaining[id]--
		if remaining[id] < 0 {
			return false
		}
	}
	return true
}
// UpdateState advances the sampler's state machine with the tokens
// just sampled. It first handles object-end stack pops, then tries
// explicit token-labeled edges, then sentinel ("accept any") edges,
// and finally errors if nothing matched.
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
	// fmt.Printf("Updating state with token: %v\n", tokenSlice)
	// fmt.Printf("Current state: %s\n", s.curNode.State)
	// fmt.Println("tokenSlice", tokenSlice)
	// todo: account for strings here

	// Token IDs that can open an object; matching tokens push the new
	// node onto the nesting stack below.
	objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
	if err != nil {
		return err
	}
	// only move to terminate state if stack is empty
	if s.curNode.State == StateInObjectEnd {
		fmt.Println("debug: node.State", s.curNode.State)
		if len(s.stack) > 0 {
			// Pop the enclosing object. NOTE(review): curNode itself is
			// not changed on pop — confirm that is intended.
			s.stack = s.stack[:len(s.stack)-1]
			fmt.Println("popped and cur state", s.curNode.State)
			return nil
		}
		return nil
	}
	// First pass: explicit token-labeled transition edges.
	for node, edge := range s.curNode.TransitionEdges {
		for _, validToken := range edge {
			if isTokenSubset(tokenSlice, validToken) {
				s.curNode = node
				// If the sampled tokens also open an object, push it.
				for _, token := range objectTokens {
					if isTokenSubset(tokenSlice, token) {
						fmt.Println("Appending to stack", s.curNode.State)
						s.stack = append(s.stack, s.curNode)
					}
				}
				// fmt.Printf("Transitioned to state: %s\n", node.State)
				return nil
			}
		}
	}
	// Second pass: sentinel edges accept any token (-1 = any, -2 = int).
	// NOTE(review): Go precedence makes this condition
	// (len(validToken) == 1 && validToken[0] == -1) || validToken[0] == -2,
	// so the -2 arm runs without the length guard and would panic on an
	// empty validToken — confirm intent and add a guard if needed.
	for node, edge := range s.curNode.TransitionEdges {
		for _, validToken := range edge {
			if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
				s.curNode = node
				// fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
				return nil
			}
		}
	}
	fmt.Println("invalid token ", tokenSlice)
	dec, err := s.proc.Decode(tokenSlice)
	if err != nil {
		return err
	}
	fmt.Println("decoded token ", dec)
	return errors.New("invalid token")
}
// Sample constrains logits to tokens that are valid in the sampler's
// current JSON parse state, setting disallowed entries to NaN.
// Returns the (possibly modified) logits slice.
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
	fmt.Printf("Sampling in state: %s\n", s.curNode.State)
	var err error
	switch s.curNode.State {
	case StateTerminate:
		// Parsing complete: only EOS may be emitted.
		for i := range logits {
			if s.proc.Is(uint32(i), model.SpecialEOS) {
				logits[i] = 1.0
			} else {
				logits[i] = math.NaN()
			}
		}
		return logits, nil

	case StateInInt:
		validStates := []int32{}
		minus, err := s.proc.Encode("-")
		if err != nil {
			return nil, err
		}
		digits := make([][]int32, 10)
		for i := 0; i < 10; i++ {
			digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
			if err != nil {
				return nil, err
			}
		}
		// Allow "-" and digits 0-9 at start. Only single-token
		// encodings are usable as direct logit indices.
		for i := range logits {
			for _, d := range digits {
				if len(d) == 1 && int32(i) == d[0] {
					validStates = append(validStates, int32(i))
				}
			}
			if len(minus) == 1 && int32(i) == minus[0] {
				validStates = append(validStates, int32(i))
			}
		}
		// BUG FIX: validStates was previously computed but never
		// applied, so the integer constraint was a no-op. Apply it.
		logits, err = s.maskLogits(logits, validStates)
		if err != nil {
			return nil, err
		}
		return logits, nil

	case StateInString:
		// Penalize newline-bearing continuations inside strings.
		penalizeNewlineVariants := []string{"\n", " \"\n"}
		penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
		if err != nil {
			return nil, err
		}
		// NOTE(review): 702 is a hard-coded token ID — confirm it is
		// still correct for the target vocabulary.
		penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
		logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
		if err != nil {
			return nil, err
		}
		validStates := getValidStates(s.curNode)
		logits, err = s.maskLogits(logits, validStates)
		if err != nil {
			return nil, err
		}
		return logits, nil

	default:
		// Generic case: mask to the tokens reachable from this node.
		validStates := getValidStates(s.curNode)
		logits, err = s.maskLogits(logits, validStates)
		if err != nil {
			return nil, err
		}
		return logits, nil
	}
}
// getValidStates flattens every token ID on node's outgoing
// transition edges into a single slice of candidate token IDs.
func getValidStates(node *Node) []int32 {
	var ids []int32
	for _, tokens := range node.TransitionEdges {
		for _, tok := range tokens {
			ids = append(ids, tok...)
		}
	}
	return ids
}
// maskLogits sets every logit whose token ID is not in validStates to
// NaN so it cannot be sampled. A sentinel value of -1 anywhere in
// validStates means "all tokens valid": the logits are returned
// unmodified (matching the original behavior, where reaching the
// sentinel during the scan returned before any masking could stick).
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
	// PERF: build the membership set once instead of rescanning
	// validStates for every logit (was O(len(logits)*len(validStates))).
	valid := make(map[int32]struct{}, len(validStates))
	for _, token := range validStates {
		if token == -1 {
			// Sentinel: everything is allowed.
			return logits, nil
		}
		valid[token] = struct{}{}
	}
	for i := range logits {
		if _, ok := valid[int32(i)]; !ok {
			logits[i] = math.NaN()
		}
	}
	return logits, nil
}
// maskSpecificLogits sets to NaN every logit whose token ID appears in
// any of the given token sequences, preventing those tokens from being
// sampled.
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
	// Collect all banned IDs once, then apply in a single pass.
	banned := make(map[int32]struct{})
	for _, seq := range tokensToMask {
		for _, id := range seq {
			banned[id] = struct{}{}
		}
	}
	for i := range logits {
		if _, ok := banned[int32(i)]; ok {
			logits[i] = math.NaN()
		}
	}
	return logits, nil
}

296
sample/hid.txt Normal file
View File

@ -0,0 +1,296 @@
package sample
import (
"slices"
"github.com/ollama/ollama/model"
)
// Rune classes used when walking vocabulary tokens through the JSON
// grammar graph.
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','} // structural runes, not raw string content

var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
// Runes of "true"/"false"; duplicates are harmless in set-style use.
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
// PDANode is one state in the pushdown-automaton JSON grammar graph.
type PDANode struct {
	State JSONState
	// TransitionEdges maps an input rune to the next state; the
	// sentinel key rune(-1) acts as a wildcard accepting any rune.
	TransitionEdges map[rune]*PDANode
	// MaskTokenIDToNode caches, per vocabulary token ID, the state
	// reached after consuming that whole token from this node
	// (populated by PreComputeValidStates).
	MaskTokenIDToNode map[int32]JSONState
}
// NewPDANode returns a PDANode for the given state with empty
// transition and token-mask tables ready to be populated.
func NewPDANode(state JSONState) *PDANode {
	node := &PDANode{State: state}
	node.TransitionEdges = make(map[rune]*PDANode)
	node.MaskTokenIDToNode = make(map[int32]JSONState)
	return node
}
// BuildGraph constructs the rune-level pushdown-automaton graph for
// JSON generation. It returns the start node and a map from each
// JSONState to its node so callers can look states up directly.
// The proc parameter is currently unused here; token masks are
// attached later by PreComputeValidStates.
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
	stateToNodeMap := make(map[JSONState]*PDANode)

	startNode := NewPDANode(StateStart)
	stateToNodeMap[StateStart] = startNode
	objNode := NewPDANode(StateInObject)
	stateToNodeMap[StateInObject] = objNode
	objEndNode := NewPDANode(StateInObjectEnd)
	stateToNodeMap[StateInObjectEnd] = objEndNode
	objKeyNode := NewPDANode(StateInObjectKey)
	stateToNodeMap[StateInObjectKey] = objKeyNode
	objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
	stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
	colonNode := NewPDANode(StateInColon)
	stateToNodeMap[StateInColon] = colonNode
	commaNode := NewPDANode(StateInComma)
	stateToNodeMap[StateInComma] = commaNode
	newlineNode := NewPDANode(StateInNewline)
	stateToNodeMap[StateInNewline] = newlineNode
	spaceNode := NewPDANode(StateInSpace)
	stateToNodeMap[StateInSpace] = spaceNode
	spaceObjNode := NewPDANode(StateInObjSpace)
	stateToNodeMap[StateInObjSpace] = spaceObjNode
	tabNode := NewPDANode(StateInTab)
	stateToNodeMap[StateInTab] = tabNode
	stringNode := NewPDANode(StateInString)
	stateToNodeMap[StateInString] = stringNode
	stringEndNode := NewPDANode(StateInStringEnd)
	stateToNodeMap[StateInStringEnd] = stringEndNode
	listNode := NewPDANode(StateInList)
	stateToNodeMap[StateInList] = listNode
	listCommaNode := NewPDANode(StateInListComma)
	stateToNodeMap[StateInListComma] = listCommaNode
	listEndNode := NewPDANode(StateListEnd)
	stateToNodeMap[StateListEnd] = listEndNode
	numberNode := NewPDANode(StateInNumber)
	stateToNodeMap[StateInNumber] = numberNode
	boolNode := NewPDANode(StateInBool)
	stateToNodeMap[StateInBool] = boolNode
	nullNode := NewPDANode(StateInNull)
	stateToNodeMap[StateInNull] = nullNode

	// Defined with structured outputs only
	intNode := NewPDANode(StateInInt)
	stateToNodeMap[StateInInt] = intNode

	// TODO:
	// consider adding a node to just point to values, could be good to compute that
	// mask rather than many different nodes

	// Connect nodes
	// TODO: if all are single tokens then this can just be connected instead of defining the token
	startNode.TransitionEdges['{'] = objNode

	objNode.TransitionEdges['"'] = objKeyNode
	objNode.TransitionEdges['\n'] = newlineNode
	// objNode.TransitionEdges['\t'] = tabNode

	newlineNode.TransitionEdges['"'] = objKeyNode
	newlineNode.TransitionEdges['\t'] = tabNode

	tabNode.TransitionEdges['"'] = objKeyNode
	// tabNode.TransitionEdges['\t'] = tabNode

	// rune(-1) wildcard: any rune keeps us inside the key.
	objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
	objKeyNode.TransitionEdges['"'] = objKeyEndNode

	objKeyEndNode.TransitionEdges[':'] = colonNode
	objEndNode.TransitionEdges[' '] = spaceNode

	// where values should be
	// this could be combined but the probs might change, we're alr doing a skip ahead
	colonNode.TransitionEdges[' '] = spaceNode

	// Leads to a value
	spaceNode.TransitionEdges['"'] = stringNode
	spaceNode.TransitionEdges['['] = listNode
	spaceNode.TransitionEdges['{'] = objNode

	// NOTE(review): validNumberRunes contains 'e'/'E' and '-' which
	// overlap validBoolRunes ('e') below — the later loop overwrites
	// the earlier edge for shared runes. Confirm the intended winner.
	for _, r := range validNumberRunes {
		spaceNode.TransitionEdges[r] = numberNode
	}
	for _, r := range validBoolRunes {
		spaceNode.TransitionEdges[r] = boolNode
	}
	for _, r := range validNullRunes {
		spaceNode.TransitionEdges[r] = nullNode
	}

	// Values
	// string node
	stringNode.TransitionEdges[rune(-1)] = stringNode
	stringNode.TransitionEdges['"'] = stringEndNode

	stringEndNode.TransitionEdges[','] = commaNode
	stringEndNode.TransitionEdges['}'] = objEndNode
	stringEndNode.TransitionEdges[']'] = listEndNode

	// TODO: add counters for allowable number of decimals, e, E, etc
	// number node
	for _, r := range validNumberRunes {
		numberNode.TransitionEdges[r] = numberNode
	}
	numberNode.TransitionEdges[','] = commaNode
	numberNode.TransitionEdges['}'] = objEndNode
	numberNode.TransitionEdges[']'] = listEndNode

	// NOTE(review): this bool self-loop is wired again further down —
	// the repeat is redundant but harmless.
	for _, r := range validBoolRunes {
		boolNode.TransitionEdges[r] = boolNode
	}

	// list node
	listNode.TransitionEdges[','] = commaNode
	listNode.TransitionEdges['"'] = stringNode
	// squash states to a value
	for _, r := range validNumberRunes {
		listNode.TransitionEdges[r] = numberNode
	}
	for _, r := range validBoolRunes {
		listNode.TransitionEdges[r] = boolNode
	}
	for _, r := range validNullRunes {
		listNode.TransitionEdges[r] = nullNode
	}

	// null node
	for _, r := range validNullRunes {
		nullNode.TransitionEdges[r] = nullNode
	}
	nullNode.TransitionEdges[','] = commaNode
	nullNode.TransitionEdges['}'] = objEndNode
	nullNode.TransitionEdges[']'] = listEndNode

	// list comma
	// should point to values
	listCommaNode.TransitionEdges['"'] = stringNode
	listCommaNode.TransitionEdges[' '] = listCommaNode
	listCommaNode.TransitionEdges['{'] = objNode
	listCommaNode.TransitionEdges['\n'] = newlineNode
	for _, r := range validNumberRunes {
		listCommaNode.TransitionEdges[r] = numberNode
	}
	for _, r := range validBoolRunes {
		listCommaNode.TransitionEdges[r] = boolNode
	}
	for _, r := range validNullRunes {
		listCommaNode.TransitionEdges[r] = nullNode
	}

	// bool node
	for _, r := range validBoolRunes {
		boolNode.TransitionEdges[r] = boolNode
	}
	boolNode.TransitionEdges['}'] = objEndNode
	boolNode.TransitionEdges[']'] = listEndNode
	boolNode.TransitionEdges[','] = commaNode

	listEndNode.TransitionEdges['}'] = objEndNode
	listEndNode.TransitionEdges[','] = commaNode

	commaNode.TransitionEdges['{'] = objNode
	commaNode.TransitionEdges['\n'] = newlineNode
	commaNode.TransitionEdges['\t'] = tabNode
	commaNode.TransitionEdges['"'] = objKeyNode
	commaNode.TransitionEdges[' '] = spaceObjNode

	spaceObjNode.TransitionEdges['"'] = objKeyNode

	return startNode, stateToNodeMap, nil
}
// PreComputeValidStates walks every vocabulary token through the graph
// from every node and caches, in node.MaskTokenIDToNode, the state
// reached when all of the token's runes can be consumed from that
// node. Sampling can then consult the cached mask instead of
// re-walking the graph per token.
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
	vocab := proc.GetVocabulary()

	// Decode the whole vocabulary once up front rather than per node.
	decodedToks := make([]string, len(vocab.Values))
	for i := range vocab.Values {
		token, err := proc.Decode([]int32{int32(i)})
		if err != nil {
			return err
		}
		decodedToks[i] = token
	}

	var err error
	for _, node := range stateToNodeMap {
		for i := range vocab.Values {
			token := decodedToks[i]
			// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
			if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
				continue
			}
			// Simulate consuming the token rune-by-rune from this node;
			// the token is valid only if every rune transitions.
			valid := true
			curNode := node
			consumedSpecialRunes := make(map[rune]bool)
			for _, r := range token {
				valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
				if err != nil {
					return err
				}
				if !valid {
					break
				}
			}
			if valid {
				node.MaskTokenIDToNode[int32(i)] = curNode.State
			}
		}
	}
	return nil
}
// isRuneValid reports whether rune r can be consumed from curNode and,
// on success, returns the node reached. consumedSpecialRunes tracks
// structural runes already used within the current token so each may
// be consumed at most once. A rune(-1) edge on curNode acts as a
// wildcard that accepts any rune.
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
	// Each structural rune may only be consumed once per token.
	if consumedSpecialRunes[r] {
		return false, nil, nil
	}

	isSpecial := slices.Contains(stringInvalidRunes, r)

	// Structural runes are never literal content of strings or keys.
	if isSpecial && (curNode.State == StateInString || curNode.State == StateInObjectKey) {
		return false, nil, nil
	}

	nextNode, ok := curNode.TransitionEdges[r]
	if !ok {
		// No explicit edge: fall back to the wildcard edge if present.
		if wildcard, found := curNode.TransitionEdges[rune(-1)]; found {
			return true, wildcard, nil
		}
		return false, nil, nil
	}

	if isSpecial {
		// A structural rune must actually advance the machine; a
		// self-transition would allow it to repeat indefinitely.
		if curNode.State == nextNode.State {
			return false, nil, nil
		}
		consumedSpecialRunes[r] = true
	}
	return true, nextNode, nil
}

View File

@ -1,104 +0,0 @@
package sample
import (
"fmt"
"math"
"github.com/ollama/ollama/model"
)
// JSONState enumerates the coarse states of the streaming JSON parser
// used to constrain token sampling.
type JSONState int

const (
	StateStart      JSONState = iota // Initial state
	StateInObject                    // Inside an object {}
	StateInArray                     // Inside an array []
	StateInString                    // Inside a string ""
	StateAfterKey                    // After object key, expecting :
	StateAfterColon                  // After :, expecting value
	StateAfterValue                  // After value, expecting , or closing bracket
	StateDone                        // JSON parsing complete
)

// JSONSampler tracks parser state across sampled tokens and masks
// logits that would produce invalid JSON.
type JSONSampler struct {
	state JSONState
	// NOTE(review): stack is never read or written by the visible
	// methods — confirm whether it is still needed.
	stack []string
	proc  model.TextProcessor
}
// NewJSONSampler returns a sampler that begins in StateStart and uses
// proc to decode candidate tokens while constraining output to JSON.
func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
	s := &JSONSampler{}
	s.state = StateStart
	s.proc = proc
	return s
}
// Sample masks logits so only tokens valid in the current JSON parse
// state (plus EOS, always allowed) can be chosen; invalid entries are
// set to NaN.
// NOTE(review): this decodes every vocabulary token on every call in
// the StateStart and StateInString cases — O(vocab) decodes per
// sampling step. Consider caching decoded tokens.
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
	// Pre-decode valid tokens for current state
	validTokens := make(map[uint32]bool)

	// Always allow EOS token in any state
	// TODO: Check for other special tokens if needed
	for i := range logits {
		if s.proc.Is(uint32(i), model.SpecialEOS) {
			validTokens[uint32(i)] = true
		}
	}

	// Build set of valid tokens based on current state
	switch s.state {
	case StateStart:
		// Only allow opening brace
		for i := range logits {
			text, err := s.proc.Decode([]int32{int32(i)})
			if err == nil && text == "{" {
				validTokens[uint32(i)] = true
			}
		}
	case StateInObject, StateInArray:
		// Allow any token
		for i := range logits {
			validTokens[uint32(i)] = true
		}
	case StateInString:
		// Allow any token except closing brace
		for i := range logits {
			text, err := s.proc.Decode([]int32{int32(i)})
			if err == nil && text != "}" {
				validTokens[uint32(i)] = true
			}
		}
	case StateDone:
		// No tokens allowed
	}

	// Mark invalid tokens as NaN in one pass
	for i := range logits {
		if !validTokens[uint32(i)] {
			logits[i] = math.NaN()
		}
	}
	return logits, nil
}
// UpdateState transitions the parser state given the token that was
// just emitted. Only object open/close is tracked here.
// NOTE(review): StateInArray and StateInString are never entered by
// this switch, so the corresponding Sample cases appear unreachable —
// confirm whether array/string tracking was intended.
func (s *JSONSampler) UpdateState(tokenID int) error {
	text, err := s.proc.Decode([]int32{int32(tokenID)})
	if err != nil {
		return fmt.Errorf("failed to decode token: %w", err)
	}

	switch s.state {
	case StateStart:
		if text != "{" {
			return fmt.Errorf("expected {, got %s", text)
		}
		s.state = StateInObject
	case StateInObject:
		if text == "}" {
			s.state = StateDone
		}
	case StateDone:
		return fmt.Errorf("unexpected token after closing bracket: %s", text)
	}
	return nil
}

View File

@ -165,9 +165,10 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
if len(logitsCopy) == 0 {
return nil, errors.New("no valid tokens found")
}
// usually, a softmax is applied to sample from the logits
// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
logitsCopy, err := computeSoftmax(logitsCopy)
if err != nil {
return nil, err
}
w := sampleuv.NewWeighted(logitsCopy, nil)
if v, ok := w.Take(); ok {
// returns the token ID
@ -176,17 +177,6 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
return nil, errors.New("weighed sampler failed")
}
// TODO: remove after next PR merge

// greedy is a Sampler that always picks the highest-logit token.
type greedy struct{}

// Greedy returns a Sampler that deterministically selects the index of
// the maximum logit.
func Greedy() Sampler {
	return greedy{}
}

// Sample returns a single-element slice containing the argmax index of
// logits, as a float64 token ID.
// NOTE(review): behavior with NaN entries (masked tokens) and empty
// input follows gonum floats.MaxIdx semantics — confirm it tolerates
// NaN-masked logits as used by the JSON samplers.
func (greedy) Sample(logits []float64) ([]float64, error) {
	return []float64{float64(floats.MaxIdx(logits))}, nil
}
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
var err error
for _, sampler := range samplers {

View File

@ -3,14 +3,9 @@ package sample
import (
"fmt"
"math"
"math/rand"
"os"
"runtime"
"slices"
"testing"
"runtime/trace"
"gonum.org/v1/gonum/floats"
)

View File

@ -1,218 +0,0 @@
package sample
import (
"fmt"
"github.com/ollama/ollama/model"
)
// token is a sequence of vocabulary token IDs.
type token []int32

// Node is one state in the token-labeled JSON state machine; each
// outgoing edge lists the token sequences that trigger the transition.
type Node struct {
	State           JSONState
	TransitionEdges map[*Node][]token
}

// NewNode returns a Node for the given state with an empty edge table.
func NewNode(state JSONState) *Node {
	return &Node{
		State:           state,
		TransitionEdges: make(map[*Node][]token),
	}
}

// Package-level token variant tables, populated by initTokens before
// buildStateMachine wires the graph.
var (
	// startToken token
	startTokenVariants []token
	// endToken token
	// stringToken token
	// objectKeyToken token
	tabToken     token
	spaceToken   token
	newlineToken token
	newlineSpace token
	// commaToken token
	// commaToken2 token
	// commaToken3 token
	// colonToken token
	// colonToken2 token
	colonTokenVariants           []token
	commaTokenVariants           []token
	stringTokenVariants          []token
	endTokenVariants             []token
	objectKeyTokenVariants       []token
	objKeyToColonVariants        []token
	stringToObjectKeyVariants    []token
	stringToCommaVariants        []token
	stringToObjectVariants       []token
	stringEndToObjectEndVariants []token
	stringEndToCommaVariants     []token
)
// ComputeTokenVariants encodes each textual variant with proc and
// returns the results.
// NOTE(review): despite the []token return type, all variants are
// concatenated into a SINGLE token ([]int32) — callers do subset
// checks against this merged sequence; confirm the merge is intended.
// Variants that fail to encode are skipped (best-effort); an error is
// returned only when no variant encodes successfully.
func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
	var allTokens token
	for _, variant := range variants {
		if t, err := proc.Encode(variant); err == nil {
			allTokens = append(allTokens, t...)
		}
	}
	if len(allTokens) == 0 {
		return nil, fmt.Errorf("no valid tokens found for variants")
	}
	return []token{allTokens}, nil
}
// initTokens populates the package-level token-variant tables by
// encoding known JSON structural strings with proc. It must succeed
// before buildStateMachine wires the state graph.
func initTokens(proc model.TextProcessor) error {
	var err error

	// Debug: inspect a known token ID.
	s, err := proc.Decode([]int32{761})
	if err != nil {
		return err
	}
	fmt.Printf("761 decoded %q\n", s)

	// Compute start token variants
	startVariants := []string{"{", " {", "{\n", " {\n"}
	startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
	if err != nil {
		return err
	}

	// Compute end token variants
	endVariants := []string{"}", " }", "}\n", " }\n"}
	endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
	if err != nil {
		return err
	}

	// Compute string token variants
	// TODO: removed \n
	stringVariants := []string{"\"", " \""}
	stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
	if err != nil {
		return err
	}

	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
	if err != nil {
		return err
	}

	// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
	objectKeyTokenVariants = stringTokenVariants

	// Compute whitespace tokens
	tabToken, err = proc.Encode("\t")
	if err != nil {
		return err
	}
	spaceToken, err = proc.Encode(" ")
	if err != nil {
		return err
	}
	newlineToken, err = proc.Encode("\n")
	if err != nil {
		return err
	}
	newlineSpace, err = proc.Encode(" \n")
	if err != nil {
		return err
	}

	// Compute colon variants
	colonVariants := []string{":"}
	colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
	if err != nil {
		return err
	}
	objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
	if err != nil {
		return err
	}

	// Compute comma variants
	commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
	commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
	if err != nil {
		return err
	}
	fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
	stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
	if err != nil {
		return err
	}

	// BUG FIX: the four assignments below previously dropped their
	// errors (err was assigned but never checked before returning nil).
	stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
	if err != nil {
		return err
	}
	// NOTE(review): stringToObjectKeyVariants is assigned a second time
	// here with a different variant set, overwriting the value computed
	// above — confirm which set is intended.
	stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
	if err != nil {
		return err
	}
	stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
	if err != nil {
		return err
	}
	stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
	if err != nil {
		return err
	}
	return nil
}
// buildStateMachine wires the token-labeled JSON state machine and
// returns its start node. initTokens must succeed first so the
// package-level token variant tables are populated.
func buildStateMachine(proc model.TextProcessor) (*Node, error) {
	if err := initTokens(proc); err != nil {
		return nil, err
	}

	startNode := NewNode(StateStart)
	objectNode := NewNode(StateInObject)
	objectKeyNode := NewNode(StateInObjectKey)
	objectKeyEndNode := NewNode(StateInObjectKeyEnd)
	stringNode := NewNode(StateInString)
	// intNode := NewNode(StateInInt)
	commaNode := NewNode(StateInComma)
	colonNode := NewNode(StateInColon)
	stringEndNode := NewNode(StateInStringEnd)
	endNode := NewNode(StateEnd)
	terminateNode := NewNode(StateTerminate)

	// sentinelToken on an edge means "accept any token".
	sentinelToken := token([]int32{-1})
	// intSentinelToken := token([]int32{-2})

	// TODO: cleanup connections of rules
	startNode.TransitionEdges[objectNode] = startTokenVariants

	objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
	// BUG FIX: this self-edge was previously assigned twice with the
	// same map key, so the newlineToken entry was silently overwritten
	// by spaceToken. Keep both variants on the single self-edge.
	objectNode.TransitionEdges[objectNode] = []token{newlineToken, spaceToken}

	objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
	// characterize end of object key
	objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
	objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
	// TODO: enable this - key -> object
	// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
	// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}

	// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
	// intNode.TransitionEdges[commaNode] = commaTokenVariants
	// TODO: handle
	// intNode.TransitionEdges[terminateNode] = endTokenVariants

	commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
	// commaNode.TransitionEdges[objectNode] = startTokenVariants

	colonNode.TransitionEdges[stringNode] = stringTokenVariants
	//TODO: enable
	// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
	colonNode.TransitionEdges[objectNode] = startTokenVariants

	stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
	stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
	// TODO: "\""," Case not accounted for
	stringNode.TransitionEdges[commaNode] = stringToCommaVariants
	// TODO: "\"",\"" Case not accounted for
	stringNode.TransitionEdges[objectNode] = stringToObjectVariants

	stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
	stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
	stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
	// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants

	// Should be obj end
	// TODO: handle
	endNode.TransitionEdges[terminateNode] = []token{}
	endNode.TransitionEdges[commaNode] = commaTokenVariants

	terminateNode.TransitionEdges[terminateNode] = []token{}

	return startNode, nil
}

View File

@ -7,92 +7,6 @@ type StructuredOutput struct {
}
// BuildStructuredOutputGraph will build a PDA graph constrained by the
// given JSON schema. Currently a stub: the schema-constraint logic is
// not yet wired in and nil is always returned; both parameters are
// unused.
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
	// _, stateToNodeMap, err := BuildGraph(proc)
	// if err != nil {
	// 	panic(err)
	// }
	return nil
}
// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
// // If no schema constraints, return original graph node
// if schema == nil {
// return graph
// }
// // Create a new node with same state
// constrainedNode := NewPDANode(graph.State)
// // Copy over existing transitions and masks
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
// for r, node := range graph.TransitionEdges {
// constrainedNode.TransitionEdges[r] = node
// }
// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
// // Apply schema constraints based on type
// switch schema.EffectiveType() {
// case "object":
// // Only allow defined property names in object keys
// if graph.State == StateInObjectKey {
// // TODO: Add property name validation
// }
// // Constrain property values based on schema
// if graph.State == StateInColon || graph.State == StateInSpace {
// // Clear transitions to only allow valid types
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
// // Add transitions based on property schemas
// for _, prop := range schema.Properties {
// switch prop.EffectiveType() {
// case "object":
// if objNode, ok := graph.TransitionEdges['{']; ok {
// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
// }
// case "array":
// if arrNode, ok := graph.TransitionEdges['[']; ok {
// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
// }
// case "string":
// if strNode, ok := graph.TransitionEdges['"']; ok {
// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
// }
// case "number":
// for _, r := range validNumberRunes {
// if numNode, ok := graph.TransitionEdges[r]; ok {
// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
// }
// }
// case "integer":
// for _, r := range validIntRunes {
// if intNode, ok := graph.TransitionEdges[r]; ok {
// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
// }
// }
// case "boolean":
// for _, r := range []rune{'t', 'f'} {
// if boolNode, ok := graph.TransitionEdges[r]; ok {
// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
// }
// }
// case "null":
// if nullNode, ok := graph.TransitionEdges['n']; ok {
// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
// }
// }
// }
// }
// case "array":
// // Constrain array items based on schema
// if schema.Items != nil {
// for r, node := range graph.TransitionEdges {
// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
// }
// }
// }
// return constrainedNode
// }

BIN
sample/trace.out Normal file

Binary file not shown.