diff --git a/model/cmd/main.go b/model/cmd/main.go index b7628e910..9d90b5f8e 100644 --- a/model/cmd/main.go +++ b/model/cmd/main.go @@ -106,11 +106,31 @@ func temp() error { } } - pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) + // pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) + + // simple schema + // This schema maps to JSON like: + // { + // "name": "some string value" + // } + schema := &sample.Schema{ + Name: "root", + Type: "object", + Properties: []*sample.Schema{ + {Name: "name", Type: "string"}, + }, + } + + pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor)) + if err != nil { + return err + } var offset int var stringBuffer string var firstTokenTime time.Duration + var totalSamplingTime time.Duration + count := 0 for range args.n { logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...) if err != nil { @@ -122,7 +142,6 @@ func temp() error { for i, f32 := range f32s { f64s[i] = float64(f32) } - sampleTime := time.Now() samplers := []sample.Sampler{ pushdownSampler, // sample.Weighed(), @@ -131,12 +150,16 @@ func temp() error { sample.Greedy(), } + samplingStart := time.Now() f64s, err = sample.Sample(f64s, samplers...) if err != nil { return err } - finishTime := time.Now() - fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) + samplingTime := time.Since(samplingStart) + totalSamplingTime += samplingTime + + fmt.Println("sampling time", samplingTime) + // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) var outputIDs []int32 for _, f64 := range f64s { @@ -164,6 +187,7 @@ func temp() error { // fmt.Printf("--- token: %q\n", s) // fmt.Printf("--- outputIDs: %v\n", outputIDs) stringBuffer += s + count++ fmt.Println("--- stringBuffer", stringBuffer) err = pushdownSampler.UpdateState(outputIDs) @@ -179,7 +203,7 @@ func temp() error { fmt.Println("\n------ Output: ------") fmt.Println(stringBuffer) fmt.Println("--------------------") - + fmt.Println("sample average time", totalSamplingTime/time.Duration(count)) return nil } diff --git a/sample/fast_json.go b/sample/fast_json.go index bd80e8388..925de06ae 100644 --- a/sample/fast_json.go +++ b/sample/fast_json.go @@ -10,6 +10,8 @@ const ( StateStart JSONState = iota StateInObject StateInObjectKey + StateInStructuredKey + StateInStructuredValue StateNewline StateTab StateSpace @@ -43,6 +45,7 @@ var JSONStates = []JSONState{ StateStart, StateInObject, StateInObjectKey, + StateInStructuredKey, StateNewline, StateTab, StateSpace, @@ -80,6 +83,8 @@ func (s JSONState) String() string { return "StateInObject" case StateInObjectKey: return "StateInObjectKey" + case StateInStructuredKey: + return "StateInStructuredKey" case StateNewline: return "StateNewline" case StateTab: diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index 0d71878e4..d805288d9 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -21,14 +21,14 @@ var validNullRunes = []rune{'n', 'u', 'l', 'l'} type PDANode struct { State JSONState TransitionEdges map[rune]*PDANode - MaskTokenIDToNode map[int32]JSONState + MaskTokenIDToNode map[int32]*PDANode } func NewPDANode(state JSONState) *PDANode { return &PDANode{ State: state, TransitionEdges: make(map[rune]*PDANode), - MaskTokenIDToNode: make(map[int32]JSONState), + MaskTokenIDToNode: make(map[int32]*PDANode), } } @@ -103,6 +103,8 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err stateToNodeMap[StateInList].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateInList].TransitionEdges[' '] = stateToNodeMap[StateInList] stateToNodeMap[StateInList].TransitionEdges['\n'] = stateToNodeMap[StateInList] + // empty list + stateToNodeMap[StateInList].TransitionEdges[']'] = stateToNodeMap[StateInListEnd] addValueConnections(stateToNodeMap[StateInList], stateToNodeMap) // null node @@ -162,6 +164,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { // TODO: tough life fr. plz fix. func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { + // TODO; should come from top level vocab := proc.GetVocabulary() decodedToks := make([]string, len(vocab.Values)) @@ -175,7 +178,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex var err error for _, node := range stateToNodeMap { - err = createMask(node, proc, decodedToks, vocab) + err = CreateMask(node, proc, decodedToks, vocab) if err != nil { return err } @@ -183,7 +186,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex return nil } -func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error { +func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error { for i := range vocab.Values { token := decodedToks[i] // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON @@ -204,7 +207,8 @@ func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, v } } if valid { - node.MaskTokenIDToNode[int32(i)] = curNode.State + // cur node allows skipping + node.MaskTokenIDToNode[int32(i)] = curNode } } return nil diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index b0bf1a7e3..fe75eb864 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -4,7 +4,6 @@ import ( "fmt" "math" "runtime" - "time" "github.com/ollama/ollama/model" ) @@ -22,12 +21,15 @@ type PushdownSampler struct { // graph should be built once and reused per tokenizer func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { - start := time.Now() + // start := time.Now() + // fmt.Println("--------------------------------") + // fmt.Println("PDA sampler") + // fmt.Println("--------------------------------") var m runtime.MemStats runtime.ReadMemStats(&m) - before := m.Alloc - fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) + // before := m.Alloc + // fmt.Printf("Alloc = %.2f MB\n", float64(before)/(1024*1024)) startNode, stateToNodeMap, err := BuildGraph(proc) if err != nil { @@ -38,10 +40,10 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { panic(err) } runtime.ReadMemStats(&m) - after := m.Alloc - fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) - fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) - fmt.Printf("Graph build time = %v\n", time.Since(start)) + // after := m.Alloc + // fmt.Printf("Alloc = %.2f MB\n", float64(after)/(1024*1024)) + // fmt.Printf("Graph memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + // fmt.Printf("Graph build time = %v\n", time.Since(start)) return &PushdownSampler{ curNode: startNode, @@ -53,6 +55,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { } // TODO: need to add resampling logic if the first sample was not good +// greedy sample + backtrack? func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { // fmt.Println(">>> sample:", s.curNode.State) switch s.curNode.State { @@ -60,7 +63,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { return s.maskLogits(logits, s.curNode) case StateInListEnd: - fmt.Println("in list end", s.braceStack) + // fmt.Println("in list end", s.braceStack) // force finish if no braces left if len(s.braceStack) == 0 { s.curNode = NewPDANode(StateTerminate) @@ -139,12 +142,12 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { } func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { - fmt.Println("update state", s.curNode.State) + // fmt.Println("current state - updating", s.curNode.State) mappedString, err := s.proc.Decode(tokenSlice) if err != nil { return err } - fmt.Println("mappedString", mappedString) + // fmt.Println("mappedString", mappedString) // TODO: should force closing for all braces - not doing square yet for _, r := range mappedString { @@ -183,23 +186,25 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { for _, tokenID := range tokenSlice { // transition to the next node - nextNodeState, ok := s.curNode.MaskTokenIDToNode[tokenID] + nextNode, ok := s.curNode.MaskTokenIDToNode[tokenID] if !ok { return fmt.Errorf("invalid token: %q", mappedString) } // fmt.Println("transitioning to", nextNodeState) // TODO: add a penalty for staying in the same state too long - if nextNodeState == s.curNode.State { + if nextNode.State == s.curNode.State { s.stateCounter++ } else { s.stateCounter = 0 } - s.curNode = s.stateToNodeMap[nextNodeState] + s.curNode = nextNode + // fmt.Println("updated curNode state", s.curNode.State) } return nil } +// greedy sample + backtrack? func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { // TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode // Should be possible through bitwise ops as well diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index 8cb6a50d8..309ead4aa 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -1,54 +1,145 @@ package sample -import "github.com/ollama/ollama/model" +import ( + "fmt" + "runtime" + "time" -type StructuredOutput struct { - schema *Schema - stateToNodeMap map[JSONState]*PDANode + "github.com/ollama/ollama/model" +) + +type SOSampler struct { + schema *Schema + propIdx int + propStateMap map[string]*PDANode + pdaSampler *PushdownSampler } -func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput { - _, stateToNodeMap, err := BuildGraph(proc) - if err != nil { - panic(err) +func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) { + pdaSampler := NewPushdownSampler(proc) + + so := &SOSampler{ + schema: schema, + propIdx: -1, + propStateMap: make(map[string]*PDANode), + pdaSampler: pdaSampler, } - return &StructuredOutput{ - schema: schema, - stateToNodeMap: stateToNodeMap, + so.schemaToGraph() + + vocab := proc.GetVocabulary() + decodedToks := make([]string, len(vocab.Values)) + for i := range vocab.Values { + token, err := proc.Decode([]int32{int32(i)}) + if err != nil { + return nil, err + } + decodedToks[i] = token } -} -func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode { + fmt.Println("--------------------------------") + fmt.Println("SOSampler") + fmt.Println("--------------------------------") + // Benchmark this section + start := time.Now() + var m runtime.MemStats + runtime.ReadMemStats(&m) + before := m.Alloc - schemaType := so.schema.EffectiveType() - switch schemaType { - case "object": - // each prop is a key - // prevState := StateInObjectKey - for _, prop := range so.schema.Properties { - // name of key - name := prop.Name - prevState := StateInObjectKey - for i, r := range name { - newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune - - // Create new node for this state if it doesn't exist - if _, exists := so.stateToNodeMap[newState]; !exists { - so.stateToNodeMap[newState] = &PDANode{ - State: newState, - TransitionEdges: make(map[rune]*PDANode), - MaskTokenIDToNode: make(map[int32]JSONState), - } - } - - // Connect previous state to this state via the rune - so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState] - prevState = newState + // TODO: still messed up + for _, node := range so.propStateMap { + // propName -> node + curState := node.State + fromNode := node + CreateMask(fromNode, proc, decodedToks, vocab) + for curState == StateInStructuredKey { + // there is only one edge + for r, toNode := range fromNode.TransitionEdges { + // fmt.Println("rune", r, "edge", toNode.State) + CreateMask(toNode, proc, decodedToks, vocab) + fmt.Printf("created mask for %c\n", r) + curState = toNode.State + fmt.Println("next state", curState) + // TODO: theres an extra gen for " right now + fromNode = toNode } - // type of value - // propType := prop.Type } } - return nil + + runtime.ReadMemStats(&m) + after := m.Alloc + fmt.Printf("Mask creation memory usage = %.2f MB\n", float64(after-before)/(1024*1024)) + fmt.Printf("Mask creation time = %v\n", time.Since(start)) + fmt.Println("--------------------------------") + + return so, nil +} + +func (s *SOSampler) schemaToGraph() { + schemaType := s.schema.EffectiveType() + switch schemaType { + case "object": + // TODO: see if we need to connect these to the JSON graph + // prevState := StateInObjectKey + // prevNode := so.stateToNodeMap[prevState] + + // each prop is a key + for _, prop := range s.schema.Properties { + // name of key + name := prop.Name + // prevState := StateInObjectKey + keyNode := &PDANode{ + State: StateInStructuredKey, // this is unchanging, will impact sampling + TransitionEdges: make(map[rune]*PDANode), + MaskTokenIDToNode: make(map[int32]*PDANode), + } + + prevNode := keyNode + for _, r := range name { + runeNode := &PDANode{ + State: StateInStructuredKey, // this is unchanging, will impact sampling + TransitionEdges: make(map[rune]*PDANode), + MaskTokenIDToNode: make(map[int32]*PDANode), + } + fmt.Println("runeNode created", runeNode.State) + fmt.Printf("runeNode created %c\n", r) + // since alloc on heap connections wil still map + prevNode.TransitionEdges[r] = runeNode + prevNode = runeNode + } + // point to end of object key node after all chars are done + prevNode.TransitionEdges['"'] = s.pdaSampler.stateToNodeMap[StateInObjectKeyEnd] + // points to start of the key + s.propStateMap[name] = keyNode + fmt.Println("name", name, "keyNode", keyNode.State) + } + } +} + +func (s *SOSampler) Sample(logits []float64) ([]float64, error) { + switch s.pdaSampler.curNode.State { + // doesnt account for multi rune case + case StateInObjectKey: + // fmt.Println("in object key - structured outputs") + // TODO: this tracking should probably be coming from a stack to track nested objects + // simple case + s.propIdx++ + prop := s.schema.Properties[s.propIdx] + // fmt.Println("prop", prop.Name) + s.pdaSampler.curNode = s.propStateMap[prop.Name] + // fmt.Println("changed curNode state to", s.pdaSampler.curNode.State) + logits, err := s.pdaSampler.maskLogits(logits, s.pdaSampler.curNode) + if err != nil { + return nil, err + } + return logits, nil + + default: + return s.pdaSampler.Sample(logits) + } + +} + +func (s *SOSampler) UpdateState(tokenSlice []int32) error { + return s.pdaSampler.UpdateState(tokenSlice) }