mirror of
https://github.com/ollama/ollama.git
synced 2025-04-07 11:28:17 +02:00
first pass so working
This commit is contained in:
parent
cc2e44b885
commit
25edfa6fdb
@ -106,8 +106,6 @@ func temp() error {
|
||||
}
|
||||
}
|
||||
|
||||
// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
||||
|
||||
// simple schema
|
||||
// This schema maps to JSON like:
|
||||
// {
|
||||
@ -119,9 +117,12 @@ func temp() error {
|
||||
Properties: []*sample.Schema{
|
||||
{Name: "name", Type: "string"},
|
||||
{Name: "age", Type: "integer"},
|
||||
{Name: "is_student", Type: "boolean"},
|
||||
// {Name: "is_student", Type: "boolean"},
|
||||
},
|
||||
}
|
||||
|
||||
// pushdownSampler := sample.NewPushdownSampler(m.(model.TextProcessor))
|
||||
pushdownSampler, err := sample.NewSOSampler(schema, m.(model.TextProcessor))
|
||||
if err != nil {
|
||||
return err
|
||||
@ -129,44 +130,47 @@ func temp() error {
|
||||
|
||||
var offset int
|
||||
var stringBuffer string
|
||||
var firstTokenTime time.Duration
|
||||
var ttft time.Duration
|
||||
var totalSamplingTime time.Duration
|
||||
count := 0
|
||||
for range args.n {
|
||||
logit, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
||||
logits, err := model.Forward(m, append(opts, model.WithInputIDs(inputIDs), model.WithOffset(offset))...)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
f32s := logit.Floats()
|
||||
f64s := make([]float64, len(f32s))
|
||||
for i, f32 := range f32s {
|
||||
f64s[i] = float64(f32)
|
||||
}
|
||||
samplers := []sample.Sampler{
|
||||
// f64s := make([]float64, len(f32s))
|
||||
// for i, f32 := range f32s {
|
||||
// f64s[i] = float64(f32)
|
||||
// }
|
||||
// samplers := []sample.Transform{
|
||||
// pushdownSampler,
|
||||
// sample.Weighed(),
|
||||
// sample.TopP(0.9),
|
||||
// sample.Weighed(),
|
||||
// sample.Greedy(),
|
||||
// }
|
||||
transforms := []sample.Transform{
|
||||
pushdownSampler,
|
||||
// sample.Weighed(),
|
||||
// sample.TopP(0.9),
|
||||
// sample.Weighed(),
|
||||
sample.Greedy(),
|
||||
}
|
||||
|
||||
samplingStart := time.Now()
|
||||
f64s, err = sample.Sample(f64s, samplers...)
|
||||
sampler := sample.NewSampler(transforms, sample.Greedy())
|
||||
sampledIdx, err := sampler.Sample(logits.Floats())
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
samplingTime := time.Since(samplingStart)
|
||||
totalSamplingTime += samplingTime
|
||||
|
||||
// fmt.Println("sampling time", samplingTime)
|
||||
fmt.Println("sampling time", samplingTime)
|
||||
// fmt.Printf("Sample time: %vms\n", finishTime.Sub(sampleTime).Milliseconds())
|
||||
|
||||
var outputIDs []int32
|
||||
for _, f64 := range f64s {
|
||||
if !m.(model.TextProcessor).Is(uint32(f64), model.SpecialEOS) {
|
||||
outputIDs = append(outputIDs, int32(f64))
|
||||
}
|
||||
|
||||
if !m.(model.TextProcessor).Is(uint32(sampledIdx), model.SpecialEOS) {
|
||||
outputIDs = append(outputIDs, int32(sampledIdx))
|
||||
}
|
||||
|
||||
if len(outputIDs) == 0 {
|
||||
@ -180,9 +184,9 @@ func temp() error {
|
||||
return err
|
||||
}
|
||||
|
||||
if firstTokenTime == 0 {
|
||||
firstTokenTime = time.Since(start)
|
||||
fmt.Printf("Time to first token: %vms\n", firstTokenTime.Milliseconds())
|
||||
if ttft == 0 {
|
||||
ttft = time.Since(start)
|
||||
fmt.Printf("Time to first token: %vms\n", ttft.Milliseconds())
|
||||
}
|
||||
|
||||
// fmt.Printf("--- token: %q\n", s)
|
||||
@ -196,6 +200,7 @@ func temp() error {
|
||||
return err
|
||||
}
|
||||
|
||||
// can do fun shifting stuff here if needed
|
||||
inputIDs = append(inputIDs, outputIDs...)
|
||||
if args.cache {
|
||||
offset = len(inputIDs) - 1
|
||||
|
@ -54,6 +54,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
//new line
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['\t'] = stateToNodeMap[StateInTab]
|
||||
stateToNodeMap[StateInNewline].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
stateToNodeMap[StateInTab].TransitionEdges['"'] = stateToNodeMap[StateInObjectKey]
|
||||
|
||||
@ -76,6 +77,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
stateToNodeMap[StateInSpace].TransitionEdges['['] = stateToNodeMap[StateInList]
|
||||
stateToNodeMap[StateInSpace].TransitionEdges['{'] = stateToNodeMap[StateInObject]
|
||||
addValueConnections(stateToNodeMap[StateInSpace], stateToNodeMap)
|
||||
stateToNodeMap[StateInSpace].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
|
||||
// Values
|
||||
// string node
|
||||
@ -97,6 +99,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||
}
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
stateToNodeMap[StateInBool].TransitionEdges[' '] = stateToNodeMap[StateInSpace]
|
||||
|
||||
// list node
|
||||
stateToNodeMap[StateInList].TransitionEdges[','] = stateToNodeMap[StateInComma]
|
||||
@ -128,6 +131,7 @@ func BuildGraph(proc model.TextProcessor) (*PDANode, map[JSONState]*PDANode, err
|
||||
for _, r := range validBoolRunes {
|
||||
stateToNodeMap[StateInBool].TransitionEdges[r] = stateToNodeMap[StateInBool]
|
||||
}
|
||||
stateToNodeMap[StateInBool].TransitionEdges['\n'] = stateToNodeMap[StateInNewline]
|
||||
addEnds(stateToNodeMap[StateInBool], stateToNodeMap)
|
||||
|
||||
stateToNodeMap[StateInListEnd].TransitionEdges['}'] = stateToNodeMap[StateInObjectEnd]
|
||||
@ -178,7 +182,7 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
|
||||
var err error
|
||||
for _, node := range stateToNodeMap {
|
||||
err = CreateMask(node, proc, decodedToks, vocab)
|
||||
err = CreateMask(node, proc, decodedToks)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@ -186,8 +190,8 @@ func PreComputeValidStates(stateToNodeMap map[JSONState]*PDANode, proc model.Tex
|
||||
return nil
|
||||
}
|
||||
|
||||
func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string, vocab *model.Vocabulary) error {
|
||||
for i := range vocab.Values {
|
||||
func CreateMask(node *PDANode, proc model.TextProcessor, decodedToks []string) error {
|
||||
for i := range decodedToks {
|
||||
token := decodedToks[i]
|
||||
// Skip EOS/BOS tokens and empty tokens since they are not valid in JSON
|
||||
if proc.Is(uint32(i), model.SpecialEOS) || proc.Is(uint32(i), model.SpecialBOS) || token == "" || token == "\"\"" {
|
||||
|
@ -57,7 +57,7 @@ func NewPushdownSampler(proc model.TextProcessor) *PushdownSampler {
|
||||
|
||||
// TODO: need to add resampling logic if the first sample was not good
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
func (s *PushdownSampler) Apply(logits []float64) ([]float64, error) {
|
||||
switch s.curNode.State {
|
||||
case StateInString:
|
||||
return s.maskLogits(logits, s.curNode)
|
||||
@ -70,7 +70,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
@ -90,7 +90,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
@ -123,7 +123,7 @@ func (s *PushdownSampler) Sample(logits []float64) ([]float64, error) {
|
||||
if s.proc.Is(uint32(i), model.SpecialEOS) {
|
||||
logits[i] = 1.0
|
||||
} else {
|
||||
logits[i] = math.NaN()
|
||||
logits[i] = math.Inf(-1)
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
@ -199,15 +199,20 @@ func (s *PushdownSampler) UpdateState(tokenSlice []int32) error {
|
||||
|
||||
// greedy sample + backtrack?
|
||||
func (s *PushdownSampler) maskLogits(logits []float64, node *PDANode) ([]float64, error) {
|
||||
// TODO: can be optimized by only masking the logits that are not in the node.MaskTokenIDToNode
|
||||
// Should be possible through bitwise ops as well
|
||||
for i := range logits {
|
||||
_, exists := node.MaskTokenIDToNode[int32(i)]
|
||||
if !exists {
|
||||
logits[i] = math.NaN()
|
||||
// Create a new slice with same length as logits, initialized to -Inf
|
||||
maskedLogits := make([]float64, len(logits))
|
||||
for i := range maskedLogits {
|
||||
maskedLogits[i] = math.Inf(-1)
|
||||
}
|
||||
|
||||
// Only update values for valid token IDs from the mask map
|
||||
for tokenID := range node.MaskTokenIDToNode {
|
||||
if int(tokenID) < len(logits) {
|
||||
maskedLogits[tokenID] = logits[tokenID]
|
||||
}
|
||||
}
|
||||
return logits, nil
|
||||
|
||||
return maskedLogits, nil
|
||||
}
|
||||
|
||||
// TODO: add penalties for string \n stuff
|
||||
|
@ -13,6 +13,7 @@ type SOSampler struct {
|
||||
propIdx int
|
||||
propToNodeMap map[string]*PDANode
|
||||
pdaSampler *PushdownSampler
|
||||
decodedToks []string
|
||||
}
|
||||
|
||||
func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error) {
|
||||
@ -27,6 +28,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
|
||||
so.schemaToGraph()
|
||||
|
||||
// This is prob slow
|
||||
vocab := proc.GetVocabulary()
|
||||
decodedToks := make([]string, len(vocab.Values))
|
||||
for i := range vocab.Values {
|
||||
@ -36,6 +38,7 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
}
|
||||
decodedToks[i] = token
|
||||
}
|
||||
so.decodedToks = decodedToks
|
||||
|
||||
fmt.Println("--------------------------------")
|
||||
fmt.Println("SOSampler")
|
||||
@ -47,16 +50,19 @@ func NewSOSampler(schema *Schema, proc model.TextProcessor) (*SOSampler, error)
|
||||
before := m.Alloc
|
||||
|
||||
// TODO: still messed up
|
||||
for _, node := range so.propToNodeMap {
|
||||
// TODO: recursion use case
|
||||
// key masks
|
||||
for _, prop := range so.schema.Properties {
|
||||
node := so.propToNodeMap[prop.Name]
|
||||
// propName -> node
|
||||
curState := node.State
|
||||
fromNode := node
|
||||
CreateMask(fromNode, proc, decodedToks, vocab)
|
||||
CreateMask(fromNode, proc, decodedToks)
|
||||
for curState == StateInStructuredKey {
|
||||
// there is only one edge
|
||||
for r, toNode := range fromNode.TransitionEdges {
|
||||
// fmt.Println("rune", r, "edge", toNode.State)
|
||||
CreateMask(toNode, proc, decodedToks, vocab)
|
||||
CreateMask(toNode, proc, decodedToks)
|
||||
fmt.Printf("created mask for %c\n", r)
|
||||
curState = toNode.State
|
||||
fmt.Println("next state", curState)
|
||||
@ -80,14 +86,11 @@ func (s *SOSampler) schemaToGraph() {
|
||||
switch schemaType {
|
||||
case "object":
|
||||
// TODO: see if we need to connect these to the JSON graph
|
||||
// prevState := StateInObjectKey
|
||||
// prevNode := so.stateToNodeMap[prevState]
|
||||
|
||||
// each prop is a key
|
||||
for _, prop := range s.schema.Properties {
|
||||
// name of key
|
||||
name := prop.Name
|
||||
// prevState := StateInObjectKey
|
||||
keyNode := &PDANode{
|
||||
State: StateInStructuredKey, // this is unchanging, will impact sampling
|
||||
TransitionEdges: make(map[rune]*PDANode),
|
||||
@ -116,10 +119,13 @@ func (s *SOSampler) schemaToGraph() {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
||||
func (s *SOSampler) Apply(logits []float64) ([]float64, error) {
|
||||
switch s.pdaSampler.curNode.State {
|
||||
// doesnt account for multi rune case
|
||||
case StateInObjectKey:
|
||||
if s.propIdx > len(s.schema.Properties)-1 {
|
||||
return nil, fmt.Errorf("propIdx out of bounds")
|
||||
}
|
||||
// fmt.Println("in object key - structured outputs")
|
||||
// TODO: this tracking should probably be coming from a stack to track nested objects
|
||||
// simple case
|
||||
@ -136,11 +142,52 @@ func (s *SOSampler) Sample(logits []float64) ([]float64, error) {
|
||||
return logits, nil
|
||||
|
||||
default:
|
||||
return s.pdaSampler.Sample(logits)
|
||||
|
||||
// Will only happen for the last prop - can also be precomputed.
|
||||
if s.propIdx == len(s.schema.Properties)-1 {
|
||||
// todo: if i incremenet propidx then i know im in last value as well
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectEnd:
|
||||
fmt.Println("<<<<< in obj end- generating mask for", s.pdaSampler.curNode.State)
|
||||
s.pdaSampler.curNode.TransitionEdges = make(map[rune]*PDANode)
|
||||
s.pdaSampler.curNode = NewPDANode(StateTerminate)
|
||||
s.propIdx++
|
||||
|
||||
case StateInNumber, StateInString, StateInBool, StateInNull, StateInListEnd:
|
||||
fmt.Println("<<<<< last prop - generating mask for", s.pdaSampler.curNode.State)
|
||||
delete(s.pdaSampler.curNode.TransitionEdges, ',')
|
||||
s.pdaSampler.curNode.MaskTokenIDToNode = make(map[int32]*PDANode)
|
||||
|
||||
CreateMask(s.pdaSampler.curNode, s.pdaSampler.proc, s.decodedToks)
|
||||
s.propIdx++
|
||||
}
|
||||
}
|
||||
return s.pdaSampler.Apply(logits)
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
func (s *SOSampler) UpdateState(tokenSlice []int32) error {
|
||||
return s.pdaSampler.UpdateState(tokenSlice)
|
||||
err := s.pdaSampler.UpdateState(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch s.pdaSampler.curNode.State {
|
||||
case StateInObjectKey:
|
||||
s.propIdx++
|
||||
fmt.Println("propIdx", s.propIdx)
|
||||
prop := s.schema.Properties[s.propIdx]
|
||||
fmt.Println("prop", prop.Name)
|
||||
s.pdaSampler.curNode = s.propToNodeMap[prop.Name]
|
||||
str, err := s.pdaSampler.proc.Decode(tokenSlice)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
fmt.Println("str", str)
|
||||
|
||||
return nil
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
Loading…
x
Reference in New Issue
Block a user