From 5ec6bb52a0feb53c994adaecc47fda1dbbf6db1c Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 25 Mar 2025 15:00:14 -0700 Subject: [PATCH] prototyping --- model/process_text.go | 1 + model/process_text_spm.go | 4 + runner/ollamarunner/runner.go | 32 ++++ sample/gtf.go | 53 ++++++ sample/gtf_test.go | 138 ++++++++++++++ sample/json_types.go | 160 ++++++++++++++++ sample/pushdown_automata.go | 327 ++++++++++++++++++++++++++++++++ sample/pushdown_runner.go | 264 ++++++++++++++++++++++++++ sample/samplers.go | 43 +++-- sample/structured_outputs.go | 299 ++++++++++++++++++++++++++++++ sample/structured_python.go | 339 ++++++++++++++++++++++++++++++++++ 11 files changed, 1647 insertions(+), 13 deletions(-) create mode 100644 sample/gtf.go create mode 100644 sample/gtf_test.go create mode 100644 sample/json_types.go create mode 100644 sample/pushdown_automata.go create mode 100644 sample/pushdown_runner.go create mode 100644 sample/structured_outputs.go create mode 100644 sample/structured_python.go diff --git a/model/process_text.go b/model/process_text.go index 01af65b62..943707729 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -32,6 +32,7 @@ type TextProcessor interface { Encode(s string, addSpecial bool) ([]int32, error) Decode([]int32) (string, error) Is(int32, Special) bool + Vocab() *Vocabulary } type Vocabulary struct { diff --git a/model/process_text_spm.go b/model/process_text_spm.go index 68e3ed015..967f18101 100644 --- a/model/process_text_spm.go +++ b/model/process_text_spm.go @@ -53,6 +53,10 @@ func (spm SentencePieceModel) Is(id int32, special Special) bool { return spm.vocab.Is(id, special) } +func (spm SentencePieceModel) Vocab() *Vocabulary { + return spm.vocab +} + func (spm *SentencePieceModel) split(s string) iter.Seq[string] { return func(yield func(string) bool) { for m, _ := spm.pre.FindStringMatch(s); m != nil; m, _ = spm.pre.FindNextMatch(m) { diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index 
31d20db80..3b85bc32d 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -468,6 +468,20 @@ func (s *Server) processBatch() error { return fmt.Errorf("failed to sample token: %w", err) } + if seq.sampler.JSONSampler != nil { + _, err = seq.sampler.JSONSampler.UpdateState([]int32{token}) + if err != nil { + return fmt.Errorf("failed to update state: %w", err) + } + } + + if seq.sampler.PythonSampler != nil { + err = seq.sampler.PythonSampler.UpdateState(token) + if err != nil { + return fmt.Errorf("failed to update state: %w", err) + } + } + // if it's an end of sequence token, break if s.model.(model.TextProcessor).Is(token, model.SpecialEOS) { // TODO (jmorganca): we should send this back @@ -562,6 +576,22 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { } } + // jsonSampler, err := sample.NewJSONSampler(s.model.(model.TextProcessor), nil) + // if err != nil { + // http.Error(w, "failed to load model vocabulary required for format", http.StatusInternalServerError) + // return + // } + // jsonSampler = nil + // pythonSampler := sample.NewPythonSampler(s.model.(model.TextProcessor), nil) + // pythonSampler := &sample.PythonSampler{} + // functions := []sample.PythonFunction{ + // { + // Name: "add_two_strings", + // Args: []string{"s1", "s2"}, + // Types: []string{"string", "string"}, + // }, + // } + // pythonSampler.Init(functions, s.model.(model.TextProcessor)) sampler := sample.NewSampler( req.Options.Temperature, req.Options.TopK, @@ -569,6 +599,8 @@ func (s *Server) completion(w http.ResponseWriter, r *http.Request) { req.Options.MinP, req.Options.Seed, grammar, + nil, + nil, ) seq, err := s.NewSequence(req.Prompt, req.Images, NewSequenceParams{ diff --git a/sample/gtf.go b/sample/gtf.go new file mode 100644 index 000000000..b03c9f7ec --- /dev/null +++ b/sample/gtf.go @@ -0,0 +1,53 @@ +package sample + +var DefaultGrammar = map[string]string{ + "unicode": `\x{hex}{2} | \u{hex}{4} | \U{hex}{8}`, + "null": 
`"null"`, + "object": `"{" (kv ("," kv)*)? "}"`, + "array": `"[" (value ("," value)*)? "]"`, + "kv": `string ":" value`, + "integer": `"0" | [1-9] [0-9]*`, + "number": `"-"? integer frac? exp?`, + "frac": `"." [0-9]+`, + "exp": `("e" | "E") ("+" | "-") [0-9]+`, + "string": `"\"" char* "\""`, + "escape": `["/" | "b" | "f" | "n" | "r" | "t" | unicode]`, + "char": `[^"\\] | escape`, + "space": `(" " | "\t" | "\n" | "\r")*`, + "hex": `[0-9] | [a-f] | [A-F]`, + "boolean": `"true" | "false"`, + "value": `object | array | string | number | boolean | "null"`, +} + +const jsonString = `object | array` + +type StateMachine struct { + states map[rune]State +} + +type State struct { + NextStates []string + // bitmask? + Mask []bool + IsTerminal bool +} + +func NewStateMachine(grammar map[string]string, startRule string) *StateMachine { + states := make(map[rune]State) + + var cumu string + flag := false + for _, r := range startRule { + if r == '"' { + flag = !flag + } + if flag { + cumu += string(r) + } + } + + sm := &StateMachine{ + states: states, + } + return sm +} diff --git a/sample/gtf_test.go b/sample/gtf_test.go new file mode 100644 index 000000000..db311b001 --- /dev/null +++ b/sample/gtf_test.go @@ -0,0 +1,138 @@ +package sample + +import ( + "testing" +) + +func TestGrammarParsing(t *testing.T) { + tests := []struct { + name string + grammar map[string]string + startRule string + input string + want bool + }{ + { + name: "simple object", + grammar: map[string]string{ + "object": `"{" "}"`, + }, + startRule: "object", + input: "{}", + want: true, + }, + { + name: "simple array", + grammar: map[string]string{ + "array": `"[" "]"`, + }, + startRule: "array", + input: "[]", + want: true, + }, + { + name: "character class", + grammar: map[string]string{ + "digit": `[0-9]`, + }, + startRule: "digit", + input: "5", + want: true, + }, + { + name: "alternation", + grammar: map[string]string{ + "bool": `"true" | "false"`, + }, + startRule: "bool", + input: "true", + want: 
true, + }, + { + name: "repetition", + grammar: map[string]string{ + "digits": `[0-9]+`, + }, + startRule: "digits", + input: "123", + want: true, + }, + { + name: "nested rules", + grammar: map[string]string{ + "value": `object | array`, + "object": `"{" "}"`, + "array": `"[" "]"`, + }, + startRule: "value", + input: "{}", + want: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + parser := NewParser(tt.grammar) + machine, err := parser.Parse(tt.startRule) + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + matcher := NewMatcher(machine) + got, err := matcher.Match(tt.input) + if err != nil { + t.Fatalf("Match() error = %v", err) + } + if got != tt.want { + t.Errorf("Match() = %v, want %v", got, tt.want) + } + }) + } +} + +func TestJSONGrammar(t *testing.T) { + tests := []struct { + name string + input string + want bool + }{ + {"empty object", "{}", true}, + {"empty array", "[]", true}, + {"simple string", `"hello"`, true}, + {"simple number", "123", true}, + {"simple boolean", "true", true}, + {"simple null", "null", true}, + {"object with string", `{"key": "value"}`, true}, + {"array with numbers", "[1, 2, 3]", true}, + {"nested object", `{"obj": {"key": "value"}}`, true}, + {"nested array", `[1, [2, 3], 4]`, true}, + {"invalid object", "{", false}, + {"invalid array", "[1, 2", false}, + {"invalid string", `"hello`, false}, + } + + parser := NewParser(DefaultGrammar) + machine, err := parser.Parse("value") + if err != nil { + t.Fatalf("Parse() error = %v", err) + } + + matcher := NewMatcher(machine) + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got, err := matcher.Match(tt.input) + if tt.want { + if err != nil { + t.Errorf("Match() error = %v", err) + } + if !got { + t.Errorf("Match() = false, want true") + } + } else { + if err == nil && got { + t.Errorf("Match() = true, want false") + } + } + }) + } +} diff --git a/sample/json_types.go b/sample/json_types.go new file mode 100644 index 
// JSONState names a node kind in the JSON pushdown automaton.
type JSONState int

// The numeric values come from iota, so the declaration order is significant.
const (
	StateStart JSONState = iota
	StateInObject
	StateInObjectKey
	StateInStructuredKey
	StateInStructuredValue
	StateNewline
	StateTab
	StateSpace
	StateInString
	StateInInt
	StateInFloat
	StateInBool
	StateInNull
	StateInColon
	StateInComma
	StateInTab
	StateInSpaceToValue
	StateInSpaceEndValue
	StateInNewlineEndValue
	StateInObjSpace
	StateInList
	StateInListComma
	StateInValue
	StateInValueEnd
	StateInListEnd
	StateInListObjectEnd
	StateInNewline
	StateInNumber
	StateInNumberEnd
	StateInStringEnd
	StateInObjectKeyEnd
	StateTerminate
	StateInObjectEnd
	StateTransitioningToTerminate
	StateInListStartJSON
)

// JSONStates lists every declared state exactly once. The order differs from
// the constant declaration: StateInListStartJSON sits after StateInObjSpace.
var JSONStates = []JSONState{
	StateStart, StateInObject, StateInObjectKey, StateInStructuredKey, StateInStructuredValue,
	StateNewline, StateTab, StateSpace,
	StateInString, StateInInt, StateInFloat, StateInBool, StateInNull,
	StateInColon, StateInComma, StateInTab,
	StateInSpaceToValue, StateInSpaceEndValue, StateInNewlineEndValue, StateInObjSpace,
	StateInListStartJSON, StateInList, StateInListComma,
	StateInValue, StateInValueEnd, StateInListEnd, StateInListObjectEnd,
	StateInNewline, StateInNumber, StateInNumberEnd, StateInStringEnd, StateInObjectKeyEnd,
	StateTerminate, StateInObjectEnd, StateTransitioningToTerminate,
}

// jsonStateNames maps each state to the identifier it is declared with.
var jsonStateNames = [...]string{
	StateStart:                    "StateStart",
	StateInObject:                 "StateInObject",
	StateInObjectKey:              "StateInObjectKey",
	StateInStructuredKey:          "StateInStructuredKey",
	StateInStructuredValue:        "StateInStructuredValue",
	StateNewline:                  "StateNewline",
	StateTab:                      "StateTab",
	StateSpace:                    "StateSpace",
	StateInString:                 "StateInString",
	StateInInt:                    "StateInInt",
	StateInFloat:                  "StateInFloat",
	StateInBool:                   "StateInBool",
	StateInNull:                   "StateInNull",
	StateInColon:                  "StateInColon",
	StateInComma:                  "StateInComma",
	StateInTab:                    "StateInTab",
	StateInSpaceToValue:           "StateInSpaceToValue",
	StateInSpaceEndValue:          "StateInSpaceEndValue",
	StateInNewlineEndValue:        "StateInNewlineEndValue",
	StateInObjSpace:               "StateInObjSpace",
	StateInList:                   "StateInList",
	StateInListComma:              "StateInListComma",
	StateInValue:                  "StateInValue",
	StateInValueEnd:               "StateInValueEnd",
	StateInListEnd:                "StateInListEnd",
	StateInListObjectEnd:          "StateInListObjectEnd",
	StateInNewline:                "StateInNewline",
	StateInNumber:                 "StateInNumber",
	StateInNumberEnd:              "StateInNumberEnd",
	StateInStringEnd:              "StateInStringEnd",
	StateInObjectKeyEnd:           "StateInObjectKeyEnd",
	StateTerminate:                "StateTerminate",
	StateInObjectEnd:              "StateInObjectEnd",
	StateTransitioningToTerminate: "StateTransitioningToTerminate",
	StateInListStartJSON:          "StateInListStartJSON",
}

// String returns the declared identifier of s, or "Unknown state: <n>" for a
// value outside the declared range.
func (s JSONState) String() string {
	if s >= 0 && int(s) < len(jsonStateNames) {
		return jsonStateNames[s]
	}
	return fmt.Sprintf("Unknown state: %d", s)
}
Number validation: + - Need proper validation for special number cases like -0 + - Should handle .5 style decimals + - Need limits on scientific notation (e, E) + +3. String escaping: + - Currently marks \ as invalid but should allow escaped sequences: + - \" + - \n + - \u1234 unicode escapes + +4. Empty object/array transitions: + - Direct {} and [] cases could be more explicit + - Need clear transitions for these edge cases + +5. Nested depth limits: + - No protection against excessive nesting + - Could cause stack overflow with deeply nested structures +*/ + +// TODO: / should be valid but an escape character +var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'} + +var ( + intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'} + validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'} +) + +var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'} + +var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'} + +var validNullRunes = []rune{'n', 'u', 'l', 'l'} + +type PDA struct { + State JSONState + TransitionEdges map[rune]*PDA + MaskTokenIDToNode map[int32]*PDA +} + +func NewPDANode(state JSONState) *PDA { + return &PDA{ + State: state, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } +} + +type PDAGraphBuilder struct { + proc model.TextProcessor + decodedToks []string + stateToNodeMap map[JSONState]*PDA + tokenToStatesMap map[int32][]JSONState +} + +func (b *PDAGraphBuilder) BuildGraph() error { + stateToNodeMap := make(map[JSONState]*PDA) + for _, state := range JSONStates { + stateToNodeMap[state] = NewPDANode(state) + } + + stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON] + + // TODO: update naming here - and revisit values + 
stateToNodeMap[StateInListStartJSON].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateInListStartJSON].TransitionEdges['['] = stateToNodeMap[StateInListStartJSON] + + stateToNodeMap[StateInObject].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] + stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + stateToNodeMap[StateInObject].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + + // new line + stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab] + stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + // stateToNodeMap[StateInNewline].TransitionEdges['{'] = stateToNodeMap[StateInObject] + + // new line end value + // stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + + stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] + // TODO: see if this is needed for formatting + stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + + stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInTab].TransitionEdges['\t'] = stateToNodeMap[StateInNewline] + + stateToNodeMap[StateInObjectKey].TransitionEdges[rune(-1)] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInObjectKey].TransitionEdges['"'] = stateToNodeMap[StateInObjectKeyEnd] 
+ + stateToNodeMap[StateInObjectKeyEnd].TransitionEdges[':'] = stateToNodeMap[StateInColon] + + stateToNodeMap[StateInObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] + stateToNodeMap[StateInObjectEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + + // where values should be + // this could be combined but the probl might change, we're alr doing a skip ahead + stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] + stateToNodeMap[StateInColon].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue] + + stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList] + stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject] + addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap) + + // Leads to a value + stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList] + stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject] + addValueConnections(stateToNodeMap[StateInSpaceToValue], stateToNodeMap) + stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInSpaceToValue].TransitionEdges['\n'] = stateToNodeMap[StateInSpaceToValue] + + // Values + // string node + stateToNodeMap[StateInString].TransitionEdges[rune(-1)] = stateToNodeMap[StateInString] + stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd] + + // String end node + addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap) + // stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + // TODO: add counters for allowable number of decimals, e, E, etc + // number node + for _, r := range validNumberRunes { + stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber] + } + 
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap) + // stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + // list node + stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma] + stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] + stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] + // early end + stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + + // list end node + stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + // stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma] + stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + // empty list + stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + addValueConnections(stateToNodeMap[StateInList], stateToNodeMap) + + // null node + for _, r := range validNullRunes { + stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull] + } + addEnds(stateToNodeMap[StateInNull], stateToNodeMap) + stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] + stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + // list comma + // should point to values + stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma] + stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList] + stateToNodeMap[StateInListComma].TransitionEdges[' '] = 
stateToNodeMap[StateInList] + stateToNodeMap[StateInListComma].TransitionEdges['\t'] = stateToNodeMap[StateInList] + + addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap) + + // list object end + stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma] + stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + // TODO: not sure if this is needed + stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + // bool node + for _, r := range validBoolRunes { + stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] + } + stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] + addEnds(stateToNodeMap[StateInBool], stateToNodeMap) + // stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + // comma node + stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject] + stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] + stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] + stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace] + // todo: review this space transition + // stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue] + + // space end value + // stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue] + stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] + stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue] + + b.stateToNodeMap = stateToNodeMap + if err := b.preComputeValidStates(); err != nil 
{ + return err + } + return nil +} + +func addEnds(node *PDA, stateToNodeMap map[JSONState]*PDA) { + node.TransitionEdges[','] = stateToNodeMap[StateInComma] + node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] + node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd] +} + +func addValueConnections(node *PDA, stateToNodeMap map[JSONState]*PDA) { + node.TransitionEdges['"'] = stateToNodeMap[StateInString] + for _, r := range validNumberRunes { + node.TransitionEdges[r] = stateToNodeMap[StateInNumber] + } + // TODO(parthsareen): force the output and shift similar to structured outputs + node.TransitionEdges['t'] = stateToNodeMap[StateInBool] + node.TransitionEdges['f'] = stateToNodeMap[StateInBool] + node.TransitionEdges['n'] = stateToNodeMap[StateInNull] +} + +func (b *PDAGraphBuilder) preComputeValidStates() error { + for _, node := range b.stateToNodeMap { + // if node.State == StateInObjectKey { + // if len(b.stateToNodeMap[StateInString].MaskTokenIDToNode) > 0 { + // b.stateToNodeMap[StateInObjectKey].MaskTokenIDToNode = b.stateToNodeMap[StateInString].MaskTokenIDToNode + // fmt.Println("copying string mask to object key mask") + // } + // } + if err := b.CreateMask(node); err != nil { + return err + } + } + return nil +} + +func (b *PDAGraphBuilder) preComputeTokenToStatesMap() error { + // TODO: make can be somewhere else too + b.tokenToStatesMap = make(map[int32][]JSONState) + for i, t := range b.decodedToks { + for _, r := range t { + if r == '"' { + b.tokenToStatesMap[int32(i)] = append(b.tokenToStatesMap[int32(i)], StateInString) + } + } + } + return nil +} + +// TODO: the mask for obj key and string should be the same? 
+func (b *PDAGraphBuilder) CreateMask(node *PDA) error { + if node == nil { + return fmt.Errorf("node cannot be nil") + } + for i := range b.decodedToks { + token := b.decodedToks[i] + // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON + if b.proc.Is(int32(i), model.SpecialEOS) || b.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" { + continue + } + curNode := node + valid := true + consumedSpecialRunes := make(map[rune]bool) + for _, r := range token { + curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes) + if curNode == nil || !valid { + break + } + } + if valid { + node.MaskTokenIDToNode[int32(i)] = curNode + } + } + return nil +} + +func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) { + if consumedSpecialRunes[r] { + return nil, false + } + + specialRune := slices.Contains(stringInvalidRunes, r) + if specialRune { + if curNode.State == StateInString || curNode.State == StateInObjectKey { + return nil, false + } + } + + // Check for specific rune transition + if nextNode, ok := curNode.TransitionEdges[r]; ok { + // fmt.Println("next node", nextNode) + if specialRune { + if curNode.State == nextNode.State { + return nil, false + } + consumedSpecialRunes[r] = true + } + return nextNode, true + } + + // Check for sentinel value - if present, any rune is valid + if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok { + return nextNode, true + } + + return nil, false +} diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go new file mode 100644 index 000000000..cf51c42e6 --- /dev/null +++ b/sample/pushdown_runner.go @@ -0,0 +1,264 @@ +package sample + +import ( + "fmt" + "math" + "runtime" + "time" + + "github.com/ollama/ollama/model" +) + +// TODO: safety in case of invalid json +// TODO: partial JSON matching? 
+// TODO: interfaces to cleanup with return values +// TODO this interface shouldn't be the sampler - should just use Sampler +// TODO: add penalties for string \n stuff +// TODO: minimize number of fwd passes if there is only one match +// TODO: greedy sample initially and then backtrack if no match + +type PushdownSampler struct { + PDAGraphBuilder + curNode *PDA + braceStack []rune + stateCounter uint32 +} + +// graph should be built once and reused per tokenizer +func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) { + start := time.Now() + + fmt.Println("--------------------------------") + fmt.Println("PDA sampler") + fmt.Println("--------------------------------") + var m runtime.MemStats + runtime.ReadMemStats(&m) + before := m.Alloc + fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) + + vocab := proc.Vocab() + decodedToks := make([]string, len(vocab.Values)) + for i := range vocab.Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return nil, err + } + decodedToks[i] = token + } + + gb := &PDAGraphBuilder{ + proc: proc, + decodedToks: decodedToks, + } + + if err := gb.BuildGraph(); err != nil { + return nil, err + } + + runtime.ReadMemStats(&m) + after := m.Alloc + fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) + fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + fmt.Printf("Graph build time = %v\n", time.Since(start)) + + // TODO: this can be simplified + return &PushdownSampler{ + curNode: gb.stateToNodeMap[StateStart], + PDAGraphBuilder: *gb, + braceStack: []rune{}, + stateCounter: 0, + }, nil +} + +// TODO: need to add resampling logic if the first sample was not good +// greedy sample + backtrack? 
+func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) { + switch s.curNode.State { + case StateInString: + return s.maskLogits(logits, s.curNode) + + case StateInListEnd: + // force finish if no braces left + if len(s.braceStack) == 0 { + s.curNode = NewPDANode(StateTerminate) + return forceFinish(s, logits) + } + + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + + case StateTerminate: + return forceFinish(s, logits) + + case StateInObjectEnd: + // force finish if no braces left + if len(s.braceStack) == 0 { + s.curNode = NewPDANode(StateTerminate) + return forceFinish(s, logits) + } + + peek := s.braceStack[len(s.braceStack)-1] + if peek == rune('[') { + s.curNode = s.stateToNodeMap[StateInListObjectEnd] + } + + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + + case StateInComma: + peek := s.braceStack[len(s.braceStack)-1] + if peek == rune('[') { + s.curNode = s.stateToNodeMap[StateInListComma] + } + + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + + default: + fmt.Println("masking logits current state", s.curNode.State) + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil + } +} + +func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) { + for i := range logits { + if s.proc.Is(int32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = float32(math.Inf(-1)) + } + } + return logits, nil +} + +func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) { + fmt.Println("current state - updating", s.curNode.State) + mappedString, err := s.proc.Decode(tokenSlice) + if err != nil { + return nil, err + } + fmt.Printf(">>> mappedString: %q\n", mappedString) + + // Special handling for EOS token in terminate state + if s.curNode.State == StateTerminate { + for _, 
tokenID := range tokenSlice { + if s.proc.Is(tokenID, model.SpecialEOS) { + return tokenSlice, nil + } + } + } + + // flag := -1 + // endBraceRunes := []rune{'}', ']'} + for _, r := range mappedString { + // TODO: if this is enabled again, make sure to appropriately handle the state transitions + // if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 { + // fmt.Printf("stack is empty, extra closing brace %c\n", r) + // // flag = i + // break + + // } + if r == rune('{') { + s.braceStack = append(s.braceStack, r) + } + if r == rune('[') { + s.braceStack = append(s.braceStack, r) + } + if r == rune('}') { + if len(s.braceStack) == 0 { + return nil, fmt.Errorf("stack is empty, extra closing brace %c", r) + } + top := s.braceStack[len(s.braceStack)-1] + if top != rune('{') { + return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') + } + s.braceStack = s.braceStack[:len(s.braceStack)-1] + } + + if r == rune(']') { + if len(s.braceStack) == 0 { + return nil, fmt.Errorf("stack is empty, extra closing brace %c", r) + } + top := s.braceStack[len(s.braceStack)-1] + if top != rune('[') { + return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') + } + s.braceStack = s.braceStack[:len(s.braceStack)-1] + } + } + + // if flag != -1 { + // tokenSlice = tokenSlice[:flag] + // } + // fmt.Println("flag!", flag) + for _, tokenID := range tokenSlice { + // transition to the next node + nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] + if !ok { + return nil, fmt.Errorf("invalid token: %q", mappedString) + } + fmt.Println("transitioning to", nextNode.State) + + // TODO: add a penalty for staying in the same state too long + if nextNode.State == s.curNode.State { + s.stateCounter++ + } else { + s.stateCounter = 0 + } + s.curNode = nextNode + fmt.Println("updated curNode state", s.curNode.State) + } + return tokenSlice, nil +} + +// greedy sample + backtrack? 
+func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) { + // Create a new slice with same length as logits, initialized to -Inf + maskedLogits := make([]float32, len(logits)) + for i := range maskedLogits { + maskedLogits[i] = float32(math.Inf(-1)) + } + + // Only update values for valid token IDs from the mask map + for tokenID := range node.MaskTokenIDToNode { + if int(tokenID) < len(logits) { + maskedLogits[tokenID] = logits[tokenID] + } + } + + return maskedLogits, nil +} + +func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) { + maxLogit := float32(math.Inf(-1)) + maxIndex := -1 + + // Find the maximum logit value among valid tokens + for tokenID := range node.MaskTokenIDToNode { + if int(tokenID) < len(logits) && logits[tokenID] > maxLogit { + maxLogit = logits[tokenID] + maxIndex = int(tokenID) + } + } + + if maxIndex == -1 { + return nil, fmt.Errorf("no valid tokens found in mask") + } + + logits[0] = float32(maxIndex) + return logits, nil + // return maxIndex, nil +} diff --git a/sample/samplers.go b/sample/samplers.go index ef8033691..bb6b4c11f 100644 --- a/sample/samplers.go +++ b/sample/samplers.go @@ -17,12 +17,14 @@ type token struct { } type Sampler struct { - rng *rand.Rand - topK int - topP float32 - minP float32 - temperature float32 - grammar *Grammar + rng *rand.Rand + topK int + topP float32 + minP float32 + temperature float32 + grammar *Grammar + JSONSampler *JSONSampler + PythonSampler *PythonSampler } func (s *Sampler) Sample(logits []float32) (int32, error) { @@ -30,6 +32,19 @@ func (s *Sampler) Sample(logits []float32) (int32, error) { return -1, errors.New("sample: no logits provided to sample") } + var err error + if s.JSONSampler != nil { + logits, err = s.JSONSampler.Apply(logits) + if err != nil { + return -1, err + } + } + if s.PythonSampler != nil { + logits, err = s.PythonSampler.ApplyMask(logits) + if err != nil { + return -1, err + } + } tokens := make([]token, 
len(logits)) for i := range logits { tokens[i].id = int32(i) @@ -127,7 +142,7 @@ func (s *Sampler) sample(tokens []token) (token, error) { } // TODO(parthsareen): update sampler interface to use json unmarshal https://github.com/ollama/ollama/issues/9278 -func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar) Sampler { +func NewSampler(temperature float32, topK int, topP float32, minP float32, seed int, grammar *Grammar, jsonSampler *JSONSampler, pythonSampler *PythonSampler) Sampler { var rng *rand.Rand if seed != -1 { // PCG requires two parameters: sequence and stream @@ -155,12 +170,14 @@ func NewSampler(temperature float32, topK int, topP float32, minP float32, seed } return Sampler{ - rng: rng, - topK: topK, - topP: topP, - minP: minP, - temperature: temperature, - grammar: grammar, + rng: rng, + topK: topK, + topP: topP, + minP: minP, + temperature: temperature, + grammar: grammar, + JSONSampler: jsonSampler, + PythonSampler: pythonSampler, } } diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go new file mode 100644 index 000000000..ec5db2302 --- /dev/null +++ b/sample/structured_outputs.go @@ -0,0 +1,299 @@ +package sample + +import ( + "fmt" + "log/slog" + "runtime" + "time" + + "github.com/ollama/ollama/grammar/jsonschema" + "github.com/ollama/ollama/model" +) + +type JSONSampler struct { + schema *jsonschema.Schema + propIdx int + propToNodeMap map[string]*PDA + pdaSampler *PushdownSampler + decodedToks []string +} + +func NewJSONSampler(proc model.TextProcessor, schema *jsonschema.Schema) (*JSONSampler, error) { + slog.Info("NewJSONSampler", "schema", schema) + if proc == nil { + return nil, fmt.Errorf("TextProcessor cannot be nil") + } + + pdaSampler, err := NewPushdownSampler(proc) + if err != nil { + return nil, fmt.Errorf("failed to create PushdownSampler: %w", err) + } + + if schema == nil { + return &JSONSampler{ + schema: nil, + propIdx: -1, + propToNodeMap: nil, + pdaSampler: 
pdaSampler, + }, nil + } + + // fmt.Println("schema not nil") + so := &JSONSampler{ + schema: schema, + propIdx: -1, + propToNodeMap: make(map[string]*PDA), + pdaSampler: pdaSampler, + } + + so.schemaToGraph() + + // Benchmark token decoding + start := time.Now() + var m runtime.MemStats + runtime.ReadMemStats(&m) + before := m.Alloc + + vocab := proc.Vocab() + decodedToks := make([]string, len(vocab.Values)) + for i := range vocab.Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return nil, err + } + decodedToks[i] = token + } + so.decodedToks = decodedToks + + runtime.ReadMemStats(&m) + after := m.Alloc + fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + fmt.Printf("Token decode time = %v\n", time.Since(start)) + + fmt.Println("--------------------------------") + fmt.Println("SOSampler") + fmt.Println("--------------------------------") + // Benchmark this section + start = time.Now() + runtime.ReadMemStats(&m) + before = m.Alloc + + // TODO: still messed up + // TODO: recursion use case + // key masks + for _, prop := range so.schema.Properties { + node := so.propToNodeMap[prop.Name] + // propName -> node + curState := node.State + fromNode := node + so.pdaSampler.CreateMask(fromNode) + for curState == StateInStructuredKey { + // there is only one edge + for r, toNode := range fromNode.TransitionEdges { + fmt.Println("rune", r, "edge", toNode.State) + so.pdaSampler.CreateMask(toNode) + fmt.Printf("created mask for %c\n", r) + curState = toNode.State + fmt.Println("next state", curState) + // TODO: theres an extra gen for " right now + fromNode = toNode + } + } + + if curState != StateInColon { + return nil, fmt.Errorf("expected state to be StateInColon, got %v", curState) + } + + // so.pdaSampler.CreateMask(fromNode) + + fromNode = fromNode.TransitionEdges[' '] + + so.pdaSampler.CreateMask(fromNode) + curState = fromNode.State + for _, toNode := range fromNode.TransitionEdges { + 
fmt.Println("toNode", toNode.State) + } + } + + // runtime.ReadMemStats(&m) + // after = m.Alloc + // fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + // fmt.Printf("Mask creation time = %v\n", time.Since(start)) + // fmt.Println("--------------------------------") + + return so, nil +} + +func (s *JSONSampler) schemaToGraph() { + schemaType := s.schema.EffectiveType() + switch schemaType { + case "object": + // TODO: see if we need to connect these to the JSON graph + + // each prop is a key + for _, prop := range s.schema.Properties { + // name of key + name := prop.Name + keyNode := &PDA{ + State: StateInStructuredKey, // this is unchanging, will impact sampling + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + + prevNode := keyNode + for _, r := range name { + runeNode := &PDA{ + State: StateInStructuredKey, // this is unchanging, will impact sampling + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + // fmt.Println("runeNode created", runeNode.State) + // fmt.Printf("runeNode created %c\n", r) + + // since alloc on heap connections wil still map + prevNode.TransitionEdges[r] = runeNode + prevNode = runeNode + } + + // point to end of object key node after all chars are done + // prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] + + // link to value node + // Create a node for the end of the key (after the closing quote) + stringEndNode := &PDA{ + State: StateInStructuredKey, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + prevNode.TransitionEdges['"'] = stringEndNode + prevNode = stringEndNode + + // Add transition for colon after key + colonNode := &PDA{ + State: StateInColon, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + prevNode.TransitionEdges[':'] = colonNode + prevNode = colonNode + + // Add transition for space after colon + 
spaceNode := &PDA{ + State: StateInSpaceToValue, + TransitionEdges: make(map[rune]*PDA), + MaskTokenIDToNode: make(map[int32]*PDA), + } + prevNode.TransitionEdges[' '] = spaceNode + prevNode = spaceNode + + value := prop.Type + switch value { + case "object": + fmt.Println("object under key: ", name) + case "array": + fmt.Println("array under key: ", name) + case "string": + fmt.Println("string under key: ", name) + prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInString] + case "number": + fmt.Println("number under key: ", name) + for _, r := range validNumberRunes { + prevNode.TransitionEdges[r] = s.pdaSampler.stateToNodeMap[StateInNumber] + } + case "boolean": + fmt.Println("boolean under key: ", name) + prevNode.TransitionEdges['t'] = s.pdaSampler.stateToNodeMap[StateInBool] + prevNode.TransitionEdges['f'] = s.pdaSampler.stateToNodeMap[StateInBool] + prevNode.TransitionEdges['n'] = s.pdaSampler.stateToNodeMap[StateInNull] + } + + // points to start of the key + s.propToNodeMap[name] = keyNode + fmt.Println("name", name, "keyNode", keyNode.State) + } + } + // TODO: do values + recursion +} + +func (s *JSONSampler) Apply(logits []float32) ([]float32, error) { + if s.schema == nil { + return s.pdaSampler.Apply(logits) + } + + switch s.pdaSampler.curNode.State { + // TODO: doesnt account for multi rune case + case StateInObjectKey: + if s.propIdx > len(s.schema.Properties)-1 { + return nil, fmt.Errorf("propIdx out of bounds") + } + // fmt.Println("in object key - structured outputs") + // TODO: this tracking should probably be coming from a stack to track nested objects + // simple case + s.propIdx++ + fmt.Println("propIdx", s.propIdx) + prop := s.schema.Properties[s.propIdx] + fmt.Println("prop", prop.Name) + s.pdaSampler.curNode = s.propToNodeMap[prop.Name] + fmt.Println("changed curNode state to", s.pdaSampler.curNode.State) + logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode) + if err != nil { + return nil, err + } + return 
logits, nil + + default: + + // Will only happen for the last prop - can also be precomputed. + if s.propIdx == len(s.schema.Properties)-1 { + // todo: if i incremenet propidx then i know im in last value as well + switch s.pdaSampler.curNode.State { + case StateInObjectEnd: + fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State) + s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA) + s.pdaSampler.curNode = NewPDANode(StateTerminate) + s.propIdx++ + + // TODO: this needs to be optimized in some way, computing mask on the fly is expensive + case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd: + fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State) + delete(s.pdaSampler.curNode.TransitionEdges, ',') + s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA) + + s.pdaSampler.CreateMask(s.pdaSampler.curNode) + s.propIdx++ + } + } + return s.pdaSampler.Apply(logits) + } +} + +func (s *JSONSampler) UpdateState(tokenSlice []int32) ([]int32, error) { + tokenSlice, err := s.pdaSampler.UpdateState(tokenSlice) + if err != nil { + return nil, err + } + + if s.schema == nil { + // Don't need to update state for unconstrained JSON sampling + return tokenSlice, nil + } + + switch s.pdaSampler.curNode.State { + case StateInObjectKey: + s.propIdx++ + fmt.Println("propIdx", s.propIdx) + prop := s.schema.Properties[s.propIdx] + fmt.Println("prop", prop.Name) + s.pdaSampler.curNode = s.propToNodeMap[prop.Name] + // TODO: this does not work - mike + // str, err := s.pdaSampler.proc.Decode(tokenSlice) + // if err != nil { + // return nil, err + // } + // fmt.Println("str", str) + + return tokenSlice, nil + default: + return tokenSlice, nil + } +} diff --git a/sample/structured_python.go b/sample/structured_python.go new file mode 100644 index 000000000..2b8de21c3 --- /dev/null +++ b/sample/structured_python.go @@ -0,0 +1,339 @@ +package sample + +import ( + "fmt" + "math" + "slices" + + 
"github.com/ollama/ollama/model" +) + +type PythonState int + +const ( + PythonStateStart PythonState = iota + StateInFunction + StateInFunctionArgs + StateInFunctionArgsType + StateInFunctionEnd + PStateInString + PStateInStringEnd + PStateInNumber + PStateInList + PStateInListEnd + PStateInDict + PStateInDictEnd + PStateInTuple + PStateInTupleEnd + PStateTerminate +) + +func (s PythonState) String() string { + switch s { + case PythonStateStart: + return "PythonStateStart" + case StateInFunction: + return "StateInFunction" + case StateInFunctionArgs: + return "StateInFunctionArgs" + case StateInFunctionArgsType: + return "StateInFunctionArgsType" + case StateInFunctionEnd: + return "StateInFunctionEnd" + case PStateInString: + return "PStateInString" + case PStateInStringEnd: + return "PStateInStringEnd" + case PStateInNumber: + return "PStateInNumber" + case PStateInList: + return "PStateInList" + case PStateInListEnd: + return "PStateInListEnd" + case PStateInDict: + return "PStateInDict" + case PStateInDictEnd: + return "PStateInDictEnd" + case PStateInTuple: + return "PStateInTuple" + case PStateInTupleEnd: + return "PStateInTupleEnd" + case PStateTerminate: + return "PStateTerminate" + default: + return fmt.Sprintf("PythonState(%d)", s) + } +} + +var PythonStates = []PythonState{ + PythonStateStart, + StateInFunction, + StateInFunctionArgs, + StateInFunctionArgsType, + StateInFunctionEnd, + PStateInString, + PStateInStringEnd, + PStateInNumber, + PStateInList, + PStateInListEnd, + PStateInDict, + PStateInDictEnd, + PStateInTuple, + PStateInTupleEnd, + PStateTerminate, +} + +type Node struct { + State PythonState + TransitionEdges map[rune]*Node + MaskTokenIDToNode map[int32]*Node +} + +func NewNode(state PythonState) *Node { + return &Node{ + State: state, + TransitionEdges: make(map[rune]*Node), + MaskTokenIDToNode: make(map[int32]*Node), + } +} + +type PythonFunction struct { + Name string + Args []string + Types []string +} + +type PythonSampler struct { 
+ stateToNodes map[PythonState]*Node + proc model.TextProcessor + decodedToks []string + curNode *Node +} + +func (s *PythonSampler) Init(functions []PythonFunction, proc model.TextProcessor) error { + s.proc = proc + decodedToks := make([]string, len(proc.Vocab().Values)) + for i := range proc.Vocab().Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return err + } + decodedToks[i] = token + } + s.decodedToks = decodedToks + s.BuildGraph() + for _, function := range functions { + prevNode := s.stateToNodes[PythonStateStart] + + for _, r := range function.Name { + nextNode := NewNode(StateInFunction) + prevNode.TransitionEdges[r] = nextNode + if err := s.CreateMask(nextNode); err != nil { + return err + } + fmt.Println("prevNode", prevNode.State) + fmt.Printf("transition edge: %q\n", r) + fmt.Println("nextNode", nextNode.State) + prevNode = nextNode + } + prevNode.TransitionEdges['('] = s.stateToNodes[StateInFunctionArgs] + s.CreateMask(prevNode) + prevNode = s.stateToNodes[StateInFunctionArgs] + for i, arg := range function.Args { + for _, r := range arg { + nextNode := NewNode(StateInFunctionArgs) + prevNode.TransitionEdges[r] = nextNode + s.CreateMask(prevNode) + prevNode = nextNode + } + prevNode.TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs] + // prevNode = s.stateToNodes[StateInFunctionArgs] + prevNode.TransitionEdges['='] = NewNode(StateInFunctionArgsType) + s.CreateMask(prevNode) + prevNode = prevNode.TransitionEdges['='] + switch function.Types[i] { + case "string": + prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInString] + s.CreateMask(prevNode.TransitionEdges['"']) + case "number": + prevNode.TransitionEdges['"'] = s.stateToNodes[PStateInNumber] + s.CreateMask(prevNode.TransitionEdges['"']) + } + } + + } + s.curNode = s.stateToNodes[PythonStateStart] + fmt.Println("curNode", s.curNode.State) + fmt.Println("transition edges", s.curNode.TransitionEdges) + if err := s.CreateMask(s.curNode); err != nil { + return 
err + } + fmt.Println("maskTokenIDToNode", s.curNode.MaskTokenIDToNode) + for tokenID, node := range s.curNode.MaskTokenIDToNode { + fmt.Printf("tokenID: %d, node: %v\n", s.decodedToks[tokenID], node.State) + } + + return nil +} + +func (s *PythonSampler) BuildGraph() error { + s.stateToNodes = make(map[PythonState]*Node) + for _, state := range PythonStates { + s.stateToNodes[state] = NewNode(state) + } + + for _, state := range s.stateToNodes { + if err := s.CreateMask(state); err != nil { + return err + } + } + + // String + s.stateToNodes[PStateInString].TransitionEdges[rune(-1)] = s.stateToNodes[PStateInString] + s.stateToNodes[PStateInString].TransitionEdges['"'] = s.stateToNodes[PStateInStringEnd] + + // String end + s.stateToNodes[PStateInStringEnd].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs] + s.stateToNodes[PStateInStringEnd].TransitionEdges[')'] = s.stateToNodes[PStateTerminate] + // Number + for _, r := range validNumberRunes { + s.stateToNodes[PStateInNumber].TransitionEdges[r] = s.stateToNodes[PStateInNumber] + } + s.stateToNodes[PStateInNumber].TransitionEdges[')'] = s.stateToNodes[PStateTerminate] + s.stateToNodes[PStateInNumber].TransitionEdges[','] = s.stateToNodes[StateInFunctionArgs] + s.stateToNodes[PStateInNumber].TransitionEdges[' '] = s.stateToNodes[StateInFunctionArgs] + + return nil +} + +func (s *PythonSampler) ApplyMask(logits []float32) ([]float32, error) { + if s.curNode.State == PStateTerminate { + logits, err := finish(s, logits) + if err != nil { + return nil, err + } + return logits, nil + } + logits, err := s.maskLogits(logits, s.curNode) + if err != nil { + return nil, err + } + return logits, nil +} + +func (s *PythonSampler) UpdateState(token int32) error { + mappedString, err := s.proc.Decode([]int32{token}) + if err != nil { + return err + } + fmt.Printf(">>> mappedString: %q\n", mappedString) + + if s.curNode.State == PStateTerminate { + if s.proc.Is(token, model.SpecialEOS) { + return nil + } + } + nextNode, 
ok := s.curNode.MaskTokenIDToNode[token] + if !ok { + return fmt.Errorf("invalid token: %q", mappedString) + } + s.curNode = nextNode + fmt.Println("curNode", s.curNode.State) + for r, node := range s.curNode.TransitionEdges { + fmt.Printf("transition edge: %q -> %v\n", r, node.State) + } + if err := s.CreateMask(s.curNode); err != nil { + return err + } + return nil +} + +func (s *PythonSampler) CreateMask(node *Node) error { + if node == nil { + return fmt.Errorf("node cannot be nil") + } + for i := range s.decodedToks { + token := s.decodedToks[i] + // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON + if s.proc.Is(int32(i), model.SpecialEOS) || s.proc.Is(int32(i), model.SpecialBOS) || token == "" || token == "\"\"" { + continue + } + curNode := node + valid := true + consumedSpecialRunes := make(map[rune]bool) + for _, r := range token { + curNode, valid = isRValid(r, curNode, consumedSpecialRunes) + if curNode == nil || !valid { + break + } + } + if valid { + if curNode.State == StateInFunction { + // fmt.Println("cm curNode", curNode.State) + // fmt.Println("cm token", s.decodedToks[i]) + } + node.MaskTokenIDToNode[int32(i)] = curNode + } + } + return nil +} + +func isRValid(r rune, curNode *Node, consumedSpecialRunes map[rune]bool) (*Node, bool) { + if consumedSpecialRunes[r] { + return nil, false + } + + specialRune := slices.Contains(stringInvalidRunes, r) + if specialRune { + if curNode.State == PStateInString || curNode.State == PStateInStringEnd { + return nil, false + } + } + + // Check for specific rune transition + if nextNode, ok := curNode.TransitionEdges[r]; ok { + // fmt.Println("next node", nextNode) + if specialRune { + if curNode.State == nextNode.State { + return nil, false + } + consumedSpecialRunes[r] = true + } + return nextNode, true + } + + // Check for sentinel value - if present, any rune is valid + if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok { + return nextNode, true + } + + return nil, false +} + 
+func (s *PythonSampler) maskLogits(logits []float32, node *Node) ([]float32, error) { + // Create a new slice with same length as logits, initialized to -Inf + maskedLogits := make([]float32, len(logits)) + for i := range maskedLogits { + maskedLogits[i] = float32(math.Inf(-1)) + } + + // Only update values for valid token IDs from the mask map + for tokenID := range node.MaskTokenIDToNode { + if int(tokenID) < len(logits) { + maskedLogits[tokenID] = logits[tokenID] + } + } + + return maskedLogits, nil +} + +func finish(s *PythonSampler, logits []float32) ([]float32, error) { + for i := range logits { + if s.proc.Is(int32(i), model.SpecialEOS) { + logits[i] = 1.0 + } else { + logits[i] = float32(math.Inf(-1)) + } + } + return logits, nil +}