diff --git a/kvcache/cache.go b/kvcache/cache.go
index 2541f7c16..d35489057 100644
--- a/kvcache/cache.go
+++ b/kvcache/cache.go
@@ -4,6 +4,7 @@ import (
 	"errors"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 var (
@@ -51,7 +52,7 @@ type Cache interface {
 	// StartForward is called before the start of the model's forward pass.
 	// For each token in the coming batch, there must be a corresponding
 	// entry in positions and seqs.
-	StartForward(ctx ml.Context, positions []int32, seqs []int) error
+	StartForward(ctx ml.Context, opts input.Options) error
 
 	// CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq
 	CopyPrefix(srcSeq, dstSeq int, len int32)
diff --git a/kvcache/causal.go b/kvcache/causal.go
index 9a79fa577..34d5337cf 100644
--- a/kvcache/causal.go
+++ b/kvcache/causal.go
@@ -8,6 +8,7 @@ import (
 	"slices"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error)
@@ -140,10 +141,10 @@ func (c *Causal) Close() {
 	}
 }
 
-func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-	c.curBatchSize = len(positions)
-	c.curSequences = seqs
-	c.curPositions = positions
+func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error {
+	c.curBatchSize = len(opts.Positions)
+	c.curSequences = opts.Sequences
+	c.curPositions = opts.Positions
 
 	var err error
 	c.curLoc, err = c.findStartLoc()
@@ -156,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err
 	}
 
 	c.curCellRange = newRange()
-	for i, pos := range positions {
-		seq := seqs[i]
+	for i, pos := range opts.Positions {
+		seq := opts.Sequences[i]
 
 		c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}}
 
diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go
index 412f33e34..22d8efb43 100644
--- a/kvcache/causal_test.go
+++ b/kvcache/causal_test.go
@@ -6,6 +6,7 @@ import (
 	"testing"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 type testCase struct {
@@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase)
 			context := backend.NewContext()
 			defer context.Close()
 
-			err := cache.StartForward(context, test.pos, test.seqs)
+			err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs})
 			if err != nil {
 				panic(err)
 			}
diff --git a/kvcache/encoder.go b/kvcache/encoder.go
index 867ee37a5..6a9df2abc 100644
--- a/kvcache/encoder.go
+++ b/kvcache/encoder.go
@@ -4,6 +4,7 @@ import (
 	"fmt"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 // Encoder cache stores K and V tensors that are position independent
@@ -78,9 +79,11 @@ func (c *EncoderCache) Close() {
 	}
 }
 
-func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
-	// The image is always in the first position
-	c.curPos = positions[0]
+func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error {
+	// We work with the most recent image
+	if len(opts.Multimodal) > 0 {
+		c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index]
+	}
 
 	return nil
 }
diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go
index 76956a88a..aaccd1661 100644
--- a/kvcache/wrapper.go
+++ b/kvcache/wrapper.go
@@ -4,6 +4,7 @@ import (
 	"math"
 
 	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model/input"
 )
 
 // Wrapper cache is a container for multiple types of caches,
@@ -40,14 +41,14 @@ func (c *WrapperCache) Close() {
 	}
 }
 
-func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error {
+func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error {
 	for i, cache := range c.caches {
-		err := cache.StartForward(ctx, positions, seqs)
+		err := cache.StartForward(ctx, opts)
 		if err != nil {
 			// unwind on error - Remove with endIndex set to math.MaxInt32 does not fail
 			for j := i - 1; j >= 0; j-- {
-				for k := range positions {
-					_ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32)
+				for k := range opts.Positions {
+					_ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32)
 				}
 			}
 
 			return err
diff --git a/model/input/input.go b/model/input/input.go
new file mode 100644
index 000000000..0cb3f3f41
--- /dev/null
+++ b/model/input/input.go
@@ -0,0 +1,37 @@
+package input
+
+// Input represents one token in the input stream
+type Input struct {
+	// Token is a single element of text.
+	Token int32
+
+	// Multimodal is opaque data representing a non-text
+	// element such as an image (or part of one if the image
+	// can be processed in pieces). It may be either together
+	// with Token or on its own.
+	Multimodal any
+
+	// MultimodalHash is a unique representation of the data
+	// stored in Multimodal, used for caching and comparing
+	// equality.
+	MultimodalHash uint64
+}
+
+// MultimodalIndex is a multimodal element (such as an image)
+// together with an index into the slice of Inputs with the
+// corresponding token. Note that the index is not the same
+// as the position - to find that use the index with the
+// Positions slice.
+type MultimodalIndex struct {
+	Index      int
+	Multimodal any
+}
+
+// Options contains the inputs for a model forward pass
+type Options struct {
+	Inputs     []int32
+	Multimodal []MultimodalIndex
+	Positions  []int32
+	Sequences  []int
+	Outputs    []int32
+}
diff --git a/model/model.go b/model/model.go
index 75b7f6397..89b6c803b 100644
--- a/model/model.go
+++ b/model/model.go
@@ -19,66 +19,12 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	_ "github.com/ollama/ollama/ml/backend"
+	"github.com/ollama/ollama/model/input"
 )
 
-// Input represents one token in the input stream
-type Input struct {
-	// Token is a single element of text.
-	Token int32
-
-	// Multimodal is opaque data representing a non-text
-	// element such as an image (or part of one if the image
-	// can be processed in pieces). It may be either together
-	// with Token or on its own.
-	Multimodal any
-
-	// MultimodalHash is a unique representation of the data
-	// stored in Multimodal, used for caching and comparing
-	// equality.
-	MultimodalHash uint64
-}
-
-// MultimodalIndex is a multimodal element (such as an image)
-// together with an index into the slice of Inputs with the
-// corresponding token. Note that the index is not the same
-// as the position - to find that use the index with the
-// Positions slice.
-type MultimodalIndex struct {
-	Index      int
-	Multimodal any
-}
-
-// Options contains the inputs for a model forward pass
-type Options struct {
-	Inputs     []int32
-	Multimodal []MultimodalIndex
-	Positions  []int32
-	Sequences  []int
-	Outputs    []int32
-}
-
-type config struct {
-	Cache kvcache.Cache
-}
-
-// Base implements the common fields and methods for all models
-type Base struct {
-	b ml.Backend
-
-	config
-}
-
-// Backend returns the underlying backend that will run the model
-func (m *Base) Backend() ml.Backend {
-	return m.b
-}
-
-func (m *Base) Config() config {
-	return m.config
-}
-
 // Model implements a specific model architecture, defining the forward pass and any model-specific configuration
 type Model interface {
-	Forward(ml.Context, Options) (ml.Tensor, error)
+	Forward(ml.Context, input.Options) (ml.Tensor, error)
 
 	Backend() ml.Backend
 	Config() config
@@ -112,7 +58,26 @@ type MultimodalProcessor interface {
 	// This function is also responsible for updating MultimodalHash for any Multimodal
 	// that is modified to ensure that there is a unique hash value that accurately
 	// represents the contents.
-	PostTokenize(ml.Context, []Input) ([]Input, error)
+	PostTokenize(ml.Context, []input.Input) ([]input.Input, error)
+}
+
+// Base implements the common fields and methods for all models
+type Base struct {
+	b ml.Backend
+
+	config
+}
+
+type config struct {
+	Cache kvcache.Cache
+}
+
+// Backend returns the underlying backend that will run the model
+func (m *Base) Backend() ml.Backend {
+	return m.b
+}
+
+func (m *Base) Config() config {
+	return m.config
 }
 
 var models = make(map[string]func(ml.Config) (Model, error))
@@ -313,7 +278,7 @@ func canNil(t reflect.Type) bool {
 		t.Kind() == reflect.Slice
 }
 
-func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
+func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) {
 	if len(opts.Positions) != len(opts.Sequences) {
 		return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences))
 	}
@@ -324,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) {
 
 	cache := m.Config().Cache
 	if cache != nil {
-		err := cache.StartForward(ctx, opts.Positions, opts.Sequences)
+		err := cache.StartForward(ctx, opts)
 		if err != nil {
 			return nil, err
 		}
diff --git a/model/model_test.go b/model/model_test.go
index 8761817e0..354dd1d8b 100644
--- a/model/model_test.go
+++ b/model/model_test.go
@@ -11,6 +11,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/backend/ggml"
 	"github.com/ollama/ollama/ml/nn"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestParseTags(t *testing.T) {
@@ -162,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) {
 
 type notTextProcessorModel struct{}
 
-func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) {
+func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) {
 	panic("unimplemented")
 }
diff --git a/model/models/llama/model.go b/model/models/llama/model.go
index 9ccfff612..1f27f522d 100644
--- a/model/models/llama/model.go
+++ b/model/models/llama/model.go
@@ -9,6 +9,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type Options struct {
@@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
 	if err != nil {
 		return nil, err
diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go
index 54c632960..31ba15dfd 100644
--- a/model/models/mllama/model.go
+++ b/model/models/mllama/model.go
@@ -12,6 +12,7 @@ import (
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/ml/nn"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type Model struct {
@@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return m.Projector.Forward(ctx, crossAttentionStates), nil
 }
 
-func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) {
-	var images []model.Input
+func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
+	var images []input.Input
 	fnvHash := fnv.New64a()
 
 	for i := range inputs {
@@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Inpu
 		}
 	}
 
-	inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 })
+	inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 })
 
 	return inputs, nil
 }
 
-func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) {
+func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 	var crossAttentionStates ml.Tensor
-	if opts.Multimodal != nil {
-		crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor)
+	if len(opts.Multimodal) > 0 {
+		crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor)
 	}
 
 	inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs))
diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go
index 3244c0b89..a411fddb1 100644
--- a/runner/ollamarunner/cache.go
+++ b/runner/ollamarunner/cache.go
@@ -10,6 +10,7 @@ import (
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 type InputCache struct {
@@ -79,7 +80,7 @@ type InputCacheSlot struct {
 	Id int
 
 	// Inputs that are stored in the KV cache
-	Inputs []model.Input
+	Inputs []input.Input
 
 	// is this cache actively being processed as part of a sequence?
 	InUse bool
@@ -88,7 +89,7 @@ type InputCacheSlot struct {
 	lastUsed time.Time
 }
 
-func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) {
+func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) {
 	var slot *InputCacheSlot
 	var numPast int32
 	var err error
@@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp
 	return slot, prompt, nil
 }
 
-func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	longest := int32(-1)
 	var longestSlot *InputCacheSlot
 
@@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot
 	return longestSlot, longest, nil
 }
 
-func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) {
+func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) {
 	oldest := time.Now()
 	var oldestSlot *InputCacheSlot
 
@@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
 	if longest > 0 && longestSlot != oldestSlot {
 		slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest,
 			"total", len(longestSlot.Inputs))
-		oldestSlot.Inputs = make([]model.Input, longest)
+		oldestSlot.Inputs = make([]input.Input, longest)
 		copy(oldestSlot.Inputs, longestSlot.Inputs[:longest])
 		if c.cache != nil {
 			c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest)
@@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i
 	return oldestSlot, longest, nil
 }
 
-func countCommonPrefix(a []model.Input, b []model.Input) int32 {
+func countCommonPrefix(a []input.Input, b []input.Input) int32 {
 	var count int32
 
 	for i := range a {
diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go
index 9ce03b73f..0a1b73f5a 100644
--- a/runner/ollamarunner/cache_test.go
+++ b/runner/ollamarunner/cache_test.go
@@ -5,7 +5,7 @@ import (
 	"testing"
 	"time"
 
-	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 )
 
 func TestCountCommon(t *testing.T) {
@@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) {
 	tests := []struct {
 		name     string
-		t1       []model.Input
-		t2       []model.Input
+		t1       []input.Input
+		t2       []input.Input
 		expected int32
 	}{
 		{
 			name:     "Equal",
-			t1:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 3,
 		},
 		{
 			name:     "Prefix",
-			t1:       []model.Input{{Token: 1}},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{{Token: 1}},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Image Prefix",
-			t1:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
+			t1:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}},
 			expected: 1,
 		},
 		{
 			name:     "Mixed",
-			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
+			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}},
 			expected: 2,
 		},
 		{
 			name:     "Mixed, Same Length",
-			t1:       []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
-			t2:       []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
+			t1:       []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}},
+			t2:       []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}},
 			expected: 1,
 		},
 		{
 			name:     "Empty",
-			t1:       []model.Input{},
-			t2:       []model.Input{{Token: 1}, {Token: 2}, {Token: 3}},
+			t1:       []input.Input{},
+			t2:       []input.Input{{Token: 1}, {Token: 2}, {Token: 3}},
 			expected: 0,
 		},
 		{
 			name:     "Both Empty",
-			t1:       []model.Input{},
-			t2:       []model.Input{},
+			t1:       []input.Input{},
+			t2:       []input.Input{},
 			expected: 0,
 		},
 	}
 
@@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) {
 	tests := []struct {
 		name    string
 		cache   InputCache
-		prompt  []model.Input
+		prompt  []input.Input
 		longest expected
 		best    expected
 	}{
@@ -91,18 +91,18 @@
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 0, len: 0},
 		},
@@ -111,18 +111,18 @@
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 2},
 			best:    expected{result: 1, len: 2},
 		},
@@ -131,18 +131,18 @@
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{},
+					Inputs:   []input.Input{},
 					InUse:    false,
 					lastUsed: time.Time{},
 				},
 			}},
-			prompt:  []model.Input{{Token: 2}},
+			prompt:  []input.Input{{Token: 2}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -152,19 +152,19 @@
 				slots: []InputCacheSlot{
 					{
 						Id:       0,
-						Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+						Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 						InUse:    false,
 						lastUsed: time.Now().Add(-time.Second),
 					},
 					{
 						Id:       1,
-						Inputs:   []model.Input{},
+						Inputs:   []input.Input{},
 						InUse:    false,
 						lastUsed: time.Time{},
 					},
 				},
 			},
-			prompt:  []model.Input{{Token: 1}},
+			prompt:  []input.Input{{Token: 1}},
 			longest: expected{result: 0, len: 1},
 			best:    expected{result: 1, len: 1},
 		},
@@ -173,18 +173,18 @@
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 2}, {Token: 3}},
+			prompt:  []input.Input{{Token: 2}, {Token: 3}},
 			longest: expected{result: 0, len: 0},
 			best:    expected{result: 1, len: 0},
 		},
@@ -193,18 +193,18 @@
 			cache: InputCache{slots: []InputCacheSlot{
 				{
 					Id:       0,
-					Inputs:   []model.Input{{Token: 1}, {Token: 2}},
+					Inputs:   []input.Input{{Token: 1}, {Token: 2}},
 					InUse:    true,
 					lastUsed: time.Now().Add(-time.Second),
 				},
 				{
 					Id:       1,
-					Inputs:   []model.Input{{Token: 1}},
+					Inputs:   []input.Input{{Token: 1}},
 					InUse:    false,
 					lastUsed: time.Now().Add(-2 * time.Second),
 				},
 			}},
-			prompt:  []model.Input{{Token: 1}, {Token: 2}},
+			prompt:  []input.Input{{Token: 1}, {Token: 2}},
 			longest: expected{result: 1, len: 1},
 			best:    expected{result: 1, len: 2},
 		},
diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go
index a51b1459e..c8383a5dd 100644
--- a/runner/ollamarunner/runner.go
+++ b/runner/ollamarunner/runner.go
@@ -26,6 +26,7 @@ import (
 	"github.com/ollama/ollama/api"
 	"github.com/ollama/ollama/ml"
 	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
 	"github.com/ollama/ollama/runner/common"
 	"github.com/ollama/ollama/sample"
@@ -41,10 +42,10 @@ type Sequence struct {
 	iBatch int
 
 	// prompt inputs left to evaluate
-	inputs []model.Input
+	inputs []input.Input
 
 	// inputs that have been added to a batch but not yet submitted to Forward
-	pendingInputs []model.Input
+	pendingInputs []input.Input
 
 	// tokens that have been generated but not returned yet (e.g. for stop sequences)
 	pendingResponses []string
@@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen
 // inputs processes the prompt and images into a list of inputs
 // by splitting the prompt on [img-] tags, tokenizing text and
 // decoding images
-func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) {
-	var inputs []model.Input
+func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) {
+	var inputs []input.Input
 	var parts []string
 	var matches [][]string
 
@@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
 			}
 
 			for _, t := range tokens {
-				inputs = append(inputs, model.Input{Token: t})
+				inputs = append(inputs, input.Input{Token: t})
 			}
 
 			// image - decode and store
@@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo
 			_, _ = s.multimodalHash.Write(images[imageIndex].Data)
 			imageHash := s.multimodalHash.Sum64()
 
-			inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
+			inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash})
 			postTokenize = true
 		}
 	}
@@ -250,9 +251,6 @@ type Server struct {
 	// KV cache
 	cache *InputCache
 
-	// next sequence for prompt processing to avoid starvation
-	nextSeq int
-
 	// multimodalHash generates hashes for comparing equality
 	// of non-text data
 	multimodalHash maphash.Hash
@@ -329,29 +327,25 @@ func (s *Server) processBatch() error {
 	}
 	defer s.mu.Unlock()
 
-	var options model.Options
-
-	seqIdx := s.nextSeq - 1
-	for range s.seqs {
-		seqIdx = (seqIdx + 1) % len(s.seqs)
-		seq := s.seqs[seqIdx]
+	var options input.Options
+	for i, seq := range s.seqs {
 		if seq == nil {
 			continue
 		}
 
 		// if past the num predict limit
 		if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict {
-			s.removeSequence(seqIdx, "limit")
+			s.removeSequence(i, "limit")
 			continue
 		}
 
 		if !s.cache.enabled {
 			seq.inputs = append(seq.cache.Inputs, seq.inputs...)
-			seq.cache.Inputs = []model.Input{}
+			seq.cache.Inputs = []input.Input{}
 		}
 
-		for i, input := range seq.inputs {
+		for j, inp := range seq.inputs {
 			if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx {
 				if len(seq.pendingInputs) == 0 {
 					err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep)
@@ -363,33 +357,23 @@
 				}
 			}
 
-			if i >= s.batchSize {
+			if j >= s.batchSize {
 				break
 			}
 
-			// TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint
-			// to the encoder cache.
-			//
-			// Break the batch when switching from text to images so that images are always at the beginning.
-			if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 ||
-				(len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) {
-				s.nextSeq = seqIdx
-				break
-			}
-
-			options.Inputs = append(options.Inputs, input.Token)
-			if input.Multimodal != nil {
-				options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal})
+			options.Inputs = append(options.Inputs, inp.Token)
+			if inp.Multimodal != nil {
+				options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal})
 			}
 
 			options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs)))
 			options.Sequences = append(options.Sequences, seq.cache.Id)
 
 			seq.iBatch = len(options.Outputs)
-			if i+1 == len(seq.inputs) {
+			if j+1 == len(seq.inputs) {
 				options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1))
 			}
-			seq.pendingInputs = append(seq.pendingInputs, input)
+			seq.pendingInputs = append(seq.pendingInputs, inp)
 		}
 
 		seq.inputs = seq.inputs[len(seq.pendingInputs):]
@@ -417,7 +401,7 @@ func (s *Server) processBatch() error {
 		// After calling Forward, pending inputs are now in the cache
 		if len(seq.pendingInputs) > 0 {
 			seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...)
-			seq.pendingInputs = []model.Input{}
+			seq.pendingInputs = []input.Input{}
 		}
 
 		// don't sample prompt processing
@@ -464,7 +448,7 @@ func (s *Server) processBatch() error {
 			return err
 		}
 
-		seq.inputs = []model.Input{{Token: token}}
+		seq.inputs = []input.Input{{Token: token}}
 
 		seq.pendingResponses = append(seq.pendingResponses, piece)
 		sequence := strings.Join(seq.pendingResponses, "")
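
For reference, a minimal sketch of how a caller assembles the new input.Options and runs a forward pass after this change. The runBatch helper, the example package name, and the literal token and position values are hypothetical; only the input.Options fields and the model.Forward signature come from the patch above.

// Illustrative only, not part of the patch.
package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model"
	"github.com/ollama/ollama/model/input"
)

// runBatch builds a two-token batch for a single sequence and runs it
// through the model.
func runBatch(ctx ml.Context, m model.Model) (ml.Tensor, error) {
	opts := input.Options{
		Inputs:    []int32{15339, 1917}, // token IDs (hypothetical values)
		Positions: []int32{0, 1},        // position of each token within its sequence
		Sequences: []int{0, 0},          // both tokens belong to sequence 0
		Outputs:   []int32{1},           // only the last batch entry needs output logits
	}

	// model.Forward checks that Positions and Sequences line up and calls
	// cache.StartForward(ctx, opts) before invoking the model's own Forward.
	return model.Forward(ctx, m, opts)
}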
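
A companion sketch of the consumer side: what a cache implementation sees through the new StartForward(ctx, opts) contract. The exampleCache type is hypothetical; translating MultimodalIndex.Index into a position via opts.Positions mirrors what the patched EncoderCache does.

// Illustrative only, not part of the patch.
package example

import (
	"github.com/ollama/ollama/ml"
	"github.com/ollama/ollama/model/input"
)

// exampleCache is a hypothetical cache used only to show the new contract.
type exampleCache struct {
	batchSize int
	imagePos  int32
}

func (c *exampleCache) StartForward(ctx ml.Context, opts input.Options) error {
	// Positions and Sequences still carry one entry per token in the batch,
	// just as the old (positions, seqs) arguments did.
	c.batchSize = len(opts.Positions)

	// MultimodalIndex.Index is an index into the batch, not a position;
	// looking it up in opts.Positions yields the position of the most
	// recent multimodal element, as the patched EncoderCache now does.
	if len(opts.Multimodal) > 0 {
		last := opts.Multimodal[len(opts.Multimodal)-1]
		c.imagePos = opts.Positions[last.Index]
	}

	return nil
}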