Mirror of https://github.com/ollama/ollama.git, synced 2025-04-12 21:59:22 +02:00
sample: wip structured outputs work
commit e18540fecc
parent 040e65abce
sample/state_machine.go (new file, 176 lines)
@@ -0,0 +1,176 @@
package sample

import (
	"bytes"
	"strings"

	"github.com/ollama/ollama/model"
)

type Node struct {
	TransitionEdges map[rune]*Node
}

type Graph struct {
	proc        model.TextProcessor
	decodedToks []string
	curNode     *Node
	grammar     []byte
	rules       map[string]string
}

// baseRules is the base set of rules used to parse the grammar,
// following the JSON grammar from RFC 7159.
var baseRules = map[string]string{
	"object":  "\"{\" (kv (\",\" kv)*)? \"}\"",
	"array":   "\"[\" (value (\",\" value)*)? \"]\"",
	"string":  "\"\\\"\" char* \"\\\"\"",
	"number":  "\"-\"? integer frac? exp?",
	"kv":      "string \":\" value",
	"integer": "\"0\" | [1-9] [0-9]*",
	"frac":    "\".\" [0-9]+",
	"exp":     "(\"e\" | \"E\") (\"+\" | \"-\") [0-9]+",
	"escape":  "[\"/\" | \"b\" | \"f\" | \"n\" | \"r\" | \"t\" | unicode]",
	"char":    "[^\"\\\\] | escape",
	"space":   "(\" \" | \"\\t\" | \"\\n\" | \"\\r\")*",
	"hex":     "[0-9] | [a-f] | [A-F]",
	"boolean": "\"true\" | \"false\"",
	"value":   "object | array | string | number | boolean | \"null\"",
	"null":    "\"null\"",
}
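
// Worked example (illustrative, not part of the original file): taken together,
// these rules derive ordinary JSON values such as
//   {"answer":[1,2.5e-3,true,null]}
// via object -> kv -> string ":" value, an array of values, a number with
// frac and exp parts, a boolean, and null.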

func (g *Graph) BuildGraph(node *Node) error {
	vocab := g.proc.Vocab()
	decodedToks := make([]string, len(vocab.Values))
	for i := range vocab.Values {
		token, err := g.proc.Decode([]int32{int32(i)})
		if err != nil {
			return err
		}
		decodedToks[i] = token
	}

	g.decodedToks = decodedToks
	g.rules = baseRules
	g.rootPrefixes()
	rootNode := &Node{
		TransitionEdges: make(map[rune]*Node),
	}
	g.parseRule(g.rules["root"], rootNode)

	return nil
}

// rootPrefixes scans the grammar line by line and records every rule whose
// name begins with "root" in g.rules.
func (g *Graph) rootPrefixes() {
	lines := bytes.Split(g.grammar, []byte("\n"))
	for _, line := range lines {
		line = bytes.TrimSpace(line)
		if len(line) == 0 || bytes.HasPrefix(line, []byte("#")) {
			continue
		}

		parts := bytes.SplitN(line, []byte("::="), 2)
		if len(parts) != 2 {
			continue
		}

		ruleName := string(bytes.TrimSpace(parts[0]))
		if strings.HasPrefix(ruleName, "root") {
			g.rules[ruleName] = string(bytes.TrimSpace(parts[1]))
		}
	}
}

// parseRule parses a grammar rule and returns a Node
func (g *Graph) parseRule(rule string, curNode *Node) *Node {
	/*
		Special characters in BNF grammar and their functions:
		::= - definition operator, means "is defined as"
		|   - alternation, means "or"
		*   - zero or more repetitions of the preceding element
		+   - one or more repetitions
		?   - optional (zero or one occurrence)
		[]  - character class, matches any single character within the brackets
		[^] - negated character class, matches any character NOT listed
		()  - grouping of elements
		-   - range operator in character classes (e.g., [a-z])
		""  - literal string match
	*/
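	// For instance, the baseRules entry for integer, "0" | [1-9] [0-9]*, combines
	// alternation, character classes with ranges, and the * repetition operator.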

	// Split rule into tokens by whitespace
	tokens := strings.Fields(rule)
	if len(tokens) == 0 {
		return &Node{
			TransitionEdges: make(map[rune]*Node),
		}
	}

	// Handle integer rule
	if strings.Contains(rule, "[0-9]+") {
		// Create node for first digit 1-9
		firstDigitNode := &Node{
			TransitionEdges: make(map[rune]*Node),
		}
		for r := '1'; r <= '9'; r++ {
			curNode.TransitionEdges[r] = firstDigitNode
		}

		// Create node for subsequent digits 0-9
		zeroToNineNode := &Node{
			TransitionEdges: make(map[rune]*Node),
		}
		for r := '0'; r <= '9'; r++ {
			// Loop back to the same node for the * operator
			zeroToNineNode.TransitionEdges[r] = zeroToNineNode
		}

		// Connect first digit to subsequent digits
		firstDigitNode.TransitionEdges = zeroToNineNode.TransitionEdges

		// Also handle the "0" case
		if strings.Contains(rule, "\"0\"") {
			zeroNode := &Node{
				TransitionEdges: make(map[rune]*Node),
			}
			curNode.TransitionEdges['0'] = zeroNode
		}

		return curNode
	}

	// recursive case
	// grammar options
	// TODO: handle left recursion
	if strings.Contains(rule, "|") {
		parts := strings.Split(rule, "|")
		savedNode := curNode
		for _, part := range parts {
			// TODO: add correct transitions
			g.parseRule(part, savedNode)
		}
	}

	for _, token := range tokens {
		if strings.HasPrefix(token, "\"") && strings.HasSuffix(token, "\"") {
			token = strings.Trim(token, "\"")

			for _, r := range token {
				newNode := &Node{
					TransitionEdges: make(map[rune]*Node),
				}
				curNode.TransitionEdges[r] = newNode
				curNode = newNode
			}
			// strNode := &Node{
			// 	TransitionEdges: make(map[rune]*Node),
			// }

			// TODO: length constraint
			// to self
		}
	}

	return curNode
}
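
For context, here is a minimal sketch of how the resulting transition graph could be consulted during constrained sampling; the matchesPrefix helper is illustrative only and is not part of this commit.

// matchesPrefix reports whether every rune of s can be consumed by following
// TransitionEdges from start; it fails as soon as a rune has no outgoing edge.
func matchesPrefix(start *Node, s string) bool {
	cur := start
	for _, r := range s {
		next, ok := cur.TransitionEdges[r]
		if !ok {
			return false
		}
		cur = next
	}
	return true
}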
sample/structured_outputs.go (new file, 3 lines)
@@ -0,0 +1,3 @@
package sample

type StructuredOutput struct{}
sample/structured_outputs_test.go (new file, 194 lines)
@@ -0,0 +1,194 @@
package sample

import (
	"testing"

	"github.com/ollama/ollama/model"
)

func TestBuildGraph(t *testing.T) {
	tests := []struct {
		name    string
		grammar []byte
		wantErr bool
	}{
		{
			name:    "empty grammar",
			grammar: []byte{},
			wantErr: false,
		},
		{
			name: "valid grammar",
			grammar: []byte(`root ::= value
value ::= string | number`),
			wantErr: false,
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			g := &Graph{
				proc:    &mockProcessor{},
				grammar: tt.grammar,
				rules:   make(map[string]string),
			}

			node := &Node{
				TransitionEdges: make(map[rune]*Node),
			}

			err := g.BuildGraph(node)
			if (err != nil) != tt.wantErr {
				t.Errorf("BuildGraph() error = %v, wantErr %v", err, tt.wantErr)
			}

			if !tt.wantErr {
				if len(g.decodedToks) == 0 {
					t.Error("Expected decoded tokens, got none")
				}
				if len(g.rules) == 0 {
					t.Error("Expected rules to be populated")
				}
			}
		})
	}
}

func TestRootPrefixes(t *testing.T) {
	tests := []struct {
		name     string
		grammar  []byte
		expected map[string]string
	}{
		{
			name:     "empty grammar",
			grammar:  []byte{},
			expected: map[string]string{},
		},
		{
			name: "grammar with root prefix",
			grammar: []byte(`root ::= value
root_string ::= string`),
			expected: map[string]string{
				"root":        "value",
				"root_string": "string",
			},
		},
		{
			name: "grammar with comments and empty lines",
			grammar: []byte(`# comment
root ::= value

# another comment
root_number ::= number`),
			expected: map[string]string{
				"root":        "value",
				"root_number": "number",
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			g := &Graph{
				grammar: tt.grammar,
				rules:   make(map[string]string),
			}

			g.rootPrefixes()

			for k, v := range tt.expected {
				if actual, ok := g.rules[k]; !ok || actual != v {
					t.Errorf("Expected rule %s = %s, got %s", k, v, actual)
				}
			}
		})
	}
}

func TestParseRule(t *testing.T) {
	tests := []struct {
		name     string
		rule     string
		expected string
	}{
		{
			name:     "empty rule",
			rule:     "",
			expected: "",
		},
		{
			name:     "simple string",
			rule:     "root ::= \"test_string\"",
			expected: "test_string",
		},
		{
			name:     "string alternation",
			rule:     "root ::= \"test_string\" | \"test_string2\"",
			expected: "test_stringtest_string2",
		},
		{
			name: "integer",
			rule: "root ::= [0-9]+",
			// TODO: this is actually infinite
			expected: "0123456789",
		},
		// TODO: handle left recursion
		// {
		// 	name:     "left recursion",
		// 	rule:     "root ::= root \"test_string\"",
		// 	expected: "test_string",
		// },
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			g := &Graph{
				rules: make(map[string]string),
			}

			rootNode := &Node{
				TransitionEdges: make(map[rune]*Node),
			}
			curNode := rootNode
			g.parseRule(tt.rule, curNode)
			sb := ""
			for {
				if len(curNode.TransitionEdges) == 0 {
					break
				}

				for r, n := range curNode.TransitionEdges {
					sb += string(r)
					curNode = n
				}
				t.Logf("sb: %s", sb)
			}

			if sb != tt.expected {
				t.Errorf("Expected %s, got %s", tt.expected, sb)
			}
		})
	}
}

// mockProcessor implements the TextProcessor interface for testing
type mockProcessor struct{}

func (m *mockProcessor) Decode(tokens []int32) (string, error) {
	return "test", nil
}

func (m *mockProcessor) Vocab() *model.Vocabulary {
	return &model.Vocabulary{
		Values: []string{"test1", "test2"},
	}
}

func (m *mockProcessor) Encode(s string, addSpecial bool) ([]int32, error) {
	return []int32{0, 1}, nil
}

func (m *mockProcessor) Is(token int32, special model.Special) bool {
	return false
}
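
Assuming the usual Go toolchain, these tests could be exercised with something like:

go test ./sample -run 'TestBuildGraph|TestRootPrefixes|TestParseRule' -v

As the TODO in the integer case notes, that subtest's walk loops on the self-referencing digit node, so it is not expected to terminate as written.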