count all vision tensors

This commit is contained in:
Michael Yang 2025-03-12 16:08:24 -07:00
parent 033cec232a
commit d2ec22371e

View File

@ -579,12 +579,16 @@ func (f GGML) GraphSize(context, batch uint64, kvCacheType string) (kv, partialO
}
func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
for name, layer := range llm.Tensors().GroupLayers() {
if strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
switch llm.KV().Architecture() {
case "mllama":
for _, layer := range llm.Tensors().GroupLayers()["v"] {
weights += layer.Size()
}
kv := func(n string) uint64 {
if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
return uint64(v)
@ -611,15 +615,8 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
embeddingLength*numPatches*maxNumTiles +
9*embeddingLength*numPaddedPatches*maxNumTiles +
numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
case "gemma3":
for name, layer := range llm.Tensors().GroupLayers() {
if strings.HasPrefix(name, "v.") {
for _, tensor := range layer {
weights += tensor.Size()
}
}
}
}
return weights, graphSize
}