From a1cda80bcb0b47d493be9dc061a2dfa8a0ddd61c Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Sat, 8 Mar 2025 15:45:31 -0800 Subject: [PATCH] model: Update encoder cache to use multimodal input processing handler The encoder cache needs to know the position of images in the input stream so that it knows when to delete them. Previously images didn't have a position, so we implied one by breaking batches before an image and then assuming the image was in the first position. However, multimodal objects are now given explicit positions in the input stream, so we can use that instead. Breaking batches was also a way to simulate a cross attention mask for mllama. However, given that it only supports a single sequence and a single image, this mask doesn't serve any real purpose. Removing the batch break does not appear to affect the quality of the output. Most of this is simply moving the input data structures to a new package to avoid import cycles. --- kvcache/cache.go | 3 +- kvcache/causal.go | 13 ++--- kvcache/causal_test.go | 3 +- kvcache/encoder.go | 9 ++-- kvcache/wrapper.go | 9 ++-- model/input/input.go | 37 ++++++++++++++ model/model.go | 83 +++++++++---------------------- model/model_test.go | 3 +- model/models/llama/model.go | 3 +- model/models/mllama/model.go | 13 ++--- runner/ollamarunner/cache.go | 13 ++--- runner/ollamarunner/cache_test.go | 72 +++++++++++++-------------- runner/ollamarunner/runner.go | 56 ++++++++------------- 13 files changed, 157 insertions(+), 160 deletions(-) create mode 100644 model/input/input.go diff --git a/kvcache/cache.go b/kvcache/cache.go index 2541f7c16..d35489057 100644 --- a/kvcache/cache.go +++ b/kvcache/cache.go @@ -4,6 +4,7 @@ import ( "errors" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) var ( @@ -51,7 +52,7 @@ type Cache interface { // StartForward is called before the start of the model's forward pass. // For each token in the coming batch, there must be a corresponding // entry in positions and seqs. - StartForward(ctx ml.Context, positions []int32, seqs []int) error + StartForward(ctx ml.Context, opts input.Options) error // CopyPrefix copies tokens in the range [0, len) from srcSeq to dstSeq CopyPrefix(srcSeq, dstSeq int, len int32) diff --git a/kvcache/causal.go b/kvcache/causal.go index 9a79fa577..34d5337cf 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -8,6 +8,7 @@ import ( "slices" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) type shiftFn func(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) @@ -140,10 +141,10 @@ func (c *Causal) Close() { } } -func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) error { - c.curBatchSize = len(positions) - c.curSequences = seqs - c.curPositions = positions +func (c *Causal) StartForward(ctx ml.Context, opts input.Options) error { + c.curBatchSize = len(opts.Positions) + c.curSequences = opts.Sequences + c.curPositions = opts.Positions var err error c.curLoc, err = c.findStartLoc() @@ -156,8 +157,8 @@ func (c *Causal) StartForward(ctx ml.Context, positions []int32, seqs []int) err } c.curCellRange = newRange() - for i, pos := range positions { - seq := seqs[i] + for i, pos := range opts.Positions { + seq := opts.Sequences[i] c.cells[c.curLoc+i] = cacheCell{pos: pos, sequences: []int{seq}} diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 412f33e34..22d8efb43 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -6,6 +6,7 @@ import ( "testing" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) type testCase struct { @@ -269,7 +270,7 @@ func testCache(t *testing.T, backend ml.Backend, cache Cache, tests []testCase) context := backend.NewContext() defer context.Close() - err := cache.StartForward(context, test.pos, test.seqs) + err := cache.StartForward(context, input.Options{Positions: test.pos, Sequences: test.seqs}) if err != nil { panic(err) } diff --git a/kvcache/encoder.go b/kvcache/encoder.go index 867ee37a5..6a9df2abc 100644 --- a/kvcache/encoder.go +++ b/kvcache/encoder.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) // Encoder cache stores K and V tensors that are position independent @@ -78,9 +79,11 @@ func (c *EncoderCache) Close() { } } -func (c *EncoderCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { - // The image is always in the first position - c.curPos = positions[0] +func (c *EncoderCache) StartForward(ctx ml.Context, opts input.Options) error { + // We work with the most recent image + if len(opts.Multimodal) > 0 { + c.curPos = opts.Positions[opts.Multimodal[len(opts.Multimodal)-1].Index] + } return nil } diff --git a/kvcache/wrapper.go b/kvcache/wrapper.go index 76956a88a..aaccd1661 100644 --- a/kvcache/wrapper.go +++ b/kvcache/wrapper.go @@ -4,6 +4,7 @@ import ( "math" "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/input" ) // Wrapper cache is a container for multiple types of caches, @@ -40,14 +41,14 @@ func (c *WrapperCache) Close() { } } -func (c *WrapperCache) StartForward(ctx ml.Context, positions []int32, seqs []int) error { +func (c *WrapperCache) StartForward(ctx ml.Context, opts input.Options) error { for i, cache := range c.caches { - err := cache.StartForward(ctx, positions, seqs) + err := cache.StartForward(ctx, opts) if err != nil { // unwind on error - Remove with endIndex set to math.MaxInt32 does not fail for j := i - 1; j >= 0; j-- { - for k := range positions { - _ = c.caches[j].Remove(seqs[k], positions[k], math.MaxInt32) + for k := range opts.Positions { + _ = c.caches[j].Remove(opts.Sequences[k], opts.Positions[k], math.MaxInt32) } } return err diff --git a/model/input/input.go b/model/input/input.go new file mode 100644 index 000000000..0cb3f3f41 --- /dev/null +++ b/model/input/input.go @@ -0,0 +1,37 @@ +package input + +// Input represents one token in the input stream +type Input struct { + // Token is a single element of text. + Token int32 + + // Multimodal is opaque data representing a non-text + // element such as an image (or part of one if the image + // can be processed in pieces). It may be either together + // with Token or on its own. + Multimodal any + + // MultimodalHash is a unique representation of the data + // stored in Multimodal, used for caching and comparing + // equality. + MultimodalHash uint64 +} + +// MultimodalIndex is a multimodal element (such as an image) +// together with an index into the slice of Inputs with the +// corresponding token. Note that the index is not the same +// as the position - to find that use the index with the +// Positions slice. +type MultimodalIndex struct { + Index int + Multimodal any +} + +// Options contains the inputs for a model forward pass +type Options struct { + Inputs []int32 + Multimodal []MultimodalIndex + Positions []int32 + Sequences []int + Outputs []int32 +} diff --git a/model/model.go b/model/model.go index 75b7f6397..89b6c803b 100644 --- a/model/model.go +++ b/model/model.go @@ -19,66 +19,12 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" _ "github.com/ollama/ollama/ml/backend" + "github.com/ollama/ollama/model/input" ) -// Input represents one token in the input stream -type Input struct { - // Token is a single element of text. - Token int32 - - // Multimodal is opaque data representing a non-text - // element such as an image (or part of one if the image - // can be processed in pieces). It may be either together - // with Token or on its own. - Multimodal any - - // MultimodalHash is a unique representation of the data - // stored in Multimodal, used for caching and comparing - // equality. - MultimodalHash uint64 -} - -// MultimodalIndex is a multimodal element (such as an image) -// together with an index into the slice of Inputs with the -// corresponding token. Note that the index is not the same -// as the position - to find that use the index with the -// Positions slice. -type MultimodalIndex struct { - Index int - Multimodal any -} - -// Options contains the inputs for a model forward pass -type Options struct { - Inputs []int32 - Multimodal []MultimodalIndex - Positions []int32 - Sequences []int - Outputs []int32 -} - -type config struct { - Cache kvcache.Cache -} - -// Base implements the common fields and methods for all models -type Base struct { - b ml.Backend - config -} - -// Backend returns the underlying backend that will run the model -func (m *Base) Backend() ml.Backend { - return m.b -} - -func (m *Base) Config() config { - return m.config -} - // Model implements a specific model architecture, defining the forward pass and any model-specific configuration type Model interface { - Forward(ml.Context, Options) (ml.Tensor, error) + Forward(ml.Context, input.Options) (ml.Tensor, error) Backend() ml.Backend Config() config @@ -112,7 +58,26 @@ type MultimodalProcessor interface { // This function is also responsible for updating MultimodalHash for any Multimodal // that is modified to ensure that there is a unique hash value that accurately // represents the contents. - PostTokenize(ml.Context, []Input) ([]Input, error) + PostTokenize(ml.Context, []input.Input) ([]input.Input, error) +} + +// Base implements the common fields and methods for all models +type Base struct { + b ml.Backend + config +} + +type config struct { + Cache kvcache.Cache +} + +// Backend returns the underlying backend that will run the model +func (m *Base) Backend() ml.Backend { + return m.b +} + +func (m *Base) Config() config { + return m.config } var models = make(map[string]func(ml.Config) (Model, error)) @@ -313,7 +278,7 @@ func canNil(t reflect.Type) bool { t.Kind() == reflect.Slice } -func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { +func Forward(ctx ml.Context, m Model, opts input.Options) (ml.Tensor, error) { if len(opts.Positions) != len(opts.Sequences) { return nil, fmt.Errorf("length of positions (%v) must match length of seqs (%v)", len(opts.Positions), len(opts.Sequences)) } @@ -324,7 +289,7 @@ func Forward(ctx ml.Context, m Model, opts Options) (ml.Tensor, error) { cache := m.Config().Cache if cache != nil { - err := cache.StartForward(ctx, opts.Positions, opts.Sequences) + err := cache.StartForward(ctx, opts) if err != nil { return nil, err } diff --git a/model/model_test.go b/model/model_test.go index 8761817e0..354dd1d8b 100644 --- a/model/model_test.go +++ b/model/model_test.go @@ -11,6 +11,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/backend/ggml" "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model/input" ) func TestParseTags(t *testing.T) { @@ -162,7 +163,7 @@ func TestGetTextProcessor(t *testing.T) { type notTextProcessorModel struct{} -func (notTextProcessorModel) Forward(ml.Context, Options) (ml.Tensor, error) { +func (notTextProcessorModel) Forward(ml.Context, input.Options) (ml.Tensor, error) { panic("unimplemented") } diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 9ccfff612..1f27f522d 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -9,6 +9,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type Options struct { @@ -137,7 +138,7 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten return hiddenState.Add(ctx, residual) } -func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) if err != nil { return nil, err diff --git a/model/models/mllama/model.go b/model/models/mllama/model.go index 54c632960..31ba15dfd 100644 --- a/model/models/mllama/model.go +++ b/model/models/mllama/model.go @@ -12,6 +12,7 @@ import ( "github.com/ollama/ollama/ml" "github.com/ollama/ollama/ml/nn" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type Model struct { @@ -101,8 +102,8 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er return m.Projector.Forward(ctx, crossAttentionStates), nil } -func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Input, error) { - var images []model.Input +func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) { + var images []input.Input fnvHash := fnv.New64a() for i := range inputs { @@ -125,15 +126,15 @@ func (m *Model) PostTokenize(ctx ml.Context, inputs []model.Input) ([]model.Inpu } } - inputs = slices.DeleteFunc(inputs, func(input model.Input) bool { return input.Token == -1 }) + inputs = slices.DeleteFunc(inputs, func(input input.Input) bool { return input.Token == -1 }) return inputs, nil } -func (m *Model) Forward(ctx ml.Context, opts model.Options) (ml.Tensor, error) { +func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) { var crossAttentionStates ml.Tensor - if opts.Multimodal != nil { - crossAttentionStates = opts.Multimodal[0].Multimodal.(ml.Tensor) + if len(opts.Multimodal) > 0 { + crossAttentionStates = opts.Multimodal[len(opts.Multimodal)-1].Multimodal.(ml.Tensor) } inputs, err := ctx.Input().FromIntSlice(opts.Inputs, len(opts.Inputs)) diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index 3244c0b89..a411fddb1 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -10,6 +10,7 @@ import ( "github.com/ollama/ollama/kvcache" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) type InputCache struct { @@ -79,7 +80,7 @@ type InputCacheSlot struct { Id int // Inputs that are stored in the KV cache - Inputs []model.Input + Inputs []input.Input // is this cache actively being processed as part of a sequence? InUse bool @@ -88,7 +89,7 @@ type InputCacheSlot struct { lastUsed time.Time } -func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*InputCacheSlot, []model.Input, error) { +func (c *InputCache) LoadCacheSlot(prompt []input.Input, cachePrompt bool) (*InputCacheSlot, []input.Input, error) { var slot *InputCacheSlot var numPast int32 var err error @@ -139,7 +140,7 @@ func (c *InputCache) LoadCacheSlot(prompt []model.Input, cachePrompt bool) (*Inp return slot, prompt, nil } -func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findLongestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { longest := int32(-1) var longestSlot *InputCacheSlot @@ -162,7 +163,7 @@ func (c *InputCache) findLongestCacheSlot(prompt []model.Input) (*InputCacheSlot return longestSlot, longest, nil } -func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, int32, error) { +func (c *InputCache) findBestCacheSlot(prompt []input.Input) (*InputCacheSlot, int32, error) { oldest := time.Now() var oldestSlot *InputCacheSlot @@ -198,7 +199,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i if longest > 0 && longestSlot != oldestSlot { slog.Debug("forking cache slot", "src", longestSlot.Id, "dst", oldestSlot.Id, "inputs", longest, "total", len(longestSlot.Inputs)) - oldestSlot.Inputs = make([]model.Input, longest) + oldestSlot.Inputs = make([]input.Input, longest) copy(oldestSlot.Inputs, longestSlot.Inputs[:longest]) if c.cache != nil { c.cache.CopyPrefix(longestSlot.Id, oldestSlot.Id, longest) @@ -208,7 +209,7 @@ func (c *InputCache) findBestCacheSlot(prompt []model.Input) (*InputCacheSlot, i return oldestSlot, longest, nil } -func countCommonPrefix(a []model.Input, b []model.Input) int32 { +func countCommonPrefix(a []input.Input, b []input.Input) int32 { var count int32 for i := range a { diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index 9ce03b73f..0a1b73f5a 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -5,7 +5,7 @@ import ( "testing" "time" - "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" ) func TestCountCommon(t *testing.T) { @@ -15,50 +15,50 @@ func TestCountCommon(t *testing.T) { tests := []struct { name string - t1 []model.Input - t2 []model.Input + t1 []input.Input + t2 []input.Input expected int32 }{ { name: "Equal", - t1: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, - t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 3, }, { name: "Prefix", - t1: []model.Input{{Token: 1}}, - t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []input.Input{{Token: 1}}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 1, }, { name: "Image Prefix", - t1: []model.Input{{Multimodal: imgA, MultimodalHash: 1}}, - t2: []model.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, + t1: []input.Input{{Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Multimodal: imgA, MultimodalHash: 1}, {Multimodal: imgB, MultimodalHash: 2}, {Multimodal: imgC, MultimodalHash: 3}}, expected: 1, }, { name: "Mixed", - t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, - t2: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, + t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}, {Token: 5}}, expected: 2, }, { name: "Mixed, Same Length", - t1: []model.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, - t2: []model.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}}, + t1: []input.Input{{Token: 1}, {Multimodal: imgA, MultimodalHash: 1}}, + t2: []input.Input{{Token: 1}, {Multimodal: imgB, MultimodalHash: 2}}, expected: 1, }, { name: "Empty", - t1: []model.Input{}, - t2: []model.Input{{Token: 1}, {Token: 2}, {Token: 3}}, + t1: []input.Input{}, + t2: []input.Input{{Token: 1}, {Token: 2}, {Token: 3}}, expected: 0, }, { name: "Both Empty", - t1: []model.Input{}, - t2: []model.Input{}, + t1: []input.Input{}, + t2: []input.Input{}, expected: 0, }, } @@ -82,7 +82,7 @@ func TestFindCacheSlot(t *testing.T) { tests := []struct { name string cache InputCache - prompt []model.Input + prompt []input.Input longest expected best expected }{ @@ -91,18 +91,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, { Id: 1, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []model.Input{{Token: 1}}, + prompt: []input.Input{{Token: 1}}, longest: expected{result: 0, len: 0}, best: expected{result: 0, len: 0}, }, @@ -111,18 +111,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []model.Input{{Token: 1}, {Token: 2}}, + prompt: []input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 2}, best: expected{result: 1, len: 2}, }, @@ -131,18 +131,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }}, - prompt: []model.Input{{Token: 2}}, + prompt: []input.Input{{Token: 2}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -152,19 +152,19 @@ func TestFindCacheSlot(t *testing.T) { slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{}, + Inputs: []input.Input{}, InUse: false, lastUsed: time.Time{}, }, }, }, - prompt: []model.Input{{Token: 1}}, + prompt: []input.Input{{Token: 1}}, longest: expected{result: 0, len: 1}, best: expected{result: 1, len: 1}, }, @@ -173,18 +173,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []model.Input{{Token: 2}, {Token: 3}}, + prompt: []input.Input{{Token: 2}, {Token: 3}}, longest: expected{result: 0, len: 0}, best: expected{result: 1, len: 0}, }, @@ -193,18 +193,18 @@ func TestFindCacheSlot(t *testing.T) { cache: InputCache{slots: []InputCacheSlot{ { Id: 0, - Inputs: []model.Input{{Token: 1}, {Token: 2}}, + Inputs: []input.Input{{Token: 1}, {Token: 2}}, InUse: true, lastUsed: time.Now().Add(-time.Second), }, { Id: 1, - Inputs: []model.Input{{Token: 1}}, + Inputs: []input.Input{{Token: 1}}, InUse: false, lastUsed: time.Now().Add(-2 * time.Second), }, }}, - prompt: []model.Input{{Token: 1}, {Token: 2}}, + prompt: []input.Input{{Token: 1}, {Token: 2}}, longest: expected{result: 1, len: 1}, best: expected{result: 1, len: 2}, }, diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index a51b1459e..c8383a5dd 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -26,6 +26,7 @@ import ( "github.com/ollama/ollama/api" "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" + "github.com/ollama/ollama/model/input" "github.com/ollama/ollama/runner/common" "github.com/ollama/ollama/sample" @@ -41,10 +42,10 @@ type Sequence struct { iBatch int // prompt inputs left to evaluate - inputs []model.Input + inputs []input.Input // inputs that have been added to a batch but not yet submitted to Forward - pendingInputs []model.Input + pendingInputs []input.Input // tokens that have been generated but not returned yet (e.g. for stop sequences) pendingResponses []string @@ -144,8 +145,8 @@ func (s *Server) NewSequence(prompt string, images []ImageData, params NewSequen // inputs processes the prompt and images into a list of inputs // by splitting the prompt on [img-] tags, tokenizing text and // decoding images -func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]model.Input, error) { - var inputs []model.Input +func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]input.Input, error) { + var inputs []input.Input var parts []string var matches [][]string @@ -168,7 +169,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo } for _, t := range tokens { - inputs = append(inputs, model.Input{Token: t}) + inputs = append(inputs, input.Input{Token: t}) } // image - decode and store @@ -196,7 +197,7 @@ func (s *Server) inputs(ctx ml.Context, prompt string, images []ImageData) ([]mo _, _ = s.multimodalHash.Write(images[imageIndex].Data) imageHash := s.multimodalHash.Sum64() - inputs = append(inputs, model.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) + inputs = append(inputs, input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) postTokenize = true } } @@ -250,9 +251,6 @@ type Server struct { // KV cache cache *InputCache - // next sequence for prompt processing to avoid starvation - nextSeq int - // multimodalHash generates hashes for comparing equality // of non-text data multimodalHash maphash.Hash @@ -329,29 +327,25 @@ func (s *Server) processBatch() error { } defer s.mu.Unlock() - var options model.Options - - seqIdx := s.nextSeq - 1 - for range s.seqs { - seqIdx = (seqIdx + 1) % len(s.seqs) - seq := s.seqs[seqIdx] + var options input.Options + for i, seq := range s.seqs { if seq == nil { continue } // if past the num predict limit if seq.numPredict > 0 && seq.numPredicted >= seq.numPredict { - s.removeSequence(seqIdx, "limit") + s.removeSequence(i, "limit") continue } if !s.cache.enabled { seq.inputs = append(seq.cache.Inputs, seq.inputs...) - seq.cache.Inputs = []model.Input{} + seq.cache.Inputs = []input.Input{} } - for i, input := range seq.inputs { + for j, inp := range seq.inputs { if int32(len(seq.cache.Inputs)+len(seq.pendingInputs)+1) > s.cache.numCtx { if len(seq.pendingInputs) == 0 { err := s.cache.ShiftCacheSlot(seq.cache, seq.numKeep) @@ -363,33 +357,23 @@ func (s *Server) processBatch() error { } } - if i >= s.batchSize { + if j >= s.batchSize { break } - // TODO(jessegross): This is a workaround for generating an attention mask and also providing a hint - // to the encoder cache. - // - // Break the batch when switching from text to images so that images are always at the beginning. - if input.Multimodal != nil && !(len(seq.pendingInputs) == 0 || - (len(options.Multimodal) > 0 && options.Multimodal[len(options.Multimodal)-1].Index == len(options.Inputs)-1)) { - s.nextSeq = seqIdx - break - } - - options.Inputs = append(options.Inputs, input.Token) - if input.Multimodal != nil { - options.Multimodal = append(options.Multimodal, model.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: input.Multimodal}) + options.Inputs = append(options.Inputs, inp.Token) + if inp.Multimodal != nil { + options.Multimodal = append(options.Multimodal, input.MultimodalIndex{Index: len(options.Inputs) - 1, Multimodal: inp.Multimodal}) } options.Positions = append(options.Positions, int32(len(seq.cache.Inputs)+len(seq.pendingInputs))) options.Sequences = append(options.Sequences, seq.cache.Id) seq.iBatch = len(options.Outputs) - if i+1 == len(seq.inputs) { + if j+1 == len(seq.inputs) { options.Outputs = append(options.Outputs, int32(len(options.Inputs)-1)) } - seq.pendingInputs = append(seq.pendingInputs, input) + seq.pendingInputs = append(seq.pendingInputs, inp) } seq.inputs = seq.inputs[len(seq.pendingInputs):] @@ -417,7 +401,7 @@ func (s *Server) processBatch() error { // After calling Forward, pending inputs are now in the cache if len(seq.pendingInputs) > 0 { seq.cache.Inputs = append(seq.cache.Inputs, seq.pendingInputs...) - seq.pendingInputs = []model.Input{} + seq.pendingInputs = []input.Input{} } // don't sample prompt processing @@ -464,7 +448,7 @@ func (s *Server) processBatch() error { return err } - seq.inputs = []model.Input{{Token: token}} + seq.inputs = []input.Input{{Token: token}} seq.pendingResponses = append(seq.pendingResponses, piece) sequence := strings.Join(seq.pendingResponses, "")