mirror of
https://github.com/ollama/ollama.git
synced 2025-04-04 09:58:31 +02:00
wip
This commit is contained in:
parent
198fde82aa
commit
c56a8b7749
@ -104,6 +104,8 @@ func temp() error {
|
||||
}
|
||||
}
|
||||
|
||||
pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
||||
var stringBuffer string
|
||||
var offset int
|
||||
for range args.n {
|
||||
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
||||
@ -118,7 +120,10 @@ func temp() error {
|
||||
}
|
||||
|
||||
// do sampling
|
||||
f64s, err = sample.Sample(f64s, sample.Greedy())
|
||||
// []ints back
|
||||
// ints map to sampled logits
|
||||
f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy())
|
||||
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -129,6 +134,7 @@ func temp() error {
|
||||
outputIDs = append(outputIDs, int32(f64))
|
||||
}
|
||||
}
|
||||
pdaSampler.UpdateState(outputIDs)
|
||||
|
||||
if len(outputIDs) == 0 {
|
||||
break
|
||||
@ -141,8 +147,9 @@ func temp() error {
|
||||
return err
|
||||
}
|
||||
|
||||
fmt.Print(s)
|
||||
|
||||
// fmt.Print(s)
|
||||
stringBuffer += s
|
||||
fmt.Println("--- stringBuffer", stringBuffer)
|
||||
inputIDs = append(inputIDs, outputIDs...)
|
||||
if args.cache {
|
||||
offset = len(inputIDs) - 1
|
||||
|
1
model/cmd/test.go
Normal file
1
model/cmd/test.go
Normal file
@ -0,0 +1 @@
|
||||
package main
|
@ -21,6 +21,7 @@ type TextProcessor interface {
|
||||
Encode(string) ([]int32, error)
|
||||
Decode([]int32) (string, error)
|
||||
Is(uint32, Special) bool
|
||||
GetVocabulary() *Vocabulary
|
||||
}
|
||||
|
||||
type Vocabulary struct {
|
||||
@ -104,6 +105,10 @@ type BytePairEncoding struct {
|
||||
*Vocabulary
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) GetVocabulary() *Vocabulary {
|
||||
return bpe.Vocabulary
|
||||
}
|
||||
|
||||
func (bpe BytePairEncoding) split(s string) ([]string, error) {
|
||||
re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2)
|
||||
if err != nil {
|
||||
|
@ -1,11 +1,7 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type JSONState int
|
||||
@ -136,219 +132,3 @@ func (s JSONState) String() string {
|
||||
return fmt.Sprintf("Unknown state: %d", s)
|
||||
}
|
||||
}
|
||||
|
||||
type JSONSampler struct {
|
||||
curNode *Node
|
||||
proc model.TextProcessor
|
||||
stack []*Node
|
||||
bracketCounter int
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor) (*JSONSampler, error) {
|
||||
// fmt.Println("Creating new JSON sampler")
|
||||
startNode, err := buildStateMachine(proc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
js := &JSONSampler{
|
||||
curNode: startNode,
|
||||
proc: proc,
|
||||
stack: []*Node{},
|
||||
bracketCounter: 0,
|
||||
}
|
||||
|
||||
return js, nil
|
||||
}
|
||||
|
||||
func isTokenSubset(subset, superset []int32) bool {
|
||||
freq1 := make(map[int32]int)
|
||||
freq2 := make(map[int32]int)
|
||||
|
||||
for _, v := range subset {
|
||||
freq1[v]++
|
||||
}
|
||||
for _, v := range superset {
|
||||
freq2[v]++
|
||||
}
|
||||
isSubset := true
|
||||
for k, count1 := range freq1 {
|
||||
count2 := freq2[k]
|
||||
if count1 > count2 {
|
||||
isSubset = false
|
||||
break
|
||||
}
|
||||
}
|
||||
return isSubset
|
||||
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenSlice []int32) error {
|
||||
// fmt.Printf("Updating state with token: %v\n", tokenSlice)
|
||||
// fmt.Printf("Current state: %s\n", s.curNode.State)
|
||||
|
||||
// fmt.Println("tokenSlice", tokenSlice)
|
||||
// todo: account for strings here
|
||||
|
||||
objectTokens, err := ComputeTokenVariants([]string{"{", " {", "{\n", " {\n"}, s.proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// only move to terminate state if stack is empty
|
||||
if s.curNode.State == StateInObjectEnd {
|
||||
fmt.Println("debug: node.State", s.curNode.State)
|
||||
if len(s.stack) > 0 {
|
||||
s.stack = s.stack[:len(s.stack)-1]
|
||||
fmt.Println("popped and cur state", s.curNode.State)
|
||||
return nil
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
for node, edge := range s.curNode.TransitionEdges {
|
||||
for _, validToken := range edge {
|
||||
if isTokenSubset(tokenSlice, validToken) {
|
||||
s.curNode = node
|
||||
for _, token := range objectTokens {
|
||||
if isTokenSubset(tokenSlice, token) {
|
||||
fmt.Println("Appending to stack", s.curNode.State)
|
||||
s.stack = append(s.stack, s.curNode)
|
||||
}
|
||||
}
|
||||
// fmt.Printf("Transitioned to state: %s\n", node.State)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
for node, edge := range s.curNode.TransitionEdges {
|
||||
for _, validToken := range edge {
|
||||
if len(validToken) == 1 && validToken[0] == -1 || validToken[0] == -2 {
|
||||
s.curNode = node
|
||||
// fmt.Printf("Accepting any token, staying in state: %s\n", node.State)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Println("invalid token ", tokenSlice)
|
||||
dec, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("decoded token ", dec)
|
||||
return errors.New("invalid token")
|
||||
}
|
||||
|
||||
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
||||
fmt.Printf("Sampling in state: %s\n", s.curNode.State)
|
||||
var err error
|
||||
|
||||
switch s.curNode.State {
|
||||
case StateTerminate:
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInInt:
|
||||
validStates := []int32{}
|
||||
minus, err := s.proc.Encode("-")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
digits := make([][]int32, 10)
|
||||
for i := 0; i < 10; i++ {
|
||||
digits[i], err = s.proc.Encode(fmt.Sprintf("%d", i))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
// Allow "-" and digits 0-9 at start
|
||||
for i := range logits {
|
||||
for _, d := range digits {
|
||||
if len(d) == 1 && int32(i) == d[0] {
|
||||
validStates = append(validStates, int32(i))
|
||||
}
|
||||
}
|
||||
if len(minus) == 1 && int32(i) == minus[0] {
|
||||
validStates = append(validStates, int32(i))
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
case StateInString:
|
||||
penalizeNewlineVariants := []string{"\n", " \"\n"}
|
||||
penalizeNewlineToks, err := ComputeTokenVariants(penalizeNewlineVariants, s.proc)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
penalizeNewlineToks = append(penalizeNewlineToks, []int32{702})
|
||||
logits, err = s.maskSpecificLogits(logits, penalizeNewlineToks)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
validStates := getValidStates(s.curNode)
|
||||
logits, err = s.maskLogits(logits, validStates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
validStates := getValidStates(s.curNode)
|
||||
logits, err = s.maskLogits(logits, validStates)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
}
|
||||
|
||||
func getValidStates(node *Node) []int32 {
|
||||
validStates := []int32{}
|
||||
for _, edge := range node.TransitionEdges {
|
||||
for _, token := range edge {
|
||||
validStates = append(validStates, token...)
|
||||
}
|
||||
}
|
||||
return validStates
|
||||
}
|
||||
|
||||
func (s *JSONSampler) maskLogits(logits []float64, validStates []int32) ([]float64, error) {
|
||||
// fmt.Printf("Masking logits with valid states: %v\n", validStates)
|
||||
// todo: this can prob be more efficient
|
||||
for i := range logits {
|
||||
isValid := false
|
||||
for _, token := range validStates {
|
||||
if token == -1 {
|
||||
// fmt.Println("Found sentinel token, returning unmasked logits")
|
||||
return logits, nil
|
||||
}
|
||||
if i == int(token) {
|
||||
// fmt.Printf("Found valid token: %d\n", token)
|
||||
isValid = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !isValid {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) maskSpecificLogits(logits []float64, tokensToMask []token) ([]float64, error) {
|
||||
// fmt.Printf("Masking specific logits: %v\n", tokensToMask)
|
||||
for i := range logits {
|
||||
for _, token := range tokensToMask {
|
||||
for _, chunked := range token {
|
||||
if int(chunked) == i {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
296
sample/hid.txt
Normal file
296
sample/hid.txt
Normal file
@ -0,0 +1,296 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
var stringInvalidRunes = []rune{'\\', '\n', '\t', '{', '}', ':', ','}
|
||||
|
||||
var intInvalidRunes = []rune{'e', 'E', ' ', '\n', '\t', '{', '}', ':', ',', '"'}
|
||||
var validIntRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '-'}
|
||||
|
||||
var validNumberRunes = []rune{'0', '1', '2', '3', '4', '5', '6', '7', '8', '9', '.', '-', '+', 'e', 'E'}
|
||||
|
||||
var validBoolRunes = []rune{'t', 'r', 'u', 'e', 'f', 'a', 'l', 's', 'e'}
|
||||
|
||||
var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||
|
||||
type PDANode struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDANode
|
||||
MaskTokenIDToNode map[int32]JSONState
|
||||
}
|
||||
|
||||
func NewPDANode(state JSONState) *PDANode {
|
||||
return &PDANode{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||
}
|
||||
}
|
||||
|
||||
func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, error) {
|
||||
stateToNodeMap := make(map[JSONState]*PDANode)
|
||||
|
||||
startNode := NewPDANode(StateStart)
|
||||
stateToNodeMap[StateStart] = startNode
|
||||
|
||||
objNode := NewPDANode(StateInObject)
|
||||
stateToNodeMap[StateInObject] = objNode
|
||||
|
||||
objEndNode := NewPDANode(StateInObjectEnd)
|
||||
stateToNodeMap[StateInObjectEnd] = objEndNode
|
||||
|
||||
objKeyNode := NewPDANode(StateInObjectKey)
|
||||
stateToNodeMap[StateInObjectKey] = objKeyNode
|
||||
|
||||
objKeyEndNode := NewPDANode(StateInObjectKeyEnd)
|
||||
stateToNodeMap[StateInObjectKeyEnd] = objKeyEndNode
|
||||
|
||||
colonNode := NewPDANode(StateInColon)
|
||||
stateToNodeMap[StateInColon] = colonNode
|
||||
|
||||
commaNode := NewPDANode(StateInComma)
|
||||
stateToNodeMap[StateInComma] = commaNode
|
||||
|
||||
newlineNode := NewPDANode(StateInNewline)
|
||||
stateToNodeMap[StateInNewline] = newlineNode
|
||||
|
||||
spaceNode := NewPDANode(StateInSpace)
|
||||
stateToNodeMap[StateInSpace] = spaceNode
|
||||
|
||||
spaceObjNode := NewPDANode(StateInObjSpace)
|
||||
stateToNodeMap[StateInObjSpace] = spaceObjNode
|
||||
|
||||
tabNode := NewPDANode(StateInTab)
|
||||
stateToNodeMap[StateInTab] = tabNode
|
||||
|
||||
stringNode := NewPDANode(StateInString)
|
||||
stateToNodeMap[StateInString] = stringNode
|
||||
|
||||
stringEndNode := NewPDANode(StateInStringEnd)
|
||||
stateToNodeMap[StateInStringEnd] = stringEndNode
|
||||
|
||||
listNode := NewPDANode(StateInList)
|
||||
stateToNodeMap[StateInList] = listNode
|
||||
|
||||
listCommaNode := NewPDANode(StateInListComma)
|
||||
stateToNodeMap[StateInListComma] = listCommaNode
|
||||
|
||||
listEndNode := NewPDANode(StateListEnd)
|
||||
stateToNodeMap[StateListEnd] = listEndNode
|
||||
|
||||
numberNode := NewPDANode(StateInNumber)
|
||||
stateToNodeMap[StateInNumber] = numberNode
|
||||
|
||||
boolNode := NewPDANode(StateInBool)
|
||||
stateToNodeMap[StateInBool] = boolNode
|
||||
|
||||
nullNode := NewPDANode(StateInNull)
|
||||
stateToNodeMap[StateInNull] = nullNode
|
||||
|
||||
// Defined with structured outputs only
|
||||
intNode := NewPDANode(StateInInt)
|
||||
stateToNodeMap[StateInInt] = intNode
|
||||
|
||||
// TODO:
|
||||
// consider adding a node to just point to values, could be good to compute that
|
||||
// mask rather than many different nodes
|
||||
|
||||
// Connect nodes
|
||||
// TODO: if all are single tokens then this can just be connected instead of defining the token
|
||||
startNode.TransitionEdges['{'] = objNode
|
||||
|
||||
objNode.TransitionEdges['"'] = objKeyNode
|
||||
objNode.TransitionEdges['\n'] = newlineNode
|
||||
// objNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
newlineNode.TransitionEdges['"'] = objKeyNode
|
||||
newlineNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
tabNode.TransitionEdges['"'] = objKeyNode
|
||||
// tabNode.TransitionEdges['\t'] = tabNode
|
||||
|
||||
objKeyNode.TransitionEdges[rune(-1)] = objKeyNode
|
||||
objKeyNode.TransitionEdges['"'] = objKeyEndNode
|
||||
|
||||
objKeyEndNode.TransitionEdges[':'] = colonNode
|
||||
objEndNode.TransitionEdges[' '] = spaceNode
|
||||
|
||||
// where values should be
|
||||
// this could be combined but the probs might change, we're alr doing a skip ahead
|
||||
colonNode.TransitionEdges[' '] = spaceNode
|
||||
|
||||
// Leads to a value
|
||||
spaceNode.TransitionEdges['"'] = stringNode
|
||||
spaceNode.TransitionEdges['['] = listNode
|
||||
spaceNode.TransitionEdges['{'] = objNode
|
||||
|
||||
for _, r := range validNumberRunes {
|
||||
spaceNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
spaceNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
|
||||
for _, r := range validNullRunes {
|
||||
spaceNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// Values
|
||||
// string node
|
||||
stringNode.TransitionEdges[rune(-1)] = stringNode
|
||||
stringNode.TransitionEdges['"'] = stringEndNode
|
||||
|
||||
stringEndNode.TransitionEdges[','] = commaNode
|
||||
stringEndNode.TransitionEdges['}'] = objEndNode
|
||||
stringEndNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
// TODO: add counters for allowable number of decimals, e, E, etc
|
||||
// number node
|
||||
for _, r := range validNumberRunes {
|
||||
numberNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
numberNode.TransitionEdges[','] = commaNode
|
||||
numberNode.TransitionEdges['}'] = objEndNode
|
||||
numberNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
|
||||
// list node
|
||||
listNode.TransitionEdges[','] = commaNode
|
||||
listNode.TransitionEdges['"'] = stringNode
|
||||
// squash states to a value
|
||||
for _, r := range validNumberRunes {
|
||||
listNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
listNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
for _, r := range validNullRunes {
|
||||
listNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// null node
|
||||
for _, r := range validNullRunes {
|
||||
nullNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
nullNode.TransitionEdges[','] = commaNode
|
||||
nullNode.TransitionEdges['}'] = objEndNode
|
||||
nullNode.TransitionEdges[']'] = listEndNode
|
||||
|
||||
// list comma
|
||||
// should point to values
|
||||
listCommaNode.TransitionEdges['"'] = stringNode
|
||||
listCommaNode.TransitionEdges[' '] = listCommaNode
|
||||
listCommaNode.TransitionEdges['{'] = objNode
|
||||
listCommaNode.TransitionEdges['\n'] = newlineNode
|
||||
|
||||
for _, r := range validNumberRunes {
|
||||
listCommaNode.TransitionEdges[r] = numberNode
|
||||
}
|
||||
for _, r := range validBoolRunes {
|
||||
listCommaNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
for _, r := range validNullRunes {
|
||||
listCommaNode.TransitionEdges[r] = nullNode
|
||||
}
|
||||
|
||||
// bool node
|
||||
for _, r := range validBoolRunes {
|
||||
boolNode.TransitionEdges[r] = boolNode
|
||||
}
|
||||
boolNode.TransitionEdges['}'] = objEndNode
|
||||
boolNode.TransitionEdges[']'] = listEndNode
|
||||
boolNode.TransitionEdges[','] = commaNode
|
||||
|
||||
listEndNode.TransitionEdges['}'] = objEndNode
|
||||
listEndNode.TransitionEdges[','] = commaNode
|
||||
|
||||
commaNode.TransitionEdges['{'] = objNode
|
||||
commaNode.TransitionEdges['\n'] = newlineNode
|
||||
commaNode.TransitionEdges['\t'] = tabNode
|
||||
commaNode.TransitionEdges['"'] = objKeyNode
|
||||
commaNode.TransitionEdges[' '] = spaceObjNode
|
||||
|
||||
spaceObjNode.TransitionEdges['"'] = objKeyNode
|
||||
|
||||
return startNode, stateToNodeMap, nil
|
||||
}
|
||||
|
||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||
|
||||
vocab := proc.GetVocabulary()
|
||||
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
|
||||
var err error
|
||||
for _, node := range stateToNodeMap {
|
||||
for i := range vocab.Values {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" {
|
||||
continue
|
||||
}
|
||||
valid := true
|
||||
curNode := node
|
||||
consumedSpecialRunes := make(map[rune]bool)
|
||||
for _, r := range token {
|
||||
valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if !valid {
|
||||
break
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) {
|
||||
if consumedSpecialRunes[r] {
|
||||
return false, nil, nil
|
||||
}
|
||||
|
||||
specialRune := slices.Contains(stringInvalidRunes, r)
|
||||
if specialRune {
|
||||
if curNode.State == StateInString || curNode.State == StateInObjectKey {
|
||||
return false, nil, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check for specific rune transition
|
||||
if nextNode, ok := curNode.TransitionEdges[r]; ok {
|
||||
if specialRune {
|
||||
if curNode.State == nextNode.State {
|
||||
return false, nil, nil
|
||||
}
|
||||
// fmt.Println("special rune", r, "consumed")
|
||||
consumedSpecialRunes[r] = true
|
||||
}
|
||||
return true, nextNode, nil
|
||||
}
|
||||
|
||||
// Check for sentinel value - if present, any rune is valid
|
||||
if nextNode, ok := curNode.TransitionEdges[rune(-1)]; ok {
|
||||
return true, nextNode, nil
|
||||
}
|
||||
|
||||
return false, nil, nil
|
||||
}
|
@ -1,104 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type JSONState int
|
||||
|
||||
const (
|
||||
StateStart JSONState = iota // Initial state
|
||||
StateInObject // Inside an object {}
|
||||
StateInArray // Inside an array []
|
||||
StateInString // Inside a string ""
|
||||
StateAfterKey // After object key, expecting :
|
||||
StateAfterColon // After :, expecting value
|
||||
StateAfterValue // After value, expecting , or closing bracket
|
||||
StateDone // JSON parsing complete
|
||||
)
|
||||
|
||||
type JSONSampler struct {
|
||||
state JSONState
|
||||
stack []string
|
||||
proc model.TextProcessor
|
||||
}
|
||||
|
||||
func NewJSONSampler(proc model.TextProcessor) *JSONSampler {
|
||||
return &JSONSampler{
|
||||
state: StateStart,
|
||||
proc: proc,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *JSONSampler) Sample(logits []float64) ([]float64, error) {
|
||||
// Pre-decode valid tokens for current state
|
||||
validTokens := make(map[uint32]bool)
|
||||
|
||||
// Always allow EOS token in any state
|
||||
// TODO: Check for other special tokens if needed
|
||||
for i := range logits {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
}
|
||||
|
||||
// Build set of valid tokens based on current state
|
||||
switch s.state {
|
||||
case StateStart:
|
||||
// Only allow opening brace
|
||||
for i := range logits {
|
||||
text, err := s.proc.Decode([]int32{int32(i)})
|
||||
if err == nil && text == "{" {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
}
|
||||
case StateInObject, StateInArray:
|
||||
// Allow any token
|
||||
for i := range logits {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
case StateInString:
|
||||
// Allow any token except closing brace
|
||||
for i := range logits {
|
||||
text, err := s.proc.Decode([]int32{int32(i)})
|
||||
if err == nil && text != "}" {
|
||||
validTokens[uint32(i)] = true
|
||||
}
|
||||
}
|
||||
case StateDone:
|
||||
// No tokens allowed
|
||||
}
|
||||
|
||||
// Mark invalid tokens as NaN in one pass
|
||||
for i := range logits {
|
||||
if !validTokens[uint32(i)] {
|
||||
logits[i] = math.NaN()
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
}
|
||||
|
||||
func (s *JSONSampler) UpdateState(tokenID int) error {
|
||||
text, err := s.proc.Decode([]int32{int32(tokenID)})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to decode token: %w", err)
|
||||
}
|
||||
|
||||
switch s.state {
|
||||
case StateStart:
|
||||
if text != "{" {
|
||||
return fmt.Errorf("expected {, got %s", text)
|
||||
}
|
||||
s.state = StateInObject
|
||||
case StateInObject:
|
||||
if text == "}" {
|
||||
s.state = StateDone
|
||||
}
|
||||
case StateDone:
|
||||
return fmt.Errorf("unexpected token after closing bracket: %s", text)
|
||||
}
|
||||
return nil
|
||||
}
|
@ -165,9 +165,10 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
||||
if len(logitsCopy) == 0 {
|
||||
return nil, errors.New("no valid tokens found")
|
||||
}
|
||||
|
||||
// usually, a softmax is applied to sample from the logits
|
||||
// in this case the uv sampler normalizes the logits so that the sum of the weights is 1
|
||||
logitsCopy, err := computeSoftmax(logitsCopy)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
w := sampleuv.NewWeighted(logitsCopy, nil)
|
||||
if v, ok := w.Take(); ok {
|
||||
// returns the token ID
|
||||
@ -176,17 +177,6 @@ func (s weighed) Sample(logits []float64) ([]float64, error) {
|
||||
return nil, errors.New("weighed sampler failed")
|
||||
}
|
||||
|
||||
// TODO: remove after next PR merge
|
||||
type greedy struct{}
|
||||
|
||||
func Greedy() Sampler {
|
||||
return greedy{}
|
||||
}
|
||||
|
||||
func (greedy) Sample(logits []float64) ([]float64, error) {
|
||||
return []float64{float64(floats.MaxIdx(logits))}, nil
|
||||
}
|
||||
|
||||
func Sample(logits []float64, samplers ...Sampler) ([]float64, error) {
|
||||
var err error
|
||||
for _, sampler := range samplers {
|
||||
|
@ -3,14 +3,9 @@ package sample
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand"
|
||||
"os"
|
||||
"runtime"
|
||||
"slices"
|
||||
"testing"
|
||||
|
||||
"runtime/trace"
|
||||
|
||||
"gonum.org/v1/gonum/floats"
|
||||
)
|
||||
|
||||
|
@ -1,218 +0,0 @@
|
||||
package sample
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type token []int32
|
||||
|
||||
type Node struct {
|
||||
State JSONState
|
||||
TransitionEdges map[*Node][]token
|
||||
}
|
||||
|
||||
func NewNode(state JSONState) *Node {
|
||||
return &Node{
|
||||
State: state,
|
||||
TransitionEdges: make(map[*Node][]token),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// startToken token
|
||||
startTokenVariants []token
|
||||
// endToken token
|
||||
// stringToken token
|
||||
// objectKeyToken token
|
||||
tabToken token
|
||||
spaceToken token
|
||||
newlineToken token
|
||||
newlineSpace token
|
||||
// commaToken token
|
||||
// commaToken2 token
|
||||
// commaToken3 token
|
||||
// colonToken token
|
||||
// colonToken2 token
|
||||
colonTokenVariants []token
|
||||
commaTokenVariants []token
|
||||
stringTokenVariants []token
|
||||
endTokenVariants []token
|
||||
objectKeyTokenVariants []token
|
||||
objKeyToColonVariants []token
|
||||
stringToObjectKeyVariants []token
|
||||
stringToCommaVariants []token
|
||||
stringToObjectVariants []token
|
||||
stringEndToObjectEndVariants []token
|
||||
stringEndToCommaVariants []token
|
||||
)
|
||||
|
||||
func ComputeTokenVariants(variants []string, proc model.TextProcessor) ([]token, error) {
|
||||
var allTokens token
|
||||
for _, variant := range variants {
|
||||
if t, err := proc.Encode(variant); err == nil {
|
||||
allTokens = append(allTokens, t...)
|
||||
}
|
||||
}
|
||||
if len(allTokens) == 0 {
|
||||
return nil, fmt.Errorf("no valid tokens found for variants")
|
||||
}
|
||||
return []token{allTokens}, nil
|
||||
}
|
||||
func initTokens(proc model.TextProcessor) error {
|
||||
var err error
|
||||
|
||||
s, err := proc.Decode([]int32{761})
|
||||
fmt.Printf("761 decoded %q\n", s)
|
||||
|
||||
// Compute start token variants
|
||||
startVariants := []string{"{", " {", "{\n", " {\n"}
|
||||
startTokenVariants, err = ComputeTokenVariants(startVariants, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// Compute end token variants
|
||||
endVariants := []string{"}", " }", "}\n", " }\n"}
|
||||
endTokenVariants, err = ComputeTokenVariants(endVariants, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compute string token variants
|
||||
// TODO: removed \n
|
||||
stringVariants := []string{"\"", " \""}
|
||||
stringTokenVariants, err = ComputeTokenVariants(stringVariants, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\",\n"}, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
// objectKeyTokenVariants = []token{stringTokenVariants[0], stringTokenVariants[1]}
|
||||
objectKeyTokenVariants = stringTokenVariants
|
||||
// Compute whitespace tokens
|
||||
tabToken, err = proc.Encode("\t")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
spaceToken, err = proc.Encode(" ")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newlineToken, err = proc.Encode("\n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
newlineSpace, err = proc.Encode(" \n")
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compute colon variants
|
||||
colonVariants := []string{":"}
|
||||
colonTokenVariants, err = ComputeTokenVariants(colonVariants, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
objKeyToColonVariants, err = ComputeTokenVariants([]string{"\":"}, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Compute comma variants
|
||||
commaVariants := []string{",", " ,", ",\n", "\",", "\", "}
|
||||
commaTokenVariants, err = ComputeTokenVariants(commaVariants, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Printf("commaTokenVariants: %v\n", commaTokenVariants)
|
||||
stringToCommaVariants, err = ComputeTokenVariants([]string{"\",", "\","}, proc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stringEndToCommaVariants, err = ComputeTokenVariants([]string{",", ",\n"}, proc)
|
||||
stringToObjectKeyVariants, err = ComputeTokenVariants([]string{"\",", ",\n", "\","}, proc)
|
||||
stringToObjectVariants, err = ComputeTokenVariants([]string{"\",\n"}, proc)
|
||||
stringEndToObjectEndVariants, err = ComputeTokenVariants([]string{"\n"}, proc)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func buildStateMachine(proc model.TextProcessor) (*Node, error) {
|
||||
if err := initTokens(proc); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
startNode := NewNode(StateStart)
|
||||
objectNode := NewNode(StateInObject)
|
||||
objectKeyNode := NewNode(StateInObjectKey)
|
||||
objectKeyEndNode := NewNode(StateInObjectKeyEnd)
|
||||
stringNode := NewNode(StateInString)
|
||||
// intNode := NewNode(StateInInt)
|
||||
commaNode := NewNode(StateInComma)
|
||||
colonNode := NewNode(StateInColon)
|
||||
stringEndNode := NewNode(StateInStringEnd)
|
||||
endNode := NewNode(StateEnd)
|
||||
terminateNode := NewNode(StateTerminate)
|
||||
|
||||
sentinelToken := token([]int32{-1})
|
||||
// intSentinelToken := token([]int32{-2})
|
||||
|
||||
// TODO: cleanup connections of rules
|
||||
startNode.TransitionEdges[objectNode] = startTokenVariants
|
||||
|
||||
objectNode.TransitionEdges[objectKeyNode] = stringTokenVariants
|
||||
objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
||||
objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
||||
|
||||
// objectNode.TransitionEdges[objectNode] = []token{newlineToken}
|
||||
// objectNode.TransitionEdges[objectNode] = []token{spaceToken}
|
||||
|
||||
objectKeyNode.TransitionEdges[objectKeyNode] = []token{sentinelToken}
|
||||
// characterize end of object key
|
||||
objectKeyNode.TransitionEdges[objectKeyEndNode] = stringTokenVariants
|
||||
objectKeyNode.TransitionEdges[colonNode] = objKeyToColonVariants
|
||||
|
||||
// TODO: enable this - key -> object
|
||||
// objectKeyNode.TransitionEdges[objectNode] = startTokenVariants
|
||||
|
||||
// objectKeyNode.TransitionEdges[intNode] = []token{sentinelToken}
|
||||
|
||||
// intNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
||||
// intNode.TransitionEdges[commaNode] = commaTokenVariants
|
||||
// TODO: handle
|
||||
// intNode.TransitionEdges[terminateNode] = endTokenVariants
|
||||
|
||||
commaNode.TransitionEdges[objectKeyNode] = stringTokenVariants
|
||||
// commaNode.TransitionEdges[objectNode] = startTokenVariants
|
||||
|
||||
colonNode.TransitionEdges[stringNode] = stringTokenVariants
|
||||
//TODO: enable
|
||||
// colonNode.TransitionEdges[intNode] = []token{intSentinelToken}
|
||||
colonNode.TransitionEdges[objectNode] = startTokenVariants
|
||||
|
||||
stringNode.TransitionEdges[stringNode] = []token{sentinelToken}
|
||||
stringNode.TransitionEdges[stringEndNode] = stringTokenVariants
|
||||
// TODO: "\""," Case not accounted for
|
||||
stringNode.TransitionEdges[commaNode] = stringToCommaVariants
|
||||
|
||||
// TODO: "\"",\"" Case not accounted for
|
||||
stringNode.TransitionEdges[objectNode] = stringToObjectVariants
|
||||
|
||||
stringEndNode.TransitionEdges[commaNode] = stringEndToCommaVariants
|
||||
stringEndNode.TransitionEdges[objectNode] = stringToObjectKeyVariants
|
||||
stringEndNode.TransitionEdges[endNode] = stringEndToObjectEndVariants
|
||||
// stringEndNode.TransitionEdges[terminateNode] = endTokenVariants
|
||||
|
||||
// Should be obj end
|
||||
// TODO: handle
|
||||
endNode.TransitionEdges[terminateNode] = []token{}
|
||||
|
||||
endNode.TransitionEdges[commaNode] = commaTokenVariants
|
||||
|
||||
terminateNode.TransitionEdges[terminateNode] = []token{}
|
||||
return startNode, nil
|
||||
}
|
@ -7,92 +7,6 @@ type StructuredOutput struct {
|
||||
}
|
||||
|
||||
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode {
|
||||
// _, stateToNodeMap, err := BuildGraph(proc)
|
||||
// if err != nil {
|
||||
// panic(err)
|
||||
// }
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// func constrainGraph(graph *PDANode, schema *Schema) *PDANode {
|
||||
// // If no schema constraints, return original graph node
|
||||
// if schema == nil {
|
||||
// return graph
|
||||
// }
|
||||
|
||||
// // Create a new node with same state
|
||||
// constrainedNode := NewPDANode(graph.State)
|
||||
|
||||
// // Copy over existing transitions and masks
|
||||
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
||||
// for r, node := range graph.TransitionEdges {
|
||||
// constrainedNode.TransitionEdges[r] = node
|
||||
// }
|
||||
// constrainedNode.MaskTokenIDToNode = graph.MaskTokenIDToNode
|
||||
|
||||
// // Apply schema constraints based on type
|
||||
// switch schema.EffectiveType() {
|
||||
// case "object":
|
||||
// // Only allow defined property names in object keys
|
||||
// if graph.State == StateInObjectKey {
|
||||
// // TODO: Add property name validation
|
||||
// }
|
||||
|
||||
// // Constrain property values based on schema
|
||||
// if graph.State == StateInColon || graph.State == StateInSpace {
|
||||
// // Clear transitions to only allow valid types
|
||||
// constrainedNode.TransitionEdges = make(map[rune]*PDANode)
|
||||
|
||||
// // Add transitions based on property schemas
|
||||
// for _, prop := range schema.Properties {
|
||||
// switch prop.EffectiveType() {
|
||||
// case "object":
|
||||
// if objNode, ok := graph.TransitionEdges['{']; ok {
|
||||
// constrainedNode.TransitionEdges['{'] = constrainGraph(objNode, prop)
|
||||
// }
|
||||
// case "array":
|
||||
// if arrNode, ok := graph.TransitionEdges['[']; ok {
|
||||
// constrainedNode.TransitionEdges['['] = constrainGraph(arrNode, prop)
|
||||
// }
|
||||
// case "string":
|
||||
// if strNode, ok := graph.TransitionEdges['"']; ok {
|
||||
// constrainedNode.TransitionEdges['"'] = constrainGraph(strNode, prop)
|
||||
// }
|
||||
// case "number":
|
||||
// for _, r := range validNumberRunes {
|
||||
// if numNode, ok := graph.TransitionEdges[r]; ok {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(numNode, prop)
|
||||
// }
|
||||
// }
|
||||
// case "integer":
|
||||
// for _, r := range validIntRunes {
|
||||
// if intNode, ok := graph.TransitionEdges[r]; ok {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(intNode, prop)
|
||||
// }
|
||||
// }
|
||||
// case "boolean":
|
||||
// for _, r := range []rune{'t', 'f'} {
|
||||
// if boolNode, ok := graph.TransitionEdges[r]; ok {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(boolNode, prop)
|
||||
// }
|
||||
// }
|
||||
// case "null":
|
||||
// if nullNode, ok := graph.TransitionEdges['n']; ok {
|
||||
// constrainedNode.TransitionEdges['n'] = constrainGraph(nullNode, prop)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// case "array":
|
||||
// // Constrain array items based on schema
|
||||
// if schema.Items != nil {
|
||||
// for r, node := range graph.TransitionEdges {
|
||||
// constrainedNode.TransitionEdges[r] = constrainGraph(node, schema.Items)
|
||||
// }
|
||||
// }
|
||||
// }
|
||||
|
||||
// return constrainedNode
|
||||
// }
|
||||
|
BIN
sample/trace.out
Normal file
BIN
sample/trace.out
Normal file
Binary file not shown.
Loading…
x
Reference in New Issue
Block a user