compute image embeddings once

This commit is contained in:
Michael Yang 2025-04-01 16:29:35 -07:00
parent 2ab14468a8
commit 87cf2fa1b8
2 changed files with 42 additions and 7 deletions

View File

@ -4,6 +4,7 @@ import (
"bytes"
"image"
"slices"
"sync"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
@ -107,14 +108,37 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
// split into patches to be sent to the text transformer
rows := make([]ml.Tensor, size.Y)
parent := imageFeatures{tensor: features}
rows := make([]*imageRow, size.Y)
for i := range rows {
rows[i] = features.View(ctx, features.Stride(1)*i*size.X, features.Dim(0), features.Stride(1), size.X)
rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
}
return rows, nil
}
type imageFeatures struct {
tensor ml.Tensor
dataOnce sync.Once
data []float32
}
type imageRow struct {
parent *imageFeatures
s int
shape []int
}
func (r *imageRow) data() []float32 {
n := 1
for _, s := range r.shape {
n *= s
}
return r.parent.data[r.s*n : (r.s+1)*n]
}
// PostTokenize arranges Mistral 3's inputs for the forward pass
// In Mistral 3 and Pixtral, the input patches are arranged as follows:
// [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@ -126,11 +150,11 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.([]ml.Tensor)
inputMultimodal := inp.Multimodal.([]*imageRow)
for i, row := range inputMultimodal {
// [IMG]
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.Dim(1) + 1})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Dim(1)-1)...)
result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1] + 1})
result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
if i == len(inputMultimodal)-1 {
// [IMG_END]
result = append(result, input.Input{Token: 13})

View File

@ -110,8 +110,19 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
// image embeddings
for _, image := range batch.Multimodal {
visionOutputs := image.Multimodal.(ml.Tensor)
ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
row := image.Multimodal.(*imageRow)
row.parent.dataOnce.Do(func() {
// use a new, throwaway context so the image tensor is not added to the graph
m.Backend().NewContext().Forward(row.parent.tensor).Compute(row.parent.tensor)
row.parent.data = row.parent.tensor.Floats()
})
imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
if err != nil {
panic(err)
}
ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
}
for i, layer := range m.Layers {