sample: wip structured outputs work

ParthSareen 2025-03-27 11:26:49 -07:00
parent 040e65abce
commit e18540fecc
3 changed files with 373 additions and 0 deletions

sample/state_machine.go (new file, 176 lines)

@@ -0,0 +1,176 @@
package sample
import (
"bytes"
"strings"
"github.com/ollama/ollama/model"
)
// Node is a single state in the transition graph; each outgoing edge is keyed
// by the rune that advances to the next state.
type Node struct {
TransitionEdges map[rune]*Node
}
// Graph builds a rune-level state machine from a grammar and the model's vocabulary.
type Graph struct {
proc model.TextProcessor
decodedToks []string
curNode *Node
grammar []byte
rules map[string]string
}
// baseRules is the built-in set of rules used to parse grammars.
// It encodes the JSON grammar from RFC 7159.
var baseRules = map[string]string{
"object": "\"{\" (kv (\",\" kv)*)? \"}\"",
"array": "\"[\" (value (\",\" value)*)? \"]\"",
"string": "\"\\\"\" char* \"\\\"\"",
"number": "\"-\"? integer frac? exp?",
"kv": "string \":\" value",
"integer": "\"0\" | [1-9] [0-9]*",
"frac": "\".\" [0-9]+",
"exp": "(\"e\" | \"E\") (\"+\" | \"-\")? [0-9]+",
"escape": "\"\\\\\" (\"\\\"\" | \"\\\\\" | \"/\" | \"b\" | \"f\" | \"n\" | \"r\" | \"t\" | \"u\" hex hex hex hex)",
"char": "[^\"\\\\] | escape",
"space": "(\" \" | \"\\t\" | \"\\n\" | \"\\r\")*",
"hex": "[0-9] | [a-f] | [A-F]",
"boolean": "\"true\" | \"false\"",
"value": "object | array | string | number | boolean | \"null\"",
"null": "\"null\"",
}
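// Taken together these rules describe JSON values: for the input {"a": 1},
// object matches the braces, kv matches "a" : 1, string matches "a", and
// number/integer match 1.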
// BuildGraph decodes every token in the processor's vocabulary, merges the
// grammar's root rules into the base rules, and builds the transition graph
// for the root rule. The node argument is currently unused.
func (g *Graph) BuildGraph(node *Node) error {
vocab := g.proc.Vocab()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := g.proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
g.decodedToks = decodedToks
g.rules = baseRules
g.rootPrefixes()
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
g.parseRule(g.rules["root"], rootNode)
return nil
}
// rootPrefixes scans the grammar line by line and records every rule whose
// name starts with "root" in g.rules.
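// For example, the grammar line
//
// root ::= object
//
// adds the entry rules["root"] = "object".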
func (g *Graph) rootPrefixes() {
lines := bytes.Split(g.grammar, []byte("\n"))
for _, line := range lines {
line = bytes.TrimSpace(line)
if len(line) == 0 || bytes.HasPrefix(line, []byte("#")) {
continue
}
parts := bytes.SplitN(line, []byte("::="), 2)
if len(parts) != 2 {
continue
}
ruleName := string(bytes.TrimSpace(parts[0]))
if strings.HasPrefix(ruleName, "root") {
g.rules[ruleName] = string(bytes.TrimSpace(parts[1]))
}
}
}
// parseRule parses a grammar rule, adding transition edges starting at
// curNode, and returns the last node reached.
func (g *Graph) parseRule(rule string, curNode *Node) *Node {
/*
Here are the special characters in BNF grammar and their functions:
::= - Definition operator, means "is defined as"
| - Alternation, means "or"
* - Zero or more repetitions of preceding element
+ - One or more repetitions
? - Optional (zero or one occurrence)
[] - Character class, matches any single character within brackets
[^] - Negated character class, matches any character NOT listed
() - Grouping of elements
- - Range operator in character classes (e.g., [a-z])
"" - Literal string match
*/
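// For orientation: the integer base rule, "0" | [1-9] [0-9]*, combines
// alternation (|), character classes ([1-9], [0-9]), and repetition (*).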
// Split rule into tokens by whitespace
tokens := strings.Fields(rule)
if len(tokens) == 0 {
return &Node{
TransitionEdges: make(map[rune]*Node),
}
}
// Handle integer rule
if strings.Contains(rule, "[0-9]+") {
// Create node for first digit 1-9
firstDigitNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
for r := '1'; r <= '9'; r++ {
curNode.TransitionEdges[r] = firstDigitNode
}
// Create node for subsequent digits 0-9
zeroToNineNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
for r := '0'; r <= '9'; r++ {
// Loop back to same node for * operator
zeroToNineNode.TransitionEdges[r] = zeroToNineNode
}
// Connect first digit to subsequent digits
firstDigitNode.TransitionEdges = zeroToNineNode.TransitionEdges
// Also handle the "0" case
if strings.Contains(rule, "\"0\"") {
zeroNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode.TransitionEdges['0'] = zeroNode
}
return curNode
}
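// With the construction above, consuming "123" walks
// curNode -'1'-> firstDigitNode -'2'-> zeroToNineNode -'3'-> zeroToNineNode,
// so any further digits remain valid.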
// Recursive case: split alternation into its options and parse each one
// TODO: handle left recursion
if strings.Contains(rule, "|") {
parts := strings.Split(rule, "|")
savedNode := curNode
for _, part := range parts {
// TODO: add correct transitions
g.parseRule(part, savedNode)
}
}
for _, token := range tokens {
if strings.HasPrefix(token, "\"") && strings.HasSuffix(token, "\"") {
token = strings.Trim(token, "\"")
for _, r := range token {
newNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode.TransitionEdges[r] = newNode
curNode = newNode
}
// strNode := &Node{
// TransitionEdges: make(map[rune]*Node),
// }
// TODO: length constraint
// to self
}
}
return curNode
}
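
A rough sketch, not part of this commit, of how the transition graph produced by parseRule can be walked; the walkLiteral helper and the "hi" literal are made up for illustration:

package sample

import "fmt"

// walkLiteral builds a tiny graph for a quoted literal and prints each rune
// that is valid next, following the single chain of transition edges.
func walkLiteral() {
	g := &Graph{rules: make(map[string]string)}
	root := &Node{TransitionEdges: make(map[rune]*Node)}
	g.parseRule(`"hi"`, root)
	for cur := root; len(cur.TransitionEdges) > 0; {
		for r, next := range cur.TransitionEdges {
			fmt.Printf("valid next rune: %q\n", r)
			cur = next
		}
	}
}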


@@ -0,0 +1,3 @@
package sample
type StructuredOutput struct{}


@@ -0,0 +1,194 @@
package sample
import (
"testing"
"github.com/ollama/ollama/model"
)
func TestBuildGraph(t *testing.T) {
tests := []struct {
name string
grammar []byte
wantErr bool
}{
{
name: "empty grammar",
grammar: []byte{},
wantErr: false,
},
{
name: "valid grammar",
grammar: []byte(`root ::= value
value ::= string | number`),
wantErr: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
proc: &mockProcessor{},
grammar: tt.grammar,
rules: make(map[string]string),
}
node := &Node{
TransitionEdges: make(map[rune]*Node),
}
err := g.BuildGraph(node)
if (err != nil) != tt.wantErr {
t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
}
if !tt.wantErr {
if len(g.decodedToks) == 0 {
t.Error("Expected decoded tokens, got none")
}
if len(g.rules) == 0 {
t.Error("Expected rules to be populated")
}
}
})
}
}
func TestRootPrefixes(t *testing.T) {
tests := []struct {
name string
grammar []byte
expected map[string]string
}{
{
name: "empty grammar",
grammar: []byte{},
expected: map[string]string{},
},
{
name: "grammar with root prefix",
grammar: []byte(`root ::= value
root_string ::= string`),
expected: map[string]string{
"root": "value",
"root_string": "string",
},
},
{
name: "grammar with comments and empty lines",
grammar: []byte(`# comment
root ::= value
# another comment
root_number ::= number`),
expected: map[string]string{
"root": "value",
"root_number": "number",
},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
grammar: tt.grammar,
rules: make(map[string]string),
}
g.rootPrefixes()
for k, v := range tt.expected {
if actual, ok := g.rules[k]; !ok || actual != v {
t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
}
}
})
}
}
func TestParseRule(t *testing.T) {
tests := []struct {
name string
rule string
expected string
}{
{
name: "empty rule",
rule: "",
expected: "",
},
{
name: "simple string",
rule: "root ::= \"test_string\"",
expected: "test_string",
},
{
name: "alternation",
rule: "root ::= \"test_string\" | \"test_string2\"",
expected: "test_stringtest_string2",
},
{
name: "integer",
rule: "root ::= [0-9]+",
// TODO: this walk is actually infinite (the digit nodes loop back to themselves)
expected: "0123456789",
},
// TODO: handle left recursion
// {
// name: "left recursion",
// rule: "root ::= root \"test_string\"",
// expected: "test_string",
// },
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
g := &Graph{
rules: make(map[string]string),
}
rootNode := &Node{
TransitionEdges: make(map[rune]*Node),
}
curNode := rootNode
g.parseRule(tt.rule, curNode)
sb := ""
for {
if len(curNode.TransitionEdges) == 0 {
break
}
for r, n := range curNode.TransitionEdges {
sb += string(r)
curNode = n
}
t.Logf("sb: %s", sb)
}
if sb != tt.expected {
t.Errorf("Expected %s, got %s", tt.expected, sb)
}
})
}
}
// mockProcessor implements the TextProcessor interface for testing
type mockProcessor struct{}
func (m *mockProcessor) Decode(tokens []int32) (string, error) {
return "test", nil
}
func (m *mockProcessor) Vocab() *model.Vocabulary {
return &model.Vocabulary{
Values: []string{"test1", "test2"},
}
}
func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
return []int32{0, 1}, nil
}
func (m *mockProcessor) Is(token int32, special model.Special) bool {
return false
}
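
One plausible next step, not implemented in this commit, is to use the graph to constrain sampling: a candidate token is kept only if every one of its decoded runes can be consumed by following transition edges from the current node. A minimal sketch, with the tokenAllowed helper name invented for illustration:

package sample

// tokenAllowed reports whether tok can be consumed from node by following
// transition edges rune by rune, and returns the node reached on success.
func tokenAllowed(node *Node, tok string) (*Node, bool) {
	cur := node
	for _, r := range tok {
		next, ok := cur.TransitionEdges[r]
		if !ok {
			return nil, false
		}
		cur = next
	}
	return cur, true
}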