wip with json stuff and cleanup

This commit is contained in:
ParthSareen 2025-02-11 16:40:40 -08:00
parent 25edfa6fdb
commit aa6d5151df
10 changed files with 561 additions and 330 deletions

49
sample/constrained.go Normal file
View File

@ -0,0 +1,49 @@
package sample
import (
"github.com/ollama/ollama/model"
)
// ConstrainedSampler wraps a PushdownSampler and carries the fields intended
// for schema-aware constraints on top of it.
type ConstrainedSampler struct {
// schema is the optional schema. Apply and UpdateState check it for nil;
// NewConstrainedSampler currently always leaves it nil.
schema *Schema
// propIdx is initialised to -1 by NewConstrainedSampler.
// NOTE(review): presumably the index of the schema property being generated - confirm.
propIdx int
// propToNodeMap maps a string key to a PDA node; NewConstrainedSampler leaves it nil.
// NOTE(review): keys are presumably schema property names - confirm.
propToNodeMap map[string]*PDA
// pdaSampler is the underlying sampler that Apply and UpdateState delegate to.
pdaSampler *PushdownSampler
// decodedToks is not populated by NewConstrainedSampler in this revision.
decodedToks []string
}
// NewConstrainedSampler builds a ConstrainedSampler backed by a fresh
// PushdownSampler created for proc. Any error from NewPushdownSampler is
// returned unchanged together with a nil sampler.
//
// The schema argument is accepted but not stored yet: the returned sampler
// always has a nil schema, so only the pushdown sampler's constraints apply.
func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
	pda, err := NewPushdownSampler(proc)
	if err != nil {
		return nil, err
	}

	// schema and propToNodeMap are intentionally left at their nil zero values.
	cs := &ConstrainedSampler{
		propIdx:    -1,
		pdaSampler: pda,
	}
	return cs, nil
}
// Apply constrains logits for the next token and returns the masked slice.
//
// It always delegates to the wrapped pushdown sampler. Schema-specific
// narrowing is not implemented yet, so a non-nil schema receives the same
// pushdown-only mask as a nil one; this matches UpdateState, which advances
// the pushdown sampler regardless of the schema. The earlier non-nil-schema
// path returned (nil, nil), handing callers a nil logits slice with no error.
func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
	// TODO: narrow the mask further using s.schema once schema support lands.
	return s.pdaSampler.Apply(logits)
}
// UpdateState feeds the newly sampled tokens to the wrapped pushdown sampler
// and returns whatever error it reports.
//
// No schema-specific state is tracked yet, so the result is the same whether
// or not a schema is set.
func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
	return s.pdaSampler.UpdateState(tokenSlice)
}

32
sample/feedback.txt Normal file
View File

@ -0,0 +1,32 @@
// Feedback from code review:
// pushdown_automata.go:
// 1. The BuildGraph function is quite long and could be split into smaller, more focused functions
// 2. Consider using constants instead of magic runes like rune(-1) for sentinel values
// 3. The state machine transitions could be defined more declaratively, perhaps in a config
// 4. The stringInvalidRunes list needs to handle escape sequences properly
// 5. The graph building could be optimized to avoid duplicate nodes/transitions
// 6. Consider adding validation for max nesting depth of braces/brackets
// 7. The CreateMask function is doing a lot - could be split into smaller pieces
// 8. isRuneValid has a "garbage interface" per TODO - needs cleaner design
// pushdown_runner.go:
// 1. The Apply method has a lot of duplicated logic around EOS handling
// 2. The UpdateState method could use more granular error messages
// 3. The braceStack validation could be moved to a separate validator
// 4. Consider adding max length limits for strings/numbers
// 5. The stateCounter isn't being used effectively yet
// 6. Need to add penalties for staying in same state too long
// 7. The maskLogits function could be optimized to avoid allocations
// 8. Missing proper cleanup/reset functionality
// 9. Error handling could be more consistent throughout
// 10. Consider adding debug logging levels instead of raw fmt.Println
// General improvements needed:
// - More comprehensive testing, especially edge cases
// - Better documentation of state machine transitions
// - Performance optimization for large inputs
// - Memory usage optimization for the graph structure
// - Cleaner interfaces between components
// - More robust error handling and recovery

View File

@ -0,0 +1,11 @@
package sample
// type fusedMaskSampler struct{}
// func FusedMaskSampler() Sampler {
// return fusedMaskSampler{}
// }
// func (f fusedMaskSampler) Sample(logits []float64) (int, error) {
// return int(logits[0]), nil
// }

View File

@ -8,6 +8,19 @@ func Greedy() Sampler {
return greedy{}
}
func (s greedy) Sample(t []float64) (int, error) {
return floats.MaxIdx(t), nil
// Sample widens logits to float64, runs every transform over them in order,
// and returns the index chosen by floats.MaxIdx on the final values. If a
// transform fails, Sample returns -1 together with that error.
func (s greedy) Sample(logits []float32, transforms ...Transform) (int, error) {
	widened := make([]float64, 0, len(logits))
	for _, v := range logits {
		widened = append(widened, float64(v))
	}

	for _, tr := range transforms {
		out, err := tr.Apply(widened)
		if err != nil {
			return -1, err
		}
		widened = out
	}

	return floats.MaxIdx(widened), nil
}

View File

@ -23,7 +23,9 @@ const (
StateInColon
StateInComma
StateInTab
StateInSpace
StateInSpaceToValue
StateInSpaceEndValue
StateInNewlineEndValue
StateInObjSpace
StateInList
StateInListComma
@ -57,7 +59,9 @@ var JSONStates = []JSONState{
StateInColon,
StateInComma,
StateInTab,
StateInSpace,
StateInSpaceToValue,
StateInSpaceEndValue,
StateInNewlineEndValue,
StateInObjSpace,
StateInList,
StateInListComma,
@ -107,7 +111,7 @@ func (s JSONState) String() string {
return "StateInComma"
case StateInTab:
return "StateInTab"
case StateInSpace:
case StateInSpaceToValue:
return "StateInSpace"
case StateInObjSpace:
return "StateInObjSpace"
@ -121,6 +125,8 @@ func (s JSONState) String() string {
return "StateInListEnd"
case StateInNewline:
return "StateInNewline"
case StateInNewlineEndValue:
return "StateInNewlineEndValue"
case StateInNumber:
return "StateInNumber"
case StateInNumberEnd:
@ -129,6 +135,8 @@ func (s JSONState) String() string {
return "StateInStringEnd"
case StateInObjectKeyEnd:
return "StateInObjectKeyEnd"
case StateInSpaceEndValue:
return "StateInSpaceEndValue"
case StateTerminate:
return "StateTerminate"
case StateInObjectEnd:

View File

@ -6,8 +6,35 @@ import (
"github.com/ollama/ollama/model"
)
/*
Key JSON rules to consider:
1. Whitespace handling:
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
- Current code only handles some whitespace cases
2. Number validation:
- Need proper validation for special number cases like -0
- Should handle .5 style decimals
- Need limits on scientific notation (e, E)
3. String escaping:
- Currently marks \ as invalid but should allow escaped sequences:
- \"
- \n
- \u1234 unicode escapes
4. Empty object/array transitions:
- Direct {} and [] cases could be more explicit
- Need clear transitions for these edge cases
5. Nested depth limits:
- No protection against excessive nesting
- Could cause stack overflow with deeply nested structures
*/
// TODO: / should be valid but an escape character
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
@ -18,31 +45,31 @@ var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
type PDANode struct {
type PDA struct {
State JSONState
TransitionEdges map[rune]*PDANode
MaskTokenIDToNode map[int32]*PDANode
TransitionEdges map[rune]*PDA
MaskTokenIDToNode map[int32]*PDA
}
func NewPDANode(state JSONState) *PDANode {
return &PDANode{
func NewPDANode(state JSONState) *PDA {
return &PDA{
State: state,
TransitionEdges: make(map[rune]*PDANode),
MaskTokenIDToNode: make(map[int32]*PDANode),
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
}
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
stateToNodeMap := make(map[JSONState]*PDANode)
// TODO: make this a loop
// PDAGraphBuilder holds the inputs used to build the JSON PDA graph and the
// resulting state-to-node table.
type PDAGraphBuilder struct {
// proc is the text processor; CreateMask uses it to skip EOS/BOS tokens.
proc model.TextProcessor
// decodedToks holds the decoded text of each token, indexed by token ID;
// CreateMask walks it to decide which tokens are valid from a node.
decodedToks []string
// stateToNodeMap maps each JSONState to its node; BuildGraph assigns it.
stateToNodeMap map[JSONState]*PDA
}
func (b *PDAGraphBuilder) BuildGraph() error {
stateToNodeMap := make(map[JSONState]*PDA)
for _, state := range JSONStates {
stateToNodeMap[state] = NewPDANode(state)
}
// TODO:
// consider adding a node to just point to values, could be good to compute that
// mask rather than many different nodes
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
@ -51,10 +78,21 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
//new line
// new line
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
// new line end value
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// TODO: see if this is needed for formatting
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
@ -68,16 +106,16 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
// where values should be
// this could be combined but the probl might change, we're alr doing a skip ahead
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
b.addValueConnections(stateToNodeMap[StateInColon])
// Leads to a value
stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
// Values
// string node
@ -85,149 +123,142 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
// String end node
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
b.addEnds(stateToNodeMap[StateInStringEnd])
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// TODO: add counters for allowable number of decimals, e, E, etc
// number node
for _, r := range validNumberRunes {
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
}
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
// bool node
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
b.addEnds(stateToNodeMap[StateInNumber])
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list node
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
// list end node
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// empty list
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
b.addValueConnections(stateToNodeMap[StateInList])
// null node
for _, r := range validNullRunes {
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
}
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
b.addEnds(stateToNodeMap[StateInNull])
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// list comma
// should point to values
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
b.addValueConnections(stateToNodeMap[StateInListComma])
// list object end
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// TODO: not sure if this is needed
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// bool node
for _, r := range validBoolRunes {
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
}
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
b.addEnds(stateToNodeMap[StateInBool])
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
// comma node
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
// space end value
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
return stateToNodeMap[StateStart], stateToNodeMap, nil
b.stateToNodeMap = stateToNodeMap
if err := b.preComputeValidStates(); err != nil {
return err
}
return nil
}
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
// addEnds wires the edges that can follow a finished value onto node:
// ',' leads to the comma node, '}' to the object-end node and ']' to the
// list-end node, all looked up in b.stateToNodeMap.
//
// NOTE(review): in this revision BuildGraph calls addEnds before it assigns
// b.stateToNodeMap, so these lookups may yield nil nodes - confirm the map is
// populated at call time.
func (b *PDAGraphBuilder) addEnds(node *PDA) {
	nodes := b.stateToNodeMap
	edges := node.TransitionEdges

	edges[','] = nodes[StateInComma]
	edges['}'] = nodes[StateInObjectEnd]
	edges[']'] = nodes[StateInListEnd]
}
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
for _, r := range validNumberRunes {
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
}
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
// TODO(parthsareen): force the output and shift similar to structured outputs
node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
}
// TODO: tough life fr. plz fix.
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
// TODO; should come from top level
vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return err
}
decodedToks[i] = token
}
var err error
for _, node := range stateToNodeMap {
err = CreateMask(node, proc, decodedToks)
if err != nil {
// preComputeValidStates runs CreateMask on every node in b.stateToNodeMap.
// It stops at the first failure and returns that error; it returns nil when
// every node was processed successfully.
func (b *PDAGraphBuilder) preComputeValidStates() error {
	var err error
	for _, n := range b.stateToNodeMap {
		err = b.CreateMask(n)
		if err != nil {
			break
		}
	}
	return err
}
func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error {
for i := range decodedToks {
token := decodedToks[i]
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
for i := range b.decodedToks {
token := b.decodedToks[i]
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
if b.proc.Is(uint32(i), model.SpecialEOS) || b.proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
continue
}
valid := true
curNode := node
valid := true
consumedSpecialRunes := make(map[rune]bool)
var err error
for _, r := range token {
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
if err != nil {
return err
}
if !valid {
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
if curNode == nil || !valid {
break
}
}
if valid {
// cur node allows skipping
node.MaskTokenIDToNode[int32(i)] = curNode
}
}
return nil
}
// TODO: garbage interface plz fix
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
if consumedSpecialRunes[r] {
return false, nil, nil
return nil, false
}
specialRune := slices.Contains(stringInvalidRunes, r)
if specialRune {
if curNode.State == StateInString || curNode.State == StateInObjectKey {
return false, nil, nil
return nil, false
}
}
@ -235,17 +266,17 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
if nextNode, ok := curNode.TransitionEdges[r]; ok {
if specialRune {
if curNode.State == nextNode.State {
return false, nil, nil
return nil, false
}
consumedSpecialRunes[r] = true
}
return true, nextNode, nil
return nextNode, true
}
// Check for sentinel value - if present, any rune is valid
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
return true, nextNode, nil
return nextNode, true
}
return false, nil, nil
return nil, false
}

View File

@ -11,17 +11,17 @@ import (
// TODO: safety in case of invalid json
// TODO: interfaces to cleanup with return values
// TODO this interface shouldn't be the sampler - should just use Sampler
// TODO: add penalties for string \n stuff
type PushdownSampler struct {
// stateful
curNode *PDANode
proc model.TextProcessor
stateToNodeMap map[JSONState]*PDANode
braceStack []rune
stateCounter uint32
PDAGraphBuilder
curNode *PDA
braceStack []rune
stateCounter uint32
}
// graph should be built once and reused per tokenizer
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
start := time.Now()
fmt.Println("--------------------------------")
@ -32,27 +32,38 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
before := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
startNode, stateToNodeMap, err := BuildGraph(proc)
if err != nil {
panic(err)
vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
token, err := proc.Decode([]int32{int32(i)})
if err != nil {
return nil, err
}
decodedToks[i] = token
}
err = PreComputeValidStates(stateToNodeMap, proc)
if err != nil {
panic(err)
gb := &PDAGraphBuilder{
proc: proc,
decodedToks: decodedToks,
}
if err := gb.BuildGraph(); err != nil {
return nil, err
}
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Graph build time = %v\n", time.Since(start))
// TODO: this can be simplified
return &PushdownSampler{
curNode: startNode,
proc: proc,
stateToNodeMap: stateToNodeMap,
braceStack: []rune{},
stateCounter: 0,
}
curNode: gb.stateToNodeMap[StateStart],
PDAGraphBuilder: *gb,
braceStack: []rune{},
stateCounter: 0,
}, nil
}
// TODO: need to add resampling logic if the first sample was not good
@ -66,14 +77,7 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.Inf(-1)
}
}
return logits, nil
return forceFinish(s, logits)
}
logits, err := s.maskLogits(logits, s.curNode)
@ -82,18 +86,14 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
}
return logits, nil
case StateTerminate:
return forceFinish(s, logits)
case StateInObjectEnd:
// force finish if no braces left
if len(s.braceStack) == 0 {
s.curNode = NewPDANode(StateTerminate)
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.Inf(-1)
}
}
return logits, nil
return forceFinish(s, logits)
}
peek := s.braceStack[len(s.braceStack)-1]
@ -112,22 +112,13 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
if peek == rune('[') {
s.curNode = s.stateToNodeMap[StateInListComma]
}
logits, err := s.maskLogits(logits, s.curNode)
if err != nil {
return nil, err
}
return logits, nil
case StateTerminate:
for i := range logits {
if s.proc.Is(uint32(i), model.SpecialEOS) {
logits[i] = 1.0
} else {
logits[i] = math.Inf(-1)
}
}
return logits, nil
default:
fmt.Println("masking logits current state", s.curNode.State)
logits, err := s.maskLogits(logits, s.curNode)
@ -138,13 +129,24 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
}
}
// forceFinish rewrites logits in place so that only end-of-sequence tokens
// remain selectable: every index that s.proc reports as SpecialEOS is set to
// 1.0 and every other index to negative infinity. It returns the same slice
// and never returns an error.
func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
	negInf := math.Inf(-1)
	for id := range logits {
		val := negInf
		if s.proc.Is(uint32(id), model.SpecialEOS) {
			val = 1.0
		}
		logits[id] = val
	}
	return logits, nil
}
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
fmt.Println("current state - updating", s.curNode.State)
mappedString, err := s.proc.Decode(tokenSlice)
if err != nil {
return err
}
fmt.Println(">>> mappedString", mappedString)
fmt.Printf(">>> mappedString: %q\n", mappedString)
// TODO: should force closing for all braces - not doing square yet
for _, r := range mappedString {
@ -198,7 +200,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
}
// greedy sample + backtrack?
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
// Create a new slice with same length as logits, initialized to -Inf
maskedLogits := make([]float64, len(logits))
for i := range maskedLogits {
@ -215,4 +218,23 @@ func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64
return maskedLogits, nil
}
// TODO: add penalties for string \n stuff
// fastMaskLogits picks the highest-scoring token among those node allows and
// reports it through logits[0]: on success the chosen token index is written
// to logits[0] (as a float64) and the same slice is returned.
//
// Token IDs in node.MaskTokenIDToNode that fall outside logits (negative or
// >= len(logits)) are ignored. When several allowed tokens share the maximum
// logit, the lowest token index wins, so the result does not depend on map
// iteration order. An error is returned when no allowed token has a logit
// greater than negative infinity.
func (s *PushdownSampler) fastMaskLogits(logits []float64, node *PDA) ([]float64, error) {
	maxLogit := math.Inf(-1)
	maxIndex := -1

	for tokenID := range node.MaskTokenIDToNode {
		idx := int(tokenID)
		if idx < 0 || idx >= len(logits) {
			// A negative ID would panic on indexing; skip anything out of range.
			continue
		}

		v := logits[idx]
		// Strictly greater wins; on an exact tie prefer the smaller index so
		// the outcome is deterministic despite random map iteration order.
		if v > maxLogit || (v == maxLogit && maxIndex != -1 && idx < maxIndex) {
			maxLogit = v
			maxIndex = idx
		}
	}

	if maxIndex == -1 {
		return nil, fmt.Errorf("no valid tokens found in mask")
	}

	// The selected token index is smuggled out through the first logit slot.
	logits[0] = float64(maxIndex)
	return logits, nil
}

View File

@ -6,6 +6,8 @@ import (
"math"
"slices"
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
"golang.org/x/exp/rand"
"gonum.org/v1/gonum/floats"
"gonum.org/v1/gonum/stat/sampleuv"
)
@ -15,33 +17,34 @@ type Transform interface {
}
type Sampler interface {
Sample([]float64) (int, error)
Sample([]float32, ...Transform) (int, error)
}
type SamplerConfig struct {
transforms []Transform
sampler Sampler
}
// NewSampler creates a sampler with the given transforms and sampling method
func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig {
return &SamplerConfig{
transforms: transforms,
sampler: sampler,
// TODO(parthsareen): potentially cache softmax values
// softmax returns a new slice holding exp(logit) / sum(exp(logits)) for each
// input; logits itself is not modified. An empty input yields an empty slice.
//
// The plain exponentials are used whenever their sum is finite and non-zero,
// so results for ordinary inputs are unchanged. If the sum overflows to +Inf
// or underflows to 0 (which would otherwise produce NaN or Inf), the
// exponentials are recomputed after subtracting the largest logit, which is
// mathematically equivalent but keeps every term in range.
func softmax(logits []float64) []float64 {
	probs := make([]float64, len(logits))
	if len(logits) == 0 {
		return probs
	}

	var sum float64
	for i, v := range logits {
		probs[i] = math.Exp(v)
		sum += probs[i]
	}

	if sum == 0 || math.IsInf(sum, 1) {
		peak := slices.Max(logits)
		// Shifting by an infinite or NaN peak would only create more NaNs, so
		// in that case the unshifted values are kept as they are.
		if !math.IsInf(peak, 0) && !math.IsNaN(peak) {
			sum = 0
			for i, v := range logits {
				probs[i] = math.Exp(v - peak)
				sum += probs[i]
			}
		}
	}

	floats.Scale(1/sum, probs)
	return probs
}
type Temperature float64
func (t Temperature) Apply(logits []float64) ([]float64, error) {
if t == 0 {
return nil, errors.New("use Greedy sampler instead of Temperature(0)")
}
if t < 0 || t > 2 {
return nil, errors.New("temperature must be between 0 and 2")
}
temp := math.Max(float64(t), 1e-7)
// subtracting max logit to avoid under/overflow
maxLogit := floats.Max(logits)
temp := math.Max(float64(t), 1e-7)
maxLogit := slices.Max(logits)
for i := range logits {
logits[i] = (logits[i] - maxLogit) / temp
}
@ -49,52 +52,41 @@ func (t Temperature) Apply(logits []float64) ([]float64, error) {
return logits, nil
}
type softmax struct{}
func Softmax() Transform {
return softmax{}
// logitMap pairs a logit value with its position in the logits slice so the
// position survives reordering.
type logitMap struct {
// index is the logit's position in the logits slice.
index int
// logit is the logit value at that position.
logit float64
}
func (softmax) Apply(logits []float64) ([]float64, error) {
return computeSoftmax(logits), nil
}
// TODO: cache softmax values
func computeSoftmax(logits []float64) []float64 {
copiedLogits := make([]float64, len(logits))
copy(copiedLogits, logits)
for i := range copiedLogits {
copiedLogits[i] = math.Exp(copiedLogits[i])
}
floatSum := floats.Sum(copiedLogits)
floats.Scale(1.0/floatSum, copiedLogits)
return copiedLogits
// logitMapComparator orders logitMap values by descending logit: it returns
// -1 when a has the larger logit, +1 when b does, and 0 when they compare
// equal.
func logitMapComparator(a, b logitMap) int {
	// Swapping the operands reverses the natural ascending order.
	return cmp.Compare(b.logit, a.logit)
}
type TopK int
// TODO(parthsareen): avoid having to check all logits after this transform
func (k TopK) Apply(logits []float64) ([]float64, error) {
if k <= 0 {
return nil, errors.New("k must be positive")
return nil, errors.New("k must be greater than 0")
}
if int(k) >= len(logits) {
return logits, nil
}
indices := make([]int, len(logits))
for i := range indices {
indices[i] = i
q := pq.NewWith(logitMapComparator)
for i, logit := range logits {
q.Enqueue(logitMap{index: i, logit: logit})
}
// sort in descending order
slices.SortFunc(indices, func(i, j int) int {
return cmp.Compare(logits[j], logits[i])
})
validLogits := make(map[int]float64)
for range k {
logitMap, _ := q.Dequeue()
validLogits[logitMap.index] = logitMap.logit
}
for _, idx := range indices[k:] {
logits[idx] = math.Inf(-1)
for i := range logits {
if _, ok := validLogits[i]; !ok {
logits[i] = math.Inf(-1)
}
}
return logits, nil
@ -107,8 +99,7 @@ func (p TopP) Apply(logits []float64) ([]float64, error) {
return nil, errors.New("p must be between 0 and 1")
}
probs := computeSoftmax(logits)
probs := softmax(logits)
indices := make([]int, len(probs))
for i := range indices {
indices[i] = i
@ -139,17 +130,11 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
return nil, errors.New("p must be between 0 and 1")
}
probs := computeSoftmax(logits)
copiedProbs := make([]float64, len(probs))
copy(copiedProbs, probs)
probs := softmax(logits)
threshold := slices.Max(probs) * float64(p)
slices.Sort(copiedProbs)
maxProb := copiedProbs[len(copiedProbs)-1]
probThreshold := float64(p) * maxProb
for i := range probs {
if probs[i] < probThreshold {
for i, prob := range probs {
if prob < threshold {
logits[i] = math.Inf(-1)
}
}
@ -157,18 +142,35 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
return logits, nil
}
type weighed struct{}
func Weighed() Sampler {
return weighed{}
// weighted is the Sampler returned by Weighted.
type weighted struct {
// src is the random source handed to sampleuv.NewWeighted by Sample; it is
// nil when Weighted was called without a seed.
src rand.Source
}
// should return single value
func (s weighed) Sample(logits []float64) (int, error) {
// Weighted returns a weighted Sampler. When seed is non-nil the sampler gets
// a random source created from *seed (converted to uint64); when seed is nil
// the sampler's source is left nil.
func Weighted(seed *int64) Sampler {
	if seed == nil {
		return weighted{}
	}
	return weighted{src: rand.NewSource(uint64(*seed))}
}
func (s weighted) Sample(logits []float32, transforms ...Transform) (int, error) {
logits64 := make([]float64, len(logits))
for i, v := range logits {
logits64[i] = float64(v)
}
var err error
for _, t := range transforms {
logits64, err = t.Apply(logits64)
if err != nil {
return -1, err
}
}
logitsCopy := make([]float64, 0, len(logits))
indices := make([]int, 0, len(logits))
// the uv sampler does not support NaN values
for i, logit := range logits {
for i, logit := range logits64 {
if !math.IsInf(logit, -1) {
logitsCopy = append(logitsCopy, logit)
indices = append(indices, i)
@ -176,38 +178,13 @@ func (s weighed) Sample(logits []float64) (int, error) {
}
if len(logitsCopy) == 0 {
return -1, errors.New("no valid tokens found")
return -1, errors.New("no valid logits found for weighed sampling")
}
softmax := computeSoftmax(logitsCopy)
w := sampleuv.NewWeighted(softmax, nil)
probs := softmax(logitsCopy)
w := sampleuv.NewWeighted(probs, s.src)
if idx, ok := w.Take(); ok {
// returns the token ID
return indices[idx], nil
}
return -1, errors.New("weighed sampler failed")
}
// Sample applies transforms and samples a token ID
func (s *SamplerConfig) Sample(input []float32) (int, error) {
logits := make([]float64, len(input))
for i, v := range input {
logits[i] = float64(v)
}
var err error
for _, t := range s.transforms {
if t == Temperature(0) {
// early return with greedy if temperature is 0
s.sampler = Greedy()
break
}
logits, err = t.Apply(logits)
if err != nil {
return -1, err
}
}
return s.sampler.Sample(logits)
return -1, errors.New("weighed sampler failed, no valid token found")
}

View File

@ -3,116 +3,129 @@ package sample
import (
"fmt"
"math"
"slices"
"math/rand/v2"
"testing"
"gonum.org/v1/gonum/floats"
"github.com/google/go-cmp/cmp"
)
func TestTemperature(t *testing.T) {
logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
logits, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
want := []float64{-14, -12, -10, -8, -6, -4, 0}
if !floats.Equal(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want)
want := []float64{-4, -10, 0, -14, -6, -12, -8}
if diff := cmp.Diff(want, logits); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=-1, got %v", logits)
logits, err = Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Errorf("expected error for temperature=-1, got %v", logits)
}
if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
t.Fatalf("expected error for temperature=2.1, got %v", logits)
logits, err = Temperature(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Errorf("expected error for temperature=0, got %v", logits)
}
logits, err = Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Errorf("expected error for temperature=2.1, got %v", logits)
}
}
func TestSoftmax(t *testing.T) {
probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
}
probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
if !floats.Equal(probs, expectedProbs) {
t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
if diff := cmp.Diff(expectedProbs, probs); diff != "" {
t.Errorf("probs mismatch (-want +got):\n%s", diff)
}
}
// TestTopK checks three cases on a seven-element logit vector: TopK(3) keeps
// the three largest logits and replaces the rest with -Inf, TopK(0) is
// expected to return an error, and TopK(10) (k larger than the input) leaves
// the logits unchanged.
//
// NOTE(review): several statements below appear twice in slightly different
// forms (t.Fatal vs t.Error, floats.Same vs cmp.Diff). This looks like removed
// and added diff lines shown together; confirm against the real file.
func TestTopK(t *testing.T) {
logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
// Only the three largest entries (1, 2, 4) survive; the rest become -Inf.
expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
if !floats.Same(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
if diff := cmp.Diff(expectedlogits, logits); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
// k=0 must be rejected.
logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
_, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Fatalf("expected error for k=0, got %v", logits)
// NOTE(review): err is nil in this branch, so the %v below always prints <nil>.
t.Errorf("expected error for k=0, got %v", err)
}
// k greater than the number of logits: the expected output equals the input.
logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4}
if !floats.Same(logits, expectedlogits) {
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
if diff := cmp.Diff(expectedlogits, logits); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
}
// TestTopP checks that TopP(0.9) keeps only the two largest logits (2 and 4)
// of the fixed input and sets the others to -Inf, and that p=1.0 and p=0.0 are
// reported as errors.
//
// NOTE(review): several statements below appear twice in slightly different
// forms (t.Fatal vs t.Error, floats.Same vs cmp.Diff). This looks like removed
// and added diff lines shown together; confirm against the real file.
func TestTopP(t *testing.T) {
logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
if !floats.Same(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want)
if diff := cmp.Diff(want, logits); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
// Boundary value p=1.0 must be rejected.
logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
_, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits)
t.Error("expected error for p=1.0")
}
// Boundary value p=0.0 must be rejected.
logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
_, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits)
t.Error("expected error for p=0.0")
}
}
// TestMinP checks that MinP(0.2) keeps only the two largest logits (3 and 4)
// of the fixed input, at their original positions, and sets the others to
// -Inf, and that p=1.0 and p=0.0 are reported as errors.
//
// NOTE(review): several statements below appear twice in slightly different
// forms (sorted vs unsorted input, t.Fatal vs t.Error, floats.Same vs
// cmp.Diff). This looks like removed and added diff lines shown together;
// confirm against the real file.
func TestMinP(t *testing.T) {
logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4}
if !floats.Same(logits, want) {
t.Fatalf("got: %v, want: %v", logits, want)
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
if diff := cmp.Diff(want, logits); diff != "" {
t.Errorf("logits mismatch (-want +got):\n%s", diff)
}
// Boundary value p=1.0 must be rejected.
logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
_, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil {
t.Fatalf("expected error for p=1.0, got %v", logits)
t.Error("expected error for p=1.0")
}
// Boundary value p=0.0 must be rejected.
logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
_, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
if err == nil {
t.Fatalf("expected error for p=0.0, got %v", logits)
t.Error("expected error for p=0.0")
}
}
// TestWeighed checks the weighted sampler on two inputs: when index 1 is the
// only finite logit it must be the sampled index, and when every logit is -Inf
// Sample must return an error.
//
// NOTE(review): both Weighed().Sample([]float64{...}) and
// Weighted(nil).Sample([]float32{...}) appear below. This looks like removed
// and added diff lines shown together; confirm against the real file.
func TestWeighed(t *testing.T) {
idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)})
idx, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
if err != nil {
t.Fatal(err)
t.Error(err)
return
}
// Index 1 holds the only finite logit, so it is the only valid choice.
want := 1
if idx != want {
t.Fatalf("got: %v, want: %v", idx, want)
if diff := cmp.Diff(want, idx); diff != "" {
t.Errorf("index mismatch (-want +got):\n%s", diff)
}
// With every logit at -Inf there is no valid token to sample.
idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)})
idx, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
if err == nil {
t.Fatalf("expected error for no valid tokens, got %v", idx)
t.Error("expected error for no valid tokens, got index", idx)
}
}
@ -132,27 +145,32 @@ func TestSample(t *testing.T) {
id: 3,
callOrder: &callOrder,
}
sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy())
got, err := sampler.Sample(input)
got, err := Greedy().Sample(input, mock1, mock2, mock3)
if err != nil {
t.Fatal(err)
}
if !slices.Equal(callOrder, []int{1, 2, 3}) {
t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
t.Error(err)
return
}
want := 3 // Greedy sampler should pick highest logit
if got != want {
t.Errorf("got %v, want %v", got, want)
if diff := cmp.Diff(want, got); diff != "" {
t.Errorf("sampled index mismatch (-want +got):\n%s", diff)
}
_, err = Weighted(nil).Sample(input, mock1, mock2, mock3)
if err != nil {
t.Error(err)
return
}
wantOrder := []int{1, 2, 3}
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
t.Errorf("call order mismatch (-want +got):\n%s", diff)
}
errMock := &testTransform{
returnErr: fmt.Errorf("mock error"),
}
sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy())
_, err = sampler.Sample(input)
_, err = Weighted(nil).Sample(input, mock1, errMock, mock2)
if err == nil {
t.Error("Expected error from sampler")
}
@ -174,14 +192,51 @@ func (ts *testTransform) Apply(logits []float64) ([]float64, error) {
return logits, nil
}
// NOTE(review): this span appears to interleave the body of
// TestSampleTemperatureZero with BenchmarkTransform (removed and added diff
// lines shown together); confirm against the real file which lines are live.
//
// TestSampleTemperatureZero builds a sampler from Temperature(0) and Greedy()
// and expects index 3 for the input {1, 2, 3, 4}.
func TestSampleTemperatureZero(t *testing.T) {
sampler := NewSampler([]Transform{Temperature(0)}, Greedy())
got, err := sampler.Sample([]float32{1, 2, 3, 4})
if err != nil {
t.Fatal(err)
// BenchmarkTransform times Apply for the Temperature, TopK, TopP and MinP
// transforms, each as its own sub-benchmark, over 1<<16 random logits.
func BenchmarkTransform(b *testing.B) {
transforms := map[string]Transform{
"Temperature": Temperature(0.5),
"TopK": TopK(10),
"TopP": TopP(0.9),
"MinP": MinP(0.2),
}
want := 3 // Greedy sampler should pick highest logit index
if got != want {
t.Fatalf("got: %v, want: %v", got, want)
// Shared input: 1<<16 values drawn from rand.Float64.
logits := make([]float64, 1<<16)
for i := range logits {
logits[i] = rand.Float64()
}
for name, transform := range transforms {
b.Run(name, func(b *testing.B) {
b.ResetTimer()
for range b.N {
_, err := transform.Apply(logits)
if err != nil {
b.Error(err)
}
}
})
}
}
// BenchmarkSample times Sample for the Greedy and Weighted samplers, each as
// its own sub-benchmark, over a shared slice of 1<<16 random float32 logits.
func BenchmarkSample(b *testing.B) {
	cases := make(map[string]Sampler, 2)
	cases["Greedy"] = Greedy()
	cases["Weighted"] = Weighted(nil)

	input := make([]float32, 1<<16)
	for idx := 0; idx < len(input); idx++ {
		input[idx] = rand.Float32()
	}

	for label, sampler := range cases {
		b.Run(label, func(b *testing.B) {
			b.ResetTimer()
			for iter := 0; iter < b.N; iter++ {
				_, err := sampler.Sample(input)
				if err != nil {
					b.Error(err)
				}
			}
		})
	}
}

View File

@ -8,27 +8,45 @@ import (
"github.com/ollama/ollama/model"
)
// JSONSampler pairs an optional schema with a PushdownSampler.
//
// NOTE(review): two type headers and two propToNodeMap declarations appear
// below (SOSampler/PDANode and JSONSampler/PDA). This looks like removed and
// added diff lines shown together; confirm against the real file.
type SOSampler struct {
type JSONSampler struct {
// schema may be nil; the constructor and methods check for that case.
schema *Schema
// propIdx is initialised to -1 by the constructor.
propIdx int
// presumably keyed by schema property name — TODO confirm.
propToNodeMap map[string]*PDANode
propToNodeMap map[string]*PDA
// pdaSampler is the underlying pushdown sampler the methods delegate to.
pdaSampler *PushdownSampler
// decodedToks is sized to one string per vocabulary value in the constructor.
decodedToks []string
}
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
pdaSampler := NewPushdownSampler(proc)
func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
pdaSampler, err := NewPushdownSampler(proc)
if err != nil {
return nil, err
}
so := &SOSampler{
if schema == nil {
return &JSONSampler{
schema: nil,
propIdx: -1,
propToNodeMap: nil,
pdaSampler: pdaSampler,
}, nil
}
fmt.Println("schema not nil")
so := &JSONSampler{
schema: schema,
propIdx: -1,
propToNodeMap: make(map[string]*PDANode),
propToNodeMap: make(map[string]*PDA),
pdaSampler: pdaSampler,
}
so.schemaToGraph()
// This is prob slow
// Benchmark token decoding
start := time.Now()
var m runtime.MemStats
runtime.ReadMemStats(&m)
before := m.Alloc
vocab := proc.GetVocabulary()
decodedToks := make([]string, len(vocab.Values))
for i := range vocab.Values {
@ -40,14 +58,18 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
}
so.decodedToks = decodedToks
runtime.ReadMemStats(&m)
after := m.Alloc
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Token decode time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
fmt.Println("SOSampler")
fmt.Println("--------------------------------")
// Benchmark this section
start := time.Now()
var m runtime.MemStats
start = time.Now()
runtime.ReadMemStats(&m)
before := m.Alloc
before = m.Alloc
// TODO: still messed up
// TODO: recursion use case
@ -57,12 +79,12 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
// propName -> node
curState := node.State
fromNode := node
CreateMask(fromNode, proc, decodedToks)
so.pdaSampler.CreateMask(fromNode)
for curState == StateInStructuredKey {
// there is only one edge
for r, toNode := range fromNode.TransitionEdges {
// fmt.Println("rune", r, "edge", toNode.State)
CreateMask(toNode, proc, decodedToks)
so.pdaSampler.CreateMask(toNode)
fmt.Printf("created mask for %c\n", r)
curState = toNode.State
fmt.Println("next state", curState)
@ -73,7 +95,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
}
runtime.ReadMemStats(&m)
after := m.Alloc
after = m.Alloc
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
fmt.Printf("Mask creation time = %v\n", time.Since(start))
fmt.Println("--------------------------------")
@ -81,7 +103,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
return so, nil
}
func (s *SOSampler) schemaToGraph() {
func (s *JSONSampler) schemaToGraph() {
schemaType := s.schema.EffectiveType()
switch schemaType {
case "object":
@ -91,18 +113,18 @@ func (s *SOSampler) schemaToGraph() {
for _, prop := range s.schema.Properties {
// name of key
name := prop.Name
keyNode := &PDANode{
keyNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDANode),
MaskTokenIDToNode: make(map[int32]*PDANode),
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
prevNode := keyNode
for _, r := range name {
runeNode := &PDANode{
runeNode := &PDA{
State: StateInStructuredKey, // this is unchanging, will impact sampling
TransitionEdges: make(map[rune]*PDANode),
MaskTokenIDToNode: make(map[int32]*PDANode),
TransitionEdges: make(map[rune]*PDA),
MaskTokenIDToNode: make(map[int32]*PDA),
}
fmt.Println("runeNode created", runeNode.State)
fmt.Printf("runeNode created %c\n", r)
@ -117,9 +139,14 @@ func (s *SOSampler) schemaToGraph() {
fmt.Println("name", name, "keyNode", keyNode.State)
}
}
// TODO: do values + recursion
}
func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
if s.schema == nil {
return s.pdaSampler.Apply(logits)
}
switch s.pdaSampler.curNode.State {
// doesn't account for the multi-rune case
case StateInObjectKey:
@ -148,17 +175,18 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
// TODO: if I increment propIdx then I know I'm in the last value as well
switch s.pdaSampler.curNode.State {
case StateInObjectEnd:
fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
s.pdaSampler.curNode = NewPDANode(StateTerminate)
s.propIdx++
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
delete(s.pdaSampler.curNode.TransitionEdges, ',')
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
s.propIdx++
}
}
@ -167,12 +195,17 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
}
func (s *SOSampler) UpdateState(tokenSlice []int32) error {
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
err := s.pdaSampler.UpdateState(tokenSlice)
if err != nil {
return err
}
if s.schema == nil {
// Don't need to update state for unconstrained JSON sampling
return nil
}
switch s.pdaSampler.curNode.State {
case StateInObjectKey:
s.propIdx++