diff --git a/convert/convert_qwen25vl.go b/convert/convert_qwen25vl.go
index 00b740421..193715bbb 100644
--- a/convert/convert_qwen25vl.go
+++ b/convert/convert_qwen25vl.go
@@ -2,15 +2,17 @@ package convert
 
 import (
 	"bytes"
-	"fmt"
+	"encoding/binary"
 	"io"
 	"log/slog"
 	"strings"
 
 	"github.com/ollama/ollama/fs/ggml"
+	"github.com/pdevine/tensor"
+	"github.com/pdevine/tensor/native"
+	"github.com/x448/float16"
 )
 
-// Matches the structure in config.json for Qwen2.5-VL
 type qwen25vlModel struct {
 	ModelParameters
 	HiddenSize uint32 `json:"hidden_size"`
@@ -21,33 +23,16 @@ type qwen25vlModel struct {
 	RopeTheta        float32 `json:"rope_theta"`
 	NumKeyValueHeads uint32  `json:"num_key_value_heads"`
 	RMSNormEPS       float32 `json:"rms_norm_eps"`
-	// TieWordEmbeddings is often present, even if not used directly here
-	TieWordEmbeddings bool `json:"tie_word_embeddings"`
-	// Vision specific parameters from its config (nested under vision_config)
-	VisionConfig struct {
-		HiddenSize        uint32  `json:"hidden_size"`
-		ImageSize         uint32  `json:"image_size"`
-		IntermediateSize  uint32  `json:"intermediate_size"`
-		LayerNormEps      float32 `json:"layer_norm_eps"`
-		NumAttentionHeads uint32  `json:"num_attention_heads"`
-		NumChannels       uint32  `json:"num_channels"`
-		NumHiddenLayers   uint32  `json:"num_hidden_layers"`
-		PatchSize         uint32  `json:"patch_size"`
-		// May include others like projection_dim, use_cls_token etc.
+	VisionModel struct {
 	} `json:"vision_config"`
-	// Might have top-level vision params too, check config.json
-	// Example: ProjectorHiddenAct string `json:"projector_hidden_act"`
 }
 
-// Compile-time check to ensure qwen25vlModel implements ModelConverter
 var _ ModelConverter = (*qwen25vlModel)(nil)
 
-// KV provides the metadata key-value pairs for the GGUF header.
 func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
-	kv := q.ModelParameters.KV(t) // Assuming ModelParameters provides defaults like general.name etc.
+	kv := q.ModelParameters.KV(t)
 	kv["general.architecture"] = "qwen25vl"
 
-	// Text model parameters
 	kv["qwen25vl.block_count"] = q.HiddenLayers
 	kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings
 	kv["qwen25vl.embedding_length"] = q.HiddenSize
@@ -57,50 +42,24 @@ func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
 	kv["qwen25vl.rope.freq_base"] = q.RopeTheta
 	kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
 
-	// Vision model parameters (prefix with 'vision.')
-	kv["qwen25vl.vision.hidden_size"] = q.VisionConfig.HiddenSize
-	kv["qwen25vl.vision.image_size"] = q.VisionConfig.ImageSize
-	kv["qwen25vl.vision.intermediate_size"] = q.VisionConfig.IntermediateSize
-	kv["qwen25vl.vision.layer_norm_eps"] = q.VisionConfig.LayerNormEps
-	kv["qwen25vl.vision.attention.head_count"] = q.VisionConfig.NumAttentionHeads
-	kv["qwen25vl.vision.num_channels"] = q.VisionConfig.NumChannels // Usually 3
-	kv["qwen25vl.vision.patch_size"] = q.VisionConfig.PatchSize
-	kv["qwen25vl.vision.block_count"] = q.VisionConfig.NumHiddenLayers
-
-	// Add other relevant vision parameters if they exist in config.json
-	// e.g., kv["qwen25vl.vision.projection_dim"] = q.VisionConfig.ProjectionDim
-
-	// Explicitly DO NOT set general.alignment here, rely on default handling
-	// if the tensor data sizes written by WriteTo are correct.
-
 	return kv
 }
 
-// Tensors processes the list of loaded tensors, handling specific cases like splitting.
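+// Tensors converts the loaded tensors for GGUF, splitting the vision patch
+// embedding weight into one tensor per temporal plane via splitPatchEmbed.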
 func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
 	var out []ggml.Tensor
 
 	for _, t := range ts {
-		// Check if this tensor needs special handling
 		if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
-			slog.Info("Splitting tensor", "name", t.Name())
 			var buf bytes.Buffer
-			// Write the original tensor data to a buffer
 			if _, err := t.WriteTo(&buf); err != nil {
-				panic(fmt.Sprintf("failed to read tensor %s for splitting: %v", t.Name(), err))
-
+				panic(err)
 			}
-			// Perform the split
 			newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
 			out = append(out, newTensors...)
-			slog.Info("Finished splitting tensor", "name", t.Name(), "output_tensors", len(newTensors))
 		} else {
-			// Pass through other tensors directly
 			out = append(out, ggml.Tensor{
-				Name:     t.Name(), // Name will be transformed by Replacements later
+				Name:     t.Name(),
 				Kind:     t.Kind(),
 				Shape:    t.Shape(),
-				WriterTo: t, // Pass the original tensor object
+				WriterTo: t,
 			})
 		}
 	}
@@ -108,15 +67,12 @@ func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
 	return out
 }
 
-// Replacements provides the rules to rename tensors from the source format to the GGUF convention.
 func (q *qwen25vlModel) Replacements() []string {
-	// Ensure these cover all transformations needed for both text and vision parts.
-	// Use the list from your original code, adding vision specific ones if missing.
 	return []string{
-		// Text model replacements
 		"lm_head", "output",
 		"model.embed_tokens", "token_embd",
 		"model.layers", "blk",
+		"visual.blocks", "v.blk",
 		"input_layernorm", "attn_norm",
 		"self_attn.k_proj", "attn_k",
 		"self_attn.v_proj", "attn_v",
@@ -125,100 +81,96 @@ func (q *qwen25vlModel) Replacements() []string {
 		"mlp.down_proj", "ffn_down",
 		"mlp.gate_proj", "ffn_gate",
 		"mlp.up_proj", "ffn_up",
-		"post_attention_layernorm", "ffn_norm", // Check if Qwen2.5 uses post_attention_layernorm or pre/post FFN norm
+		"post_attention_layernorm", "ffn_norm",
 		"model.norm", "output_norm",
-
-		// Vision model replacements (adjust based on actual HF names)
-		"visual.patch_embed.proj.weight", "v.patch_embed.proj.weight", // Base name for the split target
-		"visual.patch_embed.norm", "v.patch_embed.norm", // If norm exists
-		"visual.embed_tokens", "v.cls_token", // If CLS token exists
-		"visual.blocks", "v.blk",
-		"visual.norm", "v.post_norm", // Or v.norm depending on architecture
-		// Vision layer specific replacements (these should already be covered by text ones if names are consistent)
-		// e.g., within v.blk.*:
-		// "layer_norm1", "attn_norm",
-		// "attn.qkv", ... handle QKV split if needed ...
-		// "attn.proj", "attn_output",
-		// "layer_norm2", "ffn_norm",
-		// "mlp.fc1", "ffn_gate", // Or combine ffn_gate/ffn_up if HF uses different names
-		// "mlp.fc2", "ffn_down",
-
-		// Multi-modal projector replacements (if applicable)
-		// "multi_modal_projector.linear_1", "mm_proj.0", // Example naming
-		// "multi_modal_projector.linear_2", "mm_proj.2", // Example naming
 	}
 }
 
 func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor {
-	// Ensure shape is as expected (5D with third dimension of 2)
-	if len(shape) != 5 || shape[2] != 2 {
-		panic(fmt.Sprintf("splitPatchEmbed: expected 5D tensor with shape[2]==2, got shape %v", shape))
-	}
-
-	// Calculate target shape (remove the third dimension)
-	targetShape := append(shape[:2], shape[3:]...)
-
-	// Calculate tensor sizes
-	elementSize := uint32(2) // F16 = 2 bytes per element
-	if kind == tensorKindF32 {
-		elementSize = 4 // F32 = 4 bytes per element
-	}
-
-	// Calculate number of elements in each slice
-	elementsPerSlice := uint64(1)
-	for _, dim := range targetShape {
-		elementsPerSlice *= dim
-	}
-
-	// Calculate total elements in original tensor
-	totalElements := elementsPerSlice * shape[2] // should be 2x the slice size
-
-	// Read all data from buffer
-	data := make([]byte, totalElements*uint64(elementSize))
-	if _, err := buf.Read(data); err != nil {
-		panic(fmt.Sprintf("splitPatchEmbed: failed to read data: %v", err))
-	}
-
-	// Create the first tensor (slice 0)
-	slice0Data := make([]byte, elementsPerSlice*uint64(elementSize))
-	for i := uint64(0); i < elementsPerSlice; i++ {
-		offset := i * uint64(elementSize)
-		copy(slice0Data[offset:offset+uint64(elementSize)],
-			data[offset:offset+uint64(elementSize)])
-	}
-
-	// Create the second tensor (slice 1)
-	slice1Data := make([]byte, elementsPerSlice*uint64(elementSize))
-	for i := uint64(0); i < elementsPerSlice; i++ {
-		srcOffset := (elementsPerSlice + i) * uint64(elementSize)
-		dstOffset := i * uint64(elementSize)
-		copy(slice1Data[dstOffset:dstOffset+uint64(elementSize)],
-			data[srcOffset:srcOffset+uint64(elementSize)])
-	}
-
-	// Return the two tensors with names matching the Python implementation
-	return []ggml.Tensor{
-		{
-			Name:     "v.patch_embd.weight",
-			Kind:     kind,
-			Shape:    targetShape,
-			WriterTo: &bytesWriterTo{data: slice0Data},
-		},
-		{
-			Name:     "v.patch_embd.weight.1",
-			Kind:     kind,
-			Shape:    targetShape,
-			WriterTo: &bytesWriterTo{data: slice1Data},
-		},
-	}
-}
+	slog.Debug("splitting patch embedding", "kind", kind, "shape", shape)
+
+	if kind != tensorKindF16 {
+		panic("splitPatchEmbed: expected an F16 tensor")
+	}
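+
+	// The weight is assumed to be a 5D Conv3D kernel laid out as
+	// (out_channels, in_channels, temporal, patch_h, patch_w); Qwen2.5-VL
+	// uses a temporal patch size of 2, so each of the two temporal planes
+	// becomes its own GGUF tensor below.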
+	if len(shape) != 5 || shape[2] != 2 {
+		panic("splitPatchEmbed: expected a 5D tensor with shape[2] == 2")
+	}
+
+	// determine the number of elements in a tensor with the given shape
+	shapeToSize := func(s []int) int {
+		r := 1
+		for _, n := range s {
+			r *= n
+		}
+		return r
+	}
+
+	// tensor.WithShape() wants []int
+	intShape := make([]int, len(shape))
+	for i, v := range shape {
+		intShape[i] = int(v)
+	}
+
+	// decode the raw F16 payload to float32 so the tensor package can slice it
+	u16s := make([]uint16, shapeToSize(intShape))
+	if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil {
+		panic("splitPatchEmbed: failed to read tensor data: " + err.Error())
+	}
+
+	f32s := make([]float32, len(u16s))
+	for i := range u16s {
+		f32s[i] = float16.Frombits(u16s[i]).Float32()
+	}
+
+	newTensors := []ggml.Tensor{}
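+
+	// getDataFromSlice materializes one temporal plane of the tensor and
+	// re-encodes it as a little-endian F16 payload for the GGUF writer.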
+	getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed {
+		slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape)
+		n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s))
+		t, err := n.Slice(s...)
+		if err != nil {
+			panic(err)
+		}
+
+		ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0)
+		if err != nil {
+			panic(err)
+		}
+
+		// the slice keeps one of the two temporal planes, so the result
+		// holds half the elements of the full tensor
+		f16s := make(patchEmbed, 0, shapeToSize(shape)/2)
+		for _, row := range ts {
+			for _, col := range row {
+				f16s = append(f16s, float16.Fromfloat32(col).Bits())
+			}
+		}
+
+		return f16s
+	}
+
+	// dropping the temporal dimension yields the 4D shape of each plane;
+	// copy into a new slice so the append does not clobber shape's backing array
+	targetShape := append(append([]uint64{}, shape[:2]...), shape[3:]...)
+
+	p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil})
+	newTensors = append(newTensors, ggml.Tensor{
+		Name:     "v.patch_embed.weight",
+		Kind:     kind,
+		Shape:    targetShape,
+		WriterTo: p,
+	})
+
+	p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil})
+	newTensors = append(newTensors, ggml.Tensor{
+		Name:     "v.patch_embed.weight.1",
+		Kind:     kind,
+		Shape:    targetShape,
+		WriterTo: p,
+	})
+
+	return newTensors
 }
 
-// Helper type for writing bytes
-type bytesWriterTo struct {
-	data []byte
-}
+type patchEmbed []uint16
 
-func (b *bytesWriterTo) WriteTo(w io.Writer) (int64, error) {
-	n, err := w.Write(b.data)
-	return int64(n), err
+func (t patchEmbed) WriteTo(w io.Writer) (int64, error) {
+	if err := binary.Write(w, binary.LittleEndian, t); err != nil {
+		return 0, err
+	}
+	return int64(len(t) * 2), nil // 2 bytes per F16 element
 }
diff --git a/model/models/qwen25vl/model.go b/model/models/qwen25vl/model.go
index 213d89d2e..66b2ff195 100644
--- a/model/models/qwen25vl/model.go
+++ b/model/models/qwen25vl/model.go
@@ -140,4 +140,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
 
 func init() {
 	model.Register("qwen25vl", New)
+	// qwen2vl checkpoints are handled by the same model implementation
+	model.Register("qwen2vl", New)
 }