From 4987f13d345d77844b6737edadaa1f0432df004c Mon Sep 17 00:00:00 2001 From: Gabe Goodhart Date: Mon, 13 Oct 2025 16:26:18 -0600 Subject: [PATCH] Llama cpp bump (df1b612): granite docling / mamba2 optimizations / multimodal encoding fixes (#12552) * feat: Bump llama.cpp to df1b612 Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * fix(mtmd): Correctly encode text chunks during mtmd tokenization There can be text chunks that appear interspersed with the image embeddings that contain template delimiter tokens for some models. These need to be correctly translated to text tokens. Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * tests: Use MtmdChunk in image_test Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * style: Fix unnecessary conversion linting Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * fix(ggml): Revert changes to ggml_hip.cpp These changes were done largely by our code assistant and are likely wrong Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * fix: Revert changes in mem_nvml.cpp Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * feat: Update sync point to 1deee0 This brings in several more optimization commits and model support for EmbeddingGemma Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * feat: Update patches for 1deee0 Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * feat: sync for bump to 1deee0 Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * fix: Bad patch updates with errant `+` Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * feat: Bump llama.cpp/ggml to 7049736 Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart * fix: format-patches after latest bump Branch: LlamaCPPBump-GraniteDocling Signed-off-by: Gabe Goodhart --------- Signed-off-by: Gabe Goodhart --- Makefile.sync | 2 +- llama/build-info.cpp | 2 +- llama/llama.cpp/common/common.cpp | 1 + llama/llama.cpp/common/common.h | 8 +- llama/llama.cpp/include/llama.h | 8 + llama/llama.cpp/src/llama-arch.cpp | 62 + llama/llama.cpp/src/llama-arch.h | 15 + llama/llama.cpp/src/llama-chat.cpp | 2 +- llama/llama.cpp/src/llama-context.cpp | 6 + llama/llama.cpp/src/llama-graph.cpp | 17 + llama/llama.cpp/src/llama-graph.h | 8 + llama/llama.cpp/src/llama-hparams.cpp | 6 +- llama/llama.cpp/src/llama-hparams.h | 14 +- llama/llama.cpp/src/llama-kv-cache-iswa.cpp | 4 +- llama/llama.cpp/src/llama-kv-cache.cpp | 7 +- llama/llama.cpp/src/llama-memory-hybrid.cpp | 20 +- .../llama.cpp/src/llama-memory-recurrent.cpp | 14 +- llama/llama.cpp/src/llama-model-loader.cpp | 1 + llama/llama.cpp/src/llama-model.cpp | 375 ++++- llama/llama.cpp/src/llama-model.h | 13 + llama/llama.cpp/src/llama-sampling.cpp | 5 + llama/llama.cpp/src/llama-vocab.cpp | 6 + llama/llama.cpp/src/llama-vocab.h | 81 +- llama/llama.cpp/tools/mtmd/clip-impl.h | 1 + llama/llama.cpp/tools/mtmd/clip.cpp | 52 +- llama/llama.cpp/tools/mtmd/mtmd.cpp | 105 +- llama/llama.go | 63 +- ...loc-and-free-using-the-same-compiler.patch | 28 +- llama/patches/0002-pretokenizer.patch | 6 +- llama/patches/0003-clip-unicode.patch | 6 +- llama/patches/0004-solar-pro.patch | 53 +- .../0005-fix-deepseek-deseret-regex.patch | 2 +- ...target-ggml-cpu-for-all-cpu-variants.patch | 6 +- llama/patches/0009-remove-amx.patch | 4 +- .../0010-fix-string-arr-kv-loading.patch | 4 +- llama/patches/0011-ollama-debug-tensor.patch | 4 +- ...add-ollama-vocab-for-grammar-support.patch | 2 +- ...13-add-argsort-and-cuda-copy-for-i32.patch | 10 +- ...14-graph-memory-reporting-on-failure.patch | 18 +- .../patches/0015-ggml-Export-GPU-UUIDs.patch | 14 +- .../0016-add-C-API-for-mtmd_input_text.patch | 4 +- ...-no-power-throttling-win32-with-gnuc.patch | 4 +- .../0019-Enable-CUDA-Graphs-for-gemma3n.patch | 6 +- llama/patches/0022-ggml-No-alloc-mode.patch | 26 +- .../0023-decode-disable-output_all.patch | 2 +- ...gml-Enable-resetting-backend-devices.patch | 28 +- .../0026-GPU-discovery-enhancements.patch | 52 +- ml/backend/ggml/ggml/include/ggml-backend.h | 2 + ml/backend/ggml/ggml/include/ggml-rpc.h | 17 +- ml/backend/ggml/ggml/include/ggml.h | 22 + ml/backend/ggml/ggml/src/CMakeLists.txt | 3 + ml/backend/ggml/ggml/src/ggml-alloc.c | 30 +- ml/backend/ggml/ggml/src/ggml-backend-impl.h | 3 - ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp | 1 + .../ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h | 2 +- ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c | 8 +- ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp | 34 +- .../ggml/ggml/src/ggml-cpu/unary-ops.cpp | 103 ++ ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h | 1 + ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp | 66 + ml/backend/ggml/ggml/src/ggml-cpu/vec.h | 18 +- .../ggml/ggml/src/ggml-cuda/CMakeLists.txt | 2 + ml/backend/ggml/ggml/src/ggml-cuda/common.cuh | 36 +- .../ggml/ggml/src/ggml-cuda/fattn-common.cuh | 9 +- .../ggml/ggml/src/ggml-cuda/fattn-tile.cu | 774 +---------- .../ggml/ggml/src/ggml-cuda/fattn-tile.cuh | 1203 +++++++++++++++++ .../ggml/ggml/src/ggml-cuda/fattn-vec.cuh | 2 - .../ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu | 12 +- .../ggml/src/ggml-cuda/fattn-wmma-f16.cuh | 48 + ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu | 70 +- .../ggml/ggml/src/ggml-cuda/ggml-cuda.cu | 5 +- .../fattn-tile-instance-dkq112-dv112.cu | 5 + .../fattn-tile-instance-dkq128-dv128.cu | 5 + .../fattn-tile-instance-dkq256-dv256.cu | 5 + .../fattn-tile-instance-dkq40-dv40.cu | 5 + .../fattn-tile-instance-dkq576-dv512.cu | 5 + .../fattn-tile-instance-dkq64-dv64.cu | 5 + .../fattn-tile-instance-dkq80-dv80.cu | 5 + .../fattn-tile-instance-dkq96-dv96.cu | 5 + .../ggml/ggml/src/ggml-cuda/topk-moe.cu | 4 +- ml/backend/ggml/ggml/src/ggml-cuda/unary.cu | 54 + ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh | 3 + .../ggml/ggml/src/ggml-cuda/vendors/hip.h | 3 + .../ggml/ggml/src/ggml-hip/CMakeLists.txt | 12 +- ml/backend/ggml/ggml/src/ggml-impl.h | 3 + .../ggml/src/ggml-metal/ggml-metal-common.cpp | 4 +- .../ggml/src/ggml-metal/ggml-metal-device.cpp | 184 ++- .../ggml/src/ggml-metal/ggml-metal-device.h | 17 + .../ggml/src/ggml-metal/ggml-metal-device.m | 8 +- .../src/ggml-metal/ggml-metal-embed.metal | 1154 +++++++++------- .../ggml/src/ggml-metal/ggml-metal-impl.h | 80 +- .../ggml/src/ggml-metal/ggml-metal-ops.cpp | 527 ++++++-- .../ggml/ggml/src/ggml-metal/ggml-metal-ops.h | 5 + .../ggml/ggml/src/ggml-metal/ggml-metal.cpp | 7 +- .../ggml/ggml/src/ggml-metal/ggml-metal.metal | 1074 ++++++++------- ml/backend/ggml/ggml/src/ggml.c | 36 +- runner/llamarunner/image.go | 16 +- runner/llamarunner/image_test.go | 12 +- runner/llamarunner/runner.go | 12 +- 99 files changed, 4580 insertions(+), 2323 deletions(-) create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu create mode 100644 ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu diff --git a/Makefile.sync b/Makefile.sync index 800322e2be..2e99c7fb3d 100644 --- a/Makefile.sync +++ b/Makefile.sync @@ -1,6 +1,6 @@ UPSTREAM=https://github.com/ggml-org/llama.cpp.git WORKDIR=llama/vendor -FETCH_HEAD=364a7a6d4a786e98947c8a90430ea581213c0ba9 +FETCH_HEAD=7049736b2dd9011bf819e298b844ebbc4b5afdc9 .PHONY: help help: diff --git a/llama/build-info.cpp b/llama/build-info.cpp index 332bfecee8..ea711c8739 100644 --- a/llama/build-info.cpp +++ b/llama/build-info.cpp @@ -1,4 +1,4 @@ int LLAMA_BUILD_NUMBER = 0; -char const *LLAMA_COMMIT = "364a7a6d4a786e98947c8a90430ea581213c0ba9"; +char const *LLAMA_COMMIT = "7049736b2dd9011bf819e298b844ebbc4b5afdc9"; char const *LLAMA_COMPILER = ""; char const *LLAMA_BUILD_TARGET = ""; diff --git a/llama/llama.cpp/common/common.cpp b/llama/llama.cpp/common/common.cpp index c1e736c44c..b0591e84b0 100644 --- a/llama/llama.cpp/common/common.cpp +++ b/llama/llama.cpp/common/common.cpp @@ -1133,6 +1133,7 @@ struct llama_model_params common_model_params_to_llama(common_params & params) { mparams.use_mlock = params.use_mlock; mparams.check_tensors = params.check_tensors; mparams.use_extra_bufts = !params.no_extra_bufts; + mparams.no_host = params.no_host; if (params.kv_overrides.empty()) { mparams.kv_overrides = NULL; diff --git a/llama/llama.cpp/common/common.h b/llama/llama.cpp/common/common.h index 40c6847f32..a8cb630ea5 100644 --- a/llama/llama.cpp/common/common.h +++ b/llama/llama.cpp/common/common.h @@ -378,7 +378,7 @@ struct common_params { bool simple_io = false; // improves compatibility with subprocesses and limited consoles bool cont_batching = true; // insert new sequences for decoding on-the-fly bool no_perf = false; // disable performance metrics - bool ctx_shift = false; // context shift on infinite text generation + bool ctx_shift = false; // context shift on infinite text generation bool swa_full = false; // use full-size SWA cache (https://github.com/ggml-org/llama.cpp/pull/13194#issuecomment-2868343055) bool kv_unified = false; // enable unified KV cache @@ -392,6 +392,7 @@ struct common_params { bool check_tensors = false; // validate tensor data bool no_op_offload = false; // globally disable offload host tensor operations to device bool no_extra_bufts = false; // disable extra buffer types (used for weight repacking) + bool no_host = false; // bypass host buffer allowing extra buffers to be used bool single_turn = false; // single turn chat conversation @@ -424,7 +425,8 @@ struct common_params { int32_t timeout_write = timeout_read; // http write timeout in seconds int32_t n_threads_http = -1; // number of threads to process HTTP requests (TODO: support threadpool) int32_t n_cache_reuse = 0; // min chunk size to reuse from the cache via KV shifting - int32_t n_swa_checkpoints = 3; // max number of SWA checkpoints per slot + int32_t n_ctx_checkpoints = 8; // max number of context checkpoints per slot + int32_t cache_ram_mib = 8192; // -1 = no limit, 0 - disable, 1 = 1 MiB, etc. std::string hostname = "127.0.0.1"; std::string public_path = ""; // NOLINT @@ -432,7 +434,7 @@ struct common_params { std::string chat_template = ""; // NOLINT bool use_jinja = false; // NOLINT bool enable_chat_template = true; - common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_AUTO; + common_reasoning_format reasoning_format = COMMON_REASONING_FORMAT_DEEPSEEK; int reasoning_budget = -1; bool prefill_assistant = true; // if true, any trailing assistant message will be prefilled into the response diff --git a/llama/llama.cpp/include/llama.h b/llama/llama.cpp/include/llama.h index 452d9ec5bf..a0a660bff8 100644 --- a/llama/llama.cpp/include/llama.h +++ b/llama/llama.cpp/include/llama.h @@ -296,6 +296,7 @@ extern "C" { bool use_mlock; // force system to keep model in RAM bool check_tensors; // validate model tensor data bool use_extra_bufts; // use extra buffer types (used for weight repacking) + bool no_host; // bypass host buffer allowing extra buffers to be used }; // NOTE: changing the default values of parameters marked as [EXPERIMENTAL] may cause crashes or incorrect results in certain configurations @@ -543,6 +544,9 @@ extern "C" { // Returns true if the model is recurrent (like Mamba, RWKV, etc.) LLAMA_API bool llama_model_is_recurrent(const struct llama_model * model); + // Returns true if the model is hybrid (like Jamba, Granite, etc.) + LLAMA_API bool llama_model_is_hybrid(const struct llama_model * model); + // Returns true if the model is diffusion-based (like LLaDA, Dream, etc.) LLAMA_API bool llama_model_is_diffusion(const struct llama_model * model); @@ -791,8 +795,12 @@ extern "C" { size_t n_token_capacity, size_t * n_token_count_out); +// for backwards-compat #define LLAMA_STATE_SEQ_FLAGS_SWA_ONLY 1 +// work only with partial states, such as SWA KV cache or recurrent cache (e.g. Mamba) +#define LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY 1 + typedef uint32_t llama_state_seq_flags; LLAMA_API size_t llama_state_seq_get_size_ext( diff --git a/llama/llama.cpp/src/llama-arch.cpp b/llama/llama.cpp/src/llama-arch.cpp index f98a3574c2..9f6b6ad20e 100644 --- a/llama/llama.cpp/src/llama-arch.cpp +++ b/llama/llama.cpp/src/llama-arch.cpp @@ -94,12 +94,14 @@ static const std::map LLM_ARCH_NAMES = { { LLM_ARCH_SMOLLM3, "smollm3" }, { LLM_ARCH_OPENAI_MOE, "gpt-oss" }, { LLM_ARCH_LFM2, "lfm2" }, + { LLM_ARCH_LFM2MOE, "lfm2moe" }, { LLM_ARCH_DREAM, "dream" }, { LLM_ARCH_SMALLTHINKER, "smallthinker" }, { LLM_ARCH_LLADA, "llada" }, { LLM_ARCH_LLADA_MOE, "llada-moe" }, { LLM_ARCH_SEED_OSS, "seed_oss" }, { LLM_ARCH_GROVEMOE, "grovemoe" }, + { LLM_ARCH_APERTUS, "apertus" }, { LLM_ARCH_UNKNOWN, "(unknown)" }, }; @@ -219,6 +221,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_CLASSIFIER_OUTPUT_LABELS, "%s.classifier.output_labels" }, { LLM_KV_SHORTCONV_L_CACHE, "%s.shortconv.l_cache" }, + // sentence-transformers dense modules feature dims + { LLM_KV_DENSE_2_FEAT_IN, "%s.dense_2_feat_in" }, + { LLM_KV_DENSE_2_FEAT_OUT, "%s.dense_2_feat_out" }, + { LLM_KV_DENSE_3_FEAT_IN, "%s.dense_3_feat_in" }, + { LLM_KV_DENSE_3_FEAT_OUT, "%s.dense_3_feat_out" }, { LLM_KV_TOKENIZER_MODEL, "tokenizer.ggml.model" }, { LLM_KV_TOKENIZER_PRE, "tokenizer.ggml.pre" }, @@ -258,6 +265,11 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ADAPTER_LORA_PROMPT_PREFIX, "adapter.lora.prompt_prefix" }, { LLM_KV_ADAPTER_ALORA_INVOCATION_TOKENS, "adapter.alora.invocation_tokens" }, + { LLM_KV_XIELU_ALPHA_N, "xielu.alpha_n" }, + { LLM_KV_XIELU_ALPHA_P, "xielu.alpha_p" }, + { LLM_KV_XIELU_BETA, "xielu.beta" }, + { LLM_KV_XIELU_EPS, "xielu.eps" }, + // deprecated { LLM_KV_TOKENIZER_PREFIX_ID, "tokenizer.ggml.prefix_token_id" }, { LLM_KV_TOKENIZER_SUFFIX_ID, "tokenizer.ggml.suffix_token_id" }, @@ -1066,6 +1078,8 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_DENSE_2_OUT, "dense_2" }, + { LLM_TENSOR_DENSE_3_OUT, "dense_3" }, { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, @@ -2118,6 +2132,32 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_OUTPUT, "output" }, } }, + { + LLM_ARCH_LFM2MOE, + { + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_GATE, "blk.%d.ffn_gate" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + { LLM_TENSOR_SHORTCONV_CONV, "blk.%d.shortconv.conv" }, + { LLM_TENSOR_SHORTCONV_INPROJ, "blk.%d.shortconv.in_proj" }, + { LLM_TENSOR_SHORTCONV_OUTPROJ, "blk.%d.shortconv.out_proj" }, + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_TOKEN_EMBD_NORM, "token_embd_norm" }, + { LLM_TENSOR_FFN_GATE_INP, "blk.%d.ffn_gate_inp" }, + { LLM_TENSOR_FFN_GATE_EXPS, "blk.%d.ffn_gate_exps" }, + { LLM_TENSOR_FFN_DOWN_EXPS, "blk.%d.ffn_down_exps" }, + { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" }, + { LLM_TENSOR_FFN_EXP_PROBS_B, "blk.%d.exp_probs_b" }, + } + }, { LLM_ARCH_SMALLTHINKER, { @@ -2139,6 +2179,25 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_FFN_UP_EXPS, "blk.%d.ffn_up_exps" } }, }, + { + LLM_ARCH_APERTUS, + { + { LLM_TENSOR_TOKEN_EMBD, "token_embd" }, + { LLM_TENSOR_OUTPUT_NORM, "output_norm" }, + { LLM_TENSOR_OUTPUT, "output" }, + { LLM_TENSOR_ROPE_FREQS, "rope_freqs" }, + { LLM_TENSOR_ATTN_NORM, "blk.%d.attn_norm" }, + { LLM_TENSOR_ATTN_Q, "blk.%d.attn_q" }, + { LLM_TENSOR_ATTN_K, "blk.%d.attn_k" }, + { LLM_TENSOR_ATTN_V, "blk.%d.attn_v" }, + { LLM_TENSOR_ATTN_OUT, "blk.%d.attn_output" }, + { LLM_TENSOR_ATTN_Q_NORM, "blk.%d.attn_q_norm" }, + { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, + { LLM_TENSOR_FFN_NORM, "blk.%d.ffn_norm" }, + { LLM_TENSOR_FFN_DOWN, "blk.%d.ffn_down" }, + { LLM_TENSOR_FFN_UP, "blk.%d.ffn_up" }, + }, + }, { LLM_ARCH_DREAM, { @@ -2249,6 +2308,8 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_OUTPUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, {LLM_TENSOR_CLS_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, + {LLM_TENSOR_DENSE_2_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output + {LLM_TENSOR_DENSE_3_OUT, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL_MAT}}, // Dense layer output {LLM_TENSOR_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_DEC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, {LLM_TENSOR_ENC_OUTPUT_NORM, {LLM_TENSOR_LAYER_OUTPUT, GGML_OP_MUL}}, @@ -2489,6 +2550,7 @@ bool llm_arch_is_hybrid(const llm_arch & arch) { case LLM_ARCH_PLAMO2: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_NEMOTRON_H: return true; default: diff --git a/llama/llama.cpp/src/llama-arch.h b/llama/llama.cpp/src/llama-arch.h index aa8e0e7b7c..dc7a362af9 100644 --- a/llama/llama.cpp/src/llama-arch.h +++ b/llama/llama.cpp/src/llama-arch.h @@ -98,12 +98,14 @@ enum llm_arch { LLM_ARCH_SMOLLM3, LLM_ARCH_OPENAI_MOE, LLM_ARCH_LFM2, + LLM_ARCH_LFM2MOE, LLM_ARCH_DREAM, LLM_ARCH_SMALLTHINKER, LLM_ARCH_LLADA, LLM_ARCH_LLADA_MOE, LLM_ARCH_SEED_OSS, LLM_ARCH_GROVEMOE, + LLM_ARCH_APERTUS, LLM_ARCH_UNKNOWN, }; @@ -262,10 +264,21 @@ enum llm_kv { LLM_KV_SHORTCONV_L_CACHE, + LLM_KV_XIELU_ALPHA_N, + LLM_KV_XIELU_ALPHA_P, + LLM_KV_XIELU_BETA, + LLM_KV_XIELU_EPS, + // deprecated: LLM_KV_TOKENIZER_PREFIX_ID, LLM_KV_TOKENIZER_SUFFIX_ID, LLM_KV_TOKENIZER_MIDDLE_ID, + + // sentence-transformers dense layers in and out features + LLM_KV_DENSE_2_FEAT_IN, + LLM_KV_DENSE_2_FEAT_OUT, + LLM_KV_DENSE_3_FEAT_IN, + LLM_KV_DENSE_3_FEAT_OUT, }; enum llm_tensor { @@ -273,6 +286,8 @@ enum llm_tensor { LLM_TENSOR_TOKEN_EMBD_NORM, LLM_TENSOR_TOKEN_TYPES, LLM_TENSOR_POS_EMBD, + LLM_TENSOR_DENSE_2_OUT, + LLM_TENSOR_DENSE_3_OUT, LLM_TENSOR_OUTPUT, LLM_TENSOR_OUTPUT_NORM, LLM_TENSOR_ROPE_FREQS, diff --git a/llama/llama.cpp/src/llama-chat.cpp b/llama/llama.cpp/src/llama-chat.cpp index 66e6c6a38f..956c4e085e 100644 --- a/llama/llama.cpp/src/llama-chat.cpp +++ b/llama/llama.cpp/src/llama-chat.cpp @@ -590,7 +590,7 @@ int32_t llm_chat_apply_template( ss << message->content << "<|end_of_text|>\n"; } if (add_ass) { - ss << "<|start_of_role|>assistant<|end_of_role|>\n"; + ss << "<|start_of_role|>assistant<|end_of_role|>"; } } else if (tmpl == LLM_CHAT_TEMPLATE_GIGACHAT) { // GigaChat template diff --git a/llama/llama.cpp/src/llama-context.cpp b/llama/llama.cpp/src/llama-context.cpp index 09247cef99..53a5e3a9be 100644 --- a/llama/llama.cpp/src/llama-context.cpp +++ b/llama/llama.cpp/src/llama-context.cpp @@ -2345,6 +2345,12 @@ llama_context * llama_init_from_model( return nullptr; } + if (params.pooling_type != model->hparams.pooling_type) { + //user-specified pooling-type is different from the model default + LLAMA_LOG_WARN("%s: model default pooling_type is [%d], but [%d] was specified\n", __func__, + model->hparams.pooling_type, params.pooling_type); + } + try { auto * ctx = new llama_context(*model, params); return ctx; diff --git a/llama/llama.cpp/src/llama-graph.cpp b/llama/llama.cpp/src/llama-graph.cpp index 90cd885a60..a24853c63a 100644 --- a/llama/llama.cpp/src/llama-graph.cpp +++ b/llama/llama.cpp/src/llama-graph.cpp @@ -1853,6 +1853,23 @@ llm_graph_input_mem_hybrid * llm_graph_context::build_inp_mem_hybrid() const { return (llm_graph_input_mem_hybrid *) res->add_input(std::move(inp)); } +void llm_graph_context::build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const { + if (!cparams.embeddings || dense_2 == nullptr || dense_3 == nullptr) { + return; + } + ggml_tensor * cur = res->t_embd_pooled != nullptr ? res->t_embd_pooled : res->t_embd; + GGML_ASSERT(cur != nullptr && "missing t_embd_pooled/t_embd"); + + cur = ggml_mul_mat(ctx0, dense_2, cur); + cur = ggml_mul_mat(ctx0, dense_3, cur); + cb(cur, "result_embd_pooled", -1); + res->t_embd_pooled = cur; + ggml_build_forward_expand(gf, cur); +} + + void llm_graph_context::build_pooling( ggml_tensor * cls, ggml_tensor * cls_b, diff --git a/llama/llama.cpp/src/llama-graph.h b/llama/llama.cpp/src/llama-graph.h index 34b984afeb..dc84b79428 100644 --- a/llama/llama.cpp/src/llama-graph.h +++ b/llama/llama.cpp/src/llama-graph.h @@ -814,6 +814,14 @@ struct llm_graph_context { ggml_tensor * cls_b, ggml_tensor * cls_out, ggml_tensor * cls_out_b) const; + + // + // dense (out) + // + + void build_dense_out( + ggml_tensor * dense_2, + ggml_tensor * dense_3) const; }; // TODO: better name diff --git a/llama/llama.cpp/src/llama-hparams.cpp b/llama/llama.cpp/src/llama-hparams.cpp index 24a515a035..b6bf6bbf2c 100644 --- a/llama/llama.cpp/src/llama-hparams.cpp +++ b/llama/llama.cpp/src/llama-hparams.cpp @@ -140,7 +140,11 @@ uint32_t llama_hparams::n_embd_s() const { } bool llama_hparams::is_recurrent(uint32_t il) const { - return recurrent_layer_arr[il]; + if (il < n_layer) { + return recurrent_layer_arr[il]; + } + + GGML_ABORT("%s: il (%u) out of bounds (n_layer: %u)\n", __func__, il, n_layer); } uint32_t llama_hparams::n_pos_per_embd() const { diff --git a/llama/llama.cpp/src/llama-hparams.h b/llama/llama.cpp/src/llama-hparams.h index eb13709f88..80582728af 100644 --- a/llama/llama.cpp/src/llama-hparams.h +++ b/llama/llama.cpp/src/llama-hparams.h @@ -42,7 +42,7 @@ struct llama_hparams { uint32_t n_embd; uint32_t n_embd_features = 0; uint32_t n_layer; - int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache + int32_t n_layer_kv_from_start = -1; // if non-negative, the first n_layer_kv_from_start layers have KV cache uint32_t n_rot; uint32_t n_embd_head_k; // dimension of keys (d_k). d_q is assumed to be the same, but there are n_head q heads, and only n_head_kv k-v heads uint32_t n_embd_head_v; // dimension of values (d_v) aka n_embd_head @@ -171,6 +171,18 @@ struct llama_hparams { uint32_t laurel_rank = 64; uint32_t n_embd_altup = 256; + // needed for sentence-transformers dense layers + uint32_t dense_2_feat_in = 0; // in_features of the 2_Dense + uint32_t dense_2_feat_out = 0; // out_features of the 2_Dense + uint32_t dense_3_feat_in = 0; // in_features of the 3_Dense + uint32_t dense_3_feat_out = 0; // out_features of the 3_Dense + + // xIELU + std::array xielu_alpha_n; + std::array xielu_alpha_p; + std::array xielu_beta; + std::array xielu_eps; + // needed by encoder-decoder models (e.g. T5, FLAN-T5) // ref: https://github.com/ggerganov/llama.cpp/pull/8141 llama_token dec_start_token_id = LLAMA_TOKEN_NULL; diff --git a/llama/llama.cpp/src/llama-kv-cache-iswa.cpp b/llama/llama.cpp/src/llama-kv-cache-iswa.cpp index 827302e6d2..facba1d004 100644 --- a/llama/llama.cpp/src/llama-kv-cache-iswa.cpp +++ b/llama/llama.cpp/src/llama-kv-cache-iswa.cpp @@ -220,7 +220,7 @@ bool llama_kv_cache_iswa::get_can_shift() const { } void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_write(io, seq_id, flags); } @@ -228,7 +228,7 @@ void llama_kv_cache_iswa::state_write(llama_io_write_i & io, llama_seq_id seq_id } void llama_kv_cache_iswa::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - if ((flags & LLAMA_STATE_SEQ_FLAGS_SWA_ONLY) == 0) { + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { kv_base->state_read(io, seq_id, flags); } diff --git a/llama/llama.cpp/src/llama-kv-cache.cpp b/llama/llama.cpp/src/llama-kv-cache.cpp index 816f2d5de5..736693e174 100644 --- a/llama/llama.cpp/src/llama-kv-cache.cpp +++ b/llama/llama.cpp/src/llama-kv-cache.cpp @@ -123,11 +123,8 @@ llama_kv_cache::llama_kv_cache( throw std::runtime_error("failed to create ggml context for kv cache"); } - ggml_tensor * k; - ggml_tensor * v; - - k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); - v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); + ggml_tensor * k = ggml_new_tensor_3d(ctx, type_k, n_embd_k_gqa, kv_size, n_stream); + ggml_tensor * v = ggml_new_tensor_3d(ctx, type_v, n_embd_v_gqa, kv_size, n_stream); ggml_format_name(k, "cache_k_l%d", il); ggml_format_name(v, "cache_v_l%d", il); diff --git a/llama/llama.cpp/src/llama-memory-hybrid.cpp b/llama/llama.cpp/src/llama-memory-hybrid.cpp index abf652483c..dfb8439e01 100644 --- a/llama/llama.cpp/src/llama-memory-hybrid.cpp +++ b/llama/llama.cpp/src/llama-memory-hybrid.cpp @@ -73,7 +73,9 @@ llama_memory_context_ptr llama_memory_hybrid::init_batch(llama_batch_allocr & ba // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -175,17 +177,17 @@ std::map llama_memory_hybrid::memory_breakdo } void llama_memory_hybrid::state_write(llama_io_write_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) const { - GGML_UNUSED(flags); - - mem_attn->state_write(io, seq_id); - mem_recr->state_write(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_write(io, seq_id, flags); + } + mem_recr->state_write(io, seq_id, flags); } void llama_memory_hybrid::state_read(llama_io_read_i & io, llama_seq_id seq_id, llama_state_seq_flags flags) { - GGML_UNUSED(flags); - - mem_attn->state_read(io, seq_id); - mem_recr->state_read(io, seq_id); + if ((flags & LLAMA_STATE_SEQ_FLAGS_PARTIAL_ONLY) == 0) { + mem_attn->state_read(io, seq_id, flags); + } + mem_recr->state_read(io, seq_id, flags); } llama_kv_cache * llama_memory_hybrid::get_mem_attn() const { diff --git a/llama/llama.cpp/src/llama-memory-recurrent.cpp b/llama/llama.cpp/src/llama-memory-recurrent.cpp index 44645fcdd2..d67f5a5f47 100644 --- a/llama/llama.cpp/src/llama-memory-recurrent.cpp +++ b/llama/llama.cpp/src/llama-memory-recurrent.cpp @@ -136,6 +136,7 @@ void llama_memory_recurrent::clear(bool data) { } bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos p1) { + //printf("[DEBUG] calling llama_memory_recurrent::seq_rm` with `seq_id=%d, p0=%d, p1=%d`\n", seq_id, p0, p1); uint32_t new_head = size; if (p0 < 0) { @@ -156,7 +157,8 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos if (tail_id >= 0) { const auto & cell = cells[tail_id]; // partial intersection is invalid - if ((0 < p0 && p0 <= cell.pos) || (0 < p1 && p1 <= cell.pos)) { + if ((0 < p0 && p0 < cell.pos) || (0 < p1 && p1 <= cell.pos)) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: partial intersection is invalid, so returning false\n"); return false; } // invalidate tails which will be cleared @@ -167,6 +169,7 @@ bool llama_memory_recurrent::seq_rm(llama_seq_id seq_id, llama_pos p0, llama_pos } else { // seq_id is negative, then the range should include everything or nothing if (p0 != p1 && (p0 != 0 || p1 != std::numeric_limits::max())) { + //printf("[DEBUG] inside `llama_memory_recurrent::seq_rm`: `seq_id` is negative, so returning false\n"); return false; } } @@ -379,7 +382,9 @@ llama_memory_context_ptr llama_memory_recurrent::init_batch(llama_batch_allocr & // if all tokens are output, split by sequence ubatch = balloc.split_seq(n_ubatch); } else { - ubatch = balloc.split_equal(n_ubatch, false); + // TODO: non-sequential equal split can be done if using unified KV cache + // for simplicity, we always use sequential equal split for now + ubatch = balloc.split_equal(n_ubatch, true); } if (ubatch.n_tokens == 0) { @@ -856,9 +861,12 @@ void llama_memory_recurrent::state_write_data(llama_io_write_i & io, const std:: bool llama_memory_recurrent::state_read_meta(llama_io_read_i & io, uint32_t cell_count, llama_seq_id dest_seq_id) { if (dest_seq_id != -1) { // single sequence - seq_rm(dest_seq_id, -1, -1); + if (cell_count == 0) { + return true; + } + llama_batch_allocr balloc(hparams.n_pos_per_embd()); llama_ubatch ubatch = balloc.ubatch_reserve(cell_count, 1); diff --git a/llama/llama.cpp/src/llama-model-loader.cpp b/llama/llama.cpp/src/llama-model-loader.cpp index daef900c48..ee303bd58e 100644 --- a/llama/llama.cpp/src/llama-model-loader.cpp +++ b/llama/llama.cpp/src/llama-model-loader.cpp @@ -465,6 +465,7 @@ namespace GGUFMeta { // TODO: this is not very clever - figure out something better template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( diff --git a/llama/llama.cpp/src/llama-model.cpp b/llama/llama.cpp/src/llama-model.cpp index 0398b553f7..74e1d1623f 100644 --- a/llama/llama.cpp/src/llama-model.cpp +++ b/llama/llama.cpp/src/llama-model.cpp @@ -114,6 +114,7 @@ const char * llm_type_name(llm_type type) { case LLM_TYPE_17B_16E: return "17Bx16E (Scout)"; case LLM_TYPE_17B_128E: return "17Bx128E (Maverick)"; case LLM_TYPE_A13B: return "A13B"; + case LLM_TYPE_8B_A1B: return "8B.A1B"; case LLM_TYPE_21B_A3B: return "21B.A3B"; case LLM_TYPE_30B_A3B: return "30B.A3B"; case LLM_TYPE_106B_A12B: return "106B.A12B"; @@ -310,7 +311,7 @@ static ggml_backend_buffer_type_t select_weight_buft(const llama_hparams & hpara } // CPU: ACCEL -> GPU host -> CPU extra -> CPU -static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts) { +static buft_list_t make_cpu_buft_list(const std::vector & devices, bool use_extra_bufts, bool no_host) { buft_list_t buft_list; // add ACCEL buffer types @@ -331,11 +332,13 @@ static buft_list_t make_cpu_buft_list(const std::vector & de // generally, this will be done using the first device in the list // a better approach would be to handle this on a weight-by-weight basis using the offload_op // function of the device to determine if it would benefit from being stored in a host buffer - for (auto * dev : devices) { - ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); - if (buft) { - buft_list.emplace_back(dev, buft); - break; + if (!no_host) { + for (auto * dev : devices) { + ggml_backend_buffer_type_t buft = ggml_backend_dev_host_buffer_type(dev); + if (buft) { + buft_list.emplace_back(dev, buft); + break; + } } } @@ -512,9 +515,13 @@ void llama_model::load_hparams(llama_model_loader & ml) { llm_arch_is_recurrent(ml.get_arch())); std::fill(hparams.rope_sections.begin(), hparams.rope_sections.end(), 0); - std::fill(hparams.swa_layers.begin(), hparams.swa_layers.end(), 0); + std::fill(hparams.xielu_alpha_n.begin(), hparams.xielu_alpha_n.end(), 0.0f); + std::fill(hparams.xielu_alpha_p.begin(), hparams.xielu_alpha_p.end(), 0.0f); + std::fill(hparams.xielu_beta.begin(), hparams.xielu_beta.end(), 0.0f); + std::fill(hparams.xielu_eps.begin(), hparams.xielu_eps.end(), 0.0f); + ml.get_key_or_arr(LLM_KV_FEED_FORWARD_LENGTH, hparams.n_ff_arr, hparams.n_layer, false); ml.get_key_or_arr(LLM_KV_ATTENTION_HEAD_COUNT, hparams.n_head_arr, hparams.n_layer, false); @@ -1084,7 +1091,11 @@ void llama_model::load_hparams(llama_model_loader & ml) { } break; default: type = LLM_TYPE_UNKNOWN; - } + } + + // Load attention parameters + ml.get_key(LLM_KV_ATTENTION_KEY_LENGTH, hparams.n_embd_head_k, false); + ml.get_key(LLM_KV_ATTENTION_VALUE_LENGTH, hparams.n_embd_head_v, false); } break; case LLM_ARCH_GPT2: { @@ -1207,12 +1218,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { hparams.set_swa_pattern(6); hparams.causal_attn = false; // embeddings do not use causal attention - hparams.rope_freq_base_train_swa = 10000.0f; + hparams.rope_freq_base_train_swa = 10000.0f; hparams.rope_freq_scale_train_swa = 1.0f; - ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); + ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa); ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); - ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + ml.get_key(LLM_KV_POOLING_TYPE, hparams.pooling_type); + + //applied only if model converted with --sentence-transformers-dense-modules + ml.get_key(LLM_KV_DENSE_2_FEAT_IN, hparams.dense_2_feat_in, false); + ml.get_key(LLM_KV_DENSE_2_FEAT_OUT, hparams.dense_2_feat_out, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_IN, hparams.dense_3_feat_in, false); + ml.get_key(LLM_KV_DENSE_3_FEAT_OUT, hparams.dense_3_feat_out, false); + + GGML_ASSERT((hparams.dense_2_feat_in == 0 || hparams.dense_2_feat_in == hparams.n_embd) && "dense_2_feat_in must be equal to n_embd"); + GGML_ASSERT((hparams.dense_3_feat_out == 0 || hparams.dense_3_feat_out == hparams.n_embd) && "dense_3_feat_out must be equal to n_embd"); switch (hparams.n_layer) { case 24: type = LLM_TYPE_0_3B; break; @@ -2000,14 +2020,29 @@ void llama_model::load_hparams(llama_model_loader & ml) { for (uint32_t il = 0; il < hparams.n_layer; ++il) { hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; } + hparams.n_layer_dense_lead = hparams.n_layer; switch (hparams.n_ff()) { case 4608: type = LLM_TYPE_350M; break; case 6912: type = LLM_TYPE_700M; break; case 8192: type = LLM_TYPE_1_2B; break; case 10752: type = LLM_TYPE_2_6B; break; - default: type = LLM_TYPE_UNKNOWN; + default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_LFM2MOE: + { + ml.get_key(LLM_KV_SHORTCONV_L_CACHE, hparams.n_shortconv_l_cache); + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key(LLM_KV_LEADING_DENSE_BLOCK_COUNT, hparams.n_layer_dense_lead); + ml.get_key(LLM_KV_EXPERT_FEED_FORWARD_LENGTH, hparams.n_ff_exp); + ml.get_key(LLM_KV_EXPERT_GATING_FUNC, hparams.expert_gating_func); + + for (uint32_t il = 0; il < hparams.n_layer; ++il) { + hparams.recurrent_layer_arr[il] = hparams.n_head_kv(il) == 0; + } + + type = LLM_TYPE_8B_A1B; + } break; case LLM_ARCH_SMALLTHINKER: { const bool found_swa = ml.get_key(LLM_KV_ATTENTION_SLIDING_WINDOW, hparams.n_swa, false); @@ -2044,6 +2079,19 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; + case LLM_ARCH_APERTUS: + { + ml.get_key(LLM_KV_ATTENTION_LAYERNORM_RMS_EPS, hparams.f_norm_rms_eps); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_N, hparams.xielu_alpha_n, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_ALPHA_P, hparams.xielu_alpha_p, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_BETA, hparams.xielu_beta, hparams.n_layer); + ml.get_key_or_arr(LLM_KV_XIELU_EPS, hparams.xielu_eps, hparams.n_layer); + + switch (hparams.n_layer) { + case 32: type = LLM_TYPE_8B; break; + default: type = LLM_TYPE_UNKNOWN; + } + } break; default: throw std::runtime_error("unsupported model architecture"); } @@ -2077,7 +2125,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { LLAMA_LOG_INFO("%s: loading model tensors, this can take a while... (mmap = %s)\n", __func__, ml.use_mmap ? "true" : "false"); // build a list of buffer types for the CPU and GPU devices - pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts); + pimpl->cpu_buft_list = make_cpu_buft_list(devices, params.use_extra_bufts, params.no_host); for (auto * dev : devices) { buft_list_t buft_list = make_gpu_buft_list(dev, split_mode, tensor_split); // add CPU buffer types as a fallback @@ -3407,17 +3455,17 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } break; case LLM_ARCH_PLAMO2: { + // mamba parameters const uint32_t d_conv = hparams.ssm_d_conv; const uint32_t d_state = hparams.ssm_d_state; const uint32_t num_heads = hparams.ssm_dt_rank; const uint32_t intermediate_size = hparams.ssm_d_inner; - const uint32_t head_dim = intermediate_size / num_heads; - const uint32_t qk_dim = head_dim; - const uint32_t v_dim = head_dim; - const int64_t num_attention_heads = hparams.n_head(); - const int64_t q_num_heads = num_attention_heads; const int64_t dt_dim = std::max(64, int(hparams.n_embd / 16)); + // attention parameters + const uint32_t qk_dim = hparams.n_embd_head_k; + const uint32_t v_dim = hparams.n_embd_head_v; + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); // output @@ -3451,6 +3499,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ssm_b_norm = create_tensor(tn(LLM_TENSOR_SSM_B_NORM, i), {d_state}, 0); layer.ssm_c_norm = create_tensor(tn(LLM_TENSOR_SSM_C_NORM, i), {d_state}, 0); } else { + const int64_t num_attention_heads = hparams.n_head(i); + const int64_t q_num_heads = num_attention_heads; const int64_t num_key_value_heads = hparams.n_head_kv(i); const int64_t k_num_heads = num_key_value_heads; const int64_t v_num_heads = num_key_value_heads; @@ -3459,8 +3509,8 @@ bool llama_model::load_tensors(llama_model_loader & ml) { const int64_t v_proj_dim = v_num_heads * v_dim; layer.wqkv = create_tensor(tn(LLM_TENSOR_ATTN_QKV, "weight", i), {n_embd, q_proj_dim + k_proj_dim + v_proj_dim}, 0); - layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {head_dim, num_attention_heads}, 0); - layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {head_dim, k_num_heads}, 0); + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), {qk_dim, num_attention_heads}, 0); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), {qk_dim, k_num_heads}, 0); layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), {q_num_heads * v_dim, n_embd}, 0); } @@ -3660,6 +3710,11 @@ bool llama_model::load_tensors(llama_model_loader & ml) { output = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, TENSOR_DUPLICATED); } + // Dense linear weights + dense_2_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_2_OUT, "weight"), {n_embd, hparams.dense_2_feat_out}, TENSOR_NOT_REQUIRED); + dense_3_out_layers = create_tensor(tn(LLM_TENSOR_DENSE_3_OUT, "weight"), {hparams.dense_3_feat_in, n_embd}, TENSOR_NOT_REQUIRED); + + for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; @@ -4840,11 +4895,13 @@ bool llama_model::load_tensors(llama_model_loader & ml) { // NextN/MTP tensors (preserved but unused) - conditionally load for last nextn_predict_layers if (hparams.nextn_predict_layers > 0 && static_cast(i) >= n_layer - hparams.nextn_predict_layers) { layer.nextn.eh_proj = create_tensor(tn(LLM_TENSOR_NEXTN_EH_PROJ, "weight", i), { 2 * n_embd, n_embd }, flags); - layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags); layer.nextn.enorm = create_tensor(tn(LLM_TENSOR_NEXTN_ENORM, "weight", i), { n_embd }, flags); layer.nextn.hnorm = create_tensor(tn(LLM_TENSOR_NEXTN_HNORM, "weight", i), { n_embd }, flags); - layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags); - layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags); + + // Optional tensors + layer.nextn.embed_tokens = create_tensor(tn(LLM_TENSOR_NEXTN_EMBED_TOKENS, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_head = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_HEAD, "weight", i), { n_embd, n_vocab }, flags | TENSOR_NOT_REQUIRED); + layer.nextn.shared_head_norm = create_tensor(tn(LLM_TENSOR_NEXTN_SHARED_HEAD_NORM, "weight", i), { n_embd }, flags | TENSOR_NOT_REQUIRED); } } } @@ -5830,6 +5887,7 @@ bool llama_model::load_tensors(llama_model_loader & ml) { } } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), {n_embd, n_vocab}, 0); tok_norm = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD_NORM, "weight"), {n_embd}, 0); @@ -5841,11 +5899,23 @@ bool llama_model::load_tensors(llama_model_loader & ml) { for (int i = 0; i < n_layer; ++i) { auto & layer = layers[i]; - // ffn is same for transformer and conv layers + + const bool is_moe_layer = i >= static_cast(hparams.n_layer_dense_lead); + + // ffn/moe is same for transformer and conv layers layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); - layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); - layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); - layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + if (is_moe_layer) { + GGML_ASSERT(n_expert && n_expert_used); + layer.ffn_gate_inp = create_tensor(tn(LLM_TENSOR_FFN_GATE_INP, "weight", i), {n_embd, n_expert}, 0); + layer.ffn_gate_exps = create_tensor(tn(LLM_TENSOR_FFN_GATE_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_down_exps = create_tensor(tn(LLM_TENSOR_FFN_DOWN_EXPS, "weight", i), {hparams.n_ff_exp, n_embd, n_expert}, 0); + layer.ffn_up_exps = create_tensor(tn(LLM_TENSOR_FFN_UP_EXPS, "weight", i), {n_embd, hparams.n_ff_exp, n_expert}, 0); + layer.ffn_exp_probs_b = create_tensor(tn(LLM_TENSOR_FFN_EXP_PROBS_B, "bias", i), {n_expert}, 0); + } else { // dense + layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); + } // for operator_norm layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), {n_embd}, 0); @@ -5950,6 +6020,48 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_up_chexps = create_tensor(tn(LLM_TENSOR_FFN_UP_CHEXPS, "weight", i), { n_embd, n_ff_chexp, n_chunk_expert}, 0); } } break; + case LLM_ARCH_APERTUS: + { + tok_embd = create_tensor(tn(LLM_TENSOR_TOKEN_EMBD, "weight"), { n_embd, n_vocab }, 0); + + // output + output_norm = create_tensor(tn(LLM_TENSOR_OUTPUT_NORM, "weight"), { n_embd }, 0); + output = create_tensor(tn(LLM_TENSOR_OUTPUT, "weight"), { n_embd, n_vocab }, 0); + + for (int i = 0; i < n_layer; ++i) { + auto & layer = layers[i]; + + layer.attn_norm = create_tensor(tn(LLM_TENSOR_ATTN_NORM, "weight", i), { n_embd }, 0); + + if (hparams.rope_scaling_type_train == LLAMA_ROPE_SCALING_TYPE_LONGROPE) { + layer.rope_long = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_LONG, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + layer.rope_short = create_tensor(tn(LLM_TENSOR_ROPE_FACTORS_SHORT, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } else { + layer.rope_freqs = create_tensor(tn(LLM_TENSOR_ROPE_FREQS, "weight", i), { n_rot/2 }, TENSOR_NOT_REQUIRED | (i != 0 ? TENSOR_DUPLICATED : 0)); + } + + layer.wq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "weight", i), { n_embd, n_embd_head_k * n_head }, 0); + layer.wk = create_tensor(tn(LLM_TENSOR_ATTN_K, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wv = create_tensor(tn(LLM_TENSOR_ATTN_V, "weight", i), { n_embd, n_embd_gqa }, 0); + layer.wo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "weight", i), { n_embd_head_k * n_head, n_embd }, 0); + + // optional bias tensors + layer.bq = create_tensor(tn(LLM_TENSOR_ATTN_Q, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + layer.bk = create_tensor(tn(LLM_TENSOR_ATTN_K, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bv = create_tensor(tn(LLM_TENSOR_ATTN_V, "bias", i), { n_embd_gqa }, TENSOR_NOT_REQUIRED); + layer.bo = create_tensor(tn(LLM_TENSOR_ATTN_OUT, "bias", i), { n_embd }, TENSOR_NOT_REQUIRED); + + layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), { n_embd }, 0); + layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd }, 0); + layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), { n_embd, n_ff }, 0); + + // Q and K layernorms for Apertus + layer.attn_q_norm = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_q_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_Q_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + layer.attn_k_norm = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "weight", i), { n_embd_head_k }, 0); + layer.attn_k_norm_b = create_tensor(tn(LLM_TENSOR_ATTN_K_NORM, "bias", i), { n_embd_head_k }, TENSOR_NOT_REQUIRED); + } + } break; default: throw std::runtime_error("unknown architecture"); } @@ -6284,7 +6396,7 @@ void llama_model::print_info() const { LLAMA_LOG_INFO("%s: expert_weights_norm = %d\n", __func__, hparams.expert_weights_norm); } - if (arch == LLM_ARCH_SMALLTHINKER) { + if (arch == LLM_ARCH_SMALLTHINKER || arch == LLM_ARCH_LFM2MOE) { LLAMA_LOG_INFO("%s: n_ff_exp = %d\n", __func__, hparams.n_ff_exp); LLAMA_LOG_INFO("%s: expert_gating_func = %s\n", __func__, llama_expert_gating_func_name((llama_expert_gating_func_type) hparams.expert_gating_func)); } @@ -7819,6 +7931,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -7828,6 +7942,8 @@ struct llm_build_bert : public llm_graph_context { } if (model.layers[il].attn_k_norm) { + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, model.layers[il].attn_k_norm_b, @@ -8210,6 +8326,9 @@ struct llm_build_mpt : public llm_graph_context { // Q/K Layernorm if (model.layers[il].attn_q_norm) { + Qcur = ggml_reshape_2d(ctx0, Qcur, n_embd_head*n_head, n_tokens); + Kcur = ggml_reshape_2d(ctx0, Kcur, n_embd_head*n_head_kv, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, model.layers[il].attn_q_norm_b, @@ -16237,10 +16356,10 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } ggml_tensor * build_layer_ffn( - ggml_tensor * cur, - ggml_tensor * inpSA, - const llama_model & model, - const int il) { + ggml_tensor * cur, + ggml_tensor * inpSA, + const llama_model & model, + const int il) { // For Granite architectures - scale residual if (hparams.f_residual_scale) { @@ -17811,6 +17930,7 @@ private: const int64_t n_embd_head_q = hparams.n_embd_head_k; const int64_t n_embd_head_k = hparams.n_embd_head_k; const int64_t n_embd_head_v = hparams.n_embd_head_v; + int32_t n_head = hparams.n_head(il); int32_t n_head_kv = hparams.n_head_kv(il); const int64_t q_offset = 0; @@ -18727,6 +18847,8 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_tensor * inp_out_ids = build_inp_out_ids(); for (int il = 0; il < n_layer; ++il) { + const bool is_moe_layer = il >= static_cast(hparams.n_layer_dense_lead); + auto * prev_cur = cur; cur = build_norm(cur, model.layers[il].attn_norm, NULL, LLM_NORM_RMS, il); cb(cur, "model.layers.{}.operator_norm", il); @@ -18741,7 +18863,16 @@ struct llm_build_lfm2 : public llm_graph_context { } cur = ggml_add(ctx0, prev_cur, cur); - cur = ggml_add(ctx0, cur, build_feed_forward(cur, il)); + + auto * ffn_norm_out = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); + cb(ffn_norm_out, "model.layers.{}.ffn_norm", il); + + ggml_tensor * ffn_out = is_moe_layer ? + build_moe_feed_forward(ffn_norm_out, il) : + build_dense_feed_forward(ffn_norm_out, il); + cb(ffn_norm_out, "model.layers.{}.ffn_out", il); + + cur = ggml_add(ctx0, cur, ffn_out); } cur = build_norm(cur, model.tok_norm, NULL, LLM_NORM_RMS, -1); @@ -18756,23 +18887,32 @@ struct llm_build_lfm2 : public llm_graph_context { ggml_build_forward_expand(gf, cur); } - ggml_tensor * build_feed_forward(ggml_tensor * cur, - int il) const { - cur = build_norm(cur, model.layers[il].ffn_norm, NULL, LLM_NORM_RMS, il); - cb(cur, "model.layers.{}.ffn_norm", il); + ggml_tensor * build_moe_feed_forward(ggml_tensor * cur, + int il) const { + return build_moe_ffn(cur, + model.layers[il].ffn_gate_inp, + model.layers[il].ffn_up_exps, + model.layers[il].ffn_gate_exps, + model.layers[il].ffn_down_exps, + model.layers[il].ffn_exp_probs_b, + n_expert, n_expert_used, + LLM_FFN_SILU, true, + false, 0.0, + static_cast(hparams.expert_gating_func), + il); + } + ggml_tensor * build_dense_feed_forward(ggml_tensor * cur, + int il) const { GGML_ASSERT(!model.layers[il].ffn_up_b); GGML_ASSERT(!model.layers[il].ffn_gate_b); GGML_ASSERT(!model.layers[il].ffn_down_b); - cur = build_ffn(cur, + return build_ffn(cur, model.layers[il].ffn_up, NULL, NULL, model.layers[il].ffn_gate, NULL, NULL, model.layers[il].ffn_down, NULL, NULL, NULL, LLM_FFN_SILU, LLM_FFN_PAR, il); - cb(cur, "model.layers.{}.feed_forward.w2", il); - - return cur; } ggml_tensor * build_attn_block(ggml_tensor * cur, @@ -19292,6 +19432,141 @@ struct llm_build_grovemoe : public llm_graph_context { } }; +struct llm_build_apertus : public llm_graph_context { + llm_build_apertus(const llama_model & model, const llm_graph_params & params) : llm_graph_context(params) { + const int64_t n_embd_head = hparams.n_embd_head_v; + + GGML_ASSERT(n_embd_head == hparams.n_embd_head_k); + GGML_ASSERT(n_embd_head == hparams.n_rot); + + ggml_tensor * cur; + ggml_tensor * inpL; + + inpL = build_inp_embd(model.tok_embd); + + ggml_tensor * inp_pos = build_inp_pos(); + auto * inp_attn = build_attn_inp_kv(); + + const float kq_scale = hparams.f_attention_scale == 0.0f ? 1.0f / sqrtf(float(n_embd_head)) : hparams.f_attention_scale; + + ggml_tensor * inp_out_ids = build_inp_out_ids(); + + for (int il = 0; il < n_layer; ++il) { + ggml_tensor * inpSA = inpL; + + cur = build_norm(inpL, + model.layers[il].attn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "attn_norm", il); + + // self-attention + { + ggml_tensor * rope_factors = model.get_rope_factors(cparams, il); + + // compute Q and K and RoPE them + ggml_tensor * Qcur = build_lora_mm(model.layers[il].wq, cur); + cb(Qcur, "Qcur", il); + + ggml_tensor * Kcur = build_lora_mm(model.layers[il].wk, cur); + cb(Kcur, "Kcur", il); + + ggml_tensor * Vcur = build_lora_mm(model.layers[il].wv, cur); + cb(Vcur, "Vcur", il); + + Qcur = ggml_reshape_3d(ctx0, Qcur, n_embd_head, n_head, n_tokens); + Qcur = build_norm(Qcur, model.layers[il].attn_q_norm, NULL, LLM_NORM_RMS, il); + cb(Qcur, "Qcur_normed", il); + + Kcur = ggml_reshape_3d(ctx0, Kcur, n_embd_head, n_head_kv, n_tokens); + Kcur = build_norm(Kcur, model.layers[il].attn_k_norm, NULL, LLM_NORM_RMS, il); + cb(Kcur, "Kcur_normed", il); + + Vcur = ggml_reshape_3d(ctx0, Vcur, n_embd_head, n_head_kv, n_tokens); + + Qcur = ggml_rope_ext( + ctx0, Qcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + Kcur = ggml_rope_ext( + ctx0, Kcur, inp_pos, rope_factors, + n_rot, rope_type, n_ctx_orig, freq_base, freq_scale, + ext_factor, attn_factor, beta_fast, beta_slow + ); + + cb(Qcur, "Qcur_pos", il); + cb(Kcur, "Kcur_pos", il); + cb(Vcur, "Vcur_pos", il); + + cur = build_attn(inp_attn, + model.layers[il].wo, model.layers[il].bo, + Qcur, Kcur, Vcur, nullptr, nullptr, nullptr, kq_scale, il); + cb(cur, "attn_out", il); + } + + if (il == n_layer - 1 && inp_out_ids) { + cur = ggml_get_rows(ctx0, cur, inp_out_ids); + inpSA = ggml_get_rows(ctx0, inpSA, inp_out_ids); + } + + ggml_tensor * ffn_inp = ggml_add(ctx0, cur, inpSA); + cb(ffn_inp, "ffn_inp", il); + + // feed-forward network with xIELU activation + { + cur = build_norm(ffn_inp, + model.layers[il].ffn_norm, nullptr, + LLM_NORM_RMS, il); + cb(cur, "ffn_norm", il); + + // Up projection + ggml_tensor * up = build_lora_mm(model.layers[il].ffn_up, cur); + cb(up, "ffn_up", il); + + float alpha_n_val = hparams.xielu_alpha_n[il]; + float alpha_p_val = hparams.xielu_alpha_p[il]; + float beta_val = hparams.xielu_beta[il]; + float eps_val = hparams.xielu_eps[il]; + + // Apply xIELU activation + ggml_tensor * activated = ggml_xielu(ctx0, up, alpha_n_val, alpha_p_val, beta_val, eps_val); + cb(activated, "ffn_xielu", il); + + // Down projection + cur = build_lora_mm(model.layers[il].ffn_down, activated); + cb(cur, "ffn_down", il); + } + + cur = ggml_add(ctx0, cur, ffn_inp); + cb(cur, "ffn_out", il); + + cur = build_cvec(cur, il); + cb(cur, "l_out", il); + + // input for next layer + inpL = cur; + } + + cur = inpL; + + cur = build_norm(cur, + model.output_norm, nullptr, + LLM_NORM_RMS, -1); + + cb(cur, "result_norm", -1); + res->t_embd = cur; + + // lm_head + cur = build_lora_mm(model.output, cur); + + cb(cur, "result_output", -1); + res->t_logits = cur; + + ggml_build_forward_expand(gf, cur); + } +}; + llama_memory_i * llama_model::create_memory(const llama_memory_params & params, llama_cparams & cparams) const { llama_memory_i * res; @@ -19811,6 +20086,7 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { llm = std::make_unique(*this, params); } break; case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: { llm = std::make_unique(*this, params); } break; @@ -19826,6 +20102,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; + case LLM_ARCH_APERTUS: + { + llm = std::make_unique(*this, params); + } break; default: GGML_ABORT("fatal error"); } @@ -19833,6 +20113,12 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { // add on pooling layer llm->build_pooling(cls, cls_b, cls_out, cls_out_b); + // if the gguf model was converted with --sentence-transformers-dense-modules + // there will be two additional dense projection layers + // dense linear projections are applied after pooling + // TODO: move reranking logic here and generalize + llm->build_dense_out(dense_2_out_layers, dense_3_out_layers); + return llm->res->get_gf(); } @@ -19857,6 +20143,7 @@ llama_model_params llama_model_default_params() { /*.use_mlock =*/ false, /*.check_tensors =*/ false, /*.use_extra_bufts =*/ true, + /*.no_host =*/ false, }; return result; @@ -20029,10 +20316,12 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_OPENAI_MOE: case LLM_ARCH_HUNYUAN_DENSE: case LLM_ARCH_LFM2: + case LLM_ARCH_LFM2MOE: case LLM_ARCH_SMALLTHINKER: case LLM_ARCH_GLM4_MOE: case LLM_ARCH_SEED_OSS: case LLM_ARCH_GROVEMOE: + case LLM_ARCH_APERTUS: return LLAMA_ROPE_TYPE_NEOX; case LLM_ARCH_QWEN2VL: @@ -20143,6 +20432,10 @@ bool llama_model_is_recurrent(const llama_model * model) { return llm_arch_is_recurrent(model->arch); } +bool llama_model_is_hybrid(const llama_model * model) { + return llm_arch_is_hybrid(model->arch); +} + bool llama_model_is_diffusion(const llama_model * model) { return llm_arch_is_diffusion(model->arch); } diff --git a/llama/llama.cpp/src/llama-model.h b/llama/llama.cpp/src/llama-model.h index c086f94e82..ec3fbd334a 100644 --- a/llama/llama.cpp/src/llama-model.h +++ b/llama/llama.cpp/src/llama-model.h @@ -108,6 +108,7 @@ enum llm_type { LLM_TYPE_17B_16E, // llama4 Scout LLM_TYPE_17B_128E, // llama4 Maverick LLM_TYPE_A13B, + LLM_TYPE_8B_A1B, // lfm2moe LLM_TYPE_21B_A3B, // Ernie MoE small LLM_TYPE_30B_A3B, LLM_TYPE_106B_A12B, // GLM-4.5-Air @@ -381,6 +382,12 @@ struct llama_layer { // openai-moe struct ggml_tensor * attn_sinks = nullptr; + // xIELU activation parameters for Apertus + struct ggml_tensor * ffn_act_alpha_n = nullptr; + struct ggml_tensor * ffn_act_alpha_p = nullptr; + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct ggml_tensor * bskcn_tv = nullptr; struct llama_layer_posnet posnet; @@ -434,6 +441,12 @@ struct llama_model { std::vector layers; + //Dense linear projections for SentenceTransformers models like embeddinggemma + // For Sentence Transformers models structure see + // https://sbert.net/docs/sentence_transformer/usage/custom_models.html#structure-of-sentence-transformer-models + struct ggml_tensor * dense_2_out_layers = nullptr; + struct ggml_tensor * dense_3_out_layers = nullptr; + llama_model_params params; // gguf metadata diff --git a/llama/llama.cpp/src/llama-sampling.cpp b/llama/llama.cpp/src/llama-sampling.cpp index 8fb8600927..da34526b1f 100644 --- a/llama/llama.cpp/src/llama-sampling.cpp +++ b/llama/llama.cpp/src/llama-sampling.cpp @@ -2541,8 +2541,13 @@ static void llama_sampler_infill_apply(struct llama_sampler * smpl, llama_token_ if (n_non_eog == 0) { cur_p->size = 1; cur_p->data[0].id = ctx->vocab->token_eot(); + if (cur_p->data[0].id == LLAMA_TOKEN_NULL) { + cur_p->data[0].id = ctx->vocab->token_eos(); + } cur_p->data[0].logit = 1.0f; + GGML_ASSERT(cur_p->data[0].id != LLAMA_TOKEN_NULL); + return; } diff --git a/llama/llama.cpp/src/llama-vocab.cpp b/llama/llama.cpp/src/llama-vocab.cpp index 64c78a164e..217ede47a0 100644 --- a/llama/llama.cpp/src/llama-vocab.cpp +++ b/llama/llama.cpp/src/llama-vocab.cpp @@ -347,6 +347,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { case LLAMA_VOCAB_PRE_TYPE_OLMO: case LLAMA_VOCAB_PRE_TYPE_JAIS: case LLAMA_VOCAB_PRE_TYPE_TRILLION: + case LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING: regex_exprs = { "'s|'t|'re|'ve|'m|'ll|'d| ?\\p{L}+| ?\\p{N}+| ?[^\\s\\p{L}\\p{N}]+|\\s+(?!\\S)", }; @@ -1950,6 +1951,10 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { tokenizer_pre == "trillion") { pre_type = LLAMA_VOCAB_PRE_TYPE_TRILLION; clean_spaces = false; + } else if ( + tokenizer_pre == "granite-docling") { + pre_type = LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING; + clean_spaces = false; } else if ( tokenizer_pre == "bailingmoe" || tokenizer_pre == "llada-moe") { @@ -2156,6 +2161,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { || t.first == "<|end|>" || t.first == "" || t.first == "<|endoftext|>" + || t.first == "<|end_of_text|>" // granite || t.first == "" || t.first == "_" || t.first == "<|end▁of▁sentence|>" // DeepSeek diff --git a/llama/llama.cpp/src/llama-vocab.h b/llama/llama.cpp/src/llama-vocab.h index 0d2f28c36c..5e468675e4 100644 --- a/llama/llama.cpp/src/llama-vocab.h +++ b/llama/llama.cpp/src/llama-vocab.h @@ -8,46 +8,47 @@ // pre-tokenization types enum llama_vocab_pre_type { - LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, - LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, - LLAMA_VOCAB_PRE_TYPE_FALCON = 4, - LLAMA_VOCAB_PRE_TYPE_MPT = 5, - LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, - LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, - LLAMA_VOCAB_PRE_TYPE_REFACT = 8, - LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, - LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, - LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, - LLAMA_VOCAB_PRE_TYPE_OLMO = 12, - LLAMA_VOCAB_PRE_TYPE_DBRX = 13, - LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, - LLAMA_VOCAB_PRE_TYPE_PORO = 15, - LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, - LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, - LLAMA_VOCAB_PRE_TYPE_VIKING = 18, - LLAMA_VOCAB_PRE_TYPE_JAIS = 19, - LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, - LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, - LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, - LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, - LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, - LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, - LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, - LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, - LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, - LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, - LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, - LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, - LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, - LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, - LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, - LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, - LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, - LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, - LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_DEFAULT = 0, + LLAMA_VOCAB_PRE_TYPE_LLAMA3 = 1, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_LLM = 2, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK_CODER = 3, + LLAMA_VOCAB_PRE_TYPE_FALCON = 4, + LLAMA_VOCAB_PRE_TYPE_MPT = 5, + LLAMA_VOCAB_PRE_TYPE_STARCODER = 6, + LLAMA_VOCAB_PRE_TYPE_GPT2 = 7, + LLAMA_VOCAB_PRE_TYPE_REFACT = 8, + LLAMA_VOCAB_PRE_TYPE_COMMAND_R = 9, + LLAMA_VOCAB_PRE_TYPE_STABLELM2 = 10, + LLAMA_VOCAB_PRE_TYPE_QWEN2 = 11, + LLAMA_VOCAB_PRE_TYPE_OLMO = 12, + LLAMA_VOCAB_PRE_TYPE_DBRX = 13, + LLAMA_VOCAB_PRE_TYPE_SMAUG = 14, + LLAMA_VOCAB_PRE_TYPE_PORO = 15, + LLAMA_VOCAB_PRE_TYPE_CHATGLM3 = 16, + LLAMA_VOCAB_PRE_TYPE_CHATGLM4 = 17, + LLAMA_VOCAB_PRE_TYPE_VIKING = 18, + LLAMA_VOCAB_PRE_TYPE_JAIS = 19, + LLAMA_VOCAB_PRE_TYPE_TEKKEN = 20, + LLAMA_VOCAB_PRE_TYPE_SMOLLM = 21, + LLAMA_VOCAB_PRE_TYPE_CODESHELL = 22, + LLAMA_VOCAB_PRE_TYPE_BLOOM = 23, + LLAMA_VOCAB_PRE_TYPE_GPT3_FINNISH = 24, + LLAMA_VOCAB_PRE_TYPE_EXAONE = 25, + LLAMA_VOCAB_PRE_TYPE_CHAMELEON = 26, + LLAMA_VOCAB_PRE_TYPE_MINERVA = 27, + LLAMA_VOCAB_PRE_TYPE_DEEPSEEK3_LLM = 28, + LLAMA_VOCAB_PRE_TYPE_GPT4O = 29, + LLAMA_VOCAB_PRE_TYPE_SUPERBPE = 30, + LLAMA_VOCAB_PRE_TYPE_TRILLION = 31, + LLAMA_VOCAB_PRE_TYPE_BAILINGMOE = 32, + LLAMA_VOCAB_PRE_TYPE_LLAMA4 = 33, + LLAMA_VOCAB_PRE_TYPE_PIXTRAL = 34, + LLAMA_VOCAB_PRE_TYPE_SEED_CODER = 35, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN = 36, + LLAMA_VOCAB_PRE_TYPE_KIMI_K2 = 37, + LLAMA_VOCAB_PRE_TYPE_HUNYUAN_DENSE = 38, + LLAMA_VOCAB_PRE_TYPE_GROK_2 = 39, + LLAMA_VOCAB_PRE_TYPE_GRANITE_DOCLING = 40, }; struct LLM_KV; diff --git a/llama/llama.cpp/tools/mtmd/clip-impl.h b/llama/llama.cpp/tools/mtmd/clip-impl.h index 664b0c9ac6..7a7523851c 100644 --- a/llama/llama.cpp/tools/mtmd/clip-impl.h +++ b/llama/llama.cpp/tools/mtmd/clip-impl.h @@ -31,6 +31,7 @@ // vision-specific #define KEY_IMAGE_SIZE "clip.vision.image_size" +#define KEY_PREPROC_IMAGE_SIZE "clip.vision.preproc_image_size" #define KEY_PATCH_SIZE "clip.vision.patch_size" #define KEY_IMAGE_MEAN "clip.vision.image_mean" #define KEY_IMAGE_STD "clip.vision.image_std" diff --git a/llama/llama.cpp/tools/mtmd/clip.cpp b/llama/llama.cpp/tools/mtmd/clip.cpp index 355219a945..6699b75a4f 100644 --- a/llama/llama.cpp/tools/mtmd/clip.cpp +++ b/llama/llama.cpp/tools/mtmd/clip.cpp @@ -183,7 +183,9 @@ struct clip_hparams { int32_t projection_dim; int32_t n_head; int32_t n_layer; - int32_t proj_scale_factor = 0; // idefics3 + // idefics3 + int32_t preproc_image_size = 0; + int32_t proj_scale_factor = 0; float image_mean[3]; float image_std[3]; @@ -2263,6 +2265,7 @@ struct clip_model_loader { if (is_vision) { get_u32(KEY_IMAGE_SIZE, hparams.image_size); + get_u32(KEY_PREPROC_IMAGE_SIZE, hparams.preproc_image_size, false); get_u32(KEY_PATCH_SIZE, hparams.patch_size); get_u32(KEY_IMAGE_CROP_RESOLUTION, hparams.image_crop_resolution, false); get_i32(KEY_MINICPMV_VERSION, hparams.minicpmv_version, false); // legacy @@ -3590,10 +3593,51 @@ bool clip_image_preprocess(struct clip_ctx * ctx, const clip_image_u8 * img, str // res_imgs->data[0] = *res; res_imgs->entries.push_back(std::move(img_f32)); return true; - } - else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE + } else if (ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3) { + // The refined size has two steps: + // 1. Resize w/ aspect-ratio preserving such that the longer side is + // the preprocessor longest size + // 2. Resize w/out preserving aspect ratio such that both sides are + // multiples of image_size (always rounding up) + // + // CITE: https://github.com/huggingface/transformers/blob/main/src/transformers/models/idefics3/image_processing_idefics3.py#L737 + const clip_image_size refined_size = image_manipulation::calc_size_preserved_ratio( + original_size, params.image_size, params.preproc_image_size); + + llava_uhd::slice_instructions instructions; + instructions.overview_size = clip_image_size{params.image_size, params.image_size}; + instructions.refined_size = refined_size; + instructions.grid_size = clip_image_size{ + static_cast(std::ceil(static_cast(refined_size.width) / params.image_size)), + static_cast(std::ceil(static_cast(refined_size.height) / params.image_size)), + }; + for (int y = 0; y < refined_size.height; y += params.image_size) { + for (int x = 0; x < refined_size.width; x += params.image_size) { + instructions.slices.push_back(llava_uhd::slice_coordinates{ + /* x */x, + /* y */y, + /* size */clip_image_size{ + std::min(params.image_size, refined_size.width - x), + std::min(params.image_size, refined_size.height - y) + } + }); + } + } + auto imgs = llava_uhd::slice_image(img, instructions); + + // cast and normalize to f32 + for (size_t i = 0; i < imgs.size(); ++i) { + // clip_image_save_to_bmp(*imgs[i], "slice_" + std::to_string(i) + ".bmp"); + clip_image_f32_ptr res(clip_image_f32_init()); + normalize_image_u8_to_f32(*imgs[i], *res, params.image_mean, params.image_std); + res_imgs->entries.push_back(std::move(res)); + } + + res_imgs->grid_x = instructions.grid_size.width; + res_imgs->grid_y = instructions.grid_size.height; + return true; + } else if (ctx->proj_type() == PROJECTOR_TYPE_GLM_EDGE || ctx->proj_type() == PROJECTOR_TYPE_GEMMA3 - || ctx->proj_type() == PROJECTOR_TYPE_IDEFICS3 || ctx->proj_type() == PROJECTOR_TYPE_INTERNVL // TODO @ngxson : support dynamic resolution ) { clip_image_u8 resized_image; diff --git a/llama/llama.cpp/tools/mtmd/mtmd.cpp b/llama/llama.cpp/tools/mtmd/mtmd.cpp index 3d680945b2..35a0d25ed9 100644 --- a/llama/llama.cpp/tools/mtmd/mtmd.cpp +++ b/llama/llama.cpp/tools/mtmd/mtmd.cpp @@ -76,7 +76,7 @@ enum mtmd_slice_tmpl { MTMD_SLICE_TMPL_MINICPMV_2_5, MTMD_SLICE_TMPL_MINICPMV_2_6, MTMD_SLICE_TMPL_LLAMA4, - // TODO @ngxson : add support for idefics (SmolVLM) + MTMD_SLICE_TMPL_IDEFICS3, }; mtmd_input_text* mtmd_input_text_init(const char * text, bool add_special, bool parse_special) { @@ -124,19 +124,22 @@ struct mtmd_context { // for llava-uhd style models, we need special tokens in-between slices // minicpmv calls them "slices", llama 4 calls them "tiles" mtmd_slice_tmpl slice_tmpl = MTMD_SLICE_TMPL_NONE; - llama_token tok_ov_img_start = LLAMA_TOKEN_NULL; // overview image - llama_token tok_ov_img_end = LLAMA_TOKEN_NULL; // overview image - llama_token tok_slices_start = LLAMA_TOKEN_NULL; // start of all slices - llama_token tok_slices_end = LLAMA_TOKEN_NULL; // end of all slices - llama_token tok_sli_img_start = LLAMA_TOKEN_NULL; // single slice start - llama_token tok_sli_img_end = LLAMA_TOKEN_NULL; // single slice end - llama_token tok_sli_img_mid = LLAMA_TOKEN_NULL; // between 2 slices - llama_token tok_row_end = LLAMA_TOKEN_NULL; // end of row + std::vector tok_ov_img_start; // overview image + std::vector tok_ov_img_end; // overview image + std::vector tok_slices_start; // start of all slices + std::vector tok_slices_end; // end of all slices + std::vector tok_sli_img_start; // single slice start + std::vector tok_sli_img_end; // single slice end + std::vector tok_sli_img_mid; // between 2 slices + std::vector tok_row_end; // end of row bool tok_row_end_trail = false; bool ov_img_first = false; bool use_mrope = false; // for Qwen2VL, we need to use M-RoPE + // string template for slice image delimiters with row/col (idefics3) + std::string sli_img_start_tmpl; + // for whisper, we pre-calculate the mel filter bank whisper_preprocessor::whisper_filters w_filters; @@ -207,13 +210,13 @@ struct mtmd_context { // minicpmv 2.5 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_5; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_slices_start = lookup_token(""); - tok_slices_end = lookup_token(""); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_slices_start = {lookup_token("")}; + tok_slices_end = {lookup_token("")}; tok_sli_img_start = tok_ov_img_start; tok_sli_img_end = tok_ov_img_end; - tok_row_end = lookup_token("\n"); + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -221,11 +224,11 @@ struct mtmd_context { // minicpmv 2.6 format: // (overview) (slice) (slice) \n ... slice_tmpl = MTMD_SLICE_TMPL_MINICPMV_2_6; - tok_ov_img_start = lookup_token(""); - tok_ov_img_end = lookup_token(""); - tok_sli_img_start = lookup_token(""); - tok_sli_img_end = lookup_token(""); - tok_row_end = lookup_token("\n"); + tok_ov_img_start = {lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_sli_img_start = {lookup_token("")}; + tok_sli_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; tok_row_end_trail = false; // no trailing end-of-row token ov_img_first = true; @@ -240,9 +243,9 @@ struct mtmd_context { // <|image|> (overview) <-- overview image is last // <|image_end|> slice_tmpl = MTMD_SLICE_TMPL_LLAMA4; - tok_ov_img_start = lookup_token("<|image|>"); - tok_sli_img_mid = lookup_token("<|tile_x_separator|>"); - tok_row_end = lookup_token("<|tile_y_separator|>"); + tok_ov_img_start = {lookup_token("<|image|>")}; + tok_sli_img_mid = {lookup_token("<|tile_x_separator|>")}; + tok_row_end = {lookup_token("<|tile_y_separator|>")}; tok_row_end_trail = true; // add trailing end-of-row token ov_img_first = false; // overview image is last } @@ -255,8 +258,11 @@ struct mtmd_context { } else if (proj == PROJECTOR_TYPE_IDEFICS3) { // https://github.com/huggingface/transformers/blob/a42ba80fa520c784c8f11a973ca9034e5f859b79/src/transformers/models/idefics3/processing_idefics3.py#L192-L215 - img_beg = ""; - img_end = ""; + slice_tmpl = MTMD_SLICE_TMPL_IDEFICS3; + tok_ov_img_start = {lookup_token("\n\n"), lookup_token(""), lookup_token("")}; + tok_ov_img_end = {lookup_token("")}; + tok_row_end = {lookup_token("\n")}; + sli_img_start_tmpl = ""; } else if (proj == PROJECTOR_TYPE_PIXTRAL) { // https://github.com/huggingface/transformers/blob/1cd110c6cb6a6237614130c470e9a902dbc1a4bd/docs/source/en/model_doc/pixtral.md @@ -514,6 +520,7 @@ struct mtmd_tokenizer { ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_5 || ctx->slice_tmpl == MTMD_SLICE_TMPL_MINICPMV_2_6 || ctx->slice_tmpl == MTMD_SLICE_TMPL_LLAMA4 + || ctx->slice_tmpl == MTMD_SLICE_TMPL_IDEFICS3 ) { const int n_col = batch_f32.grid_x; const int n_row = batch_f32.grid_y; @@ -527,53 +534,45 @@ struct mtmd_tokenizer { // add overview image (first) if (ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } // add slices (or tiles) if (!chunks.empty()) { GGML_ASSERT((int)chunks.size() == n_row * n_col); - if (ctx->tok_slices_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_start}); - } + add_text(ctx->tok_slices_start); for (int y = 0; y < n_row; y++) { for (int x = 0; x < n_col; x++) { const bool is_last_in_row = (x == n_col - 1); - if (ctx->tok_sli_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_start}); + if (!ctx->tok_sli_img_start.empty()) { + add_text(ctx->tok_sli_img_start); + } else if (!ctx->sli_img_start_tmpl.empty()) { + // If using a template to preceed a slice image + const size_t sz = std::snprintf(nullptr, 0, ctx->sli_img_start_tmpl.c_str(), y+1, x+1) + 1; + std::unique_ptr buf(new char[sz]); + std::snprintf(buf.get(), sz, ctx->sli_img_start_tmpl.c_str(), y+1, x+1); + add_text(std::string(buf.get(), buf.get() + sz - 1), true); } cur.entries.emplace_back(std::move(chunks[y * n_col + x])); - if (ctx->tok_sli_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_end}); - } - if (!is_last_in_row && ctx->tok_sli_img_mid != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_sli_img_mid}); + add_text(ctx->tok_sli_img_end); + if (!is_last_in_row) { + add_text(ctx->tok_sli_img_mid); } } - if ((y != n_row - 1 || ctx->tok_row_end_trail) && ctx->tok_row_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_row_end}); + if ((y != n_row - 1 || ctx->tok_row_end_trail)) { + add_text(ctx->tok_row_end); } } - if (ctx->tok_slices_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_slices_end}); - } + add_text(ctx->tok_slices_end); } // add overview image (last) if (!ctx->ov_img_first) { - if (ctx->tok_ov_img_start != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_start}); - } + add_text(ctx->tok_ov_img_start); cur.entries.emplace_back(std::move(ov_chunk)); - if (ctx->tok_ov_img_end != LLAMA_TOKEN_NULL) { - add_text({ctx->tok_ov_img_end}); - } + add_text(ctx->tok_ov_img_end); } } else { @@ -790,7 +789,9 @@ int32_t mtmd_encode(mtmd_context * ctx, const mtmd_image_tokens * image_tokens) ctx->image_embd_v.resize(image_tokens->n_tokens() * n_mmproj_embd); bool ok = false; - if (clip_is_llava(ctx_clip) || clip_is_minicpmv(ctx_clip) || clip_is_glm(ctx_clip)) { + if (clip_is_llava(ctx_clip) + || clip_is_minicpmv(ctx_clip) + || clip_is_glm(ctx_clip)) { // TODO @ngxson : llava does not support batched encoding ; this should be fixed inside clip_image_batch_encode() const auto & entries = image_tokens->batch_f32.entries; for (size_t i = 0; i < entries.size(); i++) { diff --git a/llama/llama.go b/llama/llama.go index e0804ebddd..84618590f9 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -504,7 +504,12 @@ func (c *MtmdContext) Free() { C.mtmd_free(c.c) } -func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, error) { +type MtmdChunk struct { + Embed []float32 + Tokens []int +} + +func (c *MtmdContext) MultimodalTokenize(llamaContext *Context, data []byte) ([]MtmdChunk, error) { // Initialize the input chunks pointer ic := C.mtmd_input_chunks_init() defer C.mtmd_input_chunks_free(ic) @@ -523,35 +528,51 @@ func (c *MtmdContext) NewEmbed(llamaContext *Context, data []byte) ([][]float32, } nChunks := C.mtmd_input_chunks_size(ic) numEmbed := llamaContext.Model().NEmbd() - embed := make([][]float32, 0) + outChunks := make([]MtmdChunk, 0) for i := range int(nChunks) { chunk := C.mtmd_input_chunks_get(ic, C.size_t(i)) numTokens := int(C.mtmd_input_chunk_get_n_tokens(chunk)) slog.Debug("chunk tokens", "index", i, "numTokens", numTokens) - // Encode the chunk - if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { - return nil, errors.New("unable to encode mtmd image chunk") - } + if C.mtmd_input_chunk_get_type(chunk) == C.MTMD_INPUT_CHUNK_TYPE_TEXT { + // If this is a text chunk, add the tokens + cNumTokens := C.size_t(0) + cTokens := C.mtmd_input_chunk_get_tokens_text(chunk, &cNumTokens) + cTokensArr := unsafe.Slice(cTokens, int(cNumTokens)) + tokens := make([]int, int(cNumTokens)) + for j := range int(cNumTokens) { + tokens[j] = int(cTokensArr[j]) + } + outChunks = append(outChunks, MtmdChunk{Tokens: tokens}) + } else { + // Otherwise, encode the image chunk to embeddings - // Get the embeddings for this chunk - chunkEmbed := make([][]float32, numTokens) - chunkEmbd := C.mtmd_get_output_embd(c.c) - if nil == chunkEmbd { - continue - } + // Encode the chunk + if C.int32_t(0) != C.mtmd_encode_chunk(c.c, chunk) { + return nil, errors.New("unable to encode mtmd image chunk") + } - // Extend the embedding array for each token - s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) - rows := make([]float32, len(s)) - copy(rows, s) - for i := range numTokens { - chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + // Get the embeddings for this chunk + chunkEmbed := make([][]float32, numTokens) + chunkEmbd := C.mtmd_get_output_embd(c.c) + if nil == chunkEmbd { + return nil, errors.New("no mtmd image embedding") + } + + // Extend the embedding array for each token + s := unsafe.Slice((*float32)(chunkEmbd), numTokens*numEmbed) + rows := make([]float32, len(s)) + copy(rows, s) + for i := range numTokens { + chunkEmbed[i] = rows[i*numEmbed : (i+1)*numEmbed] + } + for _, e := range chunkEmbed { + outChunks = append(outChunks, MtmdChunk{Embed: e}) + } } - embed = append(embed, chunkEmbed...) } - slog.Debug("image embeddings", "totalEmbeddings", len(embed)) - return embed, nil + slog.Debug("image tokenization chunks", "totalChunks", len(outChunks)) + return outChunks, nil } func (c *Context) Synchronize() { diff --git a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch index 35b906d9e7..bcd60fb637 100644 --- a/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch +++ b/llama/patches/0001-ggml-backend-malloc-and-free-using-the-same-compiler.patch @@ -64,7 +64,7 @@ index ff9135fe..8ba86f82 100644 /* .init_tensor = */ NULL, // no initialization required /* .memset_tensor = */ ggml_backend_cpu_buffer_memset_tensor, diff --git a/ggml/src/ggml-cann/ggml-cann.cpp b/ggml/src/ggml-cann/ggml-cann.cpp -index b51b554e..3ba0f5a6 100755 +index ad1adba6..7d44f74f 100755 --- a/ggml/src/ggml-cann/ggml-cann.cpp +++ b/ggml/src/ggml-cann/ggml-cann.cpp @@ -843,6 +843,7 @@ static void ggml_backend_cann_buffer_free_buffer( @@ -84,7 +84,7 @@ index b51b554e..3ba0f5a6 100755 /** diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index b7e81b21..fdf8c63d 100644 +index 856e9de2..c0b1e4c1 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -567,6 +567,7 @@ struct ggml_backend_cuda_buffer_context { @@ -112,7 +112,7 @@ index b7e81b21..fdf8c63d 100644 static void * ggml_cuda_host_malloc(size_t size) { diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index e11555a7..909e17de 100644 +index 7afc881f..bf096227 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -25,6 +25,7 @@ static void ggml_backend_metal_buffer_shared_free_buffer(ggml_backend_buffer_t b @@ -132,10 +132,10 @@ index e11555a7..909e17de 100644 static void * ggml_backend_metal_buffer_private_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-opencl/ggml-opencl.cpp b/ggml/src/ggml-opencl/ggml-opencl.cpp -index 0cf3b924..09d706b5 100644 +index 79d21487..38c75018 100644 --- a/ggml/src/ggml-opencl/ggml-opencl.cpp +++ b/ggml/src/ggml-opencl/ggml-opencl.cpp -@@ -3215,6 +3215,7 @@ struct ggml_backend_opencl_buffer_context { +@@ -3212,6 +3212,7 @@ struct ggml_backend_opencl_buffer_context { static void ggml_backend_opencl_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_opencl_buffer_context * ctx = (ggml_backend_opencl_buffer_context *) buffer->context; delete ctx; @@ -144,10 +144,10 @@ index 0cf3b924..09d706b5 100644 static void * ggml_backend_opencl_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-rpc/ggml-rpc.cpp b/ggml/src/ggml-rpc/ggml-rpc.cpp -index f99681c8..59591770 100644 +index aad48d62..a46c0f52 100644 --- a/ggml/src/ggml-rpc/ggml-rpc.cpp +++ b/ggml/src/ggml-rpc/ggml-rpc.cpp -@@ -505,6 +505,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -528,6 +528,7 @@ static void ggml_backend_rpc_buffer_free_buffer(ggml_backend_buffer_t buffer) { bool status = send_rpc_cmd(ctx->sock, RPC_CMD_FREE_BUFFER, &request, sizeof(request), nullptr, 0); RPC_STATUS_ASSERT(status); delete ctx; @@ -156,10 +156,10 @@ index f99681c8..59591770 100644 static void * ggml_backend_rpc_buffer_get_base(ggml_backend_buffer_t buffer) { diff --git a/ggml/src/ggml-sycl/ggml-sycl.cpp b/ggml/src/ggml-sycl/ggml-sycl.cpp -index 4ac919ea..447ea3c4 100644 +index 45b8c216..4ec9a592 100644 --- a/ggml/src/ggml-sycl/ggml-sycl.cpp +++ b/ggml/src/ggml-sycl/ggml-sycl.cpp -@@ -331,6 +331,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { +@@ -334,6 +334,7 @@ ggml_backend_sycl_buffer_free_buffer(ggml_backend_buffer_t buffer) try { ggml_sycl_set_device(ctx->device); delete ctx; @@ -167,7 +167,7 @@ index 4ac919ea..447ea3c4 100644 } catch (sycl::exception const &exc) { std::cerr << exc.what() << "Exception caught at file:" << __FILE__ -@@ -792,6 +793,7 @@ struct ggml_backend_sycl_split_buffer_context { +@@ -795,6 +796,7 @@ struct ggml_backend_sycl_split_buffer_context { static void ggml_backend_sycl_split_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_sycl_split_buffer_context * ctx = (ggml_backend_sycl_split_buffer_context *)buffer->context; delete ctx; @@ -175,7 +175,7 @@ index 4ac919ea..447ea3c4 100644 } static void * ggml_backend_sycl_split_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -1134,6 +1136,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ +@@ -1137,6 +1139,7 @@ static const char * ggml_backend_sycl_host_buffer_type_name(ggml_backend_buffer_ static void ggml_backend_sycl_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_sycl_host_free(buffer->context); @@ -184,10 +184,10 @@ index 4ac919ea..447ea3c4 100644 static ggml_backend_buffer_t ggml_backend_sycl_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -index 2608cbd0..061cd078 100644 +index 3cd89c71..ed83236f 100644 --- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp +++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp -@@ -11603,6 +11603,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { +@@ -11600,6 +11600,7 @@ static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; ggml_vk_destroy_buffer(ctx->dev_buffer); delete ctx; @@ -195,7 +195,7 @@ index 2608cbd0..061cd078 100644 } static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { -@@ -11746,6 +11747,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe +@@ -11743,6 +11744,7 @@ static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffe static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); ggml_vk_host_free(vk_instance.devices[0], buffer->context); diff --git a/llama/patches/0002-pretokenizer.patch b/llama/patches/0002-pretokenizer.patch index 9d26bd548a..aacb1566dc 100644 --- a/llama/patches/0002-pretokenizer.patch +++ b/llama/patches/0002-pretokenizer.patch @@ -10,10 +10,10 @@ logs instead of throwing an error 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index da938af0..2a38abf4 100644 +index 7fffd171..0b6edaf4 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp -@@ -1811,16 +1811,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1812,16 +1812,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { if (type == LLAMA_VOCAB_TYPE_BPE) { add_space_prefix = false; clean_spaces = true; @@ -31,7 +31,7 @@ index da938af0..2a38abf4 100644 pre_type = LLAMA_VOCAB_PRE_TYPE_DEFAULT; } else if ( tokenizer_pre == "llama3" || -@@ -1987,7 +1978,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1992,7 +1983,8 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { pre_type = LLAMA_VOCAB_PRE_TYPE_GROK_2; clean_spaces = false; } else { diff --git a/llama/patches/0003-clip-unicode.patch b/llama/patches/0003-clip-unicode.patch index 4be92b89f0..3ba3742b02 100644 --- a/llama/patches/0003-clip-unicode.patch +++ b/llama/patches/0003-clip-unicode.patch @@ -10,7 +10,7 @@ filesystems for paths that include wide characters 1 file changed, 39 insertions(+) diff --git a/tools/mtmd/clip.cpp b/tools/mtmd/clip.cpp -index 210ecc88..355219a9 100644 +index 98e68af2..6699b75a 100644 --- a/tools/mtmd/clip.cpp +++ b/tools/mtmd/clip.cpp @@ -28,6 +28,19 @@ @@ -33,7 +33,7 @@ index 210ecc88..355219a9 100644 struct clip_logger_state g_logger_state = {GGML_LOG_LEVEL_CONT, clip_log_callback_default, NULL}; enum ffn_op_type { -@@ -2759,7 +2772,29 @@ struct clip_model_loader { +@@ -2762,7 +2775,29 @@ struct clip_model_loader { { std::vector read_buf; @@ -63,7 +63,7 @@ index 210ecc88..355219a9 100644 if (!fin) { throw std::runtime_error(string_format("%s: failed to open %s\n", __func__, fname.c_str())); } -@@ -2786,7 +2821,11 @@ struct clip_model_loader { +@@ -2789,7 +2824,11 @@ struct clip_model_loader { ggml_backend_tensor_set(cur, read_buf.data(), 0, num_bytes); } } diff --git a/llama/patches/0004-solar-pro.patch b/llama/patches/0004-solar-pro.patch index be37e5e8c4..631cba2a90 100644 --- a/llama/patches/0004-solar-pro.patch +++ b/llama/patches/0004-solar-pro.patch @@ -9,13 +9,13 @@ adds support for the Solar Pro architecture src/llama-arch.h | 3 + src/llama-hparams.cpp | 8 ++ src/llama-hparams.h | 5 + - src/llama-model-loader.cpp | 1 + + src/llama-model-loader.cpp | 2 +- src/llama-model.cpp | 207 +++++++++++++++++++++++++++++++++++++ src/llama-model.h | 3 + - 7 files changed, 248 insertions(+) + 7 files changed, 248 insertions(+), 1 deletion(-) diff --git a/src/llama-arch.cpp b/src/llama-arch.cpp -index 4e8d54c4..f98a3574 100644 +index 869e4dcc..9f6b6ad2 100644 --- a/src/llama-arch.cpp +++ b/src/llama-arch.cpp @@ -81,6 +81,7 @@ static const std::map LLM_ARCH_NAMES = { @@ -26,7 +26,7 @@ index 4e8d54c4..f98a3574 100644 { LLM_ARCH_WAVTOKENIZER_DEC, "wavtokenizer-dec" }, { LLM_ARCH_PLM, "plm" }, { LLM_ARCH_BAILINGMOE, "bailingmoe" }, -@@ -177,6 +178,7 @@ static const std::map LLM_KV_NAMES = { +@@ -179,6 +180,7 @@ static const std::map LLM_KV_NAMES = { { LLM_KV_ATTENTION_SCALE, "%s.attention.scale" }, { LLM_KV_ATTENTION_OUTPUT_SCALE, "%s.attention.output_scale" }, { LLM_KV_ATTENTION_TEMPERATURE_LENGTH, "%s.attention.temperature_length" }, @@ -34,7 +34,7 @@ index 4e8d54c4..f98a3574 100644 { LLM_KV_ATTENTION_KEY_LENGTH_MLA, "%s.attention.key_length_mla" }, { LLM_KV_ATTENTION_VALUE_LENGTH_MLA, "%s.attention.value_length_mla" }, -@@ -1879,6 +1881,24 @@ static const std::map> LLM_TENSOR_N +@@ -1893,6 +1895,24 @@ static const std::map> LLM_TENSOR_N { LLM_TENSOR_ATTN_K_NORM, "blk.%d.attn_k_norm" }, }, }, @@ -59,7 +59,7 @@ index 4e8d54c4..f98a3574 100644 { LLM_ARCH_WAVTOKENIZER_DEC, { -@@ -2368,6 +2388,7 @@ static const std::map LLM_TENSOR_INFOS = { +@@ -2429,6 +2449,7 @@ static const std::map LLM_TENSOR_INFOS = { {LLM_TENSOR_LAUREL_POST_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, // this tensor is loaded for T5, but never used {LLM_TENSOR_DEC_CROSS_ATTN_REL_B, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_NONE}}, @@ -68,7 +68,7 @@ index 4e8d54c4..f98a3574 100644 {LLM_TENSOR_POS_NET_NORM, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, {LLM_TENSOR_POS_NET_NORM1, {LLM_TENSOR_LAYER_REPEATING, GGML_OP_MUL}}, diff --git a/src/llama-arch.h b/src/llama-arch.h -index b5c6f3d7..aa8e0e7b 100644 +index c3ae7165..dc7a362a 100644 --- a/src/llama-arch.h +++ b/src/llama-arch.h @@ -85,6 +85,7 @@ enum llm_arch { @@ -79,7 +79,7 @@ index b5c6f3d7..aa8e0e7b 100644 LLM_ARCH_WAVTOKENIZER_DEC, LLM_ARCH_PLM, LLM_ARCH_BAILINGMOE, -@@ -181,6 +182,7 @@ enum llm_kv { +@@ -183,6 +184,7 @@ enum llm_kv { LLM_KV_ATTENTION_SCALE, LLM_KV_ATTENTION_OUTPUT_SCALE, LLM_KV_ATTENTION_TEMPERATURE_LENGTH, @@ -87,7 +87,7 @@ index b5c6f3d7..aa8e0e7b 100644 LLM_KV_ATTENTION_KEY_LENGTH_MLA, LLM_KV_ATTENTION_VALUE_LENGTH_MLA, -@@ -417,6 +419,7 @@ enum llm_tensor { +@@ -432,6 +434,7 @@ enum llm_tensor { LLM_TENSOR_ENC_OUTPUT_NORM, LLM_TENSOR_CLS, LLM_TENSOR_CLS_OUT, @@ -96,10 +96,10 @@ index b5c6f3d7..aa8e0e7b 100644 LLM_TENSOR_CONVNEXT_DW, LLM_TENSOR_CONVNEXT_NORM, diff --git a/src/llama-hparams.cpp b/src/llama-hparams.cpp -index c04ac58f..24a515a0 100644 +index db65d69e..b6bf6bbf 100644 --- a/src/llama-hparams.cpp +++ b/src/llama-hparams.cpp -@@ -147,6 +147,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { +@@ -151,6 +151,14 @@ uint32_t llama_hparams::n_pos_per_embd() const { return rope_type == LLAMA_ROPE_TYPE_MROPE ? 4 : 1; } @@ -115,7 +115,7 @@ index c04ac58f..24a515a0 100644 if (il < n_layer) { return swa_layers[il]; diff --git a/src/llama-hparams.h b/src/llama-hparams.h -index 0fe4b569..eb13709f 100644 +index 4e7f73ec..80582728 100644 --- a/src/llama-hparams.h +++ b/src/llama-hparams.h @@ -64,6 +64,8 @@ struct llama_hparams { @@ -127,7 +127,7 @@ index 0fe4b569..eb13709f 100644 uint32_t n_layer_dense_lead = 0; uint32_t n_lora_q = 0; uint32_t n_lora_kv = 0; -@@ -236,6 +238,9 @@ struct llama_hparams { +@@ -248,6 +250,9 @@ struct llama_hparams { uint32_t n_pos_per_embd() const; @@ -138,22 +138,23 @@ index 0fe4b569..eb13709f 100644 bool has_kv(uint32_t il) const; diff --git a/src/llama-model-loader.cpp b/src/llama-model-loader.cpp -index 8182a9ad..daef900c 100644 +index aa3a65f8..ee303bd5 100644 --- a/src/llama-model-loader.cpp +++ b/src/llama-model-loader.cpp -@@ -465,6 +465,7 @@ namespace GGUFMeta { - // TODO: this is not very clever - figure out something better +@@ -466,7 +466,7 @@ namespace GGUFMeta { template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); + template bool llama_model_loader::get_key_or_arr>(enum llm_kv kid, std::array & result, uint32_t n, bool required); +- + template bool llama_model_loader::get_key_or_arr(const std::string & key, std::array & result, uint32_t n, bool required); llama_model_loader::llama_model_loader( const std::string & fname, diff --git a/src/llama-model.cpp b/src/llama-model.cpp -index 2470f878..0398b553 100644 +index 36d495d6..74e1d162 100644 --- a/src/llama-model.cpp +++ b/src/llama-model.cpp -@@ -1845,6 +1845,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { +@@ -1865,6 +1865,21 @@ void llama_model::load_hparams(llama_model_loader & ml) { default: type = LLM_TYPE_UNKNOWN; } } break; @@ -175,7 +176,7 @@ index 2470f878..0398b553 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { ml.get_key(LLM_KV_ATTENTION_LAYERNORM_EPS, hparams.f_norm_eps); -@@ -5113,6 +5128,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { +@@ -5170,6 +5185,34 @@ bool llama_model::load_tensors(llama_model_loader & ml) { layer.ffn_norm = create_tensor(tn(LLM_TENSOR_FFN_NORM, "weight", i), {n_embd}, 0); @@ -210,7 +211,7 @@ index 2470f878..0398b553 100644 layer.ffn_gate = create_tensor(tn(LLM_TENSOR_FFN_GATE, "weight", i), {n_embd, n_ff}, 0); layer.ffn_down = create_tensor(tn(LLM_TENSOR_FFN_DOWN, "weight", i), { n_ff, n_embd}, 0); layer.ffn_up = create_tensor(tn(LLM_TENSOR_FFN_UP, "weight", i), {n_embd, n_ff}, 0); -@@ -16273,6 +16316,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { +@@ -16392,6 +16435,165 @@ struct llm_build_granite_hybrid : public llm_graph_context_mamba { } }; @@ -376,7 +377,7 @@ index 2470f878..0398b553 100644 // ref: https://github.com/facebookresearch/chameleon // based on the original build_llama() function, changes: // * qk-norm -@@ -19552,6 +19754,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { +@@ -19827,6 +20029,10 @@ ggml_cgraph * llama_model::build_graph(const llm_graph_params & params) const { { llm = std::make_unique(*this, params); } break; @@ -387,7 +388,7 @@ index 2470f878..0398b553 100644 case LLM_ARCH_WAVTOKENIZER_DEC: { llm = std::make_unique(*this, params); -@@ -19770,6 +19976,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { +@@ -20057,6 +20263,7 @@ llama_rope_type llama_model_rope_type(const llama_model * model) { case LLM_ARCH_GRANITE_MOE: case LLM_ARCH_GRANITE_HYBRID: case LLM_ARCH_CHAMELEON: @@ -396,7 +397,7 @@ index 2470f878..0398b553 100644 case LLM_ARCH_NEO_BERT: case LLM_ARCH_SMOLLM3: diff --git a/src/llama-model.h b/src/llama-model.h -index d73ce969..c086f94e 100644 +index 7f48662f..ec3fbd33 100644 --- a/src/llama-model.h +++ b/src/llama-model.h @@ -76,6 +76,7 @@ enum llm_type { @@ -407,9 +408,9 @@ index d73ce969..c086f94e 100644 LLM_TYPE_27B, LLM_TYPE_30B, LLM_TYPE_32B, -@@ -380,6 +381,8 @@ struct llama_layer { - // openai-moe - struct ggml_tensor * attn_sinks = nullptr; +@@ -387,6 +388,8 @@ struct llama_layer { + struct ggml_tensor * ffn_act_beta = nullptr; + struct ggml_tensor * ffn_act_eps = nullptr; + struct ggml_tensor * bskcn_tv = nullptr; + diff --git a/llama/patches/0005-fix-deepseek-deseret-regex.patch b/llama/patches/0005-fix-deepseek-deseret-regex.patch index 657de79de4..127fcc372d 100644 --- a/llama/patches/0005-fix-deepseek-deseret-regex.patch +++ b/llama/patches/0005-fix-deepseek-deseret-regex.patch @@ -12,7 +12,7 @@ regex 2 files changed, 22 insertions(+), 1 deletion(-) diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 2a38abf4..26fa9fad 100644 +index 0b6edaf4..3de95c67 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp @@ -299,7 +299,7 @@ struct llm_tokenizer_bpe : llm_tokenizer { diff --git a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch index 5a52723bf7..43fc8a0b2c 100644 --- a/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch +++ b/llama/patches/0008-add-phony-target-ggml-cpu-for-all-cpu-variants.patch @@ -8,10 +8,10 @@ Subject: [PATCH] add phony target ggml-cpu for all cpu variants 1 file changed, 2 insertions(+) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index c8f3d859..ff6229a0 100644 +index 892c2331..09fdf5fc 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -307,6 +307,7 @@ function(ggml_add_cpu_backend_variant tag_name) +@@ -310,6 +310,7 @@ function(ggml_add_cpu_backend_variant tag_name) endif() ggml_add_cpu_backend_variant_impl(${tag_name}) @@ -19,7 +19,7 @@ index c8f3d859..ff6229a0 100644 endfunction() ggml_add_backend(CPU) -@@ -317,6 +318,7 @@ if (GGML_CPU_ALL_VARIANTS) +@@ -320,6 +321,7 @@ if (GGML_CPU_ALL_VARIANTS) elseif (GGML_CPU_ARM_ARCH) message(FATAL_ERROR "Cannot use both GGML_CPU_ARM_ARCH and GGML_CPU_ALL_VARIANTS") endif() diff --git a/llama/patches/0009-remove-amx.patch b/llama/patches/0009-remove-amx.patch index bc5980a603..6b0b90f339 100644 --- a/llama/patches/0009-remove-amx.patch +++ b/llama/patches/0009-remove-amx.patch @@ -9,10 +9,10 @@ disable amx as it reduces performance on some systems 1 file changed, 4 deletions(-) diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index ff6229a0..33b3a15f 100644 +index 09fdf5fc..0609c650 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -327,10 +327,6 @@ if (GGML_CPU_ALL_VARIANTS) +@@ -330,10 +330,6 @@ if (GGML_CPU_ALL_VARIANTS) ggml_add_cpu_backend_variant(skylakex SSE42 AVX F16C AVX2 BMI2 FMA AVX512) ggml_add_cpu_backend_variant(icelake SSE42 AVX F16C AVX2 BMI2 FMA AVX512 AVX512_VBMI AVX512_VNNI) ggml_add_cpu_backend_variant(alderlake SSE42 AVX F16C AVX2 BMI2 FMA AVX_VNNI) diff --git a/llama/patches/0010-fix-string-arr-kv-loading.patch b/llama/patches/0010-fix-string-arr-kv-loading.patch index 8a59a520df..29a313491f 100644 --- a/llama/patches/0010-fix-string-arr-kv-loading.patch +++ b/llama/patches/0010-fix-string-arr-kv-loading.patch @@ -53,10 +53,10 @@ index 8cc4ef1c..d950dbdf 100644 } diff --git a/src/llama-vocab.cpp b/src/llama-vocab.cpp -index 26fa9fad..64c78a16 100644 +index 3de95c67..217ede47 100644 --- a/src/llama-vocab.cpp +++ b/src/llama-vocab.cpp -@@ -1767,9 +1767,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { +@@ -1768,9 +1768,7 @@ void llama_vocab::impl::load(llama_model_loader & ml, const LLM_KV & kv) { const int precompiled_charsmap_keyidx = gguf_find_key(ctx, kv(LLM_KV_TOKENIZER_PRECOMPILED_CHARSMAP).c_str()); if (precompiled_charsmap_keyidx != -1) { const gguf_type pc_type = gguf_get_arr_type(ctx, precompiled_charsmap_keyidx); diff --git a/llama/patches/0011-ollama-debug-tensor.patch b/llama/patches/0011-ollama-debug-tensor.patch index 0bbe89b315..21edb8ba36 100644 --- a/llama/patches/0011-ollama-debug-tensor.patch +++ b/llama/patches/0011-ollama-debug-tensor.patch @@ -8,7 +8,7 @@ Subject: [PATCH] ollama debug tensor 1 file changed, 6 insertions(+) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index dbc07301..f8574d01 100644 +index ba2a36d9..99509b0c 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c @@ -15,6 +15,8 @@ @@ -20,7 +20,7 @@ index dbc07301..f8574d01 100644 #if defined(_MSC_VER) || defined(__MINGW32__) #include // using malloc.h with MSC/MINGW #elif !defined(__FreeBSD__) && !defined(__NetBSD__) && !defined(__OpenBSD__) -@@ -2881,6 +2883,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { +@@ -2887,6 +2889,10 @@ static thread_ret_t ggml_graph_compute_thread(void * data) { ggml_compute_forward(¶ms, node); diff --git a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch index 1677f59ae4..b4ad69cfd2 100644 --- a/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch +++ b/llama/patches/0012-add-ollama-vocab-for-grammar-support.patch @@ -184,7 +184,7 @@ index f8c291de..2a3a62db 100644 const char * grammar_root, bool lazy, diff --git a/src/llama-sampling.cpp b/src/llama-sampling.cpp -index 2186f827..8fb86009 100644 +index 55d2e355..da34526b 100644 --- a/src/llama-sampling.cpp +++ b/src/llama-sampling.cpp @@ -1563,7 +1563,7 @@ static void llama_sampler_grammar_reset(struct llama_sampler * smpl) { diff --git a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch index 14138c673e..f87c8c38cb 100644 --- a/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch +++ b/llama/patches/0013-add-argsort-and-cuda-copy-for-i32.patch @@ -12,10 +12,10 @@ Subject: [PATCH] add argsort and cuda copy for i32 5 files changed, 256 insertions(+), 2 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp -index 14f7dcf4..f7f8da35 100644 +index 1c43865f..31478dd8 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp -@@ -7893,6 +7893,45 @@ static void ggml_compute_forward_argsort_f32( +@@ -7889,6 +7889,45 @@ static void ggml_compute_forward_argsort_f32( } } @@ -61,7 +61,7 @@ index 14f7dcf4..f7f8da35 100644 void ggml_compute_forward_argsort( const ggml_compute_params * params, ggml_tensor * dst) { -@@ -7904,6 +7943,10 @@ void ggml_compute_forward_argsort( +@@ -7900,6 +7939,10 @@ void ggml_compute_forward_argsort( { ggml_compute_forward_argsort_f32(params, dst); } break; @@ -272,10 +272,10 @@ index 746f4396..911220e9 100644 ggml_cpy_flt_cuda (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index); } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) { diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal -index 96df6f0c..44dc31c0 100644 +index 74a9aa99..375a0c7f 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal -@@ -4428,8 +4428,72 @@ kernel void kernel_argsort_f32_i32( +@@ -4346,8 +4346,72 @@ kernel void kernel_argsort_f32_i32( } } diff --git a/llama/patches/0014-graph-memory-reporting-on-failure.patch b/llama/patches/0014-graph-memory-reporting-on-failure.patch index 12d68dbad3..a3f0fc7085 100644 --- a/llama/patches/0014-graph-memory-reporting-on-failure.patch +++ b/llama/patches/0014-graph-memory-reporting-on-failure.patch @@ -23,10 +23,10 @@ index 2cb150fd..7ab3f019 100644 // Utils // Create a buffer and allocate all the tensors in a ggml_context diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 62b6d65e..fe20dca3 100644 +index f1b74078..c54ff98b 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -316,6 +316,7 @@ extern "C" { +@@ -318,6 +318,7 @@ extern "C" { GGML_API ggml_backend_buffer_type_t ggml_backend_sched_get_buffer_type(ggml_backend_sched_t sched, ggml_backend_t backend); GGML_API size_t ggml_backend_sched_get_buffer_size(ggml_backend_sched_t sched, ggml_backend_t backend); @@ -35,10 +35,10 @@ index 62b6d65e..fe20dca3 100644 GGML_API void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend); GGML_API ggml_backend_t ggml_backend_sched_get_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node); diff --git a/ggml/src/ggml-alloc.c b/ggml/src/ggml-alloc.c -index fa46f3b4..421ff7c7 100644 +index 929bc448..eee9d3b1 100644 --- a/ggml/src/ggml-alloc.c +++ b/ggml/src/ggml-alloc.c -@@ -492,6 +492,7 @@ struct node_alloc { +@@ -486,6 +486,7 @@ struct node_alloc { struct ggml_gallocr { ggml_backend_buffer_type_t * bufts; // [n_buffers] struct vbuffer ** buffers; // [n_buffers] @@ -46,7 +46,7 @@ index fa46f3b4..421ff7c7 100644 struct ggml_dyn_tallocr ** buf_tallocs; // [n_buffers] int n_buffers; -@@ -515,6 +516,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs +@@ -509,6 +510,9 @@ ggml_gallocr_t ggml_gallocr_new_n(ggml_backend_buffer_type_t * bufts, int n_bufs galloc->buffers = calloc(n_bufs, sizeof(struct vbuffer *)); GGML_ASSERT(galloc->buffers != NULL); @@ -56,7 +56,7 @@ index fa46f3b4..421ff7c7 100644 galloc->buf_tallocs = calloc(n_bufs, sizeof(struct ggml_dyn_tallocr *)); GGML_ASSERT(galloc->buf_tallocs != NULL); -@@ -582,6 +586,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { +@@ -576,6 +580,7 @@ void ggml_gallocr_free(ggml_gallocr_t galloc) { ggml_hash_set_free(&galloc->hash_set); free(galloc->hash_values); free(galloc->bufts); @@ -64,7 +64,7 @@ index fa46f3b4..421ff7c7 100644 free(galloc->buffers); free(galloc->buf_tallocs); free(galloc->node_allocs); -@@ -875,6 +880,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -869,6 +874,8 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } @@ -73,7 +73,7 @@ index fa46f3b4..421ff7c7 100644 // reallocate buffers if needed for (int i = 0; i < galloc->n_buffers; i++) { // if the buffer type is used multiple times, we reuse the same buffer -@@ -896,14 +903,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c +@@ -898,14 +905,19 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c ggml_vbuffer_free(galloc->buffers[i]); galloc->buffers[i] = ggml_vbuffer_alloc(galloc->bufts[i], galloc->buf_tallocs[i], GGML_BACKEND_BUFFER_USAGE_COMPUTE); @@ -96,7 +96,7 @@ index fa46f3b4..421ff7c7 100644 } bool ggml_gallocr_reserve(ggml_gallocr_t galloc, struct ggml_cgraph *graph) { -@@ -1058,6 +1070,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { +@@ -1060,6 +1072,22 @@ size_t ggml_gallocr_get_buffer_size(ggml_gallocr_t galloc, int buffer_id) { return ggml_vbuffer_size(galloc->buffers[buffer_id]); } diff --git a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch index 813ae42560..b58d23d927 100644 --- a/llama/patches/0015-ggml-Export-GPU-UUIDs.patch +++ b/llama/patches/0015-ggml-Export-GPU-UUIDs.patch @@ -12,7 +12,7 @@ with tools (e.g. nvidia-smi) and system management libraries (e.g. nvml). 3 files changed, 63 insertions(+), 6 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index fe20dca3..48777212 100644 +index c54ff98b..229bf387 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -158,6 +158,7 @@ extern "C" { @@ -24,7 +24,7 @@ index fe20dca3..48777212 100644 size_t memory_total; // device type diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index fdf8c63d..ad389ece 100644 +index c0b1e4c1..5b852f69 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -183,6 +183,51 @@ static int ggml_cuda_parse_id(char devName[]) { @@ -110,7 +110,7 @@ index fdf8c63d..ad389ece 100644 std::string device_name(prop.name); if (device_name == "NVIDIA GeForce MX450") { turing_devices_without_mma.push_back({ id, device_name }); -@@ -3273,6 +3320,7 @@ struct ggml_backend_cuda_device_context { +@@ -3276,6 +3323,7 @@ struct ggml_backend_cuda_device_context { std::string name; std::string description; std::string pci_bus_id; @@ -118,7 +118,7 @@ index fdf8c63d..ad389ece 100644 }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -3285,6 +3333,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t +@@ -3288,6 +3336,11 @@ static const char * ggml_backend_cuda_device_get_description(ggml_backend_dev_t return ctx->description.c_str(); } @@ -130,7 +130,7 @@ index fdf8c63d..ad389ece 100644 static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); -@@ -3301,6 +3354,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -3304,6 +3357,7 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->name = ggml_backend_cuda_device_get_name(dev); props->description = ggml_backend_cuda_device_get_description(dev); @@ -138,7 +138,7 @@ index fdf8c63d..ad389ece 100644 props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ggml_backend_cuda_device_get_memory(dev, &props->memory_free, &props->memory_total); -@@ -3871,6 +3925,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -3873,6 +3927,7 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { cudaDeviceProp prop; CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); dev_ctx->description = prop.name; @@ -147,7 +147,7 @@ index fdf8c63d..ad389ece 100644 char pci_bus_id[16] = {}; snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index 909e17de..08ab4fc9 100644 +index bf096227..f2ff9f32 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -538,6 +538,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen diff --git a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch index d34c0b7896..422d633b45 100644 --- a/llama/patches/0016-add-C-API-for-mtmd_input_text.patch +++ b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch @@ -10,11 +10,11 @@ Signed-off-by: Gabe Goodhart 2 files changed, 13 insertions(+) diff --git a/tools/mtmd/mtmd.cpp b/tools/mtmd/mtmd.cpp -index cd022c5e..3d680945 100644 +index 4d487581..35a0d25e 100644 --- a/tools/mtmd/mtmd.cpp +++ b/tools/mtmd/mtmd.cpp @@ -79,6 +79,16 @@ enum mtmd_slice_tmpl { - // TODO @ngxson : add support for idefics (SmolVLM) + MTMD_SLICE_TMPL_IDEFICS3, }; +mtmd_input_text* mtmd_input_text_init(const char * text, bool add_special, bool parse_special) { diff --git a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch index 9912e57c96..279e42c353 100644 --- a/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch +++ b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch @@ -8,10 +8,10 @@ Subject: [PATCH] no power throttling win32 with gnuc 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/ggml/src/ggml-cpu/ggml-cpu.c b/ggml/src/ggml-cpu/ggml-cpu.c -index f8574d01..530efce0 100644 +index 99509b0c..b13a491d 100644 --- a/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ggml/src/ggml-cpu/ggml-cpu.c -@@ -2431,7 +2431,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { +@@ -2437,7 +2437,7 @@ static bool ggml_thread_apply_priority(int32_t prio) { // Newer Windows 11 versions aggresively park (offline) CPU cores and often place // all our threads onto the first 4 cores which results in terrible performance with // n_threads > 4 diff --git a/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch b/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch index d88c13591a..85cba5b3a7 100644 --- a/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch +++ b/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch @@ -13,10 +13,10 @@ checks. 1 file changed, 18 insertions(+) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index ad389ece..e51c5035 100644 +index 5b852f69..827e3205 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu -@@ -2686,14 +2686,26 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud +@@ -2689,14 +2689,26 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud // Loop over nodes in GGML graph to obtain info needed for CUDA graph cuda_ctx->cuda_graph->cpy_dest_ptrs.clear(); @@ -43,7 +43,7 @@ index ad389ece..e51c5035 100644 for (int i = 0; i < cgraph->n_nodes; i++) { ggml_tensor * node = cgraph->nodes[i]; -@@ -2717,6 +2729,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud +@@ -2720,6 +2732,12 @@ static bool check_node_graph_compatibility_and_refresh_copy_ops(ggml_backend_cud if (node->op == GGML_OP_ADD && node->src[1] && node->src[1]->ne[1] > 1 && diff --git a/llama/patches/0022-ggml-No-alloc-mode.patch b/llama/patches/0022-ggml-No-alloc-mode.patch index 838398bdc5..019cb8869b 100644 --- a/llama/patches/0022-ggml-No-alloc-mode.patch +++ b/llama/patches/0022-ggml-No-alloc-mode.patch @@ -16,10 +16,10 @@ must be recreated with no-alloc set to false before loading data. 5 files changed, 310 insertions(+), 44 deletions(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 48777212..d4352663 100644 +index 229bf387..1ff53ed0 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h -@@ -303,6 +303,7 @@ extern "C" { +@@ -305,6 +305,7 @@ extern "C" { // Initialize a backend scheduler, backends with low index are given priority over backends with high index GGML_API ggml_backend_sched_t ggml_backend_sched_new(ggml_backend_t * backends, ggml_backend_buffer_type_t * bufts, int n_backends, size_t graph_size, bool parallel, bool op_offload); @@ -28,7 +28,7 @@ index 48777212..d4352663 100644 // Initialize backend buffers from a measure graph diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index 07784d6f..869dc07d 100644 +index 6792ba98..3c3f22fc 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -26,12 +26,17 @@ extern "C" { @@ -218,7 +218,7 @@ index cb2b9956..6ef5eeaf 100644 void ggml_backend_sched_set_tensor_backend(ggml_backend_sched_t sched, struct ggml_tensor * node, ggml_backend_t backend) { diff --git a/ggml/src/ggml-cuda/common.cuh b/ggml/src/ggml-cuda/common.cuh -index c4246b65..448badf0 100644 +index e0abde54..28d6bcd7 100644 --- a/ggml/src/ggml-cuda/common.cuh +++ b/ggml/src/ggml-cuda/common.cuh @@ -35,6 +35,31 @@ @@ -253,7 +253,7 @@ index c4246b65..448badf0 100644 #define STRINGIZE_IMPL(...) #__VA_ARGS__ #define STRINGIZE(...) STRINGIZE_IMPL(__VA_ARGS__) -@@ -880,6 +905,9 @@ struct ggml_cuda_pool { +@@ -856,6 +881,9 @@ struct ggml_cuda_pool { virtual void * alloc(size_t size, size_t * actual_size) = 0; virtual void free(void * ptr, size_t size) = 0; @@ -263,7 +263,7 @@ index c4246b65..448badf0 100644 }; template -@@ -1023,11 +1051,11 @@ struct ggml_backend_cuda_context { +@@ -999,11 +1027,11 @@ struct ggml_backend_cuda_context { // pool std::unique_ptr pools[GGML_CUDA_MAX_DEVICES]; @@ -277,7 +277,7 @@ index c4246b65..448badf0 100644 } return *pools[device]; } -@@ -1035,4 +1063,20 @@ struct ggml_backend_cuda_context { +@@ -1011,4 +1039,20 @@ struct ggml_backend_cuda_context { ggml_cuda_pool & pool() { return pool(device); } @@ -299,7 +299,7 @@ index c4246b65..448badf0 100644 + } }; diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index e51c5035..d324bc68 100644 +index 827e3205..811462c7 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -350,6 +350,8 @@ const ggml_cuda_device_info & ggml_cuda_info() { @@ -540,7 +540,7 @@ index e51c5035..d324bc68 100644 }; ggml_backend_buffer_type_t ggml_backend_cuda_buffer_type(int device) { -@@ -3008,6 +3070,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, +@@ -3011,6 +3073,7 @@ static bool ggml_cuda_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx, ggml_cgraph * cgraph, bool & graph_evaluated_or_captured, bool & use_cuda_graph, bool & cuda_graph_update_required) { @@ -548,7 +548,7 @@ index e51c5035..d324bc68 100644 // flag used to determine whether it is an integrated_gpu const bool integrated = ggml_cuda_info().devices[cuda_ctx->device].integrated; -@@ -3023,6 +3086,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx +@@ -3026,6 +3089,11 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx continue; } @@ -560,7 +560,7 @@ index e51c5035..d324bc68 100644 static bool disable_fusion = (getenv("GGML_CUDA_DISABLE_FUSION") != nullptr); if (!disable_fusion) { -@@ -3149,6 +3217,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx +@@ -3152,6 +3220,7 @@ static void evaluate_and_capture_cuda_graph(ggml_backend_cuda_context * cuda_ctx static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; @@ -568,7 +568,7 @@ index e51c5035..d324bc68 100644 ggml_cuda_set_device(cuda_ctx->device); -@@ -3228,6 +3297,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, +@@ -3231,6 +3300,71 @@ static enum ggml_status ggml_backend_cuda_graph_compute(ggml_backend_t backend, return GGML_STATUS_SUCCESS; } @@ -640,7 +640,7 @@ index e51c5035..d324bc68 100644 static void ggml_backend_cuda_event_record(ggml_backend_t backend, ggml_backend_event_t event) { ggml_backend_cuda_context * cuda_ctx = (ggml_backend_cuda_context *)backend->context; -@@ -3268,6 +3402,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { +@@ -3271,6 +3405,9 @@ static const ggml_backend_i ggml_backend_cuda_interface = { /* .event_record = */ ggml_backend_cuda_event_record, /* .event_wait = */ ggml_backend_cuda_event_wait, /* .graph_optimize = */ NULL, diff --git a/llama/patches/0023-decode-disable-output_all.patch b/llama/patches/0023-decode-disable-output_all.patch index fc7a411778..ddf281bb91 100644 --- a/llama/patches/0023-decode-disable-output_all.patch +++ b/llama/patches/0023-decode-disable-output_all.patch @@ -8,7 +8,7 @@ Subject: [PATCH] decode: disable output_all 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/src/llama-context.cpp b/src/llama-context.cpp -index d8a8b5e6..09247cef 100644 +index e7526e7d..53a5e3a9 100644 --- a/src/llama-context.cpp +++ b/src/llama-context.cpp @@ -974,8 +974,7 @@ int llama_context::decode(const llama_batch & batch_inp) { diff --git a/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch b/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch index 5794516469..0df3be07fb 100644 --- a/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch +++ b/llama/patches/0024-ggml-Enable-resetting-backend-devices.patch @@ -10,12 +10,12 @@ unused then it can be reset to free these data structures. ggml/include/ggml-backend.h | 1 + ggml/src/ggml-backend-impl.h | 4 ++++ ggml/src/ggml-backend.cpp | 8 ++++++++ - ggml/src/ggml-cuda/ggml-cuda.cu | 17 +++++++++++++++-- + ggml/src/ggml-cuda/ggml-cuda.cu | 16 +++++++++++++++- ggml/src/ggml-cuda/vendors/hip.h | 1 + - 5 files changed, 29 insertions(+), 2 deletions(-) + 5 files changed, 29 insertions(+), 1 deletion(-) diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index d4352663..0a2dae26 100644 +index 1ff53ed0..ba181d09 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -178,6 +178,7 @@ extern "C" { @@ -27,7 +27,7 @@ index d4352663..0a2dae26 100644 GGML_API ggml_backend_buffer_type_t ggml_backend_dev_host_buffer_type(ggml_backend_dev_t device); GGML_API ggml_backend_buffer_t ggml_backend_dev_buffer_from_host_ptr(ggml_backend_dev_t device, void * ptr, size_t size, size_t max_tensor_size); diff --git a/ggml/src/ggml-backend-impl.h b/ggml/src/ggml-backend-impl.h -index 869dc07d..4889df79 100644 +index 3c3f22fc..43c91d9f 100644 --- a/ggml/src/ggml-backend-impl.h +++ b/ggml/src/ggml-backend-impl.h @@ -195,6 +195,10 @@ extern "C" { @@ -61,7 +61,7 @@ index 6ef5eeaf..0b757af5 100644 GGML_ASSERT(device); return device->iface.get_buffer_type(device); diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index d324bc68..531d6e27 100644 +index 811462c7..87c6c34a 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -107,6 +107,11 @@ int ggml_cuda_get_device() { @@ -76,7 +76,7 @@ index d324bc68..531d6e27 100644 static cudaError_t ggml_cuda_device_malloc(void ** ptr, size_t size, int device) { ggml_cuda_set_device(device); cudaError_t err; -@@ -3512,7 +3517,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -3515,7 +3520,10 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back props->id = ggml_backend_cuda_device_get_id(dev); props->type = ggml_backend_cuda_device_get_type(dev); props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); @@ -88,7 +88,7 @@ index d324bc68..531d6e27 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY -@@ -3945,6 +3953,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g +@@ -3948,6 +3956,11 @@ static void ggml_backend_cuda_device_event_synchronize(ggml_backend_dev_t dev, g CUDA_CHECK(cudaEventSynchronize((cudaEvent_t)event->context)); } @@ -100,7 +100,7 @@ index d324bc68..531d6e27 100644 static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .get_name = */ ggml_backend_cuda_device_get_name, /* .get_description = */ ggml_backend_cuda_device_get_description, -@@ -3961,6 +3974,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { +@@ -3964,6 +3977,7 @@ static const ggml_backend_device_i ggml_backend_cuda_device_interface = { /* .event_new = */ ggml_backend_cuda_device_event_new, /* .event_free = */ ggml_backend_cuda_device_event_free, /* .event_synchronize = */ ggml_backend_cuda_device_event_synchronize, @@ -108,19 +108,11 @@ index d324bc68..531d6e27 100644 }; // backend reg -@@ -4076,7 +4090,6 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { - dev_ctx->device = i; - dev_ctx->name = GGML_CUDA_NAME + std::to_string(i); - -- ggml_cuda_set_device(i); - cudaDeviceProp prop; - CUDA_CHECK(cudaGetDeviceProperties(&prop, i)); - dev_ctx->description = prop.name; diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index 37386afc..06f9e7c1 100644 +index 890c1036..1f06be80 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -41,6 +41,7 @@ +@@ -45,6 +45,7 @@ #define cudaDeviceDisablePeerAccess hipDeviceDisablePeerAccess #define cudaDeviceEnablePeerAccess hipDeviceEnablePeerAccess #define cudaDeviceProp hipDeviceProp_t diff --git a/llama/patches/0026-GPU-discovery-enhancements.patch b/llama/patches/0026-GPU-discovery-enhancements.patch index 6addfd5125..e7073914be 100644 --- a/llama/patches/0026-GPU-discovery-enhancements.patch +++ b/llama/patches/0026-GPU-discovery-enhancements.patch @@ -9,17 +9,17 @@ management libraries for more accurate VRAM usage reporting if available. ggml/include/ggml-backend.h | 9 + ggml/src/CMakeLists.txt | 2 + ggml/src/ggml-cuda/ggml-cuda.cu | 72 +++++ - ggml/src/ggml-cuda/vendors/hip.h | 4 + + ggml/src/ggml-cuda/vendors/hip.h | 3 + ggml/src/ggml-impl.h | 8 + - ggml/src/ggml-metal/ggml-metal.cpp | 3 +- + ggml/src/ggml-metal/ggml-metal.cpp | 2 + ggml/src/mem_hip.cpp | 449 +++++++++++++++++++++++++++++ ggml/src/mem_nvml.cpp | 209 ++++++++++++++ - 8 files changed, 755 insertions(+), 1 deletion(-) + 8 files changed, 754 insertions(+) create mode 100644 ggml/src/mem_hip.cpp create mode 100644 ggml/src/mem_nvml.cpp diff --git a/ggml/include/ggml-backend.h b/ggml/include/ggml-backend.h -index 0a2dae26a..a6bf33785 100644 +index ba181d09..09ff75f9 100644 --- a/ggml/include/ggml-backend.h +++ b/ggml/include/ggml-backend.h @@ -169,6 +169,15 @@ extern "C" { @@ -39,10 +39,10 @@ index 0a2dae26a..a6bf33785 100644 GGML_API const char * ggml_backend_dev_name(ggml_backend_dev_t device); diff --git a/ggml/src/CMakeLists.txt b/ggml/src/CMakeLists.txt -index 33b3a15f0..86191ef2c 100644 +index 0609c650..aefe43bd 100644 --- a/ggml/src/CMakeLists.txt +++ b/ggml/src/CMakeLists.txt -@@ -206,6 +206,8 @@ add_library(ggml-base +@@ -209,6 +209,8 @@ add_library(ggml-base ggml-threading.h ggml-quants.c ggml-quants.h @@ -52,7 +52,7 @@ index 33b3a15f0..86191ef2c 100644 target_include_directories(ggml-base PRIVATE .) diff --git a/ggml/src/ggml-cuda/ggml-cuda.cu b/ggml/src/ggml-cuda/ggml-cuda.cu -index 531d6e272..3fa3a0575 100644 +index 87c6c34a..6a278b5e 100644 --- a/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ggml/src/ggml-cuda/ggml-cuda.cu @@ -261,6 +261,16 @@ static ggml_cuda_device_info ggml_cuda_init() { @@ -84,7 +84,7 @@ index 531d6e272..3fa3a0575 100644 GGML_LOG_INFO(" Device %d: %s, compute capability %d.%d, VMM: %s, ID: %s\n", id, prop.name, prop.major, prop.minor, device_vmm ? "yes" : "no", ggml_cuda_parse_uuid(prop, id).c_str()); -@@ -3481,6 +3496,14 @@ struct ggml_backend_cuda_device_context { +@@ -3484,6 +3499,14 @@ struct ggml_backend_cuda_device_context { std::string description; std::string pci_bus_id; std::string id; @@ -99,7 +99,7 @@ index 531d6e272..3fa3a0575 100644 }; static const char * ggml_backend_cuda_device_get_name(ggml_backend_dev_t dev) { -@@ -3501,6 +3524,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { +@@ -3504,6 +3527,28 @@ static const char * ggml_backend_cuda_device_get_id(ggml_backend_dev_t dev) { static void ggml_backend_cuda_device_get_memory(ggml_backend_dev_t dev, size_t * free, size_t * total) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; ggml_cuda_set_device(ctx->device); @@ -128,7 +128,7 @@ index 531d6e272..3fa3a0575 100644 CUDA_CHECK(cudaMemGetInfo(free, total)); } -@@ -3509,6 +3554,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend +@@ -3512,6 +3557,7 @@ static enum ggml_backend_dev_type ggml_backend_cuda_device_get_type(ggml_backend return GGML_BACKEND_DEVICE_TYPE_GPU; } @@ -136,7 +136,7 @@ index 531d6e272..3fa3a0575 100644 static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { ggml_backend_cuda_device_context * ctx = (ggml_backend_cuda_device_context *)dev->context; -@@ -3522,6 +3568,22 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back +@@ -3525,6 +3571,22 @@ static void ggml_backend_cuda_device_get_props(ggml_backend_dev_t dev, ggml_back // If you need the memory data, call ggml_backend_dev_memory() explicitly. props->memory_total = props->memory_free = 0; @@ -159,7 +159,7 @@ index 531d6e272..3fa3a0575 100644 bool host_buffer = getenv("GGML_CUDA_NO_PINNED") == nullptr; #ifdef GGML_CUDA_NO_PEER_COPY bool events = false; -@@ -4084,6 +4146,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4087,6 +4149,8 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { std::lock_guard lock(mutex); if (!initialized) { ggml_backend_cuda_reg_context * ctx = new ggml_backend_cuda_reg_context; @@ -168,7 +168,7 @@ index 531d6e272..3fa3a0575 100644 for (int i = 0; i < ggml_cuda_info().device_count; i++) { ggml_backend_cuda_device_context * dev_ctx = new ggml_backend_cuda_device_context; -@@ -4099,6 +4163,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { +@@ -4102,6 +4166,14 @@ ggml_backend_reg_t ggml_backend_cuda_reg() { snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.0", prop.pciDomainID, prop.pciBusID, prop.pciDeviceID); dev_ctx->pci_bus_id = pci_bus_id; @@ -184,20 +184,19 @@ index 531d6e272..3fa3a0575 100644 /* .iface = */ ggml_backend_cuda_device_interface, /* .reg = */ ®, diff --git a/ggml/src/ggml-cuda/vendors/hip.h b/ggml/src/ggml-cuda/vendors/hip.h -index 06f9e7c1e..eb8f66cb0 100644 +index 1f06be80..2f9ef2dc 100644 --- a/ggml/src/ggml-cuda/vendors/hip.h +++ b/ggml/src/ggml-cuda/vendors/hip.h -@@ -5,6 +5,9 @@ +@@ -5,6 +5,8 @@ #include #include #include +// for rocblas_initialize() +#include "rocblas/rocblas.h" -+ - #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT - #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT -@@ -43,6 +46,7 @@ + #if defined(GGML_HIP_ROCWMMA_FATTN) + #include +@@ -47,6 +49,7 @@ #define cudaDeviceProp hipDeviceProp_t #define cudaDeviceReset hipDeviceReset #define cudaDeviceSynchronize hipDeviceSynchronize @@ -206,10 +205,10 @@ index 06f9e7c1e..eb8f66cb0 100644 #define cudaErrorPeerAccessAlreadyEnabled hipErrorPeerAccessAlreadyEnabled #define cudaErrorPeerAccessNotEnabled hipErrorPeerAccessNotEnabled diff --git a/ggml/src/ggml-impl.h b/ggml/src/ggml-impl.h -index 86a1ebf62..9fc9fbfcf 100644 +index d0fb3bcc..80597b6e 100644 --- a/ggml/src/ggml-impl.h +++ b/ggml/src/ggml-impl.h -@@ -635,6 +635,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx +@@ -638,6 +638,14 @@ static inline bool ggml_can_fuse(const struct ggml_cgraph * cgraph, int node_idx return ggml_can_fuse_ext(cgraph, idxs, ops, num_ops); } @@ -225,7 +224,7 @@ index 86a1ebf62..9fc9fbfcf 100644 } #endif diff --git a/ggml/src/ggml-metal/ggml-metal.cpp b/ggml/src/ggml-metal/ggml-metal.cpp -index 08ab4fc91..17999a616 100644 +index f2ff9f32..f356e4a0 100644 --- a/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ggml/src/ggml-metal/ggml-metal.cpp @@ -535,6 +535,7 @@ static enum ggml_backend_dev_type ggml_backend_metal_device_get_type(ggml_backen @@ -236,18 +235,17 @@ index 08ab4fc91..17999a616 100644 static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_backend_dev_props * props) { props->name = ggml_backend_metal_device_get_name(dev); props->description = ggml_backend_metal_device_get_description(dev); -@@ -542,7 +543,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac - props->type = ggml_backend_metal_device_get_type(dev); +@@ -543,6 +544,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); -- + + props->library = GGML_METAL_NAME; props->caps = { /* .async = */ true, /* .host_buffer = */ false, diff --git a/ggml/src/mem_hip.cpp b/ggml/src/mem_hip.cpp new file mode 100644 -index 000000000..8ef19b8cf +index 00000000..8ef19b8c --- /dev/null +++ b/ggml/src/mem_hip.cpp @@ -0,0 +1,449 @@ @@ -703,7 +701,7 @@ index 000000000..8ef19b8cf \ No newline at end of file diff --git a/ggml/src/mem_nvml.cpp b/ggml/src/mem_nvml.cpp new file mode 100644 -index 000000000..c9073cef0 +index 00000000..c9073cef --- /dev/null +++ b/ggml/src/mem_nvml.cpp @@ -0,0 +1,209 @@ diff --git a/ml/backend/ggml/ggml/include/ggml-backend.h b/ml/backend/ggml/ggml/include/ggml-backend.h index a6bf33785c..09ff75f932 100644 --- a/ml/backend/ggml/ggml/include/ggml-backend.h +++ b/ml/backend/ggml/ggml/include/ggml-backend.h @@ -226,6 +226,8 @@ extern "C" { // Backend registry // + GGML_API void ggml_backend_register(ggml_backend_reg_t reg); + GGML_API void ggml_backend_device_register(ggml_backend_dev_t device); // Backend (reg) enumeration diff --git a/ml/backend/ggml/ggml/include/ggml-rpc.h b/ml/backend/ggml/ggml/include/ggml-rpc.h index 1e67411276..72eff00273 100644 --- a/ml/backend/ggml/ggml/include/ggml-rpc.h +++ b/ml/backend/ggml/ggml/include/ggml-rpc.h @@ -7,26 +7,25 @@ extern "C" { #endif -#define RPC_PROTO_MAJOR_VERSION 2 +#define RPC_PROTO_MAJOR_VERSION 3 #define RPC_PROTO_MINOR_VERSION 0 #define RPC_PROTO_PATCH_VERSION 0 #define GGML_RPC_MAX_SERVERS 16 // backend API -GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint); +GGML_BACKEND_API ggml_backend_t ggml_backend_rpc_init(const char * endpoint, uint32_t device); GGML_BACKEND_API bool ggml_backend_is_rpc(ggml_backend_t backend); -GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint); +GGML_BACKEND_API ggml_backend_buffer_type_t ggml_backend_rpc_buffer_type(const char * endpoint, uint32_t device); -GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, size_t * free, size_t * total); +GGML_BACKEND_API void ggml_backend_rpc_get_device_memory(const char * endpoint, uint32_t device, size_t * free, size_t * total); -GGML_BACKEND_API void ggml_backend_rpc_start_server(ggml_backend_t backend, const char * endpoint, - const char * cache_dir, - size_t free_mem, size_t total_mem); +GGML_BACKEND_API void ggml_backend_rpc_start_server(const char * endpoint, const char * cache_dir, + size_t n_threads, size_t n_devices, + ggml_backend_dev_t * devices, size_t * free_mem, size_t * total_mem); GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_reg(void); - -GGML_BACKEND_API ggml_backend_dev_t ggml_backend_rpc_add_device(const char * endpoint); +GGML_BACKEND_API ggml_backend_reg_t ggml_backend_rpc_add_server(const char * endpoint); #ifdef __cplusplus } diff --git a/ml/backend/ggml/ggml/include/ggml.h b/ml/backend/ggml/ggml/include/ggml.h index 36b23dc6d0..60c6b63d05 100644 --- a/ml/backend/ggml/ggml/include/ggml.h +++ b/ml/backend/ggml/ggml/include/ggml.h @@ -237,6 +237,8 @@ #define GGML_EXIT_SUCCESS 0 #define GGML_EXIT_ABORTED 1 +// TODO: convert to enum https://github.com/ggml-org/llama.cpp/pull/16187#discussion_r2388538726 +#define GGML_ROPE_TYPE_NORMAL 0 #define GGML_ROPE_TYPE_NEOX 2 #define GGML_ROPE_TYPE_MROPE 8 #define GGML_ROPE_TYPE_VISION 24 @@ -574,6 +576,7 @@ extern "C" { GGML_UNARY_OP_HARDSIGMOID, GGML_UNARY_OP_EXP, GGML_UNARY_OP_GELU_ERF, + GGML_UNARY_OP_XIELU, GGML_UNARY_OP_COUNT, }; @@ -1148,6 +1151,18 @@ extern "C" { struct ggml_context * ctx, struct ggml_tensor * a); + // xIELU activation function + // x = x * (c_a(alpha_n) + c_b(alpha_p, beta) * sigmoid(beta * x)) + eps * (x > 0) + // where c_a = softplus and c_b(a, b) = softplus(a) + b are constraining functions + // that constrain the positive and negative source alpha values respectively + GGML_API struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps); + // gated linear unit ops // A: n columns, r rows, // result is n / 2 columns, r rows, @@ -1615,6 +1630,13 @@ extern "C" { float scale, float max_bias); + GGML_API struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias); + GGML_API void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks); diff --git a/ml/backend/ggml/ggml/src/CMakeLists.txt b/ml/backend/ggml/ggml/src/CMakeLists.txt index 86191ef2ca..aefe43bdd5 100644 --- a/ml/backend/ggml/ggml/src/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/CMakeLists.txt @@ -145,6 +145,9 @@ endif() # which was introduced in POSIX.1-2008, forcing us to go higher if (CMAKE_SYSTEM_NAME MATCHES "OpenBSD") add_compile_definitions(_XOPEN_SOURCE=700) +elseif (CMAKE_SYSTEM_NAME MATCHES "AIX") + # Don't define _XOPEN_SOURCE. We need _ALL_SOURCE, which is the default, + # in order to define _SC_PHYS_PAGES. else() add_compile_definitions(_XOPEN_SOURCE=600) endif() diff --git a/ml/backend/ggml/ggml/src/ggml-alloc.c b/ml/backend/ggml/ggml/src/ggml-alloc.c index 421ff7c7a3..eee9d3b113 100644 --- a/ml/backend/ggml/ggml/src/ggml-alloc.c +++ b/ml/backend/ggml/ggml/src/ggml-alloc.c @@ -392,12 +392,8 @@ static void ggml_dyn_tallocr_free(struct ggml_dyn_tallocr * alloc) { free(alloc); } -static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc) { - size_t max_size = 0; - for (int i = 0; i < alloc->n_chunks; i++) { - max_size += alloc->chunks[i]->max_size; - } - return max_size; +static size_t ggml_dyn_tallocr_max_size(struct ggml_dyn_tallocr * alloc, int chunk) { + return chunk < alloc->n_chunks ? alloc->chunks[chunk]->max_size : 0; } @@ -417,10 +413,8 @@ static void ggml_vbuffer_free(struct vbuffer * buf) { free(buf); } -static int ggml_vbuffer_n_chunks(struct vbuffer * buf) { - int n = 0; - while (n < GGML_VBUFFER_MAX_CHUNKS && buf->chunks[n]) n++; - return n; +static size_t ggml_vbuffer_chunk_size(struct vbuffer * buf, int chunk) { + return buf->chunks[chunk] ? ggml_backend_buffer_get_size(buf->chunks[chunk]) : 0; } static size_t ggml_vbuffer_size(struct vbuffer * buf) { @@ -892,12 +886,20 @@ bool ggml_gallocr_reserve_n(ggml_gallocr_t galloc, struct ggml_cgraph * graph, c } } - size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; - size_t new_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i]); - // even if there are no tensors allocated in this buffer, we still need to allocate it to initialize views - if (new_size > cur_size || galloc->buffers[i] == NULL) { + bool realloc = galloc->buffers[i] == NULL; + size_t new_size = 0; + for (int c = 0; c < galloc->buf_tallocs[i]->n_chunks; c++) { + size_t cur_chunk_size = galloc->buffers[i] ? ggml_vbuffer_chunk_size(galloc->buffers[i], c) : 0; + size_t new_chunk_size = ggml_dyn_tallocr_max_size(galloc->buf_tallocs[i], c); + new_size += new_chunk_size; + if (new_chunk_size > cur_chunk_size) { + realloc = true; + } + } + if (realloc) { #ifndef NDEBUG + size_t cur_size = galloc->buffers[i] ? ggml_vbuffer_size(galloc->buffers[i]) : 0; GGML_LOG_DEBUG("%s: reallocating %s buffer from size %.02f MiB to %.02f MiB\n", __func__, ggml_backend_buft_name(galloc->bufts[i]), cur_size / 1024.0 / 1024.0, new_size / 1024.0 / 1024.0); #endif diff --git a/ml/backend/ggml/ggml/src/ggml-backend-impl.h b/ml/backend/ggml/ggml/src/ggml-backend-impl.h index 4889df79f7..43c91d9f27 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-backend-impl.h @@ -229,9 +229,6 @@ extern "C" { void * context; }; - // Internal backend registry API - GGML_API void ggml_backend_register(ggml_backend_reg_t reg); - // Add backend dynamic loading support to the backend // Initialize the backend diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp index 867e158dca..895a571375 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/amx/amx.cpp @@ -149,6 +149,7 @@ class extra_buffer_type : ggml::cpu::extra_buffer_type { if (op->op == GGML_OP_MUL_MAT && is_contiguous_2d(op->src[0]) && // src0 must be contiguous is_contiguous_2d(op->src[1]) && // src1 must be contiguous op->src[0]->buffer && op->src[0]->buffer->buft == ggml_backend_amx_buffer_type() && + op->src[0]->ne[0] % (TILE_K * 2 * 32) == 0 && // TODO: not sure if correct (https://github.com/ggml-org/llama.cpp/pull/16315) op->ne[0] % (TILE_N * 2) == 0 && // out_features is 32x (qtype_has_amx_kernels(op->src[0]->type) || (op->src[0]->type == GGML_TYPE_F16))) { // src1 must be host buffer diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h index 799e2b1187..713bf85e5a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu-impl.h @@ -68,7 +68,7 @@ struct ggml_compute_params { #endif // __VXE2__ #endif // __s390x__ && __VEC__ -#if defined(__ARM_FEATURE_SVE) +#if defined(__ARM_FEATURE_SVE) && defined(__linux__) #include #endif diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c index 530efce008..b13a491df7 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ggml-cpu.c @@ -691,8 +691,13 @@ bool ggml_is_numa(void) { #endif static void ggml_init_arm_arch_features(void) { -#if defined(__linux__) && defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__aarch64__) && defined(__ARM_FEATURE_SVE) +#if defined(__linux__) ggml_arm_arch_features.sve_cnt = PR_SVE_VL_LEN_MASK & prctl(PR_SVE_GET_VL); +#else + // TODO: add support of SVE for non-linux systems +#error "TODO: SVE is not supported on this platform. To use SVE, sve_cnt needs to be initialized here." +#endif #endif } @@ -2189,6 +2194,7 @@ static int ggml_get_n_tasks(struct ggml_tensor * node, int n_threads) { case GGML_UNARY_OP_GELU_ERF: case GGML_UNARY_OP_GELU_QUICK: case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_XIELU: { n_tasks = n_threads; } break; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp index f7f8da352e..31478dd8e4 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/ops.cpp @@ -3467,31 +3467,27 @@ static void ggml_compute_forward_norm_f32( GGML_ASSERT(eps >= 0.0f); - // TODO: optimize for (int64_t i03 = 0; i03 < ne03; i03++) { for (int64_t i02 = 0; i02 < ne02; i02++) { for (int64_t i01 = ith; i01 < ne01; i01 += nth) { const float * x = (float *) ((char *) src0->data + i01*nb01 + i02*nb02 + i03*nb03); - ggml_float sum = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - sum += (ggml_float)x[i00]; - } - + float sum = 0.0; + ggml_vec_sum_f32(ne00, &sum, x); float mean = sum/ne00; float * y = (float *) ((char *) dst->data + i01*nb1 + i02*nb2 + i03*nb3); + float variance = 0; - ggml_float sum2 = 0.0; - for (int64_t i00 = 0; i00 < ne00; i00++) { - float v = x[i00] - mean; - y[i00] = v; - sum2 += (ggml_float)(v*v); - } +#ifdef GGML_USE_ACCELERATE + mean = -mean; + vDSP_vsadd(x, 1, &mean, y, 1, ne00); + vDSP_measqv(y, 1, &variance, ne00); +#else + variance = ggml_vec_cvar_f32(ne00, y, x, mean); +#endif //GGML_USE_ACCELERATE - float variance = sum2/ne00; const float scale = 1.0f/sqrtf(variance + eps); - ggml_vec_scale_f32(ne00, y, scale); } } @@ -8178,7 +8174,7 @@ static void ggml_compute_forward_flash_attn_ext_f16( } // V /= S - const float S_inv = 1.0f/S; + const float S_inv = S == 0.0f ? 0.0f : 1.0f/S; ggml_vec_scale_f32(DV, VKQ32, S_inv); // dst indices @@ -8680,7 +8676,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const float dA = expf(dt_soft_plus * A[h]); const int g = h / (nh / ng); // repeat_interleave @@ -8777,7 +8773,7 @@ static void ggml_compute_forward_ssm_scan_f32( // n_head for (int h = ih0; h < ih1; ++h) { // ref: https://github.com/state-spaces/mamba/blob/62db608da60f6fc790b8ed9f4b3225e95ca15fde/mamba_ssm/ops/triton/softplus.py#L16 - const float dt_soft_plus = dt[h] <= 20.0f ? log1pf(expf(dt[h])) : dt[h]; + const float dt_soft_plus = ggml_softplus(dt[h]); const int g = h / (nh / ng); // repeat_interleave // dim @@ -9040,6 +9036,10 @@ void ggml_compute_forward_unary( { ggml_compute_forward_exp(params, dst); } break; + case GGML_UNARY_OP_XIELU: + { + ggml_compute_forward_xielu(params, dst); + } break; default: { GGML_ABORT("fatal error"); diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp index 4fce569b3b..cf1a4615d0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.cpp @@ -52,6 +52,15 @@ static inline float op_sqrt(float x) { return sqrtf(x); } +static inline float op_xielu(float x, float alpha_n, float alpha_p, float beta, float eps) { + if (x > 0.0f) { + return alpha_p * x * x + beta * x; + } else { + const float min_x_eps = fminf(x, eps); + return (expm1f(min_x_eps) - x) * alpha_n + beta * x; + } +} + static inline float op_sin(float x) { return sinf(x); } @@ -121,6 +130,86 @@ static void unary_op(const ggml_compute_params * params, ggml_tensor * dst) { } } +template +static void unary_op_params(const ggml_compute_params * params, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op(params, dst); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + +// Extend vec_unary_op to support functors +template +static inline void vec_unary_op_functor(int64_t n, dst_t * y, const src0_t * x, Op op) { + constexpr auto src0_to_f32 = type_conversion_table::to_f32; + constexpr auto f32_to_dst = type_conversion_table::from_f32; + + for (int i = 0; i < n; i++) { + y[i] = f32_to_dst(op(src0_to_f32(x[i]))); + } +} + +// Extend apply_unary_op to support functors +template +static void apply_unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + GGML_ASSERT(ggml_is_contiguous_1(src0) && ggml_is_contiguous_1(dst) && ggml_are_same_shape(src0, dst)); + + GGML_TENSOR_UNARY_OP_LOCALS + + GGML_ASSERT( nb0 == sizeof(dst_t)); + GGML_ASSERT(nb00 == sizeof(src0_t)); + + const auto [ir0, ir1] = get_thread_range(params, src0); + + for (int64_t ir = ir0; ir < ir1; ++ir) { + const int64_t i03 = ir/(ne02*ne01); + const int64_t i02 = (ir - i03*ne02*ne01)/ne01; + const int64_t i01 = (ir - i03*ne02*ne01 - i02*ne01); + + dst_t * dst_ptr = (dst_t *) ((char *) dst->data + i03*nb3 + i02*nb2 + i01*nb1 ); + const src0_t * src0_ptr = (const src0_t *) ((const char *) src0->data + i03*nb03 + i02*nb02 + i01*nb01); + + vec_unary_op_functor(ne0, dst_ptr, src0_ptr, op); + } +} + +// Generic dispatcher for functors +template +static void unary_op_functor(const ggml_compute_params * params, ggml_tensor * dst, Op op) { + const ggml_tensor * src0 = dst->src[0]; + + /* */ if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { // all f32 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { // all f16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_BF16) { // all bf16 + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_BF16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + apply_unary_op_functor(params, dst, op); + } else { + fprintf(stderr, "%s: unsupported types: dst: %s, src0: %s\n", __func__, + ggml_type_name(dst->type), ggml_type_name(src0->type)); + GGML_ABORT("fatal error"); + } +} + void ggml_compute_forward_abs(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } @@ -184,3 +273,17 @@ void ggml_compute_forward_cos(const ggml_compute_params * params, ggml_tensor * void ggml_compute_forward_log(const ggml_compute_params * params, ggml_tensor * dst) { unary_op(params, dst); } + +void ggml_compute_forward_xielu(const ggml_compute_params * params, ggml_tensor * dst) { + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + const auto xielu_op_params = [alpha_n, alpha_p, beta, eps](float f) { + return op_xielu(f, alpha_n, alpha_p, beta, eps); + }; + + unary_op_functor(params, dst, xielu_op_params); +} + diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h index b1ade2c8e3..697c1e0da0 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/unary-ops.h @@ -22,6 +22,7 @@ void ggml_compute_forward_sqrt(const struct ggml_compute_params * params, struct void ggml_compute_forward_sin(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_cos(const struct ggml_compute_params * params, struct ggml_tensor * dst); void ggml_compute_forward_log(const struct ggml_compute_params * params, struct ggml_tensor * dst); +void ggml_compute_forward_xielu(const struct ggml_compute_params * params, struct ggml_tensor * dst); #ifdef __cplusplus } diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp index 437192d525..43dc7537c3 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.cpp @@ -404,6 +404,72 @@ void ggml_vec_swiglu_f32(const int n, float * y, const float * x, const float * } } +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean) { + int i = 0; + ggml_float sum = 0; +// TODO: optimize to process the remaining elements in groups using the smaller vector sizes from AVX2 and SSE +// ref: https://github.com/ggml-org/llama.cpp/pull/15953#pullrequestreview-3310928344 +#if defined(__AVX512F__) && defined(__AVX512DQ__) + for (; i + 15 < n; i += 16) { + __m512 val = _mm512_sub_ps(_mm512_loadu_ps(x + i), + _mm512_set1_ps(mean)); + _mm512_storeu_ps(y + i, val); + sum += (ggml_float)_mm512_reduce_add_ps(_mm512_mul_ps(val, val)); + } +#elif defined(__AVX2__) && defined(__FMA__) + for (; i + 7 < n; i += 8) { + __m256 val = _mm256_sub_ps(_mm256_loadu_ps(x + i), + _mm256_set1_ps(mean)); + _mm256_storeu_ps(y + i, val); + val = _mm256_mul_ps(val,val); + __m128 val2 = _mm_add_ps(_mm256_extractf128_ps(val, 1), + _mm256_castps256_ps128(val)); + val2 = _mm_add_ps(val2, _mm_movehl_ps(val2, val2)); + val2 = _mm_add_ss(val2, _mm_movehdup_ps(val2)); + sum += (ggml_float)_mm_cvtss_f32(val2); + } +#elif defined(__SSE2__) + for (; i + 3 < n; i += 4) { + __m128 val = _mm_sub_ps(_mm_loadu_ps(x + i), + _mm_set1_ps(mean)); + _mm_storeu_ps(y + i, val); + val = _mm_mul_ps(val, val); +#if defined(__AVX__) || defined(__AVX2__) || defined(__AVX512F__) + val = _mm_add_ps(val, _mm_movehl_ps(val, val)); + val = _mm_add_ss(val, _mm_movehdup_ps(val)); +#else + __m128 tmp = _mm_shuffle_ps(val, val, _MM_SHUFFLE(2, 3, 0, 1)); + val = _mm_add_ps(val, tmp); + tmp = _mm_movehl_ps(tmp, val); + val = _mm_add_ss(val, tmp); +#endif // __AVX__ || __AVX2__ || __AVX512F__ + sum += (ggml_float)_mm_cvtss_f32(val); + } +#elif defined(__ARM_NEON) && defined(__aarch64__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vsubq_f32(vld1q_f32(x + i), + vdupq_n_f32(mean)); + vst1q_f32(y + i, val); + val = vmulq_f32(val, val); + sum += (ggml_float)vaddvq_f32(val); + } +#elif defined(__VXE__) || defined(__VXE2__) + for (; i + 3 < n; i += 4) { + float32x4_t val = vec_sub(vec_xl(0, x + i), vec_splats(mean)); + vec_xst(val, 0, y + i); + val = vec_mul(val, val); + sum += (ggml_float)vec_hsum_f32x4(val); + } +#endif + for (; i < n; ++i) { + float val = x[i] - mean; + y[i] = val; + val *= val; + sum += (ggml_float)val; + } + return sum/n; +} + ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max) { int i = 0; ggml_float sum = 0; diff --git a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h index 341e64e64f..65c7dfb6b9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cpu/vec.h +++ b/ml/backend/ggml/ggml/src/ggml-cpu/vec.h @@ -44,6 +44,7 @@ void ggml_vec_dot_bf16(int n, float * GGML_RESTRICT s, size_t bs, ggml_bf16_t * void ggml_vec_dot_f16(int n, float * GGML_RESTRICT s, size_t bs, ggml_fp16_t * GGML_RESTRICT x, size_t bx, ggml_fp16_t * GGML_RESTRICT y, size_t by, int nrc); void ggml_vec_silu_f32(const int n, float * y, const float * x); +ggml_float ggml_vec_cvar_f32(const int n, float * y, const float * x, const float mean); //it will also center y ( y = y - mean ) ggml_float ggml_vec_soft_max_f32(const int n, float * y, const float * x, float max); ggml_float ggml_vec_log_soft_max_f32(const int n, float * y, const float * x, float max); @@ -143,14 +144,14 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG for (int i = 0; i < np; i += ggml_f16_step) { ay1 = GGML_F16x_VEC_LOAD(y + i + 0 * ggml_f16_epr, 0); // 8 elements - ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elemnst + ax1 = GGML_F16x_VEC_LOAD(x[0] + i + 0*ggml_f16_epr, 0); // 8 elements sum_00 = GGML_F16x_VEC_FMA(sum_00, ax1, ay1); // sum_00 = sum_00+ax1*ay1 ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 0*ggml_f16_epr, 0); // 8 elements sum_10 = GGML_F16x_VEC_FMA(sum_10, ax1, ay1); ay2 = GGML_F16x_VEC_LOAD(y + i + 1 * ggml_f16_epr, 1); // next 8 elements - ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 ekements + ax2 = GGML_F16x_VEC_LOAD(x[0] + i + 1*ggml_f16_epr, 1); // next 8 elements sum_01 = GGML_F16x_VEC_FMA(sum_01, ax2, ay2); ax2 = GGML_F16x_VEC_LOAD(x[1] + i + 1*ggml_f16_epr, 1); sum_11 = GGML_F16x_VEC_FMA(sum_11, ax2, ay2); @@ -159,7 +160,7 @@ inline static void ggml_vec_dot_f16_unroll(const int n, const int xs, float * GG ax3 = GGML_F16x_VEC_LOAD(x[0] + i + 2*ggml_f16_epr, 2); sum_02 = GGML_F16x_VEC_FMA(sum_02, ax3, ay3); - ax1 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); + ax3 = GGML_F16x_VEC_LOAD(x[1] + i + 2*ggml_f16_epr, 2); sum_12 = GGML_F16x_VEC_FMA(sum_12, ax3, ay3); ay4 = GGML_F16x_VEC_LOAD(y + i + 3 * ggml_f16_epr, 3); @@ -654,11 +655,11 @@ inline static void ggml_vec_scale_f32(const int n, float * y, const float v) { } // leftovers // maximum number of leftover elements will be less that ggml_f32_epr. Apply predicated svmad on available elements only - if (np < n) { - svbool_t pg = svwhilelt_b32(np, n); - ay1 = svld1_f32(pg, y + np); + for (int i = np; i < n; i += ggml_f32_epr) { + svbool_t pg = svwhilelt_b32(i, n); + ay1 = svld1_f32(pg, y + i); ay1 = svmul_f32_m(pg, ay1, vx); - svst1_f32(pg, y + np, ay1); + svst1_f32(pg, y + i, ay1); } #elif defined(__riscv_v_intrinsic) for (int i = 0, avl; i < n; i += avl) { @@ -819,7 +820,8 @@ inline static void ggml_vec_tanh_f16 (const int n, ggml_fp16_t * y, const ggml_f inline static void ggml_vec_elu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : expm1f(x[i]); } inline static void ggml_vec_elu_f16 (const int n, ggml_fp16_t * y, const ggml_fp16_t * x) { for (int i = 0; i < n; ++i) { - y[i] = GGML_CPU_FP32_TO_FP16(expm1f(GGML_CPU_FP16_TO_FP32(x[i]))); + const float v = GGML_CPU_FP16_TO_FP32(x[i]); + y[i] = GGML_CPU_FP32_TO_FP16((v > 0.f) ? v : expm1f(v)); } } inline static void ggml_vec_relu_f32 (const int n, float * y, const float * x) { for (int i = 0; i < n; ++i) y[i] = (x[i] > 0.f) ? x[i] : 0.f; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt index bdcefe7b7e..3024775135 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-cuda/CMakeLists.txt @@ -44,6 +44,8 @@ if (CUDAToolkit_FOUND) list(APPEND GGML_HEADERS_CUDA "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_CUDA "*.cu") + file(GLOB SRCS "template-instances/fattn-tile*.cu") + list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_CUDA ${SRCS}) file(GLOB SRCS "template-instances/mmq*.cu") diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh index 448badf0d3..28d6bcd71a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/common.cuh @@ -245,14 +245,6 @@ static const char * cu_get_error_str(CUresult err) { #define FAST_FP16_AVAILABLE #endif // defined(FP16_AVAILABLE) && __CUDA_ARCH__ != 610 -#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) -#define FP16_MMA_AVAILABLE -#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) - -#if defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) -#define FP16_MMA_AVAILABLE -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && (defined(CDNA) || defined(RDNA3) || (defined(GGML_HIP_ROCWMMA_FATTN_GFX12) && defined(RDNA4))) - #if defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) #define AMD_MFMA_AVAILABLE #endif // defined(GGML_USE_HIP) && defined(CDNA) && !defined(GGML_HIP_NO_MMQ_MFMA) @@ -278,7 +270,8 @@ static bool fp16_available(const int cc) { } static bool fast_fp16_available(const int cc) { - return (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && cc != 610) || GGML_CUDA_CC_IS_AMD(cc); + return GGML_CUDA_CC_IS_AMD(cc) || + (GGML_CUDA_CC_IS_NVIDIA(cc) && fp16_available(cc) && ggml_cuda_highest_compiled_arch(cc) != 610); } // To be used for feature selection of external libraries, e.g. cuBLAS. @@ -287,27 +280,6 @@ static bool fast_fp16_hardware_available(const int cc) { (GGML_CUDA_CC_IS_MTHREADS(cc) && cc >= GGML_CUDA_CC_QY2); } -// Any FP16 tensor core instructions are available for ggml code. -static bool fp16_mma_available(const int cc) { -#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) - return false; -#else - if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) >= GGML_CUDA_CC_VOLTA) || - GGML_CUDA_CC_IS_CDNA(cc) || GGML_CUDA_CC_IS_RDNA3(cc) || - GGML_CUDA_CC_IS_MTHREADS(cc)) { - return true; - } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { -#if defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - return true; -#else - return false; -#endif // defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_HIP_ROCWMMA_FATTN_GFX12) - } else { - return false; - } -#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) -} - // To be used for feature selection of external libraries, e.g. cuBLAS. static bool fp16_mma_hardware_available(const int cc) { return (GGML_CUDA_CC_IS_NVIDIA(cc) && cc >= GGML_CUDA_CC_VOLTA) || @@ -625,6 +597,10 @@ static __device__ __forceinline__ void ggml_cuda_mad(half2 & acc, const half2 v, } // Aligned memory transfers of 8/16 bytes can be faster than 2 transfers with 4 bytes, especially on AMD. +// Important: do not use this function if dst and src both point at registers. +// Due to the strict aliasing rule the compiler can do incorrect optimizations if src and dst have different types. +// The function is intended for copies between registers and SRAM/VRAM to make the compiler emit the right instructions. +// If dst and src point at different address spaces then they are guaranteed to not be aliased. template static __device__ __forceinline__ void ggml_cuda_memcpy_1(void * __restrict__ dst, const void * __restrict__ src) { if constexpr (alignment != 0) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh index 33d2f0f49e..bc0c2523cc 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-common.cuh @@ -793,8 +793,6 @@ void launch_fattn( GGML_ASSERT(!mask || mask->ne[1] >= GGML_PAD(Q->ne[1], 16) && "the Flash-Attention CUDA kernel requires the mask to be padded to 16 and at least n_queries big"); - GGML_ASSERT(K->ne[1] % FATTN_KQ_STRIDE == 0 && "Incorrect KV cache padding."); - ggml_cuda_pool & pool = ctx.pool(); cudaStream_t main_stream = ctx.stream(); const int id = ggml_cuda_get_device(); @@ -878,7 +876,7 @@ void launch_fattn( // Optional optimization where the mask is scanned to determine whether part of the calculation can be skipped. // Only worth the overhead if there is at lease one FATTN_KQ_STRIDE x FATTN_KQ_STRIDE square to be skipped or // multiple sequences of possibly different lengths. - if (mask && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { + if (mask && K->ne[1] % FATTN_KQ_STRIDE == 0 && (Q->ne[1] >= 1024 || Q->ne[3] > 1)) { const int s31 = mask->nb[1] / sizeof(half2); const int s33 = mask->nb[3] / sizeof(half2); @@ -916,8 +914,7 @@ void launch_fattn( dst_tmp_meta.alloc(blocks_num.x*ncols * (2*2 + DV) * sizeof(float)); } else { - GGML_ASSERT(K->ne[1] % KQ_row_granularity == 0); - const int ntiles_KQ = K->ne[1] / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. + const int ntiles_KQ = (K->ne[1] + KQ_row_granularity - 1) / KQ_row_granularity; // Max. number of parallel blocks limited by tensor size. // parallel_blocks must not be larger than what the tensor size allows: parallel_blocks = std::min(parallel_blocks, ntiles_KQ); @@ -946,7 +943,7 @@ void launch_fattn( blocks_num.x = ntiles_x; blocks_num.y = parallel_blocks; - blocks_num.z = Q->ne[2]*Q->ne[3]; + blocks_num.z = (Q->ne[2]/ncols2)*Q->ne[3]; if (parallel_blocks > 1) { dst_tmp.alloc(parallel_blocks*ggml_nelements(KQV)); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu index 131a5099a3..3a5806d909 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cu @@ -1,755 +1,45 @@ #include "common.cuh" -#include "fattn-common.cuh" #include "fattn-tile.cuh" - -// kq_stride == number of KQ rows to process per iteration -// kq_nbatch == number of K columns to load in parallel for KQ calculation - -static int fattn_tile_get_kq_stride_host(const int D, const int ncols, const int cc, const int warp_size) { - if (GGML_CUDA_CC_IS_AMD(cc)) { - if (GGML_CUDA_CC_IS_RDNA(cc)) { - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - if (fast_fp16_available(cc)) { - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - GGML_ABORT("fatal error"); - return -1; - } - } - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - GGML_ABORT("fatal error"); - return -1; - } - GGML_UNUSED(warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_stride_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP -#ifdef RDNA - switch (D) { - case 64: - return 128; - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols == 32 ? 128 : 64; - case 128: - return ncols == 32 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // RDNA -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - case 128: - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#else - switch (D) { - case 64: - return ncols <= 16 ? 128 : 64; - case 128: - return ncols <= 16 ? 64 : 32; - case 256: - return 32; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static constexpr __device__ int fattn_tile_get_kq_nbatch_device(int D, int ncols, int warp_size) { -#ifdef GGML_USE_HIP - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else -#ifdef FAST_FP16_AVAILABLE - switch (D) { - case 64: - return 64; - case 128: - case 256: - return 128; - default: - return -1; - } -#else - switch (D) { - case 64: - return 64; - case 128: - return 128; - case 256: - return ncols <= 16 ? 128 : 64; - default: - return -1; - } -#endif // FAST_FP16_AVAILABLE -#endif // GGML_USE_HIP - GGML_UNUSED_VARS(ncols, warp_size); -} - -static int fattn_tile_get_nthreads_host(const int cc, const int ncols) { - return 256; - GGML_UNUSED_VARS(cc, ncols); -} - -static constexpr __device__ int fattn_tile_get_nthreads_device(int ncols) { - return 256; - GGML_UNUSED(ncols); -} - -static constexpr __device__ int fattn_tile_get_occupancy_device(int ncols) { -#ifdef RDNA - return 3; -#else - return ncols <= 16 ? 3 : 2; -#endif // RDNA - GGML_UNUSED(ncols); -} - -template // D == head size -__launch_bounds__(fattn_tile_get_nthreads_device(ncols), fattn_tile_get_occupancy_device(ncols)) -static __global__ void flash_attn_tile( - const char * __restrict__ Q, - const char * __restrict__ K, - const char * __restrict__ V, - const char * __restrict__ mask, - const char * __restrict__ sinks, - const int * __restrict__ KV_max, - float * __restrict__ dst, - float2 * __restrict__ dst_meta, - const float scale, - const float max_bias, - const float m0, - const float m1, - const uint32_t n_head_log2, - const float logit_softcap, - const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, - const int32_t nb01, const int32_t nb02, const int32_t nb03, - const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, - const int32_t nb11, const int32_t nb12, const int64_t nb13, - const int32_t nb21, const int32_t nb22, const int64_t nb23, - const int32_t ne31, const int32_t ne32, const int32_t ne33, - const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#ifdef FLASH_ATTN_AVAILABLE - - // Skip unused kernel variants for faster compilation: -#ifdef FP16_MMA_AVAILABLE - NO_DEVICE_CODE; - return; -#endif // FP16_MMA_AVAILABLE - - if (use_logit_softcap && !(D == 128 || D == 256)) { - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; - return; - } - - constexpr int warp_size = 32; - constexpr int nwarps = fattn_tile_get_nthreads_device(ncols) / warp_size; - constexpr int kq_stride = fattn_tile_get_kq_stride_device(D, ncols, warp_size); - static_assert(kq_stride % warp_size == 0, "kq_stride not divisable by warp_size."); - constexpr int kq_nbatch = fattn_tile_get_kq_nbatch_device(D, ncols, warp_size); - static_assert(kq_nbatch % (2*warp_size) == 0, "bad kq_nbatch"); - - // In this kernel Q, K, V are matrices while i, j, k are matrix indices. - - const int ic0 = blockIdx.x * ncols; // Index of the Q/QKV column to work on. - - const int sequence = blockIdx.z / ne02; - const int head = blockIdx.z - sequence*ne02; - const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. - const float * Q_f = (const float *) (Q + nb03* sequence + nb02* head + nb01*ic0); - const half2 * K_h2 = (const half2 *) (K + nb13* sequence + nb12*(head / gqa_ratio)); - const half2 * V_h2 = (const half2 *) (V + nb13* sequence + nb12*(head / gqa_ratio)); // K and V have same shape - const half * maskh = (const half *) (mask + nb33*(sequence % ne33) + nb31*ic0); - const float * sinksf = (const float *) (sinks); - - const int stride_KV2 = nb11 / sizeof(half2); - - const float slope = get_alibi_slope(max_bias, head, n_head_log2, m0, m1); - - constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); - constexpr int cpy_ne = cpy_nb / 4; - - constexpr int cpw = ncols/nwarps; // cols per warp - - // softmax_iter_j == number of KQ columns for which to calculate softmax in parallel. - // KQ is originall 2D but uses a Z-shaped memory pattern for larger reads/writes. -#ifdef FAST_FP16_AVAILABLE - constexpr int softmax_iter_j = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; - - __shared__ half KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ half2 Q_tmp[ncols][D/2]; - __shared__ half2 KV_tmp[kq_stride * (kq_nbatch/2 + cpy_ne)]; // Padded to avoid memory bank conflicts. - half2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#else - constexpr int softmax_iter_j = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; - - __shared__ float KQ[ncols/softmax_iter_j][kq_stride][softmax_iter_j]; - __shared__ float Q_tmp[ncols][D]; - __shared__ float KV_tmp[kq_stride * (kq_nbatch + cpy_ne)]; // Padded to avoid memory bank conflicts. - float2 VKQ[cpw][D/(2*warp_size)] = {{{0.0f, 0.0f}}}; -#endif // FAST_FP16_AVAILABLE - static_assert(cpw % softmax_iter_j == 0, "bad softmax_iter_j"); - - float KQ_max[cpw]; -#pragma unroll - for (int j0 = 0; j0 < ncols; j0 += nwarps) { - KQ_max[j0/nwarps] = -FLT_MAX/2.0f; - } - float KQ_sum[cpw] = {0.0f}; - - // Load Q data, convert to FP16 if fast. -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - const int j = j0 + threadIdx.y*cpw; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; - -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - float tmp_f[cpy_ne_D] = {0.0f}; - if (ic0 + j < ne01) { - ggml_cuda_memcpy_1(tmp_f, &Q_f[j*(nb01/sizeof(float)) + i0 + threadIdx.x*cpy_ne_D]); - } - -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp_f[i1] *= scale; - } - -#ifdef FAST_FP16_AVAILABLE - half2 tmp_h2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { - tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); - } - ggml_cuda_memcpy_1(&Q_tmp[j][i0/2 + threadIdx.x*(cpy_ne_D/2)], tmp_h2); -#else - ggml_cuda_memcpy_1 (&Q_tmp[j][i0 + threadIdx.x* cpy_ne_D], tmp_f); -#endif // FAST_FP16_AVAILABLE - } - } - - __syncthreads(); - - // Main loop over KV cache: - const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; - for (int k_VKQ_0 = blockIdx.y*kq_stride; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*kq_stride) { - // Calculate KQ tile and keep track of new maximum KQ values: - - float KQ_max_new[cpw]; -#pragma unroll - for (int j = 0; j < cpw; ++j) { - KQ_max_new[j] = KQ_max[j]; - } - - float KQ_acc[kq_stride/warp_size][cpw] = {{0.0f}}; // Accumulators for KQ matrix multiplication. - - // KQ = K @ Q matrix multiplication: -#pragma unroll - for (int k_KQ_0 = 0; k_KQ_0 < D; k_KQ_0 += kq_nbatch) { -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += nwarps) { - const int i_KQ = i_KQ_0 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/(2*warp_size) ? cpy_ne : kq_nbatch/(2*warp_size); -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += warp_size*cpy_ne_kqnb) { - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], - &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1 + threadIdx.x*cpy_ne_kqnb]); - } -#else - constexpr int cpy_ne_kqnb = cpy_ne < kq_nbatch/warp_size ? cpy_ne : kq_nbatch/warp_size; -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += warp_size*cpy_ne_kqnb) { - half2 tmp_h2[cpy_ne_kqnb/2]; - ggml_cuda_memcpy_1( - tmp_h2, &K_h2[int64_t(k_VKQ_0 + i_KQ)*stride_KV2 + k_KQ_0/2 + k_KQ_1/2 + threadIdx.x*(cpy_ne_kqnb/2)]); - - float2 tmp_f2[cpy_ne_kqnb/2]; -#pragma unroll - for (int k_KQ_2 = 0; k_KQ_2 < cpy_ne_kqnb/2; ++k_KQ_2) { - tmp_f2[k_KQ_2] = __half22float2(tmp_h2[k_KQ_2]); - } - ggml_cuda_memcpy_1( - &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1 + threadIdx.x*cpy_ne_kqnb], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch/2; k_KQ_1 += cpy_ne) { - half2 K_k[kq_stride/warp_size][cpy_ne]; - half2 Q_k[cpw][cpy_ne]; -#else -#pragma unroll - for (int k_KQ_1 = 0; k_KQ_1 < kq_nbatch; k_KQ_1 += cpy_ne) { - float K_k[kq_stride/warp_size][cpy_ne]; - float Q_k[cpw][cpy_ne]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch/2 + cpy_ne) + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&K_k[i_KQ_0/warp_size], &KV_tmp[i_KQ*(kq_nbatch + cpy_ne) + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - -#ifdef FAST_FP16_AVAILABLE - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0/2 + k_KQ_1]); -#else - ggml_cuda_memcpy_1(&Q_k[j_KQ_0], &Q_tmp[j_KQ][k_KQ_0 + k_KQ_1]); -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { -#pragma unroll - for (int k = 0; k < cpy_ne; ++k) { - ggml_cuda_mad(KQ_acc[i_KQ_0/warp_size][j_KQ_0], K_k[i_KQ_0/warp_size][k], Q_k[j_KQ_0][k]); - } - } - } - } - - if (k_KQ_0 + kq_nbatch < D) { - __syncthreads(); // Sync not needed on last iteration. - } - } - - // Apply logit softcap, mask, update KQ_max: -#pragma unroll - for (int i_KQ_0 = 0; i_KQ_0 < kq_stride; i_KQ_0 += warp_size) { - const int i_KQ = i_KQ_0 + threadIdx.x; - -#pragma unroll - for (int j_KQ_0 = 0; j_KQ_0 < cpw; ++j_KQ_0) { - const int j_KQ = j_KQ_0 + threadIdx.y*cpw; - - if (use_logit_softcap) { - KQ_acc[i_KQ_0/warp_size][j_KQ_0] = logit_softcap * tanhf(KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - - KQ_acc[i_KQ_0/warp_size][j_KQ_0] += mask ? slope*__half2float(maskh[j_KQ*ne11 + k_VKQ_0 + i_KQ]) : 0.0f; - - KQ_max_new[j_KQ_0] = fmaxf(KQ_max_new[j_KQ_0], KQ_acc[i_KQ_0/warp_size][j_KQ_0]); - } - } - - __syncthreads(); - - // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { -#ifdef FAST_FP16_AVAILABLE - half tmp[kq_stride/warp_size][softmax_iter_j]; -#else - float tmp[kq_stride/warp_size][softmax_iter_j]; -#endif // FAST_FP16_AVAILABLE - -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_max_new[j0+j1] = warp_reduce_max(KQ_max_new[j0+j1]); - const float KQ_max_scale = expf(KQ_max[j0+j1] - KQ_max_new[j0+j1]); - KQ_max[j0+j1] = KQ_max_new[j0+j1]; - - float KQ_sum_add = 0.0f; -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const float val = expf(KQ_acc[i0/warp_size][j0+j1] - KQ_max[j0+j1]); - KQ_sum_add += val; - tmp[i0/warp_size][j1] = val; - } - KQ_sum[j0+j1] = KQ_sum[j0+j1]*KQ_max_scale + KQ_sum_add; - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0+j1][i0/warp_size].x *= KQ_max_scale; - VKQ[j0+j1][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - -#pragma unroll - for (int i0 = 0; i0 < kq_stride; i0 += warp_size) { - const int i = i0 + threadIdx.x; - - ggml_cuda_memcpy_1( - KQ[j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j)][i], tmp[i0/warp_size]); - } - } - - // VKQ = V @ KQ matrix multiplication: - constexpr int V_cols_per_iter = kq_stride*kq_nbatch / D; // Number of V columns that fit in SRAM for K. - static_assert(kq_stride % V_cols_per_iter == 0, "bad V_cols_per_iter"); -#pragma unroll - for (int k0 = 0; k0 < kq_stride; k0 += V_cols_per_iter) { -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; k1 += nwarps) { - const int k_tile = k1 + threadIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne < D/(2*warp_size) ? cpy_ne : D/(2*warp_size); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*(D/2) + i0 + threadIdx.x*cpy_ne_D], - &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0 + threadIdx.x*cpy_ne_D]); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - half2 tmp_h2[cpy_ne_D/2]; - ggml_cuda_memcpy_1( - tmp_h2, &V_h2[int64_t(k_VKQ_0 + k0 + k_tile)*stride_KV2 + i0/2 + threadIdx.x*(cpy_ne_D/2)]); - - float2 tmp_f2[cpy_ne_D/2]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { - tmp_f2[i1] = __half22float2(tmp_h2[i1]); - } - ggml_cuda_memcpy_1( - &KV_tmp[k_tile*D + i0 + threadIdx.x*cpy_ne_D], tmp_f2); - } -#endif // FAST_FP16_AVAILABLE - } - - __syncthreads(); - -#ifdef FAST_FP16_AVAILABLE -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - half2 V_k[(D/2)/warp_size]; - half2 KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[k1*(D/2) + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - half tmp[softmax_iter_j]; - ggml_cuda_memcpy_1( - &tmp, KQ[j][k0 + k1]); -#pragma unroll - for (int j1 = 0; j1 < softmax_iter_j; ++j1) { - KQ_k[j0+j1] = __half2half2(tmp[j1]); - } - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size] += V_k[i0/warp_size]*KQ_k[j0]; - } - } - } -#else -#pragma unroll - for (int k1 = 0; k1 < V_cols_per_iter; ++k1) { - float2 V_k[(D/2)/warp_size]; - float KQ_k[cpw]; - - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[k1*D + i0 + threadIdx.x*cpy_ne_D]); - } -#pragma unroll - for (int j0 = 0; j0 < cpw; j0 += softmax_iter_j) { - const int j = j0/softmax_iter_j + threadIdx.y*(cpw/softmax_iter_j); - - ggml_cuda_memcpy_1( - &KQ_k[j0], KQ[j][k0 + k1]); - } - -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - VKQ[j0][i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[j0]; - VKQ[j0][i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[j0]; - } - } - } -#endif // FAST_FP16_AVAILABLE - - __syncthreads(); - } - } - - - // Attention sink: adjust running max and sum once per head - if (sinksf && blockIdx.y == 0) { - const float sink = sinksf[head]; - -#pragma unroll - for (int j0 = 0; j0 < cpw; ++j0) { - float KQ_max_new_j = fmaxf(KQ_max[j0], sink); - KQ_max_new_j = warp_reduce_max(KQ_max_new_j); - - const float KQ_max_scale = expf(KQ_max[j0] - KQ_max_new_j); - KQ_max[j0] = KQ_max_new_j; - - const float val = expf(sink - KQ_max[j0]); - KQ_sum[j0] = KQ_sum[j0] * KQ_max_scale; - if (threadIdx.x == 0) { - KQ_sum[j0] += val; - } - -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size] *= KQ_max_scale_h2; - } -#else -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size) { - VKQ[j0][i0/warp_size].x *= KQ_max_scale; - VKQ[j0][i0/warp_size].y *= KQ_max_scale; - } -#endif // FAST_FP16_AVAILABLE - } - } - -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - KQ_sum[j_VKQ_0] = warp_reduce_sum(KQ_sum[j_VKQ_0]); - } - if (gridDim.y == 1) { -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { -#ifdef FAST_FP16_AVAILABLE - const half2 KQ_sum_j_inv = make_half2(1.0f/KQ_sum[j_VKQ_0], 1.0f/KQ_sum[j_VKQ_0]); -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i] *= KQ_sum_j_inv; - } -#else - const float KQ_sum_j_inv = 1.0f/KQ_sum[j_VKQ_0]; -#pragma unroll - for (int i = 0; i < (D/2)/warp_size; ++i) { - VKQ[j_VKQ_0][i].x *= KQ_sum_j_inv; - VKQ[j_VKQ_0][i].y *= KQ_sum_j_inv; - } -#endif // FAST_FP16_AVAILABLE - } - } - - // Write back results: -#pragma unroll - for (int j_VKQ_0 = 0; j_VKQ_0 < cpw; ++j_VKQ_0) { - const int j_VKQ = j_VKQ_0 + threadIdx.y*cpw; - - if (ic0 + j_VKQ >= ne01) { - return; - } - - const int j_dst_unrolled = ((sequence*ne01 + ic0 + j_VKQ)*ne02 + head)*gridDim.y + blockIdx.y; - -#ifdef FAST_FP16_AVAILABLE - constexpr int cpy_ne_D = cpy_ne/2 < (D/2)/warp_size ? cpy_ne/2 : (D/2)/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D/2; i0 += warp_size*cpy_ne_D) { - float2 tmp[cpy_ne_D]; -#pragma unroll - for (int i1 = 0; i1 < cpy_ne_D; ++i1) { - tmp[i1] = __half22float2(VKQ[j_VKQ_0][i0/warp_size + i1]); - } - ggml_cuda_memcpy_1(&dst[j_dst_unrolled*D + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); - } -#else - constexpr int cpy_ne_D = cpy_ne < D/warp_size ? cpy_ne : D/warp_size; -#pragma unroll - for (int i0 = 0; i0 < D; i0 += warp_size*cpy_ne_D) { - ggml_cuda_memcpy_1( - &dst[j_dst_unrolled*D + i0 + threadIdx.x*cpy_ne_D], &VKQ[j_VKQ_0][i0/(2*warp_size)]); - } -#endif // FAST_FP16_AVAILABLE - - if (gridDim.y != 1 && threadIdx.x == 0) { - dst_meta[j_dst_unrolled] = make_float2(KQ_max[j_VKQ_0], KQ_sum[j_VKQ_0]); - } - } -#else - GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, - max_bias, m0, m1, n_head_log2, logit_softcap, - ne00, ne01, ne02, ne03, - nb01, nb02, nb03, - ne10, ne11, ne12, ne13, - nb11, nb12, nb13, - nb21, nb22, nb23, - ne31, ne32, ne33, - nb31, nb32, nb33); - NO_DEVICE_CODE; -#endif // FLASH_ATTN_AVAILABLE -} - -template -static void launch_fattn_tile_switch_ncols(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - - const int id = ggml_cuda_get_device(); - const int cc = ggml_cuda_info().devices[id].cc; - const int warp_size = 32; - - constexpr size_t nbytes_shared = 0; - -#ifdef GGML_USE_HIP - if constexpr (D <= 128) { - if (Q->ne[1] > 32) { - constexpr int cols_per_block = 64; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - } -#endif // GGML_USE_HIP - - if (Q->ne[1] > 16) { - constexpr int cols_per_block = 32; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); - return; - } - - constexpr int cols_per_block = 16; - const int nwarps = fattn_tile_get_nthreads_host(cc, cols_per_block) / warp_size; - fattn_kernel_t fattn_kernel = flash_attn_tile; - const int kq_stride = fattn_tile_get_kq_stride_host(D, cols_per_block, cc, warp_size); - launch_fattn - (ctx, dst, fattn_kernel, nwarps, nbytes_shared, kq_stride, true, true, false, warp_size); -} - -template -static void launch_fattn_tile_switch_head_size(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * Q = dst->src[0]; - switch (Q->ne[0]) { +#include "fattn-wmma-f16.cuh" + +void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * V = dst->src[2]; + switch (K->ne[0]) { + case 40: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 40, 40>(ctx, dst); + } break; case 64: { - launch_fattn_tile_switch_ncols< 64, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 64, 64>(ctx, dst); + } break; + case 80: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 80, 80>(ctx, dst); + } break; + case 96: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case< 96, 96>(ctx, dst); + } break; + case 112: { + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<112, 112>(ctx, dst); } break; case 128: { - launch_fattn_tile_switch_ncols<128, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<128, 128>(ctx, dst); } break; case 256: { - launch_fattn_tile_switch_ncols<256, use_logit_softcap>(ctx, dst); + GGML_ASSERT(V->ne[0] == K->ne[0]); + ggml_cuda_flash_attn_ext_tile_case<256, 256>(ctx, dst); + } break; + case 576: { + GGML_ASSERT(V->ne[0] == 512); + ggml_cuda_flash_attn_ext_tile_case<576, 512>(ctx, dst); } break; default: { GGML_ABORT("Unsupported head size"); } break; } } - -void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { - const ggml_tensor * KQV = dst; - - float logit_softcap; - memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - - if (logit_softcap == 0.0f) { - constexpr bool use_logit_softcap = false; - launch_fattn_tile_switch_head_size(ctx, dst); - } else { - constexpr bool use_logit_softcap = true; - launch_fattn_tile_switch_head_size(ctx, dst); - } -} diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh index 10dc22d1bf..2b60b3bb13 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-tile.cuh @@ -1,3 +1,1206 @@ #include "common.cuh" +#include "fattn-common.cuh" +#include "fattn-wmma-f16.cuh" + +// nbatch_fa == number of KQ rows to process per iteration +// nbatch_K == number of K columns to load in parallel for KQ calculation + +// TODO optimize kernel parameters for FP16 NVIDIA (P100) +// TODO optimize kernel parameters for head sizes 40, 80, 96, 112 + +// The ROCm compiler cannot handle templating in __launch_bounds__. +// As a workaround, define a macro to package the kernel parameters as uint32_t: +#define GGML_CUDA_FATTN_TILE_CONFIG_CASE(DKQ_, DV_, ncols_, nthreads, occupancy, nbatch_fa, nbatch_K) \ + if (DKQ == (DKQ_) && DV == (DV_) && ncols == (ncols_)) { \ + static_assert((nthreads) <= 512, "bad nthreads"); \ + static_assert((occupancy) <= 8, "bad occupancy"); \ + static_assert((nbatch_fa) <= 256, "bad nbatch_fa"); \ + static_assert((nbatch_K) <= 256, "bad nbatch_K"); \ + return ((nthreads) << 0) | ((occupancy) << 10) | ((nbatch_fa) << 14) | ((nbatch_K) << 23); \ + } \ + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp16(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 64, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 64, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 64, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 64, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 64, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 64, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_nvidia_fp32(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 3, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 128, 3, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 32, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 3, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 128, 3, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 2, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 256, 2, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 2, 64, 32) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 256, 2, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 256, 2, 64, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 2, 32, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 2, 32, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 2, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 512, 1, 128, 64) + + return 0; +} + +static constexpr __host__ __device__ uint32_t ggml_cuda_fattn_tile_get_config_amd_rdna(const int DKQ, const int DV, const int ncols) { + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 40, 40, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 4, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 8, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 16, 128, 5, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 32, 128, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 64, 64, 64, 128, 5, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 2, 64, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 4, 128, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 8, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 16, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 32, 256, 2, 32, 40) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 80, 80, 64, 256, 2, 32, 40) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 2, 64, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 4, 128, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 8, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 16, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 32, 256, 2, 32, 48) + GGML_CUDA_FATTN_TILE_CONFIG_CASE( 96, 96, 64, 256, 2, 32, 48) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 2, 64, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 4, 128, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 8, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 16, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 32, 256, 2, 32, 56) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(112, 112, 64, 256, 2, 32, 56) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 4, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 8, 128, 8, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 16, 256, 3, 128, 128) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 32, 256, 3, 128, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(128, 128, 64, 256, 3, 64, 64) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 2, 64, 8, 32, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 4, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 8, 128, 6, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 16, 256, 5, 32, 256) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(256, 256, 32, 256, 3, 64, 128) + + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 16, 256, 4, 64, 64) + GGML_CUDA_FATTN_TILE_CONFIG_CASE(576, 512, 32, 256, 2, 128, 64) + + return 0; +} + +static __host__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols, const int cc) { + if (GGML_CUDA_CC_IS_AMD(cc)) { + if (GGML_CUDA_CC_IS_RDNA(cc)) { + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); + } + if (fast_fp16_available(cc)) { + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); + } + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +} + +static constexpr __device__ uint32_t ggml_cuda_fattn_tile_get_config(const int DKQ, const int DV, const int ncols) { +#ifdef GGML_USE_HIP +#ifdef RDNA + return ggml_cuda_fattn_tile_get_config_amd_rdna(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_amd(DKQ, DV, ncols); +#endif // RDNA +#else +#ifdef FAST_FP16_AVAILABLE + return ggml_cuda_fattn_tile_get_config_nvidia_fp16(DKQ, DV, ncols); +#else + return ggml_cuda_fattn_tile_get_config_nvidia_fp32(DKQ, DV, ncols); +#endif // FAST_FP16_AVAILABLE +#endif // GGML_USE_HIP +} + +static __host__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 0) & ((1 << 10) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nthreads(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 0) & ((1 << 10) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 10) & ((1 << 4) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_occupancy(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 10) & ((1 << 4) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 14) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_fa(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 14) & ((1 << 9) - 1); +} + +static __host__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols, const int cc) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols, cc) >> 23) & ((1 << 9) - 1); +} + +static constexpr __device__ int ggml_cuda_fattn_tile_get_nbatch_K(const int DKQ, const int DV, const int ncols) { + return (ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols) >> 23) & ((1 << 9) - 1); +} + +// TODO: deduplicate with mma-f16 +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, half2 * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (2*stride_j); + const int j0_stop = ((J/2)/cpy_ne) - ((J/2)/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*cpy_ne + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*cpy_ne; + + const half2 zero[cpy_ne] = {{0.0f, 0.0f}}; + ggml_cuda_memcpy_1( + tile_KV + i*(J/2 + J_padding) + j, + !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + } + } + } + }; + // 1: max 64*16=512 bytes, 512 half + // 2: max 32*16=512 bytes, 256 half + // 3: max 16*16=256 bytes, 128 half + // 4: max 8*16=128 bytes, 64 half + // 5: max 4*16= 64 bytes, 32 half + // 6: max 2*16= 32 bytes, 16 half + // 7: max 1*16= 16 bytes, 8 half + static_assert(J % 8 == 0, "bad J"); + static_assert((J/2) % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<7>{}(load); +} + +template +static __device__ __forceinline__ void flash_attn_tile_load_tile( + const half2 * const __restrict__ KV, float * const __restrict__ tile_KV, const int stride_KV, const int i_sup) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + auto load = [&] __device__ (const int n) { + const int stride_j = warp_size >> n; + + if (stride_j == 0) { + return; + } + + const int j0_start = stride_j == warp_size ? 0 : (J/cpy_ne) - (J/cpy_ne) % (2*stride_j); + const int j0_stop = (J/cpy_ne) - (J/cpy_ne) % (1*stride_j); + const int stride_i = warp_size / stride_j; + + if (j0_start == j0_stop) { + return; + } + +#pragma unroll + for (int i0 = 0; i0 < I; i0 += nwarps*stride_i) { + const int i = i0 + threadIdx.y*stride_i + (stride_j == warp_size ? 0 : threadIdx.x / stride_j); + + if (i0 + nwarps*stride_i <= I || i < I) { +#pragma unroll + for (int j0 = j0_start; j0 < j0_stop; j0 += stride_j) { + const int j = j0*(cpy_ne/2) + (stride_j == warp_size ? threadIdx.x : threadIdx.x % stride_j)*(cpy_ne/2); + + const half2 zero[cpy_ne/2] = {{0.0f, 0.0f}}; + half2 tmp_h2[cpy_ne/2]; + ggml_cuda_memcpy_1( + tmp_h2, !oob_check || i < i_sup ? KV + i*stride_KV + j : zero); + + float2 tmp_f2[cpy_ne/2]; +#pragma unroll + for (int l = 0; l < cpy_ne/2; ++l) { + tmp_f2[l] = __half22float2(tmp_h2[l]); + } + ggml_cuda_memcpy_1(tile_KV + i*(J + J_padding) + 2*j, tmp_f2); + } + } + } + }; + // 1: max 32*16=512 bytes, 128 float + // 2: max 16*16=256 bytes, 64 float + // 3: max 8*16=128 bytes, 32 float + // 4: max 4*16= 64 bytes, 16 float + // 5: max 2*16= 32 bytes, 8 float + static_assert(J % 8 == 0, "bad J"); + static_assert(J % cpy_ne == 0, "bad J"); + ggml_cuda_unroll<5>{}(load); +} + +// Function that performs a single iteration in for the KQ matrix multiplication: +template +static __device__ __forceinline__ void flash_attn_tile_iter_KQ( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int k_VKQ_0, + const int k_VKQ_sup, + const int k_KQ_0, + float * KQ_acc) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + flash_attn_tile_load_tile + (K_h2 + int64_t(k_VKQ_0)*stride_K2 + k_KQ_0/2, KV_tmp, stride_K2, k_VKQ_sup); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE + static_assert((nbatch_K/2) % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K/2; k_KQ_1 += cpy_ne) { + half2 K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + half2 Q_k[cpw][cpy_ne]; +#else + static_assert(nbatch_K % cpy_ne == 0, "bad nbatch_K"); +#pragma unroll + for (int k_KQ_1 = 0; k_KQ_1 < nbatch_K; k_KQ_1 += cpy_ne) { + float K_k[nbatch_fa/(np*warp_size)][cpy_ne]; + float Q_k[cpw][cpy_ne]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K/2 + cpy_ne) + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&K_k[i_KQ_0/(np*warp_size)], &KV_tmp[i_KQ*(nbatch_K + cpy_ne) + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + +#ifdef FAST_FP16_AVAILABLE + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc*(DKQ/2) + k_KQ_0/2 + k_KQ_1]); +#else + ggml_cuda_memcpy_1(&Q_k[jc0], &Q_tmp[jc* DKQ + k_KQ_0 + k_KQ_1]); +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { +#pragma unroll + for (int k = 0; k < cpy_ne; ++k) { + ggml_cuda_mad(KQ_acc[i_KQ_0/(np*warp_size)*cpw + jc0], K_k[i_KQ_0/(np*warp_size)][k], Q_k[jc0][k]); + } + } + } + } + + if (k_KQ_0 + nbatch_K < DKQ) { + __syncthreads(); // Sync not needed on last iteration. + } +} + +// Function that performs a single iteration of the main loop over up to nbatch_fa tokens. +template +static __device__ __forceinline__ void flash_attn_tile_iter( + T_vec_dot * const Q_tmp, + const half2 * const __restrict__ K_h2, + const half2 * const __restrict__ V_h2, + const half * const __restrict__ mask, + const float logit_softcap, + const float slope, + T_KQ * const KQ, + T_vec_dot * const KV_tmp, + const int stride_K2, + const int stride_V2, + const int stride_mask, + float * const KQ_max, + float * const KQ_sum, + T_acc * const VKQ, + const int k_VKQ_0, + const int k_VKQ_max) { + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int ncols = ncols1*ncols2; + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // number of parallel warps per Q column + + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // KQ_cs == KQ chunk size, number of KQ values in j direction to store as one contiguous chunk in memory. + // KQ is originally 2D but uses a Z-shaped 3D memory pattern like KQ[ncols/KQ_cs][DVp][KQ_cs]. +#ifdef FAST_FP16_AVAILABLE + constexpr int KQ_cs = cpw < 2*cpy_ne ? cpw : 2*cpy_ne; +#else + constexpr int KQ_cs = cpw < 1*cpy_ne ? cpw : 1*cpy_ne; +#endif // FAST_FP16_AVAILABLE + static_assert(cpw % KQ_cs == 0, "bad KQ_cs"); + const int k_VKQ_sup = k_VKQ_max - k_VKQ_0; // k supremum, only smaller k values have valid KV data + + float KQ_max_new[cpw]; +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_max_new[jc0] = KQ_max[jc0]; + } + + float KQ_acc[nbatch_fa/(np*warp_size) * cpw] = {0.0f}; // Accumulators for KQ matrix multiplication. + + // KQ = K @ Q matrix multiplication: + constexpr int nbatch_K_last = DKQ % nbatch_K; +#pragma unroll + for (int k_KQ_0 = 0; k_KQ_0 < DKQ - nbatch_K_last; k_KQ_0 += nbatch_K) { + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + if (nbatch_K_last > 0) { + constexpr int k_KQ_0 = DKQ - nbatch_K_last; + flash_attn_tile_iter_KQ( + Q_tmp, K_h2, KV_tmp, stride_K2, k_VKQ_0, k_VKQ_sup, k_KQ_0, KQ_acc); + } + + // Apply logit softcap + mask, update KQ_max: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int j = (jc0 + (threadIdx.y / np)*cpw)/ncols2; + +#pragma unroll + for (int i_KQ_0 = 0; i_KQ_0 < nbatch_fa; i_KQ_0 += np*warp_size) { + const int i_KQ = i_KQ_0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + if (use_logit_softcap) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] = logit_softcap * tanhf(KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + + if (!oob_check || i_KQ < k_VKQ_sup) { + KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0] += (ncols2 > 1 || mask) ? + slope*__half2float(mask[j*stride_mask + k_VKQ_0 + i_KQ]) : 0.0f; + + KQ_max_new[jc0] = fmaxf(KQ_max_new[jc0], KQ_acc[(i_KQ_0/(np*warp_size))*cpw + jc0]); + } + } + + KQ_max_new[jc0] = warp_reduce_max(KQ_max_new[jc0]); + } + + if constexpr (np == 1) { + __syncthreads(); + } else { + static_assert(cpw == 1, "bad cpw"); + __shared__ float KQ_max_new_shared[nwarps]; + if (threadIdx.x == 0) { + KQ_max_new_shared[threadIdx.y] = KQ_max_new[0]; + } + __syncthreads(); + KQ_max_new[0] = KQ_max_new_shared[(threadIdx.y & ~(np-1)) + threadIdx.x % np]; + KQ_max_new[0] = warp_reduce_max(KQ_max_new[0]); + } + + // Calculate KQ softmax, write to shared KQ buffer, re-scale VKQ accumulators: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; jc0 += KQ_cs) { +#ifdef FAST_FP16_AVAILABLE + half tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#else + float tmp[nbatch_fa/(np*warp_size)][KQ_cs]; +#endif // FAST_FP16_AVAILABLE + +#pragma unroll + for (int jc1 = 0; jc1 < KQ_cs; ++jc1) { + const int jc = jc0 + jc1; + + const float KQ_max_scale = expf(KQ_max[jc] - KQ_max_new[jc]); + KQ_max[jc] = KQ_max_new[jc]; + + float KQ_sum_add = 0.0f; +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const float val = !oob_check || i0 + (threadIdx.y % np)*warp_size + threadIdx.x < k_VKQ_sup ? + expf(KQ_acc[(i0/(np*warp_size))*cpw + jc] - KQ_max[jc]) : 0.0f; + KQ_sum_add += val; + tmp[i0/(np*warp_size)][jc1] = val; + } + KQ_sum[jc] = KQ_sum[jc]*KQ_max_scale + KQ_sum_add; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + +#pragma unroll + for (int i0 = 0; i0 < nbatch_fa; i0 += np*warp_size) { + const int i = i0 + (threadIdx.y % np)*warp_size + threadIdx.x; + + ggml_cuda_memcpy_1( + KQ + (jc0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs))*(nbatch_fa*KQ_cs) + i*KQ_cs, + tmp[i0/(np*warp_size)]); + } + } + + // VKQ = V @ KQ matrix multiplication: + static_assert(DV <= DKQ, "bad DV"); + static_assert(DV % nbatch_K == 0 || (nbatch_K % 3 == 0 && DV % (nbatch_K*2/3) == 0), "bad nbatch_K"); + constexpr int nbatch_V = (DV % nbatch_K == 0 ? nbatch_K : nbatch_K*2/3) * nbatch_fa / DV; // Number of V columns that fit in SRAM for K. + static_assert(nbatch_fa % nbatch_V == 0, "bad nbatch_V"); + static_assert(nbatch_V % np == 0, "bad nbatch_V"); +#pragma unroll + for (int k0 = 0; k0 < nbatch_fa; k0 += nbatch_V) { + flash_attn_tile_load_tile + (V_h2 + int64_t(k_VKQ_0 + k0)*stride_V2, KV_tmp, stride_V2, k_VKQ_sup - k0); + __syncthreads(); + +#ifdef FAST_FP16_AVAILABLE +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + half2 V_k[(DVp/2)/warp_size]; + half2 KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/warp_size], &KV_tmp[(k1 + threadIdx.y % np)*(DV/2) + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + half tmp[KQ_cs]; + ggml_cuda_memcpy_1( + &tmp, KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); +#pragma unroll + for (int jc_VKQ_1 = 0; jc_VKQ_1 < KQ_cs; ++jc_VKQ_1) { + KQ_k[jc_VKQ_0+jc_VKQ_1] = __half2half2(tmp[jc_VKQ_1]); + } + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size] += V_k[i0/warp_size]*KQ_k[jc_VKQ_0]; + } + } + } +#else +#pragma unroll + for (int k1 = 0; k1 < nbatch_V; k1 += np) { + float2 V_k[(DVp/2)/warp_size]; + float KQ_k[cpw]; + + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&V_k[i0/(2*warp_size)], &KV_tmp[(k1 + threadIdx.y % np)*DV + i0 + threadIdx.x*cpy_ne_D]); + } +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; jc_VKQ_0 += KQ_cs) { + const int jc_KQ = jc_VKQ_0/KQ_cs + (threadIdx.y / np)*(cpw/KQ_cs); + + ggml_cuda_memcpy_1( + &KQ_k[jc_VKQ_0], KQ + jc_KQ*(nbatch_fa*KQ_cs) + (k0 + k1 + threadIdx.y % np)*KQ_cs); + } + +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { +#pragma unroll + for (int jc_VKQ_0 = 0; jc_VKQ_0 < cpw; ++jc_VKQ_0) { + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].x += V_k[i0/warp_size].x*KQ_k[jc_VKQ_0]; + VKQ[jc_VKQ_0*((DVp/2)/warp_size) + i0/warp_size].y += V_k[i0/warp_size].y*KQ_k[jc_VKQ_0]; + } + } + } +#endif // FAST_FP16_AVAILABLE + + __syncthreads(); + } +} + +template // D == head size +__launch_bounds__(ggml_cuda_fattn_tile_get_nthreads(DKQ, DV, ncols1*ncols2), ggml_cuda_fattn_tile_get_occupancy(DKQ, DV, ncols1*ncols2)) +static __global__ void flash_attn_tile( + const char * __restrict__ Q, + const char * __restrict__ K, + const char * __restrict__ V, + const char * __restrict__ mask, + const char * __restrict__ sinks, + const int * __restrict__ KV_max, + float * __restrict__ dst, + float2 * __restrict__ dst_meta, + const float scale, + const float max_bias, + const float m0, + const float m1, + const uint32_t n_head_log2, + const float logit_softcap, + const int32_t ne00, const int32_t ne01, const int32_t ne02, const int32_t ne03, + const int32_t nb01, const int32_t nb02, const int32_t nb03, + const int32_t ne10, const int32_t ne11, const int32_t ne12, const int32_t ne13, + const int32_t nb11, const int32_t nb12, const int64_t nb13, + const int32_t nb21, const int32_t nb22, const int64_t nb23, + const int32_t ne31, const int32_t ne32, const int32_t ne33, + const int32_t nb31, const int32_t nb32, const int64_t nb33) { +#ifdef FLASH_ATTN_AVAILABLE + + // Skip unused kernel variants for faster compilation: + + if ( +#ifdef GGML_USE_WMMA_FATTN + (ncols2 != 1 && DV != 40 && DV != 512) || +#endif // GGML_USE_WMMA_FATTN + (use_logit_softcap && !(DV == 128 || DV == 256)) + ) { + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; + return; + } + + static_assert(ggml_cuda_fattn_tile_get_config(DKQ, DV, ncols1*ncols2) != 0, "kernel config not defined"); + + constexpr int ncols = ncols1*ncols2; + constexpr int warp_size = 32; + constexpr int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, ncols1*ncols2) / warp_size; + constexpr int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, ncols1*ncols2); + constexpr int nbatch_K = ggml_cuda_fattn_tile_get_nbatch_K (DKQ, DV, ncols1*ncols2); + + // In this kernel Q, K, V are matrices while i, j, k are matrix indices. + + const int col_Q_0 = blockIdx.x * ncols1; // Index of the first Q column for this CUDA block to work on. + + const int sequence = blockIdx.z / (ne02/ncols2); + const int head0 = blockIdx.z*ncols2 - sequence*ne02; // == blockIdx.z % (ne02/ncols2) + const int gqa_ratio = ne02 / ne12; // With grouped query attention there are > 1 Q matrices per K, V matrix. + const float * Q_f = (const float *) (Q + nb03*sequence + nb02* head0 + nb01*col_Q_0); + const half2 * K_h2 = (const half2 *) (K + nb13*sequence + nb12*(head0 / gqa_ratio)); + const half2 * V_h2 = (const half2 *) (V + nb23*sequence + nb22*(head0 / gqa_ratio)); // K and V have same shape + + const half * maskh = mask ? (const half *) (mask + nb33*(sequence % ne33) + nb31*col_Q_0) : nullptr; + + const int stride_K2 = nb11 / sizeof(half2); + const int stride_V2 = nb21 / sizeof(half2); + const int stride_mask = nb31 / sizeof(half); + + const float slope = ncols2 == 1 ? get_alibi_slope(max_bias, head0, n_head_log2, m0, m1) : 1.0f; + + constexpr int cpy_nb = ggml_cuda_get_max_cpy_bytes(); + constexpr int cpy_ne = cpy_nb / 4; + + constexpr int cpw = ncols > nwarps ? ncols/nwarps : 1; // Q columns per warp. + constexpr int np = nwarps > ncols ? nwarps/ncols : 1; // Number of parallel warps per Q column. + static_assert(cpw == 1 || np == 1, "bad cpw / np"); + static_assert(nbatch_fa % (np*warp_size) == 0, "nbatch_fa % (np*warp_size) != 0"); + + constexpr int DKQp = (DKQ + 2*warp_size - 1) & ~(2*warp_size - 1); // DKQ padded to multiple of 2*warp_size. + constexpr int DVp = (DV + 2*warp_size - 1) & ~(2*warp_size - 1); // DV padded to multiple of 2*warp_size. + + // Q_tmp == SRAM buffer to hold Q data for the entire lifetime of the kernel. + // KV_tmp == SRAM buffer to hold fragments of K/V data while iterating over ne11. + // KV_tmp is padded to avoid memory conflicts for K (cpy_ne) and OOB accesses for V (DVp-DV). + // KQ == SRAM buffer to hold KQ fragments between KQ and VKQ matrix multiplications. + // VKQ == Accumulators in registers for the final VKQ result. +#ifdef FAST_FP16_AVAILABLE + __shared__ half2 Q_tmp[ncols * DKQ/2]; + __shared__ half2 KV_tmp[nbatch_fa * (nbatch_K/2 + cpy_ne) + DVp-DV]; + __shared__ half KQ[ncols * nbatch_fa]; + half2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#else + __shared__ float Q_tmp[ncols * DKQ]; + __shared__ float KV_tmp[nbatch_fa * (nbatch_K + cpy_ne) + DVp-DV]; + __shared__ float KQ[ncols * nbatch_fa]; + float2 VKQ[cpw * ((DVp/2)/warp_size)] = {{0.0f, 0.0f}}; +#endif // FAST_FP16_AVAILABLE + + float KQ_max[cpw]; +#pragma unroll + for (int j0 = 0; j0 < ncols; j0 += nwarps) { + KQ_max[j0/nwarps] = -FLT_MAX/2.0f; + } + float KQ_sum[cpw] = {0.0f}; + + // Load Q data, convert to FP16 if fast: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y / np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + constexpr int cpy_ne_D = cpy_ne < DKQp/warp_size ? cpy_ne : DKQp/warp_size; + +#pragma unroll + for (int i0 = 0; i0 < DKQp; i0 += np*warp_size*cpy_ne_D) { + if (i0 + np*warp_size*cpy_ne_D <= DKQ || i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D < DKQ) { + float tmp_f[cpy_ne_D] = {0.0f}; + if (ncols1 == 1 || col_Q_0 + j < ne01) { + ggml_cuda_memcpy_1 + (tmp_f, &Q_f[c*(nb02/sizeof(float)) + j*(nb01/sizeof(float)) + + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x*cpy_ne_D]); + } + +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp_f[i1] *= scale; + } + +#ifdef FAST_FP16_AVAILABLE + half2 tmp_h2[cpy_ne_D/2]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; i1 += 2) { + tmp_h2[i1/2] = make_half2(tmp_f[i1 + 0], tmp_f[i1 + 1]); + } + ggml_cuda_memcpy_1( + &Q_tmp[jc*(DKQ/2) + i0/2 + (threadIdx.y % np)*(warp_size*cpy_ne_D/2) + threadIdx.x*(cpy_ne_D/2)], + tmp_h2); +#else + ggml_cuda_memcpy_1( + &Q_tmp[jc* DKQ + i0 + (threadIdx.y % np)*(warp_size*cpy_ne_D) + threadIdx.x* cpy_ne_D], + tmp_f); +#endif // FAST_FP16_AVAILABLE + } + } + } + + __syncthreads(); + + // Main loop over KV cache: + const int k_VKQ_max = KV_max ? KV_max[sequence*gridDim.x + blockIdx.x] : ne11; + if (ncols2 == 1) { + // Branch with out-of-bounds checks. + int k_VKQ_0 = blockIdx.y*nbatch_fa; + while (k_VKQ_0 < k_VKQ_max - nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + k_VKQ_0 += gridDim.y*nbatch_fa; + } + if (k_VKQ_0 < k_VKQ_max) { + constexpr bool oob_check = true; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } else { + // Branch without out-of-bounds checks. + for (int k_VKQ_0 = blockIdx.y*nbatch_fa; k_VKQ_0 < k_VKQ_max; k_VKQ_0 += gridDim.y*nbatch_fa) { + constexpr bool oob_check = false; + flash_attn_tile_iter + (Q_tmp, K_h2, V_h2, maskh, logit_softcap, slope, KQ, KV_tmp, + stride_K2, stride_V2, stride_mask, KQ_max, KQ_sum, VKQ, k_VKQ_0, k_VKQ_max); + } + } + +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + KQ_sum[jc0] = warp_reduce_sum(KQ_sum[jc0]); + } + + if constexpr (np > 1) { + static_assert(cpw == 1, "bad cpw"); + static_assert(nbatch_fa*nbatch_K >= nwarps*DVp, "KV_tmp too small"); + +#ifdef FAST_FP16_AVAILABLE + half2 * VKQ_combine = (half2 *) KV_tmp; +#else + float * VKQ_combine = (float *) KV_tmp; +#endif // FAST_FP16_AVAILABLE + float * KQ_sum_combine = (float *) Q_tmp; + + if (threadIdx.y % np != 0) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1(&VKQ_combine[threadIdx.y*(DVp/2) + i0 + threadIdx.x*cpy_ne_D], &VKQ[i0/warp_size]); + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + ggml_cuda_memcpy_1( + &VKQ_combine[threadIdx.y*DVp + i0 + threadIdx.x*cpy_ne_D], ((const float *) VKQ) + i0/warp_size); + } +#endif // FAST_FP16_AVAILABLE + + if (threadIdx.x == 0) { + KQ_sum_combine[threadIdx.y] = KQ_sum[0]; + } + + return; + } + + __syncthreads(); + +#pragma unroll + for (int ip = 1; ip < np; ++ip) { +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne < (DVp/2)/warp_size ? cpy_ne : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + half2 tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*(DVp/2) + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + VKQ[i0/warp_size + i1] += tmp[i1]; + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + float tmp[cpy_ne_D]; + ggml_cuda_memcpy_1(tmp, &VKQ_combine[(threadIdx.y + ip)*DVp + i0 + threadIdx.x*cpy_ne_D]); +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + ((float *)VKQ)[i0/warp_size + i1] += tmp[i1]; + } + } +#endif // FAST_FP16_AVAILABLE + + KQ_sum[0] += KQ_sum_combine[threadIdx.y + ip]; + } + } + + // Attention sink: adjust KQ max and sum only for the first of all parallel blocks: + if (sinks && blockIdx.y == 0) { +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + const float sink = ((const float *) sinks)[head0 + jc % ncols2]; + + float KQ_max_new_j = fmaxf(KQ_max[jc0], sink); + const float KQ_max_scale = expf(KQ_max[jc0] - KQ_max_new_j); + KQ_max[jc0] = KQ_max_new_j; + + const float val = expf(sink - KQ_max[jc0]); + KQ_sum[jc0] = KQ_sum[jc0]*KQ_max_scale + val; + +#ifdef FAST_FP16_AVAILABLE + const half2 KQ_max_scale_h2 = make_half2(KQ_max_scale, KQ_max_scale); +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size] *= KQ_max_scale_h2; + } +#else +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size) { + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].x *= KQ_max_scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size].y *= KQ_max_scale; + } +#endif // FAST_FP16_AVAILABLE + } + } + + // Write back results: +#pragma unroll + for (int jc0 = 0; jc0 < cpw; ++jc0) { + const int jc = jc0 + (threadIdx.y/np)*cpw; + + const int j = jc / ncols2; + const int c = jc % ncols2; + + if (ncols1 > 1 && col_Q_0 + j >= ne01) { + return; + } + + const float scale = gridDim.y == 1 ? 1.0f/KQ_sum[jc0] : 1.0f; + + const int j_dst_unrolled = ((sequence*ne01 + col_Q_0 + j)*ne02 + head0 + c)*gridDim.y + blockIdx.y; + +#ifdef FAST_FP16_AVAILABLE + constexpr int cpy_ne_D = cpy_ne/2 < (DVp/2)/warp_size ? cpy_ne/2 : (DVp/2)/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp/2; i0 += warp_size*cpy_ne_D) { + float2 tmp[cpy_ne_D]; +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D; ++i1) { + tmp[i1] = __half22float2(VKQ[jc0*((DVp/2)/warp_size) + i0/warp_size + i1]); + tmp[i1].x *= scale; + tmp[i1].y *= scale; + } + if (i0 + warp_size*cpy_ne_D <= DV/2 || i0 + threadIdx.x*cpy_ne_D < DV/2) { + ggml_cuda_memcpy_1(&dst[j_dst_unrolled*DV + 2*i0 + threadIdx.x*(2*cpy_ne_D)], tmp); + } + } +#else + constexpr int cpy_ne_D = cpy_ne < DVp/warp_size ? cpy_ne : DVp/warp_size; +#pragma unroll + for (int i0 = 0; i0 < DVp; i0 += warp_size*cpy_ne_D) { + if (i0 + warp_size*cpy_ne_D <= DV || i0 + threadIdx.x*cpy_ne_D < DV) { +#pragma unroll + for (int i1 = 0; i1 < cpy_ne_D/2; ++i1) { + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].x *= scale; + VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size) + i1].y *= scale; + } + ggml_cuda_memcpy_1( + &dst[j_dst_unrolled*DV + i0 + threadIdx.x*cpy_ne_D], + &VKQ[jc0*((DVp/2)/warp_size) + i0/(2*warp_size)]); + } + } +#endif // FAST_FP16_AVAILABLE + + if (gridDim.y != 1 && threadIdx.x == 0) { + dst_meta[j_dst_unrolled] = make_float2(KQ_max[jc0], KQ_sum[jc0]); + } + } +#else + GGML_UNUSED_VARS(Q, K, V, mask, sinks, KV_max, dst, dst_meta, scale, + max_bias, m0, m1, n_head_log2, logit_softcap, + ne00, ne01, ne02, ne03, + nb01, nb02, nb03, + ne10, ne11, ne12, ne13, + nb11, nb12, nb13, + nb21, nb22, nb23, + ne31, ne32, ne33, + nb31, nb32, nb33); + NO_DEVICE_CODE; +#endif // FLASH_ATTN_AVAILABLE +} + +template +static void launch_fattn_tile_switch_ncols1(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * Q = dst->src[0]; + + const int id = ggml_cuda_get_device(); + const int cc = ggml_cuda_info().devices[id].cc; + const int warp_size = 32; + + constexpr size_t nbytes_shared = 0; + +#ifdef GGML_USE_HIP + if constexpr (DV <= 128) { + if (Q->ne[1] > 32/ncols2) { + constexpr int cols_per_block = 64; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } +#endif // GGML_USE_HIP + +#ifndef GGML_USE_HIP + if constexpr (DV <= 256) +#endif // GGML_USE_HIP + { + if (Q->ne[1] > 16/ncols2) { + constexpr int cols_per_block = 32; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if (Q->ne[1] > 8/ncols2) { + constexpr int cols_per_block = 16; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + if constexpr (ncols2 <= 8) { + if (Q->ne[1] > 4/ncols2) { + constexpr int cols_per_block = 8; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 4) { + if (Q->ne[1] > 2/ncols2) { + constexpr int cols_per_block = 4; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + } + + if constexpr (ncols2 <= 2) { + constexpr int cols_per_block = 2; + const int nwarps = ggml_cuda_fattn_tile_get_nthreads (DKQ, DV, cols_per_block, cc) / warp_size; + const int nbatch_fa = ggml_cuda_fattn_tile_get_nbatch_fa(DKQ, DV, cols_per_block, cc); + fattn_kernel_t fattn_kernel = flash_attn_tile; + launch_fattn + (ctx, dst, fattn_kernel, nwarps, nbytes_shared, nbatch_fa, true, true, false, warp_size); + return; + } + + GGML_ABORT("fatal error"); +} + +template +static void launch_fattn_tile_switch_ncols2(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + const ggml_tensor * Q = dst->src[0]; + const ggml_tensor * K = dst->src[1]; + const ggml_tensor * mask = dst->src[3]; + + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + const int gqa_ratio = Q->ne[2] / K->ne[2]; + + const bool nvidia = GGML_CUDA_CC_IS_NVIDIA(ggml_cuda_info().devices[ggml_cuda_get_device()].cc); + const int gqa_limit = nvidia && gqa_ratio <= 4 ? 16 : INT_MAX; + const bool use_gqa_opt = mask && max_bias == 0.0f && Q->ne[1] <= gqa_limit && K->ne[1] % FATTN_KQ_STRIDE == 0; + + if constexpr (DV == 512) { + if (use_gqa_opt && gqa_ratio % 16 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + } + + if constexpr (DV <= 256) { + if (use_gqa_opt && gqa_ratio % 8 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 4 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + if (use_gqa_opt && gqa_ratio % 2 == 0) { + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + + launch_fattn_tile_switch_ncols1(ctx, dst); + return; + } + GGML_ABORT("fatal error"); +} + +template +void ggml_cuda_flash_attn_ext_tile_case(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * KQV = dst; + + float logit_softcap; + memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); + + if (logit_softcap == 0.0f) { + constexpr bool use_logit_softcap = false; + launch_fattn_tile_switch_ncols2(ctx, dst); + } else { + constexpr bool use_logit_softcap = true; + launch_fattn_tile_switch_ncols2(ctx, dst); + } +} void ggml_cuda_flash_attn_ext_tile(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +#define DECL_FATTN_TILE_CASE(DKQ, DV) \ + template void ggml_cuda_flash_attn_ext_tile_case \ + (ggml_backend_cuda_context & ctx, ggml_tensor * dst) \ + +extern DECL_FATTN_TILE_CASE( 40, 40); +extern DECL_FATTN_TILE_CASE( 64, 64); +extern DECL_FATTN_TILE_CASE( 80, 80); +extern DECL_FATTN_TILE_CASE( 96, 96); +extern DECL_FATTN_TILE_CASE(112, 112); +extern DECL_FATTN_TILE_CASE(128, 128); +extern DECL_FATTN_TILE_CASE(256, 256); +extern DECL_FATTN_TILE_CASE(576, 512); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh index 59c62553b0..89ab0f1638 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-vec.cuh @@ -535,8 +535,6 @@ void ggml_cuda_flash_attn_ext_vec_case(ggml_backend_cuda_context & ctx, ggml_ten float logit_softcap; memcpy(&logit_softcap, (const float *) KQV->op_params + 2, sizeof(float)); - const int cc = ggml_cuda_info().devices[ggml_cuda_get_device()].cc; - if (Q->ne[1] == 1) { constexpr int cols_per_block = 1; if (logit_softcap == 0.0f) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu index 2219191fd9..6c90d6d52b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cu @@ -6,19 +6,19 @@ #include "fattn-common.cuh" #include "fattn-wmma-f16.cuh" -#ifdef FP16_MMA_AVAILABLE +#ifdef GGML_USE_WMMA_FATTN #if !defined(GGML_USE_HIP) #include -#ifdef GGML_USE_MUSA +#if defined(GGML_USE_MUSA) namespace wmma = mtmusa::wmma; #else // GGML_USE_MUSA namespace wmma = nvcuda::wmma; #endif // GGML_USE_MUSA -#elif defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE) +#elif defined(GGML_USE_HIP) #include namespace wmma = rocwmma; #endif // !defined(GGML_USE_HIP) -#endif // FP16_MMA_AVAILABLE +#endif // GGML_USE_WMMA_FATTN // D == head size, VKQ_stride == num VKQ rows calculated in parallel: template @@ -45,7 +45,7 @@ static __global__ void flash_attn_ext_f16( const int32_t nb21, const int32_t nb22, const int64_t nb23, const int32_t ne31, const int32_t ne32, const int32_t ne33, const int32_t nb31, const int32_t nb32, const int64_t nb33) { -#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#if defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) // Skip unused kernel variants for faster compilation: if (use_logit_softcap && !(D == 128 || D == 256)) { NO_DEVICE_CODE; @@ -481,7 +481,7 @@ static __global__ void flash_attn_ext_f16( ne31, ne32, ne33, nb31, nb32, nb33); NO_DEVICE_CODE; -#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(FP16_MMA_AVAILABLE))) +#endif // defined(FLASH_ATTN_AVAILABLE) && (__CUDA_ARCH__ == GGML_CUDA_CC_VOLTA || (defined(GGML_HIP_ROCWMMA_FATTN) && defined(GGML_USE_WMMA_FATTN))) } constexpr int get_max_power_of_2(int x) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh index beeea95eb1..7235f1b77a 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn-wmma-f16.cuh @@ -1,3 +1,51 @@ +#pragma once + #include "common.cuh" +#if (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) +#define GGML_USE_WMMA_FATTN +#endif // (!defined(GGML_USE_HIP) && __CUDA_ARCH__ >= GGML_CUDA_CC_VOLTA) || defined(GGML_USE_MUSA) + +#if defined(GGML_HIP_ROCWMMA_FATTN) +#if defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#define GGML_USE_WMMA_FATTN +#elif defined(CDNA) +#warning "rocwmma fattn on CDNA is broken on rocwmma v2.0.0, expect degraded performance" +#endif // defined(CDNA) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) +#if defined(RDNA3) +#define GGML_USE_WMMA_FATTN +#endif // defined(RDNA3) +#if defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#define GGML_USE_WMMA_FATTN +#elif defined(RDNA4) +#warning "rocwmma fattn is not suported on RDNA4 on rocwmma < v2.0.0, expect degraded performance" +#endif // defined(RDNA4) && ROCWMMA_VERSION_MAJOR > 1 +#endif // defined(GGML_HIP_ROCWMMA_FATTN) + +// WMMA flash attention requires FP16 matrix instructions to be available for ggml code. +static bool ggml_cuda_should_use_wmma_fattn(const int cc) { +#if defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) + return false; +#else + if ((GGML_CUDA_CC_IS_NVIDIA(cc) && ggml_cuda_highest_compiled_arch(cc) == GGML_CUDA_CC_VOLTA) || + GGML_CUDA_CC_IS_RDNA3(cc) || GGML_CUDA_CC_IS_MTHREADS(cc)) { + return true; + } else if (GGML_CUDA_CC_IS_CDNA(cc)){ +#if defined(GGML_HIP_ROCWMMA_FATTN) && (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) (ROCWMMA_VERSION_MAJOR < 2 || ROCWMMA_VERSION_MINOR > 0 || ROCWMMA_VERSION_PATCH > 0) + } else if (GGML_CUDA_CC_IS_RDNA4(cc)) { +#if defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + return true; +#else + return false; +#endif // defined(GGML_HIP_ROCWMMA_FATTN) && ROCWMMA_VERSION_MAJOR > 1 + } else { + return false; + } +#endif // defined(GGML_USE_HIP) && !defined(GGML_HIP_ROCWMMA_FATTN) +} + void ggml_cuda_flash_attn_ext_wmma_f16(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu index 1cbd4f5bd6..fe970adaec 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/fattn.cu @@ -198,6 +198,7 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; #endif// FLASH_ATTN_AVAILABLE + const ggml_tensor * KQV = dst; const ggml_tensor * Q = dst->src[0]; const ggml_tensor * K = dst->src[1]; const ggml_tensor * V = dst->src[2]; @@ -206,31 +207,32 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const const int gqa_ratio = Q->ne[2] / K->ne[2]; GGML_ASSERT(Q->ne[2] % K->ne[2] == 0); + float max_bias = 0.0f; + memcpy(&max_bias, (const float *) KQV->op_params + 1, sizeof(float)); + + // The effective batch size for the kernel can be increased by gqa_ratio. + // The kernel versions without this optimization are also used for ALiBi, if there is no mask, or if the KV cache is not padded, + const bool gqa_opt_applies = gqa_ratio % 2 == 0 && mask && max_bias == 0.0f && K->ne[1] % FATTN_KQ_STRIDE == 0; + const int cc = ggml_cuda_info().devices[device].cc; switch (K->ne[0]) { + case 40: case 64: - case 128: - case 256: - if (V->ne[0] != K->ne[0]) { - return BEST_FATTN_KERNEL_NONE; - } - break; case 80: case 96: + case 128: case 112: + case 256: if (V->ne[0] != K->ne[0]) { return BEST_FATTN_KERNEL_NONE; } - if (!fp16_mma_available(cc) && !turing_mma_available(cc)) { - return BEST_FATTN_KERNEL_NONE; - } break; case 576: if (V->ne[0] != 512) { return BEST_FATTN_KERNEL_NONE; } - if (!turing_mma_available(cc) || gqa_ratio % 16 != 0) { + if (!gqa_opt_applies || gqa_ratio % 16 != 0) { return BEST_FATTN_KERNEL_NONE; } break; @@ -264,47 +266,57 @@ static best_fattn_kernel ggml_cuda_get_best_fattn_kernel(const int device, const return BEST_FATTN_KERNEL_NONE; } - const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0; - - // If Turing tensor cores available, use them except for some cases with batch size 1: - if (turing_mma_available(cc)) { - best_fattn_kernel best = BEST_FATTN_KERNEL_MMA_F16; + // For small batch sizes the vector kernel may be preferable over the kernels optimized for large batch sizes: + const bool can_use_vector_kernel = Q->ne[0] <= 256 && Q->ne[0] % 64 == 0 && K->ne[1] % FATTN_KQ_STRIDE == 0; + // If Turing tensor cores available, use them: + if (turing_mma_available(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40) { if (can_use_vector_kernel) { if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { if (cc >= GGML_CUDA_CC_ADA_LOVELACE && Q->ne[1] == 1 && Q->ne[3] == 1 && !(gqa_ratio > 4 && K->ne[1] >= 8192)) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (cc >= GGML_CUDA_CC_ADA_LOVELACE) { if (Q->ne[1] <= 2) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } else { if (Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; + return BEST_FATTN_KERNEL_VEC; } } } - if ((gqa_ratio % 2 != 0 || !mask) && Q->ne[1] == 1) { - best = BEST_FATTN_KERNEL_VEC; // GQA-specific optimizations in the mma kernel do not apply. + if (!gqa_opt_applies && Q->ne[1] == 1) { + return BEST_FATTN_KERNEL_VEC; } } - return best; + return BEST_FATTN_KERNEL_MMA_F16; } - // Use kernels specialized for small batch sizes if possible: - if (Q->ne[1] <= 8 && can_use_vector_kernel) { - return BEST_FATTN_KERNEL_VEC; - } - - // For large batch sizes, use the WMMA kernel if possible: - if (fp16_mma_available(cc)) { + // Use the WMMA kernel if possible: + if (ggml_cuda_should_use_wmma_fattn(cc) && K->ne[1] % FATTN_KQ_STRIDE == 0 && Q->ne[0] != 40 && Q->ne[0] != 576) { + if (can_use_vector_kernel && Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } return BEST_FATTN_KERNEL_WMMA_F16; } - // If there is no suitable kernel for tensor cores or small batch sizes, use the generic kernel for large batch sizes: + // If there are no tensor cores available, use the generic tile kernel: + if (can_use_vector_kernel) { + if (K->type == GGML_TYPE_F16 && V->type == GGML_TYPE_F16) { + if (Q->ne[1] == 1) { + if (!gqa_opt_applies) { + return BEST_FATTN_KERNEL_VEC; + } + } + } else { + if (Q->ne[1] <= 2) { + return BEST_FATTN_KERNEL_VEC; + } + } + } return BEST_FATTN_KERNEL_TILE; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu index 3fa3a0575a..6a278b5e9b 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/ggml-cuda.cu @@ -291,7 +291,7 @@ static ggml_cuda_device_info ggml_cuda_init() { info.default_tensor_split[id] = total_vram; total_vram += prop.totalGlobalMem; - info.devices[id].integrated = prop.integrated; + info.devices[id].integrated = false; // Temporarily disabled due to issues with corrupted output (e.g. #15034) info.devices[id].nsm = prop.multiProcessorCount; info.devices[id].smpb = prop.sharedMemPerBlock; info.devices[id].warp_size = prop.warpSize; @@ -2466,6 +2466,9 @@ static bool ggml_cuda_compute_forward(ggml_backend_cuda_context & ctx, struct gg case GGML_UNARY_OP_ELU: ggml_cuda_op_elu(ctx, dst); break; + case GGML_UNARY_OP_XIELU: + ggml_cuda_op_xielu(ctx, dst); + break; default: return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu new file mode 100644 index 0000000000..a8b15ad72a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq112-dv112.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(112, 112); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu new file mode 100644 index 0000000000..1da1810550 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq128-dv128.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(128, 128); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu new file mode 100644 index 0000000000..bc65c723ec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq256-dv256.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(256, 256); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu new file mode 100644 index 0000000000..10b330fa6c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq40-dv40.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(40, 40); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu new file mode 100644 index 0000000000..254b7d2e1d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq576-dv512.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(576, 512); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu new file mode 100644 index 0000000000..5caffac046 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq64-dv64.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(64, 64); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu new file mode 100644 index 0000000000..90abb3b186 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq80-dv80.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(80, 80); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu new file mode 100644 index 0000000000..7292c0aab8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-cuda/template-instances/fattn-tile-instance-dkq96-dv96.cu @@ -0,0 +1,5 @@ +// This file has been autogenerated by generate_cu_files.py, do not edit manually. + +#include "../fattn-tile.cuh" + +DECL_FATTN_TILE_CASE(96, 96); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu index 039f284719..afe4aee240 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/topk-moe.cu @@ -13,7 +13,7 @@ It is intended as fusion of softmax->top-k->get_rows pipeline for MoE models */ -template +template __launch_bounds__(4 * WARP_SIZE, 1) __global__ void topk_moe_cuda(const float * logits, float * weights, int32_t * ids, @@ -204,8 +204,6 @@ void ggml_cuda_op_topk_moe(ggml_backend_cuda_context & ctx, GGML_ASSERT(ids->nb[1] / ggml_type_size(ids->type) == (size_t) n_experts); - cudaStream_t stream = ctx.stream(); - const int n_expert_used = weights->ne[1]; if (with_norm) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu index 5aff8a876a..3c564566a5 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu +++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cu @@ -1,4 +1,5 @@ #include "unary.cuh" +#include "convert.cuh" static __device__ __forceinline__ float op_abs(float x) { return fabsf(x); @@ -375,6 +376,59 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) swiglu_oai_cuda(src0_p, src1_p, (float *)dst_d, ggml_nelements(dst), nc, src0_o / sizeof(float), src1_o / sizeof(float), alpha, limit, stream); } +/* CUDA kernel + launcher for xIELU */ + +template +static __global__ void xielu_kernel(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps) { + const int i = blockDim.x*blockIdx.x + threadIdx.x; + + if (i >= k) { + return; + } + + const float xi = ggml_cuda_cast(x[i]); + + const float gate_pos = (xi > 0.0f); + const float y_pos = alpha_p * xi * xi + beta * xi; + const float min_v_eps = fminf(xi, eps); + const float y_neg = (expm1f(min_v_eps) - xi) * alpha_n + beta * xi; + const float out = gate_pos * y_pos + (1.0f - gate_pos) * y_neg; + + dst[i] = ggml_cuda_cast(out); +} + +template +static void xielu_cuda(const T * x, T * dst, const int k, float alpha_n, float alpha_p, float beta, float eps, cudaStream_t stream) { + const int num_blocks = (k + CUDA_XIELU_BLOCK_SIZE) / CUDA_XIELU_BLOCK_SIZE; + xielu_kernel<<>>(x, dst, k, alpha_n, alpha_p, beta, eps); +} + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst) { + const ggml_tensor * src0 = dst->src[0]; + const void * src0_d = src0->data; + void * dst_d = dst->data; + cudaStream_t stream = ctx.stream(); + + GGML_ASSERT(ggml_is_contiguous(src0)); + + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT( dst->type == GGML_TYPE_F32 || dst->type == GGML_TYPE_F16); + GGML_ASSERT(src0->type == dst->type); + + const float alpha_n = ggml_get_op_params_f32(dst, 1); + const float alpha_p = ggml_get_op_params_f32(dst, 2); + const float beta = ggml_get_op_params_f32(dst, 3); + const float eps = ggml_get_op_params_f32(dst, 4); + + if (src0->type == GGML_TYPE_F16) { + xielu_cuda((const half *)src0_d, (half *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } else { + xielu_cuda((const float *)src0_d, (float *)dst_d, ggml_nelements(src0), alpha_n, alpha_p, beta, eps, stream); + } +} + + + /* silu_back */ static __device__ __forceinline__ float op_silu_back(float grad, float x) { diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh index da3caf1d89..8e7644fcd9 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh +++ b/ml/backend/ggml/ggml/src/ggml-cuda/unary.cuh @@ -16,6 +16,7 @@ #define CUDA_SIN_BLOCK_SIZE 256 #define CUDA_COS_BLOCK_SIZE 256 #define CUDA_GLU_BLOCK_SIZE 256 +#define CUDA_XIELU_BLOCK_SIZE 256 void ggml_cuda_op_abs(ggml_backend_cuda_context & ctx, ggml_tensor * dst); @@ -72,3 +73,5 @@ void ggml_cuda_op_swiglu_oai(ggml_backend_cuda_context & ctx, ggml_tensor * dst) void ggml_cuda_op_geglu_erf(ggml_backend_cuda_context & ctx, ggml_tensor * dst); void ggml_cuda_op_geglu_quick(ggml_backend_cuda_context & ctx, ggml_tensor * dst); + +void ggml_cuda_op_xielu(ggml_backend_cuda_context & ctx, ggml_tensor * dst); diff --git a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h index eb8f66cb07..2f9ef2dc00 100644 --- a/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h +++ b/ml/backend/ggml/ggml/src/ggml-cuda/vendors/hip.h @@ -8,6 +8,9 @@ // for rocblas_initialize() #include "rocblas/rocblas.h" +#if defined(GGML_HIP_ROCWMMA_FATTN) +#include +#endif // defined(GGML_HIP_ROCWMMA_FATTN) #define CUBLAS_GEMM_DEFAULT HIPBLAS_GEMM_DEFAULT #define CUBLAS_GEMM_DEFAULT_TENSOR_OP HIPBLAS_GEMM_DEFAULT diff --git a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt index d327b90cce..934aefdcb4 100644 --- a/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt +++ b/ml/backend/ggml/ggml/src/ggml-hip/CMakeLists.txt @@ -39,12 +39,6 @@ endif() find_package(hip REQUIRED) find_package(hipblas REQUIRED) find_package(rocblas REQUIRED) -if (GGML_HIP_ROCWMMA_FATTN) - CHECK_INCLUDE_FILE_CXX("rocwmma/rocwmma.hpp" FOUND_ROCWMMA) - if (NOT ${FOUND_ROCWMMA}) - message(FATAL_ERROR "rocwmma has not been found") - endif() -endif() if (${hip_VERSION} VERSION_LESS 6.1) message(FATAL_ERROR "At least ROCM/HIP V6.1 is required") @@ -59,6 +53,8 @@ file(GLOB GGML_HEADERS_ROCM "../ggml-cuda/*.cuh") list(APPEND GGML_HEADERS_ROCM "../../include/ggml-cuda.h") file(GLOB GGML_SOURCES_ROCM "../ggml-cuda/*.cu") +file(GLOB SRCS "../ggml-cuda/template-instances/fattn-tile*.cu") +list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/fattn-mma*.cu") list(APPEND GGML_SOURCES_ROCM ${SRCS}) file(GLOB SRCS "../ggml-cuda/template-instances/mmq*.cu") @@ -117,10 +113,6 @@ if (NOT GGML_HIP_MMQ_MFMA) add_compile_definitions(GGML_HIP_NO_MMQ_MFMA) endif() -if (GGML_HIP_FORCE_ROCWMMA_FATTN_GFX12 OR ${hip_VERSION} VERSION_GREATER_EQUAL 7.0) - add_compile_definitions(GGML_HIP_ROCWMMA_FATTN_GFX12) -endif() - if (GGML_HIP_EXPORT_METRICS) set(CMAKE_HIP_FLAGS "${CMAKE_HIP_FLAGS} -Rpass-analysis=kernel-resource-usage --save-temps") endif() diff --git a/ml/backend/ggml/ggml/src/ggml-impl.h b/ml/backend/ggml/ggml/src/ggml-impl.h index 9fc9fbfcfd..80597b6ea8 100644 --- a/ml/backend/ggml/ggml/src/ggml-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-impl.h @@ -102,6 +102,9 @@ static bool ggml_op_is_empty(enum ggml_op op) { } } +static inline float ggml_softplus(float input) { + return (input > 20.0f) ? input : logf(1 + expf(input)); +} // // logging // diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp index dc7d241c3a..95627d3866 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-common.cpp @@ -112,7 +112,7 @@ static bool ggml_mem_ranges_add_dst(ggml_mem_ranges_t mrs, const ggml_tensor * t } bool ggml_mem_ranges_add(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { - for (int i = 0; i < GGML_MAX_DIMS; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i]) { ggml_mem_ranges_add_src(mrs, tensor->src[i]); } @@ -173,7 +173,7 @@ static bool ggml_mem_ranges_check_dst(ggml_mem_ranges_t mrs, const ggml_tensor * } bool ggml_mem_ranges_check(ggml_mem_ranges_t mrs, const ggml_tensor * tensor) { - for (int i = 0; i < GGML_MAX_DIMS; i++) { + for (int i = 0; i < GGML_MAX_SRC; i++) { if (tensor->src[i]) { if (!ggml_mem_ranges_check_src(mrs, tensor->src[i])) { return false; diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp index 819f31c8a3..866cd2da58 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.cpp @@ -268,6 +268,25 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu(ggml_metal_library_t l return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_SUM); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_op_sum_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows(ggml_metal_library_t lib, const ggml_tensor * op) { GGML_ASSERT(op->src[0]->nb[0] == ggml_type_size(op->src[0]->type)); @@ -338,7 +357,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar char base[256]; char name[256]; - snprintf(base, 256, "kernel_ssm_conv_%s_%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type)); + const char * suffix = ""; + + if (op->src[1]->ne[0] % 4 == 0) { + suffix = "_4"; + } + + snprintf(base, 256, "kernel_ssm_conv_%s_%s%s", ggml_type_name(op->src[0]->type), ggml_type_name(op->src[1]->type), suffix); snprintf(name, 256, "%s", base); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); @@ -352,15 +377,15 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv(ggml_metal_librar } ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_library_t lib, const ggml_tensor * op) { + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + char base[256]; char name[256]; - if (op->src[3]->ne[0] == 1) { - snprintf(base, 256, "kernel_ssm_scan_group_%s", ggml_type_name(op->src[0]->type)); - } else { - snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); - } - snprintf(name, 256, "%s", base); + const int nsg = (ne00 + 31)/32; + + snprintf(base, 256, "kernel_ssm_scan_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s_nsg=%d", base, nsg); ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); if (res) { @@ -369,7 +394,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_scan(ggml_metal_librar res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); - ggml_metal_pipeline_set_smem(res, 32*sizeof(float)); + ggml_metal_pipeline_set_smem(res, 32*sizeof(float)*nsg); return res; } @@ -918,6 +943,96 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_argsort(ggml_metal_library return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_pad"); + + snprintf(name, 256, "%s_mask=%d_ncpsg=%d", + base, + has_mask, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_PAD + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_PAD + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_PAD + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_PAD + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_PAD + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_PAD + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_PAD + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_PAD + 23); + //ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_PAD + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_PAD + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + GGML_UNUSED(op); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_%s", + "flash_attn_ext_blk"); + + snprintf(name, 256, "%s_nqptg=%d_ncpsg=%d", + base, + nqptg, + ncpsg); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + ggml_metal_cv_t cv = ggml_metal_cv_init(); + + //ggml_metal_cv_set_bool(cv, has_mask, FC_FLASH_ATTN_EXT_BLK + 0); + //ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_BLK + 1); + //ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_BLK + 2); + //ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_BLK + 3); + + //ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_BLK + 20); + //ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_BLK + 21); + //ggml_metal_cv_set_int32(cv, nsg, FC_FLASH_ATTN_EXT_BLK + 22); + //ggml_metal_cv_set_int32(cv, nwg, FC_FLASH_ATTN_EXT_BLK + 23); + ggml_metal_cv_set_int32(cv, nqptg, FC_FLASH_ATTN_EXT_BLK + 24); + ggml_metal_cv_set_int32(cv, ncpsg, FC_FLASH_ATTN_EXT_BLK + 25); + + res = ggml_metal_library_compile_pipeline(lib, base, name, cv); + + ggml_metal_cv_free(cv); + + return res; +} + ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, const ggml_tensor * op, @@ -925,6 +1040,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -937,18 +1053,23 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( const int32_t ns10 = op->src[1]->nb[1]/op->src[1]->nb[0]; const int32_t ns20 = op->src[2]->nb[1]/op->src[2]->nb[0]; + // do bounds checks for the mask? + const bool bc_mask = op->src[3] && (op->src[3]->ne[1] % 8 != 0); + snprintf(base, 256, "kernel_%s_%s_dk%d_dv%d", "flash_attn_ext", ggml_type_name(op->src[1]->type), dk, dv); - snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_ns10=%d_ns20=%d_nsg=%d", + snprintf(name, 256, "%s_mask=%d_sinks=%d_bias=%d_scap=%d_kvpad=%d_bcm=%d_ns10=%d_ns20=%d_nsg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, + bc_mask, ns10, ns20, nsg); @@ -964,6 +1085,9 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT + 4); + + ggml_metal_cv_set_bool(cv, bc_mask, FC_FLASH_ATTN_EXT + 10); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT + 21); @@ -983,6 +1107,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); @@ -1002,12 +1127,13 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( dk, dv); - snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_softcap=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", + snprintf(name, 256, "%s_mask=%d_sink=%d_bias=%d_scap=%d_kvpad=%d_ns10=%d_ns20=%d_nsg=%d_nwg=%d", base, has_mask, has_sinks, has_bias, has_scap, + has_kvpad, ns10, ns20, nsg, nwg); @@ -1023,6 +1149,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( ggml_metal_cv_set_bool(cv, has_sinks, FC_FLASH_ATTN_EXT_VEC + 1); ggml_metal_cv_set_bool(cv, has_bias, FC_FLASH_ATTN_EXT_VEC + 2); ggml_metal_cv_set_bool(cv, has_scap, FC_FLASH_ATTN_EXT_VEC + 3); + ggml_metal_cv_set_bool(cv, has_kvpad, FC_FLASH_ATTN_EXT_VEC + 4); ggml_metal_cv_set_int32(cv, ns10, FC_FLASH_ATTN_EXT_VEC + 20); ggml_metal_cv_set_int32(cv, ns20, FC_FLASH_ATTN_EXT_VEC + 21); @@ -1374,3 +1501,40 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_me return res; } +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_ADAMW); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_adamw_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd(ggml_metal_library_t lib, const ggml_tensor * op) { + assert(op->op == GGML_OP_OPT_STEP_SGD); + + char base[256]; + char name[256]; + + snprintf(base, 256, "kernel_opt_step_sgd_%s", ggml_type_name(op->src[0]->type)); + snprintf(name, 256, "%s", base); + + ggml_metal_pipeline_t res = ggml_metal_library_get_pipeline(lib, name); + if (res) { + return res; + } + + res = ggml_metal_library_compile_pipeline(lib, base, name, nullptr); + + return res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h index f6ebf90a00..28ae2e1765 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.h @@ -109,6 +109,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_set_rows (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_repeat (ggml_metal_library_t lib, enum ggml_type tsrc); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_unary (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_glu (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_sum_rows (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_soft_max (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_ssm_conv (ggml_metal_library_t lib, const struct ggml_tensor * op); @@ -134,6 +135,20 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad (ggml_me ggml_metal_pipeline_t ggml_metal_library_get_pipeline_pad_reflect_1d (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_arange (ggml_metal_library_t lib, const struct ggml_tensor * op); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_timestep_embedding(ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_adamw (ggml_metal_library_t lib, const struct ggml_tensor * op); +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_opt_step_sgd (ggml_metal_library_t lib, const struct ggml_tensor * op); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_pad( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + bool has_mask, + int32_t ncpsg); + +ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_blk( + ggml_metal_library_t lib, + const struct ggml_tensor * op, + int32_t nqptg, + int32_t ncpsg); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( ggml_metal_library_t lib, @@ -142,6 +157,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg); ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( @@ -151,6 +167,7 @@ ggml_metal_pipeline_t ggml_metal_library_get_pipeline_flash_attn_ext_vec( bool has_sinks, bool has_bias, bool has_scap, + bool has_kvpad, int32_t nsg, int32_t nwg); diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m index 523f9d71ba..fc5083043f 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-device.m @@ -656,6 +656,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te case GGML_OP_COS: case GGML_OP_LOG: return ggml_is_contiguous(op->src[0]) && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_SUM: case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: case GGML_OP_SOFT_MAX: @@ -776,9 +777,7 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te }; } case GGML_OP_GET_ROWS: - { - return op->ne[3] == 1; - } + return true; case GGML_OP_SET_ROWS: { if (op->src[0]->type != GGML_TYPE_F32) { @@ -800,6 +799,9 @@ bool ggml_metal_device_supports_op(ggml_metal_device_t dev, const struct ggml_te return false; }; } + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + return has_simdgroup_reduction; default: return false; } diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal index 6bfdd87e8a..9c0e0c56b2 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-embed.metal @@ -1954,11 +1954,20 @@ GGML_TABLE_END() #define N_SG_IQ4_XS 2 // function constants offsets -#define FC_FLASH_ATTN_EXT 100 -#define FC_FLASH_ATTN_EXT_VEC 200 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 -#define FC_MUL_MV 400 -#define FC_MUL_MM 500 +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs // @@ -2063,6 +2072,7 @@ typedef struct { } ggml_metal_kargs_clamp; typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -2128,6 +2138,35 @@ typedef struct { int32_t sect_3; } ggml_metal_kargs_rope; +typedef struct { + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + typedef struct { int32_t ne01; int32_t ne02; @@ -2146,6 +2185,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -2180,6 +2220,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -2388,6 +2429,10 @@ typedef struct{ float limit; } ggml_metal_kargs_glu; +typedef struct { + uint64_t np; +} ggml_metal_kargs_sum; + typedef struct { int64_t ne00; int64_t ne01; @@ -2457,32 +2502,45 @@ typedef struct { int64_t n_seq_tokens; int64_t n_seqs; uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { @@ -2604,6 +2662,14 @@ typedef struct { uint64_t nb01; } ggml_metal_kargs_argmax; +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_adamw; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_sgd; + #endif // GGML_METAL_IMPL #include @@ -4322,6 +4388,24 @@ kernel void kernel_geglu_quick_f32( } } +kernel void kernel_op_sum_f32( + constant ggml_metal_kargs_sum & args, + device const float * src0, + device float * dst, + ushort tiitg[[thread_index_in_threadgroup]]) { + + if (tiitg != 0) { + return; + } + + float acc = 0.0f; + for (ulong i = 0; i < args.np; ++i) { + acc += src0[i]; + } + + dst[0] = acc; +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -4631,7 +4715,38 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -4643,219 +4758,88 @@ kernel void kernel_ssm_scan_f32( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. - threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; - } - - // Assign the final state to the output buffer - s_buff[i] = s; -} - -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_group_f32( - constant ggml_metal_kargs_ssm_scan & args, - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { - - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq - - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); - - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; - - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); - - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -7112,10 +7096,142 @@ kernel void kernel_leaky_relu_f32_4( dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); } +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; @@ -7162,6 +7278,8 @@ void kernel_flash_attn_ext_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -7227,6 +7345,13 @@ void kernel_flash_attn_ext_impl( pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); } + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + { q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; @@ -7286,16 +7411,75 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = 0; ic < args.ne11; ic += C) { + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; + if (ic >= args.ne11) { + break; + } + + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } + } + + ic = 0; + } + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } + + continue; + } + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; - sm2[j*SH + tiisg] = pm2[jj][tiisg]; + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + pm2[jj] += NW; } +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + threadgroup_barrier(mem_flags::mem_threadgroup); // used to detect blocks full of -INF @@ -7314,13 +7498,14 @@ void kernel_flash_attn_ext_impl( continue; } +#endif } // Q*K^T // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); threadgroup const q_t * pq = sq; threadgroup s_t * ps = ss; @@ -7331,26 +7516,24 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // TODO: not good to unroll for large contexts - not sure why? + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - if (DK8 % 16 != 0) { + if (DK % 16 != 0) { k8x8_t mk; q8x8_t mq; FOR_UNROLL (short i = 0; i < DK8; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, pk, NS10, 0, true); - simdgroup_load(mq, pq, DK); + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - - pk += 8; - pq += 8; } } else { k8x8_t mk[2]; @@ -7359,26 +7542,22 @@ void kernel_flash_attn_ext_impl( FOR_UNROLL (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mq[0], pq + 0*8, DK); - simdgroup_load(mq[1], pq + 1*8, DK); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - - pk += 16; - pq += 16; } } simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10 - DK8); - pq += 8*(NSG*0 - DK8); + pk += 8*(NSG*NS10); ps += 8*(NSG); } } else { @@ -7392,7 +7571,7 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -7512,27 +7691,50 @@ void kernel_flash_attn_ext_impl( } { - auto sst = ss; - - device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, sst, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - v8x8_t mv; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv, pv, NS20, 0, false); - simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - pv += 8*NSG; + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; - pv += 8*(NS20 - NO*NSG); - sst += 8; + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } } } @@ -7556,7 +7758,7 @@ void kernel_flash_attn_ext_impl( simdgroup_load(vs, ss + 8*cc, SH, 0, false); for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -7646,7 +7848,7 @@ void kernel_flash_attn_ext_impl( device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float scale = 1.0f/S[jj]; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; if (DV4 % NW == 0) { FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { @@ -7691,8 +7893,8 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = 8, // queries per threadgroup - short C = 64> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -7700,13 +7902,15 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -7826,6 +8030,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_ constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; @@ -7852,9 +8057,9 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32, // cache items per threadgroup + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup short NSG> // number of simd groups void kernel_flash_attn_ext_vec_impl( constant ggml_metal_kargs_flash_attn_ext_vec & args, @@ -7863,6 +8068,7 @@ void kernel_flash_attn_ext_vec_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -7968,12 +8174,38 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - const int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -7985,7 +8217,7 @@ void kernel_flash_attn_ext_vec_impl( // Q*K^T { - device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; @@ -8000,7 +8232,7 @@ void kernel_flash_attn_ext_vec_impl( mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } } else { - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; @@ -8098,7 +8330,7 @@ void kernel_flash_attn_ext_vec_impl( } if (is_same::value) { - device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -8111,7 +8343,7 @@ void kernel_flash_attn_ext_vec_impl( } } else { FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { const short i = ii*NL + tx; @@ -8236,7 +8468,7 @@ void kernel_flash_attn_ext_vec_impl( device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { @@ -8274,8 +8506,8 @@ template< short DK, // K head size short DV, // V head size short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, @@ -8283,13 +8515,14 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_vec_nsg) { // note: disabled cases to reduce library load time case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; @@ -8413,7 +8646,8 @@ kernel void kernel_flash_attn_ext_vec_reduce( const float m = simd_max(M); const float ms = exp(M - m); - S = 1.0f/simd_sum(S*ms); + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; const short DV4 = DV/4; @@ -8433,21 +8667,17 @@ kernel void kernel_flash_attn_ext_vec_reduce( } template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -8458,190 +8688,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -8649,11 +8759,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -8665,10 +8776,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -10121,7 +10234,7 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -10134,13 +10247,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -10151,6 +10263,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -10158,24 +10273,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -10206,7 +10322,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -10228,7 +10344,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -10241,12 +10357,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -10257,6 +10372,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -10266,7 +10384,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -10275,15 +10393,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -10313,7 +10432,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -10335,7 +10454,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -10348,13 +10467,12 @@ void kernel_mul_mv_mxfp4_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -10365,6 +10483,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -10372,20 +10493,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -10403,7 +10526,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -10428,66 +10551,60 @@ kernel void kernel_mul_mv_mxfp4_f32( template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + break; } } @@ -10973,12 +11090,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; @@ -11383,3 +11501,51 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_opt_step_adamw_f32( + constant ggml_metal_kargs_opt_step_adamw & args, + device float * x, + device const float * g, + device float * g_m, + device float * g_v, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[gid]; + const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); + const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); + + g_m[gid] = gmi; + g_v[gid] = gvi; + + const float mh = gmi * beta1h; + const float vh = sqrt(gvi * beta2h) + eps; + + x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; +} + +kernel void kernel_opt_step_sgd_f32( + constant ggml_metal_kargs_opt_step_sgd & args, + device float * x, + device const float * g, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h index 88c98423eb..a448c14f66 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-impl.h @@ -69,11 +69,20 @@ #define N_SG_IQ4_XS 2 // function constants offsets -#define FC_FLASH_ATTN_EXT 100 -#define FC_FLASH_ATTN_EXT_VEC 200 -#define FC_FLASH_ATTN_EXT_VEC_REDUCE 300 -#define FC_MUL_MV 400 -#define FC_MUL_MM 500 +#define FC_FLASH_ATTN_EXT_PAD 100 +#define FC_FLASH_ATTN_EXT_BLK 200 +#define FC_FLASH_ATTN_EXT 300 +#define FC_FLASH_ATTN_EXT_VEC 400 +#define FC_FLASH_ATTN_EXT_VEC_REDUCE 500 +#define FC_MUL_MV 600 +#define FC_MUL_MM 700 + +// op-specific constants +#define OP_FLASH_ATTN_EXT_NQPTG 8 +#define OP_FLASH_ATTN_EXT_NCPSG 64 + +#define OP_FLASH_ATTN_EXT_VEC_NQPTG 1 +#define OP_FLASH_ATTN_EXT_VEC_NCPSG 32 // kernel argument structs // @@ -178,6 +187,7 @@ typedef struct { } ggml_metal_kargs_clamp; typedef struct { + int64_t nk0; int64_t ne00; int64_t ne01; int64_t ne02; @@ -243,6 +253,35 @@ typedef struct { int32_t sect_3; } ggml_metal_kargs_rope; +typedef struct { + int32_t ne11; + int32_t ne_12_2; // assume K and V are same shape + int32_t ne_12_3; + uint64_t nb11; + uint64_t nb12; + uint64_t nb13; + uint64_t nb21; + uint64_t nb22; + uint64_t nb23; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_pad; + +typedef struct { + int32_t ne01; + int32_t ne30; + int32_t ne31; + int32_t ne32; + int32_t ne33; + uint64_t nb31; + uint64_t nb32; + uint64_t nb33; +} ggml_metal_kargs_flash_attn_ext_blk; + typedef struct { int32_t ne01; int32_t ne02; @@ -261,6 +300,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -295,6 +335,7 @@ typedef struct { uint64_t nb21; uint64_t nb22; uint64_t nb23; + int32_t ne31; int32_t ne32; int32_t ne33; uint64_t nb31; @@ -503,6 +544,10 @@ typedef struct{ float limit; } ggml_metal_kargs_glu; +typedef struct { + uint64_t np; +} ggml_metal_kargs_sum; + typedef struct { int64_t ne00; int64_t ne01; @@ -572,32 +617,45 @@ typedef struct { int64_t n_seq_tokens; int64_t n_seqs; uint64_t s_off; + uint64_t nb00; uint64_t nb01; uint64_t nb02; uint64_t nb03; + uint64_t nb10; uint64_t nb11; uint64_t nb12; + uint64_t ns12; uint64_t nb13; + uint64_t nb20; uint64_t nb21; + uint64_t ns21; uint64_t nb22; + int64_t ne30; uint64_t nb31; uint64_t nb41; uint64_t nb42; + uint64_t ns42; uint64_t nb43; uint64_t nb51; uint64_t nb52; + uint64_t ns52; uint64_t nb53; + uint64_t nb0; } ggml_metal_kargs_ssm_scan; typedef struct { - int64_t ne00; + int32_t ne00t; + int32_t ne00; uint64_t nb01; uint64_t nb02; - int64_t ne10; + uint64_t nb03; + int32_t ne10; uint64_t nb10; uint64_t nb11; + uint64_t nb12; uint64_t nb1; uint64_t nb2; + uint64_t nb3; } ggml_metal_kargs_get_rows; typedef struct { @@ -719,4 +777,12 @@ typedef struct { uint64_t nb01; } ggml_metal_kargs_argmax; +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_adamw; + +typedef struct { + int64_t np; +} ggml_metal_kargs_opt_step_sgd; + #endif // GGML_METAL_IMPL diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp index e85a223c01..a61ea8fb5a 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.cpp @@ -226,6 +226,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS(uint64_t, nb0, node->src[0], nb); GGML_TENSOR_LOCALS( int64_t, ne1, node->src[1], ne); GGML_TENSOR_LOCALS(uint64_t, nb1, node->src[1], nb); + GGML_TENSOR_LOCALS( int64_t, ne2, node->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, node->src[2], nb); + GGML_TENSOR_LOCALS( int64_t, ne3, node->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, node->src[3], nb); GGML_TENSOR_LOCALS( int64_t, ne, node, ne); GGML_TENSOR_LOCALS(uint64_t, nb, node, nb); @@ -237,6 +241,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { GGML_LOG_DEBUG("%s: src1 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[1]->type), ne10, ne11, ne12, ne13, nb10, nb11, nb12, nb13, ggml_is_contiguous(node->src[1]), node->src[1]->name); } + if (node->src[2]) { + GGML_LOG_DEBUG("%s: src2 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[2]->type), ne20, ne21, ne22, ne23, nb20, nb21, nb22, nb23, + ggml_is_contiguous(node->src[2]), node->src[2]->name); + } + if (node->src[3]) { + GGML_LOG_DEBUG("%s: src3 - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], %d, %s\n", __func__, ggml_type_name(node->src[3]->type), ne30, ne31, ne32, ne33, nb30, nb31, nb32, nb33, + ggml_is_contiguous(node->src[3]), node->src[3]->name); + } if (node) { GGML_LOG_DEBUG("%s: node - %4s [%5lld, %5lld, %5lld, %5lld] [%5lld, %5lld, %5lld, %5lld], 1, %s\n", __func__, ggml_type_name(node->type), ne0, ne1, ne2, ne3, nb0, nb1, nb2, nb3, node->name); @@ -289,6 +301,10 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_glu(ctx, idx); } break; + case GGML_OP_SUM: + { + n_fuse = ggml_metal_op_sum(ctx, idx); + } break; case GGML_OP_SUM_ROWS: case GGML_OP_MEAN: { @@ -398,6 +414,14 @@ static int ggml_metal_op_encode_impl(ggml_metal_op_t ctx, int idx) { { n_fuse = ggml_metal_op_argmax(ctx, idx); } break; + case GGML_OP_OPT_STEP_ADAMW: + { + n_fuse = ggml_metal_op_opt_step_adamw(ctx, idx); + } break; + case GGML_OP_OPT_STEP_SGD: + { + n_fuse = ggml_metal_op_opt_step_sgd(ctx, idx); + } break; default: { GGML_LOG_ERROR("%s: error: node %3d, op = %8s not implemented\n", __func__, idx, ggml_op_name(node->op)); @@ -577,6 +601,7 @@ int ggml_metal_op_acc(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_cpy(lib, op->src[0]->type, op->type); ggml_metal_kargs_cpy args = { + /*.nk0 =*/ ne00, /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, @@ -827,6 +852,30 @@ int ggml_metal_op_glu(ggml_metal_op_t ctx, int idx) { return 1; } +int ggml_metal_op_sum(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + const uint64_t n = (uint64_t) ggml_nelements(op->src[0]); + + ggml_metal_kargs_sum args = { + /*.np =*/ n, + }; + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_sum(lib, op); + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); + + ggml_metal_encoder_dispatch_threadgroups(enc, 1, 1, 1, 1, 1, 1); + + return 1; +} + int ggml_metal_op_sum_rows(ggml_metal_op_t ctx, int idx) { ggml_tensor * op = ctx->node(idx); @@ -906,23 +955,31 @@ int ggml_metal_op_get_rows(ggml_metal_op_t ctx, int idx) { ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_get_rows(lib, op->src[0]->type); ggml_metal_kargs_get_rows args = { - /*.ne00 =*/ ne00, - /*.nb01 =*/ nb01, - /*.nb02 =*/ nb02, - /*.ne10 =*/ ne10, - /*.nb10 =*/ nb10, - /*.nb11 =*/ nb11, - /*.nb1 =*/ nb1, - /*.nb2 =*/ nb2, + /*.ne00t =*/ ggml_is_quantized(op->src[0]->type) ? ne00/16 : ne00, + /*.ne00 =*/ ne00, + /*.nb01 =*/ nb01, + /*.nb02 =*/ nb02, + /*.nb03 =*/ nb03, + /*.ne10 =*/ ne10, + /*.nb10 =*/ nb10, + /*.nb11 =*/ nb11, + /*.nb12 =*/ nb12, + /*.nb1 =*/ nb1, + /*.nb2 =*/ nb2, + /*.nb3 =*/ nb3, }; + const int nth = std::min(args.ne00t, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + + const int nw0 = (args.ne00t + nth - 1)/nth; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 3); - ggml_metal_encoder_dispatch_threadgroups(enc, ne10, ne11, ne12, 32, 1, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*ne10, ne11, ne12, nth, 1, 1); return 1; } @@ -1117,7 +1174,7 @@ int ggml_metal_op_ssm_conv(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_bytes(enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); + ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 3); ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne1, ne02, 1, 1, 1); @@ -1172,25 +1229,36 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { /*.n_seq_tokens =*/ n_seq_tokens, /*.n_seqs =*/ n_seqs, /*.s_off =*/ ggml_nelements(op->src[1]) * sizeof(float), + /*.nb00 =*/ nb00, /*.nb01 =*/ nb01, /*.nb02 =*/ nb02, /*.nb03 =*/ nb03, + /*.nb10 =*/ nb10, /*.nb11 =*/ nb11, /*.nb12 =*/ nb12, + /*.ns12 =*/ nb12/nb10, /*.nb13 =*/ nb13, + /*.nb20 =*/ nb20, /*.nb21 =*/ nb21, + /*.ns21 =*/ nb21/nb20, /*.nb22 =*/ nb22, + /*.ne30 =*/ ne30, /*.nb31 =*/ nb31, /*.nb41 =*/ nb41, /*.nb42 =*/ nb42, + /*.ns42 =*/ nb42/nb40, /*.nb43 =*/ nb43, /*.nb51 =*/ nb51, /*.nb52 =*/ nb52, + /*.ns52 =*/ nb52/nb50, /*.nb53 =*/ nb53, + /*.nb0 =*/ nb0, }; ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_ssm_scan(lib, op); + GGML_ASSERT(d_state <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + const size_t sms = ggml_metal_pipeline_get_smem(pipeline); ggml_metal_encoder_set_pipeline(enc, pipeline); @@ -1206,13 +1274,7 @@ int ggml_metal_op_ssm_scan(ggml_metal_op_t ctx, int idx) { ggml_metal_encoder_set_threadgroup_memory_size(enc, sms, 0); - if (ne30 == 1) { - // Mamba-2 - ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); - } else { - GGML_ASSERT(d_inner == 1); - ggml_metal_encoder_dispatch_threadgroups(enc, n_head, n_seqs, 1, d_state, 1, 1); - } + ggml_metal_encoder_dispatch_threadgroups(enc, d_inner, n_head, n_seqs, d_state, 1, 1); return 1; } @@ -1273,26 +1335,23 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne00 % ggml_blck_size(op->src[0]->type) == 0); - // TODO: support - //const int32_t nk00 = ne00/ggml_blck_size(op->type); - const int32_t nk00 = ne00; - - int nth = 32; // SIMD width - - while (nth < nk00 && nth < ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { - nth *= 2; + int64_t nk0 = ne00; + if (ggml_is_quantized(op->src[0]->type)) { + nk0 = ne00/16; + } else if (ggml_is_quantized(op->type)) { + nk0 = ne00/ggml_blck_size(op->type); } - nth = std::min(nth, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); + int nth = std::min(nk0, ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); // when rows are small, we can batch them together in a single threadgroup int nrptg = 1; // TODO: relax this constraint in the future if (ggml_blck_size(op->src[0]->type) == 1 && ggml_blck_size(op->type) == 1) { - if (nth > nk00) { - nrptg = (nth + nk00 - 1)/nk00; - nth = nk00; + if (nth > nk0) { + nrptg = (nth + nk0 - 1)/nk0; + nth = nk0; if (nrptg*nth > ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)) { nrptg--; @@ -1300,10 +1359,11 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { } } - nth = std::min(nth, nk00); + nth = std::min(nth, nk0); ggml_metal_kargs_cpy args = { - /*.ne00 =*/ nk00, + /*.nk0 =*/ nk0, + /*.ne00 =*/ ne00, /*.ne01 =*/ ne01, /*.ne02 =*/ ne02, /*.ne03 =*/ ne03, @@ -1321,12 +1381,14 @@ int ggml_metal_op_cpy(ggml_metal_op_t ctx, int idx) { /*.nb3 =*/ nb3, }; + const int nw0 = nrptg == 1 ? (nk0 + nth - 1)/nth : 1; + ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 2); - ggml_metal_encoder_dispatch_threadgroups(enc, ne01, ne02, ne03, nth, nrptg, 1); + ggml_metal_encoder_dispatch_threadgroups(enc, nw0*(ne01 + nrptg - 1)/nrptg, ne02, ne03, nth, nrptg, 1); return 1; } @@ -1520,9 +1582,8 @@ int ggml_metal_op_mul_mat(ggml_metal_op_t ctx, int idx) { !ggml_is_transposed(op->src[1]) && // for now the matrix-matrix multiplication kernel only works on A14+/M1+ SoCs // AMD GPU and older A-chips will reuse matrix-vector multiplication kernel - props_dev->has_simdgroup_mm && ne00 >= 64 && - (ne11 > ne11_mm_min || (ggml_is_quantized(op->src[0]->type) && ne12 > 1))) { - //printf("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); + props_dev->has_simdgroup_mm && ne00 >= 64 && ne11 > ne11_mm_min) { + //GGML_LOG_INFO("matrix: ne00 = %6d, ne01 = %6d, ne02 = %6d, ne11 = %6d, ne12 = %6d\n", ne00, ne01, ne02, ne11, ne12); // some Metal matrix data types require aligned pointers // ref: https://developer.apple.com/metal/Metal-Shading-Language-Specification.pdf (Table 2.5) @@ -1875,20 +1936,107 @@ bool ggml_metal_op_flash_attn_ext_use_vec(const ggml_tensor * op) { return (ne01 < 20) && (ne00 % 32 == 0); } +size_t ggml_metal_op_flash_attn_ext_extra_pad(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_VEC_NCPSG != 0; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_VEC_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } else { + const bool has_kvpad = ne11 % OP_FLASH_ATTN_EXT_NCPSG != 0; + + if (has_kvpad) { + res += OP_FLASH_ATTN_EXT_NCPSG*( + nb11*ne12*ne13 + + nb21*ne22*ne23 + + (has_mask ? ggml_type_size(GGML_TYPE_F16)*ne31*ne32*ne33 : 0)); + } + } + + return res; +} + +size_t ggml_metal_op_flash_attn_ext_extra_blk(const ggml_tensor * op) { + assert(op->op == GGML_OP_FLASH_ATTN_EXT); + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + //GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); + + size_t res = 0; + + const bool has_mask = op->src[3] != nullptr; + + if (!has_mask) { + return res; + } + + const bool is_vec = ggml_metal_op_flash_attn_ext_use_vec(op); + + // this optimization is not useful for the vector kernels + if (is_vec) { + return res; + } + + const int nqptg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NQPTG : OP_FLASH_ATTN_EXT_NQPTG; + const int ncpsg = is_vec ? OP_FLASH_ATTN_EXT_VEC_NCPSG : OP_FLASH_ATTN_EXT_NCPSG; + + const int64_t ne1 = (ne01 + nqptg - 1)/nqptg; + const int64_t ne0 = (ne30 + ncpsg - 1)/ncpsg; + + res += GGML_PAD(ggml_type_size(GGML_TYPE_I8)*ne0*ne1*ne32*ne33, 32); + + return res; +} + size_t ggml_metal_op_flash_attn_ext_extra_tmp(const ggml_tensor * op) { assert(op->op == GGML_OP_FLASH_ATTN_EXT); - const int64_t nwg = 32; + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + //GGML_TENSOR_LOCALS( int32_t, ne1, op->src[1], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb1, op->src[1], nb); + GGML_TENSOR_LOCALS( int32_t, ne2, op->src[2], ne); + GGML_TENSOR_LOCALS(uint64_t, nb2, op->src[2], nb); + //GGML_TENSOR_LOCALS( int32_t, ne3, op->src[3], ne); + //GGML_TENSOR_LOCALS(uint64_t, nb3, op->src[3], nb); - const int64_t ne01 = op->src[0]->ne[1]; - const int64_t ne02 = op->src[0]->ne[2]; - const int64_t ne03 = op->src[0]->ne[3]; - const int64_t ne20 = op->src[2]->ne[0]; + size_t res = 0; - // temp buffer for writing the results from each workgroup - // - ne20: the size of the Value head - // - + 2: the S and M values for each intermediate result - return ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + if (ggml_metal_op_flash_attn_ext_use_vec(op)) { + const int64_t nwg = 32; + + // temp buffer for writing the results from each workgroup + // - ne20: the size of the Value head + // - + 2: the S and M values for each intermediate result + res += ggml_type_size(GGML_TYPE_F32)*(ne01*ne02*ne03*nwg*(ne20 + 2)); + } + + return res; } int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { @@ -1910,8 +2058,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_TENSOR_LOCALS( int32_t, ne, op, ne); GGML_TENSOR_LOCALS( int32_t, nb, op, nb); - GGML_ASSERT(ne00 % 4 == 0); - GGML_ASSERT(ne11 % 32 == 0); + GGML_ASSERT(ne00 % 4 == 0); GGML_ASSERT(op->src[0]->type == GGML_TYPE_F32); GGML_ASSERT(op->src[1]->type == op->src[2]->type); @@ -1921,8 +2068,8 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne12 == ne22); GGML_ASSERT(!op->src[3] || op->src[3]->type == GGML_TYPE_F16); - GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= GGML_PAD(op->src[0]->ne[1], 8) && - "the Flash-Attention Metal kernel requires the mask to be padded to 8 and at least n_queries big"); + GGML_ASSERT(!op->src[3] || op->src[3]->ne[1] >= op->src[0]->ne[1] && + "the Flash-Attention Metal kernel requires the mask to be at least n_queries big"); float scale; float max_bias; @@ -1949,15 +2096,111 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(ne01 < 65536); + ggml_metal_buffer_id bid_src0 = ggml_metal_get_buffer_id(op->src[0]); + ggml_metal_buffer_id bid_src1 = ggml_metal_get_buffer_id(op->src[1]); + ggml_metal_buffer_id bid_src2 = ggml_metal_get_buffer_id(op->src[2]); + ggml_metal_buffer_id bid_src3 = has_mask ? ggml_metal_get_buffer_id(op->src[3]) : bid_src0; + ggml_metal_buffer_id bid_src4 = has_sinks ? ggml_metal_get_buffer_id(op->src[4]) : bid_src0; + + ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); + + ggml_metal_buffer_id bid_pad = bid_dst; + bid_pad.offs += ggml_nbytes(op); + + ggml_metal_buffer_id bid_blk = bid_pad; + bid_blk.offs += ggml_metal_op_flash_attn_ext_extra_pad(op); + + ggml_metal_buffer_id bid_tmp = bid_blk; + bid_tmp.offs += ggml_metal_op_flash_attn_ext_extra_blk(op); + if (!ggml_metal_op_flash_attn_ext_use_vec(op)) { // half8x8 kernel - const int64_t nqptg = 8; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 64; // cache values per simdgroup !! sync with kernel template arguments !! + const int nqptg = OP_FLASH_ATTN_EXT_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_NCPSG; // cache values per simdgroup GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 8 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + + if (has_mask) { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) != 0); + + ggml_metal_kargs_flash_attn_ext_blk args0 = { + /*.ne01 =*/ ne01, + /*.ne30 =*/ ne30, + /*.ne31 =*/ ne31, + /*.ne32 =*/ ne32, + /*.ne33 =*/ ne33, + /*.nb31 =*/ nb31, + /*.nb32 =*/ nb32, + /*.nb33 =*/ nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_blk(lib, op, nqptg, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src3, 1); + ggml_metal_encoder_set_buffer (enc, bid_blk, 2); + + const int32_t nblk1 = ((ne01 + nqptg - 1)/nqptg); + const int32_t nblk0 = ((ne30 + ncpsg - 1)/ncpsg); + + ggml_metal_encoder_dispatch_threadgroups(enc, nblk0, nblk1, ne32*ne33, 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_blk(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + const int is_q = ggml_is_quantized(op->src[1]->type) ? 1 : 0; // 2*(2*ncpsg) @@ -2007,6 +2250,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2023,24 +2267,18 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); + ggml_metal_encoder_set_buffer (enc, bid_pad, 6); + ggml_metal_encoder_set_buffer (enc, bid_blk, 7); + ggml_metal_encoder_set_buffer (enc, bid_dst, 8); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); @@ -2048,14 +2286,62 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { #undef FATTN_SMEM } else { // half4x4 kernel - const int64_t nqptg = 1; // queries per threadgroup !! sync with kernel template arguments !! - const int64_t ncpsg = 32; // cache values per simdgroup !! sync with kernel template arguments !! - const int64_t nkpsg = 1*ncpsg; + const int nqptg = OP_FLASH_ATTN_EXT_VEC_NQPTG; // queries per threadgroup + const int ncpsg = OP_FLASH_ATTN_EXT_VEC_NCPSG; // cache values per simdgroup !! sync with kernel template arguments !! + const int nkpsg = 1*ncpsg; GGML_ASSERT(nqptg <= 32); GGML_ASSERT(nqptg % 1 == 0); GGML_ASSERT(ncpsg % 32 == 0); + bool need_sync = false; + + const bool has_kvpad = ne11 % ncpsg != 0; + + if (has_kvpad) { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) != 0); + + ggml_metal_kargs_flash_attn_ext_pad args0 = { + /*.ne11 =*/ne11, + /*.ne_12_2 =*/ne12, + /*.ne_12_3 =*/ne13, + /*.nb11 =*/nb11, + /*.nb12 =*/nb12, + /*.nb13 =*/nb13, + /*.nb21 =*/nb21, + /*.nb22 =*/nb22, + /*.nb23 =*/nb23, + /*.ne31 =*/ne31, + /*.ne32 =*/ne32, + /*.ne33 =*/ne33, + /*.nb31 =*/nb31, + /*.nb32 =*/nb32, + /*.nb33 =*/nb33, + }; + + ggml_metal_pipeline_t pipeline0 = ggml_metal_library_get_pipeline_flash_attn_ext_pad(lib, op, has_mask, ncpsg); + + ggml_metal_encoder_set_pipeline(enc, pipeline0); + ggml_metal_encoder_set_bytes (enc, &args0, sizeof(args0), 0); + ggml_metal_encoder_set_buffer (enc, bid_src1, 1); + ggml_metal_encoder_set_buffer (enc, bid_src2, 2); + ggml_metal_encoder_set_buffer (enc, bid_src3, 3); + ggml_metal_encoder_set_buffer (enc, bid_pad, 4); + + assert(ne12 == ne22); + assert(ne13 == ne23); + + ggml_metal_encoder_dispatch_threadgroups(enc, ncpsg, std::max(ne12, ne32), std::max(ne13, ne33), 32, 1, 1); + + need_sync = true; + } else { + assert(ggml_metal_op_flash_attn_ext_extra_pad(op) == 0); + } + + if (need_sync) { + ggml_metal_op_concurrency_reset(ctx); + } + // ne00 + 2*ncpsg*(nsg) // for each query, we load it as f16 in shared memory (ne00) // and store the soft_max values and the mask @@ -2120,6 +2406,7 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.nb21 =*/ nb21, /*.nb22 =*/ nb22, /*.nb23 =*/ nb23, + /*.ne31 =*/ ne31, /*.ne32 =*/ ne32, /*.ne33 =*/ ne33, /*.nb31 =*/ nb31, @@ -2136,25 +2423,17 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { /*.logit_softcap =*/ logit_softcap, }; - ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, nsg, nwg); + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_flash_attn_ext_vec(lib, op, has_mask, has_sinks, has_bias, has_scap, has_kvpad, nsg, nwg); GGML_ASSERT(nsg*32 <= ggml_metal_pipeline_max_theads_per_threadgroup(pipeline)); ggml_metal_encoder_set_pipeline(enc, pipeline); ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), 0); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), 1); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), 2); - ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), 3); - if (op->src[3]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[3]), 4); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 4); - } - if (op->src[4]) { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[4]), 5); - } else { - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op->src[0]), 5); - } + ggml_metal_encoder_set_buffer (enc, bid_src0, 1); + ggml_metal_encoder_set_buffer (enc, bid_src1, 2); + ggml_metal_encoder_set_buffer (enc, bid_src2, 3); + ggml_metal_encoder_set_buffer (enc, bid_src3, 4); + ggml_metal_encoder_set_buffer (enc, bid_src4, 5); const size_t smem = FATTN_SMEM(nsg); @@ -2162,23 +2441,25 @@ int ggml_metal_op_flash_attn_ext(ggml_metal_op_t ctx, int idx) { GGML_ASSERT(smem <= props_dev->max_theadgroup_memory_size); if (nwg == 1) { + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) == 0); + // using 1 workgroup -> write the result directly into dst - ggml_metal_encoder_set_buffer(enc, ggml_metal_get_buffer_id(op), 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_dst, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); } else { // sanity checks + assert(ggml_metal_op_flash_attn_ext_extra_tmp(op) != 0); + GGML_ASSERT(ne01*ne02*ne03 == ne1*ne2*ne3); GGML_ASSERT((uint64_t)ne1*ne2*ne3 <= (1u << 31)); - ggml_metal_buffer_id bid_dst = ggml_metal_get_buffer_id(op); - // write the results from each workgroup into a temp buffer - ggml_metal_buffer_id bid_tmp = bid_dst; - bid_tmp.offs += ggml_nbytes(op); - ggml_metal_encoder_set_buffer(enc, bid_tmp, 6); + ggml_metal_encoder_set_buffer(enc, bid_pad, 6); + ggml_metal_encoder_set_buffer(enc, bid_tmp, 7); ggml_metal_encoder_set_threadgroup_memory_size(enc, smem, 0); ggml_metal_encoder_dispatch_threadgroups(enc, (ne01 + nqptg - 1)/nqptg, ne02, ne03*nwg, 32, nsg, 1); @@ -3156,3 +3437,73 @@ int ggml_metal_op_leaky_relu(ggml_metal_op_t ctx, int idx) { return 1; } + +int ggml_metal_op_opt_step_adamw(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_adamw(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_adamw args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[3]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[4]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} + +int ggml_metal_op_opt_step_sgd(ggml_metal_op_t ctx, int idx) { + ggml_tensor * op = ctx->node(idx); + + ggml_metal_library_t lib = ctx->lib; + ggml_metal_encoder_t enc = ctx->enc; + + GGML_TENSOR_LOCALS( int32_t, ne0, op->src[0], ne); + GGML_TENSOR_LOCALS(uint64_t, nb0, op->src[0], nb); + GGML_TENSOR_LOCALS( int32_t, ne, op, ne); + GGML_TENSOR_LOCALS(uint32_t, nb, op, nb); + + ggml_metal_pipeline_t pipeline = ggml_metal_library_get_pipeline_opt_step_sgd(lib, op); + + const int64_t np = ggml_nelements(op->src[0]); + ggml_metal_kargs_opt_step_sgd args = { + /*.np =*/ np, + }; + + int ida = 0; + + ggml_metal_encoder_set_pipeline(enc, pipeline); + ggml_metal_encoder_set_bytes (enc, &args, sizeof(args), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[0]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[1]), ida++); + ggml_metal_encoder_set_buffer (enc, ggml_metal_get_buffer_id(op->src[2]), ida++); + + const int nth = std::min(ggml_metal_pipeline_max_theads_per_threadgroup(pipeline), ne0); + const int64_t n = (np + nth - 1) / nth; + + ggml_metal_encoder_dispatch_threadgroups(enc, n, 1, 1, nth, 1, 1); + + return 1; +} diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h index 8df4c72e7c..f352738698 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal-ops.h @@ -39,6 +39,8 @@ size_t ggml_metal_op_mul_mat_id_extra_ids(const struct ggml_tensor * op); // return true if we should use the FA vector kernel for this op bool ggml_metal_op_flash_attn_ext_use_vec(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_pad(const struct ggml_tensor * op); +size_t ggml_metal_op_flash_attn_ext_extra_blk(const struct ggml_tensor * op); size_t ggml_metal_op_flash_attn_ext_extra_tmp(const struct ggml_tensor * op); int ggml_metal_op_concat (ggml_metal_op_t ctx, int idx); @@ -48,6 +50,7 @@ int ggml_metal_op_scale (ggml_metal_op_t ctx, int idx); int ggml_metal_op_clamp (ggml_metal_op_t ctx, int idx); int ggml_metal_op_unary (ggml_metal_op_t ctx, int idx); int ggml_metal_op_glu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_sum (ggml_metal_op_t ctx, int idx); int ggml_metal_op_sum_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_get_rows (ggml_metal_op_t ctx, int idx); int ggml_metal_op_set_rows (ggml_metal_op_t ctx, int idx); @@ -76,6 +79,8 @@ int ggml_metal_op_timestep_embedding(ggml_metal_op_t ctx, int idx); int ggml_metal_op_argmax (ggml_metal_op_t ctx, int idx); int ggml_metal_op_argsort (ggml_metal_op_t ctx, int idx); int ggml_metal_op_leaky_relu (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_adamw (ggml_metal_op_t ctx, int idx); +int ggml_metal_op_opt_step_sgd (ggml_metal_op_t ctx, int idx); #ifdef __cplusplus } diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp index 17999a6162..f356e4a0aa 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.cpp @@ -195,9 +195,9 @@ static size_t ggml_backend_metal_buffer_type_get_alloc_size(ggml_backend_buffer_ } break; case GGML_OP_FLASH_ATTN_EXT: { - if (ggml_metal_op_flash_attn_ext_use_vec(tensor)) { - res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); - } + res += ggml_metal_op_flash_attn_ext_extra_pad(tensor); + res += ggml_metal_op_flash_attn_ext_extra_blk(tensor); + res += ggml_metal_op_flash_attn_ext_extra_tmp(tensor); } break; default: break; @@ -543,6 +543,7 @@ static void ggml_backend_metal_device_get_props(ggml_backend_dev_t dev, ggml_bac props->type = ggml_backend_metal_device_get_type(dev); ggml_backend_metal_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->library = GGML_METAL_NAME; props->caps = { /* .async = */ true, diff --git a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal index 44dc31c02a..375a0c7fd1 100644 --- a/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal +++ b/ml/backend/ggml/ggml/src/ggml-metal/ggml-metal.metal @@ -1723,6 +1723,24 @@ kernel void kernel_geglu_quick_f32( } } +kernel void kernel_op_sum_f32( + constant ggml_metal_kargs_sum & args, + device const float * src0, + device float * dst, + ushort tiitg[[thread_index_in_threadgroup]]) { + + if (tiitg != 0) { + return; + } + + float acc = 0.0f; + for (ulong i = 0; i < args.np; ++i) { + acc += src0[i]; + } + + dst[0] = acc; +} + template kernel void kernel_sum_rows( constant ggml_metal_kargs_sum_rows & args, @@ -2032,7 +2050,38 @@ kernel void kernel_ssm_conv_f32_f32( x[0] = sumf; } -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-1 part +kernel void kernel_ssm_conv_f32_f32_4( + constant ggml_metal_kargs_ssm_conv & args, + device const void * src0, + device const void * src1, + device float * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + uint3 tpitg[[thread_position_in_threadgroup]], + uint3 ntg[[threads_per_threadgroup]]) { + const int64_t ir = tgpig.x; + const int64_t i2 = tgpig.y; + const int64_t i3 = tgpig.z; + + const int64_t nc = args.ne10; + //const int64_t ncs = args.ne00; + //const int64_t nr = args.ne01; + //const int64_t n_t = args.ne1; + //const int64_t n_s = args.ne2; + + device const float4 * s = (device const float4 *) ((device const char *) src0 + ir*args.nb01 + i2*args.nb00 + i3*args.nb02); + device const float4 * c = (device const float4 *) ((device const char *) src1 + ir*args.nb11); + device float * x = (device float *) ((device char *) dst + ir*args.nb0 + i2*args.nb1 + i3*args.nb2); + + float sumf = 0.0f; + + for (int64_t i0 = 0; i0 < nc/4; ++i0) { + sumf += dot(s[i0], c[i0]); + } + + x[0] = sumf; +} + +// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part kernel void kernel_ssm_scan_f32( constant ggml_metal_kargs_ssm_scan & args, device const void * src0, @@ -2044,219 +2093,88 @@ kernel void kernel_ssm_scan_f32( device const void * src6, device float * dst, threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { + uint3 tgpig[[threadgroup_position_in_grid]], + ushort3 tpitg[[thread_position_in_threadgroup]], + ushort sgitg[[simdgroup_index_in_threadgroup]], + ushort tiisg[[thread_index_in_simdgroup]], + ushort sgptg[[simdgroups_per_threadgroup]], + uint3 tgpg[[threadgroups_per_grid]]) { + constexpr short NW = N_SIMDWIDTH; - const int64_t i0 = tpitg.x; - const int64_t i1 = 0; - const int64_t ir = tgpig.x; // current head - const int64_t i3 = tgpig.y; // current seq + shared[tpitg.x] = 0.0f; - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); + const int32_t i0 = tpitg.x; + const int32_t i1 = tgpig.x; + const int32_t ir = tgpig.y; // current head + const int32_t i3 = tgpig.z; // current seq - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; + const int32_t nc = args.d_state; + const int32_t nr = args.d_inner; + const int32_t nh = args.n_head; + const int32_t ng = args.n_group; + const int32_t n_t = args.n_seq_tokens; - const int64_t s_off = args.s_off; + const int32_t s_off = args.s_off; device const int32_t * ids = (device const int32_t *) src6; device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave + + const int32_t i = i0 + i1*nc; + const int32_t g = ir / (nh / ng); // repeat_interleave + float s0 = s0_buff[i]; - float s = s_buff[i]; + float s = 0.0f; - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); + device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {ne30, nh} - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} + const float A0 = A[i0%args.ne30]; - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; + device const float * x = (device const float *)((device const char *) src1 + i1*args.nb10 + ir*args.nb11 + i3*args.nb13); // {dim, nh, nt, ns} + device const float * dt = (device const float *)((device const char *) src2 + ir*args.nb20 + i3*args.nb22); // {nh, nt, ns} + device const float * B = (device const float *)((device const char *) src4 + g*args.nb41 + i3*args.nb43); // {d_state, ng, nt, ns} + device const float * C = (device const float *)((device const char *) src5 + g*args.nb51 + i3*args.nb53); // {d_state, ng, nt, ns} - const float state = (s0 * exp(dt_soft_plus * A[i0])) + (B[i0] * x_dt); - s = state; + device float * y = dst + (i1 + ir*(nr) + i3*(n_t*nh*nr)); // {dim, nh, nt, ns} - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - if (sgptg > 1) { - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. - threadgroup_barrier(mem_flags::mem_threadgroup); - - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); - if (tiisg == 0) { - y[0] = sumf; - } - } - } else if (tiisg == 0) { - y[0] = sumf; - } - - // recurse - s0 = s; - } - - // Assign the final state to the output buffer - s_buff[i] = s; -} - -// ref: ggml.c:ggml_compute_forward_ssm_scan_f32, Mamba-2 part -kernel void kernel_ssm_scan_group_f32( - constant ggml_metal_kargs_ssm_scan & args, - device const void * src0, - device const void * src1, - device const void * src2, - device const void * src3, - device const void * src4, - device const void * src5, - device const void * src6, - device float * dst, - threadgroup float * shared [[threadgroup(0)]], - uint3 tgpig[[threadgroup_position_in_grid]], - uint3 tpitg[[thread_position_in_threadgroup]], - ushort sgitg[[simdgroup_index_in_threadgroup]], - ushort tiisg[[thread_index_in_simdgroup]], - ushort sgptg[[simdgroups_per_threadgroup]], - uint3 tgpg[[threadgroups_per_grid]]) { - - const int64_t i0 = tpitg.x; - const int64_t i1 = tgpig.x; - const int64_t ir = tgpig.y; // current head - const int64_t i3 = tgpig.z; // current seq - - const uint64_t nb00 = sizeof(float); - const uint64_t nb10 = sizeof(float); - const uint64_t nb20 = sizeof(float); - - const int64_t nc = args.d_state; - const int64_t nr = args.d_inner; - const int64_t nh = args.n_head; - const int64_t ng = args.n_group; - const int64_t n_t = args.n_seq_tokens; - - const int64_t s_off = args.s_off; - - device const int32_t * ids = (device const int32_t *) src6; - - device const float * s0_buff = (device const float *) ((device const char *) src0 + ir*args.nb02 + ids[i3]*args.nb03); - device float * s_buff = (device float *) ((device char *) dst + ir*args.nb02 + i3*args.nb03 + s_off); - const int64_t i = i0 + i1*nc; - const int64_t g = ir / (nh / ng); // repeat_interleave - float s0 = s0_buff[i]; - float s = s_buff[i]; - - device const float * A = (device const float *) ((device const char *) src3 + ir*args.nb31); // {1, nh} - device const float * x_block = (device const float *) ((device const char *) src1 + i1*nb10 + ir*args.nb11 + i3*args.nb13); - device const float * dt_block = (device const float *) ((device const char *) src2 + ir*nb20 + i3*args.nb22); - device const float * B_block = (device const float *) ((device const char *) src4 + g*args.nb41 + i3*args.nb43); - device const float * C_block = (device const float *) ((device const char *) src5 + g*args.nb51 + i3*args.nb53); - device float * y_block = (device float *) ((device char *) dst + (i1 + ir*(nr) + i3*(n_t*nh*nr))*nb00); - - for (int64_t i2 = 0; i2 < n_t; ++i2) { - device const float * x = (device const float *) ((device const char *) x_block + i2*args.nb12); // {dim, nh, nt, ns} - device const float * dt = (device const float *) ((device const char *) dt_block + i2*args.nb21); // {nh, nt, ns} - device const float * B = (device const float *) ((device const char *) B_block + i2*args.nb42); // {d_state, ng, nt, ns} - device const float * C = (device const float *) ((device const char *) C_block + i2*args.nb52); // {d_state, ng, nt, ns} - device float * y = (device float *) ((device char *) y_block + i2*(nh*nr*nb00)); // {dim, nh, nt, ns} - - const float dt_soft_plus = dt[0] <= 20.0f ? log(1.0f + exp(dt[0])) : dt[0]; - const float x_dt = x[0] * dt_soft_plus; - const float dA = exp(dt_soft_plus * A[0]); - - const float state = (s0 * dA) + (B[i0] * x_dt); - s = state; - - // Parallel sum: This relies on the fact that this kernel will be - // dispatched with each threadgroup having (d_state, 1, 1) threads which - // are subdivided into SIMD groups of size `sgptg`. The goal is to - // compute y = sum({state * C[i] for i in range(d_state)}). - // To parallelize this effectively, we first use simd_sum over each SIMD - // group to compute the sum of each SIMD group, then place the result in - // the SIMD group's indexed bucket in the shared memory. We then sum - // over the individual group sums to compute the final sum. - - // Computed for each thread - float sumf = state * C[i0]; - - // Sum the threads in the simd group => simd sum - sumf = simd_sum(sumf); - - // Once per simd group, place the group sum into the shared buffer - if (tiisg == 0) { - shared[sgitg] = sumf; - } - - // Wait for all threads in the threadgroup to reach this point. This - // ensures that all elements of the shared buffer are populated with the - // sum of the individual simd groups. + for (int i2 = 0; i2 < n_t; i2 += sgptg) { threadgroup_barrier(mem_flags::mem_threadgroup); - // For simd group 0 at indices < num simd groups, extract the shared - // simd sum - sumf = 0.0f; - if (sgitg == 0) { - if (tiisg < sgptg) { - sumf = shared[tiisg]; - } - sumf = simd_sum(sumf); + for (int t = 0; t < sgptg && i2 + t < n_t; t++) { + const float dt0 = dt[0]; + const float dtsp = dt0 <= 20.0f ? log(1.0f + exp(dt0)) : dt0; + const float x_dt = x[0] * dtsp; + const float dA = exp(dtsp * A0); + + s = (s0 * dA) + (B[i0] * x_dt); + + const float sumf = simd_sum(s * C[i0]); + if (tiisg == 0) { - y[0] = sumf; + shared[t*NW + sgitg] = sumf; } + + // recurse + s0 = s; + + x += args.ns12; + dt += args.ns21; + B += args.ns42; + C += args.ns52; } - // recurse - s0 = s; + threadgroup_barrier(mem_flags::mem_threadgroup); + + const float sumf = simd_sum(shared[sgitg*NW + tiisg]); + + if (tiisg == 0 && i2 + sgitg < n_t) { + y[sgitg*nh*nr] = sumf; + } + + y += sgptg*nh*nr; } - // Assign the final state to the output buffer s_buff[i] = s; } @@ -4513,10 +4431,142 @@ kernel void kernel_leaky_relu_f32_4( dst[tpig] = float4(x > 0.0f)*x + float4(x <= 0.0f)*(x * args.slope); } +constant bool FC_flash_attn_ext_pad_has_mask [[function_constant(FC_FLASH_ATTN_EXT_PAD + 0)]]; + +constant int32_t FC_flash_attn_ext_pad_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_PAD + 25)]]; + +// pad the last chunk of C elements of k and v into a an extra pad buffer +kernel void kernel_flash_attn_ext_pad( + constant ggml_metal_kargs_flash_attn_ext_pad & args, + device const char * k, + device const char * v, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { + const int32_t C = FC_flash_attn_ext_pad_ncpsg; + + device char * k_pad = dst; + device char * v_pad = k_pad + args.nb11*C*args.ne_12_2*args.ne_12_3; + device char * mask_pad = v_pad + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const int32_t icp = args.ne11 % C; + const int32_t ic0 = args.ne11 - icp; + + const int32_t i1 = tgpig[0]; + const int32_t i2 = tgpig[1]; + const int32_t i3 = tgpig[2]; + + if (i2 < args.ne_12_2 && i3 < args.ne_12_3) { + device const char * k_src = k + args.nb11*(ic0 + i1) + args.nb12*i2 + args.nb13*i3; + device const char * v_src = v + args.nb21*(ic0 + i1) + args.nb22*i2 + args.nb23*i3; + + device char * k_dst = k_pad + args.nb11*i1 + args.nb11*C*i2 + args.nb11*C*args.ne_12_2*i3; + device char * v_dst = v_pad + args.nb21*i1 + args.nb21*C*i2 + args.nb21*C*args.ne_12_2*i3; + + if (i1 >= icp) { + // here it is not important the exact value that will be used as we rely on masking out the scores in the attention + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = 0; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = 0; + } + } else { + for (uint64_t i = tiitg; i < args.nb11; i += ntg.x) { + k_dst[i] = k_src[i]; + } + for (uint64_t i = tiitg; i < args.nb21; i += ntg.x) { + v_dst[i] = v_src[i]; + } + } + } + + if (FC_flash_attn_ext_pad_has_mask) { + if (i2 < args.ne32 && i3 < args.ne33) { + for (int ib = i1; ib < args.ne31; ib += C) { + device const half * mask_src = (device const half *)(mask + args.nb31*ib + args.nb32*i2 + args.nb33*i3) + ic0; + device half * mask_dst = (device half *)(mask_pad) + C*ib + C*args.ne31*i2 + C*args.ne31*args.ne32*i3; + + for (int i = tiitg; i < C; i += ntg.x) { + if (i >= icp) { + mask_dst[i] = -MAXHALF; + } else { + mask_dst[i] = mask_src[i]; + } + } + } + } + } +} + +constant int32_t FC_flash_attn_ext_blk_nqptg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 24)]]; +constant int32_t FC_flash_attn_ext_blk_ncpsg [[function_constant(FC_FLASH_ATTN_EXT_BLK + 25)]]; + +// scan the blocks of the mask that are not masked +// 0 - masked (i.e. full of -INF, skip) +// 1 - not masked (i.e. at least one element of the mask is not -INF) +kernel void kernel_flash_attn_ext_blk( + constant ggml_metal_kargs_flash_attn_ext_blk & args, + device const char * mask, + device char * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiisg[[thread_index_in_simdgroup]]) { + // block size C x Q + const int32_t Q = FC_flash_attn_ext_blk_nqptg; + const int32_t C = FC_flash_attn_ext_blk_ncpsg; + + constexpr short NW = N_SIMDWIDTH; + + const int32_t i3 = tgpig[2]/args.ne32; + const int32_t i2 = tgpig[2]%args.ne32; + const int32_t i1 = tgpig[1]; + const int32_t i0 = tgpig[0]; + + char res = i0*C + C > args.ne30 ? 1 : 0; + + device const half * mask_src = (device const half *) (mask + (i1*Q)*args.nb31 + i2*args.nb32 + i3*args.nb33) + i0*C + tiisg; + + // fast route + if (res == 0) { + if (simd_max(*mask_src) > -MAXHALF/2) { + res = 1; + } + } + + // detailed check of the elements of the block + if ((C > NW || Q > 1) && res == 0) { + half m = -MAXHALF; + + FOR_UNROLL (short j = 0; j < Q; ++j) { + FOR_UNROLL (short ii = 0; ii < C/NW; ++ii) { + m = max(m, mask_src[ii*NW]); + } + + mask_src += args.nb31/2; + } + + if (simd_max(m) > -MAXHALF/2) { + res = 1; + } + } + + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne30 + C - 1)/C); + + if (tiisg == 0) { + dst[((i3*args.ne32 + i2)*nblk1 + i1)*nblk0 + i0] = res; + } +} + constant bool FC_flash_attn_ext_has_mask [[function_constant(FC_FLASH_ATTN_EXT + 0)]]; constant bool FC_flash_attn_ext_has_sinks [[function_constant(FC_FLASH_ATTN_EXT + 1)]]; constant bool FC_flash_attn_ext_has_bias [[function_constant(FC_FLASH_ATTN_EXT + 2)]]; constant bool FC_flash_attn_ext_has_scap [[function_constant(FC_FLASH_ATTN_EXT + 3)]]; +constant bool FC_flash_attn_ext_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT + 4)]]; + +constant bool FC_flash_attn_ext_bc_mask [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_scale [[function_constant(FC_FLASH_ATTN_EXT + 10)]]; //constant float FC_flash_attn_ext_max_bias [[function_constant(FC_FLASH_ATTN_EXT + 11)]]; @@ -4563,6 +4613,8 @@ void kernel_flash_attn_ext_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16, uint3 tgpig, @@ -4628,6 +4680,13 @@ void kernel_flash_attn_ext_impl( pm2[jj] = (device const half2 *) ((device const char *) mask + (iq1 + j)*args.nb31 + (iq2%args.ne32)*args.nb32 + (iq3%args.ne33)*args.nb33); } + { + const int32_t nblk1 = ((args.ne01 + Q - 1)/Q); + const int32_t nblk0 = ((args.ne11 + C - 1)/C); + + blk += (((iq3%args.ne33)*args.ne32 + (iq2%args.ne32))*nblk1 + iq1/Q)*nblk0; + } + { q += iq1*args.nb01 + iq2*args.nb02 + iq3*args.nb03; @@ -4687,16 +4746,75 @@ void kernel_flash_attn_ext_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic = 0; ic < args.ne11; ic += C) { + for (int ic0 = 0; ; ++ic0) { + int ic = ic0*C; + if (ic >= args.ne11) { + break; + } + + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_has_mask) { + threadgroup half * sm = (threadgroup half *) (sm2); + + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + for (short i = tiisg; i < C; i += NW) { + if (ic + i >= args.ne11) { + sm[2*j*SH + i] = -MAXHALF; + } + } + } + } else { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + const short j = jj*NSG + sgitg; + + pm2[jj] = (device const half2 *) ((device const half *) mask + + (iq1 + j)*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32)); + } + } + + ic = 0; + } + // read the mask into shared mem if (FC_flash_attn_ext_has_mask) { + if (blk[ic0] == 0) { + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { + pm2[jj] += NW; + } + + continue; + } + FOR_UNROLL (short jj = 0; jj < NQ; ++jj) { const short j = jj*NSG + sgitg; - sm2[j*SH + tiisg] = pm2[jj][tiisg]; + if (FC_flash_attn_ext_bc_mask) { + sm2[j*SH + tiisg] = (iq1 + j) < args.ne31 ? pm2[jj][tiisg] : half2(-MAXHALF, -MAXHALF); + } else { + sm2[j*SH + tiisg] = pm2[jj][tiisg]; + } + pm2[jj] += NW; } +#if 0 + // note: old -INF block optimization - obsoleted by pre-computing non-masked blocks + threadgroup_barrier(mem_flags::mem_threadgroup); // used to detect blocks full of -INF @@ -4715,13 +4833,14 @@ void kernel_flash_attn_ext_impl( continue; } +#endif } // Q*K^T // this is compile-time check, so it does not have runtime overhead if (is_same::value) { // we can read directly from global memory - device const k_t * pk = (device const k_t *) ((device const char *) k + ic*args.nb11); + device const k_t * pk = (device const k_t *) (k + ic*args.nb11); threadgroup const q_t * pq = sq; threadgroup s_t * ps = ss; @@ -4732,26 +4851,24 @@ void kernel_flash_attn_ext_impl( constexpr short NC = (C/8)/NSG; - // TODO: not good to unroll for large contexts - not sure why? + // note: do not unroll for large heads + #pragma unroll (DK <= 64 ? NC : 1) for (short cc = 0; cc < NC; ++cc) { qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); - if (DK8 % 16 != 0) { + if (DK % 16 != 0) { k8x8_t mk; q8x8_t mq; FOR_UNROLL (short i = 0; i < DK8; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk, pk, NS10, 0, true); - simdgroup_load(mq, pq, DK); + simdgroup_load(mk, pk + 8*i, NS10, 0, true); + simdgroup_load(mq, pq + 8*i, DK); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq, mk, mqk); - - pk += 8; - pq += 8; } } else { k8x8_t mk[2]; @@ -4760,26 +4877,22 @@ void kernel_flash_attn_ext_impl( FOR_UNROLL (short i = 0; i < DK8/2; ++i) { simdgroup_barrier(mem_flags::mem_none); - simdgroup_load(mk[0], pk + 0*8, NS10, 0, true); - simdgroup_load(mk[1], pk + 1*8, NS10, 0, true); + simdgroup_load(mq[0], pq + 0*8 + 16*i, DK); + simdgroup_load(mq[1], pq + 1*8 + 16*i, DK); - simdgroup_load(mq[0], pq + 0*8, DK); - simdgroup_load(mq[1], pq + 1*8, DK); + simdgroup_load(mk[0], pk + 0*8 + 16*i, NS10, 0, true); + simdgroup_load(mk[1], pk + 1*8 + 16*i, NS10, 0, true); simdgroup_barrier(mem_flags::mem_none); simdgroup_multiply_accumulate(mqk, mq[0], mk[0], mqk); simdgroup_multiply_accumulate(mqk, mq[1], mk[1], mqk); - - pk += 16; - pq += 16; } } simdgroup_store(mqk, ps, SH, 0, false); - pk += 8*(NSG*NS10 - DK8); - pq += 8*(NSG*0 - DK8); + pk += 8*(NSG*NS10); ps += 8*(NSG); } } else { @@ -4793,7 +4906,7 @@ void kernel_flash_attn_ext_impl( qk8x8_t mqk = make_filled_simdgroup_matrix((qk_t) 0.0f); for (short ii = 0; ii < DK16; ii += 4) { - device const kd4x4_t * pk4x4 = (device const kd4x4_t *) ((device const char *) k + ((ic + 8*cc + ty)*args.nb11)); + device const kd4x4_t * pk4x4 = (device const kd4x4_t *) (k + ((ic + 8*cc + ty)*args.nb11)); if (DK16%4 == 0) { // the head is evenly divisible by 4*16 = 64, so no need for bound checks @@ -4913,27 +5026,50 @@ void kernel_flash_attn_ext_impl( } { - auto sst = ss; - - device const v_t * pv = (device const v_t *) ((device const char *) v + ic*args.nb21); + device const v_t * pv = (device const v_t *) (v + ic*args.nb21); pv += 8*sgitg; - FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { - s8x8_t vs; - simdgroup_load(vs, sst, SH, 0, false); + if (DV <= 64) { + FOR_UNROLL (short cc = 0; cc < C/8; ++cc) { + s8x8_t vs; + simdgroup_load(vs, ss + 8*cc, SH, 0, false); - FOR_UNROLL (short ii = 0; ii < NO; ++ii) { - v8x8_t mv; + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[2]; - simdgroup_load(mv, pv, NS20, 0, false); - simdgroup_multiply_accumulate(lo[ii], vs, mv, lo[ii]); + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG, NS20, 0, false); - pv += 8*NSG; + simdgroup_multiply_accumulate(lo[2*ii + 0], vs, mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs, mv[1], lo[2*ii + 1]); + } + + pv += 8*NS20; } + } else { + FOR_UNROLL (short cc = 0; cc < (C/8)/2; ++cc) { + s8x8_t vs[2]; - pv += 8*(NS20 - NO*NSG); - sst += 8; + simdgroup_load(vs[0], ss + 16*cc + 0, SH, 0, false); + simdgroup_load(vs[1], ss + 16*cc + 8, SH, 0, false); + + FOR_UNROLL (short ii = 0; ii < NO/2; ++ii) { + v8x8_t mv[4]; + + simdgroup_load(mv[0], pv + 0*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[1], pv + 8*NSG + 16*ii*NSG + 0*8*NS20, NS20, 0, false); + simdgroup_load(mv[2], pv + 0*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + simdgroup_load(mv[3], pv + 8*NSG + 16*ii*NSG + 1*8*NS20, NS20, 0, false); + + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[0], mv[0], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[0], mv[1], lo[2*ii + 1]); + simdgroup_multiply_accumulate(lo[2*ii + 0], vs[1], mv[2], lo[2*ii + 0]); + simdgroup_multiply_accumulate(lo[2*ii + 1], vs[1], mv[3], lo[2*ii + 1]); + } + + pv += 2*8*NS20; + } } } @@ -4957,7 +5093,7 @@ void kernel_flash_attn_ext_impl( simdgroup_load(vs, ss + 8*cc, SH, 0, false); for (short ii = 4*sgitg; ii < DV16; ii += 4*NSG) { - device const vd4x4_t * pv4x4 = (device const vd4x4_t *) ((device const char *) v + ((ic + 8*cc + ty)*args.nb21)); + device const vd4x4_t * pv4x4 = (device const vd4x4_t *) (v + ((ic + 8*cc + ty)*args.nb21)); if (DV16%4 == 0) { // no need for bound checks @@ -5047,7 +5183,7 @@ void kernel_flash_attn_ext_impl( device float4 * dst4 = (device float4 *) dst + ((uint64_t)iq3*args.ne2*args.ne1 + iq2 + (uint64_t)(iq1 + j)*args.ne1)*DV4; - const float scale = 1.0f/S[jj]; + const float scale = S[jj] == 0.0 ? 0.0f : 1.0f/S[jj]; if (DV4 % NW == 0) { FOR_UNROLL (short ii = 0; ii < DV4/NW; ++ii) { @@ -5092,8 +5228,8 @@ template< void (*deq_v)(device const vd4x4_t *, short, thread v4x4_t &), short DK, // K head size short DV, // V head size - short Q = 8, // queries per threadgroup - short C = 64> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext( constant ggml_metal_kargs_flash_attn_ext & args, device const char * q, @@ -5101,13 +5237,15 @@ kernel void kernel_flash_attn_ext( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, + device const char * blk, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q_t, q4_t, q8x8_t, k_t, k4x4_t, k8x8_t, v_t, v4x4_t, v8x8_t, qk_t, qk8x8_t, s_t, s2_t, s8x8_t, o_t, o4_t, o8x8_t, kd4x4_t, nl_k, deq_k, vd4x4_t, nl_v, deq_v, DK, DV, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, blk, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_nsg) { // note: disabled cases to reduce library load time //case 1: kernel_flash_attn_ext_impl(FWD_ARGS); break; @@ -5227,6 +5365,7 @@ constant bool FC_flash_attn_ext_vec_has_mask [[function_constant(FC_FLASH_ATTN_ constant bool FC_flash_attn_ext_vec_has_sinks [[function_constant(FC_FLASH_ATTN_EXT_VEC + 1)]]; constant bool FC_flash_attn_ext_vec_has_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 2)]]; constant bool FC_flash_attn_ext_vec_has_scap [[function_constant(FC_FLASH_ATTN_EXT_VEC + 3)]]; +constant bool FC_flash_attn_ext_vec_has_kvpad [[function_constant(FC_FLASH_ATTN_EXT_VEC + 4)]]; //constant float FC_flash_attn_ext_vec_scale [[function_constant(FC_FLASH_ATTN_EXT_VEC + 10)]]; //constant float FC_flash_attn_ext_vec_max_bias [[function_constant(FC_FLASH_ATTN_EXT_VEC + 11)]]; @@ -5253,9 +5392,9 @@ template< void (*deq_v_t4)(device const vd4_t *, short, thread v4_t &), short DK, // K head size short DV, // V head size - short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32, // cache items per threadgroup + short NE, // head elements per thread + short Q, // queries per threadgroup + short C, // cache items per threadgroup short NSG> // number of simd groups void kernel_flash_attn_ext_vec_impl( constant ggml_metal_kargs_flash_attn_ext_vec & args, @@ -5264,6 +5403,7 @@ void kernel_flash_attn_ext_vec_impl( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], @@ -5369,12 +5509,38 @@ void kernel_flash_attn_ext_vec_impl( // loop over the KV cache // each simdgroup handles blocks of Q rows and C columns - for (int ic0 = (int) iwg*C*NSG; ic0 < args.ne11; ic0 += (int) NWG*C*NSG) { - const int ic = ic0 + C*sgitg; + for (int ic0 = iwg*NSG + sgitg; ; ic0 += NWG*NSG) { + int ic = ic0*C; if (ic >= args.ne11) { break; } + // the last partial chunk uses the pad buffer as source + if (FC_flash_attn_ext_vec_has_kvpad && ic + C > args.ne11) { + k = pad; + v = k + args.nb11*C*args.ne_12_2*args.ne_12_3; + mask = v + args.nb21*C*args.ne_12_2*args.ne_12_3; + + const short ikv2 = iq2/(args.ne02/args.ne_12_2); + const short ikv3 = iq3/(args.ne03/args.ne_12_3); + + k += (ikv2 + ikv3*args.ne_12_2)*args.nb11*C; + v += (ikv2 + ikv3*args.ne_12_2)*args.nb21*C; + + if (!FC_flash_attn_ext_vec_has_mask) { + if (ic + tiisg >= args.ne11) { + sm[tiisg] = -MAXHALF; + } + } else { + pm = (device const half *) (mask) + + iq1*C + + (iq2%args.ne32)*(C*args.ne31) + + (iq3%args.ne33)*(C*args.ne31*args.ne32); + } + + ic = 0; + } + if (FC_flash_attn_ext_vec_has_mask) { sm[tiisg] = pm[ic + tiisg]; } @@ -5386,7 +5552,7 @@ void kernel_flash_attn_ext_vec_impl( // Q*K^T { - device const k4_t * pk4 = (device const k4_t *) ((device const char *) k + ic*args.nb11); + device const k4_t * pk4 = (device const k4_t *) (k + ic*args.nb11); threadgroup const q4_t * pq4 = sq4; pk4 += ty*NS10/4 + tx; @@ -5401,7 +5567,7 @@ void kernel_flash_attn_ext_vec_impl( mqk[cc] += dot((float4) pk4[cc*NE*NS10/4 + ii*NL], (float4) pq4[ii*NL]); } } else { - device const kd4_t * pk = (device const kd4_t *) ((device const char *) k + ((ic + NE*cc + ty)*args.nb11)); + device const kd4_t * pk = (device const kd4_t *) (k + ((ic + NE*cc + ty)*args.nb11)); k4_t mk; @@ -5499,7 +5665,7 @@ void kernel_flash_attn_ext_vec_impl( } if (is_same::value) { - device const v4_t * pv4 = (device const v4_t *) ((device const char *) v + ic*args.nb21); + device const v4_t * pv4 = (device const v4_t *) (v + ic*args.nb21); pv4 += ty*NS20/4 + tx; @@ -5512,7 +5678,7 @@ void kernel_flash_attn_ext_vec_impl( } } else { FOR_UNROLL (short cc = 0; cc < C/NE; ++cc) { - device const vd4_t * pv4 = (device const vd4_t *) ((device const char *) v + ((ic + NE*cc + ty)*args.nb21)); + device const vd4_t * pv4 = (device const vd4_t *) (v + ((ic + NE*cc + ty)*args.nb21)); FOR_UNROLL (short ii = 0; ii < DV4/NL; ++ii) { const short i = ii*NL + tx; @@ -5637,7 +5803,7 @@ void kernel_flash_attn_ext_vec_impl( device float4 * dst4 = (device float4 *) dst; device float * dst1 = (device float *) dst + nrows*DV*NWG; // the S and M are stored after the results - const float S = NWG == 1 ? 1.0f/ss[0] : 1.0f; + const float S = NWG == 1 ? (ss[0] == 0.0f ? 0.0f : 1.0f/ss[0]) : 1.0f; // interleave the workgroup data for (short i = tiisg; i < DV4; i += NW) { @@ -5675,8 +5841,8 @@ template< short DK, // K head size short DV, // V head size short NE = 4, // head elements per thread - short Q = 1, // queries per threadgroup - short C = 32> // cache items per threadgroup + short Q = OP_FLASH_ATTN_EXT_VEC_NQPTG, // queries per threadgroup + short C = OP_FLASH_ATTN_EXT_VEC_NCPSG> // cache items per threadgroup kernel void kernel_flash_attn_ext_vec( constant ggml_metal_kargs_flash_attn_ext_vec & args, device const char * q, @@ -5684,13 +5850,14 @@ kernel void kernel_flash_attn_ext_vec( device const char * v, device const char * mask, device const char * sinks, + device const char * pad, device char * dst, threadgroup half * shmem_f16 [[threadgroup(0)]], uint3 tgpig[[threadgroup_position_in_grid]], ushort tiisg[[thread_index_in_simdgroup]], ushort sgitg[[simdgroup_index_in_threadgroup]]) { #define FWD_TMPL q4_t, k4_t, v4_t, qk_t, s_t, s4_t, o4_t, kd4_t, nl_k, deq_k_t4, vd4_t, nl_v, deq_v_t4, DK, DV, NE, Q, C -#define FWD_ARGS args, q, k, v, mask, sinks, dst, shmem_f16, tgpig, tiisg, sgitg +#define FWD_ARGS args, q, k, v, mask, sinks, pad, dst, shmem_f16, tgpig, tiisg, sgitg switch (FC_flash_attn_ext_vec_nsg) { // note: disabled cases to reduce library load time case 1: kernel_flash_attn_ext_vec_impl(FWD_ARGS); break; @@ -5814,7 +5981,8 @@ kernel void kernel_flash_attn_ext_vec_reduce( const float m = simd_max(M); const float ms = exp(M - m); - S = 1.0f/simd_sum(S*ms); + S = simd_sum(S*ms); + S = S == 0.0f ? 0.0f : 1.0f/S; const short DV4 = DV/4; @@ -5834,21 +6002,17 @@ kernel void kernel_flash_attn_ext_vec_reduce( } template -kernel void kernel_cpy( +kernel void kernel_cpy_t_t( constant ggml_metal_kargs_cpy & args, device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 tptg[[threads_per_threadgroup]]) { + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]*tptg.y + tiitg/tptg.x; - - if (i01 >= args.ne01) { - return; - } + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -5859,190 +6023,70 @@ kernel void kernel_cpy( device T1 * dst_data = (device T1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tiitg%tptg.x; i00 < args.ne00; i00 += tptg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.ne00; ) { device const T0 * src = (device T0 *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); dst_data[i00] = (T1) src[0]; + break; } } -typedef decltype(kernel_cpy) kernel_cpy_t; +typedef decltype(kernel_cpy_t_t) kernel_cpy_t; -template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f32_i32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_i32_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f32_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_f16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_f16_f16")]] kernel kernel_cpy_t kernel_cpy_t_t; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy; -template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy; +template [[host_name("kernel_cpy_bf16_f32")]] kernel kernel_cpy_t kernel_cpy_t_t; +template [[host_name("kernel_cpy_bf16_bf16")]] kernel kernel_cpy_t kernel_cpy_t_t; #endif -// TODO: templetify these kernels -kernel void kernel_cpy_f32_q8_0( +template +kernel void kernel_cpy_f32_q( constant ggml_metal_kargs_cpy & args, device const char * src0, - device char * dst, + device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK8_0; + const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK; - device block_q8_0 * dst_data = (device block_q8_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + device block_q * dst_data = (device block_q *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x*QK8_0; i00 < args.ne00; i00 += ntg.x*QK8_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { + device const float * src = (device const float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + (i00*QK)*args.nb00); - quantize_q8_0(src, dst_data[i00/QK8_0]); + quantize_func(src, dst_data[i00]); + + break; } } -kernel void kernel_cpy_f32_q4_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; +typedef decltype(kernel_cpy_f32_q) cpy_f_q_t; - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_0; - - device block_q4_0 * dst_data = (device block_q4_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_0; i00 < args.ne00; i00 += ntg.x*QK4_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_0(src, dst_data[i00/QK4_0]); - } -} - -kernel void kernel_cpy_f32_q4_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_1; - - device block_q4_1 * dst_data = (device block_q4_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_1; i00 < args.ne00; i00 += ntg.x*QK4_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q4_1(src, dst_data[i00/QK4_1]); - } -} - -kernel void kernel_cpy_f32_q5_0( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_0; - - device block_q5_0 * dst_data = (device block_q5_0 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_0; i00 < args.ne00; i00 += ntg.x*QK5_0) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_0(src, dst_data[i00/QK5_0]); - } -} - -kernel void kernel_cpy_f32_q5_1( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK5_1; - - device block_q5_1 * dst_data = (device block_q5_1 *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK5_1; i00 < args.ne00; i00 += ntg.x*QK5_1) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_q5_1(src, dst_data[i00/QK5_1]); - } -} - -kernel void kernel_cpy_f32_iq4_nl( - constant ggml_metal_kargs_cpy & args, - device const char * src0, - device char * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], - ushort3 ntg[[threads_per_threadgroup]]) { - const int i03 = tgpig[2]; - const int i02 = tgpig[1]; - const int i01 = tgpig[0]; - - const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; - - const int64_t i3 = n / (args.ne2*args.ne1*args.ne0); - const int64_t i2 = (n - i3*args.ne2*args.ne1*args.ne0) / (args.ne1*args.ne0); - const int64_t i1 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0) / args.ne0; - const int64_t i0 = (n - i3*args.ne2*args.ne1*args.ne0 - i2*args.ne1*args.ne0 - i1*args.ne0)/QK4_NL; - - device block_iq4_nl * dst_data = (device block_iq4_nl *) (dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - - for (int64_t i00 = tpitg.x*QK4_NL; i00 < args.ne00; i00 += ntg.x*QK4_NL) { - device const float * src = (device float *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01 + i00*args.nb00); - - quantize_iq4_nl(src, dst_data[i00/QK4_NL]); - } -} +template [[host_name("kernel_cpy_f32_q8_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q4_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_0")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_q5_1")]] kernel cpy_f_q_t kernel_cpy_f32_q; +template [[host_name("kernel_cpy_f32_iq4_nl")]] kernel cpy_f_q_t kernel_cpy_f32_q; template kernel void kernel_cpy_q_f32( @@ -6050,11 +6094,12 @@ kernel void kernel_cpy_q_f32( device const char * src0, device char * dst, uint3 tgpig[[threadgroup_position_in_grid]], - ushort3 tpitg[[thread_position_in_threadgroup]], + ushort tiitg[[thread_index_in_threadgroup]], ushort3 ntg[[threads_per_threadgroup]]) { const int i03 = tgpig[2]; const int i02 = tgpig[1]; - const int i01 = tgpig[0]; + const int i01 = ntg[1] == 1 ? tgpig[0]%args.ne01 : tgpig[0]*ntg[1] + tiitg/ntg[0]; + const int iw0 = ntg[1] == 1 ? tgpig[0]/args.ne01 : 0; const int64_t n = i03*args.ne02*args.ne01*args.ne00 + i02*args.ne01*args.ne00 + i01*args.ne00; @@ -6066,10 +6111,12 @@ kernel void kernel_cpy_q_f32( device const block_q * src_data = (device const block_q *)(src0 + i03*args.nb03 + i02*args.nb02 + i01*args.nb01); device T4x4 * dst_data = (device T4x4 *)(dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); - for (int64_t i00 = tpitg.x; i00 < args.ne00/16; i00 += ntg.x) { + for (int64_t i00 = iw0*ntg[0] + tiitg%ntg[0]; i00 < args.nk0; ) { T4x4 temp; dequantize_func(src_data + i00/nl, i00%nl, temp); dst_data[i00] = temp; + + break; } } @@ -7522,7 +7569,7 @@ kernel void kernel_mul_mv_iq1_m_f32( kernel_mul_mv_iq1_m_f32_impl(args, src0, src1, dst, nullptr, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_nl_f32_impl( args_t args, device const char * src0, @@ -7535,13 +7582,12 @@ void kernel_mul_mv_iq4_nl_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK4_NL; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7552,6 +7598,9 @@ void kernel_mul_mv_iq4_nl_f32_impl( device const block_iq4_nl * x = (device const block_iq4_nl *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK4_NL; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7559,24 +7608,25 @@ void kernel_mul_mv_iq4_nl_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK4_NL + it * 8; + device const float * yb = y + ix*QK4_NL + it*8; uint32_t aux32[2]; thread const uint8_t * q8 = (thread const uint8_t *)aux32; float4 qf1, qf2; - for (int ib = ix; ib < nb; ib += 16) { + // [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; row++) { - device const block_iq4_nl & xb = x[row*nb + ib]; + for (short row = 0; row < NR0; row++) { + device const block_iq4_nl & xb = x[row*ns01 + ib]; device const uint16_t * q4 = (device const uint16_t *)(xb.qs + 8*it); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7607,7 +7657,7 @@ void kernel_mul_mv_iq4_nl_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7629,7 +7679,7 @@ kernel void kernel_mul_mv_iq4_nl_f32( kernel_mul_mv_iq4_nl_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_iq4_xs_f32_impl( args_t args, device const char * src0, @@ -7642,12 +7692,11 @@ void kernel_mul_mv_iq4_xs_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_K; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7658,6 +7707,9 @@ void kernel_mul_mv_iq4_xs_f32_impl( device const block_iq4_xs * x = (device const block_iq4_xs *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_K; + const int ns01 = args.nb01/args.nb00; + const short ix = tiisg/16; // 0 or 1 const short it = tiisg%16; // 0...15 const short ib = it/2; @@ -7667,7 +7719,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; device const float * yb = y + ix * QK_K + ib * 32 + il * 8; @@ -7676,15 +7728,16 @@ void kernel_mul_mv_iq4_xs_f32_impl( float4 qf1, qf2; - for (int ibl = ix; ibl < nb; ibl += 2) { + // [TAG_MUL_MV_WEIRD] + for (int ibl = ix; ibl < nb && ibl < ns01; ibl += 2) { device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; - for (short row = 0; row < nr0; ++row) { - device const block_iq4_xs & xb = x[row*nb + ibl]; + for (short row = 0; row < NR0; ++row) { + device const block_iq4_xs & xb = x[row*ns01 + ibl]; device const uint32_t * q4 = (device const uint32_t *)(xb.qs + 16*ib + 8*il); float4 acc1 = {0.f}, acc2 = {0.f}; @@ -7714,7 +7767,7 @@ void kernel_mul_mv_iq4_xs_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7736,7 +7789,7 @@ kernel void kernel_mul_mv_iq4_xs_f32( kernel_mul_mv_iq4_xs_f32_impl(args, src0, src1, dst, shmem, tgpig, tiisg, sgitg); } -template +template void kernel_mul_mv_mxfp4_f32_impl( args_t args, device const char * src0, @@ -7749,13 +7802,12 @@ void kernel_mul_mv_mxfp4_f32_impl( const short NSG = FC_mul_mv_nsg; threadgroup float * shmem_f32 = (threadgroup float *) shmem; - const int nb = args.ne00/QK_MXFP4; const int r0 = tgpig.x; const int r1 = tgpig.y; const int im = tgpig.z; - const int first_row = (r0 * NSG + sgitg) * nr0; + const int first_row = (r0 * NSG + sgitg) * NR0; const uint i12 = im%args.ne12; const uint i13 = im/args.ne12; @@ -7766,6 +7818,9 @@ void kernel_mul_mv_mxfp4_f32_impl( device const block_mxfp4 * x = (device const block_mxfp4 *) (src0 + offset0); device const float * y = (device const float *) (src1 + offset1); + const int nb = args.ne00/QK_MXFP4; + const int ns01 = args.nb01/args.nb00; // this can be larger than nb for permuted src0 tensors + const short ix = tiisg/2; // 0...15 const short it = tiisg%2; // 0 or 1 @@ -7773,20 +7828,22 @@ void kernel_mul_mv_mxfp4_f32_impl( threadgroup_barrier(mem_flags::mem_threadgroup); float4 yl[4]; - float sumf[nr0]={0.f}; + float sumf[NR0]={0.f}; - device const float * yb = y + ix * QK_MXFP4 + it * 8; + device const float * yb = y + ix*QK_MXFP4 + it*8; + + // note: just the check `ib < nb` is enough, but adding the redundant `&& ib < ns01` check makes the kernel a bit faster + // no idea why that is - needs some deeper investigation [TAG_MUL_MV_WEIRD] + for (int ib = ix; ib < nb && ib < ns01; ib += 16) { + device const float4 * y4 = (device const float4 *) yb; - for (int ib = ix; ib < nb; ib += 16) { - device const float4 * y4 = (device const float4 *)yb; yl[0] = y4[0]; yl[1] = y4[4]; yl[2] = y4[1]; yl[3] = y4[5]; -#pragma unroll(nr0) - for (short row = 0; row < nr0; row++) { - device const block_mxfp4 & xb = x[row*nb + ib]; + FOR_UNROLL (short row = 0; row < NR0; row++) { + device const block_mxfp4 & xb = x[row*ns01 + ib]; device const uint8_t * q2 = (device const uint8_t *)(xb.qs + 8*it); float4 acc1 = yl[0]*float4(shmem_f32[q2[0] & 0x0F], shmem_f32[q2[1] & 0x0F], shmem_f32[q2[2] & 0x0F], shmem_f32[q2[3] & 0x0F]); @@ -7804,7 +7861,7 @@ void kernel_mul_mv_mxfp4_f32_impl( device float * dst_f32 = (device float *) dst + (uint64_t)im*args.ne0*args.ne1 + (uint64_t)r1*args.ne0; - for (int row = 0; row < nr0 && first_row + row < args.ne0; ++row) { + for (int row = 0; row < NR0 && first_row + row < args.ne0; ++row) { float sum_all = simd_sum(sumf[row]); if (tiisg == 0) { dst_f32[first_row + row] = sum_all; @@ -7829,66 +7886,60 @@ kernel void kernel_mul_mv_mxfp4_f32( template kernel void kernel_get_rows_q( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int64_t ind = tiitg; ind < args.ne00/16; ind += tptg.x) { + auto psrc = (device const block_q *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = (device float4x4 *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); + + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { float4x4 temp; - dequantize_func(((device const block_q *) ((const device char *) src0 + r*args.nb01 + i02*args.nb02)) + ind/nl, ind%nl, temp); - *(((device float4x4 *) ((device char *) dst + i11*args.nb2 + i10*args.nb1)) + ind) = temp; + dequantize_func(psrc + ind/nl, ind%nl, temp); + pdst[ind] = temp; + + break; } } -template +template kernel void kernel_get_rows_f( constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device float * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + device const void * src0, + device const void * src1, + device void * dst, + uint3 tgpig[[threadgroup_position_in_grid]], + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 ntg [[threads_per_threadgroup]]) { + const int32_t iw0 = tgpig.x/args.ne10; + const int32_t i10 = tgpig.x%args.ne10; + const int32_t i11 = tgpig.y; + const int32_t i12 = tgpig.z; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; + const int32_t r = ((const device int32_t *) ((const device char *) src1 + i12*args.nb12 + i11*args.nb11 + i10*args.nb10))[0]; - const int64_t i02 = i11; + const int32_t i02 = i11; + const int32_t i03 = i12; - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device float *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device T *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; - } -} + auto psrc = (const device T0 *) ((const device char *) src0 + i03*args.nb03 + i02*args.nb02 + r*args.nb01); + auto pdst = ( device T *) (( device char *) dst + i12*args.nb3 + i11*args.nb2 + i10*args.nb1); -kernel void kernel_get_rows_i32( - constant ggml_metal_kargs_get_rows & args, - device const void * src0, - device const void * src1, - device int32_t * dst, - uint3 tgpig[[threadgroup_position_in_grid]], - uint tiitg[[thread_index_in_threadgroup]], - uint3 tptg [[threads_per_threadgroup]]) { - const int64_t i10 = tgpig.x; - const int64_t i11 = tgpig.y; + for (int ind = iw0*ntg.x + tiitg; ind < args.ne00t;) { + pdst[ind] = psrc[ind]; - const int64_t r = ((const device int32_t *) ((const device char *) src1 + i11*args.nb11 + i10*args.nb10))[0]; - - const int64_t i02 = i11; - - for (int ind = tiitg; ind < args.ne00; ind += tptg.x) { - (( device int32_t *) (( device char *) dst + i11*args.nb2 + i10*args.nb1))[ind] = - ((const device int32_t *) ((const device char *) src0 + i02*args.nb02 + r*args.nb01))[ind]; + break; } } @@ -8374,12 +8425,13 @@ kernel void kernel_mul_mm_id( // get rows // -typedef decltype(kernel_get_rows_f) get_rows_f_t; +typedef decltype(kernel_get_rows_f) get_rows_f_t; -template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; -template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f32")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_f16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_i32")]] kernel get_rows_f_t kernel_get_rows_f; #if defined(GGML_METAL_HAS_BF16) -template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; +template [[host_name("kernel_get_rows_bf16")]] kernel get_rows_f_t kernel_get_rows_f; #endif typedef decltype(kernel_get_rows_q) get_rows_q_t; @@ -8784,3 +8836,51 @@ kernel void kernel_pool_2d_avg_f32( o_ptr[cur_oh * args.OW + cur_ow] = res; } + +kernel void kernel_opt_step_adamw_f32( + constant ggml_metal_kargs_opt_step_adamw & args, + device float * x, + device const float * g, + device float * g_m, + device float * g_v, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + const float alpha = pars[0]; + const float beta1 = pars[1]; + const float beta2 = pars[2]; + const float eps = pars[3]; + const float wd = pars[4]; + const float beta1h = pars[5]; + const float beta2h = pars[6]; + + const float gi = g[gid]; + const float gmi = g_m[gid] * beta1 + gi * (1.0f - beta1); + const float gvi = g_v[gid] * beta2 + gi * gi * (1.0f - beta2); + + g_m[gid] = gmi; + g_v[gid] = gvi; + + const float mh = gmi * beta1h; + const float vh = sqrt(gvi * beta2h) + eps; + + x[gid] = x[gid] * (1.0f - alpha * wd) - alpha * mh / vh; +} + +kernel void kernel_opt_step_sgd_f32( + constant ggml_metal_kargs_opt_step_sgd & args, + device float * x, + device const float * g, + device const float * pars, + uint gid[[thread_position_in_grid]]) { + + if (gid >= args.np) { + return; + } + + x[gid] = x[gid] * (1.0f - pars[0] * pars[1]) - pars[0] * g[gid]; +} diff --git a/ml/backend/ggml/ggml/src/ggml.c b/ml/backend/ggml/ggml/src/ggml.c index aecbdad5a3..2bce1375ba 100644 --- a/ml/backend/ggml/ggml/src/ggml.c +++ b/ml/backend/ggml/ggml/src/ggml.c @@ -1143,10 +1143,10 @@ static const char * GGML_UNARY_OP_NAME[GGML_UNARY_OP_COUNT] = { "HARDSIGMOID", "EXP", "GELU_ERF", + "XIELU", }; -static_assert(GGML_UNARY_OP_COUNT == 15, "GGML_UNARY_OP_COUNT != 15"); - +static_assert(GGML_UNARY_OP_COUNT == 16, "GGML_UNARY_OP_COUNT != 16"); static const char * GGML_GLU_OP_NAME[GGML_GLU_OP_COUNT] = { "REGLU", @@ -2652,6 +2652,29 @@ struct ggml_tensor * ggml_silu_inplace( return ggml_unary_inplace(ctx, a, GGML_UNARY_OP_SILU); } +// ggml_xielu + +struct ggml_tensor * ggml_xielu( + struct ggml_context * ctx, + struct ggml_tensor * a, + float alpha_n, + float alpha_p, + float beta, + float eps) { + struct ggml_tensor * result = ggml_dup_tensor(ctx, a); + + ggml_set_op_params_i32(result, 0, (int32_t) GGML_UNARY_OP_XIELU); + ggml_set_op_params_f32(result, 1, beta + ggml_softplus(alpha_n)); + ggml_set_op_params_f32(result, 2, ggml_softplus(alpha_p)); + ggml_set_op_params_f32(result, 3, beta); + ggml_set_op_params_f32(result, 4, eps); + + result->op = GGML_OP_UNARY; + result->src[0] = a; + + return result; +} + // ggml_silu_back struct ggml_tensor * ggml_silu_back( @@ -3829,6 +3852,15 @@ struct ggml_tensor * ggml_soft_max_ext( return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, false); } +struct ggml_tensor * ggml_soft_max_ext_inplace( + struct ggml_context * ctx, + struct ggml_tensor * a, + struct ggml_tensor * mask, + float scale, + float max_bias) { + return ggml_soft_max_impl(ctx, a, mask, scale, max_bias, true); +} + void ggml_soft_max_add_sinks( struct ggml_tensor * a, struct ggml_tensor * sinks) { diff --git a/runner/llamarunner/image.go b/runner/llamarunner/image.go index cc0153aeaa..9fc9708113 100644 --- a/runner/llamarunner/image.go +++ b/runner/llamarunner/image.go @@ -56,7 +56,7 @@ func (c *ImageContext) Free(modelPath string) { } } -func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]float32, error) { +func (c *ImageContext) MultimodalTokenize(llamaContext *llama.Context, data []byte) ([]llama.MtmdChunk, error) { if c == nil { return nil, nil } @@ -70,10 +70,10 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f c.mu.Lock() defer c.mu.Unlock() - embed, err := c.findImage(hash) + chunks, err := c.findImage(hash) if err != nil { if c.mtmd != nil { - embed, err = c.mtmd.NewEmbed(llamaContext, data) + chunks, err = c.mtmd.MultimodalTokenize(llamaContext, data) if err != nil { return nil, err } @@ -81,10 +81,10 @@ func (c *ImageContext) NewEmbed(llamaContext *llama.Context, data []byte) ([][]f return nil, errors.New("received image but vision model not loaded") } - c.addImage(hash, embed) + c.addImage(hash, chunks) } - return embed, nil + return chunks, nil } func (c *ImageContext) BatchSize(configuredBatchSize int) int { @@ -102,7 +102,7 @@ func (c *ImageContext) EmbedSize(llamaContext *llama.Context) int { type imageCache struct { key uint64 - val [][]float32 + val []llama.MtmdChunk lastUsed time.Time } @@ -114,7 +114,7 @@ func (c *ImageContext) hashImage(image []byte) uint64 { var errImageNotFound = errors.New("image not found in cache") -func (c *ImageContext) findImage(hash uint64) ([][]float32, error) { +func (c *ImageContext) findImage(hash uint64) ([]llama.MtmdChunk, error) { for i := range c.images { if c.images[i].key == hash { slog.Debug("loading image embeddings from cache", "entry", i) @@ -126,7 +126,7 @@ func (c *ImageContext) findImage(hash uint64) ([][]float32, error) { return nil, errImageNotFound } -func (c *ImageContext) addImage(hash uint64, embed [][]float32) { +func (c *ImageContext) addImage(hash uint64, embed []llama.MtmdChunk) { best := time.Now() var bestImage int diff --git a/runner/llamarunner/image_test.go b/runner/llamarunner/image_test.go index 2e1efaec85..f7d98a4722 100644 --- a/runner/llamarunner/image_test.go +++ b/runner/llamarunner/image_test.go @@ -3,16 +3,18 @@ package llamarunner import ( "reflect" "testing" + + "github.com/ollama/ollama/llama" ) func TestImageCache(t *testing.T) { cache := ImageContext{images: make([]imageCache, 4)} - valA := [][]float32{{0.1, 0.2}, {0.3}} - valB := [][]float32{{0.4}, {0.5}, {0.6}} - valC := [][]float32{{0.7}} - valD := [][]float32{{0.8}} - valE := [][]float32{{0.9}} + valA := []llama.MtmdChunk{{Embed: []float32{0.1, 0.2}}, {Embed: []float32{0.3}}} + valB := []llama.MtmdChunk{{Embed: []float32{0.4}}, {Embed: []float32{0.5}}, {Embed: []float32{0.6}}} + valC := []llama.MtmdChunk{{Embed: []float32{0.7}}} + valD := []llama.MtmdChunk{{Embed: []float32{0.8}}} + valE := []llama.MtmdChunk{{Embed: []float32{0.9}}} // Empty cache result, err := cache.findImage(0x5adb61d31933a946) diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 572abddec3..163aaa62d6 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -209,13 +209,19 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) return nil, fmt.Errorf("invalid image index: %d", n) } - embed, err := s.image.NewEmbed(s.lc, images[imageIndex].Data) + chunks, err := s.image.MultimodalTokenize(s.lc, images[imageIndex].Data) if err != nil { return nil, err } - for _, e := range embed { - inputs = append(inputs, input{embed: e}) + for _, c := range chunks { + if len(c.Embed) != 0 { + inputs = append(inputs, input{embed: c.Embed}) + } else { + for _, t := range c.Tokens { + inputs = append(inputs, input{token: t}) + } + } } } }