From d2ec22371edba325903588b9515dd94b15b80d76 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Wed, 12 Mar 2025 16:08:24 -0700
Subject: [PATCH] count all vision tensors

---
 fs/ggml/ggml.go | 21 +++++++++------------
 1 file changed, 9 insertions(+), 12 deletions(-)

diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go
index 00392b4af..da3ee0a79 100644
--- a/fs/ggml/ggml.go
+++ b/fs/ggml/ggml.go
@@ -579,12 +579,16 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
 }
 
 func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
+	for name, layer := range llm.Tensors().GroupLayers() {
+		if strings.HasPrefix(name, "v.") {
+			for _, tensor := range layer {
+				weights += tensor.Size()
+			}
+		}
+	}
+
 	switch llm.KV().Architecture() {
 	case "mllama":
-		for _, layer := range llm.Tensors().GroupLayers()["v"] {
-			weights += layer.Size()
-		}
-
 		kv := func(n string) uint64 {
 			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
 				return uint64(v)
@@ -611,15 +615,8 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
 			embeddingLength*numPatches*maxNumTiles +
 			9*embeddingLength*numPaddedPatches*maxNumTiles +
 			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
-	case "gemma3":
-		for name, layer := range llm.Tensors().GroupLayers() {
-			if strings.HasPrefix(name, "v.") {
-				for _, tensor := range layer {
-					weights += tensor.Size()
-				}
-			}
-		}
 	}
+
 	return weights, graphSize
 }
 
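
For illustration only, not part of the patch: the first hunk hoists the "v."-prefix accumulation out of the per-architecture switch, so vision weights are counted for every model rather than only for "mllama" and "gemma3". Below is a minimal standalone Go sketch of that accounting, using a hypothetical tensor type and layer map in place of llm.Tensors().GroupLayers().

package main

import (
	"fmt"
	"strings"
)

// tensor is a hypothetical stand-in for the fs/ggml tensor type; Size() is in bytes.
type tensor struct {
	name string
	size uint64
}

func (t tensor) Size() uint64 { return t.size }

// visionWeights mirrors the new loop in VisionGraphSize: it sums every tensor
// that lives in a layer whose name starts with "v.", independent of architecture.
func visionWeights(layers map[string][]tensor) (weights uint64) {
	for name, layer := range layers {
		if strings.HasPrefix(name, "v.") {
			for _, t := range layer {
				weights += t.Size()
			}
		}
	}
	return weights
}

func main() {
	layers := map[string][]tensor{
		"v.blk.0": {{name: "attn_q.weight", size: 4 << 20}},
		"v.blk.1": {{name: "attn_k.weight", size: 4 << 20}},
		"blk.0":   {{name: "ffn_up.weight", size: 16 << 20}}, // text tensor, not counted
	}
	fmt.Println("vision weights (bytes):", visionWeights(layers))
}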