diff --git a/go.mod b/go.mod index 78150c1ef..a4fbf3c32 100644 --- a/go.mod +++ b/go.mod @@ -24,8 +24,8 @@ require ( github.com/nlpodyssey/gopickle v0.3.0 github.com/pdevine/tensor v0.0.0-20240510204454-f88f4562727c golang.org/x/image v0.22.0 - gonum.org/v1/gonum v0.15.0 golang.org/x/tools v0.28.0 + gonum.org/v1/gonum v0.15.0 ) require ( @@ -72,7 +72,7 @@ require ( golang.org/x/arch v0.8.0 // indirect golang.org/x/crypto v0.31.0 golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa - golang.org/x/net v0.25.0 // indirect + golang.org/x/net v0.32.0 // indirect golang.org/x/sys v0.28.0 golang.org/x/term v0.27.0 golang.org/x/text v0.21.0 diff --git a/model/cmd/main.go b/model/cmd/main.go index c349e20d4..b7628e910 100644 --- a/model/cmd/main.go +++ b/model/cmd/main.go @@ -10,6 +10,7 @@ import ( "os" "path/filepath" "strings" + "time" "github.com/ollama/ollama/cache" "github.com/ollama/ollama/ml" @@ -27,6 +28,7 @@ var args struct { } func temp() error { + start := time.Now() flag.IntVar(&args.n, "n", 10, "number of samples") flag.BoolVar(&args.debug, "debug", false, "enable debug logging") flag.StringVar(&args.image, "image", "", "path to image file") @@ -104,9 +106,11 @@ func temp() error { } } - pdaSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) - var stringBuffer string + pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) + var offset int + var stringBuffer string + var firstTokenTime time.Duration for range args.n { logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...) if err != nil { @@ -118,15 +122,21 @@ func temp() error { for i, f32 := range f32s { f64s[i] = float64(f32) } + sampleTime := time.Now() + samplers := []sample.Sampler{ + pushdownSampler, + // sample.Weighed(), + // sample.TopP(0.9), + // sample.Weighed(), + sample.Greedy(), + } - // do sampling - // []ints back - // ints map to sampled logits - f64s, err = sample.Sample(f64s, pdaSampler, sample.Greedy()) - + f64s, err = sample.Sample(f64s, samplers...) if err != nil { return err } + finishTime := time.Now() + fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) var outputIDs []int32 for _, f64 := range f64s { @@ -134,7 +144,6 @@ func temp() error { outputIDs = append(outputIDs, int32(f64)) } } - pdaSampler.UpdateState(outputIDs) if len(outputIDs) == 0 { break @@ -147,14 +156,29 @@ func temp() error { return err } - // fmt.Print(s) + if firstTokenTime == 0 { + firstTokenTime = time.Since(start) + fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds()) + } + + // fmt.Printf("--- token: %q\n", s) + // fmt.Printf("--- outputIDs: %v\n", outputIDs) stringBuffer += s fmt.Println("--- stringBuffer", stringBuffer) + + err = pushdownSampler.UpdateState(outputIDs) + if err != nil { + return err + } + inputIDs = append(inputIDs, outputIDs...) if args.cache { offset = len(inputIDs) - 1 } } + fmt.Println("\n------ Output: ------") + fmt.Println(stringBuffer) + fmt.Println("--------------------") return nil } diff --git a/model/process_text.go b/model/process_text.go index f335f59c5..85f86cbdb 100644 --- a/model/process_text.go +++ b/model/process_text.go @@ -21,6 +21,7 @@ type TextProcessor interface { Encode(string) ([]int32, error) Decode([]int32) (string, error) Is(uint32, Special) bool + GetVocabulary() *Vocabulary } @@ -99,16 +100,16 @@ func (v *Vocabulary) Merge(left, right string) int { return -1 } +func (v *Vocabulary) GetVocabulary() *Vocabulary { + return v +} + type BytePairEncoding struct { Pretokenizer string *Vocabulary } -func (bpe BytePairEncoding) GetVocabulary() *Vocabulary { - return bpe.Vocabulary -} - func (bpe BytePairEncoding) split(s string) ([]string, error) { re, err := regexp2.Compile(bpe.Pretokenizer, regexp2.Unicode|regexp2.RE2) if err != nil { diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index 7bdecf38d..0d71878e4 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -44,8 +44,6 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err // consider adding a node to just point to values, could be good to compute that // mask rather than many different nodes - // Connect nodes - // TODO: if all are single tokens then this can just be connected instead of defining the token stateToNodeMap[StateStart].TransitionEdges['{'] = stateToNodeMap[StateInObject] stateToNodeMap[StateStart].TransitionEdges['['] = stateToNodeMap[StateInList] @@ -161,6 +159,7 @@ func addValueConnections(node *PDANode, stateToNodeMap map[JSONState]*PDANode) { node.TransitionEdges['n'] = stateToNodeMap[StateInNull] } +// TODO: tough life fr. plz fix. func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.TextProcessor) error { vocab := proc.GetVocabulary() @@ -176,33 +175,42 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex var err error for _, node := range stateToNodeMap { - for i := range vocab.Values { - token := decodedToks[i] - // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON - if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { - continue - } - valid := true - curNode := node - consumedSpecialRunes := make(map[rune]bool) - for _, r := range token { - valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes) - if err != nil { - return err - } - if !valid { - break - } - } - if valid { - node.MaskTokenIDToNode[int32(i)] = curNode.State - } + err = createMask(node, proc, decodedToks, vocab) + if err != nil { + return err } } return nil } -// garbage interface plz fix +func createMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error { + for i := range vocab.Values { + token := decodedToks[i] + // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON + if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { + continue + } + valid := true + curNode := node + consumedSpecialRunes := make(map[rune]bool) + var err error + for _, r := range token { + valid, curNode, err = isRuneValid(r, curNode, consumedSpecialRunes) + if err != nil { + return err + } + if !valid { + break + } + } + if valid { + node.MaskTokenIDToNode[int32(i)] = curNode.State + } + } + return nil +} + +// TODO: garbage interface plz fix func isRuneValid(r rune, curNode *PDANode, consumedSpecialRunes map[rune]bool) (bool, *PDANode, error) { if consumedSpecialRunes[r] { return false, nil, nil diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index e1d2cce65..b0bf1a7e3 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -52,6 +52,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { } } +// TODO: need to add resampling logic if the first sample was not good func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { // fmt.Println(">>> sample:", s.curNode.State) switch s.curNode.State { @@ -156,8 +157,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { // fmt.Println("pushing [ brace stack", r) } if r == rune('}') { + if len(s.braceStack) == 0 { + return fmt.Errorf("stack is empty, extra closing brace %c", r) + } top := s.braceStack[len(s.braceStack)-1] - if len(s.braceStack) == 0 || top != rune('{') { + if top != rune('{') { return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '{') } s.braceStack = s.braceStack[:len(s.braceStack)-1] @@ -165,8 +169,11 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { } if r == rune(']') { + if len(s.braceStack) == 0 { + return fmt.Errorf("stack is empty, extra closing brace %c", r) + } top := s.braceStack[len(s.braceStack)-1] - if len(s.braceStack) == 0 || top != rune('[') { + if top != rune('[') { return fmt.Errorf("unmatched closing brace, got%c, want%c", top, '[') } s.braceStack = s.braceStack[:len(s.braceStack)-1] @@ -194,6 +201,8 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { } func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { + // TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode + // Should be possible through bitwise ops as well for i := range logits { _, exists := node.MaskTokenIDToNode[int32(i)] if !exists { diff --git a/sample/sample.go b/sample/sample.go index fcd9dbbdd..714d994de 100644 --- a/sample/sample.go +++ b/sample/sample.go @@ -165,11 +165,12 @@ func (s weighed) Sample(logits []float64) ([]float64, error) { if len(logitsCopy) == 0 { return nil, errors.New("no valid tokens found") } - logitsCopy, err := computeSoftmax(logitsCopy) + + softmax, err := computeSoftmax(logitsCopy) if err != nil { return nil, err } - w := sampleuv.NewWeighted(logitsCopy, nil) + w := sampleuv.NewWeighted(softmax, nil) if v, ok := w.Take(); ok { // returns the token ID return []float64{float64(indices[v])}, nil diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index 6b3cb7132..8cb6a50d8 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -3,10 +3,52 @@ package sample import "github.com/ollama/ollama/model" type StructuredOutput struct { - schema *Schema + schema *Schema + stateToNodeMap map[JSONState]*PDANode } -func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *PDANode { +func BuildStructuredOutputGraph(schema *Schema, proc model.TextProcessor) *StructuredOutput { + _, stateToNodeMap, err := BuildGraph(proc) + if err != nil { + panic(err) + } + return &StructuredOutput{ + schema: schema, + stateToNodeMap: stateToNodeMap, + } +} + +func (so *StructuredOutput) schemaToGraph(proc model.TextProcessor) *PDANode { + + schemaType := so.schema.EffectiveType() + switch schemaType { + case "object": + // each prop is a key + // prevState := StateInObjectKey + for _, prop := range so.schema.Properties { + // name of key + name := prop.Name + prevState := StateInObjectKey + for i, r := range name { + newState := JSONState(int(StateInObjectKey) + i + 1) // Create new unique state for each rune + + // Create new node for this state if it doesn't exist + if _, exists := so.stateToNodeMap[newState]; !exists { + so.stateToNodeMap[newState] = &PDANode{ + State: newState, + TransitionEdges: make(map[rune]*PDANode), + MaskTokenIDToNode: make(map[int32]JSONState), + } + } + + // Connect previous state to this state via the rune + so.stateToNodeMap[prevState].TransitionEdges[r] = so.stateToNodeMap[newState] + prevState = newState + } + // type of value + // propType := prop.Type + } + } return nil }