diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 3f4374cd00..56ad420e53 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -57,10 +57,28 @@ func (kv KV) EmbeddingLength() uint64 { return uint64(kv.Uint("embedding_length")) } +func (kv KV) HeadCount() []uint64 { + headCountDefault := uint32(1) + headCount := kv.UintOrArrayValueAsArray("attention.head_count", headCountDefault) + if len(headCount) == 1 { + headCountDefault = headCount[0] + } + nLayers := int(kv.BlockCount()) + if len(headCount) > nLayers { + slog.Warn("got more elements of attention.head_count than layers", "len(headCount)", len(headCount), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(headCount) { + out[i] = uint64(headCountDefault) + } else { + out[i] = uint64(headCount[i]) + } + } + return out +} + func (kv KV) HeadCountMax() uint64 { - // TODO(drifkin): using the max value can cause an overestimation. In the - // future if array values become more popular, we can adapt the more invasive - // return uint64(kv.UintOrMaxArrayValue("attention.head_count", 1)) } @@ -68,6 +86,27 @@ func (kv KV) HeadCountMin() uint64 { return uint64(kv.UintOrMinArrayValue("attention.head_count", 1)) } +func (kv KV) HeadCountKV() []uint64 { + headCountKVDefault := uint32(1) + headCountKV := kv.UintOrArrayValueAsArray("attention.head_count_kv", headCountKVDefault) + if len(headCountKV) == 1 { + headCountKVDefault = headCountKV[0] + } + nLayers := int(kv.BlockCount()) + if len(headCountKV) > nLayers { + slog.Warn("got more elements of attention.head_count than layers", "len(headCountKV)", len(headCountKV), "layers", nLayers) + } + out := make([]uint64, nLayers) + for i := range nLayers { + if i >= len(headCountKV) { + out[i] = uint64(headCountKVDefault) + } else { + out[i] = uint64(headCountKV[i]) + } + } + return out +} + func (kv KV) HeadCountKVMax() uint64 { return uint64(kv.UintOrMaxArrayValue("attention.head_count_kv", 1)) } @@ -100,6 +139,26 @@ func (kv KV) ChatTemplate() string { return kv.String("tokenizer.chat_template") } +// ssm architecture parameters + +func (kv KV) SSMConvKernel() uint64 { + return uint64(kv.Uint("ssm.conv_kernel")) +} + +func (kv KV) SSMInnerSize() uint64 { + return uint64(kv.Uint("ssm.inner_size")) +} + +func (kv KV) SSMStateSize() uint64 { + return uint64(kv.Uint("ssm.state_size")) +} + +func (kv KV) SSMGroupCount() uint64 { + return uint64(kv.Uint("ssm.group_count")) +} + +// general types + func (kv KV) String(key string, defaultValue ...string) string { val, _ := keyValue(kv, key, append(defaultValue, "")...) return val @@ -131,22 +190,27 @@ func (kv KV) UintOrMinArrayValue(key string, defaultValue uint32) uint32 { } func (kv KV) UintOrArrayValue(key string, defaultValue uint32) (uint32, uint32) { + arrVal := kv.UintOrArrayValueAsArray(key, defaultValue) + return slices.Min(arrVal), slices.Max(arrVal) +} + +func (kv KV) UintOrArrayValueAsArray(key string, defaultValue uint32) []uint32 { if u32, ok := keyValue(kv, key, uint32(0)); ok { - return u32, u32 + return []uint32{u32} } else if u32s, ok := keyValue(kv, key, &array[uint32]{}); ok { - min := slices.Min(u32s.values) - max := slices.Max(u32s.values) - return min, max + return u32s.values } else if i32s, ok := keyValue(kv, key, &array[int32]{}); ok { - min := slices.Min(i32s.values) - max := slices.Max(i32s.values) - if min < 0 || max < 0 { - slog.Warn("array values are unexpectedly negative", "key", key, "min", min, "max", max) + dst := make([]uint32, len(i32s.values)) + for i, v := range i32s.values { + if v < 0 { + slog.Warn("array values are unexpectedly negative", "key", key, "i", i, "v", v) + } + dst[i] = uint32(v) } - return uint32(min), uint32(max) + return dst } - return defaultValue, defaultValue + return []uint32{defaultValue} } func (kv KV) Strings(key string, defaultValue ...[]string) []string { @@ -486,7 +550,9 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri embedding := f.KV().EmbeddingLength() heads := f.KV().HeadCountMax() + headsArr := f.KV().HeadCount() headsKV := f.KV().HeadCountKVMax() + headsKVArr := f.KV().HeadCountKV() vocab := uint64(f.KV()["tokenizer.ggml.tokens"].(*array[string]).size) embeddingHeads := f.KV().EmbeddingHeadCountMax() @@ -496,12 +562,51 @@ func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType stri layers := f.Tensors().GroupLayers() bytesPerElement := kvCacheBytesPerElement(kvCacheType) + + // Default for models unless special-cased below. These defaults mirror the + // cache usage in llama.cpp under the assumption that models without special + // cases below will use the llamarunner and caching will be handled by the + // llama.cpp layer. + // + // This also assumes that a layer without heads or headsKV set is recurrent + // which is usually the case. Some models (eg nemotronh) use "blocks" in + // place of layers where some are MLP blocks that don't have any cache. + // Models like this will need a special case below to be accurately + // estimated. var kvTotal uint64 kv = make([]uint64, f.KV().BlockCount()) + kvSizeAttn := uint64(0) + kvSizeRecurrent := uint64(0) for i := range kv { - kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKV) * bytesPerElement) + headsL := headsArr[i] + headsKVL := headsKVArr[i] + if headsL > 0 && headsKVL > 0 { + // full attention layer + // NOTE: Assumes uniform values for all attn layers + kv[i] = uint64(float64(context*(embeddingHeadsK+embeddingHeadsV)*headsKVL) * bytesPerElement) + kvSizeAttn += kv[i] + } else { + // recurrent layer + ssmDConv := f.KV().SSMConvKernel() + ssmDState := f.KV().SSMStateSize() + ssmDInner := f.KV().SSMInnerSize() + ssmNGroups := f.KV().SSMGroupCount() + nEmbdR := uint64(0) + if ssmDConv > 0 { + nEmbdR = (ssmDConv - 1) * (ssmDInner + 2*ssmNGroups*ssmDState) + } + nEmbdS := ssmDState * ssmDInner + + // recurrent always uses F32 in llama.cpp backend + // https://github.com/ggml-org/llama.cpp/blob/master/src/llama-model.cpp#L18644 + bytesPerElementRecurrent := kvCacheBytesPerElement("f32") + + kv[i] = (nEmbdR + nEmbdS) * uint64(bytesPerElementRecurrent) + kvSizeRecurrent += kv[i] + } kvTotal += kv[i] } + slog.Debug("default cache size estimate", "attention MiB", float32(kvSizeAttn)/(1024.*1024.), "attention bytes", kvSizeAttn, "recurrent MiB", float32(kvSizeRecurrent)/(1024.*1024.), "recurrent bytes", kvSizeRecurrent) switch f.KV().Architecture() { case "llama", "llama4": @@ -794,6 +899,8 @@ func kvCacheBytesPerElement(cacheType string) float64 { return 1 // 1/2 of fp16 case "q4_0": return 0.5 // 1/4 of fp16 + case "f32": + return 4 // f32 (default for recurrent) default: return 2 // f16 (default) }