From 25edfa6fdb7bfe6020891ba50cfc12f117100a14 Mon Sep 17 00:00:00 2001 From: ParthSareen Date: Tue, 4 Feb 2025 14:40:46 -0800 Subject: [PATCH] first pass so working --- model/cmd/main.go | 51 +++++++++++++++------------- sample/pushdown_automata.go | 10 ++++-- sample/pushdown_runner.go | 27 +++++++++------ sample/structured_outputs.go | 65 +++++++++++++++++++++++++++++++----- 4 files changed, 107 insertions(+), 46 deletions(-) diff --git a/model/cmd/main.go b/model/cmd/main.go index 756a5238a..0526ef032 100644 --- a/model/cmd/main.go +++ b/model/cmd/main.go @@ -106,8 +106,6 @@ func temp() error { } } - // pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) - // simple schema // This schema maps to JSON like: // { @@ -119,9 +117,12 @@ func temp() error { Properties: []*sample.Schema{ {Name: "name", Type: "string"}, {Name: "age", Type: "integer"}, + {Name: "is_student", Type: "boolean"}, + // {Name: "is_student", Type: "boolean"}, }, } + // pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor)) pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor)) if err != nil { return err @@ -129,44 +130,47 @@ func temp() error { var offset int var stringBuffer string - var firstTokenTime time.Duration + var ttft time.Duration var totalSamplingTime time.Duration count := 0 for range args.n { - logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...) + logits, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...) if err != nil { return err } - f32s := logit.Floats() - f64s := make([]float64, len(f32s)) - for i, f32 := range f32s { - f64s[i] = float64(f32) - } - samplers := []sample.Sampler{ + // f64s := make([]float64, len(f32s)) + // for i, f32 := range f32s { + // f64s[i] = float64(f32) + // } + // samplers := []sample.Transform{ + // pushdownSampler, + // sample.Weighed(), + // sample.TopP(0.9), + // sample.Weighed(), + // sample.Greedy(), + // } + transforms := []sample.Transform{ pushdownSampler, - // sample.Weighed(), - // sample.TopP(0.9), - // sample.Weighed(), - sample.Greedy(), } samplingStart := time.Now() - f64s, err = sample.Sample(f64s, samplers...) + sampler := sample.NewSampler(transforms, sample.Greedy()) + sampledIdx, err := sampler.Sample(logits.Floats()) if err != nil { return err } + samplingTime := time.Since(samplingStart) totalSamplingTime += samplingTime - // fmt.Println("sampling time", samplingTime) + fmt.Println("sampling time", samplingTime) // fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds()) var outputIDs []int32 - for _, f64 := range f64s { - if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) { - outputIDs = append(outputIDs, int32(f64)) - } + + if !m.(model.TextProcessor).Is(uint32(sampledIdx), model.SpecialEOS) { + outputIDs = append(outputIDs, int32(sampledIdx)) } if len(outputIDs) == 0 { @@ -180,9 +184,9 @@ func temp() error { return err } - if firstTokenTime == 0 { - firstTokenTime = time.Since(start) - fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds()) + if ttft == 0 { + ttft = time.Since(start) + fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds()) } // fmt.Printf("--- token: %q\n", s) @@ -196,6 +200,7 @@ func temp() error { return err } + // can do fun shifting stuff here if needed inputIDs = append(inputIDs, outputIDs...) if args.cache { offset = len(inputIDs) - 1 diff --git a/sample/pushdown_automata.go b/sample/pushdown_automata.go index d805288d9..19b237526 100644 --- a/sample/pushdown_automata.go +++ b/sample/pushdown_automata.go @@ -54,6 +54,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err //new line stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab] + stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey] @@ -76,6 +77,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList] stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject] addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap) + stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] // Values // string node @@ -97,6 +99,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] } addEnds(stateToNodeMap[StateInBool], stateToNodeMap) + stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace] // list node stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma] @@ -128,6 +131,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err for _, r := range validBoolRunes { stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool] } + stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline] addEnds(stateToNodeMap[StateInBool], stateToNodeMap) stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd] @@ -178,7 +182,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex var err error for _, node := range stateToNodeMap { - err = CreateMask(node, proc, decodedToks, vocab) + err = CreateMask(node, proc, decodedToks) if err != nil { return err } @@ -186,8 +190,8 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex return nil } -func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error { - for i := range vocab.Values { +func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error { + for i := range decodedToks { token := decodedToks[i] // Skip EOS/BOS tokens and empty tokens since they are not valid in JSON if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" { diff --git a/sample/pushdown_runner.go b/sample/pushdown_runner.go index ea568525b..97c58514b 100644 --- a/sample/pushdown_runner.go +++ b/sample/pushdown_runner.go @@ -57,7 +57,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler { // TODO: need to add resampling logic if the first sample was not good // greedy sample + backtrack? -func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { +func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) { switch s.curNode.State { case StateInString: return s.maskLogits(logits, s.curNode) @@ -70,7 +70,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { if s.proc.Is(uint32(i), model.SpecialEOS) { logits[i] = 1.0 } else { - logits[i] = math.NaN() + logits[i] = math.Inf(-1) } } return logits, nil @@ -90,7 +90,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { if s.proc.Is(uint32(i), model.SpecialEOS) { logits[i] = 1.0 } else { - logits[i] = math.NaN() + logits[i] = math.Inf(-1) } } return logits, nil @@ -123,7 +123,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) { if s.proc.Is(uint32(i), model.SpecialEOS) { logits[i] = 1.0 } else { - logits[i] = math.NaN() + logits[i] = math.Inf(-1) } } return logits, nil @@ -199,15 +199,20 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error { // greedy sample + backtrack? func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) { - // TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode - // Should be possible through bitwise ops as well - for i := range logits { - _, exists := node.MaskTokenIDToNode[int32(i)] - if !exists { - logits[i] = math.NaN() + // Create a new slice with same length as logits, initialized to -Inf + maskedLogits := make([]float64, len(logits)) + for i := range maskedLogits { + maskedLogits[i] = math.Inf(-1) + } + + // Only update values for valid token IDs from the mask map + for tokenID := range node.MaskTokenIDToNode { + if int(tokenID) < len(logits) { + maskedLogits[tokenID] = logits[tokenID] } } - return logits, nil + + return maskedLogits, nil } // TODO: add penalties for string \n stuff diff --git a/sample/structured_outputs.go b/sample/structured_outputs.go index 7e540ae9b..91c1d82de 100644 --- a/sample/structured_outputs.go +++ b/sample/structured_outputs.go @@ -13,6 +13,7 @@ type SOSampler struct { propIdx int propToNodeMap map[string]*PDANode pdaSampler *PushdownSampler + decodedToks []string } func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) { @@ -27,6 +28,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) so.schemaToGraph() + // This is prob slow vocab := proc.GetVocabulary() decodedToks := make([]string, len(vocab.Values)) for i := range vocab.Values { @@ -36,6 +38,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) } decodedToks[i] = token } + so.decodedToks = decodedToks fmt.Println("--------------------------------") fmt.Println("SOSampler") @@ -47,16 +50,19 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) before := m.Alloc // TODO: still messed up - for _, node := range so.propToNodeMap { + // TODO: recursion use case + // key masks + for _, prop := range so.schema.Properties { + node := so.propToNodeMap[prop.Name] // propName -> node curState := node.State fromNode := node - CreateMask(fromNode, proc, decodedToks, vocab) + CreateMask(fromNode, proc, decodedToks) for curState == StateInStructuredKey { // there is only one edge for r, toNode := range fromNode.TransitionEdges { // fmt.Println("rune", r, "edge", toNode.State) - CreateMask(toNode, proc, decodedToks, vocab) + CreateMask(toNode, proc, decodedToks) fmt.Printf("created mask for %c\n", r) curState = toNode.State fmt.Println("next state", curState) @@ -80,14 +86,11 @@ func (s *SOSampler) schemaToGraph() { switch schemaType { case "object": // TODO: see if we need to connect these to the JSON graph - // prevState := StateInObjectKey - // prevNode := so.stateToNodeMap[prevState] // each prop is a key for _, prop := range s.schema.Properties { // name of key name := prop.Name - // prevState := StateInObjectKey keyNode := &PDANode{ State: StateInStructuredKey, // this is unchanging, will impact sampling TransitionEdges: make(map[rune]*PDANode), @@ -116,10 +119,13 @@ func (s *SOSampler) schemaToGraph() { } } -func (s *SOSampler) Sample(logits []float64) ([]float64, error) { +func (s *SOSampler) Apply(logits []float64) ([]float64, error) { switch s.pdaSampler.curNode.State { // doesnt account for multi rune case case StateInObjectKey: + if s.propIdx > len(s.schema.Properties)-1 { + return nil, fmt.Errorf("propIdx out of bounds") + } // fmt.Println("in object key - structured outputs") // TODO: this tracking should probably be coming from a stack to track nested objects // simple case @@ -136,11 +142,52 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) { return logits, nil default: - return s.pdaSampler.Sample(logits) + + // Will only happen for the last prop - can also be precomputed. + if s.propIdx == len(s.schema.Properties)-1 { + // todo: if i incremenet propidx then i know im in last value as well + switch s.pdaSampler.curNode.State { + case StateInObjectEnd: + fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State) + s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode) + s.pdaSampler.curNode = NewPDANode(StateTerminate) + s.propIdx++ + + case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd: + fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State) + delete(s.pdaSampler.curNode.TransitionEdges, ',') + s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode) + + CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks) + s.propIdx++ + } + } + return s.pdaSampler.Apply(logits) } } func (s *SOSampler) UpdateState(tokenSlice []int32) error { - return s.pdaSampler.UpdateState(tokenSlice) + err := s.pdaSampler.UpdateState(tokenSlice) + if err != nil { + return err + } + + switch s.pdaSampler.curNode.State { + case StateInObjectKey: + s.propIdx++ + fmt.Println("propIdx", s.propIdx) + prop := s.schema.Properties[s.propIdx] + fmt.Println("prop", prop.Name) + s.pdaSampler.curNode = s.propToNodeMap[prop.Name] + str, err := s.pdaSampler.proc.Decode(tokenSlice) + if err != nil { + return err + } + fmt.Println("str", str) + + return nil + default: + return nil + } }