From 2c40c4d35eddc86673c5b0c116e2a34ef8ee2c4a Mon Sep 17 00:00:00 2001
From: Jesse Gross
Date: Sun, 9 Mar 2025 21:29:58 -0700
Subject: [PATCH] Fix follow up images and images split across batches

---
 model/models/gemma3/model.go      | 57 ++++++++++++----------
 model/models/gemma3/model_text.go | 63 +++++++++++++++++++++++--------
 2 files changed, 73 insertions(+), 47 deletions(-)

diff --git a/model/models/gemma3/model.go b/model/models/gemma3/model.go
index 0ea588740..7418bb12f 100644
--- a/model/models/gemma3/model.go
+++ b/model/models/gemma3/model.go
@@ -5,7 +5,6 @@ import (
 	"encoding/binary"
 	"hash/fnv"
 	"image"
-	"slices"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -99,49 +98,43 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	return visionOutputs, nil
 }
 
+type imageToken struct {
+	embedding ml.Tensor
+	index     int
+}
+
 func (m *Model) PostTokenize(ctx ml.Context, inputs []input.Input) ([]input.Input, error) {
-	var images []input.Input
+	var result []input.Input
 	fnvHash := fnv.New64a()
 
-	for i := range inputs {
-		if inputs[i].Multimodal == nil {
-			for j := range images {
-				if j == 0 {
-					inputs[i].Multimodal = images[j].Multimodal
-					inputs[i].MultimodalHash = images[j].MultimodalHash
-				} else {
-					inputs[i].Multimodal = inputs[i].Multimodal.(ml.Tensor).Concat(ctx, images[j].Multimodal.(ml.Tensor), 3)
-					fnvHash.Reset()
-					binary.Write(fnvHash, binary.NativeEndian, inputs[i].MultimodalHash)
-					binary.Write(fnvHash, binary.NativeEndian, images[j].MultimodalHash)
-					inputs[i].MultimodalHash = fnvHash.Sum64()
-				}
-			}
-
-			images = nil
+	for _, inp := range inputs {
+		if inp.Multimodal == nil {
+			result = append(result, inp)
 		} else {
-			images = append(images, inputs[i])
-			inputs[i].Token = -1
-		}
-	}
-
-	for i := range inputs {
-		if inputs[i].Token == -1 {
 			imageInputs := []input.Input{
 				{Token: 108},    // "\n\n"
 				{Token: 255999}, // "<start_of_image>"
 			}
+			result = append(result, imageInputs...)
+
+			// add image embeddings
+			inputMultimodal := inp.Multimodal.(ml.Tensor)
+
+			for i := range inputMultimodal.Dim(1) {
+				fnvHash.Reset()
+				binary.Write(fnvHash, binary.NativeEndian, inp.MultimodalHash)
+				fnvHash.Write([]byte{byte(i)})
+
+				imageToken := imageToken{embedding: inputMultimodal, index: i}
+				result = append(result, input.Input{Multimodal: imageToken, MultimodalHash: fnvHash.Sum64()})
+			}
 
-			// pad inputs with placeholders for image embeddings
-			imageInputs = append(imageInputs, slices.Repeat([]input.Input{{Token: 0}}, 256)...)
-			// <end_of_image>
-			imageInputs = append(imageInputs, input.Input{Token: 256000})
-
-			inputs = append(inputs[:i], append(imageInputs, inputs[i+1:]...)...)
+			// <end_of_image>
+			result = append(result, input.Input{Token: 256000})
 		}
 	}
 
-	return inputs, nil
+	return result, nil
 }
 
 func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
@@ -160,7 +153,7 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
 		return nil, err
 	}
 
-	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts.Multimodal, m.Cache), nil
+	return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
 }
 
 func init() {
diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go
index de8070d91..2180571eb 100644
--- a/model/models/gemma3/model_text.go
+++ b/model/models/gemma3/model_text.go
@@ -173,24 +173,53 @@ func (l *TextLayer) Forward(ctx ml.Context, layer int, hiddenState, positionIDs,
 	return hiddenState.Add(ctx, residual)
 }
 
-func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, multimodal []input.MultimodalIndex, cache kvcache.Cache) ml.Tensor {
+func setImageEmbeddings(ctx ml.Context, hiddenState ml.Tensor, multimodal []input.MultimodalIndex, positions []int32) []int32 {
+	var embedding ml.Tensor
+	var src, dst, length int
+	var except []int32
+
+	for _, image := range multimodal {
+		imageToken := image.Multimodal.(imageToken)
+		imageSrc := imageToken.index
+		imageDst := image.Index
+
+		if embedding == nil {
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		} else if embedding == imageToken.embedding && imageSrc+1 == src && imageDst+1 == dst {
+			src = imageSrc
+			dst = imageDst
+			length++
+		} else if embedding == imageToken.embedding && src+length == imageSrc && dst+length == imageDst {
+			length++
+		} else {
+			visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+			ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
+
+			embedding = imageToken.embedding
+			src = imageSrc
+			dst = imageDst
+			length = 1
+		}
+
+		except = append(except, positions[imageDst])
+	}
+
+	if embedding != nil {
+		visionOutputs := embedding.View(ctx, src*embedding.Stride(1), length*embedding.Dim(0))
+		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, dst*hiddenState.Stride(1), length*hiddenState.Dim(0))))
+	}
+
+	return except
+}
+
+func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, opts input.Options, cache kvcache.Cache) ml.Tensor {
 	hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
 	hiddenState = hiddenState.Scale(ctx, math.Sqrt(float64(m.TextOptions.hiddenSize)))
 
-	if multimodal != nil {
-		visionOutputs := multimodal[0].Multimodal.(ml.Tensor)
-		offset := multimodal[0].Index - 1 - visionOutputs.Dim(1)
-		hiddenState = hiddenState.Set(ctx, visionOutputs, offset*hiddenState.Stride(1))
-
-		if causal, ok := cache.(*kvcache.WrapperCache).UnderlyingCache().(*kvcache.Causal); ok {
-			except := make([]int32, visionOutputs.Dim(1))
-			for i := 0; i < visionOutputs.Dim(1); i++ {
-				except[i] = int32(offset + i)
-			}
-
-			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
-		}
-	}
+	except := setImageEmbeddings(ctx, hiddenState, opts.Multimodal, opts.Positions)
 
 	for i, layer := range m.Layers {
 		// gemma alternates between the sliding window (local) and causal (global)
@@ -203,6 +232,10 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 		wc := cache.(*kvcache.WrapperCache)
 		wc.SetLayerType(cacheType)
 
+		if causal, ok := wc.UnderlyingCache().(*kvcache.Causal); ok {
+			causal.SetCausal(ctx, kvcache.CausalOptions{Except: except})
+		}
+
 		var lastLayerOutputs ml.Tensor
 		if i == len(m.Layers)-1 {
 			lastLayerOutputs = outputs