From e18540fecc882c10d419fef9cb89cb0386843adf Mon Sep 17 00:00:00 2001
From: ParthSareen
Date: Thu, 27 Mar 2025 11:26:49 -0700
Subject: [PATCH] sample: wip structured outputs work

---
 sample/state_machine.go           | 177 ++++++++++++++++++++++++++++
 sample/structured_outputs.go      |   4 +
 sample/structured_outputs_test.go | 194 ++++++++++++++++++++++++++++++
 3 files changed, 375 insertions(+)
 create mode 100644 sample/state_machine.go
 create mode 100644 sample/structured_outputs.go
 create mode 100644 sample/structured_outputs_test.go

diff --git a/sample/state_machine.go b/sample/state_machine.go
new file mode 100644
index 000000000..766ab64b0
--- /dev/null
+++ b/sample/state_machine.go
@@ -0,0 +1,177 @@
+package sample
+
+import (
+	"bytes"
+	"maps"
+	"strings"
+
+	"github.com/ollama/ollama/model"
+)
+
+// Node is a state in the rune-transition graph. Each outgoing edge is
+// keyed by the rune that is consumed when taking the transition.
+type Node struct {
+	TransitionEdges map[rune]*Node
+}
+
+// Graph compiles a BNF-style grammar into a rune-transition graph that is
+// intended to constrain sampling to grammar-valid output (wip).
+type Graph struct {
+	proc        model.TextProcessor
+	decodedToks []string
+	curNode     *Node
+	grammar     []byte
+	rules       map[string]string
+}
+
+// baseRules is the set of rules that are used to parse the grammar
+// JSON grammar from RFC 7159
+var baseRules = map[string]string{
+	"object":  "\"{\" (kv (\",\" kv)*)? \"}\"",
+	"array":   "\"[\" (value (\",\" value)*)? \"]\"",
+	"string":  "\"\\\"\" char* \"\\\"\"",
+	"number":  "\"-\"? integer frac? exp?",
+	"kv":      "string \":\" value",
+	"integer": "\"0\" | [1-9] [0-9]*",
+	"frac":    "\".\" [0-9]+",
+	"exp":     "(\"e\" | \"E\") (\"+\" | \"-\")? [0-9]+",
+	"escape":  "[\"/\" | \"b\" | \"f\" | \"n\" | \"r\" | \"t\" | unicode]",
+	"char":    "[^\"\\\\] | escape",
+	"space":   "(\" \" | \"\\t\" | \"\\n\" | \"\\r\")*",
+	"hex":     "[0-9] | [a-f] | [A-F]",
+	"boolean": "\"true\" | \"false\"",
+	"value":   "object | array | string | number | boolean | \"null\"",
+	"null":    "\"null\"",
+}
+
+// BuildGraph decodes the model vocabulary, merges the grammar's root rules
+// into a per-graph copy of baseRules, and parses the "root" rule into the
+// graph rooted at node. Returns an error only if token decoding fails.
+func (g *Graph) BuildGraph(node *Node) error {
+	vocab := g.proc.Vocab()
+	decodedToks := make([]string, len(vocab.Values))
+	for i := range vocab.Values {
+		token, err := g.proc.Decode([]int32{int32(i)})
+		if err != nil {
+			return err
+		}
+		decodedToks[i] = token
+	}
+
+	g.decodedToks = decodedToks
+	// Clone so that per-graph root rules never mutate the shared baseRules map.
+	g.rules = maps.Clone(baseRules)
+	g.rootPrefixes()
+	// Parse into the caller-supplied root node and remember it as the
+	// current state.
+	g.parseRule(g.rules["root"], node)
+	g.curNode = node
+
+	return nil
+}
+
+// rootPrefixes scans the grammar line by line and records every rule whose
+// name starts with "root" into g.rules, skipping blank lines and "#" comments.
+func (g *Graph) rootPrefixes() {
+	lines := bytes.Split(g.grammar, []byte("\n"))
+	for _, line := range lines {
+		line = bytes.TrimSpace(line)
+		if len(line) == 0 || bytes.HasPrefix(line, []byte("#")) {
+			continue
+		}
+
+		parts := bytes.SplitN(line, []byte("::="), 2)
+		if len(parts) != 2 {
+			continue
+		}
+
+		ruleName := string(bytes.TrimSpace(parts[0]))
+		if strings.HasPrefix(ruleName, "root") {
+			g.rules[ruleName] = string(bytes.TrimSpace(parts[1]))
+		}
+	}
+}
+
+// parseRule parses a grammar rule and returns a Node
+func (g *Graph) parseRule(rule string, curNode *Node) *Node {
+	/*
+		Here are the special characters in BNF grammar and their functions:
+		::= - Definition operator, means "is defined as"
+		| - Alternation, means "or"
+		* - Zero or more repetitions of preceding element
+		+ - One or more repetitions
+		? - Optional (zero or one occurrence)
+		[] - Character class, matches any single character within brackets
+		[^] - Negated character class, matches any character NOT listed
+		() - Grouping of elements
+		- - Range operator in character classes (e.g., [a-z])
+		"" - Literal string match
+	*/
+
+	// Split rule into tokens by whitespace
+	tokens := strings.Fields(rule)
+	if len(tokens) == 0 {
+		return &Node{
+			TransitionEdges: make(map[rune]*Node),
+		}
+	}
+
+	// Handle integer rule
+	if strings.Contains(rule, "[0-9]+") {
+		// Create node for first digit 1-9
+		firstDigitNode := &Node{
+			TransitionEdges: make(map[rune]*Node),
+		}
+		for r := '1'; r <= '9'; r++ {
+			curNode.TransitionEdges[r] = firstDigitNode
+		}
+
+		// Create node for subsequent digits 0-9
+		zeroToNineNode := &Node{
+			TransitionEdges: make(map[rune]*Node),
+		}
+		for r := '0'; r <= '9'; r++ {
+			// Loop back to same node for the * operator
+			zeroToNineNode.TransitionEdges[r] = zeroToNineNode
+		}
+
+		// Connect first digit to subsequent digits
+		firstDigitNode.TransitionEdges = zeroToNineNode.TransitionEdges
+
+		// Also handle the "0" case
+		if strings.Contains(rule, "\"0\"") {
+			zeroNode := &Node{
+				TransitionEdges: make(map[rune]*Node),
+			}
+			curNode.TransitionEdges['0'] = zeroNode
+		}
+
+		return curNode
+	}
+
+	// recursive case
+	// grammar options
+	// TODO: handle left recursion
+	if strings.Contains(rule, "|") {
+		parts := strings.Split(rule, "|")
+		savedNode := curNode
+		for _, part := range parts {
+			// TODO: add correct transitions
+			g.parseRule(part, savedNode)
+		}
+	}
+
+	// Literal tokens: chain one node per rune off the current node.
+	// TODO: length constraint on literal matches
+	for _, token := range tokens {
+		if strings.HasPrefix(token, "\"") && strings.HasSuffix(token, "\"") {
+			token = strings.Trim(token, "\"")
+
+			for _, r := range token {
+				newNode := &Node{
+					TransitionEdges: make(map[rune]*Node),
+				}
+				curNode.TransitionEdges[r] = newNode
+				curNode = newNode
+			}
+		}
+	}
+
+	return curNode
+}
diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go
new file mode 100644
index 000000000..e09a4e0f3
--- /dev/null
+++ b/sample/structured_outputs.go
@@ -0,0 +1,4 @@
+package sample
+
+// StructuredOutput is a placeholder for the structured-outputs API (wip).
+type StructuredOutput struct{}
diff --git a/sample/structured_outputs_test.go b/sample/structured_outputs_test.go
new file mode 100644
index 000000000..b592f0562
--- /dev/null
+++ b/sample/structured_outputs_test.go
@@ -0,0 +1,194 @@
+package sample
+
+import (
+	"testing"
+
+	"github.com/ollama/ollama/model"
+)
+
+func TestBuildGraph(t *testing.T) {
+	tests := []struct {
+		name    string
+		grammar []byte
+		wantErr bool
+	}{
+		{
+			name:    "empty grammar",
+			grammar: []byte{},
+			wantErr: false,
+		},
+		{
+			name: "valid grammar",
+			grammar: []byte(`root ::= value
+value ::= string | number`),
+			wantErr: false,
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			g := &Graph{
+				proc:    &mockProcessor{},
+				grammar: tt.grammar,
+				rules:   make(map[string]string),
+			}
+
+			node := &Node{
+				TransitionEdges: make(map[rune]*Node),
+			}
+
+			err := g.BuildGraph(node)
+			if (err != nil) != tt.wantErr {
+				t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
+			}
+
+			if !tt.wantErr {
+				if len(g.decodedToks) == 0 {
+					t.Error("Expected decoded tokens, got none")
+				}
+				if len(g.rules) == 0 {
+					t.Error("Expected rules to be populated")
+				}
+			}
+		})
+	}
+}
+
+func TestRootPrefixes(t *testing.T) {
+	tests := []struct {
+		name     string
+		grammar  []byte
+		expected map[string]string
+	}{
+		{
+			name:     "empty grammar",
+			grammar:  []byte{},
+			expected: map[string]string{},
+		},
+		{
+			name: "grammar with root prefix",
+			grammar: []byte(`root ::= value
+root_string ::= string`),
+			expected: map[string]string{
+				"root":        "value",
+				"root_string": "string",
+			},
+		},
+		{
+			name: "grammar with comments and empty lines",
+			grammar: []byte(`# comment
+root ::= value
+
+# another comment
+root_number ::= number`),
+			expected: map[string]string{
+				"root":        "value",
+				"root_number": "number",
+			},
+		},
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			g := &Graph{
+				grammar: tt.grammar,
+				rules:   make(map[string]string),
+			}
+
+			g.rootPrefixes()
+
+			for k, v := range tt.expected {
+				if actual, ok := g.rules[k]; !ok || actual != v {
+					t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
+				}
+			}
+		})
+	}
+}
+
+func TestParseRule(t *testing.T) {
+	tests := []struct {
+		name     string
+		rule     string
+		expected string
+	}{
+		{
+			name:     "empty rule",
+			rule:     "",
+			expected: "",
+		},
+		{
+			name:     "simple string",
+			rule:     "root ::= \"test_string\"",
+			expected: "test_string",
+		},
+		{
+			name:     "simple string",
+			rule:     "root ::= \"test_string\" | \"test_string2\"",
+			expected: "test_stringtest_string2",
+		},
+		{
+			name: "integer",
+			rule: "root ::= [0-9]+",
+			// TODO: this is infinite actually
+			expected: "0123456789",
+		},
+		// TODO: handle left recursion
+		// {
+		// 	name:     "left recursion",
+		// 	rule:     "root ::= root \"test_string\"",
+		// 	expected: "test_string",
+		// },
+	}
+
+	for _, tt := range tests {
+		t.Run(tt.name, func(t *testing.T) {
+			g := &Graph{
+				rules: make(map[string]string),
+			}
+
+			rootNode := &Node{
+				TransitionEdges: make(map[rune]*Node),
+			}
+			curNode := rootNode
+			g.parseRule(tt.rule, curNode)
+			sb := ""
+			for {
+				if len(curNode.TransitionEdges) == 0 {
+					break
+				}
+
+				for r, n := range curNode.TransitionEdges {
+					sb += string(r)
+					curNode = n
+				}
+				t.Logf("sb: %s", sb)
+			}
+
+			if sb != tt.expected {
+				t.Errorf("Expected %s, got %s", tt.expected, sb)
+			}
+		})
+	}
+}
+
+// mockProcessor implements the TextProcessor interface for testing
+type mockProcessor struct{}
+
+func (m *mockProcessor) Decode(tokens []int32) (string, error) {
+	return "test", nil
+}
+
+func (m *mockProcessor) Vocab() *model.Vocabulary {
+	return &model.Vocabulary{
+		Values: []string{"test1", "test2"},
+	}
+}
+
+func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
+	return []int32{0, 1}, nil
+}
+
+func (m *mockProcessor) Is(token int32, special model.Special) bool {
+	return false
+}