roughly count gemma3 graph

The largest operation by far is (q @ k), so just count that for simplicity.
Michael Yang 2025-03-13 14:41:57 -07:00
parent d2ec22371e
commit a422ba39c9


@@ -587,34 +587,32 @@ func (llm GGML) VisionGraphSize() (weights, graphSize uint64) {
 		}
 	}
 
+	imageSize := uint64(llm.KV().Uint("vision.image_size"))
+	patchSize := uint64(llm.KV().Uint("vision.patch_size"))
+
+	numPatches := (imageSize / patchSize) * (imageSize / patchSize)
+	if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
+		numPatches++
+	}
+
+	headCount := uint64(llm.KV().Uint("vision.attention.head_count"))
+
 	switch llm.KV().Architecture() {
 	case "mllama":
-		kv := func(n string) uint64 {
-			if v, ok := llm.KV()["mllama.vision."+n].(uint32); ok {
-				return uint64(v)
-			}
-
-			return 0
-		}
-
-		imageSize := kv("image_size")
-
-		maxNumTiles := kv("max_num_tiles")
-		embeddingLength := kv("embedding_length")
-		headCount := kv("attention.head_count")
-
-		numPatches := (imageSize / kv("patch_size")) * (imageSize / kv("patch_size"))
-		if _, ok := llm.Tensors().GroupLayers()["v"]["class_embd"]; ok {
-			numPatches++
-		}
-
 		numPaddedPatches := numPatches + 8 - (numPatches%8)%8
 
+		maxNumTiles := uint64(llm.KV().Uint("vision.max_num_tiles"))
+		numChannels := uint64(llm.KV().Uint("vision.num_channels"))
+		embeddingLength := uint64(llm.KV().Uint("vision.embedding_length"))
+
 		graphSize = 4 * (8 +
-			imageSize*imageSize*kv("num_channels")*maxNumTiles +
+			imageSize*imageSize*numChannels*maxNumTiles +
 			embeddingLength*numPatches*maxNumTiles +
 			9*embeddingLength*numPaddedPatches*maxNumTiles +
 			numPaddedPatches*maxNumTiles*numPaddedPatches*maxNumTiles*headCount)
+	case "gemma3":
+		graphSize = 4 * (numPatches * numPatches * headCount)
 	}
 
 	return weights, graphSize
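
For a rough sense of scale, here is a small standalone sketch of the gemma3 branch of this estimate. The hyperparameter values (896 image size, 14 patch size, 16 heads) are illustrative assumptions, not read from this commit; the real values would come from the model's GGUF metadata.

package main

import "fmt"

func main() {
	// Assumed gemma3 vision hyperparameters, for illustration only.
	// In the real code these are read from the GGUF keys
	// vision.image_size, vision.patch_size, and vision.attention.head_count.
	imageSize := uint64(896)
	patchSize := uint64(14)
	headCount := uint64(16)

	numPatches := (imageSize / patchSize) * (imageSize / patchSize) // 64*64 = 4096

	// (q @ k) produces a numPatches x numPatches score matrix per head,
	// stored as 4-byte floats -- the same formula as the gemma3 case above.
	graphSize := 4 * (numPatches * numPatches * headCount)

	fmt.Printf("numPatches=%d graphSize=%d bytes (~%d MiB)\n",
		numPatches, graphSize, graphSize>>20)
	// With these assumed values: 4 * 4096 * 4096 * 16 bytes = 1 GiB,
	// which is why the attention scores dominate the graph allocation.
}

Counting only this dominant term keeps the estimate cheap while still scaling with the variables that actually drive memory use.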