mirror of
https://github.com/ollama/ollama.git
synced 2025-07-23 08:25:38 +02:00
265 lines
6.7 KiB
Go
265 lines
6.7 KiB
Go
package sample
|
|
|
|
import (
|
|
"fmt"
|
|
"math"
|
|
"runtime"
|
|
"time"
|
|
|
|
"github.com/ollama/ollama/model"
|
|
)
|
|
|
|
// TODO: safety in case of invalid json
|
|
// TODO: partial JSON matching?
|
|
// TODO: clean up the interfaces and their return values
|
|
// TODO: this interface shouldn't be the sampler; it should just use Sampler
|
|
// TODO: add penalties for string \n stuff
|
|
// TODO: minimize number of fwd passes if there is only one match
|
|
// TODO: greedy sample initially and then backtrack if no match
|
|
|
|
type PushdownSampler struct {
|
|
PDAGraphBuilder
|
|
curNode *PDA
|
|
braceStack []rune
|
|
stateCounter uint32
|
|
}
|
|
|
|
// graph should be built once and reused per tokenizer
|
|
func NewPushdownSampler(proc model.TextProcessor) (*PushdownSampler, error) {
|
|
start := time.Now()
|
|
|
|
fmt.Println("--------------------------------")
|
|
fmt.Println("PDA sampler")
|
|
fmt.Println("--------------------------------")
|
|
var m runtime.MemStats
|
|
runtime.ReadMemStats(&m)
|
|
before := m.Alloc
|
|
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
|
|
|
vocab := proc.Vocab()
|
|
decodedToks := make([]string, len(vocab.Values))
|
|
for i := range vocab.Values {
|
|
token, err := proc.Decode([]int32{int32(i)})
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
decodedToks[i] = token
|
|
}
|
|
|
|
gb := &PDAGraphBuilder{
|
|
proc: proc,
|
|
decodedToks: decodedToks,
|
|
}
|
|
|
|
if err := gb.BuildGraph(); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
runtime.ReadMemStats(&m)
|
|
after := m.Alloc
|
|
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
|
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
|
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
|
|
|
// TODO: this can be simplified
|
|
return &PushdownSampler{
|
|
curNode: gb.stateToNodeMap[StateStart],
|
|
PDAGraphBuilder: *gb,
|
|
braceStack: []rune{},
|
|
stateCounter: 0,
|
|
}, nil
|
|
}
|
|
|
|
// TODO: need to add resampling logic if the first sample was not good
|
|
// greedy sample + backtrack?
|
|
func (s *PushdownSampler) Apply(logits []float32) ([]float32, error) {
|
|
switch s.curNode.State {
|
|
case StateInString:
|
|
return s.maskLogits(logits, s.curNode)
|
|
|
|
case StateInListEnd:
|
|
// force finish if no braces left
|
|
if len(s.braceStack) == 0 {
|
|
s.curNode = NewPDANode(StateTerminate)
|
|
return forceFinish(s, logits)
|
|
}
|
|
|
|
logits, err := s.maskLogits(logits, s.curNode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return logits, nil
|
|
|
|
case StateTerminate:
|
|
return forceFinish(s, logits)
|
|
|
|
case StateInObjectEnd:
|
|
// force finish if no braces left
|
|
if len(s.braceStack) == 0 {
|
|
s.curNode = NewPDANode(StateTerminate)
|
|
return forceFinish(s, logits)
|
|
}
|
|
|
|
peek := s.braceStack[len(s.braceStack)-1]
|
|
if peek == rune('[') {
|
|
s.curNode = s.stateToNodeMap[StateInListObjectEnd]
|
|
}
|
|
|
|
logits, err := s.maskLogits(logits, s.curNode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return logits, nil
|
|
|
|
case StateInComma:
|
|
peek := s.braceStack[len(s.braceStack)-1]
|
|
if peek == rune('[') {
|
|
s.curNode = s.stateToNodeMap[StateInListComma]
|
|
}
|
|
|
|
logits, err := s.maskLogits(logits, s.curNode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return logits, nil
|
|
|
|
default:
|
|
fmt.Println("masking logits current state", s.curNode.State)
|
|
logits, err := s.maskLogits(logits, s.curNode)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return logits, nil
|
|
}
|
|
}
|
|
|
|
func forceFinish(s *PushdownSampler, logits []float32) ([]float32, error) {
|
|
for i := range logits {
|
|
if s.proc.Is(int32(i), model.SpecialEOS) {
|
|
logits[i] = 1.0
|
|
} else {
|
|
logits[i] = float32(math.Inf(-1))
|
|
}
|
|
}
|
|
return logits, nil
|
|
}
|
|
|
|
func (s *PushdownSampler) UpdateState(tokenSlice []int32) ([]int32, error) {
|
|
fmt.Println("current state - updating", s.curNode.State)
|
|
mappedString, err := s.proc.Decode(tokenSlice)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
fmt.Printf(">>> mappedString: %q\n", mappedString)
|
|
|
|
// Special handling for EOS token in terminate state
|
|
if s.curNode.State == StateTerminate {
|
|
for _, tokenID := range tokenSlice {
|
|
if s.proc.Is(tokenID, model.SpecialEOS) {
|
|
return tokenSlice, nil
|
|
}
|
|
}
|
|
}
|
|
|
|
// flag := -1
|
|
// endBraceRunes := []rune{'}', ']'}
|
|
for _, r := range mappedString {
|
|
// TODO: if this is enabled again, make sure to appropriately handle the state transitions
|
|
// if slices.Contains(endBraceRunes, r) && len(s.braceStack) == 0 {
|
|
// fmt.Printf("stack is empty, extra closing brace %c\n", r)
|
|
// // flag = i
|
|
// break
|
|
|
|
// }
|
|
if r == rune('{') {
|
|
s.braceStack = append(s.braceStack, r)
|
|
}
|
|
if r == rune('[') {
|
|
s.braceStack = append(s.braceStack, r)
|
|
}
|
|
if r == rune('}') {
|
|
if len(s.braceStack) == 0 {
|
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
|
}
|
|
top := s.braceStack[len(s.braceStack)-1]
|
|
if top != rune('{') {
|
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{')
|
|
}
|
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
|
}
|
|
|
|
if r == rune(']') {
|
|
if len(s.braceStack) == 0 {
|
|
return nil, fmt.Errorf("stack is empty, extra closing brace %c", r)
|
|
}
|
|
top := s.braceStack[len(s.braceStack)-1]
|
|
if top != rune('[') {
|
|
return nil, fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[')
|
|
}
|
|
s.braceStack = s.braceStack[:len(s.braceStack)-1]
|
|
}
|
|
}
|
|
|
|
// if flag != -1 {
|
|
// tokenSlice = tokenSlice[:flag]
|
|
// }
|
|
// fmt.Println("flag!", flag)
|
|
for _, tokenID := range tokenSlice {
|
|
// transition to the next node
|
|
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
|
if !ok {
|
|
return nil, fmt.Errorf("invalid token: %q", mappedString)
|
|
}
|
|
fmt.Println("transitioning to", nextNode.State)
|
|
|
|
// TODO: add a penalty for staying in the same state too long
|
|
if nextNode.State == s.curNode.State {
|
|
s.stateCounter++
|
|
} else {
|
|
s.stateCounter = 0
|
|
}
|
|
s.curNode = nextNode
|
|
fmt.Println("updated curNode state", s.curNode.State)
|
|
}
|
|
return tokenSlice, nil
|
|
}
|
|
|
|
// greedy sample + backtrack?
|
|
func (s *PushdownSampler) maskLogits(logits []float32, node *PDA) ([]float32, error) {
|
|
// Create a new slice with same length as logits, initialized to -Inf
|
|
maskedLogits := make([]float32, len(logits))
|
|
for i := range maskedLogits {
|
|
maskedLogits[i] = float32(math.Inf(-1))
|
|
}
|
|
|
|
// Only update values for valid token IDs from the mask map
|
|
for tokenID := range node.MaskTokenIDToNode {
|
|
if int(tokenID) < len(logits) {
|
|
maskedLogits[tokenID] = logits[tokenID]
|
|
}
|
|
}
|
|
|
|
return maskedLogits, nil
|
|
}
|
|
|
|
func (s *PushdownSampler) fastMaskLogits(logits []float32, node *PDA) ([]float32, error) {
|
|
maxLogit := float32(math.Inf(-1))
|
|
maxIndex := -1
|
|
|
|
// Find the maximum logit value among valid tokens
|
|
for tokenID := range node.MaskTokenIDToNode {
|
|
if int(tokenID) < len(logits) && logits[tokenID] > maxLogit {
|
|
maxLogit = logits[tokenID]
|
|
maxIndex = int(tokenID)
|
|
}
|
|
}
|
|
|
|
if maxIndex == -1 {
|
|
return nil, fmt.Errorf("no valid tokens found in mask")
|
|
}
|
|
|
|
logits[0] = float32(maxIndex)
|
|
return logits, nil
|
|
// return maxIndex, nil
|
|
}
|