mirror of
https://github.com/ollama/ollama.git
synced 2025-04-12 21:59:22 +02:00
...
This commit is contained in:
parent
84b326ac3e
commit
a3e03aa240
@ -2,15 +2,17 @@ package convert
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"encoding/binary"
|
||||
"io"
|
||||
"log/slog"
|
||||
"strings"
|
||||
|
||||
"github.com/ollama/ollama/fs/ggml"
|
||||
"github.com/pdevine/tensor"
|
||||
"github.com/pdevine/tensor/native"
|
||||
"github.com/x448/float16"
|
||||
)
|
||||
|
||||
// Matches the structure in config.json for Qwen2.5-VL
|
||||
type qwen25vlModel struct {
|
||||
ModelParameters
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
@ -21,33 +23,16 @@ type qwen25vlModel struct {
|
||||
RopeTheta float32 `json:"rope_theta"`
|
||||
NumKeyValueHeads uint32 `json:"num_key_value_heads"`
|
||||
RMSNormEPS float32 `json:"rms_norm_eps"`
|
||||
// TieWordEmbeddings is often present, even if not used directly here
|
||||
TieWordEmbeddings bool `json:"tie_word_embeddings"`
|
||||
|
||||
// Vision specific parameters from its config (nested under vision_config)
|
||||
VisionConfig struct {
|
||||
HiddenSize uint32 `json:"hidden_size"`
|
||||
ImageSize uint32 `json:"image_size"`
|
||||
IntermediateSize uint32 `json:"intermediate_size"`
|
||||
LayerNormEps float32 `json:"layer_norm_eps"`
|
||||
NumAttentionHeads uint32 `json:"num_attention_heads"`
|
||||
NumChannels uint32 `json:"num_channels"`
|
||||
NumHiddenLayers uint32 `json:"num_hidden_layers"`
|
||||
PatchSize uint32 `json:"patch_size"`
|
||||
// May include others like projection_dim, use_cls_token etc.
|
||||
VisionModel struct {
|
||||
} `json:"vision_config"`
|
||||
// Might have top-level vision params too, check config.json
|
||||
// Example: ProjectorHiddenAct string `json:"projector_hidden_act"`
|
||||
}
|
||||
|
||||
// Compile-time check to ensure qwen25vlModel implements ModelConverter
|
||||
var _ ModelConverter = (*qwen25vlModel)(nil)
|
||||
|
||||
// KV provides the metadata key-value pairs for the GGUF header.
|
||||
func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv := q.ModelParameters.KV(t) // Assuming ModelParameters provides defaults like general.name etc.
|
||||
kv := q.ModelParameters.KV(t)
|
||||
kv["general.architecture"] = "qwen25vl"
|
||||
// Text model parameters
|
||||
kv["qwen25vl.block_count"] = q.HiddenLayers
|
||||
kv["qwen25vl.context_length"] = q.MaxPositionEmbeddings
|
||||
kv["qwen25vl.embedding_length"] = q.HiddenSize
|
||||
@ -57,50 +42,24 @@ func (q *qwen25vlModel) KV(t *Tokenizer) ggml.KV {
|
||||
kv["qwen25vl.rope.freq_base"] = q.RopeTheta
|
||||
kv["qwen25vl.attention.layer_norm_rms_epsilon"] = q.RMSNormEPS
|
||||
|
||||
// Vision model parameters (prefix with 'vision.')
|
||||
kv["qwen25vl.vision.hidden_size"] = q.VisionConfig.HiddenSize
|
||||
kv["qwen25vl.vision.image_size"] = q.VisionConfig.ImageSize
|
||||
kv["qwen25vl.vision.intermediate_size"] = q.VisionConfig.IntermediateSize
|
||||
kv["qwen25vl.vision.layer_norm_eps"] = q.VisionConfig.LayerNormEps
|
||||
kv["qwen25vl.vision.attention.head_count"] = q.VisionConfig.NumAttentionHeads
|
||||
kv["qwen25vl.vision.num_channels"] = q.VisionConfig.NumChannels // Usually 3
|
||||
kv["qwen25vl.vision.patch_size"] = q.VisionConfig.PatchSize
|
||||
kv["qwen25vl.vision.block_count"] = q.VisionConfig.NumHiddenLayers
|
||||
|
||||
// Add other relevant vision parameters if they exist in config.json
|
||||
// e.g., kv["qwen25vl.vision.projection_dim"] = q.VisionConfig.ProjectionDim
|
||||
|
||||
// Explicitly DO NOT set general.alignment here, rely on default handling
|
||||
// if the tensor data sizes written by WriteTo are correct.
|
||||
|
||||
return kv
|
||||
}
|
||||
|
||||
// Tensors processes the list of loaded tensors, handling specific cases like splitting.
|
||||
func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
var out []ggml.Tensor
|
||||
|
||||
for _, t := range ts {
|
||||
// Check if this tensor needs special handling
|
||||
if strings.HasSuffix(t.Name(), "patch_embed.proj.weight") {
|
||||
slog.Info("Splitting tensor", "name", t.Name())
|
||||
var buf bytes.Buffer
|
||||
// Write the original tensor data to a buffer
|
||||
if _, err := t.WriteTo(&buf); err != nil {
|
||||
panic(fmt.Sprintf("failed to read tensor %s for splitting: %v", t.Name(), err))
|
||||
|
||||
}
|
||||
// Perform the split
|
||||
t.WriteTo(&buf)
|
||||
newTensors := splitPatchEmbed(buf, t.Kind(), t.Shape())
|
||||
out = append(out, newTensors...)
|
||||
slog.Info("Finished splitting tensor", "name", t.Name(), "output_tensors", len(newTensors))
|
||||
} else {
|
||||
// Pass through other tensors directly
|
||||
out = append(out, ggml.Tensor{
|
||||
Name: t.Name(), // Name will be transformed by Replacements later
|
||||
Name: t.Name(),
|
||||
Kind: t.Kind(),
|
||||
Shape: t.Shape(),
|
||||
WriterTo: t, // Pass the original tensor object
|
||||
WriterTo: t,
|
||||
})
|
||||
}
|
||||
}
|
||||
@ -108,15 +67,12 @@ func (q *qwen25vlModel) Tensors(ts []Tensor) []ggml.Tensor {
|
||||
return out
|
||||
}
|
||||
|
||||
// Replacements provides the rules to rename tensors from the source format to the GGUF convention.
|
||||
func (q *qwen25vlModel) Replacements() []string {
|
||||
// Ensure these cover all transformations needed for both text and vision parts.
|
||||
// Use the list from your original code, adding vision specific ones if missing.
|
||||
func (p *qwen25vlModel) Replacements() []string {
|
||||
return []string{
|
||||
// Text model replacements
|
||||
"lm_head", "output",
|
||||
"model.embed_tokens", "token_embd",
|
||||
"model.layers", "blk",
|
||||
"visual.blocks", "v.blk",
|
||||
"input_layernorm", "attn_norm",
|
||||
"self_attn.k_proj", "attn_k",
|
||||
"self_attn.v_proj", "attn_v",
|
||||
@ -125,100 +81,96 @@ func (q *qwen25vlModel) Replacements() []string {
|
||||
"mlp.down_proj", "ffn_down",
|
||||
"mlp.gate_proj", "ffn_gate",
|
||||
"mlp.up_proj", "ffn_up",
|
||||
"post_attention_layernorm", "ffn_norm", // Check if Qwen2.5 uses post_attention_layernorm or pre/post FFN norm
|
||||
"post_attention_layernorm", "ffn_norm",
|
||||
"model.norm", "output_norm",
|
||||
|
||||
// Vision model replacements (adjust based on actual HF names)
|
||||
"visual.patch_embed.proj.weight", "v.patch_embed.proj.weight", // Base name for the split target
|
||||
"visual.patch_embed.norm", "v.patch_embed.norm", // If norm exists
|
||||
"visual.embed_tokens", "v.cls_token", // If CLS token exists
|
||||
"visual.blocks", "v.blk",
|
||||
"visual.norm", "v.post_norm", // Or v.norm depending on architecture
|
||||
// Vision layer specific replacements (these should already be covered by text ones if names are consistent)
|
||||
// e.g., within v.blk.*:
|
||||
// "layer_norm1", "attn_norm",
|
||||
// "attn.qkv", ... handle QKV split if needed ...
|
||||
// "attn.proj", "attn_output",
|
||||
// "layer_norm2", "ffn_norm",
|
||||
// "mlp.fc1", "ffn_gate", // Or combine ffn_gate/ffn_up if HF uses different names
|
||||
// "mlp.fc2", "ffn_down",
|
||||
|
||||
// Multi-modal projector replacements (if applicable)
|
||||
// "multi_modal_projector.linear_1", "mm_proj.0", // Example naming
|
||||
// "multi_modal_projector.linear_2", "mm_proj.2", // Example naming
|
||||
}
|
||||
}
|
||||
|
||||
func splitPatchEmbed(buf bytes.Buffer, kind uint32, shape []uint64) []ggml.Tensor {
|
||||
// Ensure shape is as expected (5D with third dimension of 2)
|
||||
if len(shape) != 5 || shape[2] != 2 {
|
||||
panic(fmt.Sprintf("splitPatchEmbed: expected 5D tensor with shape[2]==2, got shape %v", shape))
|
||||
slog.Debug("patch stuff", "kind", kind, "shape", shape)
|
||||
|
||||
if kind != tensorKindF16 {
|
||||
panic("tensor is of wrong type")
|
||||
}
|
||||
|
||||
// Calculate target shape (remove the third dimension)
|
||||
targetShape := append(shape[:2], shape[3:]...)
|
||||
|
||||
// Calculate tensor sizes
|
||||
elementSize := uint32(2) // F16 = 2 bytes per element
|
||||
if kind == tensorKindF32 {
|
||||
elementSize = 4 // F32 = 4 bytes per element
|
||||
if len(shape) != 5 || (len(shape) == 5 && shape[2] != 2) {
|
||||
panic("wrong sized tensor")
|
||||
}
|
||||
|
||||
// Calculate number of elements in each slice
|
||||
elementsPerSlice := uint64(1)
|
||||
for _, dim := range targetShape {
|
||||
elementsPerSlice *= dim
|
||||
// determine the size of the tensor based on its shape
|
||||
shapeToSize := func(s []int) int {
|
||||
r := 1
|
||||
for _, n := range s {
|
||||
r *= int(n)
|
||||
}
|
||||
return r
|
||||
}
|
||||
|
||||
// Calculate total elements in original tensor
|
||||
totalElements := elementsPerSlice * shape[2] // should be 2x the slice size
|
||||
|
||||
// Read all data from buffer
|
||||
data := make([]byte, totalElements*uint64(elementSize))
|
||||
if _, err := buf.Read(data); err != nil {
|
||||
panic(fmt.Sprintf("splitPatchEmbed: failed to read data: %v", err))
|
||||
// tensor.WithShape() wants []int
|
||||
intShape := make([]int, len(shape))
|
||||
for i, v := range shape {
|
||||
intShape[i] = int(v)
|
||||
}
|
||||
|
||||
// Create the first tensor (slice 0)
|
||||
slice0Data := make([]byte, elementsPerSlice*uint64(elementSize))
|
||||
for i := uint64(0); i < elementsPerSlice; i++ {
|
||||
offset := i * uint64(elementSize)
|
||||
copy(slice0Data[offset:offset+uint64(elementSize)],
|
||||
data[offset:offset+uint64(elementSize)])
|
||||
u16s := make([]uint16, shapeToSize(intShape))
|
||||
if err := binary.Read(&buf, binary.LittleEndian, u16s); err != nil {
|
||||
panic("bad read")
|
||||
}
|
||||
|
||||
// Create the second tensor (slice 1)
|
||||
slice1Data := make([]byte, elementsPerSlice*uint64(elementSize))
|
||||
for i := uint64(0); i < elementsPerSlice; i++ {
|
||||
srcOffset := (elementsPerSlice + i) * uint64(elementSize)
|
||||
dstOffset := i * uint64(elementSize)
|
||||
copy(slice1Data[dstOffset:dstOffset+uint64(elementSize)],
|
||||
data[srcOffset:srcOffset+uint64(elementSize)])
|
||||
f32s := make([]float32, len(u16s))
|
||||
for i := range u16s {
|
||||
f32s[i] = float16.Frombits(u16s[i]).Float32()
|
||||
}
|
||||
|
||||
// Return the two tensors with names matching the Python implementation
|
||||
return []ggml.Tensor{
|
||||
{
|
||||
Name: "v.patch_embd.weight",
|
||||
Kind: kind,
|
||||
Shape: targetShape,
|
||||
WriterTo: &bytesWriterTo{data: slice0Data},
|
||||
},
|
||||
{
|
||||
Name: "v.patch_embd.weight.1",
|
||||
Kind: kind,
|
||||
Shape: targetShape,
|
||||
WriterTo: &bytesWriterTo{data: slice1Data},
|
||||
},
|
||||
newTensors := []ggml.Tensor{}
|
||||
|
||||
getDataFromSlice := func(f32s []float32, shape []int, s []tensor.Slice) patchEmbed {
|
||||
slog.Debug("getDataFromSlice", "num f32s", len(f32s), "shape", shape)
|
||||
n := tensor.New(tensor.WithShape(shape...), tensor.WithBacking(f32s))
|
||||
t, err := n.Slice(s...)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ts, err := native.SelectF32(t.Materialize().(*tensor.Dense), 0)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
slog.Debug("first vals", "val 1", ts[0][0], "val 2", ts[0][1], "val 3", ts[0][2])
|
||||
|
||||
f16s := make(patchEmbed, shapeToSize(shape))
|
||||
for r, row := range ts {
|
||||
for c, col := range row {
|
||||
f16s[r+c] = float16.Fromfloat32(col).Bits()
|
||||
}
|
||||
}
|
||||
|
||||
return f16s
|
||||
}
|
||||
|
||||
p := getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(0, 1, 1), nil, nil})
|
||||
newTensors = append(newTensors, ggml.Tensor{
|
||||
Name: "v.patch_embed.weight",
|
||||
Kind: kind,
|
||||
Shape: append(shape[:2], shape[3:]...),
|
||||
WriterTo: p,
|
||||
})
|
||||
|
||||
p = getDataFromSlice(f32s, intShape, []tensor.Slice{nil, nil, tensor.S(1, 2, 1), nil, nil})
|
||||
newTensors = append(newTensors, ggml.Tensor{
|
||||
Name: "v.patch_embed.weight.1",
|
||||
Kind: kind,
|
||||
Shape: append(shape[:2], shape[3:]...),
|
||||
WriterTo: p,
|
||||
})
|
||||
|
||||
return newTensors
|
||||
}
|
||||
|
||||
// bytesWriterTo adapts an in-memory byte slice to io.WriterTo.
type bytesWriterTo struct {
	data []byte
}

// patchEmbed is a sequence of 16-bit values that is serialized in
// little-endian byte order.
type patchEmbed []uint16

// WriteTo writes the wrapped bytes to w and returns the number of bytes
// written together with any error reported by w.
func (b *bytesWriterTo) WriteTo(w io.Writer) (int64, error) {
	n, err := w.Write(b.data)
	return int64(n), err
}

// WriteTo writes every value of t to w as two little-endian bytes. It returns
// the number of bytes actually written, as io.WriterTo requires, together
// with any error reported by w.
func (t patchEmbed) WriteTo(w io.Writer) (int64, error) {
	out := make([]byte, 2*len(t))
	for i, v := range t {
		binary.LittleEndian.PutUint16(out[2*i:], v)
	}

	n, err := w.Write(out)
	return int64(n), err
}
|
||||
|
@ -140,4 +140,5 @@ func (m *Model) Forward(ctx ml.Context, batch input.Batch) (ml.Tensor, error) {
|
||||
|
||||
func init() {
|
||||
model.Register("qwen25vl", New)
|
||||
model.Register("qwen2vl", New)
|
||||
}
|
||||
|
@ -1,6 +1,7 @@
|
||||
package qwen25vl
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"math"
|
||||
|
||||
"github.com/ollama/ollama/kvcache"
|
||||
@ -73,6 +74,9 @@ type SelfAttention struct {
|
||||
}
|
||||
|
||||
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *TextOptions) ml.Tensor {
|
||||
// fmt.Println(ml.Dump(ctx, sa.Query.Weight))
|
||||
// fmt.Println(ml.Dump(ctx, sa.Query.Bias))
|
||||
|
||||
batchSize := hiddenState.Dim(1)
|
||||
headDim := opts.hiddenSize / opts.numHeads
|
||||
|
||||
@ -144,6 +148,8 @@ func (l *Layer) Forward(ctx ml.Context, hiddenState, positionIDs, outputs ml.Ten
|
||||
}
|
||||
|
||||
func (m *TextModel) Forward(ctx ml.Context, inputs, positions, outputs ml.Tensor, batch input.Batch, cache kvcache.Cache) (ml.Tensor, error) {
|
||||
fmt.Println(ml.Dump(ctx, m.OutputNorm.Weight))
|
||||
|
||||
// Initial token embedding
|
||||
hiddenState := m.TokenEmbedding.Forward(ctx, inputs)
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user