diff --git a/convert/convert.go b/convert/convert.go index bed59a5750..3e98eee1ac 100644 --- a/convert/convert.go +++ b/convert/convert.go @@ -198,6 +198,8 @@ func ConvertModel(fsys fs.FS, f *os.File) error { conv = &qwen2Model{} case "Qwen2_5_VLForConditionalGeneration": conv = &qwen25VLModel{} + case "Qwen3VLForConditionalGeneration", "Qwen3VLMoeForConditionalGeneration": + conv = &qwen3VLModel{} case "BertModel": conv = &bertModel{} case "CohereForCausalLM": diff --git a/convert/convert_qwen3.go b/convert/convert_qwen3.go new file mode 100644 index 0000000000..f54418a9c9 --- /dev/null +++ b/convert/convert_qwen3.go @@ -0,0 +1,157 @@ +package convert + +import ( + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" + "github.com/pdevine/tensor" + "github.com/pdevine/tensor/native" +) + +type qwen3Model struct { + ModelParameters + MaxPositionEmbeddings uint32 `json:"max_position_embeddings"` + HiddenSize uint32 `json:"hidden_size"` + HiddenLayers uint32 `json:"num_hidden_layers"` + IntermediateSize uint32 `json:"intermediate_size"` + NumAttentionHeads uint32 `json:"num_attention_heads"` + NumKeyValueHeads uint32 `json:"num_key_value_heads"` + HeadDim uint32 `json:"head_dim"` + NumExperts uint32 `json:"num_experts"` + NumExpertsPerToken uint32 `json:"num_experts_per_tok"` + NormTopkProb bool `json:"norm_topk_prob"` + RopeTheta float32 `json:"rope_theta"` + RopeScaling struct { + Type string `json:"type"` + Factor ropeFactor `json:"factor"` + OriginalMaxPositionEmbeddings uint32 `json:"original_max_position_embeddings"` + MropeSection []int32 `json:"mrope_section"` + } `json:"rope_scaling"` + RMSNormEPS float32 `json:"rms_norm_eps"` +} + +// KV implements ModelConverter. +func (q *qwen3Model) KV(t *Tokenizer) ggml.KV { + arch := "qwen3" + if q.NumExperts > 0 { + arch += "moe" + } + + kv := q.ModelParameters.KV(t) + kv["general.architecture"] = arch + kv["block_count"] = q.HiddenLayers + kv["context_length"] = q.MaxPositionEmbeddings + kv["embedding_length"] = q.HiddenSize + kv["feed_forward_length"] = q.IntermediateSize + kv["attention.head_count"] = q.NumAttentionHeads + kv["attention.head_count_kv"] = q.NumKeyValueHeads + kv["attention.key_length"] = q.HeadDim + kv["attention.value_length"] = q.HeadDim + + if q.NumExperts > 0 { + kv["expert_count"] = q.NumExperts + kv["expert_used_count"] = q.NumExpertsPerToken + kv["norm_top_k_prob"] = q.NormTopkProb + } + + kv["rope.freq_base"] = q.RopeTheta + kv["attention.layer_norm_rms_epsilon"] = q.RMSNormEPS + + switch q.RopeScaling.Type { + case "": + // no scaling + case "yarn": + kv["rope.scaling.type"] = q.RopeScaling.Type + kv["rope.scaling.factor"] = q.RopeScaling.Factor + case "mrope", "default": + kv["rope.mrope_section"] = q.RopeScaling.MropeSection + default: + panic("unknown rope scaling type") + } + return kv +} + +// Tensors implements ModelConverter. 
+func (q *qwen3Model) Tensors(ts []Tensor) []*ggml.Tensor { + var out []*ggml.Tensor + + // TODO: handle split experts + + for _, t := range ts { + switch { + case strings.Contains(t.Name(), "ffn_gate_up_exps"): + afterFunc := func(t tensor.Tensor) (tensor.Tensor, error) { return tensor.Transpose(t, 0, 2, 1) } + for t := range splitDim(t, 2, + split{Replacer: strings.NewReplacer("gate_up", "gate"), afterFunc: afterFunc}, + split{Replacer: strings.NewReplacer("gate_up", "up"), afterFunc: afterFunc}, + ) { + t.Shape[1], t.Shape[2] = t.Shape[2], t.Shape[1] + out = append(out, t) + } + case strings.Contains(t.Name(), "ffn_down_exps"): + shape := slices.Clone(t.Shape()) + shape[1], shape[2] = shape[2], shape[1] + t.SetRepacker(func(_ string, data []float32, shape []uint64) ([]float32, error) { + dims := make([]int, len(shape)) + for i := range shape { + dims[i] = int(shape[i]) + } + + var tt tensor.Tensor = tensor.New(tensor.WithShape(dims...), tensor.WithBacking(data)) + tt, err := tensor.Transpose(tt, 0, 2, 1) + if err != nil { + return nil, err + } + + // flatten tensor so it can be written as a vector + if err := tt.Reshape(tt.Shape().TotalSize()); err != nil { + return nil, err + } + + return native.VectorF32(tt.(*tensor.Dense)) + }) + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: shape, + WriterTo: t, + }) + default: + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: t.Shape(), + WriterTo: t, + }) + } + } + + return out +} + +// Replacements implements ModelConverter. +func (q *qwen3Model) Replacements() []string { + return []string{ + "lm_head", "output", + "model.embed_tokens", "token_embd", + "model.layers", "blk", + "input_layernorm", "attn_norm", + "self_attn.k_proj", "attn_k", + "self_attn.k_norm", "attn_k_norm", + "self_attn.v_proj", "attn_v", + "self_attn.q_proj", "attn_q", + "self_attn.q_norm", "attn_q_norm", + "self_attn.o_proj", "attn_output", + "mlp.down_proj", "ffn_down", + "mlp.gate_proj", "ffn_gate", + "mlp.up_proj", "ffn_up", + "mlp.gate.weight", "ffn_gate_inp.weight", + "mlp.experts.down_proj", "ffn_down_exps.weight", + "mlp.experts.gate_up_proj", "ffn_gate_up_exps.weight", + "post_attention_layernorm", "ffn_norm", + "model.norm", "output_norm", + } +} + +var _ ModelConverter = (*qwen3Model)(nil) diff --git a/convert/convert_qwen3vl.go b/convert/convert_qwen3vl.go new file mode 100644 index 0000000000..e0ccb805fe --- /dev/null +++ b/convert/convert_qwen3vl.go @@ -0,0 +1,116 @@ +package convert + +import ( + "cmp" + "encoding/json" + "io/fs" + "slices" + "strings" + + "github.com/ollama/ollama/fs/ggml" +) + +type qwen3VLModel struct { + qwen3Model `json:"text_config"` + + VisionModel struct { + Depth uint32 `json:"depth"` + HiddenSize uint32 `json:"hidden_size"` + NumHeads uint32 `json:"num_heads"` + InChannels uint32 `json:"in_channels"` + PatchSize uint32 `json:"patch_size"` + SpatialMergeSize uint32 `json:"spatial_merge_size"` + WindowSize uint32 `json:"window_size"` + RMSNormEps float32 `json:"layer_norm_epsilon"` + RopeTheta float32 `json:"rope_theta"` + TemporalPatchSize uint32 `json:"temporal_patch_size"` + DeepstackVisualIndexes []int32 `json:"deepstack_visual_indexes"` + + Size struct { + ShortestEdge uint32 `json:"shortest_edge"` + LongestEdge uint32 `json:"longest_edge"` + } `json:"size"` + + ImageMean []float32 `json:"image_mean"` + ImageStd []float32 `json:"image_std"` + } `json:"vision_config"` +} + +func (m *qwen3VLModel) parseMore(fsys fs.FS) error { + bts, err := fs.ReadFile(fsys, 
"preprocessor_config.json") + if err != nil { + return err + } + + return json.Unmarshal(bts, &m.VisionModel) +} + +func (m *qwen3VLModel) KV(t *Tokenizer) ggml.KV { + kv := m.qwen3Model.KV(t) + + arch := "qwen3vl" + if m.NumExperts > 0 { + arch += "moe" + } + // override architecture + kv["general.architecture"] = arch + + kv["vision.block_count"] = cmp.Or(m.VisionModel.Depth, 32) + kv["vision.embedding_length"] = m.VisionModel.HiddenSize + kv["vision.attention.head_count"] = cmp.Or(m.VisionModel.NumHeads, 16) + kv["vision.num_channels"] = m.VisionModel.InChannels + kv["vision.patch_size"] = cmp.Or(m.VisionModel.PatchSize, 14) + kv["vision.spatial_merge_size"] = cmp.Or(m.VisionModel.SpatialMergeSize, 2) + kv["vision.attention.layer_norm_epsilon"] = cmp.Or(m.VisionModel.RMSNormEps, 1e-6) + kv["vision.rope.freq_base"] = cmp.Or(m.VisionModel.RopeTheta, 1e4) + kv["vision.temporal_patch_size"] = cmp.Or(m.VisionModel.TemporalPatchSize, 2) + kv["vision.deepstack_visual_indexes"] = m.VisionModel.DeepstackVisualIndexes + + kv["vision.shortest_edge"] = m.VisionModel.Size.ShortestEdge + kv["vision.longest_edge"] = m.VisionModel.Size.LongestEdge + + kv["vision.image_mean"] = m.VisionModel.ImageMean + kv["vision.image_std"] = m.VisionModel.ImageStd + + return kv +} + +func (m *qwen3VLModel) Tensors(ts []Tensor) []*ggml.Tensor { + var rest []Tensor + var out []*ggml.Tensor + for _, t := range ts { + switch { + case strings.Contains(t.Name(), "attn_qkv"): + out = append(out, slices.Collect(splitDim(t, 0, + split{Replacer: strings.NewReplacer("attn_qkv", "attn_q")}, + split{Replacer: strings.NewReplacer("attn_qkv", "attn_k")}, + split{Replacer: strings.NewReplacer("attn_qkv", "attn_v")}, + ))...) + case strings.Contains(t.Name(), "patch_embed") && strings.HasSuffix(t.Name(), "weight"): + shape := t.Shape() + out = append(out, &ggml.Tensor{ + Name: t.Name(), + Kind: t.Kind(), + Shape: append([]uint64{shape[0] * shape[1]}, shape[2:]...), + WriterTo: t, + }) + default: + rest = append(rest, t) + } + } + + return append(m.qwen3Model.Tensors(rest), out...) +} + +func (m *qwen3VLModel) Replacements() []string { + return append( + m.qwen3Model.Replacements(), + "model.language_", "", + "model.visual", "v", + "patch_embed.proj", "patch_embed", + "blocks", "blk", + "attn.qkv", "attn_qkv", + "attn.proj", "attn_out", + "deepstack_merger_list", "deepstack_merger", + ) +} diff --git a/convert/tensor.go b/convert/tensor.go index 9b8517f1ec..27bdd13ff1 100644 --- a/convert/tensor.go +++ b/convert/tensor.go @@ -19,8 +19,8 @@ type split struct { dim int slices []tensor.Slice - // fn is an optional function to apply to the tensor after slicing - fn func(tensor.Tensor) (tensor.Tensor, error) + // afterFunc is an optional function to apply to the tensor after slicing + afterFunc func(tensor.Tensor) (tensor.Tensor, error) } // splitDim splits a tensor along a specified dimension into multiple tensors. 
The dimension
@@ -54,8 +54,8 @@ func splitDim(t Tensor, dim int, splits ...split) iter.Seq[*ggml.Tensor] {
 				tt = tensor.Materialize(tt)

-				if split.fn != nil {
-					tt, err = split.fn(tt)
+				if split.afterFunc != nil {
+					tt, err = split.afterFunc(tt)
 					if err != nil {
 						return nil, err
 					}
diff --git a/convert/tensor_test.go b/convert/tensor_test.go
index 3a34bbff6f..c1f58da6e4 100644
--- a/convert/tensor_test.go
+++ b/convert/tensor_test.go
@@ -432,7 +432,7 @@ func TestSplitDim(t *testing.T) {
 	t.Run("split with transpose", func(t *testing.T) {
 		next, stop := iter.Pull(splitDim(&r, 1,
 			split{Replacer: strings.NewReplacer("a", "x")},
-			split{Replacer: strings.NewReplacer("b", "y"), fn: func(tt tensor.Tensor) (tensor.Tensor, error) {
+			split{Replacer: strings.NewReplacer("b", "y"), afterFunc: func(tt tensor.Tensor) (tensor.Tensor, error) {
 				return tensor.Transpose(tt, 1, 0)
 			}},
 		))
diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index fcb3d9fdb4..c0ca068ab6 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -242,13 +242,13 @@ func (kv KV) OllamaEngineRequired() bool {
 	return slices.Contains([]string{
 		"gemma3",
 		"gemma3n",
-		"mistral3",
-		"qwen3",
-		"qwen3moe",
+		"gptoss", "gpt-oss",
 		"llama4",
+		"mistral3",
 		"mllama",
 		"qwen25vl",
-		"gptoss", "gpt-oss",
+		"qwen3", "qwen3moe",
+		"qwen3vl", "qwen3vlmoe",
 	}, kv.Architecture())
 }
diff --git a/integration/llm_image_test.go b/integration/llm_image_test.go
index e359156535..e1c16bafdb 100644
--- a/integration/llm_image_test.go
+++ b/integration/llm_image_test.go
@@ -26,6 +26,13 @@ func TestVisionModels(t *testing.T) {
 		{
 			model: "gemma3",
 		},
+		{
+			model: "qwen3-vl:8b",
+		},
+		{
+			// Qwen 3 VL mixture of experts
+			model: "qwen3-vl:30b",
+		},
 	}

 	for _, v := range testCases {
diff --git a/ml/backend.go b/ml/backend.go
index 764ff0854b..bf390c0121 100644
--- a/ml/backend.go
+++ b/ml/backend.go
@@ -161,6 +161,7 @@ type Tensor interface {
 	AvgPool2D(ctx Context, k, s int, p float32) Tensor
 	Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
+	Conv3D(ctx Context, weight Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) Tensor

 	IM2Col(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor
diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go
index 3feb5b5d2b..33401c3048 100644
--- a/ml/backend/ggml/ggml.go
+++ b/ml/backend/ggml/ggml.go
@@ -1182,6 +1182,10 @@ func (t *Tensor) Concat(ctx ml.Context, t2 ml.Tensor, dim int) ml.Tensor {
 }

 func (t *Tensor) Contiguous(ctx ml.Context, shape ...int) ml.Tensor {
+	if slices.Contains(shape, -1) {
+		inferShape(t, shape)
+	}
+
 	switch len(shape) {
 	case 0:
 		return &Tensor{
@@ -1324,7 +1328,43 @@ func (t *Tensor) Copy(ctx ml.Context, t2 ml.Tensor) ml.Tensor {
 	}
 }

+// inferShape updates shape in place to automatically set a single -1 dimension
+// based on the input tensor and the other dimensions
+func inferShape(t *Tensor, shape []int) {
+	total := 1
+	for _, dim := range t.Shape() {
+		total *= dim
+	}
+
+	dim := -1
+	for i := range shape {
+		switch shape[i] {
+		case -1:
+			if dim != -1 {
+				panic("only one dimension can be inferred")
+			}
+			dim = i
+		case 0:
+			panic("dimension cannot be zero")
+		default:
+			if total%shape[i] != 0 {
+				panic("cannot infer dimension")
+			}
+
+			total /= shape[i]
+		}
+	}
+
+	if dim != -1 {
+		shape[dim] = total
+	}
+}
+
 func (t *Tensor) Reshape(ctx ml.Context, shape ...int) ml.Tensor {
+	if slices.Contains(shape, -1) {
+		inferShape(t, shape)
+	}
+
 	switch len(shape) {
 	case 1:
 		return &Tensor{
@@ -1537,6 +1577,16 @@ func (t *Tensor) Conv2D(ctx ml.Context, t2 ml.Tensor, s0, s1, p0, p1, d0,
d1 int } } +func (t *Tensor) Conv3D(ctx ml.Context, t2 ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor { + var tt ml.Tensor = &Tensor{ + b: t.b, + t: C.ggml_conv_3d(ctx.(*Context).ctx, t.t, t2.(*Tensor).t, C.int64_t(c), C.int(s0), C.int(s1), C.int(s2), C.int(p0), C.int(p1), C.int(p2), C.int(d0), C.int(d1), C.int(d2)), + } + + tt = tt.Reshape(ctx, t.Dim(3)/c, t2.Dim(3)/c) + return tt +} + func (t *Tensor) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { return &Tensor{ b: t.b, diff --git a/ml/backend/ggml/ggml_test.go b/ml/backend/ggml/ggml_test.go new file mode 100644 index 0000000000..4717ea905d --- /dev/null +++ b/ml/backend/ggml/ggml_test.go @@ -0,0 +1,126 @@ +package ggml + +import ( + "errors" + "os" + "testing" + + "github.com/google/go-cmp/cmp" + "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" +) + +func setup(tb testing.TB) ml.Context { + tb.Helper() + + f, err := os.CreateTemp(tb.TempDir(), "*.bin") + if err != nil { + tb.Fatal(err) + } + defer f.Close() + + if err := ggml.WriteGGUF(f, ggml.KV{"general.architecture": "test"}, nil); err != nil { + tb.Fatal(err) + } + + b, err := ml.NewBackend(f.Name(), ml.BackendParams{}) + if err != nil { + tb.Fatal(err) + } + + ctx := b.NewContext().Input() + + tb.Cleanup(func() { + ctx.Close() + b.Close() + }) + + return ctx +} + +func TestInferShape(t *testing.T) { + cases := []struct { + name string + input []int + want []int + err error + }{ + { + name: "no inferred shape", + input: []int{2, 3, 4}, + want: []int{2, 3, 4}, + }, + { + name: "infer begin", + input: []int{-1, 3, 4}, + want: []int{2, 3, 4}, + }, + { + name: "infer mid", + input: []int{2, -1, 4}, + want: []int{2, 3, 4}, + }, + { + name: "infer end", + input: []int{2, 3, -1}, + want: []int{2, 3, 4}, + }, + { + name: "too many inferred dims", + input: []int{-1, 3, -1}, + err: errors.New("only one dimension can be inferred"), + }, + { + name: "infer gather", + input: []int{2, -1}, + want: []int{2, 12}, + }, + { + name: "infer gather all", + input: []int{-1}, + want: []int{24}, + }, + { + name: "infer split", + input: []int{2, -1, 3, 2}, + want: []int{2, 2, 3, 2}, + }, + { + name: "indivisible infer", + input: []int{2, -1, 2, 4}, + err: errors.New("cannot infer dimension"), + }, + { + name: "infer zero dim", + input: []int{2, 0, 4}, + err: errors.New("dimension cannot be zero"), + }, + } + + ctx := setup(t) + tensor, ok := ctx.Empty(ml.DTypeF32, 2, 3, 4).(*Tensor) + if !ok { + t.Fatal("expected *Tensor") + } + + for _, tt := range cases { + t.Run(tt.name, func(t *testing.T) { + defer func() { + if r := recover(); r == nil && tt.err == nil { + // all good + } else if r != nil && tt.err == nil { + t.Errorf("unexpected panic: %v", r) + } else if r == nil && tt.err != nil { + t.Errorf("expected panic but did not get one: %v", tt.err) + } else if errStr, ok := r.(string); ok && errStr != tt.err.Error() { + t.Errorf("expected panic %q but got %q", tt.err.Error(), errStr) + } + }() + + inferShape(tensor, tt.input) + if diff := cmp.Diff(tt.want, tt.input); diff != "" { + t.Errorf("%s: shape mismatch (-want +got):\n%s", tt.name, diff) + } + }) + } +} diff --git a/ml/nn/convolution.go b/ml/nn/convolution.go index 8e015c73f6..db8c61471a 100644 --- a/ml/nn/convolution.go +++ b/ml/nn/convolution.go @@ -4,8 +4,26 @@ import "github.com/ollama/ollama/ml" type Conv2D struct { Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` } func (m *Conv2D) Forward(ctx ml.Context, t ml.Tensor, s0, s1, p0, p1, d0, d1 int) ml.Tensor { - return 
m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1) + t = m.Weight.Conv2D(ctx, t, s0, s1, p0, p1, d0, d1) + if m.Bias != nil { + t = t.Add(ctx, m.Bias) + } + return t +} + +type Conv3D struct { + Weight ml.Tensor `gguf:"weight"` + Bias ml.Tensor `gguf:"bias"` +} + +func (m *Conv3D) Forward(ctx ml.Context, t ml.Tensor, c, s0, s1, s2, p0, p1, p2, d0, d1, d2 int) ml.Tensor { + t = m.Weight.Conv3D(ctx, t, c, s0, s1, s2, p0, p1, p2, d0, d1, d2) + if m.Bias != nil { + t = t.Add(ctx, m.Bias) + } + return t } diff --git a/model/models/models.go b/model/models/models.go index 0cda615af6..deefeb58f9 100644 --- a/model/models/models.go +++ b/model/models/models.go @@ -14,4 +14,5 @@ import ( _ "github.com/ollama/ollama/model/models/qwen2" _ "github.com/ollama/ollama/model/models/qwen25vl" _ "github.com/ollama/ollama/model/models/qwen3" + _ "github.com/ollama/ollama/model/models/qwen3vl" ) diff --git a/model/models/qwen3/model.go b/model/models/qwen3/model.go index 72ce36ed94..483439ac47 100644 --- a/model/models/qwen3/model.go +++ b/model/models/qwen3/model.go @@ -3,6 +3,7 @@ package qwen3 import ( "cmp" "math" + "strings" "github.com/ollama/ollama/fs" "github.com/ollama/ollama/kvcache" @@ -210,7 +211,7 @@ var _ model.Model = (*Model)(nil) func New(c fs.Config) (model.Model, error) { layers := make([]Layer, c.Uint("block_count")) for i := range layers { - if c.String("general.architecture") == "qwen3moe" { + if strings.HasSuffix(c.String("general.architecture"), "moe") { layers[i].MLP = &sparse{} } else { layers[i].MLP = &dense{} diff --git a/model/models/qwen3vl/imageprocessor.go b/model/models/qwen3vl/imageprocessor.go new file mode 100644 index 0000000000..621167f5e5 --- /dev/null +++ b/model/models/qwen3vl/imageprocessor.go @@ -0,0 +1,194 @@ +package qwen3vl + +import ( + "fmt" + "image" + "math" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model/imageproc" +) + +// ImageProcessor contains configuration for the Qwen 3 VL image processing +type ImageProcessor struct { + numChannels int + patchSize int + temporalPatchSize int + mergeSize int + shortestEdge int + longestEdge int + factor int + rescaleFactor float32 + imageMean []float32 + imageStd []float32 +} + +// newImageProcessor creates a new image processor with default values +func newImageProcessor(c fs.Config) ImageProcessor { + patchSize := int(c.Uint("vision.patch_size", 14)) + mergeSize := int(c.Uint("vision.spatial_merge_size", 2)) + + return ImageProcessor{ + numChannels: int(c.Uint("vision.num_channels", 3)), // not set + patchSize: patchSize, + temporalPatchSize: 2, + mergeSize: mergeSize, + shortestEdge: int(c.Uint("vision.shortest_edge", 64<<10)), + // FIXME(mxyng): the model defined longest edge (16M) is too large for the default + // context length of 8K and will panic. Adjusting to 2M for now. 
+ // longestEdge: int(c.Uint("vision.longest_edge", 16<<20)), + longestEdge: 2 << 20, + factor: patchSize * mergeSize, + rescaleFactor: 1.0 / 255.0, + imageMean: c.Floats("vision.image_mean", imageproc.ImageNetStandardMean[:]), + imageStd: c.Floats("vision.image_std", imageproc.ImageNetStandardSTD[:]), + } +} + +// SmartResize implements the smart resize algorithm +func (p *ImageProcessor) SmartResize(height, width int) (int, int) { + factor := p.factor + + if height < factor || width < factor { + panic(fmt.Sprintf("height:%d or width:%d must be larger than factor:%d", height, width, factor)) + } else if aspectRatio := max(height, width) / min(height, width); aspectRatio > 200 { + panic(fmt.Sprintf("absolute aspect ratio must be smaller than 200, got %v", aspectRatio)) + } + + round := func(x float64) int { return int(math.RoundToEven(x)) } + + hBar := round(float64(height)/float64(factor)) * factor + wBar := round(float64(width)/float64(factor)) * factor + + if hBar*wBar > p.longestEdge { + beta := math.Sqrt(float64(height*width) / float64(p.longestEdge)) + + hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor + wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor + } else if hBar*wBar < p.shortestEdge { + beta := math.Sqrt(float64(p.shortestEdge) / float64(height*width)) + + hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor + wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor + } + + return hBar, wBar +} + +type Grid struct { + Height int + Width int + Temporal int +} + +func (p *ImageProcessor) ProcessImage(ctx ml.Context, img image.Image) (ml.Tensor, *Grid, error) { + origWidth := img.Bounds().Dx() + origHeight := img.Bounds().Dy() + + // Calculate smart resize dimensions + resizedHeight, resizedWidth := p.SmartResize(origHeight, origWidth) + + // Resize image using existing functions + resizedImg := imageproc.Resize(img, image.Point{X: resizedWidth, Y: resizedHeight}, imageproc.ResizeBilinear) + + normalizedPixels := imageproc.Normalize( + resizedImg, + [3]float32{p.imageMean[0], p.imageMean[1], p.imageMean[2]}, + [3]float32{p.imageStd[0], p.imageStd[1], p.imageStd[2]}, + true, // rescale + true, // channelFirst + ) + + // Calculate grid dimensions + grid := &Grid{ + Height: resizedHeight / p.patchSize, + Width: resizedWidth / p.patchSize, + Temporal: 1, // For single images, temporal dimension is 1 + } + + patches, err := p.createPatches(normalizedPixels, resizedHeight, resizedWidth, grid) + if err != nil { + return nil, nil, fmt.Errorf("failed to create patches: %v", err) + } + + patchDim := p.numChannels * p.temporalPatchSize * + p.patchSize * p.patchSize + numPatches := grid.Temporal * grid.Height * grid.Width + + pixelValues := ctx.Input().FromFloats(patches, patchDim, numPatches) + + // Return patches and grid dimensions + return pixelValues, grid, nil +} + +func (p *ImageProcessor) createPatches(pixels []float32, height, width int, grid *Grid) ([]float32, error) { + channels := p.numChannels + patchSize := p.patchSize + mergeSize := p.mergeSize + temporalPatchSize := p.temporalPatchSize + + // Calculate output dimensions + numPatches := grid.Temporal * grid.Height * grid.Width + patchDim := channels * temporalPatchSize * patchSize * patchSize + + result := make([]float32, numPatches*patchDim) + patchIndex := 0 + + // Single temporal frame handling (copies to all frames) + for range grid.Temporal { + for h := 0; h < grid.Height; h += mergeSize { + for w := 0; w < grid.Width; w += mergeSize { + // Handle the 
2x2 merged patches
+				for mh := range mergeSize {
+					for mw := range mergeSize {
+						baseOffset := patchIndex * patchDim
+
+						// Extract patch data for first temporal frame
+						for c := range channels {
+							channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
+
+							for py := range patchSize {
+								for px := range patchSize {
+									// Calculate source pixel coordinates
+									y := (h+mh)*patchSize + py
+									x := (w+mw)*patchSize + px
+
+									// Source index in input tensor (CHW format)
+									srcIdx := c*height*width + y*width + x
+
+									// Destination index in first temporal frame
+									dstIdx := channelOffset + (py * patchSize) + px
+
+									if srcIdx < len(pixels) && dstIdx < len(result) {
+										result[dstIdx] = pixels[srcIdx]
+									}
+								}
+							}
+						}
+
+						// Copy first temporal frame to all other frames
+						if temporalPatchSize > 1 {
+							for c := range channels {
+								channelOffset := baseOffset + (c * temporalPatchSize * patchSize * patchSize)
+								firstFrameOffset := channelOffset
+								frameSize := patchSize * patchSize
+
+								// Copy first frame to all other frames
+								for tp := 1; tp < temporalPatchSize; tp++ {
+									currentFrameOffset := channelOffset + (tp * frameSize)
+									copy(result[currentFrameOffset:currentFrameOffset+frameSize],
+										result[firstFrameOffset:firstFrameOffset+frameSize])
+								}
+							}
+						}
+
+						patchIndex++
+					}
+				}
+			}
+		}
+	}
+
+	return result, nil
+}
diff --git a/model/models/qwen3vl/model.go b/model/models/qwen3vl/model.go
new file mode 100644
index 0000000000..08beb37c20
--- /dev/null
+++ b/model/models/qwen3vl/model.go
@@ -0,0 +1,204 @@
+package qwen3vl
+
+import (
+	"bytes"
+	"image"
+	"slices"
+
+	"github.com/ollama/ollama/fs"
+	"github.com/ollama/ollama/kvcache"
+	"github.com/ollama/ollama/ml"
+	"github.com/ollama/ollama/model"
+	"github.com/ollama/ollama/model/input"
+)
+
+type Model struct {
+	model.Base
+	model.TextProcessor
+
+	*TextModel
+	*VisionModel `gguf:"v"`
+
+	ImageProcessor
+
+	positionCache []int32
+}
+
+func (m *Model) EncodeMultimodal(ctx ml.Context, multimodalData []byte) ([]input.Multimodal, error) {
+	if len(m.VisionModel.Layers) == 0 {
+		return nil, model.ErrNoVisionModel
+	}
+
+	img, _, err := image.Decode(bytes.NewReader(multimodalData))
+	if err != nil {
+		return nil, err
+	}
+
+	pixelValues, grid, err := m.ProcessImage(ctx, img)
+	if err != nil {
+		return nil, err
+	}
+
+	// Run the vision encoder, collecting the final embeddings plus any deepstack layer outputs
+	visionOutputs, deepstackVisualEmbeds := m.VisionModel.Forward(ctx, pixelValues, grid)
+	mm := []input.Multimodal{{Tensor: visionOutputs, Data: grid}}
+	for i := range deepstackVisualEmbeds {
+		mm = append(mm, input.Multimodal{Tensor: deepstackVisualEmbeds[i]})
+	}
+
+	return mm, nil
+}
+
+var (
+	tokenVision      int32 = 151655
+	tokenVisionStart int32 = 151652
+	tokenVisionEnd   int32 = 151653
+)
+
+type modelInput struct {
+	*input.Input
+	position int32
+}
+
+// PostTokenize arranges Qwen 3 VL's inputs for the forward pass
+func (m *Model) PostTokenize(inputs []*input.Input) ([]*input.Input, error) {
+	m.positionCache = m.positionCache[:0]
+	return slices.Collect(func(yield func(*input.Input) bool) {
+		for i := range inputs {
+			s := []modelInput{{Input: inputs[i]}}
+			if mm := inputs[i].Multimodal; mm != nil {
+				t := mm[0].Tensor
+				s = slices.Repeat([]modelInput{
+					{
+						position: int32(i + 1),
+						Input:    &input.Input{Token: tokenVision},
+					},
+				}, t.Dim(1)+1+1)
+
+				s[0] = modelInput{
+					Input:    &input.Input{Token: tokenVisionStart},
+					position: int32(i),
+				}
+
+				s[len(s)-1] = modelInput{
+					Input:    &input.Input{Token: tokenVisionEnd},
+					position: int32(i + mm[0].Data.(*Grid).Width/m.spatialMergeSize + 1),
+				}
+
+ s[1] = modelInput{ + Input: &input.Input{ + Token: tokenVision, + Multimodal: inputs[i].Multimodal, + MultimodalHash: inputs[i].MultimodalHash, + SameBatch: t.Dim(1), + }, + position: int32(i + 1), + } + } + + for _, e := range s { + position := e.position + if position == 0 && len(m.positionCache) > 0 { + position = m.positionCache[len(m.positionCache)-1] + 1 + } + + m.positionCache = append(m.positionCache, position) + if !yield(e.Input) { + return + } + } + } + }), nil +} + +func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) { + positionSlice := slices.Collect(makeSlice2D[int32](3, len(batch.Positions))) + for i, id := range batch.Positions { + if id < int32(len(m.positionCache)) { + id = m.positionCache[id] + } else if len(m.positionCache) > 0 { + id = id - int32(len(m.positionCache)) + m.positionCache[len(m.positionCache)-1] + 1 + } + + positionSlice[0][i] = id + positionSlice[1][i] = id + positionSlice[2][i] = id + } + + hiddenStates := m.TextModel.TokenEmbedding.Forward(ctx, batch.Inputs).Duplicate(ctx) + + var deepstackVisualEmbeds []ml.Tensor + for _, mi := range batch.Multimodal { + visionOutputs := mi.Multimodal[0].Tensor + ctx.Forward(visionOutputs.Copy(ctx, hiddenStates.View(ctx, mi.Index*hiddenStates.Stride(1), visionOutputs.Dim(0)*visionOutputs.Dim(1)))) + + if grid, ok := mi.Multimodal[0].Data.(*Grid); ok { + for i := range visionOutputs.Dim(1) { + w := grid.Width / m.spatialMergeSize + positionSlice[1][mi.Index+i] += int32(i / w) + positionSlice[2][mi.Index+i] += int32(i % w) + } + } + + deepstackVisualEmbeds = make([]ml.Tensor, len(mi.Multimodal[1:])) + for i, mm := range mi.Multimodal[1:] { + deepstackVisualEmbeds[i] = ctx.Input().Zeros(mm.Tensor.DType(), hiddenStates.Shape()...) + ctx.Forward(mm.Tensor.Copy(ctx, deepstackVisualEmbeds[i].View(ctx, mi.Index*deepstackVisualEmbeds[i].Stride(1), mm.Tensor.Dim(0)*mm.Tensor.Dim(1)))) + } + } + + positions := ctx.Input().FromInts(slices.Concat(positionSlice...), len(positionSlice[0]), len(positionSlice)) + cos, sin := m.rotaryEmbedding(ctx, positions) + for i, layer := range m.TextModel.Layers { + if m.Cache != nil { + m.Cache.SetLayer(i) + } + + var outputs ml.Tensor + if i == len(m.TextModel.Layers)-1 { + outputs = batch.Outputs + } + + hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, outputs, m.Cache, m.Options) + if i < len(deepstackVisualEmbeds) { + hiddenStates = hiddenStates.Add(ctx, deepstackVisualEmbeds[i]) + } + } + + hiddenStates = m.OutputNorm.Forward(ctx, hiddenStates, 1e-06) + return m.Output.Forward(ctx, hiddenStates), nil +} + +func New(c fs.Config) (model.Model, error) { + m := Model{ + TextProcessor: model.NewBytePairEncoding( + &model.Vocabulary{ + Values: c.Strings("tokenizer.ggml.tokens"), + Types: c.Ints("tokenizer.ggml.token_type"), + Merges: c.Strings("tokenizer.ggml.merges"), + AddBOS: c.Bool("tokenizer.ggml.add_bos_token", false), + BOS: []int32{int32(c.Uint("tokenizer.ggml.bos_token_id"))}, + AddEOS: c.Bool("tokenizer.ggml.add_eos_token", false), + EOS: append( + []int32{int32(c.Uint("tokenizer.ggml.eos_token_id"))}, + c.Ints("tokenizer.ggml.eos_token_ids")..., + ), + }, + `(?i:'s|'t|'re|'ve|'m|'ll|'d)|[^\r\n\p{L}\p{N}]?\p{L}+|\p{N}| ?[^\s\p{L}\p{N}]+[\r\n]*|\s*[\r\n]+|\s+(?!\S)|\s+`, + ), + TextModel: newTextModel(c), + VisionModel: newVisionModel(c), + ImageProcessor: newImageProcessor(c), + } + + m.Cache = kvcache.NewCausalCache(func(ctx ml.Context, layer int, key, position ml.Tensor) (ml.Tensor, error) { + m.positionCache = nil + return nil, 
kvcache.ErrNotSupported + }) + return &m, nil +} + +func init() { + model.Register("qwen3vl", New) + model.Register("qwen3vlmoe", New) +} diff --git a/model/models/qwen3vl/model_text.go b/model/models/qwen3vl/model_text.go new file mode 100644 index 0000000000..14e7d7dc08 --- /dev/null +++ b/model/models/qwen3vl/model_text.go @@ -0,0 +1,229 @@ +package qwen3vl + +import ( + "cmp" + "math" + "slices" + "strings" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/kvcache" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" + "github.com/ollama/ollama/model" +) + +type TextOptions struct { + hiddenSize, + numHeads, + numKVHeads, + keyLength, + valueLength int + + eps, + ropeBase, + ropeScale float32 + mropeSections []int + + numExperts, numExpertsUsed int + normTopKProb bool + + inverseFrequenciesCache []float32 +} + +func (o TextOptions) headDim() int { + return cmp.Or(o.keyLength, o.valueLength, o.hiddenSize/o.numHeads) +} + +type TextAttention struct { + Query *nn.Linear `gguf:"attn_q"` + QueryNorm *nn.RMSNorm `gguf:"attn_q_norm"` + Key *nn.Linear `gguf:"attn_k"` + KeyNorm *nn.RMSNorm `gguf:"attn_k_norm"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_output"` +} + +func (sa *TextAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + batchSize := hiddenStates.Dim(1) + + query := sa.Query.Forward(ctx, hiddenStates) + key := sa.Key.Forward(ctx, hiddenStates) + value := sa.Value.Forward(ctx, hiddenStates) + + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, batchSize) + key = key.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) + value = value.Reshape(ctx, opts.headDim(), opts.numKVHeads, batchSize) + + query = sa.QueryNorm.Forward(ctx, query, opts.eps) + key = sa.KeyNorm.Forward(ctx, key, opts.eps) + + query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + + attention := nn.Attention(ctx, query, key, value, 1./math.Sqrt(float64(opts.headDim())), cache) + attention = attention.Reshape(ctx, attention.Dim(0)*attention.Dim(1), batchSize) + return sa.Output.Forward(ctx, attention) +} + +type TextMLP interface { + Forward(ml.Context, ml.Tensor, *TextOptions) ml.Tensor +} + +type sparse struct { + Router *nn.Linear `gguf:"ffn_gate_inp"` + Gate *nn.LinearBatch `gguf:"ffn_gate_exps"` + Up *nn.LinearBatch `gguf:"ffn_up_exps"` + Down *nn.LinearBatch `gguf:"ffn_down_exps"` +} + +func (mlp *sparse) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts *TextOptions) ml.Tensor { + hiddenDim, sequenceLength, batchSize := hiddenStates.Dim(0), hiddenStates.Dim(1), hiddenStates.Dim(2) + hiddenStates = hiddenStates.Reshape(ctx, hiddenDim, sequenceLength*batchSize) + routerLogits := mlp.Router.Forward(ctx, hiddenStates) + + routingWeights := routerLogits.Softmax(ctx) + selectedExperts := routingWeights.TopK(ctx, opts.numExpertsUsed) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExperts, hiddenStates.Dim(1)).Rows(ctx, selectedExperts) + if opts.normTopKProb { + routingWeights = routingWeights.Reshape(ctx, opts.numExpertsUsed, hiddenStates.Dim(1)) + routingWeights = routingWeights.Div(ctx, routingWeights.SumRows(ctx)) + routingWeights = routingWeights.Reshape(ctx, 1, opts.numExpertsUsed, hiddenStates.Dim(1)) + } + + hiddenStates = hiddenStates.Reshape(ctx, hiddenStates.Dim(0), 1, hiddenStates.Dim(1)) + + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates, selectedExperts).SILU(ctx, mlp.Up.Forward(ctx, 
hiddenStates, selectedExperts)) + + experts := mlp.Down.Forward(ctx, hiddenStates, selectedExperts) + experts = experts.Mul(ctx, routingWeights) + + nextStates := experts.View(ctx, 0, experts.Dim(0), experts.Stride(2), experts.Dim(2)) + for i := 1; i < opts.numExpertsUsed; i++ { + nextStates = nextStates.Add(ctx, experts.View(ctx, i*experts.Stride(1), experts.Dim(0), experts.Stride(2), experts.Dim(2))) + } + + return nextStates +} + +type dense struct { + Gate *nn.Linear `gguf:"ffn_gate"` + Up *nn.Linear `gguf:"ffn_up"` + Down *nn.Linear `gguf:"ffn_down"` +} + +func (mlp *dense) Forward(ctx ml.Context, hiddenStates ml.Tensor, _ *TextOptions) ml.Tensor { + hiddenStates = mlp.Gate.Forward(ctx, hiddenStates).SILU(ctx, mlp.Up.Forward(ctx, hiddenStates)) + return mlp.Down.Forward(ctx, hiddenStates) +} + +type TextLayer struct { + AttentionNorm *nn.RMSNorm `gguf:"attn_norm"` + *TextAttention + + MLPNorm *nn.RMSNorm `gguf:"ffn_norm"` + TextMLP +} + +func (d *TextLayer) Forward(ctx ml.Context, hiddenStates, cos, sin, outputs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor { + residual := hiddenStates + hiddenStates = d.AttentionNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.TextAttention.Forward(ctx, hiddenStates, cos, sin, cache, opts) + + if outputs != nil { + hiddenStates = hiddenStates.Rows(ctx, outputs) + residual = residual.Rows(ctx, outputs) + } + + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = d.MLPNorm.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = d.TextMLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type TextModel struct { + TokenEmbedding *nn.Embedding `gguf:"token_embd"` + OutputNorm *nn.RMSNorm `gguf:"output_norm"` + Output *nn.Linear `gguf:"output,alt:token_embd"` + + Layers []TextLayer `gguf:"blk"` + + Options *TextOptions +} + +func (m *TextModel) rotaryEmbedding(ctx ml.Context, positions ml.Tensor) (_, _ ml.Tensor) { + positions = positions.Reshape(ctx, 1, positions.Dim(0), positions.Dim(1)) + if len(m.Options.inverseFrequenciesCache) == 0 { + m.Options.inverseFrequenciesCache = make([]float32, m.Options.headDim()/2) + for i := range m.Options.inverseFrequenciesCache { + frequency := float32(math.Pow(float64(m.Options.ropeBase), float64(i*2)/float64(m.Options.headDim()))) + m.Options.inverseFrequenciesCache[i] = 1 / frequency + } + } + + inverseFrequencies := ctx.Input().FromFloats(m.Options.inverseFrequenciesCache, 1, len(m.Options.inverseFrequenciesCache)) + + positions = positions.Cast(ctx, ml.DTypeF32) + frequencies := inverseFrequencies.Mulmat(ctx, positions) + + interleaved := frequencies.View(ctx, + 0, frequencies.Dim(0), + frequencies.Stride(1), frequencies.Dim(1), + ) + + for _, i := range []int{1, 2} { + args := []int{ + i * frequencies.Stride(0), 1, + 3 * frequencies.Stride(0), m.Options.mropeSections[i], + frequencies.Stride(1), frequencies.Dim(1), + } + + ctx.Forward(frequencies.View(ctx, i*frequencies.Stride(2)+args[0], args[1:]...). 
+ Copy(ctx, interleaved.View(ctx, args[0], args[1:]...))) + } + + interleaved = interleaved.Concat(ctx, interleaved, 0) + interleaved = interleaved.Reshape(ctx, interleaved.Dim(0), 1, interleaved.Dim(1), interleaved.Dim(2)) + return interleaved.Cos(ctx), interleaved.Sin(ctx) +} + +var _ model.Model = (*Model)(nil) + +func newTextModel(c fs.Config) *TextModel { + layers := make([]TextLayer, c.Uint("block_count")) + for i := range layers { + if strings.HasSuffix(c.String("general.architecture"), "moe") { + layers[i].TextMLP = &sparse{} + } else { + layers[i].TextMLP = &dense{} + } + } + + m := TextModel{ + Layers: layers, + Options: &TextOptions{ + hiddenSize: int(c.Uint("embedding_length")), + numHeads: int(c.Uint("attention.head_count")), + numKVHeads: int(c.Uint("attention.head_count_kv")), + keyLength: int(c.Uint("attention.key_length")), + valueLength: int(c.Uint("attention.value_length")), + eps: c.Float("attention.layer_norm_rms_epsilon"), + ropeBase: c.Float("rope.freq_base"), + ropeScale: c.Float("rope.scaling.factor", 1), + numExperts: int(c.Uint("expert_count")), + numExpertsUsed: int(c.Uint("expert_used_count")), + normTopKProb: c.Bool("norm_top_k_prob", true), + mropeSections: slices.Collect(func(yield func(int) bool) { + for _, section := range c.Ints("mrope_sections", []int32{24, 20, 20}) { + if !yield(int(section)) { + return + } + } + }), + }, + } + + return &m +} diff --git a/model/models/qwen3vl/model_vision.go b/model/models/qwen3vl/model_vision.go new file mode 100644 index 0000000000..69118666bf --- /dev/null +++ b/model/models/qwen3vl/model_vision.go @@ -0,0 +1,268 @@ +package qwen3vl + +import ( + "iter" + "math" + "slices" + + "github.com/ollama/ollama/fs" + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/ml/nn" +) + +type VisionAttention struct { + Query *nn.Linear `gguf:"attn_q"` + Key *nn.Linear `gguf:"attn_k"` + Value *nn.Linear `gguf:"attn_v"` + Output *nn.Linear `gguf:"attn_out"` +} + +func rotateHalf(ctx ml.Context, t ml.Tensor) ml.Tensor { + x1 := t.View(ctx, 0, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)) + x2 := t.View(ctx, t.Stride(0)*t.Dim(0)/2, t.Dim(0)/2, t.Stride(1), t.Dim(1), t.Stride(2), t.Dim(2), t.Stride(3), t.Dim(3)).Contiguous(ctx) + return x2.Scale(ctx, -1).Concat(ctx, x1, 0) +} + +func applyRotaryPositionalEmbedding(ctx ml.Context, t, cos, sin ml.Tensor) ml.Tensor { + return t.Mul(ctx, cos).Add(ctx, rotateHalf(ctx, t).Mul(ctx, sin)) +} + +func (sa *VisionAttention) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor { + query := sa.Query.Forward(ctx, hiddenStates) + query = query.Reshape(ctx, opts.headDim(), opts.numHeads, query.Dim(1)) + query = applyRotaryPositionalEmbedding(ctx, query, cos, sin) + + key := sa.Key.Forward(ctx, hiddenStates) + key = key.Reshape(ctx, opts.headDim(), opts.numHeads, key.Dim(1)) + key = applyRotaryPositionalEmbedding(ctx, key, cos, sin) + + value := sa.Value.Forward(ctx, hiddenStates) + value = value.Reshape(ctx, opts.headDim(), opts.numHeads, value.Dim(1)) + + attention := nn.Attention(ctx, query, key, value, math.Pow(float64(opts.headDim()), -0.5), nil) + attention = attention.Reshape(ctx, opts.hiddenSize, attention.Dim(2)) + return sa.Output.Forward(ctx, attention) +} + +type VisionMLP struct { + FC1 *nn.Linear `gguf:"linear_fc1"` + FC2 *nn.Linear `gguf:"linear_fc2"` +} + +func (mlp *VisionMLP) Forward(ctx ml.Context, hiddenStates ml.Tensor, opts VisionOptions) ml.Tensor { + return mlp.FC2.Forward(ctx, mlp.FC1.Forward(ctx, 
hiddenStates).GELU(ctx)) +} + +type VisionEncoderLayer struct { + Norm1 *nn.LayerNorm `gguf:"norm1"` + Attention *VisionAttention + Norm2 *nn.LayerNorm `gguf:"norm2"` + MLP *VisionMLP `gguf:"mlp"` +} + +func (e *VisionEncoderLayer) Forward(ctx ml.Context, hiddenStates, cos, sin ml.Tensor, opts VisionOptions) ml.Tensor { + residual := hiddenStates + hiddenStates = e.Norm1.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.Attention.Forward(ctx, hiddenStates, cos, sin, opts) + hiddenStates = hiddenStates.Add(ctx, residual) + + residual = hiddenStates + hiddenStates = e.Norm2.Forward(ctx, hiddenStates, opts.eps) + hiddenStates = e.MLP.Forward(ctx, hiddenStates, opts) + return hiddenStates.Add(ctx, residual) +} + +type VisionOptions struct { + hiddenSize, + numHeads, + patchSize, + numChannels, + spatialMergeSize, + temporalPatchSize, + gridPerSide int + + eps, + ropeTheta float32 + + deepstackVisualIndexes []int32 + mropeSections []int +} + +func (o VisionOptions) headDim() int { + return o.hiddenSize / o.numHeads +} + +type VisionPatchMerger struct { + Norm *nn.LayerNorm `gguf:"norm"` + FC1 *nn.Linear `gguf:"linear_fc1"` + FC2 *nn.Linear `gguf:"linear_fc2"` +} + +func (m *VisionPatchMerger) Forward(ctx ml.Context, visionOutputs ml.Tensor, postshuffleNorm bool, opts VisionOptions) ml.Tensor { + hiddenSize := opts.hiddenSize * opts.spatialMergeSize * opts.spatialMergeSize + if postshuffleNorm { + visionOutputs = visionOutputs.Reshape(ctx, hiddenSize, -1) + } + + visionOutputs = m.Norm.Forward(ctx, visionOutputs, opts.eps) + visionOutputs = visionOutputs.Reshape(ctx, hiddenSize, -1) + return m.FC2.Forward(ctx, m.FC1.Forward(ctx, visionOutputs).GELU(ctx)) +} + +type VisionPositionEmbedding struct { + PositionEmbedding *nn.Embedding `gguf:"pos_embed"` +} + +func makeSlice2D[T int32 | float32](n0, n1 int) iter.Seq[[]T] { + return func(yield func([]T) bool) { + for range n0 { + if !yield(make([]T, n1)) { + return + } + } + } +} + +func (m *VisionPositionEmbedding) Forward(ctx ml.Context, hiddenStates ml.Tensor, grid *Grid, opts VisionOptions) ml.Tensor { + indexSlice := slices.Collect(makeSlice2D[int32](4, grid.Height*grid.Width)) + weightSlice := slices.Collect(makeSlice2D[float32](4, grid.Height*grid.Width)) + + stepHeight := float32(opts.gridPerSide-1) / float32(grid.Height-1) + stepWidth := float32(opts.gridPerSide-1) / float32(grid.Width-1) + + var i int + for h := range grid.Height { + for w := range grid.Width { + y, x := float32(h)*stepHeight, float32(w)*stepWidth + + floorY, floorX := int32(y), int32(x) + ceilY, ceilX := min(floorY+1, int32(opts.gridPerSide-1)), min(floorX+1, int32(opts.gridPerSide-1)) + + indexSlice[0][i] = floorY*int32(opts.gridPerSide) + floorX + indexSlice[1][i] = floorY*int32(opts.gridPerSide) + ceilX + indexSlice[2][i] = ceilY*int32(opts.gridPerSide) + floorX + indexSlice[3][i] = ceilY*int32(opts.gridPerSide) + ceilX + + weightSlice[0][i] = (1 - (y - float32(floorY))) * (1 - (x - float32(floorX))) + weightSlice[1][i] = (1 - (y - float32(floorY))) * (x - float32(floorX)) + weightSlice[2][i] = (y - float32(floorY)) * (1 - (x - float32(floorX))) + weightSlice[3][i] = (y - float32(floorY)) * (x - float32(floorX)) + + i++ + } + } + + indices := ctx.Input().FromInts(slices.Concat(indexSlice...), grid.Height*grid.Width*4) + weights := ctx.Input().FromFloats(slices.Concat(weightSlice...), 1, grid.Height*grid.Width*4) + + n := hiddenStates.Dim(0) + positionEmbeds := m.PositionEmbedding.Forward(ctx, indices) + positionEmbeds = positionEmbeds.Mul(ctx, weights) + 
positionEmbeds = positionEmbeds.Reshape(ctx, n, -1, 4) + + positionEmbeds = positionEmbeds.View(ctx, 0, n, positionEmbeds.Stride(1), grid.Height*grid.Width). + Add(ctx, positionEmbeds.View(ctx, 1*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)). + Add(ctx, positionEmbeds.View(ctx, 2*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)). + Add(ctx, positionEmbeds.View(ctx, 3*positionEmbeds.Stride(2), n, positionEmbeds.Stride(1), grid.Height*grid.Width)) + + positionEmbeds = positionEmbeds.Reshape(ctx, -1, grid.Width/opts.spatialMergeSize, opts.spatialMergeSize, grid.Height/opts.spatialMergeSize) + positionEmbeds = positionEmbeds.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx, n, -1) + return hiddenStates.Add(ctx, positionEmbeds) +} + +type VisionModel struct { + PatchEmbedding *nn.Conv3D `gguf:"patch_embed"` + PositionEmbedding *VisionPositionEmbedding + Layers []VisionEncoderLayer `gguf:"blk"` + PatchMerger *VisionPatchMerger `gguf:"merger"` + DeepstackMerger []*VisionPatchMerger `gguf:"deepstack_merger"` + + VisionOptions +} + +func (m *VisionModel) positions(ctx ml.Context, grid *Grid) (_, _ ml.Tensor) { + indices := ctx.Input().FromInts(slices.Collect(func(yield func(int32) bool) { + for y := range grid.Height { + for x := range grid.Width { + if !yield(int32(y)) { + return + } + if !yield(int32(x)) { + return + } + } + } + }), grid.Width*grid.Height*2) + + indices = indices.Reshape(ctx, -1, grid.Width/m.spatialMergeSize, m.spatialMergeSize, grid.Height/m.spatialMergeSize) + indices = indices.Permute(ctx, 0, 2, 1, 3).Contiguous(ctx) + indices = indices.Reshape(ctx, -1) + + halfDim := m.headDim() / 2 + maxGrid := max(grid.Height, grid.Width) + frequencies := ctx.Input().FromFloats(slices.Collect(func(yield func(float32) bool) { + ropeTheta := float64(m.ropeTheta) + for i := range maxGrid { + for j := range halfDim / 2 { + if !yield(float32(i) / float32(math.Pow(ropeTheta, float64(j*2)/float64(halfDim)))) { + return + } + } + } + }), halfDim/2, maxGrid) + + embeds := frequencies.Rows(ctx, indices) + embeds = embeds.Reshape(ctx, halfDim, 1, -1) + embeds = embeds.Concat(ctx, embeds, 0) + return embeds.Cos(ctx), embeds.Sin(ctx) +} + +// Forward computes the vision model for an input tensor +func (m *VisionModel) Forward(ctx ml.Context, pixelValues ml.Tensor, grid *Grid) (ml.Tensor, []ml.Tensor) { + pixelValues = pixelValues.Reshape(ctx, m.patchSize, m.patchSize, m.temporalPatchSize, -1) + hiddenStates := m.PatchEmbedding.Forward(ctx, pixelValues, m.numChannels, m.patchSize, m.patchSize, m.temporalPatchSize, 0, 0, 0, 1, 1, 1) + hiddenStates = m.PositionEmbedding.Forward(ctx, hiddenStates, grid, m.VisionOptions) + + cos, sin := m.positions(ctx, grid) + + deepstackStates := make([]ml.Tensor, len(m.deepstackVisualIndexes)) + for i, layer := range m.Layers { + hiddenStates = layer.Forward(ctx, hiddenStates, cos, sin, m.VisionOptions) + if i := slices.Index(m.deepstackVisualIndexes, int32(i)); i >= 0 { + deepstackStates[i] = m.DeepstackMerger[i].Forward(ctx, hiddenStates, true, m.VisionOptions) + } + } + + hiddenStates = m.PatchMerger.Forward(ctx, hiddenStates, false, m.VisionOptions) + return hiddenStates, deepstackStates +} + +// newVisionModel creates a new instance of the Qwen vision model +func newVisionModel(c fs.Config) *VisionModel { + deepstackVisualIndexes := c.Ints("vision.deepstack_visual_indexes") + model := &VisionModel{ + Layers: make([]VisionEncoderLayer, c.Uint("vision.block_count", 32)), + DeepstackMerger: 
make([]*VisionPatchMerger, len(deepstackVisualIndexes)), + VisionOptions: VisionOptions{ + hiddenSize: int(c.Uint("vision.embedding_length", 1280)), + numHeads: int(c.Uint("vision.attention.head_count", 16)), + patchSize: int(c.Uint("vision.patch_size", 14)), + numChannels: int(c.Uint("vision.num_channels", 3)), + eps: c.Float("vision.attention.layer_norm_epsilon", 1e-6), + ropeTheta: c.Float("vision.rope.freq_base", 10000.0), + spatialMergeSize: int(c.Uint("vision.spatial_merge_size", 2)), + temporalPatchSize: int(c.Uint("vision.temporal_patch_size", 2)), + gridPerSide: int(math.Sqrt(float64(c.Uint("vision.num_positional_embeddings", 2304)))), + mropeSections: slices.Collect(func(yield func(int) bool) { + for _, section := range c.Ints("mrope_sections", []int32{24, 20, 20}) { + if !yield(int(section)) { + return + } + } + }), + deepstackVisualIndexes: deepstackVisualIndexes, + }, + } + + return model +} diff --git a/runner/ollamarunner/cache.go b/runner/ollamarunner/cache.go index a3ffc3bd29..faab1b229c 100644 --- a/runner/ollamarunner/cache.go +++ b/runner/ollamarunner/cache.go @@ -235,15 +235,28 @@ func countCommonPrefix(a []*input.Input, b []*input.Input) int32 { return count } -// TODO(jessegross): If we need to reprocess the inputs we should ensure that -// we don't split up a SameBatch -func (c *InputCache) ShiftDiscard(inputLen int32, numKeep int32) int32 { - targetFree := (c.numCtx - numKeep) / 2 - targetFree = max(targetFree, 1) +// ShiftDiscard computes how many inputs can be discarded from the cache. Inputs in the same batch +// are discarded together. +func (c *InputCache) ShiftDiscard(inputs []*input.Input, numKeep int32) int32 { + targetFree := max((c.numCtx-numKeep)/2, 1) + currentFree := c.numCtx - int32(len(inputs)) - currentFree := c.numCtx - inputLen + var discard, sameBatch int32 + for _, input := range inputs[numKeep:] { + if sameBatch <= 0 && currentFree >= targetFree { + break + } - return max(targetFree-currentFree, 0) + sameBatch-- + currentFree++ + discard++ + + if input.SameBatch > 0 { + sameBatch = int32(input.SameBatch) + } + } + + return discard } type ErrReprocessInputs struct { @@ -264,7 +277,7 @@ func (c *InputCache) ShiftCacheSlot(slot *InputCacheSlot, numKeep int32) error { } inputLen := int32(len(slot.Inputs)) - discard := c.ShiftDiscard(inputLen, numKeep) + discard := c.ShiftDiscard(slot.Inputs, numKeep) if discard <= 0 { return nil diff --git a/runner/ollamarunner/cache_test.go b/runner/ollamarunner/cache_test.go index c0693e8343..d78727e766 100644 --- a/runner/ollamarunner/cache_test.go +++ b/runner/ollamarunner/cache_test.go @@ -3,6 +3,7 @@ package ollamarunner import ( "errors" "fmt" + "slices" "testing" "time" @@ -238,59 +239,137 @@ func TestShiftDiscard(t *testing.T) { name string numCtx int32 numKeep int32 - inputLen int32 + inputs []*input.Input expected int32 }{ { name: "Shift", numCtx: 2048, numKeep: 5, - inputLen: 2048, + inputs: slices.Repeat([]*input.Input{{}}, 2048), expected: 1021, }, { name: "Max Keep", numCtx: 2048, numKeep: 2047, - inputLen: 2048, + inputs: slices.Repeat([]*input.Input{{}}, 2048), expected: 1, }, { name: "No Keep", numCtx: 2048, numKeep: 0, - inputLen: 2048, + inputs: slices.Repeat([]*input.Input{{}}, 2048), expected: 1024, }, { name: "Truncate", numCtx: 2048, numKeep: 5, - inputLen: 5000, + inputs: slices.Repeat([]*input.Input{{}}, 5000), expected: 3973, }, { name: "Truncate Keep", numCtx: 2048, numKeep: 2047, - inputLen: 5000, + inputs: slices.Repeat([]*input.Input{{}}, 5000), expected: 2953, }, { name: "No Op", 
numCtx: 2048, numKeep: 5, - inputLen: 512, + inputs: slices.Repeat([]*input.Input{{}}, 512), expected: 0, }, + { + name: "Same Batch", + numCtx: 2048, + numKeep: 5, + inputs: slices.Collect(func(yield func(*input.Input) bool) { + for range 1024 { + if !yield(&input.Input{}) { + return + } + } + + if !yield(&input.Input{SameBatch: 512 - 1}) { + return + } + + for range 2048 - 1024 - 1 { + if !yield(&input.Input{}) { + return + } + } + }), + expected: 1531, + }, + { + name: "Same Batch Near Start", + numCtx: 2048, + numKeep: 5, + inputs: slices.Collect(func(yield func(*input.Input) bool) { + for range 10 { + if !yield(&input.Input{}) { + return + } + } + + if !yield(&input.Input{SameBatch: 512 - 1}) { + return + } + + for range 2048 - 10 - 1 { + if !yield(&input.Input{}) { + return + } + } + }), + expected: 1021, + }, + { + name: "Consecutive Same Batch", + numCtx: 32, + inputs: slices.Collect(func(yield func(*input.Input) bool) { + for i := range 32 { + input := input.Input{} + if i%10 == 0 { + input.SameBatch = 10 - 1 + } + if !yield(&input) { + return + } + } + }), + expected: 20, + }, + { + name: "Overlapping Same Batch", + numCtx: 32, + inputs: slices.Collect(func(yield func(*input.Input) bool) { + for i := range 32 { + input := input.Input{} + if slices.Contains([]int{4, 8, 14}, i) { + input.SameBatch = 10 - 1 + } + if !yield(&input) { + return + } + } + }), + expected: 24, + }, } for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { c := InputCache{numCtx: tt.numCtx} - result := c.ShiftDiscard(tt.inputLen, tt.numKeep) + result := c.ShiftDiscard(tt.inputs, tt.numKeep) if result != tt.expected { - t.Errorf("shiftDiscard(ctx: %v, keep: %v input: %v): have %v; want %v", tt.numCtx, tt.numKeep, tt.inputLen, result, tt.expected) + t.Errorf("shiftDiscard(ctx: %v, keep: %v inputs: %v): have %v; want %v", tt.numCtx, tt.numKeep, len(tt.inputs), result, tt.expected) } }) } diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index e977d18fd7..153a3e576e 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -214,7 +214,6 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, parts = []string{prompt} } - postTokenize := false for i, part := range parts { // text - tokenize tokens, err := s.model.(model.TextProcessor).Encode(part, i == 0) @@ -257,11 +256,10 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]*input.Input, mmStore.addMultimodal(imageEmbeddings) inputs = append(inputs, &input.Input{Multimodal: imageEmbeddings, MultimodalHash: imageHash}) - postTokenize = true } } - if visionModel && postTokenize { + if visionModel { var err error inputs, err = multimodalProcessor.PostTokenize(inputs) if err != nil { diff --git a/server/routes.go b/server/routes.go index 3d32a9aad9..5b4d5f5db4 100644 --- a/server/routes.go +++ b/server/routes.go @@ -142,7 +142,10 @@ func (s *Server) scheduleRunner(ctx context.Context, name string, caps []model.C // This model is much more capable with a larger context, so set that // unless it would penalize performance too much - if !s.lowVRAM && slices.Contains([]string{"gptoss", "gpt-oss"}, model.Config.ModelFamily) { + if !s.lowVRAM && slices.Contains([]string{ + "gptoss", "gpt-oss", + "qwen3vl", "qwen3vlmoe", + }, model.Config.ModelFamily) { opts.NumCtx = max(opts.NumCtx, 8192) } diff --git a/server/sched.go b/server/sched.go index e262d26fd0..1c04047ef1 100644 --- a/server/sched.go +++ b/server/sched.go @@ -390,11 +390,11 @@ func (s *Scheduler) load(req 
*LlmRequest, f *ggml.GGML, systemInfo ml.SystemInfo
 	numParallel = 1
 	}

-	// `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1
+	// `mllama`, `qwen3vl`, and `qwen3vlmoe` are snowflakes and use an encoder cache which cannot be used with num_parallel > 1
 	// ref: https://github.com/ollama/ollama/issues/4165
-	if slices.Contains(req.model.Config.ModelFamilies, "mllama") && numParallel != 1 {
+	if slices.Contains([]string{"mllama", "qwen3vl", "qwen3vlmoe"}, req.model.Config.ModelFamily) && numParallel != 1 {
 		numParallel = 1
-		slog.Warn("mllama does not currently support parallel requests")
+		slog.Warn("model architecture does not currently support parallel requests", "architecture", req.model.Config.ModelFamily)
 	}

 	sessionDuration := envconfig.KeepAlive()
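
Notes on the mechanisms above follow. The sketches are standalone illustrations; helper names such as inferDim, smartResize, mropePositions, topKRouting, corners, and shiftDiscard are local stand-ins for this discussion, not identifiers from the tree.

Reshape and Contiguous now accept a single -1 dimension, which inferShape resolves from the tensor's total element count before dispatching to ggml. A minimal sketch of the same rule over a plain element count:

package main

import "fmt"

// inferDim mirrors the -1 inference rule from ml/backend/ggml/ggml.go:
// given the source tensor's total element count, fill in at most one -1
// entry, panicking on ambiguous or indivisible shapes.
func inferDim(total int, shape []int) {
	dim := -1
	for i, s := range shape {
		switch s {
		case -1:
			if dim != -1 {
				panic("only one dimension can be inferred")
			}
			dim = i
		case 0:
			panic("dimension cannot be zero")
		default:
			if total%s != 0 {
				panic("cannot infer dimension")
			}
			total /= s
		}
	}
	if dim != -1 {
		shape[dim] = total
	}
}

func main() {
	shape := []int{2, -1, 4} // source tensor holds 2*3*4 = 24 elements
	inferDim(24, shape)
	fmt.Println(shape) // [2 3 4]
}

Divisibility is checked dimension by dimension, so []int{2, -1, 2, 4} over 24 elements panics even though a -1 is present; the "indivisible infer" case in ggml_test.go exercises exactly that.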
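
SmartResize snaps each image side to a multiple of factor = patchSize * mergeSize (28 with the defaults in newImageProcessor), then rescales by a square-root area ratio when the result overflows longestEdge (pinned at 2<<20 pixels by the FIXME above) or underflows shortestEdge (64<<10). A sketch of the arithmetic with the validation panics omitted, plus one worked example:

package main

import (
	"fmt"
	"math"
)

// smartResize mirrors ImageProcessor.SmartResize: snap each side to a
// multiple of factor, then rescale so the pixel area stays within bounds.
func smartResize(height, width, factor, shortestEdge, longestEdge int) (int, int) {
	round := func(x float64) int { return int(math.RoundToEven(x)) }

	hBar := round(float64(height)/float64(factor)) * factor
	wBar := round(float64(width)/float64(factor)) * factor

	if hBar*wBar > longestEdge {
		beta := math.Sqrt(float64(height*width) / float64(longestEdge))
		hBar = int(math.Floor(float64(height)/beta/float64(factor))) * factor
		wBar = int(math.Floor(float64(width)/beta/float64(factor))) * factor
	} else if hBar*wBar < shortestEdge {
		beta := math.Sqrt(float64(shortestEdge) / float64(height*width))
		hBar = int(math.Ceil(float64(height)*beta/float64(factor))) * factor
		wBar = int(math.Ceil(float64(width)*beta/float64(factor))) * factor
	}

	return hBar, wBar
}

func main() {
	// Defaults from this change: factor = 14*2, shortest = 64<<10, longest = 2<<20.
	h, w := smartResize(3000, 4000, 28, 64<<10, 2<<20)
	fmt.Println(h, w)               // 1232 1652
	fmt.Println(h/14, w/14)         // 88 118 patches per side at patch_size 14
	fmt.Println((h / 28) * (w / 28)) // 2596 tokens after the 2x2 spatial merge
}

So a 3000x4000 photo lands on 1232x1652 pixels, an 88x118 patch grid, and 44*59 = 2596 image tokens after the merge.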
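
The text model applies M-RoPE: every token carries three position ids (temporal, height, width), split across the rotary frequency pairs by mropeSections ([24, 20, 20] by default). Model.Forward assigns the same id to all three channels for text tokens; for image patches it offsets the height and width channels by the patch's row and column within the merged grid. A simplified sketch of that assignment, which ignores the vision start/end markers and the position jump after an image, so treat it as an illustration of the data layout rather than a drop-in reimplementation:

package main

import "fmt"

// mropePositions sketches how the three M-RoPE position channels are filled:
// text tokens advance all channels together, while every patch of one image
// shares a temporal position and offsets height/width by its grid location.
func mropePositions(textBefore, gridH, gridW int) [3][]int32 {
	var pos [3][]int32
	add := func(t, h, w int32) {
		pos[0] = append(pos[0], t)
		pos[1] = append(pos[1], h)
		pos[2] = append(pos[2], w)
	}

	var next int32
	for range textBefore {
		add(next, next, next)
		next++
	}

	base := next // all patches of the image share one temporal position
	for i := range gridH * gridW {
		add(base, base+int32(i/gridW), base+int32(i%gridW))
	}
	return pos
}

func main() {
	pos := mropePositions(3, 2, 2) // 3 text tokens, then a 2x2 merged grid
	fmt.Println(pos[0])            // [0 1 2 3 3 3 3]
	fmt.Println(pos[1])            // [0 1 2 3 3 4 4]
	fmt.Println(pos[2])            // [0 1 2 3 4 3 4]
}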
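
For the MoE variants, the sparse MLP routes each token through expert_used_count experts: softmax over the router logits, top-k selection via TopK/Rows, and, when norm_top_k_prob is set, renormalization of the selected weights so they sum to 1 (the Div by SumRows above) before the expert outputs are mixed. The same arithmetic on plain slices:

package main

import (
	"fmt"
	"math"
	"sort"
)

// topKRouting sketches the sparse router: softmax over all experts, keep
// the top k, and optionally renormalize the kept weights to sum to 1.
func topKRouting(logits []float64, k int, normTopK bool) (idx []int, weights []float64) {
	// softmax over the router logits
	var sum float64
	probs := make([]float64, len(logits))
	for i, l := range logits {
		probs[i] = math.Exp(l)
		sum += probs[i]
	}
	for i := range probs {
		probs[i] /= sum
	}

	// indices of the k largest probabilities
	idx = make([]int, len(logits))
	for i := range idx {
		idx[i] = i
	}
	sort.Slice(idx, func(a, b int) bool { return probs[idx[a]] > probs[idx[b]] })
	idx = idx[:k]

	weights = make([]float64, k)
	var topSum float64
	for i, e := range idx {
		weights[i] = probs[e]
		topSum += probs[e]
	}
	if normTopK {
		for i := range weights {
			weights[i] /= topSum
		}
	}
	return idx, weights
}

func main() {
	idx, w := topKRouting([]float64{2, 1, 0.5, 3}, 2, true)
	fmt.Println(idx, w) // experts 3 and 0; kept weights renormalized to sum to 1
}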
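
VisionPositionEmbedding resamples a learned gridPerSide x gridPerSide table (48x48 under the 2304-entry default for vision.num_positional_embeddings) onto the actual patch grid with bilinear interpolation: each patch gathers the four surrounding table entries and mixes them with weights that sum to 1. The index and weight computation in isolation, with corners as a local stand-in:

package main

import "fmt"

// corners mirrors the bilinear sampling in VisionPositionEmbedding.Forward:
// map a patch-grid coordinate onto the learned table and return the four
// nearest table indices with their interpolation weights.
func corners(h, w, gridH, gridW, gridPerSide int) (idx [4]int, wt [4]float32) {
	stepY := float32(gridPerSide-1) / float32(gridH-1)
	stepX := float32(gridPerSide-1) / float32(gridW-1)

	y, x := float32(h)*stepY, float32(w)*stepX
	fy, fx := int(y), int(x)
	cy, cx := min(fy+1, gridPerSide-1), min(fx+1, gridPerSide-1)

	idx = [4]int{
		fy*gridPerSide + fx,
		fy*gridPerSide + cx,
		cy*gridPerSide + fx,
		cy*gridPerSide + cx,
	}
	dy, dx := y-float32(fy), x-float32(fx)
	wt = [4]float32{(1 - dy) * (1 - dx), (1 - dy) * dx, dy * (1 - dx), dy * dx}
	return idx, wt
}

func main() {
	// 88x118 patch grid (the 3000x4000 example above) against a 48x48 table.
	idx, wt := corners(10, 10, 88, 118, 48)
	fmt.Println(idx, wt) // the four weights always sum to 1
}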
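
On the runner side, ShiftDiscard used to free half of the non-kept context purely by count; it now refuses to stop inside a SameBatch run, so a cache shift can never cut an image embedding in half. A standalone sketch of the loop, reproducing the "Same Batch" test case: targetFree = (2048-5)/2 = 1021, but the input at index 1024 opens a 512-wide batch, so the discard extends from 1021 to 1531:

package main

import "fmt"

type Input struct{ SameBatch int }

// shiftDiscard mirrors InputCache.ShiftDiscard: free about half of the
// non-kept context, but never stop partway through a SameBatch group.
func shiftDiscard(numCtx, numKeep int, inputs []Input) int {
	targetFree := max((numCtx-numKeep)/2, 1)
	currentFree := numCtx - len(inputs)

	var discard, sameBatch int
	for _, in := range inputs[numKeep:] {
		if sameBatch <= 0 && currentFree >= targetFree {
			break
		}
		sameBatch--
		currentFree++
		discard++
		if in.SameBatch > 0 {
			sameBatch = in.SameBatch
		}
	}
	return discard
}

func main() {
	inputs := make([]Input, 2048)
	inputs[1024].SameBatch = 511 // an image whose embedding spans 512 inputs

	// A plain shift would discard (2048-5)/2 = 1021 inputs; carrying on
	// through the image batch pushes it to 1531.
	fmt.Println(shiftDiscard(2048, 5, inputs))
}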