mirror of
https://github.com/ollama/ollama.git
synced 2025-04-06 19:08:27 +02:00
tested with so
This commit is contained in:
parent
b973dedb4b
commit
524029cd6d
@ -106,11 +106,31 @@ func temp() error {
|
||||
}
|
||||
}
|
||||
|
||||
pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
||||
// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
||||
|
||||
// simple schema
|
||||
// This schema maps to JSON like:
|
||||
// {
|
||||
// "name": "some string value"
|
||||
// }
|
||||
schema := &sample.Schema{
|
||||
Name: "root",
|
||||
Type: "object",
|
||||
Properties: []*sample.Schema{
|
||||
{Name: "name", Type: "string"},
|
||||
},
|
||||
}
|
||||
|
||||
pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var offset int
|
||||
var stringBuffer string
|
||||
var firstTokenTime time.Duration
|
||||
var totalSamplingTime time.Duration
|
||||
count := 0
|
||||
for range args.n {
|
||||
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
||||
if err != nil {
|
||||
@ -122,7 +142,6 @@ func temp() error {
|
||||
for i, f32 := range f32s {
|
||||
f64s[i] = float64(f32)
|
||||
}
|
||||
sampleTime := time.Now()
|
||||
samplers := []sample.Sampler{
|
||||
pushdownSampler,
|
||||
// sample.Weighed(),
|
||||
@ -131,12 +150,16 @@ func temp() error {
|
||||
sample.Greedy(),
|
||||
}
|
||||
|
||||
samplingStart := time.Now()
|
||||
f64s, err = sample.Sample(f64s, samplers...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
finishTime := time.Now()
|
||||
fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
||||
samplingTime := time.Since(samplingStart)
|
||||
totalSamplingTime += samplingTime
|
||||
|
||||
fmt.Println("sampling time", samplingTime)
|
||||
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
||||
|
||||
var outputIDs []int32
|
||||
for _, f64 := range f64s {
|
||||
@ -164,6 +187,7 @@ func temp() error {
|
||||
// fmt.Printf("--- token: %q\n", s)
|
||||
// fmt.Printf("--- outputIDs: %v\n", outputIDs)
|
||||
stringBuffer += s
|
||||
count++
|
||||
fmt.Println("--- stringBuffer", stringBuffer)
|
||||
|
||||
err = pushdownSampler.UpdateState(outputIDs)
|
||||
@ -179,7 +203,7 @@ func temp() error {
|
||||
fmt.Println("\n------ Output: ------")
|
||||
fmt.Println(stringBuffer)
|
||||
fmt.Println("--------------------")
|
||||
|
||||
fmt.Println("sample average time", totalSamplingTime/time.Duration(count))
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -10,6 +10,8 @@ const (
|
||||
StateStart JSONState = iota
|
||||
StateInObject
|
||||
StateInObjectKey
|
||||
StateInStructuredKey
|
||||
StateInStructuredValue
|
||||
StateNewline
|
||||
StateTab
|
||||
StateSpace
|
||||
@ -43,6 +45,7 @@ var JSONStates = []JSONState{
|
||||
StateStart,
|
||||
StateInObject,
|
||||
StateInObjectKey,
|
||||
StateInStructuredKey,
|
||||
StateNewline,
|
||||
StateTab,
|
||||
StateSpace,
|
||||
@ -80,6 +83,8 @@ func (s JSONState) String() string {
|
||||
return "StateInObject"
|
||||
case StateInObjectKey:
|
||||
return "StateInObjectKey"
|
||||
case StateInStructuredKey:
|
||||
return "StateInStructuredKey"
|
||||
case StateNewline:
|
||||
return "StateNewline"
|
||||
case StateTab:
|
||||
|
@ -21,14 +21,14 @@ var validNullRunes = []rune{'n', 'u', 'l', 'l'}
|
||||
type PDANode struct {
|
||||
State JSONState
|
||||
TransitionEdges map[rune]*PDANode
|
||||
MaskTokenIDToNode map[int32]JSONState
|
||||
MaskTokenIDToNode map[int32]*PDANode
|
||||
}
|
||||
|
||||
func NewPDANode(state JSONState) *PDANode {
|
||||
return &PDANode{
|
||||
State: state,
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||
MaskTokenIDToNode: make(map[int32]*PDANode),
|
||||
}
|
||||
}
|
||||
|
||||
@ -103,6 +103,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList]
|
||||
// empty list
|
||||
stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd]
|
||||
addValueConnections(stateToNodeMap[StateInList], stateToNodeMap)
|
||||
|
||||
// null node
|
||||
@ -162,6 +164,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) {
|
||||
// TODO: tough life fr. plz fix.
|
||||
func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error {
|
||||
|
||||
// TODO: should come from top level
|
||||
vocab := proc.GetVocabulary()
|
||||
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
@ -175,7 +178,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
|
||||
var err error
|
||||
for _, node := range stateToNodeMap {
|
||||
err = createMask(node, proc, decodedToks, vocab)
|
||||
err = CreateMask(node, proc, decodedToks, vocab)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -183,7 +186,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
return nil
|
||||
}
|
||||
|
||||
func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
|
||||
func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
|
||||
for i := range vocab.Values {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
@ -204,7 +207,8 @@ func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, v
|
||||
}
|
||||
}
|
||||
if valid {
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode.State
|
||||
// cur node allows skipping
|
||||
node.MaskTokenIDToNode[int32(i)] = curNode
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
@ -4,7 +4,6 @@ import (
|
||||
"fmt"
|
||||
"math"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
@ -22,12 +21,15 @@ type PushdownSampler struct {
|
||||
|
||||
// graph should be built once and reused per tokenizer
|
||||
func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
start := time.Now()
|
||||
// start := time.Now()
|
||||
|
||||
// fmt.Println("--------------------------------")
|
||||
// fmt.Println("PDA sampler")
|
||||
// fmt.Println("--------------------------------")
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
// before := m.Alloc
|
||||
// fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024))
|
||||
|
||||
startNode, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
@ -38,10 +40,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
panic(err)
|
||||
}
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
// after := m.Alloc
|
||||
// fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024))
|
||||
// fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
// fmt.Printf("Graph build time = %v\n", time.Since(start))
|
||||
|
||||
return &PushdownSampler{
|
||||
curNode: startNode,
|
||||
@ -53,6 +55,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
}
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
// fmt.Println(">>> sample:", s.curNode.State)
|
||||
switch s.curNode.State {
|
||||
@ -60,7 +63,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
|
||||
case StateInListEnd:
|
||||
fmt.Println("in list end", s.braceStack)
|
||||
// fmt.Println("in list end", s.braceStack)
|
||||
// force finish if no braces left
|
||||
if len(s.braceStack) == 0 {
|
||||
s.curNode = NewPDANode(StateTerminate)
|
||||
@ -139,12 +142,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
}
|
||||
|
||||
func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
fmt.Println("update state", s.curNode.State)
|
||||
// fmt.Println("current state - updating", s.curNode.State)
|
||||
mappedString, err := s.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("mappedString", mappedString)
|
||||
// fmt.Println("mappedString", mappedString)
|
||||
|
||||
// TODO: should force closing for all braces - not doing square yet
|
||||
for _, r := range mappedString {
|
||||
@ -183,23 +186,25 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
|
||||
for _, tokenID := range tokenSlice {
|
||||
// transition to the next node
|
||||
nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID]
|
||||
if !ok {
|
||||
return fmt.Errorf("invalid token: %q", mappedString)
|
||||
}
|
||||
// fmt.Println("transitioning to", nextNodeState)
|
||||
|
||||
// TODO: add a penalty for staying in the same state too long
|
||||
if nextNodeState == s.curNode.State {
|
||||
if nextNode.State == s.curNode.State {
|
||||
s.stateCounter++
|
||||
} else {
|
||||
s.stateCounter = 0
|
||||
}
|
||||
s.curNode = s.stateToNodeMap[nextNodeState]
|
||||
s.curNode = nextNode
|
||||
// fmt.Println("updated curNode state", s.curNode.State)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
|
||||
// Should be possible through bitwise ops as well
|
||||
|
@ -1,54 +1,145 @@
|
||||
package sample
|
||||
|
||||
import "github.com/ollama/ollama/model"
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"time"
|
||||
|
||||
type StructuredOutput struct {
|
||||
schema *Schema
|
||||
stateToNodeMap map[JSONState]*PDANode
|
||||
"github.com/ollama/ollama/model"
|
||||
)
|
||||
|
||||
type SOSampler struct {
|
||||
schema *Schema
|
||||
propIdx int
|
||||
propStateMap map[string]*PDANode
|
||||
pdaSampler *PushdownSampler
|
||||
}
|
||||
|
||||
func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput {
|
||||
_, stateToNodeMap, err := BuildGraph(proc)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
|
||||
pdaSampler := NewPushdownSampler(proc)
|
||||
|
||||
so := &SOSampler{
|
||||
schema: schema,
|
||||
propIdx: -1,
|
||||
propStateMap: make(map[string]*PDANode),
|
||||
pdaSampler: pdaSampler,
|
||||
}
|
||||
|
||||
return &StructuredOutput{
|
||||
schema: schema,
|
||||
stateToNodeMap: stateToNodeMap,
|
||||
so.schemaToGraph()
|
||||
|
||||
vocab := proc.GetVocabulary()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
token, err := proc.Decode([]int32{int32(i)})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
}
|
||||
|
||||
func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode {
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("SOSampler")
|
||||
fmt.Println("--------------------------------")
|
||||
// Benchmark this section
|
||||
start := time.Now()
|
||||
var m runtime.MemStats
|
||||
runtime.ReadMemStats(&m)
|
||||
before := m.Alloc
|
||||
|
||||
schemaType := so.schema.EffectiveType()
|
||||
switch schemaType {
|
||||
case "object":
|
||||
// each prop is a key
|
||||
// prevState := StateInObjectKey
|
||||
for _, prop := range so.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
prevState := StateInObjectKey
|
||||
for i, r := range name {
|
||||
newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune
|
||||
|
||||
// Create new node for this state if it doesn't exist
|
||||
if _, exists := so.stateToNodeMap[newState]; !exists {
|
||||
so.stateToNodeMap[newState] = &PDANode{
|
||||
State: newState,
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]JSONState),
|
||||
}
|
||||
}
|
||||
|
||||
// Connect previous state to this state via the rune
|
||||
so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState]
|
||||
prevState = newState
|
||||
// TODO: still messed up
|
||||
for _, node := range so.propStateMap {
|
||||
// propName -> node
|
||||
curState := node.State
|
||||
fromNode := node
|
||||
CreateMask(fromNode, proc, decodedToks, vocab)
|
||||
for curState == StateInStructuredKey {
|
||||
// there is only one edge
|
||||
for r, toNode := range fromNode.TransitionEdges {
|
||||
// fmt.Println("rune", r, "edge", toNode.State)
|
||||
CreateMask(toNode, proc, decodedToks, vocab)
|
||||
fmt.Printf("created mask for %c\n", r)
|
||||
curState = toNode.State
|
||||
fmt.Println("next state", curState)
|
||||
// TODO: theres an extra gen for " right now
|
||||
fromNode = toNode
|
||||
}
|
||||
// type of value
|
||||
// propType := prop.Type
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
runtime.ReadMemStats(&m)
|
||||
after := m.Alloc
|
||||
fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024))
|
||||
fmt.Printf("Mask creation time = %v\n", time.Since(start))
|
||||
fmt.Println("--------------------------------")
|
||||
|
||||
return so, nil
|
||||
}
|
||||
|
||||
func (s *SOSampler) schemaToGraph() {
|
||||
schemaType := s.schema.EffectiveType()
|
||||
switch schemaType {
|
||||
case "object":
|
||||
// TODO: see if we need to connect these to the JSON graph
|
||||
// prevState := StateInObjectKey
|
||||
// prevNode := so.stateToNodeMap[prevState]
|
||||
|
||||
// each prop is a key
|
||||
for _, prop := range s.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
// prevState := StateInObjectKey
|
||||
keyNode := &PDANode{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]*PDANode),
|
||||
}
|
||||
|
||||
prevNode := keyNode
|
||||
for _, r := range name {
|
||||
runeNode := &PDANode{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
MaskTokenIDToNode: make(map[int32]*PDANode),
|
||||
}
|
||||
fmt.Println("runeNode created", runeNode.State)
|
||||
fmt.Printf("runeNode created %c\n", r)
|
||||
// since alloc on heap connections wil still map
|
||||
prevNode.TransitionEdges[r] = runeNode
|
||||
prevNode = runeNode
|
||||
}
|
||||
// point to end of object key node after all chars are done
|
||||
prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd]
|
||||
// points to start of the key
|
||||
s.propStateMap[name] = keyNode
|
||||
fmt.Println("name", name, "keyNode", keyNode.State)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
||||
switch s.pdaSampler.curNode.State {
|
||||
// doesnt account for multi rune case
|
||||
case StateInObjectKey:
|
||||
// fmt.Println("in object key - structured outputs")
|
||||
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||
// simple case
|
||||
s.propIdx++
|
||||
prop := s.schema.Properties[s.propIdx]
|
||||
// fmt.Println("prop", prop.Name)
|
||||
s.pdaSampler.curNode = s.propStateMap[prop.Name]
|
||||
// fmt.Println("changed curNode state to", s.pdaSampler.curNode.State)
|
||||
logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
return s.pdaSampler.Sample(logits)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *SOSampler) UpdateState(tokenSlice []int32) error {
|
||||
return s.pdaSampler.UpdateState(tokenSlice)
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user