diff --git a/model/models/mistral3/model.go b/model/models/mistral3/model.go
index 6aa7645b5..ed1bb9d9f 100644
--- a/model/models/mistral3/model.go
+++ b/model/models/mistral3/model.go
@@ -4,6 +4,7 @@ import (
 	"bytes"
 	"image"
 	"slices"
+	"sync"
 
 	"github.com/ollama/ollama/kvcache"
 	"github.com/ollama/ollama/ml"
@@ -107,14 +108,37 @@ func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, er
 	features, size := m.MultiModalProjector.Forward(ctx, visionOutputs, size)
 
 	// split into patches to be sent to the text transformer
-	rows := make([]ml.Tensor, size.Y)
+	parent := imageFeatures{tensor: features}
+	rows := make([]*imageRow, size.Y)
 	for i := range rows {
-		rows[i] = features.View(ctx, features.Stride(1)*i*size.X, features.Dim(0), features.Stride(1), size.X)
+		rows[i] = &imageRow{parent: &parent, s: i, shape: []int{features.Dim(0), size.X}}
 	}
 
 	return rows, nil
 }
 
+type imageFeatures struct {
+	tensor ml.Tensor
+
+	dataOnce sync.Once
+	data     []float32
+}
+
+type imageRow struct {
+	parent *imageFeatures
+	s      int
+	shape  []int
+}
+
+func (r *imageRow) data() []float32 {
+	n := 1
+	for _, s := range r.shape {
+		n *= s
+	}
+
+	return r.parent.data[r.s*n : (r.s+1)*n]
+}
+
 // PostTokenize arranges Mistral 3's inputs for the forward pass
 // In Mistral 3 and Pixtral, the input patches are arranged as follows:
 // [IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_BREAK][IMG]...[IMG][IMG_END]
@@ -126,11 +150,11 @@ func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
 		if inp.Multimodal == nil {
 			result = append(result, inp)
 		} else {
-			inputMultimodal := inp.Multimodal.([]ml.Tensor)
+			inputMultimodal := inp.Multimodal.([]*imageRow)
 			for i, row := range inputMultimodal {
 				// [IMG]
-				result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.Dim(1) + 1})
-				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.Dim(1)-1)...)
+				result = append(result, input.Input{Token: 10, Multimodal: row, MultimodalHash: inp.MultimodalHash, SameBatch: row.shape[1] + 1})
+				result = append(result, slices.Repeat([]input.Input{{Token: 10}}, row.shape[1]-1)...)
 				if i == len(inputMultimodal)-1 {
 					// [IMG_END]
 					result = append(result, input.Input{Token: 13})
diff --git a/model/models/mistral3/model_text.go b/model/models/mistral3/model_text.go
index 03b8caa0c..f633d45fd 100644
--- a/model/models/mistral3/model_text.go
+++ b/model/models/mistral3/model_text.go
@@ -110,8 +110,19 @@ func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor
 
 	// image embeddings
 	for _, image := range batch.Multimodal {
-		visionOutputs := image.Multimodal.(ml.Tensor)
-		ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
+		row := image.Multimodal.(*imageRow)
+		row.parent.dataOnce.Do(func() {
+			// use a new, throwaway context so the image tensor is not added to the graph
+			m.Backend().NewContext().Forward(row.parent.tensor).Compute(row.parent.tensor)
+			row.parent.data = row.parent.tensor.Floats()
+		})
+
+		imageFeature, err := ctx.Input().FromFloatSlice(row.data(), row.shape...)
+		if err != nil {
+			panic(err)
+		}
+
+		ctx.Forward(imageFeature.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), imageFeature.Dim(0)*imageFeature.Dim(1))))
 	}
 
 	for i, layer := range m.Layers {
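
Note for reviewers: as I read the diff, the core pattern is that EncodeMultimodal now returns lightweight imageRow descriptors (a parent pointer, a row index, and a shape) instead of live tensor views, and the text model's Forward materializes the parent image tensor's floats exactly once per image via sync.Once, then uploads only the slice each row needs into the current context. Below is a minimal, self-contained sketch of that once-per-parent lazy materialization, with no real backend involved; the names parentFeatures, row, and materialize are hypothetical stand-ins, and the expensive Compute/Floats round trip is faked with a deterministic fill.

package main

import (
	"fmt"
	"sync"
)

// parentFeatures stands in for imageFeatures: it owns the full feature
// matrix and materializes the flattened values at most once, no matter
// how many rows ask for them.
type parentFeatures struct {
	rows, cols int

	once sync.Once
	data []float32
}

// materialize fakes the expensive Compute+Floats round trip from the
// real change with a deterministic fill.
func (p *parentFeatures) materialize() []float32 {
	p.once.Do(func() {
		fmt.Println("computed once")
		p.data = make([]float32, p.rows*p.cols)
		for i := range p.data {
			p.data[i] = float32(i)
		}
	})
	return p.data
}

// row mirrors imageRow: a descriptor recording which slice of the
// parent it covers, rather than holding a live tensor view.
type row struct {
	parent *parentFeatures
	index  int
}

// data slices this row's values out of the parent's flattened buffer,
// triggering the one-time computation on first use.
func (r *row) data() []float32 {
	n := r.parent.cols
	return r.parent.materialize()[r.index*n : (r.index+1)*n]
}

func main() {
	p := &parentFeatures{rows: 3, cols: 4}
	views := []*row{{p, 0}, {p, 1}, {p, 2}}

	// "computed once" prints a single time even though all three rows read data.
	for _, v := range views {
		fmt.Println(v.data())
	}
}

Because the descriptors carry only plain floats and shapes, the rows no longer pin the vision encoder's graph context between encode and decode; each text forward pass re-creates just the row tensors it actually schedules.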