From c56a8b77498392b3000767528fd2372447d6f760 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Thu, 30 Jan 2025 15:05:25 -0800 Subject: [PATCH] wip --- model/cmd/main.go | 13 +- model/cmd/test.go | 1 + model/process_text.go | 5 + sample/fast_json.go | 220 -------------------------- sample/hid.txt | 296 +++++++++++++++++++++++++++++++++++ sample/json_sampler.go | 104 ------------ sample/sample.go | 18 +-- sample/sample_test.go | 5 - sample/state_machine.go | 218 -------------------------- sample/structured_outputs.go | 86 ---------- sample/trace.out | Bin 0 -> 5912 bytes 11 files changed, 316 insertions(+), 650 deletions(-) create mode 100644 model/cmd/test.go create mode 100644 sample/hid.txt delete mode 100644 sample/json_sampler.go delete mode 100644 sample/state_machine.go create mode 100644 sample/trace.out diff --git a/model/cmd/main.go b/model/cmd/main.go index ed7901c7e..c349e20d4 100644 --- a/model/cmd/main.go +++ b/model/cmd/main.go @@ -104,6 +104,8 @@ func temp() error { } } + pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) + var stringBuffer string var offset int for range args.n { logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...) @@ -118,7 +120,10 @@ func temp() error { } // do sampling - f64s, err = sample.Sample(f64s, sample.Greedy()) + // []ints back + // ints map to sampled logits + f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy()) + if err != nil { return err } @@ -129,6 +134,7 @@ func temp() error { outputIDs = append(outputIDs, int32(f64)) } } + pdaSampler.UpdateState(outputIDs) if len(outputIDs) == 0 { break @@ -141,8 +147,9 @@ func temp() error { return err } - fmt.Print(s) - + // fmt.Print(s) + stringBuffer += s + fmt.Println("--- stringBuffer", stringBuffer) inputIDs = append(inputIDs, outputIDs...) 
if args.cache { offset = len(inputIDs) - 1 diff --git a/model/cmd/test.go b/model/cmd/test.go new file mode 100644 index 000000000..06ab7d0f9 --- /dev/null +++ b/model/cmd/test.go @@ -0,0 +1 @@ +package main diff --git a/model/process_text.go b/model/process_text.go index b2024d778..f335f59c5 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -21,6 +21,7 @@ type TextProcessor interface { Encode(string) ([]int32, error) Decode([]int32) (string, error) Is(uint32, Special) bool + GetVocabulary() *Vocabulary } type Vocabulary struct { @@ -104,6 +105,10 @@ type BytePairEncoding struct { *Vocabulary } +func (bpe BytePairEncoding) GetVocabulary() *Vocabulary { + return bpe.Vocabulary +} + func (bpe BytePairEncoding) split(s string) ([]string, error) { re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2) if err != nil { diff --git a/sample/fast_json.go b/sample/fast_json.go index 486490f21..bd80e8388 100644 --- a/sample/fast_json.go +++ b/sample/fast_json.go @@ -1,11 +1,7 @@ package sample import ( - "errors" "fmt" - "math" - - "github.com/ollama/ollama/model" ) type JSONState int @@ -136,219 +132,3 @@ func (s JSONState) String() string { return fmt.Sprintf("Unknown state: %d", s) } } - -type JSONSampler struct { - curNode *Node - proc model.TextProcessor - stack []*Node - bracketCounter int -} - -func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) { - // fmt.Println("Creating new JSON sampler") - startNode, err := buildStateMachine(proc) - if err != nil { - return nil, err - } - js := &JSONSampler{ - curNode: startNode, - proc: proc, - stack: []*Node{}, - bracketCounter: 0, - } - - return js, nil -} - -func isTokenSubset(subset, superset []int32) bool { - freq1 := make(map[int32]int) - freq2 := make(map[int32]int) - - for _, v := range subset { - freq1[v]++ - } - for _, v := range superset { - freq2[v]++ - } - isSubset := true - for k, count1 := range freq1 { - count2 := freq2[k] - if count1 > count2 { - isSubset = false - break - } - } - return isSubset -} - -func (s *JSONSampler) UpdateState(tokenSlice []int32) error { - // fmt.Printf("Updating state with token: %v\n", tokenSlice) - // fmt.Printf("Current state: %s\n", s.curNode.State) - - // fmt.Println("tokenSlice", tokenSlice) - // todo: account for strings here - - objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc) - if err != nil { - return err - } - - // only move to terminate state if stack is empty - if s.curNode.State == StateInObjectEnd { - fmt.Println("debug: node.State", s.curNode.State) - if len(s.stack) > 0 { - s.stack = s.stack[:len(s.stack)-1] - fmt.Println("popped and cur state", s.curNode.State) - return nil - } - return nil - } - - for node, edge := range s.curNode.TransitionEdges { - for _, validToken := range edge { - if isTokenSubset(tokenSlice, validToken) { - s.curNode = node - for _, token := range objectTokens { - if isTokenSubset(tokenSlice, token) { - fmt.Println("Appending to stack", s.curNode.State) - s.stack = append(s.stack, s.curNode) - } - } - // fmt.Printf("Transitioned to state: %s\n", node.State) - return nil - } - } - } - for node, edge := range s.curNode.TransitionEdges { - for _, validToken := range edge { - if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 { - s.curNode = node - // fmt.Printf("Accepting any token, staying in state: %s\n", node.State) - return nil - } - } - } - fmt.Println("invalid token ", tokenSlice) - dec, err := s.proc.Decode(tokenSlice) - if err != nil { - return err - } - 
fmt.Println("decoded token ", dec) - return errors.New("invalid token") -} - -func (s *JSONSampler) Sample(logits []float64) ([]float64, error) { - fmt.Printf("Sampling in state: %s\n", s.curNode.State) - var err error - - switch s.curNode.State { - case StateTerminate: - for i := range logits { - if s.proc.Is(uint32(i), model.SpecialEOS) { - logits[i] = 1.0 - } else { - logits[i] = math.NaN() - } - } - return logits, nil - - case StateInInt: - validStates := []int32{} - minus, err := s.proc.Encode("-") - if err != nil { - return nil, err - } - digits := make([][]int32, 10) - for i := 0; i < 10; i++ { - digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i)) - if err != nil { - return nil, err - } - } - // Allow "-" and digits 0-9 at start - for i := range logits { - for _, d := range digits { - if len(d) == 1 && int32(i) == d[0] { - validStates = append(validStates, int32(i)) - } - } - if len(minus) == 1 && int32(i) == minus[0] { - validStates = append(validStates, int32(i)) - } - } - return logits, nil - - case StateInString: - penalizeNewlineVariants := []string{"\n", " \"\n"} - penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc) - if err != nil { - return nil, err - } - penalizeNewlineToks = append(penalizeNewlineToks, []int32{702}) - logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks) - if err != nil { - return nil, err - } - validStates := getValidStates(s.curNode) - logits, err = s.maskLogits(logits, validStates) - if err != nil { - return nil, err - } - return logits, nil - - default: - validStates := getValidStates(s.curNode) - logits, err = s.maskLogits(logits, validStates) - if err != nil { - return nil, err - } - return logits, nil - } -} - -func getValidStates(node *Node) []int32 { - validStates := []int32{} - for _, edge := range node.TransitionEdges { - for _, token := range edge { - validStates = append(validStates, token...) 
- } - } - return validStates -} - -func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) { - // fmt.Printf("Masking logits with valid states: %v\n", validStates) - // todo: this can prob be more efficient - for i := range logits { - isValid := false - for _, token := range validStates { - if token == -1 { - // fmt.Println("Found sentinel token, returning unmasked logits") - return logits, nil - } - if i == int(token) { - // fmt.Printf("Found valid token: %d\n", token) - isValid = true - break - } - } - if !isValid { - logits[i] = math.NaN() - } - } - return logits, nil -} - -func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) { - // fmt.Printf("Masking specific logits: %v\n", tokensToMask) - for i := range logits { - for _, token := range tokensToMask { - for _, chunked := range token { - if int(chunked) == i { - logits[i] = math.NaN() - } - } - } - } - return logits, nil -} diff --git a/sample/hid.txt b/sample/hid.txt new file mode 100644 index 000000000..d58f23cc4 --- /dev/null +++ b/sample/hid.txt @@ -0,0 +1,296 @@ +package sample + +import ( + "slices" + + "github.com/ollama/ollama/model" +) + +var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','} + +var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} +var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} + +var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'} + +var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'} + +var validNullRunes = []rune{'n', 'u', 'l', 'l'} + +type PDANode struct { + State JSONState + TransitionEdges map[rune]*PDANode + MaskTokenIDToNode map[int32]JSONState +} + +func NewPDANode(state JSONState) *PDANode { + return &PDANode{ + State: state, + TransitionEdges: make(map[rune]*PDANode), + MaskTokenIDToNode: make(map[int32]JSONState), + } +} + +func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) { + stateToNodeMap := make(map[JSONState]*PDANode) + + startNode := NewPDANode(StateStart) + stateToNodeMap[StateStart] = startNode + + objNode := NewPDANode(StateInObject) + stateToNodeMap[StateInObject] = objNode + + objEndNode := NewPDANode(StateInObjectEnd) + stateToNodeMap[StateInObjectEnd] = objEndNode + + objKeyNode := NewPDANode(StateInObjectKey) + stateToNodeMap[StateInObjectKey] = objKeyNode + + objKeyEndNode := NewPDANode(StateInObjectKeyEnd) + stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode + + colonNode := NewPDANode(StateInColon) + stateToNodeMap[StateInColon] = colonNode + + commaNode := NewPDANode(StateInComma) + stateToNodeMap[StateInComma] = commaNode + + newlineNode := NewPDANode(StateInNewline) + stateToNodeMap[StateInNewline] = newlineNode + + spaceNode := NewPDANode(StateInSpace) + stateToNodeMap[StateInSpace] = spaceNode + + spaceObjNode := NewPDANode(StateInObjSpace) + stateToNodeMap[StateInObjSpace] = spaceObjNode + + tabNode := NewPDANode(StateInTab) + stateToNodeMap[StateInTab] = tabNode + + stringNode := NewPDANode(StateInString) + stateToNodeMap[StateInString] = stringNode + + stringEndNode := NewPDANode(StateInStringEnd) + stateToNodeMap[StateInStringEnd] = stringEndNode + + listNode := NewPDANode(StateInList) + stateToNodeMap[StateInList] = listNode + + listCommaNode := NewPDANode(StateInListComma) + stateToNodeMap[StateInListComma] = listCommaNode + + listEndNode := NewPDANode(StateListEnd) + stateToNodeMap[StateListEnd] = 
listEndNode + + numberNode := NewPDANode(StateInNumber) + stateToNodeMap[StateInNumber] = numberNode + + boolNode := NewPDANode(StateInBool) + stateToNodeMap[StateInBool] = boolNode + + nullNode := NewPDANode(StateInNull) + stateToNodeMap[StateInNull] = nullNode + + // Defined with structured outputs only + intNode := NewPDANode(StateInInt) + stateToNodeMap[StateInInt] = intNode + + // TODO: + // consider adding a node to just point to values, could be good to compute that + // mask rather than many different nodes + + // Connect nodes + // TODO: if all are single tokens then this can just be connected instead of defining the token + startNode.TransitionEdges['{'] = objNode + + objNode.TransitionEdges['"'] = objKeyNode + objNode.TransitionEdges['\n'] = newlineNode + // objNode.TransitionEdges['\t'] = tabNode + + newlineNode.TransitionEdges['"'] = objKeyNode + newlineNode.TransitionEdges['\t'] = tabNode + + tabNode.TransitionEdges['"'] = objKeyNode + // tabNode.TransitionEdges['\t'] = tabNode + + objKeyNode.TransitionEdges[rune(-1)] = objKeyNode + objKeyNode.TransitionEdges['"'] = objKeyEndNode + + objKeyEndNode.TransitionEdges[':'] = colonNode + objEndNode.TransitionEdges[' '] = spaceNode + + // where values should be + // this could be combined but the probs might change, we're alr doing a skip ahead + colonNode.TransitionEdges[' '] = spaceNode + + // Leads to a value + spaceNode.TransitionEdges['"'] = stringNode + spaceNode.TransitionEdges['['] = listNode + spaceNode.TransitionEdges['{'] = objNode + + for _, r := range validNumberRunes { + spaceNode.TransitionEdges[r] = numberNode + } + for _, r := range validBoolRunes { + spaceNode.TransitionEdges[r] = boolNode + } + + for _, r := range validNullRunes { + spaceNode.TransitionEdges[r] = nullNode + } + + // Values + // string node + stringNode.TransitionEdges[rune(-1)] = stringNode + stringNode.TransitionEdges['"'] = stringEndNode + + stringEndNode.TransitionEdges[','] = commaNode + stringEndNode.TransitionEdges['}'] = objEndNode + stringEndNode.TransitionEdges[']'] = listEndNode + + // TODO: add counters for allowable number of decimals, e, E, etc + // number node + for _, r := range validNumberRunes { + numberNode.TransitionEdges[r] = numberNode + } + numberNode.TransitionEdges[','] = commaNode + numberNode.TransitionEdges['}'] = objEndNode + numberNode.TransitionEdges[']'] = listEndNode + + for _, r := range validBoolRunes { + boolNode.TransitionEdges[r] = boolNode + } + + // list node + listNode.TransitionEdges[','] = commaNode + listNode.TransitionEdges['"'] = stringNode + // squash states to a value + for _, r := range validNumberRunes { + listNode.TransitionEdges[r] = numberNode + } + for _, r := range validBoolRunes { + listNode.TransitionEdges[r] = boolNode + } + for _, r := range validNullRunes { + listNode.TransitionEdges[r] = nullNode + } + + // null node + for _, r := range validNullRunes { + nullNode.TransitionEdges[r] = nullNode + } + nullNode.TransitionEdges[','] = commaNode + nullNode.TransitionEdges['}'] = objEndNode + nullNode.TransitionEdges[']'] = listEndNode + + // list comma + // should point to values + listCommaNode.TransitionEdges['"'] = stringNode + listCommaNode.TransitionEdges[' '] = listCommaNode + listCommaNode.TransitionEdges['{'] = objNode + listCommaNode.TransitionEdges['\n'] = newlineNode + + for _, r := range validNumberRunes { + listCommaNode.TransitionEdges[r] = numberNode + } + for _, r := range validBoolRunes { + listCommaNode.TransitionEdges[r] = boolNode + } + for _, r := range validNullRunes { + 
listCommaNode.TransitionEdges[r] = nullNode + } + + // bool node + for _, r := range validBoolRunes { + boolNode.TransitionEdges[r] = boolNode + } + boolNode.TransitionEdges['}'] = objEndNode + boolNode.TransitionEdges[']'] = listEndNode + boolNode.TransitionEdges[','] = commaNode + + listEndNode.TransitionEdges['}'] = objEndNode + listEndNode.TransitionEdges[','] = commaNode + + commaNode.TransitionEdges['{'] = objNode + commaNode.TransitionEdges['\n'] = newlineNode + commaNode.TransitionEdges['\t'] = tabNode + commaNode.TransitionEdges['"'] = objKeyNode + commaNode.TransitionEdges[' '] = spaceObjNode + + spaceObjNode.TransitionEdges['"'] = objKeyNode + + return startNode, stateToNodeMap, nil +} + +func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { + + vocab := proc.GetVocabulary() + + decodedToks := make([]string, len(vocab.Values)) + for i := range vocab.Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return err + } + decodedToks[i] = token + } + + var err error + for _, node := range stateToNodeMap { + for i := range vocab.Values { + token := decodedToks[i] + // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON + if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" { + continue + } + valid := true + curNode := node + consumedSpecialRunes := make(map[rune]bool) + for _, r := range token { + valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes) + if err != nil { + return err + } + if !valid { + break + } + } + if valid { + node.MaskTokenIDToNode[int32(i)] = curNode.State + } + } + } + return nil +} + +func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) { + if consumedSpecialRunes[r] { + return false, nil, nil + } + + specialRune := slices.Contains(stringInvalidRunes, r) + if specialRune { + if curNode.State == StateInString || curNode.State == StateInObjectKey { + return false, nil, nil + } + } + + // Check for specific rune transition + if nextNode, ok := curNode.TransitionEdges[r]; ok { + if specialRune { + if curNode.State == nextNode.State { + return false, nil, nil + } + // fmt.Println("special rune", r, "consumed") + consumedSpecialRunes[r] = true + } + return true, nextNode, nil + } + + // Check for sentinel value - if present, any rune is valid + if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok { + return true, nextNode, nil + } + + return false, nil, nil +} diff --git a/sample/json_sampler.go b/sample/json_sampler.go deleted file mode 100644 index 172f92d19..000000000 --- a/sample/json_sampler.go +++ /dev/null @@ -1,104 +0,0 @@ -package sample - -import ( - "fmt" - "math" - - "github.com/ollama/ollama/model" -) - -type JSONState int - -const ( - StateStart JSONState = iota // Initial state - StateInObject // Inside an object {} - StateInArray // Inside an array [] - StateInString // Inside a string "" - StateAfterKey // After object key, expecting : - StateAfterColon // After :, expecting value - StateAfterValue // After value, expecting , or closing bracket - StateDone // JSON parsing complete -) - -type JSONSampler struct { - state JSONState - stack []string - proc model.TextProcessor -} - -func NewJSONSampler(proc model.TextProcessor) *JSONSampler { - return &JSONSampler{ - state: StateStart, - proc: proc, - } -} - -func (s *JSONSampler) Sample(logits []float64) ([]float64, error) { - // Pre-decode valid tokens for current state - validTokens := make(map[uint32]bool) 
- - // Always allow EOS token in any state - // TODO: Check for other special tokens if needed - for i := range logits { - if s.proc.Is(uint32(i), model.SpecialEOS) { - validTokens[uint32(i)] = true - } - } - - // Build set of valid tokens based on current state - switch s.state { - case StateStart: - // Only allow opening brace - for i := range logits { - text, err := s.proc.Decode([]int32{int32(i)}) - if err == nil && text == "{" { - validTokens[uint32(i)] = true - } - } - case StateInObject, StateInArray: - // Allow any token - for i := range logits { - validTokens[uint32(i)] = true - } - case StateInString: - // Allow any token except closing brace - for i := range logits { - text, err := s.proc.Decode([]int32{int32(i)}) - if err == nil && text != "}" { - validTokens[uint32(i)] = true - } - } - case StateDone: - // No tokens allowed - } - - // Mark invalid tokens as NaN in one pass - for i := range logits { - if !validTokens[uint32(i)] { - logits[i] = math.NaN() - } - } - return logits, nil -} - -func (s *JSONSampler) UpdateState(tokenID int) error { - text, err := s.proc.Decode([]int32{int32(tokenID)}) - if err != nil { - return fmt.Errorf("failed to decode token: %w", err) - } - - switch s.state { - case StateStart: - if text != "{" { - return fmt.Errorf("expected {, got %s", text) - } - s.state = StateInObject - case StateInObject: - if text == "}" { - s.state = StateDone - } - case StateDone: - return fmt.Errorf("unexpected token after closing bracket: %s", text) - } - return nil -} diff --git a/sample/sample.go b/sample/sample.go index a735785f0..fcd9dbbdd 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -165,9 +165,10 @@ func (s weighed) Sample(logits []float64) ([]float64, error) { if len(logitsCopy) == 0 { return nil, errors.New("no valid tokens found") } - - // usually, a softmax is applied to sample from the logits - // in this case the uv sampler normalizes the logits so that the sum of the weights is 1 + logitsCopy, err := computeSoftmax(logitsCopy) + if err != nil { + return nil, err + } w := sampleuv.NewWeighted(logitsCopy, nil) if v, ok := w.Take(); ok { // returns the token ID @@ -176,17 +177,6 @@ func (s weighed) Sample(logits []float64) ([]float64, error) { return nil, errors.New("weighed sampler failed") } -// TODO: remove after next PR merge -type greedy struct{} - -func Greedy() Sampler { - return greedy{} -} - -func (greedy) Sample(logits []float64) ([]float64, error) { - return []float64{float64(floats.MaxIdx(logits))}, nil -} - func Sample(logits []float64, samplers ...Sampler) ([]float64, error) { var err error for _, sampler := range samplers { diff --git a/sample/sample_test.go b/sample/sample_test.go index 8900e824f..4039c29cd 100644 --- a/sample/sample_test.go +++ b/sample/sample_test.go @@ -3,14 +3,9 @@ package sample import ( "fmt" "math" - "math/rand" - "os" - "runtime" "slices" "testing" - "runtime/trace" - "gonum.org/v1/gonum/floats" ) diff --git a/sample/state_machine.go b/sample/state_machine.go deleted file mode 100644 index a5e8779fe..000000000 --- a/sample/state_machine.go +++ /dev/null @@ -1,218 +0,0 @@ -package sample - -import ( - "fmt" - - "github.com/ollama/ollama/model" -) - -type token []int32 - -type Node struct { - State JSONState - TransitionEdges map[*Node][]token -} - -func NewNode(state JSONState) *Node { - return &Node{ - State: state, - TransitionEdges: make(map[*Node][]token), - } -} - -var ( - // startToken token - startTokenVariants []token - // endToken token - // stringToken token - // objectKeyToken token - tabToken 
token - spaceToken token - newlineToken token - newlineSpace token - // commaToken token - // commaToken2 token - // commaToken3 token - // colonToken token - // colonToken2 token - colonTokenVariants []token - commaTokenVariants []token - stringTokenVariants []token - endTokenVariants []token - objectKeyTokenVariants []token - objKeyToColonVariants []token - stringToObjectKeyVariants []token - stringToCommaVariants []token - stringToObjectVariants []token - stringEndToObjectEndVariants []token - stringEndToCommaVariants []token -) - -func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) { - var allTokens token - for _, variant := range variants { - if t, err := proc.Encode(variant); err == nil { - allTokens = append(allTokens, t...) - } - } - if len(allTokens) == 0 { - return nil, fmt.Errorf("no valid tokens found for variants") - } - return []token{allTokens}, nil -} -func initTokens(proc model.TextProcessor) error { - var err error - - s, err := proc.Decode([]int32{761}) - fmt.Printf("761 decoded %q\n", s) - - // Compute start token variants - startVariants := []string{"{", " {", "{\n", " {\n"} - startTokenVariants, err = ComputeTokenVariants(startVariants, proc) - if err != nil { - return err - } - // Compute end token variants - endVariants := []string{"}", " }", "}\n", " }\n"} - endTokenVariants, err = ComputeTokenVariants(endVariants, proc) - if err != nil { - return err - } - - // Compute string token variants - // TODO: removed \n - stringVariants := []string{"\"", " \""} - stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc) - if err != nil { - return err - } - stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc) - if err != nil { - return err - } - // objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]} - objectKeyTokenVariants = stringTokenVariants - // Compute whitespace tokens - tabToken, err = proc.Encode("\t") - if err != nil { - return err - } - spaceToken, err = proc.Encode(" ") - if err != nil { - return err - } - newlineToken, err = proc.Encode("\n") - if err != nil { - return err - } - newlineSpace, err = proc.Encode(" \n") - if err != nil { - return err - } - - // Compute colon variants - colonVariants := []string{":"} - colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc) - if err != nil { - return err - } - objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc) - if err != nil { - return err - } - - // Compute comma variants - commaVariants := []string{",", " ,", ",\n", "\",", "\", "} - commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc) - if err != nil { - return err - } - fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants) - stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc) - if err != nil { - return err - } - - stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc) - stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc) - stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc) - stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc) - - return nil -} - -func buildStateMachine(proc model.TextProcessor) (*Node, error) { - if err := initTokens(proc); err != nil { - return nil, err - } - - startNode := NewNode(StateStart) - objectNode := NewNode(StateInObject) - objectKeyNode := NewNode(StateInObjectKey) - objectKeyEndNode := NewNode(StateInObjectKeyEnd) - 
stringNode := NewNode(StateInString) - // intNode := NewNode(StateInInt) - commaNode := NewNode(StateInComma) - colonNode := NewNode(StateInColon) - stringEndNode := NewNode(StateInStringEnd) - endNode := NewNode(StateEnd) - terminateNode := NewNode(StateTerminate) - - sentinelToken := token([]int32{-1}) - // intSentinelToken := token([]int32{-2}) - - // TODO: cleanup connections of rules - startNode.TransitionEdges[objectNode] = startTokenVariants - - objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants - objectNode.TransitionEdges[objectNode] = []token{newlineToken} - objectNode.TransitionEdges[objectNode] = []token{spaceToken} - - // objectNode.TransitionEdges[objectNode] = []token{newlineToken} - // objectNode.TransitionEdges[objectNode] = []token{spaceToken} - - objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken} - // characterize end of object key - objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants - objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants - - // TODO: enable this - key -> object - // objectKeyNode.TransitionEdges[objectNode] = startTokenVariants - - // objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken} - - // intNode.TransitionEdges[intNode] = []token{intSentinelToken} - // intNode.TransitionEdges[commaNode] = commaTokenVariants - // TODO: handle - // intNode.TransitionEdges[terminateNode] = endTokenVariants - - commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants - // commaNode.TransitionEdges[objectNode] = startTokenVariants - - colonNode.TransitionEdges[stringNode] = stringTokenVariants - //TODO: enable - // colonNode.TransitionEdges[intNode] = []token{intSentinelToken} - colonNode.TransitionEdges[objectNode] = startTokenVariants - - stringNode.TransitionEdges[stringNode] = []token{sentinelToken} - stringNode.TransitionEdges[stringEndNode] = stringTokenVariants - // TODO: "\""," Case not accounted for - stringNode.TransitionEdges[commaNode] = stringToCommaVariants - - // TODO: "\"",\"" Case not accounted for - stringNode.TransitionEdges[objectNode] = stringToObjectVariants - - stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants - stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants - stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants - // stringEndNode.TransitionEdges[terminateNode] = endTokenVariants - - // Should be obj end - // TODO: handle - endNode.TransitionEdges[terminateNode] = []token{} - - endNode.TransitionEdges[commaNode] = commaTokenVariants - - terminateNode.TransitionEdges[terminateNode] = []token{} - return startNode, nil -} diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index a02ad9fc2..6b3cb7132 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -7,92 +7,6 @@ type StructuredOutput struct { } func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode { - // _, stateToNodeMap, err := BuildGraph(proc) - // if err != nil { - // panic(err) - // } return nil } - -// func constrainGraph(graph *PDANode, schema *Schema) *PDANode { -// // If no schema constraints, return original graph node -// if schema == nil { -// return graph -// } - -// // Create a new node with same state -// constrainedNode := NewPDANode(graph.State) - -// // Copy over existing transitions and masks -// constrainedNode.TransitionEdges = make(map[rune]*PDANode) -// for r, node := range graph.TransitionEdges { -// constrainedNode.TransitionEdges[r] = node -// } -// 
constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
-
-// // Apply schema constraints based on type
-// switch schema.EffectiveType() {
-// case "object":
-// // Only allow defined property names in object keys
-// if graph.State == StateInObjectKey {
-// // TODO: Add property name validation
-// }
-
-// // Constrain property values based on schema
-// if graph.State == StateInColon || graph.State == StateInSpace {
-// // Clear transitions to only allow valid types
-// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
-
-// // Add transitions based on property schemas
-// for _, prop := range schema.Properties {
-// switch prop.EffectiveType() {
-// case "object":
-// if objNode, ok := graph.TransitionEdges['{']; ok {
-// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
-// }
-// case "array":
-// if arrNode, ok := graph.TransitionEdges['[']; ok {
-// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
-// }
-// case "string":
-// if strNode, ok := graph.TransitionEdges['"']; ok {
-// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
-// }
-// case "number":
-// for _, r := range validNumberRunes {
-// if numNode, ok := graph.TransitionEdges[r]; ok {
-// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
-// }
-// }
-// case "integer":
-// for _, r := range validIntRunes {
-// if intNode, ok := graph.TransitionEdges[r]; ok {
-// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
-// }
-// }
-// case "boolean":
-// for _, r := range []rune{'t', 'f'} {
-// if boolNode, ok := graph.TransitionEdges[r]; ok {
-// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
-// }
-// }
-// case "null":
-// if nullNode, ok := graph.TransitionEdges['n']; ok {
-// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
-// }
-// }
-// }
-// }
-
-// case "array":
-// // Constrain array items based on schema
-// if schema.Items != nil {
-// for r, node := range graph.TransitionEdges {
-// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
-// }
-// }
-// }
-
-// return constrainedNode
-// }
diff --git a/sample/trace.out b/sample/trace.out
new file mode 100644
index 0000000000000000000000000000000000000000..04fd9593340fe65071a8668a8d9f1a27f83bdbc8
GIT binary patch
literal 5912
[5912 bytes of binary trace data omitted]
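
Note: the graph construction and mask precomputation added in sample/hid.txt (BuildGraph and PreComputeValidStates) are not wired up anywhere in this diff. The sketch below shows one way the two functions could be combined; buildMasks is a hypothetical helper for illustration, not part of the patch.

package sample

import "github.com/ollama/ollama/model"

// buildMasks sketches how BuildGraph and PreComputeValidStates could fit
// together: build the PDA once, then walk every vocabulary token through each
// node's rune transitions so sampling can use a per-state map lookup instead
// of decoding tokens on every step.
func buildMasks(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
	startNode, stateToNodeMap, err := BuildGraph(proc)
	if err != nil {
		return nil, nil, err
	}
	if err := PreComputeValidStates(stateToNodeMap, proc); err != nil {
		return nil, nil, err
	}
	return startNode, stateToNodeMap, nil
}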
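
Note: model/cmd/main.go now calls sample.NewPushdownSampler and pdaSampler.UpdateState, but the PushdownSampler type itself is not included in this diff. The following is a rough sketch, under that assumption, of how the precomputed MaskTokenIDToNode maps might be consumed at sampling time; pdaSketchSampler, its fields, and its method behavior are illustrative only and do not describe the real implementation.

package sample

import (
	"fmt"
	"math"
)

// pdaSketchSampler tracks the current PDA node and the state-to-node map
// produced by BuildGraph; it is an illustration only.
type pdaSketchSampler struct {
	curNode        *PDANode
	stateToNodeMap map[JSONState]*PDANode
}

// Sample keeps only the logits whose token IDs are valid transitions out of
// the current node and masks everything else with NaN.
func (s *pdaSketchSampler) Sample(logits []float64) ([]float64, error) {
	for i := range logits {
		if _, ok := s.curNode.MaskTokenIDToNode[int32(i)]; !ok {
			logits[i] = math.NaN()
		}
	}
	return logits, nil
}

// UpdateState follows the precomputed transition for the token that was
// actually sampled, moving the PDA to the node that token leads to.
func (s *pdaSketchSampler) UpdateState(tokenID int32) error {
	nextState, ok := s.curNode.MaskTokenIDToNode[tokenID]
	if !ok {
		return fmt.Errorf("token %d is not valid in state %s", tokenID, s.curNode.State)
	}
	s.curNode = s.stateToNodeMap[nextState]
	return nil
}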
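
Note: the weighed sampler in sample/sample.go now calls computeSoftmax, which is not defined anywhere in this diff. A minimal sketch of such a helper, assuming it takes raw logits and returns weights normalized to sum to 1:

package sample

import (
	"errors"
	"math"
)

// computeSoftmax (sketch): exponentiate the logits, shifted by their max for
// numerical stability, and normalize so the resulting weights sum to 1.
func computeSoftmax(logits []float64) ([]float64, error) {
	if len(logits) == 0 {
		return nil, errors.New("no logits to normalize")
	}
	maxLogit := math.Inf(-1)
	for _, l := range logits {
		if l > maxLogit {
			maxLogit = l
		}
	}
	probs := make([]float64, len(logits))
	var sum float64
	for i, l := range logits {
		probs[i] = math.Exp(l - maxLogit)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}
	return probs, nil
}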