mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 11:28:17 +02:00
wip with json stuff and cleanup
This commit is contained in:
parent
25edfa6fdb
commit
aa6d5151df
49
sample/constrained.go
Normal file
49
sample/constrained.go
Normal file
@ -0,0 +1,49 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type ConstrainedSampler struct {
|
||||
schema *Schema
|
||||
propIdx int
|
||||
propToNodeMap map[string]*PDA
|
||||
pdaSampler *PushdownSampler
|
||||
decodedToks []string
|
||||
}
|
||||
|
||||
func NewConstrainedSampler(proc model.TextProcessor, schema *Schema) (*ConstrainedSampler, error) {
|
||||
pdaSampler, err := NewPushdownSampler(proc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// if schema == nil {
|
||||
return &ConstrainedSampler{
|
||||
schema: nil,
|
||||
propIdx: -1,
|
||||
propToNodeMap: nil,
|
||||
pdaSampler: pdaSampler,
|
||||
}, nil
|
||||
|
||||
}
|
||||
|
||||
func (s *ConstrainedSampler) Apply(logits []float64) ([]float64, error) {
|
||||
if s.schema == nil {
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (s *ConstrainedSampler) UpdateState(tokenSlice []int32) error {
|
||||
if err := s.pdaSampler.UpdateState(tokenSlice); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.schema == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
32
sample/feedback.txt
Normal file
32
sample/feedback.txt
Normal file
@ -0,0 +1,32 @@
|
||||
// Feedback from code review:
|
||||
|
||||
// pushdown_automata.go:
|
||||
// 1. The BuildGraph function is quite long and could be split into smaller, more focused functions
|
||||
// 2. Consider using constants instead of magic runes like rune(-1) for sentinel values
|
||||
// 3. The state machine transitions could be defined more declaratively, perhaps in a config
|
||||
// 4. The stringInvalidRunes list needs to handle escape sequences properly
|
||||
// 5. The graph building could be optimized to avoid duplicate nodes/transitions
|
||||
// 6. Consider adding validation for max nesting depth of braces/brackets
|
||||
// 7. The CreateMask function is doing a lot - could be split into smaller pieces
|
||||
// 8. isRuneValid has a "garbage interface" per TODO - needs cleaner design
|
||||
|
||||
// pushdown_runner.go:
|
||||
// 1. The Apply method has a lot of duplicated logic around EOS handling
|
||||
// 2. The UpdateState method could use more granular error messages
|
||||
// 3. The braceStack validation could be moved to a separate validator
|
||||
// 4. Consider adding max length limits for strings/numbers
|
||||
// 5. The stateCounter isn't being used effectively yet
|
||||
// 6. Need to add penalties for staying in same state too long
|
||||
// 7. The maskLogits function could be optimized to avoid allocations
|
||||
// 8. Missing proper cleanup/reset functionality
|
||||
// 9. Error handling could be more consistent throughout
|
||||
// 10. Consider adding debug logging levels instead of raw fmt.Println
|
||||
|
||||
// General improvements needed:
|
||||
// - More comprehensive testing, especially edge cases
|
||||
// - Better documentation of state machine transitions
|
||||
// - Performance optimization for large inputs
|
||||
// - Memory usage optimization for the graph structure
|
||||
// - Cleaner interfaces between components
|
||||
// - More robust error handling and recovery
|
||||
|
11
sample/fused_mask_sample.go
Normal file
11
sample/fused_mask_sample.go
Normal file
@ -0,0 +1,11 @@
|
||||
package sample
|
||||
|
||||
// type fusedMaskSampler struct{}
|
||||
|
||||
// func FusedMaskSampler() Sampler {
|
||||
// return fusedMaskSampler{}
|
||||
// }
|
||||
|
||||
// func (f fusedMaskSampler) Sample(logits []float64) (int, error) {
|
||||
// return int(logits[0]), nil
|
||||
// }
|
@ -8,6 +8,19 @@ func Greedy() Sampler {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
func (s greedy) Sample(t []float64) (int, error) {
|
||||
return floats.MaxIdx(t), nil
|
||||
func (s greedy) Sample(logits []float32, transforms ...Transform) (int, error) {
|
||||
logits64 := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
logits64[i] = float64(v)
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, t := range transforms {
|
||||
logits64, err = t.Apply(logits64)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
return floats.MaxIdx(logits64), nil
|
||||
}
|
||||
|
@ -23,7 +23,9 @@ const (
|
||||
StateInColon
|
||||
StateInComma
|
||||
StateInTab
|
||||
StateInSpace
|
||||
StateInSpaceToValue
|
||||
StateInSpaceEndValue
|
||||
StateInNewlineEndValue
|
||||
StateInObjSpace
|
||||
StateInList
|
||||
StateInListComma
|
||||
@ -57,7 +59,9 @@ var JSONStates = []JSONState{
|
||||
StateInColon,
|
||||
StateInComma,
|
||||
StateInTab,
|
||||
StateInSpace,
|
||||
StateInSpaceToValue,
|
||||
StateInSpaceEndValue,
|
||||
StateInNewlineEndValue,
|
||||
StateInObjSpace,
|
||||
StateInList,
|
||||
StateInListComma,
|
||||
@ -107,7 +111,7 @@ func (s JSONState) String() string {
|
||||
return "StateInComma"
|
||||
case StateInTab:
|
||||
return "StateInTab"
|
||||
case StateInSpace:
|
||||
case StateInSpaceToValue:
|
||||
return "StateInSpace"
|
||||
case StateInObjSpace:
|
||||
return "StateInObjSpace"
|
||||
@ -121,6 +125,8 @@ func (s JSONState) String() string {
|
||||
return "StateInListEnd"
|
||||
case StateInNewline:
|
||||
return "StateInNewline"
|
||||
case StateInNewlineEndValue:
|
||||
return "StateInNewlineEndValue"
|
||||
case StateInNumber:
|
||||
return "StateInNumber"
|
||||
case StateInNumberEnd:
|
||||
@ -129,6 +135,8 @@ func (s JSONState) String() string {
|
||||
return "StateInStringEnd"
|
||||
case StateInObjectKeyEnd:
|
||||
return "StateInObjectKeyEnd"
|
||||
case StateInSpaceEndValue:
|
||||
return "StateInSpaceEndValue"
|
||||
case StateTerminate:
|
||||
return "StateTerminate"
|
||||
case StateInObjectEnd:
|
@ -6,8 +6,35 @@ import (
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
/*
|
||||
Key JSON rules to consider:
|
||||
|
||||
1. Whitespace handling:
|
||||
- Need to handle all valid JSON whitespace characters (\r, spaces between tokens)
|
||||
- Current code only handles some whitespace cases
|
||||
|
||||
2. Number validation:
|
||||
- Need proper validation for special number cases like -0
|
||||
- Should handle .5 style decimals
|
||||
- Need limits on scientific notation (e, E)
|
||||
|
||||
3. String escaping:
|
||||
- Currently marks \ as invalid but should allow escaped sequences:
|
||||
- \"
|
||||
- \n
|
||||
- \u1234 unicode escapes
|
||||
|
||||
4. Empty object/array transitions:
|
||||
- Direct {} and [] cases could be more explicit
|
||||
- Need clear transitions for these edge cases
|
||||
|
||||
5. Nested depth limits:
|
||||
- No protection against excessive nesting
|
||||
- Could cause stack overflow with deeply nested structures
|
||||
*/
|
||||
|
||||
// TODO: / should be valid but an escape character
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ',', '/'}
|
||||
var stringInvalidRunes = []rune{'\n', '\t', '{', '}', ':', ',', '/'}
|
||||
|
||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||
@ -18,31 +45,31 @@ var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||
|
||||
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||
|
||||
type PDANode struct {
|
||||
type PDA struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDANode
|
||||
MaskTokenIDToNode map[int32]*PDANode
|
||||
TransitionEdges map[rune]*PDA
|
||||
MaskTokenIDToNode map[int32]*PDA
|
||||
}
|
||||
|
||||
func NewPDANode(state JSONState) *PDANode {
|
||||
return &PDANode{
|
||||
func NewPDANode(state JSONState) *PDA {
|
||||
return &PDA{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]*PDANode),
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
}
|
||||
|
||||
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||
|
||||
// TODO: make this a loop
|
||||
type PDAGraphBuilder struct {
|
||||
proc model.TextProcessor
|
||||
decodedToks []string
|
||||
stateToNodeMap map[JSONState]*PDA
|
||||
}
|
||||
|
||||
func (b *PDAGraphBuilder) BuildGraph() error {
|
||||
stateToNodeMap := make(map[JSONState]*PDA)
|
||||
for _, state := range JSONStates {
|
||||
stateToNodeMap[state] = NewPDANode(state)
|
||||
}
|
||||
// TODO:
|
||||
// consider adding a node to just point to values, could be good to compute that
|
||||
// mask rather than many different nodes
|
||||
|
||||
stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
@ -51,10 +78,21 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
stateToNodeMap[StateInObject].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
stateToNodeMap[StateInObject].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
|
||||
//new line
|
||||
// new line
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
|
||||
// new line end value
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInNewlineEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
// TODO: see if this is needed for formatting
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
|
||||
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
|
||||
@ -68,16 +106,16 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
|
||||
// where values should be
|
||||
// this could be combined but the probl might change, we're alr doing a skip ahead
|
||||
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
|
||||
stateToNodeMap[StateInColon].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
stateToNodeMap[StateInColon].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInColon].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInColon], stateToNodeMap)
|
||||
b.addValueConnections(stateToNodeMap[StateInColon])
|
||||
|
||||
// Leads to a value
|
||||
stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
|
||||
stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
b.addValueConnections(stateToNodeMap[StateInSpaceToValue])
|
||||
stateToNodeMap[StateInSpaceToValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
// Values
|
||||
// string node
|
||||
@ -85,149 +123,142 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
stateToNodeMap[StateInString].TransitionEdges['"'] = stateToNodeMap[StateInStringEnd]
|
||||
|
||||
// String end node
|
||||
addEnds(stateToNodeMap[StateInStringEnd], stateToNodeMap)
|
||||
b.addEnds(stateToNodeMap[StateInStringEnd])
|
||||
stateToNodeMap[StateInStringEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInStringEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||
// number node
|
||||
for _, r := range validNumberRunes {
|
||||
stateToNodeMap[StateInNumber].TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInNumber], stateToNodeMap)
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
|
||||
b.addEnds(stateToNodeMap[StateInNumber])
|
||||
stateToNodeMap[StateInNumber].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInNumber].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// list node
|
||||
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
|
||||
// list end node
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// empty list
|
||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||
b.addValueConnections(stateToNodeMap[StateInList])
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
stateToNodeMap[StateInNull].TransitionEdges[r] = stateToNodeMap[StateInNull]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInNull], stateToNodeMap)
|
||||
b.addEnds(stateToNodeMap[StateInNull])
|
||||
stateToNodeMap[StateInNull].TransitionEdges[' '] = stateToNodeMap[StateInSpaceToValue]
|
||||
stateToNodeMap[StateInNull].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
stateToNodeMap[StateInListComma].TransitionEdges[' '] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInListComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
addValueConnections(stateToNodeMap[StateInListComma], stateToNodeMap)
|
||||
b.addValueConnections(stateToNodeMap[StateInListComma])
|
||||
|
||||
// list object end
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[','] = stateToNodeMap[StateInListComma]
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
// TODO: not sure if this is needed
|
||||
stateToNodeMap[StateInListObjectEnd].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||
}
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
b.addEnds(stateToNodeMap[StateInBool])
|
||||
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
// comma node
|
||||
stateToNodeMap[StateInComma].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
stateToNodeMap[StateInComma].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInComma].TransitionEdges[' '] = stateToNodeMap[StateInObjSpace]
|
||||
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInObjSpace].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
// space end value
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[' '] = stateToNodeMap[StateInSpaceEndValue]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
stateToNodeMap[StateInSpaceEndValue].TransitionEdges['\n'] = stateToNodeMap[StateInNewlineEndValue]
|
||||
|
||||
return stateToNodeMap[StateStart], stateToNodeMap, nil
|
||||
b.stateToNodeMap = stateToNodeMap
|
||||
if err := b.preComputeValidStates(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func addEnds(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
node.TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
node.TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
node.TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
func (b *PDAGraphBuilder) addEnds(node *PDA) {
|
||||
node.TransitionEdges[','] = b.stateToNodeMap[StateInComma]
|
||||
node.TransitionEdges['}'] = b.stateToNodeMap[StateInObjectEnd]
|
||||
node.TransitionEdges[']'] = b.stateToNodeMap[StateInListEnd]
|
||||
}
|
||||
|
||||
func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
node.TransitionEdges['"'] = stateToNodeMap[StateInString]
|
||||
func (b *PDAGraphBuilder) addValueConnections(node *PDA) {
|
||||
node.TransitionEdges['"'] = b.stateToNodeMap[StateInString]
|
||||
for _, r := range validNumberRunes {
|
||||
node.TransitionEdges[r] = stateToNodeMap[StateInNumber]
|
||||
node.TransitionEdges[r] = b.stateToNodeMap[StateInNumber]
|
||||
}
|
||||
node.TransitionEdges['t'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['f'] = stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['n'] = stateToNodeMap[StateInNull]
|
||||
// TODO(parthsareen): force the output and shift similar to structured outputs
|
||||
node.TransitionEdges['t'] = b.stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['f'] = b.stateToNodeMap[StateInBool]
|
||||
node.TransitionEdges['n'] = b.stateToNodeMap[StateInNull]
|
||||
}
|
||||
|
||||
// TODO: tough life fr. plz fix.
|
||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||
|
||||
// TODO; should come from top level
|
||||
vocab := proc.GetVocabulary()
|
||||
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, node := range stateToNodeMap {
|
||||
err = CreateMask(node, proc, decodedToks)
|
||||
if err != nil {
|
||||
func (b *PDAGraphBuilder) preComputeValidStates() error {
|
||||
for _, node := range b.stateToNodeMap {
|
||||
if err := b.CreateMask(node); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error {
|
||||
for i := range decodedToks {
|
||||
token := decodedToks[i]
|
||||
func (b *PDAGraphBuilder) CreateMask(node *PDA) error {
|
||||
for i := range b.decodedToks {
|
||||
token := b.decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
if b.proc.Is(uint32(i), model.SpecialEOS) || b.proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
continue
|
||||
}
|
||||
valid := true
|
||||
curNode := node
|
||||
valid := true
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
var err error
|
||||
for _, r := range token {
|
||||
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !valid {
|
||||
curNode, valid = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if curNode == nil || !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
// cur node allows skipping
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// TODO: garbage interface plz fix
|
||||
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||
func isRuneValid(r rune, curNode *PDA, consumedSpecialRunes map[rune]bool) (*PDA, bool) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return false, nil, nil
|
||||
return nil, false
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||
return false, nil, nil
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
@ -235,17 +266,17 @@ func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return false, nil, nil
|
||||
return nil, false
|
||||
}
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return true, nextNode, nil
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return true, nextNode, nil
|
||||
return nextNode, true
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
return nil, false
|
||||
}
|
||||
|
@ -11,17 +11,17 @@ import (
|
||||
|
||||
// TODO: safety in case of invalid json
|
||||
// TODO: interfaces to cleanup with return values
|
||||
// TODO this interface shouldn't be the sampler - should just use Sampler
|
||||
// TODO: add penalties for string \n stuff
|
||||
type PushdownSampler struct {
|
||||
// stateful
|
||||
curNode *PDANode
|
||||
proc model.TextProcessor
|
||||
stateToNodeMap map[JSONState]*PDANode
|
||||
braceStack []rune
|
||||
stateCounter uint32
|
||||
PDAGraphBuilder
|
||||
curNode *PDA
|
||||
braceStack []rune
|
||||
stateCounter uint32
|
||||
}
|
||||
|
||||
// graph should be built once and reused per tokenizer
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
|
||||
start := time.Now()
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
@ -32,27 +32,38 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
vocab := proc.GetVocabulary()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
err = PreComputeValidStates(stateToNodeMap, proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
|
||||
gb := &PDAGraphBuilder{
|
||||
proc: proc,
|
||||
decodedToks: decodedToks,
|
||||
}
|
||||
|
||||
if err := gb.BuildGraph(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
|
||||
// TODO: this can be simplified
|
||||
return &PushdownSampler{
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
stateToNodeMap: stateToNodeMap,
|
||||
braceStack: []rune{},
|
||||
stateCounter: 0,
|
||||
}
|
||||
curNode: gb.stateToNodeMap[StateStart],
|
||||
PDAGraphBuilder: *gb,
|
||||
braceStack: []rune{},
|
||||
stateCounter: 0,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
@ -66,14 +77,7 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
return forceFinish(s, logits)
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
@ -82,18 +86,14 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateTerminate:
|
||||
return forceFinish(s, logits)
|
||||
|
||||
case StateInObjectEnd:
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
return forceFinish(s, logits)
|
||||
}
|
||||
|
||||
peek := s.braceStack[len(s.braceStack)-1]
|
||||
@ -112,22 +112,13 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
|
||||
if peek == rune('[') {
|
||||
s.curNode = s.stateToNodeMap[StateInListComma]
|
||||
}
|
||||
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateTerminate:
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
fmt.Println("masking logits current state", s.curNode.State)
|
||||
logits, err := s.maskLogits(logits, s.curNode)
|
||||
@ -138,13 +129,24 @@ func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
|
||||
}
|
||||
}
|
||||
|
||||
func forceFinish(s *PushdownSampler, logits []float64) ([]float64, error) {
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("current state - updating", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println(">>> mappedString", mappedString)
|
||||
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
||||
|
||||
// TODO: should force closing for all braces - not doing square yet
|
||||
for _, r := range mappedString {
|
||||
@ -198,7 +200,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
}
|
||||
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDA) ([]float64, error) {
|
||||
|
||||
// Create a new slice with same length as logits, initialized to -Inf
|
||||
maskedLogits := make([]float64, len(logits))
|
||||
for i := range maskedLogits {
|
||||
@ -215,4 +218,23 @@ func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64
|
||||
return maskedLogits, nil
|
||||
}
|
||||
|
||||
// TODO: add penalties for string \n stuff
|
||||
func (s *PushdownSampler) fastMaskLogits(logits []float64, node *PDA) ([]float64, error) {
|
||||
maxLogit := math.Inf(-1)
|
||||
maxIndex := -1
|
||||
|
||||
// Find the maximum logit value among valid tokens
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
|
||||
maxLogit = logits[tokenID]
|
||||
maxIndex = int(tokenID)
|
||||
}
|
||||
}
|
||||
|
||||
if maxIndex == -1 {
|
||||
return nil, fmt.Errorf("no valid tokens found in mask")
|
||||
}
|
||||
|
||||
logits[0] = float64(maxIndex)
|
||||
return logits, nil
|
||||
// return maxIndex, nil
|
||||
}
|
||||
|
163
sample/sample.go
163
sample/sample.go
@ -6,6 +6,8 @@ import (
|
||||
"math"
|
||||
"slices"
|
||||
|
||||
pq "github.com/emirpasic/gods/v2/queues/priorityqueue"
|
||||
"golang.org/x/exp/rand"
|
||||
"gonum.org/v1/gonum/floats"
|
||||
"gonum.org/v1/gonum/stat/sampleuv"
|
||||
)
|
||||
@ -15,33 +17,34 @@ type Transform interface {
|
||||
}
|
||||
|
||||
type Sampler interface {
|
||||
Sample([]float64) (int, error)
|
||||
Sample([]float32, ...Transform) (int, error)
|
||||
}
|
||||
|
||||
type SamplerConfig struct {
|
||||
transforms []Transform
|
||||
sampler Sampler
|
||||
}
|
||||
|
||||
// NewSampler creates a sampler with the given transforms and sampling method
|
||||
func NewSampler(transforms []Transform, sampler Sampler) *SamplerConfig {
|
||||
return &SamplerConfig{
|
||||
transforms: transforms,
|
||||
sampler: sampler,
|
||||
// TODO(parthsareen): potentially cache softmax values
|
||||
func softmax(logits []float64) []float64 {
|
||||
var sum float64
|
||||
tt := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
tt[i] = math.Exp(v)
|
||||
sum += tt[i]
|
||||
}
|
||||
floats.Scale(1/sum, tt)
|
||||
return tt
|
||||
}
|
||||
|
||||
type Temperature float64
|
||||
|
||||
func (t Temperature) Apply(logits []float64) ([]float64, error) {
|
||||
if t == 0 {
|
||||
return nil, errors.New("use Greedy sampler instead of Temperature(0)")
|
||||
}
|
||||
if t < 0 || t > 2 {
|
||||
return nil, errors.New("temperature must be between 0 and 2")
|
||||
}
|
||||
temp := math.Max(float64(t), 1e-7)
|
||||
|
||||
// subtracting max logit to avoid under/overflow
|
||||
maxLogit := floats.Max(logits)
|
||||
|
||||
temp := math.Max(float64(t), 1e-7)
|
||||
maxLogit := slices.Max(logits)
|
||||
for i := range logits {
|
||||
logits[i] = (logits[i] - maxLogit) / temp
|
||||
}
|
||||
@ -49,52 +52,41 @@ func (t Temperature) Apply(logits []float64) ([]float64, error) {
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
type softmax struct{}
|
||||
|
||||
func Softmax() Transform {
|
||||
return softmax{}
|
||||
type logitMap struct {
|
||||
index int
|
||||
logit float64
|
||||
}
|
||||
|
||||
func (softmax) Apply(logits []float64) ([]float64, error) {
|
||||
return computeSoftmax(logits), nil
|
||||
}
|
||||
|
||||
// TODO: cache softmax values
|
||||
func computeSoftmax(logits []float64) []float64 {
|
||||
copiedLogits := make([]float64, len(logits))
|
||||
copy(copiedLogits, logits)
|
||||
for i := range copiedLogits {
|
||||
copiedLogits[i] = math.Exp(copiedLogits[i])
|
||||
}
|
||||
|
||||
floatSum := floats.Sum(copiedLogits)
|
||||
floats.Scale(1.0/floatSum, copiedLogits)
|
||||
|
||||
return copiedLogits
|
||||
func logitMapComparator(a, b logitMap) int {
|
||||
return -cmp.Compare(a.logit, b.logit)
|
||||
}
|
||||
|
||||
type TopK int
|
||||
|
||||
// TODO(parthsareen): avoid having to check all logits after this transform
|
||||
func (k TopK) Apply(logits []float64) ([]float64, error) {
|
||||
if k <= 0 {
|
||||
return nil, errors.New("k must be positive")
|
||||
return nil, errors.New("k must be greater than 0")
|
||||
}
|
||||
if int(k) >= len(logits) {
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
indices := make([]int, len(logits))
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
q := pq.NewWith(logitMapComparator)
|
||||
for i, logit := range logits {
|
||||
q.Enqueue(logitMap{index: i, logit: logit})
|
||||
}
|
||||
|
||||
// sort in descending order
|
||||
slices.SortFunc(indices, func(i, j int) int {
|
||||
return cmp.Compare(logits[j], logits[i])
|
||||
})
|
||||
validLogits := make(map[int]float64)
|
||||
for range k {
|
||||
logitMap, _ := q.Dequeue()
|
||||
validLogits[logitMap.index] = logitMap.logit
|
||||
}
|
||||
|
||||
for _, idx := range indices[k:] {
|
||||
logits[idx] = math.Inf(-1)
|
||||
for i := range logits {
|
||||
if _, ok := validLogits[i]; !ok {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
|
||||
return logits, nil
|
||||
@ -107,8 +99,7 @@ func (p TopP) Apply(logits []float64) ([]float64, error) {
|
||||
return nil, errors.New("p must be between 0 and 1")
|
||||
}
|
||||
|
||||
probs := computeSoftmax(logits)
|
||||
|
||||
probs := softmax(logits)
|
||||
indices := make([]int, len(probs))
|
||||
for i := range indices {
|
||||
indices[i] = i
|
||||
@ -139,17 +130,11 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
|
||||
return nil, errors.New("p must be between 0 and 1")
|
||||
}
|
||||
|
||||
probs := computeSoftmax(logits)
|
||||
copiedProbs := make([]float64, len(probs))
|
||||
copy(copiedProbs, probs)
|
||||
probs := softmax(logits)
|
||||
threshold := slices.Max(probs) * float64(p)
|
||||
|
||||
slices.Sort(copiedProbs)
|
||||
|
||||
maxProb := copiedProbs[len(copiedProbs)-1]
|
||||
probThreshold := float64(p) * maxProb
|
||||
|
||||
for i := range probs {
|
||||
if probs[i] < probThreshold {
|
||||
for i, prob := range probs {
|
||||
if prob < threshold {
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
@ -157,18 +142,35 @@ func (p MinP) Apply(logits []float64) ([]float64, error) {
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
type weighed struct{}
|
||||
|
||||
func Weighed() Sampler {
|
||||
return weighed{}
|
||||
type weighted struct {
|
||||
src rand.Source
|
||||
}
|
||||
|
||||
// should return single value
|
||||
func (s weighed) Sample(logits []float64) (int, error) {
|
||||
func Weighted(seed *int64) Sampler {
|
||||
var src rand.Source
|
||||
if seed != nil {
|
||||
src = rand.NewSource(uint64(*seed))
|
||||
}
|
||||
return weighted{src: src}
|
||||
}
|
||||
|
||||
func (s weighted) Sample(logits []float32, transforms ...Transform) (int, error) {
|
||||
logits64 := make([]float64, len(logits))
|
||||
for i, v := range logits {
|
||||
logits64[i] = float64(v)
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, t := range transforms {
|
||||
logits64, err = t.Apply(logits64)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
logitsCopy := make([]float64, 0, len(logits))
|
||||
indices := make([]int, 0, len(logits))
|
||||
// the uv sampler does not support NaN values
|
||||
for i, logit := range logits {
|
||||
for i, logit := range logits64 {
|
||||
if !math.IsInf(logit, -1) {
|
||||
logitsCopy = append(logitsCopy, logit)
|
||||
indices = append(indices, i)
|
||||
@ -176,38 +178,13 @@ func (s weighed) Sample(logits []float64) (int, error) {
|
||||
}
|
||||
|
||||
if len(logitsCopy) == 0 {
|
||||
return -1, errors.New("no valid tokens found")
|
||||
return -1, errors.New("no valid logits found for weighed sampling")
|
||||
}
|
||||
|
||||
softmax := computeSoftmax(logitsCopy)
|
||||
w := sampleuv.NewWeighted(softmax, nil)
|
||||
probs := softmax(logitsCopy)
|
||||
w := sampleuv.NewWeighted(probs, s.src)
|
||||
if idx, ok := w.Take(); ok {
|
||||
// returns the token ID
|
||||
return indices[idx], nil
|
||||
}
|
||||
return -1, errors.New("weighed sampler failed")
|
||||
}
|
||||
|
||||
// Sample applies transforms and samples a token ID
|
||||
func (s *SamplerConfig) Sample(input []float32) (int, error) {
|
||||
logits := make([]float64, len(input))
|
||||
for i, v := range input {
|
||||
logits[i] = float64(v)
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, t := range s.transforms {
|
||||
if t == Temperature(0) {
|
||||
// early return with greedy if temperature is 0
|
||||
s.sampler = Greedy()
|
||||
break
|
||||
}
|
||||
|
||||
logits, err = t.Apply(logits)
|
||||
if err != nil {
|
||||
return -1, err
|
||||
}
|
||||
}
|
||||
|
||||
return s.sampler.Sample(logits)
|
||||
return -1, errors.New("weighed sampler failed, no valid token found")
|
||||
}
|
||||
|
@ -3,116 +3,129 @@ package sample
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"slices"
|
||||
"math/rand/v2"
|
||||
"testing"
|
||||
|
||||
"gonum.org/v1/gonum/floats"
|
||||
"github.com/google/go-cmp/cmp"
|
||||
)
|
||||
|
||||
func TestTemperature(t *testing.T) {
|
||||
logits, err := Temperature(0.5).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
logits, err := Temperature(0.5).Apply([]float64{2, -1, 4, -3, 1, -2, 0})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := []float64{-14, -12, -10, -8, -6, -4, 0}
|
||||
if !floats.Equal(logits, want) {
|
||||
t.Fatalf("got: %v, want: %v", logits, want)
|
||||
want := []float64{-4, -10, 0, -14, -6, -12, -8}
|
||||
if diff := cmp.Diff(want, logits); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
if _, err := Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
||||
t.Fatalf("expected error for temperature=-1, got %v", logits)
|
||||
logits, err = Temperature(-1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for temperature=-1, got %v", logits)
|
||||
}
|
||||
if _, err := Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4}); err == nil {
|
||||
t.Fatalf("expected error for temperature=2.1, got %v", logits)
|
||||
logits, err = Temperature(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for temperature=0, got %v", logits)
|
||||
}
|
||||
logits, err = Temperature(2.1).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err == nil {
|
||||
t.Errorf("expected error for temperature=2.1, got %v", logits)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSoftmax(t *testing.T) {
|
||||
probs, err := Softmax().Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
probs := softmax([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
|
||||
expectedProbs := []float64{0.000751406628089903, 0.0020425349829204676, 0.005552185728064613, 0.015092405572827691, 0.04102541181635154, 0.11151863144543739, 0.8240174238263085}
|
||||
if !floats.Equal(probs, expectedProbs) {
|
||||
t.Fatalf("logits: %v, expectedlogits: %v", probs, expectedProbs)
|
||||
if diff := cmp.Diff(expectedProbs, probs); diff != "" {
|
||||
t.Errorf("probs mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopK(t *testing.T) {
|
||||
logits, err := TopK(3).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
expectedlogits := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 1, 2, 4}
|
||||
if !floats.Same(logits, expectedlogits) {
|
||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
|
||||
if diff := cmp.Diff(expectedlogits, logits); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
logits, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
|
||||
_, err = TopK(0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for k=0, got %v", logits)
|
||||
t.Errorf("expected error for k=0, got %v", err)
|
||||
}
|
||||
|
||||
logits, err = TopK(10).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
expectedlogits = []float64{-3, -2, -1, 0, 1, 2, 4}
|
||||
if !floats.Same(logits, expectedlogits) {
|
||||
t.Fatalf("logits: %v, expectedlogits: %v", logits, expectedlogits)
|
||||
if diff := cmp.Diff(expectedlogits, logits); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
}
|
||||
|
||||
func TestTopP(t *testing.T) {
|
||||
logits, err := TopP(0.9).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 2, 4}
|
||||
if !floats.Same(logits, want) {
|
||||
t.Fatalf("got: %v, want: %v", logits, want)
|
||||
if diff := cmp.Diff(want, logits); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
logits, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
|
||||
_, err = TopP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for p=1.0, got %v", logits)
|
||||
t.Error("expected error for p=1.0")
|
||||
}
|
||||
logits, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
_, err = TopP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 4})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for p=0.0, got %v", logits)
|
||||
t.Error("expected error for p=0.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMinP(t *testing.T) {
|
||||
logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
||||
logits, err := MinP(0.2).Apply([]float64{-3, -2, -1, 0, 1, 2, 4, 3})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 3, 4}
|
||||
if !floats.Same(logits, want) {
|
||||
t.Fatalf("got: %v, want: %v", logits, want)
|
||||
want := []float64{math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), math.Inf(-1), 4, 3}
|
||||
if diff := cmp.Diff(want, logits); diff != "" {
|
||||
t.Errorf("logits mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
logits, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
||||
|
||||
_, err = MinP(1.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for p=1.0, got %v", logits)
|
||||
t.Error("expected error for p=1.0")
|
||||
}
|
||||
logits, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
||||
_, err = MinP(0.0).Apply([]float64{-3, -2, -1, 0, 1, 2, 3, 4})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for p=0.0, got %v", logits)
|
||||
t.Error("expected error for p=0.0")
|
||||
}
|
||||
}
|
||||
|
||||
func TestWeighed(t *testing.T) {
|
||||
idx, err := Weighed().Sample([]float64{math.Inf(-1), 2, math.Inf(-1), math.Inf(-1)})
|
||||
idx, err := Weighted(nil).Sample([]float32{float32(math.Inf(-1)), 2, float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
want := 1
|
||||
if idx != want {
|
||||
t.Fatalf("got: %v, want: %v", idx, want)
|
||||
if diff := cmp.Diff(want, idx); diff != "" {
|
||||
t.Errorf("index mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
idx, err = Weighed().Sample([]float64{math.Inf(-1), math.Inf(-1), math.Inf(-1)})
|
||||
|
||||
idx, err = Weighted(nil).Sample([]float32{float32(math.Inf(-1)), float32(math.Inf(-1)), float32(math.Inf(-1))})
|
||||
if err == nil {
|
||||
t.Fatalf("expected error for no valid tokens, got %v", idx)
|
||||
t.Error("expected error for no valid tokens, got index", idx)
|
||||
}
|
||||
}
|
||||
|
||||
@ -132,27 +145,32 @@ func TestSample(t *testing.T) {
|
||||
id: 3,
|
||||
callOrder: &callOrder,
|
||||
}
|
||||
sampler := NewSampler([]Transform{mock1, mock2, mock3}, Greedy())
|
||||
|
||||
got, err := sampler.Sample(input)
|
||||
got, err := Greedy().Sample(input, mock1, mock2, mock3)
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
if !slices.Equal(callOrder, []int{1, 2, 3}) {
|
||||
t.Errorf("got %v, want %v", callOrder, []int{1, 2, 3})
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
|
||||
want := 3 // Greedy sampler should pick highest logit
|
||||
if got != want {
|
||||
t.Errorf("got %v, want %v", got, want)
|
||||
if diff := cmp.Diff(want, got); diff != "" {
|
||||
t.Errorf("sampled index mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
_, err = Weighted(nil).Sample(input, mock1, mock2, mock3)
|
||||
if err != nil {
|
||||
t.Error(err)
|
||||
return
|
||||
}
|
||||
wantOrder := []int{1, 2, 3}
|
||||
if diff := cmp.Diff(wantOrder, callOrder); diff != "" {
|
||||
t.Errorf("call order mismatch (-want +got):\n%s", diff)
|
||||
}
|
||||
|
||||
errMock := &testTransform{
|
||||
returnErr: fmt.Errorf("mock error"),
|
||||
}
|
||||
sampler = NewSampler([]Transform{mock1, errMock, mock2}, Greedy())
|
||||
_, err = sampler.Sample(input)
|
||||
_, err = Weighted(nil).Sample(input, mock1, errMock, mock2)
|
||||
if err == nil {
|
||||
t.Error("Expected error from sampler")
|
||||
}
|
||||
@ -174,14 +192,51 @@ func (ts *testTransform) Apply(logits []float64) ([]float64, error) {
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func TestSampleTemperatureZero(t *testing.T) {
|
||||
sampler := NewSampler([]Transform{Temperature(0)}, Greedy())
|
||||
got, err := sampler.Sample([]float32{1, 2, 3, 4})
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
func BenchmarkTransform(b *testing.B) {
|
||||
transforms := map[string]Transform{
|
||||
"Temperature": Temperature(0.5),
|
||||
"TopK": TopK(10),
|
||||
"TopP": TopP(0.9),
|
||||
"MinP": MinP(0.2),
|
||||
}
|
||||
want := 3 // Greedy sampler should pick highest logit index
|
||||
if got != want {
|
||||
t.Fatalf("got: %v, want: %v", got, want)
|
||||
|
||||
logits := make([]float64, 1<<16)
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float64()
|
||||
}
|
||||
|
||||
for name, transform := range transforms {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
_, err := transform.Apply(logits)
|
||||
if err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkSample(b *testing.B) {
|
||||
samplers := map[string]Sampler{
|
||||
"Greedy": Greedy(),
|
||||
"Weighted": Weighted(nil),
|
||||
}
|
||||
|
||||
logits := make([]float32, 1<<16)
|
||||
for i := range logits {
|
||||
logits[i] = rand.Float32()
|
||||
}
|
||||
|
||||
for name, s := range samplers {
|
||||
b.Run(name, func(b *testing.B) {
|
||||
b.ResetTimer()
|
||||
for range b.N {
|
||||
if _, err := s.Sample(logits); err != nil {
|
||||
b.Error(err)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
@ -8,27 +8,45 @@ import (
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type SOSampler struct {
|
||||
type JSONSampler struct {
|
||||
schema *Schema
|
||||
propIdx int
|
||||
propToNodeMap map[string]*PDANode
|
||||
propToNodeMap map[string]*PDA
|
||||
pdaSampler *PushdownSampler
|
||||
decodedToks []string
|
||||
}
|
||||
|
||||
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
|
||||
pdaSampler := NewPushdownSampler(proc)
|
||||
func NewJSONSampler(proc model.TextProcessor, schema *Schema) (*JSONSampler, error) {
|
||||
pdaSampler, err := NewPushdownSampler(proc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
so := &SOSampler{
|
||||
if schema == nil {
|
||||
return &JSONSampler{
|
||||
schema: nil,
|
||||
propIdx: -1,
|
||||
propToNodeMap: nil,
|
||||
pdaSampler: pdaSampler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
fmt.Println("schema not nil")
|
||||
so := &JSONSampler{
|
||||
schema: schema,
|
||||
propIdx: -1,
|
||||
propToNodeMap: make(map[string]*PDANode),
|
||||
propToNodeMap: make(map[string]*PDA),
|
||||
pdaSampler: pdaSampler,
|
||||
}
|
||||
|
||||
so.schemaToGraph()
|
||||
|
||||
// This is prob slow
|
||||
// Benchmark token decoding
|
||||
start := time.Now()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
|
||||
vocab := proc.GetVocabulary()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
@ -40,14 +58,18 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
}
|
||||
so.decodedToks = decodedToks
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Token decode memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Token decode time = %v\n", time.Since(start))
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("SOSampler")
|
||||
fmt.Println("--------------------------------")
|
||||
// Benchmark this section
|
||||
start := time.Now()
|
||||
var m runtime.MemStats
|
||||
start = time.Now()
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
before = m.Alloc
|
||||
|
||||
// TODO: still messed up
|
||||
// TODO: recursion use case
|
||||
@ -57,12 +79,12 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
// propName -> node
|
||||
curState := node.State
|
||||
fromNode := node
|
||||
CreateMask(fromNode, proc, decodedToks)
|
||||
so.pdaSampler.CreateMask(fromNode)
|
||||
for curState == StateInStructuredKey {
|
||||
// there is only one edge
|
||||
for r, toNode := range fromNode.TransitionEdges {
|
||||
// fmt.Println("rune", r, "edge", toNode.State)
|
||||
CreateMask(toNode, proc, decodedToks)
|
||||
so.pdaSampler.CreateMask(toNode)
|
||||
fmt.Printf("created mask for %c\n", r)
|
||||
curState = toNode.State
|
||||
fmt.Println("next state", curState)
|
||||
@ -73,7 +95,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
}
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
after = m.Alloc
|
||||
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
||||
fmt.Println("--------------------------------")
|
||||
@ -81,7 +103,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
return so, nil
|
||||
}
|
||||
|
||||
func (s *SOSampler) schemaToGraph() {
|
||||
func (s *JSONSampler) schemaToGraph() {
|
||||
schemaType := s.schema.EffectiveType()
|
||||
switch schemaType {
|
||||
case "object":
|
||||
@ -91,18 +113,18 @@ func (s *SOSampler) schemaToGraph() {
|
||||
for _, prop := range s.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
keyNode := &PDANode{
|
||||
keyNode := &PDA{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]*PDANode),
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
|
||||
prevNode := keyNode
|
||||
for _, r := range name {
|
||||
runeNode := &PDANode{
|
||||
runeNode := &PDA{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]*PDANode),
|
||||
TransitionEdges: make(map[rune]*PDA),
|
||||
MaskTokenIDToNode: make(map[int32]*PDA),
|
||||
}
|
||||
fmt.Println("runeNode created", runeNode.State)
|
||||
fmt.Printf("runeNode created %c\n", r)
|
||||
@ -117,9 +139,14 @@ func (s *SOSampler) schemaToGraph() {
|
||||
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||
}
|
||||
}
|
||||
// TODO: do values + recursion
|
||||
}
|
||||
|
||||
func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
|
||||
func (s *JSONSampler) Apply(logits []float64) ([]float64, error) {
|
||||
if s.schema == nil {
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
// doesnt account for multi rune case
|
||||
case StateInObjectKey:
|
||||
@ -148,17 +175,18 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
|
||||
// todo: if i incremenet propidx then i know im in last value as well
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectEnd:
|
||||
fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
|
||||
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
|
||||
fmt.Println("<<<<< in obj end - generating mask for", s.pdaSampler.curNode.State)
|
||||
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDA)
|
||||
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
||||
s.propIdx++
|
||||
|
||||
// TODO: this needs to be optimized in some way, computing mask on the fly is expensive
|
||||
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
||||
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
||||
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
||||
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
|
||||
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDA)
|
||||
|
||||
CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
|
||||
s.pdaSampler.CreateMask(s.pdaSampler.curNode)
|
||||
s.propIdx++
|
||||
}
|
||||
}
|
||||
@ -167,12 +195,17 @@ func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
|
||||
|
||||
}
|
||||
|
||||
func (s *SOSampler) UpdateState(tokenSlice []int32) error {
|
||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||
err := s.pdaSampler.UpdateState(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if s.schema == nil {
|
||||
// Don't need to update state for unconstrained JSON sampling
|
||||
return nil
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectKey:
|
||||
s.propIdx++
|
||||
|
Loading…
x
Reference in New Issue
Block a user