Bruce MacDonald 2025-03-21 13:17:13 -07:00 committed by jmorganca
parent 63e6509ec0
commit 4d8dac8ffc
6 changed files with 290 additions and 85 deletions

View File

@@ -73,6 +73,7 @@ func (p *mistral3Model) KV(t *Tokenizer) ggml.KV {
kv["mistral3.vision.image_size"] = p.VisionModel.ImageSize
kv["mistral3.vision.patch_size"] = p.VisionModel.PatchSize
kv["mistral3.vision.num_channels"] = p.VisionModel.NumChannels
// kv["mistral3.vision.attention.layer_norm_epsilon"] = 1e-05 // Default value
kv["mistral3.vision.rope.freq_base"] = p.VisionModel.RopeTheta
// Multimodal configuration

View File

@@ -51,7 +51,7 @@ func (p *ImageProcessor) pack(img image.Image, mean, std [3]float32) []float32 {
func (p ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := image.Point{p.imageSize, p.imageSize}
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic)
data := p.pack(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD)
return data, nil
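
A note for context: the hunk above only swaps the resampling filter; the fixed square output and the ImageNet-standard normalization are unchanged. Below is a minimal standalone sketch of the same pipeline, assuming the imageproc helpers keep the signatures used elsewhere in this diff (Normalize stands in for the package-private pack; the file name and the 224 output size are placeholder values):

package main

import (
	"image"
	_ "image/jpeg"
	"os"

	"github.com/ollama/ollama/model/imageproc"
)

func main() {
	f, err := os.Open("input.jpg") // placeholder path
	if err != nil {
		panic(err)
	}
	defer f.Close()

	img, _, err := image.Decode(f)
	if err != nil {
		panic(err)
	}

	const imageSize = 224 // assumed; the real value comes from the processor config
	outputSize := image.Point{imageSize, imageSize}

	// Flatten any alpha channel, resize to a fixed square with bicubic
	// resampling, then normalize to float32 (boolean flags as in ProcessImage above).
	newImage := imageproc.Composite(img)
	newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBicubic)
	data := imageproc.Normalize(newImage, imageproc.ImageNetStandardMean, imageproc.ImageNetStandardSTD, true, true)

	_ = data // imageSize*imageSize*3 floats
}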

View File

@@ -8,6 +8,7 @@ import (
"io"
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model/imageproc"
)
@@ -27,8 +28,8 @@ func getResizeOutputImageSize(img image.Image, longestEdge int, patchSize image.
if ratio > 1.0 {
newSize = image.Point{
int(math.Ceil(float64(b.Max.X) / ratio)),
int(math.Ceil(float64(b.Max.Y) / ratio)),
int(math.Floor(float64(b.Max.X) / ratio)),
int(math.Floor(float64(b.Max.Y) / ratio)),
}
}
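
The Ceil-to-Floor change above moves the rounded output size by at most one pixel per axis. A small self-contained illustration, assuming (the computation sits outside this hunk) that ratio is the longer side divided by longestEdge, as is typical for longest-edge resizing:

package main

import (
	"fmt"
	"math"
)

func main() {
	// Hypothetical 3000x2000 source image with longestEdge = 1024.
	w, h := 3000.0, 2000.0
	ratio := math.Max(w, h) / 1024.0 // ~2.93

	fmt.Println(int(math.Ceil(w/ratio)), int(math.Ceil(h/ratio)))   // 1024 683
	fmt.Println(int(math.Floor(w/ratio)), int(math.Floor(h/ratio))) // 1024 682
	// Flooring keeps the shorter edge at or below the exact
	// aspect-preserving size instead of rounding it up.
}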
@@ -66,3 +67,30 @@ func Preprocess(imageData io.Reader) ([]float32, map[string]any, error) {
opts := map[string]any{}
return data, opts, nil
}
type ImageProcessor struct {
imageSize int
patchSize int
numChannels int
longestEdge int
}
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
longestEdge: int(c.Uint("vision.longest_edge", 1024)),
}
}
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
return data, nil
}
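
A minimal in-package sketch exercising the new ProcessImage path with a synthetic image, written as a test (package clause omitted; it needs "image" and "testing"); the field values mirror the defaults in newImageProcessor above:

func TestProcessImageSketch(t *testing.T) {
	p := ImageProcessor{imageSize: 1540, patchSize: 14, numChannels: 3, longestEdge: 1024}

	// 2048x1024 synthetic image: the longer side exceeds longestEdge,
	// so getResizeOutputImageSize scales it down.
	img := image.NewRGBA(image.Rect(0, 0, 2048, 1024))

	f32s, err := p.ProcessImage(img)
	if err != nil {
		t.Fatal(err)
	}

	// Normalize emits one float per channel per pixel of the resized image.
	size := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
	if want := size.X * size.Y * p.numChannels; len(f32s) != want {
		t.Fatalf("got %d floats, want %d", len(f32s), want)
	}
}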

View File

@@ -1,65 +1,27 @@
package mistral3
import (
"bytes"
"image"
_ "image/jpeg"
_ "image/png"
"slices"
"github.com/ollama/ollama/kvcache"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/model"
"github.com/ollama/ollama/model/imageproc"
"github.com/ollama/ollama/model/input"
)
type Model struct {
model.Base
*TextModel
*VisionModel `gguf:"v,vision"`
*MultiModalProjector `gguf:"mm"`
ImageProcessor
// TODO: Add VisionModel field
// *VisionModel `gguf:"v,vision"`
// TODO: Add MultiModalProjector field for combining vision and text features
// *MultiModalProjector `gguf:"mm"`
}
// Adding ImageProcessor struct
type ImageProcessor struct {
imageSize int
patchSize int
numChannels int
longestEdge int
}
// Function to create a new ImageProcessor
func newImageProcessor(c ml.Config) ImageProcessor {
return ImageProcessor{
imageSize: int(c.Uint("vision.image_size", 1024)),
patchSize: int(c.Uint("vision.patch_size", 16)),
numChannels: int(c.Uint("vision.num_channels", 3)),
longestEdge: int(c.Uint("vision.longest_edge", 1024)),
}
}
// Method to process images for the model
func (p *ImageProcessor) ProcessImage(img image.Image) ([]float32, error) {
// Get output size based on longest edge and patch size
outputSize := getResizeOutputImageSize(img, p.longestEdge, image.Point{p.patchSize, p.patchSize})
// Resize the image
newImage := imageproc.Composite(img)
newImage = imageproc.Resize(newImage, outputSize, imageproc.ResizeBilinear)
// Normalize image data
data := imageproc.Normalize(newImage, imageproc.ClipDefaultMean, imageproc.ClipDefaultSTD, true, true)
return data, nil
}
// TODO: Implement MultimodalProcessor interface
// var _ model.MultimodalProcessor = (*Model)(nil)
// Implement MultimodalProcessor interface
var _ model.MultimodalProcessor = (*Model)(nil)
func New(c ml.Config) (model.Model, error) {
textModel, err := NewTextModel(c)
@@ -68,15 +30,10 @@ func New(c ml.Config) (model.Model, error) {
}
m := &Model{
TextModel: textModel,
// Initialize the ImageProcessor
ImageProcessor: newImageProcessor(c),
// TODO: Initialize VisionModel if present
// VisionModel: newVisionModel(c),
// TODO: Initialize MultiModalProjector
// MultiModalProjector: &MultiModalProjector{...},
TextModel: textModel,
VisionModel: newVisionModel(c),
ImageProcessor: newImageProcessor(c),
MultiModalProjector: newMultiModalProjector(c),
}
m.Cache = kvcache.NewCausalCache(m.TextModel.Shift)
@@ -84,37 +41,63 @@ func New(c ml.Config) (model.Model, error) {
return m, nil
}
// Implement EncodeMultimodal method for processing images
func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) (any, error) {
// Check if vision model exists - return error for now
return nil, model.ErrNoVisionModel
if len(m.VisionModel.Layers) == 0 {
return nil, model.ErrNoVisionModel
}
// This will be implemented when adding the vision model:
/*
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
// Decode image
image, _, err := image.Decode(bytes.NewReader(multimodalData))
if err != nil {
return nil, err
}
// Process image
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
// Create tensor from image data
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
// Forward pass through vision model
visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
// Project to text embedding space
visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs, m.VisionModel.eps)
return visionOutputs, nil
}
func (m *Model) PostTokenize(inputs []input.Input) ([]input.Input, error) {
var result []input.Input
for _, inp := range inputs {
if inp.Multimodal == nil {
result = append(result, inp)
} else {
inputMultimodal := inp.Multimodal.(ml.Tensor)
// Add special image tokens - using the imageTokenIndex from config
result = append(result,
input.Input{Token: int32(m.MultiModalProjector.imageTokenIndex)}, // Image token
input.Input{Multimodal: inputMultimodal, MultimodalHash: inp.MultimodalHash}, // Image data
)
// Add image token placeholders
result = append(result, slices.Repeat([]input.Input{{Token: 0}}, inputMultimodal.Dim(1)-1)...)
}
}
f32s, err := m.ImageProcessor.ProcessImage(image)
if err != nil {
return nil, err
}
pixelValues, err := ctx.Input().FromFloatSlice(f32s,
m.ImageProcessor.imageSize,
m.ImageProcessor.imageSize,
m.ImageProcessor.numChannels,
)
if err != nil {
return nil, err
}
// Will need VisionModel to process this
// visionOutputs := m.VisionModel.Forward(ctx, pixelValues)
// visionOutputs = m.MultiModalProjector.Forward(ctx, visionOutputs)
// return visionOutputs, nil
*/
return result, nil
}
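
Each image therefore takes up inputMultimodal.Dim(1)+1 positions in the expanded sequence: one for the image token, one carrying the tensor itself, and Dim(1)-1 zero-token placeholders. A hypothetical fragment illustrating the count (values assumed; package and imports omitted):

// Illustration only; `projected` stands in for inputMultimodal above and is
// assumed to hold 4 image embeddings (projected.Dim(1) == 4).
var projected ml.Tensor

seq := []input.Input{
	{Token: 10},             // image token (the default image_token_index)
	{Multimodal: projected}, // the first embedding position carries the tensor
}
seq = append(seq, slices.Repeat([]input.Input{{Token: 0}}, 4-1)...)
// len(seq) == 5: one image token plus four embedding positions.
_ = seq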
func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
@@ -133,8 +116,20 @@ func (m *Model) Forward(ctx ml.Context, opts input.Options) (ml.Tensor, error) {
return nil, err
}
// TODO: Add handling of multimodal inputs when vision model is added
// Set image embeddings into hidden state if present in opts.Multimodal
// Handle multimodal inputs
// var except []int
// hiddenState := m.TextModel.TokenEmbedding.Forward(ctx, inputs)
// for _, image := range opts.Multimodal {
// visionOutputs := image.Multimodal.(ml.Tensor)
// // Copy vision outputs into the hidden state
// ctx.Forward(visionOutputs.Copy(ctx, hiddenState.View(ctx, image.Index*hiddenState.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1))))
// for i := range visionOutputs.Dim(1) {
// except = append(except, image.Index+i)
// }
// }
return m.TextModel.Forward(ctx, inputs, positions, outputs, opts, m.Cache), nil
}

View File

@@ -0,0 +1,143 @@
package mistral3
import (
"math"
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
var batchSize int = 1
type VisionSelfAttention struct {
Query *nn.Linear `gguf:"attn_q"`
Key *nn.Linear `gguf:"attn_k"`
Value *nn.Linear `gguf:"attn_v"`
Output *nn.Linear `gguf:"attn_output"`
RopeFactors ml.Tensor `gguf:"rope_freqs.weight"`
}
func (sa *VisionSelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
headDim := opts.headDim
query := sa.Query.Forward(ctx, hiddenState)
key := sa.Key.Forward(ctx, hiddenState)
value := sa.Value.Forward(ctx, hiddenState)
query = query.Reshape(ctx, headDim, opts.numHeads, batchSize)
key = key.Reshape(ctx, headDim, opts.numHeads, batchSize)
value = value.Reshape(ctx, headDim, opts.numHeads, batchSize)
ropeType := uint32(0)
query = query.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
key = key.RoPE(ctx, positionIDs, sa.RopeFactors, uint32(headDim), ropeType, opts.ropeBase, opts.ropeScale)
attention := nn.Attention(ctx, query, key, value, 1.0/math.Sqrt(float64(headDim)), nil)
attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2), batchSize)
return sa.Output.Forward(ctx, attention)
}
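
With the vision defaults used later in this diff (embedding_length 1024, head_count 16), each head sees 64-dimensional queries and keys, matching the key_length default of 64 and giving an attention scale of 1/8. A quick check:

package main

import (
	"fmt"
	"math"
)

func main() {
	hiddenSize, numHeads := 1024, 16 // vision.embedding_length, vision.attention.head_count defaults
	headDim := hiddenSize / numHeads // 64, the vision.attention.key_length default
	fmt.Println(headDim, 1/math.Sqrt(float64(headDim))) // 64 0.125
}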
type VisionMLP struct {
Gate *nn.Linear `gguf:"ffn_gate"`
Up *nn.Linear `gguf:"ffn_up"`
Down *nn.Linear `gguf:"ffn_down"`
}
func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenState ml.Tensor, opts *VisionModelOptions) ml.Tensor {
hiddenState = mlp.Gate.Forward(ctx, hiddenState).GELU(ctx).Mul(ctx, mlp.Up.Forward(ctx, hiddenState))
return mlp.Down.Forward(ctx, hiddenState)
}
type VisionEncoderLayer struct {
AttentionNorm *nn.RMSNorm `gguf:"attn_norm"`
SelfAttention *VisionSelfAttention
FFNNorm *nn.RMSNorm `gguf:"ffn_norm"`
MLP *VisionMLP `gguf:"mlp"`
}
func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, opts *VisionModelOptions) ml.Tensor {
residual := hiddenState
// self attention
hiddenState = e.AttentionNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.SelfAttention.Forward(ctx, hiddenState, positionIDs, opts)
hiddenState = hiddenState.Add(ctx, residual)
residual = hiddenState
// feed forward
hiddenState = e.FFNNorm.Forward(ctx, hiddenState, opts.eps)
hiddenState = e.MLP.Forward(ctx, hiddenState, opts)
return hiddenState.Add(ctx, residual)
}
type VisionModelOptions struct {
hiddenSize int
numHeads int
headDim int
intermediateSize int
imageSize int
patchSize int
numChannels int
eps float32
ropeBase float32
ropeScale float32
}
type VisionModel struct {
PatchEmbedding *nn.Conv2D `gguf:"patch_conv"`
EncoderNorm *nn.LayerNorm `gguf:"encoder_norm"`
Layers []VisionEncoderLayer `gguf:"blk"`
*VisionModelOptions
}
func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor) ml.Tensor {
numPatchesH := m.imageSize / m.patchSize
numPatchesW := m.imageSize / m.patchSize
numPatches := numPatchesH * numPatchesW
hiddenState := m.PatchEmbedding.Forward(ctx, pixelValues, m.patchSize, m.patchSize, 0, 0, 1, 1)
hiddenState = hiddenState.Reshape(ctx, numPatches, m.hiddenSize)
hiddenState = hiddenState.Permute(ctx, 1, 0, 2, 3).Contiguous(ctx)
// Create position IDs
positions := make([]int32, numPatches)
for i := range positions {
positions[i] = int32(i)
}
positionIDs, err := ctx.Input().FromIntSlice(positions, len(positions))
if err != nil {
panic(err)
}
// Apply encoder normalization
hiddenState = m.EncoderNorm.Forward(ctx, hiddenState, m.eps)
// Process through transformer layers
for _, layer := range m.Layers {
hiddenState = layer.Forward(ctx, hiddenState, positionIDs, m.VisionModelOptions)
}
return hiddenState
}
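
With the defaults below (image_size 1540, patch_size 14, embedding_length 1024), this forward pass produces 110 patches per side, i.e. 12,100 patch embeddings of width 1024, before any spatial merging in the projector. The arithmetic:

package main

import "fmt"

func main() {
	imageSize, patchSize, hiddenSize := 1540, 14, 1024 // defaults from newVisionModel
	perSide := imageSize / patchSize                   // 110
	fmt.Println(perSide*perSide, hiddenSize)           // 12100 patch embeddings, each 1024 wide
}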
func newVisionModel(c ml.Config) *VisionModel {
return &VisionModel{
Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 24)),
VisionModelOptions: &VisionModelOptions{
hiddenSize: int(c.Uint("vision.embedding_length", 1024)),
numHeads: int(c.Uint("vision.attention.head_count", 16)),
headDim: int(c.Uint("vision.attention.key_length", 64)),
intermediateSize: int(c.Uint("vision.feed_forward_length", 4096)),
imageSize: int(c.Uint("vision.image_size", 1540)),
patchSize: int(c.Uint("vision.patch_size", 14)),
numChannels: int(c.Uint("vision.num_channels", 3)),
eps: c.Float("vision.attention.layer_norm_epsilon", 1e-05),
ropeBase: c.Float("vision.rope.freq_base", 10000.0),
ropeScale: c.Float("vision.rope.freq_scale", 1.0),
},
}
}

View File

@@ -0,0 +1,38 @@
package mistral3
import (
"github.com/ollama/ollama/ml"
"github.com/ollama/ollama/ml/nn"
)
type MultiModalProjector struct {
Norm *nn.RMSNorm `gguf:"norm"`
Projection *nn.Linear `gguf:"projection"`
spatialMergeSize int
imageTokenIndex int
hasBias bool
}
func (p *MultiModalProjector) Forward(ctx ml.Context, visionOutputs ml.Tensor, eps float32) ml.Tensor {
// Apply normalization
visionOutputs = p.Norm.Forward(ctx, visionOutputs, eps)
// If the spatial merge size is > 1, average pool the patches
if p.spatialMergeSize > 1 {
// Implementation depends on how the model handles spatial merging
// For simplicity, we'll use a spatial pooling approach
visionOutputs = visionOutputs.AvgPool2D(ctx, p.spatialMergeSize, p.spatialMergeSize, 0)
}
// Project to text embedding dimension
return p.Projection.Forward(ctx, visionOutputs)
}
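
Assuming AvgPool2D here pools over the two spatial patch dimensions, the default spatialMergeSize of 2 cuts the number of image embeddings by a factor of 4; for the 110x110 patch grid of the vision model above, that is 12,100 down to 3,025:

package main

import "fmt"

func main() {
	perSide, merge := 110, 2 // patches per side (1540/14), default spatial_merge_size
	merged := perSide / merge
	fmt.Println(perSide*perSide, "->", merged*merged) // 12100 -> 3025 image embeddings
}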
func newMultiModalProjector(c ml.Config) *MultiModalProjector {
return &MultiModalProjector{
spatialMergeSize: int(c.Uint("spatial_merge_size", 2)),
imageTokenIndex: int(c.Uint("image_token_index", 10)),
hasBias: c.Bool("mm.projector_bias", false),
}
}