From 2aba569a2a593f56651ded7f5011480ece70c80f Mon Sep 17 00:00:00 2001
From: Thomas Stocker
Date: Tue, 14 Oct 2025 19:59:58 +0200
Subject: [PATCH] Vulkan based on #9650 (#11835)
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

* implement the vulkan C backend
* add support in gpu.go
* add support in gen_linux.sh
* it builds
* fix segfault
* fix compilation
* fix free memory monitor
* fix total memory monitor
* update gpu.go
* fix build
* fix check_perfmon len
* remove cap_get_bound check
* fix vulkan handle releasing
* fix build on Fedora 40
* fix vulkan on windows
* making amdgpu work on arm architecture with vulkan
* add x86_64 lines in VulkanGlobs and capLinuxGlobs
* add aarch64 lines in vulkanGlobs and capLinuxGlobs
* Fix variable name
* Add vulkan build patch from @jmorganca
* Sync vendored ggml to add Vulkan support
* Updated dockerfile
  https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871
  Signed-off-by: Vadim Grinco
* Installing rocm library
  Signed-off-by: Vadim Grinco
* This version works well, built based on this:
  https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871
  Signed-off-by: Vadim Grinco
* Applied 00-fix-vulkan-building.patch
  Work done by McBane87 here: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871
  Signed-off-by: Vadim Grinco
* Fixed the "detached head" issues
  Signed-off-by: Vadim Grinco
* Merged in the right direction
  Signed-off-by: Vadim Grinco
* Merging the latest stable (#2)
* Applied 00-fix-vulkan-building.patch
* Implemented vulkan backend based on the work done by whyvl, Dts0, McBane87 and others
  Tested on AMD Ryzen 7 8845HS w/ Radeon 780M Graphics with ROCm disabled

```
[GIN-debug] POST /v1/chat/completions --> github.com/ollama/ollama/server.(*Server).ChatHandler-fm (6 handlers)
[GIN-debug] POST /v1/completions --> github.com/ollama/ollama/server.(*Server).GenerateHandler-fm (6 handlers)
[GIN-debug] POST /v1/embeddings --> github.com/ollama/ollama/server.(*Server).EmbedHandler-fm (6 handlers)
[GIN-debug] GET /v1/models --> github.com/ollama/ollama/server.(*Server).ListHandler-fm (6 handlers)
[GIN-debug] GET /v1/models/:model --> github.com/ollama/ollama/server.(*Server).ShowHandler-fm (6 handlers)
time=2025-03-11T13:00:40.793Z level=INFO source=gpu.go:199 msg="vulkan: load libvulkan and libcap ok"
time=2025-03-11T13:00:40.877Z level=INFO source=gpu.go:421 msg="error looking up vulkan GPU memory" error="device is a CPU"
time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:443 msg="amdgpu detected, but no compatible rocm library found. Either install rocm v6, or follow manual install instructions at https://github.com/ollama/ollama/blob/main/docs/linux.md#manual-install"
time=2025-03-11T13:00:40.878Z level=WARN source=amd_linux.go:348 msg="unable to verify rocm library: no suitable rocm found, falling back to CPU"
time=2025-03-11T13:00:40.879Z level=INFO source=types.go:137 msg="inference compute" id=0 library=vulkan variant="" compute=1.3 driver=1.3 name="AMD Radeon Graphics (RADV GFX1103_R1)" total="15.6 GiB" available="15.6 GiB"
```

```
# ollama run phi4:14b
>>> /set verbose
Set 'verbose' mode.
>>> how's it going?
Hello! I'm here to help you with any questions or tasks you have. How can I assist you today? 😊
total duration: 3.341959745s
load duration: 18.165612ms
prompt eval count: 15 token(s)
prompt eval duration: 475ms
prompt eval rate: 31.58 tokens/s
eval count: 26 token(s)
eval duration: 2.846s
eval rate: 9.14 tokens/s
>>>
```

* This is no longer needed
  Signed-off-by: Vadim Grinco
* Fixes SIGSEGV: segmentation violation running gemma3 models on ollama 0.6.0 #21
  Patch provided by McBane87 on https://github.com/whyvl/ollama-vulkan/issues/21
  Signed-off-by: Vadim Grinco
* Applied 04-disable-mmap-vulkan.patch
  From: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2660836871
  Signed-off-by: Vadim Grinco
* Pulled new upstream code for ggml-vulkan backend
  Signed-off-by: Vadim Grinco
* Merged latest ollama 0.6.2 and nasrally's Flash Attention patches (#5)
* readme: add Ellama to list of community integrations (#9800)
* readme: add screenpipe to community integrations (#9786)
* Add support for ROCm gfx1151 (#9773)
* conditionally enable parallel pipelines
* sample: make mutations in transforms explicit (#9743)
* updated minP to use early exit making use of sorted tokens
* ml/backend/ggml: allocate memory with malloc when loading model (#9822)
* runner: remove cache prompt flag from ollama runner (#9826)
  We do not need to bypass the prompt caching in the ollama runner yet, as only embedding models needed to bypass the prompt caching. When embedding models are implemented they can skip initializing this cache completely.
* ollamarunner: Check for minBatch of context space when shifting
  Models can specify that a group of inputs needs to be handled in a single batch. However, context shifting didn't respect this and could trigger a break anyway. In this case, we should instead trigger a context shift earlier so that it occurs before the grouped batch.
  Note that there are still some corner cases:
  - A long prompt that exceeds the context window can get truncated in the middle of an image. With the current models, this will result in the model not recognizing the image at all, which is pretty much the expected result with truncation.
  - The context window is set to less than the minimum batch size. The only solution to this is to refuse to load the model with these settings. However, this can never occur with current models and default settings.
  Since users are unlikely to run into these scenarios, fixing them is left as a follow up.
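To make the minBatch guard described above concrete, here is a minimal Go sketch of that kind of check. The function and parameter names (shouldShiftBeforeBatch, numCtx, cachedLen, batchLen, minBatch) are hypothetical illustrations of the idea, not the actual ollamarunner code:

```go
package sketch

// shouldShiftBeforeBatch reports whether the cache should be shifted now,
// before decoding a grouped batch, so the shift never lands in the middle of
// inputs that must stay in one batch (for example the tokens of an image).
// Sketch only: the names here are illustrative, not the real runner fields.
func shouldShiftBeforeBatch(numCtx, cachedLen, batchLen, minBatch int) bool {
	need := batchLen
	if need < minBatch {
		need = minBatch // a grouped batch may not be split below this size
	}
	// If the grouped inputs no longer fit in the remaining context space,
	// shift early instead of breaking the group mid-batch.
	return cachedLen+need > numCtx
}
```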
* Applied latest patches from McBane87
  See this for details: https://github.com/whyvl/ollama-vulkan/issues/7#issuecomment-2708820861
  Signed-off-by: Vadim Grinco
* Add ability to enable flash attention on vulkan (#4)
* discover: add flash attention handling for vulkan
* envconfig: fix typo in config.go
  As part of the process some code was refactored and I added a new field FlashAttention to GpuInfo, since the previous solution didn't allow for a granular check via vulkan extensions. As a side effect, this now allows for granular per-device FA support checking in other places.

---------

Signed-off-by: Vadim Grinco
Co-authored-by: zeo <108888572+zeozeozeo@users.noreply.github.com>
Co-authored-by: Louis Beaumont
Co-authored-by: Daniel Hiltgen
Co-authored-by: Michael Yang
Co-authored-by: Parth Sareen
Co-authored-by: Jeffrey Morgan
Co-authored-by: Bruce MacDonald
Co-authored-by: Jesse Gross
Co-authored-by: Nikita <50599445+nasrally@users.noreply.github.com>

* Revert Readme changes
* Revert
* Revert changes in amd_linux.go
* Revert changes in amd_linux.go
* Remove flashattention setting gpu.go
* Revert whitespace changes in gpu.go
* Revert changes in transforms_test.go
* Revert changes in runner.go
* Revert changes in Makefile.sync
* Revert some unintended changes in Dockerfile
* Revert vulkan copy changes in Dockerfile
* Update Vulkan Code to de4c07f93783a1a96456a44dc16b9db538ee1618
* Fixed duplicate sync in ggml.go
* Revert changes in ggml.go
* Revert changes in ggml.go
* enable flash attention on vulkan
* revert remove parenthesis
* fixed flash attention logic enabling
* vk_check_flash_attention 0 means supported
* Update gpu.go
* Add vulkan to Windows Build script
* Remove commented out code
* Enable Vulkan Flash attention in FlashAttentionSupported
* Fix logging
* Update Vulkan backend to e54d41befcc1575f4c898c5ff4ef43970cead75f
* Removed libcap related code
  libcap is not directly related to Vulkan and should be added by its own PR. It adds additional library dependencies for building and also requires users to run setcap or run ollama as root, which is not ideal for easy use.
* Fix Unit Test (Add Vulkan Library)
* Add vulkan to TestHomogeneousGPUs Test
* vulkan: get GPU ID (ollama v0.11.5)
  Signed-off-by: Xiaodong Ye
* disable mmap for vulkan
* Reduce Changes
  remove TestHomogeneousGPUs (doesn't exist on master)
* Update vulkan version to the version used in llama.cpp
* rename gpu patch to correct number
* added Vulkan API to get correct Device UUID
  current UUID from pipelineCacheUUID does not match CUDA
* Fix GPU ID Patch
* Remove Code not in llama.cpp
* modified UUID code inside ggml
* Fix Patch
* Copied minimal definition from vulkan header
* Fix compile error on Mac
  Metal is preferred, so we're disabling Vulkan for now
* Removed unused code
  Fix linter error in CI
* Fix patches apply
* fixing lint error
* Removed unneeded function call
  Somehow removing this call fixed the crashing when the Vulkan header was removed
* added missing NL
* Fixed missing members in Vulkan header
  also added zero clear for some structs
* Fixed wrong structure ID
* Fixed Vulkan header
  More aligned with official header definition now
* build Vulkan as separate function
* Vulkan on Windows Test
* temporarily comment out gate to run windows task
* temporarily use windows-latest for build
* Commenting out other presets to build vulkan
* re-enable cpu
* commenting out error action stop
* temporarily commenting out rocm
* set vulkan path
* comment out cuda for faster turnaround
* correct vulkan install
* correct vulkan silent install
* fixed install command
* revert debugging changes (vulkan builds on windows)
* revert windows-latest
* trying to build vulkan for linux
* temporarily disable cuda and rocm
* try again linux build
* fix version
* trying to fix
* trying again
* trying again
* fix version
* fixed vulkan-sdk name
* try again
* trying again
* try without version number
* try again
* add some more extra
* trying to use version 1.4.313
* revert debugging changes
* Filter out already supported gpus
* revert debug code
* Use runners for GPU discovery
  This revamps how we discover GPUs in the system by leveraging the Ollama runner. This should eliminate inconsistency between our GPU discovery and the runner's capabilities at runtime, particularly for cases where we try to filter out unsupported GPUs. Now the runner does that implicitly based on the actual device list. In some cases free VRAM reporting can be unreliable, which can lead to scheduling mistakes, so this also includes a patch to leverage more reliable VRAM reporting libraries if available.
  Automatic workarounds have been removed as only one GPU leveraged this, which is now documented. This GPU will soon fall off the support matrix with the next ROCm bump.
  Additional cleanup of the scheduler and discovery packages can be done in the future once we have switched on the new memory management code and removed support for the llama runner.
* timing info for runner
* WIP - wire up Vulkan with the new engine based discovery
  Not a complete implementation - free VRAM is better, but not accurate on windows
* fix - trust the library paths from discovery when starting runner
* fix index bug
* fix vulkan ids to be underlying
* fix - give bootstrapping more time on slow systems
* Test if Vulkan device is supported
* vk_check_flash_attention is not needed (coopmat2, coopmat and scalar implementations exist)
* Handle GGML_VK_VISIBLE_DEVICES
* ask for supported first
* win: fix CPU query buffer handling
  Try in a short loop until we get the size right.
* test: harden integration tests for slow start
  If the server takes a while to start up, block tests from starting until it's online to avoid setting large timeouts in individual test cases.
* gofumpt fix
* fix build
* merge fixes
* merge fixes
* fixed build
* merge fixes
* fixing build
* fixed build
* fixed formatting
* fixed build
* fix vulkan gpu id patch
* sync llama.cpp vulkan code
* update build windows script
* merge fixes
* fix format
* fixed vulkan casing
* handle igpu as gpu
* improve case
* print out unknown library
* return Vulkan for vulkan library
* Revert "return Vulkan for vulkan library"
  This reverts commit 690461a12fd5e93295d174c97edefb2bc33285b1.
* fixed patch number
* return Library Name
* remove debug code
* return integrated in vulkan backend
* Return pci Properties
* update patch
* directly get pci properties without parsing
* workaround for filtering devices
  The correct way is to have a LibraryPosition parameter in the deviceInfo.
* Revert "directly get pci properties without parsing"
  This reverts commit 8e0624851f5ed7d9f74518f574dfb422e4dd4dc2.
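The "win: fix CPU query buffer handling" item above boils down to a retry-until-the-size-sticks pattern: the buffer size a Windows query reports can change between the sizing call and the data call, so the caller loops a few times. A minimal Go sketch under that assumption follows; the callback signature and error value are stand-ins, not the real Windows API bindings used here:

```go
package sketch

import "errors"

// errInsufficientBuffer stands in for the Windows ERROR_INSUFFICIENT_BUFFER
// status; real code would check the status returned by the API bindings.
var errInsufficientBuffer = errors.New("insufficient buffer")

// queryWithRetry calls query in a short loop, growing the buffer to whatever
// size the last call asked for, and gives up after a few attempts.
func queryWithRetry(query func(buf []byte, size *uint32) error) ([]byte, error) {
	var size uint32
	var buf []byte
	for attempt := 0; attempt < 4; attempt++ {
		if err := query(buf, &size); err == nil {
			return buf[:size], nil
		} else if !errors.Is(err, errInsufficientBuffer) {
			return nil, err
		}
		buf = make([]byte, size) // retry with the size the API reported
	}
	return nil, errors.New("buffer size kept changing; giving up")
}
```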
* Set FilteredID for Environment Filtering
* ROCm Library is named ROCm
* revert changes in patch
* Create 0028-vulkan-pci-and-memory.patch
* vulkan memory patch
* casing fix
* Add more pci properties
* Added better memory management
* Added better memory management
* fixed patch
* Fixed patch
* FilterID creation grouped by library
* filter out vulkan supported by other gpu
* fixing deviceid compare
* Vulkan: Fix FA coopmat1 invalid array indexing
* Use the same Vulkan version 1.4.321.1 everywhere
* Remove unneeded patch
* vulkan update
* sync vulkan glsl files
* only use the filteredid (numeric device number) for vulkan
* simplify code

---------

Signed-off-by: Vadim Grinco
Signed-off-by: Xiaodong Ye
Co-authored-by: pufferffish
Co-authored-by: KOISHI KOMEIJI FROM TOUHOU 11
Co-authored-by: DSLstandard
Co-authored-by: pufferffish
Co-authored-by: yeongbba
Co-authored-by: tomaThomas
Co-authored-by: Antoine Viallon
Co-authored-by: Vadim Grinco
Co-authored-by: zeo <108888572+zeozeozeo@users.noreply.github.com>
Co-authored-by: Louis Beaumont
Co-authored-by: Daniel Hiltgen
Co-authored-by: Michael Yang
Co-authored-by: Parth Sareen
Co-authored-by: Jeffrey Morgan
Co-authored-by: Bruce MacDonald
Co-authored-by: Jesse Gross
Co-authored-by: Nikita <50599445+nasrally@users.noreply.github.com>
Co-authored-by: Masato Nakasaka
Co-authored-by: Xiaodong Ye
Co-authored-by: Daniel Hiltgen
---
.github/workflows/test.yaml | 35 +- CMakeLists.txt | 12 + CMakePresets.json | 9 + Dockerfile | 24 +- discover/gpu.go | 31 +- discover/runner.go | 57 +- discover/types.go | 5 +- envconfig/config.go | 2 + llama/llama.go | 4 +- ...027-vulkan-get-GPU-ID-ollama-v0.11.5.patch | 95 + .../patches/0028-vulkan-pci-and-memory.patch | 253 + llm/server.go | 1 + ml/backend/ggml/ggml.go | 10 +- ml/backend/ggml/ggml/.rsync-filter | 4 + .../ggml/ggml/src/ggml-vulkan/CMakeLists.txt | 211 + .../ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp | 13903 ++++++++++++++++ .../ggml-vulkan/vulkan-shaders/CMakeLists.txt | 31 + .../src/ggml-vulkan/vulkan-shaders/acc.comp | 29 + .../src/ggml-vulkan/vulkan-shaders/add.comp | 69 + .../ggml-vulkan/vulkan-shaders/add_id.comp | 42 + .../ggml-vulkan/vulkan-shaders/argmax.comp | 60 + .../ggml-vulkan/vulkan-shaders/argsort.comp | 79 + .../src/ggml-vulkan/vulkan-shaders/clamp.comp | 17 + .../ggml-vulkan/vulkan-shaders/concat.comp | 41 + .../vulkan-shaders/contig_copy.comp | 49 + .../ggml-vulkan/vulkan-shaders/conv2d_dw.comp | 105 + .../ggml-vulkan/vulkan-shaders/conv2d_mm.comp | 349 + .../vulkan-shaders/conv_transpose_1d.comp | 98 + .../src/ggml-vulkan/vulkan-shaders/copy.comp | 23 + .../vulkan-shaders/copy_from_quant.comp | 51 + .../vulkan-shaders/copy_to_quant.comp | 296 + .../src/ggml-vulkan/vulkan-shaders/cos.comp | 17 + .../vulkan-shaders/count_equal.comp | 31 + .../vulkan-shaders/dequant_f32.comp | 20 + .../vulkan-shaders/dequant_funcs.glsl | 616 + .../vulkan-shaders/dequant_funcs_cm2.glsl | 720 + .../vulkan-shaders/dequant_head.glsl | 13 + .../vulkan-shaders/dequant_iq1_m.comp | 42 + .../vulkan-shaders/dequant_iq1_s.comp | 35 + .../vulkan-shaders/dequant_iq2_s.comp | 44 + .../vulkan-shaders/dequant_iq2_xs.comp | 43 + .../vulkan-shaders/dequant_iq2_xxs.comp | 49 + .../vulkan-shaders/dequant_iq3_s.comp | 40 + .../vulkan-shaders/dequant_iq3_xxs.comp | 51 + .../vulkan-shaders/dequant_iq4_nl.comp | 32 + .../vulkan-shaders/dequant_iq4_xs.comp | 34 + .../vulkan-shaders/dequant_mxfp4.comp | 32 + .../vulkan-shaders/dequant_q2_k.comp | 34 + .../vulkan-shaders/dequant_q3_k.comp | 42 + .../vulkan-shaders/dequant_q4_0.comp |
30 + .../vulkan-shaders/dequant_q4_1.comp | 32 + .../vulkan-shaders/dequant_q4_k.comp | 68 + .../vulkan-shaders/dequant_q5_0.comp | 34 + .../vulkan-shaders/dequant_q5_1.comp | 35 + .../vulkan-shaders/dequant_q5_k.comp | 70 + .../vulkan-shaders/dequant_q6_k.comp | 33 + .../vulkan-shaders/dequant_q8_0.comp | 31 + .../vulkan-shaders/diag_mask_inf.comp | 34 + .../src/ggml-vulkan/vulkan-shaders/div.comp | 27 + .../src/ggml-vulkan/vulkan-shaders/exp.comp | 21 + .../vulkan-shaders/flash_attn.comp | 383 + .../vulkan-shaders/flash_attn_base.glsl | 202 + .../vulkan-shaders/flash_attn_cm1.comp | 418 + .../vulkan-shaders/flash_attn_cm2.comp | 320 + .../flash_attn_split_k_reduce.comp | 120 + .../src/ggml-vulkan/vulkan-shaders/geglu.comp | 13 + .../ggml-vulkan/vulkan-shaders/geglu_erf.comp | 27 + .../vulkan-shaders/geglu_quick.comp | 11 + .../src/ggml-vulkan/vulkan-shaders/gelu.comp | 25 + .../ggml-vulkan/vulkan-shaders/gelu_erf.comp | 39 + .../vulkan-shaders/gelu_quick.comp | 23 + .../vulkan-shaders/generic_binary_head.glsl | 51 + .../vulkan-shaders/generic_head.glsl | 9 + .../vulkan-shaders/generic_unary_head.glsl | 76 + .../ggml-vulkan/vulkan-shaders/get_rows.comp | 42 + .../vulkan-shaders/get_rows_quant.comp | 51 + .../ggml-vulkan/vulkan-shaders/glu_head.glsl | 19 + .../ggml-vulkan/vulkan-shaders/glu_main.glsl | 29 + .../vulkan-shaders/group_norm.comp | 66 + .../vulkan-shaders/hardsigmoid.comp | 22 + .../ggml-vulkan/vulkan-shaders/hardswish.comp | 22 + .../ggml-vulkan/vulkan-shaders/im2col.comp | 103 + .../ggml-vulkan/vulkan-shaders/im2col_3d.comp | 125 + .../ggml-vulkan/vulkan-shaders/l2_norm.comp | 41 + .../vulkan-shaders/leaky_relu.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/mul.comp | 27 + .../mul_mat_split_k_reduce.comp | 48 + .../vulkan-shaders/mul_mat_vec.comp | 169 + .../vulkan-shaders/mul_mat_vec_base.glsl | 182 + .../vulkan-shaders/mul_mat_vec_iq1_m.comp | 82 + .../vulkan-shaders/mul_mat_vec_iq1_s.comp | 79 + .../vulkan-shaders/mul_mat_vec_iq2_s.comp | 90 + .../vulkan-shaders/mul_mat_vec_iq2_xs.comp | 87 + .../vulkan-shaders/mul_mat_vec_iq2_xxs.comp | 87 + .../vulkan-shaders/mul_mat_vec_iq3_s.comp | 90 + .../vulkan-shaders/mul_mat_vec_iq3_xxs.comp | 88 + .../vulkan-shaders/mul_mat_vec_nc.comp | 122 + .../vulkan-shaders/mul_mat_vec_p021.comp | 154 + .../vulkan-shaders/mul_mat_vec_q2_k.comp | 130 + .../vulkan-shaders/mul_mat_vec_q3_k.comp | 132 + .../vulkan-shaders/mul_mat_vec_q4_k.comp | 136 + .../vulkan-shaders/mul_mat_vec_q5_k.comp | 167 + .../vulkan-shaders/mul_mat_vec_q6_k.comp | 130 + .../vulkan-shaders/mul_mat_vecq.comp | 140 + .../ggml-vulkan/vulkan-shaders/mul_mm.comp | 481 + .../vulkan-shaders/mul_mm_cm2.comp | 609 + .../vulkan-shaders/mul_mm_funcs.glsl | 556 + .../ggml-vulkan/vulkan-shaders/mul_mmq.comp | 449 + .../vulkan-shaders/mul_mmq_funcs.glsl | 105 + .../ggml-vulkan/vulkan-shaders/multi_add.comp | 111 + .../src/ggml-vulkan/vulkan-shaders/norm.comp | 44 + .../vulkan-shaders/opt_step_adamw.comp | 42 + .../vulkan-shaders/opt_step_sgd.comp | 22 + .../src/ggml-vulkan/vulkan-shaders/pad.comp | 49 + .../ggml-vulkan/vulkan-shaders/pool2d.comp | 74 + .../vulkan-shaders/quantize_q8_1.comp | 127 + .../src/ggml-vulkan/vulkan-shaders/reglu.comp | 9 + .../src/ggml-vulkan/vulkan-shaders/relu.comp | 21 + .../ggml-vulkan/vulkan-shaders/repeat.comp | 26 + .../vulkan-shaders/repeat_back.comp | 37 + .../ggml-vulkan/vulkan-shaders/rms_norm.comp | 105 + .../vulkan-shaders/rms_norm_back.comp | 55 + .../vulkan-shaders/rms_norm_partials.comp | 65 + .../src/ggml-vulkan/vulkan-shaders/roll.comp | 
46 + .../ggml-vulkan/vulkan-shaders/rope_head.glsl | 55 + .../vulkan-shaders/rope_multi.comp | 58 + .../ggml-vulkan/vulkan-shaders/rope_neox.comp | 41 + .../ggml-vulkan/vulkan-shaders/rope_norm.comp | 41 + .../vulkan-shaders/rope_vision.comp | 47 + .../src/ggml-vulkan/vulkan-shaders/rte.glsl | 5 + .../src/ggml-vulkan/vulkan-shaders/scale.comp | 24 + .../ggml-vulkan/vulkan-shaders/sigmoid.comp | 20 + .../src/ggml-vulkan/vulkan-shaders/silu.comp | 22 + .../ggml-vulkan/vulkan-shaders/silu_back.comp | 26 + .../src/ggml-vulkan/vulkan-shaders/sin.comp | 17 + .../ggml-vulkan/vulkan-shaders/soft_max.comp | 195 + .../vulkan-shaders/soft_max_back.comp | 54 + .../src/ggml-vulkan/vulkan-shaders/sqrt.comp | 17 + .../ggml-vulkan/vulkan-shaders/square.comp | 17 + .../src/ggml-vulkan/vulkan-shaders/sub.comp | 29 + .../ggml-vulkan/vulkan-shaders/sum_rows.comp | 70 + .../ggml-vulkan/vulkan-shaders/swiglu.comp | 9 + .../vulkan-shaders/swiglu_oai.comp | 14 + .../src/ggml-vulkan/vulkan-shaders/tanh.comp | 20 + .../vulkan-shaders/timestep_embedding.comp | 42 + .../src/ggml-vulkan/vulkan-shaders/types.glsl | 1465 ++ .../ggml-vulkan/vulkan-shaders/upscale.comp | 100 + .../src/ggml-vulkan/vulkan-shaders/utils.glsl | 25 + .../vulkan-shaders/vulkan-shaders-gen.cpp | 1097 ++ .../src/ggml-vulkan/vulkan-shaders/wkv6.comp | 87 + .../src/ggml-vulkan/vulkan-shaders/wkv7.comp | 91 + scripts/build_windows.ps1 | 15 +- 152 files changed, 29425 insertions(+), 15 deletions(-) create mode 100644 llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch create mode 100644 llama/patches/0028-vulkan-pci-and-memory.patch create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl 
create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp create mode 100644 
ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp create 
mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp create mode 100644 ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index e470540a22..12ee713511 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -52,6 +52,12 @@ jobs: container: rocm/dev-ubuntu-22.04:6.1.2 extra-packages: rocm-libs flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_PREFIX_PATH=/opt/rocm' + - preset: Vulkan + container: ubuntu:22.04 + extra-packages: > + mesa-vulkan-drivers vulkan-tools + libvulkan1 libvulkan-dev + vulkan-sdk cmake ccache g++ make runs-on: linux container: ${{ matrix.container }} steps: @@ -59,7 +65,19 @@ jobs: - run: | [ -n "${{ 
matrix.container }}" ] || sudo=sudo $sudo apt-get update + # Add LunarG Vulkan SDK apt repo for Ubuntu 22.04 + if [ "${{ matrix.preset }}" = "Vulkan" ]; then + $sudo apt-get install -y --no-install-recommends wget gnupg ca-certificates software-properties-common + wget -qO - https://packages.lunarg.com/lunarg-signing-key-pub.asc | $sudo gpg --dearmor -o /usr/share/keyrings/lunarg-archive-keyring.gpg + # Use signed-by to bind the repo to the installed keyring to avoid NO_PUBKEY + echo "deb [signed-by=/usr/share/keyrings/lunarg-archive-keyring.gpg] https://packages.lunarg.com/vulkan/1.4.313 jammy main" | $sudo tee /etc/apt/sources.list.d/lunarg-vulkan-1.4.313-jammy.list > /dev/null + $sudo apt-get update + fi $sudo apt-get install -y cmake ccache ${{ matrix.extra-packages }} + # Export VULKAN_SDK if provided by LunarG package (defensive) + if [ -d "/usr/lib/x86_64-linux-gnu/vulkan" ] && [ "${{ matrix.preset }}" = "Vulkan" ]; then + echo "VULKAN_SDK=/usr" >> $GITHUB_ENV + fi env: DEBIAN_FRONTEND: noninteractive - uses: actions/cache@v4 @@ -92,18 +110,21 @@ jobs: - preset: ROCm install: https://download.amd.com/developer/eula/rocm-hub/AMD-Software-PRO-Edition-24.Q4-WinSvr2022-For-HIP.exe flags: '-DAMDGPU_TARGETS=gfx1010 -DCMAKE_C_COMPILER=clang -DCMAKE_CXX_COMPILER=clang++ -DCMAKE_C_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma" -DCMAKE_CXX_FLAGS="-parallel-jobs=4 -Wno-ignored-attributes -Wno-deprecated-pragma"' + - preset: Vulkan + install: https://sdk.lunarg.com/sdk/download/1.4.321.1/windows/vulkansdk-windows-X64-1.4.321.1.exe runs-on: windows steps: - run: | choco install -y --no-progress ccache ninja ccache -o cache_dir=${{ github.workspace }}\.ccache - - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' + - if: matrix.preset == 'CUDA' || matrix.preset == 'ROCm' || matrix.preset == 'Vulkan' id: cache-install uses: actions/cache/restore@v4 with: path: | C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA C:\Program Files\AMD\ROCm + C:\VulkanSDK key: ${{ matrix.install }} - if: matrix.preset == 'CUDA' name: Install CUDA ${{ matrix.cuda-version }} @@ -133,6 +154,18 @@ jobs: echo "HIPCXX=$hipPath\bin\clang++.exe" | Out-File -FilePath $env:GITHUB_ENV -Append echo "HIP_PLATFORM=amd" | Out-File -FilePath $env:GITHUB_ENV -Append echo "CMAKE_PREFIX_PATH=$hipPath" | Out-File -FilePath $env:GITHUB_ENV -Append + - if: matrix.preset == 'Vulkan' + name: Install Vulkan ${{ matrix.rocm-version }} + run: | + $ErrorActionPreference = "Stop" + if ("${{ steps.cache-install.outputs.cache-hit }}" -ne 'true') { + Invoke-WebRequest -Uri "${{ matrix.install }}" -OutFile "install.exe" + Start-Process -FilePath .\install.exe -ArgumentList "-c","--am","--al","in" -NoNewWindow -Wait + } + + $vulkanPath = (Resolve-Path "C:\VulkanSDK\*").path + echo "$vulkanPath\bin" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append + echo "VULKAN_SDK=$vulkanPath" >> $env:GITHUB_ENV - if: ${{ !cancelled() && steps.cache-install.outputs.cache-hit != 'true' }} uses: actions/cache/save@v4 with: diff --git a/CMakeLists.txt b/CMakeLists.txt index 1b6d0fb26f..50a80629a1 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -139,3 +139,15 @@ if(CMAKE_HIP_COMPILER) endforeach() endif() endif() + +find_package(Vulkan) +if(Vulkan_FOUND) + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/ggml/ggml/src/ggml-vulkan) + install(TARGETS ggml-vulkan + RUNTIME_DEPENDENCIES + PRE_INCLUDE_REGEXES vulkan + PRE_EXCLUDE_REGEXES ".*" + RUNTIME DESTINATION ${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + LIBRARY DESTINATION 
${OLLAMA_INSTALL_DIR} COMPONENT Vulkan + ) +endif() diff --git a/CMakePresets.json b/CMakePresets.json index 15ec5943e9..b1bcb798e6 100644 --- a/CMakePresets.json +++ b/CMakePresets.json @@ -70,6 +70,10 @@ "CMAKE_HIP_FLAGS": "-parallel-jobs=4", "AMDGPU_TARGETS": "gfx940;gfx941;gfx942;gfx1010;gfx1012;gfx1030;gfx1100;gfx1101;gfx1102;gfx1151;gfx1200;gfx1201;gfx908:xnack-;gfx90a:xnack+;gfx90a:xnack-" } + }, + { + "name": "Vulkan", + "inherits": [ "Default" ] } ], "buildPresets": [ @@ -122,6 +126,11 @@ "name": "ROCm 6", "inherits": [ "ROCm" ], "configurePreset": "ROCm 6" + }, + { + "name": "Vulkan", + "targets": [ "ggml-vulkan" ], + "configurePreset": "Vulkan" } ] } diff --git a/Dockerfile b/Dockerfile index 6407f70979..63b671fea2 100644 --- a/Dockerfile +++ b/Dockerfile @@ -7,6 +7,7 @@ ARG ROCMVERSION=6.3.3 ARG JETPACK5VERSION=r35.4.1 ARG JETPACK6VERSION=r36.4.0 ARG CMAKEVERSION=3.31.2 +ARG VULKANVERSION=1.4.321.1 # We require gcc v10 minimum. v10.3 has regressions, so the rockylinux 8.5 AppStream has the latest compatible version FROM --platform=linux/amd64 rocm/dev-almalinux-8:${ROCMVERSION}-complete AS base-amd64 @@ -17,6 +18,16 @@ RUN yum install -y yum-utils \ && dnf install -y ccache \ && yum-config-manager --add-repo https://developer.download.nvidia.com/compute/cuda/repos/rhel8/x86_64/cuda-rhel8.repo ENV PATH=/opt/rh/gcc-toolset-10/root/usr/bin:$PATH +ARG VULKANVERSION +RUN wget https://sdk.lunarg.com/sdk/download/${VULKANVERSION}/linux/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz -O /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ + && tar xvf /tmp/vulkansdk-linux-x86_64-${VULKANVERSION}.tar.xz \ + && dnf -y install ninja-build \ + && ln -s /usr/bin/python3 /usr/bin/python \ + && /${VULKANVERSION}/vulkansdk -j 8 vulkan-headers \ + && /${VULKANVERSION}/vulkansdk -j 8 shaderc +RUN cp -r /${VULKANVERSION}/x86_64/include/* /usr/local/include/ \ + && cp -r /${VULKANVERSION}/x86_64/lib/* /usr/local/lib +ENV PATH=/${VULKANVERSION}/x86_64/bin:$PATH FROM --platform=linux/arm64 almalinux:8 AS base-arm64 # install epel-release for ccache @@ -106,6 +117,13 @@ RUN --mount=type=cache,target=/root/.ccache \ && cmake --build --parallel ${PARALLEL} --preset 'JetPack 6' \ && cmake --install build --component CUDA --strip --parallel ${PARALLEL} +FROM base AS vulkan +RUN --mount=type=cache,target=/root/.ccache \ + cmake --preset 'Vulkan' -DOLLAMA_RUNNER_DIR="vulkan" \ + && cmake --build --parallel --preset 'Vulkan' \ + && cmake --install build --component Vulkan --strip --parallel 8 + + FROM base AS build WORKDIR /go/src/github.com/ollama/ollama COPY go.mod go.sum . 
@@ -123,7 +141,8 @@ RUN --mount=type=cache,target=/root/.cache/go-build \ FROM --platform=linux/amd64 scratch AS amd64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ COPY --from=cuda-12 dist/lib/ollama /lib/ollama/ -COPY --from=cuda-13 dist/lib/ollama/ /lib/ollama/ +COPY --from=cuda-13 dist/lib/ollama /lib/ollama/ +COPY --from=vulkan dist/lib/ollama /lib/ollama/ FROM --platform=linux/arm64 scratch AS arm64 # COPY --from=cuda-11 dist/lib/ollama/ /lib/ollama/ @@ -136,12 +155,13 @@ FROM scratch AS rocm COPY --from=rocm-6 dist/lib/ollama /lib/ollama FROM ${FLAVOR} AS archive +ARG VULKANVERSION COPY --from=cpu dist/lib/ollama /lib/ollama COPY --from=build /bin/ollama /bin/ollama FROM ubuntu:24.04 RUN apt-get update \ - && apt-get install -y ca-certificates \ + && apt-get install -y ca-certificates libvulkan1 \ && apt-get clean \ && rm -rf /var/lib/apt/lists/* COPY --from=archive /bin /usr/bin diff --git a/discover/gpu.go b/discover/gpu.go index 15d1e79f24..2f394fdf86 100644 --- a/discover/gpu.go +++ b/discover/gpu.go @@ -70,6 +70,7 @@ func devInfoToInfoList(devs []ml.DeviceInfo) GpuInfoList { if dev.Library == "ROCm" && rocmDir != "" { info.DependencyPath = append(info.DependencyPath, rocmDir) } + // TODO any special processing of Vulkan devices? resp = append(resp, info) } if len(resp) == 0 { @@ -97,7 +98,16 @@ func (l GpuInfoList) GetVisibleDevicesEnv() []string { if len(l) == 0 { return nil } - return []string{rocmGetVisibleDevicesEnv(l)} + res := []string{} + envVar := rocmGetVisibleDevicesEnv(l) + if envVar != "" { + res = append(res, envVar) + } + envVar = vkGetVisibleDevicesEnv(l) + if envVar != "" { + res = append(res, envVar) + } + return res } func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { @@ -127,6 +137,25 @@ func rocmGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { return envVar + strings.Join(ids, ",") } +func vkGetVisibleDevicesEnv(gpuInfo []GpuInfo) string { + ids := []string{} + for _, info := range gpuInfo { + if info.Library != "Vulkan" { + continue + } + if info.filterID != "" { + ids = append(ids, info.filterID) + } else { + ids = append(ids, info.ID) + } + } + if len(ids) == 0 { + return "" + } + envVar := "GGML_VK_VISIBLE_DEVICES=" + return envVar + strings.Join(ids, ",") +} + // GetSystemInfo returns the last cached state of the GPUs on the system func GetSystemInfo() SystemInfo { deviceMu.Lock() diff --git a/discover/runner.go b/discover/runner.go index 6dfcae1138..318d405944 100644 --- a/discover/runner.go +++ b/discover/runner.go @@ -86,6 +86,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev // are enumerated, but not actually supported. 
// We run this in serial to avoid potentially initializing a GPU multiple // times concurrently leading to memory contention + // TODO refactor so we group the lib dirs and do serial per version, but parallel for different libs for dir := range libDirs { var dirs []string if dir != "" { @@ -131,19 +132,25 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev go func(i int) { defer wg.Done() var envVar string + id := devices[i].ID if devices[i].Library == "ROCm" { if runtime.GOOS != "linux" { envVar = "HIP_VISIBLE_DEVICES" } else { envVar = "ROCR_VISIBLE_DEVICES" } - } else { + } else if devices[i].Library == "CUDA" { envVar = "CUDA_VISIBLE_DEVICES" + } else if devices[i].Library == "Vulkan" { + id = devices[i].FilteredID + envVar = "GGML_VK_VISIBLE_DEVICES" + } else { + slog.Error("Unknown Library:" + devices[i].Library) } extraEnvs := []string{ - "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs - envVar + "=" + devices[i].ID, // Filter to just this one GPU + "GGML_CUDA_INIT=1", // force deep initialization to trigger crash on unsupported GPUs + envVar + "=" + id, // Filter to just this one GPU } if len(bootstrapDevices(ctx2ndPass, devices[i].LibraryPath, extraEnvs)) == 0 { needsDelete[i] = true @@ -163,6 +170,8 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev wg.Wait() logutil.Trace("supported GPU library combinations", "supported", supported) + filterOutVulkanThatAreSupportedByOtherGPU(needsDelete) + // Mark for deletion any overlaps - favoring the library version that can cover all GPUs if possible filterOverlapByLibrary(supported, needsDelete) @@ -184,7 +193,7 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev } } - // Now filter out any overlap with different libraries (favor CUDA/ROCm over others) + // Now filter out any overlap with different libraries (favor CUDA/HIP over others) for i := 0; i < len(devices); i++ { for j := i + 1; j < len(devices); j++ { // For this pass, we only drop exact duplicates @@ -346,6 +355,37 @@ func GPUDevices(ctx context.Context, runners []FilteredRunnerDiscovery) []ml.Dev return devices } +func filterOutVulkanThatAreSupportedByOtherGPU(needsDelete []bool) { + // Filter out Vulkan devices that share a PCI ID with a non-Vulkan device that is not marked for deletion + for i := range devices { + if devices[i].Library != "Vulkan" || needsDelete[i] { + continue + } + if devices[i].PCIID == "" { + continue + } + for j := range devices { + if i == j { + continue + } + if devices[j].PCIID == "" { + continue + } + if devices[j].PCIID == devices[i].PCIID && devices[j].Library != "Vulkan" && !needsDelete[j] { + needsDelete[i] = true + slog.Debug("dropping Vulkan duplicate by PCI ID", + "vulkan_id", devices[i].ID, + "vulkan_libdir", devices[i].LibraryPath[len(devices[i].LibraryPath)-1], + "pci_id", devices[i].PCIID, + "kept_library", devices[j].Library, + "kept_id", devices[j].ID, + ) + break + } + } + } +} + func filterOverlapByLibrary(supported map[string]map[string]map[string]int, needsDelete []bool) { // For multi-GPU systems, use the newest version that supports all the GPUs for _, byLibDirs := range supported { @@ -451,6 +491,7 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s cmd.Stdout = os.Stdout cmd.Stderr = os.Stderr } + // cmd.SysProcAttr = llm.LlamaServerSysProcAttr // circular dependency - bring back once refactored pathEnvVal := strings.Join(libraryPaths, 
string(filepath.ListSeparator)) pathNeeded := true @@ -508,6 +549,14 @@ func bootstrapDevices(ctx context.Context, ollamaLibDirs []string, extraEnvs []s } } logutil.Trace("runner enumerated devices", "OLLAMA_LIBRARY_PATH", ollamaLibDirs, "devices", devices) + + // Enumerate returned devices starting at 0 per library and assign the per-library index as FilteredID + libCounts := make(map[string]int) + for i := range devices { + lib := devices[i].Library + devices[i].FilteredID = strconv.Itoa(libCounts[lib]) + libCounts[lib]++ + } return devices } diff --git a/discover/types.go b/discover/types.go index f0c3989df4..adb2f43a74 100644 --- a/discover/types.go +++ b/discover/types.go @@ -37,7 +37,7 @@ type GpuInfo struct { // TODO better name maybe "InferenceProcessor"? UnreliableFreeMemory bool // GPU information - filterID string // AMD Workaround: The numeric ID of the device used to filter out other devices + filterID string // AMD/Vulkan Workaround: The numeric ID of the device used to filter out other devices Name string `json:"name"` // user friendly name if available ComputeMajor int `json:"compute_major"` // Compute Capability or gfx ComputeMinor int `json:"compute_minor"` @@ -175,7 +175,8 @@ func (l GpuInfoList) FlashAttentionSupported() bool { supportsFA := gpu.Library == "cpu" || gpu.Name == "Metal" || gpu.Library == "Metal" || (gpu.Library == "CUDA" && gpu.DriverMajor >= 7 && !(gpu.ComputeMajor == 7 && gpu.ComputeMinor == 2)) || // We don't have kernels for Jetson Xavier - gpu.Library == "ROCm" + gpu.Library == "ROCm" || + gpu.Library == "Vulkan" if !supportsFA { return false diff --git a/envconfig/config.go b/envconfig/config.go index dd25ed80d6..14c7c61700 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -217,6 +217,7 @@ var ( CudaVisibleDevices = String("CUDA_VISIBLE_DEVICES") HipVisibleDevices = String("HIP_VISIBLE_DEVICES") RocrVisibleDevices = String("ROCR_VISIBLE_DEVICES") + VkVisibleDevices = String("GGML_VK_VISIBLE_DEVICES") GpuDeviceOrdinal = String("GPU_DEVICE_ORDINAL") HsaOverrideGfxVersion = String("HSA_OVERRIDE_GFX_VERSION") ) @@ -307,6 +308,7 @@ func AsMap() map[string]EnvVar { ret["CUDA_VISIBLE_DEVICES"] = EnvVar{"CUDA_VISIBLE_DEVICES", CudaVisibleDevices(), "Set which NVIDIA devices are visible"} ret["HIP_VISIBLE_DEVICES"] = EnvVar{"HIP_VISIBLE_DEVICES", HipVisibleDevices(), "Set which AMD devices are visible by numeric ID"} ret["ROCR_VISIBLE_DEVICES"] = EnvVar{"ROCR_VISIBLE_DEVICES", RocrVisibleDevices(), "Set which AMD devices are visible by UUID or numeric ID"} + ret["GGML_VK_VISIBLE_DEVICES"] = EnvVar{"GGML_VK_VISIBLE_DEVICES", VkVisibleDevices(), "Set which Vulkan devices are visible by numeric ID"} ret["GPU_DEVICE_ORDINAL"] = EnvVar{"GPU_DEVICE_ORDINAL", GpuDeviceOrdinal(), "Set which AMD devices are visible by numeric ID"} ret["HSA_OVERRIDE_GFX_VERSION"] = EnvVar{"HSA_OVERRIDE_GFX_VERSION", HsaOverrideGfxVersion(), "Override the gfx used for all detected AMD GPUs"} ret["OLLAMA_INTEL_GPU"] = EnvVar{"OLLAMA_INTEL_GPU", IntelGPU(), "Enable experimental Intel GPU detection"} diff --git a/llama/llama.go b/llama/llama.go index 84618590f9..c995b3ead2 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -69,7 +69,9 @@ func EnumerateGPUs() []ml.DeviceID { for i := range C.ggml_backend_dev_count() { device := C.ggml_backend_dev_get(i) - if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + switch C.ggml_backend_dev_type(device) { + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: var props 
C.struct_ggml_backend_dev_props C.ggml_backend_dev_get_props(device, &props) ids = append(ids, ml.DeviceID{ diff --git a/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch new file mode 100644 index 0000000000..997dd3860a --- /dev/null +++ b/llama/patches/0027-vulkan-get-GPU-ID-ollama-v0.11.5.patch @@ -0,0 +1,95 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Xiaodong Ye +Date: Mon, 18 Aug 2025 12:48:07 +0800 +Subject: [PATCH] vulkan: get GPU ID (ollama v0.11.5) + +Signed-off-by: Xiaodong Ye +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 37 ++++++++++++++++++++++++++++ + 1 file changed, 37 insertions(+) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index 061cd078..adea7783 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -11588,6 +11588,29 @@ static void ggml_vk_get_device_description(int device, char * description, size_ + snprintf(description, description_size, "%s", props.deviceName.data()); + } + ++static std::string ggml_vk_get_device_id(int device) { ++ ggml_vk_instance_init(); ++ ++ std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); ++ ++ vk::PhysicalDeviceProperties2 props; ++ vk::PhysicalDeviceIDProperties deviceIDProps; ++ props.pNext = &deviceIDProps; ++ devices[device].getProperties2(&props); ++ ++ const auto& uuid = deviceIDProps.deviceUUID; ++ char id[64]; ++ snprintf(id, sizeof(id), ++ "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", ++ uuid[0], uuid[1], uuid[2], uuid[3], ++ uuid[4], uuid[5], ++ uuid[6], uuid[7], ++ uuid[8], uuid[9], ++ uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] ++ ); ++ return std::string(id); ++} ++ + // backend interface + + #define UNUSED GGML_UNUSED +@@ -12394,6 +12417,12 @@ void ggml_backend_vk_get_device_description(int device, char * description, size + ggml_vk_get_device_description(dev_idx, description, description_size); + } + ++std::string ggml_backend_vk_get_device_id(int device) { ++ GGML_ASSERT(device < (int) vk_instance.device_indices.size()); ++ int dev_idx = vk_instance.device_indices[device]; ++ return ggml_vk_get_device_id(dev_idx); ++} ++ + void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); +@@ -12481,6 +12510,7 @@ struct ggml_backend_vk_device_context { + std::string description; + bool is_integrated_gpu; + std::string pci_bus_id; ++ std::string id; + }; + + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { +@@ -12493,6 +12523,11 @@ static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t de + return ctx->description.c_str(); + } + ++static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { ++ ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; ++ return ctx->id.c_str(); ++} ++ + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx->device, free, total); +@@ -12519,6 +12554,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + + props->name = ggml_backend_vk_device_get_name(dev); + 
props->description = ggml_backend_vk_device_get_description(dev); ++ props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); +@@ -12965,6 +13001,7 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ++ ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, +-- +2.51.0 \ No newline at end of file diff --git a/llama/patches/0028-vulkan-pci-and-memory.patch b/llama/patches/0028-vulkan-pci-and-memory.patch new file mode 100644 index 0000000000..337eb847bc --- /dev/null +++ b/llama/patches/0028-vulkan-pci-and-memory.patch @@ -0,0 +1,253 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Daniel Hiltgen +Date: Fri Sep 5 08:25:03 2025 -0700 +Subject: [PATCH] Vulkan PCI and Memory + +--- + ggml/src/ggml-vulkan/ggml-vulkan.cpp | 176 ++++++++++++++++++++++----- + 1 file changed, 145 insertions(+), 31 deletions(-) + +diff --git a/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +index adea7783..fb7204ce 100644 +--- a/ggml/src/ggml-vulkan/ggml-vulkan.cpp ++++ b/ggml/src/ggml-vulkan/ggml-vulkan.cpp +@@ -12423,31 +12423,99 @@ std::string ggml_backend_vk_get_device_id(int device) { + return ggml_vk_get_device_id(dev_idx); + } + +-void ggml_backend_vk_get_device_memory(int device, size_t * free, size_t * total) { +- GGML_ASSERT(device < (int) vk_instance.device_indices.size()); +- GGML_ASSERT(device < (int) vk_instance.device_supports_membudget.size()); ++////////////////////////// ++ ++struct ggml_backend_vk_device_context { ++ size_t device; ++ std::string name; ++ std::string description; ++ bool is_integrated_gpu; ++ // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) ++ std::string pci_id; ++ std::string id; ++ std::string uuid; ++ int major; ++ int minor; ++ int driver_major; ++ int driver_minor; ++ int pci_bus_id; ++ int pci_device_id; ++ int pci_domain_id; ++}; ++ ++void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); ++ GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); ++ ++ vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + +- vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device]]; +- vk::PhysicalDeviceMemoryBudgetPropertiesEXT budgetprops; +- vk::PhysicalDeviceMemoryProperties2 memprops = {}; +- bool membudget_supported = vk_instance.device_supports_membudget[device]; ++ vk::PhysicalDeviceMemoryProperties memprops = vkdev.getMemoryProperties(); ++ vk::PhysicalDeviceProperties2 props2; ++ vkdev.getProperties2(&props2); + +- if (membudget_supported) { +- memprops.pNext = &budgetprops; ++ if (!ctx->is_integrated_gpu) ++ { ++ // Use vendor specific management libraries for best VRAM reporting if available ++ switch (props2.properties.vendorID) { ++ case VK_VENDOR_ID_AMD: ++ if (ggml_hip_mgmt_init() == 0) { ++ int status 
= ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_hip_mgmt_release(); ++ return; ++ } ++ ggml_hip_mgmt_release(); ++ } ++ break; ++ case VK_VENDOR_ID_NVIDIA: ++ if (ggml_nvml_init() == 0) { ++ int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); ++ if (status == 0) { ++ GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); ++ ggml_nvml_release(); ++ return; ++ } ++ ggml_nvml_release(); ++ } ++ break; ++ } + } +- vkdev.getMemoryProperties2(&memprops); ++ // else fallback to memory budget if supported + +- for (uint32_t i = 0; i < memprops.memoryProperties.memoryHeapCount; ++i) { +- const vk::MemoryHeap & heap = memprops.memoryProperties.memoryHeaps[i]; ++ *total = 0; ++ *free = 0; ++ vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; ++ vk::PhysicalDeviceMemoryProperties2 memprops2; ++ memprops2.pNext = &mem_budget_props; ++ vkdev.getMemoryProperties2(&memprops2); ++ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { ++ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *total += memprops2.memoryProperties.memoryHeaps[i].size; ++ } else if (ctx->is_integrated_gpu) { ++ // Include shared memory on iGPUs ++ *total += memprops2.memoryProperties.memoryHeaps[i].size; ++ } ++ } ++ for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { ++ if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { ++ *free += mem_budget_props.heapBudget[i]; ++ } else if (ctx->is_integrated_gpu) { ++ *free += mem_budget_props.heapBudget[i]; ++ } ++ } ++ if (*total > 0 && *free > 0) { ++ return; ++ } else if (*total > 0) { ++ *free = *total; ++ return; ++ } + ++ // else just report the physical memory ++ for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; +- +- if (membudget_supported && i < budgetprops.heapUsage.size()) { +- *free = budgetprops.heapBudget[i] - budgetprops.heapUsage[i]; +- } else { +- *free = heap.size; +- } ++ *free = heap.size; + break; + } + } +@@ -12502,16 +12570,17 @@ static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + return std::string(pci_bus_id); + } + +-////////////////////////// +- +-struct ggml_backend_vk_device_context { +- size_t device; +- std::string name; +- std::string description; +- bool is_integrated_gpu; +- std::string pci_bus_id; +- std::string id; +-}; ++static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { ++ if (id.empty()) return false; ++ unsigned int d = 0, b = 0, dev = 0, func = 0; ++ // Expected format: dddd:bb:dd.f (all hex) ++ int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); ++ if (n < 4) return false; ++ if (domain) *domain = (int) d; ++ if (bus) *bus = (int) b; ++ if (device) *device = (int) dev; ++ return true; ++} + + static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; +@@ -12530,7 +12599,7 @@ static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + + static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + 
ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; +- ggml_backend_vk_get_device_memory(ctx->device, free, total); ++ ggml_backend_vk_get_device_memory(ctx, free, total); + } + + static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { +@@ -12556,7 +12625,7 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); +- props->device_id = ctx->pci_bus_id.empty() ? nullptr : ctx->pci_bus_id.c_str(); ++ props->device_id = ctx->pci_id.empty() ? nullptr : ctx->pci_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, +@@ -12564,6 +12633,16 @@ static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; ++ ++ props->compute_major = ctx->major; ++ props->compute_minor = ctx->minor; ++ props->driver_major = ctx->driver_major; ++ props->driver_minor = ctx->driver_minor; ++ props->integrated = ctx->is_integrated_gpu; ++ props->pci_bus_id = ctx->pci_bus_id; ++ props->pci_device_id = ctx->pci_device_id; ++ props->pci_domain_id = ctx->pci_domain_id; ++ props->library = GGML_VK_NAME; + } + + static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { +@@ -12992,6 +13071,8 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + static std::mutex mutex; + std::lock_guard lock(mutex); + if (!initialized) { ++ std::vector vk_devices = vk_instance.instance.enumeratePhysicalDevices(); ++ + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; +@@ -13000,13 +13081,46 @@ static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; +- ctx->pci_bus_id = ggml_backend_vk_get_device_pci_id(i); ++ ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); ++ ++ // Gather additional information about the device ++ int dev_idx = vk_instance.device_indices[i]; ++ vk::PhysicalDeviceProperties props1; ++ vk_devices[dev_idx].getProperties(&props1); ++ vk::PhysicalDeviceProperties2 props2; ++ vk::PhysicalDeviceIDProperties device_id_props; ++ vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; ++ vk::PhysicalDeviceDriverProperties driver_props; ++ props2.pNext = &device_id_props; ++ device_id_props.pNext = &pci_bus_props; ++ pci_bus_props.pNext = &driver_props; ++ vk_devices[dev_idx].getProperties2(&props2); ++ std::ostringstream oss; ++ oss << std::hex << std::setfill('0'); ++ oss << "GPU-"; ++ int byteIdx = 0; ++ for (int i = 0; i < 16; ++i, ++byteIdx) { ++ oss << std::setw(2) << static_cast(device_id_props.deviceUUID[i]); ++ if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { ++ oss << '-'; ++ } ++ } ++ ctx->uuid = oss.str(); ++ ctx->pci_bus_id = pci_bus_props.pciBus; ++ ctx->pci_device_id = pci_bus_props.pciDevice; ++ 
ctx->pci_domain_id = pci_bus_props.pciDomain; ++ ctx->id = std::to_string(i); ++ ctx->major = 0; ++ ctx->minor = 0; ++ // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string ++ ctx->driver_major = 0; ++ ctx->driver_minor = 0; + } + initialized = true; + } +-- +2.51.0 \ No newline at end of file diff --git a/llm/server.go b/llm/server.go index febcc1e65f..1a997e4d82 100644 --- a/llm/server.go +++ b/llm/server.go @@ -567,6 +567,7 @@ func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requi if (runtime.GOOS == "windows" && gpus[0].Library == "CUDA" && s.options.UseMMap == nil) || (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || + (gpus[0].Library == "Vulkan" && s.options.UseMMap == nil) || (s.options.UseMMap != nil && !*s.options.UseMMap) { s.loadRequest.UseMmap = false } diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index dc71c8de47..07e55dd3c3 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -57,7 +57,8 @@ var initDevices = sync.OnceFunc(func() { } case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: accels = append(accels, d) - case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpus = append(gpus, d) } @@ -470,7 +471,9 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { // Mimic llama runner logs summarizing layers and memory gpuLayers := 0 for layer := range maps.Values(b.layers) { - if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + switch C.ggml_backend_dev_type(layer.d) { + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: gpuLayers++ } } @@ -479,7 +482,8 @@ func (b *Backend) Load(ctx context.Context, progress func(float32)) error { switch C.ggml_backend_dev_type(b.output) { case C.GGML_BACKEND_DEVICE_TYPE_CPU: slog.Info("offloading output layer to CPU") - case C.GGML_BACKEND_DEVICE_TYPE_GPU: + case C.GGML_BACKEND_DEVICE_TYPE_GPU, + C.GGML_BACKEND_DEVICE_TYPE_IGPU: slog.Info("offloading output layer to GPU") gpuLayers++ case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: diff --git a/ml/backend/ggml/ggml/.rsync-filter b/ml/backend/ggml/ggml/.rsync-filter index a2b1b7d946..449ec9e5d0 100644 --- a/ml/backend/ggml/ggml/.rsync-filter +++ b/ml/backend/ggml/ggml/.rsync-filter @@ -20,10 +20,14 @@ include /src/ggml-cuda/vendors/ include /src/ggml-cuda/template-instances/ include /src/ggml-hip/ include /src/ggml-metal/ +include src/ggml-vulkan/ +include src/ggml-vulkan/vulkan-shaders include CMakeLists.txt include *.[chm] include *.cpp include *.cu include *.cuh include *.metal +include *.comp +include *.glsl hide * diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt new file mode 100644 index 0000000000..83a83887b5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/CMakeLists.txt @@ -0,0 +1,211 @@ +cmake_minimum_required(VERSION 3.19) +cmake_policy(SET CMP0114 NEW) +cmake_policy(SET CMP0116 NEW) + +find_package(Vulkan COMPONENTS glslc REQUIRED) + +function(detect_host_compiler) + if (CMAKE_HOST_SYSTEM_NAME STREQUAL "Windows") + find_program(HOST_C_COMPILER NAMES cl gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES cl g++ clang++ NO_CMAKE_FIND_ROOT_PATH) + else() + find_program(HOST_C_COMPILER NAMES gcc clang NO_CMAKE_FIND_ROOT_PATH) + find_program(HOST_CXX_COMPILER NAMES g++ clang++ 
NO_CMAKE_FIND_ROOT_PATH) + endif() + set(HOST_C_COMPILER "${HOST_C_COMPILER}" PARENT_SCOPE) + set(HOST_CXX_COMPILER "${HOST_CXX_COMPILER}" PARENT_SCOPE) +endfunction() + +# Function to test shader extension support +# Parameters: +# EXTENSION_NAME - Name of the extension to test (e.g., "GL_EXT_integer_dot_product") +# TEST_SHADER_FILE - Path to the test shader file +# RESULT_VARIABLE - Name of the variable to set (ON/OFF) based on test result +function(test_shader_extension_support EXTENSION_NAME TEST_SHADER_FILE RESULT_VARIABLE) + execute_process( + COMMAND ${Vulkan_GLSLC_EXECUTABLE} -o - -fshader-stage=compute --target-env=vulkan1.3 "${TEST_SHADER_FILE}" + OUTPUT_VARIABLE glslc_output + ERROR_VARIABLE glslc_error + ) + + if (${glslc_error} MATCHES ".*extension not supported: ${EXTENSION_NAME}.*") + message(STATUS "${EXTENSION_NAME} not supported by glslc") + set(${RESULT_VARIABLE} OFF PARENT_SCOPE) + else() + message(STATUS "${EXTENSION_NAME} supported by glslc") + set(${RESULT_VARIABLE} ON PARENT_SCOPE) + add_compile_definitions(${RESULT_VARIABLE}) + + # Ensure the extension support is forwarded to vulkan-shaders-gen + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -D${RESULT_VARIABLE}=ON) + set(VULKAN_SHADER_GEN_CMAKE_ARGS "${VULKAN_SHADER_GEN_CMAKE_ARGS}" PARENT_SCOPE) + endif() +endfunction() + +if (Vulkan_FOUND) + message(STATUS "Vulkan found") + + ggml_add_backend_library(ggml-vulkan + ggml-vulkan.cpp + ../../include/ggml-vulkan.h + ) + + set(VULKAN_SHADER_GEN_CMAKE_ARGS "") + + # Test all shader extensions + test_shader_extension_support( + "GL_KHR_cooperative_matrix" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat.comp" + "GGML_VULKAN_COOPMAT_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_NV_cooperative_matrix2" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/coopmat2.comp" + "GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_EXT_integer_dot_product" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/integer_dot.comp" + "GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT" + ) + + test_shader_extension_support( + "GL_EXT_bfloat16" + "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders/feature-tests/bfloat16.comp" + "GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT" + ) + + target_link_libraries(ggml-vulkan PRIVATE Vulkan::Vulkan) + target_include_directories(ggml-vulkan PRIVATE ${CMAKE_CURRENT_BINARY_DIR}) + + # Workaround to the "can't dereference invalidated vector iterator" bug in clang-cl debug build + # Posssibly relevant: https://stackoverflow.com/questions/74748276/visual-studio-no-displays-the-correct-length-of-stdvector + if (MSVC AND CMAKE_CXX_COMPILER_ID STREQUAL "Clang") + add_compile_definitions(_ITERATOR_DEBUG_LEVEL=0) + endif() + + if (GGML_VULKAN_CHECK_RESULTS) + add_compile_definitions(GGML_VULKAN_CHECK_RESULTS) + endif() + + if (GGML_VULKAN_DEBUG) + add_compile_definitions(GGML_VULKAN_DEBUG) + endif() + + if (GGML_VULKAN_MEMORY_DEBUG) + add_compile_definitions(GGML_VULKAN_MEMORY_DEBUG) + endif() + + if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DGGML_VULKAN_SHADER_DEBUG_INFO=ON) + endif() + + if (GGML_VULKAN_VALIDATE) + add_compile_definitions(GGML_VULKAN_VALIDATE) + endif() + + if (GGML_VULKAN_RUN_TESTS) + add_compile_definitions(GGML_VULKAN_RUN_TESTS) + endif() + + # Set up toolchain for host compilation whether cross-compiling or not + if (CMAKE_CROSSCOMPILING) + if (GGML_VULKAN_SHADERS_GEN_TOOLCHAIN) + 
set(HOST_CMAKE_TOOLCHAIN_FILE ${GGML_VULKAN_SHADERS_GEN_TOOLCHAIN}) + else() + detect_host_compiler() + if (NOT HOST_C_COMPILER OR NOT HOST_CXX_COMPILER) + message(FATAL_ERROR "Host compiler not found") + else() + message(STATUS "Host compiler: ${HOST_C_COMPILER} ${HOST_CXX_COMPILER}") + endif() + configure_file(${CMAKE_CURRENT_SOURCE_DIR}/cmake/host-toolchain.cmake.in ${CMAKE_BINARY_DIR}/host-toolchain.cmake @ONLY) + set(HOST_CMAKE_TOOLCHAIN_FILE ${CMAKE_BINARY_DIR}/host-toolchain.cmake) + endif() + else() + # For non-cross-compiling, use empty toolchain (use host compiler) + set(HOST_CMAKE_TOOLCHAIN_FILE "") + endif() + + include(ExternalProject) + + if (CMAKE_CROSSCOMPILING) + list(APPEND VULKAN_SHADER_GEN_CMAKE_ARGS -DCMAKE_TOOLCHAIN_FILE=${HOST_CMAKE_TOOLCHAIN_FILE}) + message(STATUS "vulkan-shaders-gen toolchain file: ${HOST_CMAKE_TOOLCHAIN_FILE}") + endif() + + ExternalProject_Add( + vulkan-shaders-gen + SOURCE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders + CMAKE_ARGS -DCMAKE_INSTALL_PREFIX=${CMAKE_BINARY_DIR}/$ + -DCMAKE_INSTALL_BINDIR=. + -DCMAKE_BUILD_TYPE=$ + ${VULKAN_SHADER_GEN_CMAKE_ARGS} + + BUILD_COMMAND ${CMAKE_COMMAND} --build . --config $ + BUILD_ALWAYS TRUE + + # NOTE: When DESTDIR is set using Makefile generators and + # "make install" triggers the build step, vulkan-shaders-gen + # would be installed into the DESTDIR prefix, so it is unset + # to ensure that does not happen. + + INSTALL_COMMAND ${CMAKE_COMMAND} -E env --unset=DESTDIR + ${CMAKE_COMMAND} --install . --config $ + ) + + set (_ggml_vk_host_suffix $,.exe,>) + set (_ggml_vk_genshaders_dir "${CMAKE_BINARY_DIR}/$") + set (_ggml_vk_genshaders_cmd "${_ggml_vk_genshaders_dir}/vulkan-shaders-gen${_ggml_vk_host_suffix}") + set (_ggml_vk_header "${CMAKE_CURRENT_BINARY_DIR}/ggml-vulkan-shaders.hpp") + set (_ggml_vk_input_dir "${CMAKE_CURRENT_SOURCE_DIR}/vulkan-shaders") + set (_ggml_vk_output_dir "${CMAKE_CURRENT_BINARY_DIR}/vulkan-shaders.spv") + + file(GLOB _ggml_vk_shader_files CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.comp") + + # Because external projects do not provide source-level tracking, + # the vulkan-shaders-gen sources need to be explicitly added to + # ensure that changes will cascade into shader re-generation. 
+ + file(GLOB _ggml_vk_shaders_gen_sources + CONFIGURE_DEPENDS "${_ggml_vk_input_dir}/*.cpp" + "${_ggml_vk_input_dir}/*.h") + + add_custom_command( + OUTPUT ${_ggml_vk_header} + COMMAND ${_ggml_vk_genshaders_cmd} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + DEPENDS ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders header" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_header}) + + foreach (file_full ${_ggml_vk_shader_files}) + get_filename_component(file ${file_full} NAME) + set (_ggml_vk_target_cpp "${CMAKE_CURRENT_BINARY_DIR}/${file}.cpp") + + add_custom_command( + OUTPUT ${_ggml_vk_target_cpp} + DEPFILE ${_ggml_vk_target_cpp}.d + COMMAND ${_ggml_vk_genshaders_cmd} + --glslc ${Vulkan_GLSLC_EXECUTABLE} + --source ${file_full} + --output-dir ${_ggml_vk_output_dir} + --target-hpp ${_ggml_vk_header} + --target-cpp ${_ggml_vk_target_cpp} + DEPENDS ${file_full} + ${_ggml_vk_shaders_gen_sources} + vulkan-shaders-gen + COMMENT "Generate vulkan shaders for ${file}" + ) + target_sources(ggml-vulkan PRIVATE ${_ggml_vk_target_cpp}) + endforeach() + +else() + message(WARNING "Vulkan not found") +endif() diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp new file mode 100644 index 0000000000..aa48a63c65 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/ggml-vulkan.cpp @@ -0,0 +1,13903 @@ +#include "ggml-vulkan.h" +#include +#if defined(GGML_VULKAN_RUN_TESTS) || defined(GGML_VULKAN_CHECK_RESULTS) +#include +#include "ggml-cpu.h" +#endif + +// See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- +#define VULKAN_HPP_DISPATCH_LOADER_DYNAMIC 1 +// We use VULKAN_HPP_DEFAULT_DISPATCHER, but not VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE +// to avoid conflicts with applications or other libraries who might use it. 
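// Illustrative sketch (not part of the patch): redirecting
// VULKAN_HPP_DEFAULT_DISPATCHER to ggml_vk_default_dispatcher(), as the
// declarations below do, lets the backend own a single dynamic dispatcher
// instead of the global storage that
// VULKAN_HPP_DEFAULT_DISPATCH_LOADER_DYNAMIC_STORAGE would define. The
// matching definition is assumed to look roughly like:
//
//   static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance;
//   DispatchLoaderDynamic & ggml_vk_default_dispatcher() {
//       return ggml_vk_default_dispatcher_instance;
//   }
//
// initialized once via DispatchLoaderDynamic::init() after the vk::Instance is
// created, so vulkan.hpp calls resolve through it without clashing with a host
// application's own dispatcher.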
+#if VK_HEADER_VERSION >= 301 +namespace vk::detail { class DispatchLoaderDynamic; } +using vk::detail::DispatchLoaderDynamic; +#else +namespace vk { class DispatchLoaderDynamic; } +using vk::DispatchLoaderDynamic; +#endif +DispatchLoaderDynamic & ggml_vk_default_dispatcher(); +#define VULKAN_HPP_DEFAULT_DISPATCHER ggml_vk_default_dispatcher() + +#include + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#if defined(_MSC_VER) +# define NOMINMAX 1 +# include +# define YIELD() YieldProcessor() +#elif defined(__clang__) || defined(__GNUC__) +# if defined(__x86_64__) ||defined(__i386__) +# include +# define YIELD() _mm_pause() +# elif defined(__arm__) || defined(__aarch64__) +# if defined(__clang__) +# include +# define YIELD() __yield() +# else +# define YIELD() asm volatile("yield") +# endif +# endif +#endif + +#if !defined(YIELD) +#define YIELD() +#endif + +#include "ggml-impl.h" +#include "ggml-backend-impl.h" + +#include "ggml-vulkan-shaders.hpp" + +// remove this once it's more widely available in the SDK +#if !defined(VK_KHR_shader_bfloat16) + +#define VK_KHR_shader_bfloat16 1 +#define VK_KHR_SHADER_BFLOAT16_SPEC_VERSION 1 +#define VK_KHR_SHADER_BFLOAT16_EXTENSION_NAME "VK_KHR_shader_bfloat16" +#define VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR ((VkStructureType)1000141000) +#define VK_COMPONENT_TYPE_BFLOAT16_KHR ((VkComponentTypeKHR)1000141000) + +typedef struct VkPhysicalDeviceShaderBfloat16FeaturesKHR { + VkStructureType sType; + void* pNext; + VkBool32 shaderBFloat16Type; + VkBool32 shaderBFloat16DotProduct; + VkBool32 shaderBFloat16CooperativeMatrix; +} VkPhysicalDeviceShaderBfloat16FeaturesKHR; +#endif + +#define ROUNDUP_POW2(M, N) (((M) + (N) - 1) & ~((N) - 1)) +#define CEIL_DIV(M, N) (((M) + (N)-1) / (N)) +static bool is_pow2(uint32_t x) { return x > 1 && (x & (x-1)) == 0; } + +#define VK_VENDOR_ID_AMD 0x1002 +#define VK_VENDOR_ID_APPLE 0x106b +#define VK_VENDOR_ID_INTEL 0x8086 +#define VK_VENDOR_ID_NVIDIA 0x10de + +#define VK_DEVICE_DESCRIPTOR_POOL_SIZE 256 + +#define GGML_VK_MAX_NODES 8192 + +#define MAX_VK_BUFFERS 256 + +#define VK_CHECK(err, msg) \ + do { \ + vk::Result err_ = (err); \ + if (err_ != vk::Result::eSuccess) { \ + fprintf(stderr, "ggml_vulkan: %s error %s at %s:%d\n", \ + #err, to_string(err_).c_str(), __FILE__, __LINE__); \ + exit(1); \ + } \ + } while (0) + +#ifdef GGML_VULKAN_DEBUG +#define VK_LOG_DEBUG(msg) std::cerr << msg << std::endl +#else +#define VK_LOG_DEBUG(msg) ((void) 0) +#endif // GGML_VULKAN_DEBUG + +struct ggml_backend_vk_context; + +#define MAX_PARAMETER_COUNT 12 +// Max number of adds that can be fused without exceeding MAX_PARAMETER_COUNT. 
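// Worked reading of the bound (illustrative): fusing N GGML_OP_ADD nodes binds
// N + 1 source tensors, 1 destination and 1 buffer for the RMS partial sums,
// i.e. (N + 1) + 1 + 1 = N + 3 descriptors, so N <= MAX_PARAMETER_COUNT - 3 = 9,
// which is the value MAX_FUSED_ADDS expands to below.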
+#define MAX_FUSED_ADDS (MAX_PARAMETER_COUNT - 3) + +struct vk_pipeline_struct { + std::string name; + vk::ShaderModule shader_module; + vk::PipelineLayout layout; + vk::Pipeline pipeline; + uint32_t push_constant_size; + uint32_t parameter_count; + std::array wg_denoms; + uint32_t align; + // true if fields have been set by ggml_vk_create_pipeline + bool initialized {}; + // set to true to request the pipeline is compiled after the dryrun + bool needed {}; + // set to true when the shader has been compiled + bool compiled {}; + // number of registers used, extracted from pipeline executable properties + uint32_t register_count {}; +}; + +typedef std::shared_ptr vk_pipeline; +typedef std::weak_ptr vk_pipeline_ref; + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline); + +struct vk_matmul_pipeline_struct { + vk_pipeline l, m, s; + vk_pipeline a_l, a_m, a_s; +}; + +typedef std::shared_ptr vk_matmul_pipeline; + +struct vk_matmul_pipeline2 { + vk_matmul_pipeline2() { + f16acc = std::make_shared(); + f32acc = std::make_shared(); + } + vk_matmul_pipeline f32acc; + vk_matmul_pipeline f16acc; +}; + +struct vk_device_struct; +typedef std::shared_ptr vk_device; +typedef std::weak_ptr vk_device_ref; + +struct vk_buffer_struct; +typedef std::shared_ptr vk_buffer; +typedef std::weak_ptr vk_buffer_ref; + +struct ggml_backend_vk_buffer_type_context { + std::string name; + vk_device device; +}; + +struct vk_queue; + +// Stores command pool/buffers. There's an instance of this +// for each (context,queue) pair and for each (device,queue) pair. +struct vk_command_pool { + void init(vk_device& device, vk_queue *q_); + void destroy(vk::Device& device); + + vk::CommandPool pool; + uint32_t cmd_buffer_idx; + std::vector cmd_buffers; + + vk_queue *q; +}; + +// Prevent simultaneous submissions to the same queue. +// This could be per vk_queue if we stopped having two vk_queue structures +// sharing the same vk::Queue. 
+static std::mutex queue_mutex; + +struct vk_queue { + uint32_t queue_family_index; + vk::Queue queue; + + vk_command_pool cmd_pool; + + vk::PipelineStageFlags stage_flags; + + bool transfer_only; + + // copy everything except the cmd_pool + void copyFrom(vk_queue &other) { + queue_family_index = other.queue_family_index; + queue = other.queue; + stage_flags = other.stage_flags; + transfer_only = other.transfer_only; + } +}; + +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft); +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size); +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft); +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor); +static ggml_backend_buffer_type_i ggml_backend_vk_buffer_type_interface = { + /* .get_name = */ ggml_backend_vk_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_vk_buffer_type_get_alloc_size, + /* .is_host = */ NULL, +}; + +#ifdef GGML_VULKAN_MEMORY_DEBUG +class vk_memory_logger; +#endif +class vk_perf_logger; +static void ggml_vk_destroy_buffer(vk_buffer& buf); + +static constexpr uint32_t mul_mat_vec_max_cols = 8; +static constexpr uint32_t p021_max_gqa_ratio = 8; + +enum vk_device_architecture { + OTHER, + AMD_GCN, + AMD_RDNA1, + AMD_RDNA2, + AMD_RDNA3, + INTEL_XE2, + NVIDIA_PRE_TURING, +}; + +static vk_device_architecture get_device_architecture(const vk::PhysicalDevice& device) { + vk::PhysicalDeviceProperties props = device.getProperties(); + + if (props.vendorID == VK_VENDOR_ID_AMD) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool amd_shader_core_properties = false; + bool integer_dot_product = false; + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_AMD_shader_core_properties", properties.extensionName) == 0) { + amd_shader_core_properties = true; + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0) { + integer_dot_product = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!amd_shader_core_properties || !integer_dot_product || !subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceShaderCorePropertiesAMD shader_core_props_amd; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR integer_dot_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &shader_core_props_amd; + shader_core_props_amd.pNext = &integer_dot_props; + integer_dot_props.pNext = &subgroup_size_control_props; + + device.getProperties2(&props2); + + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 64) { + return vk_device_architecture::AMD_GCN; + } + if (subgroup_size_control_props.maxSubgroupSize == 64 && subgroup_size_control_props.minSubgroupSize == 32) { + // RDNA + if (shader_core_props_amd.wavefrontsPerSimd == 20) { + return vk_device_architecture::AMD_RDNA1; + } + if 
(integer_dot_props.integerDotProduct4x8BitPackedMixedSignednessAccelerated) { + return vk_device_architecture::AMD_RDNA3; + } + return vk_device_architecture::AMD_RDNA2; + } + } else if (props.vendorID == VK_VENDOR_ID_INTEL) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool subgroup_size_control = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + subgroup_size_control = true; + } + } + + if (!subgroup_size_control) { + return vk_device_architecture::OTHER; + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + + props2.pNext = &subgroup_size_control_props; + device.getProperties2(&props2); + + if (subgroup_size_control_props.minSubgroupSize == 16) { + // Xe2 architecture uses SIMD16 while previous Xe and Gen architecture uses SIMD8. + // Minimum subgroup size matches the SIMD width so we distinguish architecture by checking this value. + // https://www.intel.com/content/www/us/en/content-details/824434/2024-intel-tech-tour-xe2-and-lunar-lake-s-gpu.html + // https://www.intel.com/content/www/us/en/docs/oneapi/optimization-guide-gpu/2025-0/intel-xe-gpu-architecture.html + return vk_device_architecture::INTEL_XE2; + } + } else if (props.vendorID == VK_VENDOR_ID_NVIDIA) { + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool cooperative_matrix = false; + + // Detect "pre-turing" based on lack of coopmat support. + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0) { + cooperative_matrix = true; + break; + } + } + + if (!cooperative_matrix) { + return vk_device_architecture::NVIDIA_PRE_TURING; + } + } + return vk_device_architecture::OTHER; +} + +enum vk_conv_shapes { + CONV_SHAPE_128x128, + CONV_SHAPE_64x32, + CONV_SHAPE_32x256, + CONV_SHAPE_COUNT, +}; + +enum dmmv_wg_sizes { + DMMV_WG_SIZE_SUBGROUP, + DMMV_WG_SIZE_LARGE, + DMMV_WG_SIZE_COUNT, +}; + +enum FaCodePath { + FA_SCALAR, + FA_COOPMAT1, + FA_COOPMAT2, +}; + +struct vk_fa_pipeline_state { + vk_fa_pipeline_state(uint32_t HSK, uint32_t HSV, bool small_rows, FaCodePath path, bool aligned, bool f32acc) + : HSK(HSK), HSV(HSV), small_rows(small_rows), path(path), aligned(aligned), f32acc(f32acc) {} + + uint32_t HSK, HSV; + bool small_rows; + FaCodePath path; + bool aligned; + bool f32acc; + + bool operator<(const vk_fa_pipeline_state &b) const { + return std::tie(HSK, HSV, small_rows, path, aligned, f32acc) < + std::tie(b.HSK, b.HSV, b.small_rows, b.path, b.aligned, b.f32acc); + } +}; + +enum shader_reduction_mode { + SHADER_REDUCTION_MODE_SHMEM, + SHADER_REDUCTION_MODE_HYBRID, + SHADER_REDUCTION_MODE_SUBGROUP, + SHADER_REDUCTION_MODE_COUNT, +}; + +static constexpr uint32_t num_argsort_pipelines = 11; +static constexpr uint32_t max_argsort_cols = 1 << (num_argsort_pipelines-1); + +struct vk_device_struct { + std::recursive_mutex mutex; + + vk::PhysicalDevice physical_device; + vk::PhysicalDeviceProperties properties; + std::string name; + uint64_t max_memory_allocation_size; + uint64_t max_buffer_size; + uint64_t suballocation_block_size; + bool fp16; + bool bf16; + bool pipeline_robustness; + vk::Device device; + uint32_t vendor_id; + vk::DriverId driver_id; + vk_device_architecture architecture; + vk_queue compute_queue; + vk_queue transfer_queue; + bool single_queue; + uint32_t subgroup_size; + uint32_t shader_core_count; + bool uma; + bool 
prefer_host_memory; + bool float_controls_rte_fp16; + bool subgroup_arithmetic; + bool subgroup_shuffle; + bool subgroup_ballot; + bool subgroup_clustered; + bool multi_add; + bool shader_int64; + bool buffer_device_address; + + bool add_rms_fusion; + uint32_t partials_binding_alignment; + + bool integer_dot_product; + // 0: default, 1: force mmvq, -1: disable mmvq + int32_t mmvq_mode; + + bool subgroup_size_control; + uint32_t subgroup_min_size; + uint32_t subgroup_max_size; + bool subgroup_require_full_support; + + bool coopmat_support; + bool coopmat_acc_f32_support {}; + bool coopmat_acc_f16_support {}; + bool coopmat_bf16_support {}; + bool coopmat_support_16x16x16_f16acc {}; + bool coopmat_support_16x16x16_f32acc {}; + bool coopmat1_fa_support {}; + uint32_t coopmat_m; + uint32_t coopmat_n; + uint32_t coopmat_k; + + bool coopmat_int_support; + uint32_t coopmat_int_m; + uint32_t coopmat_int_n; + uint32_t coopmat_int_k; + + bool coopmat2; + + bool pipeline_executable_properties_support {}; + + size_t idx; + + bool mul_mat_l[GGML_TYPE_COUNT]; + bool mul_mat_m[GGML_TYPE_COUNT]; + bool mul_mat_s[GGML_TYPE_COUNT]; + bool mul_mat_id_l[GGML_TYPE_COUNT]; + bool mul_mat_id_m[GGML_TYPE_COUNT]; + bool mul_mat_id_s[GGML_TYPE_COUNT]; + + // set to true to indicate that some shaders need to be compiled after the dryrun + bool need_compiles {}; + + vk::DescriptorSetLayout dsl; + + vk_matmul_pipeline pipeline_matmul_f32 {}; + vk_matmul_pipeline pipeline_matmul_f32_f16 {}; + vk_matmul_pipeline pipeline_matmul_bf16 {}; + vk_matmul_pipeline2 pipeline_matmul_f16; + vk_matmul_pipeline2 pipeline_matmul_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_COUNT]; + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_COUNT]; + + vk_matmul_pipeline pipeline_matmul_id_f32 {}; + vk_matmul_pipeline pipeline_matmul_id_bf16 {}; + vk_matmul_pipeline2 pipeline_matmul_id_f16; + vk_matmul_pipeline2 pipeline_matmul_id_f16_f32; + + vk_matmul_pipeline2 pipeline_dequant_mul_mat_mat_id[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_matmul_split_k_reduce; + vk_pipeline pipeline_quantize_q8_1; + vk_pipeline pipeline_quantize_q8_1_x4; + + vk_pipeline pipeline_dequant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_dequant_mul_mat_vec_f32_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_f16_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + vk_pipeline pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_dequant_mul_mat_vec_q8_1_f32[DMMV_WG_SIZE_COUNT][GGML_TYPE_COUNT][mul_mat_vec_max_cols]; + + vk_pipeline pipeline_mul_mat_vec_p021_f16_f32[p021_max_gqa_ratio]; + vk_pipeline pipeline_mul_mat_vec_nc_f16_f32; + vk_pipeline pipeline_get_rows[GGML_TYPE_COUNT]; + vk_pipeline pipeline_get_rows_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_acc_f32; + + // [src0 0=fp32,1=fp16][src1 0=fp32,1=fp16][dst 0=fp32,1=fp16] + vk_pipeline pipeline_add[2][2][2]; + vk_pipeline pipeline_add_norepeat[2][2][2]; + vk_pipeline pipeline_sub[2][2][2]; + vk_pipeline pipeline_sub_norepeat[2][2][2]; + vk_pipeline pipeline_mul[2][2][2]; + vk_pipeline pipeline_mul_norepeat[2][2][2]; + vk_pipeline pipeline_div[2][2][2]; + vk_pipeline pipeline_div_norepeat[2][2][2]; + vk_pipeline pipeline_add_rms[2][2][2]; + vk_pipeline pipeline_add_rms_norepeat[2][2][2]; + + // indexed by num_additional_fused_ops == num_adds - 1 + vk_pipeline pipeline_multi_add[MAX_FUSED_ADDS]; + 
vk_pipeline pipeline_multi_add_rms[MAX_FUSED_ADDS]; + + vk_pipeline pipeline_add_id_f32; + + vk_pipeline pipeline_concat_f32, pipeline_concat_f16, pipeline_concat_i32; + vk_pipeline pipeline_upscale_nearest_f32, pipeline_upscale_bilinear_f32, pipeline_upscale_bilinear_ac_f32; + vk_pipeline pipeline_scale_f32; + vk_pipeline pipeline_sqr_f32; + vk_pipeline pipeline_sqrt_f32; + vk_pipeline pipeline_sin_f32; + vk_pipeline pipeline_cos_f32; + vk_pipeline pipeline_clamp_f32; + vk_pipeline pipeline_pad_f32; + vk_pipeline pipeline_roll_f32; + vk_pipeline pipeline_repeat_f32, pipeline_repeat_back_f32; + vk_pipeline pipeline_cpy_f32_f32, pipeline_cpy_f32_f16, pipeline_cpy_f16_f16, pipeline_cpy_f16_f32, pipeline_cpy_f32_bf16, pipeline_cpy_f32_i32, pipeline_cpy_i32_f32; + vk_pipeline pipeline_contig_cpy_f32_f32, pipeline_contig_cpy_f32_f16, pipeline_contig_cpy_f16_f16, pipeline_contig_cpy_f16_f32, pipeline_contig_cpy_f32_bf16, pipeline_contig_cpy_f32_i32, pipeline_contig_cpy_i32_f32; + vk_pipeline pipeline_cpy_f32_quant[GGML_TYPE_COUNT]; + vk_pipeline pipeline_cpy_quant_f32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows_i32[GGML_TYPE_COUNT]; + vk_pipeline pipeline_set_rows_i64[GGML_TYPE_COUNT]; + vk_pipeline pipeline_norm_f32; + vk_pipeline pipeline_group_norm_f32; + vk_pipeline pipeline_rms_norm_f32; + vk_pipeline pipeline_rms_norm_mul_f32; + vk_pipeline pipeline_rms_norm_partials_f32; + vk_pipeline pipeline_rms_norm_mul_partials_f32; + vk_pipeline pipeline_rms_norm_back_f32; + vk_pipeline pipeline_l2_norm_f32; + + // [src/dst 0=fp32,1=fp16] + vk_pipeline pipeline_exp[2]; + vk_pipeline pipeline_gelu[2]; + vk_pipeline pipeline_gelu_erf[2]; + vk_pipeline pipeline_gelu_quick[2]; + vk_pipeline pipeline_silu[2]; + vk_pipeline pipeline_relu[2]; + vk_pipeline pipeline_tanh[2]; + vk_pipeline pipeline_sigmoid[2]; + vk_pipeline pipeline_hardsigmoid[2]; + vk_pipeline pipeline_hardswish[2]; + + vk_pipeline pipeline_geglu[2]; + vk_pipeline pipeline_reglu[2]; + vk_pipeline pipeline_swiglu[2]; + vk_pipeline pipeline_swiglu_oai[2]; + vk_pipeline pipeline_geglu_erf[2]; + vk_pipeline pipeline_geglu_quick[2]; + + vk_pipeline pipeline_leaky_relu_f32; + vk_pipeline pipeline_silu_back_f32; + vk_pipeline pipeline_diag_mask_inf_f32; + vk_pipeline pipeline_soft_max_f32, pipeline_soft_max_f32_f16; + vk_pipeline pipeline_soft_max_f32_wg512, pipeline_soft_max_f32_f16_wg512; + vk_pipeline pipeline_soft_max_back_f32; + vk_pipeline pipeline_rope_norm_f32, pipeline_rope_norm_f16; + vk_pipeline pipeline_rope_neox_f32, pipeline_rope_neox_f16; + vk_pipeline pipeline_rope_multi_f32, pipeline_rope_multi_f16; + vk_pipeline pipeline_rope_vision_f32, pipeline_rope_vision_f16; + vk_pipeline pipeline_argsort_f32[num_argsort_pipelines]; + vk_pipeline pipeline_sum_rows_f32; + vk_pipeline pipeline_argmax_f32; + vk_pipeline pipeline_count_equal_i32; + vk_pipeline pipeline_im2col_f32, pipeline_im2col_f32_f16; + vk_pipeline pipeline_im2col_3d_f32, pipeline_im2col_3d_f32_f16; + vk_pipeline pipeline_timestep_embedding_f32; + vk_pipeline pipeline_conv_transpose_1d_f32; + vk_pipeline pipeline_pool2d_f32; + vk_pipeline pipeline_rwkv_wkv6_f32; + vk_pipeline pipeline_rwkv_wkv7_f32; + vk_pipeline pipeline_opt_step_adamw_f32; + vk_pipeline pipeline_opt_step_sgd_f32; + vk_pipeline pipeline_conv2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_f16_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv_transpose_2d_f16_f32[CONV_SHAPE_COUNT]; + vk_pipeline pipeline_conv2d_dw_whcn_f32, 
pipeline_conv2d_dw_whcn_f16_f32; + vk_pipeline pipeline_conv2d_dw_cwhn_f32, pipeline_conv2d_dw_cwhn_f16_f32; + + std::map pipeline_flash_attn_f32_f16[GGML_TYPE_COUNT]; + + vk_pipeline pipeline_flash_attn_split_k_reduce; + + std::vector all_pipelines; + + std::vector> pinned_memory; + + vk::Fence fence; + vk_buffer sync_staging; + + ggml_backend_buffer_type buffer_type; + + bool disable_fusion; + bool disable_host_visible_vidmem; + bool allow_sysmem_fallback; + bool disable_graph_optimize; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + std::unique_ptr memory_logger; +#endif + + // for GGML_VK_PERF_LOGGER + std::unique_ptr perf_logger; + vk::QueryPool query_pool; + int32_t num_queries; + + ~vk_device_struct() { + VK_LOG_DEBUG("destroy device " << name); + + device.destroyFence(fence); + + ggml_vk_destroy_buffer(sync_staging); + + compute_queue.cmd_pool.destroy(device); + transfer_queue.cmd_pool.destroy(device); + + for (auto& pipeline : all_pipelines) { + if (pipeline.expired()) { + continue; + } + + vk_pipeline pl = pipeline.lock(); + ggml_vk_destroy_pipeline(device, pl); + } + all_pipelines.clear(); + + device.destroyDescriptorSetLayout(dsl); + + device.destroy(); + } +}; + +void vk_command_pool::init(vk_device& device, vk_queue *q_) { + cmd_buffer_idx = 0; + q = q_; + + vk::CommandPoolCreateInfo command_pool_create_info(vk::CommandPoolCreateFlags(VK_COMMAND_POOL_CREATE_TRANSIENT_BIT), q->queue_family_index); + pool = device->device.createCommandPool(command_pool_create_info); +} + +void vk_command_pool::destroy(vk::Device& device) { + device.destroyCommandPool(pool); + pool = nullptr; + cmd_buffers.clear(); +} + +struct vk_buffer_struct { + vk::Buffer buffer = VK_NULL_HANDLE; + vk::DeviceMemory device_memory = VK_NULL_HANDLE; + vk::MemoryPropertyFlags memory_property_flags; + void * ptr; + size_t size = 0; + vk::DeviceAddress bda_addr {}; + + vk_device device; + + ~vk_buffer_struct() { + if (size == 0) { + return; + } + VK_LOG_DEBUG("~vk_buffer_struct(" << buffer << ", " << size << ")"); + + device->device.freeMemory(device_memory); + device->device.destroyBuffer(buffer); + } +}; + +struct vk_subbuffer { + vk_buffer buffer; + uint64_t offset; + uint64_t size; + + operator vk::DescriptorBufferInfo() const { + return { buffer->buffer, offset, size }; + } +}; + +struct vk_semaphore { + vk::Semaphore s; + uint64_t value; +}; + +struct vk_submission { + vk::CommandBuffer buffer; + std::vector wait_semaphores; + std::vector signal_semaphores; +}; + +typedef std::vector vk_sequence; + +struct vk_mat_mat_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t k_split; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; + uint32_t padded_N; +}; +struct vk_mat_vec_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t ne02; uint32_t ne12; uint32_t broadcast2; uint32_t broadcast3; +}; + +struct vk_mat_mat_id_push_constants { + uint32_t M; uint32_t N; uint32_t K; + uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t nei1; uint32_t nbi1; uint32_t ne11; + uint32_t padded_N; +}; +struct vk_mat_vec_id_push_constants { + uint32_t ncols; uint32_t stride_a; uint32_t stride_b; uint32_t stride_d; + uint32_t batch_stride_a; 
uint32_t batch_stride_b; uint32_t batch_stride_d; + uint32_t nei0; uint32_t ne11; +}; + +struct vk_flash_attn_push_constants { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + uint32_t nem2; + uint32_t nem3; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t nb21; + uint32_t nb22; + uint32_t nb23; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask_n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +}; +static_assert(sizeof(vk_flash_attn_push_constants) <= 128, "sizeof(vk_flash_attn_push_constants) must be <= 128"); + +struct vk_op_push_constants { + uint32_t KX; + uint32_t KY; + float param1; + float param2; +}; + +struct vk_op_glu_push_constants { + uint32_t N; + uint32_t ne00; + uint32_t ne20; + uint32_t mode; // 0: default, 1: swapped, 2: split + float alpha; // for swiglu_oai + float limit; +}; + +struct vk_op_unary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + float param1; float param2; + uint32_t ne0_012mp; uint32_t ne0_012L; + uint32_t ne0_01mp; uint32_t ne0_01L; + uint32_t ne0_0mp; uint32_t ne0_0L; + uint32_t ne1_012mp; uint32_t ne1_012L; + uint32_t ne1_01mp; uint32_t ne1_01L; + uint32_t ne1_0mp; uint32_t ne1_0L; +}; +static_assert(sizeof(vk_op_unary_push_constants) <= 128, "sizeof(vk_op_unary_push_constants) must be <= 128"); + +static vk_op_unary_push_constants vk_op_unary_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst, int64_t ne = 0) { + GGML_ASSERT(ne != 0 || (ggml_nelements(src0) == ggml_nelements(dst))); + ne = ne != 0 ? 
ne : ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_unary_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + return p; // offsets are initialized later in ggml_vk_op +} + +struct vk_op_pad_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t misalign_offsets; + + uint32_t lp0; uint32_t rp0; + uint32_t lp1; uint32_t rp1; + uint32_t lp2; uint32_t rp2; + uint32_t lp3; uint32_t rp3; +}; + +static vk_op_pad_push_constants vk_op_pad_push_constants_init(const ggml_tensor * src0, const ggml_tensor * dst) { + int64_t ne = ggml_nelements(dst); + GGML_ASSERT(ne <= (int64_t)std::numeric_limits::max()); + + vk_op_pad_push_constants p{}; + p.ne = (uint32_t)ne; + + size_t src0_tsize = ggml_type_size(src0->type); + p.ne00 = (uint32_t)src0->ne[0]; + p.ne01 = (uint32_t)src0->ne[1]; + p.ne02 = (uint32_t)src0->ne[2]; + p.ne03 = (uint32_t)src0->ne[3]; + p.nb00 = (uint32_t)(src0->nb[0] / src0_tsize); + p.nb01 = (uint32_t)(src0->nb[1] / src0_tsize); + p.nb02 = (uint32_t)(src0->nb[2] / src0_tsize); + p.nb03 = (uint32_t)(src0->nb[3] / src0_tsize); + + size_t dst_tsize = ggml_type_size(dst->type); + p.ne10 = (uint32_t)dst->ne[0]; + p.ne11 = (uint32_t)dst->ne[1]; + p.ne12 = (uint32_t)dst->ne[2]; + p.ne13 = (uint32_t)dst->ne[3]; + p.nb10 = (uint32_t)(dst->nb[0] / dst_tsize); + p.nb11 = (uint32_t)(dst->nb[1] / dst_tsize); + p.nb12 = (uint32_t)(dst->nb[2] / dst_tsize); + p.nb13 = (uint32_t)(dst->nb[3] / dst_tsize); + + p.lp0 = dst->op_params[0]; + p.rp0 = dst->op_params[1]; + p.lp1 = dst->op_params[2]; + p.rp1 = dst->op_params[3]; + p.lp2 = dst->op_params[4]; + p.rp2 = dst->op_params[5]; + p.lp3 = dst->op_params[6]; + p.rp3 = dst->op_params[7]; + + return p; // fastdiv values and offsets are initialized later in ggml_vk_op +} + +// See https://gmplib.org/~tege/divcnst-pldi94.pdf figure 4.1. +// Precompute mp (m' in the paper) and L such that division +// can be computed using a multiply (high 32b of 64b result) +// and a shift: +// +// n/d = (mulhi(n, mp) + n) >> L; +static void init_fastdiv_values(uint32_t d, uint32_t &mp, uint32_t &L) +{ + // compute L = ceil(log2(d)); + L = 0; + while (L < 32 && (uint32_t{1} << L) < d) { + L++; + } + + mp = (uint32_t)((uint64_t{1} << 32) * ((uint64_t{1} << L) - d) / d + 1); +} + +template void init_pushconst_fastdiv(T &p) { + GGML_UNUSED(p); + static_assert(!std::is_const::value, "unexpected type"); +} + +template <> void init_pushconst_fastdiv(vk_op_unary_push_constants &p) { + // Compute magic values to divide by these six numbers. 
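// Worked example (illustrative): for d = 7 the routine above yields
// L = ceil(log2(7)) = 3 and mp = floor(2^32 * (2^3 - 7) / 7) + 1 = 613566757.
// Dividing n = 100 in the shader then becomes
//   (mulhi(100, 613566757) + 100) >> 3 = (14 + 100) >> 3 = 14 = 100 / 7,
// one widening multiply plus a shift instead of an integer division.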
+ init_fastdiv_values(p.ne02*p.ne01*p.ne00, p.ne0_012mp, p.ne0_012L); + init_fastdiv_values(p.ne01*p.ne00, p.ne0_01mp, p.ne0_01L); + init_fastdiv_values(p.ne00, p.ne0_0mp, p.ne0_0L); + init_fastdiv_values(p.ne12*p.ne11*p.ne10, p.ne1_012mp, p.ne1_012L); + init_fastdiv_values(p.ne11*p.ne10, p.ne1_01mp, p.ne1_01L); + init_fastdiv_values(p.ne10, p.ne1_0mp, p.ne1_0L); +} + +struct vk_op_binary_push_constants { + uint32_t ne; + uint32_t ne00; uint32_t ne01; uint32_t ne02; uint32_t ne03; uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; uint32_t nb10; uint32_t nb11; uint32_t nb12; uint32_t nb13; + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; uint32_t nb20; uint32_t nb21; uint32_t nb22; uint32_t nb23; + uint32_t misalign_offsets; + float param1; float param2; int32_t param3; +}; + +struct vk_op_multi_add_push_constants { + // shape for dst + uint32_t ne20; uint32_t ne21; uint32_t ne22; uint32_t ne23; + + // strides for srcs+dst + uint32_t nb[MAX_PARAMETER_COUNT][4]; + + uint32_t rms_partials; +}; +// update multi_add.comp if this changes +static_assert(MAX_PARAMETER_COUNT == 12); +static_assert(sizeof(vk_op_multi_add_push_constants) <= 256); + +struct vk_op_add_id_push_constants { + uint32_t ne0; + uint32_t ne1; + uint32_t s01; + uint32_t s02; + uint32_t s11; + uint32_t s21; +}; + +struct vk_op_diag_mask_push_constants { + uint32_t ncols; + uint32_t rows_per_channel; + int32_t n_past; +}; + +struct vk_op_rope_push_constants { + uint32_t ncols; + uint32_t n_dims; + float freq_scale; + uint32_t p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint32_t has_ff; + uint32_t ne02; + uint32_t s1; + uint32_t s2; + int32_t sections[4]; + uint32_t is_back; +}; + +struct vk_op_soft_max_push_constants { + uint32_t KX; + uint32_t KY; + uint32_t ne00; + uint32_t ne01; + uint32_t ne02; + uint32_t ne12; + uint32_t ne13; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + float scale; + float max_bias; + float m0; + float m1; + uint32_t n_head_log2; + uint32_t nrows_x; + uint32_t has_sinks; +}; + +struct vk_op_argsort_push_constants { + uint32_t ncols; + int32_t order; +}; + +struct vk_op_im2col_push_constants { + uint64_t dst_addr; + uint32_t batch_offset; uint32_t offset_delta; + uint32_t IC; + uint32_t IW; uint32_t IH; + uint32_t OW; uint32_t OH; + uint32_t KW; uint32_t KH; + uint32_t pelements; + uint32_t CHW; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; + int32_t d0; int32_t d1; +}; + +struct vk_op_im2col_3d_push_constants { + uint64_t dst_addr; + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +}; + +struct vk_op_timestep_embedding_push_constants { + uint32_t nb1; + uint32_t dim; + uint32_t max_period; +}; + +struct vk_op_conv_transpose_1d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +}; + +struct vk_op_pool2d_push_constants { + uint32_t IW; uint32_t IH; + 
uint32_t OW; uint32_t OH; + uint32_t OC; + uint32_t pelements; + uint32_t op; + int32_t k0; int32_t k1; + int32_t s0; int32_t s1; + int32_t p0; int32_t p1; +}; + +struct vk_op_rwkv_wkv6_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +struct vk_op_rwkv_wkv7_push_constants { + uint32_t B; + uint32_t T; + uint32_t C; + uint32_t H; +}; + +struct vk_op_conv2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); +} + +struct vk_op_conv_transpose_2d_push_constants { + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // init_fastdiv_values constants for dividing by KW, KW*KH, OW, OW*OH, s0, s1 + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +}; + +template <> void init_pushconst_fastdiv(vk_op_conv_transpose_2d_push_constants &p) { + // Compute magic values to divide by KW, KW*KH, OW, OW*OH, s0, s1 + init_fastdiv_values(p.KW, p.KWmp, p.KWL); + init_fastdiv_values(p.KW*p.KH, p.KWKHmp, p.KWKHL); + init_fastdiv_values(p.OW, p.OWmp, p.OWL); + init_fastdiv_values(p.OW*p.OH, p.OWOHmp, p.OWOHL); + init_fastdiv_values(p.s0, p.s0mp, p.s0L); + init_fastdiv_values(p.s1, p.s1mp, p.s1L); +} + +struct vk_op_conv2d_dw_push_constants { + uint32_t ne; + uint32_t batches; + uint32_t channels; + uint32_t dst_w; + uint32_t dst_h; + uint32_t src_w; + uint32_t src_h; + uint32_t knl_w; + uint32_t knl_h; + int32_t stride_x; + int32_t stride_y; + int32_t pad_x; + int32_t pad_y; + int32_t dilation_x; + int32_t dilation_y; +}; + +struct vk_op_upscale_push_constants { + uint32_t ne; uint32_t a_offset; uint32_t d_offset; + uint32_t ne00; uint32_t ne01; + uint32_t nb00; uint32_t nb01; uint32_t nb02; uint32_t nb03; + uint32_t ne10; uint32_t ne11; uint32_t ne12; uint32_t ne13; + float sf0; float sf1; float sf2; float sf3; +}; + +struct vk_op_sum_rows_push_constants +{ + uint32_t n_cols; + uint32_t ne01, ne02; + uint32_t nb01, nb02, nb03; + uint32_t nb11, nb12, nb13; + float weight; + uint32_t misalign_offsets; + uint32_t ne0_12mp, ne0_12L; + uint32_t ne0_1mp, ne0_1L; +}; + +static vk_op_sum_rows_push_constants vk_op_sum_rows_push_constants_init(const ggml_tensor * src, const ggml_tensor * dst, int64_t n_cols) { + uint32_t type_size = 
(uint32_t)ggml_type_size(src->type); + vk_op_sum_rows_push_constants p = {}; + p.n_cols = (uint32_t)n_cols; + p.ne01 = (uint32_t)src->ne[1]; + p.ne02 = (uint32_t)src->ne[2]; + p.nb01 = (uint32_t)src->nb[1] / type_size; + p.nb02 = (uint32_t)src->nb[2] / type_size; + p.nb03 = (uint32_t)src->nb[3] / type_size; + p.nb11 = (uint32_t)dst->nb[1] / type_size; + p.nb12 = (uint32_t)dst->nb[2] / type_size; + p.nb13 = (uint32_t)dst->nb[3] / type_size; + p.weight = 1.0f; + return p; +} + +template <> void init_pushconst_fastdiv(vk_op_sum_rows_push_constants &p) { + init_fastdiv_values(p.ne01*p.ne02, p.ne0_12mp, p.ne0_12L); + init_fastdiv_values(p.ne01, p.ne0_1mp, p.ne0_1L); +} + +// Allow pre-recording command buffers +struct vk_staging_memcpy { + vk_staging_memcpy(void * _dst, const void * _src, size_t _n) : dst(_dst), src(_src), n(_n) {} + + void * dst; + const void * src; + size_t n; +}; + +struct vk_staging_memset { + vk_staging_memset(void * _dst, uint32_t _val, size_t _n) : dst(_dst), val(_val), n(_n) {} + + void * dst; + uint32_t val; + size_t n; +}; + +struct vk_context_struct { + vk_submission * s; + std::vector seqs; + + int exit_tensor_idx; + + std::vector in_memcpys; + std::vector out_memcpys; + std::vector memsets; + + vk_command_pool * p {}; +}; +typedef std::shared_ptr vk_context; +typedef std::weak_ptr vk_context_ref; + +struct ggml_vk_garbage_collector { + std::vector tl_semaphores; + std::vector semaphores; + std::vector events; + std::vector temp_buffers; + std::vector contexts; +}; + +#if defined(GGML_VULKAN_MEMORY_DEBUG) || defined(GGML_VULKAN_DEBUG) +#define VK_LOG_MEMORY(msg) std::cerr << "ggml_vulkan memory: " << msg << std::endl + +static std::string format_size(size_t size) { + const size_t kib = 1024; + const size_t mib = kib * 1024; + const size_t gib = mib * 1024; + + std::ostringstream oss; + oss << std::fixed << std::setprecision(2); + + if (size >= gib) { + oss << static_cast(size) / gib << " GiB"; + } else if (size >= mib) { + oss << static_cast(size) / mib << " MiB"; + } else if (size >= kib) { + oss << static_cast(size) / kib << " KiB"; + } else { + oss << size << " B"; + } + + return oss.str(); +} + +class vk_memory_logger { +public: + vk_memory_logger(): total_device(0), total_host(0) {} + void log_allocation(vk_buffer_ref buf_ref, size_t size); + void log_deallocation(vk_buffer_ref buf_ref); + +private: + std::map allocations; // Track allocations + size_t total_device; + size_t total_host; +}; +#else +#define VK_LOG_MEMORY(msg) ((void) 0) +#endif // GGML_VULKAN_MEMORY_DEBUG + +class vk_perf_logger { + public: + void print_timings() { + if (timings.empty()) { + return; + } + uint64_t total_all_op_times = 0; + std::cerr << "----------------\nVulkan Timings:" << std::endl; + for (const auto & t : timings) { + uint64_t total_op_times = 0; + for (const auto & time : t.second) { + total_op_times += time; + } + std::cerr << t.first << ": " << t.second.size() << " x " << (total_op_times / t.second.size() / 1000.0) + << " us"; + + // If we have as many flops entries as timing entries for the op, then compute and log the flops/S. 
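// Worked example (illustrative, assuming times are recorded in nanoseconds):
// a 4096x4096x4096 MUL_MAT is logged as m*n*(2k - 1) ~= 1.374e11 flops; if its
// accumulated time is 50 ms (5e7 ns), the line below prints
// (1.374e11 / 1e9) / (5e7 / 1e9) ~= 2748 GFLOPS/s.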
+ auto it = flops.find(t.first); + if (it != flops.end() && (it->second).size() == t.second.size()) { + uint64_t total_op_flops = 0; + for (const auto & elem : it->second) { + total_op_flops += elem; + } + std::cerr << " (" + << (double(total_op_flops) / (1000.0 * 1000.0 * 1000.0)) / + (double(total_op_times) / (1000.0 * 1000.0 * 1000.0)) + << " GFLOPS/s)"; + } + + total_all_op_times += total_op_times; + + std::cerr << std::endl; + } + + if (timings.size() > 0) { + std::cerr << "Total time: " << total_all_op_times / 1000.0 << " us." << std::endl; + } + + timings.clear(); + flops.clear(); + } + + void log_timing(const ggml_tensor * node, uint64_t time) { + if (node->op == GGML_OP_UNARY) { + timings[ggml_unary_op_name(ggml_get_unary_op(node))].push_back(time); + return; + } + if (node->op == GGML_OP_MUL_MAT || node->op == GGML_OP_MUL_MAT_ID) { + const uint64_t m = node->src[0]->ne[1]; + const uint64_t n = node->ne[1]; + const uint64_t k = node->src[1]->ne[0]; + const uint64_t batch = node->src[1]->ne[2] * node->src[1]->ne[3]; + std::string name = ggml_op_name(node->op); + if ((node->op == GGML_OP_MUL_MAT && n <= mul_mat_vec_max_cols) || + (node->op == GGML_OP_MUL_MAT_ID && node->src[2]->ne[1] == 1)) { + name += "_VEC"; + } + name += " "; + name += ggml_type_name(node->src[0]->type); + name += " m=" + std::to_string(m) + " n=" + std::to_string(n) + " k=" + std::to_string(k); + if (batch > 1) { + name += " batch=" + std::to_string(batch); + } + timings[name].push_back(time); + flops[name].push_back(m * n * (k + (k - 1)) * batch); + return; + } + if (node->op == GGML_OP_CONV_2D || node->op == GGML_OP_CONV_TRANSPOSE_2D) { + std::string name = ggml_op_name(node->op); + ggml_tensor * knl = node->src[0]; + uint64_t OW = node->ne[0]; + uint64_t OH = node->ne[1]; + uint64_t N = node->ne[3]; + uint64_t Cout = node->ne[2]; + uint64_t KW = knl->ne[0]; + uint64_t KH = knl->ne[1]; + uint64_t Cin = node->src[1]->ne[2]; + // KxCRS @ CRSxNPQ = KxNPQ -> M=K, K=CRS, N=NPQ + uint64_t size_M = Cout; + uint64_t size_K = Cin * KW * KH; + uint64_t size_N = N * OW * OH; + uint64_t n_flops = size_M * size_N * (size_K + (size_K - 1)); + name += " M=Cout=" + std::to_string(size_M) + ", K=Cin*KW*KH=" + std::to_string(size_K) + + ", N=N*OW*OH=" + std::to_string(size_N); + flops[name].push_back(n_flops); + timings[name].push_back(time); + return; + } + if (node->op == GGML_OP_RMS_NORM) { + std::string name = ggml_op_name(node->op); + name += "(" + std::to_string(node->ne[0]) + "," + std::to_string(node->ne[1]) + "," + std::to_string(node->ne[2]) + "," + std::to_string(node->ne[3]) + ")"; + timings[name].push_back(time); + return; + } + timings[ggml_op_name(node->op)].push_back(time); + } + private: + std::map> timings; + std::map> flops; +}; + +struct ggml_backend_vk_context { + std::string name; + + vk_device device; + + size_t semaphore_idx, event_idx; + ggml_vk_garbage_collector gc; + size_t prealloc_size_x, prealloc_size_y, prealloc_size_split_k, prealloc_size_add_rms_partials, prealloc_size_add_rms_partials_offset; + vk_buffer prealloc_x, prealloc_y, prealloc_split_k, prealloc_add_rms_partials; + vk::Fence fence, almost_ready_fence; + bool almost_ready_fence_pending {}; + // Set before op_add and unset after op_rms_norm to indicate that the add should + // write partial sums to accumulate the square of the vector components + bool do_add_rms_partials; + + // Cache most recent tensor that was converted into prealloc_y, and what pipeline it used to convert. 
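// How this cache is consumed (a sketch; the actual check lives in the mat-mul
// paths later in this file, and the pipeline name used here is illustrative):
//
//   if (prealloc_y_last_pipeline_used == to_fp16_vk_1.get() &&
//       prealloc_y_last_tensor_used == src1) {
//       // prealloc_y already holds the converted src1 -- skip the conversion
//   } else {
//       // convert src1 into prealloc_y and update both cache fields
//   }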
+    vk_pipeline_struct * prealloc_y_last_pipeline_used {};
+    const ggml_tensor * prealloc_y_last_tensor_used {};
+
+    // Track which nodes have been used since the last sync, and whether they were written to
+    std::vector<const ggml_tensor *> unsynced_nodes_written;
+    std::vector<const ggml_tensor *> unsynced_nodes_read;
+    // Track which prealloc buffers have pending reads that need to be synchronized.
+    // These are checked before writing to the buffer (and call ggml_vk_sync_buffers if set),
+    // and set to true after the buffer contents are consumed.
+    bool prealloc_x_need_sync, prealloc_y_need_sync, prealloc_split_k_need_sync;
+
+    vk_buffer buffer_pool[MAX_VK_BUFFERS];
+
+    vk_context_ref compute_ctx;
+    vk_context_ref transfer_ctx;
+
+    std::vector<vk_context_ref> tensor_ctxs;
+
+    std::vector<vk::DescriptorPool> descriptor_pools;
+    std::vector<vk::DescriptorSet> descriptor_sets;
+    uint32_t descriptor_set_idx {};
+    uint32_t pipeline_descriptor_set_requirements {};
+
+    vk_command_pool compute_cmd_pool;
+    vk_command_pool transfer_cmd_pool;
+
+    // number of additional consecutive nodes that are being fused with the
+    // node currently being processed
+    int num_additional_fused_ops {};
+};
+
+static void * const vk_ptr_base = (void *)(uintptr_t) 0x1000; // NOLINT
+
+static uint64_t vk_tensor_offset(const ggml_tensor * tensor) {
+    if (tensor->view_src) {
+        return (uint8_t *) tensor->view_src->data - (uint8_t *) vk_ptr_base;
+    }
+    return (uint8_t *) tensor->data - (uint8_t *) vk_ptr_base;
+}
+
+struct ggml_backend_vk_buffer_context {
+    vk_device_ref device;
+    vk_buffer dev_buffer;
+    std::string name;
+
+    ggml_backend_vk_buffer_context(vk_device_ref device, vk_buffer&& dev_buffer, std::string& name) :
+        device(device),
+        dev_buffer(dev_buffer),
+        name(name) {
+    }
+
+    ~ggml_backend_vk_buffer_context() {
+        ggml_vk_destroy_buffer(dev_buffer);
+    }
+};
+
+#ifdef GGML_VULKAN_MEMORY_DEBUG
+static std::mutex log_mutex;
+
+void vk_memory_logger::log_allocation(vk_buffer_ref buf_ref, size_t size) {
+    std::lock_guard<std::mutex> guard(log_mutex);
+    vk_buffer buf = buf_ref.lock();
+    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
+    const std::string type = device ? "device" : "host";
+    allocations[buf->buffer] = size;
+    total_device += device ? size : 0;
+    total_host += device ? 0 : size;
+    VK_LOG_MEMORY(buf->device->name << ": +" << format_size(size) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
+}
+
+void vk_memory_logger::log_deallocation(vk_buffer_ref buf_ref) {
+    if (buf_ref.expired() || buf_ref.lock()->size == 0) {
+        return;
+    }
+
+    std::lock_guard<std::mutex> guard(log_mutex);
+    vk_buffer buf = buf_ref.lock();
+    const bool device = bool(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eDeviceLocal);
+    std::string type = device ? "device" : "host";
+    auto it = allocations.find(buf->buffer);
+    total_device -= device ? it->second : 0;
+    total_host -= device ? 0 : it->second;
+    if (it != allocations.end()) {
+        VK_LOG_MEMORY(buf->device->name << ": -" << format_size(it->second) << " " << type << " at " << buf->buffer << ". Total device: " << format_size(total_device) << ", total host: " << format_size(total_host));
+        allocations.erase(it);
+    } else {
+        VK_LOG_MEMORY("ERROR " << buf->device->name << ": Attempted to deallocate unknown " << type << " memory at " << buf->buffer);
+    }
+}
+#endif // GGML_VULKAN_MEMORY_DEBUG
+
+struct vk_instance_t {
+    vk::Instance instance;
+
+    bool debug_utils_support = false; // VK_EXT_debug_utils enabled
+    PFN_vkSetDebugUtilsObjectNameEXT pfn_vkSetDebugUtilsObjectNameEXT = {};
+    PFN_vkQueueBeginDebugUtilsLabelEXT pfn_vkQueueBeginDebugUtilsLabelEXT = {};
+    PFN_vkQueueEndDebugUtilsLabelEXT pfn_vkQueueEndDebugUtilsLabelEXT = {};
+    PFN_vkCmdBeginDebugUtilsLabelEXT pfn_vkCmdBeginDebugUtilsLabelEXT = {};
+    PFN_vkCmdEndDebugUtilsLabelEXT pfn_vkCmdEndDebugUtilsLabelEXT = {};
+    PFN_vkCmdInsertDebugUtilsLabelEXT pfn_vkCmdInsertDebugUtilsLabelEXT = {};
+
+    std::vector<size_t> device_indices;
+    std::vector<bool> device_supports_membudget;
+    vk_device devices[GGML_VK_MAX_DEVICES];
+};
+
+static bool vk_instance_initialized = false;
+static vk_instance_t vk_instance;
+
+static bool vk_perf_logger_enabled = false;
+
+#ifdef GGML_VULKAN_CHECK_RESULTS
+static size_t vk_skip_checks;
+static size_t vk_output_tensor;
+
+static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name);
+static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx);
+#endif
+
+typedef void (*ggml_vk_func_t)(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst);
+
+static void ggml_backend_vk_free(ggml_backend_t backend);
+
+static VkDeviceSize ggml_vk_get_max_buffer_range(const ggml_backend_vk_context * ctx, const vk_buffer &buf, const VkDeviceSize offset) {
+    const VkDeviceSize range = std::min(VkDeviceSize{buf->size - offset},
+                                        VkDeviceSize{ctx->device->properties.limits.maxStorageBufferRange});
+    return range;
+}
+
+// Wait for ctx->fence to be signaled.
+static void ggml_vk_wait_for_fence(ggml_backend_vk_context * ctx) {
+    // Use waitForFences while most of the graph executes. Hopefully the CPU can sleep
+    // during this wait.
+    if (ctx->almost_ready_fence_pending) {
+        VK_CHECK(ctx->device->device.waitForFences({ ctx->almost_ready_fence }, true, UINT64_MAX), "almost_ready_fence");
+        ctx->device->device.resetFences({ ctx->almost_ready_fence });
+        ctx->almost_ready_fence_pending = false;
+    }
+
+    // Spin (w/pause) waiting for the graph to finish executing.
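+    // Poll the fence with getFenceStatus and yield between polls so the tail end of the
+    // graph is picked up with low latency rather than sleeping in waitForFences again.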
+ vk::Result result; + while ((result = ctx->device->device.getFenceStatus(ctx->fence)) != vk::Result::eSuccess) { + if (result != vk::Result::eNotReady) { + fprintf(stderr, "ggml_vulkan: error %s at %s:%d\n", to_string(result).c_str(), __FILE__, __LINE__); + exit(1); + } + for (uint32_t i = 0; i < 100; ++i) { + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + YIELD(); + } + } + ctx->device->device.resetFences({ ctx->fence }); +} + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; + +static void ggml_vk_create_pipeline_func(vk_device& device, vk_pipeline& pipeline, size_t spv_size, const void* spv_data, const std::string entrypoint, + uint32_t parameter_count, std::array wg_denoms, std::vector specialization_constants, + bool disable_robustness, bool require_full_subgroups, uint32_t required_subgroup_size) { + VK_LOG_DEBUG("ggml_vk_create_pipeline(" << device->name << ", " << pipeline->name << ", " << entrypoint << ", " << parameter_count << + ", (" << wg_denoms[0] << "," << wg_denoms[1] << "," << wg_denoms[2] << "), specialization_constants, " << + disable_robustness << ", " << require_full_subgroups << ", " << required_subgroup_size << ")"); + GGML_ASSERT(parameter_count > 0); + GGML_ASSERT(parameter_count <= MAX_PARAMETER_COUNT); + GGML_ASSERT(wg_denoms[0] > 0 && wg_denoms[1] > 0 && wg_denoms[2] > 0); // NOLINT + + vk::ShaderModuleCreateInfo shader_module_create_info({}, spv_size, reinterpret_cast(spv_data)); + pipeline->shader_module = device->device.createShaderModule(shader_module_create_info); + + vk::PushConstantRange pcr( + vk::ShaderStageFlagBits::eCompute, + 0, + pipeline->push_constant_size + ); + + vk::PipelineLayoutCreateInfo pipeline_layout_create_info(vk::PipelineLayoutCreateFlags(), device->dsl, pcr); + pipeline->layout = device->device.createPipelineLayout(pipeline_layout_create_info); + + std::vector specialization_entries(specialization_constants.size()); + + for (size_t i = 0; i < specialization_constants.size(); i++) { + specialization_entries[i].constantID = i; + specialization_entries[i].offset = i * sizeof(uint32_t); + specialization_entries[i].size = sizeof(uint32_t); + } + + vk::SpecializationInfo specialization_info( + specialization_entries.size(), + specialization_entries.data(), + specialization_constants.size() * sizeof(uint32_t), + specialization_constants.data() + ); + + vk::PipelineShaderStageCreateFlags pipeline_shader_stage_create_flags{}; + + if (device->subgroup_require_full_support && require_full_subgroups) { + pipeline_shader_stage_create_flags |= vk::PipelineShaderStageCreateFlagBits::eRequireFullSubgroupsEXT; + } + + vk::PipelineShaderStageCreateInfo pipeline_shader_create_info( + pipeline_shader_stage_create_flags, + vk::ShaderStageFlagBits::eCompute, + pipeline->shader_module, + entrypoint.c_str(), + &specialization_info); + + vk::PipelineShaderStageRequiredSubgroupSizeCreateInfoEXT pipeline_shader_stage_required_subgroup_size_create_info; + pipeline_shader_stage_required_subgroup_size_create_info.requiredSubgroupSize = required_subgroup_size; + if (device->subgroup_size_control && required_subgroup_size > 0) { + GGML_ASSERT(device->subgroup_min_size <= required_subgroup_size && required_subgroup_size <= device->subgroup_max_size); + pipeline_shader_create_info.setPNext(&pipeline_shader_stage_required_subgroup_size_create_info); + } + + vk::ComputePipelineCreateInfo 
compute_pipeline_create_info( + device->pipeline_executable_properties_support ? + vk::PipelineCreateFlagBits::eCaptureStatisticsKHR : + vk::PipelineCreateFlags{}, + pipeline_shader_create_info, + pipeline->layout); + + vk::PipelineRobustnessCreateInfoEXT rci; + + if (device->pipeline_robustness && disable_robustness) { + rci.storageBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + rci.uniformBuffers = vk::PipelineRobustnessBufferBehaviorEXT::eDisabled; + compute_pipeline_create_info.setPNext(&rci); + } + + try { + pipeline->pipeline = device->device.createComputePipeline(VK_NULL_HANDLE, compute_pipeline_create_info).value; + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Compute pipeline creation failed for " << pipeline->name << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + pipeline->compiled = true; + + if (vk_instance.debug_utils_support) { + vk::DebugUtilsObjectNameInfoEXT duoni; + duoni.objectType = vk::ObjectType::ePipeline; + duoni.pObjectName = pipeline->name.c_str(); + duoni.objectHandle = /*reinterpret_cast*/(uint64_t)(static_cast(pipeline->pipeline)); + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT(device->device, &static_cast(duoni)); + } + + if (device->pipeline_executable_properties_support) { + vk::PipelineExecutableInfoKHR executableInfo; + executableInfo.pipeline = pipeline->pipeline; + + auto statistics = device->device.getPipelineExecutableStatisticsKHR(executableInfo); + for (auto & s : statistics) { + // "Register Count" is reported by NVIDIA drivers. + if (strcmp(s.name, "Register Count") == 0) { + VK_LOG_DEBUG(pipeline->name << " " << s.name << ": " << s.value.u64 << " registers"); + pipeline->register_count = (uint32_t)s.value.u64; + } + } + } + + { + std::lock_guard guard(device->mutex); + device->all_pipelines.push_back(pipeline); + } + + { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + } + compile_count_cond.notify_all(); +} + +static void ggml_vk_destroy_pipeline(vk::Device& device, vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_pipeline_destroy_pipeline(" << pipeline->name << ")"); + device.destroyPipelineLayout(pipeline->layout); + + device.destroyShaderModule(pipeline->shader_module); + + device.destroyPipeline(pipeline->pipeline); +} + +static void ggml_pipeline_request_descriptor_sets(ggml_backend_vk_context *ctx, vk_pipeline& pipeline, uint32_t n) { + VK_LOG_DEBUG("ggml_pipeline_request_descriptor_sets(" << pipeline->name << ", " << n << ")"); + ctx->pipeline_descriptor_set_requirements += n; + if (!pipeline->compiled) { + pipeline->needed = true; + ctx->device->need_compiles = true; + } +} + +static void ggml_pipeline_allocate_descriptor_sets(ggml_backend_vk_context * ctx) { + + if (ctx->descriptor_sets.size() >= ctx->pipeline_descriptor_set_requirements) { + // Enough descriptors are available + return; + } + + vk_device& device = ctx->device; + + uint32_t to_alloc = ctx->pipeline_descriptor_set_requirements - ctx->descriptor_sets.size(); + uint32_t pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE - ctx->descriptor_sets.size() % VK_DEVICE_DESCRIPTOR_POOL_SIZE; + uint32_t pool_idx = ctx->descriptor_sets.size() / VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + while (to_alloc > 0) { + const uint32_t alloc_count = std::min(pool_remaining, to_alloc); + to_alloc -= alloc_count; + pool_remaining = VK_DEVICE_DESCRIPTOR_POOL_SIZE; + + if (pool_idx >= ctx->descriptor_pools.size()) { + vk::DescriptorPoolSize 
descriptor_pool_size(vk::DescriptorType::eStorageBuffer, MAX_PARAMETER_COUNT * VK_DEVICE_DESCRIPTOR_POOL_SIZE);
+            vk::DescriptorPoolCreateInfo descriptor_pool_create_info({}, VK_DEVICE_DESCRIPTOR_POOL_SIZE, descriptor_pool_size);
+            ctx->descriptor_pools.push_back(device->device.createDescriptorPool(descriptor_pool_create_info));
+        }
+
+        std::vector<vk::DescriptorSetLayout> layouts(alloc_count);
+        for (uint32_t i = 0; i < alloc_count; i++) {
+            layouts[i] = device->dsl;
+        }
+        vk::DescriptorSetAllocateInfo descriptor_set_alloc_info(ctx->descriptor_pools[pool_idx], alloc_count, layouts.data());
+        std::vector<vk::DescriptorSet> sets = device->device.allocateDescriptorSets(descriptor_set_alloc_info);
+        ctx->descriptor_sets.insert(ctx->descriptor_sets.end(), sets.begin(), sets.end());
+
+        pool_idx++;
+    }
+}
+
+static vk::CommandBuffer ggml_vk_create_cmd_buffer(vk_device& device, vk_command_pool& p) {
+    VK_LOG_DEBUG("ggml_vk_create_cmd_buffer()");
+
+    if (p.cmd_buffers.size() > p.cmd_buffer_idx) {
+        // Reuse command buffer
+        return p.cmd_buffers[p.cmd_buffer_idx++];
+    }
+
+    vk::CommandBufferAllocateInfo command_buffer_alloc_info(
+        p.pool,
+        vk::CommandBufferLevel::ePrimary,
+        1);
+    const std::vector<vk::CommandBuffer> cmd_buffers = device->device.allocateCommandBuffers(command_buffer_alloc_info);
+    auto buf = cmd_buffers.front();
+
+    p.cmd_buffers.push_back(buf);
+    p.cmd_buffer_idx++;
+
+    return buf;
+}
+
+static void ggml_vk_submit(vk_context& ctx, vk::Fence fence) {
+    if (ctx->seqs.empty()) {
+        if (fence) {
+            std::lock_guard<std::mutex> guard(queue_mutex);
+            ctx->p->q->queue.submit({}, fence);
+        }
+        return;
+    }
+    VK_LOG_DEBUG("ggml_vk_submit(" << ctx << ", " << fence << ")");
+
+    std::vector<std::vector<uint64_t>> tl_wait_vals;
+    std::vector<std::vector<uint64_t>> tl_signal_vals;
+    std::vector<std::vector<vk::Semaphore>> tl_wait_semaphores;
+    std::vector<std::vector<vk::Semaphore>> tl_signal_semaphores;
+    std::vector<vk::TimelineSemaphoreSubmitInfo> tl_submit_infos;
+    std::vector<vk::SubmitInfo> submit_infos;
+    int idx = -1;
+    std::vector<std::vector<vk::PipelineStageFlags>> stage_flags;
+
+    size_t reserve = 0;
+
+    for (const auto& sequence : ctx->seqs) {
+        reserve += sequence.size();
+    }
+
+    // Pre-reserve vectors to prevent reallocation, which invalidates pointers
+    tl_wait_semaphores.reserve(reserve);
+    tl_wait_vals.reserve(reserve);
+    tl_signal_semaphores.reserve(reserve);
+    tl_signal_vals.reserve(reserve);
+    tl_submit_infos.reserve(reserve);
+    submit_infos.reserve(reserve);
+    stage_flags.reserve(reserve);
+
+    for (const auto& sequence : ctx->seqs) {
+        for (const auto& submission : sequence) {
+            stage_flags.push_back({});
+            idx++;
+            tl_wait_vals.push_back({});
+            tl_wait_semaphores.push_back({});
+            tl_signal_vals.push_back({});
+            tl_signal_semaphores.push_back({});
+            for (size_t i = 0; i < submission.wait_semaphores.size(); i++) {
+                stage_flags[idx].push_back(ctx->p->q->stage_flags);
+                tl_wait_vals[idx].push_back(submission.wait_semaphores[i].value);
+                tl_wait_semaphores[idx].push_back(submission.wait_semaphores[i].s);
+            }
+            for (size_t i = 0; i < submission.signal_semaphores.size(); i++) {
+                tl_signal_vals[idx].push_back(submission.signal_semaphores[i].value);
+                tl_signal_semaphores[idx].push_back(submission.signal_semaphores[i].s);
+            }
+            tl_submit_infos.push_back({
+                (uint32_t) submission.wait_semaphores.size(),
+                tl_wait_vals[idx].data(),
+                (uint32_t) submission.signal_semaphores.size(),
+                tl_signal_vals[idx].data(),
+            });
+            tl_submit_infos[idx].sType = vk::StructureType::eTimelineSemaphoreSubmitInfo;
+            tl_submit_infos[idx].pNext = nullptr;
+            vk::SubmitInfo si{
+                (uint32_t) submission.wait_semaphores.size(),
+                tl_wait_semaphores[idx].data(),
+                stage_flags[idx].data(),
+                1,
+                &submission.buffer,
+                (uint32_t) submission.signal_semaphores.size(),
+                tl_signal_semaphores[idx].data(),
+            };
+            si.setPNext(&tl_submit_infos[idx]);
+            submit_infos.push_back(si);
+        }
+    }
+
+    std::lock_guard<std::mutex> guard(queue_mutex);
+    ctx->p->q->queue.submit(submit_infos, fence);
+
+    ctx->seqs.clear();
+}
+
+static uint32_t ggml_vk_find_queue_family_index(std::vector<vk::QueueFamilyProperties>& queue_family_props, const vk::QueueFlags& required, const vk::QueueFlags& avoid, int32_t compute_index, uint32_t min_num_queues) {
+    VK_LOG_DEBUG("ggml_vk_find_queue_family_index()");
+    const uint32_t qfsize = queue_family_props.size();
+
+    // Try with avoid preferences first
+    for (uint32_t i = 0; i < qfsize; i++) {
+        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required && !(queue_family_props[i].queueFlags & avoid)) {
+            return i;
+        }
+    }
+
+    // Fall back to only required
+    for (size_t i = 0; i < qfsize; i++) {
+        if (queue_family_props[i].queueCount >= min_num_queues && (compute_index < 0 || i != (uint32_t) compute_index) && queue_family_props[i].queueFlags & required) {
+            return i;
+        }
+    }
+
+    // Fall back to reusing compute queue
+    for (size_t i = 0; i < qfsize; i++) {
+        if (queue_family_props[i].queueCount >= min_num_queues && queue_family_props[i].queueFlags & required) {
+            return i;
+        }
+    }
+
+    // Fall back to ignoring min_num_queues
+    for (size_t i = 0; i < qfsize; i++) {
+        if (queue_family_props[i].queueFlags & required) {
+            return i;
+        }
+    }
+
+    // All commands that are allowed on a queue that supports transfer operations are also allowed on a queue that supports either graphics or compute operations.
+    // Thus, if the capabilities of a queue family include VK_QUEUE_GRAPHICS_BIT or VK_QUEUE_COMPUTE_BIT, then reporting the VK_QUEUE_TRANSFER_BIT capability separately for that queue family is optional.
+    if (compute_index >= 0) {
+        return compute_index;
+    }
+
+    std::cerr << "ggml_vulkan: No suitable queue family index found."
<< std::endl; + + for(auto &q_family : queue_family_props) { + std::cerr << "Queue number: " + std::to_string(q_family.queueCount) << " flags: " + to_string(q_family.queueFlags) << std::endl; + } + abort(); +} + +static void ggml_vk_create_queue(vk_device& device, vk_queue& q, uint32_t queue_family_index, uint32_t queue_index, vk::PipelineStageFlags&& stage_flags, bool transfer_only) { + VK_LOG_DEBUG("ggml_vk_create_queue()"); + std::lock_guard guard(device->mutex); + + q.queue_family_index = queue_family_index; + q.transfer_only = transfer_only; + + q.cmd_pool.init(device, &q); + + q.queue = device->device.getQueue(queue_family_index, queue_index); + + q.stage_flags = stage_flags; +} + +static vk_context ggml_vk_create_context(ggml_backend_vk_context * ctx, vk_command_pool& p) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_context(" << result << ")"); + ctx->gc.contexts.emplace_back(result); + result->p = &p; + return result; +} + +static vk_context ggml_vk_create_temporary_context(vk_command_pool& p) { + vk_context result = std::make_shared(); + VK_LOG_DEBUG("ggml_vk_create_temporary_context(" << result << ")"); + result->p = &p; + return result; +} + +static vk_semaphore * ggml_vk_create_binary_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eBinary, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.semaphores.push_back({ semaphore, 0 }); + return &ctx->gc.semaphores[ctx->gc.semaphores.size() - 1]; +} + +static vk_semaphore * ggml_vk_create_timeline_semaphore(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_create_timeline_semaphore()"); + if (ctx->semaphore_idx >= ctx->gc.tl_semaphores.size()) { + vk::SemaphoreTypeCreateInfo tci{ vk::SemaphoreType::eTimeline, 0 }; + vk::SemaphoreCreateInfo ci{}; + ci.setPNext(&tci); + vk::Semaphore semaphore = ctx->device->device.createSemaphore(ci); + ctx->gc.tl_semaphores.push_back({ semaphore, 0 }); + } + return &ctx->gc.tl_semaphores[ctx->semaphore_idx++]; +} + +static vk::Event ggml_vk_create_event(ggml_backend_vk_context * ctx) { + if (ctx->event_idx >= ctx->gc.events.size()) { + ctx->gc.events.push_back(ctx->device->device.createEvent({})); + } + return ctx->gc.events[ctx->event_idx++]; +} + +static void ggml_vk_command_pool_cleanup(vk_device& device, vk_command_pool& p) { + VK_LOG_DEBUG("ggml_vk_command_pool_cleanup()"); + + // Requires command buffers to be done + device->device.resetCommandPool(p.pool); + p.cmd_buffer_idx = 0; +} + +static void ggml_vk_queue_command_pools_cleanup(vk_device& device) { + VK_LOG_DEBUG("ggml_vk_queue_command_pools_cleanup()"); + + // Arbitrary frequency to cleanup/reuse command buffers + static constexpr uint32_t cleanup_frequency = 10; + + if (device->compute_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->compute_queue.cmd_pool); + } + if (device->transfer_queue.cmd_pool.cmd_buffer_idx >= cleanup_frequency) { + ggml_vk_command_pool_cleanup(device, device->transfer_queue.cmd_pool); + } +} + + +static uint32_t find_properties(const vk::PhysicalDeviceMemoryProperties* mem_props, vk::MemoryRequirements* mem_req, vk::MemoryPropertyFlags flags) { + for (uint32_t i = 0; i < mem_props->memoryTypeCount; ++i) { + vk::MemoryType memory_type = mem_props->memoryTypes[i]; + if ((mem_req->memoryTypeBits & ((uint64_t)1 << i)) && + (flags & 
memory_type.propertyFlags) == flags && + mem_props->memoryHeaps[memory_type.heapIndex].size >= mem_req->size) { + return static_cast(i); + } + } + return UINT32_MAX; +} + +static vk_buffer ggml_vk_create_buffer(vk_device& device, size_t size, const std::initializer_list & req_flags_list) { + VK_LOG_DEBUG("ggml_vk_create_buffer(" << device->name << ", " << size << ", " << to_string(req_flags_list.begin()[0]) << ", " << to_string(req_flags_list.begin()[req_flags_list.size()-1]) << ")"); + if (size > device->max_buffer_size) { + throw vk::OutOfDeviceMemoryError("Requested buffer size exceeds device buffer size limit"); + } + + vk_buffer buf = std::make_shared(); + + if (size == 0) { + buf->size = 0; + return buf; + } + + vk::BufferUsageFlags usage_flags = vk::BufferUsageFlagBits::eStorageBuffer | vk::BufferUsageFlagBits::eTransferSrc | vk::BufferUsageFlagBits::eTransferDst; + vk::MemoryAllocateFlags mem_flags {}; + if (device->buffer_device_address) { + usage_flags |= vk::BufferUsageFlagBits::eShaderDeviceAddress; + mem_flags |= vk::MemoryAllocateFlagBits::eDeviceAddress; + } + + vk::BufferCreateInfo buffer_create_info{ + vk::BufferCreateFlags(), + size, + usage_flags, + vk::SharingMode::eExclusive, + 0, + nullptr, + }; + + buf->buffer = device->device.createBuffer(buffer_create_info); + + vk::MemoryRequirements mem_req = device->device.getBufferMemoryRequirements(buf->buffer); + + vk::PhysicalDeviceMemoryProperties mem_props = device->physical_device.getMemoryProperties(); + + const vk::MemoryAllocateFlagsInfo mem_flags_info { mem_flags }; + + for (auto it = req_flags_list.begin(); it != req_flags_list.end(); it++) { + const auto & req_flags = *it; + + uint32_t memory_type_index = find_properties(&mem_props, &mem_req, req_flags); + + if (memory_type_index == UINT32_MAX) { + continue; + } + buf->memory_property_flags = req_flags; + + try { + buf->device_memory = device->device.allocateMemory({ mem_req.size, memory_type_index, &mem_flags_info }); + break; + } catch (const vk::SystemError& e) { + // loop and retry + // during last attempt throw the exception + if (it + 1 == req_flags_list.end()) { + device->device.destroyBuffer(buf->buffer); + throw e; + } + } + } + + if (!buf->device_memory) { + device->device.destroyBuffer(buf->buffer); + throw vk::OutOfDeviceMemoryError("No suitable memory type found"); + } + + buf->ptr = nullptr; + + if (buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + buf->ptr = device->device.mapMemory(buf->device_memory, 0, VK_WHOLE_SIZE); + } + + device->device.bindBufferMemory(buf->buffer, buf->device_memory, 0); + + buf->device = device; + buf->size = size; + + if (device->buffer_device_address) { + const vk::BufferDeviceAddressInfo addressInfo(buf->buffer); + buf->bda_addr = device->device.getBufferAddress(addressInfo); + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger->log_allocation(buf, size); +#endif + + return buf; +} + +static vk_buffer ggml_vk_create_buffer_check(vk_device& device, size_t size, vk::MemoryPropertyFlags req_flags, vk::MemoryPropertyFlags fallback_flags = vk::MemoryPropertyFlags(0)) { + try { + return ggml_vk_create_buffer(device, size, {req_flags, fallback_flags}); + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Memory allocation of size " << size << " failed." 
<< std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } +} + +static vk_buffer ggml_vk_create_buffer_device(vk_device& device, size_t size) { + vk_buffer buf; + try { + if (device->prefer_host_memory) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); + } else if (device->uma) { + // Fall back to host memory type + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else if (device->disable_host_visible_vidmem) { + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } else { + // use rebar if available, otherwise fallback to device only visible memory + if (device->allow_sysmem_fallback) { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + } else { + buf = ggml_vk_create_buffer(device, size, {vk::MemoryPropertyFlagBits::eDeviceLocal | vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent, + vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } + } catch (const vk::SystemError& e) { + std::cerr << "ggml_vulkan: Device memory allocation of size " << size << " failed." << std::endl; + std::cerr << "ggml_vulkan: " << e.what() << std::endl; + throw e; + } + + return buf; +} + +static void ggml_vk_destroy_buffer(vk_buffer& buf) { + if (buf == nullptr) { + return; + } + +#ifdef GGML_VULKAN_MEMORY_DEBUG + if (buf->device != nullptr) { + buf->device->memory_logger->log_deallocation(buf); + } +#endif + + buf.reset(); +} + +static vk_subbuffer ggml_vk_subbuffer(const ggml_backend_vk_context* ctx, const vk_buffer& buf, size_t offset = 0) { + return { buf, offset, ggml_vk_get_max_buffer_range(ctx, buf, offset) }; +} + +static void ggml_vk_sync_buffers(ggml_backend_vk_context* ctx, vk_context& subctx) { + VK_LOG_DEBUG("ggml_vk_sync_buffers()"); + + const bool transfer_queue = subctx->p->q->transfer_only; + + if (ctx) { + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + } + + subctx->s->buffer.pipelineBarrier( + subctx->p->q->stage_flags, + subctx->p->q->stage_flags, + {}, + { { + { !transfer_queue ? (vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) }, + { !transfer_queue ? 
(vk::AccessFlagBits::eShaderRead | vk::AccessFlagBits::eShaderWrite | vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) : (vk::AccessFlagBits::eTransferRead | vk::AccessFlagBits::eTransferWrite) } + } }, + {}, + {} + ); +} + +static void ggml_vk_wait_events(vk_context& ctx, std::vector&& events) { + VK_LOG_DEBUG("ggml_vk_wait_events()"); + if (events.empty()) { + return; + } + + ctx->s->buffer.waitEvents( + events, + ctx->p->q->stage_flags, + ctx->p->q->stage_flags, + {}, + {}, + {} + ); +} + +// number of rows/cols for flash attention shader +static constexpr uint32_t flash_attention_num_small_rows = 32; +static constexpr uint32_t scalar_flash_attention_num_small_rows = 1; + +static uint32_t get_fa_scalar_num_large_rows(uint32_t hsv) { + if (hsv >= 192) { + return 2; + } else { + return 8; + } +} + +// The FA coopmat1 shader assumes 16x16x16 matrix multiply support. +// 128 threads split into four subgroups, each subgroup does 1/4 +// of the Bc dimension. +static constexpr uint32_t coopmat1_flash_attention_num_large_rows = 16; +static constexpr uint32_t scalar_flash_attention_Bc = 64; +static constexpr uint32_t scalar_flash_attention_workgroup_size = 128; + +static uint32_t get_fa_num_small_rows(FaCodePath path) { + if (path == FA_COOPMAT2) { + return flash_attention_num_small_rows; + } else { + return scalar_flash_attention_num_small_rows; + } +} + +static std::array fa_rows_cols(FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) { + GGML_UNUSED(clamp); + GGML_UNUSED(hsv); + + if (path == FA_SCALAR) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, 64}; + } else { + if ((hsv | hsk) & 8) { + // HSV/HSK not being a multiple of 16 makes D_split smaller, which makes cols_per_iter + // larger, and Bc needs to be >= cols_per_thread. 64 is large enough, 32 is not. + return {get_fa_scalar_num_large_rows(hsv), 64}; + } else { + return {get_fa_scalar_num_large_rows(hsv), 32}; + } + } + } + + if (path == FA_COOPMAT1) { + if (small_rows) { + return {scalar_flash_attention_num_small_rows, scalar_flash_attention_Bc}; + } else { + return {coopmat1_flash_attention_num_large_rows, scalar_flash_attention_Bc}; + } + } + + // small rows, large cols + if (small_rows) { + return {get_fa_num_small_rows(FA_COOPMAT2), 32}; + } + + // small cols to reduce register count + if (ggml_is_quantized(type) || hsk >= 256 || hsv >= 256) { + if (hsk >= 512 || hsv >= 512) { + return {32, 32}; + } else { + return {64, 32}; + } + } + return {64, 64}; +} + +static uint32_t fa_align(FaCodePath path, uint32_t hsk, uint32_t hsv, ggml_type type, bool small_rows) { + return fa_rows_cols(path, hsk, hsv, 0, type, small_rows)[1]; +} + +static bool ggml_vk_matmul_shmem_support(const vk_device& device, const std::vector& warptile, bool mul_mat_id, ggml_type src0_type) { + + uint32_t lut_size = 0; + switch (src0_type) { + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + lut_size = 2*2048; + break; + case GGML_TYPE_IQ2_XXS: + lut_size = 8*256; + break; + case GGML_TYPE_IQ2_XS: + lut_size = 8*512; + break; + case GGML_TYPE_IQ2_S: + lut_size = 8*1024; + break; + case GGML_TYPE_IQ3_XXS: + lut_size = 4*256; + break; + case GGML_TYPE_IQ3_S: + lut_size = 4*512; + break; + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_MXFP4: + lut_size = 4*16; + break; + default: + break; + } + + // Needs to be kept up to date on shader changes + const uint32_t bank_conflict_offset = device->coopmat_support ? 8 : 1; + const uint32_t type_size = device->fp16 ? 
sizeof(ggml_fp16_t) : sizeof(float);
+    const uint32_t warps = warptile[0] / warptile[10];
+
+    const uint32_t load_bufs = (warptile[1] + warptile[2]) * (warptile[3] + bank_conflict_offset) * type_size;
+    const uint32_t mmid_row_ids = mul_mat_id ? (warptile[2] * 2 * sizeof(uint16_t)) : 0;
+    const uint32_t coopmat_stage = device->coopmat_support ? warptile[7] * warptile[8] / warps * sizeof(float) : 0;
+    const uint32_t ballots_sh = mul_mat_id ? (warps * 4 * sizeof(uint32_t)) : 0;
+
+    const uint32_t total_size = load_bufs + mmid_row_ids + coopmat_stage + lut_size + ballots_sh;
+    const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize;
+
+    VK_LOG_DEBUG("ggml_vk_matmul_shmem_support(warptile=(" << warptile[0] << "," << warptile[1] << "," << warptile[2] << "), "
+                 "mul_mat_id=" << mul_mat_id << ", src0_type=" << ggml_type_name(src0_type) << ", supported=" << supported);
+
+    return supported;
+}
+
+struct GpuPipelineConfig {
+    // GPU architecture identifier.
+    // Example: vk_device_architecture::AMD_GCN
+    vk_device_architecture arch;
+
+    // Mapping of pipeline names to their specific subgroup sizes.
+    // Example: {"soft_max_f32", 64}
+    std::unordered_map<std::string, uint32_t> pipelines;
+
+    // Default subgroup size for this GPU.
+    // Defaults to 0 if not explicitly provided.
+    uint32_t default_subgroup_size = 0;
+};
+
+// Pipeline configuration for RDNA1 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna1_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+    {"argmax", 64}, {"mul_mat_vec", 64},
+    {"mul_mat_vec_f16", 32}, {"mul_mat_vec_f32_f16", 32}
+};
+
+// Pipeline configuration for RDNA2 GPUs.
+static const std::unordered_map<std::string, uint32_t> rdna2_pipelines = {
+    {"soft_max", 64}, {"im2col", 64},
+};
+
+static constexpr uint32_t RDNA_DEFAULT_SUBGROUP_SIZE = 32;
+
+// Define configurations for different GPUs.
+static std::vector<GpuPipelineConfig> gpu_pipeline_configs = {
+    {
+        vk_device_architecture::AMD_RDNA1,
+        {
+            rdna1_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+    {
+        vk_device_architecture::AMD_RDNA2,
+        {
+            rdna2_pipelines,
+        },
+        RDNA_DEFAULT_SUBGROUP_SIZE
+    },
+};
+
+static uint32_t get_subgroup_size(const std::string &pipeline_name, const vk_device_architecture &arch) {
+    for (const auto &config : gpu_pipeline_configs) {
+        if (config.arch == arch) {
+            auto pipIt = config.pipelines.find(pipeline_name);
+            if (pipIt != config.pipelines.end()) {
+                return pipIt->second;
+            }
+            std::vector<std::pair<std::string, uint32_t>> sorted_pipelines(config.pipelines.begin(), config.pipelines.end());
+            std::sort(sorted_pipelines.begin(), sorted_pipelines.end(),
+                      [](const auto &a, const auto &b) { return a.first.size() > b.first.size(); });
+            for (const auto &entry : sorted_pipelines) {
+                if (pipeline_name.find(entry.first) != std::string::npos) {
+                    return entry.second;
+                }
+            }
+            return config.default_subgroup_size;
+        }
+    }
+    return 0; // If no matching configuration is found
+}
+
+static void ggml_vk_load_shaders(vk_device& device) {
+    VK_LOG_DEBUG("ggml_vk_load_shaders(" << device->name << ")");
+
+    // some shaders have a minimum subgroup size
+    const uint32_t subgroup_size_8 = std::max(device->subgroup_size, 8u);
+    const uint32_t subgroup_size_16 = std::max(device->subgroup_size, 16u);
+    const uint32_t subgroup_size_32 = std::max(device->subgroup_size, 32u);
+
+    const uint32_t mul_mat_subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ?
device->subgroup_min_size : device->subgroup_size; + const uint32_t mul_mat_subgroup_size_8 = std::max(mul_mat_subgroup_size, 8u); + const uint32_t mul_mat_subgroup_size_16 = std::max(mul_mat_subgroup_size, 16u); + const uint32_t mul_mat_subgroup_size_32 = std::max(mul_mat_subgroup_size, 32u); + + const bool subgroup_min_size_16 = (!device->subgroup_size_control && device->subgroup_size >= 16) || + (device->subgroup_size_control && device->subgroup_max_size >= 16); + + // mulmat + std::vector l_warptile, m_warptile, s_warptile, + l_warptile_id, m_warptile_id, s_warptile_id, + l_warptile_mmq, m_warptile_mmq, s_warptile_mmq, + l_warptile_mmq_int, m_warptile_mmq_int, s_warptile_mmq_int, + l_warptile_mmq_k, m_warptile_mmq_k, s_warptile_mmq_k, + l_warptile_mmqid, m_warptile_mmqid, s_warptile_mmqid; + std::array l_wg_denoms, m_wg_denoms, s_wg_denoms, + l_mmq_wg_denoms, m_mmq_wg_denoms, s_mmq_wg_denoms, + l_mmq_wg_denoms_k, m_mmq_wg_denoms_k, s_mmq_wg_denoms_k, + l_mmqid_wg_denoms, m_mmqid_wg_denoms, s_mmqid_wg_denoms; + + uint32_t l_align, m_align, s_align; + if (device->coopmat2) { + // spec constants and tile sizes for non-quant matmul/matmul_id + l_warptile = { 256, 128, 256, 64, 1 }; + m_warptile = { 256, 128, 128, 64, 0 }; + s_warptile = { 128, 64, 64, 64, 0 }; + l_wg_denoms = {128, 256, 1 }; + m_wg_denoms = {128, 128, 1 }; + s_wg_denoms = { 64, 64, 1 }; + + // spec constants and tile sizes for quant matmul (non-Qi_K) + l_warptile_mmq = { 256, 128, 256, 64, 1 }; + m_warptile_mmq = { 256, 128, 128, 64, 1 }; + s_warptile_mmq = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms = { 128, 256, 1 }; + m_mmq_wg_denoms = { 128, 128, 1 }; + s_mmq_wg_denoms = { 32, 64, 1 }; + + // spec constants and tile sizes for quant matmul (Qi_K) + l_warptile_mmq_k = { 256, 128, 256, 64, 1 }; + m_warptile_mmq_k = { 256, 128, 128, 64, 1 }; + s_warptile_mmq_k = { 256, 32, 64, 128, 0 }; + l_mmq_wg_denoms_k = { 128, 256, 1 }; + m_mmq_wg_denoms_k = { 128, 128, 1 }; + s_mmq_wg_denoms_k = { 32, 64, 1 }; + + // spec constants and tile sizes for quant matmul_id + l_warptile_mmqid = { 256, 128, 128, 16, 1, device->subgroup_size }; + m_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + s_warptile_mmqid = { 256, 128, 64, 16, 0, device->subgroup_size }; + l_mmqid_wg_denoms = { 128, 128, 1 }; + m_mmqid_wg_denoms = { 128, 64, 1 }; + s_mmqid_wg_denoms = { 128, 64, 1 }; + + l_align = 128; + m_align = 64; + s_align = 32; + } else { + // Matrix cores require different warp group sizes + const uint32_t tm_l = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_m = device->coopmat_support ? device->coopmat_m : 4; + const uint32_t tm_s = device->coopmat_support ? device->coopmat_m : 2; + const uint32_t tn_l = device->coopmat_support ? device->coopmat_n : 4; + const uint32_t tn_m = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tn_s = device->coopmat_support ? device->coopmat_n : 2; + const uint32_t tk_l = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_m = device->coopmat_support ? device->coopmat_k : 1; + const uint32_t tk_s = device->coopmat_support ? 
device->coopmat_k : 1; + + l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, subgroup_size_8 }; + m_warptile_mmq = { 128, 64, 64, 32, subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, subgroup_size_8 }; + s_warptile_mmq = { subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, subgroup_size_8 }; + + l_warptile_mmq_int = { 128, 128, 128, 32, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 }; + m_warptile_mmq_int = { 128, 64, 64, 32, subgroup_size_8, 32, 2, 2, 2, 1, subgroup_size_8 }; + s_warptile_mmq_int = { subgroup_size_32, 32, 32, 32, 32, 32, 2, 2, 1, 1, subgroup_size_8 }; + + l_warptile_id = { 128, 128, 128, 16, mul_mat_subgroup_size_16 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_16 }; + m_warptile_id = { 128, 64, 64, 16, mul_mat_subgroup_size_16, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_16 }; + s_warptile_id = { mul_mat_subgroup_size_16, 32, 32, 16, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_16 }; + + l_warptile_mmqid = { 128, 128, 128, 32, mul_mat_subgroup_size_8 * 2, 64, 2, tm_l, tn_l, tk_l, mul_mat_subgroup_size_8 }; + m_warptile_mmqid = { 128, 64, 64, 32, mul_mat_subgroup_size_8, 32, 2, tm_m, tn_m, tk_m, mul_mat_subgroup_size_8 }; + s_warptile_mmqid = { mul_mat_subgroup_size_32, 32, 32, 32, 32, 32, 2, tm_s, tn_s, tk_s, mul_mat_subgroup_size_8 }; + + // chip specific tuning + if ((device->architecture == AMD_GCN) && (device->driver_id != vk::DriverId::eAmdProprietary)) { + m_warptile_mmq = m_warptile_mmq_int = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + m_warptile_mmqid = { 256, 64, 64, 32, 16, 16, 2, 2, 2, 1, 16 }; + } + + l_mmq_wg_denoms = l_wg_denoms = {128, 128, 1 }; + m_mmq_wg_denoms = m_wg_denoms = { 64, 64, 1 }; + s_mmq_wg_denoms = s_wg_denoms = { 32, 32, 1 }; + l_align = 128; + m_align = 64; + s_align = 32; + + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + ggml_type t = (ggml_type)i; + // Disable medium and large matrix multiplication if not enough shared memory is available + // Check mmq warptiles as the largest configuration + // Throw an error if not enough for any matrix multiplication is available + if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmq, false, t)) { + std::cerr << "ggml_vulkan: Error: Shared memory size too small for matrix multiplication." 
<< std::endl;
+                throw std::runtime_error("Shared memory size too small for matrix multiplication.");
+            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmq, false, t)) {
+                device->mul_mat_m[i] = false;
+                device->mul_mat_l[i] = false;
+            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmq, false, t)) {
+                device->mul_mat_l[i] = false;
+            }
+
+            // Disable mul_mat_id if not enough shared memory is available
+            if (!ggml_vk_matmul_shmem_support(device, s_warptile_mmqid, true, t)) {
+                device->mul_mat_id_s[i] = false;
+                device->mul_mat_id_m[i] = false;
+                device->mul_mat_id_l[i] = false;
+            } else if (!ggml_vk_matmul_shmem_support(device, m_warptile_mmqid, true, t)) {
+                device->mul_mat_id_m[i] = false;
+                device->mul_mat_id_l[i] = false;
+            } else if (!ggml_vk_matmul_shmem_support(device, l_warptile_mmqid, true, t)) {
+                device->mul_mat_id_l[i] = false;
+            }
+        }
+    }
+
+    if (!device->pipeline_matmul_f32) {
+        device->pipeline_matmul_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_f32_f16) {
+        device->pipeline_matmul_f32_f16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_f32) {
+        device->pipeline_matmul_id_f32 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_bf16) {
+        device->pipeline_matmul_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+    if (!device->pipeline_matmul_id_bf16) {
+        device->pipeline_matmul_id_bf16 = std::make_shared<vk_matmul_pipeline_struct>();
+    }
+
+    std::vector<std::future<void>> compiles;
+    auto const &ggml_vk_create_pipeline = [&](vk_device& device, vk_pipeline& pipeline, const char *name, size_t spv_size, const void* spv_data, const char *entrypoint,
+                                              uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
+                                              uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
+
+        if (!require_full_subgroups && required_subgroup_size == 0) {
+            required_subgroup_size = get_subgroup_size(name, device->architecture);
+        }
+
+        if (!pipeline) {
+            pipeline = std::make_shared<vk_pipeline_struct>();
+        }
+        if (!pipeline->initialized) {
+            pipeline->name = name;
+            pipeline->parameter_count = parameter_count;
+            pipeline->push_constant_size = push_constant_size;
+            pipeline->wg_denoms = wg_denoms;
+            pipeline->align = align;
+            pipeline->initialized = true;
+        }
+
+        if (!pipeline->needed || pipeline->compiled) {
+            return;
+        }
+        {
+            // wait until fewer than N compiles are in progress
+            uint32_t N = std::max(1u, std::thread::hardware_concurrency());
+            std::unique_lock<std::mutex> guard(compile_count_mutex);
+            while (compile_count >= N) {
+                compile_count_cond.wait(guard);
+            }
+            compile_count++;
+        }
+
+        compiles.push_back(std::async(ggml_vk_create_pipeline_func, std::ref(device), std::ref(pipeline), spv_size, spv_data, entrypoint,
+                                      parameter_count, wg_denoms, specialization_constants, disable_robustness, require_full_subgroups, required_subgroup_size));
+    };
+
+    auto const &ggml_vk_create_pipeline2 = [&](vk_device& device, vk_pipeline& pipeline, const std::string &name, size_t spv_size, const void* spv_data, const char *entrypoint,
+                                               uint32_t parameter_count, uint32_t push_constant_size, std::array<uint32_t, 3> wg_denoms, const std::vector<uint32_t>& specialization_constants,
+                                               uint32_t align, bool disable_robustness = false, bool require_full_subgroups = false, uint32_t required_subgroup_size = 0) {
+        return ggml_vk_create_pipeline(device, pipeline, name.c_str(), spv_size, spv_data, entrypoint,
+                                       parameter_count, push_constant_size, wg_denoms, specialization_constants,
+                                       align, disable_robustness, require_full_subgroups,
required_subgroup_size); + }; + + auto const &fa_wg_denoms = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::array { + return {fa_rows_cols(path, hsk, hsv, clamp, type, small_rows)[0], 1, 1}; + }; + + auto const &fa_spec_constants = [&](FaCodePath path, uint32_t hsk, uint32_t hsv, uint32_t clamp, ggml_type type, bool small_rows) -> std::vector { + // For large number of rows, 128 invocations seems to work best. + // For small number of rows (e.g. N==1), 256 works better. But matrix granularity for 256 is 32, so we + // can't use 256 for D==80. + // For scalar, use 128 (arbitrary) + // The same D_split value is used for both HSK and HSV, so just base it on the union of the LSBs. + const uint32_t D = (hsk|hsv); + uint32_t wg_size = (path == FA_SCALAR || path == FA_COOPMAT1) + ? scalar_flash_attention_workgroup_size + : ((small_rows && (D % 32) == 0) ? 256 : 128); + auto rows_cols = fa_rows_cols(path, hsk, hsv, clamp, type, small_rows); + + // D_split can't be larger than a subgroup because we use subgroupShuffle to reduce it. + // D_split can't be larger than the LSB of D divided by 4 due to vectorization in the shader. + const uint32_t D_lsb = D ^ (D & (D-1)); + uint32_t D_split = std::min(std::min(device->subgroup_size, 8u), D_lsb / 4); + + return {wg_size, rows_cols[0], rows_cols[1], hsk, hsv, clamp, D_split}; + }; + +#define CREATE_FA(TYPE, NAMELC, FAPATH, SUFFIX) \ + for (auto &fa : device->pipeline_flash_attn_f32_f16[TYPE]) { \ + uint32_t HSK = fa.first.HSK; \ + uint32_t HSV = fa.first.HSV; \ + bool small_rows = fa.first.small_rows; \ + FaCodePath path = fa.first.path; \ + bool aligned = fa.first.aligned; \ + bool f32acc = fa.first.f32acc; \ + if (path == FAPATH) { \ + if (aligned) { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_aligned_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,0,TYPE,small_rows), fa_align(FAPATH,HSK,HSV,TYPE,small_rows), true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } else { \ + if (f32acc) { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f32acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 
32 : 0)); \ + } else { \ + ggml_vk_create_pipeline(device, fa.second, "flash_attn_f32_f16_f16acc" #NAMELC, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _len, flash_attn_f32_f16_ ## NAMELC ## _f16acc ## SUFFIX ## _data, "main", 6, sizeof(vk_flash_attn_push_constants), fa_wg_denoms(FAPATH, HSK,HSV,1,TYPE,small_rows), fa_spec_constants(FAPATH, HSK,HSV,1,TYPE,small_rows), 1, true, FAPATH==FA_COOPMAT1, (FAPATH==FA_COOPMAT1 ? 32 : 0)); \ + } \ + } \ + } \ + } + + CREATE_FA(GGML_TYPE_F16, f16, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_SCALAR, ) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_SCALAR, ) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat1_fa_support) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT1, _cm1) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT1, _cm1) + } +#endif +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_FA(GGML_TYPE_F16, f16, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_0, q4_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q4_1, q4_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_0, q5_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q5_1, q5_1, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_Q8_0, q8_0, FA_COOPMAT2, _cm2) + CREATE_FA(GGML_TYPE_IQ4_NL, iq4_nl, FA_COOPMAT2, _cm2) + } +#endif +#undef CREATE_FA + +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm2_len, NAMELC ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm2_len, NAMELC ## _aligned ## F16ACC ## _cm2_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + CREATE_MM(PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT) \ + + CREATE_MM2(pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3) + } +#endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_0], matmul_q4_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_1], matmul_q4_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_0], matmul_q5_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_1], matmul_q5_1_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q8_0], matmul_q8_0_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q2_K], matmul_q2_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q3_K], matmul_q3_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q4_K], matmul_q4_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q5_K], matmul_q5_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_Q6_K], matmul_q6_k_f16, mmq_wg_denoms_k, warptile_mmq_k, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_S], matmul_iq1_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ1_M], matmul_iq1_m_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ2_S], matmul_iq2_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ3_S], matmul_iq3_s_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + CREATE_MM2(pipeline_dequant_mul_mat_mat_f16[GGML_TYPE_MXFP4], matmul_mxfp4_f16, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3) + + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM2(pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_id_push_constants, 4) +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, 
vk_mat_mat_id_push_constants, 4) + } +#endif + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) + CREATE_MM2(pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f16, mmqid_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4) +#undef CREATE_MM +#undef CREATE_MM2 + } else +#endif // defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) +#if defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->coopmat_support) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, 
PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _cm1_len, NAMELC ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, true); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, true); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, true); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _cm1_len, NAMELC ## _aligned ## F16ACC ## _cm1_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, true); \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->coopmat_acc_f16_support) { \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + if (device->coopmat_acc_f32_support) { \ + CREATE_MM(TYPE, PIPELINE_NAME . 
f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + } \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, ); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, ) + } +#endif + + if (device->coopmat_acc_f16_support) { + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], 
matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } else { + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + 
CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, ); + } + + GGML_ASSERT(device->subgroup_ballot); + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (device->coopmat_bf16_support) { + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id); + } +#endif + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_XXS, 
pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_id_push_constants, 4, _id); +#undef CREATE_MM2 +#undef CREATE_MM + } else +#endif // defined(VK_KHR_cooperative_matrix) && defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (device->fp16) { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _len, NAMELC ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _len, NAMELC ## _aligned ## F16ACC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->l, #NAMELC "_f16acc_l", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> 
PIPELINE_NAME .f32acc->l, #NAMELC "_l", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _m[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->m, #NAMELC "_f16acc_m", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->m, #NAMELC "_m", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + } \ + if (device->mul_mat ## ID ## _s[TYPE]) { \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f16acc->s, #NAMELC "_f16acc_s", NAMELC ## _f16acc_len, NAMELC ## _f16acc_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME .f32acc->s, #NAMELC "_s", NAMELC ## _len, NAMELC ## _data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + } \ + + // Create 2 variants, {f16,f32} accumulator +#define CREATE_MM2(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + CREATE_MM(TYPE, PIPELINE_NAME . f16acc, NAMELC, _f16acc, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + CREATE_MM(TYPE, PIPELINE_NAME . f32acc, NAMELC, , WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16, matmul_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_f16_f32, matmul_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0], matmul_q4_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1], matmul_q4_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0], matmul_q5_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1], matmul_q5_1_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0], matmul_q8_0_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K], matmul_q2_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K], matmul_q3_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K], matmul_q4_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K], matmul_q5_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_Q6_K, 
pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K], matmul_q6_k_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S], matmul_iq1_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M], matmul_iq1_m_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS], matmul_iq2_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS], matmul_iq2_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S], matmul_iq2_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS], matmul_iq3_xxs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S], matmul_iq3_s_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS], matmul_iq4_xs_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL], matmul_iq4_nl_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4], matmul_mxfp4_f32, mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0], matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1], matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0], matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1], matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0], matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_subgroup_f16, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_subgroup_f16_f32, wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_subgroup_q4_0_f32, 
mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_subgroup_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_subgroup_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_subgroup_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_subgroup_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_subgroup_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_subgroup_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_subgroup_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_subgroup_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_subgroup_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_subgroup_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_subgroup_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_subgroup_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_subgroup_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_subgroup_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_subgroup_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_subgroup_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_subgroup_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + 
CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_subgroup_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_subgroup_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16, matmul_id_f16, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_F16, pipeline_matmul_id_f16_f32, matmul_id_f16_f32, wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MM2(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0], matmul_id_q4_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1], matmul_id_q4_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0], matmul_id_q5_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1], matmul_id_q5_1_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0], matmul_id_q8_0_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K], matmul_id_q2_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K], matmul_id_q3_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K], matmul_id_q4_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K], matmul_id_q5_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K], matmul_id_q6_k_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S], matmul_id_iq1_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M], matmul_id_iq1_m_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS], matmul_id_iq2_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS], matmul_id_iq2_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S], matmul_id_iq2_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + 
CREATE_MM2(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS], matmul_id_iq3_xxs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S], matmul_id_iq3_s_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS], matmul_id_iq4_xs_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL], matmul_id_iq4_nl_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM2(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4], matmul_id_mxfp4_f32, mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + } +#undef CREATE_MM2 +#undef CREATE_MMQ +#undef CREATE_MM + } else { + // Create 6 variants, {s,m,l}x{unaligned,aligned} +#define CREATE_MM(TYPE, PIPELINE_NAME, NAMELC, F16ACC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID, REQSUBGROUPSIZE) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC #F16ACC "_l", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->m, #NAMELC #F16ACC "_m", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC #F16ACC "_s", NAMELC ## F16ACC ## _fp32_len, NAMELC ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_l, #NAMELC #F16ACC "_aligned_l", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, l_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_m, #NAMELC #F16ACC "_aligned_m", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, m_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->a_s, #NAMELC #F16ACC "_aligned_s", NAMELC ## _aligned ## F16ACC ## _fp32_len, NAMELC ## _aligned ## F16ACC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, s_align, false, REQSUBGROUPSIZE > 0, REQSUBGROUPSIZE); \ + +#define CREATE_MMQ(TYPE, PIPELINE_NAME, NAMELC, WG_DENOMS, WARPTILE, PUSHCONST, PARAMCOUNT, ID) \ + if (device->mul_mat ## ID ## _l[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->l, #NAMELC "_l", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), l_ ## WG_DENOMS, l_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _m[TYPE]) \ + ggml_vk_create_pipeline(device, device-> 
PIPELINE_NAME ->m, #NAMELC "_m", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), m_ ## WG_DENOMS, m_ ## WARPTILE, 1); \ + if (device->mul_mat ## ID ## _s[TYPE]) \ + ggml_vk_create_pipeline(device, device-> PIPELINE_NAME ->s, #NAMELC "_s", NAMELC ## _fp32_len, NAMELC ## _fp32_data, "main", PARAMCOUNT, sizeof(PUSHCONST), s_ ## WG_DENOMS, s_ ## WARPTILE, 1); \ + + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32, matmul_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_f32_f16, matmul_f32_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16.f32acc, matmul_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_f16_f32.f32acc, matmul_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q2_K].f32acc, matmul_q2_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q3_K].f32acc, matmul_q3_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q4_K].f32acc, matmul_q4_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q5_K].f32acc, matmul_q5_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat[GGML_TYPE_Q6_K].f32acc, matmul_q6_k_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_S].f32acc, matmul_iq1_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ1_M].f32acc, matmul_iq1_m_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XXS].f32acc, matmul_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_XS].f32acc, matmul_iq2_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ2_S].f32acc, matmul_iq2_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + 
CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_XXS].f32acc, matmul_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ3_S].f32acc, matmul_iq3_s_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_XS].f32acc, matmul_iq4_xs_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat[GGML_TYPE_IQ4_NL].f32acc, matmul_iq4_nl_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat[GGML_TYPE_MXFP4].f32acc, matmul_mxfp4_f32, , mmq_wg_denoms, warptile_mmq, vk_mat_mat_push_constants, 3, , 0); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + CREATE_MMQ(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_0].f32acc, matmul_q4_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q4_1].f32acc, matmul_q4_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_0].f32acc, matmul_q5_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q5_1].f32acc, matmul_q5_1_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + CREATE_MMQ(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_q8_1[GGML_TYPE_Q8_0].f32acc, matmul_q8_0_q8_1, mmq_wg_denoms, warptile_mmq_int, vk_mat_mat_push_constants, 3, ); + } +#endif + + if (device->subgroup_ballot && device->subgroup_require_full_support && subgroup_min_size_16) { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_subgroup_f32_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_subgroup_f16, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_subgroup_f16_f32, , wg_denoms, warptile_id, vk_mat_mat_push_constants, 4, _id, mul_mat_subgroup_size_16); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_subgroup_bf16, , wg_denoms, warptile_id, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size_16); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_subgroup_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_subgroup_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_subgroup_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_subgroup_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_subgroup_q8_0_f32, , 
mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_subgroup_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_subgroup_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_subgroup_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_subgroup_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_subgroup_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_subgroup_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_subgroup_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_subgroup_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_subgroup_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_subgroup_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_subgroup_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_subgroup_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_subgroup_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_subgroup_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_subgroup_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, mul_mat_subgroup_size); + } else { + CREATE_MM(GGML_TYPE_F32, pipeline_matmul_id_f32, matmul_id_f32_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16.f32acc, matmul_id_f16, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + 
CREATE_MM(GGML_TYPE_F16, pipeline_matmul_id_f16_f32.f32acc, matmul_id_f16_f32, , wg_denoms, warptile, vk_mat_mat_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0); + + CREATE_MM(GGML_TYPE_Q4_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_0].f32acc, matmul_id_q4_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_1].f32acc, matmul_id_q4_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_0].f32acc, matmul_id_q5_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_1, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_1].f32acc, matmul_id_q5_1_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q8_0, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q8_0].f32acc, matmul_id_q8_0_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q2_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q2_K].f32acc, matmul_id_q2_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q3_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q3_K].f32acc, matmul_id_q3_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q4_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q4_K].f32acc, matmul_id_q4_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q5_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q5_K].f32acc, matmul_id_q5_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_Q6_K, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_Q6_K].f32acc, matmul_id_q6_k_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_S].f32acc, matmul_id_iq1_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ1_M, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ1_M].f32acc, matmul_id_iq1_m_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XXS].f32acc, matmul_id_iq2_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_XS].f32acc, matmul_id_iq2_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ2_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ2_S].f32acc, matmul_id_iq2_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_XXS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_XXS].f32acc, matmul_id_iq3_xxs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ3_S, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ3_S].f32acc, matmul_id_iq3_s_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + CREATE_MM(GGML_TYPE_IQ4_XS, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_XS].f32acc, matmul_id_iq4_xs_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0); + 
CREATE_MM(GGML_TYPE_IQ4_NL, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_IQ4_NL].f32acc, matmul_id_iq4_nl_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+            CREATE_MM(GGML_TYPE_MXFP4, pipeline_dequant_mul_mat_mat_id[GGML_TYPE_MXFP4].f32acc, matmul_id_mxfp4_f32, , mmq_wg_denoms, warptile_mmqid, vk_mat_mat_id_push_constants, 4, _id, 0);
+        }
+    }
+    // reusing CREATE_MM from the fp32 path
+    if ((device->coopmat2 || device->coopmat_support)
+#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT)
+        && !device->coopmat_bf16_support
+#endif
+    ) {
+        // use scalar tile sizes
+        l_warptile = { 128, 128, 128, 16, subgroup_size_8 * 2, 64, 2, 4, 4, 1, subgroup_size_8 };
+        m_warptile = { 128, 64, 64, 16, subgroup_size_8, 32, 2, 4, 2, 1, subgroup_size_8 };
+        s_warptile = { subgroup_size_16, 32, 32, 16, 32, 32, 2, 2, 2, 1, subgroup_size_8 };
+
+        l_wg_denoms = {128, 128, 1 };
+        m_wg_denoms = { 64, 64, 1 };
+        s_wg_denoms = { 32, 32, 1 };
+
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_bf16, matmul_bf16, , wg_denoms, warptile, vk_mat_mat_push_constants, 3, , 0);
+        CREATE_MM(GGML_TYPE_BF16, pipeline_matmul_id_bf16, matmul_id_bf16, , wg_denoms, warptile, vk_mat_mat_id_push_constants, 4, _id, 0);
+    }
+#undef CREATE_MM
+
+    // mul mat vec
+
+    // the number of rows computed per shader depends on GPU model and quant
+    uint32_t rm_stdq = 1;
+    uint32_t rm_kq = 2;
+    if (device->vendor_id == VK_VENDOR_ID_AMD) {
+        if (device->architecture == AMD_GCN) {
+            rm_stdq = 2;
+            rm_kq = 4;
+        }
+    } else if (device->vendor_id == VK_VENDOR_ID_INTEL)
+        rm_stdq = 2;
+    uint32_t rm_iq = 2 * rm_kq;
+
+    const bool use_subgroups = device->subgroup_arithmetic && device->architecture != vk_device_architecture::AMD_GCN;
+    // Ensure a subgroup size >= 16 is available
+    const bool use_subgroups16 = use_subgroups && subgroup_min_size_16;
+
+    const uint32_t subgroup_size = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control && device->subgroup_min_size <= 16 && device->subgroup_max_size >= 16) ? 16 : device->subgroup_size;
+    const uint32_t subgroup_size16 = std::max(subgroup_size, 16u);
+
+    const uint32_t force_subgroup_size = use_subgroups ? subgroup_size : 0;
+    const uint32_t force_subgroup_size16 = use_subgroups16 ? subgroup_size16 : 0;
+
+    for (uint32_t w = 0; w < DMMV_WG_SIZE_COUNT; ++w) {
+        const uint32_t wg_size_subgroup = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size : (subgroup_size * 4);
+        const uint32_t wg_size_subgroup16 = (w == DMMV_WG_SIZE_SUBGROUP) ? subgroup_size16 : (subgroup_size16 * 4);
+
+        const shader_reduction_mode reduc = (use_subgroups && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
+                                            (use_subgroups && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
+                                            SHADER_REDUCTION_MODE_SHMEM;
+
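+        // [editor's note: explanatory comment added during review, not part of the upstream ggml sync]
+        // Each w iteration emits one workgroup-size variant of every mul_mat_vec pipeline:
+        // DMMV_WG_SIZE_SUBGROUP dispatches a single subgroup per workgroup, the large variant
+        // dispatches four. The reduction-mode selectors (reduc above, reduc16 below) follow that
+        // choice: pure subgroup reduction for the single-subgroup variant, subgroup plus shared
+        // memory (HYBRID) for the large one, and plain shared memory whenever subgroup arithmetic
+        // is unavailable. The *16 selectors additionally guarantee a subgroup size of at least
+        // 16 lanes; e.g. assuming a 32-wide subgroup, the two variants run 32- and 128-thread
+        // workgroups respectively.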
+        const shader_reduction_mode reduc16 = (use_subgroups16 && w == DMMV_WG_SIZE_SUBGROUP) ? SHADER_REDUCTION_MODE_SUBGROUP :
+                                              (use_subgroups16 && w == DMMV_WG_SIZE_LARGE) ? SHADER_REDUCTION_MODE_HYBRID :
+                                              SHADER_REDUCTION_MODE_SHMEM;
+
+        for (uint32_t i = 0; i < mul_mat_vec_max_cols; ++i) {
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f32_f32", arr_dmmv_f32_f32_f32_len[reduc], arr_dmmv_f32_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f32_f32", arr_dmmv_f16_f32_f32_len[reduc], arr_dmmv_f16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f32_f32", arr_dmmv_bf16_f32_f32_len[reduc], arr_dmmv_bf16_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f32_f32", arr_dmmv_q4_0_f32_f32_len[reduc], arr_dmmv_q4_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f32_f32", arr_dmmv_q4_1_f32_f32_len[reduc], arr_dmmv_q4_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f32_f32", arr_dmmv_q5_0_f32_f32_len[reduc], arr_dmmv_q5_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f32_f32", arr_dmmv_q5_1_f32_f32_len[reduc], arr_dmmv_q5_1_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_f32_f32", arr_dmmv_q8_0_f32_f32_len[reduc], arr_dmmv_q8_0_f32_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f32_f32", arr_dmmv_q2_k_f32_f32_len[reduc16], arr_dmmv_q2_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+            ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f32_f32", arr_dmmv_q3_k_f32_f32_len[reduc16], arr_dmmv_q3_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16);
+
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f32_f32", arr_dmmv_q4_k_f32_f32_len[reduc16], arr_dmmv_q4_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f32_f32", arr_dmmv_q5_k_f32_f32_len[reduc16], arr_dmmv_q5_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f32_f32", arr_dmmv_q6_k_f32_f32_len[reduc16], arr_dmmv_q6_k_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f32_f32", arr_dmmv_iq1_s_f32_f32_len[reduc16], arr_dmmv_iq1_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f32_f32", arr_dmmv_iq1_m_f32_f32_len[reduc16], arr_dmmv_iq1_m_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f32_f32", arr_dmmv_iq2_xxs_f32_f32_len[reduc16], arr_dmmv_iq2_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f32_f32", arr_dmmv_iq2_xs_f32_f32_len[reduc16], arr_dmmv_iq2_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f32_f32", arr_dmmv_iq2_s_f32_f32_len[reduc16], arr_dmmv_iq2_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f32_f32", arr_dmmv_iq3_xxs_f32_f32_len[reduc16], arr_dmmv_iq3_xxs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f32_f32", arr_dmmv_iq3_s_f32_f32_len[reduc16], arr_dmmv_iq3_s_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + 
ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f32_f32", arr_dmmv_iq4_xs_f32_f32_len[reduc16], arr_dmmv_iq4_xs_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f32_f32", arr_dmmv_iq4_nl_f32_f32_len[reduc16], arr_dmmv_iq4_nl_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f32_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f32_f32", arr_dmmv_mxfp4_f32_f32_len[reduc16], arr_dmmv_mxfp4_f32_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F32 ][i], "mul_mat_vec_f32_f16_f32", arr_dmmv_f32_f16_f32_len[reduc], arr_dmmv_f32_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_F16 ][i], "mul_mat_vec_f16_f16_f32", arr_dmmv_f16_f16_f32_len[reduc], arr_dmmv_f16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_BF16][i], "mul_mat_vec_bf16_f16_f32", arr_dmmv_bf16_f16_f32_len[reduc], arr_dmmv_bf16_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2, 1, 1}, {wg_size_subgroup, 2, i+1}, 1, false, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_f16_f32", arr_dmmv_q4_0_f16_f32_len[reduc], arr_dmmv_q4_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_f16_f32", arr_dmmv_q4_1_f16_f32_len[reduc], arr_dmmv_q4_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_f16_f32", arr_dmmv_q5_0_f16_f32_len[reduc], arr_dmmv_q5_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_f16_f32", arr_dmmv_q5_1_f16_f32_len[reduc], arr_dmmv_q5_1_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup, 2*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q8_0][i], 
"mul_mat_vec_q8_0_f16_f32", arr_dmmv_q8_0_f16_f32_len[reduc], arr_dmmv_q8_0_f16_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup, 1*rm_stdq, i+1}, 1, true, use_subgroups, force_subgroup_size); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q2_K][i], "mul_mat_vec_q2_k_f16_f32", arr_dmmv_q2_k_f16_f32_len[reduc16], arr_dmmv_q2_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q3_K][i], "mul_mat_vec_q3_k_f16_f32", arr_dmmv_q3_k_f16_f32_len[reduc16], arr_dmmv_q3_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q4_K][i], "mul_mat_vec_q4_k_f16_f32", arr_dmmv_q4_k_f16_f32_len[reduc16], arr_dmmv_q4_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q5_K][i], "mul_mat_vec_q5_k_f16_f32", arr_dmmv_q5_k_f16_f32_len[reduc16], arr_dmmv_q5_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_Q6_K][i], "mul_mat_vec_q6_k_f16_f32", arr_dmmv_q6_k_f16_f32_len[reduc16], arr_dmmv_q6_k_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_kq, 1, 1}, {wg_size_subgroup16, rm_kq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_S][i], "mul_mat_vec_iq1_s_f16_f32", arr_dmmv_iq1_s_f16_f32_len[reduc16], arr_dmmv_iq1_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ1_M][i], "mul_mat_vec_iq1_m_f16_f32", arr_dmmv_iq1_m_f16_f32_len[reduc16], arr_dmmv_iq1_m_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XXS][i], "mul_mat_vec_iq2_xxs_f16_f32", arr_dmmv_iq2_xxs_f16_f32_len[reduc16], arr_dmmv_iq2_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_XS][i], "mul_mat_vec_iq2_xs_f16_f32", arr_dmmv_iq2_xs_f16_f32_len[reduc16], arr_dmmv_iq2_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ2_S][i], "mul_mat_vec_iq2_s_f16_f32", 
arr_dmmv_iq2_s_f16_f32_len[reduc16], arr_dmmv_iq2_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_XXS][i], "mul_mat_vec_iq3_xxs_f16_f32", arr_dmmv_iq3_xxs_f16_f32_len[reduc16], arr_dmmv_iq3_xxs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ3_S][i], "mul_mat_vec_iq3_s_f16_f32", arr_dmmv_iq3_s_f16_f32_len[reduc16], arr_dmmv_iq3_s_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_XS][i], "mul_mat_vec_iq4_xs_f16_f32", arr_dmmv_iq4_xs_f16_f32_len[reduc16], arr_dmmv_iq4_xs_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_IQ4_NL][i], "mul_mat_vec_iq4_nl_f16_f32", arr_dmmv_iq4_nl_f16_f32_len[reduc16], arr_dmmv_iq4_nl_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_f16_f32[w][GGML_TYPE_MXFP4][i], "mul_mat_vec_mxfp4_f16_f32", arr_dmmv_mxfp4_f16_f32_len[reduc16], arr_dmmv_mxfp4_f16_f32_data[reduc16], "main", 3, sizeof(vk_mat_vec_push_constants), {rm_iq, 1, 1}, {wg_size_subgroup16, rm_iq, i+1}, 1, true, use_subgroups16, force_subgroup_size16); + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (device->integer_dot_product) { + const uint32_t subgroup_size_int = (device->vendor_id == VK_VENDOR_ID_INTEL && device->subgroup_size_control) ? device->subgroup_min_size : device->subgroup_size; + const uint32_t wg_size_subgroup_int = (w == DMMV_WG_SIZE_SUBGROUP) ? 
subgroup_size_int : (subgroup_size_int * 4); + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_0][i], "mul_mat_vec_q4_0_q8_1_f32", arr_dmmv_q4_0_q8_1_f32_len[reduc], arr_dmmv_q4_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q4_1][i], "mul_mat_vec_q4_1_q8_1_f32", arr_dmmv_q4_1_q8_1_f32_len[reduc], arr_dmmv_q4_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_0][i], "mul_mat_vec_q5_0_q8_1_f32", arr_dmmv_q5_0_q8_1_f32_len[reduc], arr_dmmv_q5_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q5_1][i], "mul_mat_vec_q5_1_q8_1_f32", arr_dmmv_q5_1_q8_1_f32_len[reduc], arr_dmmv_q5_1_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {2*rm_stdq, 1, 1}, {wg_size_subgroup_int, 2*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_q8_1_f32[w][GGML_TYPE_Q8_0][i], "mul_mat_vec_q8_0_q8_1_f32", arr_dmmv_q8_0_q8_1_f32_len[reduc], arr_dmmv_q8_0_q8_1_f32_data[reduc], "main", 3, sizeof(vk_mat_vec_push_constants), {1*rm_stdq, 1, 1}, {wg_size_subgroup_int, 1*rm_stdq, i+1}, 1, true, use_subgroups, subgroup_size_int); + } +#endif // GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT + } + } + + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F32 ], "mul_mat_vec_id_f32_f32", mul_mat_vec_id_f32_f32_len, mul_mat_vec_id_f32_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_F16 ], "mul_mat_vec_id_f16_f32", mul_mat_vec_id_f16_f32_len, mul_mat_vec_id_f16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_BF16], "mul_mat_vec_id_bf16_f32", mul_mat_vec_id_bf16_f32_len, mul_mat_vec_id_bf16_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2, 1, 1}, {device->subgroup_size, 2}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_0], "mul_mat_vec_id_q4_0_f32", mul_mat_vec_id_q4_0_f32_len, mul_mat_vec_id_q4_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_1], "mul_mat_vec_id_q4_1_f32", mul_mat_vec_id_q4_1_f32_len, mul_mat_vec_id_q4_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_0], "mul_mat_vec_id_q5_0_f32", mul_mat_vec_id_q5_0_f32_len, mul_mat_vec_id_q5_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 
2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_1], "mul_mat_vec_id_q5_1_f32", mul_mat_vec_id_q5_1_f32_len, mul_mat_vec_id_q5_1_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {2*rm_stdq, 1, 1}, {device->subgroup_size, 2*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q8_0], "mul_mat_vec_id_q8_0_f32", mul_mat_vec_id_q8_0_f32_len, mul_mat_vec_id_q8_0_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {1*rm_stdq, 1, 1}, {device->subgroup_size, 1*rm_stdq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q2_K], "mul_mat_vec_id_q2_k_f32", mul_mat_vec_id_q2_k_f32_len, mul_mat_vec_id_q2_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q3_K], "mul_mat_vec_id_q3_k_f32", mul_mat_vec_id_q3_k_f32_len, mul_mat_vec_id_q3_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q4_K], "mul_mat_vec_id_q4_k_f32", mul_mat_vec_id_q4_k_f32_len, mul_mat_vec_id_q4_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q5_K], "mul_mat_vec_id_q5_k_f32", mul_mat_vec_id_q5_k_f32_len, mul_mat_vec_id_q5_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_Q6_K], "mul_mat_vec_id_q6_k_f32", mul_mat_vec_id_q6_k_f32_len, mul_mat_vec_id_q6_k_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_kq, 1, 1}, {subgroup_size_16, rm_kq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_S], "mul_mat_vec_id_iq1_s_f32", mul_mat_vec_id_iq1_s_f32_len, mul_mat_vec_id_iq1_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ1_M], "mul_mat_vec_id_iq1_m_f32", mul_mat_vec_id_iq1_m_f32_len, mul_mat_vec_id_iq1_m_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XXS], "mul_mat_vec_id_iq2_xxs_f32", mul_mat_vec_id_iq2_xxs_f32_len, mul_mat_vec_id_iq2_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_XS], "mul_mat_vec_id_iq2_xs_f32", mul_mat_vec_id_iq2_xs_f32_len, mul_mat_vec_id_iq2_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ2_S], "mul_mat_vec_id_iq2_s_f32", mul_mat_vec_id_iq2_s_f32_len, mul_mat_vec_id_iq2_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_XXS], "mul_mat_vec_id_iq3_xxs_f32", mul_mat_vec_id_iq3_xxs_f32_len, mul_mat_vec_id_iq3_xxs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ3_S], "mul_mat_vec_id_iq3_s_f32", mul_mat_vec_id_iq3_s_f32_len, mul_mat_vec_id_iq3_s_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_XS], "mul_mat_vec_id_iq4_xs_f32", mul_mat_vec_id_iq4_xs_f32_len, mul_mat_vec_id_iq4_xs_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_IQ4_NL], "mul_mat_vec_id_iq4_nl_f32", mul_mat_vec_id_iq4_nl_f32_len, mul_mat_vec_id_iq4_nl_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_dequant_mul_mat_vec_id_f32[GGML_TYPE_MXFP4], "mul_mat_vec_id_mxfp4_f32", mul_mat_vec_id_mxfp4_f32_len, mul_mat_vec_id_mxfp4_f32_data, "main", 4, sizeof(vk_mat_vec_id_push_constants), {rm_iq, 1, 1}, {subgroup_size_16, rm_iq}, 1, true); + + // dequant shaders + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_F32 ], "f32_to_f16", dequant_f32_len, dequant_f32_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_0], "dequant_q4_0", dequant_q4_0_len, dequant_q4_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_1], "dequant_q4_1", dequant_q4_1_len, dequant_q4_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_0], "dequant_q5_0", dequant_q5_0_len, dequant_q5_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_1], "dequant_q5_1", dequant_q5_1_len, dequant_q5_1_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q8_0], "dequant_q8_0", dequant_q8_0_len, dequant_q8_0_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q2_K], "dequant_q2_k", dequant_q2_k_len, dequant_q2_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q3_K], "dequant_q3_k", dequant_q3_k_len, dequant_q3_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q4_K], "dequant_q4_k", dequant_q4_k_len, dequant_q4_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q5_K], "dequant_q5_k", dequant_q5_k_len, dequant_q5_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_Q6_K], "dequant_q6_k", dequant_q6_k_len, dequant_q6_k_data, "main", 2, 5 * sizeof(uint32_t), {256 * 64, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, 
device->pipeline_dequant[GGML_TYPE_IQ1_S], "dequant_iq1_s", dequant_iq1_s_len, dequant_iq1_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ1_M], "dequant_iq1_m", dequant_iq1_m_len, dequant_iq1_m_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XXS], "dequant_iq2_xxs", dequant_iq2_xxs_len, dequant_iq2_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_XS], "dequant_iq2_xs", dequant_iq2_xs_len, dequant_iq2_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ2_S], "dequant_iq2_s", dequant_iq2_s_len, dequant_iq2_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_XXS], "dequant_iq3_xxs", dequant_iq3_xxs_len, dequant_iq3_xxs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ3_S], "dequant_iq3_s", dequant_iq3_s_len, dequant_iq3_s_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_XS], "dequant_iq4_xs", dequant_iq4_xs_len, dequant_iq4_xs_data, "main", 2, 5 * sizeof(uint32_t), {256 * 32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_IQ4_NL], "dequant_iq4_nl", dequant_iq4_nl_len, dequant_iq4_nl_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_dequant[GGML_TYPE_MXFP4], "dequant_mxfp4", dequant_mxfp4_len, dequant_mxfp4_data, "main", 2, 5 * sizeof(uint32_t), {256 * 16, 1, 1}, {}, 1); + + // get_rows + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F32 ], "get_rows_f32", get_rows_f32_len, get_rows_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_F16 ], "get_rows_f16", get_rows_f16_len, get_rows_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_BF16], "get_rows_bf16", get_rows_bf16_len, get_rows_bf16_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_0], "get_rows_q4_0", get_rows_q4_0_len, get_rows_q4_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_1], "get_rows_q4_1", get_rows_q4_1_len, get_rows_q4_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_0], "get_rows_q5_0", get_rows_q5_0_len, get_rows_q5_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_1], "get_rows_q5_1", get_rows_q5_1_len, get_rows_q5_1_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q8_0], "get_rows_q8_0", get_rows_q8_0_len, get_rows_q8_0_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q2_K], "get_rows_q2_k", get_rows_q2_k_len, get_rows_q2_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q3_K], "get_rows_q3_k", get_rows_q3_k_len, get_rows_q3_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q4_K], "get_rows_q4_k", get_rows_q4_k_len, get_rows_q4_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q5_K], "get_rows_q5_k", get_rows_q5_k_len, get_rows_q5_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_Q6_K], "get_rows_q6_k", get_rows_q6_k_len, get_rows_q6_k_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_S], "get_rows_iq1_s", get_rows_iq1_s_len, get_rows_iq1_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ1_M], "get_rows_iq1_m", get_rows_iq1_m_len, get_rows_iq1_m_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs", get_rows_iq2_xxs_len, get_rows_iq2_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs", get_rows_iq2_xs_len, get_rows_iq2_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ2_S], "get_rows_iq2_s", get_rows_iq2_s_len, get_rows_iq2_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs", get_rows_iq3_xxs_len, get_rows_iq3_xxs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ3_S], "get_rows_iq3_s", get_rows_iq3_s_len, get_rows_iq3_s_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs", get_rows_iq4_xs_len, get_rows_iq4_xs_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl", get_rows_iq4_nl_len, get_rows_iq4_nl_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows[GGML_TYPE_MXFP4], "get_rows_mxfp4", get_rows_mxfp4_len, get_rows_mxfp4_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F32 ], "get_rows_f32_f32", get_rows_f32_f32_len, get_rows_f32_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_F16 ], "get_rows_f16_f32", get_rows_f16_f32_len, get_rows_f16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_BF16], "get_rows_bf16_f32", get_rows_bf16_f32_len, get_rows_bf16_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), { 512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_0], "get_rows_q4_0_f32", get_rows_q4_0_f32_len, get_rows_q4_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_1], "get_rows_q4_1_f32", get_rows_q4_1_f32_len, get_rows_q4_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_0], "get_rows_q5_0_f32", get_rows_q5_0_f32_len, get_rows_q5_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_1], "get_rows_q5_1_f32", get_rows_q5_1_f32_len, get_rows_q5_1_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q8_0], "get_rows_q8_0_f32", get_rows_q8_0_f32_len, get_rows_q8_0_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q2_K], "get_rows_q2_k_f32", get_rows_q2_k_f32_len, get_rows_q2_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q3_K], "get_rows_q3_k_f32", get_rows_q3_k_f32_len, get_rows_q3_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q4_K], "get_rows_q4_k_f32", get_rows_q4_k_f32_len, get_rows_q4_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q5_K], "get_rows_q5_k_f32", get_rows_q5_k_f32_len, get_rows_q5_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_Q6_K], "get_rows_q6_k_f32", get_rows_q6_k_f32_len, get_rows_q6_k_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_S], "get_rows_iq1_s_f32", get_rows_iq1_s_f32_len, get_rows_iq1_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ1_M], "get_rows_iq1_m_f32", get_rows_iq1_m_f32_len, get_rows_iq1_m_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XXS], "get_rows_iq2_xxs_f32", get_rows_iq2_xxs_f32_len, get_rows_iq2_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_XS], "get_rows_iq2_xs_f32", get_rows_iq2_xs_f32_len, get_rows_iq2_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ2_S], "get_rows_iq2_s_f32", get_rows_iq2_s_f32_len, get_rows_iq2_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_XXS], "get_rows_iq3_xxs_f32", get_rows_iq3_xxs_f32_len, get_rows_iq3_xxs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ3_S], "get_rows_iq3_s_f32", get_rows_iq3_s_f32_len, get_rows_iq3_s_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_XS], "get_rows_iq4_xs_f32", get_rows_iq4_xs_f32_len, get_rows_iq4_xs_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_IQ4_NL], "get_rows_iq4_nl_f32", get_rows_iq4_nl_f32_len, get_rows_iq4_nl_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_get_rows_f32[GGML_TYPE_MXFP4], "get_rows_mxfp4_f32", get_rows_mxfp4_f32_len, get_rows_mxfp4_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {1024, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_matmul_split_k_reduce, "split_k_reduce", split_k_reduce_len, split_k_reduce_data, "main", 2, 2 * sizeof(uint32_t), {256 * 4, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_flash_attn_split_k_reduce, "fa_split_k_reduce", fa_split_k_reduce_len, fa_split_k_reduce_data, "main", 3, 5 * sizeof(uint32_t), {1, device->subgroup_size, 1}, {device->subgroup_size}, 1, true); + + if (device->subgroup_clustered && device->subgroup_require_full_support) { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_subgroup_len, quantize_q8_1_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_subgroup_len, quantize_q8_1_x4_subgroup_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1, true, true); + } else { + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1, "quantize_q8_1", quantize_q8_1_len, quantize_q8_1_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_quantize_q8_1_x4, "quantize_q8_1_x4", quantize_q8_1_x4_len, quantize_q8_1_x4_data, "main", 2, 1 * sizeof(uint32_t), {32 * device->subgroup_size / 8, 1, 1}, { device->subgroup_size }, 1); + } + + for (uint32_t i = 0; i < p021_max_gqa_ratio; ++i) { + if (device->subgroup_arithmetic && device->subgroup_require_full_support) { + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_subgroup_add_len, mul_mat_vec_p021_f16_f32_subgroup_add_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true, true); + } else { + ggml_vk_create_pipeline2(device, device->pipeline_mul_mat_vec_p021_f16_f32[i], "mul_mat_vec_p021_f16_f32"+std::to_string(i+1), mul_mat_vec_p021_f16_f32_len, mul_mat_vec_p021_f16_f32_data, "main", 3, 6 * sizeof(uint32_t), {1, 1, 1}, {device->subgroup_size, i + 1}, 1, true); + } + } + ggml_vk_create_pipeline(device, device->pipeline_mul_mat_vec_nc_f16_f32, "mul_mat_vec_nc_f16_f32", mul_mat_vec_nc_f16_f32_len, mul_mat_vec_nc_f16_f32_data, "main", 3, 
12 * sizeof(uint32_t), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_norm_f32, "norm_f32", norm_f32_len, norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_group_norm_f32, "group_norm_f32", group_norm_f32_len, group_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_f32, "rms_norm_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_f32, "rms_norm_mul_f32", rms_norm_f32_len, rms_norm_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_partials_f32, "rms_norm_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 0}, 1, true); + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_mul_partials_f32, "rms_norm_mul_partials_f32", rms_norm_partials_f32_len, rms_norm_partials_f32_data, "main", 4, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {0, 1}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_rms_norm_back_f32, "rms_norm_back_f32", rms_norm_back_f32_len, rms_norm_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_l2_norm_f32, "l2_norm_f32", l2_norm_f32_len, l2_norm_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f32, "cpy_f32_f32", cpy_f32_f32_len, cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_f16, "cpy_f32_f16", cpy_f32_f16_len, cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f16, "cpy_f16_f16", cpy_f16_f16_len, cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f16_f32, "cpy_f16_f32", cpy_f16_f32_len, cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_bf16,"cpy_f32_bf16",cpy_f32_bf16_len,cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_i32_f32, "cpy_i32_f32", cpy_i32_f32_len, cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_i32, "cpy_f32_i32", cpy_f32_i32_len, cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f32, "contig_cpy_f32_f32", contig_cpy_f32_f32_len, contig_cpy_f32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_f16, "contig_cpy_f32_f16", contig_cpy_f32_f16_len, contig_cpy_f32_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f16, "contig_cpy_f16_f16", contig_cpy_f16_f16_len, contig_cpy_f16_f16_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + 
ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f16_f32, "contig_cpy_f16_f32", contig_cpy_f16_f32_len, contig_cpy_f16_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_bf16,"contig_cpy_f32_bf16",contig_cpy_f32_bf16_len,contig_cpy_f32_bf16_data,"main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_i32_f32, "contig_cpy_i32_f32", contig_cpy_i32_f32_len, contig_cpy_i32_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_contig_cpy_f32_i32, "contig_cpy_f32_i32", contig_cpy_f32_i32_len, contig_cpy_f32_i32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_rte_len, cpy_f32_q4_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_rte_len, cpy_f32_q4_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_rte_len, cpy_f32_q5_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_rte_len, cpy_f32_q5_1_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_rte_len, cpy_f32_q8_0_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_rte_len, cpy_f32_iq4_nl_rte_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_0], "cpy_f32_q4_0", cpy_f32_q4_0_len, cpy_f32_q4_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q4_1], "cpy_f32_q4_1", cpy_f32_q4_1_len, cpy_f32_q4_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_0], "cpy_f32_q5_0", cpy_f32_q5_0_len, cpy_f32_q5_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q5_1], "cpy_f32_q5_1", cpy_f32_q5_1_len, cpy_f32_q5_1_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_Q8_0], "cpy_f32_q8_0", cpy_f32_q8_0_len, cpy_f32_q8_0_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_f32_quant[GGML_TYPE_IQ4_NL], "cpy_f32_iq4_nl", cpy_f32_iq4_nl_len, cpy_f32_iq4_nl_data, "main", 2, sizeof(vk_op_unary_push_constants), {32, 1, 1}, {}, 1); + } + +#define SET_ROWS(itype, rte) \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F32], "set_rows_f32" #itype, set_rows_f32 ## itype ## 
rte ## _len, set_rows_f32 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_F16], "set_rows_f16" #itype, set_rows_f16 ## itype ## rte ## _len, set_rows_f16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_BF16], "set_rows_bf16" #itype, set_rows_bf16 ## itype ## rte ## _len, set_rows_bf16 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_0], "set_rows_q4_0" #itype, set_rows_q4_0 ## itype ## rte ## _len, set_rows_q4_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q4_1], "set_rows_q4_1" #itype, set_rows_q4_1 ## itype ## rte ## _len, set_rows_q4_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_0], "set_rows_q5_0" #itype, set_rows_q5_0 ## itype ## rte ## _len, set_rows_q5_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q5_1], "set_rows_q5_1" #itype, set_rows_q5_1 ## itype ## rte ## _len, set_rows_q5_1 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_Q8_0], "set_rows_q8_0" #itype, set_rows_q8_0 ## itype ## rte ## _len, set_rows_q8_0 ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_set_rows ## itype [GGML_TYPE_IQ4_NL], "set_rows_iq4_nl" #itype, set_rows_iq4_nl ## itype ## rte ## _len, set_rows_iq4_nl ## itype ## rte ## _data, "main", 3, sizeof(vk_op_binary_push_constants), {1, 1, 1}, {1}, 1, true); + + if (device->float_controls_rte_fp16) { + SET_ROWS(_i32, _rte) + SET_ROWS(_i64, _rte) + } else { + SET_ROWS(_i32, ) + SET_ROWS(_i64, ) + } +#undef SET_ROWS + + + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_0], "cpy_q4_0_f32", cpy_q4_0_f32_len, cpy_q4_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q4_1], "cpy_q4_1_f32", cpy_q4_1_f32_len, cpy_q4_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q4_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_0], "cpy_q5_0_f32", cpy_q5_0_f32_len, cpy_q5_0_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q5_1], "cpy_q5_1_f32", cpy_q5_1_f32_len, cpy_q5_1_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q5_1), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_Q8_0], "cpy_q8_0_f32", cpy_q8_0_f32_len, cpy_q8_0_f32_data, "main", 2, 
sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_Q8_0), 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cpy_quant_f32[GGML_TYPE_IQ4_NL], "cpy_iq4_nl_f32", cpy_iq4_nl_f32_len, cpy_iq4_nl_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {(uint32_t)ggml_blck_size(GGML_TYPE_IQ4_NL), 1, 1}, {}, 1); + + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + + bool rte = device->float_controls_rte_fp16; +#define CREATE_BINARY(name, namemod, spec, bindings) \ + for (int s0 : {0,1}) for (int s1 : {0,1}) for (int d : {0,1}) \ + ggml_vk_create_pipeline2(device, device->pipeline_ ## name ## namemod[s0][s1][d], \ + #name + get_suffix(s0, s1, d) + #namemod, name ## _len[s0][s1][d][rte], name ## _data[s0][s1][d][rte], \ + "main", (bindings), sizeof(vk_op_binary_push_constants), {512, 1, 1}, spec, 1); + + CREATE_BINARY(add, , {0}, 4) + CREATE_BINARY(add, _norepeat, {1}, 4) + CREATE_BINARY(sub, , {0}, 3) + CREATE_BINARY(sub, _norepeat, {1}, 3) + CREATE_BINARY(mul, , {0}, 3) + CREATE_BINARY(mul, _norepeat, {1}, 3) + CREATE_BINARY(div, , {0}, 3) + CREATE_BINARY(div, _norepeat, {1}, 3) + CREATE_BINARY(add_rms, , {0}, 4) + CREATE_BINARY(add_rms, _norepeat, {1}, 4) +#undef CREATE_BINARY + + if (device->multi_add) { + for (uint32_t i = 0; i < MAX_FUSED_ADDS; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_multi_add[i], "multi_add_f32_" + std::to_string(i+1), multi_add_f32_len, multi_add_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + ggml_vk_create_pipeline2(device, device->pipeline_multi_add_rms[i], "multi_add_rms_f32_" + std::to_string(i+1), multi_add_rms_f32_len, multi_add_rms_f32_data, "main", MAX_PARAMETER_COUNT, sizeof(vk_op_multi_add_push_constants), {512, 1, 1}, {i+2}, 1); + } + } + + ggml_vk_create_pipeline(device, device->pipeline_add_id_f32, "add_id_f32", add_id_f32_len, add_id_f32_data, "main", 4, sizeof(vk_op_add_id_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_acc_f32, "acc_f32", acc_f32_len, acc_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_concat_f32, "concat_f32", concat_f32_len, concat_f32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_f16, "concat_f16", concat_f16_len, concat_f16_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_concat_i32, "concat_i32", concat_i32_len, concat_i32_data, "main", 3, sizeof(vk_op_binary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_upscale_nearest_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_NEAREST}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR}, 1); + ggml_vk_create_pipeline(device, device->pipeline_upscale_bilinear_ac_f32, "upscale_f32", upscale_f32_len, upscale_f32_data, "main", 2, sizeof(vk_op_upscale_push_constants), {512, 1, 1}, {GGML_SCALE_MODE_BILINEAR | 
GGML_SCALE_FLAG_ALIGN_CORNERS}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_scale_f32, "scale_f32", scale_f32_len, scale_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sqr_f32, "sqr_f32", sqr_f32_len, sqr_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sqrt_f32, "sqrt_f32", sqrt_f32_len, sqrt_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_sin_f32, "sin_f32", sin_f32_len, sin_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_cos_f32, "cos_f32", cos_f32_len, cos_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_clamp_f32, "clamp_f32", clamp_f32_len, clamp_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pad_f32, "pad_f32", pad_f32_len, pad_f32_data, "main", 2, sizeof(vk_op_pad_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_roll_f32, "roll_f32", roll_f32_len, roll_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_repeat_f32, "repeat_f32", repeat_f32_len, repeat_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_repeat_back_f32, "repeat_back_f32", repeat_back_f32_len, repeat_back_f32_data, "main", 2, sizeof(vk_op_unary_push_constants), {512, 1, 1}, {}, 1); + +#define CREATE_UNARY(name) \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + CREATE_UNARY(gelu) + CREATE_UNARY(gelu_erf) + CREATE_UNARY(gelu_quick) + CREATE_UNARY(silu) + CREATE_UNARY(relu) + CREATE_UNARY(tanh) + CREATE_UNARY(sigmoid) + CREATE_UNARY(hardsigmoid) + CREATE_UNARY(hardswish) +#undef CREATE_UNARY + +#define CREATE_UNARY_RTE(name) \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); \ + } + CREATE_UNARY_RTE(exp) +#undef CREATE_UNARY_RTE + +#define CREATE_GLU(name) \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32_rte", name ## _f32_rte_len, name ## _f32_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 
1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16_rte", name ## _f16_rte_len, name ## _f16_rte_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [0], #name "_f32", name ## _f32_len, name ## _f32_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_ ## name [1], #name "_f16", name ## _f16_len, name ## _f16_data, "main", 3, sizeof(vk_op_glu_push_constants), {512, 1, 1}, {}, 1, true); \ + } + + CREATE_GLU(geglu) + CREATE_GLU(reglu) + CREATE_GLU(swiglu) + CREATE_GLU(swiglu_oai) + CREATE_GLU(geglu_erf) + CREATE_GLU(geglu_quick) +#undef CREATE_GLU + + ggml_vk_create_pipeline(device, device->pipeline_leaky_relu_f32, "leaky_relu_f32", leaky_relu_f32_len, leaky_relu_f32_data, "main", 2, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_silu_back_f32, "silu_back_f32", silu_back_f32_len, silu_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_diag_mask_inf_f32, "diag_mask_inf_f32", diag_mask_inf_f32_len, diag_mask_inf_f32_data, "main", 2, sizeof(vk_op_diag_mask_push_constants), {1, 512, 1}, {}, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32, "soft_max_f32", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_wg512, "soft_max_f32_wg512", soft_max_f32_len, soft_max_f32_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16, "soft_max_f32_f16", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_f32_f16_wg512, "soft_max_f32_f16_wg512", soft_max_f32_f16_len, soft_max_f32_f16_data, "main", 4, sizeof(vk_op_soft_max_push_constants), {1, 1, 1}, { 512 }, 1); + ggml_vk_create_pipeline(device, device->pipeline_soft_max_back_f32, "soft_max_back_f32", soft_max_back_f32_len, soft_max_back_f32_data, "main", 3, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1, true); + + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f32, "rope_norm_f32", rope_norm_f32_len, rope_norm_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f32, "rope_neox_f32", rope_neox_f32_len, rope_neox_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f32, "rope_multi_f32", rope_multi_f32_len, rope_multi_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f32, "rope_vision_f32", rope_vision_f32_len, rope_vision_f32_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + + if (device->float_controls_rte_fp16) { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_rte_len, rope_norm_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", 
rope_neox_f16_rte_len, rope_neox_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_rte_len, rope_multi_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_rte_len, rope_vision_f16_rte_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } else { + ggml_vk_create_pipeline(device, device->pipeline_rope_norm_f16, "rope_norm_f16", rope_norm_f16_len, rope_norm_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_neox_f16, "rope_neox_f16", rope_neox_f16_len, rope_neox_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_multi_f16, "rope_multi_f16", rope_multi_f16_len, rope_multi_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_rope_vision_f16, "rope_vision_f16", rope_vision_f16_len, rope_vision_f16_data, "main", 4, sizeof(vk_op_rope_push_constants), {1, 512, 1}, {}, 1); + } + + for (uint32_t i = 0; i < num_argsort_pipelines; ++i) { + ggml_vk_create_pipeline2(device, device->pipeline_argsort_f32[i], "argsort_f32_"+std::to_string(i), argsort_f32_len, argsort_f32_data, "main", 2, sizeof(vk_op_argsort_push_constants), {1u<<i, 1, 1}, {1u<<i, i}, 1, true); + } + + ggml_vk_create_pipeline(device, device->pipeline_argmax_f32, "argmax_f32", argmax_f32_len, argmax_f32_data, "main", 2, sizeof(vk_op_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_sum_rows_f32, "sum_rows_f32", sum_rows_f32_len, sum_rows_f32_data, "main", 2, sizeof(vk_op_sum_rows_push_constants), {1, 1, 1}, { device->subgroup_size }, 1); + + ggml_vk_create_pipeline(device, device->pipeline_count_equal_i32, "count_equal_i32", count_equal_i32_len, count_equal_i32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, { device->subgroup_size }, 1); + +#define IM2COL(bda) \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32, "im2col_f32", im2col_f32 ## bda ## _len, im2col_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32, "im2col_3d_f32", im2col_3d_f32 ## bda ## _len, im2col_3d_f32 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + if (device->float_controls_rte_fp16) { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16_rte ## bda ## _len, im2col_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16_rte ## bda ## _len, im2col_3d_f32_f16_rte ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } else { \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_f32_f16, "im2col_f32_f16", im2col_f32_f16 ## bda ## _len, im2col_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_push_constants), {512, 1, 1}, { device->subgroup_size }, 1, true); \ + ggml_vk_create_pipeline(device, device->pipeline_im2col_3d_f32_f16, "im2col_3d_f32_f16", im2col_3d_f32_f16 ## bda ## 
_len, im2col_3d_f32_f16 ## bda ## _data, "main", 2, sizeof(vk_op_im2col_3d_push_constants), {512, 1, 1}, { 512 }, 1, true); \ + } + if (device->shader_int64 && device->buffer_device_address) { + IM2COL(_bda) + } else { + IM2COL() + } + + ggml_vk_create_pipeline(device, device->pipeline_timestep_embedding_f32, "timestep_embedding_f32", timestep_embedding_f32_len, timestep_embedding_f32_data, "main", 2, sizeof(vk_op_timestep_embedding_push_constants), {256, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_conv_transpose_1d_f32, "conv_transpose_1d_f32", conv_transpose_1d_f32_len, conv_transpose_1d_f32_data, "main", 3, sizeof(vk_op_conv_transpose_1d_push_constants), {1, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_pool2d_f32, "pool2d_f32", pool2d_f32_len, pool2d_f32_data, "main", 2, sizeof(vk_op_pool2d_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv6_f32, "rwkv_wkv6_f32", rwkv_wkv6_f32_len, rwkv_wkv6_f32_data, "main", 7, sizeof(vk_op_rwkv_wkv6_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_rwkv_wkv7_f32, "rwkv_wkv7_f32", rwkv_wkv7_f32_len, rwkv_wkv7_f32_data, "main", 8, sizeof(vk_op_rwkv_wkv7_push_constants), {1, 1, 1}, {device->subgroup_size}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_opt_step_adamw_f32, "opt_step_adamw_f32", opt_step_adamw_f32_len, opt_step_adamw_f32_data, "main", 5, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + ggml_vk_create_pipeline(device, device->pipeline_opt_step_sgd_f32, "opt_step_sgd_f32", opt_step_sgd_f32_len, opt_step_sgd_f32_data, "main", 3, sizeof(vk_op_push_constants), {512, 1, 1}, {}, 1); + + // conv2d, conv_transpose_2d + for (uint32_t s = 0; s < CONV_SHAPE_COUNT; ++s) { + uint32_t conv2d_WG_SIZE = 256; + uint32_t conv2d_BS_K = 128; + uint32_t conv2d_BS_CRS = 16; + uint32_t use_collectives = 0; // Enables subgroup ops for preventing the re-calculation of indices. + uint32_t conv2d_BS_NPQ = 128; + uint32_t conv2d_TS_K = 8; + uint32_t conv2d_SHMEM_PAD = 4; + bool conv2d_UNROLL = true; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + conv2d_SHMEM_PAD = 8; // 8 float16_t + } +#endif + + if (device->vendor_id == VK_VENDOR_ID_INTEL) { + conv2d_SHMEM_PAD = 0; + conv2d_UNROLL = false; + } else if (device->vendor_id == VK_VENDOR_ID_AMD) { + conv2d_SHMEM_PAD = device->architecture == vk_device_architecture::AMD_GCN ? 1 : 4; + } + + switch (s) { + default: + case CONV_SHAPE_128x128: + conv2d_BS_K = 128; + conv2d_BS_NPQ = 128; + conv2d_BS_CRS = 16; + if (device->vendor_id == VK_VENDOR_ID_AMD && device->architecture != vk_device_architecture::AMD_GCN) { + conv2d_UNROLL = false; + } + break; + case CONV_SHAPE_64x32: + conv2d_BS_K = 64; + conv2d_BS_NPQ = 32; + conv2d_BS_CRS = 32; + conv2d_TS_K = 4; + break; + case CONV_SHAPE_32x256: + conv2d_BS_K = 32; + conv2d_BS_NPQ = 256; + conv2d_BS_CRS = 16; + break; + } + + // Use collectives on pre-Turing NVIDIA GPUs and GCN AMD cards, which had slower integer math. + bool allow_collectives_nv = device->vendor_id != VK_VENDOR_ID_NVIDIA || + device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + bool allow_collectives_amd = device->vendor_id != VK_VENDOR_ID_AMD || + device->architecture == vk_device_architecture::AMD_GCN; + + if (device->subgroup_shuffle && + device->vendor_id != VK_VENDOR_ID_INTEL && // Do not enable collectives on Intel, see PR 14316. 
+ allow_collectives_nv && + allow_collectives_amd) { + use_collectives = 1; + conv2d_BS_CRS = std::min( + device->subgroup_size, + conv2d_BS_CRS); // CRS block size should be capped at subgroup size for correctness when shuffle is used. + } + + uint32_t conv2d_shmem_req = + (conv2d_BS_K * (conv2d_BS_CRS + conv2d_SHMEM_PAD) + conv2d_BS_CRS * (conv2d_BS_NPQ + conv2d_SHMEM_PAD)) * sizeof(float); + if (device->properties.limits.maxComputeSharedMemorySize < conv2d_shmem_req) { + conv2d_BS_CRS = 8; + if (use_collectives) { + conv2d_BS_CRS = std::min(device->subgroup_size, conv2d_BS_CRS); + } + } + + std::array wg_denoms = { conv2d_BS_K, conv2d_BS_NPQ, 1 }; + std::vector spec_constants = { conv2d_WG_SIZE, conv2d_BS_K, conv2d_BS_CRS, conv2d_BS_NPQ, conv2d_TS_K, use_collectives, conv2d_SHMEM_PAD }; + +#define CREATE_CONV(name, type_suffix, spv_suffix) \ + ggml_vk_create_pipeline( \ + device, device->pipeline_##name##type_suffix[s], #name #type_suffix, \ + name##type_suffix##spv_suffix##_len, name##type_suffix##spv_suffix##_data, "main", 3, \ + sizeof(vk_op_##name##_push_constants), wg_denoms, spec_constants, 1, true, use_collectives); +#define CREATE_CONVS(spv_suffix) \ + CREATE_CONV(conv2d, _f32, spv_suffix) \ + CREATE_CONV(conv2d, _f16_f32, spv_suffix) \ + if (device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_conv_transpose_2d_push_constants)) { \ + CREATE_CONV(conv_transpose_2d, _f32, spv_suffix) \ + CREATE_CONV(conv_transpose_2d, _f16_f32, spv_suffix) \ + } +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (device->coopmat2) { + CREATE_CONVS(_cm2) + } else +#endif + if (conv2d_UNROLL) { + CREATE_CONVS(_unroll) + } else { + CREATE_CONVS( ) + } +#undef CREATE_CONV +#undef CREATE_CONVS + } + + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f32, "conv2d_dw_whcn_f32", conv2d_dw_whcn_f32_len, conv2d_dw_whcn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f32, "conv2d_dw_cwhn_f32", conv2d_dw_cwhn_f32_len, conv2d_dw_cwhn_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_whcn_f16_f32, "conv2d_dw_whcn_f16_f32", conv2d_dw_whcn_f16_f32_len, conv2d_dw_whcn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + ggml_vk_create_pipeline(device, device->pipeline_conv2d_dw_cwhn_f16_f32, "conv2d_dw_cwhn_f16_f32", conv2d_dw_cwhn_f16_f32_len, conv2d_dw_cwhn_f16_f32_data, "main", 3, sizeof(vk_op_conv2d_dw_push_constants), {512, 1, 1}, {}, 1); + + for (auto &c : compiles) { + c.wait(); + } + device->need_compiles = false; +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch); + +static vk_device ggml_vk_get_device(size_t idx) { + VK_LOG_DEBUG("ggml_vk_get_device(" << idx << ")"); + + if (vk_instance.devices[idx] == nullptr) { + VK_LOG_DEBUG("Initializing new vk_device"); + vk_device device = std::make_shared(); + vk_instance.devices[idx] = device; + +#ifdef GGML_VULKAN_MEMORY_DEBUG + device->memory_logger = std::unique_ptr(new vk_memory_logger()); +#endif + if (vk_perf_logger_enabled) { + device->perf_logger = std::unique_ptr(new vk_perf_logger()); + } + + size_t dev_num = vk_instance.device_indices[idx]; + + std::vector physical_devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= physical_devices.size()) { 
+ std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." << std::endl; + throw std::runtime_error("Device not found"); + } + + device->physical_device = physical_devices[dev_num]; + const std::vector ext_props = device->physical_device.enumerateDeviceExtensionProperties(); + + device->architecture = get_device_architecture(device->physical_device); + + const char* GGML_VK_PREFER_HOST_MEMORY = getenv("GGML_VK_PREFER_HOST_MEMORY"); + device->prefer_host_memory = GGML_VK_PREFER_HOST_MEMORY != nullptr; + + const char* GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM = getenv("GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM"); + device->disable_host_visible_vidmem = GGML_VK_DISABLE_HOST_VISIBLE_VIDMEM != nullptr; + + const char* GGML_VK_ALLOW_SYSMEM_FALLBACK = getenv("GGML_VK_ALLOW_SYSMEM_FALLBACK"); + device->allow_sysmem_fallback = GGML_VK_ALLOW_SYSMEM_FALLBACK != nullptr; + + const char* GGML_VK_DISABLE_GRAPH_OPTIMIZE = getenv("GGML_VK_DISABLE_GRAPH_OPTIMIZE"); + device->disable_graph_optimize = GGML_VK_DISABLE_GRAPH_OPTIMIZE != nullptr; + + bool fp16_storage = false; + bool fp16_compute = false; + bool maintenance4_support = false; + bool sm_builtins = false; + bool amd_shader_core_properties2 = false; + bool pipeline_robustness = false; + bool coopmat2_support = false; + bool pipeline_executable_properties_support = false; + device->coopmat_support = false; + device->integer_dot_product = false; + bool bfloat16_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_KHR_maintenance4", properties.extensionName) == 0) { + maintenance4_support = true; + } else if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; + } else if (strcmp("VK_NV_shader_sm_builtins", properties.extensionName) == 0) { + sm_builtins = true; + } else if (strcmp("VK_AMD_shader_core_properties2", properties.extensionName) == 0) { + amd_shader_core_properties2 = true; + } else if (strcmp("VK_EXT_pipeline_robustness", properties.extensionName) == 0) { + pipeline_robustness = true; + } else if (strcmp("VK_EXT_subgroup_size_control", properties.extensionName) == 0) { + device->subgroup_size_control = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + device->coopmat_support = true; + device->coopmat_m = 0; + device->coopmat_n = 0; + device->coopmat_k = 0; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + device->integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; +#endif + } else if (strcmp("VK_KHR_pipeline_executable_properties", properties.extensionName) == 0) { + pipeline_executable_properties_support = true; + } + } + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceMaintenance4Properties props4; + vk::PhysicalDeviceSubgroupProperties 
subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderSMBuiltinsPropertiesNV sm_props; + vk::PhysicalDeviceShaderCoreProperties2AMD amd_shader_core_properties2_props; + vk::PhysicalDeviceVulkan11Properties vk11_props; + vk::PhysicalDeviceVulkan12Properties vk12_props; + vk::PhysicalDeviceSubgroupSizeControlPropertiesEXT subgroup_size_control_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + driver_props.pNext = &vk11_props; + vk11_props.pNext = &vk12_props; + + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&vk12_props; + + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&props4; + last_struct = (VkBaseOutStructure *)&props4; + } + if (sm_builtins) { + last_struct->pNext = (VkBaseOutStructure *)&sm_props; + last_struct = (VkBaseOutStructure *)&sm_props; + } + if (amd_shader_core_properties2) { + last_struct->pNext = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + last_struct = (VkBaseOutStructure *)&amd_shader_core_properties2_props; + } + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_props; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_props; + } + +#if defined(VK_NV_cooperative_matrix2) + vk::PhysicalDeviceCooperativeMatrix2PropertiesNV coopmat2_props; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_props; + last_struct = (VkBaseOutStructure *)&coopmat2_props; + } +#endif + + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + device->physical_device.getProperties2(&props2); + device->properties = props2.properties; + device->vendor_id = device->properties.vendorID; + device->driver_id = driver_props.driverID; + + const char* GGML_VK_FORCE_MAX_ALLOCATION_SIZE = getenv("GGML_VK_FORCE_MAX_ALLOCATION_SIZE"); + + if (GGML_VK_FORCE_MAX_ALLOCATION_SIZE != nullptr) { + device->max_memory_allocation_size = std::stoull(GGML_VK_FORCE_MAX_ALLOCATION_SIZE); + } else if (maintenance4_support) { + device->max_memory_allocation_size = std::min(props3.maxMemoryAllocationSize, props4.maxBufferSize); + } else { + device->max_memory_allocation_size = props3.maxMemoryAllocationSize; + } + + const char* GGML_VK_FORCE_MAX_BUFFER_SIZE = getenv("GGML_VK_FORCE_MAX_BUFFER_SIZE"); + + if (GGML_VK_FORCE_MAX_BUFFER_SIZE != nullptr) { + device->max_buffer_size = std::stoull(GGML_VK_FORCE_MAX_BUFFER_SIZE); + } else if (maintenance4_support) { + device->max_buffer_size = props4.maxBufferSize; + } else { + device->max_buffer_size = device->max_memory_allocation_size; + } + + const char* GGML_VK_SUBALLOCATION_BLOCK_SIZE = getenv("GGML_VK_SUBALLOCATION_BLOCK_SIZE"); + + if (GGML_VK_SUBALLOCATION_BLOCK_SIZE != nullptr) { + device->suballocation_block_size = std::stoull(GGML_VK_SUBALLOCATION_BLOCK_SIZE); + } else { + // Limit batching of allocations to 1GB by default to avoid fragmentation issues + device->suballocation_block_size = 1024*1024*1024; + } + device->suballocation_block_size = std::min(device->suballocation_block_size, device->max_memory_allocation_size); + + device->subgroup_size = subgroup_props.subgroupSize; + device->uma = device->properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + if (sm_builtins) { + 
device->shader_core_count = sm_props.shaderSMCount; + } else if (amd_shader_core_properties2) { + device->shader_core_count = amd_shader_core_properties2_props.activeComputeUnitCount; + } else { + device->shader_core_count = 0; + } + device->float_controls_rte_fp16 = vk12_props.shaderRoundingModeRTEFloat16; + + device->subgroup_arithmetic = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eArithmetic); +#ifdef __APPLE__ + // Workaround for subgroup arithmetic failing on MoltenVK with AMD GPUs (issue 15846) + if (device->vendor_id == VK_VENDOR_ID_AMD) { + device->subgroup_arithmetic = false; + } +#endif + device->subgroup_shuffle = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eShuffle); + device->subgroup_clustered = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eClustered); + + device->subgroup_ballot = (vk11_props.subgroupSupportedStages & vk::ShaderStageFlagBits::eCompute) && + (vk11_props.subgroupSupportedOperations & vk::SubgroupFeatureFlagBits::eBallot); + + const bool force_disable_f16 = getenv("GGML_VK_DISABLE_F16") != nullptr; + + device->fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + if (!ggml_vk_khr_cooperative_matrix_support(device->properties, driver_props, device->architecture)) { + device->coopmat_support = false; + } + + device->integer_dot_product = device->integer_dot_product && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated; + + std::vector queue_family_props = device->physical_device.getQueueFamilyProperties(); + + // Try to find a non-graphics compute queue and transfer-focused queues + const uint32_t compute_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eCompute, vk::QueueFlagBits::eGraphics, -1, 1); + const uint32_t transfer_queue_family_index = ggml_vk_find_queue_family_index(queue_family_props, vk::QueueFlagBits::eTransfer, vk::QueueFlagBits::eCompute | vk::QueueFlagBits::eGraphics, compute_queue_family_index, 1); + + const float priorities[] = { 1.0f, 1.0f }; + device->single_queue = compute_queue_family_index == transfer_queue_family_index && queue_family_props[compute_queue_family_index].queueCount == 1; + + std::vector device_queue_create_infos; + if (compute_queue_family_index != transfer_queue_family_index) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), transfer_queue_family_index, 1, priorities + 1}); + } else if(!device->single_queue) { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 2, priorities}); + } else { + device_queue_create_infos.push_back({vk::DeviceQueueCreateFlags(), compute_queue_family_index, 1, priorities}); + } + vk::DeviceCreateInfo device_create_info; + std::vector device_extensions; + vk::PhysicalDeviceFeatures device_features = device->physical_device.getFeatures(); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + device_features2.features = (VkPhysicalDeviceFeatures)device_features; + + VkPhysicalDeviceVulkan11Features vk11_features; + 
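For readers skimming the diff: the repeated `last_struct->pNext = ...; last_struct = ...` statements around this point implement the standard Vulkan pNext-chain idiom. Optional feature/property structs are appended to a singly linked list, and a single `vkGetPhysicalDeviceFeatures2` (or `getProperties2`) call then fills every struct in the chain. A minimal stand-alone sketch of that idiom follows; it is not part of the patch, and the `chain_append` helper name is made up for illustration.

```
// Sketch only: append optional extension structs to a pNext chain, then query once.
#include <vulkan/vulkan.h>

static void chain_append(VkBaseOutStructure ** last, void * next) {
    (*last)->pNext = (VkBaseOutStructure *) next;   // link the new struct in
    *last          = (VkBaseOutStructure *) next;   // it becomes the new chain tail
}

static void query_features(VkPhysicalDevice phys_dev, bool has_subgroup_size_control) {
    VkPhysicalDeviceFeatures2 features2 = {};
    features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2;

    VkPhysicalDeviceVulkan12Features vk12 = {};
    vk12.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES;
    features2.pNext = &vk12;

    VkBaseOutStructure * last = (VkBaseOutStructure *) &vk12;

    // Only chain structs for extensions the device actually advertises.
    VkPhysicalDeviceSubgroupSizeControlFeaturesEXT ssc = {};
    ssc.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT;
    if (has_subgroup_size_control) {
        chain_append(&last, &ssc);
    }

    vkGetPhysicalDeviceFeatures2(phys_dev, &features2); // fills every struct in the chain
}
```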
vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + last_struct = (VkBaseOutStructure *)&vk12_features; + + VkPhysicalDevicePipelineRobustnessFeaturesEXT pl_robustness_features; + pl_robustness_features.pNext = nullptr; + pl_robustness_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_ROBUSTNESS_FEATURES_EXT; + pl_robustness_features.pipelineRobustness = VK_FALSE; + + if (pipeline_robustness) { + last_struct->pNext = (VkBaseOutStructure *)&pl_robustness_features; + last_struct = (VkBaseOutStructure *)&pl_robustness_features; + device_extensions.push_back("VK_EXT_pipeline_robustness"); + } + + VkPhysicalDeviceSubgroupSizeControlFeaturesEXT subgroup_size_control_features; + subgroup_size_control_features.pNext = nullptr; + subgroup_size_control_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SUBGROUP_SIZE_CONTROL_FEATURES_EXT; + subgroup_size_control_features.computeFullSubgroups = false; + subgroup_size_control_features.subgroupSizeControl = false; + + if (device->subgroup_size_control) { + last_struct->pNext = (VkBaseOutStructure *)&subgroup_size_control_features; + last_struct = (VkBaseOutStructure *)&subgroup_size_control_features; + } + +#if defined(VK_KHR_cooperative_matrix) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (device->coopmat_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + +#if defined(VK_NV_cooperative_matrix2) + VkPhysicalDeviceCooperativeMatrix2FeaturesNV coopmat2_features {}; + coopmat2_features.pNext = nullptr; + coopmat2_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_2_FEATURES_NV; + if (coopmat2_support) { + last_struct->pNext = (VkBaseOutStructure *)&coopmat2_features; + last_struct = (VkBaseOutStructure *)&coopmat2_features; + device_extensions.push_back("VK_NV_cooperative_matrix2"); + } +#endif + +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.pNext = nullptr; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif + + VkPhysicalDeviceMaintenance4Features maint4_features {}; + maint4_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_MAINTENANCE_4_FEATURES; + if (maintenance4_support) { + last_struct->pNext = (VkBaseOutStructure *)&maint4_features; + last_struct = (VkBaseOutStructure *)&maint4_features; + device_extensions.push_back("VK_KHR_maintenance4"); + } + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (device->integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = 
(VkBaseOutStructure *)&shader_integer_dot_product_features; + device_extensions.push_back("VK_KHR_shader_integer_dot_product"); + } + + VkPhysicalDevicePipelineExecutablePropertiesFeaturesKHR pep_features {}; + pep_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_PIPELINE_EXECUTABLE_PROPERTIES_FEATURES_KHR; + if (pipeline_executable_properties_support) { + last_struct->pNext = (VkBaseOutStructure *)&pep_features; + last_struct = (VkBaseOutStructure *)&pep_features; + device_extensions.push_back("VK_KHR_pipeline_executable_properties"); + } + + vkGetPhysicalDeviceFeatures2(device->physical_device, &device_features2); + + device->pipeline_executable_properties_support = pipeline_executable_properties_support; + + device->fp16 = device->fp16 && vk12_features.shaderFloat16; + +#if defined(VK_KHR_shader_bfloat16) + device->bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + device->bf16 = false; +#endif + + device->pipeline_robustness = pl_robustness_features.pipelineRobustness; + + device->multi_add = vk12_props.shaderRoundingModeRTEFloat16 && + device->properties.limits.maxPushConstantsSize >= sizeof(vk_op_multi_add_push_constants) && + vk12_features.runtimeDescriptorArray && + device->vendor_id != VK_VENDOR_ID_INTEL && + getenv("GGML_VK_DISABLE_MULTI_ADD") == nullptr; + + device->shader_int64 = device_features2.features.shaderInt64; + device->buffer_device_address = vk12_features.bufferDeviceAddress; + + if (device->subgroup_size_control) { + device->subgroup_min_size = subgroup_size_control_props.minSubgroupSize; + device->subgroup_max_size = subgroup_size_control_props.maxSubgroupSize; + device_extensions.push_back("VK_EXT_subgroup_size_control"); + } + + device->subgroup_size_control = device->subgroup_size_control && + (subgroup_size_control_props.requiredSubgroupSizeStages & vk::ShaderStageFlagBits::eCompute) && + subgroup_size_control_features.subgroupSizeControl; + + device->subgroup_require_full_support = subgroup_size_control_features.computeFullSubgroups; + +#if defined(VK_KHR_cooperative_matrix) + device->coopmat_support = device->coopmat_support && coopmat_features.cooperativeMatrix; + + // coopmat1 fa shader currently assumes 32 invocations per subgroup + device->coopmat1_fa_support = device->coopmat_support && device->subgroup_require_full_support && + device->subgroup_size_control && device->subgroup_min_size <= 32 && + device->subgroup_max_size >= 32; +#endif + + if (coopmat2_support) { +#if defined(VK_NV_cooperative_matrix2) && defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (coopmat2_features.cooperativeMatrixWorkgroupScope && + coopmat2_features.cooperativeMatrixFlexibleDimensions && + coopmat2_features.cooperativeMatrixReductions && + coopmat2_features.cooperativeMatrixConversions && + coopmat2_features.cooperativeMatrixPerElementOperations && + coopmat2_features.cooperativeMatrixTensorAddressing && + coopmat2_features.cooperativeMatrixBlockLoads && + vk12_features.bufferDeviceAddress) { + + std::vector flexible_dimensions; + uint32_t count = 0; + + PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV = + (PFN_vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV) + vk_instance.instance.getProcAddr("vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV"); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, nullptr); + + VkCooperativeMatrixFlexibleDimensionsPropertiesNV empty_prop {}; 
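The flexible-dimensions query here, and the `vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR` query further down, both use Vulkan's two-call enumeration idiom: call once with a null pointer to get the count, size the array and set each element's `sType`, then call again to fill it. A compact sketch of that idiom, illustrative only and not part of the patch, using the KHR cooperative-matrix query because its signature is the simpler of the two:

```
// Sketch only: Vulkan's two-call enumeration idiom, as used in the queries above.
#include <vulkan/vulkan.h>
#include <vector>

static std::vector<VkCooperativeMatrixPropertiesKHR>
query_coopmat_shapes(VkInstance instance, VkPhysicalDevice phys_dev) {
    auto fn = (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)
        vkGetInstanceProcAddr(instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR");
    if (fn == nullptr) {
        return {};  // extension not present; the diff only reaches this path when it is
    }

    uint32_t count = 0;
    fn(phys_dev, &count, nullptr);                    // first call: how many entries?

    std::vector<VkCooperativeMatrixPropertiesKHR> props(count);
    for (auto & p : props) {
        p       = {};                                 // zero pNext and all fields
        p.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR;
    }
    fn(phys_dev, &count, props.data());               // second call: fill the array
    return props;
}
```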
+ empty_prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_FLEXIBLE_DIMENSIONS_PROPERTIES_NV; + flexible_dimensions.resize(count, empty_prop); + + _vkGetPhysicalDeviceCooperativeMatrixFlexibleDimensionsPropertiesNV(device->physical_device, &count, flexible_dimensions.data()); + + bool found_fp16_128 = false, + found_fp16_256 = false, + found_fp32_128 = false, + found_fp32_256 = false; + // need to support fp16*fp16 with fp16/fp32 accumulator, for workgroupsize 128 + // with 32x16x16 and 256 with 32x32x16. + for (auto &prop : flexible_dimensions) { + if (prop.saturatingAccumulation == VK_FALSE && + prop.scope == VK_SCOPE_WORKGROUP_KHR && + prop.AType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + + if (prop.workgroupInvocations == 128 && + prop.MGranularity <= 32 && + prop.NGranularity <= 16 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_128 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_128 = true; + } + } + if (prop.workgroupInvocations == 256 && + prop.MGranularity <= 32 && + prop.NGranularity <= 32 && + prop.KGranularity <= 16) { + if (prop.CType == VK_COMPONENT_TYPE_FLOAT16_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT16_KHR) { + found_fp16_256 = true; + } + if (prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR) { + found_fp32_256 = true; + } + } + } + } + if (found_fp16_128 && found_fp16_256 && + found_fp32_128 && found_fp32_256 && + coopmat2_props.cooperativeMatrixFlexibleDimensionsMaxDimension >= 512) { + device->coopmat2 = true; + } + } +#endif + } + + if (!vk11_features.storageBuffer16BitAccess) { + std::cerr << "ggml_vulkan: device " << GGML_VK_NAME << idx << " does not support 16-bit storage." 
<< std::endl; + throw std::runtime_error("Unsupported device"); + } + + device_extensions.push_back("VK_KHR_16bit_storage"); + +#ifdef GGML_VULKAN_VALIDATE + device_extensions.push_back("VK_KHR_shader_non_semantic_info"); +#endif + + if (device->fp16) { + device_extensions.push_back("VK_KHR_shader_float16_int8"); + } + +#if defined(VK_KHR_cooperative_matrix) + if (device->coopmat_support) { + // Query supported shapes + std::vector cm_props; + + PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR = + (PFN_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR)vkGetInstanceProcAddr(vk_instance.instance, "vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR"); + + uint32_t cm_props_num; + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, nullptr); + + cm_props.resize(cm_props_num); + + for (auto& prop : cm_props) { + prop.sType = VK_STRUCTURE_TYPE_COOPERATIVE_MATRIX_PROPERTIES_KHR; + } + + pfn_vkGetPhysicalDeviceCooperativeMatrixPropertiesKHR(device->physical_device, &cm_props_num, cm_props.data()); + + VK_LOG_DEBUG("ggml_vulkan: Cooperative Matrix Shapes: " << cm_props.size()); + + for (auto& prop : cm_props) { + VK_LOG_DEBUG("ggml_vulkan: M: " << prop.MSize << " N: " << prop.NSize << " K: " << prop.KSize << " A: " << vk::to_string((vk::ComponentTypeKHR)prop.AType) << " B: " << vk::to_string((vk::ComponentTypeKHR)prop.BType) << " C: " << vk::to_string((vk::ComponentTypeKHR)prop.CType) << " Result: " << vk::to_string((vk::ComponentTypeKHR)prop.ResultType) << " saturatingAccumulation: " << prop.saturatingAccumulation << " scope: " << vk::to_string((vk::ScopeKHR)prop.scope)); + + if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eFloat16 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat32 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat32) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f32_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f32_support = true; + } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f32acc = true; + } + } else if ((vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eFloat16 && + (vk::ComponentTypeKHR)prop.ResultType == vk::ComponentTypeKHR::eFloat16) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_acc_f16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_acc_f16_support = true; + } + if (prop.MSize == 16 && prop.NSize == 16 && prop.KSize == 16) { + device->coopmat_support_16x16x16_f16acc = true; + } + } + } else if ((vk::ComponentTypeKHR)prop.AType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.BType == vk::ComponentTypeKHR::eSint8 && + (vk::ComponentTypeKHR)prop.CType == vk::ComponentTypeKHR::eSint32 && + (vk::ComponentTypeKHR)prop.ResultType == 
vk::ComponentTypeKHR::eSint32 && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup && + device->coopmat_int_m == 0 + ) { + device->coopmat_int_support = true; + device->coopmat_int_m = prop.MSize; + device->coopmat_int_n = prop.NSize; + device->coopmat_int_k = prop.KSize; + } +#if defined(VK_KHR_shader_bfloat16) && defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (prop.AType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.BType == VK_COMPONENT_TYPE_BFLOAT16_KHR && + prop.CType == VK_COMPONENT_TYPE_FLOAT32_KHR && + prop.ResultType == VK_COMPONENT_TYPE_FLOAT32_KHR && + (vk::ScopeKHR)prop.scope == vk::ScopeKHR::eSubgroup + ) { + // coopmat sizes not set yet + if (device->coopmat_m == 0) { + device->coopmat_bf16_support = true; + device->coopmat_m = prop.MSize; + device->coopmat_n = prop.NSize; + device->coopmat_k = prop.KSize; + } else if (device->coopmat_m == prop.MSize && device->coopmat_n == prop.NSize && device->coopmat_k == prop.KSize) { + // Only enable if shape is identical + device->coopmat_bf16_support = true; + } + } +#endif + } + + if (device->coopmat_m == 0 || !device->coopmat_acc_f32_support) { + // No suitable matmul mode found + GGML_LOG_DEBUG("ggml_vulkan: WARNING: No suitable matrix core mode found. Disabling matrix cores.\n"); + device->coopmat_support = false; + } + if (getenv("GGML_VK_DISABLE_BFLOAT16")) { + device->coopmat_bf16_support = false; + } + } + + if (device->coopmat_support) { + device_extensions.push_back("VK_KHR_cooperative_matrix"); + } +#if defined(VK_KHR_shader_bfloat16) + if (device->coopmat_bf16_support) { + device_extensions.push_back("VK_KHR_shader_bfloat16"); + } +#endif +#endif + device->name = GGML_VK_NAME + std::to_string(idx); + + device_create_info = { + vk::DeviceCreateFlags(), + device_queue_create_infos, + {}, + device_extensions + }; + device_create_info.setPNext(&device_features2); + device->device = device->physical_device.createDevice(device_create_info); + + // Queues + ggml_vk_create_queue(device, device->compute_queue, compute_queue_family_index, 0, { vk::PipelineStageFlagBits::eComputeShader | vk::PipelineStageFlagBits::eTransfer }, false); + + // Shaders + // Disable matmul tile sizes early if performance low or not supported + for (uint32_t i = 0; i < GGML_TYPE_COUNT; ++i) { + switch (device->vendor_id) { +#ifndef GGML_VULKAN_RUN_TESTS + case VK_VENDOR_ID_AMD: + case VK_VENDOR_ID_INTEL: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + case VK_VENDOR_ID_APPLE: + device->mul_mat_l[i] = false; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = false; + device->mul_mat_id_l[i] = false; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = false; + break; +#endif + default: + device->mul_mat_l[i] = true; + device->mul_mat_m[i] = true; + device->mul_mat_s[i] = true; + device->mul_mat_id_l[i] = true; + device->mul_mat_id_m[i] = true; + device->mul_mat_id_s[i] = true; + break; + } + } + + + std::vector dsl_binding; + std::vector dsl_binding_flags; + for (uint32_t i = 0; i < MAX_PARAMETER_COUNT; i++) { + dsl_binding.push_back({i, vk::DescriptorType::eStorageBuffer, 1, vk::ShaderStageFlagBits::eCompute}); + dsl_binding_flags.push_back({}); + } + + vk::DescriptorSetLayoutBindingFlagsCreateInfo dslbfci = { dsl_binding_flags }; + + vk::DescriptorSetLayoutCreateInfo descriptor_set_layout_create_info( + {}, + dsl_binding); + descriptor_set_layout_create_info.setPNext(&dslbfci); + 
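The descriptor-set-layout setup just above declares one storage-buffer binding per shader parameter slot (bindings 0..MAX_PARAMETER_COUNT-1), all visible to the compute stage, and wraps them in a single layout. A roughly equivalent C-API sketch, omitting the binding-flags chain for brevity; `kMaxParams` is a made-up stand-in for MAX_PARAMETER_COUNT:

```
// Sketch only: a compute-only descriptor set layout with N storage-buffer bindings.
#include <vulkan/vulkan.h>
#include <vector>

static VkDescriptorSetLayout make_storage_buffer_layout(VkDevice device, uint32_t kMaxParams) {
    std::vector<VkDescriptorSetLayoutBinding> bindings(kMaxParams);
    for (uint32_t i = 0; i < kMaxParams; ++i) {
        bindings[i].binding            = i;  // matches layout(binding = i) in the shaders
        bindings[i].descriptorType     = VK_DESCRIPTOR_TYPE_STORAGE_BUFFER;
        bindings[i].descriptorCount    = 1;
        bindings[i].stageFlags         = VK_SHADER_STAGE_COMPUTE_BIT;
        bindings[i].pImmutableSamplers = nullptr;
    }

    VkDescriptorSetLayoutCreateInfo info = {};
    info.sType        = VK_STRUCTURE_TYPE_DESCRIPTOR_SET_LAYOUT_CREATE_INFO;
    info.bindingCount = kMaxParams;
    info.pBindings    = bindings.data();

    VkDescriptorSetLayout layout = VK_NULL_HANDLE;
    vkCreateDescriptorSetLayout(device, &info, nullptr, &layout);
    return layout;
}
```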
device->dsl = device->device.createDescriptorSetLayout(descriptor_set_layout_create_info); + + ggml_vk_load_shaders(device); + + if (!device->single_queue) { + const uint32_t transfer_queue_index = compute_queue_family_index == transfer_queue_family_index ? 1 : 0; + ggml_vk_create_queue(device, device->transfer_queue, transfer_queue_family_index, transfer_queue_index, { vk::PipelineStageFlagBits::eTransfer }, true); + } else { + // TODO: Use pointer or reference to avoid copy + device->transfer_queue.copyFrom(device->compute_queue); + device->transfer_queue.cmd_pool.init(device, &device->transfer_queue); + } + + device->buffer_type = { + /* .iface = */ ggml_backend_vk_buffer_type_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), idx), + /* .context = */ new ggml_backend_vk_buffer_type_context{ device->name, device }, + }; + + device->fence = device->device.createFence({}); + + device->idx = idx; + + device->disable_fusion = getenv("GGML_VK_DISABLE_FUSION") != nullptr; + + device->add_rms_fusion = !device->disable_fusion && + device->subgroup_arithmetic && + device->vendor_id != VK_VENDOR_ID_INTEL; + device->partials_binding_alignment = + std::max(4u, (uint32_t)device->properties.limits.minStorageBufferOffsetAlignment); + + device->mmvq_mode = 0; + if (getenv("GGML_VK_DISABLE_MMVQ")) { + device->mmvq_mode = -1; + } else if (getenv("GGML_VK_FORCE_MMVQ")) { + device->mmvq_mode = 1; + } + + return device; + } + + return vk_instance.devices[idx]; +} + +static void ggml_vk_print_gpu_info(size_t idx) { + GGML_ASSERT(idx < vk_instance.device_indices.size()); + size_t dev_num = vk_instance.device_indices[idx]; + VK_LOG_DEBUG("ggml_vk_print_gpu_info(" << dev_num << ")"); + GGML_ASSERT(vk_instance_initialized); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + if (dev_num >= devices.size()) { + std::cerr << "ggml_vulkan: Device with index " << dev_num << " does not exist." 
<< std::endl; + throw std::runtime_error("Device not found"); + } + + vk::PhysicalDevice physical_device = devices[dev_num]; + std::vector ext_props = physical_device.enumerateDeviceExtensionProperties(); + + bool fp16_storage = false; + bool fp16_compute = false; + bool coopmat_support = false; + bool coopmat2_support = false; + bool integer_dot_product = false; + bool bfloat16_support = false; + + for (auto properties : ext_props) { + if (strcmp("VK_KHR_16bit_storage", properties.extensionName) == 0) { + fp16_storage = true; + } else if (strcmp("VK_KHR_shader_float16_int8", properties.extensionName) == 0) { + fp16_compute = true; +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_cooperative_matrix", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT")) { + coopmat_support = true; +#endif +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + } else if (strcmp("VK_NV_cooperative_matrix2", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_COOPMAT2")) { + coopmat2_support = true; +#endif +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_integer_dot_product", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_INTEGER_DOT_PRODUCT")) { + integer_dot_product = true; +#endif +#if defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + } else if (strcmp("VK_KHR_shader_bfloat16", properties.extensionName) == 0 && + !getenv("GGML_VK_DISABLE_BFLOAT16")) { + bfloat16_support = true; +#endif + } + } + + const vk_device_architecture device_architecture = get_device_architecture(physical_device); + + const char* GGML_VK_DISABLE_F16 = getenv("GGML_VK_DISABLE_F16"); + bool force_disable_f16 = GGML_VK_DISABLE_F16 != nullptr; + + bool fp16 = !force_disable_f16 && fp16_storage && fp16_compute; + + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceMaintenance3Properties props3; + vk::PhysicalDeviceSubgroupProperties subgroup_props; + vk::PhysicalDeviceDriverProperties driver_props; + vk::PhysicalDeviceShaderIntegerDotProductPropertiesKHR shader_integer_dot_product_props; + props2.pNext = &props3; + props3.pNext = &subgroup_props; + subgroup_props.pNext = &driver_props; + + // Pointer to the last chain element + VkBaseOutStructure * last_struct = (VkBaseOutStructure *)&driver_props; + + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_props; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_props; + } + + physical_device.getProperties2(&props2); + + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + device_features2.pNext = nullptr; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + VkPhysicalDeviceVulkan12Features vk12_features; + vk12_features.pNext = nullptr; + vk12_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_2_FEATURES; + vk11_features.pNext = &vk12_features; + + // Pointer to the last chain element + last_struct = (VkBaseOutStructure *)&vk12_features; + +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + VkPhysicalDeviceCooperativeMatrixFeaturesKHR coopmat_features; + coopmat_features.pNext = nullptr; + coopmat_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_COOPERATIVE_MATRIX_FEATURES_KHR; + coopmat_features.cooperativeMatrix = VK_FALSE; + + if (coopmat_support) { + last_struct->pNext = 
(VkBaseOutStructure *)&coopmat_features; + last_struct = (VkBaseOutStructure *)&coopmat_features; + } +#endif + + VkPhysicalDeviceShaderIntegerDotProductFeaturesKHR shader_integer_dot_product_features {}; + shader_integer_dot_product_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_INTEGER_DOT_PRODUCT_FEATURES_KHR; + if (integer_dot_product) { + last_struct->pNext = (VkBaseOutStructure *)&shader_integer_dot_product_features; + last_struct = (VkBaseOutStructure *)&shader_integer_dot_product_features; + } + +#if defined(VK_KHR_shader_bfloat16) + VkPhysicalDeviceShaderBfloat16FeaturesKHR bfloat16_features {}; + bfloat16_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_SHADER_BFLOAT16_FEATURES_KHR; + if (bfloat16_support) { + last_struct->pNext = (VkBaseOutStructure *)&bfloat16_features; + last_struct = (VkBaseOutStructure *)&bfloat16_features; + } +#endif + + vkGetPhysicalDeviceFeatures2(physical_device, &device_features2); + + fp16 = fp16 && vk12_features.shaderFloat16; + +#if defined(VK_KHR_shader_bfloat16) + bool bf16 = bfloat16_support && bfloat16_features.shaderBFloat16Type; +#else + bool bf16 = false; +#endif + + uint32_t default_subgroup_size = get_subgroup_size("", device_architecture); + const size_t subgroup_size = (default_subgroup_size != 0) ? default_subgroup_size : subgroup_props.subgroupSize; + const bool uma = props2.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu; + + integer_dot_product = integer_dot_product + && shader_integer_dot_product_props.integerDotProduct4x8BitPackedSignedAccelerated + && shader_integer_dot_product_features.shaderIntegerDotProduct; + + coopmat_support = coopmat_support +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + && coopmat_features.cooperativeMatrix +#endif + && ggml_vk_khr_cooperative_matrix_support(props2.properties, driver_props, device_architecture); + + std::string matrix_cores = coopmat2_support ? "NV_coopmat2" : coopmat_support ? "KHR_coopmat" : "none"; + + std::string device_name = props2.properties.deviceName.data(); + GGML_LOG_DEBUG("ggml_vulkan: %zu = %s (%s) | uma: %d | fp16: %d | bf16: %d | warp size: %zu | shared memory: %d | int dot: %d | matrix cores: %s\n", + idx, device_name.c_str(), driver_props.driverName.data(), uma, fp16, bf16, subgroup_size, + props2.properties.limits.maxComputeSharedMemorySize, integer_dot_product, matrix_cores.c_str()); + + if (props2.properties.deviceType == vk::PhysicalDeviceType::eCpu) { + GGML_LOG_DEBUG("ggml_vulkan: Warning: Device type is CPU. This is probably not the device you want.\n"); + } +} + +static bool ggml_vk_instance_validation_ext_available(); +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector& instance_extensions); +static bool ggml_vk_instance_debug_utils_ext_available(const std::vector & instance_extensions); +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev); + +static DispatchLoaderDynamic ggml_vk_default_dispatcher_instance; +DispatchLoaderDynamic & ggml_vk_default_dispatcher() { + return ggml_vk_default_dispatcher_instance; +} + +static void ggml_vk_instance_init() { + if (vk_instance_initialized) { + return; + } + VK_LOG_DEBUG("ggml_vk_instance_init()"); + + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + ggml_vk_default_dispatcher_instance.init(vkGetInstanceProcAddr); + + uint32_t api_version = vk::enumerateInstanceVersion(); + + if (api_version < VK_API_VERSION_1_2) { + std::cerr << "ggml_vulkan: Error: Vulkan 1.2 required." 
<< std::endl; + throw vk::SystemError(vk::Result::eErrorFeatureNotPresent, "Vulkan 1.2 required"); + } + + vk::ApplicationInfo app_info{ "ggml-vulkan", 1, nullptr, 0, api_version }; + + const std::vector instance_extensions = vk::enumerateInstanceExtensionProperties(); + const bool validation_ext = ggml_vk_instance_validation_ext_available(); +#ifdef __APPLE__ + const bool portability_enumeration_ext = ggml_vk_instance_portability_enumeration_ext_available(instance_extensions); +#endif + const bool debug_utils_ext = ggml_vk_instance_debug_utils_ext_available(instance_extensions) && getenv("GGML_VK_DEBUG_MARKERS") != nullptr; + std::vector layers; + + if (validation_ext) { + layers.push_back("VK_LAYER_KHRONOS_validation"); + } + std::vector extensions; + if (validation_ext) { + extensions.push_back("VK_EXT_validation_features"); + } +#ifdef __APPLE__ + if (portability_enumeration_ext) { + extensions.push_back("VK_KHR_portability_enumeration"); + } +#endif + if (debug_utils_ext) { + extensions.push_back("VK_EXT_debug_utils"); + } + vk::InstanceCreateInfo instance_create_info(vk::InstanceCreateFlags{}, &app_info, layers, extensions); +#ifdef __APPLE__ + if (portability_enumeration_ext) { + instance_create_info.flags |= vk::InstanceCreateFlagBits::eEnumeratePortabilityKHR; + } +#endif + + std::vector features_enable; + vk::ValidationFeaturesEXT validation_features; + + if (validation_ext) { + features_enable = { vk::ValidationFeatureEnableEXT::eBestPractices }; + validation_features = { + features_enable, + {}, + }; + validation_features.setPNext(nullptr); + instance_create_info.setPNext(&validation_features); + GGML_LOG_DEBUG("ggml_vulkan: Validation layers enabled\n"); + } + vk_instance.instance = vk::createInstance(instance_create_info); + vk_instance_initialized = true; + + if (debug_utils_ext) { + vk_instance.debug_utils_support = true; + vk_instance.pfn_vkSetDebugUtilsObjectNameEXT = (PFN_vkSetDebugUtilsObjectNameEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkSetDebugUtilsObjectNameEXT"); + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT = (PFN_vkQueueBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkQueueEndDebugUtilsLabelEXT = (PFN_vkQueueEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkQueueEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdBeginDebugUtilsLabelEXT = (PFN_vkCmdBeginDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdBeginDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdEndDebugUtilsLabelEXT = (PFN_vkCmdEndDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdEndDebugUtilsLabelEXT"); + vk_instance.pfn_vkCmdInsertDebugUtilsLabelEXT = (PFN_vkCmdInsertDebugUtilsLabelEXT) vkGetInstanceProcAddr(vk_instance.instance, "vkCmdInsertDebugUtilsLabelEXT"); + } + + vk_perf_logger_enabled = getenv("GGML_VK_PERF_LOGGER") != nullptr; + + // See https://github.com/KhronosGroup/Vulkan-Hpp?tab=readme-ov-file#extensions--per-device-function-pointers- + VULKAN_HPP_DEFAULT_DISPATCHER.init(vk_instance.instance); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + // Emulate behavior of CUDA_VISIBLE_DEVICES for Vulkan + char * devices_env = getenv("GGML_VK_VISIBLE_DEVICES"); + if (devices_env != nullptr) { + size_t num_available_devices = devices.size(); + + std::string devices(devices_env); + std::replace(devices.begin(), devices.end(), ',', ' '); + + std::stringstream ss(devices); + size_t tmp; + while (ss >> tmp) { + if(tmp >= 
num_available_devices) { + std::cerr << "ggml_vulkan: Invalid device index " << tmp << " in GGML_VK_VISIBLE_DEVICES." << std::endl; + throw std::runtime_error("Invalid Vulkan device index"); + } + vk_instance.device_indices.push_back(tmp); + } + } else { + // If no vulkan devices are found, return early + if (devices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; + } + + // Default to using all dedicated GPUs + for (size_t i = 0; i < devices.size(); i++) { + vk::PhysicalDeviceProperties2 new_props; + vk::PhysicalDeviceDriverProperties new_driver; + vk::PhysicalDeviceIDProperties new_id; + new_props.pNext = &new_driver; + new_driver.pNext = &new_id; + devices[i].getProperties2(&new_props); + + if ((new_props.properties.deviceType == vk::PhysicalDeviceType::eDiscreteGpu || new_props.properties.deviceType == vk::PhysicalDeviceType::eIntegratedGpu) && ggml_vk_device_is_supported(devices[i])) { + // Check if there are two physical devices corresponding to the same GPU + auto old_device = std::find_if( + vk_instance.device_indices.begin(), + vk_instance.device_indices.end(), + [&devices, &new_id](const size_t k){ + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceIDProperties old_id; + old_props.pNext = &old_id; + devices[k].getProperties2(&old_props); + return std::equal(std::begin(old_id.deviceUUID), std::end(old_id.deviceUUID), std::begin(new_id.deviceUUID)); + } + ); + if (old_device == vk_instance.device_indices.end()) { + vk_instance.device_indices.push_back(i); + } else { + // There can be two physical devices corresponding to the same GPU if there are 2 different drivers + // This can cause error when splitting layers aross the devices, need to keep only 1 + VK_LOG_DEBUG("Device " << i << " and device " << *old_device << " have the same deviceUUID"); + + vk::PhysicalDeviceProperties2 old_props; + vk::PhysicalDeviceDriverProperties old_driver; + old_props.pNext = &old_driver; + devices[*old_device].getProperties2(&old_props); + + std::map driver_priorities {}; + int old_priority = std::numeric_limits::max(); + int new_priority = std::numeric_limits::max(); + + // Check https://registry.khronos.org/vulkan/specs/1.3-extensions/man/html/VkDriverId.html for the list of driver id + // Smaller number -> higher priority + switch (old_props.properties.vendorID) { + case VK_VENDOR_ID_AMD: + driver_priorities[vk::DriverId::eMesaRadv] = 1; + driver_priorities[vk::DriverId::eAmdOpenSource] = 2; + driver_priorities[vk::DriverId::eAmdProprietary] = 3; + break; + case VK_VENDOR_ID_INTEL: + driver_priorities[vk::DriverId::eIntelOpenSourceMESA] = 1; + driver_priorities[vk::DriverId::eIntelProprietaryWindows] = 2; + break; + case VK_VENDOR_ID_NVIDIA: + driver_priorities[vk::DriverId::eNvidiaProprietary] = 1; +#if defined(VK_API_VERSION_1_3) && VK_HEADER_VERSION >= 235 + driver_priorities[vk::DriverId::eMesaNvk] = 2; +#endif + break; + } + + if (driver_priorities.count(old_driver.driverID)) { + old_priority = driver_priorities[old_driver.driverID]; + } + if (driver_priorities.count(new_driver.driverID)) { + new_priority = driver_priorities[new_driver.driverID]; + } + + if (new_priority < old_priority) { + auto r = std::remove(vk_instance.device_indices.begin(), vk_instance.device_indices.end(), *old_device); + vk_instance.device_indices.erase(r, vk_instance.device_indices.end()); + vk_instance.device_indices.push_back(i); + + VK_LOG_DEBUG("Prioritize device " << i << " driver " << new_driver.driverName << " over device " << *old_device << " driver " << 
old_driver.driverName); + } + else { + VK_LOG_DEBUG("Prioritize device " << *old_device << " driver " << old_driver.driverName << " over device " << i << " driver " << new_driver.driverName << std::endl); + } + } + } + } + + // If no GPUs found, fall back to the first non-CPU device. + // If only CPU devices are available, return without devices. + if (vk_instance.device_indices.empty()) { + for (size_t i = 0; i < devices.size(); i++) { + if (devices[i].getProperties().deviceType != vk::PhysicalDeviceType::eCpu) { + vk_instance.device_indices.push_back(i); + break; + } + } + } + + if (vk_instance.device_indices.empty()) { + GGML_LOG_INFO("ggml_vulkan: No devices found.\n"); + return; + } + } + GGML_LOG_DEBUG("ggml_vulkan: Found %zu Vulkan devices:\n", vk_instance.device_indices.size()); + + for (size_t i = 0; i < vk_instance.device_indices.size(); i++) { + vk::PhysicalDevice vkdev = devices[vk_instance.device_indices[i]]; + std::vector extensionprops = vkdev.enumerateDeviceExtensionProperties(); + + bool membudget_supported = false; + for (const auto & ext : extensionprops) { + if (strcmp(VK_EXT_MEMORY_BUDGET_EXTENSION_NAME, ext.extensionName) == 0) { + membudget_supported = true; + break; + } + } + + vk_instance.device_supports_membudget.push_back(membudget_supported); + + ggml_vk_print_gpu_info(i); + } +} + +static void ggml_vk_init(ggml_backend_vk_context * ctx, size_t idx) { + VK_LOG_DEBUG("ggml_vk_init(" << ctx->name << ", " << idx << ")"); + ggml_vk_instance_init(); + GGML_ASSERT(idx < vk_instance.device_indices.size()); + + ctx->name = GGML_VK_NAME + std::to_string(idx); + + ctx->device = ggml_vk_get_device(idx); + + ctx->semaphore_idx = 0; + ctx->event_idx = 0; + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + ctx->fence = ctx->device->device.createFence({}); + ctx->almost_ready_fence = ctx->device->device.createFence({}); + + ctx->compute_cmd_pool.init(ctx->device, &ctx->device->compute_queue); + ctx->transfer_cmd_pool.init(ctx->device, &ctx->device->transfer_queue); + +#ifdef GGML_VULKAN_CHECK_RESULTS + const char* skip_checks = getenv("GGML_VULKAN_SKIP_CHECKS"); + vk_skip_checks = (skip_checks == NULL ? 0 : atoi(skip_checks)); + const char* output_tensor = getenv("GGML_VULKAN_OUTPUT_TENSOR"); + vk_output_tensor = (output_tensor == NULL ? 
0 : atoi(output_tensor)); +#endif +} + +static vk_pipeline ggml_vk_get_to_fp16(ggml_backend_vk_context * ctx, ggml_type type) { + VK_LOG_DEBUG("ggml_vk_get_to_fp16()"); + switch (type) { + case GGML_TYPE_F32: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant[type]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_pipeline(" << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ", " << prec << ")"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f32; + } + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f32_f16; + } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_bf16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_f16.f32acc; + } + } + + // MMQ + if (src1_type == GGML_TYPE_Q8_1) { + vk_matmul_pipeline pipelines = (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_q8_1[src0_type].f32acc; + + if (pipelines->s == nullptr && pipelines->m == nullptr && pipelines->l == nullptr) { + return nullptr; + } + + return pipelines; + } + + if (src1_type != GGML_TYPE_F32 && !ctx->device->coopmat2) { + return nullptr; + } + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + if (ctx->device->coopmat2) { + assert(src1_type == GGML_TYPE_F16); + return prec == GGML_PREC_DEFAULT ? ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat_f16[src0_type].f32acc; + } + if (ctx->device->coopmat_support) { + return (ctx->device->fp16 && ctx->device->coopmat_acc_f16_support && prec == GGML_PREC_DEFAULT) ? 
ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; + } + return (ctx->device->fp16 && prec == GGML_PREC_DEFAULT) ? ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f16acc : ctx->device->pipeline_dequant_mul_mat_mat[src0_type].f32acc; +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type, uint32_t num_cols, uint32_t m, uint32_t k) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec()"); + GGML_ASSERT(b_type == GGML_TYPE_F32 || b_type == GGML_TYPE_F16 || b_type == GGML_TYPE_Q8_1); + GGML_ASSERT(num_cols >= 1 && num_cols <= mul_mat_vec_max_cols); + + if (b_type == GGML_TYPE_Q8_1) { + switch (a_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + break; + default: + return nullptr; + } + } + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + // heuristic to choose workgroup size + uint32_t dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + if ((ctx->device->vendor_id == VK_VENDOR_ID_NVIDIA && ctx->device->architecture != vk_device_architecture::NVIDIA_PRE_TURING) || ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + // Prefer larger workgroups when M is small, to spread the work out more + // and keep more SMs busy. + // q6_k seems to prefer small workgroup size even for "medium" values of M. + if (a_type == GGML_TYPE_Q6_K) { + if (m < 4096 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } else { + if (m <= 8192 && k >= 1024) { + dmmv_wg = DMMV_WG_SIZE_LARGE; + } + } + } + + if (b_type == GGML_TYPE_Q8_1) { + if (ctx->device->vendor_id == VK_VENDOR_ID_INTEL) { + dmmv_wg = DMMV_WG_SIZE_SUBGROUP; + } + return ctx->device->pipeline_dequant_mul_mat_vec_q8_1_f32[dmmv_wg][a_type][num_cols-1]; + } + + return b_type == GGML_TYPE_F32 ? 
ctx->device->pipeline_dequant_mul_mat_vec_f32_f32[dmmv_wg][a_type][num_cols-1] : ctx->device->pipeline_dequant_mul_mat_vec_f16_f32[dmmv_wg][a_type][num_cols-1]; +} + +static vk_matmul_pipeline ggml_vk_get_mul_mat_mat_id_pipeline(ggml_backend_vk_context * ctx, ggml_type src0_type, ggml_type src1_type, ggml_prec prec) { + VK_LOG_DEBUG("ggml_vk_get_mul_mat_mat_id_pipeline()"); + if (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f32; + } + if (src0_type == GGML_TYPE_BF16 && src1_type == GGML_TYPE_BF16) { + return ctx->device->pipeline_matmul_id_bf16; + } + if (prec == GGML_PREC_DEFAULT && ctx->device->fp16 && !(ctx->device->coopmat_support && !ctx->device->coopmat_acc_f16_support)) { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f16acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f16acc; + } + } else { + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F32) { + return ctx->device->pipeline_matmul_id_f16_f32.f32acc; + } + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return ctx->device->pipeline_matmul_id_f16.f32acc; + } + } + + GGML_ASSERT(src1_type == GGML_TYPE_F32 || (ctx->device->coopmat2 && src1_type == GGML_TYPE_F16)); + + switch (src0_type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + // XXX TODO 'prec' is not actually allowed in mul_mat_id. 
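The selection logic that follows picks between the f16-accumulator and f32-accumulator variants of the quantized matmul-id pipelines: prefer f16 accumulation when the device supports fp16 shaders, but fall back to whichever variant was actually compiled. Distilled into a tiny stand-alone helper with hypothetical names, for illustration only:

```
// Sketch only: pick an accumulator precision, preferring fp16 but falling back
// to whichever pipeline variant actually exists.
#include <cassert>

struct PipelinePair {
    void * f16acc;   // stand-ins for the vk_matmul_pipeline handles in the diff
    void * f32acc;
};

static void * select_accumulator(const PipelinePair & p, bool device_fp16) {
    const bool prefer_fp16 = device_fp16;
    const bool has_fp16acc = p.f16acc != nullptr;
    const bool has_fp32acc = p.f32acc != nullptr;

    if (has_fp16acc && (prefer_fp16 || !has_fp32acc)) {
        return p.f16acc;
    }
    assert(has_fp32acc);   // at least one variant must have been compiled
    return p.f32acc;
}
```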
+ bool prefer_fp16acc = ctx->device->fp16 /*&& prec == GGML_PREC_DEFAULT*/; + bool support_fp16acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc != nullptr; + bool support_fp32acc = ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc != nullptr; + + if (support_fp16acc && (prefer_fp16acc || !support_fp32acc)) { + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f16acc; + } else { + GGML_ASSERT(support_fp32acc); + return ctx->device->pipeline_dequant_mul_mat_mat_id[src0_type].f32acc; + } +} + +static vk_pipeline ggml_vk_get_dequantize_mul_mat_vec_id(ggml_backend_vk_context * ctx, ggml_type a_type, ggml_type b_type) { + VK_LOG_DEBUG("ggml_vk_get_dequantize_mul_mat_vec_id()"); + GGML_ASSERT(b_type == GGML_TYPE_F32); + + switch (a_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return nullptr; + } + + return ctx->device->pipeline_dequant_mul_mat_vec_id_f32[a_type]; +} + +static vk_buffer ggml_vk_pool_malloc(ggml_backend_vk_context * ctx, size_t size) { + VK_LOG_DEBUG("ggml_vk_pool_malloc(" << size << ")"); + VK_LOG_MEMORY("ggml_vk_pool_malloc"); + + int best_i = -1; + size_t best_size = std::numeric_limits::max(); //smallest unused buffer that fits our needs + int worst_i = -1; + size_t worst_size = 0; //largest unused buffer seen so far + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer &b = ctx->buffer_pool[i]; + if (b != nullptr && b->size >= size && b->size < best_size) { + best_i = i; + best_size = b->size; + } + if (b != nullptr && b->size > worst_size) { + worst_i = i; + worst_size = b->size; + } + } + if(best_i != -1) { + //found the smallest buffer that fits our needs + vk_buffer b = ctx->buffer_pool[best_i]; + ctx->buffer_pool[best_i].reset(); + return b; + } + if(worst_i != -1) { + //no buffer that fits our needs, resize largest one to save memory + vk_buffer& b = ctx->buffer_pool[worst_i]; + ggml_vk_destroy_buffer(b); + } + + return ggml_vk_create_buffer_device(ctx->device, size); +} + +static void ggml_vk_pool_free(ggml_backend_vk_context * ctx, vk_buffer& buffer) { + VK_LOG_DEBUG("ggml_vk_pool_free(" << buffer->size << ")"); + for (int i = 0; i < MAX_VK_BUFFERS; ++i) { + vk_buffer& b = ctx->buffer_pool[i]; + if (b == nullptr) { + b = buffer; + return; + } + } + std::cerr << "ggml_vulkan: WARNING: vk buffer pool full, increase MAX_VK_BUFFERS" << std::endl; + ggml_vk_destroy_buffer(buffer); +} + +// Returns an available temporary buffer that may only be used temporarily, it will be reused +static vk_buffer ggml_vk_create_buffer_temp(ggml_backend_vk_context * ctx, size_t size) { + // Try to find existing temp buffer with enough capacity + for (auto& buffer : ctx->gc.temp_buffers) { + if (buffer->size >= size) { + return buffer; + } + } + + VK_LOG_MEMORY("ggml_vk_create_buffer_temp(" << size << ")"); + + // Otherwise create new buffer + vk_buffer buf = ggml_vk_pool_malloc(ctx, size); + ctx->gc.temp_buffers.push_back(buf); + + return buf; +} + +static void * ggml_vk_host_malloc(vk_device& device, size_t size) { + 
VK_LOG_MEMORY("ggml_vk_host_malloc(" << size << ")"); + vk_buffer buf = ggml_vk_create_buffer(device, size, + {vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent}); + + if(!(buf->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible)) { + fprintf(stderr, "WARNING: failed to allocate %.2f MB of pinned memory\n", + size/1024.0/1024.0); + device->device.freeMemory(buf->device_memory); + device->device.destroyBuffer(buf->buffer); + return nullptr; + } + + std::lock_guard<std::mutex> guard(device->mutex); + device->pinned_memory.push_back(std::make_tuple(buf->ptr, size, buf)); + + return buf->ptr; +} + +static void ggml_vk_host_free(vk_device& device, void* ptr) { + if (ptr == nullptr) { + return; + } + VK_LOG_MEMORY("ggml_vk_host_free(" << ptr << ")"); + std::lock_guard<std::mutex> guard(device->mutex); + + vk_buffer buf; + size_t index; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + index = i; + break; + } + } + if (buf == nullptr) { + fprintf(stderr, "WARNING: failed to free pinned memory: memory not in map\n"); + return; + } + + ggml_vk_destroy_buffer(buf); + + device->pinned_memory.erase(device->pinned_memory.begin() + index); +} + +static void ggml_vk_host_get(vk_device& device, const void * ptr, vk_buffer& buf, size_t& buf_offset) { + std::lock_guard<std::mutex> guard(device->mutex); + buf = nullptr; + buf_offset = 0; + for (size_t i = 0; i < device->pinned_memory.size(); i++) { + const uint8_t* addr = (const uint8_t*) std::get<0>(device->pinned_memory[i]); + const uint8_t* endr = addr + std::get<1>(device->pinned_memory[i]); + if (ptr >= addr && ptr < endr) { + buf = std::get<2>(device->pinned_memory[i]); + buf_offset = ((const uint8_t *)ptr) - addr; + break; + } + } +} + +static vk_submission ggml_vk_begin_submission(vk_device& device, vk_command_pool& p, bool one_time = true) { + vk_submission s; + s.buffer = ggml_vk_create_cmd_buffer(device, p); + if (one_time) { + s.buffer.begin({ vk::CommandBufferUsageFlagBits::eOneTimeSubmit }); + } else { + s.buffer.begin({ vk::CommandBufferUsageFlags{} }); + } + + return s; +} + +template <typename T> size_t push_constant_size(const T &t) { + static_assert(std::is_class<T>::value, "T must be a struct/class"); + GGML_UNUSED(t); + return sizeof(T); +} +template <typename T> size_t push_constant_size(const std::vector<T> &t) { + GGML_UNUSED(t); + return sizeof(T) * t.size(); +} +template <typename T, uint32_t N> size_t push_constant_size(const std::array<T, N> &t) { + GGML_UNUSED(t); + return sizeof(T) * N; +} + +template <typename T> const T *push_constant_data(const T &t) { + static_assert(std::is_class<T>::value, "T must be a struct/class"); + return &t; +} +template <typename T> const T *push_constant_data(const std::vector<T> &t) { + return t.data(); +} +template <typename T, uint32_t N> const T *push_constant_data(const std::array<T, N> &t) { + return t.data(); +} + +template <typename T> +static void ggml_vk_dispatch_pipeline(ggml_backend_vk_context* ctx, vk_context& subctx, vk_pipeline& pipeline, std::initializer_list<vk::DescriptorBufferInfo> const& descriptor_buffer_infos, const T &push_constants, std::array<uint32_t, 3> elements) { + const uint32_t wg0 = CEIL_DIV(elements[0], pipeline->wg_denoms[0]); + const uint32_t wg1 = CEIL_DIV(elements[1], pipeline->wg_denoms[1]); + const uint32_t wg2 = CEIL_DIV(elements[2],
pipeline->wg_denoms[2]); + VK_LOG_DEBUG("ggml_vk_dispatch_pipeline(" << pipeline->name << ", {"; + for (auto& buffer : descriptor_buffer_infos) { + std::cerr << "(" << buffer.buffer << ", " << buffer.offset << ", " << buffer.range << "), "; + } + std::cerr << "}, (" << wg0 << "," << wg1 << "," << wg2 << "))"); + GGML_ASSERT(ctx->descriptor_set_idx < ctx->descriptor_sets.size()); + GGML_ASSERT(descriptor_buffer_infos.size() <= MAX_PARAMETER_COUNT); + GGML_ASSERT(pipeline->parameter_count == descriptor_buffer_infos.size()); + + vk::DescriptorSet& descriptor_set = ctx->descriptor_sets[ctx->descriptor_set_idx++]; + vk::WriteDescriptorSet write_descriptor_set{ descriptor_set, 0, 0, pipeline->parameter_count, vk::DescriptorType::eStorageBuffer, nullptr, descriptor_buffer_infos.begin() }; + ctx->device->device.updateDescriptorSets({ write_descriptor_set }, {}); + + subctx->s->buffer.pushConstants(pipeline->layout, vk::ShaderStageFlagBits::eCompute, 0, push_constant_size(push_constants), push_constant_data(push_constants)); + subctx->s->buffer.bindPipeline(vk::PipelineBindPoint::eCompute, pipeline->pipeline); + subctx->s->buffer.bindDescriptorSets(vk::PipelineBindPoint::eCompute, + pipeline->layout, + 0, + { descriptor_set }, + {}); + subctx->s->buffer.dispatch(wg0, wg1, wg2); +} + +static void ggml_vk_end_submission(vk_submission& s, std::vector wait_semaphores, std::vector signal_semaphores) { + s.buffer.end(); + + s.wait_semaphores = std::move(wait_semaphores); + s.signal_semaphores = std::move(signal_semaphores); +} + +static void ggml_vk_ctx_end(vk_context& ctx) { + VK_LOG_DEBUG("ggml_vk_ctx_end(" << ctx << ", " << ctx->seqs.size() << ")"); + if (ctx->s == nullptr) { + return; + } + + ctx->s->buffer.end(); + ctx->s = nullptr; +} + +static void ggml_vk_ctx_begin(vk_device& device, vk_context& subctx) { + VK_LOG_DEBUG("ggml_vk_ctx_begin(" << device->name << ")"); + if (subctx->s != nullptr) { + ggml_vk_ctx_end(subctx); + } + + subctx->seqs.push_back({ ggml_vk_begin_submission(device, *subctx->p) }); + subctx->s = subctx->seqs[subctx->seqs.size() - 1].data(); +} + +static size_t ggml_vk_align_size(size_t width, size_t align) { + VK_LOG_DEBUG("ggml_vk_align_size(" << width << ", " << align << ")"); + return CEIL_DIV(width, align) * align; +} + +static void deferred_memcpy(void * dst, const void * src, size_t size, std::vector* memcpys = nullptr) { + if (memcpys == nullptr) { + memcpy(dst, src, size); + } else { + memcpys->emplace_back(dst, src, size); + } +} + +static void deferred_memset(void * dst, uint32_t val, size_t size, std::vector* memsets = nullptr) { + if (memsets == nullptr) { + memset(dst, val, size); + } else { + memsets->emplace_back(dst, val, size); + } +} + +static void ggml_vk_ensure_sync_staging_buffer(vk_device& device, size_t size) { + if (device->sync_staging == nullptr || device->sync_staging->size < size) { + VK_LOG_MEMORY("ggml_vk_ensure_sync_staging_buffer(" << size << ")"); + ggml_vk_destroy_buffer(device->sync_staging); + device->sync_staging = ggml_vk_create_buffer_check(device, size, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent | vk::MemoryPropertyFlagBits::eHostCached, + vk::MemoryPropertyFlagBits::eHostVisible | vk::MemoryPropertyFlagBits::eHostCoherent); + } +} + +static void ggml_vk_buffer_write_nc_async(ggml_backend_vk_context * ctx, vk_context& subctx, vk_buffer& dst, size_t offset, const ggml_tensor * tensor, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_nc_async(" << tensor << ")"); + 
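The deferred_memcpy/deferred_memset helpers defined above follow a record-or-execute pattern: without a queue the copy happens immediately, otherwise it is recorded so the caller can replay all host-side copies once the GPU work has been set up. A minimal sketch of that pattern, with hypothetical type names rather than the backend's own:

```
#include <cstddef>
#include <cstdio>
#include <cstring>
#include <vector>

// Illustrative record of a host copy to be replayed later.
struct pending_copy {
    void * dst;
    const void * src;
    size_t n;
};

// Run the copy now, or record it if a queue is supplied.
static void deferred_copy(void * dst, const void * src, size_t n,
                          std::vector<pending_copy> * queue = nullptr) {
    if (queue == nullptr) {
        std::memcpy(dst, src, n);
    } else {
        queue->push_back({ dst, src, n });
    }
}

int main() {
    const char src[] = "staged";
    char dst[8] = {};
    std::vector<pending_copy> queue;

    deferred_copy(dst, src, sizeof(src), &queue); // recorded, not yet executed
    for (const auto & c : queue) {                // flushed after submission setup
        std::memcpy(c.dst, c.src, c.n);
    }
    std::printf("%s\n", dst); // staged
    return 0;
}
```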
GGML_ASSERT(!ggml_is_contiguous(tensor)); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_nc_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(ctx->device, tensor->data, buf, buf_offset); + + const uint64_t ne0 = tensor->ne[0]; + const uint64_t ne1 = tensor->ne[1]; + const uint64_t ne2 = tensor->ne[2]; + const uint64_t ne3 = tensor->ne[3]; + const uint64_t nb0 = tensor->nb[0]; + const uint64_t nb1 = tensor->nb[1]; + const uint64_t nb2 = tensor->nb[2]; + const uint64_t nb3 = tensor->nb[3]; + const ggml_type type = tensor->type; + const uint64_t ts = ggml_type_size(type); + const uint64_t bs = ggml_blck_size(type); + + const uint64_t dstnb0 = ts; + const uint64_t dstnb1 = dstnb0*(ne0/bs); + const uint64_t dstnb2 = dstnb1*ne1; + const uint64_t dstnb3 = dstnb2*ne2; + + const uint64_t ne = ggml_nelements(tensor); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices; + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2, offset + i3*dstnb3 + i2*dstnb2, dstnb2 }); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + slices.push_back({ buf_offset + i3*nb3 + i2*nb2 + i1*nb1, offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, dstnb1 }); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = offset + i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + slices.push_back({ s_off + i1*nb0, d_off + i0*dstnb0, dstnb0 }); + } + } + } + } + } + } + + ggml_vk_sync_buffers(ctx, subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + vk_buffer& staging = ctx->device->sync_staging; + const uint64_t copy_size = ts*ne/bs; + ggml_vk_ensure_sync_staging_buffer(ctx->device, copy_size); + VkBufferCopy buf_copy{ 0, offset, copy_size }; + + ggml_vk_sync_buffers(ctx, subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + for (uint64_t i3 = 0; i3 < ne3; i3++) { + for (uint64_t i2 = 0; i2 < ne2; i2++) { + // Find longest contiguous slice + if (ne1*nb1 == dstnb2) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2, dstnb2, &subctx->in_memcpys); + } else { + for (uint64_t i1 = 0; i1 < ne1; i1++) { + if (ne0*nb0/bs == dstnb1) { + deferred_memcpy((uint8_t *)staging->ptr + i3*dstnb3 + i2*dstnb2 + i1*dstnb1, (const uint8_t *) tensor->data + buf_offset + i3*nb3 + i2*nb2 + i1*nb1, dstnb1, &subctx->in_memcpys); + } else { + const uint64_t s_off = buf_offset + i3*nb3 + i2*nb2 + i1*nb1; + const uint64_t d_off = i3*dstnb3 + i2*dstnb2 + i1*dstnb1; + for (uint64_t i0 = 0; i0 < ne0; i0++) { + deferred_memcpy((uint8_t *)staging->ptr + d_off + i0*dstnb0, (const uint8_t *) tensor->data + s_off + i0*nb0, dstnb0, &subctx->in_memcpys); + } + } + } + } + } + } +} + +static void ggml_vk_buffer_write_2d_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t 
height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d_async(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + std::cerr << "ggml_vulkan: buffer_write_async dst buffer is host_visible. Use synchronous write." << std::endl; + GGML_ABORT("fatal error"); + } + // Check if src is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(dst->device, src, buf, buf_offset); + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + std::vector slices(1); + if (width == spitch) { + // Only do single write if stride is equal + slices[0].srcOffset = buf_offset; + slices[0].dstOffset = offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = buf_offset + i * spitch; + slices[i].dstOffset = offset + i * width; + slices[i].size = width; + } + } + + ggml_vk_sync_buffers(nullptr, subctx); + subctx->s->buffer.copyBuffer(buf->buffer, dst->buffer, slices); + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous write to non-pinned memory not supported"); + } + + // Staging buffer required + const size_t copy_size = width*height; + ggml_vk_ensure_sync_staging_buffer(dst->device, copy_size); + + vk_buffer& staging_buffer = dst->device->sync_staging; + + VkBufferCopy buf_copy = { + 0, + offset, + copy_size}; + + ggml_vk_sync_buffers(nullptr, subctx); + vkCmdCopyBuffer(subctx->s->buffer, (VkBuffer)staging_buffer->buffer, (VkBuffer)dst->buffer, 1, &buf_copy); + + if (width == spitch) { + deferred_memcpy((uint8_t *)staging_buffer->ptr, src, width * height, &subctx->in_memcpys); + } else { + for (size_t i = 0; i < height; i++) { + deferred_memcpy((uint8_t *)staging_buffer->ptr + i * width, (const uint8_t *) src + i * spitch, width, &subctx->in_memcpys); + } + } +} + +static void ggml_vk_buffer_write_async(vk_context subctx, vk_buffer& dst, size_t offset, const void * src, size_t size, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_write_async(" << size << ")"); + return ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_write_2d(vk_buffer& dst, size_t offset, const void * src, size_t spitch, size_t width, size_t height) { + VK_LOG_DEBUG("ggml_vk_buffer_write_2d(" << width << ", " << height << ")"); + // Buffer is already mapped + if(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible) { + GGML_ASSERT(dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + for (size_t i = 0; i < height; i++) { + memcpy((uint8_t *)dst->ptr + offset + i * width, (const uint8_t *) src + i * spitch, width); + } + } else { + std::lock_guard guard(dst->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(dst->device, subctx); + ggml_vk_buffer_write_2d_async(subctx, dst, offset, src, spitch, width, height, true); + ggml_vk_ctx_end(subctx); + + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + for (auto& mset : subctx->memsets) { + memset(mset.dst, mset.val, mset.n); + } + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_buffer_write_2d waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + 
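A quick aside on the pitch handling used by the 2D write helpers above: when the source row pitch equals the row width, the rows are contiguous and a single copy suffices; otherwise each row is copied separately. A self-contained sketch of that idea, using plain memcpy in place of Vulkan buffer copies:

```
#include <cstddef>
#include <cstdio>
#include <cstring>

// Copy `height` rows of `width` bytes from a source with row pitch `spitch`
// into a tightly packed destination, collapsing to one copy when possible.
static void copy_2d(void * dst, const void * src, size_t spitch, size_t width, size_t height) {
    if (width == spitch) {
        // Rows are contiguous in the source: a single copy covers everything.
        std::memcpy(dst, src, width * height);
    } else {
        for (size_t i = 0; i < height; i++) {
            std::memcpy((char *)dst + i * width, (const char *)src + i * spitch, width);
        }
    }
}

int main() {
    // 3 rows of 4 payload bytes, stored with a pitch of 6 bytes per row.
    const char src[18] = { 'a','b','c','d',0,0, 'e','f','g','h',0,0, 'i','j','k','l',0,0 };
    char dst[13] = {};
    copy_2d(dst, src, /*spitch=*/6, /*width=*/4, /*height=*/3);
    std::printf("%s\n", dst); // abcdefghijkl
    return 0;
}
```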
ggml_vk_queue_command_pools_cleanup(dst->device); + } +} + +static void ggml_vk_buffer_write(vk_buffer& dst, size_t offset, const void * src, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_write(" << size << ")"); + ggml_vk_buffer_write_2d(dst, offset, src, 0, size, 1); +} + +static void ggml_vk_buffer_read_2d_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t spitch, size_t dpitch, size_t width, size_t height, bool sync_staging = false) { + VK_LOG_DEBUG("ggml_vk_buffer_read_2d_async(offset=" << offset << ", width=" << width << ", height=" << height << ")"); + GGML_ASSERT(width > 0); + GGML_ASSERT(height > 0); + GGML_ASSERT(src != nullptr); + + // TODO: staging_offset is not used + + // Check if dst is pinned memory + vk_buffer buf = nullptr; + size_t buf_offset = 0; + ggml_vk_host_get(src->device, dst, buf, buf_offset); + + std::vector<vk::BufferCopy> slices(1); + if (width == spitch && width == dpitch) { + // Only do single write if stride is equal + slices[0].srcOffset = offset; + slices[0].dstOffset = buf_offset; + slices[0].size = width * height; + } else { + slices.resize(height); + for (size_t i = 0; i < height; i++) { + slices[i].srcOffset = offset + i * spitch; + slices[i].dstOffset = buf_offset + i * dpitch; + slices[i].size = width; + } + } + + if (buf != nullptr) { + // Memory is pinned, use as staging buffer + ggml_vk_sync_buffers(nullptr, subctx); + subctx->s->buffer.copyBuffer(src->buffer, buf->buffer, slices); + + return; + } + VK_LOG_DEBUG("STAGING"); + + if (!sync_staging) { + GGML_ABORT("Asynchronous read from non-pinned memory not supported"); + } + + // Fall back to staging buffer + const size_t copy_size = dpitch * height; + ggml_vk_ensure_sync_staging_buffer(src->device, copy_size); + + vk_buffer& staging_buffer = src->device->sync_staging; + + ggml_vk_sync_buffers(nullptr, subctx); + subctx->s->buffer.copyBuffer(src->buffer, staging_buffer->buffer, slices); + + deferred_memcpy(dst, staging_buffer->ptr, copy_size, &subctx->out_memcpys); +} + +static void ggml_vk_buffer_read_async(vk_context subctx, vk_buffer& src, size_t offset, void * dst, size_t size, bool sync_staging = false) { + return ggml_vk_buffer_read_2d_async(subctx, src, offset, dst, size, size, size, 1, sync_staging); +} + +static void ggml_vk_buffer_read(vk_buffer& src, size_t offset, void * dst, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_read(" << src->buffer << ", " << offset << ", " << size << ")"); + + // If the device is not a UMA device the memory is host-accessible through rebar. While writing + // through PCIe is sufficiently fast, reading data back over PCIe is slower than going through + // the HW device-to-host copy path.
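That comment motivates the branch that follows: only UMA devices take the direct memcpy read path, while host-visible VRAM on a discrete GPU (resizable BAR) is read back through a device-to-host copy. A tiny illustrative decision helper, with assumed flag names rather than the backend's actual fields:

```
#include <cstdio>

enum class read_path { direct_memcpy, staged_copy };

// Illustrative decision only: read directly when the buffer is host visible on
// a UMA device; otherwise stage the read through a device-to-host copy.
static read_path choose_read_path(bool host_visible, bool uma_device) {
    return (host_visible && uma_device) ? read_path::direct_memcpy
                                        : read_path::staged_copy;
}

int main() {
    // Host-visible VRAM on a discrete GPU (rebar) still goes through staging.
    std::printf("%s\n", choose_read_path(true, false) == read_path::staged_copy
                            ? "staged_copy" : "direct_memcpy");
    return 0;
}
```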
+ if(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && src->device->uma) { + GGML_ASSERT(src->memory_property_flags & vk::MemoryPropertyFlagBits::eHostCoherent); + + memcpy(dst, (uint8_t *) src->ptr + offset, size); + } else { + std::lock_guard guard(src->device->mutex); + + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_read_async(subctx, src, offset, dst, size, true); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_read waitForFences"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); + + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + } +} + +static void ggml_vk_buffer_copy_async(vk_context& ctx, vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_copy_async(" << size << ")"); + // Make sure both buffers are on same device + GGML_ASSERT(src->device == dst->device); + + VkBufferCopy bc{ src_offset, dst_offset, size }; + + vkCmdCopyBuffer(ctx->s->buffer, (VkBuffer)src->buffer, (VkBuffer)dst->buffer, 1, &bc); +} + +static void ggml_vk_buffer_copy(vk_buffer& dst, size_t dst_offset, vk_buffer& src, size_t src_offset, size_t size) { + if (src->device == dst->device) { + std::lock_guard guard(src->device->mutex); + VK_LOG_DEBUG("ggml_vk_buffer_copy(SINGLE_DEVICE, " << size << ")"); + // Copy within the device + vk_context subctx = ggml_vk_create_temporary_context(src->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(src->device, subctx); + ggml_vk_buffer_copy_async(subctx, dst, dst_offset, src, src_offset, size); + ggml_vk_ctx_end(subctx); + ggml_vk_submit(subctx, src->device->fence); + VK_CHECK(src->device->device.waitForFences({ src->device->fence }, true, UINT64_MAX), "vk_buffer_copy waitForFences"); + src->device->device.resetFences({ src->device->fence }); + ggml_vk_queue_command_pools_cleanup(src->device); + } else { + VK_LOG_DEBUG("ggml_vk_buffer_copy(MULTI_DEVICE, " << size << ")"); + // Copy device to device + ggml_vk_ensure_sync_staging_buffer(src->device, size); + ggml_vk_ensure_sync_staging_buffer(dst->device, size); + + // Copy to src staging buffer + ggml_vk_buffer_copy(src->device->sync_staging, 0, src, src_offset, size); + // memcpy to dst staging buffer + memcpy(dst->device->sync_staging->ptr, src->device->sync_staging->ptr, size); + // Copy to dst buffer + ggml_vk_buffer_copy(dst, dst_offset, dst->device->sync_staging, 0, size); + } +} + +static void ggml_vk_buffer_memset_async(vk_context& ctx, vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset_async(" << offset << ", " << c << ", " << size << ")"); + + if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && + dst->device->uma) { + deferred_memset((uint8_t*)dst->ptr + offset, c, size, &ctx->memsets); + return; + } + + // Fall back to GPU fillBuffer for non-UMA or non-host-visible buffers + ctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); +} + +static void ggml_vk_buffer_memset(vk_buffer& dst, size_t offset, uint32_t c, size_t size) { + VK_LOG_DEBUG("ggml_vk_buffer_memset(" << offset << ", " << c << ", " << size << ")"); + + if (dst->memory_property_flags & vk::MemoryPropertyFlagBits::eHostVisible && + dst->device->uma) { + 
memset((uint8_t*)dst->ptr + offset, c, size); + return; + } + + std::lock_guard guard(dst->device->mutex); + vk_context subctx = ggml_vk_create_temporary_context(dst->device->transfer_queue.cmd_pool); + ggml_vk_ctx_begin(dst->device, subctx); + subctx->s->buffer.fillBuffer(dst->buffer, offset, size, c); + ggml_vk_ctx_end(subctx); + + ggml_vk_submit(subctx, dst->device->fence); + VK_CHECK(dst->device->device.waitForFences({ dst->device->fence }, true, UINT64_MAX), "vk_memset waitForFences"); + dst->device->device.resetFences({ dst->device->fence }); + ggml_vk_queue_command_pools_cleanup(dst->device); +} + +static uint32_t ggml_vk_guess_split_k(ggml_backend_vk_context * ctx, uint32_t m, uint32_t n, uint32_t k, bool disable_split_k, const vk_pipeline& pipeline) { + VK_LOG_DEBUG("ggml_vk_guess_split_k(" << m << ", " << n << ", " << k << ", " << disable_split_k << ")"); + + if (disable_split_k) { + return 1; + } + + uint32_t split_k = 1; + if (ctx->device->shader_core_count != 0 && m >= pipeline->wg_denoms[0] && n >= pipeline->wg_denoms[1]) { + // If k is 'large' and the SMs will fill less than halfway, use split_k. + uint32_t m_tiles = CEIL_DIV(m, pipeline->wg_denoms[0]); + uint32_t n_tiles = CEIL_DIV(n, pipeline->wg_denoms[1]); + + if (k >= 2048) { + if (m_tiles * n_tiles <= ctx->device->shader_core_count / 2) { + split_k = ctx->device->shader_core_count / (m_tiles * n_tiles); + } else if (m_tiles * n_tiles <= ctx->device->shader_core_count * 2 / 3) { + split_k = 3; + } + // Cap the split at 8x. Unless k is huge this is a lot of overhead. + split_k = std::min(split_k, 8u); + + // ggml_vk_matmul will align the splits to be a multiple of 256. + // If this rounded up size would cause the last split to be empty, + // then reduce the split count. + while (true) { + if (split_k == 1) { + break; + } + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + if (k_split * (split_k - 1) < k) { + break; + } + split_k--; + } + } + } + + return split_k; +} + +static vk_pipeline ggml_vk_guess_matmul_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + + if (ctx->device->coopmat2) { + const uint32_t shader_core_count = ctx->device->shader_core_count; + const uint32_t tiles_l = CEIL_DIV(m, mmp->a_l->wg_denoms[0]) * CEIL_DIV(n, mmp->a_l->wg_denoms[1]); + const uint32_t tiles_m = CEIL_DIV(m, mmp->a_m->wg_denoms[0]) * CEIL_DIV(n, mmp->a_m->wg_denoms[1]); + + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + + // Prefer large over medium if either: + // - medium or large tiles would overfill the GPU + // - large tiles with a split_k==3 fits in the GPU and medium tiles with split_k==2 does not + // (medium with split_k==2 is probably better if it fits - more workgroups running and less split_k overhead) + bool prefer_large = tiles_m > shader_core_count || tiles_l > shader_core_count || + // split_k==3 with large tiles likely better than medium tiles with no split_k. + (tiles_l <= shader_core_count / 3 && tiles_m > shader_core_count / 2); + + if ((ctx->device->mul_mat_l[src0_type] && (n > crossover_large && prefer_large)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_s[src0_type])) { + return aligned ? 
mmp->a_l : mmp->l; + } + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_s[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_m[src0_type] && !ctx->device->mul_mat_l[src0_type])) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_l[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_l : mmp->l; + + GGML_UNUSED(src1_type); +} + +static uint32_t ggml_vk_guess_matmul_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type, ggml_type src1_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ", " << ggml_type_name(src1_type) << ")"); + return ggml_vk_guess_matmul_pipeline(ctx, mmp, m, n, true, src0_type, src1_type)->align; +} + +static void ggml_vk_matmul( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& split_k_buffer, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t split_k, uint32_t batch, uint32_t ne02, uint32_t ne12, uint32_t broadcast2, uint32_t broadcast3, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), split_k: (" << (split_k_buffer.buffer != nullptr ? 
split_k_buffer.buffer->buffer : VK_NULL_HANDLE) << ", " << split_k_buffer.offset << ", " << split_k_buffer.size << "), m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", split_k: " << split_k << ", batch: " << batch << ", ne02: " << ne02 << ", ne12: " << ne12 << ", broadcast2: " << broadcast2 << ", broadcast3: " << broadcast3 << ", padded_n: " << padded_n << ")"); + if (split_k == 1) { + const vk_mat_mat_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k, ne02, ne12, broadcast2, broadcast3, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d }, pc, { m, n, batch }); + return; + } + + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + GGML_ASSERT(batch_stride_d == m * n); + + // Round the split size up to a multiple of 256 (k-quant alignment) + uint32_t k_split = CEIL_DIV(k, split_k); + k_split = ROUNDUP_POW2(k_split, 256); + + const vk_mat_mat_push_constants pc1 = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, k_split, ne02, ne12, broadcast2, broadcast3, padded_n }; + // Make sure enough workgroups get assigned for split k to work + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, split_k_buffer }, pc1, { (CEIL_DIV(m, pipeline->wg_denoms[0]) * pipeline->wg_denoms[0]) * split_k, n, batch }); + ggml_vk_sync_buffers(ctx, subctx); + const std::array pc2 = { (uint32_t)(m * n * batch), split_k }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_matmul_split_k_reduce, { split_k_buffer, d }, pc2, { m * n * batch, 1, 1 }); + ctx->prealloc_split_k_need_sync = true; +} + +static vk_pipeline ggml_vk_guess_matmul_id_pipeline(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, uint32_t m, uint32_t n, bool aligned, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_id_pipeline(" << m << ", " << n << ", " << aligned << ", " << ggml_type_name(src0_type) << ")"); + + if (ctx->device->coopmat2) { + // Use large shader when the N dimension is greater than the medium shader's tile size + uint32_t crossover_large = mmp->m->wg_denoms[1]; + if ((ctx->device->mul_mat_id_l[src0_type] && (n > crossover_large)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_s[src0_type])) { + return aligned ? mmp->a_l : mmp->l; + } + // Use medium shader when the N dimension is greater than the small shader's tile size + uint32_t crossover_medium = mmp->s->wg_denoms[1]; + if ((ctx->device->mul_mat_id_m[src0_type] && (n > crossover_medium)) || !ctx->device->mul_mat_id_s[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? mmp->a_s : mmp->s; + } + + if ((ctx->device->mul_mat_id_s[src0_type] && (m <= 32 || n <= 32)) || (!ctx->device->mul_mat_id_m[src0_type] && !ctx->device->mul_mat_id_l[src0_type])) { + return aligned ? mmp->a_s : mmp->s; + } + if ((ctx->device->mul_mat_id_m[src0_type] && (m <= 64 || n <= 64)) || !ctx->device->mul_mat_id_l[src0_type]) { + return aligned ? mmp->a_m : mmp->m; + } + return aligned ? 
mmp->a_l : mmp->l; +} + +static uint32_t ggml_vk_guess_matmul_id_pipeline_align(ggml_backend_vk_context * ctx, vk_matmul_pipeline& mmp, int m, int n, ggml_type src0_type) { + VK_LOG_DEBUG("ggml_vk_guess_matmul_pipeline_align(" << m << ", " << n << ", " << ggml_type_name(src0_type) << ")"); + return ggml_vk_guess_matmul_id_pipeline(ctx, mmp, m, n, true, src0_type)->align; +} + +static void ggml_vk_matmul_id( + ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline& pipeline, + vk_subbuffer&& a, vk_subbuffer&& b, vk_subbuffer&& d, vk_subbuffer&& ids, + uint32_t m, uint32_t n, uint32_t k, uint32_t stride_a, uint32_t stride_b, uint32_t stride_d, + uint32_t batch_stride_a, uint32_t batch_stride_b, uint32_t batch_stride_d, + uint32_t n_as, uint32_t nei0, uint32_t nei1, uint32_t nbi1, uint32_t ne11, + uint32_t padded_n) { + VK_LOG_DEBUG("ggml_vk_matmul_id(a: (" << a.buffer->buffer << ", " << a.offset << ", " << a.size << "), b: (" << b.buffer->buffer << ", " << b.offset << ", " << b.size << "), d: (" << d.buffer->buffer << ", " << d.offset << ", " << d.size << "), ids: (" << ids.buffer->buffer << ", " << ids.offset << ", " << ids.size << "), " << + "m: " << m << ", n: " << n << ", k: " << k << ", stride_a: " << stride_a << ", stride_b: " << stride_b << ", stride_d: " << stride_d << ", " << + "batch_stride_a: " << batch_stride_a << ", batch_stride_b: " << batch_stride_b << ", batch_stride_d: " << batch_stride_d << ", " << + "n_as: " << n_as << ", nei0: " << nei0 << ", nei1: " << nei1 << ", nbi1: " << nbi1 << ", ne11: " << ne11 << ")"); + const vk_mat_mat_id_push_constants pc = { m, n, k, stride_a, stride_b, stride_d, batch_stride_a, batch_stride_b, batch_stride_d, + nei0, nei1, nbi1, ne11, padded_n }; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { a, b, d, ids }, pc, { m, nei1, n_as }); +} + +static bool ggml_vk_dim01_contiguous(const ggml_tensor * tensor) { + return + tensor->nb[0] == ggml_type_size(tensor->type) && + tensor->nb[1] == (tensor->nb[0]*tensor->ne[0])/ggml_blck_size(tensor->type) && + (tensor->ne[3] == 1 || tensor->nb[3] == tensor->nb[2]*tensor->ne[2]); +} + +static vk_pipeline ggml_vk_get_cpy_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src, const ggml_tensor * dst, ggml_type to) { + + // Choose "contiguous copy" shader if src/dst are contiguous + bool contig = ggml_is_contiguous(src) && (!dst || ggml_is_contiguous(dst)); + + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f16; + } else { + return ctx->device->pipeline_cpy_f32_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + if (src->type == GGML_TYPE_F16 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f32; + } else { + return ctx->device->pipeline_cpy_f16_f32; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_BF16) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_bf16; + } else { + return ctx->device->pipeline_cpy_f32_bf16; + } + } + if (src->type == GGML_TYPE_F32 && to == GGML_TYPE_I32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_i32; + } else { + return ctx->device->pipeline_cpy_f32_i32; + } + } + if (src->type 
== GGML_TYPE_I32 && to == GGML_TYPE_F32) { + if (contig) { + return ctx->device->pipeline_contig_cpy_i32_f32; + } else { + return ctx->device->pipeline_cpy_i32_f32; + } + } + if (src->type == GGML_TYPE_F32) { + switch (to) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_f32_quant[to]; + default: + break; + } + } + + if (to == GGML_TYPE_F32) { + switch (src->type) { + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return ctx->device->pipeline_cpy_quant_f32[src->type]; + default: + break; + } + } + + if (src->type == to) { + // Copy two or four bytes at a time, depending on block size. + // For quantized types, we scale by block size/type size. But + // this path is also used for bf16->bf16 for example, where the + // type size must be exactly 2 or 4. + GGML_ASSERT(ggml_is_quantized(to) || ggml_type_size(src->type) == 2 || ggml_type_size(src->type) == 4); + if ((ggml_type_size(src->type) % 4) == 0) { + if (contig) { + return ctx->device->pipeline_contig_cpy_f32_f32; + } else { + return ctx->device->pipeline_cpy_f32_f32; + } + } else { + if (contig) { + return ctx->device->pipeline_contig_cpy_f16_f16; + } else { + return ctx->device->pipeline_cpy_f16_f16; + } + } + } + + std::cerr << "Missing CPY op for types: " << ggml_type_name(src->type) << " " << ggml_type_name(to) << std::endl; + GGML_ABORT("fatal error"); +} + +static void ggml_vk_cpy_to_contiguous(ggml_backend_vk_context * ctx, vk_context& subctx, vk_pipeline pipeline, const ggml_tensor * tensor, vk_subbuffer&& in, vk_subbuffer&& out) { + VK_LOG_DEBUG("ggml_vk_cpy_to_contiguous((" << tensor << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << "), "; + std::cerr << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ")"); + const int tensor_type_size = ggml_type_size(tensor->type); + + const uint32_t ne = ggml_nelements(tensor); + std::array elements; + + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + vk_op_unary_push_constants pc = { + (uint32_t)ne, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], (uint32_t)tensor->nb[0] / tensor_type_size, (uint32_t)tensor->nb[1] / tensor_type_size, (uint32_t)tensor->nb[2] / tensor_type_size, (uint32_t)tensor->nb[3] / tensor_type_size, + (uint32_t)tensor->ne[0], (uint32_t)tensor->ne[1], (uint32_t)tensor->ne[2], (uint32_t)tensor->ne[3], 1 , (uint32_t)tensor->ne[0] , (uint32_t)(tensor->ne[0] * tensor->ne[1]) , (uint32_t)(tensor->ne[0] * tensor->ne[1] * tensor->ne[2]), + 0, + 0.0f, 0.0f, + 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, + }; + init_pushconst_fastdiv(pc); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, pc, elements); + ggml_vk_sync_buffers(ctx, subctx); +} + +static vk_pipeline ggml_vk_get_quantize_pipeline(ggml_backend_vk_context * ctx, ggml_type type, bool use_x4_blocks) { + switch(type) { + case GGML_TYPE_Q8_1: + return use_x4_blocks ? 
ctx->device->pipeline_quantize_q8_1_x4 : ctx->device->pipeline_quantize_q8_1; + default: + std::cerr << "Missing quantize pipeline for type: " << ggml_type_name(type) << std::endl; + GGML_ABORT("fatal error"); + } +} + +static void ggml_vk_quantize_q8_1(ggml_backend_vk_context * ctx, vk_context& subctx, vk_subbuffer&& in, vk_subbuffer&& out, uint32_t ne, bool use_x4_blocks = false) { + VK_LOG_DEBUG("ggml_vk_quantize_q8_1(" << "buffer in size=" << in.buffer->size << ", buffer out size=" << out.buffer->size << ", " << ne << ")"); + + vk_pipeline pipeline = use_x4_blocks ? ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true) : ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, false); + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { in, out }, std::array{ne}, { ne, 1, 1 }); + ggml_vk_sync_buffers(ctx, subctx); +} + +static void ggml_vk_mul_mat_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool disable_split_k, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_q_f16((" << src0 << ", name=" << src0->name << ", type=" << ggml_type_name(src0->type) << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << ggml_type_name(src1->type) << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << ggml_type_name(dst->type) << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne21 = dst->ne[1]; + const uint32_t stride_d = dst->nb[1] / ggml_type_size(dst->type); + const uint32_t stride_batch_d = stride_d*ne21; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || + !ggml_vk_dim01_contiguous(src1); + + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0; + + // Check for mmq first + vk_matmul_pipeline mmp = quantize_y ? ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, GGML_TYPE_Q8_1, (ggml_prec)dst->op_params[0]) : nullptr; + + if (mmp == nullptr) { + // Fall back to f16 dequant mul mat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + quantize_y = false; + } + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != f16_type && !y_f32_kernel) || y_non_contig); + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint32_t kpad = quantize_y ? 0 : ggml_vk_align_size(ne10, ggml_vk_guess_matmul_pipeline_align(ctx, mmp, ne01, ne11, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? 
GGML_TYPE_F32 : src1->type))); + const bool aligned = !quantize_y && ne10 == kpad && ne01 > 8 && ne11 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_pipeline(ctx, mmp, ne01, ne11, aligned, qx_needs_dequant ? f16_type : src0->type, quantize_y ? GGML_TYPE_Q8_1 : (y_f32_kernel ? GGML_TYPE_F32 : src1->type)); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) : ne11; + const int x_ne = ne01 * ne00; + const int y_ne = padded_n * ne10; + const int d_ne = ne11 * ne01; + + const uint32_t split_k = ggml_vk_guess_split_k(ctx, ne01, ne11, ne10, disable_split_k, pipeline); + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (y_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } + const uint64_t split_k_size = split_k > 1 ? 
d_sz * ne12 * ne13 * split_k : 0; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (split_k > 1 && split_k_size > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + if (split_k > 1 && ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + GGML_ASSERT(d_D->size >= d_buf_offset + d_sz * ne02 * ne03); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); + } + if (y_non_contig) { + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + 
ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant && !quantize_y) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + + // compute + ggml_vk_matmul( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz_total }, + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), { ctx->prealloc_split_k, 0, d_sz * ne12 * ne13 * split_k }, + ne01, ne11, ne10, + ne10, ne10, stride_d, stride_batch_x, stride_batch_y, stride_batch_d, + split_k, ne12*ne13, ne02, ne12, r2, r3, padded_n + ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } +} + +// Device tuning +static bool ggml_vk_should_use_mmvq(const vk_device& device, uint32_t m, uint32_t n, uint32_t k, ggml_type src0_type) { + if (device->mmvq_mode == 1) { + return true; + } else if (device->mmvq_mode == -1) { + return false; + } + + // MMVQ is generally good for batches + if (n > 1) { + return true; + } + + switch (device->vendor_id) { + case VK_VENDOR_ID_NVIDIA: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::NVIDIA_PRE_TURING; + default: + return true; + } + case VK_VENDOR_ID_AMD: + switch (src0_type) { + case GGML_TYPE_Q8_0: + return device->architecture == vk_device_architecture::AMD_GCN; + default: + return true; + } + case VK_VENDOR_ID_INTEL: + switch (src0_type) { + // From tests on A770 Linux, may need more tuning + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q5_1: + return false; + default: + return true; + } + default: + return true; + } + + GGML_UNUSED(m); + GGML_UNUSED(k); +} + +static void ggml_vk_mul_mat_vec_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t r2 = ne12 / ne02; + const uint64_t r3 = ne13 / ne03; + + // batch_n indicates that we need to compute a few vector results, and this assumes + // ne12 and ne13 are 1. It overloads the batch_strides to hold the row strides. + GGML_ASSERT(ne11 == 1 || ne12 * ne13 == 1); + bool batch_n = ne11 > 1; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + bool quantize_y = ctx->device->integer_dot_product && src1->type == GGML_TYPE_F32 && ggml_is_contiguous(src1) && (ne11 * ne10) % 4 == 0 && ggml_vk_should_use_mmvq(ctx->device, ne01, ne11, ne10, src0->type); + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + + // Check for mmq first + vk_pipeline dmmv = quantize_y ? 
ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, GGML_TYPE_Q8_1, ne11, ne20, ne00) : nullptr; + vk_pipeline to_q8_1 = nullptr; + + if (dmmv == nullptr) { + // Fall back to f16 dequant mul mat + dmmv = ggml_vk_get_dequantize_mul_mat_vec(ctx, src0->type, src1->type, ne11, ne20, ne00); + quantize_y = false; + } + + if (quantize_y) { + to_q8_1 = ggml_vk_get_quantize_pipeline(ctx, GGML_TYPE_Q8_1, true); + } + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = !quantize_y && ((src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig); + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne11 * ne01; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = quantize_y ? (y_ne * ggml_type_size(GGML_TYPE_Q8_1) / ggml_blck_size(GGML_TYPE_Q8_1)) : (f16_f32_kernel ? sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne); + const uint64_t d_sz = sizeof(float) * d_ne; + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + uint64_t y_sz_upd = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_upd = CEIL_DIV(y_sz_upd, 144) * 144; + } + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if ((qy_needs_dequant || quantize_y) && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + if (quantize_y) { + ggml_pipeline_request_descriptor_sets(ctx, to_q8_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else if (quantize_y) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= CEIL_DIV(y_sz * ne12 * ne13, 144) * 144); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + 
} + + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + if (quantize_y) { + if (ctx->prealloc_y_last_pipeline_used != to_q8_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0), y_ne * ne12 * ne13, true); + ctx->prealloc_y_last_pipeline_used = to_q8_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + // For batch_n, the A matrix is the same for each batch, and B/D use the row stride as the batch stride + uint32_t stride_batch_x = batch_n ? 0 : ne00*ne01; + uint32_t stride_batch_y = batch_n ? ne10 : (ne10*ne11); + uint32_t stride_batch_d = batch_n ? ne20 : (ne20*ne21); + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // TODO: Clean up this whole sz * ne_2 * ne_3 thing, it hasn't been necessary for a long time + uint32_t y_sz_total = y_sz * ne12 * ne13; + if (quantize_y) { + y_sz_total = CEIL_DIV(y_sz_total, 144) * 144; + } + + // compute + const vk_mat_vec_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + stride_batch_x, stride_batch_y, stride_batch_d, + (uint32_t)ne02, (uint32_t)ne12, (uint32_t)r2, (uint32_t)r3, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, vk_subbuffer{ d_Y, y_buf_offset, y_sz_total }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23} }, + pc, { groups_x, (uint32_t)(ne12 * ne13), groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig || quantize_y) { + ctx->prealloc_y_need_sync = true; + } +} + +static void ggml_vk_mul_mat_vec_p021_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_p021_f16_f32(" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", 
name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(ggml_is_permuted(src0) && ggml_is_permuted(src1)); + GGML_ASSERT(src0->nb[0] <= src0->nb[1] && src0->nb[2] <= src0->nb[3]); // NOLINT + GGML_ASSERT(src1->nb[0] <= src1->nb[1] && src1->nb[2] <= src1->nb[3]); // NOLINT + GGML_ASSERT(src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + // const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + // const uint64_t ne13 = src1->ne[3]; + + GGML_ASSERT(ne11 == 1); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + + bool src1_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + src1_uma = d_Qy != nullptr; + } + + const uint64_t x_ne = ne00 * ne01 * ne02; + const uint64_t y_ne = ne10 * ne11 * ne12; + const uint64_t d_ne = ne01 * ne11 * ne12; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t d_sz = sizeof(float) * d_ne; + + // With grouped query attention there are > 1 Q matrices per K, V matrix. 
+    uint32_t gqa_ratio = (uint32_t)ne12 / (uint32_t)ne02;
+    if (gqa_ratio > 8 || gqa_ratio == 0 || ne12 != ne02 * gqa_ratio) {
+        gqa_ratio = 1;
+    }
+
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], 1);
+        return;
+    }
+
+    vk_buffer d_D = dst_buf_ctx->dev_buffer;
+    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
+    GGML_ASSERT(d_D != nullptr);
+    vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
+    const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
+    GGML_ASSERT(d_Qx != nullptr);
+    if (!src1_uma) {
+        d_Qy = src1_buf_ctx->dev_buffer;
+        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
+        GGML_ASSERT(d_Qy != nullptr);
+    }
+
+    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
+
+    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
+
+    // compute
+    const std::array<uint32_t, 6> pc = { (uint32_t)ne00, (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)) };
+
+    uint32_t workgroups_z = (uint32_t)ne12;
+    // When gqa_ratio > 1, each invocation does multiple rows and we can launch fewer workgroups
+    if (gqa_ratio > 1) {
+        workgroups_z /= gqa_ratio;
+    }
+
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_p021_f16_f32[gqa_ratio - 1], { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, { 1, (uint32_t)ne01, workgroups_z });
+}
+
+static void ggml_vk_mul_mat_vec_nc_f16_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) {
+    VK_LOG_DEBUG("ggml_vk_mul_mat_nc_f16_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3];
+    std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3];
+    std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3];
+    std::cerr << "), " << (dryrun ? "dryrun" : "") << ")");
+    GGML_ASSERT(!ggml_is_transposed(src0));
+    GGML_ASSERT(!ggml_is_transposed(src1));
+    GGML_ASSERT(!ggml_is_permuted(src0));
+    GGML_ASSERT(src0->type == GGML_TYPE_F16);
+    GGML_ASSERT(src1->type == GGML_TYPE_F32);
+
+    const uint64_t ne00 = src0->ne[0];
+    const uint64_t ne01 = src0->ne[1];
+    const uint64_t ne02 = src0->ne[2];
+    const uint64_t ne03 = src0->ne[3];
+
+    const uint64_t nb01 = src0->nb[1];
+    const uint64_t nb02 = src0->nb[2];
+
+    const uint64_t nb12 = src1->nb[2];
+
+    // const uint64_t ne10 = src1->ne[0];
+    const uint64_t ne11 = src1->ne[1];
+    const uint64_t ne12 = src1->ne[2];
+    // const uint64_t ne13 = src1->ne[3];
+
+    const uint32_t nb03 = (uint32_t)(src0->nb[3] / sizeof(ggml_fp16_t));
+    const uint32_t nb13 = (uint32_t)(src1->nb[3] / sizeof(float));
+    const uint32_t nb23 = (uint32_t)(dst->nb[3] / sizeof(float));
+
+    GGML_ASSERT(ne11 == 1);
+    GGML_ASSERT(src0->ne[3] == src1->ne[3]); // checked in supports_op
+
+    ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context;
+    ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context;
+    ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context;
+
+    vk_buffer d_Qy = nullptr;
+    size_t qy_buf_offset = 0;
+
+    bool src1_uma = false;
+
+    if (ctx->device->uma) {
+        ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset);
+        src1_uma = d_Qy != nullptr;
+    }
+
+    const uint64_t d_ne = ne01 * ne11 * ne12 * ne03;
+
+    const uint32_t row_stride_x = nb01 / sizeof(ggml_fp16_t);
+    const uint32_t channel_stride_x = nb02 / sizeof(ggml_fp16_t);
+    const uint32_t channel_stride_y = nb12 / sizeof(float);
+
+    const uint64_t qx_sz = ggml_nbytes(src0);
+    const uint64_t qy_sz = ggml_nbytes(src1);
+    const uint64_t d_sz = sizeof(float) * d_ne;
+
+    if (dryrun) {
+        // Request descriptor sets
+        ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32, 1);
+        return;
+    }
+
+    vk_buffer d_D = dst_buf_ctx->dev_buffer;
+    const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs;
+    GGML_ASSERT(d_D != nullptr);
+    vk_buffer d_Qx = src0_buf_ctx->dev_buffer;
+    const uint64_t qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs;
+    GGML_ASSERT(d_Qx != nullptr);
+    if (!src1_uma) {
+        d_Qy = src1_buf_ctx->dev_buffer;
+        qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs;
+        GGML_ASSERT(d_Qy != nullptr);
+    }
+
+    const uint64_t qy_buffer_offset = (qy_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t qy_shader_offset = qy_buf_offset - qy_buffer_offset;
+
+    const uint64_t d_buffer_offset = (d_buf_offset / ctx->device->properties.limits.minStorageBufferOffsetAlignment) * ctx->device->properties.limits.minStorageBufferOffsetAlignment;
+    const uint64_t d_shader_offset = d_buf_offset - d_buffer_offset;
+
+    // compute
+    const std::array<uint32_t, 12> pc = { (uint32_t)ne00, (uint32_t)ne01, row_stride_x, channel_stride_x, channel_stride_y, (uint32_t)(ne12 / ne02), (uint32_t)ne12, (uint32_t)(qy_shader_offset / ggml_type_size(src1->type)), (uint32_t)(d_shader_offset / ggml_type_size(dst->type)), nb03, nb13, nb23 };
+    ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_mul_mat_vec_nc_f16_f32,
+        { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz }, vk_subbuffer{ d_Qy, qy_buffer_offset, qy_sz + qy_shader_offset }, vk_subbuffer{ d_D, d_buffer_offset, d_sz + d_shader_offset } }, pc, {
(uint32_t)ne03, (uint32_t)ne01, (uint32_t)ne12 }); +} + +static void ggml_vk_mul_mat(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * src0, ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat(" << src0 << ", " << src1 << ", " << dst << ")"); + + // Handle huge A matrix by splitting the M dimensions. This works well for convolution use cases + // where the M dimension is very large. + // Split_k doesn't work with M splitting. + const size_t nbytes = ggml_nbytes(src0); + const bool needs_split = nbytes > ctx->device->properties.limits.maxStorageBufferRange; + if (needs_split) { + // Choose the number of rows that can fit (and divide by two, to allow for any additional offsets) + const uint32_t M_split = ctx->device->properties.limits.maxStorageBufferRange / (2 * src0->nb[1]); + uint32_t m_offset = 0; + while (m_offset < dst->ne[0]) { + const uint32_t cur_M_size = std::min(M_split, (uint32_t)(dst->ne[0] - m_offset)); + ggml_tensor dst2 = *dst; + ggml_tensor src02 = *src0; + + dst2.view_src = dst->view_src ? dst->view_src : dst; + src02.view_src = src0->view_src ? src0->view_src : src0; + + dst2.view_offs += m_offset * dst->nb[0]; + src02.view_offs += m_offset * src0->nb[1]; + dst2.ne[0] = cur_M_size; + src02.ne[1] = cur_M_size; + + ggml_vk_mul_mat_q_f16(ctx, subctx, &src02, src1, &dst2, true, dryrun); + + m_offset += cur_M_size; + } + } else if (src0->type == GGML_TYPE_F16 && ggml_is_permuted(src0) && ggml_is_permuted(src1) && dst->ne[1] == 1 && + // detect 0213 permutation, and batch size of 1 + src0->nb[0] <= src0->nb[2] && + src0->nb[2] <= src0->nb[1] && + src0->nb[1] <= src0->nb[3] && + src1->nb[0] <= src1->nb[2] && + src1->nb[2] <= src1->nb[1] && + src1->nb[1] <= src1->nb[3] && + src0->ne[3] == 1 && + src1->ne[3] == 1) { + ggml_vk_mul_mat_vec_p021_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + } else if (src0->type == GGML_TYPE_F16 && !ggml_is_contiguous(src0) && !ggml_is_transposed(src1) && dst->ne[1] == 1 && + !ggml_is_permuted(src0) && !ggml_is_permuted(src1)) { + ggml_vk_mul_mat_vec_nc_f16_f32(ctx, subctx, src0, src1, dst, dryrun); + // mul_mat_vec supports batching ne12*ne13 when ne11==1, or treating ne11 as the batch size (up to four) + // when ne12 and ne13 are one. 
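+        // For reference, the full dispatch chain of ggml_vk_mul_mat is:
+        //   1. an A matrix larger than maxStorageBufferRange is split along M and multiplied chunk by chunk;
+        //   2. f16 A, 0213-permuted operands and dst->ne[1] == 1 use the p021 kernel;
+        //   3. f16 A that is non-contiguous but unpermuted, again with dst->ne[1] == 1, takes this nc branch;
+        //   4. dst->ne[1] == 1 (or at most mul_mat_vec_max_cols columns with no batching in src1) uses the generic mul_mat_vec path below;
+        //   5. everything else falls through to the full matrix-matrix multiply.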
+ } else if ((dst->ne[1] == 1 || (dst->ne[1] <= mul_mat_vec_max_cols && src1->ne[2] * src1->ne[3] == 1)) && + (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_q_f16(ctx, subctx, src0, src1, dst, dryrun); + } else { + ggml_vk_mul_mat_q_f16(ctx, subctx, src0, src1, dst, false, dryrun); + } +} + +static void ggml_vk_mul_mat_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3] << "),)"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint32_t nbi1 = ids->nb[1]; + const uint32_t nbi2 = ids->nb[2]; + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + const uint64_t n_as = ne02; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy 
!= nullptr; + ids_uma = d_ids != nullptr; + } + + // Reformat and convert to fp16 if non-contiguous, or for coopmat2 for better perf + const bool x_non_contig = (ctx->device->coopmat2 && src0->type == GGML_TYPE_F32) || + !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = (ctx->device->coopmat2 && src1->type == GGML_TYPE_F32) || + (src0->type == GGML_TYPE_BF16 && src1->type != GGML_TYPE_BF16) || + !ggml_vk_dim01_contiguous(src1); + + // If src0 is BF16, try to use a BF16 x BF16 multiply + ggml_type f16_type = src0->type == GGML_TYPE_BF16 ? GGML_TYPE_BF16 : GGML_TYPE_F16; + + const bool y_f32_kernel = src1->type == GGML_TYPE_F32 && !y_non_contig; + + vk_matmul_pipeline mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, src0->type, y_non_contig ? f16_type : src1->type, (ggml_prec)dst->op_params[0]); + + const bool qx_needs_dequant = mmp == nullptr || x_non_contig; + const bool qy_needs_dequant = (src1->type != f16_type && !y_f32_kernel) || y_non_contig; + + if (qx_needs_dequant) { + // Fall back to dequant + f16 mulmat + mmp = ggml_vk_get_mul_mat_mat_id_pipeline(ctx, f16_type, y_f32_kernel ? GGML_TYPE_F32 : f16_type, (ggml_prec)dst->op_params[0]); + } + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint32_t kpad = ggml_vk_align_size(ne10, ggml_vk_guess_matmul_id_pipeline_align(ctx, mmp, ne01, nei1, qx_needs_dequant ? f16_type : src0->type)); + const bool aligned = ne10 == kpad && ne01 > 8 && nei1 > 8; + + vk_pipeline pipeline = ggml_vk_guess_matmul_id_pipeline(ctx, mmp, ne01, nei1, aligned, qx_needs_dequant ? f16_type : src0->type); + + // Reserve extra storage in the N dimension for the Y matrix, so we can avoid bounds-checking + uint32_t padded_n = qy_needs_dequant ? ROUNDUP_POW2(ne11, pipeline->wg_denoms[1]) :ne11; + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = padded_n * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = !qx_needs_dequant ? qx_sz : sizeof(ggml_fp16_t) * x_ne; + const uint64_t y_sz = y_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, f16_type); + } else { + to_fp16_vk_0 = ggml_vk_get_to_fp16(ctx, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, f16_type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if (!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if (!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if (!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + GGML_ASSERT(d_X->size >= x_sz * ne02 * ne03); + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + GGML_ASSERT(d_Y->size >= y_sz * ne12 * ne13); + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig || qx_needs_dequant) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + + if (x_non_contig) { + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } else if (qx_needs_dequant) { + const std::vector pc = { (uint32_t)ne01, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)(ggml_nelements(src0)) }; + ggml_vk_dispatch_pipeline(ctx, subctx, to_fp16_vk_0, + { vk_subbuffer{ d_Qx, qx_buf_offset, qx_sz * ne02 * ne03 }, vk_subbuffer{ d_X, 0, x_sz * ne02 * ne03 } }, pc, { (uint32_t)(x_ne * ne02 * ne03), 1, 1}); + ggml_vk_sync_buffers(ctx, subctx); + } + if (y_non_contig) { + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, 
to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + uint32_t stride_batch_x = ne00*ne01; + uint32_t stride_batch_y = ne10*ne11; + + if (!ggml_vk_dim01_contiguous(src0) && !qx_needs_dequant) { + stride_batch_x = src0->nb[0] / ggml_type_size(src0->type); + } + + if (!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + // compute + ggml_vk_matmul_id( + ctx, subctx, pipeline, + { d_X, x_buf_offset, x_sz * ne02 * ne03 }, { d_Y, y_buf_offset, y_sz * ne12 * ne13 }, + { d_D, d_buf_offset, d_sz * ne22 * ne23 }, { d_ids, ids_buf_offset, ids_sz }, + ne01, ne21, ne10, ne10, ne10, ne01, + stride_batch_x, stride_batch_y, ne20*ne21, + n_as, nei0, nei1, nbi1 / ggml_type_size(ids->type), ne11, padded_n + ); // NOLINT + + if (x_non_contig || qx_needs_dequant) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } +} + +static void ggml_vk_mul_mat_vec_id_q_f16(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * ids, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_vec_id_q_f16((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + std::cerr << "), (" << ids << ", name=" << ids->name << ", type=" << ids->type << ", ne0=" << ids->ne[0] << ", ne1=" << ids->ne[1] << ", ne2=" << ids->ne[2] << ", ne3=" << ids->ne[3] << ", nb0=" << ids->nb[0] << ", nb1=" << ids->nb[1] << ", nb2=" << ids->nb[2] << ", nb3=" << ids->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << (dryrun ? 
"dryrun" : "") << ")"); + GGML_ASSERT(ggml_vk_dim01_contiguous(src0) || src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || src0->type == GGML_TYPE_BF16); // NOLINT + GGML_ASSERT(ggml_vk_dim01_contiguous(src1) || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); // NOLINT + GGML_ASSERT(ids->type == GGML_TYPE_I32); + + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + + const uint64_t ne10 = src1->ne[0]; + const uint64_t ne11 = src1->ne[1]; + const uint64_t ne12 = src1->ne[2]; + const uint64_t ne13 = src1->ne[3]; + + const uint64_t nei0 = ids->ne[0]; + const uint64_t nei1 = ids->ne[1]; + + const uint64_t nbi2 = ids->nb[2]; + + GGML_ASSERT(nei1 == 1); + + const uint64_t ne20 = dst->ne[0]; + const uint64_t ne21 = dst->ne[1]; + const uint64_t ne22 = dst->ne[2]; + const uint64_t ne23 = dst->ne[3]; + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = (ggml_backend_vk_buffer_context *)src1->buffer->context; + ggml_backend_vk_buffer_context * ids_buf_ctx = (ggml_backend_vk_buffer_context *)ids->buffer->context; + + vk_buffer d_Qx = nullptr; + size_t qx_buf_offset = 0; + vk_buffer d_Qy = nullptr; + size_t qy_buf_offset = 0; + vk_buffer d_ids = nullptr; + size_t ids_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool ids_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_Qx, qx_buf_offset); + ggml_vk_host_get(ctx->device, src1->data, d_Qy, qy_buf_offset); + ggml_vk_host_get(ctx->device, ids->data, d_ids, ids_buf_offset); + src0_uma = d_Qx != nullptr; + src1_uma = d_Qy != nullptr; + ids_uma = d_ids != nullptr; + } + + const bool x_non_contig = !ggml_vk_dim01_contiguous(src0); + const bool y_non_contig = !ggml_vk_dim01_contiguous(src1); + + const bool f16_f32_kernel = src1->type == GGML_TYPE_F32; + + const bool qx_needs_dequant = x_non_contig; + const bool qy_needs_dequant = (src1->type != GGML_TYPE_F16 && !f16_f32_kernel) || y_non_contig; + + // Not implemented + GGML_ASSERT(y_non_contig || !qy_needs_dequant); // NOLINT + + const uint64_t x_ne = ne01 * ne00; + const uint64_t y_ne = ne11 * ne10; + const uint64_t d_ne = ne21 * ne20; + + const uint64_t qx_sz = ggml_vk_align_size(ggml_type_size(src0->type) * x_ne / ggml_blck_size(src0->type), ctx->device->properties.limits.minStorageBufferOffsetAlignment); + const uint64_t qy_sz = ggml_type_size(src1->type) * y_ne / ggml_blck_size(src1->type); + const uint64_t x_sz = x_non_contig ? ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment) : qx_sz; + const uint64_t y_sz = f16_f32_kernel ? 
sizeof(float) * y_ne : sizeof(ggml_fp16_t) * y_ne; + const uint64_t ids_sz = nbi2; + const uint64_t d_sz = sizeof(float) * d_ne; + + vk_pipeline to_fp16_vk_0 = nullptr; + vk_pipeline to_fp16_vk_1 = nullptr; + if (x_non_contig) { + to_fp16_vk_0 = ggml_vk_get_cpy_pipeline(ctx, src0, nullptr, src0->type); + } + if (y_non_contig) { + to_fp16_vk_1 = ggml_vk_get_cpy_pipeline(ctx, src1, nullptr, src1->type); + } else { + to_fp16_vk_1 = ggml_vk_get_to_fp16(ctx, src1->type); + } + vk_pipeline dmmv = ggml_vk_get_dequantize_mul_mat_vec_id(ctx, src0->type, src1->type); + GGML_ASSERT(!qx_needs_dequant || to_fp16_vk_0 != nullptr); // NOLINT + GGML_ASSERT(!qy_needs_dequant || to_fp16_vk_1 != nullptr); // NOLINT + GGML_ASSERT(dmmv != nullptr); + + if (dryrun) { + const uint64_t x_sz_upd = x_sz * ne02 * ne03; + const uint64_t y_sz_upd = y_sz * ne12 * ne13; + if ( + (qx_needs_dequant && x_sz_upd > ctx->device->properties.limits.maxStorageBufferRange) || + (qy_needs_dequant && y_sz_upd > ctx->device->properties.limits.maxStorageBufferRange)) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (qx_needs_dequant && ctx->prealloc_size_x < x_sz_upd) { + ctx->prealloc_size_x = x_sz_upd; + } + if (qy_needs_dequant && ctx->prealloc_size_y < y_sz_upd) { + ctx->prealloc_size_y = y_sz_upd; + } + + // Request descriptor sets + if (qx_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_0, 1); + } + if (qy_needs_dequant) { + ggml_pipeline_request_descriptor_sets(ctx, to_fp16_vk_1, 1); + } + ggml_pipeline_request_descriptor_sets(ctx, dmmv, 1); + return; + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + const uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + GGML_ASSERT(d_D != nullptr); + vk_buffer d_X; + uint64_t x_buf_offset = 0; + vk_buffer d_Y; + uint64_t y_buf_offset = 0; + if(!src0_uma) { + d_Qx = src0_buf_ctx->dev_buffer; + qx_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_Qx != nullptr); + } + if(!src1_uma) { + d_Qy = src1_buf_ctx->dev_buffer; + qy_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Qy != nullptr); + } + if(!ids_uma) { + d_ids = ids_buf_ctx->dev_buffer; + ids_buf_offset = vk_tensor_offset(ids) + ids->view_offs; + GGML_ASSERT(d_ids != nullptr); + } + if (qx_needs_dequant) { + d_X = ctx->prealloc_x; + } else { + d_X = d_Qx; + x_buf_offset = qx_buf_offset; + GGML_ASSERT(qx_sz == x_sz); + } + if (qy_needs_dequant) { + d_Y = ctx->prealloc_y; + } else { + d_Y = d_Qy; + y_buf_offset = qy_buf_offset; + GGML_ASSERT(qy_sz == y_sz); + } + + if (x_non_contig) { + if (ctx->prealloc_x_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + } + + if (x_non_contig) { + GGML_ASSERT(x_sz == ggml_vk_align_size(ggml_type_size(src0->type) * x_ne, ctx->device->properties.limits.minStorageBufferOffsetAlignment)); + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_0, src0, ggml_vk_subbuffer(ctx, d_Qx, qx_buf_offset), ggml_vk_subbuffer(ctx, d_X, 0)); + } + if (y_non_contig) { + GGML_ASSERT(y_sz == ggml_type_size(src1->type) * y_ne); + if (ctx->prealloc_y_last_pipeline_used != to_fp16_vk_1.get() || + ctx->prealloc_y_last_tensor_used != src1) { + if (ctx->prealloc_y_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + ggml_vk_cpy_to_contiguous(ctx, subctx, to_fp16_vk_1, src1, ggml_vk_subbuffer(ctx, d_Qy, qy_buf_offset), ggml_vk_subbuffer(ctx, d_Y, 0)); + ctx->prealloc_y_last_pipeline_used = to_fp16_vk_1.get(); + ctx->prealloc_y_last_tensor_used = src1; + } + } + + uint32_t stride_batch_y = ne10*ne11; + + if 
(!ggml_vk_dim01_contiguous(src1) && !qy_needs_dequant) { + stride_batch_y = src1->nb[0] / ggml_type_size(src1->type); + } + + const uint32_t max_groups_x = ctx->device->properties.limits.maxComputeWorkGroupCount[0]; + + uint32_t groups_x = ne01; + uint32_t groups_z = 1; + + if (ne01 > max_groups_x) { + groups_z = 64; + groups_x = CEIL_DIV(groups_x, groups_z); + } + + // compute + const vk_mat_vec_id_push_constants pc = { + (uint32_t)ne00, (uint32_t)ne10, (uint32_t)ne10, (uint32_t)ne01, + (uint32_t)x_ne, stride_batch_y, (uint32_t)(ne20*ne21), + (uint32_t)nei0, (uint32_t)ne11, + }; + ggml_vk_dispatch_pipeline(ctx, subctx, dmmv, + { vk_subbuffer{ d_X, x_buf_offset, x_sz * ne02 * ne03 }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz * ne12 * ne13 }, vk_subbuffer{ d_D, d_buf_offset, d_sz * ne22 * ne23}, vk_subbuffer{ d_ids, ids_buf_offset, ids_sz } }, + pc, { groups_x, (uint32_t)nei0, groups_z }); + + if (x_non_contig) { + ctx->prealloc_x_need_sync = true; + } + if (y_non_contig) { + ctx->prealloc_y_need_sync = true; + } +} + +static void ggml_vk_mul_mat_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_mul_mat_id(" << src0 << ", " << src1 << ", " << src2 << ", " << dst << ")"); + if (src2->ne[1] == 1 && (src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16 || ggml_is_quantized(src0->type))) { + ggml_vk_mul_mat_vec_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } else { + ggml_vk_mul_mat_id_q_f16(ctx, subctx, src0, src1, src2, dst, dryrun); + } +} + +static bool ggml_vk_flash_attn_scalar_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = get_fa_scalar_num_large_rows(hsv); + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * sizeof(float); + + const uint32_t masksh = Bc * Br * sizeof(float); + + const uint32_t Qf = Br * (hsk / 4 + 2) * 4 * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + masksh + Qf; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static bool ggml_vk_flash_attn_coopmat_shmem_support(const vk_device& device, const uint32_t hsk, uint32_t hsv, bool f32acc) { + // Needs to be kept up to date on shader changes + GGML_UNUSED(hsv); + const uint32_t wg_size = scalar_flash_attention_workgroup_size; + const uint32_t Br = coopmat1_flash_attention_num_large_rows; + const uint32_t Bc = scalar_flash_attention_Bc; + + const uint32_t hsk_pad = ROUNDUP_POW2(hsk, 16); + + const uint32_t acctype = f32acc ? 4 : 2; + const uint32_t f16vec4 = 8; + + const uint32_t tmpsh = wg_size * sizeof(float); + const uint32_t tmpshv4 = wg_size * 4 * acctype; + + const uint32_t qstride = hsk_pad / 4 + 2; + const uint32_t Qf = Br * qstride * f16vec4; + + const uint32_t sfshstride = (hsk <= 128) ? 
(Br + 8) : Br; + const uint32_t sfsh = Bc * sfshstride * acctype; + + const uint32_t kshstride = hsk_pad / 4 + 2; + const uint32_t ksh = Bc * kshstride * f16vec4; + + const uint32_t slope = Br * sizeof(float); + + const uint32_t total_size = tmpsh + tmpshv4 + Qf + sfsh + ksh + slope; + const bool supported = total_size <= device->properties.limits.maxComputeSharedMemorySize; + + VK_LOG_DEBUG("ggml_vk_flash_attn_coopmat_shmem_support(HSK=" << hsk << ", HSV=" << hsv << ", f32acc=" << f32acc << ", total_size=" << total_size << ", supported=" << supported); + + return supported; +} + +static void ggml_vk_flash_attn(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * q, const ggml_tensor * k, const ggml_tensor * v, const ggml_tensor * mask, const ggml_tensor * sinks, ggml_tensor * dst, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_flash_attn((" << q << ", name=" << q->name << ", type=" << q->type << ", ne0=" << q->ne[0] << ", ne1=" << q->ne[1] << ", ne2=" << q->ne[2] << ", ne3=" << q->ne[3] << ", nb0=" << q->nb[0] << ", nb1=" << q->nb[1] << ", nb2=" << q->nb[2] << ", nb3=" << q->nb[3]; + std::cerr << "), (" << k << ", name=" << k->name << ", type=" << k->type << ", ne0=" << k->ne[0] << ", ne1=" << k->ne[1] << ", ne2=" << k->ne[2] << ", ne3=" << k->ne[3] << ", nb0=" << k->nb[0] << ", nb1=" << k->nb[1] << ", nb2=" << k->nb[2] << ", nb3=" << k->nb[3]; + std::cerr << "), (" << v << ", name=" << v->name << ", type=" << v->type << ", ne0=" << v->ne[0] << ", ne1=" << v->ne[1] << ", ne2=" << v->ne[2] << ", ne3=" << v->ne[3] << ", nb0=" << v->nb[0] << ", nb1=" << v->nb[1] << ", nb2=" << v->nb[2] << ", nb3=" << v->nb[3]; + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + if (sinks) { + std::cerr << "), (" << sinks << ", name=" << sinks->name << ", type=" << sinks->type << ", ne0=" << sinks->ne[0] << ", ne1=" << sinks->ne[1] << ", ne2=" << sinks->ne[2] << ", ne3=" << sinks->ne[3] << ", nb0=" << sinks->nb[0] << ", nb1=" << sinks->nb[1] << ", nb2=" << sinks->nb[2] << ", nb3=" << sinks->nb[3]; + } + std::cerr << "), " << (dryrun ? "dryrun" : "") << ")"); + + GGML_TENSOR_LOCALS(int64_t, neq, q, ne) + GGML_TENSOR_LOCALS(size_t, nbq, q, nb) + GGML_TENSOR_LOCALS(int64_t, nek, k, ne) + GGML_TENSOR_LOCALS(size_t, nbk, k, nb) + GGML_TENSOR_LOCALS(int64_t, nev, v, ne) + GGML_TENSOR_LOCALS(size_t, nbv, v, nb) + GGML_TENSOR_LOCALS(int64_t, ne, dst, ne) + GGML_TENSOR_LOCALS(size_t, nb, dst, nb) + + const uint32_t nem1 = mask ? mask->ne[1] : 0; + const uint32_t nem2 = mask ? mask->ne[2] : 0; + const uint32_t nem3 = mask ? 
mask->ne[3] : 0; + + const uint32_t HSK = nek0; + const uint32_t HSV = nev0; + uint32_t N = neq1; + const uint32_t KV = nek1; + + GGML_ASSERT(ne0 == HSV); + GGML_ASSERT(ne2 == N); + + // input tensor rows must be contiguous + GGML_ASSERT(nbq0 == ggml_type_size(q->type)); + GGML_ASSERT(nbk0 == ggml_type_size(k->type)); + GGML_ASSERT(nbv0 == ggml_type_size(v->type)); + + GGML_ASSERT(neq0 == HSK); + + GGML_ASSERT(neq1 == N); + + GGML_ASSERT(nev1 == nek1); + + // dst cannot be transposed or permuted + GGML_ASSERT(nb0 == sizeof(float)); + GGML_ASSERT(nb0 <= nb1); + GGML_ASSERT(nb1 <= nb2); + GGML_ASSERT(nb2 <= nb3); + + assert(dst->type == GGML_TYPE_F32); + assert(q->type == GGML_TYPE_F32); + assert(k->type == v->type); + + FaCodePath path = ctx->device->coopmat2 ? FA_COOPMAT2 : + ctx->device->coopmat1_fa_support ? FA_COOPMAT1 : FA_SCALAR; + + if (path == FA_COOPMAT1) { + const bool coopmat_shape_supported = (dst->op_params[3] == GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f32acc) || + (dst->op_params[3] != GGML_PREC_F32 && ctx->device->coopmat_support_16x16x16_f16acc); + + const bool coopmat_shmem_supported = ggml_vk_flash_attn_coopmat_shmem_support(ctx->device, HSK, HSV, dst->op_params[3] == GGML_PREC_F32); + + if (!coopmat_shape_supported || !coopmat_shmem_supported) { + path = FA_SCALAR; + } + } + + uint32_t gqa_ratio = 1; + uint32_t qk_ratio = neq2 / nek2; + uint32_t workgroups_x = (uint32_t)neq1; + uint32_t workgroups_y = (uint32_t)neq2; + uint32_t workgroups_z = (uint32_t)neq3; + + // For scalar/coopmat1 FA, we can use the "large" size to accommodate qga. + // For coopmat2 FA, we always use the small size (which is still pretty large for gqa). + uint32_t max_gqa; + switch (path) { + case FA_SCALAR: + case FA_COOPMAT1: + // We may switch from coopmat1 to scalar, so use the scalar limit for both + max_gqa = get_fa_scalar_num_large_rows(HSV); + break; + case FA_COOPMAT2: + max_gqa = get_fa_num_small_rows(FA_COOPMAT2); + break; + default: + GGML_ASSERT(0); + } + + if (N == 1 && qk_ratio > 1 && qk_ratio <= max_gqa && + qk_ratio * nek2 == neq2 && nek2 == nev2 && nem2 <= 1) { + // grouped query attention - make the N dimension equal to gqa_ratio, reduce + // workgroups proportionally in y dimension. The shader will detect gqa_ratio > 1 + // and change addressing calculations to index Q's dimension 2. + gqa_ratio = qk_ratio; + N = gqa_ratio; + workgroups_y /= N; + } + + bool small_rows = N <= get_fa_num_small_rows(path); + + // coopmat1 does not actually support "small rows" (it needs 16 rows). + // So use scalar instead. + if (small_rows && path == FA_COOPMAT1) { + path = FA_SCALAR; + } + + // scalar is faster than coopmat2 when N==1 + if (N == 1 && path == FA_COOPMAT2) { + path = FA_SCALAR; + } + + // with large hsk/hsv, scalar path may need to use small_rows to fit in shared memory + if (path == FA_SCALAR && + !ggml_vk_flash_attn_scalar_shmem_support(ctx->device, HSK, HSV)) { + small_rows = true; + } + + const uint32_t q_stride = (uint32_t)(nbq1 / ggml_type_size(q->type)); + const uint32_t k_stride = (uint32_t)(nbk1 / ggml_type_size(k->type)); + const uint32_t v_stride = (uint32_t)(nbv1 / ggml_type_size(v->type)); + + uint32_t alignment = fa_align(path, HSK, HSV, k->type, small_rows); + bool aligned = (KV % alignment) == 0 && + // the "aligned" shader variant will forcibly align strides, for performance + (q_stride & 7) == 0 && (k_stride & 7) == 0 && (v_stride & 7) == 0; + + // Need to use the coopmat2 variant that clamps loads when HSK/HSV aren't sufficiently aligned. 
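+    // (HSK | HSV) % 16 != 0 holds exactly when at least one of the two head sizes is not a
+    // multiple of 16. Worked example: HSK = HSV = 72 gives (72 | 72) % 16 == 8, so a coopmat2
+    // dispatch must take the clamping (unaligned) variant even if KV and the q/k/v strides
+    // pass the alignment check above.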
+ if (((HSK | HSV) % 16) != 0 && path == FA_COOPMAT2) { + aligned = false; + } + + bool f32acc = path == FA_SCALAR || dst->op_params[3] == GGML_PREC_F32; + + vk_fa_pipeline_state fa_pipeline_state(HSK, HSV, small_rows, path, aligned, f32acc); + + vk_pipeline pipeline = nullptr; + + auto &pipelines = ctx->device->pipeline_flash_attn_f32_f16[k->type]; + auto it = pipelines.find(fa_pipeline_state); + if (it != pipelines.end()) { + pipeline = it->second; + } else { + pipelines[fa_pipeline_state] = pipeline = std::make_shared(); + } + + assert(pipeline); + + uint32_t split_kv = KV; + uint32_t split_k = 1; + + // Use a placeholder core count if one isn't available. split_k is a big help for perf. + const uint32_t shader_core_count = ctx->device->shader_core_count ? ctx->device->shader_core_count : 16; + + // Try to use split_k when KV is large enough to be worth the overhead + if (workgroups_x == 1 && shader_core_count > 0) { + // Try to run two workgroups per SM. + split_k = shader_core_count * 2 / (workgroups_y * workgroups_z); + if (split_k > 1) { + // Try to evenly split KV into split_k chunks, but it needs to be a multiple + // of "align", so recompute split_k based on that. + split_kv = ROUNDUP_POW2(std::max(1u, KV / split_k), alignment); + split_k = CEIL_DIV(KV, split_kv); + workgroups_x = split_k; + } + } + + // Reserve space for split_k temporaries. For each split x batch, we need to store the O matrix (D x ne1) + // and the per-row m and L values (ne1 rows). We store all the matrices first, followed by the rows. + const uint64_t split_k_size = split_k > 1 ? (HSV * ne1 * sizeof(float) + ne1 * sizeof(float) * 2) * split_k * ne3 : 0; + if (split_k_size > ctx->device->properties.limits.maxStorageBufferRange) { + GGML_ABORT("Requested preallocation size is too large"); + } + if (ctx->prealloc_size_split_k < split_k_size) { + ctx->prealloc_size_split_k = split_k_size; + } + + if (dryrun) { + // Request descriptor sets + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_flash_attn_split_k_reduce, 1); + } + return; + } + + float scale = 1.0f; + float max_bias = 0.0f; + float logit_softcap = 0.0f; + + memcpy(&scale, (const float *) dst->op_params + 0, sizeof(float)); + memcpy(&max_bias, (const float *) dst->op_params + 1, sizeof(float)); + memcpy(&logit_softcap, (const float *) dst->op_params + 2, sizeof(float)); + + if (logit_softcap != 0) { + scale /= logit_softcap; + } + + const uint32_t n_head_kv = neq2; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + vk_buffer d_Q = nullptr, d_K = nullptr, d_V = nullptr, d_D = nullptr, d_M = nullptr, d_S = nullptr; + size_t q_buf_offset = 0, k_buf_offset = 0, v_buf_offset = 0, d_buf_offset = 0, m_buf_offset = 0, s_buf_offset = 0; + + bool Q_uma = false, K_uma = false, V_uma = false, D_uma = false, M_uma = false, S_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, q->data, d_Q, q_buf_offset); + ggml_vk_host_get(ctx->device, k->data, d_K, k_buf_offset); + ggml_vk_host_get(ctx->device, v->data, d_V, v_buf_offset); + ggml_vk_host_get(ctx->device, dst->data, d_D, d_buf_offset); + Q_uma = d_Q != nullptr; + K_uma = d_K != nullptr; + V_uma = d_V != nullptr; + D_uma = d_D != nullptr; + if (mask) { + ggml_vk_host_get(ctx->device, mask->data, d_M, m_buf_offset); + M_uma = d_M != nullptr; + } 
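+        // mask and sinks are optional; ggml_vk_host_get only fills in d_M / d_S when the host
+        // pointer resolves to memory the backend already tracks as a Vulkan buffer (the UMA case).
+        // Otherwise the corresponding *_uma flag stays false and the buffer is taken from the
+        // tensor's device buffer context further down.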
+ if (sinks) { + ggml_vk_host_get(ctx->device, sinks->data, d_S, s_buf_offset); + S_uma = d_S != nullptr; + } + } + + + ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * q_buf_ctx = (ggml_backend_vk_buffer_context *)q->buffer->context; + ggml_backend_vk_buffer_context * k_buf_ctx = (ggml_backend_vk_buffer_context *)k->buffer->context; + ggml_backend_vk_buffer_context * v_buf_ctx = (ggml_backend_vk_buffer_context *)v->buffer->context; + + if (!Q_uma) { + d_Q = q_buf_ctx->dev_buffer; + q_buf_offset = vk_tensor_offset(q) + q->view_offs; + } + if (!K_uma) { + d_K = k_buf_ctx->dev_buffer; + k_buf_offset = vk_tensor_offset(k) + k->view_offs; + } + if (!V_uma) { + d_V = v_buf_ctx->dev_buffer; + v_buf_offset = vk_tensor_offset(v) + v->view_offs; + } + if (!D_uma) { + d_D = d_buf_ctx->dev_buffer; + d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + if (!M_uma) { + d_M = d_Q; + m_buf_offset = q_buf_offset; + if (mask) { + ggml_backend_vk_buffer_context * m_buf_ctx = (ggml_backend_vk_buffer_context*)mask->buffer->context; + d_M = m_buf_ctx->dev_buffer; + m_buf_offset = vk_tensor_offset(mask) + mask->view_offs; + } + } + + if (!S_uma) { + d_S = d_Q; + s_buf_offset = q_buf_offset; + if (sinks) { + ggml_backend_vk_buffer_context * s_buf_ctx = (ggml_backend_vk_buffer_context*)sinks->buffer->context; + d_S = s_buf_ctx->dev_buffer; + s_buf_offset = vk_tensor_offset(sinks) + sinks->view_offs; + } + } + + uint32_t mask_n_head_log2 = ((sinks != nullptr) << 24) | ((mask != nullptr) << 16) | n_head_log2; + + const vk_flash_attn_push_constants pc = { N, KV, + (uint32_t)ne1, (uint32_t)ne2, (uint32_t)ne3, + (uint32_t)neq2, (uint32_t)neq3, + (uint32_t)nek2, (uint32_t)nek3, + (uint32_t)nev2, (uint32_t)nev3, + nem1, nem2, nem3, + q_stride, (uint32_t)nbq2, (uint32_t)nbq3, + k_stride, (uint32_t)nbk2, (uint32_t)nbk3, + v_stride, (uint32_t)nbv2, (uint32_t)nbv3, + scale, max_bias, logit_softcap, + mask_n_head_log2, m0, m1, + gqa_ratio, split_kv, split_k }; + + if (split_k > 1) { + if (ctx->prealloc_split_k_need_sync) { + ggml_vk_sync_buffers(ctx, subctx); + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), + }, + // We only use split_k when group query attention is enabled, which means + // there's no more than one tile of rows (i.e. workgroups_x would have been + // one). We reuse workgroups_x to mean the number of splits, so we need to + // cancel out the divide by wg_denoms[0]. 
+ pc, { workgroups_x * pipeline->wg_denoms[0], workgroups_y, workgroups_z }); + + ggml_vk_sync_buffers(ctx, subctx); + const std::array pc2 = { HSV, (uint32_t)ne1, (uint32_t)ne3, split_k, (sinks != nullptr) }; + ggml_vk_dispatch_pipeline(ctx, subctx, ctx->device->pipeline_flash_attn_split_k_reduce, + { + ggml_vk_subbuffer(ctx, ctx->prealloc_split_k, 0), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), + }, + pc2, { (uint32_t)ne1, HSV, (uint32_t)ne3 }); + ctx->prealloc_split_k_need_sync = true; + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, d_Q, q_buf_offset), + ggml_vk_subbuffer(ctx, d_K, k_buf_offset), + ggml_vk_subbuffer(ctx, d_V, v_buf_offset), + ggml_vk_subbuffer(ctx, d_M, m_buf_offset), + ggml_vk_subbuffer(ctx, d_S, s_buf_offset), + ggml_vk_subbuffer(ctx, d_D, d_buf_offset), + }, + pc, { workgroups_x, workgroups_y, workgroups_z }); + } +} + +static std::array ggml_vk_get_conv_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cin, Cout] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + // Copied from ggml.c: int64_t ggml_calc_conv_output_size(int64_t ins, int64_t ks, int s, int p, int d) + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins + 2 * p - d * (ks - 1) - 1) / s + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[3]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[1], dst->op_params[3], dst->op_params[5]); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], dst->op_params[2], dst->op_params[4]); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + +static std::array ggml_vk_get_conv_transpose_2d_elements(const ggml_tensor *dst) { + const ggml_tensor *src0 = dst->src[0]; + const ggml_tensor *src1 = dst->src[1]; + + // src0 - kernel: [KW, KH, Cout, Cin] + // src1 - input: [W, H, Cin, N] + // dst - result: [OW, OH, Cout, N] + + auto calc_conv_output_size = [](int64_t ins, int64_t ks, int s, int p, int d) -> int64_t { + return (ins - 1) * s - 2 * p + (ks - 1) * d + 1; + }; + // parallelize in {OW/BS_K, OH/BS_NPQ, 1} + int64_t W = src1->ne[0]; + int64_t H = src1->ne[1]; + int64_t KW = src0->ne[0]; + int64_t KH = src0->ne[1]; + int64_t Cout = src0->ne[2]; + int64_t N = src1->ne[3]; + int64_t OH = calc_conv_output_size(H, KH, dst->op_params[0], 0, 1); + int64_t OW = calc_conv_output_size(W, KW, dst->op_params[0], 0, 1); + int64_t NPQ = N * OW * OH; + + // Tile output matrix to (K/NB_K, NPQ/NB_NPQ, 1) workgroups + std::array elements = { static_cast(Cout), static_cast(NPQ), 1 }; + return elements; +} + +static vk_pipeline ggml_vk_op_get_pipeline(ggml_backend_vk_context * ctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, const ggml_tensor * dst, ggml_op op) { + switch (op) { + case GGML_OP_GET_ROWS: + GGML_ASSERT(src1->type == GGML_TYPE_I32); + if (dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_get_rows[src0->type]; + } + if (dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_get_rows_f32[src0->type]; + } + return nullptr; + case GGML_OP_ACC: + if (src0->type 
== GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_acc_f32; + } + return nullptr; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (src1->type != GGML_TYPE_F32 && src1->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16)) { + return nullptr; + } + switch (op) { + case GGML_OP_ADD: + { + if (ctx->num_additional_fused_ops > 0) { + if (ctx->do_add_rms_partials) { + return ctx->device->pipeline_multi_add_rms[ctx->num_additional_fused_ops]; + } else { + return ctx->device->pipeline_multi_add[ctx->num_additional_fused_ops]; + } + } + if (ctx->do_add_rms_partials) { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_rms_norepeat : ctx->device->pipeline_add_rms; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } else { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_add_norepeat : ctx->device->pipeline_add; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + } + case GGML_OP_SUB: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_sub_norepeat : ctx->device->pipeline_sub; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_MUL: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? ctx->device->pipeline_mul_norepeat : ctx->device->pipeline_mul; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + case GGML_OP_DIV: + { + auto pipelines = ggml_are_same_shape(src0, src1) ? 
ctx->device->pipeline_div_norepeat : ctx->device->pipeline_div; + return pipelines[src0->type == GGML_TYPE_F16][src1->type == GGML_TYPE_F16][dst->type == GGML_TYPE_F16]; + } + default: + break; + } + return nullptr; + case GGML_OP_ADD_ID: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && src2->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_add_id_f32; + } + return nullptr; + case GGML_OP_CONCAT: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_concat_f32; + } + if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_concat_f16; + } + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_concat_i32; + } + return nullptr; + case GGML_OP_UPSCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + int mode = ggml_get_op_params_i32(dst, 0); + switch (mode) { + case GGML_SCALE_MODE_NEAREST: + return ctx->device->pipeline_upscale_nearest_f32; + case GGML_SCALE_MODE_BILINEAR: + return ctx->device->pipeline_upscale_bilinear_f32; + case GGML_SCALE_MODE_BILINEAR | GGML_SCALE_FLAG_ALIGN_CORNERS: + return ctx->device->pipeline_upscale_bilinear_ac_f32; + } + } + return nullptr; + case GGML_OP_SCALE: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_scale_f32; + } + return nullptr; + case GGML_OP_SQR: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqr_f32; + } + return nullptr; + case GGML_OP_SQRT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sqrt_f32; + } + return nullptr; + case GGML_OP_SIN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sin_f32; + } + return nullptr; + case GGML_OP_COS: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_cos_f32; + } + return nullptr; + case GGML_OP_CLAMP: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_clamp_f32; + } + return nullptr; + case GGML_OP_PAD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pad_f32; + } + return nullptr; + case GGML_OP_ROLL: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_roll_f32; + } + return nullptr; + case GGML_OP_REPEAT: + if (ggml_type_size(src0->type) == sizeof(float) && ggml_type_size(dst->type) == sizeof(float)) { + return ctx->device->pipeline_repeat_f32; + } + return nullptr; + case GGML_OP_REPEAT_BACK: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_repeat_back_f32; + } + return nullptr; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + return ggml_vk_get_cpy_pipeline(ctx, src0, dst, dst->type); + case GGML_OP_SET_ROWS: + if (src1->type == GGML_TYPE_I64) { + return ctx->device->pipeline_set_rows_i64[dst->type]; + } else { + return ctx->device->pipeline_set_rows_i32[dst->type]; + } + case GGML_OP_SILU_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_silu_back_f32; + } + return nullptr; + case GGML_OP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return 
ctx->device->pipeline_norm_f32; + } + return nullptr; + case GGML_OP_GROUP_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_group_norm_f32; + } + return nullptr; + case GGML_OP_RMS_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ctx->do_add_rms_partials) { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_partials_f32 : ctx->device->pipeline_rms_norm_partials_f32; + } else { + return ctx->num_additional_fused_ops > 0 ? ctx->device->pipeline_rms_norm_mul_f32 : ctx->device->pipeline_rms_norm_f32; + } + } + return nullptr; + case GGML_OP_RMS_NORM_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rms_norm_back_f32; + } + return nullptr; + case GGML_OP_L2_NORM: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_l2_norm_f32; + } + return nullptr; + case GGML_OP_UNARY: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_unary_op(dst)) { + case GGML_UNARY_OP_EXP: + return ctx->device->pipeline_exp[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SILU: + return ctx->device->pipeline_silu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU: + return ctx->device->pipeline_gelu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_ERF: + return ctx->device->pipeline_gelu_erf[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_GELU_QUICK: + return ctx->device->pipeline_gelu_quick[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_RELU: + return ctx->device->pipeline_relu[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_TANH: + return ctx->device->pipeline_tanh[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_SIGMOID: + return ctx->device->pipeline_sigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSIGMOID: + return ctx->device->pipeline_hardsigmoid[dst->type == GGML_TYPE_F16]; + case GGML_UNARY_OP_HARDSWISH: + return ctx->device->pipeline_hardswish[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; + case GGML_OP_GLU: + if ((src0->type != GGML_TYPE_F32 && src0->type != GGML_TYPE_F16) || + (dst->type != GGML_TYPE_F32 && dst->type != GGML_TYPE_F16) || + (src0->type != dst->type)) { + return nullptr; + } + + switch (ggml_get_glu_op(dst)) { + case GGML_GLU_OP_GEGLU: + return ctx->device->pipeline_geglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_REGLU: + return ctx->device->pipeline_reglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU: + return ctx->device->pipeline_swiglu[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_SWIGLU_OAI: + return ctx->device->pipeline_swiglu_oai[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_ERF: + return ctx->device->pipeline_geglu_erf[dst->type == GGML_TYPE_F16]; + case GGML_GLU_OP_GEGLU_QUICK: + return ctx->device->pipeline_geglu_quick[dst->type == GGML_TYPE_F16]; + default: + break; + } + return nullptr; + case GGML_OP_DIAG_MASK_INF: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_diag_mask_inf_f32; + } + return nullptr; + case GGML_OP_SOFT_MAX: + GGML_ASSERT(!src1 || src1->type == GGML_TYPE_F32 || src1->type == GGML_TYPE_F16); + GGML_ASSERT(!src2 || src2->type == GGML_TYPE_F32); + + if (src0->type == GGML_TYPE_F32 && (src1 == nullptr || src1->type == GGML_TYPE_F32) && 
dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_wg512 : ctx->device->pipeline_soft_max_f32; + } + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + return src0->ne[0] > 1024 ? ctx->device->pipeline_soft_max_f32_f16_wg512 : ctx->device->pipeline_soft_max_f32_f16; + } + return nullptr; + case GGML_OP_SOFT_MAX_BACK: + if (src0->type == GGML_TYPE_F32 && src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_soft_max_back_f32; + } + return nullptr; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + { + const int mode = ((const int32_t *) dst->op_params)[2]; + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_neox) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_neox_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_neox_f16; + } + } else if (is_mrope && !is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_multi_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_multi_f16; + } + } else if (is_vision) { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_vision_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_vision_f16; + } + } else { + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rope_norm_f32; + } + if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_rope_norm_f16; + } + } + return nullptr; + } + case GGML_OP_ARGSORT: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + uint32_t idx = (uint32_t)ceilf(log2f(float(dst->ne[0]))); + return ctx->device->pipeline_argsort_f32[idx]; + } + return nullptr; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_sum_rows_f32; + } + return nullptr; + case GGML_OP_ARGMAX: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_I32) { + return ctx->device->pipeline_argmax_f32; + } + return nullptr; + case GGML_OP_COUNT_EQUAL: + if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32 && dst->type == GGML_TYPE_I64) { + return ctx->device->pipeline_count_equal_i32; + } + return nullptr; + case GGML_OP_IM2COL: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_f32_f16; + } + return nullptr; + case GGML_OP_IM2COL_3D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_im2col_3d_f32; + } + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F16) { + return ctx->device->pipeline_im2col_3d_f32_f16; + } + return nullptr; + case GGML_OP_TIMESTEP_EMBEDDING: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_timestep_embedding_f32; + } + return nullptr; + case GGML_OP_CONV_TRANSPOSE_1D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return 
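The argsort case above picks a shader variant by the smallest power of two that covers the row length. A standalone sketch of that index computation (the mapping from the index to an actual shader on the device is not shown):

```
// idx = ceil(log2(ncols)): index of the argsort pipeline variant whose
// power-of-two capacity covers the row length.
#include <cmath>
#include <cstdint>
#include <cstdio>

static uint32_t argsort_pipeline_index(uint32_t ncols) {
    return (uint32_t) ceilf(log2f((float) ncols));
}

int main() {
    const uint32_t cases[] = { 2, 3, 1024, 1025, 32000 };
    for (uint32_t ncols : cases) {
        const uint32_t idx = argsort_pipeline_index(ncols);
        printf("ncols=%u -> pipeline index %u (covers %u slots)\n", ncols, idx, 1u << idx);
    }
    return 0;
}
```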
ctx->device->pipeline_conv_transpose_1d_f32; + } + return nullptr; + case GGML_OP_POOL_2D: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_pool2d_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV6: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv6_f32; + } + return nullptr; + case GGML_OP_RWKV_WKV7: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_rwkv_wkv7_f32; + } + return nullptr; + case GGML_OP_OPT_STEP_ADAMW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_adamw_f32; + } + return nullptr; + case GGML_OP_OPT_STEP_SGD: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_opt_step_sgd_f32; + } + return nullptr; + case GGML_OP_LEAKY_RELU: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + return ctx->device->pipeline_leaky_relu_f32; + } + return nullptr; + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + if (src1->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32 && + ggml_is_contiguous(src0) && ggml_is_contiguous(src1) && ggml_is_contiguous(dst)) { + std::array<uint32_t, 3> elements; + if (op == GGML_OP_CONV_2D) elements = ggml_vk_get_conv_elements(dst); + else if (op == GGML_OP_CONV_TRANSPOSE_2D) elements = ggml_vk_get_conv_transpose_2d_elements(dst); + vk_conv_shapes shape; + + uint32_t tiles[CONV_SHAPE_COUNT]; + for (uint32_t i = 0; i < CONV_SHAPE_COUNT; ++i) { + tiles[i] = CEIL_DIV(elements[0], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[0]) * CEIL_DIV(elements[1], ctx->device->pipeline_conv2d_f32[i]->wg_denoms[1]); + } + + // We can't query number of shader cores on Intel, use 32 as a placeholder + // so small convolutions will still choose a smaller tile. + const uint32_t shader_core_count = ctx->device->shader_core_count > 0 ?
ctx->device->shader_core_count : 32; + + if (elements[0] > 64 && tiles[CONV_SHAPE_128x128] >= shader_core_count * 2) { + shape = CONV_SHAPE_128x128; + } else if (elements[0] <= 32 && tiles[CONV_SHAPE_32x256] >= shader_core_count * 2) { + shape = CONV_SHAPE_32x256; + } else { + shape = CONV_SHAPE_64x32; + } + + if (op == GGML_OP_CONV_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv2d_f16_f32[shape]; + } + } else if (op == GGML_OP_CONV_TRANSPOSE_2D) { + if (src0->type == GGML_TYPE_F32) { + return ctx->device->pipeline_conv_transpose_2d_f32[shape]; + } else if (src0->type == GGML_TYPE_F16) { + return ctx->device->pipeline_conv_transpose_2d_f16_f32[shape]; + } + } + } + return nullptr; + case GGML_OP_CONV_2D_DW: + if (src0->type == GGML_TYPE_F32 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f32; + } + } else if (src0->type == GGML_TYPE_F16 && dst->type == GGML_TYPE_F32) { + if (ggml_is_contiguous(src1)) { + return ctx->device->pipeline_conv2d_dw_whcn_f16_f32; + } else if (ggml_is_contiguous_channels(src1)) { + return ctx->device->pipeline_conv2d_dw_cwhn_f16_f32; + } + } + return nullptr; + default: + return nullptr; + } + + GGML_UNUSED(src2); +} + +static bool ggml_vk_op_supports_incontiguous(ggml_op op) { + switch (op) { + case GGML_OP_CPY: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_ADD_ID: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ROPE: + case GGML_OP_RMS_NORM: + case GGML_OP_CONV_2D_DW: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_SET_ROWS: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return true; + default: + return false; + } +} + +static uint32_t get_misalign_bytes(ggml_backend_vk_context * ctx, const ggml_tensor * t) +{ + return ((vk_tensor_offset(t) + t->view_offs) & (ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1)); +} + +template <typename T> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, T &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + GGML_UNUSED(p); + GGML_UNUSED(src0); + GGML_UNUSED(src1); + GGML_UNUSED(src2); + GGML_UNUSED(dst); + static_assert(!std::is_const<T>::value, "unexpected type"); + GGML_ASSERT(!src0 || get_misalign_bytes(ctx, src0) == 0); + GGML_ASSERT(!src1 || get_misalign_bytes(ctx, src1) == 0); + GGML_ASSERT(!src2 || get_misalign_bytes(ctx, src2) == 0); + GGML_ASSERT(!dst || get_misalign_bytes(ctx, dst) == 0); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_unary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, 
vk_op_sum_rows_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_pad_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_im2col_3d_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.misalign_offsets = (a_offset << 16) | d_offset; + + GGML_UNUSED(src0); + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_binary_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t b_offset = get_misalign_bytes(ctx, src1) / ggml_type_size(src1->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + GGML_ASSERT(dst->op != GGML_OP_GET_ROWS || (a_offset == 0 && b_offset == 0 && d_offset == 0)); + + p.misalign_offsets = (a_offset << 16) | (b_offset << 8) | d_offset; + + GGML_UNUSED(src2); +} + +template <> void init_pushconst_tensor_offsets(ggml_backend_vk_context * ctx, vk_op_upscale_push_constants &p, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst) { + const uint32_t a_offset = get_misalign_bytes(ctx, src0) / ggml_type_size(src0->type); + const uint32_t d_offset = get_misalign_bytes(ctx, dst) / ggml_type_size(dst->type); + + p.a_offset = a_offset; + p.d_offset = d_offset; + + GGML_UNUSED(src1); + GGML_UNUSED(src2); +} + +template <typename PC> +static void ggml_vk_op_f32(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, ggml_op op, PC&& pc, bool dryrun = false) { + VK_LOG_DEBUG("ggml_vk_op_f32((" << src0 << ", name=" << src0->name << ", type=" << src0->type << ", ne0=" << src0->ne[0] << ", ne1=" << src0->ne[1] << ", ne2=" << src0->ne[2] << ", ne3=" << src0->ne[3] << ", nb0=" << src0->nb[0] << ", nb1=" << src0->nb[1] << ", nb2=" << src0->nb[2] << ", nb3=" << src0->nb[3]; + if (src1 != nullptr) { + std::cerr << "), (" << src1 << ", name=" << src1->name << ", type=" << src1->type << ", ne0=" << src1->ne[0] << ", ne1=" << src1->ne[1] << ", ne2=" << src1->ne[2] << ", ne3=" << src1->ne[3] << ", nb0=" << src1->nb[0] << ", nb1=" << src1->nb[1] << ", nb2=" << src1->nb[2] << ", nb3=" << src1->nb[3]; + } + if (src2 != nullptr) { + std::cerr << "), (" << src2 << ", name=" << src2->name << ", type=" << src2->type << ", ne0=" << src2->ne[0] <<
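The misalignment handling introduced here splits a tensor's byte offset into an aligned descriptor offset plus a small element offset that is packed into a 32-bit push constant (`(a_offset << 16) | d_offset` for unary ops). A standalone sketch with made-up alignment and offsets:

```
// Sketch of descriptor-offset alignment plus misalignment push constant.
// The alignment, tensor offset and type size below are example values.
#include <cstdint>
#include <cstdio>

int main() {
    const uint64_t align         = 64;         // hypothetical minStorageBufferOffsetAlignment
    const uint64_t tensor_offset = 4096 + 40;  // byte offset of the tensor in its buffer
    const uint64_t type_size     = 4;          // f32

    // Remainder that the descriptor offset cannot express.
    const uint32_t misalign_bytes = (uint32_t)(tensor_offset & (align - 1));  // 40
    // Align the descriptor offset down; the shader adds the element offset back.
    const uint64_t desc_offset = tensor_offset & ~(align - 1);                // 4096
    const uint32_t a_offset    = misalign_bytes / (uint32_t)type_size;        // 10 elements

    const uint32_t d_offset = 0;                          // dst assumed aligned here
    const uint32_t packed   = (a_offset << 16) | d_offset;

    printf("desc_offset=%llu a_offset=%u packed=0x%08x\n",
           (unsigned long long) desc_offset, a_offset, packed);
    return 0;
}
```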
", ne1=" << src2->ne[1] << ", ne2=" << src2->ne[2] << ", ne3=" << src2->ne[3] << ", nb0=" << src2->nb[0] << ", nb1=" << src2->nb[1] << ", nb2=" << src2->nb[2] << ", nb3=" << src2->nb[3]; + } + std::cerr << "), (" << dst << ", name=" << dst->name << ", type=" << dst->type << ", ne0=" << dst->ne[0] << ", ne1=" << dst->ne[1] << ", ne2=" << dst->ne[2] << ", ne3=" << dst->ne[3] << ", nb0=" << dst->nb[0] << ", nb1=" << dst->nb[1] << ", nb2=" << dst->nb[2] << ", nb3=" << dst->nb[3]; + std::cerr << "), " << ggml_op_name(op) << ", " << (dryrun ? "dryrun" : "") << ")"); + GGML_ASSERT(op == GGML_OP_GET_ROWS || op == GGML_OP_CPY || (!ggml_is_quantized(src0->type) && (src1 == nullptr || !ggml_is_quantized(src1->type)))); // NOLINT + GGML_ASSERT(ggml_vk_op_supports_incontiguous(op) || ggml_vk_dim01_contiguous(src0)); // NOLINT + GGML_ASSERT(dst->buffer != nullptr); + const uint64_t ne00 = src0->ne[0]; + const uint64_t ne01 = src0->ne[1]; + const uint64_t ne02 = src0->ne[2]; + const uint64_t ne03 = src0->ne[3]; + const uint64_t ne0 = ne00 * ne01; + + const bool use_src1 = src1 != nullptr; + const uint64_t ne10 = use_src1 ? src1->ne[0] : 0; + const uint64_t ne11 = use_src1 ? src1->ne[1] : 0; + const uint64_t ne12 = use_src1 ? src1->ne[2] : 0; + const uint64_t ne13 = use_src1 ? src1->ne[3] : 0; + const uint64_t ne1 = ne10 * ne11; + // const uint64_t nb10 = use_src1 ? src1->nb[0] : 0; + + const bool use_src2 = src2 != nullptr; + const uint64_t ne20 = use_src2 ? src2->ne[0] : 0; + const uint64_t ne21 = use_src2 ? src2->ne[1] : 0; + const uint64_t ne22 = use_src2 ? src2->ne[2] : 0; + const uint64_t ne23 = use_src2 ? src2->ne[3] : 0; + const uint64_t ne2 = ne20 * ne21; + + const uint64_t ned0 = dst->ne[0]; + const uint64_t ned1 = dst->ne[1]; + const uint64_t ned2 = dst->ne[2]; + const uint64_t ned3 = dst->ne[3]; + const uint64_t ned = ned0 * ned1; + + init_pushconst_fastdiv(pc); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, dst, op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(op) << " for " << ggml_type_name(src0->type); + if (src1 != nullptr) { + std::cerr << " and " << ggml_type_name(src1->type); + } + std::cerr << " to " << ggml_type_name(dst->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + const bool op_supports_incontiguous = ggml_vk_op_supports_incontiguous(op); + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src0_buf_ctx = (ggml_backend_vk_buffer_context *)src0->buffer->context; + ggml_backend_vk_buffer_context * src1_buf_ctx = use_src1 ? (ggml_backend_vk_buffer_context *)src1->buffer->context : nullptr; + ggml_backend_vk_buffer_context * src2_buf_ctx = use_src2 ? 
(ggml_backend_vk_buffer_context *)src2->buffer->context : nullptr; + + vk_buffer d_X = nullptr; + size_t x_buf_offset = 0; + vk_buffer d_Y = nullptr; + size_t y_buf_offset = 0; + vk_buffer d_Z = nullptr; + size_t z_buf_offset = 0; + + bool src0_uma = false; + bool src1_uma = false; + bool src2_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, src0->data, d_X, x_buf_offset); + src0_uma = d_X != nullptr; + if (use_src1) { + ggml_vk_host_get(ctx->device, src1->data, d_Y, y_buf_offset); + src1_uma = d_Y != nullptr; + } + if (use_src2) { + ggml_vk_host_get(ctx->device, src2->data, d_Z, z_buf_offset); + src2_uma = d_Z != nullptr; + } + } + + vk_buffer d_D = dst_buf_ctx->dev_buffer; + + GGML_ASSERT(d_D != nullptr); + uint64_t d_buf_offset = vk_tensor_offset(dst) + dst->view_offs; + if(!src0_uma) { + d_X = src0_buf_ctx->dev_buffer; + x_buf_offset = vk_tensor_offset(src0) + src0->view_offs; + GGML_ASSERT(d_X != nullptr); + } + if (use_src1 && !src1_uma) { + d_Y = src1_buf_ctx->dev_buffer; + y_buf_offset = vk_tensor_offset(src1) + src1->view_offs; + GGML_ASSERT(d_Y != nullptr); + } + if (use_src2 && !src2_uma) { + d_Z = src2_buf_ctx->dev_buffer; + z_buf_offset = vk_tensor_offset(src2) + src2->view_offs; + GGML_ASSERT(d_Z != nullptr); + } + // Compute misalignment offset for descriptors and store it in push constants, then align the descriptor offsets. + init_pushconst_tensor_offsets(ctx, pc, src0, src1, src2, dst); + x_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + y_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + z_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + d_buf_offset &= ~(ctx->device->properties.limits.minStorageBufferOffsetAlignment - 1); + + std::array<uint32_t, 3> elements; + + // Single call if dimension 2 is contiguous + GGML_ASSERT(op_supports_incontiguous || (ggml_is_contiguous(src0) && (src1 == nullptr || ggml_is_contiguous(src1)))); + + switch (op) { + case GGML_OP_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + { + const uint32_t nr = ggml_nrows(src0); + if (nr > 262144) { + elements = { 512, 512, CEIL_DIV(nr, 262144) }; + } else if (nr > 512) { + elements = { 512, CEIL_DIV(nr, 512), 1 }; + } else { + elements = { nr, 1, 1 }; + } + } break; + case GGML_OP_RMS_NORM: + if (ctx->do_add_rms_partials) { + // Run one element per thread, 128 threads per workgroup + elements = { (uint32_t)CEIL_DIV(ne00, 128), 1, 1 }; + } else { + elements = { (uint32_t)ne01, (uint32_t)ne02, (uint32_t)ne03 }; + } + break; + + case GGML_OP_SUM: + // We use GGML_OP_SUM_ROWS with 1 row. 
+ elements = { 1, 1, 1 }; + break; + case GGML_OP_GROUP_NORM: + { + const uint32_t num_groups = dst->op_params[0]; + elements = { num_groups * (uint32_t)src0->ne[3], 1, 1 }; + } break; + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + elements = { (uint32_t)ggml_nrows(src0), (uint32_t)ne00, 1 }; + break; + case GGML_OP_GET_ROWS: + elements = { (uint32_t)ne00, (uint32_t)ne10, (uint32_t)(ne11 * ne12) }; + elements[1] = std::min(elements[1], ctx->device->properties.limits.maxComputeWorkGroupCount[1]); + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + break; + case GGML_OP_ARGSORT: + elements = { (uint32_t)ne00, (uint32_t)ggml_nrows(src0), 1 }; + break; + case GGML_OP_IM2COL: + { + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t batch = src1->ne[is_2D ? 3 : 2]; + + elements = { OW * KW * KH, OH, batch * IC }; + } break; + case GGML_OP_IM2COL_3D: + { + const uint32_t IC = ((const uint32_t *)(dst->op_params))[9]; + + const uint32_t N = ne13 / IC; + + const uint32_t KD = ne02; + const uint32_t KH = ne01; + const uint32_t KW = ne00; + + const uint32_t OD = ned3 / N; + const uint32_t OH = ned2; + const uint32_t OW = ned1; + + const uint32_t IC_KD_KH_KW = IC*KD*KH*KW; + const uint32_t N_OD_OH = N*OD*OH; + + elements = { IC_KD_KH_KW, OW, N_OD_OH }; + elements[2] = std::min(elements[2], ctx->device->properties.limits.maxComputeWorkGroupCount[2]); + } break; + case GGML_OP_TIMESTEP_EMBEDDING: + { + const uint32_t dim = dst->op_params[0]; + uint32_t half_ceil = (dim + 1) / 2; + elements = { half_ceil, (uint32_t)src0->ne[0], 1 }; + } break; + case GGML_OP_CONV_TRANSPOSE_1D: + { + elements = {uint32_t(src0->ne[1]), 1, 1}; // parallelize in {Cout, 1, 1} + } break; + case GGML_OP_POOL_2D: + { + const uint32_t N = dst->ne[3]; + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + elements = { N * OC * OH * OW, 1, 1}; + } break; + case GGML_OP_CONV_2D: + { + elements = ggml_vk_get_conv_elements(dst); + } break; + case GGML_OP_CONV_TRANSPOSE_2D: + { + elements = ggml_vk_get_conv_transpose_2d_elements(dst); + } break; + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_DIV: + case GGML_OP_MUL: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_CPY: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_UNARY: + case GGML_OP_GLU: + case GGML_OP_CONV_2D_DW: + { + uint32_t ne = ggml_nelements(dst); + if (op == GGML_OP_CPY && ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + // copy_to_quant has block size of 32, and each thread does QUANT_K elements. + // Splitting into 512x512xZ wouldn't work well since each workgroup does 1024 elements. + // So divide by block size here before splitting into 512x512 groups. 
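Several of the cases above fold a 1D work count into a 512 x 512 x Z grid so no single dispatch dimension grows unbounded. A standalone sketch of that split, with `CEIL_DIV` reproduced locally and `std::array<uint32_t, 3>` as the element type:

```
// Sketch of the dispatch-size split used by the row-wise and element-wise ops:
// a 1D count n is mapped onto at most {512, 512, ceil(n / 262144)}.
#include <array>
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

static std::array<uint32_t, 3> split_elements(uint32_t n) {
    if (n > 262144) {                       // 512 * 512
        return { 512u, 512u, CEIL_DIV(n, 262144u) };
    } else if (n > 512) {
        return { 512u, CEIL_DIV(n, 512u), 1u };
    }
    return { n, 1u, 1u };
}

int main() {
    const uint32_t cases[] = { 100, 4096, 1000000 };
    for (uint32_t n : cases) {
        const auto e = split_elements(n);
        printf("n=%u -> {%u, %u, %u}\n", n, e[0], e[1], e[2]);
    }
    return 0;
}
```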
+ if (op == GGML_OP_CPY && !ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + ne = CEIL_DIV(ne, ggml_blck_size(dst->type)); + } + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } break; + case GGML_OP_ADD_ID: + { + elements = { (uint32_t)ne01, (uint32_t)ne02, 1 }; + } break; + case GGML_OP_SET_ROWS: + { + uint32_t ne = ggml_nelements(src0); + if (ggml_is_quantized(dst->type)) { + // quants run 32 threads each doing QUANT_K elements + ne = CEIL_DIV(ne, 32 * ggml_blck_size(dst->type)); + } else { + // scalar types do one element per thread, running 512 threads + ne = CEIL_DIV(ne, 512); + } + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + } + break; + default: + elements = { (uint32_t)ggml_nelements(src0), 1, 1 }; + break; + } + + uint64_t x_sz, y_sz, z_sz, d_sz; + + if (op_supports_incontiguous) { + x_sz = ggml_nbytes(src0) + get_misalign_bytes(ctx, src0); + y_sz = use_src1 ? ggml_nbytes(src1) + get_misalign_bytes(ctx, src1) : 0; + z_sz = use_src2 ? ggml_nbytes(src2) + get_misalign_bytes(ctx, src2) : 0; + d_sz = ggml_nbytes(dst) + get_misalign_bytes(ctx, dst); + + if (x_buf_offset + x_sz >= d_X->size) { + x_sz = ggml_vk_get_max_buffer_range(ctx, d_X, x_buf_offset); + } + if (use_src1 && y_buf_offset + y_sz >= d_Y->size) { + y_sz = ggml_vk_get_max_buffer_range(ctx, d_Y, y_buf_offset); + } + if (use_src2 && z_buf_offset + z_sz >= d_Z->size) { + z_sz = ggml_vk_get_max_buffer_range(ctx, d_Z, z_buf_offset); + } + if (d_buf_offset + d_sz >= d_D->size) { + d_sz = ggml_vk_get_max_buffer_range(ctx, d_D, d_buf_offset); + } + } else { + x_sz = ggml_type_size(src0->type)/ggml_blck_size(src0->type) * ne0 * ne02 * ne03; + y_sz = use_src1 ? ggml_type_size(src1->type) * ne1 * ne12 * ne13 : 0; + z_sz = use_src2 ? ggml_type_size(src2->type) * ne2 * ne22 * ne23 : 0; + d_sz = ggml_type_size(dst->type) * ned * ned2 * ned3; + } + + if (op == GGML_OP_ADD || op == GGML_OP_RMS_NORM) { + vk_buffer d_A = ctx->do_add_rms_partials ? ctx->prealloc_add_rms_partials : d_X; + size_t a_buf_offset = ctx->do_add_rms_partials ? 
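For copies involving quantized types, the element count above is first converted to blocks and then to 2- or 4-byte units, so each invocation moves one fixed-size word. A standalone sketch using hypothetical block parameters (32 elements per 34-byte block):

```
// Sketch of the work-unit counting for quantized copies. The block size and
// byte size below are hypothetical, not tied to a specific ggml type.
#include <cstdint>
#include <cstdio>

#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b))

int main() {
    uint32_t ne = 8192;               // logical element count
    const uint32_t blck_size = 32;    // elements per block (hypothetical)
    const uint32_t type_size = 34;    // bytes per block    (hypothetical)

    // quant -> quant copy: count 2- or 4-byte words instead of elements
    uint32_t units = ne / blck_size;                  // 256 blocks
    units *= (type_size % 4 == 0) ? type_size / 4     // whole 4-byte words
                                  : type_size / 2;    // otherwise 2-byte words

    // float -> quant copy: one work item per destination block instead
    const uint32_t f2q = CEIL_DIV(ne, blck_size);     // 256

    printf("quant->quant units=%u, float->quant groups=%u\n", units, f2q);
    return 0;
}
```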
ctx->prealloc_size_add_rms_partials_offset : 0; + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { vk_subbuffer{ d_X, x_buf_offset, x_sz }, + vk_subbuffer{ d_Y, y_buf_offset, y_sz }, + vk_subbuffer{ d_D, d_buf_offset, d_sz }, + ggml_vk_subbuffer(ctx, d_A, a_buf_offset), + }, pc, elements); + } else if (op == GGML_OP_GLU) { + // Empty src1 is possible in glu, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_SOFT_MAX) { + // Empty src1 and src2 is possible in soft_max, but the shader needs a buffer + vk_subbuffer subbuf_y; + if (use_src1) { + subbuf_y = { d_Y, y_buf_offset, y_sz }; + } else { + subbuf_y = { d_X, 0, x_sz }; + } + + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, subbuf_y, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_ROPE || op == GGML_OP_ROPE_BACK) { + // Empty src2 is possible in rope, but the shader needs a buffer + vk_subbuffer subbuf_z; + if (use_src2) { + subbuf_z = { d_Z, z_buf_offset, z_sz }; + } else { + subbuf_z = { d_X, 0, x_sz }; + } + + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, subbuf_z, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_IM2COL || op == GGML_OP_IM2COL_3D) { + if (ctx->device->shader_int64 && ctx->device->buffer_device_address) { + // buffer device address path doesn't use dst buffer + d_sz = 1; + } + // im2col uses only src1 and dst buffers + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_COUNT_EQUAL) { + // count_equal assumes that destination buffer is initialized with zeroes + ggml_vk_buffer_memset_async(subctx, d_D, d_buf_offset, 0, d_sz); + ggml_vk_sync_buffers(ctx, subctx); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (op == GGML_OP_OPT_STEP_SGD) { + // OPT_STEP_SGD works on src0, it does not need dst + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz } }, pc, elements); + } else if (use_src2) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_Z, z_buf_offset, z_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else if (use_src1) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_Y, y_buf_offset, y_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } else { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { vk_subbuffer{ d_X, x_buf_offset, x_sz }, vk_subbuffer{ d_D, d_buf_offset, d_sz } }, pc, elements); + } +} + +static void ggml_vk_get_rows(ggml_backend_vk_context * ctx, vk_context& subctx, 
const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_acc(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + int nb1 = dst->op_params[0] / 4; // 4 bytes of float32 + int nb2 = dst->op_params[1] / 4; // 4 bytes of float32 + // int nb3 = dst->op_params[2] / 4; // 4 bytes of float32 - unused + int offset = dst->op_params[3] / 4; // offset in bytes + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ACC, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t)nb1, (uint32_t)nb2, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, offset, + }, dryrun); +} + +static void ggml_vk_multi_add(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_cgraph * cgraph, int node_idx, bool dryrun = false) { + const ggml_tensor *first_node = cgraph->nodes[node_idx]; + const ggml_tensor *dst = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + + // Make a list of all the tensors used by the op. + // Last element of the list is the dest tensor. 
+ const ggml_tensor *tensors[MAX_PARAMETER_COUNT]; + uint32_t num_srcs = ctx->num_additional_fused_ops + 2; + uint32_t num_tensors = num_srcs + 1; + GGML_ASSERT(num_tensors + ctx->do_add_rms_partials <= MAX_PARAMETER_COUNT); + + tensors[0] = first_node->src[0]; + tensors[1] = first_node->src[1]; + for (int32_t i = 0; i < ctx->num_additional_fused_ops; ++i) { + // check whether the previous result is src[0] or src[1] + if (cgraph->nodes[node_idx + i] == cgraph->nodes[node_idx + i + 1]->src[0]) { + tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[1]; + } else { + tensors[i+2] = cgraph->nodes[node_idx + i + 1]->src[0]; + } + } + tensors[num_srcs] = dst; + + vk_op_multi_add_push_constants pc; + pc.ne20 = (uint32_t)dst->ne[0]; + pc.ne21 = (uint32_t)dst->ne[1]; + pc.ne22 = (uint32_t)dst->ne[2]; + pc.ne23 = (uint32_t)dst->ne[3]; + + for (uint32_t i = 0; i < num_tensors; ++i) { + const ggml_tensor *t = tensors[i]; + pc.nb[i][0] = (uint32_t)t->nb[0] / sizeof(float); + pc.nb[i][1] = (uint32_t)t->nb[1] / sizeof(float); + pc.nb[i][2] = (uint32_t)t->nb[2] / sizeof(float); + pc.nb[i][3] = (uint32_t)t->nb[3] / sizeof(float); + } + pc.rms_partials = ctx->do_add_rms_partials; + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, tensors[0], tensors[1], nullptr, dst, dst->op); + + if (pipeline == nullptr) { + std::cerr << "ggml_vulkan: Error: Missing multi_add"; + GGML_ABORT("fatal error"); + } + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * buf_ctx[MAX_PARAMETER_COUNT]; + vk_buffer buf[MAX_PARAMETER_COUNT]; + size_t offset[MAX_PARAMETER_COUNT]; + bool uma[MAX_PARAMETER_COUNT]; + + for (uint32_t i = 0; i < num_tensors; ++i) { + buf_ctx[i] = (ggml_backend_vk_buffer_context *)tensors[i]->buffer->context; + buf[i] = nullptr; + offset[i] = 0; + uma[i] = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, tensors[i]->data, buf[i], offset[i]); + uma[i] = buf[i] != nullptr; + } + if (!uma[i]) { + buf[i] = buf_ctx[i]->dev_buffer; + offset[i] = vk_tensor_offset(tensors[i]) + tensors[i]->view_offs; + } + GGML_ASSERT(buf[i] != nullptr); + } + // If any remaining descriptors are unused, just point them at src[0] + for (uint32_t i = num_tensors; i < MAX_PARAMETER_COUNT; ++i) { + buf[i] = buf[0]; + offset[i] = 0; + } + if (ctx->do_add_rms_partials) { + buf[num_tensors] = ctx->prealloc_add_rms_partials; + offset[num_tensors] = ctx->prealloc_size_add_rms_partials_offset; + } + + std::array<uint32_t, 3> elements; + + uint32_t ne = ggml_nelements(dst); + if (ne > 262144) { + elements = { 512, 512, CEIL_DIV(ne, 262144) }; + } else if (ne > 512) { + elements = { 512, CEIL_DIV(ne, 512), 1 }; + } else { + elements = { ne, 1, 1 }; + } + + static_assert(MAX_PARAMETER_COUNT == 12); + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, + { + ggml_vk_subbuffer(ctx, buf[0], offset[0]), + ggml_vk_subbuffer(ctx, buf[1], offset[1]), + ggml_vk_subbuffer(ctx, buf[2], offset[2]), + ggml_vk_subbuffer(ctx, buf[3], offset[3]), + ggml_vk_subbuffer(ctx, buf[4], offset[4]), + ggml_vk_subbuffer(ctx, buf[5], offset[5]), + ggml_vk_subbuffer(ctx, buf[6], offset[6]), + ggml_vk_subbuffer(ctx, buf[7], offset[7]), + ggml_vk_subbuffer(ctx, buf[8], offset[8]), + ggml_vk_subbuffer(ctx, buf[9], offset[9]), + ggml_vk_subbuffer(ctx, buf[10], offset[10]), + ggml_vk_subbuffer(ctx, buf[11], offset[11]), + }, pc, elements); +} + +static void ggml_vk_add(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool 
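The fused multi-add path above walks the chain of add nodes and, at each step, keeps the operand that is not the previous node's result. A standalone sketch of that gathering step, using simplified stand-in node types rather than the real `ggml_tensor`/`ggml_cgraph`:

```
// Sketch: for t1 = a + b; t2 = t1 + c; t3 = d + t2; collect {a, b, c, d, t3}.
#include <cstdio>
#include <vector>

struct Node { const char * name; const Node * src[2]; };

int main() {
    Node a{"a", {}}, b{"b", {}}, c{"c", {}}, d{"d", {}};
    Node t1{"t1", {&a, &b}};
    Node t2{"t2", {&t1, &c}};   // previous result feeds src[0]
    Node t3{"t3", {&d, &t2}};   // previous result feeds src[1]
    const Node * chain[] = { &t1, &t2, &t3 };

    std::vector<const Node *> tensors = { t1.src[0], t1.src[1] };
    for (int i = 0; i + 1 < 3; ++i) {
        const Node * next = chain[i + 1];
        // keep whichever operand is not the previous node's result
        tensors.push_back(next->src[0] == chain[i] ? next->src[1] : next->src[0]);
    }
    tensors.push_back(chain[2]);   // final destination last

    for (const Node * t : tensors) printf("%s ", t->name);   // a b c d t3
    printf("\n");
    return 0;
}
```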
dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_ADD, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, ctx->do_add_rms_partials, + }, dryrun); +} + +static void ggml_vk_sub(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SUB, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_mul(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_MUL, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / 
dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_div(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_DIV, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_add_id(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t src2_type_size = ggml_type_size(src2->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ADD_ID, { + (uint32_t)dst->ne[0], + (uint32_t)dst->ne[1], + (uint32_t)src0->nb[1] / src0_type_size, + (uint32_t)src0->nb[2] / src0_type_size, + (uint32_t)src1->nb[1] / src1_type_size, + (uint32_t)src2->nb[1] / src2_type_size, + }, dryrun); +} + +static void ggml_vk_op_f32_wkv(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_rwkv_wkv6_push_constants&& pc, int version, bool dryrun = false) { + GGML_ASSERT(version == 6 || version == 7); + int num_srcs = version == 6 ? 
6 : 7; + + for (int i = 0; i < num_srcs; i++) { + GGML_ASSERT(!ggml_is_quantized(dst->src[i]->type)); + } + + GGML_ASSERT(dst->buffer != nullptr); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, dst->src[0], dst->src[1], dst->src[2], dst, dst->op); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + ggml_backend_vk_buffer_context * src_buf_ctxs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + for (int i = 0; i < num_srcs; i++) { + src_buf_ctxs[i] = (ggml_backend_vk_buffer_context *)dst->src[i]->buffer->context; + } + + vk_buffer d_D = nullptr, d_srcs[7] = { nullptr, nullptr, nullptr, nullptr, nullptr, nullptr, nullptr }; + size_t dst_offset = 0, src_offsets[7] = { 0, 0, 0, 0, 0, 0, 0 }; + bool dst_uma = false, srcs_uma[7] = { false, false, false, false, false, false, false }; + + if (ctx->device->uma) { + for (int i = 0; i < num_srcs; i++) { + ggml_vk_host_get(ctx->device, dst->src[i]->data, d_srcs[i], src_offsets[i]); + srcs_uma[i] = d_srcs[i] != nullptr; + } + + ggml_vk_host_get(ctx->device, dst->data, d_D, dst_offset); + dst_uma = d_D != nullptr; + } + + uint64_t src_sizes[7] = { 0, 0, 0, 0, 0, 0, 0 }; + for (int i = 0; i < num_srcs; i++) { + src_sizes[i] = ggml_nbytes(dst->src[i]); + if (!srcs_uma[i]) { + d_srcs[i] = src_buf_ctxs[i]->dev_buffer; + src_offsets[i] = vk_tensor_offset(dst->src[i]) + dst->src[i]->view_offs; + } + } + + const uint64_t dst_size = ggml_nbytes(dst); + if (!dst_uma) { + d_D = dst_buf_ctx->dev_buffer; + dst_offset = vk_tensor_offset(dst) + dst->view_offs; + } + + std::array elements = { + (uint32_t)(pc.B * pc.H), + 1, + 1 + }; + + if (version == 6) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); + } else if (version == 7) { + ggml_vk_dispatch_pipeline(ctx, subctx, pipeline, { + vk_subbuffer{ d_srcs[0], src_offsets[0], src_sizes[0] }, + vk_subbuffer{ d_srcs[1], src_offsets[1], src_sizes[1] }, + vk_subbuffer{ d_srcs[2], src_offsets[2], src_sizes[2] }, + vk_subbuffer{ d_srcs[3], src_offsets[3], src_sizes[3] }, + vk_subbuffer{ d_srcs[4], src_offsets[4], src_sizes[4] }, + vk_subbuffer{ d_srcs[5], src_offsets[5], src_sizes[5] }, + vk_subbuffer{ d_srcs[6], src_offsets[6], src_sizes[6] }, + vk_subbuffer{ d_D, dst_offset, dst_size } + }, pc, elements); + } else { + // shouldn't happen + GGML_ASSERT(false); + } +} + +static void ggml_vk_rwkv_wkv6(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[5]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 6, + dryrun + ); +} + +static void ggml_vk_rwkv_wkv7(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t seq_length = dst->src[0]->ne[2]; + 
const size_t n_embed = dst->ne[0]; + const size_t n_heads = dst->src[0]->ne[1]; + const size_t n_seqs = dst->src[6]->ne[1]; + + ggml_vk_op_f32_wkv( + ctx, subctx, dst, + { + (uint32_t)n_seqs, + (uint32_t)seq_length, + (uint32_t)n_embed, + (uint32_t)n_heads, + }, + 7, + dryrun + ); +} + +static void ggml_vk_op_f32_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, const vk_op_push_constants&& pc, bool dryrun = false) { + const ggml_tensor * x = dst->src[0]; + const ggml_tensor * g = dst->src[1]; + const ggml_tensor * gm = dst->src[2]; + const ggml_tensor * gv = dst->src[3]; + const ggml_tensor * p = dst->src[4]; + + GGML_ASSERT(x->type == GGML_TYPE_F32); + GGML_ASSERT(g->type == GGML_TYPE_F32); + GGML_ASSERT(gm->type == GGML_TYPE_F32); + GGML_ASSERT(gv->type == GGML_TYPE_F32); + GGML_ASSERT(p->type == GGML_TYPE_F32); + GGML_ASSERT(dst->buffer != nullptr); + GGML_ASSERT(ggml_is_contiguous(x)); + GGML_ASSERT(ggml_is_contiguous(g)); + GGML_ASSERT(ggml_is_contiguous(gm)); + GGML_ASSERT(ggml_is_contiguous(gv)); + GGML_ASSERT(ggml_is_contiguous(p)); + GGML_ASSERT(ggml_are_same_shape(x, g)); + GGML_ASSERT(ggml_are_same_shape(x, gm)); + GGML_ASSERT(ggml_are_same_shape(x, gv)); + GGML_ASSERT(ggml_nelements(p) == 7); + + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, g, gm, gv, dst, GGML_OP_OPT_STEP_ADAMW); + GGML_ASSERT(pipeline != nullptr); + + if (dryrun) { + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + return; + } + + ggml_backend_vk_buffer_context * x_buf_ctx = (ggml_backend_vk_buffer_context *)x->buffer->context; + ggml_backend_vk_buffer_context * g_buf_ctx = (ggml_backend_vk_buffer_context *)g->buffer->context; + ggml_backend_vk_buffer_context * gm_buf_ctx = (ggml_backend_vk_buffer_context *)gm->buffer->context; + ggml_backend_vk_buffer_context * gv_buf_ctx = (ggml_backend_vk_buffer_context *)gv->buffer->context; + ggml_backend_vk_buffer_context * p_buf_ctx = (ggml_backend_vk_buffer_context *)p->buffer->context; + + vk_buffer d_X = nullptr, d_G = nullptr, d_GM = nullptr, d_GV = nullptr, d_P = nullptr; + size_t x_offset = 0, g_offset = 0, gm_offset = 0, gv_offset = 0, p_offset = 0; + bool X_uma = false, G_uma = false, GM_uma = false, GV_uma = false, P_uma = false; + + if (ctx->device->uma) { + ggml_vk_host_get(ctx->device, x->data, d_X, x_offset); + ggml_vk_host_get(ctx->device, g->data, d_G, g_offset); + ggml_vk_host_get(ctx->device, gm->data, d_GM, gm_offset); + ggml_vk_host_get(ctx->device, gv->data, d_GV, gv_offset); + ggml_vk_host_get(ctx->device, p->data, d_P, p_offset); + + X_uma = d_X != nullptr; + G_uma = d_G != nullptr; + GM_uma = d_GM != nullptr; + GV_uma = d_GV != nullptr; + P_uma = d_P != nullptr; + } + + if (!X_uma) { + d_X = x_buf_ctx->dev_buffer; + x_offset = vk_tensor_offset(x) + x->view_offs; + } + if (!G_uma) { + d_G = g_buf_ctx->dev_buffer; + g_offset = vk_tensor_offset(g) + g->view_offs; + } + if (!GM_uma) { + d_GM = gm_buf_ctx->dev_buffer; + gm_offset = vk_tensor_offset(gm) + gm->view_offs; + } + if (!GV_uma) { + d_GV = gv_buf_ctx->dev_buffer; + gv_offset = vk_tensor_offset(gv) + gv->view_offs; + } + if (!P_uma) { + d_P = p_buf_ctx->dev_buffer; + p_offset = vk_tensor_offset(p) + p->view_offs; + } + + const uint64_t x_size = ggml_nbytes(x); + const uint64_t g_size = ggml_nbytes(g); + const uint64_t gm_size = ggml_nbytes(gm); + const uint64_t gv_size = ggml_nbytes(gv); + const uint64_t p_size = ggml_nbytes(p); + + std::array elements = { (uint32_t)ggml_nelements(x), 1, 1 }; + + ggml_vk_dispatch_pipeline(ctx, subctx, 
pipeline, { + vk_subbuffer{ d_X, x_offset, x_size }, + vk_subbuffer{ d_G, g_offset, g_size }, + vk_subbuffer{ d_GM, gm_offset, gm_size }, + vk_subbuffer{ d_GV, gv_offset, gv_size }, + vk_subbuffer{ d_P, p_offset, p_size }, + }, pc, elements); +} + +static void ggml_vk_opt_step_adamw(ggml_backend_vk_context * ctx, vk_context& subctx, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32_opt_step_adamw( + ctx, subctx, dst, + { (uint32_t)n, 0, 0.0f, 0.0f }, + dryrun + ); +} + +static void ggml_vk_opt_step_sgd(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + const size_t n = ggml_nelements(dst->src[0]); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_OPT_STEP_SGD, { (uint32_t)n, 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_concat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + int * op_params = (int *)dst->op_params; + + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONCAT, { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, op_params[0], + }, dryrun); +} + +static void ggml_vk_upscale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t mode = (uint32_t)ggml_get_op_params_i32(dst, 0); + + float sf0 = (float)dst->ne[0] / src0->ne[0]; + float sf1 = (float)dst->ne[1] / src0->ne[1]; + float sf2 = (float)dst->ne[2] / src0->ne[2]; + float sf3 = (float)dst->ne[3] / src0->ne[3]; + + if (mode & GGML_SCALE_FLAG_ALIGN_CORNERS) { + sf0 = (float)(dst->ne[0] - 1) / (src0->ne[0] - 1); + sf1 = (float)(dst->ne[1] - 1) / (src0->ne[1] - 1); + } + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UPSCALE, { + (uint32_t)ggml_nelements(dst), 0, 0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], + (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)dst->ne[0], (uint32_t)dst->ne[1], (uint32_t)dst->ne[2],(uint32_t)dst->ne[3], + sf0, sf1, sf2, sf3, + }, dryrun); +} + +static void ggml_vk_scale(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = 
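The upscale path above computes per-axis scale factors as dst/src, switching to (dst-1)/(src-1) on the two spatial axes when align-corners is requested, so the first and last samples map onto each other. A standalone sketch with made-up sizes:

```
// Sketch of the upscale scale-factor computation, default vs. align-corners.
#include <cstdio>

int main() {
    const int src_w = 16, src_h = 16;
    const int dst_w = 32, dst_h = 24;

    float sf0 = (float) dst_w / src_w;              // 2.0
    float sf1 = (float) dst_h / src_h;              // 1.5

    const bool align_corners = true;                // example flag value
    if (align_corners) {
        sf0 = (float)(dst_w - 1) / (src_w - 1);     // 31/15
        sf1 = (float)(dst_h - 1) / (src_h - 1);     // 23/15
    }
    printf("sf0=%f sf1=%f\n", sf0, sf1);
    return 0;
}
```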
ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SCALE, std::move(p), dryrun); +} + +static void ggml_vk_sqr(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQR, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_sqrt(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SQRT, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_sin(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SIN, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_cos(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_COS, vk_op_unary_push_constants_init(src0, dst), dryrun); +} + +static void ggml_vk_clamp(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + p.param1 = ggml_get_op_params_f32(dst, 0); + p.param2 = ggml_get_op_params_f32(dst, 1); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CLAMP, std::move(p), dryrun); +} + +static void ggml_vk_pad(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_pad_push_constants p = vk_op_pad_push_constants_init(src0, dst); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_PAD, std::move(p), dryrun); +} + +static void ggml_vk_roll(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = ggml_get_op_params_i32(dst, 0); + const int32_t s1 = ggml_get_op_params_i32(dst, 1); + const int32_t s2 = ggml_get_op_params_i32(dst, 2); + const int32_t s3 = ggml_get_op_params_i32(dst, 3); + const uint32_t s01_packed = ((s0 + 0x8000) << 16) | (s1 + 0x8000); + const uint32_t s23_packed = ((s2 + 0x8000) << 16) | (s3 + 0x8000); + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst); + memcpy(&p.param1, &s01_packed, sizeof(float)); + memcpy(&p.param2, &s23_packed, sizeof(float)); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ROLL, std::move(p), dryrun); +} + +static void ggml_vk_repeat(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT, std::move(p), dryrun); +} + +static void ggml_vk_repeat_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ggml_nelements(dst)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_REPEAT_BACK, std::move(p), dryrun); +} + +static void ggml_vk_cpy(ggml_backend_vk_context * ctx, 
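The roll op above biases each signed shift by 0x8000, packs two shifts per 32-bit word, and bit-copies the word into a float push-constant slot with memcpy (the shader reinterprets the bits). A standalone sketch of the packing and the matching unpacking:

```
// Sketch of the ROLL shift packing: bias to unsigned 16-bit, pack two per
// 32-bit word, bit-copy into a float parameter, then recover the shifts.
#include <cstdint>
#include <cstdio>
#include <cstring>

int main() {
    const int32_t s0 = -3, s1 = 7;

    const uint32_t packed = ((uint32_t)(s0 + 0x8000) << 16) | (uint32_t)(s1 + 0x8000);

    float param1;
    memcpy(&param1, &packed, sizeof(float));        // type-pun without UB

    // Unpacking, as the shader side would do it:
    uint32_t bits;
    memcpy(&bits, &param1, sizeof(uint32_t));
    const int32_t u0 = (int32_t)(bits >> 16)     - 0x8000;   // -3
    const int32_t u1 = (int32_t)(bits & 0xffffu) - 0x8000;   //  7
    printf("unpacked shifts: %d %d\n", u0, u1);
    return 0;
}
```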
vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t ne = (uint32_t)ggml_nelements(src0); + if (ggml_is_quantized(src0->type) && ggml_is_quantized(dst->type)) { + // Convert from number of logical elements to 2- or 4-byte units. + ne /= ggml_blck_size(src0->type); + if ((ggml_type_size(src0->type) % 4) == 0) { + ne *= ggml_type_size(src0->type) / 4; + } else { + ne *= ggml_type_size(src0->type) / 2; + } + } + + vk_op_unary_push_constants p = vk_op_unary_push_constants_init(src0, dst, ne); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_CPY, std::move(p), dryrun); +} + +static void ggml_vk_set_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + // Skip empty skip_rows operations. For most ops the empty check at the start + // of ggml_vk_build_graph is sufficient, but set_rows can have a nonempty dst + // with empty srcs. + if (ggml_is_empty(src0) || ggml_is_empty(src1)) { + return; + } + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SET_ROWS, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + 0.0f, 0.0f, 0, + }, dryrun); +} + +static void ggml_vk_silu_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SILU_BACK, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_group_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const int * int_op_params = (const int *)dst->op_params; + const float * float_op_params = (const float *)dst->op_params; + + const uint32_t num_groups = int_op_params[0]; + const float eps = float_op_params[1]; + const uint32_t group_size = src0->ne[0] * src0->ne[1] * ((src0->ne[2] + num_groups - 1) / num_groups); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_GROUP_NORM, { group_size, 0, eps, 0.0f }, dryrun); +} + +static uint32_t ggml_vk_rms_num_partials(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t ne = 
(uint32_t)node->ne[0]; + const uint32_t denom = ctx->device->pipeline_add_rms[0][0][0]->wg_denoms[0]; + const uint32_t num_partials = CEIL_DIV(ne, denom); + return num_partials; +} + +static uint32_t ggml_vk_rms_partials_size(ggml_backend_vk_context * ctx, const ggml_tensor *node) { + const uint32_t num_partials = ggml_vk_rms_num_partials(ctx, node); + const uint32_t num_bytes = ROUNDUP_POW2(num_partials * sizeof(uint32_t), ctx->device->partials_binding_alignment); + return num_bytes; +} + +static void ggml_vk_rms_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, float * op_params, bool dryrun = false) { + const uint32_t src0_type_size = ggml_type_size(src0->type); + const uint32_t src1_type_size = ggml_type_size(src1->type); + const uint32_t dst_type_size = ggml_type_size(dst->type); + + uint32_t param3 = ctx->do_add_rms_partials ? ggml_vk_rms_num_partials(ctx, dst) : 0; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM, { + (uint32_t)ggml_nelements(src0), + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2],(uint32_t)src0->ne[3], (uint32_t)src0->nb[0] / src0_type_size, (uint32_t)src0->nb[1] / src0_type_size, (uint32_t)src0->nb[2] / src0_type_size, (uint32_t)src0->nb[3] / src0_type_size, + (uint32_t)src1->ne[0], (uint32_t)src1->ne[1], (uint32_t)src1->ne[2],(uint32_t)src1->ne[3], (uint32_t)src1->nb[0] / src1_type_size, (uint32_t)src1->nb[1] / src1_type_size, (uint32_t)src1->nb[2] / src1_type_size, (uint32_t)src1->nb[3] / src1_type_size, + (uint32_t) dst->ne[0], (uint32_t) dst->ne[1], (uint32_t) dst->ne[2],(uint32_t) dst->ne[3], (uint32_t) dst->nb[0] / dst_type_size, (uint32_t) dst->nb[1] / dst_type_size, (uint32_t) dst->nb[2] / dst_type_size, (uint32_t) dst->nb[3] / dst_type_size, + 0, + op_params[0], 0.0f, (int32_t)param3, + }, dryrun); + + if (ctx->do_add_rms_partials) { + ctx->prealloc_size_add_rms_partials_offset += ggml_vk_rms_partials_size(ctx, src0); + ctx->do_add_rms_partials = false; + } +} + +static void ggml_vk_rms_norm_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_RMS_NORM_BACK, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_l2_norm(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_L2_NORM, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0], 0.0f }, dryrun); +} + +static void ggml_vk_unary(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_UNARY, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_glu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const float * op_params_f = (const float *)dst->op_params; + + const bool swapped = (bool)dst->op_params[1]; + const bool split = src1 != nullptr; + const float alpha = op_params_f[2]; + const float limit = op_params_f[3]; + + GGML_ASSERT(ggml_is_contiguous(src0)); + + if (!split) { + 
GGML_ASSERT(src0->ne[0] / 2 == dst->ne[0]); + } else { + GGML_ASSERT(src0->ne[0] == src1->ne[0]); + GGML_ASSERT(src0->ne[0] == dst->ne[0]); + GGML_ASSERT(src0->type == src1->type); + } + + const uint32_t mode = split ? 2 : (swapped ? 1 : 0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_GLU, + { + (uint32_t)ggml_nelements(dst), + (uint32_t)src0->ne[0], + (uint32_t)dst->ne[0], + mode, + alpha, + limit + }, dryrun); +} + +static void ggml_vk_diag_mask_inf(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_DIAG_MASK_INF, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], op_params[0] }, dryrun); +} + +static void ggml_vk_soft_max(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + + float scale = op_params[0]; + float max_bias = op_params[1]; + + const uint32_t ncols = (uint32_t)src0->ne[0]; + const uint32_t nrows_x = (uint32_t)ggml_nrows(src0); + const uint32_t nrows_y = (uint32_t)src0->ne[1]; + + const uint32_t ne12 = src1 ? (uint32_t)(src1->ne[2]) : 0u; + const uint32_t ne13 = src1 ? (uint32_t)(src1->ne[3]) : 0u; + const uint32_t nb11 = src1 ? (uint32_t)(src1->nb[1] / src1->nb[0]) : 0u; + const uint32_t nb12 = src1 ? (uint32_t)(src1->nb[2] / src1->nb[0]) : 0u; + const uint32_t nb13 = src1 ? (uint32_t)(src1->nb[3] / src1->nb[0]) : 0u; + + const uint32_t n_head_kv = src0->ne[2]; + const uint32_t n_head_log2 = 1u << (uint32_t) floorf(log2f((float) n_head_kv)); + + const float m0 = powf(2.0f, -(max_bias ) / n_head_log2); + const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_SOFT_MAX, { + ncols, + src1 != nullptr ? 
nrows_y : (uint32_t)0, + (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], (uint32_t)src0->ne[2], + ne12, ne13, + nb11, nb12, nb13, + scale, max_bias, + m0, m1, + n_head_log2, + nrows_x, + src2 != nullptr + }, dryrun); +} + +static void ggml_vk_soft_max_back(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + float * op_params = (float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_SOFT_MAX_BACK, { (uint32_t)src0->ne[0], (uint32_t)ggml_nrows(src0), op_params[0], op_params[1] }, dryrun); +} + +static void ggml_vk_rope(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, const ggml_tensor * src2, ggml_tensor * dst, bool backprop, bool dryrun = false) { + const int n_dims = ((int32_t *) dst->op_params)[1]; + const int mode = ((int32_t *) dst->op_params)[2]; + // const int n_ctx = ((int32_t *) dst->op_params)[3]; + const int n_ctx_orig = ((int32_t *) dst->op_params)[4]; + const float freq_base = ((float *) dst->op_params)[5]; + const float freq_scale = ((float *) dst->op_params)[6]; + const float ext_factor = ((float *) dst->op_params)[7]; + const float attn_factor = ((float *) dst->op_params)[8]; + const float beta_fast = ((float *) dst->op_params)[9]; + const float beta_slow = ((float *) dst->op_params)[10]; + int sections[4] {}; + if (mode & GGML_ROPE_TYPE_MROPE) { + memcpy(sections, (int32_t *) dst->op_params + 11, sizeof(int)*4); + } + + float corr_dims[2]; + ggml_rope_yarn_corr_dims(n_dims, n_ctx_orig, freq_base, beta_fast, beta_slow, corr_dims); + + const float theta_scale = powf(freq_base, -2.0f/n_dims); + + uint32_t s1 = src0->nb[1] / ggml_type_size(src0->type); + uint32_t s2 = src0->nb[2] / ggml_type_size(src0->type); + + ggml_vk_op_f32(ctx, subctx, src0, src1, src2, dst, GGML_OP_ROPE, { + (uint32_t)src0->ne[0], (uint32_t)n_dims, freq_scale, (uint32_t)src0->ne[1], + freq_base, ext_factor, attn_factor, {corr_dims[0], corr_dims[1]}, theta_scale, + src2 != nullptr, (uint32_t)src0->ne[2], s1, s2, + { sections[0], sections[1], sections[2], sections[3] }, backprop + }, dryrun); +} + +static void ggml_vk_argsort(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + int32_t * op_params = (int32_t *)dst->op_params; + + uint32_t ncols = src0->ne[0]; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGSORT, { + ncols, + op_params[0], + }, dryrun); +} + +static void ggml_vk_sum(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, ggml_nelements(src0)); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM, p, dryrun); +} + +static void ggml_vk_sum_rows(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_SUM_ROWS, p, dryrun); +} + +static void ggml_vk_mean(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + vk_op_sum_rows_push_constants p = vk_op_sum_rows_push_constants_init(src0, dst, src0->ne[0]); + p.weight = 1.0f / (float)src0->ne[0]; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, 
nullptr, dst, GGML_OP_MEAN, p, dryrun); +} + +static void ggml_vk_argmax(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_ARGMAX, { (uint32_t)src0->ne[0], (uint32_t)src0->ne[1], 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_count_equal(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_COUNT_EQUAL, { (uint32_t)ggml_nelements(src0), 0, 0.0f, 0.0f }, dryrun); +} + +static void ggml_vk_im2col(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + const int32_t s0 = dst->op_params[0]; + const int32_t s1 = dst->op_params[1]; + const int32_t p0 = dst->op_params[2]; + const int32_t p1 = dst->op_params[3]; + const int32_t d0 = dst->op_params[4]; + const int32_t d1 = dst->op_params[5]; + + const bool is_2D = dst->op_params[6] == 1; + + const uint32_t IC = src1->ne[is_2D ? 2 : 1]; + const uint32_t IH = is_2D ? src1->ne[1] : 1; + const uint32_t IW = src1->ne[0]; + + const uint32_t KH = is_2D ? src0->ne[1] : 1; + const uint32_t KW = src0->ne[0]; + + const uint32_t OH = is_2D ? dst->ne[2] : 1; + const uint32_t OW = dst->ne[1]; + + const uint32_t offset_delta = src1->nb[is_2D ? 2 : 1] / 4; // nb is byte offset, src is type float32 + const uint32_t batch_offset = src1->nb[is_2D ? 3 : 2] / 4; // nb is byte offset, src is type float32 + + const uint32_t pelements = OW * KW * KH; + + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL, { + dst_addr, + batch_offset, offset_delta, + IC, IW, IH, OW, OH, KW, KH, + pelements, + IC * KH * KW, + s0, s1, p0, p1, d0, d1, + }, dryrun); +} + +static void ggml_vk_im2col_3d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_TENSOR_BINARY_OP_LOCALS + + const int32_t s0 = ((const int32_t *)(dst->op_params))[0]; + const int32_t s1 = ((const int32_t *)(dst->op_params))[1]; + const int32_t s2 = ((const int32_t *)(dst->op_params))[2]; + const int32_t p0 = ((const int32_t *)(dst->op_params))[3]; + const int32_t p1 = ((const int32_t *)(dst->op_params))[4]; + const int32_t p2 = ((const int32_t *)(dst->op_params))[5]; + const int32_t d0 = ((const int32_t *)(dst->op_params))[6]; + const int32_t d1 = ((const int32_t *)(dst->op_params))[7]; + const int32_t d2 = ((const int32_t *)(dst->op_params))[8]; + const int32_t IC = ((const int32_t *)(dst->op_params))[9]; + + const int64_t N = ne13 / IC; + const int64_t ID = ne12; + const int64_t IH = ne11; + const int64_t IW = ne10; + + const int64_t KD = ne02; + const int64_t KH = ne01; + const int64_t KW = ne00; + + const int64_t OD = ne3 / N; + const int64_t OH = ne2; + const int64_t OW = ne1; + + const ggml_backend_vk_buffer_context * d_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + const vk_buffer d_buf = d_buf_ctx->dev_buffer; + + const vk::DeviceAddress dst_addr = d_buf->bda_addr + vk_tensor_offset(dst) + dst->view_offs; + + vk_op_im2col_3d_push_constants pc {}; + + 
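As an aside before the im2col_3d push constants are filled in below: im2col flattens each receptive field into one row of a column matrix so the convolution becomes a plain matrix multiply. The following is a rough CPU analogue of the 2D case, reusing the parameter names above but with a simplified, illustrative memory layout; it is a sketch of the general algorithm, not the Vulkan shader or ggml's exact dst layout.

```
#include <cstddef>
#include <vector>

// Illustrative CPU im2col (2D): src is src[c][y][x] row-major, the result is
// cols[(oh*OW + ow)][(c*KH + kh)*KW + kw]. Simplified sketch only.
static std::vector<float> im2col_2d(const std::vector<float> & src,
                                    int IC, int IH, int IW,
                                    int KH, int KW, int OH, int OW,
                                    int s0, int s1, int p0, int p1, int d0, int d1) {
    std::vector<float> cols((size_t)OH * OW * IC * KH * KW, 0.0f);
    for (int oh = 0; oh < OH; ++oh) {
        for (int ow = 0; ow < OW; ++ow) {
            for (int c = 0; c < IC; ++c) {
                for (int kh = 0; kh < KH; ++kh) {
                    for (int kw = 0; kw < KW; ++kw) {
                        const int ih = oh * s1 + kh * d1 - p1; // vertical stride/dilation/padding
                        const int iw = ow * s0 + kw * d0 - p0; // horizontal stride/dilation/padding
                        if (ih >= 0 && ih < IH && iw >= 0 && iw < IW) {
                            cols[((size_t)(oh * OW + ow) * IC + c) * KH * KW + (size_t)kh * KW + kw] =
                                src[((size_t)c * IH + ih) * IW + iw];
                        }
                        // out-of-range taps stay zero, which plays the role of the padded border
                    }
                }
            }
        }
    }
    return cols;
}
```

In the 2D path above, offset_delta and batch_offset are the per-channel and per-image strides of src1 expressed in float elements, which is why the byte strides nb[...] are divided by 4.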
pc.dst_addr = dst_addr; + pc.nb10 = nb10 / ggml_type_size(src1->type); + pc.nb11 = nb11 / ggml_type_size(src1->type); + pc.nb12 = nb12 / ggml_type_size(src1->type); + pc.nb13 = nb13 / ggml_type_size(src1->type); + pc.s0 = s0; + pc.s1 = s1; + pc.s2 = s2; + pc.p0 = p0; + pc.p1 = p1; + pc.p2 = p2; + pc.d0 = d0; + pc.d1 = d1; + pc.d2 = d2; + pc.IW = IW; + pc.IH = IH; + pc.ID = ID; + pc.IC = IC; + pc.KW = KW; + pc.OH = OH; + pc.KD_KH_KW = KD*KH*KW; + pc.KH_KW = KH*KW; + pc.IC_KD_KH_KW = IC*KD*KH*KW; + pc.N_OD_OH = N*OD*OH; + pc.OD_OH = OD*OH; + pc.OD_OH_OW_IC_KD_KH_KW = OD*OH*OW*IC*KD*KH*KW; + pc.OH_OW_IC_KD_KH_KW = OH*OW*IC*KD*KH*KW; + pc.OW_IC_KD_KH_KW = OW*IC*KD*KH*KW; + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_IM2COL_3D, std::move(pc), dryrun); +} + +static void ggml_vk_timestep_embedding(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + const uint32_t dim = dst->op_params[0]; + const uint32_t max_period = dst->op_params[1]; + const uint32_t nb1 = dst->nb[1] / ggml_type_size(dst->type); + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_TIMESTEP_EMBEDDING, { + nb1, dim, max_period, + }, dryrun); +} + +static void ggml_vk_conv_transpose_1d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + // src0: (K, Cout, Cin, 1) -- kernel + // src1: (L, Cin, 1, 1) -- input + // dst: (*, Cout, 1, 1) + + GGML_ASSERT(src0->type == GGML_TYPE_F32); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT( dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float)); + GGML_ASSERT(nb10 == sizeof(float)); + + const int32_t s0 = dst->op_params[0]; + + vk_op_conv_transpose_1d_push_constants p{}; + p.Cout = static_cast(ne01); + p.Cin = static_cast(ne02); + p.K = static_cast(ne00); + p.L = static_cast(ne10); + p.KL = static_cast(ne0); + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb11 = static_cast(nb11 / nb10); + p.nb1 = static_cast(nb1 / nb0); + p.s0 = static_cast(s0); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_1D, std::move(p), dryrun); +} + +static void ggml_vk_pool_2d(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, ggml_tensor * dst, bool dryrun = false) { + uint32_t op = static_cast(dst->op_params[0]); + const int32_t k1 = dst->op_params[1]; + const int32_t k0 = dst->op_params[2]; + const int32_t s1 = dst->op_params[3]; + const int32_t s0 = dst->op_params[4]; + const int32_t p1 = dst->op_params[5]; + const int32_t p0 = dst->op_params[6]; + + const uint32_t IH = src0->ne[1]; + const uint32_t IW = src0->ne[0]; + + const uint32_t N = dst->ne[3]; + + const uint32_t OC = dst->ne[2]; + const uint32_t OH = dst->ne[1]; + const uint32_t OW = dst->ne[0]; + + const uint32_t parallel_elements = N * OC * OH * OW; + + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_POOL_2D, { + IW, IH, OW, OH, OC, + parallel_elements, + op, + k0, k1, s0, s1, p0, p1, + }, dryrun); +} + +static void ggml_vk_conv_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == 
sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv2d_push_constants p{}; + p.Cout = static_cast(ne03); + p.Cin = static_cast(ne02); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[1]); + p.p0 = static_cast(dst->op_params[2]); + p.p1 = static_cast(dst->op_params[3]); + p.d0 = static_cast(dst->op_params[4]); + p.d1 = static_cast(dst->op_params[5]); + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne03 == ne2); + GGML_ASSERT(ne02 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D, std::move(p), dryrun); +} + +static void ggml_vk_conv_transpose_2d(ggml_backend_vk_context * ctx, vk_context & subctx, const ggml_tensor * src0, + const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_F16); + GGML_ASSERT(src1->type == GGML_TYPE_F32); + GGML_ASSERT(dst->type == GGML_TYPE_F32); + + GGML_TENSOR_BINARY_OP_LOCALS + + GGML_ASSERT(nb00 == sizeof(float) || nb00 == sizeof(ggml_fp16_t)); + GGML_ASSERT(nb10 == sizeof(float)); + GGML_ASSERT(nb0 == sizeof(float)); + + vk_op_conv_transpose_2d_push_constants p{}; + p.Cout = static_cast(ne02); + p.Cin = static_cast(ne03); + p.N = static_cast(ne13); + + p.KW = static_cast(ne00); + p.KH = static_cast(ne01); + p.W = static_cast(ne10); + p.H = static_cast(ne11); + p.OW = static_cast(ne0); + p.OH = static_cast(ne1); + + p.s0 = static_cast(dst->op_params[0]); + p.s1 = static_cast(dst->op_params[0]); + p.p0 = 0; + p.p1 = 0; + p.d0 = 1; + p.d1 = 1; + + p.nb01 = static_cast(nb01 / nb00); + p.nb02 = static_cast(nb02 / nb00); + p.nb03 = static_cast(nb03 / nb00); + + p.nb11 = static_cast(nb11 / nb10); + p.nb12 = static_cast(nb12 / nb10); + p.nb13 = static_cast(nb13 / nb10); + + p.nb1 = static_cast(nb1 / nb0); + p.nb2 = static_cast(nb2 / nb0); + p.nb3 = static_cast(nb3 / nb0); + + GGML_ASSERT(ne02 == ne2); + GGML_ASSERT(ne03 == ne12); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_TRANSPOSE_2D, std::move(p), dryrun); +} + +static void ggml_vk_conv_2d_dw(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * src0, const ggml_tensor * src1, ggml_tensor * dst, bool dryrun = false) { + vk_op_conv2d_dw_push_constants p{}; + p.ne = ggml_nelements(dst); + p.channels = dst->ne[2]; + p.batches = dst->ne[3]; + p.dst_w = dst->ne[0]; + p.dst_h = dst->ne[1]; + p.src_w = src1->ne[0]; + p.src_h = src1->ne[1]; + p.knl_w = src0->ne[0]; + p.knl_h = src0->ne[1]; + p.stride_x = dst->op_params[0]; + p.stride_y = dst->op_params[1]; + p.pad_x = dst->op_params[2]; + p.pad_y = dst->op_params[3]; + p.dilation_x = dst->op_params[4]; + p.dilation_y = dst->op_params[5]; + + GGML_ASSERT(src0->ne[3] == p.channels); + GGML_ASSERT(src1->ne[3] == p.batches); + + ggml_vk_op_f32(ctx, subctx, src0, src1, nullptr, dst, GGML_OP_CONV_2D_DW, std::move(p), dryrun); +} + +static void ggml_vk_leaky_relu(ggml_backend_vk_context * ctx, vk_context& subctx, const ggml_tensor * 
src0, ggml_tensor * dst, bool dryrun = false) { + const float * op_params = (const float *)dst->op_params; + ggml_vk_op_f32(ctx, subctx, src0, nullptr, nullptr, dst, GGML_OP_LEAKY_RELU, { (uint32_t)ggml_nelements(src0), 0, op_params[0], 0.0f }, dryrun); +} + +#ifdef GGML_VULKAN_RUN_TESTS +static void ggml_vk_print_matrix_area(const void * data, ggml_type type, int ne0, int ne1, int i0, int i1, int i2) { + if (type != GGML_TYPE_F32 && type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < ne0 && idx1 >= 0 && idx1 < ne1) { + float val; + if (type == GGML_TYPE_F32) { + val = *((const float *) data + i2*ne1*ne0 + idx1*ne0 + idx0); + } else if (type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*((const ggml_fp16_t *) data + i2*ne1*ne0 + idx1*ne0 + idx0)); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +template +static void ggml_vk_test_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, int split_k, int shader_size) { + VK_LOG_DEBUG("ggml_vk_test_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << shader_size << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_s; + shname = "F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_s; + shname = "F32_F16_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_s; + shname = "F16_F32_ALIGNED_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_s; + shname = "F16_ALIGNED_S"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_m; + shname = "F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_m; + shname = "F32_F16_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_m; + shname = "F16_F32_ALIGNED_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_m; + shname = "F16_ALIGNED_M"; + } else { + GGML_ABORT("fatal error"); + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->a_l; + shname = "F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->a_l; + shname = "F32_F16_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->a_l; + shname = "F16_F32_ALIGNED_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->a_l; + shname = "F16_ALIGNED_L"; + } else { + GGML_ABORT("fatal error"); + } + } else { + GGML_ASSERT(0); + } + + const size_t 
kpad = ggml_vk_align_size(k, p->align); + + if (k != kpad) { + if (shader_size == 0) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->s; + shname = "F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->s; + shname = "F32_F16_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->s; + shname = "F16_F32_S"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->s; + shname = "F16_S"; + } + } else if (shader_size == 1) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->m; + shname = "F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->m; + shname = "F32_F16_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->m; + shname = "F16_F32_M"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->m; + shname = "F16_M"; + } + } else if (shader_size == 2) { + if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32->l; + shname = "F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f32_f16->l; + shname = "F32_F16_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16_f32.f32acc->l; + shname = "F16_F32_L"; + } else if (std::is_same() && std::is_same()) { + p = ctx->device->pipeline_matmul_f16.f32acc->l; + shname = "F16_L"; + } + } + } + + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + vk_buffer d_X = ggml_vk_create_buffer_check(ctx->device, sizeof(X_TYPE) * x_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_Y = ggml_vk_create_buffer_check(ctx->device, sizeof(Y_TYPE) * y_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_D = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + + X_TYPE* x = (X_TYPE *) malloc(sizeof(X_TYPE) * x_ne); + Y_TYPE* y = (Y_TYPE *) malloc(sizeof(Y_TYPE) * y_ne); + float* d = (float *) malloc(sizeof(float) * d_ne); + + for (size_t i = 0; i < x_ne; i++) { + if (std::is_same()) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = 1.0f; + // x[i] = i + 1; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + } else if (std::is_same()) { + x[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // x[i] = ggml_fp32_to_fp16(1.0f); + // x[i] = ggml_fp32_to_fp16(i + 1); + // x[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + } else { + GGML_ABORT("fatal error"); + } + } + for (size_t i = 0; i < y_ne; i++) { + if (std::is_same()) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 
1.0f : 0.0f; + // y[i] = i + 1; + } else if (std::is_same()) { + y[i] = ggml_fp32_to_fp16((rand() / (float)RAND_MAX) * 2.0f - 1.0f); + // y[i] = ggml_fp32_to_fp16((i % k == i / k) ? 1.0f : 0.0f); + // y[i] = ggml_fp32_to_fp16(i + 1); + } else { + GGML_ABORT("fatal error"); + } + } + + ggml_vk_buffer_write(d_X, 0, x, sizeof(X_TYPE) * k * m * batch); + ggml_vk_buffer_write(d_Y, 0, y, sizeof(Y_TYPE) * k * n * batch); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, ggml_vk_subbuffer(ctx, d_X), ggml_vk_subbuffer(ctx, d_Y), ggml_vk_subbuffer(ctx, d_D), ggml_vk_subbuffer(ctx, ctx->prealloc_split_k), + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_matmul waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + double time = std::chrono::duration_cast(end-begin).count() / 1000.0; + + // copy dst to host + ggml_vk_buffer_read(d_D, 0, d, sizeof(float) * d_ne); + + float * d_chk = (float *) malloc(sizeof(float) * d_ne); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_type src0_type; + ggml_type src1_type; + + if (std::is_same()) { + src0_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src0_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + if (std::is_same()) { + src1_type = GGML_TYPE_F32; + } else if (std::is_same()) { + src1_type = GGML_TYPE_F16; + } else { + GGML_ABORT("fatal error"); + } + + ggml_tensor * src0_ggml = ggml_new_tensor_3d(ggml_ctx, src0_type, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, src1_type, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = x; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST " << shname << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << "Expected result: " << std::endl << 
std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + free(d_chk); + + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + + ggml_vk_destroy_buffer(d_X); + ggml_vk_destroy_buffer(d_Y); + ggml_vk_destroy_buffer(d_D); + + free(x); + free(y); + free(d); +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(float *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor->data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_quantize_data(const float * from, void * to, size_t ne, ggml_type quant) { + ggml_quantize_chunk(quant, from, to, 0, 1, ne, nullptr); +} + +static void ggml_vk_dequantize_data(const void * from, float * to, size_t ne, ggml_type quant) { + if (quant == GGML_TYPE_F32) { + memcpy(to, from, sizeof(float) * ne); + return; + } + + const auto * tt = ggml_get_type_traits(quant); + + ggml_to_float_t dequant_fn = tt->to_float; + + dequant_fn(from, to, ne); +} + +static void ggml_vk_test_dequant(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { + VK_LOG_DEBUG("ggml_vk_test_dequant(" << ne << ")"); + const size_t x_sz = sizeof(float) * ne; + const size_t x_sz_f16 = sizeof(ggml_fp16_t) * ne; + const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); + float * x = (float *) malloc(x_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz_f16, 
{vk::MemoryPropertyFlagBits::eDeviceLocal}); + float * x_ref = (float *) malloc(x_sz); + ggml_fp16_t * x_chk = (ggml_fp16_t *) malloc(x_sz_f16); + + for (size_t i = 0; i < ne; i++) { + x[i] = rand() / (float)RAND_MAX; + } + + vk_pipeline p = ggml_vk_get_to_fp16(ctx, quant); + + ggml_vk_quantize_data(x, qx, ne, quant); + ggml_vk_dequantize_data(qx, x_ref, ne, quant); + + ggml_pipeline_request_descriptor_sets(ctx, p, 1); + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + const std::vector pc = { 1, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne, (uint32_t)ne }; + ggml_vk_dispatch_pipeline(ctx, subctx, p, { vk_subbuffer{ qx_buf, 0, qx_sz }, vk_subbuffer{ x_buf, 0, x_sz_f16 } }, pc, { (uint32_t)ne, 1, 1}); + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + + double ms_dequant = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(x_buf, 0, x_chk, x_sz_f16); + + int first_err = -1; + + double avg_err = 0.0; + for (size_t i = 0; i < ne; i++) { + double error = std::fabs(x_ref[i] - ggml_fp16_to_fp32(x_chk[i])); + avg_err += error; + + if (first_err < 0 && error > 0.05) { + first_err = i; + } + } + + avg_err /= ne; + + std::cerr << "TEST DEQUANT " << ggml_type_name(quant) << " time=" << ms_dequant << "ms avg_err=" << avg_err << std::endl; + + if (avg_err > 0.1) { + std::cerr << "first_error = " << first_err << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << ggml_fp16_to_fp32(x_chk[i]) << ", "; + } + std::cerr << std::endl << "Expected result: " << std::endl << std::endl; + for (int i = std::max(0, first_err - 5); i < std::min((int)ne, first_err + 5); i++) { + std::cerr << x_ref[i] << ", "; + } + std::cerr << std::endl; + } + + ggml_vk_destroy_buffer(x_buf); + ggml_vk_destroy_buffer(qx_buf); + + free(x); + free(qx); + free(x_ref); + free(x_chk); +} + +// This does not work without ggml q8_1 quantization support +// +// typedef uint16_t ggml_half; +// typedef uint32_t ggml_half2; +// +// #define QK8_1 32 +// typedef struct { +// union { +// struct { +// ggml_half d; // delta +// ggml_half s; // d * sum(qs[i]) +// } GGML_COMMON_AGGR_S; +// ggml_half2 ds; +// } GGML_COMMON_AGGR_U; +// int8_t qs[QK8_1]; // quants +// } block_q8_1; +// +// static void ggml_vk_test_quantize(ggml_backend_vk_context * ctx, size_t ne, ggml_type quant) { +// VK_LOG_DEBUG("ggml_vk_test_quantize(" << ne << ")"); +// GGML_ASSERT(quant == GGML_TYPE_Q8_1); +// +// const size_t x_sz = sizeof(float) * ne; +// const size_t qx_sz = ne * ggml_type_size(quant)/ggml_blck_size(quant); +// float * x = (float *) malloc(x_sz); +// block_q8_1 * qx = (block_q8_1 *)malloc(qx_sz); +// block_q8_1 * qx_res = (block_q8_1 *)malloc(qx_sz); +// vk_buffer x_buf = ggml_vk_create_buffer_check(ctx->device, x_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); +// vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, 
qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); +// +// for (size_t i = 0; i < ne; i++) { +// x[i] = rand() / (float)RAND_MAX; +// } +// +// vk_pipeline p = ggml_vk_get_quantize_pipeline(ctx, quant); +// +// ggml_pipeline_request_descriptor_sets(ctx, p, 1); +// +// if (ctx->device->need_compiles) { +// ggml_vk_load_shaders(ctx->device); +// } +// +// ggml_pipeline_allocate_descriptor_sets(ctx); +// +// ggml_vk_buffer_write(x_buf, 0, x, x_sz); +// +// vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); +// ggml_vk_ctx_begin(ctx->device, subctx); +// ggml_vk_quantize_q8_1(ctx, subctx, ggml_vk_subbuffer(ctx, x_buf), ggml_vk_subbuffer(ctx, qx_buf), ne); +// ggml_vk_ctx_end(subctx); +// +// auto begin = std::chrono::high_resolution_clock::now(); +// +// ggml_vk_submit(subctx, ctx->fence); +// VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_quantize waitForFences"); +// ctx->device->device.resetFences({ ctx->fence }); +// ggml_vk_queue_command_pools_cleanup(ctx->device); +// +// auto end = std::chrono::high_resolution_clock::now(); +// +// double ms_quant = std::chrono::duration_cast(end-begin).count() / 1000.0; +// ggml_vk_buffer_read(qx_buf, 0, qx, qx_sz); +// +// ggml_vk_quantize_data(x, qx_res, ne, quant); +// +// int first_err = -1; +// +// for (size_t i = 0; i < ne / 32; i++) { +// double error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// error = std::fabs(ggml_fp16_to_fp32(qx_res[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) - ggml_fp16_to_fp32(qx[i].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s)); +// +// if (first_err < 0 && error > 0.1) { +// first_err = i; +// } +// +// for (size_t j = 0; j < 32; j++) { +// uint64_t error = std::abs(qx_res[i].qs[j] - qx[i].qs[j]); +// +// if (first_err < 0 && error > 1) { +// first_err = i; +// } +// } +// } +// +// std::cerr << "TEST QUANTIZE " << ggml_type_name(quant) << " time=" << ms_quant << "ms " << (first_err == -1 ? 
"CORRECT" : "INCORRECT") << std::endl; +// +// if (first_err != -1) { +// std::cerr << "first_error = " << first_err << std::endl; +// std::cerr << "Actual result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx[first_err].qs[j] << " "; +// } +// std::cerr << std::endl << std::endl << "Expected result: " << std::endl << std::endl; +// std::cout << "d=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.d) << " s=" << ggml_fp16_to_fp32(qx_res[first_err].GGML_COMMON_AGGR_U.GGML_COMMON_AGGR_S.s) << " "; +// for (size_t j = 0; j < 32; j++) { +// std::cout << " qs" << j << "=" << (uint32_t)qx_res[first_err].qs[j] << " "; +// } +// std::cerr << std::endl; +// } +// +// ggml_vk_destroy_buffer(x_buf); +// ggml_vk_destroy_buffer(qx_buf); +// +// free(x); +// free(qx); +// free(qx_res); +// } + +static void ggml_vk_test_dequant_matmul(ggml_backend_vk_context * ctx, size_t m, size_t n, size_t k, size_t batch, size_t num_it, size_t split_k, size_t shader_size, ggml_type quant, bool mmq = false) { + VK_LOG_DEBUG("ggml_vk_test_dequant_matmul(" << m << ", " << n << ", " << k << ", " << batch << ", " << num_it << ", " << split_k << ", " << ggml_type_name(quant) << ")"); + const size_t x_ne = m * k * batch; + const size_t y_ne = k * n * batch; + const size_t d_ne = m * n * batch; + + vk_matmul_pipeline2 * pipelines; + + if (mmq) { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat_q8_1; + } else { + pipelines = ctx->device->pipeline_dequant_mul_mat_mat; + } + + const bool fp16acc = ctx->device->fp16; + + vk_pipeline p; + std::string shname; + if (shader_size == 0) { + p = fp16acc ? pipelines[quant].f16acc->a_s : pipelines[quant].f32acc->a_s; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_S"; + } else if (shader_size == 1) { + p = fp16acc ? pipelines[quant].f16acc->a_m : pipelines[quant].f32acc->a_m; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_M"; + } else if (shader_size == 2) { + p = fp16acc ? pipelines[quant].f16acc->a_l : pipelines[quant].f32acc->a_l; + shname = std::string(ggml_type_name(quant)) + "_ALIGNED_L"; + } else { + GGML_ASSERT(0); + } + + const size_t kpad = mmq ? 0 : ggml_vk_align_size(k, p->align); + + if (mmq || k != kpad) { + if (shader_size == 0) { + p = fp16acc ? pipelines[quant].f16acc->s : pipelines[quant].f32acc->s; + shname = std::string(ggml_type_name(quant)) + "_S"; + } else if (shader_size == 1) { + p = fp16acc ? pipelines[quant].f16acc->m : pipelines[quant].f32acc->m; + shname = std::string(ggml_type_name(quant)) + "_M"; + } else if (shader_size == 2) { + p = fp16acc ? pipelines[quant].f16acc->l : pipelines[quant].f32acc->l; + shname = std::string(ggml_type_name(quant)) + "_L"; + } else { + GGML_ASSERT(0); + } + } + + if (p == nullptr) { + std::cerr << "error: no pipeline for ggml_vk_test_dequant_matmul " << ggml_type_name(quant) << std::endl; + return; + } + + const size_t x_sz = sizeof(float) * x_ne; + const size_t y_sz = sizeof(float) * y_ne; + const size_t qx_sz = x_ne * ggml_type_size(quant)/ggml_blck_size(quant); + const size_t qy_sz = mmq ? 
y_ne * ggml_type_size(GGML_TYPE_Q8_1)/ggml_blck_size(GGML_TYPE_Q8_1) : y_sz; + const size_t d_sz = sizeof(float) * d_ne; + float * x = (float *) malloc(x_sz); + float * y = (float *) malloc(y_sz); + void * qx = malloc(qx_sz); + vk_buffer qx_buf = ggml_vk_create_buffer_check(ctx->device, qx_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer y_buf = ggml_vk_create_buffer_check(ctx->device, y_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer qy_buf = ggml_vk_create_buffer_check(ctx->device, qy_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + vk_buffer d_buf = ggml_vk_create_buffer_check(ctx->device, d_sz, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + float * d = (float *) malloc(d_sz); + float * d_chk = (float *) malloc(d_sz); + + for (size_t i = 0; i < x_ne; i++) { + x[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // x[i] = (i % k == i / k) ? 1.0f : 0.0f; + // x[i] = i % k; + } + + ggml_vk_quantize_data(x, qx, x_ne, quant); + + for (size_t i = 0; i < y_ne; i++) { + y[i] = (rand() / (float)RAND_MAX) * 2.0f - 1.0f; + // y[i] = (i % k == i / k) ? 1.0f : 0.0f; + // y[i] = i % k; + } + + ggml_pipeline_request_descriptor_sets(ctx, p, num_it); + if (split_k > 1) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_matmul_split_k_reduce, num_it); + + if (ctx->prealloc_split_k == nullptr || ctx->prealloc_split_k->size < sizeof(float) * d_ne * split_k) { + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_check(ctx->device, sizeof(float) * d_ne * split_k, {vk::MemoryPropertyFlagBits::eDeviceLocal}); + } + } + if (mmq) { + ggml_pipeline_request_descriptor_sets(ctx, ctx->device->pipeline_quantize_q8_1, num_it); + } + + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + + ggml_pipeline_allocate_descriptor_sets(ctx); + + ggml_vk_buffer_write(qx_buf, 0, qx, qx_sz); + ggml_vk_buffer_write(y_buf, 0, y, y_sz); + + vk_context subctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ggml_vk_ctx_begin(ctx->device, subctx); + if (mmq) { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_quantize_q8_1(ctx, subctx, { y_buf, 0, y_sz }, { qy_buf, 0, qy_sz }, y_ne); + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { qy_buf, 0, qy_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } else { + for (size_t i = 0; i < num_it; i++) { + ggml_vk_matmul( + ctx, subctx, p, { qx_buf, 0, qx_sz }, { y_buf, 0, y_sz }, { d_buf, 0, d_sz }, { ctx->prealloc_split_k, 0, ctx->prealloc_size_split_k }, + m, n, k, + k, k, m, k*m, k*n, m*n, + split_k, batch, batch, batch, 1, 1, n + ); + } + } + ggml_vk_ctx_end(subctx); + + auto begin = std::chrono::high_resolution_clock::now(); + + ggml_vk_submit(subctx, ctx->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->fence }, true, UINT64_MAX), "ggml_vk_test_dequant waitForFences"); + ctx->device->device.resetFences({ ctx->fence }); + ggml_vk_queue_command_pools_cleanup(ctx->device); + + auto end = std::chrono::high_resolution_clock::now(); + + double time_ms = std::chrono::duration_cast(end-begin).count() / 1000.0; + ggml_vk_buffer_read(d_buf, 0, d, d_sz); + + ggml_init_params iparams = { + /*.mem_size =*/ 1024*1024*1024, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ true, + }; + + ggml_context * ggml_ctx = ggml_init(iparams); + + ggml_tensor * src0_ggml = 
ggml_new_tensor_3d(ggml_ctx, quant, k, m, batch); + ggml_tensor * src1_ggml = ggml_new_tensor_3d(ggml_ctx, GGML_TYPE_F32, k, n, batch); + ggml_tensor * tensor_ggml = ggml_mul_mat(ggml_ctx, src0_ggml, src1_ggml); + + src0_ggml->data = qx; + src1_ggml->data = y; + tensor_ggml->data = d_chk; + + ggml_cgraph * cgraph = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph, tensor_ggml); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph, 1); + + ggml_free(ggml_ctx); + + double avg_err = 0.0; + int first_err_n = -1; + int first_err_m = -1; + int first_err_b = -1; + + for (size_t i = 0; i < m*n*batch; i++) { + double err = std::fabs(d[i] - d_chk[i]); + avg_err += err; + + if ((err > 0.05f || std::isnan(err)) && first_err_n == -1) { + first_err_b = i / (m * n); + first_err_n = (i % (m * n)) / m; + first_err_m = (i % (m * n)) % m; + } + } + + avg_err /= m * n; + + double tflops = 2.0*m*n*k*batch*num_it / (time_ms / 1000.0) / (1000.0*1000.0*1000.0*1000.0); + + std::cerr << "TEST dequant matmul " << shname; + if (mmq) { + std::cerr << " mmq"; + } + std::cerr << " m=" << m << " n=" << n << " k=" << k << " batch=" << batch << " split_k=" << split_k << " matmul " << time_ms / num_it << "ms " << tflops << " TFLOPS avg_err=" << avg_err << std::endl; + + if (avg_err > 0.01 || std::isnan(avg_err)) { + std::cerr << "m = " << first_err_m << " n = " << first_err_n << " b = " << first_err_b << std::endl; + std::cerr << "Actual result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "Expected result: " << std::endl << std::endl; + ggml_vk_print_matrix_area(d_chk, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "src0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(x, GGML_TYPE_F32, k, m, first_err_m, first_err_n, first_err_b); + std::cerr << std::endl; + std::cerr << "src1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(y, GGML_TYPE_F32, k, n, first_err_m, first_err_n, first_err_b); + + if (split_k > 1) { + float * split_k_buf = (float *) malloc(sizeof(float) * d_ne * split_k); + ggml_vk_buffer_read(ctx->prealloc_split_k, 0, split_k_buf, sizeof(float) * d_ne * split_k); + + std::cerr << "d_buf0: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf1: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf2: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 2 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + std::cerr << "d_buf3: " << std::endl << std::endl; + ggml_vk_print_matrix_area(split_k_buf + 3 * d_ne, GGML_TYPE_F32, m, n, first_err_m, first_err_n, first_err_b); + + free(split_k_buf); + } + } + + ggml_vk_destroy_buffer(qx_buf); + ggml_vk_destroy_buffer(y_buf); + ggml_vk_destroy_buffer(qy_buf); + ggml_vk_destroy_buffer(d_buf); + + free(x); + free(qx); + free(y); + free(d); + free(d_chk); +} +#endif + +static void ggml_vk_preallocate_buffers(ggml_backend_vk_context * ctx) { +#if defined(GGML_VULKAN_RUN_TESTS) + const std::vector vals { + 512, 512, 128, + 128, 512, 512, + 4096, 512, 4096, + 11008, 512, 4096, + 4096, 512, 11008, + 32000, 512, 4096, + 8, 8, 8, + 100, 46, 576, + 623, 111, 128, + 100, 46, 558, + 512, 1, 256, + 128, 110, 622, + 511, 511, 127, + 511, 511, 7, + 511, 511, 17, + 49, 49, 
128, + 128, 49, 49, + 4096, 49, 4096, + }; + const size_t num_it = 100; + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q4_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q4_0, true); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0); + + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 0, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 1, GGML_TYPE_Q8_0, true); + ggml_vk_test_dequant_matmul(ctx, 4096, 512, 4096, 2, num_it, 1, 2, GGML_TYPE_Q8_0, true); + + abort(); + + for (size_t i = 0; i < vals.size(); i += 3) { + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2); + std::cerr << '\n'; + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1); + ggml_vk_test_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2); + std::cerr << '\n' << std::endl; + + if (vals[i + 2] % 32 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_0); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_0); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_0); + std::cerr << '\n' << std::endl; + } + + if (vals[i + 2] % 256 == 0) { + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 1, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 0, GGML_TYPE_Q4_K); + 
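The TFLOPS figure these dequant-matmul sweeps print comes from the standard matmul operation count: 2·m·n·k floating point operations per matrix product (one multiply and one add per accumulation step), scaled by the batch and iteration counts and divided by the elapsed time. A minimal standalone sketch of that arithmetic, with made-up sample numbers:

```
#include <cstddef>
#include <cstdio>

// Throughput bookkeeping as used by the test helpers above.
static double matmul_tflops(size_t m, size_t n, size_t k, size_t batch, size_t num_it, double time_ms) {
    const double flops = 2.0 * m * n * k * batch * num_it; // multiply + add per (m,n,k) triple
    return flops / (time_ms / 1000.0) / 1e12;              // ms -> s, then ops/s -> TFLOPS
}

int main() {
    // e.g. m=4096, n=512, k=4096, batch=2, 100 iterations finishing in 1500 ms -> ~2.3 TFLOPS
    std::printf("%.2f TFLOPS\n", matmul_tflops(4096, 512, 4096, 2, 100, 1500.0));
    return 0;
}
```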
ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 2, 2, GGML_TYPE_Q4_K); + std::cerr << '\n'; + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 0, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 1, GGML_TYPE_Q4_K); + ggml_vk_test_dequant_matmul(ctx, vals[i], vals[i + 1], vals[i + 2], 2, num_it, 4, 2, GGML_TYPE_Q4_K); + std::cerr << '\n' << std::endl; + } + } + + GGML_ABORT("fatal error"); +#endif + + if (ctx->prealloc_x == nullptr || (ctx->prealloc_size_x > 0 && ctx->prealloc_x->size < ctx->prealloc_size_x)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(x_size: " << ctx->prealloc_size_x << ")"); + // Resize buffer + if (ctx->prealloc_x != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_x); + } + ctx->prealloc_x = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_x); + } + if (ctx->prealloc_y == nullptr || (ctx->prealloc_size_y > 0 && ctx->prealloc_y->size < ctx->prealloc_size_y)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(y_size: " << ctx->prealloc_size_y << ")"); + // Resize buffer + if (ctx->prealloc_y != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_y); + } + ctx->prealloc_y = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_y); + } + if (ctx->prealloc_split_k == nullptr || (ctx->prealloc_size_split_k > 0 && ctx->prealloc_split_k->size < ctx->prealloc_size_split_k)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(split_k_size: " << ctx->prealloc_size_split_k << ")"); + // Resize buffer + if (ctx->prealloc_split_k != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + } + ctx->prealloc_split_k = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_split_k); + } + if (ctx->prealloc_add_rms_partials == nullptr || (ctx->prealloc_size_add_rms_partials > 0 && ctx->prealloc_add_rms_partials->size < ctx->prealloc_size_add_rms_partials)) { + VK_LOG_MEMORY("ggml_vk_preallocate_buffers(add_partials_size: " << ctx->prealloc_size_add_rms_partials << ")"); + // Resize buffer + if (ctx->prealloc_add_rms_partials != nullptr) { + ggml_vk_destroy_buffer(ctx->prealloc_add_rms_partials); + } + ctx->prealloc_add_rms_partials = ggml_vk_create_buffer_device(ctx->device, ctx->prealloc_size_add_rms_partials); + } +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context* ctx, ggml_cgraph * cgraph, ggml_tensor* tensor, int tensor_idx, bool use_fence, bool almost_ready); + +// Returns true if node has enqueued work into the queue, false otherwise +// If submit is true, all operations queued so far are submitted to Vulkan, overlapping command list creation with GPU execution.
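As a conceptual illustration of that record/submit overlap, a driver loop of this shape records nodes into the current command context and flushes periodically so the GPU can start on earlier batches while the CPU keeps encoding later ones. The types, the record/submit callbacks and the submit threshold below are placeholders, not the actual graph compute loop in this patch:

```
#include <functional>
#include <vector>

// Sketch of a record-and-periodically-submit driver (placeholder types only).
struct FakeNode { int op; };

static void run_graph(const std::vector<FakeNode> & nodes,
                      const std::function<bool(const FakeNode &)> & record, // encode one node, true if it enqueued GPU work
                      const std::function<void()> & submit) {               // flush everything recorded so far to the queue
    int since_submit = 0;
    for (size_t i = 0; i < nodes.size(); ++i) {
        if (record(nodes[i])) {
            ++since_submit;
        }
        const bool last = (i + 1 == nodes.size());
        if (last || since_submit >= 100) { // the threshold is arbitrary in this sketch
            submit();                      // the GPU starts on this batch while we keep recording
            since_submit = 0;
        }
    }
}
```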
+static bool ggml_vk_build_graph(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int node_idx, ggml_tensor *node_begin, int node_idx_begin, bool dryrun, bool last_node, bool almost_ready, bool submit){ + ggml_tensor * node = cgraph->nodes[node_idx]; + if (ggml_is_empty(node) || !node->buffer) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_build_graph(" << node << ", " << ggml_op_name(node->op) << ")"); + ctx->semaphore_idx = 0; + + ggml_tensor * src0 = node->src[0]; + ggml_tensor * src1 = node->src[1]; + ggml_tensor * src2 = node->src[2]; + ggml_tensor * src3 = node->src[3]; + + switch (node->op) { + // Return on empty ops to avoid generating a compute_ctx and setting exit_tensor + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + return false; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + break; + default: + return false; + } + break; + case GGML_OP_ADD: + { + int next_node_idx = node_idx + 1 + ctx->num_additional_fused_ops; + if (next_node_idx < cgraph->n_nodes && + cgraph->nodes[next_node_idx]->op == GGML_OP_RMS_NORM && + cgraph->nodes[next_node_idx]->src[0] == cgraph->nodes[next_node_idx - 1] && + ggml_nrows(cgraph->nodes[next_node_idx]) == 1 && + ctx->device->add_rms_fusion) { + if (dryrun) { + ctx->prealloc_size_add_rms_partials += ggml_vk_rms_partials_size(ctx, cgraph->nodes[node_idx]); + } + ctx->do_add_rms_partials = true; + } + } break; + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD_ID: + case GGML_OP_ACC: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_FLASH_ATTN_EXT: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + break; + default: + std::cerr << "ggml_vulkan: Error: Missing op: " << ggml_op_name(node->op) << std::endl; + GGML_ABORT("fatal error"); + } + + vk_context compute_ctx; + + 
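The dryrun flag handled just below implements a two-pass scheme: the first pass over the graph only tallies what will be needed (pipelines, descriptor sets, preallocation sizes such as the split_k and add_rms partial buffers), and the second pass records the actual commands against buffers that are already large enough. A loose sketch of that idea, with hypothetical names that are not part of this patch:

```
#include <cstddef>

// Two-pass planning: count first, allocate once, then record for real.
struct vk_plan {
    size_t descriptor_sets = 0;
    size_t scratch_bytes   = 0;
};

template <typename NodeRange, typename CostFn>
vk_plan plan_graph(const NodeRange & nodes, CostFn scratch_cost) {
    vk_plan plan;
    for (const auto & node : nodes) {
        plan.descriptor_sets += 1;                  // one descriptor set per recorded op in this sketch
        plan.scratch_bytes   += scratch_cost(node); // e.g. split_k or dequant staging space for that node
    }
    return plan; // grow the preallocated buffers to these sizes, then run the recording pass
}
```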
if (!dryrun) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + } else { + switch (node->op) { + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_CPY: + case GGML_OP_SET_ROWS: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_UNARY: + case GGML_OP_GLU: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_SGD: + { + // These operations all go through ggml_vk_op_f32, so short-circuit and + // do the only thing needed for the dryrun. + vk_pipeline pipeline = ggml_vk_op_get_pipeline(ctx, src0, src1, src2, node, node->op); + ggml_pipeline_request_descriptor_sets(ctx, pipeline, 1); + if (node->op == GGML_OP_RMS_NORM) { + ctx->do_add_rms_partials = false; + } + return false; + } + default: + break; + } + } + + if (!dryrun) { + // This logic detects dependencies between nodes in the graph and calls ggml_vk_sync_buffers + // to synchronize them. This handles most "normal" synchronization when computing the graph, and when + // there is no auxiliary memory use, it shouldn't be necessary to call ggml_vk_sync_buffers + // outside of this logic. When a node uses one of the prealloc buffers for something like + // dequantization or split_k, additional synchronization is needed between those passes. + bool need_sync = false; + + // Check whether "node" requires synchronization. The node requires synchronization if it + // overlaps in memory with another unsynchronized node and at least one of them is a write. + // Destination nodes are checked against both the written/read lists. Source nodes are only + // checked against the written list. Two nodes overlap in memory if they come from the same + // buffer and the tensor or view ranges overlap.
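The hazard check described above reduces to a half-open byte-range intersection between two tensors that live in the same buffer. A self-contained version of the test that the lambda below implements (sizes are assumed to be non-zero):

```
#include <cassert>
#include <cstdint>

// Two half-open byte ranges [a, a + as) and [b, b + bs) within the same buffer
// intersect if each one starts before the other one ends.
static bool ranges_overlap(uint64_t a, uint64_t as, uint64_t b, uint64_t bs) {
    return (b <= a && a < b + bs) || (a <= b && b < a + as);
}

int main() {
    assert(ranges_overlap(0, 16, 8, 16));    // partial overlap -> hazard
    assert(!ranges_overlap(0, 16, 16, 16));  // adjacent ranges -> no hazard
}
```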
+ auto const &overlaps_unsynced = [&](const ggml_tensor *node, const std::vector &unsynced_nodes) -> bool { + if (unsynced_nodes.size() == 0) { + return false; + } + auto n_base = vk_tensor_offset(node) + node->view_offs; + auto n_size = ggml_nbytes(node); + ggml_backend_vk_buffer_context * a_buf_ctx = (ggml_backend_vk_buffer_context *)node->buffer->context; + vk_buffer a_buf = a_buf_ctx->dev_buffer; + for (auto &other : unsynced_nodes) { + ggml_backend_vk_buffer_context * o_buf_ctx = (ggml_backend_vk_buffer_context *)other->buffer->context; + vk_buffer o_buf = o_buf_ctx->dev_buffer; + if (a_buf == o_buf) { + auto o_base = vk_tensor_offset(other) + other->view_offs; + auto o_size = ggml_nbytes(other); + + if ((o_base <= n_base && n_base < o_base + o_size) || + (n_base <= o_base && o_base < n_base + n_size)) { + return true; + } + } + } + return false; + }; + + // For all fused ops, check if the destination node or any of the source + // nodes require synchronization. + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1 && !need_sync; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + if (overlaps_unsynced(cur_node, ctx->unsynced_nodes_read) || overlaps_unsynced(cur_node, ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + if (overlaps_unsynced(cur_node->src[j], ctx->unsynced_nodes_written)) { + need_sync = true; + break; + } + } + } + if (need_sync) { + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + // Add the last fused node and all fused source nodes to the unsynchronized list. + const ggml_tensor * last_node = cgraph->nodes[node_idx + ctx->num_additional_fused_ops]; + ctx->unsynced_nodes_written.push_back(last_node); + for (int32_t i = 0; i < ctx->num_additional_fused_ops + 1; ++i) { + const ggml_tensor *cur_node = cgraph->nodes[node_idx + i]; + for (uint32_t j = 0; j < GGML_MAX_SRC; ++j) { + if (!cur_node->src[j]) { + continue; + } + ctx->unsynced_nodes_read.push_back(cur_node->src[j]); + } + } + } + + switch (node->op) { + case GGML_OP_REPEAT: + ggml_vk_repeat(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_REPEAT_BACK: + ggml_vk_repeat_back(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ACC: + ggml_vk_acc(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_GET_ROWS: + ggml_vk_get_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD: + if (ctx->num_additional_fused_ops) { + ggml_vk_multi_add(ctx, compute_ctx, cgraph, node_idx, dryrun); + } else { + ggml_vk_add(ctx, compute_ctx, src0, src1, node, dryrun); + } + break; + case GGML_OP_SUB: + ggml_vk_sub(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL: + ggml_vk_mul(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_DIV: + ggml_vk_div(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ADD_ID: + ggml_vk_add_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + case GGML_OP_CONCAT: + ggml_vk_concat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_UPSCALE: + ggml_vk_upscale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SCALE: + ggml_vk_scale(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQR: + ggml_vk_sqr(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SQRT: + ggml_vk_sqrt(ctx, compute_ctx, src0, node, 
dryrun); + + break; + case GGML_OP_SIN: + ggml_vk_sin(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COS: + ggml_vk_cos(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CLAMP: + ggml_vk_clamp(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_PAD: + ggml_vk_pad(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ROLL: + ggml_vk_roll(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CPY: + case GGML_OP_CONT: + case GGML_OP_DUP: + ggml_vk_cpy(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SET_ROWS: + ggml_vk_set_rows(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_SILU_BACK: + ggml_vk_silu_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_NORM: + ggml_vk_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_GROUP_NORM: + ggml_vk_group_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_RMS_NORM: + if (ctx->num_additional_fused_ops > 0) { + // fused rms_norm + mul + ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + ggml_tensor *other_src = mul->src[0] == node ? mul->src[1] : mul->src[0]; + ggml_vk_rms_norm(ctx, compute_ctx, src0, other_src, mul, (float *)node->op_params, dryrun); + } else { + ggml_vk_rms_norm(ctx, compute_ctx, src0, src0, node, (float *)node->op_params, dryrun); + } + break; + case GGML_OP_RMS_NORM_BACK: + ggml_vk_rms_norm_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_L2_NORM: + ggml_vk_l2_norm(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(node)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + ggml_vk_unary(ctx, compute_ctx, src0, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(node)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + ggml_vk_glu(ctx, compute_ctx, src0, src1, node, dryrun); + break; + default: + return false; + } + break; + case GGML_OP_DIAG_MASK_INF: + ggml_vk_diag_mask_inf(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SOFT_MAX: + ggml_vk_soft_max(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + case GGML_OP_SOFT_MAX_BACK: + ggml_vk_soft_max_back(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_ROPE: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, false, dryrun); + + break; + case GGML_OP_ROPE_BACK: + ggml_vk_rope(ctx, compute_ctx, src0, src1, src2, node, true, dryrun); + + break; + case GGML_OP_ARGSORT: + ggml_vk_argsort(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM: + ggml_vk_sum(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_SUM_ROWS: + ggml_vk_sum_rows(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MEAN: + ggml_vk_mean(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_ARGMAX: + ggml_vk_argmax(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_COUNT_EQUAL: + ggml_vk_count_equal(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_IM2COL: + ggml_vk_im2col(ctx, compute_ctx, src0, src1, node, dryrun); + + 
break; + case GGML_OP_IM2COL_3D: + ggml_vk_im2col_3d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_TIMESTEP_EMBEDDING: + ggml_vk_timestep_embedding(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CONV_TRANSPOSE_1D: + ggml_vk_conv_transpose_1d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_POOL_2D: + ggml_vk_pool_2d(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_CONV_2D: + ggml_vk_conv_2d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONV_TRANSPOSE_2D: + ggml_vk_conv_transpose_2d(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_CONV_2D_DW: + ggml_vk_conv_2d_dw(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_LEAKY_RELU: + ggml_vk_leaky_relu(ctx, compute_ctx, src0, node, dryrun); + + break; + case GGML_OP_MUL_MAT: + ggml_vk_mul_mat(ctx, compute_ctx, src0, src1, node, dryrun); + + break; + case GGML_OP_MUL_MAT_ID: + ggml_vk_mul_mat_id(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + + case GGML_OP_FLASH_ATTN_EXT: + ggml_vk_flash_attn(ctx, compute_ctx, src0, src1, src2, src3, node->src[4], node, dryrun); + + break; + + case GGML_OP_RWKV_WKV6: + ggml_vk_rwkv_wkv6(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_RWKV_WKV7: + ggml_vk_rwkv_wkv7(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_OPT_STEP_ADAMW: + ggml_vk_opt_step_adamw(ctx, compute_ctx, node, dryrun); + + break; + + case GGML_OP_OPT_STEP_SGD: + ggml_vk_opt_step_sgd(ctx, compute_ctx, src0, src1, src2, node, dryrun); + + break; + default: + return false; + } + + if (dryrun) { + return false; + } + + ctx->tensor_ctxs[node_idx] = compute_ctx; + +#if defined(GGML_VULKAN_CHECK_RESULTS) + // Force context reset on each node so that each tensor ends up in its own context + // and can be run and compared to its CPU equivalent separately + last_node = true; +#endif + + if (submit || last_node) { + ggml_vk_ctx_end(compute_ctx); + + // TODO probably it'd be better to pass a exit_node flag to ggml_vk_compute_forward + if (last_node) { + compute_ctx->exit_tensor_idx = node_idx_begin; + } + else { + compute_ctx->exit_tensor_idx = -1; + } + + ctx->compute_ctx.reset(); + + bool ok = ggml_vk_compute_forward(ctx, cgraph, node_begin, node_idx_begin, false, almost_ready); + if (!ok) { + if (node->op == GGML_OP_UNARY) { + std::cerr << __func__ << ": error: op not supported UNARY " << node->name << " (" << ggml_unary_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else if (node->op == GGML_OP_GLU) { + std::cerr << __func__ << ": error: op not supported GLU " << node->name << " (" << ggml_glu_op_name(static_cast(node->op_params[0])) << ")" << std::endl; + } else { + std::cerr << __func__ << ": error: op not supported " << node->name << " (" << ggml_op_name(node->op) << ")" << std::endl; + } + } + + } + return true; +} + +static bool ggml_vk_compute_forward(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, ggml_tensor * tensor, int tensor_idx, bool use_fence = true, bool almost_ready = false) { + GGML_UNUSED(cgraph); + ggml_backend_buffer * buf = nullptr; + + switch (tensor->op) { + case GGML_OP_ADD: + case GGML_OP_ACC: + case GGML_OP_GET_ROWS: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + case GGML_OP_ADD_ID: + case GGML_OP_CONCAT: + case GGML_OP_UPSCALE: + case GGML_OP_SCALE: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case 
GGML_OP_CPY: + case GGML_OP_SET_ROWS: + case GGML_OP_CONT: + case GGML_OP_DUP: + case GGML_OP_SILU_BACK: + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_RMS_NORM: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_L2_NORM: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_NONE: + case GGML_OP_ARGSORT: + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_TRANSPOSE_1D: + case GGML_OP_POOL_2D: + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + case GGML_OP_CONV_2D_DW: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + case GGML_OP_LEAKY_RELU: + case GGML_OP_REPEAT: + case GGML_OP_REPEAT_BACK: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + buf = tensor->buffer; + break; + case GGML_OP_UNARY: + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_GLU: + switch (ggml_get_glu_op(tensor)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + buf = tensor->buffer; + break; + default: + return false; + } + break; + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + case GGML_OP_FLASH_ATTN_EXT: + buf = tensor->buffer; + + break; + default: + return false; + } + + if (buf == nullptr) { + return false; + } + + VK_LOG_DEBUG("ggml_vk_compute_forward(" << tensor << ", name=" << tensor->name << ", op=" << ggml_op_name(tensor->op) << ", type=" << tensor->type << ", ne0=" << tensor->ne[0] << ", ne1=" << tensor->ne[1] << ", ne2=" << tensor->ne[2] << ", ne3=" << tensor->ne[3] << ", nb0=" << tensor->nb[0] << ", nb1=" << tensor->nb[1] << ", nb2=" << tensor->nb[2] << ", nb3=" << tensor->nb[3] << ", view_src=" << tensor->view_src << ", view_offs=" << tensor->view_offs << ")"); + + vk_context subctx = ctx->tensor_ctxs[tensor_idx].lock(); + + // always wait for the GPU work to be done for the last submit + if (tensor_idx == subctx->exit_tensor_idx) { + use_fence = true; + } + + // Only run if ctx hasn't been submitted yet + if (!subctx->seqs.empty()) { +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_0(ctx, cgraph, tensor_idx); + use_fence = true; +#endif + + // Do staging buffer copies + for (auto& cpy : subctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + for (auto& mset : subctx->memsets) { + memset(mset.dst, mset.val, mset.n); + } + + if (almost_ready && !ctx->almost_ready_fence_pending && !use_fence) { + ggml_vk_submit(subctx, ctx->almost_ready_fence); + ctx->almost_ready_fence_pending = true; + } else { + ggml_vk_submit(subctx, use_fence ? 
ctx->fence : vk::Fence{}); + } + + if (use_fence) { + ggml_vk_wait_for_fence(ctx); + } +#ifdef GGML_VULKAN_CHECK_RESULTS + ggml_vk_check_results_1(ctx, cgraph, tensor_idx); +#endif + } + + if (tensor_idx == subctx->exit_tensor_idx) { + // Do staging buffer copies + for (auto& cpy : subctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + subctx->in_memcpys.clear(); + subctx->out_memcpys.clear(); + subctx->memsets.clear(); + } + + return true; +} + +// Clean up after graph processing is done +static void ggml_vk_graph_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_graph_cleanup()"); + for (auto& buffer : ctx->gc.temp_buffers) { + ggml_vk_pool_free(ctx, buffer); + } + ctx->gc.temp_buffers.clear(); + ctx->prealloc_y_last_pipeline_used = {}; + + ctx->unsynced_nodes_written.clear(); + ctx->unsynced_nodes_read.clear(); + ctx->prealloc_x_need_sync = ctx->prealloc_y_need_sync = ctx->prealloc_split_k_need_sync = false; + + ggml_vk_command_pool_cleanup(ctx->device, ctx->compute_cmd_pool); + ggml_vk_command_pool_cleanup(ctx->device, ctx->transfer_cmd_pool); + + for (size_t i = 0; i < ctx->gc.semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.semaphores[i].s }); + } + ctx->gc.semaphores.clear(); + + for (size_t i = 0; i < ctx->gc.tl_semaphores.size(); i++) { + ctx->device->device.destroySemaphore({ ctx->gc.tl_semaphores[i].s }); + } + ctx->gc.tl_semaphores.clear(); + ctx->semaphore_idx = 0; + + ctx->event_idx = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.resetEvent(event); + } + + ctx->tensor_ctxs.clear(); + ctx->gc.contexts.clear(); + ctx->pipeline_descriptor_set_requirements = 0; + ctx->descriptor_set_idx = 0; +} + +// Clean up on backend free +static void ggml_vk_cleanup(ggml_backend_vk_context * ctx) { + VK_LOG_DEBUG("ggml_vk_cleanup(" << ctx->name << ")"); + ggml_vk_graph_cleanup(ctx); + + ggml_vk_destroy_buffer(ctx->prealloc_x); + ggml_vk_destroy_buffer(ctx->prealloc_y); + ggml_vk_destroy_buffer(ctx->prealloc_split_k); + ctx->prealloc_y_last_pipeline_used = nullptr; + + for (auto& buffer : ctx->buffer_pool) { + ggml_vk_destroy_buffer(buffer); + } + + ctx->prealloc_size_x = 0; + ctx->prealloc_size_y = 0; + ctx->prealloc_size_split_k = 0; + + for (auto& event : ctx->gc.events) { + ctx->device->device.destroyEvent(event); + } + ctx->gc.events.clear(); + + ctx->device->device.destroyFence(ctx->fence); + ctx->device->device.destroyFence(ctx->almost_ready_fence); + + for (auto& pool : ctx->descriptor_pools) { + ctx->device->device.destroyDescriptorPool(pool); + } + ctx->descriptor_pools.clear(); + ctx->descriptor_sets.clear(); + + ctx->compute_cmd_pool.destroy(ctx->device->device); + ctx->transfer_cmd_pool.destroy(ctx->device->device); +} + +static int ggml_vk_get_device_count() { + ggml_vk_instance_init(); + + return vk_instance.device_indices.size(); +} + +static void ggml_vk_get_device_description(int device, char * description, size_t description_size) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties props; + devices[device].getProperties(&props); + + snprintf(description, description_size, "%s", props.deviceName.data()); +} + +static std::string ggml_vk_get_device_id(int device) { + ggml_vk_instance_init(); + + std::vector devices = vk_instance.instance.enumeratePhysicalDevices(); + + vk::PhysicalDeviceProperties2 props; + vk::PhysicalDeviceIDProperties deviceIDProps; + props.pNext = &deviceIDProps; + devices[device].getProperties2(&props); 
+ + const auto& uuid = deviceIDProps.deviceUUID; + char id[64]; + snprintf(id, sizeof(id), + "GPU-%02x%02x%02x%02x-%02x%02x-%02x%02x-%02x%02x-%02x%02x%02x%02x%02x%02x", + uuid[0], uuid[1], uuid[2], uuid[3], + uuid[4], uuid[5], + uuid[6], uuid[7], + uuid[8], uuid[9], + uuid[10], uuid[11], uuid[12], uuid[13], uuid[14], uuid[15] + ); + return std::string(id); +} + +// backend interface + +#define UNUSED GGML_UNUSED + +// device backend + +static bool ggml_backend_buffer_is_vk(ggml_backend_buffer_t buffer) { + return buffer->buft->iface.get_name == ggml_backend_vk_buffer_type_name; +} + +static void ggml_backend_vk_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_free_buffer()"); + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + ggml_vk_destroy_buffer(ctx->dev_buffer); + delete ctx; + delete buffer; +} + +static void * ggml_backend_vk_buffer_get_base(ggml_backend_buffer_t buffer) { + return vk_ptr_base; + + UNUSED(buffer); +} + +static enum ggml_status ggml_backend_vk_buffer_init_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_init_tensor(" << buffer << " (" << buffer->context << "), " << tensor << ")"); + if (tensor->view_src != nullptr) { + GGML_ASSERT(tensor->view_src->buffer->buft == buffer->buft); + } + return GGML_STATUS_SUCCESS; +} + +static void ggml_backend_vk_buffer_memset_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, uint8_t value, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_memset_tensor(" << buffer << ", " << tensor << ", " << value << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + uint32_t val32 = (uint32_t)value * 0x01010101; + ggml_vk_buffer_memset(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, val32, size); +} + +static void ggml_backend_vk_buffer_set_tensor(ggml_backend_buffer_t buffer, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_set_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_buffer_get_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_buffer_get_tensor(" << buffer << ", " << tensor << ", " << data << ", " << offset << ", " << size << ")"); + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read(buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_buffer_cpy_tensor(ggml_backend_buffer_t buffer, const ggml_tensor * src, ggml_tensor * dst) { + if (ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy(dst_buf, 
vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + + return true; + } + return false; + + UNUSED(buffer); +} + +static void ggml_backend_vk_buffer_clear(ggml_backend_buffer_t buffer, uint8_t value) { + ggml_backend_vk_buffer_context * ctx = (ggml_backend_vk_buffer_context *)buffer->context; + + ggml_vk_buffer_memset(ctx->dev_buffer, 0, value, buffer->size); +} + +static ggml_backend_buffer_i ggml_backend_vk_buffer_interface = { + /* .free_buffer = */ ggml_backend_vk_buffer_free_buffer, + /* .get_base = */ ggml_backend_vk_buffer_get_base, + /* .init_tensor = */ ggml_backend_vk_buffer_init_tensor, + /* .memset_tensor = */ ggml_backend_vk_buffer_memset_tensor, + /* .set_tensor = */ ggml_backend_vk_buffer_set_tensor, + /* .get_tensor = */ ggml_backend_vk_buffer_get_tensor, + /* .cpy_tensor = */ ggml_backend_vk_buffer_cpy_tensor, + /* .clear = */ ggml_backend_vk_buffer_clear, + /* .reset = */ NULL, +}; + +// vk buffer type +static const char * ggml_backend_vk_buffer_type_name(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return ctx->name.c_str(); +} + +static ggml_backend_buffer_t ggml_backend_vk_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + VK_LOG_MEMORY("ggml_backend_vk_buffer_type_alloc_buffer(" << size << ")"); + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + + vk_buffer dev_buffer = nullptr; + try { + dev_buffer = ggml_vk_create_buffer_device(ctx->device, size); + } catch (const vk::SystemError& e) { + return nullptr; + } + + ggml_backend_vk_buffer_context * bufctx = new ggml_backend_vk_buffer_context(ctx->device, std::move(dev_buffer), ctx->name); + + return ggml_backend_buffer_init(buft, ggml_backend_vk_buffer_interface, bufctx, size); +} + +static size_t ggml_backend_vk_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->properties.limits.minStorageBufferOffsetAlignment; +} + +static size_t ggml_backend_vk_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + ggml_backend_vk_buffer_type_context * ctx = (ggml_backend_vk_buffer_type_context *) buft->context; + return ctx->device->suballocation_block_size; +} + +static size_t ggml_backend_vk_buffer_type_get_alloc_size(ggml_backend_buffer_type_t buft, const ggml_tensor * tensor) { + return ggml_nbytes(tensor); + + UNUSED(buft); +} + +ggml_backend_buffer_type_t ggml_backend_vk_buffer_type(size_t dev_num) { + ggml_vk_instance_init(); + + VK_LOG_DEBUG("ggml_backend_vk_buffer_type(" << dev_num << ")"); + + vk_device dev = ggml_vk_get_device(dev_num); + + return &dev->buffer_type; +} + +// host buffer type + +static const char * ggml_backend_vk_host_buffer_type_name(ggml_backend_buffer_type_t buft) { + return GGML_VK_NAME "_Host"; + + UNUSED(buft); +} + +static const char * ggml_backend_vk_host_buffer_name(ggml_backend_buffer_t buffer) { + return GGML_VK_NAME "_Host"; + + UNUSED(buffer); +} + +static void ggml_backend_vk_host_buffer_free_buffer(ggml_backend_buffer_t buffer) { + VK_LOG_MEMORY("ggml_backend_vk_host_buffer_free_buffer()"); + ggml_vk_host_free(vk_instance.devices[0], buffer->context); + delete buffer; +} + +static ggml_backend_buffer_t ggml_backend_vk_host_buffer_type_alloc_buffer(ggml_backend_buffer_type_t buft, size_t size) { + 
VK_LOG_MEMORY("ggml_backend_vk_host_buffer_type_alloc_buffer(" << size << ")"); + + size += 32; // Behave like the CPU buffer type + void * ptr = nullptr; + try { + ptr = ggml_vk_host_malloc(vk_instance.devices[0], size); + } catch (vk::SystemError& e) { + GGML_LOG_WARN("ggml_vulkan: Failed to allocate pinned memory (%s)\n", e.what()); + // fallback to cpu buffer + return ggml_backend_buft_alloc_buffer(ggml_backend_cpu_buffer_type(), size); + } + + ggml_backend_buffer_t buffer = ggml_backend_cpu_buffer_from_ptr(ptr, size); + buffer->buft = buft; + buffer->iface.free_buffer = ggml_backend_vk_host_buffer_free_buffer; + + return buffer; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_alignment(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->properties.limits.minMemoryMapAlignment; + + UNUSED(buft); +} + +static size_t ggml_backend_vk_host_buffer_type_get_max_size(ggml_backend_buffer_type_t buft) { + return vk_instance.devices[0]->suballocation_block_size; + + UNUSED(buft); +} + +// Should be changed to return device-specific host buffer type +// but that probably requires changes in llama.cpp +ggml_backend_buffer_type_t ggml_backend_vk_host_buffer_type() { + static struct ggml_backend_buffer_type ggml_backend_vk_buffer_type_host = { + /* .iface = */ { + /* .get_name = */ ggml_backend_vk_host_buffer_type_name, + /* .alloc_buffer = */ ggml_backend_vk_host_buffer_type_alloc_buffer, + /* .get_alignment = */ ggml_backend_vk_host_buffer_type_get_alignment, + /* .get_max_size = */ ggml_backend_vk_host_buffer_type_get_max_size, + /* .get_alloc_size = */ ggml_backend_cpu_buffer_type()->iface.get_alloc_size, + /* .is_host = */ ggml_backend_cpu_buffer_type()->iface.is_host, + }, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), 0), + /* .context = */ nullptr, + }; + + // Make sure device 0 is initialized + ggml_vk_instance_init(); + ggml_vk_get_device(0); + + return &ggml_backend_vk_buffer_type_host; +} + + +// backend + +static const char * ggml_backend_vk_name(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return ctx->name.c_str(); +} + +static void ggml_backend_vk_free(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + VK_LOG_DEBUG("ggml_backend_vk_free(" << ctx->name << ")"); + + ggml_vk_cleanup(ctx); + + delete ctx; + delete backend; +} + +static ggml_backend_buffer_type_t ggml_backend_vk_get_default_buffer_type(ggml_backend_t backend) { + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + return &ctx->device->buffer_type; +} + +static void ggml_backend_vk_set_tensor_async(ggml_backend_t backend, ggml_tensor * tensor, const void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_set_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = 
ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_write_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static void ggml_backend_vk_get_tensor_async(ggml_backend_t backend, const ggml_tensor * tensor, void * data, size_t offset, size_t size) { + VK_LOG_DEBUG("ggml_backend_vk_get_tensor_async(" << size << ")"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + GGML_ASSERT((tensor->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || tensor->buffer->buft == ggml_backend_vk_host_buffer_type()) && "unsupported buffer type"); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer buf = buf_ctx->dev_buffer; + + ggml_vk_buffer_read_async(transfer_ctx, buf, vk_tensor_offset(tensor) + tensor->view_offs + offset, data, size); +} + +static bool ggml_backend_vk_cpy_tensor_async(ggml_backend_t backend, const ggml_tensor * src, ggml_tensor * dst) { + VK_LOG_DEBUG("ggml_backend_vk_cpy_tensor_async()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if ((dst->buffer->buft == ggml_backend_vk_get_default_buffer_type(backend) || dst->buffer->buft == ggml_backend_vk_host_buffer_type()) && ggml_backend_buffer_is_vk(src->buffer)) { + ggml_backend_vk_buffer_context * src_buf_ctx = (ggml_backend_vk_buffer_context *)src->buffer->context; + ggml_backend_vk_buffer_context * dst_buf_ctx = (ggml_backend_vk_buffer_context *)dst->buffer->context; + + vk_context transfer_ctx; + + if (ctx->transfer_ctx.expired()) { + // Initialize new transfer context + transfer_ctx = ggml_vk_create_context(ctx, ctx->transfer_cmd_pool); + ctx->transfer_ctx = transfer_ctx; + ggml_vk_ctx_begin(ctx->device, transfer_ctx); + } else { + transfer_ctx = ctx->transfer_ctx.lock(); + } + + vk_buffer src_buf = src_buf_ctx->dev_buffer; + vk_buffer dst_buf = dst_buf_ctx->dev_buffer; + + ggml_vk_buffer_copy_async(transfer_ctx, dst_buf, vk_tensor_offset(dst) + dst->view_offs, src_buf, vk_tensor_offset(src) + src->view_offs, ggml_nbytes(src)); + return true; + } + + return false; +} + +static void ggml_backend_vk_synchronize(ggml_backend_t backend) { + VK_LOG_DEBUG("ggml_backend_vk_synchronize()"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + if(ctx->transfer_ctx.expired()) { + return; + } + + vk_context transfer_ctx = ctx->transfer_ctx.lock(); + + ggml_vk_ctx_end(transfer_ctx); + + for (auto& cpy : transfer_ctx->in_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ggml_vk_submit(transfer_ctx, ctx->fence); + ggml_vk_wait_for_fence(ctx); + + for (auto& cpy : transfer_ctx->out_memcpys) { + memcpy(cpy.dst, cpy.src, cpy.n); + } + + ctx->transfer_ctx.reset(); +} + +static bool ggml_vk_is_empty(ggml_tensor * node) { + return ggml_is_empty(node) || node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; +} + +static bool ggml_vk_can_fuse(const struct ggml_cgraph * cgraph, int node_idx, std::initializer_list ops) { + if (!ggml_can_fuse(cgraph, node_idx, ops)) { + return false; + } + + if 
(ops.size() == 2 && ops.begin()[0] == GGML_OP_RMS_NORM && ops.begin()[1] == GGML_OP_MUL) { + // additional constraints specific to this fusion + const ggml_tensor *rms_norm = cgraph->nodes[node_idx]; + const ggml_tensor *mul = cgraph->nodes[node_idx + 1]; + + GGML_ASSERT(rms_norm->src[0]->type == GGML_TYPE_F32); + GGML_ASSERT(rms_norm->type == GGML_TYPE_F32); + // rms_norm only supports f32 + if (mul->src[0]->type != GGML_TYPE_F32 || + mul->src[1]->type != GGML_TYPE_F32 || + mul->type != GGML_TYPE_F32) { + return false; + } + // if rms_norm is the B operand, then we don't handle broadcast + if (rms_norm == mul->src[1] && + !ggml_are_same_shape(mul->src[0], rms_norm)) { + return false; + } + // rms_norm shader assumes contiguous rows + if (!ggml_is_contiguous_rows(mul->src[0]) || !ggml_is_contiguous_rows(mul->src[1])) { + return false; + } + } + return true; +} + +static uint32_t ggml_vk_fuse_multi_add(ggml_backend_vk_context * ctx, const struct ggml_cgraph * cgraph, int node_idx) { + + const ggml_tensor *first_node = cgraph->nodes[node_idx]; + if (first_node->op != GGML_OP_ADD) { + return 0; + } + + if (!ctx->device->multi_add) { + return 0; + } + + int32_t num_adds = 1; + while (node_idx + num_adds < cgraph->n_nodes && + cgraph->nodes[node_idx + num_adds]->op == GGML_OP_ADD && + num_adds < MAX_FUSED_ADDS) { + num_adds++; + } + + // The shader currently requires same shapes (but different strides are allowed), + // everything f32, and no misalignment + for (int32_t i = 0; i < num_adds; ++i) { + const ggml_tensor *next_node = cgraph->nodes[node_idx + i]; + if (!ggml_are_same_shape(first_node, next_node->src[0]) || + !ggml_are_same_shape(first_node, next_node->src[1]) || + next_node->type != GGML_TYPE_F32 || + next_node->src[0]->type != GGML_TYPE_F32 || + next_node->src[1]->type != GGML_TYPE_F32 || + get_misalign_bytes(ctx, next_node) || + get_misalign_bytes(ctx, next_node->src[0]) || + get_misalign_bytes(ctx, next_node->src[1])) { + num_adds = i; + } + } + + // Verify we can fuse these + ggml_op adds[MAX_FUSED_ADDS]; + for (int32_t i = 0; i < num_adds; ++i) { + adds[i] = GGML_OP_ADD; + } + + // decrease num_adds if they can't all be fused + while (num_adds > 1 && !ggml_can_fuse(cgraph, node_idx, adds, num_adds)) { + num_adds--; + } + + // a single add is not "fused", so just return zero + if (num_adds == 1) { + return 0; + } + return num_adds; +} + +static ggml_status ggml_backend_vk_graph_compute(ggml_backend_t backend, ggml_cgraph * cgraph) { + VK_LOG_DEBUG("ggml_backend_vk_graph_compute(" << cgraph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (vk_instance.debug_utils_support) { + vk::DebugUtilsLabelEXT dul = {}; + dul.pLabelName = "ggml_backend_vk_graph_compute"; + dul.color = std::array{1.0f, 1.0f, 1.0f, 1.0f}; + vk_instance.pfn_vkQueueBeginDebugUtilsLabelEXT(ctx->device->compute_queue.queue, reinterpret_cast(&dul)); + } + + ctx->prealloc_size_add_rms_partials = 0; + ctx->prealloc_size_add_rms_partials_offset = 0; + ctx->do_add_rms_partials = false; + + uint64_t total_mat_mul_bytes = 0; + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ctx->device->disable_fusion) { + uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); + if (num_adds) { + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + } + ggml_vk_build_graph(ctx, cgraph, i, nullptr, 0, true, false, false, false); + if (cgraph->nodes[i]->op == 
GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + total_mat_mul_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } else if (cgraph->nodes[i]->op == GGML_OP_CONV_2D || cgraph->nodes[i]->op == GGML_OP_CONV_TRANSPOSE_2D) { + // Return CRSxNPQxsizeof(*) to account as many bytes as mul_mat has in im2col->mul_mat mode. + auto CRS_size = + cgraph->nodes[i]->src[0]->ne[0] * cgraph->nodes[i]->src[0]->ne[1] * cgraph->nodes[i]->src[1]->ne[2]; + auto NPQ_size = cgraph->nodes[i]->ne[0] * cgraph->nodes[i]->ne[1] * cgraph->nodes[i]->ne[3]; + total_mat_mul_bytes += NPQ_size * CRS_size * ggml_type_size(cgraph->nodes[i]->type); + } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; + } + if (ctx->device->need_compiles) { + ggml_vk_load_shaders(ctx->device); + } + ggml_vk_preallocate_buffers(ctx); + ggml_pipeline_allocate_descriptor_sets(ctx); + + int last_node = cgraph->n_nodes - 1; + + // If the last op in the cgraph isn't backend GPU, the command buffer doesn't get closed properly + while (last_node > 0 && ggml_vk_is_empty(cgraph->nodes[last_node])) { + last_node -= 1; + } + + // Reserve tensor context space for all nodes + ctx->tensor_ctxs.resize(cgraph->n_nodes); + + bool first_node_in_batch = true; // true if next node will be first node in a batch + int submit_node_idx = 0; // index to first node in a batch + + vk_context compute_ctx; + if (vk_perf_logger_enabled) { + // allocate/resize the query pool + if (ctx->device->num_queries < cgraph->n_nodes + 1) { + if (ctx->device->query_pool) { + ctx->device->device.destroyQueryPool(ctx->device->query_pool); + } + vk::QueryPoolCreateInfo query_create_info; + query_create_info.queryType = vk::QueryType::eTimestamp; + query_create_info.queryCount = cgraph->n_nodes + 100; + ctx->device->query_pool = ctx->device->device.createQueryPool(query_create_info); + ctx->device->num_queries = query_create_info.queryCount; + } + + ctx->device->device.resetQueryPool(ctx->device->query_pool, 0, cgraph->n_nodes+1); + + GGML_ASSERT(ctx->compute_ctx.expired()); + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, 0); + } + + ctx->prealloc_y_last_pipeline_used = nullptr; + ctx->prealloc_y_last_tensor_used = nullptr; + + if (ctx->prealloc_size_add_rms_partials) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // initialize partial sums to zero. + ggml_vk_buffer_memset_async(compute_ctx, ctx->prealloc_add_rms_partials, 0, 0, ctx->prealloc_size_add_rms_partials); + ggml_vk_sync_buffers(ctx, compute_ctx); + } + + // Submit after enough work has accumulated, to overlap CPU cmdbuffer generation with GPU execution. + // Estimate the amount of matmul work by looking at the weight matrix size, and submit every 100MB + // (and scaled down based on model size, so smaller models submit earlier). + // Also submit at least every 100 nodes, in case there are workloads without as much matmul. 
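A rough, self-contained model of the submit-threshold schedule described in the comment above: the initial byte threshold is the smaller of 100 MB and 1/40 of the graph's total matmul weight bytes, and it doubles after each of the first few submits. The 2 GiB figure here is a made-up example, not a measured value.

```
#include <algorithm>
#include <cstdint>
#include <cstdio>

int main() {
    // Hypothetical graph with ~2 GiB of matmul weight traffic.
    const uint64_t total_mat_mul_bytes = 2ull * 1024 * 1024 * 1024;

    // Initial threshold: at most 100 MB, scaled down for small models.
    uint64_t per_submit = std::min<uint64_t>(100ull * 1000 * 1000,
                                             total_mat_mul_bytes / 40u);

    for (int submit_count = 0; submit_count < 5; ++submit_count) {
        std::printf("submit %d fires after ~%llu MB of matmul weights\n",
                    submit_count,
                    (unsigned long long) (per_submit / 1000000));
        if (submit_count < 3) {
            per_submit *= 2;  // early submits stay small to start the GPU sooner
        }
    }
    return 0;
}
```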
+ int nodes_per_submit = 100; + int submitted_nodes = 0; + int submit_count = 0; + uint64_t mul_mat_bytes = 0; + uint64_t mul_mat_bytes_per_submit = std::min(uint64_t(100*1000*1000), total_mat_mul_bytes / 40u); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (first_node_in_batch) { + submit_node_idx = i; + } + + if (cgraph->nodes[i]->op == GGML_OP_MUL_MAT || cgraph->nodes[i]->op == GGML_OP_MUL_MAT_ID) { + mul_mat_bytes += ggml_nbytes(cgraph->nodes[i]->src[0]); + } + + if (!ctx->device->disable_fusion) { + uint32_t num_adds = ggml_vk_fuse_multi_add(ctx, cgraph, i); + if (num_adds) { + ctx->num_additional_fused_ops = num_adds - 1; + } else if (ggml_vk_can_fuse(cgraph, i, { GGML_OP_RMS_NORM, GGML_OP_MUL })) { + ctx->num_additional_fused_ops = 1; + } + } + + // Signal the almost_ready fence when the graph is mostly complete (< 20% remaining) + bool almost_ready = (cgraph->n_nodes - i) < cgraph->n_nodes / 5; + bool submit = (submitted_nodes >= nodes_per_submit) || + (mul_mat_bytes >= mul_mat_bytes_per_submit) || + (i + ctx->num_additional_fused_ops == last_node) || + (almost_ready && !ctx->almost_ready_fence_pending); + + bool enqueued = ggml_vk_build_graph(ctx, cgraph, i, cgraph->nodes[submit_node_idx], submit_node_idx, false, i + ctx->num_additional_fused_ops == last_node, almost_ready, submit); + + if (vk_perf_logger_enabled) { + if (ctx->compute_ctx.expired()) { + compute_ctx = ggml_vk_create_context(ctx, ctx->compute_cmd_pool); + ctx->compute_ctx = compute_ctx; + ggml_vk_ctx_begin(ctx->device, compute_ctx); + } else { + compute_ctx = ctx->compute_ctx.lock(); + } + // If there are fused ops, just write out timestamps for all nodes to keep the accounting simple + for (int j = 0; j < ctx->num_additional_fused_ops + 1; ++j) { + compute_ctx->s->buffer.writeTimestamp(vk::PipelineStageFlagBits::eAllCommands, ctx->device->query_pool, i+j+1); + } + } + + if (enqueued) { + ++submitted_nodes; + +#ifndef GGML_VULKAN_CHECK_RESULTS + if (first_node_in_batch) { + first_node_in_batch = false; + } +#endif + } + + if (submit && enqueued) { + first_node_in_batch = true; + submitted_nodes = 0; + mul_mat_bytes = 0; + if (submit_count < 3) { + mul_mat_bytes_per_submit *= 2; + } + submit_count++; + } + i += ctx->num_additional_fused_ops; + ctx->num_additional_fused_ops = 0; + } + + if (vk_perf_logger_enabled) { + // End the command buffer and submit/wait + GGML_ASSERT(!ctx->compute_ctx.expired()); + compute_ctx = ctx->compute_ctx.lock(); + ggml_vk_ctx_end(compute_ctx); + + ggml_vk_submit(compute_ctx, ctx->device->fence); + VK_CHECK(ctx->device->device.waitForFences({ ctx->device->fence }, true, UINT64_MAX), "GGML_VULKAN_PERF waitForFences"); + ctx->device->device.resetFences({ ctx->device->fence }); + + // Get the results and pass them to the logger + std::vector timestamps(cgraph->n_nodes + 1); + VK_CHECK(ctx->device->device.getQueryPoolResults(ctx->device->query_pool, 0, cgraph->n_nodes + 1, (cgraph->n_nodes + 1)*sizeof(uint64_t), timestamps.data(), sizeof(uint64_t), vk::QueryResultFlagBits::e64 | vk::QueryResultFlagBits::eWait), "get timestamp results"); + for (int i = 0; i < cgraph->n_nodes; i++) { + if (!ggml_vk_is_empty(cgraph->nodes[i])) { + ctx->device->perf_logger->log_timing(cgraph->nodes[i], uint64_t((timestamps[i+1] - timestamps[i]) * ctx->device->properties.limits.timestampPeriod)); + } + } + + ctx->device->perf_logger->print_timings(); + } + + ggml_vk_graph_cleanup(ctx); + + return GGML_STATUS_SUCCESS; + + UNUSED(backend); +} + +// Sort the graph for improved parallelism. 
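The reorder pass defined below relies on a dependency test that also treats two nodes as dependent when they are views of the same backing tensor. A reduced, self-contained version of that test (Tensor is a stand-in struct, not ggml_tensor):

```
#include <array>
#include <cassert>

// Stand-in tensor: only the fields the dependency test cares about.
struct Tensor {
    std::array<const Tensor *, 4> src{};  // direct inputs
    const Tensor *view_src = nullptr;     // backing tensor if this is a view
};

// dst depends on src if it reads it directly, or if both alias the same
// backing tensor (the "implicit dependency" case in is_src_of below).
static bool depends_on(const Tensor &dst, const Tensor &src) {
    for (const Tensor *s : dst.src) {
        if (s == &src) {
            return true;
        }
    }
    const Tensor *d2 = dst.view_src ? dst.view_src : &dst;
    const Tensor *s2 = src.view_src ? src.view_src : &src;
    return d2 == s2;
}

int main() {
    Tensor base, a, b, c;
    a.view_src = &base;
    b.view_src = &base;
    c.src[0] = &a;
    assert(depends_on(a, b));   // alias through the shared backing tensor
    assert(depends_on(c, a));   // ordinary producer/consumer edge
    assert(!depends_on(c, b));  // different backing tensor, no direct edge
}
```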
+static void ggml_vk_graph_optimize(ggml_backend_t backend, struct ggml_cgraph * graph) +{ + VK_LOG_DEBUG("ggml_vk_graph_optimize(" << graph->n_nodes << " nodes)"); + ggml_backend_vk_context * ctx = (ggml_backend_vk_context *)backend->context; + + if (ctx->device->disable_graph_optimize) { + return; + } + + auto const &is_empty = [](ggml_tensor * node) -> bool { + return node->op == GGML_OP_NONE || node->op == GGML_OP_RESHAPE || node->op == GGML_OP_TRANSPOSE || node->op == GGML_OP_VIEW || node->op == GGML_OP_PERMUTE; + }; + + auto const &is_src_of = [](const ggml_tensor *dst, const ggml_tensor *src) -> bool { + for (uint32_t s = 0; s < GGML_MAX_SRC; ++s) { + if (dst->src[s] == src) { + return true; + } + } + // implicit dependency if they view the same tensor + const ggml_tensor *dst2 = dst->view_src ? dst->view_src : dst; + const ggml_tensor *src2 = src->view_src ? src->view_src : src; + if (dst2 == src2) { + return true; + } + return false; + }; + + // This function tries to reorder the graph to allow nodes to run in parallel. + // This helps with small batches, but for large batches its a slowdown, probably + // due to cache contention. So only reorder if the majority of nodes have few rows. + int num_small_nodes = 0; + int num_counted_nodes = 0; + for (int i = 0; i < graph->n_nodes; ++i) { + if (!is_empty(graph->nodes[i]) && + graph->nodes[i]->op != GGML_OP_SET_ROWS) { + if (ggml_nrows(graph->nodes[i]) <= 8) { + num_small_nodes++; + } + num_counted_nodes++; + } + } + if (num_small_nodes < num_counted_nodes / 2) { + return; + } + + std::vector new_order; + std::vector used(graph->n_nodes, false); + int first_unused = 0; + while (first_unused < graph->n_nodes) { + std::vector current_set; + + // First, grab the next unused node. + current_set.push_back(first_unused); + + // Loop through the next N nodes. Grab any that don't depend on other nodes that + // haven't already been run. Nodes that have already been run have used[i] set + // to true. Allow nodes that depend on the previous node if it's a fusion pattern + // that we support (e.g. RMS_NORM + MUL). + // This first pass only grabs "real" (non-view nodes). Second pass grabs view nodes. + // The goal is to not interleave real and view nodes in a way that breaks fusion. + const int NUM_TO_CHECK = 20; + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !(j == c+1 && c == current_set.back() && graph->nodes[c]->op == GGML_OP_RMS_NORM && graph->nodes[j]->op == GGML_OP_MUL)) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + // Second pass grabs view nodes. + // Skip this if it would break a fusion optimization (don't split up add->rms_norm or add->add). + if (graph->nodes[current_set.back()]->op != GGML_OP_ADD) { + for (int j = first_unused+1; j < std::min(first_unused + NUM_TO_CHECK, graph->n_nodes); ++j) { + if (used[j]) { + continue; + } + if (!is_empty(graph->nodes[j])) { + continue; + } + bool ok = true; + for (int c = first_unused; c < j; ++c) { + bool c_in_current_set = std::find(current_set.begin(), current_set.end(), c) != current_set.end(); + // skip views whose srcs haven't been processed. 
+ if (!used[c] && + is_src_of(graph->nodes[j], graph->nodes[c]) && + !c_in_current_set) { + ok = false; + break; + } + } + if (ok) { + current_set.push_back(j); + } + } + } + + // Push the current set into new_order + for (auto c : current_set) { + new_order.push_back(graph->nodes[c]); + used[c] = true; + } + while (first_unused < graph->n_nodes && used[first_unused]) { + first_unused++; + } + } + // Replace the graph with the new order. + for (int i = 0; i < graph->n_nodes; ++i) { + graph->nodes[i] = new_order[i]; + } +} + +// TODO: enable async and synchronize +static ggml_backend_i ggml_backend_vk_interface = { + /* .get_name = */ ggml_backend_vk_name, + /* .free = */ ggml_backend_vk_free, + /* .set_tensor_async = */ NULL, // ggml_backend_vk_set_tensor_async, + /* .get_tensor_async = */ NULL, // ggml_backend_vk_get_tensor_async, + /* .cpy_tensor_async = */ NULL, // ggml_backend_vk_cpy_tensor_async, + /* .synchronize = */ NULL, // ggml_backend_vk_synchronize, + /* .graph_plan_create = */ NULL, + /* .graph_plan_free = */ NULL, + /* .graph_plan_update = */ NULL, + /* .graph_plan_compute = */ NULL, + /* .graph_compute = */ ggml_backend_vk_graph_compute, + /* .event_record = */ NULL, + /* .event_wait = */ NULL, + /* .graph_optimize = */ ggml_vk_graph_optimize, +}; + +static ggml_guid_t ggml_backend_vk_guid() { + static ggml_guid guid = { 0xb8, 0xf7, 0x4f, 0x86, 0x40, 0x3c, 0xe1, 0x02, 0x91, 0xc8, 0xdd, 0xe9, 0x02, 0x3f, 0xc0, 0x2b }; + return &guid; +} + +ggml_backend_t ggml_backend_vk_init(size_t dev_num) { + VK_LOG_DEBUG("ggml_backend_vk_init(" << dev_num << ")"); + + ggml_backend_vk_context * ctx = new ggml_backend_vk_context; + ggml_vk_init(ctx, dev_num); + + ggml_backend_t vk_backend = new ggml_backend { + /* .guid = */ ggml_backend_vk_guid(), + /* .iface = */ ggml_backend_vk_interface, + /* .device = */ ggml_backend_reg_dev_get(ggml_backend_vk_reg(), dev_num), + /* .context = */ ctx, + }; + + return vk_backend; +} + +bool ggml_backend_is_vk(ggml_backend_t backend) { + return backend != NULL && ggml_guid_matches(backend->guid, ggml_backend_vk_guid()); +} + +int ggml_backend_vk_get_device_count() { + return ggml_vk_get_device_count(); +} + +void ggml_backend_vk_get_device_description(int device, char * description, size_t description_size) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + ggml_vk_get_device_description(dev_idx, description, description_size); +} + +std::string ggml_backend_vk_get_device_id(int device) { + GGML_ASSERT(device < (int) vk_instance.device_indices.size()); + int dev_idx = vk_instance.device_indices[device]; + return ggml_vk_get_device_id(dev_idx); +} + +////////////////////////// + +struct ggml_backend_vk_device_context { + size_t device; + std::string name; + std::string description; + bool is_integrated_gpu; + // Combined string id in the form "dddd:bb:dd.f" (domain:bus:device.function) + std::string pci_id; + std::string id; + std::string uuid; + int major; + int minor; + int driver_major; + int driver_minor; + int pci_bus_id; + int pci_device_id; + int pci_domain_id; +}; + +void ggml_backend_vk_get_device_memory(ggml_backend_vk_device_context *ctx, size_t * free, size_t * total) { + GGML_ASSERT(ctx->device < (int) vk_instance.device_indices.size()); + GGML_ASSERT(ctx->device < (int) vk_instance.device_supports_membudget.size()); + + vk::PhysicalDevice vkdev = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[ctx->device]]; + + vk::PhysicalDeviceMemoryProperties 
memprops = vkdev.getMemoryProperties(); + vk::PhysicalDeviceProperties2 props2; + vkdev.getProperties2(&props2); + + if (!ctx->is_integrated_gpu) + { + // Use vendor specific management libraries for best VRAM reporting if available + switch (props2.properties.vendorID) { + case VK_VENDOR_ID_AMD: + if (ggml_hip_mgmt_init() == 0) { + int status = ggml_hip_get_device_memory(ctx->pci_bus_id, ctx->pci_device_id, free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing ADLX memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_hip_mgmt_release(); + return; + } + ggml_hip_mgmt_release(); + } + break; + case VK_VENDOR_ID_NVIDIA: + if (ggml_nvml_init() == 0) { + int status = ggml_nvml_get_device_memory(ctx->uuid.c_str(), free, total); + if (status == 0) { + GGML_LOG_DEBUG("%s utilizing NVML memory reporting free: %zu total: %zu\n", __func__, *free, *total); + ggml_nvml_release(); + return; + } + ggml_nvml_release(); + } + break; + } + } + // else fallback to memory budget if supported + + *total = 0; + *free = 0; + vk::PhysicalDeviceMemoryBudgetPropertiesEXT mem_budget_props; + vk::PhysicalDeviceMemoryProperties2 memprops2; + memprops2.pNext = &mem_budget_props; + vkdev.getMemoryProperties2(&memprops2); + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } else if (ctx->is_integrated_gpu) { + // Include shared memory on iGPUs + *total += memprops2.memoryProperties.memoryHeaps[i].size; + } + } + for (int i = 0; i < memprops2.memoryProperties.memoryHeapCount; i++) { + if (memprops2.memoryProperties.memoryHeaps[i].flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *free += mem_budget_props.heapBudget[i]; + } else if (ctx->is_integrated_gpu) { + *free += mem_budget_props.heapBudget[i]; + } + } + if (*total > 0 && *free > 0) { + return; + } else if (*total > 0) { + *free = *total; + return; + } + + // else just report the physical memory + for (const vk::MemoryHeap& heap : memprops2.memoryProperties.memoryHeaps) { + if (heap.flags & vk::MemoryHeapFlagBits::eDeviceLocal) { + *total = heap.size; + *free = heap.size; + break; + } + } +} + +static vk::PhysicalDeviceType ggml_backend_vk_get_device_type(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + vk::PhysicalDeviceProperties2 props = {}; + device.getProperties2(&props); + + return props.properties.deviceType; +} + +static std::string ggml_backend_vk_get_device_pci_id(int device_idx) { + GGML_ASSERT(device_idx >= 0 && device_idx < (int) vk_instance.device_indices.size()); + + vk::PhysicalDevice device = vk_instance.instance.enumeratePhysicalDevices()[vk_instance.device_indices[device_idx]]; + + const std::vector ext_props = device.enumerateDeviceExtensionProperties(); + + bool ext_support = false; + + for (const auto& properties : ext_props) { + if (strcmp("VK_EXT_pci_bus_info", properties.extensionName) == 0) { + ext_support = true; + break; + } + } + + if (!ext_support) { + return ""; + } + + vk::PhysicalDeviceProperties2 props = {}; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_info = {}; + + props.pNext = &pci_bus_info; + + device.getProperties2(&props); + + const uint32_t pci_domain = pci_bus_info.pciDomain; + const uint32_t pci_bus = pci_bus_info.pciBus; + 
const uint32_t pci_device = pci_bus_info.pciDevice; + const uint8_t pci_function = (uint8_t) pci_bus_info.pciFunction; // pci function is between 0 and 7, prevent printf overflow warning + + char pci_bus_id[16] = {}; + snprintf(pci_bus_id, sizeof(pci_bus_id), "%04x:%02x:%02x.%x", pci_domain, pci_bus, pci_device, pci_function); + + return std::string(pci_bus_id); +} + +static bool ggml_backend_vk_parse_pci_bus_id(const std::string & id, int *domain, int *bus, int *device) { + if (id.empty()) return false; + unsigned int d = 0, b = 0, dev = 0, func = 0; + // Expected format: dddd:bb:dd.f (all hex) + int n = sscanf(id.c_str(), "%4x:%2x:%2x.%1x", &d, &b, &dev, &func); + if (n < 4) return false; + if (domain) *domain = (int) d; + if (bus) *bus = (int) b; + if (device) *device = (int) dev; + return true; +} + +static const char * ggml_backend_vk_device_get_name(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->name.c_str(); +} + +static const char * ggml_backend_vk_device_get_description(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->description.c_str(); +} + +static const char * ggml_backend_vk_device_get_id(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ctx->id.c_str(); +} + +static void ggml_backend_vk_device_get_memory(ggml_backend_dev_t device, size_t * free, size_t * total) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)device->context; + ggml_backend_vk_get_device_memory(ctx, free, total); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_buffer_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_buffer_type(ctx->device); +} + +static ggml_backend_buffer_type_t ggml_backend_vk_device_get_host_buffer_type(ggml_backend_dev_t dev) { + UNUSED(dev); + return ggml_backend_vk_host_buffer_type(); +} + +static enum ggml_backend_dev_type ggml_backend_vk_device_get_type(ggml_backend_dev_t dev) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + + return ctx->is_integrated_gpu ? GGML_BACKEND_DEVICE_TYPE_IGPU : GGML_BACKEND_DEVICE_TYPE_GPU; +} + +static void ggml_backend_vk_device_get_props(ggml_backend_dev_t dev, struct ggml_backend_dev_props * props) { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + + props->name = ggml_backend_vk_device_get_name(dev); + props->description = ggml_backend_vk_device_get_description(dev); + props->id = ggml_backend_vk_device_get_id(dev); + props->type = ggml_backend_vk_device_get_type(dev); + props->device_id = ctx->pci_id.empty() ? 
nullptr : ctx->pci_id.c_str(); + ggml_backend_vk_device_get_memory(dev, &props->memory_free, &props->memory_total); + props->caps = { + /* .async = */ false, + /* .host_buffer = */ true, + /* .buffer_from_host_ptr = */ false, + /* .events = */ false, + }; + + props->compute_major = ctx->major; + props->compute_minor = ctx->minor; + props->driver_major = ctx->driver_major; + props->driver_minor = ctx->driver_minor; + props->integrated = ctx->is_integrated_gpu; + props->pci_bus_id = ctx->pci_bus_id; + props->pci_device_id = ctx->pci_device_id; + props->pci_domain_id = ctx->pci_domain_id; + props->library = GGML_VK_NAME; +} + +static ggml_backend_t ggml_backend_vk_device_init(ggml_backend_dev_t dev, const char * params) { + UNUSED(params); + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + return ggml_backend_vk_init(ctx->device); +} + +static bool ggml_backend_vk_device_supports_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + switch (op->op) { + case GGML_OP_UNARY: + switch (ggml_get_unary_op(op)) { + case GGML_UNARY_OP_EXP: + case GGML_UNARY_OP_GELU: + case GGML_UNARY_OP_GELU_ERF: + case GGML_UNARY_OP_GELU_QUICK: + case GGML_UNARY_OP_SILU: + case GGML_UNARY_OP_RELU: + case GGML_UNARY_OP_TANH: + case GGML_UNARY_OP_SIGMOID: + case GGML_UNARY_OP_HARDSIGMOID: + case GGML_UNARY_OP_HARDSWISH: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + case GGML_OP_GLU: + switch (ggml_get_glu_op(op)) { + case GGML_GLU_OP_GEGLU: + case GGML_GLU_OP_REGLU: + case GGML_GLU_OP_SWIGLU: + case GGML_GLU_OP_SWIGLU_OAI: + case GGML_GLU_OP_GEGLU_ERF: + case GGML_GLU_OP_GEGLU_QUICK: + return ggml_is_contiguous(op->src[0]) && + (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16) && + (op->src[0]->type == op->type); + default: + return false; + } + case GGML_OP_MUL_MAT: + case GGML_OP_MUL_MAT_ID: + { + ggml_type src0_type = op->src[0]->type; + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_MUL_MAT_ID) { + if (!device->mul_mat_id_s[src0_type] && !device->mul_mat_id_m[src0_type] && !device->mul_mat_id_l[src0_type]) { + // If there's not enough shared memory for row_ids and the result tile, fallback to CPU + return false; + } + } + switch (src0_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + break; + default: + return false; + } + struct ggml_tensor * a; + struct ggml_tensor * b; + if (op->op == GGML_OP_MUL_MAT) { + a = op->src[0]; + b = op->src[1]; + } else { + a = op->src[2]; + b = op->src[1]; + } + if (a->ne[3] != b->ne[3]) { + return false; + } + if (!(ggml_vk_dim01_contiguous(op->src[0]) || op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16 || 
op->src[0]->type == GGML_TYPE_BF16) || + !(ggml_vk_dim01_contiguous(op->src[1]) || op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16)) { + return false; + } + if (op->src[0]->type == GGML_TYPE_BF16 && op->src[1]->type == GGML_TYPE_F16) { + // We currently don't have a bf16 x f16 shader, or an fp16->bf16 copy shader. + // So don't support this combination for now. + return false; + } + + return true; + } + case GGML_OP_FLASH_ATTN_EXT: + { + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + auto device = ggml_vk_get_device(ctx->device); + bool coopmat2 = device->coopmat2; + uint32_t HSK = op->src[1]->ne[0]; + uint32_t HSV = op->src[2]->ne[0]; + if ((HSK % 8) != 0 || (HSV % 8) != 0) { + return false; + } + if (op->src[4] && op->src[4]->type != GGML_TYPE_F32) { + return false; + } + if (op->src[0]->type != GGML_TYPE_F32) { + return false; + } + if (op->type != GGML_TYPE_F32) { + return false; + } + if (op->src[3] && op->src[3]->type != GGML_TYPE_F16) { + return false; + } + // It's straightforward to support different K/V dequant, but would + // significantly increase the number of pipelines + if (op->src[1]->type != op->src[2]->type) { + return false; + } + switch (op->src[1]->type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q8_0: + // supported in scalar and coopmat2 paths + break; + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + // K dequants currently disabled because D dimension is rounded up to 256 and runs inefficiently + //case GGML_TYPE_Q2_K: + //case GGML_TYPE_Q3_K: + //case GGML_TYPE_Q4_K: + //case GGML_TYPE_Q5_K: + //case GGML_TYPE_Q6_K: + //case GGML_TYPE_IQ1_S: + //case GGML_TYPE_IQ1_M: + //case GGML_TYPE_IQ2_XXS: + //case GGML_TYPE_IQ2_XS: + //case GGML_TYPE_IQ2_S: + //case GGML_TYPE_IQ3_XXS: + //case GGML_TYPE_IQ3_S: + //case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + // currently supported only in coopmat2 path + if (!coopmat2) { + return false; + } + break; + default: + return false; + } + if (!coopmat2 && !device->subgroup_shuffle) { + // scalar FA uses subgroupShuffle + return false; + } + return true; + } + case GGML_OP_GET_ROWS: + { + switch (op->src[0]->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_Q2_K: + case GGML_TYPE_Q3_K: + case GGML_TYPE_Q4_K: + case GGML_TYPE_Q5_K: + case GGML_TYPE_Q6_K: + case GGML_TYPE_IQ1_S: + case GGML_TYPE_IQ1_M: + case GGML_TYPE_IQ2_XXS: + case GGML_TYPE_IQ2_XS: + case GGML_TYPE_IQ2_S: + case GGML_TYPE_IQ3_XXS: + case GGML_TYPE_IQ3_S: + case GGML_TYPE_IQ4_XS: + case GGML_TYPE_IQ4_NL: + case GGML_TYPE_MXFP4: + return true; + default: + return false; + } + } + case GGML_OP_SET_ROWS: + { + switch (op->type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + return false; + } + } + case GGML_OP_CONT: + case GGML_OP_CPY: + case GGML_OP_DUP: + { + ggml_type src0_type = op->src[0]->type; + ggml_type src1_type = op->src[1] != nullptr ? 
op->src[1]->type : src0_type; + + if (src0_type == GGML_TYPE_F32) { + switch (src1_type) { + case GGML_TYPE_F32: + case GGML_TYPE_F16: + case GGML_TYPE_BF16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } + } + if (src1_type == GGML_TYPE_F32) { + switch (src0_type) { + case GGML_TYPE_F16: + case GGML_TYPE_Q4_0: + case GGML_TYPE_Q4_1: + case GGML_TYPE_Q5_0: + case GGML_TYPE_Q5_1: + case GGML_TYPE_Q8_0: + case GGML_TYPE_IQ4_NL: + return true; + default: + break; + } + } + + if (src0_type == GGML_TYPE_F16 && src1_type == GGML_TYPE_F16) { + return true; + } + + if ( + (src0_type == GGML_TYPE_F32 && src1_type == GGML_TYPE_I32) || + (src0_type == GGML_TYPE_I32 && src1_type == GGML_TYPE_F32) + ) { + return true; + } + + // We can handle copying from a type to the same type if it's + // contiguous (memcpy). We use f16 or f32 shaders to do the copy, + // so the type/block size must be a multiple of 4. + if (src0_type == src1_type && + ggml_is_contiguous(op->src[0]) && ggml_is_contiguous(op) && + (ggml_type_size(src0_type) % 2) == 0) { + return true; + } + return false; + } + case GGML_OP_REPEAT: + return ggml_type_size(op->type) == sizeof(float) && ggml_type_size(op->src[0]->type) == sizeof(float); + case GGML_OP_REPEAT_BACK: + return op->type == GGML_TYPE_F32 && op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ROPE: + case GGML_OP_ROPE_BACK: + case GGML_OP_NONE: + case GGML_OP_RESHAPE: + case GGML_OP_VIEW: + case GGML_OP_PERMUTE: + case GGML_OP_TRANSPOSE: + case GGML_OP_RMS_NORM: + return true; + case GGML_OP_NORM: + case GGML_OP_GROUP_NORM: + case GGML_OP_L2_NORM: + return ggml_is_contiguous(op->src[0]); + case GGML_OP_ADD: + case GGML_OP_SUB: + case GGML_OP_MUL: + case GGML_OP_DIV: + return (op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + (op->src[1]->type == GGML_TYPE_F32 || op->src[1]->type == GGML_TYPE_F16) && + (op->type == GGML_TYPE_F32 || op->type == GGML_TYPE_F16); + case GGML_OP_ADD_ID: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32 && op->src[2]->type == GGML_TYPE_I32 && + op->type == GGML_TYPE_F32; + case GGML_OP_SILU_BACK: + case GGML_OP_RMS_NORM_BACK: + case GGML_OP_SQR: + case GGML_OP_SQRT: + case GGML_OP_SIN: + case GGML_OP_COS: + case GGML_OP_CLAMP: + case GGML_OP_LEAKY_RELU: + case GGML_OP_OPT_STEP_ADAMW: + case GGML_OP_OPT_STEP_SGD: + return op->src[0]->type == GGML_TYPE_F32; + case GGML_OP_ARGSORT: + return op->ne[0] <= max_argsort_cols; + case GGML_OP_UPSCALE: + case GGML_OP_ACC: + case GGML_OP_CONCAT: + case GGML_OP_SCALE: + case GGML_OP_PAD: + case GGML_OP_ROLL: + case GGML_OP_DIAG_MASK_INF: + case GGML_OP_SOFT_MAX: + case GGML_OP_SOFT_MAX_BACK: + return true; + case GGML_OP_SUM: + case GGML_OP_SUM_ROWS: + case GGML_OP_MEAN: + return op->src[0]->type == GGML_TYPE_F32 && ggml_is_contiguous_rows(op->src[0]); + case GGML_OP_ARGMAX: + case GGML_OP_COUNT_EQUAL: + case GGML_OP_IM2COL: + case GGML_OP_IM2COL_3D: + case GGML_OP_TIMESTEP_EMBEDDING: + case GGML_OP_CONV_2D_DW: + case GGML_OP_POOL_2D: + case GGML_OP_RWKV_WKV6: + case GGML_OP_RWKV_WKV7: + return true; + case GGML_OP_CONV_TRANSPOSE_1D: + return op->src[0]->type == GGML_TYPE_F32 && op->src[1]->type == GGML_TYPE_F32; + case GGML_OP_CONV_2D: + case GGML_OP_CONV_TRANSPOSE_2D: + { + // Op is disabled for Apple because it segfaults at pipeline create time on MoltenVK + ggml_backend_vk_device_context * ctx = 
(ggml_backend_vk_device_context *)dev->context; + const vk_device& device = ggml_vk_get_device(ctx->device); + if (op->op == GGML_OP_CONV_TRANSPOSE_2D && + device->properties.limits.maxPushConstantsSize < sizeof(vk_op_conv_transpose_2d_push_constants)) { + return false; + } + // Channel-contiguous format is not supported yet. + return ((op->src[0]->type == GGML_TYPE_F32 || op->src[0]->type == GGML_TYPE_F16) && + op->src[1]->type == GGML_TYPE_F32 && + op->type == GGML_TYPE_F32 && + ggml_is_contiguous(op->src[0]) && + ggml_is_contiguous(op->src[1]) && + ggml_is_contiguous(op)); + } + default: + return false; + } + + UNUSED(dev); +} + +static bool ggml_backend_vk_device_supports_buft(ggml_backend_dev_t dev, ggml_backend_buffer_type_t buft) { + if (buft->iface.get_name != ggml_backend_vk_buffer_type_name) { + return false; + } + + ggml_backend_vk_device_context * ctx = (ggml_backend_vk_device_context *)dev->context; + ggml_backend_vk_buffer_type_context * buft_ctx = (ggml_backend_vk_buffer_type_context *)buft->context; + + return buft_ctx->device->idx == ctx->device; +} + +static bool ggml_backend_vk_device_offload_op(ggml_backend_dev_t dev, const ggml_tensor * op) { + const int min_batch_size = 32; + + return (op->ne[1] >= min_batch_size && op->op != GGML_OP_GET_ROWS) || + (op->ne[2] >= min_batch_size && op->op == GGML_OP_MUL_MAT_ID); + + UNUSED(dev); +} + +static const struct ggml_backend_device_i ggml_backend_vk_device_i = { + /* .get_name = */ ggml_backend_vk_device_get_name, + /* .get_description = */ ggml_backend_vk_device_get_description, + /* .get_memory = */ ggml_backend_vk_device_get_memory, + /* .get_type = */ ggml_backend_vk_device_get_type, + /* .get_props = */ ggml_backend_vk_device_get_props, + /* .init_backend = */ ggml_backend_vk_device_init, + /* .get_buffer_type = */ ggml_backend_vk_device_get_buffer_type, + /* .get_host_buffer_type = */ ggml_backend_vk_device_get_host_buffer_type, + /* .buffer_from_host_ptr = */ NULL, + /* .supports_op = */ ggml_backend_vk_device_supports_op, + /* .supports_buft = */ ggml_backend_vk_device_supports_buft, + /* .offload_op = */ ggml_backend_vk_device_offload_op, + /* .event_new = */ NULL, + /* .event_free = */ NULL, + /* .event_synchronize = */ NULL, +}; + +static const char * ggml_backend_vk_reg_get_name(ggml_backend_reg_t reg) { + UNUSED(reg); + return GGML_VK_NAME; +} + +static size_t ggml_backend_vk_reg_get_device_count(ggml_backend_reg_t reg) { + UNUSED(reg); + return ggml_backend_vk_get_device_count(); +} + +static ggml_backend_dev_t ggml_backend_vk_reg_get_device(ggml_backend_reg_t reg, size_t device) { + static std::vector<ggml_backend_dev_t> devices; + + static bool initialized = false; + + { + static std::mutex mutex; + std::lock_guard<std::mutex> lock(mutex); + if (!initialized) { + std::vector<vk::PhysicalDevice> vk_devices = vk_instance.instance.enumeratePhysicalDevices(); + + for (int i = 0; i < ggml_backend_vk_get_device_count(); i++) { + ggml_backend_vk_device_context * ctx = new ggml_backend_vk_device_context; + char desc[256]; + ggml_backend_vk_get_device_description(i, desc, sizeof(desc)); + ctx->device = i; + ctx->name = GGML_VK_NAME + std::to_string(i); + ctx->description = desc; + ctx->is_integrated_gpu = ggml_backend_vk_get_device_type(i) == vk::PhysicalDeviceType::eIntegratedGpu; + ctx->pci_id = ggml_backend_vk_get_device_pci_id(i); + ctx->id = ggml_backend_vk_get_device_id(i); + devices.push_back(new ggml_backend_device { + /* .iface = */ ggml_backend_vk_device_i, + /* .reg = */ reg, + /* .context = */ ctx, + }); + + // Gather additional information about the device +
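// The ID, PCI bus and driver property structs below are chained through pNext, so a single getProperties2() call fills them all; the device UUID is then formatted as "GPU-xxxxxxxx-xxxx-xxxx-xxxx-xxxxxxxxxxxx" so it can be passed to the NVML-based memory query in ggml_backend_vk_get_device_memory +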
int dev_idx = vk_instance.device_indices[i]; + vk::PhysicalDeviceProperties props1; + vk_devices[dev_idx].getProperties(&props1); + vk::PhysicalDeviceProperties2 props2; + vk::PhysicalDeviceIDProperties device_id_props; + vk::PhysicalDevicePCIBusInfoPropertiesEXT pci_bus_props; + vk::PhysicalDeviceDriverProperties driver_props; + props2.pNext = &device_id_props; + device_id_props.pNext = &pci_bus_props; + pci_bus_props.pNext = &driver_props; + vk_devices[dev_idx].getProperties2(&props2); + std::ostringstream oss; + oss << std::hex << std::setfill('0'); + oss << "GPU-"; + int byteIdx = 0; + for (int i = 0; i < 16; ++i, ++byteIdx) { + oss << std::setw(2) << static_cast<uint32_t>(device_id_props.deviceUUID[i]); + if (byteIdx == 3 || byteIdx == 5 || byteIdx == 7 || byteIdx == 9) { + oss << '-'; + } + } + ctx->uuid = oss.str(); + ctx->pci_bus_id = pci_bus_props.pciBus; + ctx->pci_device_id = pci_bus_props.pciDevice; + ctx->pci_domain_id = pci_bus_props.pciDomain; + ctx->id = std::to_string(i); + ctx->major = 0; + ctx->minor = 0; + // TODO regex parse driver_props.driverInfo for a X.Y or X.Y.Z version string + ctx->driver_major = 0; + ctx->driver_minor = 0; + } + initialized = true; + } + } + + GGML_ASSERT(device < devices.size()); + return devices[device]; +} + +static const struct ggml_backend_reg_i ggml_backend_vk_reg_i = { + /* .get_name = */ ggml_backend_vk_reg_get_name, + /* .get_device_count = */ ggml_backend_vk_reg_get_device_count, + /* .get_device = */ ggml_backend_vk_reg_get_device, + /* .get_proc_address = */ NULL, +}; + +ggml_backend_reg_t ggml_backend_vk_reg() { + static ggml_backend_reg reg = { + /* .api_version = */ GGML_BACKEND_API_VERSION, + /* .iface = */ ggml_backend_vk_reg_i, + /* .context = */ nullptr, + }; + try { + ggml_vk_instance_init(); + return &reg; + } catch (const vk::SystemError& e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: System error: " << e.what()); + return nullptr; + } catch (const std::exception &e) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: " << e.what()); + return nullptr; + } catch (...) { + VK_LOG_DEBUG("ggml_backend_vk_reg() -> Error: unknown exception during Vulkan init"); + return nullptr; + } +} + +// Extension availability +static bool ggml_vk_instance_validation_ext_available() { +#ifdef GGML_VULKAN_VALIDATE + // Check if validation layer provides the extension + const std::string layer_name = "VK_LAYER_KHRONOS_validation"; + for (const auto& layer : vk::enumerateInstanceLayerProperties()) { + if (layer_name == layer.layerName.data()) { + for (const auto& ext : vk::enumerateInstanceExtensionProperties(layer_name)) { + if (strcmp("VK_EXT_validation_features", ext.extensionName.data()) == 0) { + return true; + } + } + } + } + + std::cerr << "ggml_vulkan: WARNING: Validation layer or layer extension VK_EXT_validation_features not found." << std::endl; +#endif + return false; +} +static bool ggml_vk_instance_portability_enumeration_ext_available(const std::vector<vk::ExtensionProperties>& instance_extensions) { +#ifdef __APPLE__ + // Check for portability enumeration extension for MoltenVK support + for (const auto& properties : instance_extensions) { + if (strcmp("VK_KHR_portability_enumeration", properties.extensionName) == 0) { + return true; + } + } + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_KHR_portability_enumeration not found."
<< std::endl; +#endif + return false; + + UNUSED(instance_extensions); +} + +// Extension availability +static bool ggml_vk_instance_debug_utils_ext_available( + const std::vector<vk::ExtensionProperties> & instance_extensions) { + // Check for portability enumeration extension for MoltenVK support + for (const auto & properties : instance_extensions) { + if (strcmp("VK_EXT_debug_utils", properties.extensionName) == 0) { + return true; + } + } + + std::cerr << "ggml_vulkan: WARNING: Instance extension VK_EXT_debug_utils not found." << std::endl; + return false; + + UNUSED(instance_extensions); +} + +static bool ggml_vk_device_is_supported(const vk::PhysicalDevice & vkdev) { + VkPhysicalDeviceFeatures2 device_features2; + device_features2.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_FEATURES_2; + + VkPhysicalDeviceVulkan11Features vk11_features; + vk11_features.pNext = nullptr; + vk11_features.sType = VK_STRUCTURE_TYPE_PHYSICAL_DEVICE_VULKAN_1_1_FEATURES; + device_features2.pNext = &vk11_features; + + vkGetPhysicalDeviceFeatures2(vkdev, &device_features2); + + return vk11_features.storageBuffer16BitAccess; +} + +static bool ggml_vk_khr_cooperative_matrix_support(const vk::PhysicalDeviceProperties& props, const vk::PhysicalDeviceDriverProperties& driver_props, vk_device_architecture arch) { + switch (props.vendorID) { + case VK_VENDOR_ID_INTEL: + // Only allowing Xe2 GPU at the moment since Xe2 GPU can gain significant performance boost, + // while some older hardware (ex. Arc A770) has performance regressions + return arch == vk_device_architecture::INTEL_XE2; + case VK_VENDOR_ID_AMD: + if (driver_props.driverID == vk::DriverId::eAmdProprietary || driver_props.driverID == vk::DriverId::eAmdOpenSource) { + // Workaround for AMD proprietary driver reporting support on all GPUs + return arch == vk_device_architecture::AMD_RDNA3; + } + return true; + default: + return true; + } +} + +// checks + +#ifdef GGML_VULKAN_CHECK_RESULTS +static void ggml_vk_print_graph_origin(const ggml_tensor * tensor, std::vector<const ggml_tensor *>& done, int level = 0) { + if (std::find(done.begin(), done.end(), tensor) != done.end() || level > 10) { + return; + } + for (int j = 0; j < level; j++) { + std::cerr << " "; + } + std::cerr << ggml_op_name(tensor->op) << " gpu=" << (tensor->extra != nullptr) << std::endl; + + done.push_back(tensor); + + for (int i = 0; i < GGML_MAX_SRC; i++) { + if (tensor->src[i] != nullptr) { + ggml_vk_print_graph_origin(tensor->src[i], done, level + 1); + } + } +} + +static void ggml_vk_print_tensor_area(const ggml_tensor * tensor, const void * data, int i0, int i1, int i2, int i3) { + if (tensor->type != GGML_TYPE_F32 && tensor->type != GGML_TYPE_F16 && tensor->type != GGML_TYPE_I32) { + return; + } + i0 = std::max(i0, 5); + i1 = std::max(i1, 5); + i2 = std::max(i2, 0); + i3 = std::max(i3, 0); + fprintf(stderr, " "); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + fprintf(stderr, "%7d ", idx1); + } + fprintf(stderr, "\n"); + for (int idx0 = i0 - 5; idx0 < i0 + 5; idx0++) { + fprintf(stderr, "%7d: ", idx0); + for (int idx1 = i1 - 5; idx1 < i1 + 5; idx1++) { + if (idx0 >= 0 && idx0 < tensor->ne[0] && idx1 >= 0 && idx1 < tensor->ne[1] && i2 >= 0 && i2 < tensor->ne[2] && i3 >= 0 && i3 < tensor->ne[3]) { + float val; + if (tensor->type == GGML_TYPE_F32) { + val = *(const float *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + val = ggml_fp16_to_fp32(*(const ggml_fp16_t *) ((const char *) data + i3*tensor->nb[3] +
i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + val = *(const int32_t *) ((const char *) data + i3*tensor->nb[3] + i2*tensor->nb[2] + idx1*tensor->nb[1] + idx0*tensor->nb[0]); + } else { + GGML_ABORT("fatal error"); + } + fprintf(stderr, "% 7.2f ", val); + } else { + fprintf(stderr, " "); + } + } + fprintf(stderr, "\n"); + } +} + +static void ggml_vk_print_tensor(const ggml_tensor * tensor, const char * name) { + void * tensor_data = tensor->data; + + const bool is_gpu = tensor->buffer != nullptr && ggml_backend_buffer_is_vk(tensor->buffer); + + if (is_gpu) { + const size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer buffer_gpu = buf_ctx->dev_buffer; + ggml_vk_buffer_read(buffer_gpu, vk_tensor_offset(tensor) + tensor->view_offs, tensor_data, tensor_size); + } + + std::cerr << "TENSOR CHECK " << name << " (" << tensor->name << "): " << ggml_op_name(tensor->op) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << std::endl; + if (tensor->src[0] != nullptr) { + std::cerr << "tensor->src[0]=" << tensor->src[0] << " name=" << tensor->src[0]->name << " op=" << ggml_op_name(tensor->src[0]->op) << " type=" << ggml_type_name(tensor->src[0]->type) << " ne0=" << tensor->src[0]->ne[0] << " nb0=" << tensor->src[0]->nb[0] << " ne1=" << tensor->src[0]->ne[1] << " nb1=" << tensor->src[0]->nb[1] << " ne2=" << tensor->src[0]->ne[2] << " nb2=" << tensor->src[0]->nb[2] << " ne3=" << tensor->src[0]->ne[3] << " nb3=" << tensor->src[0]->nb[3] << std::endl; + } + if (tensor->src[1] != nullptr) { + std::cerr << "tensor->src[1]=" << tensor->src[1] << " name=" << tensor->src[1]->name << " op=" << ggml_op_name(tensor->src[1]->op) << " type=" << ggml_type_name(tensor->src[1]->type) << " ne0=" << tensor->src[1]->ne[0] << " nb0=" << tensor->src[1]->nb[0] << " ne1=" << tensor->src[1]->ne[1] << " nb1=" << tensor->src[1]->nb[1] << " ne2=" << tensor->src[1]->ne[2] << " nb2=" << tensor->src[1]->nb[2] << " ne3=" << tensor->src[1]->ne[3] << " nb3=" << tensor->src[1]->nb[3] << std::endl; + } + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + + if (is_gpu) { + free(tensor_data); + } +} + +void * comp_result; +size_t comp_size; +size_t comp_nb[GGML_MAX_DIMS]; +size_t check_counter = 0; +static void ggml_vk_check_results_0(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { + return; + } + + bool fused_rms_norm_mul = false; + int rms_norm_idx = -1; + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + fused_rms_norm_mul = true; + tensor = cgraph->nodes[tensor_idx + 1]; + } + + check_counter++; + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_0(" << tensor->name << ")"); 
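+ // Clone this node's sources into a host-side ggml context, recompute the op on the CPU, and stash the reference result in comp_result/comp_nb so ggml_vk_check_results_1 can compare it against the Vulkan output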
+ + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + + struct ggml_init_params iparams = { + /*.mem_size =*/ 2ul*1024ul*1024ul*1024ul, + /*.mem_buffer =*/ NULL, + /*.no_alloc =*/ false, + }; + + struct ggml_context * ggml_ctx = ggml_init(iparams); + + std::array<ggml_tensor *, 6> src_clone = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + std::array<size_t, 6> src_size = {0, 0, 0, 0, 0, 0}; + std::array<void *, 6> src_buffer = {nullptr, nullptr, nullptr, nullptr, nullptr, nullptr}; + const char * srci_name[6] = {"src0", "src1", "src2", "src3", "src4", "src5"}; + + struct ggml_tensor * tensor_clone = nullptr; + + for (int i = 0; i < 6; i++) { + ggml_tensor * srci = tensor->src[i]; + if (fused_rms_norm_mul) { + rms_norm_idx = tensor->src[0]->op == GGML_OP_RMS_NORM ? 0 : 1; + ggml_tensor *rms_norm = tensor->src[rms_norm_idx]; + switch (i) { + case 0: srci = rms_norm->src[0]; break; + case 1: srci = tensor->src[1 - rms_norm_idx]; break; + default: continue; + } + } + if (srci == nullptr) { + continue; + } + ggml_tensor * srci_clone = ggml_dup_tensor(ggml_ctx, srci); + size_t srci_size = ggml_nbytes(srci); + + src_clone[i] = srci_clone; + src_size[i] = ggml_nbytes(srci); + src_buffer[i] = malloc(srci_size); + + srci_clone->data = src_buffer[i]; + if (ggml_backend_buffer_is_host(srci->buffer)) { + memcpy(srci_clone->data, srci->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } else if (ggml_backend_buffer_is_vk(srci->buffer)) { + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)srci->buffer->context; + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(srci) + srci->view_offs; + if (!ggml_is_contiguous(srci) && ggml_vk_dim01_contiguous(srci)) { + for (int i3 = 0; i3 < srci->ne[3]; i3++) { + for (int i2 = 0; i2 < srci->ne[2]; i2++) { + const int idx = i3*srci->ne[2] + i2; + ggml_vk_buffer_read(buffer_gpu, offset + idx * srci->nb[2], ((char *)srci_clone->data + idx * srci_clone->nb[2]), srci->ne[1] * srci->nb[1]); + } + } + + srci_clone->nb[0] = srci->nb[0]; + srci_clone->nb[1] = srci->nb[1]; + for (int i = 2; i < GGML_MAX_DIMS; i++) { + srci_clone->nb[i] = srci_clone->nb[i - 1]*srci_clone->ne[i - 1]; + } + } else { + if (offset + srci_size >= buffer_gpu->size) { + srci_size = buffer_gpu->size - offset; + } + ggml_vk_buffer_read(buffer_gpu, offset, srci_clone->data, srci_size); + memcpy(srci_clone->nb, srci->nb, sizeof(size_t) * GGML_MAX_DIMS); + } + } else { + GGML_ABORT("fatal error"); + } + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(srci, srci_name[i]); + } + } + + if (tensor->op == GGML_OP_FLASH_ATTN_EXT) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_flash_attn_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], params[0], params[1], params[2]); + if (src_clone[4]) { + ggml_flash_attn_ext_add_sinks(tensor_clone, src_clone[4]); + } + } else if (tensor->op == GGML_OP_MUL_MAT) { + tensor_clone = ggml_mul_mat(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL_MAT_ID) { + tensor_clone = ggml_mul_mat_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } else if (tensor->op == GGML_OP_SUB) { + tensor_clone = ggml_sub(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_MUL) { + if (fused_rms_norm_mul) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->src[rms_norm_idx]->op_params); + tensor_clone = ggml_mul(ggml_ctx,
tensor_clone, src_clone[1 - rms_norm_idx]); + } else { + tensor_clone = ggml_mul(ggml_ctx, src_clone[0], src_clone[1]); + } + } else if (tensor->op == GGML_OP_DIV) { + tensor_clone = ggml_div(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_CONCAT) { + tensor_clone = ggml_concat(ggml_ctx, src_clone[0], src_clone[1], *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_UPSCALE) { + tensor_clone = ggml_interpolate(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], (ggml_scale_mode) tensor->op_params[0]); + } else if (tensor->op == GGML_OP_SCALE) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_scale_bias(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_SQR) { + tensor_clone = ggml_sqr(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SQRT) { + tensor_clone = ggml_sqrt(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SIN) { + tensor_clone = ggml_sin(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COS) { + tensor_clone = ggml_cos(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_CLAMP) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_clamp(ggml_ctx, src_clone[0], params[0], params[1]); + } else if (tensor->op == GGML_OP_PAD) { + tensor_clone = ggml_pad_ext(ggml_ctx, src_clone[0], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3], + tensor->op_params[4], tensor->op_params[5], tensor->op_params[6], tensor->op_params[7]); + } else if (tensor->op == GGML_OP_REPEAT) { + tensor_clone = ggml_repeat(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_REPEAT_BACK) { + tensor_clone = ggml_repeat_back(ggml_ctx, src_clone[0], tensor); + } else if (tensor->op == GGML_OP_ADD) { + tensor_clone = ggml_add(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ACC) { + tensor_clone = ggml_acc(ggml_ctx, src_clone[0], src_clone[1], tensor->op_params[0], tensor->op_params[1], tensor->op_params[2], tensor->op_params[3]); + } else if (tensor->op == GGML_OP_NORM) { + tensor_clone = ggml_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_GROUP_NORM) { + const float * float_params = (const float *)tensor->op_params; + tensor_clone = ggml_group_norm(ggml_ctx, src_clone[0], tensor->op_params[0], float_params[1]); + } else if (tensor->op == GGML_OP_RMS_NORM) { + tensor_clone = ggml_rms_norm(ggml_ctx, src_clone[0], *(float *)tensor->op_params); + } else if (tensor->op == GGML_OP_RMS_NORM_BACK) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_rms_norm_back(ggml_ctx, src_clone[0], src_clone[1], eps); + } else if (tensor->op == GGML_OP_SILU_BACK) { + tensor_clone = ggml_silu_back(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_L2_NORM) { + const float eps = ((float *) tensor->op_params)[0]; + tensor_clone = ggml_l2_norm(ggml_ctx, src_clone[0], eps); + } else if (tensor->op == GGML_OP_SOFT_MAX) { + if (src1 != nullptr) { + const float * params = (const float *)tensor->op_params; + tensor_clone = ggml_soft_max_ext(ggml_ctx, src_clone[0], src_clone[1], params[0], params[1]); + } else { + tensor_clone = ggml_soft_max(ggml_ctx, src_clone[0]); + } + } else if (tensor->op == GGML_OP_SOFT_MAX_BACK) { + tensor_clone = ggml_soft_max_ext_back(ggml_ctx, src_clone[0], src_clone[1], ((float *)tensor->op_params)[0], ((float *)tensor->op_params)[1]); + } else if 
(tensor->op == GGML_OP_DIAG_MASK_INF) { + tensor_clone = ggml_diag_mask_inf(ggml_ctx, src_clone[0], tensor->op_params[0]); + } else if (tensor->op == GGML_OP_ROPE || tensor->op == GGML_OP_ROPE_BACK) { + const int n_dims = ((int32_t *) tensor->op_params)[1]; + const int mode = ((int32_t *) tensor->op_params)[2]; + //const int n_ctx_ggml = ((int32_t *) tensor->op_params)[3]; + const int n_ctx_orig_ggml = ((int32_t *) tensor->op_params)[4]; + const float freq_base = ((float *) tensor->op_params)[5]; + const float freq_scale = ((float *) tensor->op_params)[6]; + const float ext_factor = ((float *) tensor->op_params)[7]; + const float attn_factor = ((float *) tensor->op_params)[8]; + const float beta_fast = ((float *) tensor->op_params)[9]; + const float beta_slow = ((float *) tensor->op_params)[10]; + if (mode & GGML_ROPE_TYPE_MROPE) { + int32_t *sections = ((int32_t *) tensor->op_params) + 11; + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_multi(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_multi_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, sections, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } else { + if (tensor->op == GGML_OP_ROPE) { + tensor_clone = ggml_rope_ext(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } else { + tensor_clone = ggml_rope_ext_back(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], n_dims, mode, n_ctx_orig_ggml, freq_base, freq_scale, ext_factor, attn_factor, beta_fast, beta_slow); + } + } + } else if (tensor->op == GGML_OP_UNARY) { + switch (ggml_get_unary_op(tensor)) { + case GGML_UNARY_OP_EXP: + tensor_clone = ggml_exp(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SILU: + tensor_clone = ggml_silu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU: + tensor_clone = ggml_gelu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_ERF: + tensor_clone = ggml_gelu_erf(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_GELU_QUICK: + tensor_clone = ggml_gelu_quick(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_RELU: + tensor_clone = ggml_relu(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_TANH: + tensor_clone = ggml_tanh(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_SIGMOID: + tensor_clone = ggml_sigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSIGMOID: + tensor_clone = ggml_hardsigmoid(ggml_ctx, src_clone[0]); + break; + case GGML_UNARY_OP_HARDSWISH: + tensor_clone = ggml_hardswish(ggml_ctx, src_clone[0]); + break; + default: + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + } else if (tensor->op == GGML_OP_GLU) { + if (src_clone[1] == nullptr) { + tensor_clone = ggml_glu(ggml_ctx, src_clone[0], (ggml_glu_op) tensor->op_params[0], tensor->op_params[1]); + } else { + tensor_clone = ggml_glu_split(ggml_ctx, src_clone[0], src_clone[1], (ggml_glu_op) tensor->op_params[0]); + } + ggml_set_op_params_i32(tensor_clone, 2, ggml_get_op_params_i32(tensor, 2)); + ggml_set_op_params_i32(tensor_clone, 3, ggml_get_op_params_i32(tensor, 3)); + } else if (tensor->op == GGML_OP_CPY || tensor->op == GGML_OP_DUP) { + if (src1 == nullptr) { + tensor_clone = ggml_dup(ggml_ctx, src_clone[0]); + 
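// ggml_dup keeps the source tensor's type, so the clone's output type is forced below to match the original node +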
tensor_clone->type = tensor->type; + } else { + tensor_clone = ggml_cpy(ggml_ctx, src_clone[0], src_clone[1]); + } + } else if (tensor->op == GGML_OP_CONT) { + tensor_clone = ggml_cont_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_RESHAPE) { + tensor_clone = ggml_reshape_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3]); + } else if (tensor->op == GGML_OP_VIEW) { + tensor_clone = ggml_view_4d(ggml_ctx, src_clone[0], tensor->ne[0], tensor->ne[1], tensor->ne[2], tensor->ne[3], tensor->nb[1], tensor->nb[2], tensor->nb[3], ((int32_t *) tensor->op_params)[0]); + } else if (tensor->op == GGML_OP_PERMUTE) { + int32_t * params = (int32_t *)tensor->op_params; + tensor_clone = ggml_permute(ggml_ctx, src_clone[0], params[0], params[1], params[2], params[3]); + } else if (tensor->op == GGML_OP_TRANSPOSE) { + tensor_clone = ggml_transpose(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_GET_ROWS) { + tensor_clone = ggml_get_rows(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_ARGSORT) { + tensor_clone = ggml_argsort(ggml_ctx, src_clone[0], (ggml_sort_order) *(int *)tensor->op_params); + } else if (tensor->op == GGML_OP_SUM) { + tensor_clone = ggml_sum(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_SUM_ROWS) { + tensor_clone = ggml_sum_rows(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_MEAN) { + tensor_clone = ggml_mean(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_ARGMAX) { + tensor_clone = ggml_argmax(ggml_ctx, src_clone[0]); + } else if (tensor->op == GGML_OP_COUNT_EQUAL) { + tensor_clone = ggml_count_equal(ggml_ctx, src_clone[0], src_clone[1]); + } else if (tensor->op == GGML_OP_IM2COL) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + + const bool is_2D = tensor->op_params[6] == 1; + tensor_clone = ggml_im2col(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1, is_2D, tensor->type); + } else if (tensor->op == GGML_OP_IM2COL_3D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t s2 = tensor->op_params[2]; + const int32_t p0 = tensor->op_params[3]; + const int32_t p1 = tensor->op_params[4]; + const int32_t p2 = tensor->op_params[5]; + const int32_t d0 = tensor->op_params[6]; + const int32_t d1 = tensor->op_params[7]; + const int32_t d2 = tensor->op_params[8]; + const int32_t IC = tensor->op_params[9]; + + tensor_clone = ggml_im2col_3d(ggml_ctx, src_clone[0], src_clone[1], IC, s0, s1, s2, p0, p1, p2, d0, d1, d2, tensor->type); + } else if (tensor->op == GGML_OP_TIMESTEP_EMBEDDING) { + const int32_t dim = tensor->op_params[0]; + const int32_t max_period = tensor->op_params[1]; + tensor_clone = ggml_timestep_embedding(ggml_ctx, src_clone[0], dim, max_period); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_1D){ + const int32_t s0 = tensor->op_params[0]; + const int32_t p0 = tensor->op_params[1]; + const int32_t d0 = tensor->op_params[2]; + tensor_clone = ggml_conv_transpose_1d(ggml_ctx, src_clone[0], src_clone[1], s0, p0, d0); + } else if (tensor->op == GGML_OP_POOL_2D) { + enum ggml_op_pool op = static_cast<ggml_op_pool>(tensor->op_params[0]); + const int32_t k0 = tensor->op_params[1]; + const int32_t k1 = tensor->op_params[2]; + const int32_t s0 =
tensor->op_params[3]; + const int32_t s1 = tensor->op_params[4]; + const int32_t p0 = tensor->op_params[5]; + const int32_t p1 = tensor->op_params[6]; + + tensor_clone = ggml_pool_2d(ggml_ctx, src_clone[0], op, k0, k1, s0, s1, p0, p1); + } else if (tensor->op == GGML_OP_CONV_2D) { + const int32_t s0 = tensor->op_params[0]; + const int32_t s1 = tensor->op_params[1]; + const int32_t p0 = tensor->op_params[2]; + const int32_t p1 = tensor->op_params[3]; + const int32_t d0 = tensor->op_params[4]; + const int32_t d1 = tensor->op_params[5]; + tensor_clone = ggml_conv_2d(ggml_ctx, src_clone[0], src_clone[1], s0, s1, p0, p1, d0, d1); + } else if (tensor->op == GGML_OP_CONV_TRANSPOSE_2D) { + const int32_t s = tensor->op_params[0]; + tensor_clone = ggml_conv_transpose_2d_p0(ggml_ctx, src_clone[0], src_clone[1], s); + } else if (tensor->op == GGML_OP_LEAKY_RELU) { + const float * op_params = (const float *)tensor->op_params; + tensor_clone = ggml_leaky_relu(ggml_ctx, src_clone[0], op_params[0], false); + } else if (tensor->op == GGML_OP_RWKV_WKV6) { + tensor_clone = ggml_rwkv_wkv6(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4], src_clone[5]); + } else if (tensor->op == GGML_OP_RWKV_WKV7) { + tensor_clone = ggml_rwkv_wkv7(ggml_ctx, src_clone[0], src_clone[1], src_clone[2], src_clone[3], + src_clone[4], src_clone[5], src_clone[6]); + } else if (tensor->op == GGML_OP_OPT_STEP_ADAMW) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_adamw(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2], src_clone[3], src_clone[4]); + } else if (tensor->op == GGML_OP_OPT_STEP_SGD) { + src_clone[0]->flags = src0->flags; + tensor_clone = ggml_opt_step_sgd(ggml_ctx, src_clone[0], src_clone[1], + src_clone[2]); + } else if (tensor->op == GGML_OP_ADD_ID) { + tensor_clone = ggml_add_id(ggml_ctx, src_clone[0], src_clone[1], src_clone[2]); + } + else { + std::cerr << "Missing vk_check_results OP: " << ggml_op_name(tensor->op) << std::endl; + GGML_ABORT("fatal error"); + } + + ggml_cgraph * cgraph_cpu = ggml_new_graph(ggml_ctx); + ggml_build_forward_expand(cgraph_cpu, tensor_clone); + + ggml_graph_compute_with_ctx(ggml_ctx, cgraph_cpu, 8); + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + ggml_vk_print_tensor(tensor_clone, "tensor_clone"); + } + + comp_size = ggml_nbytes(tensor_clone); + + comp_result = malloc(comp_size); + memcpy(comp_result, tensor_clone->data, comp_size); + memcpy(comp_nb, tensor_clone->nb, sizeof(size_t) * GGML_MAX_DIMS); + + for (int i = 0; i < 6; i++) { + if (src_buffer[i] != nullptr) { + free(src_buffer[i]); + } + } + + ggml_free(ggml_ctx); + + VK_LOG_DEBUG("END ggml_vk_check_results_0(" << tensor->name << ")"); +} + +static void ggml_vk_check_results_1(ggml_backend_vk_context * ctx, ggml_cgraph * cgraph, int tensor_idx) { + ggml_tensor * tensor = cgraph->nodes[tensor_idx]; + if (tensor->op == GGML_OP_TRANSPOSE || tensor->op == GGML_OP_SET_ROWS) { + return; + } + if (ctx->num_additional_fused_ops == 1 && + tensor->op == GGML_OP_RMS_NORM && + cgraph->nodes[tensor_idx + 1]->op == GGML_OP_MUL) { + tensor = cgraph->nodes[tensor_idx + 1]; + } + + if (!(vk_output_tensor > 0 && vk_output_tensor == check_counter) && check_counter <= vk_skip_checks) { + return; + } + + VK_LOG_DEBUG("ggml_vk_check_results_1(" << tensor->name << ")"); + + ggml_tensor * src0 = tensor->src[0]; + ggml_tensor * src1 = tensor->src[1]; + ggml_tensor * src2 = tensor->src[2]; + ggml_tensor * src3 = tensor->src[3]; + + void * tensor_data = tensor->data; + + if 
(ggml_backend_buffer_is_vk(tensor->buffer)) { + size_t tensor_size = ggml_nbytes(tensor); + tensor_data = malloc(tensor_size); + + ggml_backend_vk_buffer_context * buf_ctx = (ggml_backend_vk_buffer_context *)tensor->buffer->context; + + vk_buffer& buffer_gpu = buf_ctx->dev_buffer; + uint64_t offset = vk_tensor_offset(tensor) + tensor->view_offs; + if (offset + tensor_size >= buffer_gpu->size) { + tensor_size = buffer_gpu->size - offset; + } + + ggml_vk_buffer_read(buffer_gpu, offset, tensor_data, tensor_size); + } + + float first_error_result = -1.0f; + float first_error_correct = -1.0f; + std::array first_error = { -1, -1, -1, -1 }; + double avg_err = 0.0; + size_t counter = 0; + + for (int i3 = 0; i3 < tensor->ne[3]; i3++) { + for (int i2 = 0; i2 < tensor->ne[2]; i2++) { + for (int i1 = 0; i1 < tensor->ne[1]; i1++) { + for (int i0 = 0; i0 < tensor->ne[0]; i0++) { + const bool buffer_size_fit = i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0] < comp_size; + float correct = 0.0f; + float result = 0.0f; + + if (buffer_size_fit) { + if (tensor->type == GGML_TYPE_F32) { + correct = *(float *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(float *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_F16) { + correct = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_fp16_to_fp32(*(ggml_fp16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_BF16) { + correct = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0])); + result = ggml_bf16_to_fp32(*(ggml_bf16_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0])); + } else if (tensor->type == GGML_TYPE_I32) { + correct = *(int32_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int32_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else if (tensor->type == GGML_TYPE_I64) { + correct = *(int64_t *) ((char *) comp_result + i3*comp_nb[3] + i2*comp_nb[2] + i1*comp_nb[1] + i0*comp_nb[0]); + result = *(int64_t *) ((char *) tensor_data + i3*tensor->nb[3] + i2*tensor->nb[2] + i1*tensor->nb[1] + i0*tensor->nb[0]); + } else { + std::cerr << "Results check not implemented for type " << ggml_type_name(tensor->type) << std::endl; + } + } else { + std::cerr << "Missing debug code for type " << ggml_type_name(tensor->type) << std::endl; + GGML_ABORT("fatal error"); + } + + if ((std::isnan(correct) != std::isnan(result)) || (std::isinf(correct) != std::isinf(result)) || !buffer_size_fit) { + std::cerr << "ERROR: Invalid value in " << ggml_op_name(tensor->op) << " i3=" << i3 << " i2=" << i2 << " i1=" << i1 << " i0=" << i0 << " result=" << result << " correct=" << correct << " avg_err=" << (avg_err / counter) << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << 
std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " src0->name=" << src0->name << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " src1->name=" << src1->name << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " src2->name=" << src2->name << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " src3->name=" << src3->name << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, i0, i1, i2, i3); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, i0, i1, i2, i3); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } + const double denom = std::fabs(correct) > 1.0f ? (std::fabs(correct) > 1e-8 ? 
std::fabs(correct) : 1e-8) : 1.0f; + if (first_error[0] == -1 && std::fabs(correct - result) / denom > 0.5) { + first_error[0] = i0; + first_error[1] = i1; + first_error[2] = i2; + first_error[3] = i3; + first_error_result = result; + first_error_correct = correct; + } + + // Special case, value is infinite, avoid NaN result in avg_err + // NaN also appears in results, if both are nan error is 0 + if (!std::isinf(correct) && !std::isinf(result) && !std::isnan(correct) && !std::isnan(result)) { + avg_err += std::fabs(correct - result) / denom; + } + counter++; + } + } + } + } + + avg_err /= counter; + + if (vk_output_tensor > 0 && vk_output_tensor == check_counter) { + std::cerr << "TENSOR CHECK: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, 5, 5, 0, 0); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, 5, 5, 0, 0); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + } + + if (avg_err > 0.5 || std::isnan(avg_err)) { + std::cerr << "ERROR: avg_err=" << avg_err << " in " << ggml_op_name(tensor->op) << " (check " << check_counter << ")" << std::endl; + std::cerr << "tensor=" << tensor << " 
tensor->name=" << tensor->name << " tensor->type: " << ggml_type_name(tensor->type) << " ne0=" << tensor->ne[0] << " nb0=" << tensor->nb[0] << " ne1=" << tensor->ne[1] << " nb1=" << tensor->nb[1] << " ne2=" << tensor->ne[2] << " nb2=" << tensor->nb[2] << " ne3=" << tensor->ne[3] << " nb3=" << tensor->nb[3] << " offset=" << tensor->view_offs << std::endl; + if (src0 != nullptr) { + std::cerr << "src0=" << src0 << " op=" << ggml_op_name(src0->op) << " type=" << ggml_type_name(src0->type) << " ne0=" << src0->ne[0] << " nb0=" << src0->nb[0] << " ne1=" << src0->ne[1] << " nb1=" << src0->nb[1] << " ne2=" << src0->ne[2] << " nb2=" << src0->nb[2] << " ne3=" << src0->ne[3] << " nb3=" << src0->nb[3] << " offset=" << src0->view_offs << std::endl; + } + if (src1 != nullptr) { + std::cerr << "src1=" << src1 << " op=" << ggml_op_name(src1->op) << " type=" << ggml_type_name(src1->type) << " ne0=" << src1->ne[0] << " nb0=" << src1->nb[0] << " ne1=" << src1->ne[1] << " nb1=" << src1->nb[1] << " ne2=" << src1->ne[2] << " nb2=" << src1->nb[2] << " ne3=" << src1->ne[3] << " nb3=" << src1->nb[3] << " offset=" << src1->view_offs << std::endl; + } + if (src2 != nullptr) { + std::cerr << "src2=" << src2 << " op=" << ggml_op_name(src2->op) << " type=" << ggml_type_name(src2->type) << " ne0=" << src2->ne[0] << " nb0=" << src2->nb[0] << " ne1=" << src2->ne[1] << " nb1=" << src2->nb[1] << " ne2=" << src2->ne[2] << " nb2=" << src2->nb[2] << " ne3=" << src2->ne[3] << " nb3=" << src2->nb[3] << " offset=" << src2->view_offs << std::endl; + } + if (src3 != nullptr) { + std::cerr << "src3=" << src3 << " op=" << ggml_op_name(src3->op) << " type=" << ggml_type_name(src3->type) << " ne0=" << src3->ne[0] << " nb0=" << src3->nb[0] << " ne1=" << src3->ne[1] << " nb1=" << src3->nb[1] << " ne2=" << src3->ne[2] << " nb2=" << src3->nb[2] << " ne3=" << src3->ne[3] << " nb3=" << src3->nb[3] << " offset=" << src3->view_offs << std::endl; + } + std::cerr << "First error: result=" << first_error_result << " correct=" << first_error_correct << " i3=" << first_error[3] << " i2=" << first_error[2] << " i1=" << first_error[1] << " i0=" << first_error[0] << std::endl; + std::cerr << std::endl << "Result:" << std::endl; + ggml_vk_print_tensor_area(tensor, tensor_data, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl << "Correct:" << std::endl; + ggml_vk_print_tensor_area(tensor, comp_result, first_error[0], first_error[1], first_error[2], first_error[3]); + std::cerr << std::endl; + std::vector done; + ggml_vk_print_graph_origin(tensor, done); + GGML_ABORT("fatal error"); + } else { + std::cerr << check_counter << " " << tensor->name << " op=" << ggml_op_name(tensor->op) << " avg_err=" << avg_err << std::endl; + } + + free(comp_result); + comp_result = nullptr; + comp_size = 0; + + if (ggml_backend_buffer_is_vk(tensor->buffer)) { + free(tensor_data); + } + + VK_LOG_DEBUG("END ggml_vk_check_results_1(" << tensor->name << ")"); +} +#endif + +GGML_BACKEND_DL_IMPL(ggml_backend_vk_reg) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt new file mode 100644 index 0000000000..e1f613fb4f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/CMakeLists.txt @@ -0,0 +1,31 @@ +cmake_minimum_required(VERSION 3.19) +project("vulkan-shaders-gen" C CXX) + +find_package (Threads REQUIRED) + +if (GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + message(STATUS 
"Enabling coopmat glslc support") +endif() +if (GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + message(STATUS "Enabling coopmat2 glslc support") +endif() +if (GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + message(STATUS "Enabling dot glslc support") +endif() +if (GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + add_compile_definitions(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + message(STATUS "Enabling bfloat16 glslc support") +endif() +if (GGML_VULKAN_SHADER_DEBUG_INFO) + add_compile_definitions(GGML_VULKAN_SHADER_DEBUG_INFO) + message(STATUS "Enabling shader debug info") +endif() + +set(TARGET vulkan-shaders-gen) +add_executable(${TARGET} vulkan-shaders-gen.cpp) +install(TARGETS ${TARGET} RUNTIME) +target_compile_features(${TARGET} PRIVATE cxx_std_17) +target_link_libraries(vulkan-shaders-gen PUBLIC Threads::Threads) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp new file mode 100644 index 0000000000..5084a70ed4 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/acc.comp @@ -0,0 +1,29 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + const uint offset = p.param3; + const uint src1_i = idx - offset; + const uint oz = src1_i / p.nb02; + const uint oy = (src1_i - (oz * p.nb02)) / p.nb01; + const uint ox = src1_i % p.nb01; + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + if (ox < p.ne10 && oy < p.ne11 && oz < p.ne12) { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + ox + oy * p.ne10 + oz * p.ne10 * p.ne11])); + } else { + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)])); + } +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp new file mode 100644 index 0000000000..3bcfe6908e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add.comp @@ -0,0 +1,69 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout (binding = 3, std430) buffer PartialBuf {float partial_sums[];}; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + +void main() { + uint idx = get_idx(); + uint orig_idx = idx; + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + FLOAT_TYPE sum_sq = 0; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + FLOAT_TYPE sum = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) + FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)]); + sum_sq += 
sum*sum; + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); + + idx += num_threads; + } + +#if ADD_RMS + if (p.param3 != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp new file mode 100644 index 0000000000..495249d5f6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/add_id.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint ne0; + uint ne1; + uint s01; + uint s02; + uint s11; + uint s21; +} p; + +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {int32_t data_c[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i1 = gl_WorkGroupID.x; + const uint i2 = gl_WorkGroupID.y; + + const uint i11 = data_c[i1 + i2 * p.s21]; + + const uint s1 = p.ne0; + const uint s2 = p.ne0 * p.ne1; + + const uint d0 = i1 * s1 + i2 * s2; + const uint a0 = i1 * p.s01 + i2 * p.s02; + const uint b0 = i11 * p.s11; + + for (uint i0 = gl_LocalInvocationID.x; i0 < p.ne0; i0 += BLOCK_SIZE) { + data_d[d0 + i0] = data_a[a0 + i0] + data_b[b0 + i0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp new file mode 100644 index 0000000000..7c12877671 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argmax.comp @@ -0,0 +1,60 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +#define FLT_MAX 3.402823466e+38F + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +shared FLOAT_TYPE tmpmax[BLOCK_SIZE]; +shared uint tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + + if (row >= p.KY) { + return; + } + + A_TYPE amax = -FLT_MAX; + uint acol = col; + + if (col < p.KX) { + amax = data_a[row*p.KX + col]; + } + + for (uint i = col + BLOCK_SIZE; i < p.KX; i += BLOCK_SIZE) { + A_TYPE val = data_a[row*p.KX + i]; + if (val > amax) { + amax = val; + acol = i; + } + } + + tmp[col] = acol; + tmpmax[col] = amax; + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s && col + s < p.KX) { + if (tmpmax[col] < tmpmax[col + s]) { + tmpmax[col] = tmpmax[col + s]; + tmp[col] = tmp[col + s]; + } + } + barrier(); + 
} + + if (col == 0) { + data_d[row] = D_TYPE(tmp[0]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp new file mode 100644 index 0000000000..c81b84452e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/argsort.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" + +layout(constant_id = 0) const int BLOCK_SIZE = 1024; +layout(constant_id = 1) const int BLOCK_SIZE_LOG2 = 10; +#define ASC 0 + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) buffer D {int data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint order; +} p; + +shared int dst_row[BLOCK_SIZE]; +shared A_TYPE a_sh[BLOCK_SIZE]; + +void swap(uint idx0, uint idx1) { + int tmp = dst_row[idx0]; + dst_row[idx0] = dst_row[idx1]; + dst_row[idx1] = tmp; +} + +void argsort(bool needs_bounds_check) { + // bitonic sort + const int col = int(gl_LocalInvocationID.x); + const uint row = gl_WorkGroupID.y; + + const uint row_offset = row * p.ncols; + + // initialize indices + dst_row[col] = col; + a_sh[col] = data_a[row_offset + col]; + barrier(); + + uint num_outer_loop_iters = BLOCK_SIZE_LOG2; + [[unroll]] for (uint k = 2, outer_idx = 0; outer_idx < num_outer_loop_iters; k *= 2, outer_idx++) { + uint num_inner_loop_iters = outer_idx + 1; + [[unroll]] for (uint j = k / 2, inner_idx = 0; inner_idx < num_inner_loop_iters; j /= 2, inner_idx++) { + const int ixj = int(col ^ j); + + int idx_0 = (col & k) == 0 ? col : ixj; + int idx_1 = (col & k) == 0 ? ixj : col; + + int sh_idx_0 = dst_row[idx_0]; + int sh_idx_1 = dst_row[idx_1]; + bool idx_0_oob = needs_bounds_check ? sh_idx_0 >= p.ncols : false; + bool idx_1_oob = needs_bounds_check ? sh_idx_1 >= p.ncols : false; + + if ((idx_0_oob || + (!idx_1_oob && a_sh[sh_idx_0] > a_sh[sh_idx_1])) && (ixj > col)) { + swap(idx_0, idx_1); + } + + barrier(); + } + } + + if (col < p.ncols) { + if (p.order == ASC) { + data_d[row_offset + col] = dst_row[col]; + } else { + data_d[row_offset + p.ncols - col - 1] = dst_row[col]; + } + } +} + +void main() { + if (p.ncols == BLOCK_SIZE) { + argsort(false); + } else { + argsort(true); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp new file mode 100644 index 0000000000..653431895e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/clamp.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val < p.param1 ? p.param1 : (val > p.param2 ? 
p.param2 : val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp new file mode 100644 index 0000000000..e404698382 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/concat.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + const int dim = p.param3; + + if (idx >= p.ne) { + return; + } + + const uint i3 = idx / (p.ne22*p.ne21*p.ne20); + const uint i3_offset = i3 * p.ne22*p.ne21*p.ne20; + const uint i2 = (idx - i3_offset) / (p.ne21*p.ne20); + const uint i2_offset = i2*p.ne21*p.ne20; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne20; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne20; + + uint o[4] = {0, 0, 0, 0}; + o[dim] = dim == 0 ? p.ne00 : (dim == 1 ? p.ne01 : (dim == 2 ? p.ne02 : p.ne03)); + + const uint src0_idx = i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00; + const uint src1_idx = (i3 - o[3])*p.nb13 + (i2 - o[2])*p.nb12 + (i1 - o[1])*p.nb11 + (i0 - o[0])*p.nb10; + const uint dst_idx = i3*p.nb23 + i2*p.nb22 + i1*p.nb21 + i0*p.nb20; + + const bool is_src0 = i0 < p.ne00 && i1 < p.ne01 && i2 < p.ne02 && i3 < p.ne03; + +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : data_b[get_boffset() + src1_idx]); +#else + if (is_src0) { + data_d[get_doffset() + dst_idx] = data_a[get_aoffset() + src0_idx]; + } else { + data_d[get_doffset() + dst_idx] = data_b[get_boffset() + src1_idx]; + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp new file mode 100644 index 0000000000..ca1a3ac25b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/contig_copy.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +#extension GL_EXT_control_flow_attributes : require + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + // fast path for when all four iterations are in-bounds + if (idx + (num_iter-1)*num_threads < p.ne) { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } else { + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + idx]); + data_d[get_doffset() + idx] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + idx] = D_TYPE(data_a[get_aoffset() + idx]); +#else + data_d[get_doffset() + idx] = data_a[get_aoffset() + idx]; +#endif + idx += num_threads; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp new file mode 100644 index 0000000000..70a301488e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_dw.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint ne; + uint batches; + uint channels; + uint dst_w; + uint dst_h; + uint src_w; + uint src_h; + uint knl_w; + uint knl_h; + int stride_x; + int stride_y; + int pad_x; + int pad_y; + int dilation_x; + int dilation_y; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE knl_data[];}; +layout (binding = 1) readonly buffer B {B_TYPE src_data[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst_data[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE conv_2d_dw_whcn(uint idx) { + uint i0 = idx / p.dst_w; + uint dst_x = idx - i0 * p.dst_w; + uint i1 = i0 / p.dst_h; + uint dst_y = i0 - i1 * p.dst_h; + uint n = i1 / p.channels; + uint c = i1 - n * p.channels; + + uint src_i = n * p.channels * p.src_h * p.src_w + c * p.src_h * p.src_w; + uint knl_i = c * p.knl_h * p.knl_w; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * p.src_w + src_x]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[knl_i + knl_y * p.knl_w + knl_x]); + sum = fma(v, k, sum); + } + } + return sum; +} + +FLOAT_TYPE conv_2d_dw_cwhn(uint idx) { + uint i0 = idx / p.channels; + uint c = idx - i0 * p.channels; + uint i1 = i0 / p.dst_w; + uint dst_x = i0 - i1 * p.dst_w; + uint n = i1 / p.dst_h; + uint dst_y = i1 - n * p.dst_h; + + uint src_i = n * p.channels * p.src_h * p.src_w; + uint src_row = p.src_w * p.channels; + uint knl_row = p.knl_w * p.channels; + + FLOAT_TYPE sum = 0.0; + for (uint knl_y = 0; knl_y < p.knl_h; ++knl_y) { + uint src_y = dst_y * p.stride_y + knl_y * p.dilation_y - p.pad_y; + if (src_y >= p.src_h) { // src_y < 0 will wrap to a large unsigned int + continue; + } + for (uint knl_x = 0; knl_x < p.knl_w; ++knl_x) { + uint src_x = dst_x * p.stride_x + knl_x * p.dilation_x - p.pad_x; + if (src_x >= p.src_w) { // src_x < 0 will wrap to a large unsigned int + continue; + } + FLOAT_TYPE v = FLOAT_TYPE(src_data[src_i + src_y * src_row + src_x * p.channels + c]); + FLOAT_TYPE k = FLOAT_TYPE(knl_data[ knl_y * knl_row + knl_x * p.channels + c]); + sum = fma(v, k, sum); + } + } + return sum; +} + +void main() { + uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + if (idx >= p.ne) { + return; + } + + FLOAT_TYPE result = +#ifdef WHCN + conv_2d_dw_whcn(idx); +#else + conv_2d_dw_cwhn(idx); +#endif + dst_data[idx] = D_TYPE(result); +} + diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp new file mode 100644 index 0000000000..0367e80bbf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv2d_mm.comp @@ -0,0 +1,349 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#ifdef COOPMAT2 +#extension GL_NV_cooperative_matrix2 : enable +#extension 
GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#ifdef USE_COLLECTIVES +# extension GL_KHR_shader_subgroup_shuffle : enable +#endif + +#include "types.glsl" + +// shape notation: [dim(N), ..., dim(0)] -- stride(dim(j)) >= stride(dim(i)) if i > j +layout(binding = 0) readonly buffer A { + A_TYPE knl_data[]; +}; // src0 - kernel: [KW, KH, Cin, Cout] for conv_2d, [KW, KH, Cout, Cin] for conv_transposed_2d + +layout(binding = 1) readonly buffer B { + B_TYPE src_data[]; +}; // src1 - input: [W, H, Cin, N] -- channel_first format + +layout(binding = 2) writeonly buffer D { + D_TYPE dst_data[]; +}; // dst - result: [OW, OH, Cout, N] + +layout(push_constant) uniform parameter { + // I/O channels, batch size + uint32_t Cout; + uint32_t Cin; + uint32_t N; + + // Tensor spatial sizes: kernel, input, output + uint32_t KW; + uint32_t KH; + uint32_t W; + uint32_t H; + uint32_t OW; + uint32_t OH; + + // Parameters: stride, padding, dilation - 0=y, 1=x + uint32_t s0; + uint32_t s1; + uint32_t p0; + uint32_t p1; + uint32_t d0; + uint32_t d1; + + // Strides in elements + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + + uint32_t nb1; + uint32_t nb2; + uint32_t nb3; + + // fastdiv helper values + uint32_t KWmp; uint32_t KWL; + uint32_t KWKHmp; uint32_t KWKHL; + uint32_t OWmp; uint32_t OWL; + uint32_t OWOHmp; uint32_t OWOHL; +#ifdef TRANSPOSE + uint32_t s0mp; uint32_t s0L; + uint32_t s1mp; uint32_t s1L; +#endif +} + +p; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; +// Blocktile sizes +layout(constant_id = 1) const uint BS_K = 128; +layout(constant_id = 2) const uint BS_CRS = 16; +layout(constant_id = 3) const uint BS_NPQ = 128; +// Thread-tile sizes +layout(constant_id = 4) const uint TS_K = 8; +layout(constant_id = 5) const uint use_collectives = 1; +layout(constant_id = 6) const uint SHMEM_PAD = 4; + +uint32_t tid = gl_LocalInvocationID.x; +const uint32_t WG_SIZE = gl_WorkGroupSize.x; + +uint splitWork(uint work_size, uint block_size) { + return (block_size + work_size - 1) / block_size; +} + +uint32_t K = p.Cout; +uint32_t CRS = p.Cin * p.KH * p.KW; +uint32_t NPQ = p.N * p.OH * p.OW; + +uint32_t n_elems_out = K * NPQ; + +// Number of blocktiles per input +uint32_t NB_CRS = splitWork(CRS, BS_CRS); + +#ifdef COOPMAT2 +#define SHMEM_TYPE float16_t +#else +#define SHMEM_TYPE float +#endif + +const uint32_t Ash_stride = BS_CRS + SHMEM_PAD; +const uint32_t Bsh_stride = BS_NPQ + SHMEM_PAD; + +const uint32_t Ash_numel = BS_K * BS_CRS; +const uint32_t Bsh_numel = BS_CRS * BS_NPQ; + +const uint32_t Ash_len = BS_K * Ash_stride; +const uint32_t Bsh_len = BS_CRS * Bsh_stride; + +shared SHMEM_TYPE Ash[Ash_len]; // K x CRS +shared SHMEM_TYPE Bsh[Bsh_len]; // CRS x NPQ + +// Threadtile sizes +const uint32_t TS_NPQ = BS_K * BS_NPQ / WG_SIZE / TS_K; + +// Number of threadtiles per blocktile +const uint32_t NT_K = BS_K / TS_K; +const uint32_t NT_NPQ = BS_NPQ / TS_NPQ; + +/* +Compute +KxCRS @ CRSxNPQ = K x NPQ +K=Cout +C=Cin +R,S=KH,KW +P,Q=OH,OW +*/ + +uint32_t B_idx_K = gl_WorkGroupID.x; +uint32_t B_idx_NPQ = gl_WorkGroupID.y; + +uint32_t T_y = tid / NT_NPQ; +uint32_t T_x = tid % NT_NPQ; + +uint32_t Ar = tid / BS_CRS; +uint32_t Ac = tid % BS_CRS; +const uint32_t ArpWg = WG_SIZE / BS_CRS; + +uint32_t Br = tid / BS_NPQ; +uint32_t Bc = tid % BS_NPQ; +const uint32_t BrpWg = WG_SIZE / BS_NPQ; + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint 
L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +#ifdef COOPMAT2 +#define ACC_TYPE float16_t + +ACC_TYPE perElemOpStore(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem) +{ + uint32_t K_idx = B_idx_K * BS_K + r; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + c; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = D_TYPE(elem); + } + return elem; +} +#endif + +void main() { +#ifdef COOPMAT2 + coopmat matC; + matC = coopmat(0.0); +#else + float regC[TS_K][TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = 0.0; + } + } +#endif + /* Advance block in CRS dim */ + for (uint32_t B_idx_CRS = 0; B_idx_CRS < NB_CRS; B_idx_CRS++) { + uint32_t CRS_idx_a; + uint32_t Cin_idx_a; + uint32_t KH_idx_a; + uint32_t KW_idx_a; + +#ifdef USE_COLLECTIVES + uint32_t cached_CRS_idx; + uint32_t cached_Cin_idx; + uint32_t cached_KH_idx; + uint32_t cached_KW_idx; + if (use_collectives == 1) { + cached_CRS_idx = B_idx_CRS * BS_CRS + gl_SubgroupInvocationID; + cached_Cin_idx = fastdiv(cached_CRS_idx, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t cached_CRS_remainder = (cached_CRS_idx - cached_Cin_idx * p.KW * p.KH); + cached_KH_idx = fastdiv(cached_CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + cached_KW_idx = cached_CRS_remainder - cached_KH_idx * p.KW; + + CRS_idx_a = subgroupShuffle(cached_CRS_idx, Ac); + Cin_idx_a = subgroupShuffle(cached_Cin_idx, Ac); + KH_idx_a = subgroupShuffle(cached_KH_idx, Ac); + KW_idx_a = subgroupShuffle(cached_KW_idx, Ac); + } else { + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; + } +#else + CRS_idx_a = B_idx_CRS * BS_CRS + Ac; // Global CRS_idx_a (column index of A) + Cin_idx_a = fastdiv(CRS_idx_a, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); / (p.KW * p.KH); + CRS_remainder = CRS_idx_a - Cin_idx_a * p.KW * p.KH; + KH_idx_a = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_a = CRS_remainder - KH_idx_a * p.KW; +#endif + + /* Load kernel to A_block: (BS_K x BS_CRS)*/ + for (uint32_t r_offset = 0; r_offset < BS_K; r_offset += ArpWg) { + uint32_t B_ly = r_offset + Ar; + uint32_t B_lx = Ac; + uint32_t K_idx = B_idx_K * BS_K + B_ly; /* Global K_idx (row index of A)*/ +#ifdef TRANSPOSE + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + K_idx * p.nb02 + Cin_idx_a * p.nb03, K * CRS - 1); +#else + uint32_t knl_idx = min(KW_idx_a + KH_idx_a * p.nb01 + Cin_idx_a * p.nb02 + K_idx * p.nb03, K * CRS - 1); +#endif + float val = knl_data[knl_idx]; + if (K_idx >= K || CRS_idx_a >= CRS) { + val = 0.0; + } + Ash[B_ly * Ash_stride + B_lx] = SHMEM_TYPE(val); + } + /* Load input to B_block: (BS_CRS x BS_NPQ) */ + UNROLL for (uint32_t r_offset = 0; r_offset < BS_CRS; r_offset += BrpWg) { + uint32_t B_ly = r_offset + Br; /* Row index of B block */ + uint32_t B_lx = Bc; + uint32_t NPQ_idx = B_idx_NPQ * 
BS_NPQ + B_lx; /* Global NPQ index (column index of B) */ + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t NPQ_remainder = NPQ_idx - N_idx * p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_remainder, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_remainder - OH_idx * p.OW; + + uint32_t CRS_idx_b; + uint32_t Cin_idx_b; + uint32_t KH_idx_b; + uint32_t KW_idx_b; +#ifdef USE_COLLECTIVES + if (use_collectives == 1) { + CRS_idx_b = subgroupShuffle(cached_CRS_idx, r_offset + Br); + Cin_idx_b = subgroupShuffle(cached_Cin_idx, r_offset + Br); + KH_idx_b = subgroupShuffle(cached_KH_idx, r_offset + Br); + KW_idx_b = subgroupShuffle(cached_KW_idx, r_offset + Br); + } else { + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; + } +#else + CRS_idx_b = B_idx_CRS * BS_CRS + B_ly; /* Global CRS index (row index of B) */ + Cin_idx_b = fastdiv(CRS_idx_b, p.KWKHmp, p.KWKHL); // divide by (p.KW * p.KH); + uint32_t CRS_remainder = CRS_idx_b - Cin_idx_b * p.KW * p.KH; + KH_idx_b = fastdiv(CRS_remainder, p.KWmp, p.KWL); // divide by p.KW; + KW_idx_b = CRS_remainder - KH_idx_b * p.KW; +#endif + +#ifdef TRANSPOSE + uint32_t H_idx_x_s1 = OH_idx - KH_idx_b * p.d1 + p.p1; + uint32_t W_idx_x_s0 = OW_idx - KW_idx_b * p.d0 + p.p0; + uint32_t H_idx = fastdiv(H_idx_x_s1, p.s1mp, p.s1L); + uint32_t W_idx = fastdiv(W_idx_x_s0, p.s0mp, p.s0L); +#else + uint32_t H_idx = OH_idx * p.s1 + KH_idx_b * p.d1 - p.p1; + uint32_t W_idx = OW_idx * p.s0 + KW_idx_b * p.d0 - p.p0; +#endif + uint32_t src_idx = + min(max(W_idx + H_idx * p.nb11 + Cin_idx_b * p.nb12 + N_idx * p.nb13, 0), p.Cin * p.N * p.W * p.H - 1); + float val = src_data[src_idx]; + if (CRS_idx_b >= CRS || NPQ_idx >= NPQ + || H_idx >= p.H || W_idx >= p.W // Lower bound checks aren't necessary. 
(idx >= 0x80000000 for such case) +#ifdef TRANSPOSE + || (H_idx_x_s1 - H_idx * p.s1 != 0) || (W_idx_x_s0 - W_idx * p.s0 != 0) +#endif + ) { + val = 0.0; + } + Bsh[B_ly * Bsh_stride + B_lx] = SHMEM_TYPE(val); + } + barrier(); +#ifdef COOPMAT2 + coopmat matA; + coopmat matB; + + coopMatLoad(matA, Ash, 0, Ash_stride, gl_CooperativeMatrixLayoutRowMajor); + coopMatLoad(matB, Bsh, 0, Bsh_stride, gl_CooperativeMatrixLayoutRowMajor); + matC = coopMatMulAdd(matA, matB, matC); +#else + if (T_y * TS_K < K) { + UNROLL for (uint32_t CRS_lidx = 0; CRS_lidx < BS_CRS; CRS_lidx++) { + float regA[TS_K]; + float regB[TS_NPQ]; + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + regA[T_ly] = Ash[(T_y * TS_K + T_ly) * Ash_stride + CRS_lidx]; + } + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regB[T_lx] = Bsh[CRS_lidx * Bsh_stride + T_x * TS_NPQ + T_lx]; + } + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + regC[T_ly][T_lx] = fma(regA[T_ly], regB[T_lx], regC[T_ly][T_lx]); + } + } + } + } +#endif + barrier(); + } + /* Save C* */ +#ifdef COOPMAT2 + coopMatPerElementNV(matC, matC, perElemOpStore); +#else + if (T_y * TS_K < K) { + for (uint32_t T_ly = 0; T_ly < TS_K; T_ly++) { + for (uint32_t T_lx = 0; T_lx < TS_NPQ; T_lx++) { + uint32_t K_idx = B_idx_K * BS_K + T_y * TS_K + T_ly; + uint32_t NPQ_idx = B_idx_NPQ * BS_NPQ + T_x * TS_NPQ + T_lx; + uint32_t N_idx = fastdiv(NPQ_idx, p.OWOHmp, p.OWOHL); // divide by p.OH * p.OW; + uint32_t OH_idx = fastdiv(NPQ_idx - N_idx * p.OH * p.OW, p.OWmp, p.OWL); // divide by p.OW; + uint32_t OW_idx = NPQ_idx - N_idx * p.OH * p.OW - OH_idx * p.OW; + uint32_t dst_idx = OW_idx + OH_idx * p.nb1 + K_idx * p.nb2 + N_idx * p.nb3; + if (K_idx < K && NPQ_idx < NPQ) { + dst_data[dst_idx] = regC[T_ly][T_lx]; + } + } + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp new file mode 100644 index 0000000000..5217e18bdd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/conv_transpose_1d.comp @@ -0,0 +1,98 @@ +#version 450 + +#include "types.glsl" + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; // src0 - kernel: [K, Cout, Cin] +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; // src1 - input: [L, Cin] +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; // dst - result [KL, Cout] + +layout(local_size_x = 128 , local_size_y = 1, local_size_z = 1) in; + +layout (push_constant) uniform parameter { + uint32_t Cout; + uint32_t Cin; + uint32_t K; + uint32_t L; + uint32_t KL; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb11; + uint32_t nb1; + + int32_t s0; +} p; + + +uint32_t Cout_idx = gl_WorkGroupID.x; +const uint32_t bs = gl_WorkGroupSize.x; +uint32_t tid = gl_LocalInvocationID.x; +// Code is more straightforward if we assume it is bs*s0+K instead of (bs-1)*s0+K. 
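+// tmp[] acts as a shared sliding window of partial outputs: each block of L indices accumulates its kernel contributions into it, and the window is shifted left by bs*s0 before the next block is processed (see the shift loop in main below).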
+uint32_t tmp_len = bs*p.s0+p.K; +shared D_TYPE tmp[4096]; + +uint splitWork(uint workSize){ + return (bs + workSize -1) / bs; +} + +void main(){ + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx < tmp_len){ + tmp[idx] = 0.0; + } + } + + uint32_t L_blocks = splitWork(p.L); + for(uint32_t L_block_id = 0; L_block_id < L_blocks; L_block_id++){ + if(L_block_id > 0){ + barrier(); + // Shift values in tmp to the current processing window + for(int i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + if(idx >= bs*p.s0 && idx < tmp_len){ + tmp[idx-bs*p.s0] = tmp[idx]; + tmp[idx] = 0.0; + }else if(idx >= p.K && idx < bs*p.s0){ + tmp[idx] = 0.0; + } + } + } + barrier(); + + // Save contributions of the block to tmp + uint32_t L_idx = L_block_id*bs + tid; + for(uint32_t K_idx = 0; K_idx < p.K; K_idx++){ + D_TYPE dp = 0.0; + for(uint32_t Cin_idx = 0; Cin_idx < p.Cin; Cin_idx++){ + A_TYPE elemKrn = data_a[K_idx + Cout_idx * p.nb01 + Cin_idx * p.nb02]; + if(L_idx < p.L){ + B_TYPE elemInp = data_b[L_idx + Cin_idx*p.nb11]; + dp = fma(elemKrn, elemInp, dp); + } + } + tmp[tid*p.s0 + K_idx] += dp; + barrier(); + } + + // Save the computed values except the last block that can have different size + uint32_t KLb_idx = L_block_id*bs*p.s0; + if(L_block_id < L_blocks-1){ + for(uint32_t s0_idx = 0; s0_idx < p.s0; s0_idx++){ + uint32_t sh_idx = p.s0*tid+s0_idx; + uint32_t KL_idx = KLb_idx+sh_idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[sh_idx]; + } + } + } + } + + for(uint32_t i = 0; i < splitWork(tmp_len); i++){ + uint32_t idx = i*bs+tid; + uint32_t KL_idx = (L_blocks-1)*bs*p.s0+idx; + if(KL_idx < p.KL){ + data_d[KL_idx + Cout_idx*p.nb1] = tmp[idx]; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp new file mode 100644 index 0000000000..9f8bfd3c18 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + +#if defined(DATA_D_BF16) + float f = float(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(fp32_to_bf16(f)); +#elif !defined(OPTIMIZATION_ERROR_WORKAROUND) + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx(idx)]); +#else + data_d[get_doffset() + dst_idx(idx)] = data_a[get_aoffset() + src0_idx(idx)]; +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp new file mode 100644 index 0000000000..06df509525 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_from_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" +#include "dequant_funcs.glsl" + +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_MXFP4) +// 16 invocations needed for init_iq_shmem +layout(local_size_x = 16, local_size_y = 1, local_size_z = 1) in; +#else +layout(local_size_x = 1, local_size_y = 1, local_size_z = 1) in; +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); + if (gl_LocalInvocationIndex.x != 0) { + return; + } +#endif + + const uint idx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 
QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = get_doffset() + dst_idx(idx); + uint src_idx = src0_idx_quant(idx, QUANT_K); + + const uint a_offset = 0; + const uint ib = src_idx; + const vec2 dm = get_dm(ib, a_offset); + + [[unroll]] for (int j = 0; j < QUANT_K; j += 4) { + vec4 v = dequantize4(ib, j / QUANT_R, a_offset); + v = v * dm.x + vec4(dm.y); + +#if QUANT_R == 2 + data_d[dst_idx + j/2 + 0] = v[0]; + data_d[dst_idx + j/2 + QUANT_K/2 + 0] = v[1]; + data_d[dst_idx + j/2 + 1] = v[2]; + data_d[dst_idx + j/2 + QUANT_K/2 + 1] = v[3]; +#else + data_d[dst_idx + j + 0] = v[0]; + data_d[dst_idx + j + 1] = v[1]; + data_d[dst_idx + j + 2] = v[2]; + data_d[dst_idx + j + 3] = v[3]; +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp new file mode 100644 index 0000000000..b8c40eec10 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/copy_to_quant.comp @@ -0,0 +1,296 @@ +#version 450 + +#include "rte.glsl" +#include "types.glsl" + +#if defined(SET_ROWS) && QUANT_K == 1 +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 512; +#else +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; +const uint BLOCK_SIZE = 32; +#endif + +layout (binding = 0) readonly buffer S {float data_s[];}; + +#if defined(SET_ROWS) +#include "generic_binary_head.glsl" +layout (binding = 1) readonly buffer C {B_TYPE data_i[];}; +layout (binding = 2) writeonly buffer Q {A_TYPE data_q[];}; + +#if B_SIZE == 64 +#define DATA_I_SWIZZLE .x +#else +#define DATA_I_SWIZZLE +#endif + +#else +#include "generic_unary_head.glsl" +layout (binding = 1) writeonly buffer Q {A_TYPE data_q[];}; +#endif + +#if defined(DATA_A_Q4_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -8; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q4_0/2 + j]*id; + + const uint xi0 = min(15, int(x0 + 8.5)); + const uint xi1 = min(15, int(x1 + 8.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q4_1) +void quantize(uint dst_idx, uint src_idx) +{ + float vmin = 1.0/0.0; + float vmax = -vmin; + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1; ++j) { + const float v = data_s[src_idx + j]; + + if (v < vmin) vmin = v; + if (v > vmax) vmax = v; + } + + const float d = (vmax - vmin) / ((1 << 4) - 1); + const float id = (d != 0.0) ? 
1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(vmin); + + [[unroll]] for (int j = 0; j < QUANT_K_Q4_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - vmin)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q4_1/2 + j] - vmin)*id; + + const uint xi0 = min(15, int(x0 + 0.5)); + const uint xi1 = min(15, int(x1 + 0.5)); + + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + } +} +#endif + +#if defined(DATA_A_Q5_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + const float d = vmax / -16; + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_0/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_Q5_0/2 + j]*id; + + const uint xi0 = min(31, int(x0 + 16.5)); + const uint xi1 = min(31, int(x1 + 16.5)); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_0/2); + } + data_q[dst_idx].qh[0] = uint16_t(qh & 0xFFFF); + data_q[dst_idx].qh[1] = uint16_t(qh >> 16); +} +#endif + +#if defined(DATA_A_Q5_1) +void quantize(uint dst_idx, uint src_idx) +{ + float min = data_s[src_idx + 0]; + float max = min; + + [[unroll]] for (int j = 1; j < QUANT_K_Q5_1; ++j) { + const float v = data_s[src_idx + j]; + min = v < min ? v : min; + max = v > max ? v : max; + } + + const float d = (max - min) / 31; + const float id = (d != 0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + data_q[dst_idx].m = float16_t(min); + + uint32_t qh = 0; + [[unroll]] for (int j = 0; j < QUANT_K_Q5_1/2; ++j) { + const float x0 = (data_s[src_idx + 0 + j] - min)*id; + const float x1 = (data_s[src_idx + QUANT_K_Q5_1/2 + j] - min)*id; + + const uint xi0 = uint(x0 + 0.5); + const uint xi1 = uint(x1 + 0.5); + + data_q[dst_idx].qs[j] = uint8_t((xi0 & 0xf) | ((xi1 & 0xf) << 4)); + qh |= ((xi0 & 0x10u) >> 4) << (j + 0); + qh |= ((xi1 & 0x10u) >> 4) << (j + QUANT_K_Q5_1/2); + } + data_q[dst_idx].qh = qh; +} +#endif + +#if defined(DATA_A_Q8_0) +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; // absolute max + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; j++) { + const float v = data_s[src_idx + j]; + amax = max(amax, abs(v)); + } + + const float d = amax / ((1 << 7) - 1); + const float id = (d != 0.0) ? 1.0/d : 0.0; + + data_q[dst_idx].d = float16_t(d); + + [[unroll]] for (int j = 0; j < QUANT_K_Q8_0; ++j) { + const float x0 = data_s[src_idx + j]*id; + + data_q[dst_idx].qs[j] = int8_t(round(x0)); + } +} +#endif + +#if defined(DATA_A_IQ4_NL) +uint best_index(float x) { + if (x <= kvalues_iq4nl[0]) return 0; + if (x >= kvalues_iq4nl[15]) return 15; + int ml = 0, mu = 15; + while (mu-ml > 1) { + int mav = (ml+mu)/2; + if (x < kvalues_iq4nl[mav]) mu = mav; else ml = mav; + } + return x - kvalues_iq4nl[mu-1] < kvalues_iq4nl[mu] - x ? mu-1 : mu; +} + +void quantize(uint dst_idx, uint src_idx) +{ + float amax = 0.0; + float vmax = 0.0; + + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL; ++j) { + const float v = data_s[src_idx + j]; + if (amax < abs(v)) { + amax = abs(v); + vmax = v; + } + } + + float d = vmax / kvalues_iq4nl[0]; + const float id = (d != 0.0) ? 
1.0/d : 0.0; + + float sumqx = 0, sumq2 = 0; + [[unroll]] for (int j = 0; j < QUANT_K_IQ4_NL/2; ++j) { + const float x0 = data_s[src_idx + 0 + j]*id; + const float x1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*id; + const uint xi0 = best_index(x0); + const uint xi1 = best_index(x1); + data_q[dst_idx].qs[j] = uint8_t(xi0 | (xi1 << 4)); + const float v0 = kvalues_iq4nl[xi0]; + const float v1 = kvalues_iq4nl[xi1]; + const float w0 = data_s[src_idx + 0 + j]*data_s[src_idx + 0 + j]; + const float w1 = data_s[src_idx + QUANT_K_IQ4_NL/2 + j]*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumqx += w0*v0*data_s[src_idx + j] + w1*v1*data_s[src_idx + QUANT_K_IQ4_NL/2 + j]; + sumq2 += w0*v0*v0 + w1*v1*v1; + } + + data_q[dst_idx].d = float16_t(sumq2 > 0 ? sumqx/sumq2 : d); + +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(data_s[src_idx]); +} +#endif + +#if defined(DATA_A_BF16) +void quantize(uint dst_idx, uint src_idx) +{ + data_q[dst_idx] = A_TYPE(fp32_to_bf16(data_s[src_idx])); +} +#endif + +#if defined(SET_ROWS) + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint idx = ((gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x) * BLOCK_SIZE + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + uint i12 = fastmod(i03, p.ne12); + uint i11 = fastmod(i02, p.ne11); + uint i10 = i01; + + uint i1 = data_i[src1_idx(i10, i11, i12, 0) + get_boffset()] DATA_I_SWIZZLE; + + uint src0_idx = src0_idx(i00, i01, i02, i03) + get_aoffset(); + uint dst_idx = dst_idx(i00 / QUANT_K, i1, i02, i03) + get_doffset(); + + quantize(dst_idx, src0_idx); +} + +#else + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint idx = (gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x) * QUANT_K; + + if (idx >= p.ne) { + return; + } + + uint dst_idx = dst_idx_quant(idx, QUANT_K); + uint src_idx = get_aoffset() + src0_idx(idx); + + quantize(dst_idx, src_idx); +} + +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp new file mode 100644 index 0000000000..db6865db98 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/cos.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(cos(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp new file mode 100644 index 0000000000..e75df66756 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/count_equal.comp @@ -0,0 +1,31 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" +#include "generic_head.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +const uint CHUNK_SIZE = 512; + 
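+// Each workgroup scans one CHUNK_SIZE slice of the inputs; every invocation strides through the slice counting data_a[idx] == data_b[idx] matches, then adds its local count to data_d[0] with atomicAdd.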
+void main() { + const uint base = gl_WorkGroupID.x * CHUNK_SIZE; + const uint col = gl_LocalInvocationID.x; + + uint count = 0; + [[unroll]] + for (uint i = 0; i < CHUNK_SIZE; i += gl_WorkGroupSize.x) { + const uint idx = base + i + col; + if (idx >= p.KX) { + break; + } + count += uint(data_a[idx] == data_b[idx]); + } + + atomicAdd(data_d[0], D_TYPE(count)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp new file mode 100644 index 0000000000..765afffa80 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_f32.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_GlobalInvocationID.x * 16; + + if (i >= p.nel) { + return; + } + + [[unroll]] for (uint l = 0; l < 16; l++) { + data_b[i + l] = D_TYPE(data_a[i + l]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl new file mode 100644 index 0000000000..0d98f5a9d6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs.glsl @@ -0,0 +1,616 @@ +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#endif + +#include "types.glsl" + +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +#if defined(DATA_A_F32) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_F16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(data_a[a_offset + ib], data_a[a_offset + ib + 1]); +} +#endif + +#if defined(DATA_A_BF16) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(bf16_to_fp32(data_a[a_offset + ib]), bf16_to_fp32(data_a[a_offset + ib + 1])); +} +#endif + +#if defined(DATA_A_Q4_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2(vui & 0xF, vui >> 4) - 8.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12) - 8.0f); +} +#endif + +#if defined(DATA_A_Q4_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(vui & 0xF, vui >> 4); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(vui & 0xF, (vui >> 4) & 0xF, (vui >> 8) & 0xF, vui >> 12); +} +#endif + +#if defined(DATA_A_Q5_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = uint(data_a[a_offset + ib].qh[1]) << 16 | data_a[a_offset + ib].qh[0]; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return (vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y) - 16.0f); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const 
uint uint_qh = uint(data_a_packed16[a_offset + ib].qh[1]) << 16 | data_a_packed16[a_offset + ib].qh[0]; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f); +} +#endif + +#if defined(DATA_A_Q5_1) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a[a_offset + ib].qh; + const ivec2 qh = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2((vui & 0xF) | qh.x, (vui >> 4) | qh.y); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint uint_qh = data_a_packed16[a_offset + ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> iqs) << 4) & 0x10, (uint_qh >> (iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (iqs + 1)) << 4) & 0x10, (uint_qh >> (iqs + 13)) & 0x10); + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y); +} +#endif + +#if defined(DATA_A_Q8_0) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + return vec2(int(data_a[a_offset + ib].qs[iqs]), int(data_a[a_offset + ib].qs[iqs + 1])); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[a_offset + ib].qs[iqs/2 + 1])).xy; + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#if defined(DATA_A_IQ1_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint idxhi = bitfieldExtract(qh, 3 * int(ib8 & 3), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + const int i8 = int(iqs % 8); + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const float dl = 2 * bitfieldExtract(qh, 12, 3) + 1; + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + // Signed bitfield extract. 
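+ // bitfieldExtract on the signed int16_t grid sign-extends each 2-bit field, giving values in [-2, 1] that are then offset by delta and scaled by dl.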
+ const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec2 gvec = ivec2( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2) + ); + return dl * (vec2(gvec) + delta); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib8 = iqs / 8; + const uint ib16 = iqs / 16; + const int i8 = int(iqs % 8); + const uint sc = data_a[a_offset + ib].scales[iqs / 64]; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + // Signed bitfield extract. + const ivec4 gvec = ivec4( + bitfieldExtract(grid, 2 * (i8), 2), + bitfieldExtract(grid, 2 * (i8 + 1), 2), + bitfieldExtract(grid, 2 * (i8 + 2), 2), + bitfieldExtract(grid, 2 * (i8 + 3), 2) + ); + return dl * (vec4(gvec) + delta); +} +#endif + +#if defined(DATA_A_IQ2_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = (iqs / 8) % 4; + const uint qs = data_a[a_offset + ib].qs[8 * ib32 + ib8]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[4 * ib32 + 2], + data_a_packed16[a_offset + ib].qs[4 * ib32 + 3])); + const float db = 0.25 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xxs_grid[qs][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? 
-1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint scale = (data_a[a_offset + ib].scales[iqs / 32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[iqs / 8]; + const float db = 0.25 * (0.5 + scale); + const uint sign7 = qs >> 9; + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq2xs_grid[qs & 511][(iqs % 8) / 4] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ2_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid[iqs % 4] * (sign0 ? -1.0 : 1.0), + grid[(iqs % 4) + 1] * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint ib8 = iqs / 8; + + const uint scale = (data_a[a_offset + ib].scales[ib32] >> (4 * ((iqs / 16) & 1))) & 0xf; + const uint qs = data_a[a_offset + ib].qs[ib8]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[a_offset + ib].qs[QUANT_K / 8 + ib8] >> (iqs % 8); + + const float db = 0.25 * (0.5 + scale); + const u8vec4 grid = unpack8(iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(iqs % 8) / 4]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? 
-1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + // Scales are stored as packed 7+7+7+7+4 bits (4 sign tuples and 1 int4 scale) + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs] >> (8 * (iqs % 4))); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + return db * vec2( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint is = QUANT_K / 4 + 4 * ib32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint signs = pack32(u16vec2(data_a_packed16[a_offset + ib].qs[is / 2], + data_a_packed16[a_offset + ib].qs[is / 2 + 1])); + const float db = 0.5 * (0.5 + (signs >> 28)); + const uint sign7 = bitfieldExtract(signs, 7 * (int(ib4 / 2) % 4), 7); + // Add parity bit + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint sign = sign8 >> (iqs % 8); + const u8vec4 grid = unpack8(iq3xxs_grid[qs]); + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + return db * vec4( + grid.x * (sign0 ? -1.0 : 1.0), + grid.y * (sign1 ? -1.0 : 1.0), + grid.z * (sign2 ? -1.0 : 1.0), + grid.w * (sign3 ? -1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ3_S) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint qs = data_a[a_offset + ib].qs[iqs / 4]; + const uint qh = data_a[a_offset + ib].qh[iqs / 32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[iqs / 64]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + const float db = 1 + 2 * ((scale >> (4 * ((iqs / 32) & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ((iqs / 4) % 8))) & 256)] >> (8 * (iqs % 4)); + return db * vec2( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0) + ); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib4 = iqs / 4; + const uint ib32 = iqs / 32; + const uint qs = data_a[a_offset + ib].qs[ib4]; + const uint qh = data_a[a_offset + ib].qh[ib32]; + const uint sign = data_a[a_offset + ib].signs[iqs / 8] >> (iqs % 8); + const uint scale = data_a[a_offset + ib].scales[ib32 / 2]; + bool sign0 = (sign & 1) != 0; + bool sign1 = (sign & 2) != 0; + bool sign2 = (sign & 4) != 0; + bool sign3 = (sign & 8) != 0; + const float db = 1 + 2 * ((scale >> (4 * (ib32 & 1))) & 0xf); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - ib4 % 8)) & 256)] >> (8 * (iqs % 4)); + return db * vec4( + int(grid & 0xFF) * (sign0 ? -1.0 : 1.0), + int((grid >> 8) & 0xFF) * (sign1 ? -1.0 : 1.0), + int((grid >> 16) & 0xFF) * (sign2 ? -1.0 : 1.0), + int((grid >> 24) & 0xFF) * (sign3 ? 
-1.0 : 1.0) + ); +} +#endif + +#if defined(DATA_A_IQ4_XS) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec2 qs = u8vec2(data_a[a_offset + ib].qs[iq], data_a[a_offset + ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint ib32 = iqs / 32; + const uint iq = 16 * ib32 + (iqs % 16); + + const uint sl = (data_a[a_offset + ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = (data_a[a_offset + ib].scales_h >> (2 * ib32)) & 3; + const uint qshift = (iqs & 16) >> 2; + u8vec4 qs = u8vec4( + data_a[a_offset + ib].qs[iq + 0], + data_a[a_offset + ib].qs[iq + 1], + data_a[a_offset + ib].qs[iq + 2], + data_a[a_offset + ib].qs[iq + 3] + ); + qs = (qs >> qshift) & uint8_t(0xF); + + const float dl = float(int(sl | (sh << 4)) - 32); + return dl * vec4( + kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y], + kvalues_iq4nl[qs.z], kvalues_iq4nl[qs.w]); +} +#endif + +#if defined(DATA_A_IQ4_NL) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a_packed16[a_offset + ib].qs[iqs/2]); + return vec4(kvalues_iq4nl[vui & 0xF], kvalues_iq4nl[(vui >> 4) & 0xF], kvalues_iq4nl[(vui >> 8) & 0xF], kvalues_iq4nl[vui >> 12]); +} +#endif + +#if defined(DATA_A_MXFP4) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + const uint vui = uint(data_a[a_offset + ib].qs[iqs]); + return vec2(kvalues_mxfp4[vui & 0xF], kvalues_mxfp4[vui >> 4]); +} +vec4 dequantize4(uint ib, uint iqs, uint a_offset) { + vec2 v0 = dequantize(ib, iqs, a_offset); + vec2 v1 = dequantize(ib, iqs + 1, a_offset); + return vec4(v0.x, v0.y, v1.x, v1.y); +} +#endif + +#if defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(0, 0); +} +#endif + +#if defined(DATA_A_IQ1_M) +vec2 get_dm(uint ib, uint a_offset) { + const uint16_t[4] scales = data_a[a_offset + ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + return vec2(d, 0); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), 0); +} +#endif + +#if defined(DATA_A_MXFP4) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(e8m0_to_fp32(data_a[a_offset + ib].e), 0); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +vec2 get_dm(uint ib, uint a_offset) { + return vec2(float(data_a[a_offset + ib].d), float(data_a[a_offset + ib].m)); +} +#endif + +#if defined(DATA_A_Q2_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = 
iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[a_offset + ib].qs[qsi], data_a[a_offset + ib].qs[qsi + 1]); + const uint scales = data_a[a_offset + ib].scales[scalesi]; + const vec2 d = vec2(data_a[a_offset + ib].d); + + return d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q3_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[a_offset + ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[a_offset + ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[a_offset + ib].d) * float(us - 32); + + return vec2(dl * float(int8_t((data_a[a_offset + ib].qs[qsi ] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi ] & m) != 0) ? 0 : 4)), + dl * float(int8_t((data_a[a_offset + ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[a_offset + ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q4_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q5_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[a_offset + ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 
0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[a_offset + ib].scales[scidx0] & 0xF) | ((data_a[a_offset + ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[a_offset + ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[a_offset + ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + return vec2(fma(d, float((data_a[a_offset + ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[a_offset + ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[a_offset + ib].qh[qhi + 1] & hm) != 0 ? 16 : 0), m)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif + +#if defined(DATA_A_Q6_K) +vec2 dequantize(uint ib, uint iqs, uint a_offset) { + iqs /= 2; + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[a_offset + ib].d) * float(data_a[a_offset + ib].scales[is]); + + return vec2(dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[a_offset + ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[a_offset + ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +} +vec2 get_dm(uint ib, uint a_offset) { + return vec2(1, 0); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl new file mode 100644 index 0000000000..6a5bb4574d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_funcs_cm2.glsl @@ -0,0 +1,720 @@ + +#include "types.glsl" + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ4_0 { + block_q4_0_packed16 block; +}; + +float16_t dequantFuncQ4_0(const in decodeBufQ4_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = uint32_t(bl.block.qs[(idx & 0xE) >> 1]); + qs >>= shift; + qs &= 0x0F0F; + qs = unpack8(qs)[idx & 1]; + float16_t ret = (float16_t(qs) - float16_t(8)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ4_1 { + block_q4_1 block; +}; + +float16_t dequantFuncQ4_1(const in decodeBufQ4_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(qs) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ5_0 { + block_q5_0 block; +}; + +float16_t dequantFuncQ5_0(const in decodeBufQ5_0 bl, const in uint blockCoords[2], 
const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = uint(bl.block.qh[1]) << 16 | bl.block.qh[0]; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = (float16_t(qs | qh) - float16_t(16)) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufQ5_1 { + block_q5_1 block; +}; + +float16_t dequantFuncQ5_1(const in decodeBufQ5_1 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const float16_t m = bl.block.m; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + + const uint uint_qh = bl.block.qh; + const uint qh = ((uint_qh >> idx) << 4) & 0x10; + + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + + float16_t ret = float16_t(qs | qh) * d + m; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ8_0 { + block_q8_0_packed16 block; +}; + +float16_t dequantFuncQ8_0(const in decodeBufQ8_0 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + // Load 16b and select the byte for this element + int32_t qs = unpack8(bl.block.qs[(iqs & 0x1E) >> 1])[iqs & 1]; + float16_t ret = float16_t(qs) * d; + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 4) buffer decodeBufQ2_K { + block_q2_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ2_K_packed16 { + block_q2_K_packed16 block; +}; + +float16_t dequantFuncQ2_K(const in decodeBufQ2_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ2_K_packed16 bl16 = decodeBufQ2_K_packed16(bl); + const f16vec2 d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint scalesi = (idx & 0xF0) >> 4; // 0..15 + const uint qsshift = (idx & 0x60) >> 4; // 0,2,4,6 + + uint qs = uint32_t(bl16.block.qs[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qs = (qs >> qsshift) & 0x0303; + qs = unpack8(qs)[idx & 1]; + + const uint scales = bl.block.scales[scalesi]; + float16_t ret = d.x * float16_t(scales & 0xF) * float16_t(qs) - d.y * float16_t(scales >> 4); + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ3_K { + block_q3_K block; +}; + +float16_t dequantFuncQ3_K(const in decodeBufQ3_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint idx = coordInBlock[1]; + const uint iqs = idx; + + const uint n = iqs / 128; // 0,1 + const uint qsi = n * 32 + (iqs % 32); // 0..63 + const uint hmi = (iqs % 32); // 0..31 + const uint j = (iqs % 128) / 8; // 0..15 + const uint is = iqs / 16; // 0..15 + const uint halfsplit = ((iqs % 128) / 32); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + uint32_t scaleidx0 = (is < 8) ? is : (is-8); + uint32_t scaleidx0shift = (is < 8) ? 
0 : 4; + uint32_t scaleidx1 = is + 8 - (is/4)*4; + uint32_t scaleidx1shift = (is/4)*2; + + const int8_t us = int8_t(((bl.block.scales[scaleidx0] >> scaleidx0shift) & 0xF) | (((bl.block.scales[scaleidx1] >> scaleidx1shift) & 3) << 4)); + + const float16_t dl = bl.block.d * float16_t(us - 32); + + float16_t ret = dl * float16_t(int8_t((bl.block.qs[qsi ] >> qsshift) & 3) - (((bl.block.hmask[hmi ] & m) != 0) ? 0 : 4)); + + return ret; +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K { + block_q4_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed16 { + block_q4_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ4_K_packed128 { + block_q4_K_packed128 block; +}; + +#if defined(IS_MUL_MM2) + +// For Q4_K and Q5_K in the mat-mul shader, we decode a tile's worth of scales +// into shared memory and then process the whole tile using those scales. +// There is a fetch function that loads into private variables and then a store +// function that stores into shared memory. +// Q4_K and Q5_K have the same encoding of scales, so everything is shared except +// the part that fetches from the structure (which has a different block layout). +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +const uint shAscales_stride = (BM + 2); +// 1 scale per 32 elements -> 8 scales per block, per row +shared vec2 shAscales[8 * shAscales_stride]; +uvec4 row_v; +#endif + +#if defined(DATA_A_Q4_K) +layout (binding = 0) readonly buffer A_Q4_K_128 {block_q4_K_packed128 data_a_q4_k_packed128[];}; + +void fetch_scalesQ4_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q4_k_packed128[block_index].q4k[0]; + } +} +#endif +#if defined(DATA_A_Q5_K) +layout (binding = 0) readonly buffer A_Q5_K_128 {block_q5_K_packed128 data_a_q5_k_packed128[];}; + +void fetch_scalesQ5_K(uint ir_BM, uint pos_a, uint stride_a, uint block_k, uint tid, bool in_bounds) +{ + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + uint row = ir_BM + tid_row; + uint block_index = pos_a + row * stride_a + (block_k / QUANT_K); + if (in_bounds || row < p.M) { + row_v = data_a_q5_k_packed128[block_index].q5k[0]; + } +} +#endif + +#if defined(DATA_A_Q4_K) || defined(DATA_A_Q5_K) +void store_scalesQ4_K(uint tid) +{ + barrier(); + + uint tids_per_row = BLOCK_SIZE / BM; + uint is_per_tid = 8 / tids_per_row; + uint is_start = is_per_tid * (tid % tids_per_row); + uint tid_row = tid / tids_per_row; + + [[unroll]] for (uint idx = 0; idx < is_per_tid; ++idx) { + uint is = idx + is_start; + uvec4 v = row_v; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? 
mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); + shAscales[is * shAscales_stride + tid_row] = vec2(d,m); + } + + barrier(); +} +#endif + +#endif + +float16_t dequantFuncQ4_K(const in decodeBufQ4_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ4_K_packed16 bl16 = decodeBufQ4_K_packed16(bl); + decodeBufQ4_K_packed128 bl128 = decodeBufQ4_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q4_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q4k[0]; + const vec2 loadd = vec2(unpackFloat2x16(v.x)); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float d = loadd.x * float(sc); + const float m = loadd.y * float(mbyte); +#endif + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4 + 8 * (idx & 1))) & 0xF; + + float ret = d * float(qs) - m; + + return float16_t(ret); +} + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K { + block_q5_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed16 { + block_q5_K_packed16 block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ5_K_packed128 { + block_q5_K_packed128 block; +}; + +float16_t dequantFuncQ5_K(const in decodeBufQ5_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ5_K_packed16 bl16 = decodeBufQ5_K_packed16(bl); + decodeBufQ5_K_packed128 bl128 = decodeBufQ5_K_packed128(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x20) >> 5; // 0,1 + const uint is = (idx & 0xE0) >> 5; // 0..7 + +#if defined(IS_MUL_MM2) && defined(DATA_A_Q5_K) + vec2 v = shAscales[is * shAscales_stride + (blockCoords[0] % BM)]; + float d = v.x; + float m = v.y; +#else + uvec4 v = bl128.block.q5k[0]; + + const f16vec2 loadd = unpackFloat2x16(v.x); + + uint32_t sc; + uint32_t mbyte; + + uint32_t scale0 = v.y; + uint32_t scale4 = v.z; + uint32_t scale8 = v.w; + + uint32_t sc_lo = scale0; + uint32_t mb_lo = scale4; + uint32_t sc_hi = (scale8 & 0x0F0F0F0F) | ((scale0 & 0xC0C0C0C0) >> 2); + uint32_t mb_hi = ((scale8 & 0xF0F0F0F0) >> 4) | ((scale4 & 0xC0C0C0C0) >> 2); + + sc = is < 4 ? sc_lo : sc_hi; + mbyte = is < 4 ? 
mb_lo : mb_hi; + sc = sc >> (8 * (is & 3)); + mbyte = mbyte >> (8 * (is & 3)); + sc &= 0x3F; + mbyte &= 0x3F; + + const float16_t d = loadd.x * float16_t(sc); + const float16_t m = loadd.y * float16_t(mbyte); +#endif + + uint qh = uint32_t(bl16.block.qh[(idx & 0x1E) >> 1]); + qh = ((qh >> is) & 0x101) << 4; + + uint qs = uint32_t(bl16.block.qs[((idx & 0xC0) >> 2) + ((idx & 0x1E) >> 1)]); + qs = (qs >> (b * 4)) & 0x0F0F; + qs = unpack8(qs | qh)[idx & 1]; + + float ret = d * float(qs) - m; + + return float16_t(ret); +} + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufQ6_K { + block_q6_K block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 16) buffer decodeBufQ6_K_packed16 { + block_q6_K_packed16 block; +}; + +float16_t dequantFuncQ6_K(const in decodeBufQ6_K bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufQ6_K_packed16 bl16 = decodeBufQ6_K_packed16(bl); + const uint idx = coordInBlock[1]; + + const uint b = (idx & 0x40) >> 6; // 0,1 + const uint qhshift = (idx & 0x60) >> 4; // 0,2,4,6 + const uint is = (idx & 0xF0) >> 4; // 0..15 + + const float16_t dscale = bl.block.d * float16_t(bl.block.scales[is]); + + uint ql = uint32_t(bl16.block.ql[((idx & 0x80) >> 2) + ((idx & 0x3E) >> 1)]); + ql = (ql >> (b * 4)) & 0x0F0F; + + uint qh = uint32_t(bl16.block.qh[((idx & 0x80) >> 3) + ((idx & 0x1E) >> 1)]); + qh = ((qh >> qhshift) & 0x0303) << 4; + + int q = unpack8(ql | qh)[idx & 1]; + + float16_t ret = dscale * float16_t(q - 32); + + return ret; +} + +#if defined(DATA_A_IQ1_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_S { + block_iq1_s block; +}; + +float16_t dequantFuncIQ1_S(const in decodeBufIQ1_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; + const uint ib8 = (idx & 0xF8) >> 3; + + const uint qh = bl.block.qh[ib32]; + const uint qs = bl.block.qs[ib8]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]; + + float16_t ret = float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * int(idx % 8), 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ1_M) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ1_M { + block_iq1_m block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 8) buffer decodeBufIQ1_M_packed64 { + block_iq1_m_packed64 block; +}; + +float16_t dequantFuncIQ1_M(const in decodeBufIQ1_M bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ1_M_packed64 bl64 = decodeBufIQ1_M_packed64(bl); + const uint idx = coordInBlock[1]; + + uvec2 scales = unpack32(bl64.block.scales); + const float16_t d = uint16BitsToHalf(uint16_t(((scales.x & 0xF000) >> 12) | ((scales.x & 0xF0000000) >> 24) | ((scales.y & 0xF000) >> 4) | ((scales.y & 0xF0000000) >> 16))); + + const uint ib8 = (idx & 0xF8) >> 3; + const uint ib16 = (idx & 0xF0) >> 4; + const int i8 = int(idx % 8); + const uint sc = bl.block.scales[ib8 / 8]; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib16] >> (4 * (ib8 & 1)); + const float dl = 2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1; + const float delta = ((qh & 8) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; + const uint grid = iq1s_grid[qs | ((qh & 7) << 8)]; + + float16_t ret = d * float16_t(dl) * (float16_t(bitfieldExtract(int(grid), 2 * i8, 2)) + float16_t(delta)); + return ret; +} +#endif + +#if defined(DATA_A_IQ2_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS { + block_iq2_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XXS_packed16 { + block_iq2_xxs_packed16 block; +}; + +float16_t dequantFuncIQ2_XXS(const in decodeBufIQ2_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ2_XXS_packed16 bl16 = decodeBufIQ2_XXS_packed16(bl); + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0x18) >> 3; // 0..3 + const uint iqs = 8 * ib32 + ib8; + + const uint qs = bl.block.qs[iqs]; + const uint signscale = pack32(u16vec2(bl16.block.qs[4*ib32+2], bl16.block.qs[4*ib32+3])); + + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float(signscale >> 28)); + uint sign = bitfieldExtract(signscale, 7 * int(ib8), 7); + sign |= bitCount(sign) << 7; + + uint g2 = iq2xxs_grid[qs][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? -1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_XS { + block_iq2_xs block; +}; + +float16_t dequantFuncIQ2_XS(const in decodeBufIQ2_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint is = (idx & 0xE0) >> 5; // 0..8 + const uint sshift = (idx & 0x10) >> 2; // 0,4 + const uint iqs = (idx & 0xF8) >> 3; // 0..63 + + const uint16_t qs = bl.block.qs[iqs]; + const float dscale = float(bl.block.d) * 0.25 * (0.5 + float((bl.block.scales[is] >> sshift) & 0xF)); + + uint sign = uint(qs >> 9); + sign |= bitCount(sign) << 7; + uint g2 = iq2xs_grid[qs & 0x1FF][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 g = vec2(unpack8(g2)); + + vec2 ret = dscale * g * ((sign & (1 << (idx & 7))) != 0 ? 
-1.0hf : 1.0hf); + return float16_t(ret[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ2_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ2_S { + block_iq2_s block; +}; + +float16_t dequantFuncIQ2_S(const in decodeBufIQ2_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint ib32 = (idx & 0xE0) >> 5; // 0..7 + const uint ib8 = (idx & 0xF8) >> 3; // 0..31 + const uint qhshift = 2 * (ib8 % 4); + + const uint scale = (bl.block.scales[ib32] >> ((idx & 0x10) >> 2)) & 0xf; + const uint qs = bl.block.qs[ib8]; + const uint qh = bl.block.qh[ib32]; + const uint sign = bl.block.qs[QUANT_K / 8 + ib8] >> (idx & 0x6); + + const float d = float(bl.block.d); + const float db = d * 0.25 * (0.5 + scale); + const ivec2 sign01 = 1 - (2 & ivec2(sign << 1, sign)); + uint g2 = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)][(idx & 4) >> 2]; + g2 >>= (idx & 2) * 8; + const vec2 v = db * vec2(sign01) * vec2(unpack8(g2)); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_XXS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS { + block_iq3_xxs block; +}; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_XXS_packed16 { + block_iq3_xxs_packed16 block; +}; + +float16_t dequantFuncIQ3_XXS(const in decodeBufIQ3_XXS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + decodeBufIQ3_XXS_packed16 bl16 = decodeBufIQ3_XXS_packed16(bl); + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint is = QUANT_K / 4 + ((idx & 0xE0) >> 3);// 8 values + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint signs = pack32(u16vec2( + bl16.block.qs[is/2+0], + bl16.block.qs[is/2+1] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (idx & 0x6); + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const uint grid = iq3xxs_grid[qs] >> (16 * ((idx & 2) >> 1)); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ3_S) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ3_S { + block_iq3_s block; +}; + +float16_t dequantFuncIQ3_S(const in decodeBufIQ3_S bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + uint idx = coordInBlock[1]; + + const uint iqs = (idx & 0xFC) >> 2; // 0..63 + const uint iqh = (idx & 0xE0) >> 5; + + const float d = float(bl.block.d); + const uint qs = bl.block.qs[iqs]; + const uint qh = bl.block.qh[iqh]; + const int8_t sign = int8_t(bl.block.signs[iqs / 2] >> (idx & 0x6)); + const uint scale = bl.block.scales[iqs / 16]; + const ivec2 sign01 = ivec2(1 - (2 & ivec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)] >> ((idx & 2) << 3); + const vec2 v = db * vec2(sign01) * vec2(unpack8(grid).xy); + + return float16_t(v[idx & 1]); +} +#endif + +#if defined(DATA_A_IQ4_XS) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_XS { + block_iq4_xs block; +}; + +float16_t dequantFuncIQ4_XS(const in decodeBufIQ4_XS bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + + const uint ib32 
= (idx & 0xE0) >> 5; // 0..7 + + const uint sl = (bl.block.scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((bl.block.scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 16) >> 2; + const uint q = (bl.block.qs[16 * ib32 + (idx % 16)] >> qshift) & 0xF; + + float16_t ret = d * float16_t(int(sl | (sh << 4)) - 32) * float16_t(kvalues_iq4nl[q]); + return ret; +} +#endif + +#if defined(DATA_A_IQ4_NL) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufIQ4_NL { + block_iq4_nl block; +}; + +float16_t dequantFuncIQ4_NL(const in decodeBufIQ4_NL bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float16_t d = bl.block.d; + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_iq4nl[qs]) * d; + return ret; +} +#endif + +#if defined(DATA_A_MXFP4) +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufMXFP4 { + block_mxfp4 block; +}; + +float16_t dequantFuncMXFP4(const in decodeBufMXFP4 bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const float d = e8m0_to_fp32(bl.block.e); + const uint idx = coordInBlock[1]; + const uint iqs = idx & 0xF; + const uint shift = (idx & 0x10) >> 2; + uint32_t qs = bl.block.qs[iqs]; + qs >>= shift; + qs &= 0xF; + float16_t ret = float16_t(kvalues_mxfp4[qs] * d); + return ret; +} +#endif + +#if defined(DATA_A_Q4_0) +#define dequantFuncA dequantFuncQ4_0 +#elif defined(DATA_A_Q4_1) +#define dequantFuncA dequantFuncQ4_1 +#elif defined(DATA_A_Q5_0) +#define dequantFuncA dequantFuncQ5_0 +#elif defined(DATA_A_Q5_1) +#define dequantFuncA dequantFuncQ5_1 +#elif defined(DATA_A_Q8_0) +#define dequantFuncA dequantFuncQ8_0 +#elif defined(DATA_A_Q2_K) +#define dequantFuncA dequantFuncQ2_K +#elif defined(DATA_A_Q3_K) +#define dequantFuncA dequantFuncQ3_K +#elif defined(DATA_A_Q4_K) +#define dequantFuncA dequantFuncQ4_K +#define fetch_scales fetch_scalesQ4_K +#define store_scales store_scalesQ4_K +#elif defined(DATA_A_Q5_K) +#define dequantFuncA dequantFuncQ5_K +#define fetch_scales fetch_scalesQ5_K +#define store_scales store_scalesQ4_K +#elif defined(DATA_A_Q6_K) +#define dequantFuncA dequantFuncQ6_K +#elif defined(DATA_A_IQ1_S) +#define dequantFuncA dequantFuncIQ1_S +#elif defined(DATA_A_IQ1_M) +#define dequantFuncA dequantFuncIQ1_M +#elif defined(DATA_A_IQ2_XXS) +#define dequantFuncA dequantFuncIQ2_XXS +#elif defined(DATA_A_IQ2_XS) +#define dequantFuncA dequantFuncIQ2_XS +#elif defined(DATA_A_IQ2_S) +#define dequantFuncA dequantFuncIQ2_S +#elif defined(DATA_A_IQ3_XXS) +#define dequantFuncA dequantFuncIQ3_XXS +#elif defined(DATA_A_IQ3_S) +#define dequantFuncA dequantFuncIQ3_S +#elif defined(DATA_A_IQ4_XS) +#define dequantFuncA dequantFuncIQ4_XS +#elif defined(DATA_A_IQ4_NL) +#define dequantFuncA dequantFuncIQ4_NL +#elif defined(DATA_A_MXFP4) +#define dequantFuncA dequantFuncMXFP4 +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl new file mode 100644 index 0000000000..addceafade --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_head.glsl @@ -0,0 +1,13 @@ +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint M; + uint K; + uint stride_a; + uint stride_b; + uint nel; +} p; + +#include 
"types.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp new file mode 100644 index 0000000000..637c95fa35 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_m.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_m data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint ib64 = ib32 / 2; + const uint b_idx = 256 * ib + 32 * ib32; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ib].scales[ib64]; + [[unroll]] for (int l = 0; l < 4; ++l) { + const uint ib16 = 2 * ib32 + l / 2; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const uint qh = data_a[ib].qh[ib16] >> (4 * (l & 1)); + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp new file mode 100644 index 0000000000..d1cbc5e9d0 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq1_s.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq1_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + uint qh = data_a[ib].qh[ib32]; + const float d = float(data_a[ib].d); + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? 
-IQ1S_DELTA : IQ1S_DELTA; + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint hi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (hi << 8)]); + [[unroll]] for (int j = 0; j < 8; ++j) { + data_b[b_idx + 8 * l + j] = D_TYPE(dl * (bitfieldExtract(grid, 2*j, 2) + delta)); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp new file mode 100644 index 0000000000..78490162cd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_s.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + uint qh = data_a[ib].qh[ib32]; + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint qs = data_a[ib].qs[4 * ib32 + l]; + const uint8_t sign = data_a[ib].qs[QUANT_K / 8 + 4 * ib32 + l]; + qs |= (qh << (8 - 2 * l)) & 0x300; + const uvec2 grid = iq2s_grid[qs]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign & 128) != 0 ? 
-1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp new file mode 100644 index 0000000000..9b8ce0a7f8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xs.comp @@ -0,0 +1,43 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (32 values with 2 scales) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * ib32; + + const float d = float(data_a[ib].d); + const vec2 scale = vec2(data_a[ib].scales[ib32] & 0xf, data_a[ib].scales[ib32] >> 4); + const vec2 db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + uint16_t qs = data_a[ib].qs[4 * ib32 + l]; + const uint sign7 = qs >> 9; + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uvec2 grid = iq2xs_grid[qs & 511]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db[l/2] * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db[l/2] * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db[l/2] * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db[l/2] * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db[l/2] * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db[l/2] * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db[l/2] * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db[l/2] * grid1.w * ((sign8 & 128) != 0 ? 
-1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp new file mode 100644 index 0000000000..aacf07d0f8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq2_xxs.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq2_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // Each block is described by 4 lattice indices, 4x7 sign bits and 4 scale bits + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[8*is + 4], + data_a[ib].qs[8*is + 5], + data_a[ib].qs[8*is + 6], + data_a[ib].qs[8*is + 7] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.25; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + const uint sign8 = sign7 | (bitCount(sign7) << 7); // parity bit + const uint qs = data_a[ib].qs[8 * is + l]; + const uvec2 grid = iq2xxs_grid[qs]; + const u8vec4 grid0 = unpack8(grid.x); + const u8vec4 grid1 = unpack8(grid.y); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp new file mode 100644 index 0000000000..f2c20b1d2c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_s.comp @@ -0,0 +1,40 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_s data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale nibble. + // Each block contains 4 scale bytes (8 scales) for 256 output values. + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + + const float d = float(data_a[ib].d); + const float db = d * (1 + 2 * ((data_a[ib].scales[is / 2] >> (4 * (is % 2))) & 0xf)); + + // We must produce 32 values using 4 sign bytes, 1 qh byte, 8 qs bytes. 
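+    // Each of the 8 qs bytes picks a 4-value entry from iq3s_grid, with the
+    // matching qh bit supplying the 9th index bit; the corresponding sign
+    // nibble then flips individual values before the db scale is applied.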
+ uint qh = data_a[ib].qh[is]; + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = 8 * is + l; + const uint qs = data_a[ib].qs[iqs]; + const uint gidx = qs | ((qh << (8 - l)) & 256); + const uint8_t signs = data_a[ib].signs[iqs / 2] >> (4 * (l & 1)); + const u8vec4 grid = unpack8(iq3s_grid[gidx]); + data_b[b_idx + 4 * l + 0] = D_TYPE(db * grid.x * ((signs & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 1] = D_TYPE(db * grid.y * ((signs & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 2] = D_TYPE(db * grid.z * ((signs & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 4 * l + 3] = D_TYPE(db * grid.w * ((signs & 8) != 0 ? -1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp new file mode 100644 index 0000000000..671c1f4a0d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq3_xxs.comp @@ -0,0 +1,51 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq3_xxs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 scale block (32 values) + // 8 threads handle 1 superblock + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint is = gl_LocalInvocationID.x % 8; + const uint b_idx = 256 * ib + 32 * is; + const uint s_idx = QUANT_K / 4 + 4 * is; + + const float d = float(data_a[ib].d); + uint signscale = pack32(u8vec4( + data_a[ib].qs[s_idx + 0], + data_a[ib].qs[s_idx + 1], + data_a[ib].qs[s_idx + 2], + data_a[ib].qs[s_idx + 3] + )); + const float db = d * (0.5 + (signscale >> 28)) * 0.5; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint sign7 = bitfieldExtract(signscale, 7 * int(l), 7); + // Restore parity bit. + const uint sign8 = sign7 | (bitCount(sign7) << 7); + const uint qs0 = data_a[ib].qs[8 * is + 2 * l]; + const uint qs1 = data_a[ib].qs[8 * is + 2 * l + 1]; + const u8vec4 grid0 = unpack8(iq3xxs_grid[qs0]); + const u8vec4 grid1 = unpack8(iq3xxs_grid[qs1]); + data_b[b_idx + 8 * l + 0] = D_TYPE(db * grid0.x * ((sign8 & 1) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 1] = D_TYPE(db * grid0.y * ((sign8 & 2) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 2] = D_TYPE(db * grid0.z * ((sign8 & 4) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 3] = D_TYPE(db * grid0.w * ((sign8 & 8) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 4] = D_TYPE(db * grid1.x * ((sign8 & 16) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 5] = D_TYPE(db * grid1.y * ((sign8 & 32) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 6] = D_TYPE(db * grid1.z * ((sign8 & 64) != 0 ? -1.0 : 1.0)); + data_b[b_idx + 8 * l + 7] = D_TYPE(db * grid1.w * ((sign8 & 128) != 0 ? 
-1.0 : 1.0)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp new file mode 100644 index 0000000000..8f7833eab2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_nl.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_nl data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp new file mode 100644 index 0000000000..a313699775 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_iq4_xs.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_iq4_xs data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + // Each thread handles 1 subblock (1 scale and 32 quantized values) + const uint ib = gl_WorkGroupID.x * 32 + gl_LocalInvocationID.x / 8; + + init_iq_shmem(gl_WorkGroupSize); + + if (ib >= p.nel / 256) { + return; + } + + const uint ib32 = gl_LocalInvocationID.x % 8; + + const float d = float(data_a[ib].d); + // Scales are 6 bits + const uint scale = ((data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF) + | (((data_a[ib].scales_h >> (2 * ib32)) & 3) << 4); + const float dl = d * (int(scale) - 32); + + const uint b_idx = 256 * ib + 32 * ib32; + const uint q_idx = 16 * ib32; + [[unroll]] for (uint l = 0; l < 16; ++l) { + data_b[b_idx + l + 0] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(dl * kvalues_iq4nl[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp new file mode 100644 index 0000000000..ffba5a77dd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_mxfp4.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_mxfp4 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + init_iq_shmem(gl_WorkGroupSize); + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = 
e8m0_to_fp32(data_a[ib].e); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] & 0xF]); + data_b[b_idx + l + 16] = D_TYPE(d * kvalues_mxfp4[data_a[ib].qs[q_idx + l] >> 4]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp new file mode 100644 index 0000000000..58dc2e5dfd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q2_k.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.nel / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 32 * ip + il; + const uint8_t qs = data_a[i].qs[32 * ip + il]; + + FLOAT_TYPE dall = FLOAT_TYPE(data_a[i].d.x); + FLOAT_TYPE dmin = FLOAT_TYPE(data_a[i].d.y); + data_b[y_idx + 0] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+0] & 0xF) * ((qs >> 0) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+0] >> 4)); + data_b[y_idx + 32] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+2] & 0xF) * ((qs >> 2) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+2] >> 4)); + data_b[y_idx + 64] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+4] & 0xF) * ((qs >> 4) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+4] >> 4)); + data_b[y_idx + 96] = D_TYPE(dall * FLOAT_TYPE((data_a[i].scales[is+6] & 0xF) * ((qs >> 6) & 3)) - dmin * FLOAT_TYPE(data_a[i].scales[is+6] >> 4)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp new file mode 100644 index 0000000000..0c90be8b4e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q3_k.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = uint(gl_WorkGroupID.x * 256 + wgy); + if (i >= p.nel / QUANT_K) { + return; + } + + const uint r = gl_LocalInvocationID.x / 4; + const uint tid = r / 2; + const uint is0 = r % 2; + const uint l0 = 16 * is0 + 4 * (gl_LocalInvocationID.x % 4); + const uint n = tid / 4; + const uint j = tid - 4*n; + + const uint8_t m = uint8_t(1 << (4*n + j)); + const uint is = 8*n + 2*j + is0; + const uint shift = 2*j; + + const int8_t us = int8_t(is < 4 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+8] >> 0) & 3) << 4) : + is < 8 ? (data_a[i].scales[is-0] & 0xF) | (((data_a[i].scales[is+4] >> 2) & 3) << 4) : + is < 12 ? 
(data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is+0] >> 4) & 3) << 4) : + (data_a[i].scales[is-8] >> 4) | (((data_a[i].scales[is-4] >> 6) & 3) << 4)); + const FLOAT_TYPE d_all = FLOAT_TYPE(data_a[i].d); + const FLOAT_TYPE dl = d_all * FLOAT_TYPE(us - 32); + + const uint y_idx = i * QUANT_K + 128 * n + 32 * j; + const uint qs_idx = 32*n; + + for (uint l = l0; l < l0 + 4; ++l) { + data_b[y_idx + l] = D_TYPE(dl * FLOAT_TYPE(int8_t((data_a[i].qs[qs_idx + l] >> shift) & 3) - (((data_a[i].hmask[l] & m) != 0) ? 0 : 4))); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp new file mode 100644 index 0000000000..b92b292135 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_0.comp @@ -0,0 +1,30 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint q_idx = 8*il; + const uint b_idx = 1024*i + 32*ir + q_idx; + + const float d = float(data_a[ib].d); + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] & 0xF) - 8.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * ((data_a[ib].qs[q_idx + l] >> 4) - 8.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp new file mode 100644 index 0000000000..6b63cbe583 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_1.comp @@ -0,0 +1,32 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q4_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + data_b[b_idx + l + 0] = D_TYPE(d * (data_a[ib].qs[q_idx + l] & 0xF) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (data_a[ib].qs[q_idx + l] >> 4) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp new file mode 100644 index 0000000000..8b7be557e9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q4_k.comp @@ -0,0 +1,68 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 32, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.nel / QUANT_K) { + 
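+            // ib is past the last superblock (p.nel / QUANT_K) in the tensor,
+            // so there is nothing left for this workgroup to dequantize.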
return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 8; + const uint ir = tid % 8; + const uint is = 2 * il; + const uint n = 4; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + n * ir; + const uint qs_idx = 32*il + n * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + [[unroll]] for (uint l = 0; l < n; ++l) { + data_b[y_idx + l ] = D_TYPE(d1 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] & 0xF) - m1); + data_b[y_idx + l + 32] = D_TYPE(d2 * FLOAT_TYPE(data_a[ib].qs[qs_idx + l] >> 4) - m2); + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp new file mode 100644 index 0000000000..f1b0bac872 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_0.comp @@ -0,0 +1,34 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const uint qh = uint(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10)) - 16.0f)); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10)) - 16.0f)); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp new file mode 100644 index 0000000000..c495b31f17 --- /dev/null +++ 
b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_1.comp @@ -0,0 +1,35 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q5_1 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 8*il; + + const float d = float(data_a[ib].d); + const float m = float(data_a[ib].m); + const uint qh = data_a[ib].qh; + + const uint q_idx = 8*il; + + [[unroll]] for (uint l = 0; l < 8; ++l) { + const uint iqs = q_idx + l; + const uint vui = uint(data_a[ib].qs[iqs]); + data_b[b_idx + l + 0] = D_TYPE(d * (((vui & 0xF) | (((qh >> iqs) << 4) & 0x10))) + m); + data_b[b_idx + l + 16] = D_TYPE(d * (((vui >> 4) | ((qh >> (iqs + 12)) & 0x10))) + m); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp new file mode 100644 index 0000000000..6bc04670fc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q5_k.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint ib = gl_WorkGroupID.x * 256 + wgy; + if (ib >= p.nel / QUANT_K) { + return; + } + + const uint tid = gl_LocalInvocationID.x; + const uint il = tid / 16; + const uint ir = tid % 16; + const uint is = 2 * il; + + const FLOAT_TYPE dall = FLOAT_TYPE(data_a[ib].d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(data_a[ib].d.y); + + const uint y_idx = ib * QUANT_K + 64 * il + 2 * ir; + const uint qs_idx = 32*il + 2 * ir; + const uint qh_idx = 2 * ir; + + uint scidx0 = (is < 4) ? is : (is + 4); + uint scidx1 = (is < 4) ? is : (is - 4); + uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint scidxshift1 = (is < 4) ? 0 : 2; + uint mbidx0 = is + 4; + uint mbidx1 = (is < 4) ? is + 4 : is; + uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + uint mbidxshift0 = (is < 4) ? 0 : 4; + uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + uint mbidxshift1 = (is < 4) ? 0 : 2; + + uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d1 = dall * sc; + const FLOAT_TYPE m1 = dmin * mbyte; + + scidx0 = (is < 4) ? is + 1 : (is + 5); + scidx1 = (is < 4) ? is + 1 : (is - 3); + scidxmask1 = (is < 4) ? 0x30 : 0xC0; + scidxshift1 = (is < 4) ? 0 : 2; + mbidx0 = is + 5; + mbidx1 = (is < 4) ? is + 5 : is + 1; + mbidxmask0 = (is < 4) ? 0xF : 0xF0; + mbidxshift0 = (is < 4) ? 0 : 4; + mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + mbidxshift1 = (is < 4) ? 
0 : 2; + + sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const FLOAT_TYPE d2 = dall * sc; + const FLOAT_TYPE m2 = dmin * mbyte; + + const uint8_t hm1 = uint8_t(1 << (2 * il )); + const uint8_t hm2 = uint8_t(1 << (2 * il + 1)); + data_b[y_idx ] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] & 0xF) + (((data_a[ib].qh[qh_idx ] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 1] = D_TYPE(d1 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] & 0xF) + (((data_a[ib].qh[qh_idx + 1] & hm1) != 0) ? 16 : 0)) - m1); + data_b[y_idx + 32] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx ] >> 4) + (((data_a[ib].qh[qh_idx ] & hm2) != 0) ? 16 : 0)) - m2); + data_b[y_idx + 33] = D_TYPE(d2 * FLOAT_TYPE((data_a[ib].qs[qs_idx + 1] >> 4) + (((data_a[ib].qh[qh_idx + 1] & hm2) != 0) ? 16 : 0)) - m2); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp new file mode 100644 index 0000000000..c8d6fcb49f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q6_k.comp @@ -0,0 +1,33 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 64, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + [[unroll]] for (uint wgy = 0; wgy < 256; wgy++) { + const uint i = gl_WorkGroupID.x * 256 + wgy; + if (i >= p.nel / QUANT_K) { + return; + } + const uint tid = gl_LocalInvocationID.x; + const uint ip = tid / 32; + const uint il = tid - 32 * ip; + const uint is = 8 * ip + il / 16; + + const uint y_idx = i * QUANT_K + 128 * ip + il; + + const uint ql_idx = 64 * ip + il; + const uint8_t qh = data_a[i].qh[32 * ip + il]; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[i].d); + + data_b[y_idx + 0] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 0] * (int8_t((data_a[i].ql[ql_idx + 0] & 0xF) | (((qh >> 0) & 3) << 4)) - 32))); + data_b[y_idx + 32] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 2] * (int8_t((data_a[i].ql[ql_idx + 32] & 0xF) | (((qh >> 2) & 3) << 4)) - 32))); + data_b[y_idx + 64] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 4] * (int8_t((data_a[i].ql[ql_idx + 0] >> 4) | (((qh >> 4) & 3) << 4)) - 32))); + data_b[y_idx + 96] = D_TYPE(d * FLOAT_TYPE(data_a[i].scales[is + 6] * (int8_t((data_a[i].ql[ql_idx + 32] >> 4) | (((qh >> 6) & 3) << 4)) - 32))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp new file mode 100644 index 0000000000..10844ddf78 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/dequant_q8_0.comp @@ -0,0 +1,31 @@ +#version 450 + +#include "dequant_head.glsl" + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {block_q8_0 data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_b[];}; + +void main() { + const uint i = gl_WorkGroupID.x * 4 + gl_LocalInvocationID.x / 64; + + const uint tid = gl_LocalInvocationID.x % 64; + const uint il = tid/32; + const uint ir = tid%32; + const uint ib = 32*i + ir; + if (ib >= p.nel / 32) { + return; + } + + const uint b_idx = 1024*i + 32*ir + 16*il; + + const float d = float(data_a[ib].d); + + const uint q_idx = 16*il; 
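+ // block_q8_0 stores one fp16 scale d followed by 32 signed 8-bit quants, so
+ // dequantization is simply d * qs[k]. Two threads share each block: the half
+ // selected by il writes the 16 consecutive outputs [q_idx, q_idx + 16).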
+ + [[unroll]] for (uint l = 0; l < 16; l += 2) { + data_b[b_idx + l ] = D_TYPE(d * data_a[ib].qs[q_idx + l ]); + data_b[b_idx + l + 1] = D_TYPE(d * data_a[ib].qs[q_idx + l + 1]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp new file mode 100644 index 0000000000..9cef8a8ec3 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/diag_mask_inf.comp @@ -0,0 +1,34 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint ncols; + uint rows_per_channel; + uint n_past; +} p; + +#include "types.glsl" + +layout(local_size_x = 1, local_size_y = 512, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint col = gl_GlobalInvocationID.y; + const uint row = gl_GlobalInvocationID.x; + + if (col >= p.ncols) { + return; + } + + const uint i = row*p.ncols + col; + if (col > p.n_past + row % p.rows_per_channel) { + data_d[i] = D_TYPE(uintBitsToFloat(0xFF800000)); + } else { + data_d[i] = D_TYPE(data_a[i]); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp new file mode 100644 index 0000000000..572472f8a9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/div.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) / FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp new file mode 100644 index 0000000000..b69d4ddb09 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/exp.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "rte.glsl" +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(exp(float(data_a[i]))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp new file mode 100644 index 0000000000..62acbf107a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn.comp @@ -0,0 +1,383 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : 
require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_shuffle : enable + +#include "types.glsl" +#include "flash_attn_base.glsl" + +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + +const uint32_t cols_per_iter = WorkGroupSize / D_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +shared FLOAT_TYPE tmpsh[WorkGroupSize]; +shared vec4 tmpshv4[WorkGroupSize]; + +shared float masksh[Bc][Br]; +shared vec4 Qf[Br][HSK / 4]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = gl_LocalInvocationIndex / D_split; + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && + i * Br + r < N) { + Qf[r][d] = vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d]) * p.scale; + } + } + barrier(); + + vec4 Of[Br][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = vec4(0.0); + } + } + + float Lf[Br], Mf[Br]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
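+ // 0xFEFFFFFF is the bit pattern of -FLT_MAX/2, so Moldf - Mf starts out as 0
+ // instead of the NaN that -inf - -inf would produce. Mf holds the running row
+ // maximum and Lf the running softmax denominator; Of is rescaled by
+ // exp(Mold - M) whenever the maximum grows (the usual online-softmax update).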
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + float slope[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = 1.0; + } + + // ALiBi + if (p.max_bias > 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + float Sf[Br][cols_per_thread]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = 0.0; + } + } + + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + [[unroll]] for (uint32_t d = 0; d < HSK_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * k_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 K_Tf = dequantize4(ib, iqs, k_offset, BINDING_IDX_K); +#else + vec4 K_Tf = vec4(data_kv4[k_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * k_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += dot(Qf[r][d * D_split + d_tid], K_Tf); + } + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + // Compute sum across the D_split + [[unroll]] for (uint s = D_split / 2; s > 0; s >>= 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Sf[r][c] += subgroupShuffleXor(Sf[r][c], s); + } + } + } + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + Sf[r][c] = p.logit_softcap * tanh(Sf[r][c]); + } + } + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + masksh[c][r] = float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)]); + } else { + masksh[c][r] = float(0); + } + } + } + barrier(); + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float mvf = masksh[c * cols_per_iter + col_tid][r]; + + Sf[r][c] += slope[r]*mvf; + } + } + barrier(); + } + + float rowmaxf[Br], Pf[Br][cols_per_thread], rowsumf[Br], eMf[Br], Moldf[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + rowmaxf[r] = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + rowmaxf[r] = max(rowmaxf[r], Sf[r][c]); + } + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + [[unroll]] for (uint32_t c = 0; c < 
cols_per_thread; ++c) { + Pf[r][c] = exp(Sf[r][c] - Mf[r]); + } + eMf[r] = exp(Moldf[r] - Mf[r]); + + // Compute sum across row of P + rowsumf[r] = 0.0; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + rowsumf[r] += Pf[r][c]; + } + + Lf[r] = eMf[r]*Lf[r] + rowsumf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] = eMf[r] * Of[r][d]; + } + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] += Pf[r][c] * Vf; + } + } + } + + barrier(); + } + + // reduce across threads + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float rowmaxf, eMf; + + tmpsh[tid] = Mf[r]; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = max(tmpsh[tid], tmpsh[tid + s]); + } + barrier(); + } + rowmaxf = tmpsh[d_tid]; + barrier(); + + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf = exp(Moldf - Mf[r]); + + Lf[r] = eMf*Lf[r]; + + tmpsh[tid] = Lf[r]; + + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + tmpsh[tid] = tmpsh[tid] + tmpsh[tid + s]; + } + barrier(); + } + Lf[r] = tmpsh[d_tid]; + barrier(); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + + Of[r][d] = eMf * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x) / 2; s >= D_split; s >>= 1) { + if (tid < s) { + Of[r][d] += tmpshv4[tid + s]; + tmpshv4[tid] = Of[r][d]; + } + barrier(); + } + Of[r][d] = tmpshv4[d_tid]; + barrier(); + } + } + + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
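+ // Scratch layout: k_num * ne3 partial O tiles of HSV * ne1 floats each, followed
+ // by per-split (L, M) row pairs of ne1 floats each. flash_attn_split_k_reduce.comp
+ // later rescales the partial O values by exp(M - M_max) and divides by the combined L.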
+ if (p.k_num > 1) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + perElemOpStoreCol0(r, 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(r, 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + float sink = perElemOpGetSink(r, 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ms; + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + + float Lfrcp[Br]; + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + Of[r][d] *= Lfrcp[r]; +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -vec4(ACC_TYPE_MAX), vec4(ACC_TYPE_MAX)); +#endif + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (r < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(r, 4*(d * D_split + d_tid) + comp, Of[r][d][comp], o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < Br; ++r) { + if (i * Br + r < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * HSV + (i * Br + r) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl new file mode 100644 index 0000000000..9b1f153bf7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_base.glsl @@ -0,0 +1,202 @@ + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (constant_id = 0) const uint32_t WorkGroupSize = 128; +layout (constant_id = 1) const uint32_t Br = 1; +layout (constant_id = 2) const uint32_t Bc = 32; +layout (constant_id = 3) const uint32_t HSK = 32; +layout (constant_id = 4) const uint32_t HSV = 32; +layout (constant_id = 5) const uint32_t Clamp = 0; +layout (constant_id = 6) const uint32_t D_split = 16; + +// Round up head sizes to a multiple of 16, for coopmat1/coopmat2 paths +const uint32_t HSK_pad = (HSK + 15) & ~15; +const uint32_t HSV_pad = (HSV + 15) & ~15; + +const bool KV_bounds_check = Clamp != 0; + +layout (push_constant) uniform parameter { + uint32_t N; + uint32_t KV; + + uint32_t ne1; + uint32_t ne2; + uint32_t ne3; + + uint32_t neq2; + uint32_t neq3; + uint32_t nek2; + uint32_t nek3; + uint32_t nev2; + uint32_t nev3; + uint32_t nem1; + uint32_t nem2; + uint32_t nem3; + + uint32_t nb01; + uint32_t nb02; + uint32_t nb03; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t 
nb21; + uint32_t nb22; + uint32_t nb23; + + float scale; + float max_bias; + float logit_softcap; + + uint32_t mask_n_head_log2; + float m0; + float m1; + + uint32_t gqa_ratio; + uint32_t split_kv; + uint32_t k_num; +} p; + +#define SINK_ENABLE_BIT (1<<24) +#define MASK_ENABLE_BIT (1<<16) +#define N_LOG2_MASK 0xFFFF + +layout (binding = 4) readonly buffer S {float data_s[];}; + +layout (binding = 5) writeonly buffer O {D_TYPE data_o[];}; + +#if defined(A_TYPE_PACKED16) +#define BINDING_IDX_K 0 +#define BINDING_IDX_V 1 +layout (binding = 1) readonly buffer K_PACKED16 {A_TYPE_PACKED16 k_data_packed16[];} k_packed; +layout (binding = 2) readonly buffer V_PACKED16 {A_TYPE_PACKED16 v_data_packed16[];} v_packed; +#endif + +#if defined(DATA_A_Q4_0) +#define BLOCK_BYTE_SIZE 18 + +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + uint vui_lo = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(k_packed.k_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(k_packed.k_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } else { + uint vui_lo = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 0]); + uint vui_hi = uint(v_packed.v_data_packed16[a_offset + ib].qs[(iqs & 0xF) / 2 + 1]); + uint shift = (iqs & 0x10) >> 2; + vui_lo >>= shift; + vui_hi >>= shift; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * (vec4(vui_lo & 0xF, (vui_lo >> 8) & 0xF, vui_hi & 0xF, (vui_hi >> 8) & 0xF) - 8.0f); + } +} +#endif + +#if defined(DATA_A_Q8_0) +#define BLOCK_BYTE_SIZE 34 +vec4 dequantize4(uint ib, uint iqs, uint a_offset, uint binding_idx) { + if (binding_idx == BINDING_IDX_K) { + const i8vec2 v0 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(k_packed.k_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(k_packed.k_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } else { + const i8vec2 v0 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(v_packed.v_data_packed16[a_offset + ib].qs[iqs / 2 + 1])).xy; + + return float(v_packed.v_data_packed16[a_offset + ib].d) * vec4(v0.x, v0.y, v1.x, v1.y); + } +} +#endif + +#define CEIL_DIV(a, b) (((a) + (b) - 1) / (b)) + + +// Store column zero. This is used to save per-row m and L values for split_k. +ACC_TYPE perElemOpStoreCol0(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c == 0) { + uint32_t offset = iq2 + r; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +// Load the slope matrix, indexed by Q's dimension 2. +ACC_TYPE perElemOpComputeSlope(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + uint32_t n_head_log2 = p.mask_n_head_log2 & N_LOG2_MASK; + + const ACC_TYPE base = ACC_TYPE(h < n_head_log2 ? p.m0 : p.m1); + const int exph = int(h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1); + + return ACC_TYPE(pow(base, ACC_TYPE(exph))); +} + +// Load the sink value, indexed by Q's dimension 2. 
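+// A sink acts as one extra logit per head that contributes no value vector: the main
+// shaders compare it with the running maximum M and fold exp(sink - M) (or a rescale
+// of O and L by exp(M - sink)) into the denominator before normalizing.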
+ACC_TYPE perElemOpGetSink(const in uint32_t r, const in uint32_t c, const in ACC_TYPE elem, const in uint32_t iq2) +{ + const uint32_t h = iq2 + (r % p.gqa_ratio); + + return ACC_TYPE(data_s[h]); +} + +uint32_t i, N, KV, split_k_index, Tr, start_j, end_j, + iq2, iq3, rk2, rk3, rv2, rv3, ik2, ik3, iv2, iv3, + q_stride, k_stride, v_stride, m_stride; + +void init_indices() +{ + N = p.N; + KV = p.KV; + + i = gl_WorkGroupID.x; + split_k_index = 0; + + if (p.k_num > 1) { + i = 0; + split_k_index = gl_WorkGroupID.x; + } + + Tr = CEIL_DIV(N, Br); + + start_j = split_k_index * p.split_kv / Bc; + end_j = CEIL_DIV(min(KV, (split_k_index + 1) * p.split_kv), Bc); + + // When not using grouped query attention, all rows share the same iq2, equal to gl_WorkGroupID.y. + // When using grouped query attention, each workgroup does gqa_ratio consecutive values of iq2. + iq2 = gl_WorkGroupID.y * p.gqa_ratio; + iq3 = gl_WorkGroupID.z; + + // broadcast factors + rk2 = p.neq2/p.nek2; + rk3 = p.neq3/p.nek3; + + rv2 = p.neq2/p.nev2; + rv3 = p.neq3/p.nev3; + + // k indices + ik3 = iq3 / rk3; + ik2 = iq2 / rk2; + + // v indices + iv3 = iq3 / rv3; + iv2 = iq2 / rv2; + + // nb?1 are already divided by the type size and are in units of elements. + // When using grouped query attention, Q is indexed by iq2, so the stride + // should be nb02 (which is in bytes). + q_stride = p.gqa_ratio > 1 ? (p.nb02 / 4) : p.nb01; + k_stride = p.nb11; + v_stride = p.nb21; + // When using grouped query attention, all rows use the same mask (stride 0). + // "p.gqa_ratio >> 16" is just a roundabout way of writing zero + // that prevents the compiler from folding the "&" through the select + // and breaking the alignment detection. + m_stride = (p.gqa_ratio > 1) ? (p.gqa_ratio >> 16) : KV; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp new file mode 100644 index 0000000000..2066a05b34 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm1.comp @@ -0,0 +1,418 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable + +#include "types.glsl" +#include "flash_attn_base.glsl" + +const uint32_t HSK_per_thread = HSK / D_split; +const uint32_t HSV_per_thread = HSV / D_split; + +const uint32_t row_split = 4; +const uint32_t rows_per_thread = Br / row_split; +const uint32_t cols_per_iter = gl_WorkGroupSize.x / D_split / row_split; +const uint32_t cols_per_thread = Bc / cols_per_iter; + + +layout (binding = 0) readonly buffer Q {float data_q[];}; +layout (binding = 0) readonly buffer QV4 {vec4 data_qv4[];}; +layout (binding = 1) readonly buffer K {float16_t data_k[];}; +layout (binding = 1) readonly buffer KV4 {f16vec4 data_kv4[];}; +layout (binding = 2) readonly buffer V {float16_t data_v[];}; +layout (binding = 2) readonly buffer VV4 {f16vec4 data_vv4[];}; +layout (binding = 3) readonly buffer M {float16_t data_m[];}; + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. 
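+// With grouped query attention (gqa_ratio > 1) the Br rows of a workgroup cover
+// gqa_ratio consecutive query heads sharing one K/V head, so the store is indexed
+// by head (iq2 + r) rather than by token position.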
+D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + return elem; +} + +// These need to be supported N,M values for a MatBc x MatBr x 16 coopmatmuladd +const uint32_t MatBr = 16; +const uint32_t MatBc = 16; + +shared FLOAT_TYPE tmpsh[gl_WorkGroupSize.x]; +shared ACC_TYPEV4 tmpshv4[gl_WorkGroupSize.x]; + +const uint32_t qstride = HSK_pad / 4 + 2; // in units of f16vec4 +shared f16vec4 Qf[Br * qstride]; + +// Avoid padding for hsk==256 to make it fit in 48KB shmem. +const uint32_t sfshstride = (HSK <= 128) ? (Br + 8) : Br; +shared ACC_TYPE sfsh[Bc * sfshstride]; + +const uint32_t kshstride = HSK_pad / 4 + 2; // in units of f16vec4 +shared f16vec4 ksh[Bc * kshstride]; + +shared float slope[Br]; + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + const uint32_t tid = gl_LocalInvocationIndex; + + const uint32_t threads_per_rowgroup = gl_WorkGroupSize.x / row_split; + const uint32_t row_tid = gl_LocalInvocationIndex / threads_per_rowgroup; + const uint32_t d_tid = gl_LocalInvocationIndex % D_split; + const uint32_t col_tid = (gl_LocalInvocationIndex % threads_per_rowgroup) / D_split; + +#define tile_row(r) (row_tid * rows_per_thread + (r)) + + // Zero-initialize shared memory for Q/K when HSK is not a multiple of 16 (HSK_pad > HSK). + if ((HSK % 16) != 0) { + [[unroll]] for (uint i = 0; i < Br * qstride; i += gl_WorkGroupSize.x) { + if (i + tid < Br * qstride) { + Qf[i + tid] = f16vec4(0); + } + } + [[unroll]] for (uint i = 0; i < Bc * kshstride; i += gl_WorkGroupSize.x) { + if (i + tid < Bc * kshstride) { + ksh[i + tid] = f16vec4(0); + } + } + barrier(); + } + + uint32_t q_offset = (iq2*p.nb02+iq3*p.nb03) / 4; + + [[unroll]] for (uint32_t idx = 0; idx < Br * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t r = (idx + tid) / (HSK / 4); + if (r < Br && d < HSK / 4 && + i * Br + r < N) { + Qf[r * qstride + d] = f16vec4(data_qv4[q_offset / 4 + (i * Br + r) * q_stride / 4 + d] * p.scale); + } + } + barrier(); + + ACC_TYPEV4 Of[rows_per_thread][HSV_per_thread / 4]; + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPEV4(0.0); + } + } + + float Lf[rows_per_thread], Mf[rows_per_thread]; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = 0; + Mf[r] = NEG_FLT_MAX_OVER_2; + } + + // ALiBi + if (p.max_bias > 0.0f) { + if (tid < Br) { + uint r = tid; + slope[r] = perElemOpComputeSlope(r, col_tid, ACC_TYPE(0), iq2); + } + barrier(); + } else { + if (tid < Br) { + uint r = tid; + slope[r] = 1.0; + } + barrier(); + } + +#if BLOCK_SIZE > 1 + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / BLOCK_BYTE_SIZE; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / BLOCK_BYTE_SIZE; +#else + uint32_t k_offset = (ik2*p.nb12 + ik3*p.nb13) / 2; + uint32_t v_offset = (iv2*p.nb22 + iv3*p.nb23) / 2; +#endif + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + [[unroll]] for (uint32_t idx = 0; idx < Bc * HSK / 4; idx += gl_WorkGroupSize.x) { + uint32_t d = (idx + tid) % (HSK / 4); + uint32_t c = (idx + tid) / (HSK / 4); + if (c < Bc && d < HSK / 4) { + f16vec4 K_Tf = f16vec4(0); + if (!KV_bounds_check || j * Bc + c < KV) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c) * k_stride * BLOCK_SIZE + 4 * d; + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + K_Tf = f16vec4(dequantize4(ib, iqs, k_offset, BINDING_IDX_K)); +#else + K_Tf = f16vec4(data_kv4[k_offset / 4 + (j * Bc + c) * k_stride / 4 + d]); +#endif + } + + ksh[c * kshstride + d] = K_Tf; + } + } + barrier(); + + // K * Q^T -> S^T: Bc x HSK_pad * HSK_pad x Br -> Bc x Br + // Bc split across workgroup (four subgroups), loop over HSK in chunks of 16: 16 x 16 * 16 x 16 -> 16 x 16 + // This is written transposed in order to allow for N being 8 if implementations need it + coopmat SfMat = coopmat(0); + coopmat KMat; + coopmat QMat; + + for (uint32_t d = 0; d < HSK_pad / 16; ++d) { + coopMatLoad(QMat, Qf, d * 16 / 4, qstride, gl_CooperativeMatrixLayoutColumnMajor); + + uint coord = (gl_SubgroupID * MatBc) * kshstride + d * 16 / 4; + coopMatLoad(KMat, ksh, coord, kshstride, gl_CooperativeMatrixLayoutRowMajor); + + SfMat = coopMatMulAdd(KMat, QMat, SfMat); + } + + uint coord = gl_SubgroupID * MatBc * sfshstride; + coopMatStore(SfMat, sfsh, coord, sfshstride, gl_CooperativeMatrixLayoutRowMajor); + barrier(); + + if (p.logit_softcap != 0.0f) { + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) / Br; + uint32_t r = (idx + tid) % Br; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + sfsh[c * sfshstride + r] = ACC_TYPE(p.logit_softcap * tanh(sfsh[c * sfshstride + r])); + } + } + barrier(); + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + [[unroll]] for (uint32_t idx = 0; idx < Bc * Br; idx += gl_WorkGroupSize.x) { + uint32_t c = (idx + tid) % Bc; + uint32_t r = (idx + tid) / Bc; + if (idx + tid < Bc * Br || idx + gl_WorkGroupSize.x <= Bc * Br) { + if ((!KV_bounds_check || j * Bc + c < KV) && (!nem1_bounds_check || i * Br + r < p.nem1)) { + sfsh[c * sfshstride + r] += ACC_TYPE(slope[r] * float(data_m[m_offset + (i * Br + r) * m_stride + (j * Bc + c)])); + } + } + } + barrier(); + } + + float eMf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float rowmaxf = NEG_FLT_MAX_OVER_2; + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + 
continue; + } + rowmaxf = max(rowmaxf, float(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride])); + } + float Moldf = Mf[r]; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf, Moldf); + eMf[r] = exp(Moldf - Mf[r]); + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + } + } + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t c = 0; c < cols_per_thread; ++c) { + if (KV_bounds_check && j * Bc + c * cols_per_iter + col_tid >= KV) { + continue; + } + float Pf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Pf[r] = exp(sfsh[tile_row(r) + (c * cols_per_iter + col_tid) * sfshstride] - Mf[r]); + Lf[r] += Pf[r]; + } + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { +#if BLOCK_SIZE > 1 + uint coord = (j * Bc + c * cols_per_iter + col_tid) * v_stride * BLOCK_SIZE + 4 * (d * D_split + d_tid); + uint ib = coord / BLOCK_SIZE; + uint iqs = (coord % BLOCK_SIZE); + vec4 Vf = dequantize4(ib, iqs, v_offset, BINDING_IDX_V); +#else + vec4 Vf = vec4(data_vv4[v_offset / 4 + (j * Bc + c * cols_per_iter + col_tid) * v_stride / 4 + d * D_split + d_tid]); +#endif + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] += ACC_TYPE(Pf[r]) * ACC_TYPEV4(Vf); + } + } + } + + barrier(); + } + + // reduce across threads + + float rowmaxf[rows_per_thread], eMf[rows_per_thread], Moldf[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE M = Mf[r]; + tmpsh[tid] = M; + // Compute max across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + M = max(M, tmpsh[tid ^ s]); + barrier(); + tmpsh[tid] = M; + barrier(); + } + rowmaxf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Moldf[r] = Mf[r]; + + // M = max(rowmax, Mold) + // eM = e^(Mold - M) + Mf[r] = max(rowmaxf[r], Moldf[r]); + eMf[r] = exp(Moldf[r] - Mf[r]); + + Lf[r] = eMf[r]*Lf[r]; + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + FLOAT_TYPE L = Lf[r]; + tmpsh[tid] = L; + // Compute sum across the row + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + L += tmpsh[tid ^ s]; + barrier(); + tmpsh[tid] = L; + barrier(); + } + Lf[r] = tmpsh[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + + Of[r][d] = ACC_TYPE(eMf[r]) * Of[r][d]; + tmpshv4[tid] = Of[r][d]; + + barrier(); + [[unroll]] for (int s = int(gl_WorkGroupSize.x / row_split) / 2; s >= D_split; s >>= 1) { + Of[r][d] += tmpshv4[tid ^ s]; + barrier(); + tmpshv4[tid] = Of[r][d]; + barrier(); + } + Of[r][d] = tmpshv4[d_tid + row_tid * threads_per_rowgroup]; + barrier(); + } + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. 
+ if (p.k_num > 1) { + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Lf[r]), o_offset, iq2, N); + perElemOpStoreCol0(tile_row(r), 0u, ACC_TYPE(Mf[r]), o_offset + p.ne1, iq2, N); + } + } + + return; + } + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + float sink = perElemOpGetSink(tile_row(r), 0u, ACC_TYPE(0), iq2); + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > Mf[r]) { + ms = exp(Mf[r] - sink); + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + Of[r][d] *= ACC_TYPE(ms); + } + } else { + vs = exp(sink - Mf[r]); + } + + Lf[r] = Lf[r]*ms + vs; + } + } + + float Lfrcp[rows_per_thread]; + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Lfrcp[r] = 1.0 / Lf[r]; + } + + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + Of[r][d] *= ACC_TYPE(Lfrcp[r]); +#if defined(ACC_TYPE_MAX) + Of[r][d] = clamp(Of[r][d], -ACC_TYPE_MAX, ACC_TYPE_MAX); +#endif + } + } + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + if (p.gqa_ratio > 1) { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + perElemOpGqaStore(tile_row(r), 4*(d * D_split + d_tid) + comp, float(Of[r][d][comp]), o_offset, iq2, N); + } + } + } + } + } else { + [[unroll]] for (uint32_t r = 0; r < rows_per_thread; ++r) { + if (i * Br + tile_row(r) < N) { + [[unroll]] for (uint32_t d = 0; d < HSV_per_thread / 4; ++d) { + [[unroll]] for (uint32_t comp = 0; comp < 4; ++comp) { + data_o[o_offset + iq2 * HSV + (i * Br + tile_row(r)) * p.ne1 * HSV + 4*(d * D_split + d_tid) + comp] = D_TYPE(Of[r][d][comp]); + } + } + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp new file mode 100644 index 0000000000..910da1ab0c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_cm2.comp @@ -0,0 +1,320 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#extension GL_EXT_null_initializer : enable + +#include "types.glsl" +#include "dequant_funcs_cm2.glsl" +#include "flash_attn_base.glsl" + 
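+// The coopmat2 path binds Q, K, V and the mask as raw bytes and loads them through
+// tensorLayoutNV/coopMatLoadTensorNV. When BLOCK_SIZE is defined, quantized K/V
+// blocks are decoded on load by DEQUANTFUNC from dequant_funcs_cm2.glsl.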
+layout (binding = 0) readonly buffer Q {uint8_t data_q[];}; +layout (binding = 1) readonly buffer K {uint8_t data_k[];}; +layout (binding = 2) readonly buffer V {uint8_t data_v[];}; +layout (binding = 3) readonly buffer M {uint8_t data_m[];}; + +ACC_TYPE maxReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return max(x, y); +} + +ACC_TYPE smearReduce(const in ACC_TYPE x, const in ACC_TYPE y) { + return x; +} + +// Replace matrix elements >= numRows or numCols with 'replace' +ACC_TYPE replacePadding(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem, const in ACC_TYPE replace, const in uint32_t numRows, const in uint32_t numCols) { + if (row >= numRows || col >= numCols) { + return replace; + } + return elem; +} + +ACC_TYPE Exp(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem) +{ + return exp(elem); +} + +ACC_TYPE Max(const in uint32_t row, const in uint32_t col, const in ACC_TYPE elem0, const in ACC_TYPE elem1) +{ + return max(elem0, elem1); +} + +#if defined(BLOCK_SIZE) +#define DECODEFUNC , DEQUANTFUNC +#else +#define DECODEFUNC +#endif + +// Store the output when doing grouped query attention. +// Rows index by Q's dimension 2, and the first N rows are valid. +D_TYPE perElemOpGqaStore(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t o_offset, const in uint32_t iq2, const in uint32_t N) +{ + if (r < N && c < HSV) { + uint32_t offset = (iq2 + r) * HSV + c; + data_o[o_offset + offset] = D_TYPE(elem); + } + return elem; +} + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + init_indices(); + + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutQ = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, Clamp> tensorLayoutK = createTensorLayoutNV(2, Clamp); + tensorLayoutNV<2, Clamp> tensorLayoutV = createTensorLayoutNV(2, Clamp); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if defined(BLOCK_SIZE) + tensorLayoutK = setTensorLayoutBlockSizeNV(tensorLayoutK, 1, BLOCK_SIZE); + tensorLayoutV = setTensorLayoutBlockSizeNV(tensorLayoutV, 1, BLOCK_SIZE); +#endif + + tensorLayoutQ = setTensorLayoutDimensionNV(tensorLayoutQ, N, HSK); + tensorLayoutK = setTensorLayoutDimensionNV(tensorLayoutK, KV, HSK); + tensorLayoutV = setTensorLayoutDimensionNV(tensorLayoutV, KV, HSV); + + // hint to the compiler that strides are aligned for the aligned variant of the shader + if (Clamp != gl_CooperativeMatrixClampModeConstantNV) + { + q_stride &= ~7; +#if !defined(BLOCK_SIZE) + k_stride &= ~7; + v_stride &= ~7; +#endif + m_stride &= ~7; + } + tensorLayoutQ = setTensorLayoutStrideNV(tensorLayoutQ, q_stride, 1); + tensorLayoutK = setTensorLayoutStrideNV(tensorLayoutK, k_stride, 1); + tensorLayoutV = setTensorLayoutStrideNV(tensorLayoutV, v_stride, 1); + + coopmat Q; + coopmat Qf16; + + uint32_t q_offset = iq2*p.nb02+iq3*p.nb03; + coopMatLoadTensorNV(Q, data_q, q_offset, sliceTensorLayoutNV(tensorLayoutQ, i * Br, Br, 0, HSK_pad)); + + Qf16 = coopmat(Q); + Qf16 *= float16_t(p.scale); + + coopmat O = coopmat(0); + + coopmat L, M; + + // Use -FLT_MAX/2 rather than -inf to reduce the possibility of NaNs, e.g. when computing Mold-M. 
+ const float NEG_FLT_MAX_OVER_2 = uintBitsToFloat(0xFEFFFFFF); + + L = coopmat(0); + M = coopmat(NEG_FLT_MAX_OVER_2); + + coopmat slopeMat = coopmat(1.0); + + // ALiBi + if (p.max_bias > 0.0f) { + coopMatPerElementNV(slopeMat, slopeMat, perElemOpComputeSlope, iq2); + } + + uint32_t m_offset = 0; + if (p.nem2 != 1 || p.nem3 != 1) { + m_offset = ((iq3 % p.nem3) * p.nem2 + (iq2 % p.nem2)) * p.nem1 * KV * 2 /*sizeof(float16_t)*/; + } + + [[dont_unroll]] + for (uint32_t j = start_j; j < end_j; ++j) { + + coopmat S = coopmat(0); + + coopmat K_T; + + uint32_t k_offset = ik2*p.nb12 + ik3*p.nb13; + coopMatLoadTensorNV(K_T, data_k, k_offset, sliceTensorLayoutNV(tensorLayoutK, j * Bc, Bc, 0, HSK_pad), tensorViewTranspose DECODEFUNC); + S = coopMatMulAdd(Qf16, K_T, S); + + if (p.logit_softcap != 0.0f) { + [[unroll]] + for (int k = 0; k < S.length(); ++k) { + S[k] = ACC_TYPE(p.logit_softcap)*tanh(S[k]); + } + } + + if ((p.mask_n_head_log2 & MASK_ENABLE_BIT) != 0) { + bool nem1_bounds_check = !(p.gqa_ratio > 1) && (p.nem1 % Br) != 0; + + if (nem1_bounds_check) { + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutM = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, p.nem1, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } else { + tensorLayoutNV<2, Clamp> tensorLayoutM = createTensorLayoutNV(2, Clamp); + // Don't clamp against nem1 when GQA is enabled + uint32_t m_height = p.gqa_ratio > 1 ? ~0 : p.nem1; + tensorLayoutM = setTensorLayoutDimensionNV(tensorLayoutM, m_height, KV); + tensorLayoutM = setTensorLayoutStrideNV(tensorLayoutM, m_stride, 1); + + coopmat mv; + + coopMatLoadTensorNV(mv, data_m, m_offset, sliceTensorLayoutNV(tensorLayoutM, i * Br, Br, j * Bc, Bc)); + + S += slopeMat*coopmat(mv); + } + } + + // Clear padding elements to -inf, so they don't contribute to rowmax + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(S, S, replacePadding, ACC_TYPE(NEG_FLT_MAX_OVER_2), R, C); + } + + coopmat rowmax, P, rowsum, eM; + + coopMatReduceNV(rowmax, S, gl_CooperativeMatrixReduceRowNV, maxReduce); + + coopmat Mold = M; + + // M = max(rowmax, Mold) + // P = e^(S - M) + // eM = e^(Mold - M) + coopMatPerElementNV(M, rowmax, Max, Mold); + coopMatPerElementNV(P, S - M, Exp); + coopMatPerElementNV(eM, Mold - M, Exp); + + // Clear padding elements to 0, so they don't contribute to rowsum + if (Clamp != 0 && + ((j + 1) * Bc > KV || + (i + 1) * Br > N)) { + + uint R = ((i + 1) * Br > N) ? (N % Br) : Br; + uint C = ((j + 1) * Bc > KV) ? (KV % Bc) : Bc; + + coopMatPerElementNV(P, P, replacePadding, ACC_TYPE(0.0), R, C); + } + + coopmat P_A = coopmat(P); + + // compute rowsum by multiplying by matrix of all ones. 
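+ // P (Br x Bc) times an all-ones matrix leaves each row's sum replicated across
+ // the columns of the result, which is how a row reduction is expressed with
+ // coopMatMulAdd; L accumulates that replicated running denominator.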
+ coopmat One = coopmat(1.0); + + rowsum = coopmat(0.0); + rowsum = coopMatMulAdd(P_A, One, rowsum); + + coopmat V; + uint32_t v_offset = iv2*p.nb22 + iv3*p.nb23; + coopMatLoadTensorNV(V, data_v, v_offset, sliceTensorLayoutNV(tensorLayoutV, j * Bc, Bc, 0, HSV_pad) DECODEFUNC); + + L = eM*L + rowsum; + + // This is the "diagonal" matrix in the paper, but since we do componentwise + // multiply rather than matrix multiply it has the diagonal element smeared + // across the row + coopmat eMdiag; + + // resize eM by using smear/reduce + coopMatReduceNV(eMdiag, eM, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // multiply with fp16 accumulation, then add to O. + coopmat PV = coopmat(0); + PV = coopMatMulAdd(P_A, V, PV); + + O = eMdiag * O + coopmat(PV); + } + + // If there is split_k, then the split_k resolve shader does the final + // division by L. Store the intermediate O value and per-row m and L values. + if (p.k_num > 1) { + coopmat O_D = coopmat(O); + + uint32_t o_offset = HSV * p.ne1 * (split_k_index + iq3 * p.k_num); + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + + o_offset = HSV * p.ne1 * p.ne3 * p.k_num + p.ne1 * (split_k_index + iq3 * p.k_num) * 2; + coopMatPerElementNV(L, L, perElemOpStoreCol0, o_offset, iq2, N); + coopMatPerElementNV(M, M, perElemOpStoreCol0, o_offset + p.ne1, iq2, N); + return; + } + + coopmat Ldiag; + + // resize L by using smear/reduce + coopMatReduceNV(Ldiag, L, gl_CooperativeMatrixReduceRowNV, smearReduce); + + if ((p.mask_n_head_log2 & SINK_ENABLE_BIT) != 0) { + coopmat S; + coopMatPerElementNV(S, S, perElemOpGetSink, iq2); + + coopmat Mr; + + // resize M by using smear/reduce + coopMatReduceNV(Mr, M, gl_CooperativeMatrixReduceRowNV, smearReduce); + + // O, Ldiag, Mr all have the same type so all element locations match + [[unroll]] for (uint32_t i = 0; i < Ldiag.length(); ++i) { + ACC_TYPE sink = S[i]; + + ACC_TYPE ms = ACC_TYPE(1.0f); + ACC_TYPE vs = ACC_TYPE(1.0f); + + if (sink > Mr[i]) { + ms = exp(Mr[i] - sink); + + O[i] *= ms; + } else { + vs = exp(sink - Mr[i]); + } + + Ldiag[i] = Ldiag[i]*ms + vs; + } + } + + [[unroll]] + for (int k = 0; k < Ldiag.length(); ++k) { + Ldiag[k] = ACC_TYPE(1.0) / Ldiag[k]; + } + + O = Ldiag*O; + +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < O.length(); ++i) { O[i] = clamp(O[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + uint32_t o_offset = iq3*p.ne2*p.ne1*HSV; + + coopmat O_D = coopmat(O); + if (p.gqa_ratio > 1) { + coopMatPerElementNV(O_D, O_D, perElemOpGqaStore, o_offset, iq2, N); + } else { + tensorLayoutNV<3, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(3, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.ne2, p.ne1, HSV); + + // permute dimensions + tensorViewNV<3, false, 1, 0, 2> tensorViewPermute = createTensorViewNV(3, false, 1, 0, 2); + + coopMatStoreTensorNV(O_D, data_o, o_offset, sliceTensorLayoutNV(tensorLayoutD, i * Br, Br, iq2, N, 0, HSV_pad), tensorViewPermute); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp new file mode 100644 index 0000000000..06e83822fe --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/flash_attn_split_k_reduce.comp @@ -0,0 +1,120 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, 
local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 1) readonly buffer B {float data_s[];}; +layout (binding = 2) writeonly buffer D {float data_d[];}; + +layout (push_constant) uniform parameter { + uint D; + uint N; + uint ne3; + uint k_num; + uint sinks; +} p; + +shared float tmpsh[BLOCK_SIZE]; + +void main() { + // Each workgroup handles a row + const uint n = gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + const uint iq3 = gl_WorkGroupID.z; + + uint D = p.D; + uint N = p.N; + uint k_num = p.k_num; + + uint l_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + n; + uint m_offset = D * N * p.ne3 * k_num + N * iq3 * k_num * 2 + N + n; + uint lm_stride = N * 2; + + // Compute the max m value for the row + float m_max = -1.0/0.0; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float m = data_a[m_offset + (k + tid) * lm_stride]; + m_max = max(m_max, m); + } + + // reduce across the workgroup + tmpsh[tid] = m_max; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + m_max = max(m_max, tmpsh[tid + s]); + tmpsh[tid] = m_max; + } + barrier(); + } + m_max = tmpsh[0]; + + barrier(); + + // Compute L based on m_max + float L = 0; + for (uint k = 0; k + tid < k_num; k += BLOCK_SIZE) { + float l = data_a[l_offset + (k + tid) * lm_stride]; + float m = data_a[m_offset + (k + tid) * lm_stride]; + L += exp(m - m_max) * l; + } + + // reduce across the workgroup + tmpsh[tid] = L; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + L += tmpsh[tid + s]; + tmpsh[tid] = L; + } + barrier(); + } + L = tmpsh[0]; + + float sink; + if (p.sinks != 0) { + sink = data_s[n]; + + float ms = 1.0f; + float vs = 1.0f; + + if (sink > m_max) { + ms = exp(m_max - sink); + } else { + vs = exp(sink - m_max); + } + + L = L*ms + vs; + } + + L = 1.0 / L; + + // D dimension is split across workgroups in the y dimension + uint d = tid + gl_WorkGroupID.y * BLOCK_SIZE; + // Scale and sum the O contributions based on m_max and store the result to memory + if (d < D) { + float O = 0.0; + [[unroll]] for (uint k = 0; k < k_num; ++k) { + uint o_offset = D * N * (k + iq3 * k_num) + D * n + d; + float m = data_a[m_offset + k * lm_stride]; + O += exp(m - m_max) * data_a[o_offset]; + } + if (p.sinks != 0) { + if (sink > m_max) { + float ms = 1.0f; + ms = exp(m_max - sink); + O *= ms; + } + } + O *= L; + + const float FLT_MAX = uintBitsToFloat(0x7F7FFFFF); + O = clamp(O, -FLT_MAX, FLT_MAX); + + data_d[iq3 * D * N + D * n + d] = O; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp new file mode 100644 index 0000000000..e017b50368 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu.comp @@ -0,0 +1,13 @@ +#version 450 + +#include "glu_head.glsl" + +const float GELU_COEF_A = 0.044715f; +const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + +float op(float a, float b) { + const float val = SQRT_2_OVER_PI*a*(1.0f + GELU_COEF_A*a*a); + return 0.5f*a*(2.0f - 2.0f / (exp(2 * val) + 1)) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp new file mode 100644 index 0000000000..759a1848fa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_erf.comp @@ -0,0 +1,27 @@ +#version 450 + +#include 
"glu_head.glsl" + +// based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation +// ref: https://www.johndcook.com/blog/python_erf/ +const float p_erf = 0.3275911f; +const float a1_erf = 0.254829592f; +const float a2_erf = -0.284496736f; +const float a3_erf = 1.421413741f; +const float a4_erf = -1.453152027f; +const float a5_erf = 1.061405429f; + +const float SQRT_2_INV = 0.70710678118654752440084436210484f; + +float op(float a, float b) { + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + return 0.5f * a * (1.0f + erf_approx) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp new file mode 100644 index 0000000000..c4032ab21d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/geglu_quick.comp @@ -0,0 +1,11 @@ +#version 450 + +#include "glu_head.glsl" + +const float GELU_QUICK_COEF = -1.702f; + +float op(float a, float b) { + return a * (1.0f / (1.0f + exp(GELU_QUICK_COEF * a))) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp new file mode 100644 index 0000000000..a95c2525c8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu.comp @@ -0,0 +1,25 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_COEF_A = 0.044715f; + const float SQRT_2_OVER_PI = 0.79788456080286535587989211986876f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + const float val = SQRT_2_OVER_PI*xi*(1.0f + GELU_COEF_A*xi*xi); + data_d[i] = D_TYPE(0.5f*xi*(2.0f - 2.0f / (exp(2 * val) + 1))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp new file mode 100644 index 0000000000..58375aba09 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_erf.comp @@ -0,0 +1,39 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + // based on Abramowitz and Stegun formula 7.1.26 or similar Hastings' approximation + // ref: https://www.johndcook.com/blog/python_erf/ + const float p_erf = 0.3275911f; + const float a1_erf = 0.254829592f; + const float a2_erf = -0.284496736f; + const float a3_erf = 1.421413741f; + const float a4_erf = -1.453152027f; + const float a5_erf = 1.061405429f; + + const float SQRT_2_INV = 0.70710678118654752440084436210484f; + const uint i = gl_GlobalInvocationID.z * 262144 + 
gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float a = float(data_a[i]); + const float a_div_sqr2 = a * SQRT_2_INV; + const float sign_x = sign(a_div_sqr2); + const float x = abs(a_div_sqr2); + const float t = 1.0f / (1.0f + p_erf * x); + const float y = 1.0f - (((((a5_erf * t + a4_erf) * t) + a3_erf) * t + a2_erf) * t + a1_erf) * t * exp(-x * x); + const float erf_approx = sign_x * y; + + data_d[i] = D_TYPE(0.5f * a * (1.0f + erf_approx)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp new file mode 100644 index 0000000000..bfdfe2182d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/gelu_quick.comp @@ -0,0 +1,23 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const float GELU_QUICK_COEF = -1.702f; + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * (1.0f / (1.0f + exp(GELU_QUICK_COEF * x)))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl new file mode 100644 index 0000000000..99595fc688 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_binary_head.glsl @@ -0,0 +1,51 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +#include "rte.glsl" +#include "utils.glsl" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint ne20; uint ne21; uint ne22; uint ne23; uint nb20; uint nb21; uint nb22; uint nb23; + uint misalign_offsets; + float param1; float param2; int param3; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +// true if src0/src1 are the same shape and the indices can be reused without additional modulus +layout(constant_id = 0) const bool norepeat = false; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_boffset() { return (p.misalign_offsets >> 8) & 0xFF; } +uint get_doffset() { return p.misalign_offsets & 0xFF; } + + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03) { + get_indices(idx, i00, i01, i02, i03, p.ne00, p.ne01, p.ne02, p.ne03); +} + +uint src0_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint src1_idx(uint i00, uint i01, uint i02, uint i03) { + if (norepeat) { + return i03*p.nb13 + i02*p.nb12 + i01*p.nb11 + i00*p.nb10; + } else { + return fastmod(i03, p.ne13)*p.nb13 + fastmod(i02, p.ne12)*p.nb12 + fastmod(i01, p.ne11)*p.nb11 + fastmod(i00, p.ne10)*p.nb10; + } +} + +uint 
dst_idx(uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb23 + i02*p.nb22 + i01*p.nb21 + i00*p.nb20; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl new file mode 100644 index 0000000000..66e46ae679 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_head.glsl @@ -0,0 +1,9 @@ +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + float param1; + float param2; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl new file mode 100644 index 0000000000..8dc9d360d5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/generic_unary_head.glsl @@ -0,0 +1,76 @@ +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + float param1; float param2; + + uint ne0_012mp; uint ne0_012L; + uint ne0_01mp; uint ne0_01L; + uint ne0_0mp; uint ne0_0L; + uint ne1_012mp; uint ne1_012L; + uint ne1_01mp; uint ne1_01L; + uint ne1_0mp; uint ne1_0L; +} p; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + +uint src0_idx(uint idx) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; +} + +uint dst_idx(uint idx) { + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; +} + +uint src0_idx_quant(uint idx, uint qk) { + const uint i03 = fastdiv(idx, p.ne0_012mp, p.ne0_012L); + const uint i03_offset = i03 * p.ne02*p.ne01*p.ne00; + const uint i02 = fastdiv(idx - i03_offset, p.ne0_01mp, p.ne0_01L); + const uint i02_offset = i02*p.ne01*p.ne00; + const uint i01 = fastdiv(idx - i03_offset - i02_offset, p.ne0_0mp, p.ne0_0L); + const uint i00 = idx - i03_offset - i02_offset - i01*p.ne00; + return i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + (i00/qk)*p.nb00; +} + +uint dst_idx_quant(uint idx, uint qk) { + const uint i13 = fastdiv(idx, p.ne1_012mp, 
p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + (i10/qk)*p.nb10; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp new file mode 100644 index 0000000000..76d83041ce --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = gl_GlobalInvocationID.x; + + if (i00 >= p.ne00) { + return; + } + + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; + + const uint i01 = data_b[get_boffset() + i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = get_aoffset() + i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = get_doffset() + i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + +#if defined(DATA_A_BF16) + FLOAT_TYPE v = FLOAT_TYPE(bf16_to_fp32(data_a[a_offset + i00])); +#else + FLOAT_TYPE v = FLOAT_TYPE(data_a[a_offset + i00]); +#endif +#ifndef OPTIMIZATION_ERROR_WORKAROUND + data_d[d_offset + i00] = D_TYPE(v); +#else + data_d[d_offset + i00] = D_TYPE(v); +#endif + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp new file mode 100644 index 0000000000..9dba437edb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/get_rows_quant.comp @@ -0,0 +1,51 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "types.glsl" +#include "generic_binary_head.glsl" +#include "dequant_funcs.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint i00 = (gl_GlobalInvocationID.x)*2; + +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + if (i00 >= p.ne00) { + return; + } + + uint gid_z = gl_GlobalInvocationID.z; + while (gid_z < p.ne11 * p.ne12) { + uint gid_y = gl_GlobalInvocationID.y; + while (gid_y < p.ne10) { + const uint i10 = gid_y; + const uint i11 = gid_z / p.ne12; + const uint i12 = gid_z % p.ne12; + + const uint i01 = data_b[i10*p.nb10 + i11*p.nb11 + i12*p.nb12]; + + const uint a_offset = i01*p.nb01 + i11*p.nb02 + i12*p.nb03; + const uint d_offset = i10*p.nb21 + i11*p.nb22 + i12*p.nb23; + + const uint ib = a_offset + i00/QUANT_K; // block index + const uint iqs = (i00%QUANT_K)/QUANT_R; // quant index + const uint iybs = i00 - i00%QUANT_K; // dst block start index + const uint y_offset = QUANT_R == 1 ? 
1 : QUANT_K/2; + + vec2 v = dequantize(ib, iqs, 0); + const vec2 dm = get_dm(ib, 0); + v = v * dm.x + dm.y; + + data_d[d_offset + iybs + iqs ] = D_TYPE(v.x); + data_d[d_offset + iybs + iqs + y_offset] = D_TYPE(v.y); + + gid_y += gl_WorkGroupSize.y * gl_NumWorkGroups.y; + } + gid_z += gl_WorkGroupSize.z * gl_NumWorkGroups.z; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl new file mode 100644 index 0000000000..2168989340 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_head.glsl @@ -0,0 +1,19 @@ +#extension GL_EXT_shader_16bit_storage : require + +#include "rte.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {A_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter +{ + uint N; + uint ne00; + uint ne20; + uint mode; + float alpha; + float limit; +} p; diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl new file mode 100644 index 0000000000..85cf65a9ec --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/glu_main.glsl @@ -0,0 +1,29 @@ +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.N) { + return; + } + + const uint row = i / p.ne20; + const uint col = i - row * p.ne20; + + if (p.mode == 0) { + // Default + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx]), float(data_a[idx + offset]))); + } else if (p.mode == 1) { + // Swapped + const uint offset = p.ne00 / 2; + const uint idx = row * p.ne00 + col; + + data_d[row * offset + col] = D_TYPE(op(float(data_a[idx + offset]), float(data_a[idx]))); + } else { + // Split + const uint idx = row * p.ne00 + col; + + data_d[idx] = D_TYPE(op(float(data_a[idx]), float(data_b[idx]))); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp new file mode 100644 index 0000000000..bdf97dbb5d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/group_norm.comp @@ -0,0 +1,66 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared float tmp[BLOCK_SIZE]; + +void main() { + const uint group_size = p.KX; + const float eps = p.param1; + + const uint tid = gl_LocalInvocationID.x; + const uint start = gl_WorkGroupID.x * group_size + tid; + const uint end = (gl_WorkGroupID.x + 1) * group_size; + + tmp[tid] = 0.0f; + + // Calculate mean + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + tmp[tid] += float(data_a[col]); + } + + // tmp up partial tmps and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float mean = tmp[0] / group_size; + barrier(); + tmp[tid] = 0.0f; + + // Calculate variance + [[unroll]] for (uint 
col = start; col < end; col += BLOCK_SIZE) { + const float xi = float(data_a[col]) - mean; + data_d[col] = D_TYPE(xi); + tmp[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + const float variance = tmp[0] / group_size; + const float scale = inversesqrt(variance + eps); + + [[unroll]] for (uint col = start; col < end; col += BLOCK_SIZE) { + data_d[col] *= D_TYPE(scale); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp new file mode 100644 index 0000000000..b4dbdf3141 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardsigmoid.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp new file mode 100644 index 0000000000..1ec315915e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/hardswish.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float x = float(data_a[i]); + data_d[i] = D_TYPE(x * min(1.0f, max(0.0f, (x + 3.0f) / 6.0f))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp new file mode 100644 index 0000000000..1827d647a2 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col.comp @@ -0,0 +1,103 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require + +#include "rte.glsl" +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + BDA_STORAGE_T dst_addr; + uint batch_offset; uint offset_delta; + uint IC; + uint IW; uint IH; + uint OW; uint OH; + uint KW; uint KH; + uint pelements; + uint CHW; + int s0; int s1; + int p0; int p1; + int d0; int d1; +} p; + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +const uint NUM_ITER = 512 / BLOCK_SIZE; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + +void main() { + const uint gidx = gl_GlobalInvocationID.x; + + const uint oh = gl_GlobalInvocationID.y; + const uint batch = 
gl_GlobalInvocationID.z / p.IC; + const uint ic = gl_GlobalInvocationID.z % p.IC; + + const uint src_base = ic * p.offset_delta + batch * p.batch_offset; + const BDA_OFFSET_T dst_base = ((BDA_OFFSET_T(batch) * p.OH + oh) * p.OW) * p.CHW + BDA_OFFSET_T(ic) * (p.KW * p.KH); + const int oh_s1 = int(oh) * p.s1; + const uint ksize = p.OW * p.KH; + + const uint base_linear_idx = gidx * NUM_ITER; + + uint current_kx = base_linear_idx / ksize; + const uint rem = base_linear_idx - (current_kx * ksize); + uint current_ky = rem / p.OW; + uint current_ix = rem % p.OW; + + A_TYPE values[NUM_ITER]; + BDA_OFFSET_T offset_dst[NUM_ITER]; + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + values[idx] = A_TYPE(0); + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + + const uint iiw = current_ix * p.s0 + current_kx * p.d0 - p.p0; + const uint iih = oh_s1 + current_ky * p.d1 - p.p1; + + offset_dst[idx] = dst_base + BDA_OFFSET_T(current_ix) * p.CHW + current_ky * p.KW + current_kx; + + if ((iih < p.IH) && (iiw < p.IW)) { + values[idx] = data_a[src_base + iih * p.IW + iiw]; + } + + if (++current_ix == p.OW) { + current_ix = 0; + if (++current_ky == p.KH) { + current_ky = 0; + current_kx++; + } + } + } + + [[unroll]] for (uint idx = 0; idx < NUM_ITER; ++idx) { + + const uint linear_idx = base_linear_idx + idx; + + if (linear_idx >= p.pelements) { + continue; + } + +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst[idx]); + dst_addr.d = D_TYPE(values[idx]); +#else + data_d[offset_dst[idx]] = D_TYPE(values[idx]); +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp new file mode 100644 index 0000000000..4bf8b4ca04 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/im2col_3d.comp @@ -0,0 +1,125 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "rte.glsl" +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + BDA_STORAGE_T dst_addr; + uint32_t nb10; + uint32_t nb11; + uint32_t nb12; + uint32_t nb13; + uint32_t s0; + uint32_t s1; + uint32_t s2; + uint32_t p0; + uint32_t p1; + uint32_t p2; + uint32_t d0; + uint32_t d1; + uint32_t d2; + uint32_t IW; + uint32_t IH; + uint32_t ID; + uint32_t IC; + uint32_t KW; + uint32_t OH; + uint32_t KD_KH_KW; + uint32_t KH_KW; + uint32_t IC_KD_KH_KW; + uint32_t N_OD_OH; + uint32_t OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW; + uint32_t misalign_offsets; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +#if BDA +layout (buffer_reference) buffer D_ptr {D_TYPE d;}; +#endif + +void main() { + const uint32_t i = gl_GlobalInvocationID.x; + + uint32_t nb10 = p.nb10; + uint32_t nb11 = p.nb11; + uint32_t nb12 = p.nb12; + uint32_t nb13 = p.nb13; + uint32_t s0 = p.s0; + uint32_t s1 = p.s1; + uint32_t s2 = p.s2; + uint32_t p0 = p.p0; + uint32_t p1 = p.p1; + uint32_t p2 = p.p2; + uint32_t d0 = 
p.d0; + uint32_t d1 = p.d1; + uint32_t d2 = p.d2; + uint32_t IW = p.IW; + uint32_t IH = p.IH; + uint32_t ID = p.ID; + uint32_t IC = p.IC; + uint32_t KW = p.KW; + uint32_t OH = p.OH; + uint32_t KD_KH_KW = p.KD_KH_KW; + uint32_t KH_KW = p.KH_KW; + uint32_t IC_KD_KH_KW = p.IC_KD_KH_KW; + uint32_t N_OD_OH = p.N_OD_OH; + uint32_t OD_OH = p.OD_OH; + uint32_t OD_OH_OW_IC_KD_KH_KW = p.OD_OH_OW_IC_KD_KH_KW; + uint32_t OH_OW_IC_KD_KH_KW = p.OH_OW_IC_KD_KH_KW; + uint32_t OW_IC_KD_KH_KW = p.OW_IC_KD_KH_KW; + + if (i >= IC_KD_KH_KW) { + return; + } + + const uint32_t iic = i / KD_KH_KW; + const uint32_t ikd = (i - iic * KD_KH_KW) / KH_KW; + const uint32_t ikh = (i - iic * KD_KH_KW - ikd * KH_KW) / KW; + const uint32_t ikw = i % KW; + + const uint32_t iow = gl_GlobalInvocationID.y; + for (uint32_t iz = gl_GlobalInvocationID.z; iz < N_OD_OH; iz += gl_NumWorkGroups.z) { + const uint32_t in_ = iz / OD_OH; + const uint32_t iod = (iz - in_*OD_OH) / OH; + const uint32_t ioh = iz % OH; + + const uint32_t iiw = iow * s0 + ikw * d0 - p0; + const uint32_t iih = ioh * s1 + ikh * d1 - p1; + const uint32_t iid = iod * s2 + ikd * d2 - p2; + + const BDA_OFFSET_T offset_dst = BDA_OFFSET_T(in_)*OD_OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(iod)*OH_OW_IC_KD_KH_KW + BDA_OFFSET_T(ioh)*OW_IC_KD_KH_KW + BDA_OFFSET_T(iow)*IC_KD_KH_KW + iic*KD_KH_KW + ikd * KH_KW + ikh*KW + ikw; + + const uint32_t offset_src = (in_*IC + iic)*nb13 + iid*nb12 + iih*nb11 + iiw*nb10; +#if BDA + D_ptr dst_addr = D_ptr(p.dst_addr + D_SIZE * offset_dst); + if (iih >= IH || iiw >= IW || iid >= ID) { + dst_addr.d = D_TYPE(0.0f); + } else { + dst_addr.d = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#else + if (iih >= IH || iiw >= IW || iid >= ID) { + data_d[offset_dst + get_doffset()] = D_TYPE(0.0f); + } else { + data_d[offset_dst + get_doffset()] = D_TYPE(data_a[offset_src + get_aoffset()]); + } +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp new file mode 100644 index 0000000000..83ef2f8795 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/l2_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[row*p.KX + col]); + sum[tid] += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE scale = inversesqrt(max(sum[0], FLOAT_TYPE(p.param1))); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale * FLOAT_TYPE(data_a[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp new file mode 100644 index 0000000000..b281e855cb --- 
/dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/leaky_relu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float val = float(data_a[i]); + data_d[i] = D_TYPE(max(val, 0.0f) + min(val, 0.0f) * p.param1); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp new file mode 100644 index 0000000000..02ef1eace1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul.comp @@ -0,0 +1,27 @@ +#version 450 + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) * FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp new file mode 100644 index 0000000000..4c64fd47af --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_split_k_reduce.comp @@ -0,0 +1,48 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 256, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {float data_a[];}; +layout (binding = 0) readonly buffer A4 {vec4 data_a4[];}; +layout (binding = 1) writeonly buffer D {float data_d[];}; +layout (binding = 1) writeonly buffer D4 {vec4 data_d4[];}; + +layout (push_constant) uniform parameter { + uint ne; + uint k_num; +} p; + +void main() { + // Each invocation handles four consecutive components + const uint idx = gl_GlobalInvocationID.x * 4; + + if (idx >= p.ne) { + return; + } + + // Check if all four components are in bounds and aligned, + // then use vector loads + if (idx + 3 < p.ne && (p.ne % 4) == 0) { + vec4 result = vec4(0.0f); + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a4[(i * p.ne + idx) / 4]; + } + + data_d4[idx / 4] = result; + } else { + [[unroll]] for (uint j = 0; j < 4; ++j) { + if (idx + j < p.ne) { + float result = 0.0f; + + [[unroll]] for (uint i = 0; i < p.k_num; i++) { + result += data_a[i * p.ne + idx + j]; + } + + data_d[idx + j] = result; + } + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp new file mode 100644 index 0000000000..9a03925cfd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec.comp @@ -0,0 +1,169 @@ +#version 450 + +#extension 
GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#if !defined(DATA_A_F32) && !defined(DATA_A_F16) && !defined(DATA_A_BF16) +#define K_PER_ITER 8 +#else +#define K_PER_ITER 2 +#endif + + +uint a_offset, b_offset, d_offset, y_offset; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i, bool lastiter) +{ + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + K_PER_ITER*tid; + const uint iqs = (col%QUANT_K)/QUANT_R; // quant index + const uint iybs = col - col%QUANT_K; // y block start index + +#if K_PER_ITER == 8 +#if QUANT_R == 2 + const vec4 bv02 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv13 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs + y_offset) / 4]); + const vec4 bv0 = vec4(bv02.x, bv13.x, bv02.y, bv13.y); + const vec4 bv1 = vec4(bv02.z, bv13.z, bv02.w, bv13.w); +#else + const vec4 bv0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4]); + const vec4 bv1 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + iybs + iqs) / 4 + 1]); +#endif +#else + // Check if the second of the pair of elements is OOB, and don't fetch B or + // accumulate it. We still fetch a pair of elements for A, which is fine for + // quantized formats since they'll be within the same block. We should + // probably skip fetching the second element for F16/F32, but as of now we + // still do. + const bool OOB = lastiter && (iybs + iqs + y_offset >= p.ncols); + + FLOAT_TYPE b0 = 0, b1 = 0; + b0 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs]); + if (!OOB) { + b1 = FLOAT_TYPE(data_b[j*p.batch_stride_b + b_offset + iybs + iqs + y_offset]); + } +#endif + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib = (ibi + col)/QUANT_K; // block index + ibi += p.ncols; + +#if K_PER_ITER == 8 + vec4 v = dequantize4(ib, iqs, a_offset); + vec4 v2 = dequantize4(ib, iqs+(4/QUANT_R), a_offset); + + const vec2 dm = get_dm(ib, a_offset); + if (dm.y != 0) { // quant has min component + v = v * dm.x + dm.y; + v2 = v2 * dm.x + dm.y; + } + + // matrix multiplication + FLOAT_TYPE rowtmp = dot(bv0, v); + rowtmp += dot(bv1, v2); + + if (dm.y == 0) + rowtmp *= dm.x; + + temp[j][n] += rowtmp; +#else + const vec2 v = dequantize(ib, iqs, a_offset); + + // matrix multiplication + temp[j][n] = fma(FLOAT_TYPE(v.x), b0, temp[j][n]); + if (!OOB) { + temp[j][n] = fma(FLOAT_TYPE(v.y), b1, temp[j][n]); + } +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + + y_offset = QUANT_R == 1 ? 1 : QUANT_K/2; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + // If the K dimension is odd, we need lastiter==true on the last iteration + // so OOB is computed correctly. Skip some unrolling to make that happen. 
+ if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, false); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER, true); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl new file mode 100644 index 0000000000..450dee0408 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_base.glsl @@ -0,0 +1,182 @@ +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_8bit_storage : require + +#if USE_SUBGROUP_ADD || USE_SUBGROUP_ADD_NO_SHMEM +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_arithmetic : require +#endif + +#ifdef MUL_MAT_ID +#define EXPERT_COUNT 8 +#endif + +#include "types.glsl" + +#ifndef MMQ +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#else +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#endif + +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +#ifdef B_TYPE_VEC2 +layout (binding = 1) readonly buffer BV2 {B_TYPE_VEC2 data_b_v2[];}; +#endif +#ifdef B_TYPE_VEC4 +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; +#endif + +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +#include "dequant_funcs.glsl" + +layout (push_constant) uniform parameter +{ + uint ncols; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint ne11; +#else + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +void get_offsets(out uint a_offset, out uint b_offset, out uint d_offset) { +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.y; +#else + const uint batch_idx = gl_GlobalInvocationID.y; +#endif + +#ifndef MUL_MAT_ID + uint batch_idx_a = 0; + if (batch_idx != 0) { + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + batch_idx_a = i03 * p.ne02 + i02; + } +#else + const uint expert_id 
= data_ids[expert_idx]; +#endif + + a_offset = +#ifdef MUL_MAT_ID + expert_id * p.batch_stride_a; +#else + batch_idx_a * p.batch_stride_a; +#endif + b_offset = +#ifdef MUL_MAT_ID + (expert_idx % p.ne11) * p.stride_b; +#else + batch_idx * p.batch_stride_b; +#endif + d_offset = +#ifdef MUL_MAT_ID + expert_idx * p.stride_d; +#else + batch_idx * p.batch_stride_d; +#endif +} + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; +layout (constant_id = 1) const uint NUM_ROWS = 1; +layout (constant_id = 2) const uint NUM_COLS = 1; + +#ifdef USE_SUBGROUP_ADD_NO_SHMEM +void reduce_result(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +} +#else +shared FLOAT_TYPE tmpsh[NUM_COLS][NUM_ROWS][BLOCK_SIZE]; + +void reduce_result(FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const in uint32_t d_offset, const in uint32_t first_row, const in uint32_t num_rows, const in uint32_t tid) { + // subgroupAdd is probably faster on devices that support it, + // particularly when the workgroup has more than one subgroup +#if USE_SUBGROUP_ADD + // sum up partial sums within a subgroup + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = subgroupAdd(temp[j][n]); + } + } + + // Go through shared memory to sum partials across subgroups + if (gl_SubgroupInvocationID == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][gl_SubgroupID] = temp[j][n]; + } + } + } + barrier(); + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < gl_NumSubgroups; ++s) { + temp[j][n] += tmpsh[j][n][s]; + } + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(temp[j][n]); + } + } + } +#else + // sum up partial sums and write back result + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] = temp[j][n]; + } + } + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE/2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + tmpsh[j][n][tid] += tmpsh[j][n][tid + s]; + } + } + } + barrier(); + } + if (tid == 0) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + data_d[j*p.batch_stride_d + d_offset + first_row + n] = D_TYPE(tmpsh[j][n][0]); + } + } + } +#endif +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp new file mode 100644 index 0000000000..4cb292380c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_m.comp @@ -0,0 +1,82 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void 
calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint16_t[4] scales = data_a[ibi].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + + const uint sc = data_a[ibi].scales[ib32 / 2] >> (6 * (ib32 & 1)); + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qh = data_a[ibi].qh[2 * ib32 + l / 2] >> (4 * (l&1)); + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const float delta = ((qh & 8) != 0) ? -IQ1M_DELTA : IQ1M_DELTA; + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(l / 2), 3) + 1); + + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp new file mode 100644 index 0000000000..0b74b33212 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq1_s.comp @@ -0,0 +1,79 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + 
first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint qh = data_a[ibi].qh[ib32]; + const float dl = d * float(2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + + [[unroll]] for (uint l = 0; l < 4; ++l) { + const uint qs = data_a[ibi].qs[4 * ib32 + l]; + const uint idxhi = bitfieldExtract(qh, 3 * int(l), 3); + const int16_t grid = int16_t(iq1s_grid[qs | (idxhi << 8)]); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int k = 0; k < 4; ++k) { + sum = fma(FLOAT_TYPE(b0[k]), bitfieldExtract(grid, 2 * k, 2) + delta, + fma(FLOAT_TYPE(b4[k]), bitfieldExtract(grid, 8 + 2 * k, 2) + delta, sum)); + } + temp[j][n] = fma(dl, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp new file mode 100644 index 0000000000..e424af12c5 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + const uint qh = data_a[ibi].qh[ib32]; + const u8vec2 qs16 = unpack8(uint32_t(data_a_packed16[ibi].qs[itid])).xy; // vec4 used due to #12147 
+ const u8vec2 sign16 = unpack8(uint32_t(data_a_packed16[ibi].qs[QUANT_K / 16 + itid])).xy; + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint8_t sign = sign16[l]; + const uint qs = qs16[l] | ((qh << (8 - nibble_shift - 2 * l)) & 0x300); + const uvec2 grid = iq2s_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp new file mode 100644 index 0000000000..0cd906dbbf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint nibble_shift = 4 * (itid & 1); + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = 
float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32] >> nibble_shift) & 0xF; + const float db = d * (0.5 + scale) * 0.25; + + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[2 * itid + l]; + const uint sign = qs >> 9; + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xs_grid[qs & 511].x)); + const vec4 grid1 = vec4(unpack8(iq2xs_grid[qs & 511].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp new file mode 100644 index 0000000000..71bd72d17e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq2_xxs.comp @@ -0,0 +1,87 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = 
float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[4 * ib32 + 2], + data_a_packed16[ibi].qs[4 * ib32 + 3])); + const float db = d * 0.25 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs = data_a[ibi].qs[8 * ib32 + 2 * (itid & 1) + l]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq2xxs_grid[qs].x)); + const vec4 grid1 = vec4(unpack8(iq2xxs_grid[qs].y)); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? -grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp new file mode 100644 index 0000000000..a4b9ab1f94 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_s.comp @@ -0,0 +1,90 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ib32, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 32 * ib32; + + uint ibi = a_offset / QUANT_K + first_row * 
num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint scale = (data_a[ibi].scales[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const float dscale = d * (1 + 2 * scale); + const uint qh = data_a[ibi].qh[ib32]; + FLOAT_TYPE sum[NUM_COLS]; + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + sum[j] = 0.0; + } + [[unroll]] for (uint l = 0; l < 4; ++l) { + const u8vec2 qs = unpack8(uint32_t(data_a_packed16[ibi].qs[4 * ib32 + l])).xy; // vec4 used due to #12147 + const uint sign = data_a[ibi].signs[4 * ib32 + l]; + const vec4 grid0 = vec4(unpack8(iq3s_grid[qs.x | ((qh << (8 - 2*l)) & 0x100)])); + const vec4 grid1 = vec4(unpack8(iq3s_grid[qs.y | ((qh << (7 - 2*l)) & 0x100)])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + sum[j] = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign & 128) != 0 ? -grid1.w : grid1.w), + sum[j])))))))); + } + } + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + temp[j][n] = fma(dscale, sum[j], temp[j][n]); + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 8 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/8; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 8; // 0...7 + const uint ix = tid / 8; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp new file mode 100644 index 0000000000..40849c691f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_iq3_xxs.comp @@ -0,0 +1,88 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void 
calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y_idx = i * QUANT_K + 16 * itid; + const uint ib32 = itid / 2; // 0..7 + + uint ibi = a_offset / QUANT_K + first_row * num_blocks_per_row + i; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const float d = float(data_a[ibi].d); + const uint signscale = pack32(u16vec2( + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32], + data_a_packed16[ibi].qs[QUANT_K / 8 + 2 * ib32 + 1])); + const float db = d * 0.5 * (0.5 + (signscale >> 28)); + [[unroll]] for (uint l = 0; l < 2; ++l) { + const uint qs0 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l]; + const uint qs1 = data_a[ibi].qs[8 * ib32 + 4 * (itid & 1) + 2 * l + 1]; + const uint sign = bitfieldExtract(signscale, 7 * int(2 * (itid & 1) + l), 7); + const uint sign7 = bitCount(sign); + const vec4 grid0 = vec4(unpack8(iq3xxs_grid[qs0])); + const vec4 grid1 = vec4(unpack8(iq3xxs_grid[qs1])); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const vec4 b0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 0]); + const vec4 b4 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 2*l + 1]); + + FLOAT_TYPE sum = + fma(FLOAT_TYPE(b0.x), FLOAT_TYPE((sign & 1) != 0 ? -grid0.x : grid0.x), + fma(FLOAT_TYPE(b0.y), FLOAT_TYPE((sign & 2) != 0 ? -grid0.y : grid0.y), + fma(FLOAT_TYPE(b0.z), FLOAT_TYPE((sign & 4) != 0 ? -grid0.z : grid0.z), + fma(FLOAT_TYPE(b0.w), FLOAT_TYPE((sign & 8) != 0 ? -grid0.w : grid0.w), + fma(FLOAT_TYPE(b4.x), FLOAT_TYPE((sign & 16) != 0 ? -grid1.x : grid1.x), + fma(FLOAT_TYPE(b4.y), FLOAT_TYPE((sign & 32) != 0 ? -grid1.y : grid1.y), + fma(FLOAT_TYPE(b4.z), FLOAT_TYPE((sign & 64) != 0 ? -grid1.z : grid1.z), + fma(FLOAT_TYPE(b4.w), FLOAT_TYPE((sign7 & 1) != 0 ? 
-grid1.w : grid1.w), + FLOAT_TYPE(0.0))))))))); + temp[j][n] = fma(db, sum, temp[j][n]); + } + } + ibi += num_blocks_per_row; + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint blocks_per_wg = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid % 16; // 0...15 + const uint ix = tid / 16; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += blocks_per_wg) + calc_superblock(a_offset, b_offset, itid, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + init_iq_shmem(gl_WorkGroupSize); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp new file mode 100644 index 0000000000..638878d94c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_nc.comp @@ -0,0 +1,122 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#define BLOCK_SIZE 32 +#define FLOAT_TYPE float + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint row_stride_x; + uint channel_stride_x; + uint channel_stride_y; + uint channel_x_divisor; + uint ne12; + uint b_offset; + uint d_offset; + uint nb03; + uint nb13; + uint nb23; +} p; + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + const uint channel = gl_GlobalInvocationID.z; + const uint i3 = gl_WorkGroupID.x; + const uint channel_x = channel / p.channel_x_divisor; + const uint channel_y = channel % p.ne12; + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + const uint idst = i3*p.nb23 + channel*nrows_dst + row_dst; + + FLOAT_TYPE temp = 0.0f; + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.row_stride_x % 4) == 0 && (p.channel_stride_x % 4) == 0; + + for (uint col_x0 = 0; col_x0 < p.ncols_x;) { + + // Unroll 2x and do vec4 loads if aligned + const uint unroll_count = 2; + if (col_x0 + unroll_count * 4 * BLOCK_SIZE <= p.ncols_x && is_aligned) { + [[unroll]] for (uint i = 0; i < unroll_count; ++i) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = i3*p.nb03 + 
channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } + // do vec4 loads if aligned + } else if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + const uint col_x = col_x0 + 4*tid; + + const uint row_y = col_x; + + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; + + const vec4 av4 = vec4(data_a_v4[ix / 4]); + const vec4 bv4 = vec4(data_b_v4[iy / 4]); + + temp += dot(av4, bv4); + + col_x0 += 4*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + if (col_x >= p.ncols_x) { + break; + } + + const uint row_y = col_x; + + const uint ix = i3*p.nb03 + channel_x*p.channel_stride_x + row_x*p.row_stride_x + col_x; + const uint iy = i3*p.nb13 + channel_y*p.channel_stride_y + row_y; + + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + temp = fma(xi, FLOAT_TYPE(data_b[iy]), temp); + col_x0 += BLOCK_SIZE; + } + } + + tmp[tid] = temp; + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + tmp[tid] += tmp[tid + s]; + } + barrier(); + } + + if (tid == 0) { + dst[idst] = tmp[0]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp new file mode 100644 index 0000000000..7aa070eebd --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_p021.comp @@ -0,0 +1,154 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#if USE_SUBGROUP_ADD +#extension GL_KHR_shader_subgroup_arithmetic : enable +#endif + +#define FLOAT_TYPE float + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE dst[];}; + +layout (binding = 0) readonly buffer AV4 {A_TYPE_VEC4 data_a_v4[];}; +layout (binding = 1) readonly buffer BV4 {B_TYPE_VEC4 data_b_v4[];}; + +layout(constant_id = 0) const int BLOCK_SIZE = 32; +// gqa_ratio is in the range [1,8] +layout(constant_id = 1) const uint gqa_ratio = 1; + +layout (push_constant) uniform parameter +{ + uint ncols_x; + uint nrows_x; + uint nchannels_x; + uint nchannels_y; + uint b_offset; + uint d_offset; +} p; + +#if !USE_SUBGROUP_ADD +shared FLOAT_TYPE tmp[8][BLOCK_SIZE]; +#endif + +void main() { + const uint tid = gl_LocalInvocationID.x; + const uint row_x = gl_GlobalInvocationID.y; + + uint channel, channel_x; + + // When gqa_ratio > 1, each invocation does multiple rows. + // The row in the A matrix is starting from channel / gqa_ratio and the + // rows in the B matrix are [channel, channel+gqa_ratio). + // When gqa_ratio is 1, each invocation does one row. 
+ if (gqa_ratio > 1) { + channel_x = gl_GlobalInvocationID.z; + channel = channel_x * gqa_ratio; + } else { + channel = gl_GlobalInvocationID.z; + channel_x = channel / (p.nchannels_y / p.nchannels_x);; + } + + const uint nrows_y = p.ncols_x; + const uint nrows_dst = p.nrows_x; + const uint row_dst = row_x; + + FLOAT_TYPE temp[8]; + [[unroll]] for (uint i = 0; i < 8; ++i) { + temp[i] = FLOAT_TYPE(0.0f); + } + + // Detect alignment for vector loads + bool is_aligned = (p.ncols_x % 4) == 0 && (p.nchannels_x % 4) == 0 && (nrows_y % 4) == 0; + + for (uint col_x0 = 0; col_x0 < p.ncols_x; col_x0 += BLOCK_SIZE) { + + // Use vec4 loads if aligned + if (col_x0 + 4*BLOCK_SIZE <= p.ncols_x && is_aligned) { + + uint col_x = col_x0 + 4*tid; + const uint row_y = col_x; + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const vec4 av4 = vec4(data_a_v4[ix / 4]); + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + vec4 bv4 = data_b_v4[iy / 4]; + temp[c] += dot(av4, bv4); + } + + col_x0 += 3*BLOCK_SIZE; + } else { + const uint col_x = col_x0 + tid; + + if (col_x >= p.ncols_x) { + break; + } + + // x is transposed and permuted + const uint ix = row_x*p.nchannels_x*p.ncols_x + channel_x*p.ncols_x + col_x; + const FLOAT_TYPE xi = FLOAT_TYPE(data_a[ix]); + + const uint row_y = col_x; + + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // y is not transposed but permuted + const uint iy = (channel + c)*nrows_y + row_y; + + temp[c] = fma(xi, FLOAT_TYPE(data_b[iy]), temp[c]); + } + } + } + +#if USE_SUBGROUP_ADD + // reduce vec4 at a time + vec4 t = vec4(temp[0], temp[1], temp[2], temp[3]); + t = subgroupAdd(t); + temp[0] = t[0]; + temp[1] = t[1]; + temp[2] = t[2]; + temp[3] = t[3]; + if (gqa_ratio > 4) { + t = vec4(temp[4], temp[5], temp[6], temp[7]); + t = subgroupAdd(t); + temp[4] = t[0]; + temp[5] = t[1]; + temp[6] = t[2]; + temp[7] = t[3]; + } +#else + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + tmp[c][tid] = temp[c]; + } + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] += tmp[c][tid + s]; + tmp[c][tid] = temp[c]; + } + } + barrier(); + } + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + temp[c] = tmp[c][tid]; + } +#endif + + if (tid == 0) { + [[unroll]] for (uint c = 0; c < gqa_ratio; ++c) { + // dst is not transposed and not permuted + const uint idst = (channel + c)*nrows_dst + row_dst; + dst[idst] = temp[c]; + } + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp new file mode 100644 index 0000000000..03ed25d3bf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q2_k.comp @@ -0,0 +1,130 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache1[2][BLOCK_SIZE/16][16]; +shared FLOAT_TYPE sccache2[2][BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint v_im, const uint ix, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const 
uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + } + barrier(); + + if (i >= num_blocks_per_row) + continue; + } else { + const uint32_t scale = uint32_t(data_a[ib0 + i].scales[itid]); + sccache1[csel][ix][itid] = FLOAT_TYPE(scale & 0xF); + sccache2[csel][ix][itid] = FLOAT_TYPE((scale >> 4) & 0xF); + barrier(); + } + + const uint32_t qs_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum1 = FLOAT_TYPE(0.0); + FLOAT_TYPE sum2 = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum1 = fma(FLOAT_TYPE(b0[l]), sccache1[csel][ix][ 8*v_im] * qs_u32_0[l ], + fma(FLOAT_TYPE(b16[l]), sccache1[csel][ix][1 + 8*v_im] * qs_u32_0[l+2], + fma(FLOAT_TYPE(b32[l]), sccache1[csel][ix][2 + 8*v_im] * qs_u32_2[l ], + fma(FLOAT_TYPE(b48[l]), sccache1[csel][ix][3 + 8*v_im] * qs_u32_2[l+2], + fma(FLOAT_TYPE(b64[l]), sccache1[csel][ix][4 + 8*v_im] * qs_u32_4[l ], + fma(FLOAT_TYPE(b80[l]), sccache1[csel][ix][5 + 8*v_im] * qs_u32_4[l+2], + fma(FLOAT_TYPE(b96[l]), sccache1[csel][ix][6 + 8*v_im] * qs_u32_6[l ], + fma(FLOAT_TYPE(b112[l]), sccache1[csel][ix][7 + 8*v_im] * qs_u32_6[l+2], sum1)))))))); + sum2 = fma(FLOAT_TYPE(b0[l]), sccache2[csel][ix][ 8*v_im], + fma(FLOAT_TYPE(b16[l]), sccache2[csel][ix][1 + 8*v_im], + fma(FLOAT_TYPE(b32[l]), sccache2[csel][ix][2 + 8*v_im], + fma(FLOAT_TYPE(b48[l]), sccache2[csel][ix][3 + 8*v_im], + fma(FLOAT_TYPE(b64[l]), sccache2[csel][ix][4 + 8*v_im], + fma(FLOAT_TYPE(b80[l]), sccache2[csel][ix][5 + 8*v_im], + fma(FLOAT_TYPE(b96[l]), sccache2[csel][ix][6 + 8*v_im], + fma(FLOAT_TYPE(b112[l]), sccache2[csel][ix][7 + 8*v_im], sum2)))))))); + } + temp[j][n] = fma(dall, sum1, fma(-dmin, sum2, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are 
used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... + const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, v_im, ix, q_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp new file mode 100644 index 0000000000..528f224d86 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q3_k.comp @@ -0,0 +1,132 @@ +#version 450 +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][2][8]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint ix, const uint itid8, const uint v_im, const uint v_im4, const uint v_in, const uint32_t hm_m[4], const uint q_offset, const uint y_offset, const uint s_shift, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t hmk = ~(uint32_t(data_a_packed16[ib0 + i].hmask[v_in]) | (uint32_t(data_a_packed16[ib0 + i].hmask[v_in + 8]) << 16)); + const vec4 hmk_0 = vec4(unpack8(((hmk & hm_m[0]) >> ( v_im4)) << 2)); + const vec4 hmk_1 = vec4(unpack8(((hmk & hm_m[1]) >> (1 + v_im4)) << 2)); + const vec4 hmk_2 = vec4(unpack8(((hmk & hm_m[2]) >> (2 + v_im4)) << 2)); + const vec4 hmk_3 = vec4(unpack8(((hmk & hm_m[3]) >> (3 + v_im4)) << 2)); + + // 0, 1, 16, 17 + uint32_t qs_u32 = uint32_t(data_a[ib0 + i].qs[q_offset]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 1]) << 8); + qs_u32 |= (uint32_t(data_a[ib0 + 
i].qs[q_offset + 16]) | (uint32_t(data_a[ib0 + i].qs[q_offset + 17]) << 8)) << 16; + const vec4 qs_u32_0 = vec4(unpack8(qs_u32 & 0x03030303)); + const vec4 qs_u32_2 = vec4(unpack8((qs_u32 >> 2) & 0x03030303)); + const vec4 qs_u32_4 = vec4(unpack8((qs_u32 >> 4) & 0x03030303)); + const vec4 qs_u32_6 = vec4(unpack8((qs_u32 >> 6) & 0x03030303)); + + if (all_threads) { + sccache[csel][ix][v_im][itid8] = FLOAT_TYPE(int8_t(((data_a[ib0+i].scales[itid8] >> v_im4) & 0xF) | (((data_a[ib0+i].scales[itid8%4+8] >> s_shift) & 3) << 4)) - 32); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 b0 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 0]); + vec2 b16 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 8]); + vec2 b32 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 16]); + vec2 b48 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 24]); + vec2 b64 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 32]); + vec2 b80 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 40]); + vec2 b96 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 48]); + vec2 b112 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y_idx) / 2 + 56]); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0); + [[unroll]] for (int l = 0; l < 2; ++l) { + sum = fma(FLOAT_TYPE( b0[l]) * sccache[csel][ix][v_im][0], qs_u32_0[l ] - hmk_0[l ], + fma(FLOAT_TYPE( b16[l]) * sccache[csel][ix][v_im][1], qs_u32_0[l+2] - hmk_0[l+2], + fma(FLOAT_TYPE( b32[l]) * sccache[csel][ix][v_im][2], qs_u32_2[l ] - hmk_1[l ], + fma(FLOAT_TYPE( b48[l]) * sccache[csel][ix][v_im][3], qs_u32_2[l+2] - hmk_1[l+2], + fma(FLOAT_TYPE( b64[l]) * sccache[csel][ix][v_im][4], qs_u32_4[l ] - hmk_2[l ], + fma(FLOAT_TYPE( b80[l]) * sccache[csel][ix][v_im][5], qs_u32_4[l+2] - hmk_2[l+2], + fma(FLOAT_TYPE( b96[l]) * sccache[csel][ix][v_im][6], qs_u32_6[l ] - hmk_3[l ], + fma(FLOAT_TYPE(b112[l]) * sccache[csel][ix][v_im][7], qs_u32_6[l+2] - hmk_3[l+2], sum)))))))); + } + temp[j][n] = fma(d, sum, temp[j][n]); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + const uint itid8 = itid%8; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
+ const uint v_im4 = v_im*4; + const uint v_in = itid - 8*v_im; // 0...7 + + const uint32_t m = 0x01010101 << (4 * v_im); + uint32_t hm_m[4]; + [[unroll]] for (uint j = 0; j < 4; ++j) + hm_m[j] = m << j; + + const uint l0 = 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint s_shift = v_im4 + 2*(itid8/4); + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, ix, itid8, v_im, v_im4, v_in, hm_m, q_offset, y_offset, s_shift, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp new file mode 100644 index 0000000000..21d07d2e50 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q4_k.comp @@ -0,0 +1,136 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4]; + const uint32_t qs64_u32 = data_a_packed32[ib0 + i].qs[q_offset / 4 + 16]; + + 
const uint32_t qs0_u32_lo4 = qs0_u32 & 0x0F0F0F0F; + const uint32_t qs0_u32_hi4 = (qs0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t qs64_u32_lo4 = qs64_u32 & 0x0F0F0F0F; + const uint32_t qs64_u32_hi4 = (qs64_u32 >> 4) & 0x0F0F0F0F; + + const vec4 qs0_lo4 = vec4(unpack8(qs0_u32_lo4)); + const vec4 qs64_lo4 = vec4(unpack8(qs64_u32_lo4)); + const vec4 qs0_hi4 = vec4(unpack8(qs0_u32_hi4)); + const vec4 qs64_hi4 = vec4(unpack8(qs64_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_lo4.x; + const FLOAT_TYPE q4_1 = qs0_lo4.y; + const FLOAT_TYPE q4_2 = qs0_lo4.z; + const FLOAT_TYPE q4_3 = qs0_lo4.w; + const FLOAT_TYPE q4_4 = qs0_hi4.x; + const FLOAT_TYPE q4_5 = qs0_hi4.y; + const FLOAT_TYPE q4_6 = qs0_hi4.z; + const FLOAT_TYPE q4_7 = qs0_hi4.w; + const FLOAT_TYPE q4_8 = qs64_lo4.x; + const FLOAT_TYPE q4_9 = qs64_lo4.y; + const FLOAT_TYPE q4_10 = qs64_lo4.z; + const FLOAT_TYPE q4_11 = qs64_lo4.w; + const FLOAT_TYPE q4_12 = qs64_hi4.x; + const FLOAT_TYPE q4_13 = qs64_hi4.y; + const FLOAT_TYPE q4_14 = qs64_hi4.z; + const FLOAT_TYPE q4_15 = qs64_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by10 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 ]); + vec4 by132 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y1_idx) / 4 + 8]); + vec4 by20 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 ]); + vec4 by232 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y2_idx) / 4 + 8]); + + const FLOAT_TYPE sx = fma(FLOAT_TYPE(by10.x), q4_0, fma(FLOAT_TYPE(by10.y), q4_1, fma(FLOAT_TYPE(by10.z), q4_2, FLOAT_TYPE(by10.w) * q4_3))); + const FLOAT_TYPE sy = fma(FLOAT_TYPE(by132.x), q4_4, fma(FLOAT_TYPE(by132.y), q4_5, fma(FLOAT_TYPE(by132.z), q4_6, FLOAT_TYPE(by132.w) * q4_7))); + const FLOAT_TYPE sz = fma(FLOAT_TYPE(by20.x), q4_8, fma(FLOAT_TYPE(by20.y), q4_9, fma(FLOAT_TYPE(by20.z), q4_10, FLOAT_TYPE(by20.w) * q4_11))); + const FLOAT_TYPE sw = fma(FLOAT_TYPE(by232.x), q4_12, fma(FLOAT_TYPE(by232.y), q4_13, fma(FLOAT_TYPE(by232.z), q4_14, FLOAT_TYPE(by232.w) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x), sc2, fma(FLOAT_TYPE(by132.x), sc3, fma(FLOAT_TYPE(by20.x), sc6, fma(FLOAT_TYPE(by232.x), sc7, + fma(FLOAT_TYPE(by10.y), sc2, fma(FLOAT_TYPE(by132.y), sc3, fma(FLOAT_TYPE(by20.y), sc6, fma(FLOAT_TYPE(by232.y), sc7, + fma(FLOAT_TYPE(by10.z), sc2, fma(FLOAT_TYPE(by132.z), sc3, fma(FLOAT_TYPE(by20.z), sc6, fma(FLOAT_TYPE(by232.z), sc7, + fma(FLOAT_TYPE(by10.w), sc2, fma(FLOAT_TYPE(by132.w), sc3, fma(FLOAT_TYPE(by20.w), sc6, FLOAT_TYPE(by232.w) * sc7))))))))))))))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + const uint n = 4; + + const uint v_im = il / 2; // 0 or 1. 
0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = n * (2 * ir + v_in); // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp new file mode 100644 index 0000000000..9e46c89a11 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q5_k.comp @@ -0,0 +1,167 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint v_im, const uint l0, const uint q_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows) { + const uint y1_idx = i * QUANT_K + y_offset; + const uint y2_idx = y1_idx + 128; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + vec2 d = vec2(data_a[ib0 + i].d); + const FLOAT_TYPE dall = FLOAT_TYPE(d.x); + const FLOAT_TYPE dmin = FLOAT_TYPE(d.y); + + const uint32_t scale0_u32 = data_a_packed16[ib0 + i].scales[v_im ]; + const uint32_t scale4_u32 = data_a_packed16[ib0 + i].scales[v_im + 2]; + const uint32_t scale8_u32 = data_a_packed16[ib0 + i].scales[v_im + 4]; + + const uint32_t scale_0_4_l = (scale4_u32 << 16) | scale0_u32; + const uint32_t scale_0_4_h = (scale_0_4_l & 0xC0C0C0C0) >> 2; + const vec4 scale_0_4_l_f = vec4(unpack8(scale_0_4_l & 0x3F3F3F3F)); + const vec4 scale8_f = vec4(unpack8((((scale8_u32 << 12) | scale8_u32) & 0x0F0F0F0F) | scale_0_4_h)); + + const FLOAT_TYPE sc0 = scale_0_4_l_f.x; + const FLOAT_TYPE sc1 = scale_0_4_l_f.y; + const FLOAT_TYPE sc2 = scale_0_4_l_f.z; + const FLOAT_TYPE sc3 = scale_0_4_l_f.w; + const FLOAT_TYPE sc4 = scale8_f.x; + const FLOAT_TYPE sc5 = scale8_f.y; + const FLOAT_TYPE sc6 = scale8_f.z; + const FLOAT_TYPE sc7 = scale8_f.w; + + const uint32_t qs0_16_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 8]) << 16); + const uint32_t qs64_80_u32 = uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 32]) | (uint32_t(data_a_packed16[ib0 + i].qs[q_offset / 2 + 40]) << 16); + + uint32_t qs0_16_u32_lo4 = qs0_16_u32 & 0x0F0F0F0F; + uint32_t qs0_16_u32_hi4 = (qs0_16_u32 >> 4) & 0x0F0F0F0F; + uint32_t qs64_80_u32_lo4 = qs64_80_u32 & 0x0F0F0F0F; + uint32_t qs64_80_u32_hi4 = (qs64_80_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh = pack32(u16vec2(data_a_packed16[ib0 + i].qh[l0 / 2], 
data_a_packed16[ib0 + i].qh[l0 / 2 + 8])); + + const uint32_t qs0_16_lo4_offset16 = ((qh >> (2*v_im)) & 0x01010101) << 4; + const uint32_t qs0_16_hi4_offset16 = ((qh >> (2*v_im)) & 0x02020202) << 3; + const uint32_t qs64_80_lo4_offset16 = ((qh >> (2*v_im)) & 0x10101010); + const uint32_t qs64_80_hi4_offset16 = ((qh >> (2*v_im)) & 0x20202020) >> 1; + + qs0_16_u32_lo4 += qs0_16_lo4_offset16; + qs0_16_u32_hi4 += qs0_16_hi4_offset16; + qs64_80_u32_lo4 += qs64_80_lo4_offset16; + qs64_80_u32_hi4 += qs64_80_hi4_offset16; + + const vec4 qs0_16_lo4 = vec4(unpack8(qs0_16_u32_lo4)); + const vec4 qs64_80_lo4 = vec4(unpack8(qs64_80_u32_lo4)); + const vec4 qs0_16_hi4 = vec4(unpack8(qs0_16_u32_hi4)); + const vec4 qs64_80_hi4 = vec4(unpack8(qs64_80_u32_hi4)); + + const FLOAT_TYPE q4_0 = qs0_16_lo4.x; + const FLOAT_TYPE q4_1 = qs0_16_lo4.y; + const FLOAT_TYPE q4_2 = qs0_16_lo4.z; + const FLOAT_TYPE q4_3 = qs0_16_lo4.w; + const FLOAT_TYPE q4_4 = qs0_16_hi4.x; + const FLOAT_TYPE q4_5 = qs0_16_hi4.y; + const FLOAT_TYPE q4_6 = qs0_16_hi4.z; + const FLOAT_TYPE q4_7 = qs0_16_hi4.w; + const FLOAT_TYPE q4_8 = qs64_80_lo4.x; + const FLOAT_TYPE q4_9 = qs64_80_lo4.y; + const FLOAT_TYPE q4_10 = qs64_80_lo4.z; + const FLOAT_TYPE q4_11 = qs64_80_lo4.w; + const FLOAT_TYPE q4_12 = qs64_80_hi4.x; + const FLOAT_TYPE q4_13 = qs64_80_hi4.y; + const FLOAT_TYPE q4_14 = qs64_80_hi4.z; + const FLOAT_TYPE q4_15 = qs64_80_hi4.w; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec2 by10 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 ]); + vec2 by116 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 8]); + vec2 by132 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 16]); + vec2 by148 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y1_idx) / 2 + 24]); + vec2 by20 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 ]); + vec2 by216 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 8]); + vec2 by232 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 16]); + vec2 by248 = vec2(data_b_v2[(j*p.batch_stride_b + b_offset + y2_idx) / 2 + 24]); + + const FLOAT_TYPE sx = + fma(FLOAT_TYPE(by10.x), q4_0, + fma(FLOAT_TYPE(by10.y), q4_1, + fma(FLOAT_TYPE(by116.x), q4_2, + FLOAT_TYPE(by116.y) * q4_3))); + const FLOAT_TYPE sy = + fma(FLOAT_TYPE(by132.x), q4_4, + fma(FLOAT_TYPE(by132.y), q4_5, + fma(FLOAT_TYPE(by148.x), q4_6, + FLOAT_TYPE(by148.y) * q4_7))); + const FLOAT_TYPE sz = + fma(FLOAT_TYPE(by20.x), q4_8, + fma(FLOAT_TYPE(by20.y), q4_9, + fma(FLOAT_TYPE(by216.x), q4_10, + FLOAT_TYPE(by216.y) * q4_11))); + const FLOAT_TYPE sw = + fma(FLOAT_TYPE(by232.x), q4_12, + fma(FLOAT_TYPE(by232.y), q4_13, + fma(FLOAT_TYPE(by248.x), q4_14, + FLOAT_TYPE(by248.y) * q4_15))); + const FLOAT_TYPE smin = + fma(FLOAT_TYPE(by10.x) + FLOAT_TYPE(by10.y) + FLOAT_TYPE(by116.x) + FLOAT_TYPE(by116.y), sc2, + fma(FLOAT_TYPE(by132.x) + FLOAT_TYPE(by132.y) + FLOAT_TYPE(by148.x) + FLOAT_TYPE(by148.y), sc3, + fma(FLOAT_TYPE(by20.x) + FLOAT_TYPE(by20.y) + FLOAT_TYPE(by216.x) + FLOAT_TYPE(by216.y), sc6, + (FLOAT_TYPE(by232.x) + FLOAT_TYPE(by232.y) + FLOAT_TYPE(by248.x) + FLOAT_TYPE(by248.y)) * sc7))); + temp[j][n] = fma(dall, fma(sx, sc0, fma(sy, sc1, fma(sz, sc4, sw * sc5))), fma(-dmin, smin, temp[j][n])); + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint 
it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint il = itid/4; // 0...3 + const uint ir = itid - 4*il; // 0...3 + + const uint v_im = il / 2; // 0 or 1. 0 computes 0,32 + 128,160, 1 computes 64,96 + 192,224 + const uint v_in = il % 2; + + const uint l0 = 4*ir + 2*v_in; // 0...15 + const uint q_offset = 32*v_im + l0; + const uint y_offset = 64*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + [[unroll]] for (uint i = ix; i < num_blocks_per_row; i += it_size) + calc_superblock(a_offset, b_offset, v_im, l0, q_offset, y_offset, i, num_blocks_per_row, first_row, num_rows); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp new file mode 100644 index 0000000000..d7a7f6426e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vec_q6_k.comp @@ -0,0 +1,130 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sccache[2][BLOCK_SIZE/16][16]; + +FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; +uint csel = 0; + +void calc_superblock(const uint a_offset, const uint b_offset, const uint itid, const uint ix, const uint ql_offset, const uint qh_offset, const uint s_offset, const uint y_offset, const uint i, const uint num_blocks_per_row, const uint first_row, const uint num_rows, const bool all_threads) { + const uint y_idx = i * QUANT_K + y_offset; + + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint ib0 = a_offset / QUANT_K + (first_row+n)*num_blocks_per_row; + csel ^= 1; + + if (!all_threads) { // when we don't have enough blocks to use all threads + if (i < num_blocks_per_row) + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + + if (i >= num_blocks_per_row) + continue; + } + + const uint32_t ql0_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 1]) << 16); + const uint32_t ql32_u32 = uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 16]) | (uint32_t(data_a_packed16[ib0 + i].ql[ql_offset / 2 + 17]) << 16); + + const uint32_t ql0_u32_lo4 = ql0_u32 & 0x0F0F0F0F; + const uint32_t ql0_u32_hi4 = (ql0_u32 >> 4) & 0x0F0F0F0F; + const uint32_t ql32_u32_lo4 = ql32_u32 & 0x0F0F0F0F; + const uint32_t ql32_u32_hi4 = (ql32_u32 >> 4) & 0x0F0F0F0F; + + const uint32_t qh_u32 = uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2]) | (uint32_t(data_a_packed16[ib0 + i].qh[qh_offset / 2 + 1]) << 16); + const uint32_t qh0_u32 = (qh_u32 & 0x03030303) << 4; + const uint32_t qh2_u32 = (qh_u32 & 0x0C0C0C0C) << 2; + const uint32_t qh4_u32 = (qh_u32 & 0x30303030); + const uint32_t qh6_u32 = (qh_u32 & 0xC0C0C0C0) >> 2; + + const uint32_t q0_u32 = ql0_u32_lo4 | qh0_u32; + const 
uint32_t q1_u32 = ql32_u32_lo4 | qh2_u32; + const uint32_t q2_u32 = ql0_u32_hi4 | qh4_u32; + const uint32_t q3_u32 = ql32_u32_hi4 | qh6_u32; + + const vec4 q0 = vec4(unpack8(q0_u32)) - 32; + const vec4 q1 = vec4(unpack8(q1_u32)) - 32; + const vec4 q2 = vec4(unpack8(q2_u32)) - 32; + const vec4 q3 = vec4(unpack8(q3_u32)) - 32; + + if (all_threads) { + sccache[csel][ix][itid] = FLOAT_TYPE(data_a[ib0 + i].scales[itid]); + barrier(); + } + + const FLOAT_TYPE d = FLOAT_TYPE(data_a[ib0 + i].d); + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + vec4 by0 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 ]); + vec4 by32 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 8]); + vec4 by64 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 16]); + vec4 by96 = vec4(data_b_v4[(j*p.batch_stride_b + b_offset + y_idx) / 4 + 24]); + + FLOAT_TYPE sum[4] = {0, 0, 0, 0}; + [[unroll]] for (uint l = 0; l < 4; ++l) { + sum[0] = fma(FLOAT_TYPE(by0[l]), q0[l], sum[0]); + sum[1] = fma(FLOAT_TYPE(by32[l]), q1[l], sum[1]); + sum[2] = fma(FLOAT_TYPE(by64[l]), q2[l], sum[2]); + sum[3] = fma(FLOAT_TYPE(by96[l]), q3[l], sum[3]); + } + temp[j][n] = fma(fma(sum[0], sccache[csel][ix][s_offset], fma(sum[1], sccache[csel][ix][s_offset + 2], fma(sum[2], sccache[csel][ix][s_offset + 4], sum[3] * sccache[csel][ix][s_offset + 6]))), d, temp[j][n]); + } + } +} + +void compute_outputs(const uint first_row, const uint num_rows) { + uint a_offset, b_offset, d_offset; + get_offsets(a_offset, b_offset, d_offset); + + const uint num_blocks_per_row = p.ncols / QUANT_K; + + // 16 threads are used to process each block + const uint it_size = gl_WorkGroupSize.x/16; + const uint tid = gl_LocalInvocationID.x; + const uint itid = tid%16; // 0...15 + const uint ix = tid/16; + + const uint v_im = itid/8; // 0 or 1. 0 computes 0..., 1 computes 128... 
+ const uint v_in = itid - 8*v_im; // 0...7 + + const uint l0 = 4 * v_in; // 0, 4, 8, ..., 28 + const uint is = v_in / 4; + + const uint ql_offset = 64*v_im + l0; + const uint qh_offset = 32*v_im + l0; + const uint s_offset = 8*v_im + is; + const uint y_offset = 128*v_im + l0; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint i = 0; i < NUM_ROWS; ++i) { + temp[j][i] = FLOAT_TYPE(0); + } + } + + const uint nbr_par_th = num_blocks_per_row%it_size; + const uint nbr_all_th = num_blocks_per_row - nbr_par_th; + uint i0 = 0; + [[unroll]] for (; i0 < nbr_all_th; i0 += it_size) + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, true); + calc_superblock(a_offset, b_offset, itid, ix, ql_offset, qh_offset, s_offset, y_offset, i0 + ix, num_blocks_per_row, first_row, num_rows, false); + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp new file mode 100644 index 0000000000..64293f6eca --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mat_vecq.comp @@ -0,0 +1,140 @@ +#version 450 + +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_integer_dot_product : require + +#define MMQ +#define B_TYPE block_q8_1_x4 + +#include "mul_mat_vec_base.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define K_PER_ITER 8 + +#include "mul_mmq_funcs.glsl" + +uint a_offset, b_offset, d_offset; + +int32_t cache_b_qs[2]; +vec2 cache_b_ds; + +void iter(inout FLOAT_TYPE temp[NUM_COLS][NUM_ROWS], const uint first_row, const uint num_rows, const uint tid, const uint i) { + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + const uint col = i*BLOCK_SIZE + tid*K_PER_ITER; + + // Preload data_b block + const uint b_block_idx = (j*p.batch_stride_b + col) / QUANT_K_Q8_1 + b_offset; + const uint b_qs_idx = tid % 4; + const uint b_block_idx_outer = b_block_idx / 4; + const uint b_block_idx_inner = b_block_idx % 4; + cache_b_ds = vec2(data_b[b_block_idx_outer].ds[b_block_idx_inner]); + +#if QUANT_R == 2 + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx + 4]; +#else + cache_b_qs[0] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2]; + cache_b_qs[1] = data_b[b_block_idx_outer].qs[b_block_idx_inner * 8 + b_qs_idx * 2 + 1]; +#endif + + uint ibi = first_row*p.ncols; + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + const uint a_block_idx = (ibi + col)/QUANT_K + a_offset; + ibi += p.ncols; + + int32_t q_sum = 0; +#if QUANT_R == 2 + const i32vec2 data_a_qs = repack(a_block_idx, b_qs_idx); + q_sum += dotPacked4x8EXT(data_a_qs.x, + cache_b_qs[0]); + q_sum += dotPacked4x8EXT(data_a_qs.y, + cache_b_qs[1]); +#else + int32_t data_a_qs = repack(a_block_idx, b_qs_idx * 2); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[0]); + data_a_qs = repack(a_block_idx, b_qs_idx * 2 
+ 1); + q_sum += dotPacked4x8EXT(data_a_qs, + cache_b_qs[1]); +#endif + +#if QUANT_AUXF == 1 + temp[j][n] += mul_q8_1(q_sum, get_d(a_block_idx), cache_b_ds, 4); +#else + temp[j][n] += mul_q8_1(q_sum, get_dm(a_block_idx), cache_b_ds, 4); +#endif + } + } +} + +void compute_outputs(const uint32_t first_row, const uint32_t num_rows) { + const uint tid = gl_LocalInvocationID.x; + + get_offsets(a_offset, b_offset, d_offset); + a_offset /= QUANT_K; + b_offset /= QUANT_K_Q8_1; + + FLOAT_TYPE temp[NUM_COLS][NUM_ROWS]; + + [[unroll]] for (uint j = 0; j < NUM_COLS; ++j) { + [[unroll]] for (uint n = 0; n < num_rows; ++n) { + temp[j][n] = FLOAT_TYPE(0.0f); + } + } + + uint num_iters = p.ncols / (K_PER_ITER * BLOCK_SIZE); + if (num_iters * K_PER_ITER * BLOCK_SIZE + K_PER_ITER*tid < p.ncols) { + num_iters++; + } + int unroll_count = 4; + uint unrolled_iters = num_iters & ~(unroll_count - 1); + + uint i = 0; + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + + unroll_count = 2; + unrolled_iters = num_iters & ~(unroll_count - 1); + +#if K_PER_ITER == 2 + if ((p.ncols & 1) != 0 && + unrolled_iters == num_iters && + unrolled_iters > 0) { + unrolled_iters -= unroll_count; + } +#endif + + while (i < unrolled_iters) { + // Manually partially unroll the loop + [[unroll]] for (uint k = 0; k < unroll_count; ++k) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + } + while (i < num_iters) { + iter(temp, first_row, num_rows, tid, i*K_PER_ITER); + i++; + } + + reduce_result(temp, d_offset, first_row, num_rows, tid); +} + +void main() { + const uint first_row = NUM_ROWS * (gl_WorkGroupID.x + gl_NumWorkGroups.x * gl_WorkGroupID.z); + + // do NUM_ROWS at a time, unless there aren't enough remaining rows + if (first_row + NUM_ROWS <= p.stride_d) { + compute_outputs(first_row, NUM_ROWS); + } else { + if (first_row >= p.stride_d) { + return; + } + compute_outputs(first_row, p.stride_d - first_row); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp new file mode 100644 index 0000000000..85400ac5fc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm.comp @@ -0,0 +1,481 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif +#if defined(DATA_A_IQ1_M) +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#if defined(DATA_A_BF16) && defined(COOPMAT) +#extension GL_EXT_bfloat16 : enable +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#endif + +#if defined(COOPMAT) || defined(MUL_MAT_ID_USE_SUBGROUPS) +#extension GL_KHR_shader_subgroup_basic : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.glsl" + +#ifndef LOAD_VEC_A +#define LOAD_VEC_A 1 +#endif +#ifndef LOAD_VEC_B +#define LOAD_VEC_B 1 +#endif + +// Load 2 values at once without affecting index calculations through LOAD_VEC +#if (defined(DATA_A_F32) || defined(DATA_A_F16) || defined(DATA_A_BF16)) && !defined(ALIGNED) +#define LOAD_VEC_BATCH_A 2 +#else +#define LOAD_VEC_BATCH_A 1 +#endif +#if 
!defined(ALIGNED) +#define LOAD_VEC_BATCH_B 2 +#else +#define LOAD_VEC_BATCH_B 1 +#endif + +#if !defined(TO_FLOAT_TYPE) +#define TO_FLOAT_TYPE FLOAT_TYPE +#endif + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +#if defined(A_TYPE_PACKED16) +layout (binding = 0) readonly buffer A_PACKED16 {A_TYPE_PACKED16 data_a_packed16[];}; +#endif +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif + +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 2 + 4) +#else +#define SHMEM_STRIDE (BK / 2 + 1) +#endif + +shared FLOAT_TYPE_VEC2 buf_a[BM * SHMEM_STRIDE]; +shared FLOAT_TYPE_VEC2 buf_b[BN * SHMEM_STRIDE]; + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[BN]; +uint _ne1; + +#ifdef MUL_MAT_ID_USE_SUBGROUPS +shared uvec4 ballots_sh[NUM_WARPS]; + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? 
data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec2(ii0, ii1); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif // MUL_MAT_ID_USE_SUBGROUPS +#endif // MUL_MAT_ID + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mm_funcs.glsl" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A / LOAD_VEC_BATCH_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B / LOAD_VEC_BATCH_B); + + const uint loadstride_a = gl_WorkGroupSize.x * LOAD_VEC_A * LOAD_VEC_BATCH_A / BK; + const uint loadstride_b = gl_WorkGroupSize.x * LOAD_VEC_B * LOAD_VEC_BATCH_B / BK; + +#ifdef MUL_MAT_ID +#ifdef MUL_MAT_ID_USE_SUBGROUPS + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } +#else + _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1 && _ne1 < (ic + 1) * BN; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0 && _ne1 < (ic + 1) * BN; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + if (_ne1 >= ic * BN) { + row_ids[_ne1 - ic * BN] = u16vec2(ii0, ii1); + } + _ne1++; + } + } + } + + barrier(); +#endif + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = 
p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / LOAD_VEC_A; +#ifdef MUL_MAT_ID + uint pos_b = 0; +#else + uint pos_b = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / LOAD_VEC_B; +#endif + +#ifdef COOPMAT + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TM, TK, gl_MatrixUseA> cache_a; + coopmat<FLOAT_TYPE, gl_ScopeSubgroup, TK, TN, gl_MatrixUseB> cache_b; + coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator> sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat<ACC_TYPE, gl_ScopeSubgroup, TM, TN, gl_MatrixUseAccumulator>(0.0f); + } +#else + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + FLOAT_TYPE_VEC2 cache_a[WMITER * TM]; + FLOAT_TYPE_VEC2 cache_b[TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; l < BM; l += loadstride_a) { + load_a_to_shmem(pos_a, loadr_a, loadc_a + l, ir * BM + loadc_a + l, block, end_k); + } + [[unroll]] for (uint l = 0; l < BN; l += loadstride_b) { +#if !defined(MUL_MAT_ID) + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic * BN + loadc_b + l, block, end_k); +#else + load_b_to_shmem(pos_b, loadr_b, loadc_b + l, ic, _ne1, block, end_k); +#endif + } + + barrier(); + + pos_a += BK / LOAD_VEC_A; + pos_b += BK / LOAD_VEC_B; + +#ifdef COOPMAT + [[unroll]] for (uint i = 0; i < BK; i += TK) { + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + // Load from shared into cache + coopMatLoad(cache_a, buf_a, (warp_r * WM + cm_row * TM) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatLoad(cache_b, buf_b, (warp_c * WN + cm_col * TN) * SHMEM_STRIDE + i / 2, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + sums[cm_col * cms_per_row + cm_row] = coopMatMulAdd(cache_a, cache_b, sums[cm_col * cms_per_row + cm_row]); + } + } + } +#else + [[unroll]] for (uint i = 0; i < BK / 2; i++) { + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint j = 0; j < TM; j++) { + cache_a[wsir * TM + j] = buf_a[(warp_r * WM + wsir * WSUBM + tiwr * TM + j) * SHMEM_STRIDE + i]; + } + } + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint j = 0; j < TN; j++) { + cache_b[j] = buf_b[(warp_c * WN + wsic * WSUBN + tiwc * TN + j) * SHMEM_STRIDE + i]; + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + sums[sums_idx] = fma(ACC_TYPE(cache_a[wsir * TM + cr].x), ACC_TYPE(cache_b[cc].x), fma(ACC_TYPE(cache_a[wsir * TM + cr].y), ACC_TYPE(cache_b[cc].y), sums[sums_idx])); + } + } + } + } + } +#endif + + barrier(); + } + +#if defined(ACC_TYPE_MAX) +#ifdef COOPMAT + [[unroll]] for (uint j = 0; j < cms_per_row * cms_per_col; j++) { + [[unroll]] for (uint i = 0; i < sums[j].length(); ++i) { + sums[j][i] = clamp(sums[j][i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } + } +#else + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = clamp(sums[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); + } +#endif +#endif + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * 
gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i - ic * BN]; + + if (dr + cm_row * TM + store_r < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i - ic * BN]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + if (dr_warp + cr < p.M) { + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp new 
file mode 100644 index 0000000000..2e04baa44e --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_cm2.comp @@ -0,0 +1,609 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require + +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require + +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_cooperative_matrix : enable +#extension GL_NV_cooperative_matrix2 : enable +#extension GL_EXT_buffer_reference : enable +#extension GL_KHR_shader_subgroup_ballot : enable +#extension GL_KHR_shader_subgroup_vote : enable +#ifdef DATA_A_BF16 +#extension GL_EXT_bfloat16 : enable +#endif + +#include "types.glsl" +#include "utils.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +#define IS_MUL_MM2 1 + +layout (constant_id = 0) const uint BLOCK_SIZE = 256; +layout (constant_id = 1) const uint BM = 64; +layout (constant_id = 2) const uint BN = 64; +layout (constant_id = 3) const uint BK = 16; // Assumed to be 32 if working with a quant + +layout (constant_id = 4) const bool enable_smaller_matrices = false; +const uint BNover2 = enable_smaller_matrices ? (BN / 2) : BN; +const uint BNover4 = enable_smaller_matrices ? (BN / 4) : BN; + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif + // N dimension for the B matrix can be >= p.N + uint padded_N; +} p; + + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer B {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#if QUANT_K > 1 +#define DECODEFUNCA , dequantFuncA + +#include "dequant_funcs_cm2.glsl" + +#else +#define DECODEFUNCA +#endif + +#if !defined(fetch_scales) +#define fetch_scales(a, b, c, d, e, f) +#endif +#if !defined(store_scales) +#define store_scales(a) +#endif + +#if defined(DATA_A_BF16) +#define MAT_TYPE bfloat16_t +#else +#define MAT_TYPE FLOAT_TYPE +#endif + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; + +shared u16vec4 row_ids[BN]; + +layout(buffer_reference, std430, buffer_reference_align = 2) buffer decodeBufB { + B_TYPE b[]; +}; + +uint _ne1; +layout (constant_id = 5) const uint subgroup_size = 32; +shared uvec4 ballots_sh[BLOCK_SIZE / subgroup_size]; + +B_TYPE decodeFuncB(const in decodeBufB bl, const in uint blockCoords[2], const in uint coordInBlock[2]) +{ + const uint row_i = blockCoords[0]; + + if (row_i >= _ne1) { + return B_TYPE(0.0); + } + + const u16vec4 row_idx = row_ids[row_i & (BN - 1)]; + B_TYPE ret = data_b[row_idx.y * p.batch_stride_b + row_idx.x * p.stride_b + blockCoords[1]]; + + return ret; +} + +D_TYPE perElemOpD(const in uint32_t r, const in uint32_t c, const in D_TYPE elem, const in uint32_t ir, const in uint32_t ic) +{ + uint dr = ir * BM + r; + uint dc = ic * BN + c; + + if (dr < p.M && dc < _ne1) { + uint row_i = c; + const u16vec4 row_idx = row_ids[row_i]; + data_d[row_idx.y * p.batch_stride_d + row_idx.z * p.stride_d + dr] = elem; + } + return elem; 
+} + +void load_row_ids(uint expert_idx, bool nei0_is_pow2, uint ic) { + _ne1 = 0; + uint num_elements = p.nei1 * p.nei0; + uint nei0shift = findLSB(p.nei0); + + uint ids[16]; + uint iter = 0; + + for (uint j = 0; j < num_elements; j += BLOCK_SIZE) { + // prefetch up to 16 elements + if (iter == 0) { + [[unroll]] for (uint k = 0; k < 16; ++k) { + uint i = j + gl_LocalInvocationIndex + k*BLOCK_SIZE; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + ids[k] = in_range ? data_ids[ii1*p.nbi1 + ii0] : 0; + } + } + uint i = j + gl_LocalInvocationIndex; + bool in_range = i < num_elements; + uint ii1; + if (nei0_is_pow2) { + ii1 = i >> nei0shift; + } else { + ii1 = i / p.nei0; + } + uint ii0 = i - ii1 * p.nei0; + uint id = ids[iter++]; + uvec4 ballot = subgroupBallot(in_range && id == expert_idx); + + ballots_sh[gl_SubgroupID] = ballot; + barrier(); + + uint subgroup_base = 0; + uint total = 0; + for (uint k = 0; k < gl_NumSubgroups; ++k) { + if (k == gl_SubgroupID) { + subgroup_base = total; + } + total += subgroupBallotBitCount(ballots_sh[k]); + } + barrier(); + + uint idx = subgroup_base + subgroupBallotExclusiveBitCount(ballot); + if (in_range && id == expert_idx && _ne1 + idx >= ic * BN && _ne1 + idx < (ic + 1) * BN) { + row_ids[_ne1 + idx - ic * BN] = u16vec4(fastmod(ii0, p.ne11), ii1, ii0, 0); + } + _ne1 += total; + iter &= 15; + if (_ne1 >= (ic + 1) * BN) { + break; + } + } + barrier(); +} +#endif + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + + const uint tid = gl_LocalInvocationIndex; + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + +#ifdef MUL_MAT_ID + if (bitCount(p.nei0) == 1) { + load_row_ids(expert_idx, true, ic); + } else { + load_row_ids(expert_idx, false, ic); + } + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + uint start_k = 0; + const uint end_k = p.K; +#else + uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + +#ifdef MUL_MAT_ID + uint pos_a = (expert_idx * p.batch_stride_a) / QUANT_K; + uint pos_b = 0; +#else + uint pos_a = (batch_idx_a * p.batch_stride_a) / QUANT_K; + uint pos_b = batch_idx * p.batch_stride_b; + uint pos_d = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + + uint stride_a = p.stride_a / QUANT_K; + uint stride_b = p.stride_b; + + // Hint to the compiler that values are aligned (want 16B alignment). + // Quants are always block-aligned, no alignment needed. 
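+    // Concretely: the ALIGNED shader variant is only used when these strides are already
+    // multiples of 8 elements (16 bytes for fp16), so masking with ~7 should be a no-op at
+    // run time; it mainly lets the compiler prove the alignment and emit vectorized loads.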
+#if ALIGNED +#if QUANT_K == 1 + stride_a &= ~7; +#endif + stride_b &= ~7; +#endif + + // Create layouts for both clamped and unclamped accesses + tensorLayoutNV<2> tensorLayoutA = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutAClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2> tensorLayoutB = createTensorLayoutNV(2); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutBClamp = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + tensorLayoutNV<2, gl_CooperativeMatrixClampModeConstantNV> tensorLayoutD = createTensorLayoutNV(2, gl_CooperativeMatrixClampModeConstantNV); + +#if QUANT_K > 1 + tensorLayoutA = setTensorLayoutBlockSizeNV(tensorLayoutA, 1, QUANT_K); + tensorLayoutAClamp = setTensorLayoutBlockSizeNV(tensorLayoutAClamp, 1, QUANT_K); +#endif + + // Use end_k rather than p.K as the dimension because that's what + // we need to bound check against when using split_k. + // Bounds check B against padded_N, but bounds check D against N. + tensorLayoutA = setTensorLayoutDimensionNV(tensorLayoutA, p.M, end_k); + tensorLayoutB = setTensorLayoutDimensionNV(tensorLayoutB, p.padded_N, end_k); + tensorLayoutD = setTensorLayoutDimensionNV(tensorLayoutD, p.N, p.M); + tensorLayoutAClamp = setTensorLayoutDimensionNV(tensorLayoutAClamp, p.M, end_k); + tensorLayoutBClamp = setTensorLayoutDimensionNV(tensorLayoutBClamp, p.padded_N, end_k); + + tensorLayoutD = setTensorLayoutStrideNV(tensorLayoutD, p.stride_d, 1); + + tensorViewNV<2, false, 1, 0> tensorViewTranspose = createTensorViewNV(2, false, 1, 0); + +#if !defined(MUL_MAT_ID) + + const uint START_ALIGN_K = 256; + // For Qi_K (block size 256), unroll whole 256 element tiles. + // For legacy quants (block size 32), unroll 8x. + const uint UNROLL_K = (QUANT_K == 256) ? 256 : (BK * 8); + const uint unroll_count = UNROLL_K / BK; + + // Detect a fast path where all loads are entirely in bounds and no clamping is required + if ((ir + 1) * BM <= p.M && (ic + 1) * BN <= p.padded_N && (start_k % START_ALIGN_K) == 0 && (end_k % BK) == 0 && +#if QUANT_K == 1 + (stride_a % 8) == 0 && +#endif + (stride_b % 8) == 0) { + // Hint to the compiler that values are aligned (want 16B alignment) + start_k &= ~(START_ALIGN_K-1); + stride_b &= ~7; +#if QUANT_K == 1 + stride_a &= ~7; +#endif + + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + uint k_iters = (end_k - start_k) / UNROLL_K; + uint block_k = start_k; + + // fetch scale values for a tile of quants. These will be copied into shared memory. + // The fetches and stores are pipelined to hide the latency. 
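+    // For quantized A types (fetch_scales/store_scales are no-op macros otherwise), the call
+    // below kicks off the scale load for the first tile; inside the k-loop each iteration
+    // first copies the previously fetched scales to shared memory via store_scales() and then
+    // issues fetch_scales() for block_k + UNROLL_K, so that global-memory latency overlaps
+    // with the coopMatMulAdd work on the current tile.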
+ fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, true); + + if (enable_smaller_matrices && ic * BN + BNover4 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover4, ir * BM, BM), tensorViewTranspose); + return; + } else if (enable_smaller_matrices && ic * BN + BNover2 >= p.N) { + coopmat sum = coopmat(0.0); + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BNover2, ir * BM, BM), tensorViewTranspose); + return; + } else { + coopmat sum = coopmat(0.0); + + for (uint i = 0; i < k_iters; ++i) { + + store_scales(tid); + if (block_k + UNROLL_K < end_k) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + UNROLL_K, tid, true); + } + + // Manually partial unroll + [[unroll]] for (uint j = 0; j < unroll_count; ++j) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, 
BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } + } + // Do any remaining iterations that were not unrolled + if (block_k < end_k) { + store_scales(tid); + } + while (block_k < end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + block_k += BK; + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + coopmat mat_d = coopmat(sum); + + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); + return; + } + } else +#endif // !defined(MUL_MAT_ID) + { + tensorLayoutA = setTensorLayoutStrideNV(tensorLayoutA, stride_a, 1); + + tensorLayoutAClamp = setTensorLayoutStrideNV(tensorLayoutAClamp, stride_a, 1); + + tensorLayoutB = setTensorLayoutStrideNV(tensorLayoutB, stride_b, 1); + + tensorLayoutBClamp = setTensorLayoutStrideNV(tensorLayoutBClamp, stride_b, 1); + + uint k_iters = (end_k - start_k + BK - 1) / BK; + + fetch_scales(ir * BM, pos_a, stride_a, start_k, tid, false); + store_scales(tid); + +#ifdef MUL_MAT_ID + if (enable_smaller_matrices && ic * BN + BNover4 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover4, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } + if (enable_smaller_matrices && ic * BN + BNover2 >= _ne1) { + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + 
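+                    // Fully in-bounds case: the A tile is dequantized on the fly via DECODEFUNCA,
+                    // and the B tile is gathered through decodeFuncB, which remaps rows via the
+                    // row_ids table built in load_row_ids(); rows at or beyond _ne1 decode to 0.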
+ coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BNover2, block_k, BK), tensorViewTranspose, decodeFuncB); + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); + return; + } +#endif + coopmat sum; + sum = coopmat(0.0); + + [[dont_unroll]] + for (uint block_k = start_k, i = 0; i < k_iters; block_k += BK, ++i) { + + if ((block_k % QUANT_K) == 0) { + store_scales(tid); + } + if (block_k + BK < end_k && ((block_k + BK) % QUANT_K) == 0) { + fetch_scales(ir * BM, pos_a, stride_a, block_k + BK, tid, false); + } + + if ((ir + 1) * BM <= p.M && block_k + BK <= end_k) { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutA, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } else { + coopmat mat_a; + coopmat mat_b; + + coopMatLoadTensorNV(mat_a, data_a, pos_a, sliceTensorLayoutNV(tensorLayoutAClamp, ir * BM, BM, block_k, BK) DECODEFUNCA); +#ifdef MUL_MAT_ID + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutB, ic * BN, BN, block_k, BK), tensorViewTranspose, decodeFuncB); +#else + coopMatLoadTensorNV(mat_b, data_b, pos_b, sliceTensorLayoutNV(tensorLayoutBClamp, ic * BN, BN, block_k, BK), tensorViewTranspose); +#endif + + sum = coopMatMulAdd(mat_a, mat_b, sum); + } + } +#if defined(ACC_TYPE_MAX) + [[unroll]] for (uint i = 0; i < sum.length(); ++i) { sum[i] = clamp(sum[i], -ACC_TYPE_MAX, ACC_TYPE_MAX); } +#endif + + // Convert from ACC_TYPE to D_TYPE + coopmat mat_d; + mat_d = coopmat(sum); + +#ifdef MUL_MAT_ID + // Call callback to store each element, remapping row through shared memory + coopMatPerElementNV(mat_d, mat_d, perElemOpD, ir, ic); +#else + coopMatStoreTensorNV(mat_d, data_d, pos_d, sliceTensorLayoutNV(tensorLayoutD, ic * BN, BN, ir * BM, BM), tensorViewTranspose); +#endif + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl new file mode 100644 index 0000000000..0ebfbd6462 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mm_funcs.glsl @@ -0,0 +1,556 @@ +void load_a_to_shmem(const uint pos_a, const uint row, const uint col, const uint idx_m, const uint block, const uint end_k) { +#if defined(DATA_A_F32) || defined(DATA_A_F16) +#if LOAD_VEC_A == 8 + const uint idx = pos_a + 
col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC8 aa = FLOAT_TYPE_VEC8(data_a[idx]); + buf_a[buf_idx ] = aa[0].xy; + buf_a[buf_idx + 1] = aa[0].zw; + buf_a[buf_idx + 2] = aa[1].xy; + buf_a[buf_idx + 3] = aa[1].zw; +#elif LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(data_a[idx]); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], + data_a[idx + 1]); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(data_a[idx], 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_BF16) +#if LOAD_VEC_A == 4 + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + FLOAT_TYPE_VEC4 aa = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_a[idx])); + buf_a[buf_idx ] = aa.xy; + buf_a[buf_idx + 1] = aa.zw; +#else // LOAD_VEC_BATCH_A == 2 + const uint idx = pos_a + col * p.stride_a + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_m < p.M && block + row * 2 + 1 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), + TO_FLOAT_TYPE(data_a[idx + 1])); + } else if (idx_m < p.M && block + row * 2 < end_k) { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_a[idx]), 0.0f); + } else { + buf_a[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +#elif defined(DATA_A_Q4_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = (vec4(unpack8(vui & 0x0F0F0F0F)) - 8.0f) * d; + const vec4 v1 = (vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) - 8.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q4_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + 2 * row; + + const uint ib = idx / 4; + const uint iqs = idx & 0x03; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint vui = uint(data_a_packed16[ib].qs[2*iqs]) | (uint(data_a_packed16[ib].qs[2*iqs + 1]) << 16); + const vec4 v0 = vec4(unpack8(vui & 0x0F0F0F0F)) * d + m; + const vec4 v1 = vec4(unpack8((vui >> 4) & 0x0F0F0F0F)) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v0.xy); + buf_a[buf_idx + 1 ] = FLOAT_TYPE_VEC2(v0.zw); + buf_a[buf_idx + 8 ] = FLOAT_TYPE_VEC2(v1.xy); + buf_a[buf_idx + 9 ] = FLOAT_TYPE_VEC2(v1.zw); +#elif defined(DATA_A_Q5_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const uint uint_qh = uint(data_a_packed16[ib].qh[1]) << 16 | uint(data_a_packed16[ib].qh[0]); + const ivec2 qh0 = ivec2(((uint_qh >> 
2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = (vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) - 16.0f) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q5_1) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const float m = float(data_a_packed16[ib].m); + const uint uint_qh = data_a_packed16[ib].qh; + const ivec2 qh0 = ivec2(((uint_qh >> 2*iqs) << 4) & 0x10, (uint_qh >> (2*iqs + 12)) & 0x10); + const ivec2 qh1 = ivec2(((uint_qh >> (2*iqs + 1)) << 4) & 0x10, (uint_qh >> (2*iqs + 13)) & 0x10); + + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + const vec4 v = vec4((vui & 0xF) | qh0.x, ((vui >> 4) & 0xF) | qh0.y, ((vui >> 8) & 0xF) | qh1.x, (vui >> 12) | qh1.y) * d + m; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xz); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(v.yw); +#elif defined(DATA_A_Q8_0) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const float d = float(data_a_packed16[ib].d); + const i8vec2 v0 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs])).xy; // vec4 used due to #12147 + const i8vec2 v1 = unpack8(int32_t(data_a_packed16[ib].qs[2*iqs + 1])).xy; + const vec4 v = vec4(v0.x, v0.y, v1.x, v1.y) * d; + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2(v.zw); +#elif defined(DATA_A_Q2_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint qsi = (iqs / 64) * 32 + (iqs % 16) * 2; // 0,2,4..30 + const uint scalesi = iqs / 8; // 0..15 + const uint qsshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + + const uvec2 qs = uvec2(data_a[ib].qs[qsi], data_a[ib].qs[qsi + 1]); + const uint scales = data_a[ib].scales[scalesi]; + const vec2 d = vec2(data_a[ib].d); + + const vec2 v = d.x * float(scales & 0xF) * vec2((qs >> qsshift) & 3) - d.y * float(scales >> 4); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_Q3_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + const uint hmi = (iqs % 16) * 2; // 0,2,4..30 + const uint j = (iqs % 64) / 4; // 0..3 + const uint is = iqs / 8; // 0..15 + const uint halfsplit = ((iqs % 64) / 16); // 0,1,2,3 + const uint qsshift = halfsplit * 2; // 0,2,4,6 + const uint m = 1 << (4 * n + halfsplit); // 1,2,4,8,16,32,64,128 + + const int8_t us = int8_t(((data_a[ib].scales[is % 8] >> (4 * int(is / 8))) & 0xF) + | (((data_a[ib].scales[8 + (is % 4)] >> (2 * int(is / 4))) & 3) << 4)); + const float dl = float(data_a[ib].d) * float(us - 32); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dl * float(int8_t((data_a[ib].qs[qsi ] >> qsshift) & 3) - (((data_a[ib].hmask[hmi ] & m) != 0) ? 
0 : 4)), + dl * float(int8_t((data_a[ib].qs[qsi + 1] >> qsshift) & 3) - (((data_a[ib].hmask[hmi + 1] & m) != 0) ? 0 : 4))); +#elif defined(DATA_A_Q4_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0 | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF), m)); +#elif defined(DATA_A_Q5_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 32; // 0,1,2,3 + const uint b = (iqs % 32) / 16; // 0,1 + const uint is = 2 * n + b; // 0..7 + const uint qsi = n * 32 + (iqs % 16) * 2; // 0,2,4..126 + const uint qhi = (iqs % 16) * 2; // 0,2,4..30 + + const uint8_t hm = uint8_t(1 << (iqs / 16)); + + const vec2 loadd = vec2(data_a[ib].d); + + const uint scidx0 = (is < 4) ? is : (is + 4); + const uint scidx1 = (is < 4) ? is : (is - 4); + const uint scidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint scidxshift1 = (is < 4) ? 0 : 2; + const uint mbidx0 = is + 4; + const uint mbidx1 = (is < 4) ? is + 4 : is; + const uint mbidxmask0 = (is < 4) ? 0xF : 0xF0; + const uint mbidxshift0 = (is < 4) ? 0 : 4; + const uint mbidxmask1 = (is < 4) ? 0x30 : 0xC0; + const uint mbidxshift1 = (is < 4) ? 0 : 2; + + const uint8_t sc = uint8_t((data_a[ib].scales[scidx0] & 0xF) | ((data_a[ib].scales[scidx1] & scidxmask1) >> scidxshift1)); + const uint8_t mbyte = uint8_t(((data_a[ib].scales[mbidx0] & mbidxmask0) >> mbidxshift0) | ((data_a[ib].scales[mbidx1] & mbidxmask1) >> mbidxshift1)); + + const float d = loadd.x * sc; + const float m = -loadd.y * mbyte; + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(fma(d, float((data_a[ib].qs[qsi ] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi ] & hm) != 0 ? 16 : 0), m), + fma(d, float((data_a[ib].qs[qsi + 1] >> (b * 4)) & 0xF) + float((data_a[ib].qh[qhi + 1] & hm) != 0 ? 
16 : 0), m)); +#elif defined(DATA_A_Q6_K) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint iqs = idx % 128; // 0..127 + + const uint n = iqs / 64; // 0,1 + const uint b = (iqs % 64) / 32; // 0,1 + const uint is_b = (iqs % 16) / 8; // 0,1 + const uint qhshift = ((iqs % 64) / 16) * 2; // 0,2,4,6 + const uint is = 8 * n + qhshift + is_b; // 0..15 + const uint qsi = n * 64 + (iqs % 32) * 2; // 0,2,4..126 + const uint qhi = n * 32 + (iqs % 16) * 2; // 0,2,4..62 + + const float dscale = float(data_a[ib].d) * float(data_a[ib].scales[is]); + + buf_a[buf_idx] = FLOAT_TYPE_VEC2(dscale * float(int8_t(((data_a[ib].ql[qsi ] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi ] >> qhshift) & 3) << 4)) - 32), + dscale * float(int8_t(((data_a[ib].ql[qsi + 1] >> (b * 4)) & 0xF) | (((data_a[ib].qh[qhi + 1] >> qhshift) & 3) << 4)) - 32)); +#elif defined(DATA_A_IQ1_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 32; + + const float d = float(data_a[ib].d); + const uint qh = data_a[ib].qh[ib32]; + const uint qs = data_a[ib].qs[ib8]; + const float dl = d * (2 * bitfieldExtract(qh, 12, 3) + 1); + const float delta = ((qh & 0x8000) != 0) ? -IQ1S_DELTA : IQ1S_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | (bitfieldExtract(qh, 3 * int(ib8 & 3), 3) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ1_M) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; + const uint ib16 = ib8 / 2; + + const uint16_t[4] scales = data_a[ib].scales; + const u16vec4 s = u16vec4(scales[0], scales[1], scales[2], scales[3]) >> 12; + const float d = float(unpackHalf2x16(s.x | (s.y << 4) | (s.z << 8) | (s.w << 12)).x); + const uint sc = scales[ib8 / 8]; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib16] >> (4 * (ib8 & 1)); + const float dl = d * (2 * bitfieldExtract(sc, 3 * int(ib16 & 3), 3) + 1); + const float delta = ((qh & 8) != 0) ? 
-IQ1M_DELTA : IQ1M_DELTA; + const int16_t grid = int16_t(iq1s_grid[qs | ((qh & 7) << 8)]); + + [[unroll]] for (int k = 0; k < 4; ++k) { + buf_a[buf_idx + k] = FLOAT_TYPE_VEC2(dl * (bitfieldExtract(grid, 4 * k , 2) + delta), + dl * (bitfieldExtract(grid, 4 * k + 2, 2) + delta)); + } +#elif defined(DATA_A_IQ2_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[8 * ib32 + ib8]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[8*ib32 + 4], + data_a[ib].qs[8*ib32 + 5], + data_a[ib].qs[8*ib32 + 6], + data_a[ib].qs[8*ib32 + 7] + )); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + (signs >> 28))); + const uint32_t sign7 = bitfieldExtract(signs, 7 * int(ib8), 7); + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xxs_grid[qs]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib32 = (idx % 32) / 4; // 0..7 + const uint ib8 = idx % 4; // 0..3 + + const float d = float(data_a[ib].d); + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uint qs = data_a[ib].qs[4 * ib32 + ib8]; + const uint sign7 = qs >> 9; + const uint sign = sign7 | (bitCount(sign7) << 7); + const uvec2 grid = iq2xs_grid[qs & 511]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? 
-grid1.w : grid1.w); +#elif defined(DATA_A_IQ2_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 32; // 8 values per idx + const uint ib8 = idx % 32; // 0..31 + const uint ib32 = ib8 / 4; // 0..7 + + const uint scale = (data_a[ib].scales[ib32] >> (2 * (ib8 & 2))) & 0xf; + const uint qs = data_a[ib].qs[ib8]; + const uint qh = data_a[ib].qh[ib32]; + const uint qhshift = 2 * (ib8 % 4); + const uint sign = data_a[ib].qs[QUANT_K / 8 + ib8]; + + const float d = float(data_a[ib].d); + const FLOAT_TYPE db = FLOAT_TYPE(d * 0.25 * (0.5 + scale)); + const uvec2 grid = iq2s_grid[qs | ((qh << (8 - qhshift)) & 0x300)]; + const vec4 grid0 = vec4(unpack8(grid.x)); + const vec4 grid1 = vec4(unpack8(grid.y)); + + buf_a[buf_idx ] = db * FLOAT_TYPE_VEC2((sign & 1) != 0 ? -grid0.x : grid0.x, + (sign & 2) != 0 ? -grid0.y : grid0.y); + buf_a[buf_idx + 1] = db * FLOAT_TYPE_VEC2((sign & 4) != 0 ? -grid0.z : grid0.z, + (sign & 8) != 0 ? -grid0.w : grid0.w); + buf_a[buf_idx + 2] = db * FLOAT_TYPE_VEC2((sign & 16) != 0 ? -grid1.x : grid1.x, + (sign & 32) != 0 ? -grid1.y : grid1.y); + buf_a[buf_idx + 3] = db * FLOAT_TYPE_VEC2((sign & 64) != 0 ? -grid1.z : grid1.z, + (sign & 128) != 0 ? -grid1.w : grid1.w); +#elif defined(DATA_A_IQ3_XXS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint is = QUANT_K / 4 + 4 * (iqs / 8); // 8 values + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint signs = pack32(u8vec4( + data_a[ib].qs[is+0], + data_a[ib].qs[is+1], + data_a[ib].qs[is+2], + data_a[ib].qs[is+3] + )); + const float db = d * 0.5 * (0.5 + (signs >> 28)); + const uint32_t sign7 = bitfieldExtract(signs, 7 * (int(iqs / 2) % 4), 7); + const uint sign = (sign7 | (bitCount(sign7) << 7)) >> (4 * (idx % 2)); + const uint grid = iq3xxs_grid[qs]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? -v.w : v.w); +#elif defined(DATA_A_IQ3_S) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 64; // 4 values per idx + const uint iqs = idx % 64; // 0..63 + const uint iqh = iqs / 8; + + const float d = float(data_a[ib].d); + const uint qs = data_a[ib].qs[iqs]; + const uint qh = data_a[ib].qh[iqh]; + const int8_t sign = int8_t(data_a[ib].signs[iqs / 2] >> (4 * (idx % 2))); + const uint scale = data_a[ib].scales[iqs / 16]; + const i8vec2 sign01 = i8vec2(1 - (2 & i8vec2(sign << 1, sign))); + const float db = d * (1 + 2 * ((scale >> (4 * (iqh & 1))) & 0xf)); + const uint32_t grid = iq3s_grid[qs | ((qh << (8 - (iqs % 8))) & 256)]; + const vec4 v = db * vec4(unpack8(grid)); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2((sign & 1) != 0 ? -v.x : v.x, + (sign & 2) != 0 ? -v.y : v.y); + buf_a[buf_idx + 1] = FLOAT_TYPE_VEC2((sign & 4) != 0 ? -v.z : v.z, + (sign & 8) != 0 ? 
-v.w : v.w); +#elif defined(DATA_A_IQ4_XS) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_A / 2; + + const uint ib = idx / 128; // 2 values per idx + const uint ib32 = (idx % 128) / 16; // 0..7 + const uint iq = 16 * ib32 + 2 * (idx % 8); + + const uint sl = (data_a[ib].scales_l[ib32/2] >> (4 * (ib32 & 1))) & 0xF; + const uint sh = ((data_a[ib].scales_h) >> (2 * ib32)) & 3; + const uint qshift = (idx & 8) >> 1; + u8vec2 qs = u8vec2(data_a[ib].qs[iq], data_a[ib].qs[iq + 1]); + qs = (qs >> qshift) & uint8_t(0xF); + + const float d = float(data_a[ib].d); + const vec2 v = d * float(int(sl | (sh << 4)) - 32) * vec2(kvalues_iq4nl[qs.x], kvalues_iq4nl[qs.y]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(v.xy); +#elif defined(DATA_A_IQ4_NL) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = idx & 0x07; + + const FLOAT_TYPE d = FLOAT_TYPE(data_a_packed16[ib].d); + const uint vui = uint(data_a_packed16[ib].qs[iqs]); + + buf_a[buf_idx ] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[vui & 0xF], + kvalues_iq4nl[bitfieldExtract(vui, 8, 4)]); + buf_a[buf_idx + 8] = d * FLOAT_TYPE_VEC2(kvalues_iq4nl[bitfieldExtract(vui, 4, 4)], + kvalues_iq4nl[vui >> 12]); +#elif defined(DATA_A_MXFP4) + const uint idx = pos_a + col * p.stride_a / LOAD_VEC_A + row; + const uint buf_idx = col * SHMEM_STRIDE + row; + + const uint ib = idx / 8; + const uint iqs = (idx & 0x07) * 2; + + const float d = e8m0_to_fp32(data_a[ib].e); + const uint vui = uint(data_a[ib].qs[iqs]); + const uint vui2 = uint(data_a[ib].qs[iqs+1]); + + buf_a[buf_idx ] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui & 0xF] * d, + kvalues_mxfp4[vui2 & 0xF] * d); + buf_a[buf_idx + 8] = FLOAT_TYPE_VEC2(kvalues_mxfp4[vui >> 4] * d, + kvalues_mxfp4[vui2 >> 4] * d); +#endif +} + +#if !defined(MUL_MAT_ID) +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint idx_n, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const uint idx = pos_b + col * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_BATCH_B == 2 + const uint idx = pos_b + col * p.stride_b + row * 2; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (idx_n < p.N && block + row * 2 + 1 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (idx_n < p.N && block + row * 2 < end_k) { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#else +void load_b_to_shmem(const uint pos_b, const uint row, const uint col, const uint ic, const uint _ne1, const uint block, const uint end_k) { +#if LOAD_VEC_B == 8 + // Not supported for b_type bf16 because bf16mat2x4 does not exist + const 
u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; + FLOAT_TYPE_VEC8 bb = FLOAT_TYPE_VEC8(data_b[idx]); + buf_b[buf_idx + 0] = bb[0].xy; + buf_b[buf_idx + 1] = bb[0].zw; + buf_b[buf_idx + 2] = bb[1].xy; + buf_b[buf_idx + 3] = bb[1].zw; +#elif LOAD_VEC_B == 4 + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + row; + const uint buf_idx = col * SHMEM_STRIDE + row * LOAD_VEC_B / 2; +#if defined(DATA_B_BF16) + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(TO_FLOAT_TYPE(data_b[idx])); +#else + FLOAT_TYPE_VEC4 bb = FLOAT_TYPE_VEC4(data_b[idx]); +#endif + buf_b[buf_idx + 0] = bb.xy; + buf_b[buf_idx + 1] = bb.zw; +#else // LOAD_VEC_BATCH_B == 2 + const uint row_i = ic * BN + col; + const uint buf_idx = col * SHMEM_STRIDE + row; + if (row_i < _ne1 && block + row * 2 + 1 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), + TO_FLOAT_TYPE(data_b[idx + 1])); + } else if (row_i < _ne1 && block + row * 2 < end_k) { + const u16vec2 row_idx = row_ids[col]; + const uint idx = pos_b + row_idx.y * p.batch_stride_b + (row_idx.x % p.ne11) * p.stride_b + row * 2; + buf_b[buf_idx] = FLOAT_TYPE_VEC2(TO_FLOAT_TYPE(data_b[idx]), 0.0f); + } else { + buf_b[buf_idx] = FLOAT_TYPE_VEC2(0.0f); + } +#endif +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp new file mode 100644 index 0000000000..b5d761c0ba --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq.comp @@ -0,0 +1,449 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#extension GL_EXT_integer_dot_product : require + +#ifdef FLOAT16 +#extension GL_EXT_shader_explicit_arithmetic_types_float16 : require +#endif + +#ifdef COOPMAT +#extension GL_KHR_cooperative_matrix : enable +#extension GL_KHR_memory_scope_semantics : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#ifdef MUL_MAT_ID +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#endif + +#include "types.glsl" + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE_PACKED16 data_a[];}; +#if defined(A_TYPE_PACKED32) +layout (binding = 0) readonly buffer A_PACKED32 {A_TYPE_PACKED32 data_a_packed32[];}; +#endif +layout (binding = 1) readonly buffer B {block_q8_1_x4_packed128 data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +#ifdef MUL_MAT_ID +layout (binding = 3) readonly buffer IDS {int data_ids[];}; +#endif + +layout (push_constant) uniform parameter +{ + uint M; + uint N; + uint K; + uint stride_a; + uint stride_b; + uint stride_d; + + uint batch_stride_a; + uint batch_stride_b; + uint batch_stride_d; + +#ifdef MUL_MAT_ID + uint nei0; + uint nei1; + uint nbi1; + uint ne11; +#else + uint k_split; + uint ne02; + uint ne12; + uint broadcast2; + uint broadcast3; +#endif +} p; + +layout (constant_id = 0) const uint BLOCK_SIZE = 64; +layout (constant_id = 1) const uint BM = 64; +layout 
(constant_id = 2) const uint BN = 64; +// layout (constant_id = 3) const uint BK = 32; +layout (constant_id = 4) const uint WM = 32; +layout (constant_id = 5) const uint WN = 32; +layout (constant_id = 6) const uint WMITER = 2; +layout (constant_id = 7) const uint TM = 4; +layout (constant_id = 8) const uint TN = 2; +layout (constant_id = 9) const uint TK = 1; // Only needed for coopmat +layout (constant_id = 10) const uint WARP = 32; + +#define BK 32 + +#ifdef COOPMAT +#define SHMEM_STRIDE (BK / 4 + 4) +#else +#define SHMEM_STRIDE (BK / 4 + 1) +#endif + +shared int32_t buf_a_qs[BM * SHMEM_STRIDE]; + +#ifndef COOPMAT +#if QUANT_AUXF == 1 +shared FLOAT_TYPE buf_a_dm[BM]; +#else +shared FLOAT_TYPE_VEC2 buf_a_dm[BM]; +#endif +#endif + +shared int32_t buf_b_qs[BN * SHMEM_STRIDE]; +#ifndef COOPMAT +shared FLOAT_TYPE_VEC2 buf_b_ds[BN]; +#endif + +#define LOAD_VEC_A (4 * QUANT_R) +#define LOAD_VEC_B 16 + +#ifdef MUL_MAT_ID +shared u16vec2 row_ids[4096]; +#endif // MUL_MAT_ID + +#define NUM_WARPS (BLOCK_SIZE / WARP) + +#ifdef COOPMAT +shared ACC_TYPE coopmat_stage[TM * TN * NUM_WARPS]; +#endif + +#include "mul_mmq_funcs.glsl" + +void main() { +#ifdef NEEDS_INIT_IQ_SHMEM + init_iq_shmem(gl_WorkGroupSize); +#endif + +#ifdef MUL_MAT_ID + const uint expert_idx = gl_GlobalInvocationID.z; +#else + const uint batch_idx = gl_GlobalInvocationID.z; + + const uint i13 = batch_idx / p.ne12; + const uint i12 = batch_idx % p.ne12; + + const uint i03 = i13 / p.broadcast3; + const uint i02 = i12 / p.broadcast2; + + const uint batch_idx_a = i03 * p.ne02 + i02; +#endif + + const uint blocks_m = (p.M + BM - 1) / BM; + const uint ir = gl_WorkGroupID.x % blocks_m; + const uint ik = gl_WorkGroupID.x / blocks_m; + const uint ic = gl_WorkGroupID.y; + + const uint WNITER = (WM * WN) / (WARP * TM * TN * WMITER); + const uint WSUBM = WM / WMITER; + const uint WSUBN = WN / WNITER; + +#ifdef COOPMAT + const uint warp_i = gl_SubgroupID; + + const uint tiw = gl_SubgroupInvocationID; + + const uint cms_per_row = WM / TM; + const uint cms_per_col = WN / TN; + + const uint storestride = WARP / TM; + const uint store_r = tiw % TM; + const uint store_c = tiw / TM; +#else + const uint warp_i = gl_LocalInvocationID.x / WARP; + + const uint tiw = gl_LocalInvocationID.x % WARP; + + const uint tiwr = tiw % (WSUBM / TM); + const uint tiwc = tiw / (WSUBM / TM); +#endif + + const uint warp_r = warp_i % (BM / WM); + const uint warp_c = warp_i / (BM / WM); + + const uint loadr_a = gl_LocalInvocationID.x % (BK / LOAD_VEC_A); + const uint loadc_a = gl_LocalInvocationID.x / (BK / LOAD_VEC_A); + const uint loadr_b = gl_LocalInvocationID.x % (BK / LOAD_VEC_B); + const uint loadc_b = gl_LocalInvocationID.x / (BK / LOAD_VEC_B); + + const uint loadstride_a = BLOCK_SIZE * LOAD_VEC_A / BK; + const uint loadstride_b = BLOCK_SIZE * LOAD_VEC_B / BK; + +#ifdef MUL_MAT_ID + uint _ne1 = 0; + for (uint ii1 = 0; ii1 < p.nei1; ii1++) { + for (uint ii0 = 0; ii0 < p.nei0; ii0++) { + if (data_ids[ii1*p.nbi1 + ii0] == expert_idx) { + row_ids[_ne1] = u16vec2(ii0, ii1); + _ne1++; + } + } + } + + barrier(); + + // Workgroup has no work + if (ic * BN >= _ne1) return; +#endif + +#ifdef MUL_MAT_ID + const uint start_k = 0; + const uint end_k = p.K; +#else + const uint start_k = ik * p.k_split; + const uint end_k = min(p.K, (ik + 1) * p.k_split); +#endif + + uint pos_a_ib = ( +#ifdef MUL_MAT_ID + expert_idx * p.batch_stride_a + +#else + batch_idx_a * p.batch_stride_a + +#endif + ir * BM * p.stride_a + start_k) / BK; +#ifdef MUL_MAT_ID + uint pos_b_ib = 0; +#else + uint 
pos_b_ib = (batch_idx * p.batch_stride_b + ic * BN * p.stride_b + start_k) / BK; +#endif + +#ifdef COOPMAT + coopmat cache_a; + coopmat cache_b; + coopmat cm_result; + + coopmat factors[cms_per_row * cms_per_col]; + + coopmat sums[cms_per_row * cms_per_col]; + + [[unroll]] for (uint i = 0; i < cms_per_row * cms_per_col; i++) { + sums[i] = coopmat(0.0f); + } +#else + int32_t cache_a_qs[WMITER * TM * BK / 4]; + + int32_t cache_b_qs[TN * BK / 4]; + + ACC_TYPE sums[WMITER * TM * WNITER * TN]; + + [[unroll]] for (uint i = 0; i < WMITER*TM*WNITER*TN; i++) { + sums[i] = ACC_TYPE(0.0f); + } +#endif + +#if QUANT_AUXF == 1 + FLOAT_TYPE cache_a_dm[WMITER * TM]; +#else + FLOAT_TYPE_VEC2 cache_a_dm[WMITER * TM]; +#endif + + FLOAT_TYPE_VEC2 cache_b_ds[TN]; + + for (uint block = start_k; block < end_k; block += BK) { + [[unroll]] for (uint l = 0; loadc_a + l < BM; l += loadstride_a) { + const uint ib = pos_a_ib + (loadc_a + l) * p.stride_a / BK; + const uint iqs = loadr_a; + const uint buf_ib = loadc_a + l; + + if (iqs == 0) { +#if QUANT_AUXF == 1 + buf_a_dm[buf_ib] = get_d(ib); +#else + buf_a_dm[buf_ib] = get_dm(ib); +#endif + } +#if QUANT_R == 1 + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs] = repack(ib, iqs); +#else + const i32vec2 vals = repack(ib, iqs); + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs ] = vals.x; + buf_a_qs[buf_ib * SHMEM_STRIDE + iqs + 4] = vals.y; +#endif + } + [[unroll]] for (uint l = 0; loadc_b + l < BN; l += loadstride_b) { +#ifdef MUL_MAT_ID + const u16vec2 row_idx = row_ids[ic * BN + loadc_b + l]; + const uint idx = pos_b_ib + row_idx.y * p.batch_stride_b / LOAD_VEC_B + (row_idx.x % p.ne11) * p.stride_b / LOAD_VEC_B + loadr_b; + const uint ib = idx / 8; + const uint iqs = idx & 0x7; +#else + const uint ib = pos_b_ib + (loadc_b + l) * p.stride_b / BK; + const uint ib_outer = ib / 4; + const uint ib_inner = ib % 4; + + const uint iqs = loadr_b; +#endif + + const uint buf_ib = loadc_b + l; + + if (iqs == 0) { + buf_b_ds[buf_ib] = FLOAT_TYPE_VEC2(data_b[ib_outer].ds[ib_inner]); + } + const ivec4 values = data_b[ib_outer].qs[ib_inner * 2 + iqs]; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 ] = values.x; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 1] = values.y; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 2] = values.z; + buf_b_qs[buf_ib * SHMEM_STRIDE + iqs * 4 + 3] = values.w; + } + + barrier(); + + pos_a_ib += 1; + pos_b_ib += 1; + +#ifdef COOPMAT + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + const uint ib_a = warp_r * WM + cm_row * TM; + // Load from shared into cache + coopMatLoad(cache_a, buf_a_qs, ib_a * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutRowMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TM; t_idx++) { + cache_a_dm[t_idx] = buf_a_dm[ib_a + t_idx]; + } + + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const uint ib_b = warp_c * WN + cm_col * TN; + coopMatLoad(cache_b, buf_b_qs, ib_b * SHMEM_STRIDE, SHMEM_STRIDE, gl_CooperativeMatrixLayoutColumnMajor); + + // TODO: only cache values that are actually needed + [[unroll]] for (uint t_idx = 0; t_idx < TN; t_idx++) { + cache_b_dm[t_idx] = buf_b_d[ib_b + t_idx]; + } + + cm_result = coopmat(0); + cm_result = coopMatMulAdd(cache_a, cache_b, cm_result); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + coopmat_stage[warp_i * TM * TN + (store_c + col) * TM + store_r] = ACC_TYPE(float(cache_a_d[store_r]) * float(cache_b_d[store_c + col])); + } + + coopMatLoad(factors, coopmat_stage, warp_i * TM * TN, TM, 
gl_CooperativeMatrixLayoutColumnMajor); + sums[cm_col * cms_per_row + cm_row] += factors * coopmat(cm_result); + } + } +#else + // Load from shared into cache + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint ib = warp_r * WM + wsir * WSUBM + tiwr * TM + cr; + cache_a_dm[wsir * TM + cr] = buf_a_dm[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_a_qs[(wsir * TM + cr) * (BK / 4) + idx_k] = buf_a_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + } + + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + const uint ib = warp_c * WN + wsic * WSUBN + tiwc * TN + cc; + cache_b_ds[cc] = buf_b_ds[ib]; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + cache_b_qs[cc * (BK / 4) + idx_k] = buf_b_qs[ib * SHMEM_STRIDE + idx_k]; + } + } + + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + [[unroll]] for (uint cc = 0; cc < TN; cc++) { + [[unroll]] for (uint cr = 0; cr < TM; cr++) { + const uint cache_a_idx = wsir * TM + cr; + const uint sums_idx = (wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr; + int32_t q_sum = 0; + [[unroll]] for (uint idx_k = 0; idx_k < BK / 4; idx_k++) { + q_sum += dotPacked4x8EXT(cache_a_qs[cache_a_idx * (BK / 4) + idx_k], + cache_b_qs[cc * (BK / 4) + idx_k]); + } + + sums[sums_idx] += mul_q8_1(q_sum, cache_a_dm[cache_a_idx], cache_b_ds[cc], 1); + } + } + } + } +#endif + + barrier(); + } + + const uint dr = ir * BM + warp_r * WM; + const uint dc = ic * BN + warp_c * WN; + +#ifndef MUL_MAT_ID + const uint offsets = batch_idx * p.batch_stride_d + ik * p.batch_stride_d * gl_NumWorkGroups.z; +#endif + +#ifdef COOPMAT +#ifdef MUL_MAT_ID + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < BN; col += storestride) { + const uint row_i = dc + cm_col * TN + col + store_c; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; + + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } +#else + const bool is_aligned = p.stride_d % 4 == 0; // Assumption: D_TYPE == float + + [[unroll]] for (uint cm_row = 0; cm_row < cms_per_row; cm_row++) { + [[unroll]] for (uint cm_col = 0; cm_col < cms_per_col; cm_col++) { + const bool is_in_bounds = dr + (cm_row + 1) * TM <= p.M && dc + (cm_col + 1) * TN <= p.N; + + if (is_aligned && is_in_bounds) { + // Full coopMat is within bounds and stride_d is aligned with 16B + coopmat cm_dtype = coopmat(sums[cm_col * cms_per_row + cm_row]); + coopMatStore(cm_dtype, data_d, offsets + (dc + cm_col * TN) * p.stride_d + dr + cm_row * TM, p.stride_d, gl_CooperativeMatrixLayoutColumnMajor); + } else if (is_in_bounds) { + // Full coopMat is within bounds, but stride_d is not aligned + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } else if (dr + cm_row * TM < p.M && dc + cm_col * TN < p.N) { + // Partial 
coopMat is within bounds + coopMatStore(sums[cm_col * cms_per_row + cm_row], coopmat_stage, warp_i * TM * TN, TM, gl_CooperativeMatrixLayoutColumnMajor); + + [[unroll]] for (uint col = 0; col < TN; col += storestride) { + if (dr + cm_row * TM + store_r < p.M && dc + cm_col * TN + col + store_c < p.N) { + data_d[offsets + (dc + cm_col * TN + col + store_c) * p.stride_d + dr + cm_row * TM + store_r] = D_TYPE(coopmat_stage[warp_i * TM * TN + (col + store_c) * TM + store_r]); + } + } + } + } + } +#endif // MUL_MAT_ID +#else + [[unroll]] for (uint wsic = 0; wsic < WNITER; wsic++) { + [[unroll]] for (uint wsir = 0; wsir < WMITER; wsir++) { + + const uint dr_warp = dr + wsir * WSUBM + tiwr * TM; + const uint dc_warp = dc + wsic * WSUBN + tiwc * TN; + [[unroll]] for (uint cc = 0; cc < TN; cc++) { +#ifdef MUL_MAT_ID + const uint row_i = dc_warp + cc; + if (row_i >= _ne1) break; + + const u16vec2 row_idx = row_ids[row_i]; +#endif // MUL_MAT_ID + [[unroll]] for (uint cr = 0; cr < TM; cr++) { +#ifdef MUL_MAT_ID + data_d[row_idx.y * p.batch_stride_d + row_idx.x * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); +#else + if (dr_warp + cr < p.M && dc_warp + cc < p.N) { + data_d[offsets + (dc_warp + cc) * p.stride_d + dr_warp + cr] = D_TYPE(sums[(wsic * TN + cc) * (WMITER * TM) + wsir * TM + cr]); + } +#endif // MUL_MAT_ID + } + } + } + } +#endif // COOPMAT +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl new file mode 100644 index 0000000000..fe71eb131c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/mul_mmq_funcs.glsl @@ -0,0 +1,105 @@ +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require + +#include "types.glsl" + +// Each iqs value maps to a 32-bit integer + +#if defined(DATA_A_Q4_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q4_0 block (18 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (8 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q4_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q4_1 block (20 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + return i32vec2( vui & 0x0F0F0F0F, + (vui >> 4) & 0x0F0F0F0F); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q5_0) +i32vec2 repack(uint ib, uint iqs) { + // Use 2-byte loads since a q5_0 block (22 bytes) is not divisible by 4 + const u16vec2 quants = u16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1]); + const uint32_t vui = pack32(quants); + const int32_t qh = int32_t((uint32_t(data_a[ib].qh[1]) << 16 | data_a[ib].qh[0]) >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 
0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(da * (float(q_sum) * dsb.x - (16 / sum_divisor) * dsb.y)); +} +#endif + +#if defined(DATA_A_Q5_1) +i32vec2 repack(uint ib, uint iqs) { + // Use 4-byte loads since a q5_1 block (24 bytes) is divisible by 4 + const uint32_t vui = data_a_packed32[ib].qs[iqs]; + const int32_t qh = int32_t(data_a_packed32[ib].qh >> (4 * iqs)); + const int32_t v0 = int32_t(vui & 0x0F0F0F0F) + | ((qh & 0xF) * 0x02040810) & 0x10101010; // (0,1,2,3) -> (4,12,20,28) + + const int32_t v1 = int32_t((vui >> 4) & 0x0F0F0F0F) + | (((qh >> 16) & 0xF) * 0x02040810) & 0x10101010; // (16,17,18,19) -> (4,12,20,28) + + return i32vec2(v0, v1); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const vec2 dma, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * dma.x * dsb.x + dma.y * dsb.y / sum_divisor); +} +#endif + +#if defined(DATA_A_Q8_0) +int32_t repack(uint ib, uint iqs) { + // Use 2-byte loads since a q8_0 block (34 bytes) is not divisible by 4 + return pack32(i16vec2(data_a[ib].qs[iqs * 2 ], + data_a[ib].qs[iqs * 2 + 1])); +} + +ACC_TYPE mul_q8_1(const int32_t q_sum, const float da, const vec2 dsb, const int32_t sum_divisor) { + return ACC_TYPE(float(q_sum) * da * dsb.x); +} +#endif + +#if defined(DATA_A_Q4_0) || defined(DATA_A_Q5_0) || defined(DATA_A_Q8_0) || defined(DATA_A_IQ1_S) || defined(DATA_A_IQ2_XXS) || defined(DATA_A_IQ2_XS) || defined(DATA_A_IQ2_S) || defined(DATA_A_IQ3_XXS) || defined(DATA_A_IQ3_S) || defined(DATA_A_IQ4_XS) || defined(DATA_A_IQ4_NL) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(data_a[ib].d); +} +#endif + +#if defined(DATA_A_MXFP4) +FLOAT_TYPE get_d(uint ib) { + return FLOAT_TYPE(e8m0_to_fp32(data_a[ib].e)); +} +#endif + +#if defined(DATA_A_Q4_1) || defined(DATA_A_Q5_1) +FLOAT_TYPE_VEC2 get_dm(uint ib) { + return FLOAT_TYPE_VEC2(data_a_packed32[ib].dm); +} +#endif diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp new file mode 100644 index 0000000000..1e8f694a72 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/multi_add.comp @@ -0,0 +1,111 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require +#extension GL_EXT_nonuniform_qualifier : enable +#extension GL_EXT_control_flow_attributes : require +#if ADD_RMS +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable +#endif + +#include "rte.glsl" +#include "types.glsl" +#include "utils.glsl" + +layout (push_constant) uniform parameter2 +{ + // shape for dst + uint ne20; uint ne21; uint ne22; uint ne23; + + // strides for srcs+dst + uint nb[12][4]; + + uint rms_partials; +} p; + +// Workaround for MoltenVK Bug, see https://github.com/ggml-org/llama.cpp/issues/15498 +// layout (binding = 0) readonly buffer A {A_TYPE data_a[];} a[]; +// layout (binding = 0) writeonly buffer D {D_TYPE data_d[];} d[]; +layout (binding = 0) buffer A {A_TYPE data_a[];} a[]; +layout (binding = 0) buffer D {D_TYPE data_d[];} d[]; + +layout (binding = 0, std430) buffer PartialBuf {float partial_sums[];} partials[]; + +layout(constant_id = 0) const uint num_srcs = 2; + +uint src_idx(uint s, uint i00, uint i01, uint i02, uint i03) { + return i03*p.nb[s][3] + i02*p.nb[s][2] + i01*p.nb[s][1] + i00*p.nb[s][0]; +} + +uint dst_idx(uint i00, uint i01, uint i02, uint i03) { + uint 
nb20 = p.nb[num_srcs][0]; + uint nb21 = p.nb[num_srcs][1]; + uint nb22 = p.nb[num_srcs][2]; + uint nb23 = p.nb[num_srcs][3]; + return i03*nb23 + i02*nb22 + i01*nb21 + i00*nb20; +} + +uint get_idx() { + return gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; +} + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +#if ADD_RMS +// XXX TODO this could be sized based on number of subgroups, but that't not considered a constant +shared FLOAT_TYPE sumsh[num_threads]; +#endif + +void main() { + uint idx = get_idx(); + uint orig_idx = idx; + + uint ne = p.ne20 * p.ne21 * p.ne22 * p.ne23; + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + FLOAT_TYPE sum_sq = 0; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03, p.ne20, p.ne21, p.ne22, p.ne23); + + FLOAT_TYPE sum = FLOAT_TYPE(0); + [[unroll]] for (uint s = 0; s < num_srcs; ++s) { + sum += FLOAT_TYPE(a[s].data_a[src_idx(s, i00, i01, i02, i03)]); + } + sum_sq += sum*sum; + d[num_srcs].data_d[dst_idx(i00, i01, i02, i03)] = D_TYPE(sum); + + idx += num_threads; + } + +#if ADD_RMS + if (p.rms_partials != 0) { + // reduce the sum within each subgroup, then across subgroups + const uint NumSubgroups = num_threads / gl_SubgroupSize; + sum_sq = subgroupAdd(sum_sq); + if (gl_SubgroupInvocationID == 0) { + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + [[unroll]] for (uint s = NumSubgroups / 2; s > 0; s >>= 1) { + if (gl_SubgroupID < s && gl_SubgroupInvocationID == 0) { + sum_sq += sumsh[gl_SubgroupID + s]; + sumsh[gl_SubgroupID] = sum_sq; + } + barrier(); + } + + if (gl_SubgroupID == 0 && gl_SubgroupInvocationID == 0) { + partials[num_srcs + 1].partial_sums[orig_idx / (num_iter * num_threads)] = sum_sq; + } + } +#endif +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp new file mode 100644 index 0000000000..cc3ea0b760 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/norm.comp @@ -0,0 +1,44 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +shared vec2 sum[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + sum[tid] = vec2(0.0f, 0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const float xi = float(data_a[row*p.KX + col]); + sum[tid].x += xi; + sum[tid].y += xi * xi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum[tid] += sum[tid + s]; + } + barrier(); + } + + const float mean = sum[0].x / p.KX; + const float var = sum[0].y / p.KX - mean * mean; + const float inv_std = inversesqrt(var + p.param1); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE((float(data_a[row*p.KX + col]) - mean) * inv_std); + } +} diff --git 
a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp new file mode 100644 index 0000000000..1f05f922cc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_adamw.comp @@ -0,0 +1,42 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE x[];}; +layout (binding = 1) readonly buffer G {A_TYPE grad[];}; +layout (binding = 2) buffer GM {A_TYPE gradm[];}; +layout (binding = 3) buffer GV {A_TYPE gradv[];}; +layout (binding = 4) readonly buffer P {float params[7];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = params[0]; + const float beta1 = params[1]; + const float beta2 = params[2]; + const float eps = params[3]; + const float wd = params[4]; + const float beta1h = params[5]; + const float beta2h = params[6]; + + const float gi = grad[i]; + const float gmi = gradm[i]*beta1 + gi*(1.0f - beta1); + const float gvi = gradv[i]*beta2 + gi*gi*(1.0f - beta2); + + gradm[i] = gmi; + gradv[i] = gvi; + + const float mh = gmi*beta1h; + const float vh = sqrt(gvi*beta2h) + eps; + + x[i] = x[i]*(1.0f - alpha*wd) - alpha*mh/vh; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp new file mode 100644 index 0000000000..1251f9cc64 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/opt_step_sgd.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) buffer X {A_TYPE data_x[];}; +layout (binding = 1) readonly buffer G {A_TYPE data_grad[];}; +layout (binding = 2) readonly buffer P {float data_params[2];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float alpha = data_params[0]; + const float keep = 1.f - alpha * data_params[1]; + + data_x[i] = data_x[i] * keep - alpha * data_grad[i]; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp new file mode 100644 index 0000000000..f3c8176872 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pad.comp @@ -0,0 +1,49 @@ +#version 450 + +#include "types.glsl" + +layout (push_constant) uniform parameter +{ + uint ne; + uint ne00; uint ne01; uint ne02; uint ne03; uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; uint nb10; uint nb11; uint nb12; uint nb13; + uint misalign_offsets; + + uint lp0; uint rp0; + uint lp1; uint rp1; + uint lp2; uint rp2; + uint lp3; uint rp3; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + 
const uint i3 = idx / (p.ne12*p.ne11*p.ne10); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = (idx - i3_offset) / (p.ne11*p.ne10); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = (idx - i3_offset - i2_offset) / p.ne10; + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint src0_idx = (i3 - p.lp3)*p.nb03 + (i2 - p.lp2)*p.nb02 + (i1 - p.lp1)*p.nb01 + (i0 - p.lp0)*p.nb00; + const uint dst_idx = i3*p.nb13 + i2*p.nb12 + i1*p.nb11 + i0*p.nb10; + + const bool is_src0 = i0 >= p.lp0 && i0 < p.ne10 - p.rp0 && + i1 >= p.lp1 && i1 < p.ne11 - p.rp1 && + i2 >= p.lp2 && i2 < p.ne12 - p.rp2 && + i3 >= p.lp3 && i3 < p.ne13 - p.rp3; + + data_d[get_doffset() + dst_idx] = D_TYPE(is_src0 ? data_a[get_aoffset() + src0_idx] : 0.0f); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp new file mode 100644 index 0000000000..d9d7166e36 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/pool2d.comp @@ -0,0 +1,74 @@ +#version 450 + +#include "types.glsl" + +#extension GL_EXT_shader_16bit_storage : require + +layout(push_constant) uniform parameter { + uint IW; uint IH; + uint OW; uint OH; + uint OC; + uint pelements; + uint op; + int k0; int k1; + int s0; int s1; + int p0; int p1; +} p; + +#define BLOCK_SIZE 512 +#define FLT_MAX 3.402823466e+38F +#define OP_POOL_MAX 0u +#define OP_POOL_AVG 1u + +layout (local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout(binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint idx = gl_GlobalInvocationID.x; + if (idx >= p.pelements) { + return; + } + + const uint O_HW = p.OW * p.OH; + + const uint nc = idx / O_HW; + const uint cur_oh = (idx % O_HW) / p.OW; + const uint cur_ow = (idx % O_HW) % p.OW; + + const int start_h = int(cur_oh) * p.s0 - p.p0; + const uint bh = max(start_h, 0); + const uint eh = min(start_h + p.k0, p.IH); + + const int start_w = int(cur_ow) * p.s1 - p.p1; + const uint bw = max(start_w, 0); + const uint ew = min(start_w + p.k1, p.IW); + + const float scale = 1.0 / float(p.k0 * p.k1); + float res; + + if (p.op == OP_POOL_AVG) { + res = 0.0; + } else if (p.op == OP_POOL_MAX) { + res = -FLT_MAX; + } else { + return; + } + + #pragma unroll + for (uint i = bh; i < eh; i++) { + #pragma unroll + for (uint j = bw; j < ew; j++) { + const float cur = D_TYPE(data_a[nc * p.IH * p.IW + i * p.IW + j]); + + if (p.op == OP_POOL_AVG) { + res += cur * scale; + } else if (p.op == OP_POOL_MAX) { + res = max(res, cur); + } + } + } + + data_d[nc * O_HW + cur_oh * p.OW + cur_ow] = res; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp new file mode 100644 index 0000000000..0f3c6ca871 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/quantize_q8_1.comp @@ -0,0 +1,127 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require +#extension GL_EXT_shader_16bit_storage : require + +#ifdef USE_SUBGROUPS +#extension GL_KHR_shader_subgroup_basic : require +#extension GL_KHR_shader_subgroup_clustered : require + +#define INVOCATION_ID gl_SubgroupInvocationID.x +#else +#define INVOCATION_ID gl_LocalInvocationID.x +#endif + +layout (push_constant) uniform parameter +{ + uint ne; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint GROUP_SIZE = 32; 
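+// Summary of the q8_1 quantization performed below, per 32-value block
+// (derived from this shader's code):
+//   d   = max(|x_i|) / 127
+//   q_i = round(x_i / d)       (packed four int8 per 32-bit word)
+//   ds  = (d, d * sum(q_i))    (per-block fp16 scale and scaled sum)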
+layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {vec4 data_a[];}; +#ifndef QBLOCK_X4 +layout (binding = 1) writeonly buffer D {block_q8_1_packed32 data_b[];}; +#else +layout (binding = 1) writeonly buffer D {block_q8_1_x4 data_b[];}; +#endif + +#ifndef USE_SUBGROUPS +shared float shmem[GROUP_SIZE]; +#endif + +void quantize() { + const uint wgid = gl_WorkGroupID.x; + const uint tid = INVOCATION_ID; + + // Each thread handles a vec4, so 8 threads handle a block + const uint blocks_per_group = GROUP_SIZE / 8; + + const uint block_in_wg = tid / 8; + + const uint ib = wgid * blocks_per_group + block_in_wg; + const uint iqs = tid % 8; + +#ifndef QBLOCK_X4 + if (ib >= gl_NumWorkGroups.x * blocks_per_group) { + return; + } +#else + const uint ibx4_outer = ib / 4; + const uint ibx4_inner = ib % 4; + + const uint required_x4_blocks = (p.ne + 127) / 128; + if (ibx4_outer >= required_x4_blocks) { + return; + } +#endif + + const uint a_idx = ib * 8 + iqs; + + vec4 vals = a_idx < p.ne ? data_a[a_idx] : vec4(0.0f); + const vec4 abs_vals = abs(vals); + + // Find absolute max for each block + const float thread_max = max(max(abs_vals.x, abs_vals.y), max(abs_vals.z, abs_vals.w)); +#ifndef USE_SUBGROUPS + shmem[tid] = thread_max; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] = max(shmem[tid], shmem[tid + s]); + } + barrier(); + } + + const float amax = shmem[block_in_wg * 8]; +#else + const float amax = subgroupClusteredMax(thread_max, 8); +#endif + + const float d = amax / 127.0; + const float d_inv = d != 0.0 ? 1.0 / d : 0.0; + vals = round(vals * d_inv); + +#ifndef QBLOCK_X4 + data_b[ib].qs[iqs] = pack32(i8vec4(round(vals))); +#else + data_b[ibx4_outer].qs[ibx4_inner * 8 + iqs] = pack32(i8vec4(round(vals))); +#endif + +#ifndef USE_SUBGROUPS + barrier(); +#endif + + // Calculate the sum for each block + const float thread_sum = vals.x + vals.y + vals.z + vals.w; +#ifndef USE_SUBGROUPS + shmem[tid] = thread_sum; + barrier(); + [[unroll]] for (uint s = 4; s > 0; s >>= 1) { + if (iqs < s) { + shmem[tid] += shmem[tid + s]; + } + barrier(); + } +#else + const float sum = subgroupClusteredAdd(thread_sum, 8); +#endif + if (iqs == 0) { +#ifndef USE_SUBGROUPS + const float sum = shmem[tid]; +#endif + +#ifndef QBLOCK_X4 + data_b[ib].ds = f16vec2(vec2(d, sum * d)); +#else + data_b[ibx4_outer].ds[ibx4_inner] = f16vec2(vec2(d, sum * d)); +#endif + } +} + +void main() { + quantize(); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp new file mode 100644 index 0000000000..86be2669a1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/reglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.glsl" + +float op(float a, float b) { + return max(a, 0.0f) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp new file mode 100644 index 0000000000..5725cef236 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/relu.comp @@ -0,0 +1,21 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + 
const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + data_d[i] = D_TYPE(max(float(data_a[i]), 0)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp new file mode 100644 index 0000000000..8f4b9a8684 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint src0_idx_mod(uint idx) { + const uint i13 = idx / (p.ne12*p.ne11*p.ne10); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = (idx - i13_offset) / (p.ne11*p.ne10); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = (idx - i13_offset - i12_offset) / p.ne10; + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + return (i13 % p.ne03)*p.nb03 + (i12 % p.ne02)*p.nb02 + (i11 % p.ne01)*p.nb01 + (i10 % p.ne00)*p.nb00; +} + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(data_a[get_aoffset() + src0_idx_mod(idx)]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp new file mode 100644 index 0000000000..87df782944 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/repeat_back.comp @@ -0,0 +1,37 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + // Destination multi-index (inlined dst_idx) + const uint i13 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i13_offset = i13 * p.ne12*p.ne11*p.ne10; + const uint i12 = fastdiv(idx - i13_offset, p.ne1_01mp, p.ne1_01L); + const uint i12_offset = i12*p.ne11*p.ne10; + const uint i11 = fastdiv(idx - i13_offset - i12_offset, p.ne1_0mp, p.ne1_0L); + const uint i10 = idx - i13_offset - i12_offset - i11*p.ne10; + const uint d_idx = i13*p.nb13 + i12*p.nb12 + i11*p.nb11 + i10*p.nb10; + + // Accumulate from sources + A_TYPE acc = A_TYPE(0); + for (uint i3 = i13; i3 < p.ne03; i3 += p.ne13) { + for (uint i2 = i12; i2 < p.ne02; i2 += p.ne12) { + for (uint i1 = i11; i1 < p.ne01; i1 += p.ne11) { + for (uint i0 = i10; i0 < p.ne00; i0 += p.ne10) { + acc += data_a[i3*p.nb03 + i2*p.nb02 + i1*p.nb01 + i0*p.nb00]; + } + } + } + } + + data_d[get_doffset() + d_idx] = D_TYPE(acc); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp new file mode 100644 index 0000000000..d5b211ffaa --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm.comp @@ -0,0 +1,105 @@ +#version 450 + +#include "generic_binary_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void rms_norm(uint num_iters) { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = gl_WorkGroupID.x; + const uint channel = 
gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + const uint tid = gl_LocalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + FLOAT_TYPE xi = FLOAT_TYPE(0); + if (col < ncols) { + xi = FLOAT_TYPE(data_a[a_offset + col]); + } + sum += xi * xi; + } + + sumsh[tid] = sum; + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum += sumsh[tid + s]; + sumsh[tid] = sum; + } + barrier(); + } + sum = sumsh[0]; + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } + } else { + [[unroll]] for (uint col = tid, idx = 0; idx < num_iters; col += BLOCK_SIZE, ++idx) { + if (col >= ncols) { + continue; + } + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } + } +} + +void main() { + // instantiate the rms_norm function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.ne00 + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + rms_norm(num_blocks); + } else if (num_blocks > 16) { + rms_norm(32); + } else if (num_blocks > 8) { + rms_norm(16); + } else if (num_blocks > 4) { + rms_norm(8); + } else if (num_blocks == 4) { + rms_norm(4); + } else if (num_blocks == 3) { + rms_norm(3); + } else if (num_blocks == 2) { + rms_norm(2); + } else if (num_blocks == 1) { + rms_norm(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp new file mode 100644 index 0000000000..87707fc149 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_back.comp @@ -0,0 +1,55 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 512 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_b[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_xx[BLOCK_SIZE]; +shared FLOAT_TYPE sum_xg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + // Compute derivative of x[i]/norm(x) = g[i]/norm(x) - x[i] dot(x,g)/KX / norm(x)^1.5 + + // partial sums for thread 
in warp + sum_xx[tid] = FLOAT_TYPE(0.0f); + sum_xg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_a[row*p.KX + col]); + const FLOAT_TYPE xi = FLOAT_TYPE(data_b[row*p.KX + col]); + sum_xx[tid] += xi * xi; + sum_xg[tid] += xi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (int s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_xx[tid] += sum_xx[tid + s]; + sum_xg[tid] += sum_xg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE eps = FLOAT_TYPE(p.param1); + const FLOAT_TYPE mean = sum_xx[0] / FLOAT_TYPE(p.KX); + const FLOAT_TYPE scale_g = inversesqrt(mean + eps); + const FLOAT_TYPE scale_x = -scale_g * sum_xg[0] / (sum_xx[0] + FLOAT_TYPE(p.KX) * eps); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE( + scale_g * FLOAT_TYPE(data_a[row*p.KX + col]) + + scale_x * FLOAT_TYPE(data_b[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp new file mode 100644 index 0000000000..4618b2c7e8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rms_norm_partials.comp @@ -0,0 +1,65 @@ +#version 450 + +#include "generic_binary_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#extension GL_KHR_shader_subgroup_arithmetic : enable +#extension GL_KHR_shader_subgroup_basic : enable + +#define BLOCK_SIZE 128 + +layout (constant_id = 1) const bool do_multiply = false; + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 3, std430) readonly buffer PartialsBuf {float partial_sums[];}; + +shared FLOAT_TYPE sumsh[BLOCK_SIZE]; + +void main() { + const uint ncols = p.ne00; + const uint nrows = gl_NumWorkGroups.x; + const uint nchannels = gl_NumWorkGroups.y; + + const uint row = 0; + const uint channel = gl_WorkGroupID.y; + const uint samp = gl_WorkGroupID.z; + // The work is split across multiple workgroups in the x dimension. 
Each invocation + // processes one element + const uint tid = gl_GlobalInvocationID.x; + + const uint stride_row = p.nb01; + const uint stride_channel = p.nb02; + const uint stride_sample = p.nb03; + + uint32_t a_offset = samp*stride_sample + channel*stride_channel + row*stride_row + get_aoffset(); + uint32_t b_offset = src1_idx(0, row, channel, samp) + get_boffset(); + uint32_t d_offset = ((samp*nchannels + channel)*nrows + row)*ncols + get_doffset(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); // partial sum for thread in warp + + uint32_t num_partials = p.param3; + for (uint32_t i = gl_SubgroupInvocationID; i < num_partials; i += gl_SubgroupSize) { + sum += partial_sums[i]; + } + sum = subgroupAdd(sum); + + uint col = tid; + if (col >= ncols) { + return; + } + + const FLOAT_TYPE mean = sum / FLOAT_TYPE(ncols); + const FLOAT_TYPE scale = inversesqrt(mean + FLOAT_TYPE(p.param1)); + + if (do_multiply) { + if (ncols > p.ne10) { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + fastmod(col, p.ne10)])); + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col]) * FLOAT_TYPE(data_b[b_offset + col])); + } + } else { + data_d[d_offset + col] = D_TYPE(scale * FLOAT_TYPE(data_a[a_offset + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp new file mode 100644 index 0000000000..68fbd0c7be --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/roll.comp @@ -0,0 +1,46 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +uint wrap_idx(int i, uint ne) { + if (i < 0) { + return i + ne; + } else if (i >= ne) { + return i - ne; + } + return i; +} + +void main() { + const uint idx = get_idx(); + if (idx >= p.ne) { + return; + } + + const uint i3 = fastdiv(idx, p.ne1_012mp, p.ne1_012L); + const uint i3_offset = i3 * p.ne12*p.ne11*p.ne10; + const uint i2 = fastdiv(idx - i3_offset, p.ne1_01mp, p.ne1_01L); + const uint i2_offset = i2*p.ne11*p.ne10; + const uint i1 = fastdiv(idx - i3_offset - i2_offset, p.ne1_0mp, p.ne1_0L); + const uint i0 = idx - i3_offset - i2_offset - i1*p.ne10; + + const uint p1 = floatBitsToUint(p.param1); + const uint p2 = floatBitsToUint(p.param2); + const int s0 = int(p1 >> 16) - 0x8000; + const int s1 = int(p1 & 0xFFFF) - 0x8000; + const int s2 = int(p2 >> 16) - 0x8000; + const int s3 = int(p2 & 0xFFFF) - 0x8000; + + const uint i00 = wrap_idx(int(i0) - s0, p.ne10); + const uint i01 = wrap_idx(int(i1) - s1, p.ne11); + const uint i02 = wrap_idx(int(i2) - s2, p.ne12); + const uint i03 = wrap_idx(int(i3) - s3, p.ne13); + + const uint a_idx = i03*p.nb03 + i02*p.nb02 + i01*p.nb01 + i00*p.nb00; + const uint d_idx = i3 *p.nb13 + i2 *p.nb12 + i1 *p.nb11 + i0 *p.nb10; + + data_d[get_doffset() + d_idx] = D_TYPE(data_a[get_aoffset() + a_idx]); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl new file mode 100644 index 0000000000..50fc1f1e2d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_head.glsl @@ -0,0 +1,55 @@ +#include "types.glsl" + +#extension GL_EXT_shader_16bit_storage : require + +#include "rte.glsl" + +layout(local_size_x = 1, local_size_y = 256, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly 
buffer Y {int data_pos[];}; +layout (binding = 2) readonly buffer Z {float data_ff[];}; +layout (binding = 3) writeonly buffer D {D_TYPE data_d[];}; + +layout (push_constant) uniform parameter { + uint ncols; + uint n_dims; + float freq_scale; + uint p_delta_rows; + float freq_base; + float ext_factor; + float attn_factor; + float corr_dims[2]; + float theta_scale; + uint has_ff; + uint ne02; + uint s1; + uint s2; + int sections[4]; + uint is_back; +} p; + +float rope_yarn_ramp(const float low, const float high, const uint i0) { + const float y = (i0 / 2 - low) / max(0.001f, high - low); + return 1.0f - min(1.0f, max(0.0f, y)); +} + +void rope_yarn(const float theta_extrap, const uint i0, out float cos_theta, out float sin_theta) { + float mscale = p.attn_factor; + // Get n-d rotational scaling corrected for extrapolation + float theta_interp = p.freq_scale * theta_extrap; + float theta = theta_interp; + if (p.ext_factor != 0.0f) { + float ramp_mix = rope_yarn_ramp(p.corr_dims[0], p.corr_dims[1], i0) * p.ext_factor; + theta = theta_interp * (1 - ramp_mix) + theta_extrap * ramp_mix; + + // Get n-d magnitude scaling corrected for interpolation + mscale *= 1.0f + 0.1f * log(1.0f / p.freq_scale); + } + // Backprogagation uses inverted rotation + if (p.is_back != 0) { + theta = -theta; + } + cos_theta = cos(theta) * mscale; + sin_theta = sin(theta) * mscale; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp new file mode 100644 index 0000000000..111286b498 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -0,0 +1,58 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + + const int sect_dims = p.sections[0] + p.sections[1] + p.sections[2] + p.sections[3]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= p.sections[0] && sector < sec_w) { + theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); + } + else if (sector >= sec_w + p.sections[2]) { + theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); + } + + const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp new file mode 100644 index 0000000000..06e095bef9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_neox.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + if (i0 >= p.n_dims) { + data_d[idst + i0/2 + 0] = data_a[ix + i0/2 + 0]; + data_d[idst + i0/2 + 1] = data_a[ix + i0/2 + 1]; + + return; + } + + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims/2]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims/2] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp new file mode 100644 index 0000000000..6ba9575409 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_norm.comp @@ -0,0 +1,41 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0; + + if (i0 >= p.n_dims) { + data_d[idst + 0] = data_a[ix + 0]; + data_d[idst + 1] = data_a[ix + 1]; + + return; + } + + const float theta_base = data_pos[channel_x] * pow(p.theta_scale, i0/2.0f); + + const float freq_factor = p.has_ff != 0 ? 
data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + 1]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + 1] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp new file mode 100644 index 0000000000..d37d1c1043 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rope_vision.comp @@ -0,0 +1,47 @@ +#version 450 + +#include "rope_head.glsl" + +void main() { + const uint i0 = 2*gl_GlobalInvocationID.y; + uint ne0 = p.ncols; + uint ne1 = p.p_delta_rows; + uint ne2 = p.ne02; + + if (i0 >= ne0) { + return; + } + + const uint row_dst = gl_GlobalInvocationID.x; + + const uint row_x = row_dst % ne1; + const uint channel_x = row_dst / ne1; + + const uint idst = row_dst*ne0 + i0/2; + const uint ix = channel_x*p.s2 + row_x*p.s1 + i0/2; + + const int sect_dims = p.sections[0] + p.sections[1]; + const int sec_w = p.sections[1] + p.sections[0]; + const uint sector = (i0 / 2) % sect_dims; + + float theta_base = 0.0; + if (sector < p.sections[0]) { + const uint p0 = sector; + theta_base = data_pos[channel_x]*pow(p.theta_scale, p0); + } + else if (sector >= p.sections[0] && sector < sec_w) { + const uint p0 = sector - p.sections[0]; + theta_base = data_pos[channel_x + ne2]*pow(p.theta_scale, p0); + } + + const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f; + + float cos_theta, sin_theta; + rope_yarn(theta_base / freq_factor, i0, cos_theta, sin_theta); + + const float x0 = float(data_a[ix + 0]); + const float x1 = float(data_a[ix + p.n_dims]); + + data_d[idst + 0] = D_TYPE(x0*cos_theta - x1*sin_theta); + data_d[idst + p.n_dims] = D_TYPE(x0*sin_theta + x1*cos_theta); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl new file mode 100644 index 0000000000..ad51c1e80b --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/rte.glsl @@ -0,0 +1,5 @@ + +#if RTE16 +#extension GL_EXT_spirv_intrinsics : enable +spirv_execution_mode(capabilities = [4467], 4462, 16); // RoundingModeRTE, 16 bits +#endif // RTE16 diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp new file mode 100644 index 0000000000..35ec726a01 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/scale.comp @@ -0,0 +1,24 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +const uint num_threads = 128; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 4; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + + data_d[get_doffset() + idx] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + idx]) * FLOAT_TYPE(p.param1) + FLOAT_TYPE(p.param2)); + idx += num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp new file mode 100644 index 0000000000..32298d43c6 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sigmoid.comp @@ -0,0 +1,20 @@ 
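+// sigmoid.comp: element-wise logistic function, y = 1 / (1 + exp(-x))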
+#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. / (1 + exp(-1. * float(data_a[i])))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp new file mode 100644 index 0000000000..7d1cc6f45a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu.comp @@ -0,0 +1,22 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + const float xi = float(data_a[i]); + data_d[i] = D_TYPE(xi / (1.0f + exp(-xi))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp new file mode 100644 index 0000000000..e5d949ff18 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/silu_back.comp @@ -0,0 +1,26 @@ +#version 450 + +#include "generic_head.glsl" +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer X {B_TYPE data_x[];}; +layout (binding = 2) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + + // Compute derivative of SiLU(x): 1/(1+exp(-x)) - x*exp(-x)/(1+exp(-x))^2 + + const float xi = float(data_x[i]); + const float s = 1.0f / (1.0f + exp(-xi)); + data_d[i] = D_TYPE(data_g[i] * (s + xi * s * (1 - s))); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp new file mode 100644 index 0000000000..61f17b2f00 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sin.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sin(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp new file mode 100644 index 0000000000..dca0d896bc --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max.comp @@ -0,0 +1,195 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +layout (push_constant) uniform parameter +{ + uint KX; + uint KY; + 
uint ne00; + uint ne01; + uint ne02; + uint ne12; + uint ne13; + uint nb11; + uint nb12; + uint nb13; + float scale; + float max_bias; + float m0; + float m1; + uint n_head_log2; + uint nrows_x; + uint has_sinks; +} p; + +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_b[];}; +layout (binding = 2) readonly buffer Z {float data_c[];}; +layout (binding = 3) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE vals[BLOCK_SIZE]; + +// num_iters is the number of BLOCK_SIZE loop iterations we need to iterate +// over all the columns. The main function tries to pass a constant here, +// as if it were a template function, to allow unrolling. +void soft_max(uint num_iters) { + const uint tid = gl_LocalInvocationID.x; + const uint rowx = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + + const uint32_t i03 = rowx / (p.ne01 * p.ne02); + const uint32_t i02 = (rowx - i03 * p.ne01 * p.ne02) / p.ne01; + const uint32_t i01 = rowx % p.ne01; + + uint rowy_start = 0; + if (p.KY > 0) { + rowy_start = i01 * p.nb11 + (i02 % p.ne12) * p.nb12 + (i03 % p.ne13) * p.nb13; + } + + if (rowx >= p.nrows_x) { + return; + } + + float slope = 1.0f; + + // ALiBi + if (p.max_bias > 0.0f) { + const uint h = (rowx / p.ne01) % p.ne02; // head index + + const float base = h < p.n_head_log2 ? p.m0 : p.m1; + const uint exp = h < p.n_head_log2 ? h + 1 : 2*(h - p.n_head_log2) + 1; + + slope = pow(base, exp); + } + + // Find max + FLOAT_TYPE max_val = p.has_sinks == 0 ? uintBitsToFloat(0xFF800000) : data_c[i02]; + + // Cache values while we compute the max, so we don't need to read them + // again when we're ready to compute exp(x-max). + const uint DATA_CACHE_SIZE = 16; + FLOAT_TYPE data_cache[DATA_CACHE_SIZE]; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + FLOAT_TYPE a = FLOAT_TYPE(0); + if (col < p.KX) { + a = data_a[rowx * p.KX + col]; + } + + FLOAT_TYPE b = FLOAT_TYPE(0); + if (p.KY > 0 && col < p.KX) { + b = data_b[rowy_start + col]; + } + + FLOAT_TYPE v = a * p.scale + slope * b; + + if (col < p.KX) { + max_val = max(max_val, v); + } + + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = v; + } + } + + // reduce across the workgroup + vals[tid] = max_val; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] = max(vals[tid], vals[tid + s]); + } + barrier(); + } + + max_val = vals[0]; + barrier(); + + FLOAT_TYPE sum = FLOAT_TYPE(0.0f); + + // Compute sum{exp(x - max)} + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + break; + } + + // compute exp(a*scale+b*slope), add it to sum, and cache the new value + // in data_cache if possible. + const uint i = rowx * p.KX + col; + FLOAT_TYPE val; + if (idx < DATA_CACHE_SIZE) { + val = exp(data_cache[idx] - max_val); + } else { + val = exp(FLOAT_TYPE(data_a[i]) * p.scale + (p.KY > 0 ? 
slope * FLOAT_TYPE(data_b[rowy_start + col]) : FLOAT_TYPE(0.0f)) - max_val); + } + sum += val; + if (idx < DATA_CACHE_SIZE) { + data_cache[idx] = val; + } else { + data_d[i] = D_TYPE(val); + } + } + + // reduce across the workgroup + vals[tid] = sum; + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + vals[tid] += vals[tid + s]; + } + barrier(); + } + sum = vals[0]; + + if (p.has_sinks != 0) { + sum += FLOAT_TYPE(exp(FLOAT_TYPE(data_c[i02]) - max_val)); + } + + FLOAT_TYPE rcpdivisor = 1.0/sum; + + [[unroll]] for (uint col0 = 0, idx = 0; idx < num_iters; col0 += BLOCK_SIZE, ++idx) { + const uint col = col0 + tid; + + if (col >= p.KX) { + continue; + } + + if (idx < DATA_CACHE_SIZE) { + data_d[rowx*p.KX + col] = D_TYPE(data_cache[idx] * rcpdivisor); + } else { + data_d[rowx*p.KX + col] *= D_TYPE(rcpdivisor); + } + } +} + +void main() { + // instantiate the soft_max function for several different + // dimensions, to allow loop unrolling + uint num_blocks = (p.KX + BLOCK_SIZE - 1) / BLOCK_SIZE; + if (num_blocks > 32) { + soft_max(num_blocks); + } else if (num_blocks > 16) { + soft_max(32); + } else if (num_blocks > 8) { + soft_max(16); + } else if (num_blocks > 4) { + soft_max(8); + } else if (num_blocks == 4) { + soft_max(4); + } else if (num_blocks == 3) { + soft_max(3); + } else if (num_blocks == 2) { + soft_max(2); + } else if (num_blocks == 1) { + soft_max(1); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp new file mode 100644 index 0000000000..d873332eeb --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/soft_max_back.comp @@ -0,0 +1,54 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : enable + +#include "generic_head.glsl" +#include "types.glsl" + +layout(constant_id = 0) const uint BLOCK_SIZE = 32; +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +// In this shader Y = softmax(X) and X is not provided as input. 
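+// Per row, the gradient written below is:
+//   dX = scale * (dY - dot(Y, dY)) * Y
+// with dot(Y, dY) reduced across the row in shared memory.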
+ +layout (binding = 0) readonly buffer G {A_TYPE data_g[];}; +layout (binding = 1) readonly buffer Y {B_TYPE data_y[];}; +layout (binding = 2) buffer D {D_TYPE data_d[];}; + +shared FLOAT_TYPE sum_yg[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint tid = gl_LocalInvocationID.x; + + if (row >= p.KY) { + return; + } + + FLOAT_TYPE scale = p.param1; + + // partial sums for thread in warp + sum_yg[tid] = FLOAT_TYPE(0.0f); + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + const FLOAT_TYPE gi = FLOAT_TYPE(data_g[row*p.KX + col]); + const FLOAT_TYPE yi = FLOAT_TYPE(data_y[row*p.KX + col]); + sum_yg[tid] += yi * gi; + } + + // sum up partial sums and write back result + barrier(); + [[unroll]] for (uint s = BLOCK_SIZE / 2; s > 0; s >>= 1) { + if (tid < s) { + sum_yg[tid] += sum_yg[tid + s]; + } + barrier(); + } + + const FLOAT_TYPE dot_yg = sum_yg[0]; + + [[unroll]] for (uint col = tid; col < p.KX; col += BLOCK_SIZE) { + data_d[row*p.KX + col] = D_TYPE(scale + * (FLOAT_TYPE(data_g[row*p.KX + col]) - dot_yg) + * FLOAT_TYPE(data_y[row*p.KX + col])); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp new file mode 100644 index 0000000000..70daad6c5d --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sqrt.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(sqrt(val)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp new file mode 100644 index 0000000000..4eb56afcb1 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/square.comp @@ -0,0 +1,17 @@ +#version 450 + +#include "types.glsl" +#include "generic_unary_head.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +void main() { + const uint idx = get_idx(); + + if (idx >= p.ne) { + return; + } + + const FLOAT_TYPE val = FLOAT_TYPE(data_a[get_aoffset() + src0_idx(idx)]); + data_d[get_doffset() + dst_idx(idx)] = D_TYPE(val * val); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp new file mode 100644 index 0000000000..bc924b520a --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sub.comp @@ -0,0 +1,29 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +#include "types.glsl" +#include "generic_binary_head.glsl" + +const uint num_threads = 256; + +layout(local_size_x = num_threads, local_size_y = 1, local_size_z = 1) in; + +void main() { + uint idx = get_idx(); + + // num_threads * num_iter must equal 512, to match the wg_denoms and get_idx calculation + const uint num_iter = 2; + + [[unroll]] for (uint i = 0; i < num_iter; ++i) { + if (idx >= p.ne) { + continue; + } + uint i00, i01, i02, i03; + get_indices(idx, i00, i01, i02, i03); + + data_d[get_doffset() + dst_idx(i00, i01, i02, i03)] = D_TYPE(FLOAT_TYPE(data_a[get_aoffset() + src0_idx(i00, i01, i02, i03)]) - FLOAT_TYPE(data_b[get_boffset() + src1_idx(i00, i01, i02, i03)])); + + idx += 
num_threads; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp new file mode 100644 index 0000000000..bc22aa7bd7 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/sum_rows.comp @@ -0,0 +1,70 @@ +#version 450 + +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x_id = 0, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +layout (constant_id = 0) const uint BLOCK_SIZE = 32; + +layout (push_constant) uniform parameter +{ + uint n_cols; + uint ne01, ne02; + uint nb01, nb02, nb03; + uint nb11, nb12, nb13; + float weight; + uint misalign_offsets; + uint ne0_12mp, ne0_12L; + uint ne0_1mp, ne0_1L; +} p; + +uint get_aoffset() { return p.misalign_offsets >> 16; } +uint get_doffset() { return p.misalign_offsets & 0xFFFF; } + +// see init_fastdiv_values in ggml-vulkan.cpp +uint fastdiv(uint n, uint mp, uint L) { + uint msbs, lsbs; + // msbs = mulhi(n, mp) + umulExtended(n, mp, msbs, lsbs); + return (msbs + n) >> L; +} + + +shared FLOAT_TYPE tmp[BLOCK_SIZE]; + +void main() { + const uint row = gl_WorkGroupID.z * 262144 + gl_WorkGroupID.y * 512 + gl_WorkGroupID.x; + const uint col = gl_LocalInvocationID.x; + const float weight = p.weight; + + const uint i03 = fastdiv(row, p.ne0_12mp, p.ne0_12L); + const uint i03_offset = i03 * p.ne01*p.ne02; + const uint i02 = fastdiv(row - i03_offset, p.ne0_1mp, p.ne0_1L); + const uint i01 = row - i03_offset - i02*p.ne01; + + const uint src_idx = get_aoffset() + i01 * p.nb01 + i02 * p.nb02 + i03 * p.nb03; + const uint dst_idx = get_doffset() + i01 * p.nb11 + i02 * p.nb12 + i03 * p.nb13; + + tmp[col] = FLOAT_TYPE(0.0); + + for (uint i = col; i < p.n_cols; i += BLOCK_SIZE) { + tmp[col] += FLOAT_TYPE(data_a[src_idx + i]); + } + + barrier(); + [[unroll]] for (int s = int(BLOCK_SIZE) / 2; s > 0; s >>= 1) { + if (col < s) { + tmp[col] += tmp[col + s]; + } + barrier(); + } + + if (col == 0) { + data_d[dst_idx] = D_TYPE(tmp[0] * weight); + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp new file mode 100644 index 0000000000..4fee433a12 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu.comp @@ -0,0 +1,9 @@ +#version 450 + +#include "glu_head.glsl" + +float op(float a, float b) { + return a / (1.0f + exp(-a)) * b; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp new file mode 100644 index 0000000000..bda9dea21c --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/swiglu_oai.comp @@ -0,0 +1,14 @@ +#version 450 + +#include "glu_head.glsl" + +float op(float a, float b) { + float xi = min(a, p.limit); + float gi = max(min(b, p.limit), -p.limit); + + float out_glu = xi / (1.0f + exp(-xi * p.alpha)); + out_glu = out_glu * (1.0f + gi); + return out_glu; +} + +#include "glu_main.glsl" diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp new file mode 100644 index 0000000000..7b5eb413bf --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/tanh.comp @@ -0,0 +1,20 @@ +#version 450 + +#include "generic_head.glsl" +#include 
"types.glsl" + +#extension GL_EXT_control_flow_attributes : enable + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (i >= p.KX) { + return; + } + data_d[i] = D_TYPE(1. - 2. / (exp(2.*float(data_a[i])) + 1.)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp new file mode 100644 index 0000000000..1605565457 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/timestep_embedding.comp @@ -0,0 +1,42 @@ +#version 450 + +#extension GL_EXT_shader_16bit_storage : require + +layout (push_constant) uniform parameter +{ + uint nb1; + uint dim; + uint max_period; +} p; + +#include "types.glsl" + +#extension GL_EXT_control_flow_attributes : enable +#define BLOCK_SIZE 256 + +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer X {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +void main() { + const uint i = gl_WorkGroupID.y; + const uint j = gl_GlobalInvocationID.x; + const uint d_offset = i * p.nb1; + + const uint half_dim = p.dim / 2; + + if (p.dim % 2 != 0 && j == half_dim) { + data_d[d_offset + 2 * half_dim] = 0.f; + } + + if (j >= half_dim) { + return; + } + + const float timestep = float(data_a[i]); + const float freq = float(exp(-log(p.max_period) * j / half_dim)); + const float arg = timestep * freq; + data_d[d_offset + j] = D_TYPE(cos(arg)); + data_d[d_offset + j + half_dim] = D_TYPE(sin(arg)); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl new file mode 100644 index 0000000000..2fa54ce51f --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/types.glsl @@ -0,0 +1,1465 @@ +#if !defined(GGML_TYPES_COMP) +#define GGML_TYPES_COMP + +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int32 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int16 : require +#extension GL_EXT_shader_explicit_arithmetic_types_int8 : require +#extension GL_EXT_shader_16bit_storage : require + +#if defined(DATA_A_F32) +#define QUANT_K 1 +#define QUANT_R 1 + +#if LOAD_VEC_A == 4 +#define A_TYPE vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE mat2x4 +#else +#define A_TYPE float +#endif +#endif + +#if defined(DATA_A_F16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if LOAD_VEC_A == 4 +#define A_TYPE f16vec4 +#elif LOAD_VEC_A == 8 +#define A_TYPE f16mat2x4 +#else +#define A_TYPE float16_t +#endif +#endif + +#if defined(DATA_A_BF16) +#define QUANT_K 1 +#define QUANT_R 1 + +#if LOAD_VEC_A == 4 +#define A_TYPE u16vec4 +#elif LOAD_VEC_A == 8 +#error unsupported +#else +#define A_TYPE uint16_t +#endif +#endif + +#define QUANT_K_Q4_0 32 +#define QUANT_R_Q4_0 2 + +struct block_q4_0 +{ + float16_t d; + uint8_t qs[16]; +}; +struct block_q4_0_packed16 +{ + float16_t d; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q4_0) +#define QUANT_K QUANT_K_Q4_0 +#define QUANT_R QUANT_R_Q4_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q4_0 +#define A_TYPE_PACKED16 block_q4_0_packed16 +#endif + +#define QUANT_K_Q4_1 32 +#define QUANT_R_Q4_1 2 + +struct 
block_q4_1 +{ + float16_t d; + float16_t m; + uint8_t qs[16]; +}; + +struct block_q4_1_packed16 +{ + float16_t d; + float16_t m; + uint16_t qs[16/2]; +}; + +struct block_q4_1_packed32 +{ + f16vec2 dm; + uint32_t qs[16/4]; +}; + +#if defined(DATA_A_Q4_1) +#define QUANT_K QUANT_K_Q4_1 +#define QUANT_R QUANT_R_Q4_1 +#define QUANT_AUXF 2 +#define A_TYPE block_q4_1 +#define A_TYPE_PACKED16 block_q4_1_packed16 +#define A_TYPE_PACKED32 block_q4_1_packed32 +#endif + +#define QUANT_K_Q5_0 32 +#define QUANT_R_Q5_0 2 + +struct block_q5_0 +{ + float16_t d; + uint16_t qh[2]; + uint8_t qs[16]; +}; + +struct block_q5_0_packed16 +{ + float16_t d; + uint16_t qh[2]; + uint16_t qs[16/2]; +}; + +#if defined(DATA_A_Q5_0) +#define QUANT_K QUANT_K_Q5_0 +#define QUANT_R QUANT_R_Q5_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q5_0 +#define A_TYPE_PACKED16 block_q5_0_packed16 +#endif + +#define QUANT_K_Q5_1 32 +#define QUANT_R_Q5_1 2 + +struct block_q5_1 +{ + float16_t d; + float16_t m; + uint qh; + uint8_t qs[16]; +}; + +struct block_q5_1_packed16 +{ + float16_t d; + float16_t m; + uint qh; + uint16_t qs[16/2]; +}; + +struct block_q5_1_packed32 +{ + f16vec2 dm; + uint qh; + uint32_t qs[16/4]; +}; + +#if defined(DATA_A_Q5_1) +#define QUANT_K QUANT_K_Q5_1 +#define QUANT_R QUANT_R_Q5_1 +#define QUANT_AUXF 2 +#define A_TYPE block_q5_1 +#define A_TYPE_PACKED16 block_q5_1_packed16 +#define A_TYPE_PACKED32 block_q5_1_packed32 +#endif + +#define QUANT_K_Q8_0 32 +#define QUANT_R_Q8_0 1 + +struct block_q8_0 +{ + float16_t d; + int8_t qs[32]; +}; +struct block_q8_0_packed16 +{ + float16_t d; + int16_t qs[32/2]; +}; +struct block_q8_0_packed32 +{ + float16_t d; + int32_t qs[32/4]; +}; + +#if defined(DATA_A_Q8_0) +#define QUANT_K QUANT_K_Q8_0 +#define QUANT_R QUANT_R_Q8_0 +#define QUANT_AUXF 1 +#define A_TYPE block_q8_0 +#define A_TYPE_PACKED16 block_q8_0_packed16 +#define A_TYPE_PACKED32 block_q8_0_packed32 +#endif + +#define QUANT_K_Q8_1 32 +#define QUANT_R_Q8_1 1 + +struct block_q8_1 +{ + f16vec2 ds; + int8_t qs[32]; +}; +struct block_q8_1_packed16 +{ + f16vec2 ds; + int16_t qs[16]; +}; +struct block_q8_1_packed32 +{ + f16vec2 ds; + int32_t qs[8]; +}; + +// 4 blocks in one to allow 16-byte/128-bit alignment and loads +struct block_q8_1_x4 +{ + f16vec2 ds[4]; + int32_t qs[32]; +}; +struct block_q8_1_x4_packed128 +{ + f16vec2 ds[4]; + ivec4 qs[8]; +}; + +// K-quants +#define QUANT_K_Q2_K 256 + +struct block_q2_K +{ + uint8_t scales[QUANT_K_Q2_K/16]; + uint8_t qs[QUANT_K_Q2_K/4]; + f16vec2 d; +}; + +struct block_q2_K_packed16 +{ + uint16_t scales[QUANT_K_Q2_K/16/2]; + uint16_t qs[QUANT_K_Q2_K/4/2]; + f16vec2 d; +}; + +struct block_q2_K_packed32 +{ + uint32_t scales[QUANT_K_Q2_K/16/4]; + uint32_t qs[QUANT_K_Q2_K/4/4]; + f16vec2 d; +}; + +#if defined(DATA_A_Q2_K) +#define QUANT_K QUANT_K_Q2_K +#define QUANT_R 1 +#define A_TYPE block_q2_K +#define A_TYPE_PACKED16 block_q2_K_packed16 +#define A_TYPE_PACKED32 block_q2_K_packed32 +#endif + +#define QUANT_K_Q3_K 256 + +struct block_q3_K +{ + uint8_t hmask[QUANT_K_Q3_K/8]; + uint8_t qs[QUANT_K_Q3_K/4]; + uint8_t scales[12]; + float16_t d; +}; + +struct block_q3_K_packed16 +{ + uint16_t hmask[QUANT_K_Q3_K/8/2]; + uint16_t qs[QUANT_K_Q3_K/4/2]; + uint16_t scales[12/2]; + float16_t d; +}; + +#if defined(DATA_A_Q3_K) +#define QUANT_K QUANT_K_Q3_K +#define QUANT_R 1 +#define A_TYPE block_q3_K +#define A_TYPE_PACKED16 block_q3_K_packed16 +#endif + +#define QUANT_K_Q4_K 256 + +struct block_q4_K +{ + f16vec2 d; + uint8_t scales[3*QUANT_K_Q4_K/64]; + uint8_t qs[QUANT_K_Q4_K/2]; +}; + +struct 
block_q4_K_packed16 +{ + f16vec2 d; + uint16_t scales[3*QUANT_K_Q4_K/64/2]; + uint16_t qs[QUANT_K_Q4_K/2/2]; +}; + +struct block_q4_K_packed32 +{ + f16vec2 d; + uint32_t scales[3*QUANT_K_Q4_K/64/4]; + uint32_t qs[QUANT_K_Q4_K/2/4]; +}; + +struct block_q4_K_packed128 +{ + uvec4 q4k[9]; +}; + +#if defined(DATA_A_Q4_K) +#define QUANT_K QUANT_K_Q4_K +#define QUANT_R 1 +#define A_TYPE block_q4_K +#define A_TYPE_PACKED16 block_q4_K_packed16 +#define A_TYPE_PACKED32 block_q4_K_packed32 +#endif + +#define QUANT_K_Q5_K 256 + +struct block_q5_K +{ + f16vec2 d; + uint8_t scales[12]; + uint8_t qh[QUANT_K_Q5_K/8]; + uint8_t qs[QUANT_K_Q5_K/2]; +}; + +struct block_q5_K_packed16 +{ + f16vec2 d; + uint16_t scales[12/2]; + uint16_t qh[QUANT_K_Q5_K/8/2]; + uint16_t qs[QUANT_K_Q5_K/2/2]; +}; + +struct block_q5_K_packed128 +{ + uvec4 q5k[11]; +}; + +#if defined(DATA_A_Q5_K) +#define QUANT_K QUANT_K_Q5_K +#define QUANT_R 1 +#define A_TYPE block_q5_K +#define A_TYPE_PACKED16 block_q5_K_packed16 +#endif + +#define QUANT_K_Q6_K 256 + +struct block_q6_K +{ + uint8_t ql[QUANT_K_Q6_K/2]; + uint8_t qh[QUANT_K_Q6_K/4]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +struct block_q6_K_packed16 +{ + uint16_t ql[QUANT_K_Q6_K/2/2]; + uint16_t qh[QUANT_K_Q6_K/4/2]; + int8_t scales[QUANT_K_Q6_K/16]; + float16_t d; +}; + +#if defined(DATA_A_Q6_K) +#define QUANT_K QUANT_K_Q6_K +#define QUANT_R 1 +#define A_TYPE block_q6_K +#define A_TYPE_PACKED16 block_q6_K_packed16 +#endif + +// IQuants + +#define QUANT_K_IQ1_S 256 +#define QUANT_R_IQ1_S 1 + +struct block_iq1_s { + float16_t d; + uint8_t qs[QUANT_K_IQ1_S/8]; + uint16_t qh[QUANT_K_IQ1_S/32]; +}; + +#define QUANT_K_IQ1_M 256 +#define QUANT_R_IQ1_M 1 + +struct block_iq1_m { + uint8_t qs[QUANT_K_IQ1_M/8]; + uint8_t qh[QUANT_K_IQ1_M/16]; + uint16_t scales[QUANT_K_IQ1_M/64]; +}; + +struct block_iq1_m_packed64 { + uint64_t qs[QUANT_K_IQ1_M/8/8]; + uint64_t qh[QUANT_K_IQ1_M/16/8]; + uint64_t scales; +}; + +#if defined(DATA_A_IQ1_S) +#define QUANT_K QUANT_K_IQ1_S +#define QUANT_R QUANT_R_IQ1_S +#define A_TYPE block_iq1_s +#endif + +#if defined(DATA_A_IQ1_M) +#define QUANT_K QUANT_K_IQ1_M +#define QUANT_R QUANT_R_IQ1_M +#define A_TYPE block_iq1_m +#endif + +#if defined(DATA_A_IQ1_S) || defined(DATA_A_IQ1_M) +#define IQ1S_DELTA 0.125f +#define IQ1M_DELTA 0.125f + +// Packed IQ1S grid where every 2 vec8 are encoded on 32 bits (2 bits per coordinate). 
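+// Descriptive note (not in upstream): each 32-bit entry of the table below packs two
+// 16-bit grid values; init_iq_shmem() after the table splits every entry with unpack16()
+// into the 2048-entry shared iq1s_grid array before any thread reads it.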
+const uint[1024] iq1s_grid_const = { + 0xfffdffff, 0xfff7fff0, 0xffccfff5, 0xffdfffc0, 0xffd7ffdd, 0xff30ffd5, 0xff03ff0c, 0xff10ff01, + 0xff7dff7f, 0xff75ff77, 0xff5fff40, 0xff57ff5d, 0xfcf3ff55, 0xfcccfcf0, 0xfcc1fcc3, 0xfcc5fcc4, + 0xfc3cfcd0, 0xfc34fc31, 0xfc00fc0d, 0xfc1cfc05, 0xfc11fc13, 0xfc70fc17, 0xfc43fc4c, 0xfc50fc41, + 0xfdfdfdff, 0xfdf5fdf7, 0xfddffdc0, 0xfdd7fddd, 0xfd30fdd5, 0xfd04fd0c, 0xfd14fd13, 0xfd7dfd7f, + 0xfd75fd77, 0xfd40fd4c, 0xfd5ffd44, 0xfd57fd5d, 0xf3ccfd55, 0xf3c1f3c3, 0xf33cf3d0, 0xf300f334, + 0xf313f305, 0xf34cf310, 0xf350f344, 0xf0f3f0fc, 0xf0f1f0f0, 0xf0c7f0c0, 0xf0d4f0c5, 0xf030f03f, + 0xf00ff035, 0xf003f00c, 0xf001f000, 0xf01ff004, 0xf010f01d, 0xf015f017, 0xf04cf07c, 0xf047f040, + 0xf05cf045, 0xf050f053, 0xf054f051, 0xf1c4f1c3, 0xf133f13c, 0xf10df10f, 0xf107f100, 0xf11cf11f, + 0xf114f111, 0xf14cf170, 0xf144f143, 0xf7fdf7ff, 0xf7f5f7f7, 0xf7dff7c0, 0xf7d7f7dd, 0xf730f7d5, + 0xf701f70c, 0xf77ff710, 0xf777f77d, 0xf740f775, 0xf75df75f, 0xf755f757, 0xf4ccf4f0, 0xf4c4f4c3, + 0xf4d0f4d3, 0xf40ff43c, 0xf400f40c, 0xf413f41c, 0xf44cf414, 0xf441f443, 0xf450f444, 0xf5fdf5ff, + 0xf5f5f5f7, 0xf5dff5c0, 0xf5d7f5dd, 0xf530f5d5, 0xf504f50c, 0xf510f51c, 0xf57df57f, 0xf577f570, + 0xf540f575, 0xf55df55f, 0xf555f557, 0xcfcccfcf, 0xcfc4cfc3, 0xcfd0cfd3, 0xcf33cf3c, 0xcf00cf0f, + 0xcf1ccf07, 0xcf10cf13, 0xcf4ccf14, 0xcf41cf43, 0xcf50cf5c, 0xccf3ccfc, 0xccf4ccf1, 0xcccdcccf, + 0xccc7ccc0, 0xccd3ccdc, 0xcc30ccd4, 0xcc0fcc35, 0xcc0dcc0c, 0xcc00cc03, 0xcc04cc01, 0xcc10cc1f, + 0xcc4dcc73, 0xcc5ccc40, 0xcdcccc53, 0xcdc1cdc3, 0xcd3fcdd0, 0xcd34cd31, 0xcd00cd0d, 0xcd05cd07, + 0xcd11cd13, 0xcd4ccd70, 0xcd41cd43, 0xc3fccd50, 0xc3f4c3f1, 0xc3c0c3c3, 0xc3c4c3c7, 0xc3d1c3dc, + 0xc330c33c, 0xc337c331, 0xc30cc335, 0xc300c303, 0xc304c301, 0xc310c31d, 0xc373c317, 0xc34fc374, + 0xc340c343, 0xc344c347, 0xc35cc345, 0xc350c353, 0xc0fdc354, 0xc0f5c0f0, 0xc0c3c0cc, 0xc0c1c0c0, + 0xc0dfc0c4, 0xc0d0c0dd, 0xc0d5c0d7, 0xc033c03c, 0xc031c030, 0xc00dc00c, 0xc000c003, 0xc004c001, + 0xc01cc005, 0xc010c013, 0xc014c011, 0xc07dc07f, 0xc070c073, 0xc075c077, 0xc04cc04f, 0xc040c043, + 0xc044c041, 0xc05fc045, 0xc050c05d, 0xc1f3c1fc, 0xc1f1c1f0, 0xc1c1c1c0, 0xc1c5c1c7, 0xc1d1c1dc, + 0xc13dc13f, 0xc130c133, 0xc135c137, 0xc100c10c, 0xc107c101, 0xc11cc104, 0xc110c113, 0xc114c117, + 0xc171c115, 0xc14dc175, 0xc153c140, 0xc7ccc154, 0xc7d0c7c1, 0xc733c73c, 0xc734c731, 0xc700c70f, + 0xc705c707, 0xc71cc71f, 0xc711c713, 0xc770c714, 0xc743c74c, 0xc4cfc750, 0xc4c0c4cd, 0xc4dcc4c5, + 0xc43dc4d0, 0xc430c433, 0xc40cc437, 0xc400c403, 0xc404c401, 0xc41fc405, 0xc415c410, 0xc44cc474, + 0xc440c44d, 0xc45cc447, 0xc454c451, 0xc5c1c5f4, 0xc5d1c5d3, 0xc531c533, 0xc50fc534, 0xc500c50d, + 0xc51cc507, 0xc514c511, 0xc54cc570, 0xc545c541, 0xdffddfff, 0xdff5dff7, 0xdfdfdfc0, 0xdfd0dfdd, + 0xdfd5dfd7, 0xdf0cdf30, 0xdf1cdf04, 0xdf7fdf10, 0xdf77df7d, 0xdf40df75, 0xdf5ddf5f, 0xdf57df50, + 0xdcf0df55, 0xdcc3dccc, 0xdcd0dcc4, 0xdc33dc3d, 0xdc00dc34, 0xdc05dc07, 0xdc13dc1c, 0xdc11dc10, + 0xdc4fdc70, 0xdc44dc41, 0xddfcdc50, 0xddf5ddf7, 0xddc0ddcc, 0xdddddddf, 0xddd5ddd7, 0xdd0cdd30, + 0xdd04dd01, 0xdd7cdd10, 0xdd75dd77, 0xdd40dd4c, 0xdd5ddd5f, 0xdd55dd57, 0xd3c3d3f0, 0xd3c4d3c1, + 0xd333d3d0, 0xd331d330, 0xd30dd334, 0xd307d300, 0xd311d305, 0xd34cd370, 0xd344d343, 0xd350d35c, + 0xd0c0d0f4, 0xd0d4d0dc, 0xd030d03f, 0xd00cd037, 0xd000d003, 0xd01dd004, 0xd017d010, 0xd04fd074, + 0xd040d043, 0xd045d047, 0xd053d05c, 0xd054d051, 0xd1cfd1f0, 0xd1c4d1cd, 0xd13cd1d0, 0xd100d134, + 0xd11cd11f, 0xd173d114, 0xd14fd171, 0xd7ffd145, 0xd7f7d7fd, 0xd7c0d7f5, 0xd7ddd7df, 
0xd7d5d7d7, + 0xd70cd730, 0xd710d703, 0xd77dd77f, 0xd775d777, 0xd75dd75f, 0xd755d757, 0xd4ccd4f4, 0xd4c4d4c3, + 0xd431d4d0, 0xd40dd434, 0xd41cd400, 0xd411d413, 0xd470d414, 0xd441d44f, 0xd453d444, 0xd5ffd450, + 0xd5f7d5fd, 0xd5dfd5f5, 0xd5d7d5dd, 0xd530d5d5, 0xd501d50c, 0xd510d504, 0xd57dd57f, 0xd575d577, + 0xd55fd540, 0xd557d55d, 0x3ff0d555, 0x3fc13fcc, 0x3f343fd0, 0x3f003f0d, 0x3f053f07, 0x3f133f1c, + 0x3f433f11, 0x3f5c3f44, 0x3cff3f51, 0x3cf33cfc, 0x3cf43cf1, 0x3cc03ccd, 0x3cc73cc1, 0x3cdc3cc5, + 0x3cd43cd1, 0x3c373c30, 0x3c0c3c35, 0x3c003c03, 0x3c043c01, 0x3c103c05, 0x3c153c17, 0x3c733c7c, + 0x3c4f3c71, 0x3c403c4d, 0x3c5c3c5f, 0x3df03c5d, 0x3dc33dcc, 0x3dd03dc1, 0x3d0d3d3c, 0x3d053d00, + 0x3d143d13, 0x3d433d74, 0x33fc3d50, 0x33c433c0, 0x333033d4, 0x33353337, 0x3303330c, 0x33013300, + 0x331d331c, 0x33173310, 0x337c3315, 0x33743371, 0x334d334f, 0x335f3340, 0x3354335c, 0x30fd30fc, + 0x30f530f0, 0x30c330cc, 0x30c130c0, 0x30df30c4, 0x30d530d0, 0x3033303c, 0x30313030, 0x300f3034, + 0x3003300c, 0x30013000, 0x30043007, 0x3013301c, 0x30113010, 0x307d3014, 0x30703073, 0x304c3077, + 0x30403043, 0x30443041, 0x30503045, 0x30553057, 0x31f031fc, 0x31c331f4, 0x31c731c0, 0x31dc31c5, + 0x31d431d3, 0x313d313f, 0x31373130, 0x310c310f, 0x3100310d, 0x31043101, 0x3110311d, 0x317c3117, + 0x31753170, 0x31403143, 0x3153315c, 0x37f03151, 0x37c037cc, 0x37d037c5, 0x3734373d, 0x3700370f, + 0x371c3707, 0x37113713, 0x37703714, 0x3743374c, 0x37443741, 0x34fc3750, 0x34f134f0, 0x34cf34f5, + 0x34c034c3, 0x34dc34c7, 0x34d134d3, 0x3430343f, 0x340c3435, 0x3403340d, 0x34013400, 0x341f3404, + 0x3410341d, 0x34153411, 0x34743471, 0x3440344d, 0x34473441, 0x3453345c, 0x34543451, 0x353335c1, + 0x35343531, 0x35073500, 0x35133505, 0x35433514, 0x0ffc3550, 0x0ff00ff3, 0x0ff40ff1, 0x0fc00fcd, + 0x0fdc0fc5, 0x0fd40fd3, 0x0f300f3f, 0x0f0c0f37, 0x0f000f03, 0x0f040f01, 0x0f170f10, 0x0f740f71, + 0x0f470f40, 0x0f5c0f5f, 0x0f540f51, 0x0cf70cf0, 0x0cf50cf4, 0x0cc30ccc, 0x0cc10cc0, 0x0cc40cc7, + 0x0cd00cdf, 0x0cd70cd1, 0x0c3c0cd5, 0x0c300c33, 0x0c340c31, 0x0c0c0c0f, 0x0c030c0d, 0x0c010c00, + 0x0c040c07, 0x0c1c0c05, 0x0c100c13, 0x0c140c11, 0x0c700c7d, 0x0c430c4c, 0x0c410c40, 0x0c5f0c44, + 0x0c550c50, 0x0df10dfc, 0x0dc00dcd, 0x0ddc0dc5, 0x0d3d0dd3, 0x0d350d30, 0x0d030d0c, 0x0d010d00, + 0x0d1d0d04, 0x0d700d10, 0x0d4d0d4f, 0x0d440d40, 0x0d530d45, 0x03f003f3, 0x03c303cc, 0x03c103c0, + 0x03c403c7, 0x03d003dc, 0x03d503d7, 0x0333033c, 0x03310330, 0x03350334, 0x030c030f, 0x03000303, + 0x03070301, 0x03050304, 0x031d031c, 0x03100313, 0x03140311, 0x0377037f, 0x034c0375, 0x03400343, + 0x03440341, 0x0353035c, 0x03550350, 0x00fd00fc, 0x00f000f3, 0x00f400f1, 0x00cc00cf, 0x00c300cd, + 0x00c100c0, 0x00c500c4, 0x00d300dc, 0x00d100d0, 0x003f00d4, 0x003d003c, 0x00300033, 0x00370031, + 0x000f0034, 0x000d000c, 0x00000003, 0x00070001, 0x00050004, 0x001c001f, 0x00100013, 0x00170011, + 0x00150014, 0x0073007c, 0x00740070, 0x004f0075, 0x0043004c, 0x00410040, 0x00440047, 0x0053005c, + 0x00510050, 0x01ff0054, 0x01fd01fc, 0x01f101f3, 0x01f401f7, 0x01c301cc, 0x01c701c0, 0x01df01c4, + 0x01dd01dc, 0x01d001d3, 0x01d701d1, 0x013c01d4, 0x01310130, 0x01340137, 0x010f0135, 0x010d010c, + 0x01000103, 0x01070101, 0x01050104, 0x0113011c, 0x01140110, 0x0170017d, 0x01770171, 0x01750174, + 0x0140014c, 0x015d0145, 0x01510150, 0x01540157, 0x07f007f3, 0x07f407f1, 0x07c007cf, 0x07dc07c7, + 0x073007d5, 0x07350737, 0x0703070c, 0x07010700, 0x07040707, 0x071d071f, 0x07100713, 0x0774077d, + 0x074d074f, 0x07470740, 0x0754075c, 0x04fd04fc, 0x04f504f0, 0x04c304cc, 0x04c104c0, 0x04d004c4, + 0x0433043c, 
0x04310430, 0x040f0434, 0x040d040c, 0x04000403, 0x04070401, 0x04050404, 0x0413041c, + 0x04110410, 0x047c0414, 0x04740470, 0x0443044c, 0x04410440, 0x04440447, 0x05f30450, 0x05c005f7, + 0x05df05c5, 0x05d105d0, 0x053005d4, 0x05340537, 0x0500050c, 0x05070501, 0x051d0504, 0x05170510, + 0x057c0515, 0x054d0575, 0x05410540, 0x05450547, 0x1ff0055c, 0x1fc11fc3, 0x1fd01fc4, 0x1f0f1f33, + 0x1f011f00, 0x1f051f07, 0x1f131f1c, 0x1f141f11, 0x1f411f7c, 0x1cfc1f50, 0x1cf11cf3, 0x1ccd1cf4, + 0x1cdc1cc0, 0x1cd11cdd, 0x1c301cd4, 0x1c0c1c34, 0x1c011c00, 0x1c101c04, 0x1c151c11, 0x1c751c73, + 0x1c401c4d, 0x1c511c5c, 0x1dcc1c54, 0x1dc41dc1, 0x1d3c1d3f, 0x1d001d31, 0x1d071d01, 0x1d701d1f, + 0x1d411d4c, 0x13cc1d50, 0x13c013cd, 0x13c513c1, 0x13d113dc, 0x133f13d4, 0x1330133d, 0x13351337, + 0x1303130c, 0x13011300, 0x13051304, 0x131d131f, 0x13731310, 0x13741370, 0x134d134f, 0x13401343, + 0x13471341, 0x135c1345, 0x13541353, 0x10f710f0, 0x10cc10f5, 0x10c110c0, 0x103310c4, 0x10311030, + 0x100f1034, 0x1003100c, 0x10011000, 0x101c1004, 0x10101013, 0x10141011, 0x10741071, 0x104c1075, + 0x10411040, 0x10451044, 0x1050105d, 0x10571051, 0x11f411fd, 0x11df11c0, 0x11d711d1, 0x113f11d4, + 0x11371130, 0x110c1135, 0x11001103, 0x11071101, 0x111f1105, 0x11171110, 0x117d117f, 0x11751170, + 0x11411143, 0x11441147, 0x1153115f, 0x11551151, 0x17c417c1, 0x173c17d0, 0x1700170d, 0x171c1705, + 0x17701714, 0x1747174c, 0x14fc1751, 0x14cf14f3, 0x14dc14c0, 0x14d114d3, 0x143f14d4, 0x1430143c, + 0x14371431, 0x1403140c, 0x14011400, 0x141f1404, 0x14151410, 0x1473147d, 0x14401475, 0x1453145c, + 0x14541450, 0x15c115cc, 0x153c15c7, 0x15341533, 0x1500150f, 0x15051507, 0x15101513, 0x15711514, + 0x15471543, 0x15511545, 0x7ffd7fff, 0x7ff57ff7, 0x7fdd7fdf, 0x7fd57fd7, 0x7f0f7f30, 0x7f037f0c, + 0x7f047f01, 0x7f7f7f10, 0x7f777f7d, 0x7f407f75, 0x7f5d7f5f, 0x7f557f57, 0x7ccc7cf0, 0x7cc17cc3, + 0x7cd07cc4, 0x7c337c3c, 0x7c0f7c34, 0x7c007c0d, 0x7c077c01, 0x7c137c04, 0x7c147c11, 0x7c747c70, + 0x7c417c43, 0x7c507c44, 0x7dfd7dff, 0x7df57df7, 0x7ddf7dc0, 0x7dd77ddd, 0x7d0c7dd5, 0x7d047d03, + 0x7d7f7d10, 0x7d777d7d, 0x7d407d75, 0x7d5d7d5f, 0x7d557d57, 0x73c473c3, 0x7333733c, 0x7300730c, + 0x731c7305, 0x73147313, 0x73447343, 0x70f470fc, 0x70c070cd, 0x70d170c5, 0x703f70d4, 0x7030703c, + 0x700c7037, 0x70007003, 0x70047001, 0x70107005, 0x70177011, 0x707c7015, 0x70717073, 0x704f7074, + 0x7040704d, 0x70517047, 0x71c171cc, 0x71d071c4, 0x7133713c, 0x71357134, 0x7100710f, 0x71057104, + 0x7111711c, 0x71707115, 0x7145714c, 0x77ff7153, 0x77f777fd, 0x77c077f5, 0x77dd77df, 0x77d577d7, + 0x7730773c, 0x7703770c, 0x77107704, 0x777f7714, 0x7777777d, 0x77407775, 0x775d775f, 0x77557757, + 0x74f174f0, 0x74c374cc, 0x74d074c1, 0x7433743c, 0x74347431, 0x740d740f, 0x74057400, 0x7413741c, + 0x74417470, 0x74507444, 0x75fd75ff, 0x75f575f7, 0x75df75c0, 0x75d775dd, 0x753075d5, 0x7503750c, + 0x757f7501, 0x7577757d, 0x75407575, 0x755d755f, 0x75557557, 0x4fcc4ff0, 0x4fc74fc1, 0x4fd04fc4, + 0x4f314f3c, 0x4f004f34, 0x4f054f07, 0x4f154f14, 0x4f4c4f70, 0x4f414f43, 0x4f504f44, 0x4cf34cfc, + 0x4cf44cf1, 0x4cc04ccf, 0x4cc54cc7, 0x4cd34cdc, 0x4cd44cd1, 0x4c304c3f, 0x4c0c4c0f, 0x4c004c03, + 0x4c044c01, 0x4c104c1d, 0x4c714c73, 0x4c404c4d, 0x4c5c4c47, 0x4c514c53, 0x4df04c54, 0x4dc34dcc, + 0x4dd04dc4, 0x4d314d33, 0x4d0f4d34, 0x4d004d0d, 0x4d114d07, 0x4d704d14, 0x4d414d43, 0x43fc4d54, + 0x43f143f3, 0x43c043cf, 0x43d143c7, 0x4335433f, 0x4303430c, 0x43014300, 0x43044307, 0x431c431f, + 0x4310431d, 0x43714373, 0x4343434d, 0x43474340, 0x4354435c, 0x40f040ff, 0x40f540f7, 0x40cc40cf, + 0x40c040c3, 0x40c440c1, 0x40d040dc, 
0x40d540d4, 0x4033403c, 0x40314030, 0x400f4034, 0x400d400c, + 0x40004003, 0x40074001, 0x40054004, 0x4013401c, 0x40114010, 0x407c4014, 0x40774070, 0x404d404c, + 0x40404043, 0x40444041, 0x405f4045, 0x4050405d, 0x40554057, 0x41f341fc, 0x41c041cf, 0x41df41c4, + 0x41d441d1, 0x41374130, 0x410c4134, 0x4100410d, 0x41044101, 0x41174110, 0x4173417d, 0x41754174, + 0x4143414d, 0x41534140, 0x41544151, 0x47c147f0, 0x47d047c4, 0x4731473c, 0x470d470f, 0x47014700, + 0x47134705, 0x47704710, 0x4741474c, 0x47504744, 0x44f144f3, 0x44cf44f4, 0x44c044cd, 0x44c544c7, + 0x44dc44df, 0x44d144d3, 0x443d443f, 0x44374430, 0x440c4435, 0x44004403, 0x44044401, 0x4410441d, + 0x44154411, 0x4473447c, 0x444d444f, 0x44454440, 0x4451445c, 0x45c045f0, 0x453345d0, 0x45344531, + 0x4500450f, 0x451c4507, 0x454c4570, 0x45404543, 0x5fff4541, 0x5ff75ffd, 0x5fc05ff5, 0x5fdd5fdf, + 0x5fd55fd7, 0x5f0c5f30, 0x5f015f03, 0x5f7f5f04, 0x5f775f7d, 0x5f405f75, 0x5f5d5f5f, 0x5f555f57, + 0x5cf45cf0, 0x5cc35ccc, 0x5cc45cc1, 0x5c315cc5, 0x5c0c5c34, 0x5c075c00, 0x5c1c5c05, 0x5c705c13, + 0x5c4d5c4f, 0x5c445c41, 0x5df75dfd, 0x5dcf5df5, 0x5ddd5dc4, 0x5dd55dd7, 0x5d0c5d30, 0x5d045d01, + 0x5d7f5d10, 0x5d775d7d, 0x5d405d75, 0x5d5d5d5f, 0x5d555d57, 0x53d053c4, 0x5333533c, 0x5303530f, + 0x53075300, 0x531c5305, 0x53115310, 0x53145317, 0x50f15370, 0x50cf50f4, 0x50c050cd, 0x50d150c7, + 0x503d50d4, 0x500c5030, 0x50005003, 0x50045001, 0x50155010, 0x5073507c, 0x50715070, 0x504d5074, + 0x50475040, 0x51cc51f0, 0x51c551c1, 0x51d051dc, 0x51315133, 0x510d5135, 0x51015100, 0x511f5107, + 0x5171511d, 0x5140514f, 0x51445141, 0x5153515c, 0x57ff5151, 0x57f757fd, 0x57df57f5, 0x57d757dd, + 0x570c57d5, 0x57015703, 0x577f5704, 0x5777577d, 0x57405775, 0x575d575f, 0x57555757, 0x54c354f0, + 0x54dc54c4, 0x543c54d0, 0x5400540f, 0x541c5405, 0x54145411, 0x5441544f, 0x55fd55ff, 0x55f555f7, + 0x55dd55df, 0x55d555d7, 0x5503550c, 0x557f5501, 0x5577557d, 0x55405575, 0x555d555f, 0x55555557 +}; + +shared uint16_t iq1s_grid[2048]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq1s_grid_const.length(); i += wgsize.x) { + uint idx = i + gl_LocalInvocationIndex.x; + if (iq1s_grid_const.length() % wgsize.x == 0 || idx < iq1s_grid_const.length()) { + u16vec2 g = unpack16(iq1s_grid_const[idx]); + iq1s_grid[2*idx+0] = g.x; + iq1s_grid[2*idx+1] = g.y; + } + } + barrier(); +} +#endif + +#define QUANT_K_IQ2_XXS 256 +#define QUANT_R_IQ2_XXS 1 + +struct block_iq2_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_XXS/4]; +}; + +struct block_iq2_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XXS/8]; +}; + +#if defined(DATA_A_IQ2_XXS) + +const uvec2[256] iq2xxs_grid_const = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x082b0808, 0x08080808), + uvec2(0x082b082b, 0x08080808), uvec2(0x082b2b08, 0x08080808), uvec2(0x082b2b2b, 0x08080808), uvec2(0x19080819, 0x08080808), + uvec2(0x19081908, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), + uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b082b2b, 0x08080808), + uvec2(0x2b2b082b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x2b081908, 
0x08080819), uvec2(0x2b192b08, 0x08080819), + uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x082b082b, 0x0808082b), uvec2(0x2b08082b, 0x0808082b), + uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x082b0819, 0x08081908), + uvec2(0x082b1908, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x192b0808, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), + uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), uvec2(0x08082b08, 0x08081919), + uvec2(0x082b0808, 0x08081919), uvec2(0x1908192b, 0x08081919), uvec2(0x192b2b19, 0x08081919), uvec2(0x2b080808, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b2b1908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x08081919, 0x08082b08), + uvec2(0x08082b08, 0x08082b08), uvec2(0x08191908, 0x08082b08), uvec2(0x082b2b08, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x2b082b08, 0x08082b08), + uvec2(0x08081908, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x0808082b, 0x08082b2b), uvec2(0x08191908, 0x08082b2b), + uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x082b0819, 0x08190808), + uvec2(0x19080808, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), + uvec2(0x2b191919, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x082b0808, 0x08190819), + uvec2(0x19190808, 0x08190819), uvec2(0x19192b2b, 0x08190819), uvec2(0x2b080808, 0x08190819), uvec2(0x082b1908, 0x0819082b), + uvec2(0x19081919, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x08082b08, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x082b1919, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08192b08, 0x08191919), + uvec2(0x192b082b, 0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x0819192b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x2b2b0808, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x19081908, 0x082b0808), + uvec2(0x192b0819, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b08082b, 0x082b0808), uvec2(0x082b2b19, 0x082b0819), + uvec2(0x19082b08, 0x082b0819), uvec2(0x08080808, 0x082b082b), uvec2(0x0808082b, 0x082b082b), uvec2(0x08080819, 0x082b1908), + uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x19080808, 0x082b1908), uvec2(0x1919192b, 0x082b1908), + uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x192b1908, 0x082b1919), uvec2(0x2b190808, 0x082b192b), + uvec2(0x08082b08, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), uvec2(0x2b191908, 0x082b2b08), uvec2(0x19081908, 0x082b2b2b), + uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x08192b08, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x19080808, 0x19080808), 
uvec2(0x19082b08, 0x19080808), + uvec2(0x1919192b, 0x19080808), uvec2(0x192b0808, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), + uvec2(0x2b190808, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x192b0819, 0x19080819), + uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x19082b08, 0x1908082b), uvec2(0x1919192b, 0x1908082b), uvec2(0x192b2b08, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b192b19, 0x19081908), + uvec2(0x0819082b, 0x19081919), uvec2(0x082b1908, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08080819, 0x19082b08), + uvec2(0x08081908, 0x19082b08), uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x08080808, 0x19082b19), uvec2(0x19192b08, 0x19082b19), uvec2(0x192b0819, 0x19082b19), uvec2(0x2b08082b, 0x19082b19), + uvec2(0x19081919, 0x19082b2b), uvec2(0x2b190808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x08082b08, 0x19190808), + uvec2(0x08190819, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x2b080808, 0x19190808), + uvec2(0x2b082b08, 0x19190808), uvec2(0x08081908, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x2b2b1908, 0x19190819), + uvec2(0x2b190819, 0x1919082b), uvec2(0x2b190808, 0x19191908), uvec2(0x2b19082b, 0x19191908), uvec2(0x08082b2b, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x19191908, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x08190819, 0x19192b08), + uvec2(0x08192b19, 0x19192b08), uvec2(0x192b1908, 0x19192b08), uvec2(0x19080808, 0x19192b19), uvec2(0x08082b08, 0x19192b2b), + uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x192b2b08, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x19191919, 0x192b0819), uvec2(0x08192b08, 0x192b082b), uvec2(0x192b0808, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x08081919, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x0819082b, 0x192b1919), + uvec2(0x2b081908, 0x192b1919), uvec2(0x1908082b, 0x192b2b08), uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), + uvec2(0x08082b2b, 0x2b080808), uvec2(0x19080819, 0x2b080808), uvec2(0x2b08082b, 0x2b080808), uvec2(0x08081908, 0x2b080819), + uvec2(0x08192b08, 0x2b080819), uvec2(0x19080808, 0x2b080819), uvec2(0x08190819, 0x2b08082b), uvec2(0x08080819, 0x2b081908), + uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x192b0808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x1908192b, 0x2b081919), uvec2(0x2b191908, 0x2b081919), + uvec2(0x08082b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x192b0808, 0x2b08192b), uvec2(0x0808082b, 0x2b082b08), + uvec2(0x08081908, 0x2b082b19), uvec2(0x08190819, 0x2b082b2b), uvec2(0x08081908, 0x2b190808), uvec2(0x08190808, 0x2b190808), + uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x2b2b0819, 0x2b190808), uvec2(0x0819192b, 0x2b190819), + uvec2(0x2b080808, 0x2b190819), uvec2(0x19081919, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x082b082b, 0x2b191908), + uvec2(0x19081908, 0x2b191908), uvec2(0x19190819, 0x2b191919), uvec2(0x2b080819, 0x2b192b08), uvec2(0x082b0808, 0x2b192b19), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b081919, 0x2b2b0808), uvec2(0x08082b19, 
0x2b2b0819), + uvec2(0x08080808, 0x2b2b082b), uvec2(0x08192b08, 0x2b2b1908), uvec2(0x19190808, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19) +}; + +shared uvec2 iq2xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2xxs_grid.length(); i += wgsize.x) { + if (iq2xxs_grid_const.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xxs_grid_const.length()) { + iq2xxs_grid[i + gl_LocalInvocationIndex.x] = iq2xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XXS +#define QUANT_R QUANT_R_IQ2_XXS +#define A_TYPE block_iq2_xxs +#define A_TYPE_PACKED16 block_iq2_xxs_packed16 +#endif + +#define QUANT_K_IQ2_XS 256 +#define QUANT_R_IQ2_XS 1 + +struct block_iq2_xs +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint8_t scales[QUANT_K_IQ2_XS/32]; +}; + +struct block_iq2_xs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_XS/8]; + uint16_t scales[QUANT_K_IQ2_XS/64]; +}; + +#if defined(DATA_A_IQ2_XS) + +const uvec2 iq2xs_grid_const[512] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x2b080808, 0x08080808), + uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), + uvec2(0x2b191908, 0x08080808), uvec2(0x2b192b19, 0x08080808), uvec2(0x2b2b0808, 0x08080808), uvec2(0x08080819, 0x08080819), + uvec2(0x08081908, 0x08080819), uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), + uvec2(0x0819082b, 0x08080819), uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x08192b2b, 0x08080819), + uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), + uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), uvec2(0x2b081908, 0x08080819), + uvec2(0x2b190808, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), uvec2(0x08081919, 0x0808082b), + uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), uvec2(0x082b0808, 0x0808082b), + uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), + uvec2(0x0808192b, 0x08081908), uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), + uvec2(0x08191919, 0x08081908), uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), + uvec2(0x19080808, 0x08081908), 
uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 0x08081908), uvec2(0x19082b08, 0x08081908), + uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), uvec2(0x1919192b, 0x08081908), uvec2(0x192b0808, 0x08081908), + uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x08080808, 0x08081919), + uvec2(0x0808082b, 0x08081919), uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x082b0808, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x19190808, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x2b080808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x08190808, 0x0808192b), uvec2(0x082b192b, 0x0808192b), uvec2(0x19080808, 0x0808192b), + uvec2(0x1908082b, 0x0808192b), uvec2(0x2b081908, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08082b2b, 0x08082b08), uvec2(0x08190819, 0x08082b08), + uvec2(0x08191908, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), uvec2(0x19080819, 0x08082b08), + uvec2(0x19081908, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x19192b08, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b2b0808, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), uvec2(0x08081908, 0x08082b19), + uvec2(0x08190808, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x2b080819, 0x08082b19), uvec2(0x2b082b19, 0x08082b19), + uvec2(0x08080808, 0x08082b2b), uvec2(0x082b0808, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x2b19192b, 0x08082b2b), + uvec2(0x2b2b0808, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), uvec2(0x0808192b, 0x08190808), + uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), uvec2(0x08191919, 0x08190808), + uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), uvec2(0x19080808, 0x08190808), + uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), + uvec2(0x19191908, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b2b2b, 0x08190808), uvec2(0x2b080819, 0x08190808), + uvec2(0x2b081908, 0x08190808), uvec2(0x2b190808, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), + uvec2(0x08081919, 0x08190819), uvec2(0x08082b08, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x082b0808, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), uvec2(0x19190808, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x2b19192b, 0x08190819), uvec2(0x08080819, 0x0819082b), + uvec2(0x08081908, 0x0819082b), uvec2(0x0808192b, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x19080808, 0x0819082b), + uvec2(0x192b0808, 0x0819082b), uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), + uvec2(0x08082b08, 0x08191908), uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x082b0808, 0x08191908), + uvec2(0x19080819, 0x08191908), uvec2(0x19081908, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x19080808, 
0x08191919), uvec2(0x08080808, 0x0819192b), uvec2(0x08191908, 0x0819192b), + uvec2(0x19082b19, 0x0819192b), uvec2(0x08080819, 0x08192b08), uvec2(0x08081908, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x0819082b, 0x08192b08), uvec2(0x19080808, 0x08192b08), uvec2(0x19191908, 0x08192b08), uvec2(0x2b08192b, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x192b192b, 0x08192b19), uvec2(0x19190819, 0x08192b2b), + uvec2(0x2b2b2b19, 0x08192b2b), uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), + uvec2(0x08082b08, 0x082b0808), uvec2(0x08082b2b, 0x082b0808), uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x19080819, 0x082b0808), uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), + uvec2(0x2b080808, 0x082b0808), uvec2(0x2b2b0808, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), + uvec2(0x08190808, 0x082b0819), uvec2(0x19080808, 0x082b0819), uvec2(0x19082b08, 0x082b0819), uvec2(0x192b1919, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x2b080808, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), + uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x082b2b19, 0x082b1908), + uvec2(0x19080808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x19080819, 0x082b1919), uvec2(0x1919082b, 0x082b1919), + uvec2(0x2b192b19, 0x082b1919), uvec2(0x08080819, 0x082b192b), uvec2(0x08192b2b, 0x082b192b), uvec2(0x2b2b192b, 0x082b192b), + uvec2(0x08080808, 0x082b2b08), uvec2(0x08082b08, 0x082b2b08), uvec2(0x08082b2b, 0x082b2b08), uvec2(0x082b0808, 0x082b2b08), + uvec2(0x19191919, 0x082b2b08), uvec2(0x2b082b08, 0x082b2b08), uvec2(0x2b2b082b, 0x082b2b08), uvec2(0x192b2b08, 0x082b2b19), + uvec2(0x2b190808, 0x082b2b19), uvec2(0x08082b08, 0x082b2b2b), uvec2(0x082b0808, 0x082b2b2b), uvec2(0x2b08082b, 0x082b2b2b), + uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), uvec2(0x08081908, 0x19080808), + uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), uvec2(0x0819082b, 0x19080808), + uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), + uvec2(0x19080808, 0x19080808), uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), + uvec2(0x19082b2b, 0x19080808), uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x192b0808, 0x19080808), + uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), + uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), + uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x19080819, 0x19080819), + uvec2(0x19081908, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x2b080808, 0x19080819), uvec2(0x2b081919, 0x19080819), + uvec2(0x2b2b082b, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), uvec2(0x08190808, 0x1908082b), + uvec2(0x0819082b, 0x1908082b), uvec2(0x082b2b19, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x08080808, 0x19081908), + uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), uvec2(0x08082b08, 0x19081908), uvec2(0x08190819, 0x19081908), + uvec2(0x08191908, 0x19081908), uvec2(0x08192b19, 0x19081908), 
uvec2(0x082b0808, 0x19081908), uvec2(0x19080819, 0x19081908), + uvec2(0x19081908, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x2b080808, 0x19081908), uvec2(0x2b191908, 0x19081908), + uvec2(0x08080819, 0x19081919), uvec2(0x08081908, 0x19081919), uvec2(0x08190808, 0x19081919), uvec2(0x082b1908, 0x19081919), + uvec2(0x19080808, 0x19081919), uvec2(0x2b192b2b, 0x19081919), uvec2(0x08080808, 0x1908192b), uvec2(0x08082b2b, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), + uvec2(0x08190808, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x19081919, 0x19082b08), uvec2(0x19191908, 0x19082b08), + uvec2(0x192b082b, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x19081908, 0x19082b19), + uvec2(0x19190808, 0x19082b19), uvec2(0x192b2b19, 0x19082b19), uvec2(0x08081908, 0x19082b2b), uvec2(0x08080808, 0x19190808), + uvec2(0x0808082b, 0x19190808), uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), + uvec2(0x08191908, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), + uvec2(0x19081908, 0x19190808), uvec2(0x19190808, 0x19190808), uvec2(0x2b080808, 0x19190808), uvec2(0x08080819, 0x19190819), + uvec2(0x08081908, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x08191919, 0x19190819), uvec2(0x19080808, 0x19190819), + uvec2(0x1908082b, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x2b2b2b2b, 0x1919082b), + uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x082b0819, 0x19191908), + uvec2(0x19080808, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b2b0819, 0x19191908), + uvec2(0x08080808, 0x19191919), uvec2(0x08082b08, 0x19191919), uvec2(0x2b080808, 0x19191919), uvec2(0x2b082b08, 0x19191919), + uvec2(0x082b0819, 0x1919192b), uvec2(0x192b2b08, 0x1919192b), uvec2(0x2b2b0819, 0x1919192b), uvec2(0x08080808, 0x19192b08), + uvec2(0x08191908, 0x19192b08), uvec2(0x19080819, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x2b192b19, 0x19192b08), + uvec2(0x08192b2b, 0x19192b19), uvec2(0x19080808, 0x19192b19), uvec2(0x1908082b, 0x19192b19), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x08190808, 0x192b0808), uvec2(0x19080808, 0x192b0808), + uvec2(0x19191908, 0x192b0808), uvec2(0x192b082b, 0x192b0808), uvec2(0x2b08192b, 0x192b0808), uvec2(0x2b2b2b19, 0x192b0808), + uvec2(0x08080808, 0x192b0819), uvec2(0x082b1908, 0x192b082b), uvec2(0x19082b2b, 0x192b082b), uvec2(0x2b19082b, 0x192b082b), + uvec2(0x08080808, 0x192b1908), uvec2(0x0819192b, 0x192b1908), uvec2(0x08190808, 0x192b1919), uvec2(0x19080808, 0x192b1919), + uvec2(0x19081919, 0x192b1919), uvec2(0x2b2b1908, 0x192b1919), uvec2(0x08080819, 0x192b2b08), uvec2(0x192b2b2b, 0x192b2b08), + uvec2(0x082b1919, 0x192b2b19), uvec2(0x0808192b, 0x192b2b2b), uvec2(0x19191908, 0x192b2b2b), uvec2(0x192b082b, 0x192b2b2b), + uvec2(0x08080808, 0x2b080808), uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), + uvec2(0x08190819, 0x2b080808), uvec2(0x08191908, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b2b2b, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b08082b, 0x2b080808), uvec2(0x2b2b2b08, 0x2b080808), uvec2(0x2b2b2b2b, 
0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x0808192b, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x19080808, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19192b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x082b0808, 0x2b08082b), + uvec2(0x2b080808, 0x2b08082b), uvec2(0x2b08082b, 0x2b08082b), uvec2(0x2b2b0808, 0x2b08082b), uvec2(0x2b2b2b08, 0x2b08082b), + uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x19080808, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b082b19, 0x2b081908), + uvec2(0x08080808, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x2b2b1919, 0x2b081919), uvec2(0x08192b08, 0x2b08192b), + uvec2(0x192b2b2b, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08082b08, 0x2b082b08), uvec2(0x082b1919, 0x2b082b08), + uvec2(0x19192b2b, 0x2b082b08), uvec2(0x2b080808, 0x2b082b08), uvec2(0x2b08082b, 0x2b082b08), uvec2(0x2b2b2b08, 0x2b082b08), + uvec2(0x0808192b, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x2b080808, 0x2b082b2b), uvec2(0x2b082b08, 0x2b082b2b), + uvec2(0x2b19192b, 0x2b082b2b), uvec2(0x2b2b2b08, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), uvec2(0x08081908, 0x2b190808), + uvec2(0x08190808, 0x2b190808), uvec2(0x19080808, 0x2b190808), uvec2(0x1919192b, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x08080808, 0x2b190819), uvec2(0x082b082b, 0x2b190819), uvec2(0x192b1908, 0x2b190819), uvec2(0x1919192b, 0x2b19082b), + uvec2(0x2b082b19, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x08081919, 0x2b191908), uvec2(0x19081908, 0x2b191908), + uvec2(0x19190808, 0x2b191908), uvec2(0x19192b08, 0x2b191908), uvec2(0x082b2b19, 0x2b191919), uvec2(0x2b190808, 0x2b191919), + uvec2(0x2b19082b, 0x2b191919), uvec2(0x19080819, 0x2b19192b), uvec2(0x19190819, 0x2b192b08), uvec2(0x2b2b192b, 0x2b192b08), + uvec2(0x19082b19, 0x2b192b19), uvec2(0x08191919, 0x2b192b2b), uvec2(0x192b0808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), + uvec2(0x0808082b, 0x2b2b0808), uvec2(0x08082b08, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), uvec2(0x082b0808, 0x2b2b0808), + uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x2b2b0808, 0x2b2b0808), uvec2(0x19190819, 0x2b2b0819), uvec2(0x19192b19, 0x2b2b0819), + uvec2(0x2b2b192b, 0x2b2b0819), uvec2(0x08080808, 0x2b2b082b), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b08, 0x2b2b082b), + uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b080808, 0x2b2b082b), uvec2(0x2b2b0808, 0x2b2b082b), uvec2(0x19080808, 0x2b2b1908), + uvec2(0x2b191919, 0x2b2b1908), uvec2(0x192b1919, 0x2b2b192b), uvec2(0x2b192b08, 0x2b2b192b), uvec2(0x08082b2b, 0x2b2b2b08), + uvec2(0x082b0808, 0x2b2b2b08), uvec2(0x082b082b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b0808, 0x2b2b2b08), + uvec2(0x2b2b2b08, 0x2b2b2b08), uvec2(0x08081908, 0x2b2b2b19), uvec2(0x2b081908, 0x2b2b2b19), uvec2(0x2b08192b, 0x2b2b2b19), + uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x082b2b2b, 0x2b2b2b2b), uvec2(0x2b190819, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b), +}; + +shared uvec2 iq2xs_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2xs_grid.length(); i += wgsize.x) { + if (iq2xs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2xs_grid_const.length()) { + iq2xs_grid[i + gl_LocalInvocationIndex.x] = iq2xs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_XS 
+#define QUANT_R QUANT_R_IQ2_XS +#define A_TYPE block_iq2_xs +#define A_TYPE_PACKED16 block_iq2_xs_packed16 +#endif + +#define QUANT_K_IQ2_S 256 +#define QUANT_R_IQ2_S 1 + +struct block_iq2_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ2_S/4]; + uint8_t qh[QUANT_K_IQ2_S/32]; + uint8_t scales[QUANT_K_IQ2_S/32]; +}; + +struct block_iq2_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ2_S/8]; + uint16_t qh[QUANT_K_IQ2_S/64]; + uint16_t scales[QUANT_K_IQ2_S/64]; +}; + +#if defined(DATA_A_IQ2_S) + +const uvec2 iq2s_grid_const[1024] = { + uvec2(0x08080808, 0x08080808), uvec2(0x0808082b, 0x08080808), uvec2(0x08081919, 0x08080808), uvec2(0x08082b08, 0x08080808), + uvec2(0x08082b2b, 0x08080808), uvec2(0x08190819, 0x08080808), uvec2(0x08191908, 0x08080808), uvec2(0x0819192b, 0x08080808), + uvec2(0x08192b19, 0x08080808), uvec2(0x082b0808, 0x08080808), uvec2(0x082b082b, 0x08080808), uvec2(0x082b1919, 0x08080808), + uvec2(0x082b2b08, 0x08080808), uvec2(0x19080819, 0x08080808), uvec2(0x19081908, 0x08080808), uvec2(0x1908192b, 0x08080808), + uvec2(0x19082b19, 0x08080808), uvec2(0x19190808, 0x08080808), uvec2(0x1919082b, 0x08080808), uvec2(0x19191919, 0x08080808), + uvec2(0x19192b08, 0x08080808), uvec2(0x192b0819, 0x08080808), uvec2(0x192b1908, 0x08080808), uvec2(0x192b192b, 0x08080808), + uvec2(0x192b2b19, 0x08080808), uvec2(0x2b080808, 0x08080808), uvec2(0x2b08082b, 0x08080808), uvec2(0x2b081919, 0x08080808), + uvec2(0x2b082b08, 0x08080808), uvec2(0x2b190819, 0x08080808), uvec2(0x2b191908, 0x08080808), uvec2(0x2b2b0808, 0x08080808), + uvec2(0x2b2b1919, 0x08080808), uvec2(0x2b2b2b2b, 0x08080808), uvec2(0x08080819, 0x08080819), uvec2(0x08081908, 0x08080819), + uvec2(0x0808192b, 0x08080819), uvec2(0x08082b19, 0x08080819), uvec2(0x08190808, 0x08080819), uvec2(0x0819082b, 0x08080819), + uvec2(0x08191919, 0x08080819), uvec2(0x08192b08, 0x08080819), uvec2(0x082b0819, 0x08080819), uvec2(0x082b1908, 0x08080819), + uvec2(0x19080808, 0x08080819), uvec2(0x1908082b, 0x08080819), uvec2(0x19081919, 0x08080819), uvec2(0x19082b08, 0x08080819), + uvec2(0x19190819, 0x08080819), uvec2(0x19191908, 0x08080819), uvec2(0x1919192b, 0x08080819), uvec2(0x19192b19, 0x08080819), + uvec2(0x192b0808, 0x08080819), uvec2(0x192b1919, 0x08080819), uvec2(0x192b2b08, 0x08080819), uvec2(0x2b080819, 0x08080819), + uvec2(0x2b081908, 0x08080819), uvec2(0x2b190808, 0x08080819), uvec2(0x2b19082b, 0x08080819), uvec2(0x2b191919, 0x08080819), + uvec2(0x2b2b0819, 0x08080819), uvec2(0x2b2b1908, 0x08080819), uvec2(0x08080808, 0x0808082b), uvec2(0x0808082b, 0x0808082b), + uvec2(0x08081919, 0x0808082b), uvec2(0x08082b08, 0x0808082b), uvec2(0x08190819, 0x0808082b), uvec2(0x08191908, 0x0808082b), + uvec2(0x082b0808, 0x0808082b), uvec2(0x082b2b2b, 0x0808082b), uvec2(0x19080819, 0x0808082b), uvec2(0x19081908, 0x0808082b), + uvec2(0x1908192b, 0x0808082b), uvec2(0x19082b19, 0x0808082b), uvec2(0x19190808, 0x0808082b), uvec2(0x19191919, 0x0808082b), + uvec2(0x2b080808, 0x0808082b), uvec2(0x2b081919, 0x0808082b), uvec2(0x2b082b2b, 0x0808082b), uvec2(0x2b191908, 0x0808082b), + uvec2(0x2b2b082b, 0x0808082b), uvec2(0x08080819, 0x08081908), uvec2(0x08081908, 0x08081908), uvec2(0x0808192b, 0x08081908), + uvec2(0x08082b19, 0x08081908), uvec2(0x08190808, 0x08081908), uvec2(0x0819082b, 0x08081908), uvec2(0x08191919, 0x08081908), + uvec2(0x08192b08, 0x08081908), uvec2(0x082b0819, 0x08081908), uvec2(0x082b1908, 0x08081908), uvec2(0x082b192b, 0x08081908), + uvec2(0x082b2b19, 0x08081908), uvec2(0x19080808, 0x08081908), uvec2(0x1908082b, 0x08081908), uvec2(0x19081919, 
0x08081908), + uvec2(0x19082b08, 0x08081908), uvec2(0x19082b2b, 0x08081908), uvec2(0x19190819, 0x08081908), uvec2(0x19191908, 0x08081908), + uvec2(0x1919192b, 0x08081908), uvec2(0x19192b19, 0x08081908), uvec2(0x192b0808, 0x08081908), uvec2(0x192b082b, 0x08081908), + uvec2(0x192b1919, 0x08081908), uvec2(0x2b080819, 0x08081908), uvec2(0x2b081908, 0x08081908), uvec2(0x2b08192b, 0x08081908), + uvec2(0x2b082b19, 0x08081908), uvec2(0x2b190808, 0x08081908), uvec2(0x2b191919, 0x08081908), uvec2(0x2b192b08, 0x08081908), + uvec2(0x2b2b0819, 0x08081908), uvec2(0x2b2b1908, 0x08081908), uvec2(0x08080808, 0x08081919), uvec2(0x0808082b, 0x08081919), + uvec2(0x08081919, 0x08081919), uvec2(0x08082b08, 0x08081919), uvec2(0x08082b2b, 0x08081919), uvec2(0x08190819, 0x08081919), + uvec2(0x08191908, 0x08081919), uvec2(0x0819192b, 0x08081919), uvec2(0x08192b19, 0x08081919), uvec2(0x082b0808, 0x08081919), + uvec2(0x082b1919, 0x08081919), uvec2(0x082b2b08, 0x08081919), uvec2(0x19080819, 0x08081919), uvec2(0x19081908, 0x08081919), + uvec2(0x1908192b, 0x08081919), uvec2(0x19082b19, 0x08081919), uvec2(0x19190808, 0x08081919), uvec2(0x1919082b, 0x08081919), + uvec2(0x19191919, 0x08081919), uvec2(0x19192b08, 0x08081919), uvec2(0x192b0819, 0x08081919), uvec2(0x192b1908, 0x08081919), + uvec2(0x2b080808, 0x08081919), uvec2(0x2b08082b, 0x08081919), uvec2(0x2b081919, 0x08081919), uvec2(0x2b082b08, 0x08081919), + uvec2(0x2b190819, 0x08081919), uvec2(0x2b191908, 0x08081919), uvec2(0x2b2b0808, 0x08081919), uvec2(0x08080819, 0x0808192b), + uvec2(0x08081908, 0x0808192b), uvec2(0x0808192b, 0x0808192b), uvec2(0x08082b19, 0x0808192b), uvec2(0x08190808, 0x0808192b), + uvec2(0x08191919, 0x0808192b), uvec2(0x19080808, 0x0808192b), uvec2(0x19081919, 0x0808192b), uvec2(0x19082b08, 0x0808192b), + uvec2(0x19190819, 0x0808192b), uvec2(0x19191908, 0x0808192b), uvec2(0x192b0808, 0x0808192b), uvec2(0x2b080819, 0x0808192b), + uvec2(0x2b081908, 0x0808192b), uvec2(0x2b190808, 0x0808192b), uvec2(0x08080808, 0x08082b08), uvec2(0x0808082b, 0x08082b08), + uvec2(0x08081919, 0x08082b08), uvec2(0x08082b08, 0x08082b08), uvec2(0x08190819, 0x08082b08), uvec2(0x08191908, 0x08082b08), + uvec2(0x0819192b, 0x08082b08), uvec2(0x08192b19, 0x08082b08), uvec2(0x082b0808, 0x08082b08), uvec2(0x082b1919, 0x08082b08), + uvec2(0x082b2b2b, 0x08082b08), uvec2(0x19080819, 0x08082b08), uvec2(0x19081908, 0x08082b08), uvec2(0x1908192b, 0x08082b08), + uvec2(0x19082b19, 0x08082b08), uvec2(0x19190808, 0x08082b08), uvec2(0x1919082b, 0x08082b08), uvec2(0x19191919, 0x08082b08), + uvec2(0x19192b08, 0x08082b08), uvec2(0x192b0819, 0x08082b08), uvec2(0x192b1908, 0x08082b08), uvec2(0x2b080808, 0x08082b08), + uvec2(0x2b081919, 0x08082b08), uvec2(0x2b191908, 0x08082b08), uvec2(0x2b2b2b2b, 0x08082b08), uvec2(0x08080819, 0x08082b19), + uvec2(0x08081908, 0x08082b19), uvec2(0x08190808, 0x08082b19), uvec2(0x0819082b, 0x08082b19), uvec2(0x08191919, 0x08082b19), + uvec2(0x08192b08, 0x08082b19), uvec2(0x082b0819, 0x08082b19), uvec2(0x19080808, 0x08082b19), uvec2(0x19081919, 0x08082b19), + uvec2(0x19082b08, 0x08082b19), uvec2(0x19190819, 0x08082b19), uvec2(0x19191908, 0x08082b19), uvec2(0x192b0808, 0x08082b19), + uvec2(0x2b080819, 0x08082b19), uvec2(0x2b190808, 0x08082b19), uvec2(0x08080808, 0x08082b2b), uvec2(0x08190819, 0x08082b2b), + uvec2(0x08191908, 0x08082b2b), uvec2(0x082b082b, 0x08082b2b), uvec2(0x082b2b08, 0x08082b2b), uvec2(0x082b2b2b, 0x08082b2b), + uvec2(0x19190808, 0x08082b2b), uvec2(0x2b192b19, 0x08082b2b), uvec2(0x08080819, 0x08190808), uvec2(0x08081908, 0x08190808), + 
uvec2(0x0808192b, 0x08190808), uvec2(0x08082b19, 0x08190808), uvec2(0x08190808, 0x08190808), uvec2(0x0819082b, 0x08190808), + uvec2(0x08191919, 0x08190808), uvec2(0x08192b08, 0x08190808), uvec2(0x082b0819, 0x08190808), uvec2(0x082b1908, 0x08190808), + uvec2(0x082b192b, 0x08190808), uvec2(0x19080808, 0x08190808), uvec2(0x1908082b, 0x08190808), uvec2(0x19081919, 0x08190808), + uvec2(0x19082b08, 0x08190808), uvec2(0x19190819, 0x08190808), uvec2(0x19191908, 0x08190808), uvec2(0x1919192b, 0x08190808), + uvec2(0x19192b19, 0x08190808), uvec2(0x192b0808, 0x08190808), uvec2(0x192b082b, 0x08190808), uvec2(0x192b1919, 0x08190808), + uvec2(0x192b2b08, 0x08190808), uvec2(0x2b080819, 0x08190808), uvec2(0x2b081908, 0x08190808), uvec2(0x2b08192b, 0x08190808), + uvec2(0x2b190808, 0x08190808), uvec2(0x2b191919, 0x08190808), uvec2(0x2b192b08, 0x08190808), uvec2(0x2b2b0819, 0x08190808), + uvec2(0x2b2b1908, 0x08190808), uvec2(0x08080808, 0x08190819), uvec2(0x0808082b, 0x08190819), uvec2(0x08081919, 0x08190819), + uvec2(0x08082b08, 0x08190819), uvec2(0x08082b2b, 0x08190819), uvec2(0x08190819, 0x08190819), uvec2(0x08191908, 0x08190819), + uvec2(0x0819192b, 0x08190819), uvec2(0x08192b19, 0x08190819), uvec2(0x082b0808, 0x08190819), uvec2(0x082b082b, 0x08190819), + uvec2(0x082b1919, 0x08190819), uvec2(0x082b2b08, 0x08190819), uvec2(0x19080819, 0x08190819), uvec2(0x19081908, 0x08190819), + uvec2(0x1908192b, 0x08190819), uvec2(0x19082b19, 0x08190819), uvec2(0x19190808, 0x08190819), uvec2(0x1919082b, 0x08190819), + uvec2(0x19191919, 0x08190819), uvec2(0x19192b08, 0x08190819), uvec2(0x192b0819, 0x08190819), uvec2(0x192b1908, 0x08190819), + uvec2(0x2b080808, 0x08190819), uvec2(0x2b08082b, 0x08190819), uvec2(0x2b081919, 0x08190819), uvec2(0x2b082b08, 0x08190819), + uvec2(0x2b190819, 0x08190819), uvec2(0x2b191908, 0x08190819), uvec2(0x08080819, 0x0819082b), uvec2(0x08081908, 0x0819082b), + uvec2(0x08082b19, 0x0819082b), uvec2(0x08190808, 0x0819082b), uvec2(0x08191919, 0x0819082b), uvec2(0x082b0819, 0x0819082b), + uvec2(0x082b1908, 0x0819082b), uvec2(0x19080808, 0x0819082b), uvec2(0x19081919, 0x0819082b), uvec2(0x19190819, 0x0819082b), + uvec2(0x19191908, 0x0819082b), uvec2(0x2b080819, 0x0819082b), uvec2(0x2b081908, 0x0819082b), uvec2(0x2b190808, 0x0819082b), + uvec2(0x08080808, 0x08191908), uvec2(0x0808082b, 0x08191908), uvec2(0x08081919, 0x08191908), uvec2(0x08082b08, 0x08191908), + uvec2(0x08190819, 0x08191908), uvec2(0x08191908, 0x08191908), uvec2(0x0819192b, 0x08191908), uvec2(0x08192b19, 0x08191908), + uvec2(0x082b0808, 0x08191908), uvec2(0x082b1919, 0x08191908), uvec2(0x082b2b08, 0x08191908), uvec2(0x19080819, 0x08191908), + uvec2(0x19081908, 0x08191908), uvec2(0x1908192b, 0x08191908), uvec2(0x19082b19, 0x08191908), uvec2(0x19190808, 0x08191908), + uvec2(0x1919082b, 0x08191908), uvec2(0x19191919, 0x08191908), uvec2(0x19192b08, 0x08191908), uvec2(0x192b0819, 0x08191908), + uvec2(0x192b1908, 0x08191908), uvec2(0x2b080808, 0x08191908), uvec2(0x2b08082b, 0x08191908), uvec2(0x2b081919, 0x08191908), + uvec2(0x2b082b08, 0x08191908), uvec2(0x2b190819, 0x08191908), uvec2(0x2b191908, 0x08191908), uvec2(0x2b2b0808, 0x08191908), + uvec2(0x08080819, 0x08191919), uvec2(0x08081908, 0x08191919), uvec2(0x0808192b, 0x08191919), uvec2(0x08082b19, 0x08191919), + uvec2(0x08190808, 0x08191919), uvec2(0x0819082b, 0x08191919), uvec2(0x08191919, 0x08191919), uvec2(0x08192b08, 0x08191919), + uvec2(0x082b0819, 0x08191919), uvec2(0x082b1908, 0x08191919), uvec2(0x19080808, 0x08191919), uvec2(0x1908082b, 0x08191919), + uvec2(0x19081919, 
0x08191919), uvec2(0x19082b08, 0x08191919), uvec2(0x19190819, 0x08191919), uvec2(0x19191908, 0x08191919), + uvec2(0x192b0808, 0x08191919), uvec2(0x2b080819, 0x08191919), uvec2(0x2b081908, 0x08191919), uvec2(0x2b190808, 0x08191919), + uvec2(0x08080808, 0x0819192b), uvec2(0x08081919, 0x0819192b), uvec2(0x08082b08, 0x0819192b), uvec2(0x08190819, 0x0819192b), + uvec2(0x08191908, 0x0819192b), uvec2(0x082b0808, 0x0819192b), uvec2(0x19080819, 0x0819192b), uvec2(0x19081908, 0x0819192b), + uvec2(0x19190808, 0x0819192b), uvec2(0x2b080808, 0x0819192b), uvec2(0x2b2b2b2b, 0x0819192b), uvec2(0x08080819, 0x08192b08), + uvec2(0x08081908, 0x08192b08), uvec2(0x0808192b, 0x08192b08), uvec2(0x08082b19, 0x08192b08), uvec2(0x08190808, 0x08192b08), + uvec2(0x08191919, 0x08192b08), uvec2(0x08192b08, 0x08192b08), uvec2(0x082b0819, 0x08192b08), uvec2(0x19080808, 0x08192b08), + uvec2(0x1908082b, 0x08192b08), uvec2(0x19081919, 0x08192b08), uvec2(0x19082b08, 0x08192b08), uvec2(0x19190819, 0x08192b08), + uvec2(0x19191908, 0x08192b08), uvec2(0x192b0808, 0x08192b08), uvec2(0x2b080819, 0x08192b08), uvec2(0x2b081908, 0x08192b08), + uvec2(0x08080808, 0x08192b19), uvec2(0x0808082b, 0x08192b19), uvec2(0x08081919, 0x08192b19), uvec2(0x08082b08, 0x08192b19), + uvec2(0x08190819, 0x08192b19), uvec2(0x08191908, 0x08192b19), uvec2(0x082b0808, 0x08192b19), uvec2(0x19080819, 0x08192b19), + uvec2(0x19081908, 0x08192b19), uvec2(0x19190808, 0x08192b19), uvec2(0x192b2b19, 0x08192b19), uvec2(0x2b2b082b, 0x08192b19), + uvec2(0x08081908, 0x08192b2b), uvec2(0x08190808, 0x08192b2b), uvec2(0x19080808, 0x08192b2b), uvec2(0x1919192b, 0x08192b2b), + uvec2(0x08080808, 0x082b0808), uvec2(0x0808082b, 0x082b0808), uvec2(0x08081919, 0x082b0808), uvec2(0x08082b08, 0x082b0808), + uvec2(0x08190819, 0x082b0808), uvec2(0x08191908, 0x082b0808), uvec2(0x0819192b, 0x082b0808), uvec2(0x08192b19, 0x082b0808), + uvec2(0x082b0808, 0x082b0808), uvec2(0x082b1919, 0x082b0808), uvec2(0x082b2b2b, 0x082b0808), uvec2(0x19080819, 0x082b0808), + uvec2(0x19081908, 0x082b0808), uvec2(0x19190808, 0x082b0808), uvec2(0x1919082b, 0x082b0808), uvec2(0x19191919, 0x082b0808), + uvec2(0x192b1908, 0x082b0808), uvec2(0x2b080808, 0x082b0808), uvec2(0x2b082b2b, 0x082b0808), uvec2(0x2b191908, 0x082b0808), + uvec2(0x2b2b2b2b, 0x082b0808), uvec2(0x08080819, 0x082b0819), uvec2(0x08081908, 0x082b0819), uvec2(0x08190808, 0x082b0819), + uvec2(0x0819082b, 0x082b0819), uvec2(0x08191919, 0x082b0819), uvec2(0x082b0819, 0x082b0819), uvec2(0x19080808, 0x082b0819), + uvec2(0x1908082b, 0x082b0819), uvec2(0x19081919, 0x082b0819), uvec2(0x19190819, 0x082b0819), uvec2(0x19191908, 0x082b0819), + uvec2(0x192b0808, 0x082b0819), uvec2(0x2b080819, 0x082b0819), uvec2(0x2b081908, 0x082b0819), uvec2(0x2b190808, 0x082b0819), + uvec2(0x08080808, 0x082b082b), uvec2(0x08082b2b, 0x082b082b), uvec2(0x082b082b, 0x082b082b), uvec2(0x082b2b08, 0x082b082b), + uvec2(0x082b2b2b, 0x082b082b), uvec2(0x19081908, 0x082b082b), uvec2(0x19190808, 0x082b082b), uvec2(0x2b082b08, 0x082b082b), + uvec2(0x2b082b2b, 0x082b082b), uvec2(0x2b2b2b08, 0x082b082b), uvec2(0x08080819, 0x082b1908), uvec2(0x08081908, 0x082b1908), + uvec2(0x0808192b, 0x082b1908), uvec2(0x08082b19, 0x082b1908), uvec2(0x08190808, 0x082b1908), uvec2(0x08191919, 0x082b1908), + uvec2(0x08192b08, 0x082b1908), uvec2(0x082b0819, 0x082b1908), uvec2(0x082b1908, 0x082b1908), uvec2(0x19080808, 0x082b1908), + uvec2(0x1908082b, 0x082b1908), uvec2(0x19081919, 0x082b1908), uvec2(0x19082b08, 0x082b1908), uvec2(0x19190819, 0x082b1908), + uvec2(0x19191908, 0x082b1908), 
uvec2(0x192b0808, 0x082b1908), uvec2(0x2b080819, 0x082b1908), uvec2(0x2b081908, 0x082b1908), + uvec2(0x2b190808, 0x082b1908), uvec2(0x08080808, 0x082b1919), uvec2(0x08081919, 0x082b1919), uvec2(0x08082b08, 0x082b1919), + uvec2(0x08190819, 0x082b1919), uvec2(0x08191908, 0x082b1919), uvec2(0x082b0808, 0x082b1919), uvec2(0x19080819, 0x082b1919), + uvec2(0x19081908, 0x082b1919), uvec2(0x19190808, 0x082b1919), uvec2(0x192b192b, 0x082b1919), uvec2(0x2b080808, 0x082b1919), + uvec2(0x08080819, 0x082b192b), uvec2(0x08081908, 0x082b192b), uvec2(0x08190808, 0x082b192b), uvec2(0x19080808, 0x082b192b), + uvec2(0x19192b19, 0x082b192b), uvec2(0x08080808, 0x082b2b08), uvec2(0x08081919, 0x082b2b08), uvec2(0x08190819, 0x082b2b08), + uvec2(0x08191908, 0x082b2b08), uvec2(0x19080819, 0x082b2b08), uvec2(0x19081908, 0x082b2b08), uvec2(0x19190808, 0x082b2b08), + uvec2(0x2b082b2b, 0x082b2b08), uvec2(0x2b2b2b2b, 0x082b2b08), uvec2(0x08080819, 0x082b2b19), uvec2(0x08081908, 0x082b2b19), + uvec2(0x08190808, 0x082b2b19), uvec2(0x2b191919, 0x082b2b19), uvec2(0x08082b2b, 0x082b2b2b), uvec2(0x082b082b, 0x082b2b2b), + uvec2(0x192b1908, 0x082b2b2b), uvec2(0x2b082b08, 0x082b2b2b), uvec2(0x2b082b2b, 0x082b2b2b), uvec2(0x08080819, 0x19080808), + uvec2(0x08081908, 0x19080808), uvec2(0x0808192b, 0x19080808), uvec2(0x08082b19, 0x19080808), uvec2(0x08190808, 0x19080808), + uvec2(0x0819082b, 0x19080808), uvec2(0x08191919, 0x19080808), uvec2(0x08192b08, 0x19080808), uvec2(0x08192b2b, 0x19080808), + uvec2(0x082b0819, 0x19080808), uvec2(0x082b1908, 0x19080808), uvec2(0x082b192b, 0x19080808), uvec2(0x19080808, 0x19080808), + uvec2(0x1908082b, 0x19080808), uvec2(0x19081919, 0x19080808), uvec2(0x19082b08, 0x19080808), uvec2(0x19082b2b, 0x19080808), + uvec2(0x19190819, 0x19080808), uvec2(0x19191908, 0x19080808), uvec2(0x1919192b, 0x19080808), uvec2(0x19192b19, 0x19080808), + uvec2(0x192b0808, 0x19080808), uvec2(0x192b082b, 0x19080808), uvec2(0x192b1919, 0x19080808), uvec2(0x2b080819, 0x19080808), + uvec2(0x2b081908, 0x19080808), uvec2(0x2b190808, 0x19080808), uvec2(0x2b191919, 0x19080808), uvec2(0x2b192b08, 0x19080808), + uvec2(0x2b2b0819, 0x19080808), uvec2(0x2b2b1908, 0x19080808), uvec2(0x08080808, 0x19080819), uvec2(0x0808082b, 0x19080819), + uvec2(0x08081919, 0x19080819), uvec2(0x08082b08, 0x19080819), uvec2(0x08190819, 0x19080819), uvec2(0x08191908, 0x19080819), + uvec2(0x0819192b, 0x19080819), uvec2(0x08192b19, 0x19080819), uvec2(0x082b0808, 0x19080819), uvec2(0x082b082b, 0x19080819), + uvec2(0x082b1919, 0x19080819), uvec2(0x19080819, 0x19080819), uvec2(0x19081908, 0x19080819), uvec2(0x1908192b, 0x19080819), + uvec2(0x19082b19, 0x19080819), uvec2(0x19190808, 0x19080819), uvec2(0x1919082b, 0x19080819), uvec2(0x19191919, 0x19080819), + uvec2(0x19192b08, 0x19080819), uvec2(0x192b0819, 0x19080819), uvec2(0x192b1908, 0x19080819), uvec2(0x2b080808, 0x19080819), + uvec2(0x2b08082b, 0x19080819), uvec2(0x2b081919, 0x19080819), uvec2(0x2b082b08, 0x19080819), uvec2(0x2b190819, 0x19080819), + uvec2(0x2b191908, 0x19080819), uvec2(0x2b2b0808, 0x19080819), uvec2(0x08080819, 0x1908082b), uvec2(0x08081908, 0x1908082b), + uvec2(0x08190808, 0x1908082b), uvec2(0x0819082b, 0x1908082b), uvec2(0x08191919, 0x1908082b), uvec2(0x08192b08, 0x1908082b), + uvec2(0x082b1908, 0x1908082b), uvec2(0x19080808, 0x1908082b), uvec2(0x19081919, 0x1908082b), uvec2(0x19082b08, 0x1908082b), + uvec2(0x19190819, 0x1908082b), uvec2(0x19191908, 0x1908082b), uvec2(0x192b0808, 0x1908082b), uvec2(0x2b080819, 0x1908082b), + uvec2(0x2b081908, 0x1908082b), uvec2(0x08080808, 
0x19081908), uvec2(0x0808082b, 0x19081908), uvec2(0x08081919, 0x19081908), + uvec2(0x08082b08, 0x19081908), uvec2(0x08082b2b, 0x19081908), uvec2(0x08190819, 0x19081908), uvec2(0x08191908, 0x19081908), + uvec2(0x0819192b, 0x19081908), uvec2(0x08192b19, 0x19081908), uvec2(0x082b0808, 0x19081908), uvec2(0x082b082b, 0x19081908), + uvec2(0x082b1919, 0x19081908), uvec2(0x082b2b08, 0x19081908), uvec2(0x19080819, 0x19081908), uvec2(0x19081908, 0x19081908), + uvec2(0x1908192b, 0x19081908), uvec2(0x19082b19, 0x19081908), uvec2(0x19190808, 0x19081908), uvec2(0x1919082b, 0x19081908), + uvec2(0x19191919, 0x19081908), uvec2(0x19192b08, 0x19081908), uvec2(0x192b0819, 0x19081908), uvec2(0x192b1908, 0x19081908), + uvec2(0x2b080808, 0x19081908), uvec2(0x2b08082b, 0x19081908), uvec2(0x2b081919, 0x19081908), uvec2(0x2b082b08, 0x19081908), + uvec2(0x2b190819, 0x19081908), uvec2(0x2b191908, 0x19081908), uvec2(0x2b2b0808, 0x19081908), uvec2(0x08080819, 0x19081919), + uvec2(0x08081908, 0x19081919), uvec2(0x0808192b, 0x19081919), uvec2(0x08082b19, 0x19081919), uvec2(0x08190808, 0x19081919), + uvec2(0x0819082b, 0x19081919), uvec2(0x08191919, 0x19081919), uvec2(0x08192b08, 0x19081919), uvec2(0x082b0819, 0x19081919), + uvec2(0x082b1908, 0x19081919), uvec2(0x19080808, 0x19081919), uvec2(0x1908082b, 0x19081919), uvec2(0x19081919, 0x19081919), + uvec2(0x19082b08, 0x19081919), uvec2(0x19190819, 0x19081919), uvec2(0x19191908, 0x19081919), uvec2(0x192b0808, 0x19081919), + uvec2(0x192b2b2b, 0x19081919), uvec2(0x2b080819, 0x19081919), uvec2(0x2b081908, 0x19081919), uvec2(0x2b190808, 0x19081919), + uvec2(0x08080808, 0x1908192b), uvec2(0x0808082b, 0x1908192b), uvec2(0x08081919, 0x1908192b), uvec2(0x08082b08, 0x1908192b), + uvec2(0x08190819, 0x1908192b), uvec2(0x08191908, 0x1908192b), uvec2(0x082b0808, 0x1908192b), uvec2(0x19080819, 0x1908192b), + uvec2(0x19081908, 0x1908192b), uvec2(0x19190808, 0x1908192b), uvec2(0x2b080808, 0x1908192b), uvec2(0x2b2b1919, 0x1908192b), + uvec2(0x08080819, 0x19082b08), uvec2(0x08081908, 0x19082b08), uvec2(0x08082b19, 0x19082b08), uvec2(0x08190808, 0x19082b08), + uvec2(0x0819082b, 0x19082b08), uvec2(0x08191919, 0x19082b08), uvec2(0x08192b08, 0x19082b08), uvec2(0x082b0819, 0x19082b08), + uvec2(0x082b1908, 0x19082b08), uvec2(0x19080808, 0x19082b08), uvec2(0x1908082b, 0x19082b08), uvec2(0x19081919, 0x19082b08), + uvec2(0x19082b08, 0x19082b08), uvec2(0x19190819, 0x19082b08), uvec2(0x19191908, 0x19082b08), uvec2(0x192b0808, 0x19082b08), + uvec2(0x2b081908, 0x19082b08), uvec2(0x2b190808, 0x19082b08), uvec2(0x08080808, 0x19082b19), uvec2(0x0808082b, 0x19082b19), + uvec2(0x08081919, 0x19082b19), uvec2(0x08082b08, 0x19082b19), uvec2(0x08190819, 0x19082b19), uvec2(0x08191908, 0x19082b19), + uvec2(0x082b0808, 0x19082b19), uvec2(0x19080819, 0x19082b19), uvec2(0x19081908, 0x19082b19), uvec2(0x19190808, 0x19082b19), + uvec2(0x2b080808, 0x19082b19), uvec2(0x2b19192b, 0x19082b19), uvec2(0x08080819, 0x19082b2b), uvec2(0x08081908, 0x19082b2b), + uvec2(0x08190808, 0x19082b2b), uvec2(0x19080808, 0x19082b2b), uvec2(0x08080808, 0x19190808), uvec2(0x0808082b, 0x19190808), + uvec2(0x08081919, 0x19190808), uvec2(0x08082b08, 0x19190808), uvec2(0x08190819, 0x19190808), uvec2(0x08191908, 0x19190808), + uvec2(0x0819192b, 0x19190808), uvec2(0x08192b19, 0x19190808), uvec2(0x082b0808, 0x19190808), uvec2(0x082b082b, 0x19190808), + uvec2(0x082b1919, 0x19190808), uvec2(0x082b2b08, 0x19190808), uvec2(0x19080819, 0x19190808), uvec2(0x19081908, 0x19190808), + uvec2(0x1908192b, 0x19190808), uvec2(0x19082b19, 0x19190808), 
uvec2(0x19190808, 0x19190808), uvec2(0x1919082b, 0x19190808), + uvec2(0x19191919, 0x19190808), uvec2(0x19192b08, 0x19190808), uvec2(0x192b0819, 0x19190808), uvec2(0x192b1908, 0x19190808), + uvec2(0x2b080808, 0x19190808), uvec2(0x2b08082b, 0x19190808), uvec2(0x2b081919, 0x19190808), uvec2(0x2b082b08, 0x19190808), + uvec2(0x2b190819, 0x19190808), uvec2(0x2b191908, 0x19190808), uvec2(0x08080819, 0x19190819), uvec2(0x08081908, 0x19190819), + uvec2(0x0808192b, 0x19190819), uvec2(0x08082b19, 0x19190819), uvec2(0x08190808, 0x19190819), uvec2(0x0819082b, 0x19190819), + uvec2(0x08191919, 0x19190819), uvec2(0x08192b08, 0x19190819), uvec2(0x082b0819, 0x19190819), uvec2(0x082b1908, 0x19190819), + uvec2(0x19080808, 0x19190819), uvec2(0x1908082b, 0x19190819), uvec2(0x19081919, 0x19190819), uvec2(0x19082b08, 0x19190819), + uvec2(0x19190819, 0x19190819), uvec2(0x19191908, 0x19190819), uvec2(0x192b0808, 0x19190819), uvec2(0x2b080819, 0x19190819), + uvec2(0x2b081908, 0x19190819), uvec2(0x2b190808, 0x19190819), uvec2(0x08080808, 0x1919082b), uvec2(0x08081919, 0x1919082b), + uvec2(0x08082b08, 0x1919082b), uvec2(0x08190819, 0x1919082b), uvec2(0x08191908, 0x1919082b), uvec2(0x082b0808, 0x1919082b), + uvec2(0x19080819, 0x1919082b), uvec2(0x19081908, 0x1919082b), uvec2(0x19190808, 0x1919082b), uvec2(0x192b2b19, 0x1919082b), + uvec2(0x2b080808, 0x1919082b), uvec2(0x08080819, 0x19191908), uvec2(0x08081908, 0x19191908), uvec2(0x0808192b, 0x19191908), + uvec2(0x08082b19, 0x19191908), uvec2(0x08190808, 0x19191908), uvec2(0x0819082b, 0x19191908), uvec2(0x08191919, 0x19191908), + uvec2(0x08192b08, 0x19191908), uvec2(0x082b0819, 0x19191908), uvec2(0x082b1908, 0x19191908), uvec2(0x19080808, 0x19191908), + uvec2(0x1908082b, 0x19191908), uvec2(0x19081919, 0x19191908), uvec2(0x19082b08, 0x19191908), uvec2(0x19190819, 0x19191908), + uvec2(0x19191908, 0x19191908), uvec2(0x192b0808, 0x19191908), uvec2(0x2b080819, 0x19191908), uvec2(0x2b081908, 0x19191908), + uvec2(0x2b190808, 0x19191908), uvec2(0x08080808, 0x19191919), uvec2(0x0808082b, 0x19191919), uvec2(0x08081919, 0x19191919), + uvec2(0x08082b08, 0x19191919), uvec2(0x08190819, 0x19191919), uvec2(0x08191908, 0x19191919), uvec2(0x082b0808, 0x19191919), + uvec2(0x19080819, 0x19191919), uvec2(0x19081908, 0x19191919), uvec2(0x19190808, 0x19191919), uvec2(0x2b080808, 0x19191919), + uvec2(0x08080819, 0x1919192b), uvec2(0x08081908, 0x1919192b), uvec2(0x08190808, 0x1919192b), uvec2(0x082b192b, 0x1919192b), + uvec2(0x19080808, 0x1919192b), uvec2(0x08080808, 0x19192b08), uvec2(0x0808082b, 0x19192b08), uvec2(0x08081919, 0x19192b08), + uvec2(0x08082b08, 0x19192b08), uvec2(0x08190819, 0x19192b08), uvec2(0x08191908, 0x19192b08), uvec2(0x082b0808, 0x19192b08), + uvec2(0x19080819, 0x19192b08), uvec2(0x19081908, 0x19192b08), uvec2(0x19190808, 0x19192b08), uvec2(0x19192b2b, 0x19192b08), + uvec2(0x2b080808, 0x19192b08), uvec2(0x08080819, 0x19192b19), uvec2(0x08081908, 0x19192b19), uvec2(0x08190808, 0x19192b19), + uvec2(0x19080808, 0x19192b19), uvec2(0x08080808, 0x19192b2b), uvec2(0x08192b19, 0x19192b2b), uvec2(0x2b081919, 0x19192b2b), + uvec2(0x2b2b2b08, 0x19192b2b), uvec2(0x08080819, 0x192b0808), uvec2(0x08081908, 0x192b0808), uvec2(0x0808192b, 0x192b0808), + uvec2(0x08190808, 0x192b0808), uvec2(0x0819082b, 0x192b0808), uvec2(0x08191919, 0x192b0808), uvec2(0x08192b08, 0x192b0808), + uvec2(0x082b0819, 0x192b0808), uvec2(0x082b1908, 0x192b0808), uvec2(0x19080808, 0x192b0808), uvec2(0x19081919, 0x192b0808), + uvec2(0x19082b08, 0x192b0808), uvec2(0x19190819, 0x192b0808), uvec2(0x19191908, 
0x192b0808), uvec2(0x192b0808, 0x192b0808), + uvec2(0x2b081908, 0x192b0808), uvec2(0x2b190808, 0x192b0808), uvec2(0x08080808, 0x192b0819), uvec2(0x0808082b, 0x192b0819), + uvec2(0x08081919, 0x192b0819), uvec2(0x08082b08, 0x192b0819), uvec2(0x08190819, 0x192b0819), uvec2(0x08191908, 0x192b0819), + uvec2(0x082b0808, 0x192b0819), uvec2(0x19080819, 0x192b0819), uvec2(0x19081908, 0x192b0819), uvec2(0x19190808, 0x192b0819), + uvec2(0x2b080808, 0x192b0819), uvec2(0x2b192b19, 0x192b0819), uvec2(0x08081908, 0x192b082b), uvec2(0x08190808, 0x192b082b), + uvec2(0x19080808, 0x192b082b), uvec2(0x1919192b, 0x192b082b), uvec2(0x2b2b0819, 0x192b082b), uvec2(0x08080808, 0x192b1908), + uvec2(0x08081919, 0x192b1908), uvec2(0x08082b08, 0x192b1908), uvec2(0x08190819, 0x192b1908), uvec2(0x08191908, 0x192b1908), + uvec2(0x082b0808, 0x192b1908), uvec2(0x19080819, 0x192b1908), uvec2(0x19081908, 0x192b1908), uvec2(0x19190808, 0x192b1908), + uvec2(0x2b080808, 0x192b1908), uvec2(0x08080819, 0x192b1919), uvec2(0x08081908, 0x192b1919), uvec2(0x08190808, 0x192b1919), + uvec2(0x19080808, 0x192b1919), uvec2(0x19082b2b, 0x192b1919), uvec2(0x192b2b08, 0x192b1919), uvec2(0x2b19082b, 0x192b1919), + uvec2(0x08080808, 0x192b192b), uvec2(0x2b191908, 0x192b192b), uvec2(0x08080819, 0x192b2b08), uvec2(0x08081908, 0x192b2b08), + uvec2(0x08190808, 0x192b2b08), uvec2(0x192b1919, 0x192b2b08), uvec2(0x2b192b08, 0x192b2b08), uvec2(0x08080808, 0x192b2b19), + uvec2(0x082b2b2b, 0x192b2b19), uvec2(0x1908082b, 0x192b2b2b), uvec2(0x2b2b0819, 0x192b2b2b), uvec2(0x08080808, 0x2b080808), + uvec2(0x0808082b, 0x2b080808), uvec2(0x08081919, 0x2b080808), uvec2(0x08082b08, 0x2b080808), uvec2(0x08190819, 0x2b080808), + uvec2(0x08191908, 0x2b080808), uvec2(0x08192b19, 0x2b080808), uvec2(0x082b0808, 0x2b080808), uvec2(0x082b1919, 0x2b080808), + uvec2(0x19080819, 0x2b080808), uvec2(0x19081908, 0x2b080808), uvec2(0x19190808, 0x2b080808), uvec2(0x1919082b, 0x2b080808), + uvec2(0x19191919, 0x2b080808), uvec2(0x19192b08, 0x2b080808), uvec2(0x192b0819, 0x2b080808), uvec2(0x2b080808, 0x2b080808), + uvec2(0x2b081919, 0x2b080808), uvec2(0x2b190819, 0x2b080808), uvec2(0x2b191908, 0x2b080808), uvec2(0x08080819, 0x2b080819), + uvec2(0x08081908, 0x2b080819), uvec2(0x08082b19, 0x2b080819), uvec2(0x08190808, 0x2b080819), uvec2(0x0819082b, 0x2b080819), + uvec2(0x08191919, 0x2b080819), uvec2(0x08192b08, 0x2b080819), uvec2(0x082b0819, 0x2b080819), uvec2(0x082b1908, 0x2b080819), + uvec2(0x19080808, 0x2b080819), uvec2(0x1908082b, 0x2b080819), uvec2(0x19081919, 0x2b080819), uvec2(0x19082b08, 0x2b080819), + uvec2(0x19190819, 0x2b080819), uvec2(0x19191908, 0x2b080819), uvec2(0x2b080819, 0x2b080819), uvec2(0x2b081908, 0x2b080819), + uvec2(0x2b190808, 0x2b080819), uvec2(0x2b2b2b19, 0x2b080819), uvec2(0x08080808, 0x2b08082b), uvec2(0x08081919, 0x2b08082b), + uvec2(0x08082b2b, 0x2b08082b), uvec2(0x08190819, 0x2b08082b), uvec2(0x08191908, 0x2b08082b), uvec2(0x19080819, 0x2b08082b), + uvec2(0x19081908, 0x2b08082b), uvec2(0x19190808, 0x2b08082b), uvec2(0x08080819, 0x2b081908), uvec2(0x08081908, 0x2b081908), + uvec2(0x0808192b, 0x2b081908), uvec2(0x08082b19, 0x2b081908), uvec2(0x08190808, 0x2b081908), uvec2(0x0819082b, 0x2b081908), + uvec2(0x08191919, 0x2b081908), uvec2(0x08192b08, 0x2b081908), uvec2(0x082b0819, 0x2b081908), uvec2(0x19080808, 0x2b081908), + uvec2(0x1908082b, 0x2b081908), uvec2(0x19081919, 0x2b081908), uvec2(0x19082b08, 0x2b081908), uvec2(0x19190819, 0x2b081908), + uvec2(0x19191908, 0x2b081908), uvec2(0x192b0808, 0x2b081908), uvec2(0x2b080819, 0x2b081908), 
uvec2(0x2b081908, 0x2b081908), + uvec2(0x2b190808, 0x2b081908), uvec2(0x08080808, 0x2b081919), uvec2(0x0808082b, 0x2b081919), uvec2(0x08081919, 0x2b081919), + uvec2(0x08082b08, 0x2b081919), uvec2(0x08190819, 0x2b081919), uvec2(0x08191908, 0x2b081919), uvec2(0x082b0808, 0x2b081919), + uvec2(0x19080819, 0x2b081919), uvec2(0x19081908, 0x2b081919), uvec2(0x19190808, 0x2b081919), uvec2(0x2b080808, 0x2b081919), + uvec2(0x2b082b2b, 0x2b081919), uvec2(0x08080819, 0x2b08192b), uvec2(0x08081908, 0x2b08192b), uvec2(0x08190808, 0x2b08192b), + uvec2(0x082b2b19, 0x2b08192b), uvec2(0x19080808, 0x2b08192b), uvec2(0x08080808, 0x2b082b08), uvec2(0x08081919, 0x2b082b08), + uvec2(0x08190819, 0x2b082b08), uvec2(0x08191908, 0x2b082b08), uvec2(0x19080819, 0x2b082b08), uvec2(0x19081908, 0x2b082b08), + uvec2(0x19190808, 0x2b082b08), uvec2(0x2b2b082b, 0x2b082b08), uvec2(0x08080819, 0x2b082b19), uvec2(0x08081908, 0x2b082b19), + uvec2(0x19080808, 0x2b082b19), uvec2(0x192b1919, 0x2b082b19), uvec2(0x082b082b, 0x2b082b2b), uvec2(0x19192b08, 0x2b082b2b), + uvec2(0x19192b2b, 0x2b082b2b), uvec2(0x2b08082b, 0x2b082b2b), uvec2(0x2b2b082b, 0x2b082b2b), uvec2(0x08080819, 0x2b190808), + uvec2(0x08081908, 0x2b190808), uvec2(0x08082b19, 0x2b190808), uvec2(0x08190808, 0x2b190808), uvec2(0x0819082b, 0x2b190808), + uvec2(0x08191919, 0x2b190808), uvec2(0x08192b08, 0x2b190808), uvec2(0x082b1908, 0x2b190808), uvec2(0x19080808, 0x2b190808), + uvec2(0x1908082b, 0x2b190808), uvec2(0x19081919, 0x2b190808), uvec2(0x19082b08, 0x2b190808), uvec2(0x19190819, 0x2b190808), + uvec2(0x19191908, 0x2b190808), uvec2(0x192b0808, 0x2b190808), uvec2(0x2b080819, 0x2b190808), uvec2(0x2b081908, 0x2b190808), + uvec2(0x2b190808, 0x2b190808), uvec2(0x08080808, 0x2b190819), uvec2(0x08081919, 0x2b190819), uvec2(0x08190819, 0x2b190819), + uvec2(0x08191908, 0x2b190819), uvec2(0x19080819, 0x2b190819), uvec2(0x19081908, 0x2b190819), uvec2(0x19190808, 0x2b190819), + uvec2(0x19192b2b, 0x2b190819), uvec2(0x08080819, 0x2b19082b), uvec2(0x08081908, 0x2b19082b), uvec2(0x08190808, 0x2b19082b), + uvec2(0x19080808, 0x2b19082b), uvec2(0x2b2b192b, 0x2b19082b), uvec2(0x08080808, 0x2b191908), uvec2(0x0808082b, 0x2b191908), + uvec2(0x08081919, 0x2b191908), uvec2(0x08082b08, 0x2b191908), uvec2(0x08190819, 0x2b191908), uvec2(0x08191908, 0x2b191908), + uvec2(0x082b0808, 0x2b191908), uvec2(0x19080819, 0x2b191908), uvec2(0x19081908, 0x2b191908), uvec2(0x19190808, 0x2b191908), + uvec2(0x2b080808, 0x2b191908), uvec2(0x2b19192b, 0x2b191908), uvec2(0x08080819, 0x2b191919), uvec2(0x08081908, 0x2b191919), + uvec2(0x08190808, 0x2b191919), uvec2(0x19080808, 0x2b191919), uvec2(0x2b192b08, 0x2b191919), uvec2(0x2b2b0819, 0x2b191919), + uvec2(0x08080808, 0x2b19192b), uvec2(0x1908192b, 0x2b19192b), uvec2(0x192b1908, 0x2b19192b), uvec2(0x08080819, 0x2b192b08), + uvec2(0x08081908, 0x2b192b08), uvec2(0x08190808, 0x2b192b08), uvec2(0x082b192b, 0x2b192b08), uvec2(0x19080808, 0x2b192b08), + uvec2(0x2b2b2b19, 0x2b192b08), uvec2(0x08080808, 0x2b192b19), uvec2(0x19082b19, 0x2b192b19), uvec2(0x1919082b, 0x2b192b19), + uvec2(0x2b190808, 0x2b192b2b), uvec2(0x08080808, 0x2b2b0808), uvec2(0x08081919, 0x2b2b0808), uvec2(0x08082b2b, 0x2b2b0808), + uvec2(0x08191908, 0x2b2b0808), uvec2(0x082b082b, 0x2b2b0808), uvec2(0x082b2b2b, 0x2b2b0808), uvec2(0x19080819, 0x2b2b0808), + uvec2(0x19081908, 0x2b2b0808), uvec2(0x19190808, 0x2b2b0808), uvec2(0x2b2b082b, 0x2b2b0808), uvec2(0x2b2b2b2b, 0x2b2b0808), + uvec2(0x19080808, 0x2b2b0819), uvec2(0x192b1919, 0x2b2b0819), uvec2(0x0808082b, 0x2b2b082b), uvec2(0x08082b2b, 
0x2b2b082b), + uvec2(0x082b082b, 0x2b2b082b), uvec2(0x082b2b08, 0x2b2b082b), uvec2(0x082b2b2b, 0x2b2b082b), uvec2(0x2b08082b, 0x2b2b082b), + uvec2(0x2b082b08, 0x2b2b082b), uvec2(0x2b082b2b, 0x2b2b082b), uvec2(0x2b2b2b08, 0x2b2b082b), uvec2(0x08080819, 0x2b2b1908), + uvec2(0x08081908, 0x2b2b1908), uvec2(0x08190808, 0x2b2b1908), uvec2(0x19080808, 0x2b2b1908), uvec2(0x2b082b19, 0x2b2b1908), + uvec2(0x2b2b1908, 0x2b2b1908), uvec2(0x08080808, 0x2b2b1919), uvec2(0x08192b19, 0x2b2b1919), uvec2(0x19190819, 0x2b2b192b), + uvec2(0x08082b2b, 0x2b2b2b08), uvec2(0x082b2b08, 0x2b2b2b08), uvec2(0x2b2b082b, 0x2b2b2b08), uvec2(0x19191908, 0x2b2b2b19), + uvec2(0x2b08192b, 0x2b2b2b19), uvec2(0x08082b08, 0x2b2b2b2b), uvec2(0x08082b2b, 0x2b2b2b2b), uvec2(0x082b0808, 0x2b2b2b2b), + uvec2(0x082b082b, 0x2b2b2b2b), uvec2(0x082b2b08, 0x2b2b2b2b), uvec2(0x2b082b08, 0x2b2b2b2b), uvec2(0x2b2b2b2b, 0x2b2b2b2b) +}; + +shared uvec2 iq2s_grid[1024]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq2s_grid.length(); i += wgsize.x) { + if (iq2s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq2s_grid_const.length()) { + iq2s_grid[i + gl_LocalInvocationIndex.x] = iq2s_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ2_S +#define QUANT_R QUANT_R_IQ2_S +#define A_TYPE block_iq2_s +#define A_TYPE_PACKED16 block_iq2_s_packed16 +#endif + +#define QUANT_K_IQ3_XXS 256 +#define QUANT_R_IQ3_XXS 1 + +struct block_iq3_xxs +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_XXS/4 + QUANT_K_IQ3_XXS/8]; +}; + +struct block_iq3_xxs_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_XXS/8 + QUANT_K_IQ3_XXS/16]; +}; + +#if defined(DATA_A_IQ3_XXS) + +const uint32_t iq3xxs_grid_const[256] = { + 0x04040404, 0x04040414, 0x04040424, 0x04040c0c, 0x04040c1c, 0x04040c3e, 0x04041404, 0x04041414, + 0x04041c0c, 0x04042414, 0x04043e1c, 0x04043e2c, 0x040c040c, 0x040c041c, 0x040c0c04, 0x040c0c14, + 0x040c140c, 0x040c142c, 0x040c1c04, 0x040c1c14, 0x040c240c, 0x040c2c24, 0x040c3e04, 0x04140404, + 0x04140414, 0x04140424, 0x04140c0c, 0x04141404, 0x04141414, 0x04141c0c, 0x04141c1c, 0x04141c3e, + 0x04142c0c, 0x04142c3e, 0x04143e2c, 0x041c040c, 0x041c043e, 0x041c0c04, 0x041c0c14, 0x041c142c, + 0x041c3e04, 0x04240c1c, 0x04241c3e, 0x04242424, 0x04242c3e, 0x04243e1c, 0x04243e2c, 0x042c040c, + 0x042c043e, 0x042c1c14, 0x042c2c14, 0x04341c2c, 0x04343424, 0x043e0c04, 0x043e0c24, 0x043e0c34, + 0x043e241c, 0x043e340c, 0x0c04040c, 0x0c04041c, 0x0c040c04, 0x0c040c14, 0x0c04140c, 0x0c04141c, + 0x0c041c04, 0x0c041c14, 0x0c041c24, 0x0c04243e, 0x0c042c04, 0x0c0c0404, 0x0c0c0414, 0x0c0c0c0c, + 0x0c0c1404, 0x0c0c1414, 0x0c14040c, 0x0c14041c, 0x0c140c04, 0x0c140c14, 0x0c14140c, 0x0c141c04, + 0x0c143e14, 0x0c1c0404, 0x0c1c0414, 0x0c1c1404, 0x0c1c1c0c, 0x0c1c2434, 0x0c1c3434, 0x0c24040c, + 0x0c24042c, 0x0c242c04, 0x0c2c1404, 0x0c2c1424, 0x0c2c2434, 0x0c2c3e0c, 0x0c34042c, 0x0c3e1414, + 0x0c3e2404, 0x14040404, 0x14040414, 0x14040c0c, 0x14040c1c, 0x14041404, 0x14041414, 0x14041434, + 0x14041c0c, 0x14042414, 0x140c040c, 0x140c041c, 0x140c042c, 0x140c0c04, 0x140c0c14, 0x140c140c, + 0x140c1c04, 0x140c341c, 0x140c343e, 0x140c3e04, 0x14140404, 0x14140414, 0x14140c0c, 0x14140c3e, + 0x14141404, 0x14141414, 0x14141c3e, 0x14142404, 0x14142c2c, 0x141c040c, 0x141c0c04, 0x141c0c24, + 0x141c3e04, 0x141c3e24, 0x14241c2c, 0x14242c1c, 0x142c041c, 0x142c143e, 0x142c240c, 0x142c3e24, + 0x143e040c, 0x143e041c, 0x143e0c34, 0x143e242c, 
0x1c04040c, 0x1c040c04, 0x1c040c14, 0x1c04140c, + 0x1c04141c, 0x1c042c04, 0x1c04342c, 0x1c043e14, 0x1c0c0404, 0x1c0c0414, 0x1c0c1404, 0x1c0c1c0c, + 0x1c0c2424, 0x1c0c2434, 0x1c14040c, 0x1c14041c, 0x1c140c04, 0x1c14142c, 0x1c142c14, 0x1c143e14, + 0x1c1c0c0c, 0x1c1c1c1c, 0x1c241c04, 0x1c24243e, 0x1c243e14, 0x1c2c0404, 0x1c2c0434, 0x1c2c1414, + 0x1c2c2c2c, 0x1c340c24, 0x1c341c34, 0x1c34341c, 0x1c3e1c1c, 0x1c3e3404, 0x24040424, 0x24040c3e, + 0x24041c2c, 0x24041c3e, 0x24042c1c, 0x24042c3e, 0x240c3e24, 0x24141404, 0x24141c3e, 0x24142404, + 0x24143404, 0x24143434, 0x241c043e, 0x241c242c, 0x24240424, 0x24242c0c, 0x24243424, 0x242c142c, + 0x242c241c, 0x242c3e04, 0x243e042c, 0x243e0c04, 0x243e0c14, 0x243e1c04, 0x2c040c14, 0x2c04240c, + 0x2c043e04, 0x2c0c0404, 0x2c0c0434, 0x2c0c1434, 0x2c0c2c2c, 0x2c140c24, 0x2c141c14, 0x2c143e14, + 0x2c1c0414, 0x2c1c2c1c, 0x2c240c04, 0x2c24141c, 0x2c24143e, 0x2c243e14, 0x2c2c0414, 0x2c2c1c0c, + 0x2c342c04, 0x2c3e1424, 0x2c3e2414, 0x34041424, 0x34042424, 0x34042434, 0x34043424, 0x340c140c, + 0x340c340c, 0x34140c3e, 0x34143424, 0x341c1c04, 0x341c1c34, 0x34242424, 0x342c042c, 0x342c2c14, + 0x34341c1c, 0x343e041c, 0x343e140c, 0x3e04041c, 0x3e04042c, 0x3e04043e, 0x3e040c04, 0x3e041c14, + 0x3e042c14, 0x3e0c1434, 0x3e0c2404, 0x3e140c14, 0x3e14242c, 0x3e142c14, 0x3e1c0404, 0x3e1c0c2c, + 0x3e1c1c1c, 0x3e1c3404, 0x3e24140c, 0x3e24240c, 0x3e2c0404, 0x3e2c0414, 0x3e2c1424, 0x3e341c04, +}; + +shared uint32_t iq3xxs_grid[256]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq3xxs_grid.length(); i += wgsize.x) { + if (iq3xxs_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3xxs_grid.length()) { + iq3xxs_grid[i + gl_LocalInvocationIndex.x] = iq3xxs_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_XXS +#define QUANT_R QUANT_R_IQ3_XXS +#define A_TYPE block_iq3_xxs +#define A_TYPE_PACKED16 block_iq3_xxs_packed16 +#endif + +#define QUANT_K_IQ3_S 256 +#define QUANT_R_IQ3_S 1 + +struct block_iq3_s +{ + float16_t d; + uint8_t qs[QUANT_K_IQ3_S/4]; + uint8_t qh[QUANT_K_IQ3_S/32]; + uint8_t signs[QUANT_K_IQ3_S/8]; + uint8_t scales[QUANT_K_IQ3_S/64]; +}; + +struct block_iq3_s_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ3_S/4/2]; + uint16_t qh[QUANT_K_IQ3_S/32/2]; + uint16_t signs[QUANT_K_IQ3_S/8/2]; + uint16_t scales[QUANT_K_IQ3_S/64/2]; +}; + +#if defined(DATA_A_IQ3_S) + +const uint32_t iq3s_grid_const[512] = { + 0x01010101, 0x01010103, 0x01010105, 0x0101010b, 0x0101010f, 0x01010301, 0x01010303, 0x01010305, + 0x01010309, 0x0101030d, 0x01010501, 0x01010503, 0x0101050b, 0x01010707, 0x01010901, 0x01010905, + 0x0101090b, 0x0101090f, 0x01010b03, 0x01010b07, 0x01010d01, 0x01010d05, 0x01010f03, 0x01010f09, + 0x01010f0f, 0x01030101, 0x01030103, 0x01030105, 0x01030109, 0x01030301, 0x01030303, 0x0103030b, + 0x01030501, 0x01030507, 0x0103050f, 0x01030703, 0x0103070b, 0x01030909, 0x01030d03, 0x01030d0b, + 0x01030f05, 0x01050101, 0x01050103, 0x0105010b, 0x0105010f, 0x01050301, 0x01050307, 0x0105030d, + 0x01050503, 0x0105050b, 0x01050701, 0x01050709, 0x01050905, 0x0105090b, 0x0105090f, 0x01050b03, + 0x01050b07, 0x01050f01, 0x01050f07, 0x01070107, 0x01070303, 0x0107030b, 0x01070501, 0x01070505, + 0x01070703, 0x01070707, 0x0107070d, 0x01070909, 0x01070b01, 0x01070b05, 0x01070d0f, 0x01070f03, + 0x01070f0b, 0x01090101, 0x01090307, 0x0109030f, 0x01090503, 0x01090509, 0x01090705, 0x01090901, + 0x01090907, 0x01090b03, 0x01090f01, 
0x010b0105, 0x010b0109, 0x010b0501, 0x010b0505, 0x010b050d, + 0x010b0707, 0x010b0903, 0x010b090b, 0x010b090f, 0x010b0d0d, 0x010b0f07, 0x010d010d, 0x010d0303, + 0x010d0307, 0x010d0703, 0x010d0b05, 0x010d0f03, 0x010f0101, 0x010f0105, 0x010f0109, 0x010f0501, + 0x010f0505, 0x010f050d, 0x010f0707, 0x010f0b01, 0x010f0b09, 0x03010101, 0x03010103, 0x03010105, + 0x03010109, 0x03010301, 0x03010303, 0x03010307, 0x0301030b, 0x0301030f, 0x03010501, 0x03010505, + 0x03010703, 0x03010709, 0x0301070d, 0x03010b09, 0x03010b0d, 0x03010d03, 0x03010f05, 0x03030101, + 0x03030103, 0x03030107, 0x0303010d, 0x03030301, 0x03030309, 0x03030503, 0x03030701, 0x03030707, + 0x03030903, 0x03030b01, 0x03030b05, 0x03030f01, 0x03030f0d, 0x03050101, 0x03050305, 0x0305030b, + 0x0305030f, 0x03050501, 0x03050509, 0x03050705, 0x03050901, 0x03050907, 0x03050b0b, 0x03050d01, + 0x03050f05, 0x03070103, 0x03070109, 0x0307010f, 0x03070301, 0x03070307, 0x03070503, 0x0307050f, + 0x03070701, 0x03070709, 0x03070903, 0x03070d05, 0x03070f01, 0x03090107, 0x0309010b, 0x03090305, + 0x03090309, 0x03090703, 0x03090707, 0x03090905, 0x0309090d, 0x03090b01, 0x03090b09, 0x030b0103, + 0x030b0301, 0x030b0307, 0x030b0503, 0x030b0701, 0x030b0705, 0x030b0b03, 0x030d0501, 0x030d0509, + 0x030d050f, 0x030d0909, 0x030d090d, 0x030f0103, 0x030f0107, 0x030f0301, 0x030f0305, 0x030f0503, + 0x030f070b, 0x030f0903, 0x030f0d05, 0x030f0f01, 0x05010101, 0x05010103, 0x05010107, 0x0501010b, + 0x0501010f, 0x05010301, 0x05010305, 0x05010309, 0x0501030d, 0x05010503, 0x05010507, 0x0501050f, + 0x05010701, 0x05010705, 0x05010903, 0x05010907, 0x0501090b, 0x05010b01, 0x05010b05, 0x05010d0f, + 0x05010f01, 0x05010f07, 0x05010f0b, 0x05030101, 0x05030105, 0x05030301, 0x05030307, 0x0503030f, + 0x05030505, 0x0503050b, 0x05030703, 0x05030709, 0x05030905, 0x05030b03, 0x05050103, 0x05050109, + 0x0505010f, 0x05050503, 0x05050507, 0x05050701, 0x0505070f, 0x05050903, 0x05050b07, 0x05050b0f, + 0x05050f03, 0x05050f09, 0x05070101, 0x05070105, 0x0507010b, 0x05070303, 0x05070505, 0x05070509, + 0x05070703, 0x05070707, 0x05070905, 0x05070b01, 0x05070d0d, 0x05090103, 0x0509010f, 0x05090501, + 0x05090507, 0x05090705, 0x0509070b, 0x05090903, 0x05090f05, 0x05090f0b, 0x050b0109, 0x050b0303, + 0x050b0505, 0x050b070f, 0x050b0901, 0x050b0b07, 0x050b0f01, 0x050d0101, 0x050d0105, 0x050d010f, + 0x050d0503, 0x050d0b0b, 0x050d0d03, 0x050f010b, 0x050f0303, 0x050f050d, 0x050f0701, 0x050f0907, + 0x050f0b01, 0x07010105, 0x07010303, 0x07010307, 0x0701030b, 0x0701030f, 0x07010505, 0x07010703, + 0x07010707, 0x0701070b, 0x07010905, 0x07010909, 0x0701090f, 0x07010b03, 0x07010d07, 0x07010f03, + 0x07030103, 0x07030107, 0x0703010b, 0x07030309, 0x07030503, 0x07030507, 0x07030901, 0x07030d01, + 0x07030f05, 0x07030f0d, 0x07050101, 0x07050305, 0x07050501, 0x07050705, 0x07050709, 0x07050b01, + 0x07070103, 0x07070301, 0x07070309, 0x07070503, 0x07070507, 0x0707050f, 0x07070701, 0x07070903, + 0x07070907, 0x0707090f, 0x07070b0b, 0x07070f07, 0x07090107, 0x07090303, 0x0709030d, 0x07090505, + 0x07090703, 0x07090b05, 0x07090d01, 0x07090d09, 0x070b0103, 0x070b0301, 0x070b0305, 0x070b050b, + 0x070b0705, 0x070b0909, 0x070b0b0d, 0x070b0f07, 0x070d030d, 0x070d0903, 0x070f0103, 0x070f0107, + 0x070f0501, 0x070f0505, 0x070f070b, 0x09010101, 0x09010109, 0x09010305, 0x09010501, 0x09010509, + 0x0901050f, 0x09010705, 0x09010903, 0x09010b01, 0x09010f01, 0x09030105, 0x0903010f, 0x09030303, + 0x09030307, 0x09030505, 0x09030701, 0x0903070b, 0x09030907, 0x09030b03, 0x09030b0b, 0x09050103, + 0x09050107, 0x09050301, 0x0905030b, 0x09050503, 0x09050707, 
0x09050901, 0x09050b0f, 0x09050d05, + 0x09050f01, 0x09070109, 0x09070303, 0x09070307, 0x09070501, 0x09070505, 0x09070703, 0x0907070b, + 0x09090101, 0x09090105, 0x09090509, 0x0909070f, 0x09090901, 0x09090f03, 0x090b010b, 0x090b010f, + 0x090b0503, 0x090b0d05, 0x090d0307, 0x090d0709, 0x090d0d01, 0x090f0301, 0x090f030b, 0x090f0701, + 0x090f0907, 0x090f0b03, 0x0b010105, 0x0b010301, 0x0b010309, 0x0b010505, 0x0b010901, 0x0b010909, + 0x0b01090f, 0x0b010b05, 0x0b010d0d, 0x0b010f09, 0x0b030103, 0x0b030107, 0x0b03010b, 0x0b030305, + 0x0b030503, 0x0b030705, 0x0b030f05, 0x0b050101, 0x0b050303, 0x0b050507, 0x0b050701, 0x0b05070d, + 0x0b050b07, 0x0b070105, 0x0b07010f, 0x0b070301, 0x0b07050f, 0x0b070909, 0x0b070b03, 0x0b070d0b, + 0x0b070f07, 0x0b090103, 0x0b090109, 0x0b090501, 0x0b090705, 0x0b09090d, 0x0b0b0305, 0x0b0b050d, + 0x0b0b0b03, 0x0b0b0b07, 0x0b0d0905, 0x0b0f0105, 0x0b0f0109, 0x0b0f0505, 0x0d010303, 0x0d010307, + 0x0d01030b, 0x0d010703, 0x0d010707, 0x0d010d01, 0x0d030101, 0x0d030501, 0x0d03050f, 0x0d030d09, + 0x0d050305, 0x0d050709, 0x0d050905, 0x0d050b0b, 0x0d050d05, 0x0d050f01, 0x0d070101, 0x0d070309, + 0x0d070503, 0x0d070901, 0x0d09050b, 0x0d090907, 0x0d090d05, 0x0d0b0101, 0x0d0b0107, 0x0d0b0709, + 0x0d0b0d01, 0x0d0d010b, 0x0d0d0901, 0x0d0f0303, 0x0d0f0307, 0x0f010101, 0x0f010109, 0x0f01010f, + 0x0f010501, 0x0f010505, 0x0f01070d, 0x0f010901, 0x0f010b09, 0x0f010d05, 0x0f030105, 0x0f030303, + 0x0f030509, 0x0f030907, 0x0f03090b, 0x0f050103, 0x0f050109, 0x0f050301, 0x0f05030d, 0x0f050503, + 0x0f050701, 0x0f050b03, 0x0f070105, 0x0f070705, 0x0f07070b, 0x0f070b07, 0x0f090103, 0x0f09010b, + 0x0f090307, 0x0f090501, 0x0f090b01, 0x0f0b0505, 0x0f0b0905, 0x0f0d0105, 0x0f0d0703, 0x0f0f0101, +}; + +shared uint32_t iq3s_grid[512]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + [[unroll]] for (uint i = 0; i < iq3s_grid.length(); i += wgsize.x) { + if (iq3s_grid.length() % wgsize.x == 0 || i + gl_LocalInvocationIndex.x < iq3s_grid.length()) { + iq3s_grid[i + gl_LocalInvocationIndex.x] = iq3s_grid_const[i + gl_LocalInvocationIndex.x]; + } + } + barrier(); +} + +#define QUANT_K QUANT_K_IQ3_S +#define QUANT_R QUANT_R_IQ3_S +#define A_TYPE block_iq3_s +#define A_TYPE_PACKED16 block_iq3_s_packed16 +#endif + +#define QUANT_K_IQ4_XS 256 +#define QUANT_R_IQ4_XS 1 + +struct block_iq4_xs +{ + float16_t d; + uint16_t scales_h; + uint8_t scales_l[QUANT_K_IQ4_XS/64]; + uint8_t qs[QUANT_K_IQ4_XS/2]; +}; + +#if defined(DATA_A_IQ4_XS) +#define QUANT_K QUANT_K_IQ4_XS +#define QUANT_R QUANT_R_IQ4_XS +#define A_TYPE block_iq4_xs +#endif + +#define QUANT_K_IQ4_NL 32 +#define QUANT_R_IQ4_NL 2 + +struct block_iq4_nl +{ + float16_t d; + uint8_t qs[QUANT_K_IQ4_NL/2]; +}; + +struct block_iq4_nl_packed16 +{ + float16_t d; + uint16_t qs[QUANT_K_IQ4_NL/2/2]; +}; + +#if defined(DATA_A_IQ4_NL) +#define QUANT_K QUANT_K_IQ4_NL +#define QUANT_R QUANT_R_IQ4_NL +#define A_TYPE block_iq4_nl +#define A_TYPE_PACKED16 block_iq4_nl_packed16 +#endif + +#define QUANT_K_MXFP4 32 +#define QUANT_R_MXFP4 2 + +struct block_mxfp4 +{ + uint8_t e; + uint8_t qs[QUANT_K_MXFP4/2]; +}; + +//struct block_mxfp4_packed16 +//{ +// uint8_t e; +// uint16_t qs[QUANT_K_MXFP4/2/2]; +//}; + +#if defined(DATA_A_MXFP4) +#define QUANT_K QUANT_K_MXFP4 +#define QUANT_R QUANT_R_MXFP4 +#define QUANT_AUXF 1 +#define A_TYPE block_mxfp4 +//#define A_TYPE_PACKED16 block_mxfp4_packed16 +#endif + +#if defined(DATA_A_IQ4_NL) || defined(DATA_A_IQ4_XS) +const int8_t kvalues_iq4nl_const[16] = { + int8_t(-127), 
int8_t(-104), int8_t(-83), int8_t(-65), int8_t(-49), int8_t(-35), int8_t(-22), int8_t(-10), + int8_t(1), int8_t(13), int8_t(25), int8_t(38), int8_t(53), int8_t(69), int8_t(89), int8_t(113) +}; + +shared FLOAT_TYPE kvalues_iq4nl[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_iq4nl.length(); i += wgsize.x) { + kvalues_iq4nl[i] = FLOAT_TYPE(kvalues_iq4nl_const[i]); + } + barrier(); +} +#endif + +#if defined(DATA_A_MXFP4) +const FLOAT_TYPE kvalues_mxfp4_const[16] = { + FLOAT_TYPE(0.0f), FLOAT_TYPE(0.5f), FLOAT_TYPE(1.0f), FLOAT_TYPE(1.5f), FLOAT_TYPE(2.0f), FLOAT_TYPE(3.0f), FLOAT_TYPE(4.0f), FLOAT_TYPE(6.0f), + FLOAT_TYPE(-0.0f), FLOAT_TYPE(-0.5f), FLOAT_TYPE(-1.0f), FLOAT_TYPE(-1.5f), FLOAT_TYPE(-2.0f), FLOAT_TYPE(-3.0f), FLOAT_TYPE(-4.0f), FLOAT_TYPE(-6.0f) +}; + +shared FLOAT_TYPE kvalues_mxfp4[16]; + +#define NEEDS_INIT_IQ_SHMEM +void init_iq_shmem(uvec3 wgsize) +{ + // copy the table into shared memory and sync + for (uint i = gl_LocalInvocationIndex.x; i < kvalues_mxfp4.length(); i += wgsize.x) { + kvalues_mxfp4[i] = kvalues_mxfp4_const[i]; + } + barrier(); +} +#endif + +// returns the bfloat value in the low 16b. +// See ggml_compute_fp32_to_bf16 +uint32_t fp32_to_bf16(float f) +{ + uint32_t u = floatBitsToUint(f); + u = (u + (0x7fff + ((u >> 16) & 1))) >> 16; + return u; +} + +float bf16_to_fp32(uint32_t u) +{ + return uintBitsToFloat(u << 16); +} + +vec4 bf16_to_fp32(uvec4 u) +{ + return vec4(bf16_to_fp32(u.x), bf16_to_fp32(u.y), bf16_to_fp32(u.z), bf16_to_fp32(u.w)); +} + +float e8m0_to_fp32(uint8_t x) { + uint32_t bits; + + if (x == 0) { + bits = 0x00400000; + } else { + bits = x; + bits = bits << 23; + } + + return uintBitsToFloat(bits); +} + +#if BDA + +#extension GL_EXT_buffer_reference : enable +#extension GL_EXT_shader_explicit_arithmetic_types_int64 : enable + +#define BDA_STORAGE_T uint64_t +#define BDA_OFFSET_T uint64_t + +#else + +#define BDA_STORAGE_T uvec2 +#define BDA_OFFSET_T uint + +#endif + +#endif // !defined(GGML_TYPES_COMP) diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp new file mode 100644 index 0000000000..154a2172d8 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/upscale.comp @@ -0,0 +1,100 @@ +#version 450 + +layout (push_constant) uniform parameter +{ + uint ne; uint a_offset; uint d_offset; + uint ne00; uint ne01; + uint nb00; uint nb01; uint nb02; uint nb03; + uint ne10; uint ne11; uint ne12; uint ne13; + float sf0; float sf1; float sf2; float sf3; +} p; + +#include "types.glsl" + +layout(local_size_x = 512, local_size_y = 1, local_size_z = 1) in; + +layout (binding = 0) readonly buffer A {A_TYPE data_a[];}; +layout (binding = 1) writeonly buffer D {D_TYPE data_d[];}; + +// from ggml.h: enum ggml_scale_mode, enum ggml_scale_flag +#define NEAREST 0 +#define BILINEAR 1 +#define ALIGN_CORNERS (1 << 8) + +layout (constant_id = 0) const uint scale_mode = 0; + +float fetch_nearest(uint i10, uint i11, uint i12, uint i13) { + const uint i00 = uint(i10 / p.sf0); + const uint i01 = uint(i11 / p.sf1); + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + + return data_a[p.a_offset + i03 * p.nb03 + i02 * p.nb02 + i01 * p.nb01 + i00 * p.nb00]; +} + +float fetch_bilinear(ivec2 c0, ivec2 c1, vec2 d, uint i12, uint i13) { + const uint i02 = uint(i12 / p.sf2); + const uint i03 = uint(i13 / p.sf3); + 
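Stepping back for a moment to the bf16 and e8m0 helpers added in types.glsl above: fp32_to_bf16 rounds to nearest-even and keeps the high 16 bits, bf16_to_fp32 shifts them back into place, and e8m0_to_fp32 decodes an 8-bit biased exponent as the scale 2^(x-127). Below is a minimal host-side C++ sketch of the same bit tricks, not part of this patch; like the shader it omits the NaN special case that ggml_compute_fp32_to_bf16 handles, and the values in main are made up for illustration.

```
#include <cmath>
#include <cstdint>
#include <cstdio>
#include <cstring>

// fp32 -> bf16 with round-to-nearest-even, mirroring the shader's bit trick.
static uint16_t fp32_to_bf16(float f) {
    uint32_t u;
    std::memcpy(&u, &f, sizeof u);
    u = (u + (0x7fff + ((u >> 16) & 1))) >> 16;   // add rounding bias, keep top 16 bits
    return (uint16_t)u;
}

// bf16 is simply the high half of an fp32 bit pattern.
static float bf16_to_fp32(uint16_t h) {
    uint32_t u = (uint32_t)h << 16;
    float f;
    std::memcpy(&f, &u, sizeof f);
    return f;
}

// e8m0: an 8-bit biased exponent encoding the power-of-two scale 2^(x - 127).
static float e8m0_to_fp32(uint8_t x) {
    return std::ldexp(1.0f, (int)x - 127);
}

int main() {
    float x = 3.14159f;
    printf("%.6f -> bf16 0x%04x -> %.6f\n", x, fp32_to_bf16(x), bf16_to_fp32(fp32_to_bf16(x)));
    printf("e8m0(130) = %g\n", e8m0_to_fp32(130));   // 2^3 = 8
}
```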
const uint base = p.a_offset + i03 * p.nb03 + i02 * p.nb02; + + const float v00 = data_a[base + c0.y * p.nb01 + c0.x * p.nb00]; + const float v01 = data_a[base + c0.y * p.nb01 + c1.x * p.nb00]; + const float v10 = data_a[base + c1.y * p.nb01 + c0.x * p.nb00]; + const float v11 = data_a[base + c1.y * p.nb01 + c1.x * p.nb00]; + + return + v00 * (1.0-d.x) * (1.0-d.y) + + v01 * d.x * (1.0-d.y) + + v10 * (1.0-d.x) * d.y + + v11 * d.x * d.y; +} + +float interpolate_bilinear(uint i10, uint i11, uint i12, uint i13) { + const ivec2 ne0 = ivec2(p.ne00, p.ne01); + + const vec2 c = (vec2(i10, i11) + 0.5) / vec2(p.sf0, p.sf1) - 0.5; + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = max(ivec2(c0f), 0); + const ivec2 c1 = min(ivec2(c0f + 1), ne0 - 1); + + return fetch_bilinear(c0, c1, d, i12, i13); +} + +float interpolate_bilinear_align_corners(uint i10, uint i11, uint i12, uint i13) { + const vec2 c = vec2(i10, i11) / vec2(p.sf0, p.sf1); + const vec2 c0f = floor(c); + const vec2 d = c - c0f; + const ivec2 c0 = ivec2(c0f); + const ivec2 c1 = c0 + 1; + + return fetch_bilinear(c0, c1, d, i12, i13); +} + +void main() { + const uint idx = gl_GlobalInvocationID.z * 262144 + gl_GlobalInvocationID.y * 512 + gl_GlobalInvocationID.x; + + if (idx >= p.ne) { + return; + } + + const uint i10 = idx % p.ne10; + const uint i11 = (idx / p.ne10) % p.ne11; + const uint i12 = (idx / (p.ne10 * p.ne11)) % p.ne12; + const uint i13 = (idx / (p.ne10 * p.ne11 * p.ne12)) % p.ne13; + + float result; + switch (scale_mode) { + case NEAREST: + result = fetch_nearest(i10, i11, i12, i13); + break; + case BILINEAR: + result = interpolate_bilinear(i10, i11, i12, i13); + break; + case BILINEAR | ALIGN_CORNERS: + result = interpolate_bilinear_align_corners(i10, i11, i12, i13); + break; + } + + data_d[p.d_offset + idx] = D_TYPE(result); +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl new file mode 100644 index 0000000000..dc4a1e6d96 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/utils.glsl @@ -0,0 +1,25 @@ +#ifndef UTILS_COMP +#define UTILS_COMP + +// mod and div are expensive and coordinates/dimensions are often power of 2 or equal to 1 +uint fastmod(uint a, uint b) { + if ((b & (b-1)) == 0) { + return a & (b-1); + } + return a % b; +} + +uint fastdiv(uint a, uint b) { + return (a < b) ? 
0 : (a / b); +} + +void get_indices(uint idx, out uint i00, out uint i01, out uint i02, out uint i03, uint ne00, uint ne01, uint ne02, uint ne03) { + i03 = fastdiv(idx, (ne02*ne01*ne00)); + const uint i03_offset = i03 * ne02*ne01*ne00; + i02 = fastdiv((idx - i03_offset), (ne01*ne00)); + const uint i02_offset = i02*ne01*ne00; + i01 = (idx - i03_offset - i02_offset) / ne00; + i00 = idx - i03_offset - i02_offset - i01*ne00; +} + +#endif // UTILS_COMP diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp new file mode 100644 index 0000000000..f0cc24ff31 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/vulkan-shaders-gen.cpp @@ -0,0 +1,1097 @@ +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#ifdef _WIN32 + #define NOMINMAX + #include + #include // For _mkdir on Windows +#else + #include + #include + #include +#endif + +#define ASYNCIO_CONCURRENCY 64 + +std::mutex lock; +std::vector> shader_fnames; +std::locale c_locale("C"); + +std::string GLSLC = "glslc"; +std::string input_filepath = ""; +std::string output_dir = "/tmp"; +std::string target_hpp = ""; +std::string target_cpp = ""; + +const std::vector type_names = { + "f32", + "f16", + "q4_0", + "q4_1", + "q5_0", + "q5_1", + "q8_0", + "q2_k", + "q3_k", + "q4_k", + "q5_k", + "q6_k", + "iq1_s", + "iq1_m", + "iq2_xxs", + "iq2_xs", + "iq2_s", + "iq3_xxs", + "iq3_s", + "iq4_xs", + "iq4_nl", + "mxfp4", + "bf16", +}; + +enum MatMulIdType { + NONE, + DEFAULT, + SUBGROUP, +}; + +namespace { + +void execute_command(const std::string& command, std::string& stdout_str, std::string& stderr_str) { +#ifdef _WIN32 + HANDLE stdout_read, stdout_write; + HANDLE stderr_read, stderr_write; + SECURITY_ATTRIBUTES sa = { sizeof(SECURITY_ATTRIBUTES), NULL, TRUE }; + + if (!CreatePipe(&stdout_read, &stdout_write, &sa, 0) || + !SetHandleInformation(stdout_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stdout pipe"); + } + + if (!CreatePipe(&stderr_read, &stderr_write, &sa, 0) || + !SetHandleInformation(stderr_read, HANDLE_FLAG_INHERIT, 0)) { + throw std::runtime_error("Failed to create stderr pipe"); + } + + PROCESS_INFORMATION pi; + STARTUPINFOA si = {}; + si.cb = sizeof(STARTUPINFOA); + si.dwFlags = STARTF_USESTDHANDLES; + si.hStdOutput = stdout_write; + si.hStdError = stderr_write; + + std::vector cmd(command.begin(), command.end()); + cmd.push_back('\0'); + + if (!CreateProcessA(NULL, cmd.data(), NULL, NULL, TRUE, 0, NULL, NULL, &si, &pi)) { + throw std::runtime_error("Failed to create process"); + } + + CloseHandle(stdout_write); + CloseHandle(stderr_write); + + std::array buffer; + DWORD bytes_read; + + while (ReadFile(stdout_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while (ReadFile(stderr_read, buffer.data(), (DWORD)buffer.size(), &bytes_read, NULL) && bytes_read > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + CloseHandle(stdout_read); + CloseHandle(stderr_read); + WaitForSingleObject(pi.hProcess, INFINITE); + CloseHandle(pi.hProcess); + CloseHandle(pi.hThread); +#else + int stdout_pipe[2]; + int stderr_pipe[2]; + + if (pipe(stdout_pipe) != 0 || pipe(stderr_pipe) != 0) { + throw std::runtime_error("Failed to create pipes"); + } + + pid_t 
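A quick note on the utils.glsl helpers that just went by: fastmod and fastdiv shortcut the power-of-two and smaller-than-divisor cases, and get_indices unflattens a linear element index into 4D tensor coordinates. The following CPU-side C++ sketch reproduces the same decomposition with plain division standing in for fastdiv; it is illustrative only, and the tensor shape in main is invented for the example.

```
#include <cstdint>
#include <cstdio>

// Reference for the index math in utils.glsl: a flat element index
// idx == ((i03*ne02 + i02)*ne01 + i01)*ne00 + i00 is split back into
// 4D coordinates (i00, i01, i02, i03).
static void get_indices(uint32_t idx,
                        uint32_t ne00, uint32_t ne01, uint32_t ne02,
                        uint32_t &i00, uint32_t &i01, uint32_t &i02, uint32_t &i03) {
    i03 = idx / (ne02 * ne01 * ne00);
    uint32_t rem = idx - i03 * ne02 * ne01 * ne00;
    i02 = rem / (ne01 * ne00);
    rem -= i02 * ne01 * ne00;
    i01 = rem / ne00;
    i00 = rem - i01 * ne00;
}

int main() {
    uint32_t i00, i01, i02, i03;
    get_indices(37, /*ne00=*/4, /*ne01=*/3, /*ne02=*/2, i00, i01, i02, i03);
    printf("i00=%u i01=%u i02=%u i03=%u\n", i00, i01, i02, i03);  // i00=1 i01=0 i02=1 i03=1
}
```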
pid = fork(); + if (pid < 0) { + throw std::runtime_error("Failed to fork process"); + } + + if (pid == 0) { + close(stdout_pipe[0]); + close(stderr_pipe[0]); + dup2(stdout_pipe[1], STDOUT_FILENO); + dup2(stderr_pipe[1], STDERR_FILENO); + close(stdout_pipe[1]); + close(stderr_pipe[1]); + execl("/bin/sh", "sh", "-c", command.c_str(), (char*) nullptr); + _exit(EXIT_FAILURE); + } else { + close(stdout_pipe[1]); + close(stderr_pipe[1]); + + std::array buffer; + ssize_t bytes_read; + + while ((bytes_read = read(stdout_pipe[0], buffer.data(), buffer.size())) > 0) { + stdout_str.append(buffer.data(), bytes_read); + } + + while ((bytes_read = read(stderr_pipe[0], buffer.data(), buffer.size())) > 0) { + stderr_str.append(buffer.data(), bytes_read); + } + + close(stdout_pipe[0]); + close(stderr_pipe[0]); + waitpid(pid, nullptr, 0); + } +#endif +} + +bool directory_exists(const std::string& path) { + struct stat info; + if (stat(path.c_str(), &info) != 0) { + return false; // Path doesn't exist or can't be accessed + } + return (info.st_mode & S_IFDIR) != 0; // Check if it is a directory +} + +bool create_directory(const std::string& path) { +#ifdef _WIN32 + return _mkdir(path.c_str()) == 0 || errno == EEXIST; // EEXIST means the directory already exists +#else + return mkdir(path.c_str(), 0755) == 0 || errno == EEXIST; // 0755 is the directory permissions +#endif +} + +std::string to_uppercase(const std::string& input) { + std::string result = input; + for (char& c : result) { + c = std::toupper(c); + } + return result; +} + +bool string_starts_with(const std::string& str, const std::string& prefix) { + if (prefix.size() > str.size()) { + return false; + } + return std::equal(prefix.begin(), prefix.end(), str.begin()); +} + +bool string_ends_with(const std::string& str, const std::string& suffix) { + if (suffix.size() > str.size()) { + return false; + } + return std::equal(suffix.rbegin(), suffix.rend(), str.rbegin()); +} + +bool is_quantized_type(const std::string& type_name) { + return type_name != "f32" && type_name != "f16" && type_name != "bf16"; +} + +bool is_legacy_quant(const std::string& type_name) { + return type_name == "q4_0" || type_name == "q4_1" || type_name == "q5_0" || type_name == "q5_1" || type_name == "q8_0"; +} + +bool is_k_quant(const std::string& type_name) { + return string_ends_with(type_name, "_k"); +} + +bool is_iq_quant(const std::string& type_name) { + return string_starts_with(type_name, "iq"); +} + +static const char path_separator = '/'; + +std::string join_paths(const std::string& path1, const std::string& path2) { + return path1 + path_separator + path2; +} + +std::string basename(const std::string &path) { + return path.substr(path.find_last_of("/\\") + 1); +} + +std::stringstream make_generic_stringstream() { + std::stringstream ss; + ss.imbue(c_locale); + return ss; +} + +std::string read_binary_file(const std::string& path, bool may_not_exist = false) { + FILE* f = fopen(path.c_str(), "rb"); + if (!f) { + if (!may_not_exist) { + std::cerr << "Error opening file: " << path << " (" << strerror(errno) << ")\n"; + } + return {}; + } + + fseek(f, 0, SEEK_END); + size_t size = ftell(f); + fseek(f, 0, SEEK_SET); + + std::string data(size, '\0'); + size_t read_size = fread(data.data(), 1, size, f); + fclose(f); + if (read_size != size) { + std::cerr << "Error reading file: " << path << " (" << strerror(errno) << ")\n"; + return {}; + } + + return data; +} + +void write_binary_file(const std::string& path, const std::string& content) { + FILE* f = fopen(path.c_str(), 
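A little further down in this file, string_to_spv() bounds the number of concurrent glslc invocations through acquire_compile_slot() and decrement_compile_count(); the angle-bracket contents of the relevant declarations (for example `using compile_count_guard = std::unique_ptr;`) did not survive intact here, so the exact template arguments are an assumption. The sketch below shows the presumed pattern, a condition-variable-bounded counter whose slot is released by a unique_ptr custom deleter; the names acquire_slot, release_slot and slot_guard are illustrative, not the patch's literal declarations.

```
#include <chrono>
#include <condition_variable>
#include <cstdint>
#include <cstdio>
#include <memory>
#include <mutex>
#include <thread>
#include <vector>

// Bounded-concurrency guard in the spirit of acquire_compile_slot(): at most
// `limit` workers hold a slot at once, and the unique_ptr's custom deleter
// releases the slot automatically when the guard goes out of scope.
static uint32_t active_count = 0;
static std::mutex count_mutex;
static std::condition_variable count_cond;

static void release_slot(uint32_t *) {
    std::lock_guard<std::mutex> guard(count_mutex);
    active_count--;
    count_cond.notify_all();
}

using slot_guard = std::unique_ptr<uint32_t, decltype(&release_slot)>;

static slot_guard acquire_slot(uint32_t limit) {
    std::unique_lock<std::mutex> guard(count_mutex);
    count_cond.wait(guard, [limit] { return active_count < limit; });
    active_count++;
    return slot_guard(&active_count, &release_slot);
}

int main() {
    std::vector<std::thread> workers;
    for (int i = 0; i < 8; i++) {
        workers.emplace_back([i] {
            slot_guard slot = acquire_slot(2);   // pretend "compile", at most 2 in flight
            std::this_thread::sleep_for(std::chrono::milliseconds(10));
            printf("finished %d\n", i);
        });
    }
    for (auto &w : workers) {
        w.join();
    }
}
```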
"wb"); + if (!f) { + std::cerr << "Error opening file for writing: " << path << " (" << strerror(errno) << ")\n"; + return; + } + + size_t write_size = fwrite(content.data(), 1, content.size(), f); + fclose(f); + if (write_size != content.size()) { + std::cerr << "Error writing file: " << path << " (" << strerror(errno) << ")\n"; + return; + } +} + +void write_file_if_changed(const std::string& path, const std::string& content) { + std::string existing = read_binary_file(path, true); + if (existing != content) { + write_binary_file(path, content); + } +} + + +// variables to track number of compiles in progress +static uint32_t compile_count = 0; +static std::mutex compile_count_mutex; +static std::condition_variable compile_count_cond; +static bool generate_dep_file = true; + +void decrement_compile_count(uint32_t * count) { + if (count) { + std::lock_guard guard(compile_count_mutex); + assert(compile_count > 0); + compile_count--; + compile_count_cond.notify_all(); + } +} + +using compile_count_guard = std::unique_ptr; + +compile_count_guard acquire_compile_slot() { + // wait until fewer than N compiles are in progress. + // 16 is an arbitrary limit, the goal is to avoid "failed to create pipe" errors. + uint32_t N = std::max(1u, std::min(16u, std::thread::hardware_concurrency())); + std::unique_lock guard(compile_count_mutex); + compile_count_cond.wait(guard, [N] { return compile_count < N; }); + compile_count++; + return compile_count_guard(&compile_count, &decrement_compile_count); +} + +void string_to_spv_func(std::string name, std::string in_path, std::string out_path, std::map defines, bool coopmat, bool dep_file, compile_count_guard slot) { + std::string target_env = (name.find("_cm2") != std::string::npos) ? "--target-env=vulkan1.3" : "--target-env=vulkan1.2"; + + // disable spirv-opt for coopmat shaders for https://github.com/ggerganov/llama.cpp/issues/10734 + // disable spirv-opt for bf16 shaders for https://github.com/ggml-org/llama.cpp/issues/15344 + std::string opt_level = (coopmat || name.find("bf16") != std::string::npos) ? 
"" : "-O"; + + #ifdef _WIN32 + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, "\"" + in_path + "\"", "-o", "\"" + out_path + "\""}; + #else + std::vector cmd = {GLSLC, "-fshader-stage=compute", target_env, opt_level, in_path, "-o", out_path}; + #endif + + if (dep_file) { + cmd.push_back("-MD"); + cmd.push_back("-MF"); + cmd.push_back("\"" + target_cpp + ".d\""); + } + + #ifdef GGML_VULKAN_SHADER_DEBUG_INFO + cmd.push_back("-g"); + #endif + + for (const auto& define : defines) { + cmd.push_back("-D" + define.first + "=" + define.second); + } + + std::string command; + for (const auto& part : cmd) { + command += part + " "; + } + + std::string stdout_str, stderr_str; + try { + // std::cout << "Executing command: "; + // for (const auto& part : cmd) { + // std::cout << part << " "; + // } + // std::cout << std::endl; + + execute_command(command, stdout_str, stderr_str); + if (!stderr_str.empty()) { + std::cerr << "cannot compile " << name << "\n\n" << command << "\n\n" << stderr_str << std::endl; + return; + } + + if (dep_file) { + // replace .spv output path with the embed .cpp path which is used as output in CMakeLists.txt + std::string dep = read_binary_file(target_cpp + ".d", true); + if (!dep.empty()) { + size_t pos = dep.find(out_path); + if (pos != std::string::npos) { + dep.replace(pos, out_path.length(), target_cpp); + } + write_binary_file(target_cpp + ".d", dep); + } + } + + std::lock_guard guard(lock); + shader_fnames.push_back(std::make_pair(name, out_path)); + } catch (const std::exception& e) { + std::cerr << "Error executing command for " << name << ": " << e.what() << std::endl; + } +} + +std::map merge_maps(const std::map& a, const std::map& b) { + std::map result = a; + result.insert(b.begin(), b.end()); + return result; +} + +static std::vector> compiles; +void string_to_spv(std::string name, const std::string& source, const std::map& defines, bool fp16 = true, bool coopmat = false, bool coopmat2 = false, bool f16acc = false) { + name = name + (f16acc ? "_f16acc" : "") + (coopmat ? "_cm1" : "") + (coopmat2 ? "_cm2" : (fp16 ? "" : "_fp32")); + std::string out_path = join_paths(output_dir, name + ".spv"); + + if (input_filepath == "") { + // No input source to compile, only generate header for all shaders + shader_fnames.push_back(std::pair(name, out_path)); + return; + } else if (basename(input_filepath) != source) { + // Only compile shader variants matching the input filename + return; + } + + compile_count_guard slot = acquire_compile_slot(); + compiles.push_back(std::async( + string_to_spv_func, name, input_filepath, out_path, defines, coopmat, generate_dep_file, std::move(slot))); + // Don't write the same dep file from multiple processes + generate_dep_file = false; +} + +void matmul_shaders(bool fp16, MatMulIdType matmul_id_type, bool coopmat, bool coopmat2, bool f16acc) { + std::string load_vec = coopmat2 ? "1" : fp16 ? "8" : "4"; + std::string aligned_b_type_f32 = coopmat2 ? "float" : fp16 ? "mat2x4" : "vec4"; + std::string aligned_b_type_f16 = coopmat2 ? "float16_t" : fp16 ? "f16mat2x4" : "f16vec4"; + + std::map base_dict; + std::string shader_name = "matmul"; + + if (matmul_id_type == MatMulIdType::DEFAULT) { + base_dict["MUL_MAT_ID"] = "1"; + shader_name = "matmul_id"; + } else if (matmul_id_type == MatMulIdType::SUBGROUP) { + base_dict["MUL_MAT_ID"] = "1"; + base_dict["MUL_MAT_ID_USE_SUBGROUPS"] = "1"; + shader_name = "matmul_id_subgroup"; + } + + if (fp16) { + base_dict["FLOAT16"] = "1"; + } + + base_dict["ACC_TYPE" ] = f16acc ? 
"float16_t" : "float"; + base_dict["ACC_TYPE_VEC2"] = f16acc ? "f16vec2" : "vec2"; + if (f16acc) { + base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } + + if (coopmat) { + base_dict["COOPMAT"] = "1"; + } + + const std::string source_name = coopmat2 ? "mul_mm_cm2.comp" : "mul_mm.comp"; + + auto const &FLOAT_TYPE = [&](int vec, const std::string &t) -> std::string { + switch (vec) { + case 1: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "float"; + } + return "bfloat16_t"; + } + if (coopmat2 || fp16) { + return "float16_t"; + } + return "float"; + case 2: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec2"; + } + return "bf16vec2"; + } + if (coopmat2 || fp16) { + return "f16vec2"; + } + return "vec2"; + case 4: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "vec4"; + } + return "bf16vec4"; + } + if (coopmat2 || fp16) { + return "f16vec4"; + } + return "vec4"; + case 8: + if (t == "bf16") { + // scalar path promotes to float + if (!coopmat && !coopmat2) { + return "mat2x4"; + } + throw std::runtime_error("bf16 vec8 not supported"); + } + if (coopmat2 || fp16) { + return "f16mat2x4"; + } + return "mat2x4"; + default: + throw std::runtime_error("invalid vector size"); + } + }; + + const std::map float_type_dict_f16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "f16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "f16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "f16")}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, "f16")}, + }; + + // Shaders with f16 B_TYPE + string_to_spv(shader_name + "_f32_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}, }), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f32_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F32", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + string_to_spv(shader_name + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_f16), {{"DATA_A_F16", "1"}, {"LOAD_VEC_A", load_vec}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + + // bf16 + { + // For aligned matmul loads + std::string load_vec_a = coopmat2 ? "1" : "4"; + + // scalar path promotes to float + std::string to_float_type = (coopmat || coopmat2) ? "uintBitsToBFloat16EXT" : "bf16_to_fp32"; + + const std::map float_type_dict_bf16 = { + {"FLOAT_TYPE", FLOAT_TYPE(1, "bf16")}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, "bf16")}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, "bf16")}, + }; + + // If bfloat16 is not supported, then only compile the scalar (promote to fp32) shader +#if !defined(GGML_VULKAN_BFLOAT16_GLSLC_SUPPORT) + if (!(coopmat || coopmat2)) +#endif + { + string_to_spv(shader_name + "_bf16", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"B_TYPE", coopmat2 ? 
"bfloat16_t" : "uint16_t"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_bf16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict_bf16), {{"TO_FLOAT_TYPE", to_float_type}, {"DATA_A_BF16", "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", "4"}, {"B_TYPE", coopmat2 ? "bfloat16_t" : "u16vec4"}, {"D_TYPE", "float"}, {"B_IS_FLOAT", "1"}, {"DATA_B_BF16", "1"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + } + + for (const auto& tname : type_names) { + std::string load_vec_quant = "2"; + if ((tname == "q4_0") || (tname == "q4_1") || (tname == "iq1_s") || (tname == "iq1_m") || (tname == "iq2_xxs") || (tname == "iq2_xs") || (tname == "iq2_s")) + load_vec_quant = "8"; + else if ((tname == "q5_0") || (tname == "q5_1") || (tname == "q8_0") || (tname == "iq3_xxs") || (tname == "iq3_s") || (tname == "iq4_nl") || (tname == "mxfp4")) + load_vec_quant = "4"; + + if (tname == "bf16") { + continue; + } + + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + // For unaligned, load one at a time for f32/f16, or two at a time for quants + std::string load_vec_a_unaligned = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? "1" : load_vec_quant; + // For aligned matmul loads + std::string load_vec_a = (coopmat2 || tname == "f32" || tname == "f16" || tname == "bf16") ? load_vec : load_vec_quant; + + const std::map float_type_dict = { + {"FLOAT_TYPE", FLOAT_TYPE(1, tname)}, + {"FLOAT_TYPE_VEC2", FLOAT_TYPE(2, tname)}, + {"FLOAT_TYPE_VEC4", FLOAT_TYPE(4, tname)}, + {"FLOAT_TYPE_VEC8", FLOAT_TYPE(8, tname)}, + }; + + // don't generate f32 variants for coopmat2 + if (!coopmat2) { + string_to_spv(shader_name + "_" + tname + "_f32", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f32_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f32}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + + if (tname != "f16" && tname != "f32") { + string_to_spv(shader_name + "_" + tname + "_f16", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a_unaligned}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}}), fp16, coopmat, coopmat2, f16acc); + string_to_spv(shader_name + "_" + tname + "_f16_aligned", source_name, merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"LOAD_VEC_A", load_vec_a}, {"LOAD_VEC_B", load_vec}, {"B_TYPE", aligned_b_type_f16}, {"D_TYPE", "float"}, {"ALIGNED", "1"}}), fp16, coopmat, coopmat2, f16acc); + } + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (!coopmat && !coopmat2 && matmul_id_type == MatMulIdType::NONE && is_legacy_quant(tname)) { + string_to_spv(shader_name + "_" + tname + "_q8_1", "mul_mmq.comp", merge_maps(merge_maps(base_dict, float_type_dict), {{data_a_key, "1"}, {"D_TYPE", "float"},}), fp16, coopmat, coopmat2, f16acc); + } +#endif + } +} + +void process_shaders() { + std::map base_dict = {{"FLOAT_TYPE", "float"}}; + + // matmul + for (const MatMulIdType& matmul_id_type : {MatMulIdType::NONE, MatMulIdType::DEFAULT, MatMulIdType::SUBGROUP}) { + // No coopmats + // fp32 + matmul_shaders(false, matmul_id_type, false, false, 
false); + + // fp16, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, false, false, false); + matmul_shaders(true, matmul_id_type, false, false, true); + + if (matmul_id_type != MatMulIdType::DEFAULT) { +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + // Coopmat, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, true, false, false); + matmul_shaders(true, matmul_id_type, true, false, true); +#endif + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + // Coopmat2, fp32acc and fp16acc + matmul_shaders(true, matmul_id_type, false, true, false); + matmul_shaders(true, matmul_id_type, false, true, true); +#endif + } + } + + // flash attention + for (const auto& f16acc : {false, true}) { + std::map fa_base_dict = base_dict; + fa_base_dict["ACC_TYPE"] = f16acc ? "float16_t" : "float"; + fa_base_dict["ACC_TYPEV4"] = f16acc ? "f16vec4" : "vec4"; + if (f16acc) { + fa_base_dict["ACC_TYPE_MAX"] = "\"float16_t(65504.0)\""; + } + + for (const auto& tname : type_names) { + if (tname == "f32") { + continue; + } + if (tname == "bf16") continue; + +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, true, f16acc); + } else { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm2.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"DEQUANTFUNC", "dequantFunc"+to_uppercase(tname) }, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, true, f16acc); + } +#endif +#if defined(GGML_VULKAN_COOPMAT_GLSLC_SUPPORT) + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn_cm1.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname)}, {"COOPMAT", "1"}}), true, true, false, f16acc); + } +#endif + if (tname == "f16") { + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{"Q_TYPE", "float"}, {"D_TYPE", "float"}}), true, false, false, f16acc); + } else if (tname == "q4_0" || tname == "q8_0") { + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + string_to_spv("flash_attn_f32_f16_" + tname, "flash_attn.comp", + merge_maps(fa_base_dict, {{data_a_key, "1"}, {"Q_TYPE", "float"}, {"D_TYPE", "float"}, {"BLOCK_SIZE", "QUANT_K_"+to_uppercase(tname) }}), true, false, false, f16acc); + } + } + } + + for (const auto& tname : type_names) { + // mul mat vec + std::string data_a_key = "DATA_A_" + to_uppercase(tname); + std::string shader = (string_ends_with(tname, "_k") || string_starts_with(tname, "iq1_") || string_starts_with(tname, "iq2_") || string_starts_with(tname, "iq3_")) ? 
"mul_mat_vec_" + tname + ".comp" : "mul_mat_vec.comp"; + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}})); + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + + string_to_spv("mul_mat_vec_" + tname + "_f32_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_f16_f32_subgroup_no_shmem", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "float16_t"}, {"B_TYPE_VEC2", "f16vec2"}, {"B_TYPE_VEC4", "f16vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + + string_to_spv("mul_mat_vec_id_" + tname + "_f32", shader, merge_maps(base_dict, {{"MUL_MAT_ID", "1"}, {data_a_key, "1"}, {"B_TYPE", "float"}, {"B_TYPE_VEC2", "vec2"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}})); + + // mul mat vec with integer dot product +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + if (is_legacy_quant(tname)) { + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}})); + string_to_spv("mul_mat_vec_" + tname + "_q8_1_f32_subgroup_no_shmem", "mul_mat_vecq.comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"FLOAT_TYPE_VEC2", "vec2"}, {"ACC_TYPE", "float"}, {"USE_SUBGROUP_ADD_NO_SHMEM", "1"}})); + } +#endif + + // Dequant shaders + if (tname != "f16" && tname != "bf16") { + string_to_spv("dequant_" + tname, "dequant_" + tname + ".comp", merge_maps(base_dict, {{data_a_key, "1"}, {"D_TYPE", "float16_t"}})); + } + + shader = (tname == "f32" || tname == "f16" || tname == "bf16") ? 
"get_rows.comp" : "get_rows_quant.comp"; + + if (tname == "f16") { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}})); + } else { + string_to_spv("get_rows_" + tname, shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float16_t"}})); + } + string_to_spv("get_rows_" + tname + "_f32", shader, merge_maps(base_dict, {{data_a_key, "1"}, {"B_TYPE", "int"}, {"D_TYPE", "float"}})); + } + + string_to_spv("mul_mat_vec_p021_f16_f32_subgroup_add", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}, {"USE_SUBGROUP_ADD", "1"}}); + string_to_spv("mul_mat_vec_p021_f16_f32", "mul_mat_vec_p021.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + string_to_spv("mul_mat_vec_nc_f16_f32", "mul_mat_vec_nc.comp", {{"A_TYPE", "float16_t"}, {"A_TYPE_VEC4", "f16vec4"}, {"B_TYPE", "float"}, {"B_TYPE_VEC4", "vec4"}, {"D_TYPE", "float"}}); + + // Norms + string_to_spv("norm_f32", "norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("group_norm_f32", "group_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_f32", "rms_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_partials_f32", "rms_norm_partials.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("rms_norm_back_f32", "rms_norm_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("l2_norm_f32", "l2_norm.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("cpy_f32_f32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("cpy_f32_f16", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("cpy_f16_f16", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f16_f32", "copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("cpy_f32_bf16","copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("contig_cpy_f32_f32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_i32", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("contig_cpy_i32_f32", "contig_copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + string_to_spv("contig_cpy_f32_f16", "contig_copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}}); + string_to_spv("contig_cpy_f16_f16", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f16_f32", "contig_copy.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("contig_cpy_f32_bf16","contig_copy.comp",{{"A_TYPE", "float"}, {"D_TYPE", "uint16_t"}, {"DATA_D_BF16", "1"}}); + string_to_spv("cpy_f32_i32", "copy.comp", {{"A_TYPE", "float"}, {"D_TYPE", "int"}}); + string_to_spv("cpy_i32_f32", "copy.comp", {{"A_TYPE", "int"}, {"D_TYPE", "float"}}); + + for (std::string t : 
{"q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("cpy_f32_" + t, "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("cpy_f32_" + t + "_rte", "copy_to_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("cpy_" + t + "_f32", "copy_from_quant.comp", {{"DATA_A_" + to_uppercase(t), "1"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + } + + for (std::string t : {"f32", "f16", "bf16", "q4_0", "q4_1", "q5_0", "q5_1", "q8_0", "iq4_nl"}) { + string_to_spv("set_rows_" + t + "_i32", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i32_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uint"}, {"B_SIZE", "32"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + string_to_spv("set_rows_" + t + "_i64", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + string_to_spv("set_rows_" + t + "_i64_rte", "copy_to_quant.comp", {{"SET_ROWS", "1"}, {"DATA_A_" + to_uppercase(t), "1"}, {"B_TYPE", "uvec2"}, {"B_SIZE", "64"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}}); + } + + auto get_type_str = [](bool f16) { + return f16 ? "float16_t" : "float"; + }; + auto get_suffix = [](bool src0_f16, bool src1_f16, bool dst_f16) { + std::string s; + s += std::string(src0_f16 ? "_f16" : "_f32"); + s += std::string(src1_f16 ? "_f16" : "_f32"); + s += std::string(dst_f16 ? "_f16" : "_f32"); + return s; + }; + for (std::string op : {"add", "sub", "mul", "div", "add_rms", }) { + for (auto src0_f16 : {false, true}) { + for (auto src1_f16 : {false, true}) { + for (auto dst_f16 : {false, true}) { + for (auto rte : {false, true}) { + auto source = op == "add_rms" ? std::string("add") : op; + auto name = op + get_suffix(src0_f16, src1_f16, dst_f16) + (rte ? "_rte" : ""); + auto add_rms = op == "add_rms" ? "1" : "0"; + string_to_spv(name.c_str(), source + ".comp", {{"A_TYPE", get_type_str(src0_f16)}, {"B_TYPE", get_type_str(src1_f16)}, {"D_TYPE", get_type_str(dst_f16)}, {"FLOAT_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}, {"ADD_RMS" , add_rms}}); + } + } + } + } + } + + string_to_spv("sub_f32", "sub.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("acc_f32", "acc.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("split_k_reduce", "mul_mat_split_k_reduce.comp", {}); + string_to_spv("fa_split_k_reduce", "flash_attn_split_k_reduce.comp", {}); + + string_to_spv("quantize_q8_1", "quantize_q8_1.comp", {}); + string_to_spv("quantize_q8_1_subgroup", "quantize_q8_1.comp", {{"USE_SUBGROUPS", "1"}}); + + string_to_spv("quantize_q8_1_x4", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}}); + string_to_spv("quantize_q8_1_x4_subgroup", "quantize_q8_1.comp", {{"QBLOCK_X4", "1"}, {"USE_SUBGROUPS", "1"}}); + + string_to_spv("mul_f32", "mul.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("div_f32", "div.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("repeat_f32", "repeat.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("repeat_back_f32", "repeat_back.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("scale_f32", "scale.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqr_f32", "square.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sqrt_f32", "sqrt.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("sin_f32", "sin.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("cos_f32", "cos.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("clamp_f32", "clamp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}}); + + string_to_spv("pad_f32", "pad.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("concat_f32", "concat.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("concat_f16", "concat.comp", {{"A_TYPE", "float16_t"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"OPTIMIZATION_ERROR_WORKAROUND", "1"}}); + string_to_spv("concat_i32", "concat.comp", {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}}); + + string_to_spv("upscale_f32", "upscale.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("exp_f16" + suffix, "exp.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("exp_f32" + suffix, "exp.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"} , {"RTE16", rte ? 
"1" : "0"}}); + } + string_to_spv("gelu_f16", "gelu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_f32", "gelu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_erf_f16", "gelu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_erf_f32", "gelu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("gelu_quick_f16", "gelu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("gelu_quick_f32", "gelu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_f16", "silu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("silu_f32", "silu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("relu_f16", "relu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("relu_f32", "relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("tanh_f16", "tanh.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("tanh_f32", "tanh.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("sigmoid_f16", "sigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("sigmoid_f32", "sigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardsigmoid_f16","hardsigmoid.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardsigmoid_f32","hardsigmoid.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("hardswish_f16", "hardswish.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("hardswish_f32", "hardswish.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + for (auto rte : {false, true}) { + std::string suffix = rte ? "_rte" : ""; + string_to_spv("geglu_f16" + suffix, "geglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_f32" + suffix, "geglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f16" + suffix, "reglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("reglu_f32" + suffix, "reglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f16" + suffix, "swiglu.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_f32" + suffix, "swiglu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f16" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("swiglu_oai_f32" + suffix, "swiglu_oai.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f16" + suffix, "geglu_erf.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_erf_f32" + suffix, "geglu_erf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f16" + suffix,"geglu_quick.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", rte ? "1" : "0"}}); + string_to_spv("geglu_quick_f32" + suffix,"geglu_quick.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"RTE16", rte ? 
"1" : "0"}}); + } + + string_to_spv("leaky_relu_f32", "leaky_relu.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("silu_back_f32", "silu_back.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("diag_mask_inf_f32", "diag_mask_inf.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("soft_max_f32", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_f32_f16", "soft_max.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float16_t"}, {"D_TYPE", "float"}})); + string_to_spv("soft_max_back_f32", "soft_max_back.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rope_norm_f32", "rope_norm.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_norm_f16", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_norm_f16_rte", "rope_norm.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_neox_f32", "rope_neox.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_neox_f16", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_neox_f16_rte", "rope_neox.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_multi_f32", "rope_multi.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_multi_f16", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_multi_f16_rte", "rope_multi.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("rope_vision_f32", "rope_vision.comp", {{"A_TYPE", "float"}, {"D_TYPE", "float"}}); + string_to_spv("rope_vision_f16", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}}); + string_to_spv("rope_vision_f16_rte", "rope_vision.comp", {{"A_TYPE", "float16_t"}, {"D_TYPE", "float16_t"}, {"RTE16", "1"}}); + + string_to_spv("argsort_f32", "argsort.comp", {{"A_TYPE", "float"}}); + + string_to_spv("argmax_f32", "argmax.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "int"}})); + string_to_spv("sum_rows_f32", "sum_rows.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + string_to_spv("count_equal_i32", "count_equal.comp", merge_maps(base_dict, {{"A_TYPE", "int"}, {"B_TYPE", "int"}, {"D_TYPE", "int"}})); + + for (std::string dim_str : {"", "_3d"}) { + for (bool bda : {false, true}) { + std::string bda_str = bda ? "_bda" : ""; + std::string bda_def = bda ? 
"1" : "0"; + string_to_spv("im2col" + dim_str + "_f32" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}, {"D_SIZE", "4"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"BDA", bda_def}})); + string_to_spv("im2col" + dim_str + "_f32_f16_rte" + bda_str, "im2col" + dim_str + ".comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float16_t"}, {"D_SIZE", "2"}, {"RTE16", "1"}, {"BDA", bda_def}})); + } + } + + string_to_spv("timestep_embedding_f32", "timestep_embedding.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("conv_transpose_1d_f32", "conv_transpose_1d.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}}); + + string_to_spv("pool2d_f32", "pool2d.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("rwkv_wkv6_f32", "wkv6.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("rwkv_wkv7_f32", "wkv7.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + string_to_spv("opt_step_adamw_f32", "opt_step_adamw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + string_to_spv("opt_step_sgd_f32", "opt_step_sgd.comp", merge_maps(base_dict, {{"A_TYPE", "float"}})); + + for (auto transpose : {false, true}) { + for (auto unroll : {false, true}) { + for (auto a_f16 : {false, true}) { + std::map defines = { + {"A_TYPE", a_f16 ? "float16_t" : "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, + {"USE_COLLECTIVES", "1"}, {"UNROLL", unroll ? "[[unroll]]" : ""}, + }; + if (transpose) defines["TRANSPOSE"] = "1"; + std::string name = std::string(transpose ? "conv_transpose_2d": "conv2d") + + (a_f16 ? "_f16" : "") + "_f32"; + string_to_spv(name + (unroll ? 
"_unroll" : ""), "conv2d_mm.comp", defines); +#if defined(GGML_VULKAN_COOPMAT2_GLSLC_SUPPORT) + if (unroll) { + defines["COOPMAT2"] = "1"; + string_to_spv(name, "conv2d_mm.comp", defines, true, false, true); + } +#endif + } + } + } + + string_to_spv("conv2d_dw_whcn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + string_to_spv("conv2d_dw_whcn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"WHCN", "1"}})); + string_to_spv("conv2d_dw_cwhn_f16_f32", "conv2d_dw.comp", merge_maps(base_dict, {{"A_TYPE", "float16_t"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"CWHN", "1"}})); + + string_to_spv("roll_f32", "roll.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("add_id_f32", "add_id.comp", merge_maps(base_dict, {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}})); + + string_to_spv("multi_add_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "0"}}); + string_to_spv("multi_add_rms_f32", "multi_add.comp", {{"A_TYPE", "float"}, {"B_TYPE", "float"}, {"D_TYPE", "float"}, {"FLOAT_TYPE", "float"}, {"RTE16", "1"}, {"ADD_RMS" , "1"}}); + + for (auto &c : compiles) { + c.wait(); + } +} + +void write_output_files() { + std::stringstream hdr = make_generic_stringstream(); + std::stringstream src = make_generic_stringstream(); + + hdr << "#include \n\n"; + src << "#include \"" << basename(target_hpp) << "\"\n\n"; + + std::sort(shader_fnames.begin(), shader_fnames.end()); + for (const auto& pair : shader_fnames) { + const std::string& name = pair.first; + #ifdef _WIN32 + std::string path = pair.second; + std::replace(path.begin(), path.end(), '/', '\\' ); + #else + const std::string& path = pair.second; + #endif + + hdr << "extern const uint64_t " << name << "_len;\n"; + hdr << "extern const unsigned char " << name << "_data[];\n\n"; + + if (input_filepath != "") { + std::string data = read_binary_file(path); + if (data.empty()) { + continue; + } + + src << "const uint64_t " << name << "_len = " << data.size() << ";\n"; + src << "const unsigned char " << name << "_data[" << data.size() << "] = {\n" << std::hex; + auto bytes = reinterpret_cast(data.data()); + for (size_t i = 0; i < data.size(); ++i) { + src << "0x" << static_cast(bytes[i]) << ","; + if ((i + 1) % 12 == 0) src << "\n"; + } + src << std::dec << "\n};\n\n"; + } + } + + std::string suffixes[2] = {"_f32", "_f16"}; + for (auto op : {"add", "sub", "mul", "div", "add_rms"}) { + hdr << "extern const void * " << op << "_data[2][2][2][2];\n"; + hdr << "extern const uint64_t " << op << "_len[2][2][2][2];\n"; + + std::string op_file = op == "add_rms" ? 
"add.comp" : std::string(op) + ".comp"; + if (basename(input_filepath) != op_file) { + continue; + } + std::stringstream data = make_generic_stringstream(); + std::stringstream len = make_generic_stringstream(); + data << "const void * " << op << "_data[2][2][2][2] = "; + len << "const uint64_t " << op << "_len[2][2][2][2] = "; + for (uint32_t t0 = 0; t0 < 2; ++t0) { + if (t0 == 0) { + data << "{"; + len << "{"; + } + for (uint32_t t1 = 0; t1 < 2; ++t1) { + if (t1 == 0) { + data << "{"; + len << "{"; + } + for (uint32_t t2 = 0; t2 < 2; ++t2) { + if (t2 == 0) { + data << "{"; + len << "{"; + } + for (uint32_t rte = 0; rte < 2; ++rte) { + if (rte == 0) { + data << "{"; + len << "{"; + } + data << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + len << op << suffixes[t0] << suffixes[t1] << suffixes[t2] << ((rte != 0) ? "_rte" : ""); + data << "_data,"; + len << "_len,"; + if (rte == 1) { + data << "}, "; + len << "}, "; + } + } + if (t2 == 1) { + data << "}, "; + len << "}, "; + } + } + if (t1 == 1) { + data << "}, "; + len << "}, "; + } + } + if (t0 == 1) { + data << "};\n"; + len << "};\n"; + } + } + src << data.str(); + src << len.str(); + } + + std::vector btypes = {"f16", "f32"}; + +#if defined(GGML_VULKAN_INTEGER_DOT_GLSLC_SUPPORT) + btypes.push_back("q8_1"); +#endif + + for (const std::string& btype : btypes) { + for (const auto& tname : type_names) { + if (btype == "q8_1" && !is_legacy_quant(tname)) { + continue; + } + hdr << "extern const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3];\n"; + hdr << "extern const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3];\n"; + if (basename(input_filepath) == "mul_mat_vec.comp") { + src << "const void * arr_dmmv_" << tname << "_" << btype << "_f32_data[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_data, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_data};\n"; + src << "const uint64_t arr_dmmv_" << tname << "_" << btype << "_f32_len[3] = {mul_mat_vec_" << tname << "_" << btype << "_f32_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_len, mul_mat_vec_" << tname << "_" << btype << "_f32_subgroup_no_shmem_len};\n"; + } + } + } + + if (input_filepath == "") { + write_file_if_changed(target_hpp, hdr.str()); + } + if (target_cpp != "") { + write_binary_file(target_cpp, src.str()); + } +} + +} // namespace + +int main(int argc, char** argv) { + std::map args; + for (int i = 1; i < argc; ++i) { + std::string arg = argv[i]; + if (arg.rfind("--", 0) == 0) { + if (i + 1 < argc && argv[i + 1][0] != '-') { + args[arg] = argv[i + 1]; + ++i; + } else { + args[arg] = ""; + } + } + } + + if (args.find("--glslc") != args.end()) { + GLSLC = args["--glslc"]; // Path to glslc + } + if (args.find("--source") != args.end()) { + input_filepath = args["--source"]; // The shader source file to compile + } + if (args.find("--output-dir") != args.end()) { + output_dir = args["--output-dir"]; // Directory for containing SPIR-V output + } + if (args.find("--target-hpp") != args.end()) { + target_hpp = args["--target-hpp"]; // Path to generated header file + } + if (args.find("--target-cpp") != args.end()) { + target_cpp = args["--target-cpp"]; // Path to generated cpp file + } + + if (!directory_exists(output_dir)) { + if (!create_directory(output_dir)) { + std::cerr << "Error creating output directory: " << output_dir << "\n"; + return EXIT_FAILURE; + } + } + + process_shaders(); + + write_output_files(); + + 
return EXIT_SUCCESS; +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp new file mode 100644 index 0000000000..35cc6c45f9 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv6.comp @@ -0,0 +1,87 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 1) readonly buffer VBuf { A_TYPE v[]; }; +layout(binding = 2) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 3) readonly buffer TimeFBuf { A_TYPE tf[]; }; +layout(binding = 4) readonly buffer TimeDBuf { A_TYPE td[]; }; +layout(binding = 5) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 6) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _k[BLOCK_SIZE], _r[BLOCK_SIZE], _tf[BLOCK_SIZE], _td[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid]; + } + + barrier(); + _tf[tid] = tf[head_id * head_size + tid]; + barrier(); + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _k[tid] = k[t]; + _r[tid] = r[t]; + _td[tid] = td[t]; + barrier(); + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 tf_vec = vec4(_tf[j], _tf[j+1], _tf[j+2], _tf[j+3]); + vec4 td_vec = vec4(_td[j], _td[j+1], _td[j+2], _td[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + + vec4 temp = tf_vec * kv + s_vec; + y += dot(r_vec, temp); + + s_vec = s_vec * td_vec + kv; + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + i * head_size + tid] = state[i]; + } +} diff --git a/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp new file mode 100644 index 0000000000..88c1c02b32 --- /dev/null +++ b/ml/backend/ggml/ggml/src/ggml-vulkan/vulkan-shaders/wkv7.comp @@ -0,0 +1,91 @@ +#version 450 + +#extension GL_EXT_control_flow_attributes : require + +#define BLOCK_SIZE 64 +layout(local_size_x = BLOCK_SIZE, local_size_y = 1, local_size_z = 1) in; + +layout(push_constant) uniform Parameters { + uint B; + uint T; + uint C; + uint H; +}; + +layout(binding = 0) readonly buffer RBuf { A_TYPE r[]; }; +layout(binding = 1) readonly buffer WBuf { A_TYPE w[]; }; +layout(binding = 2) readonly buffer KBuf { A_TYPE k[]; }; +layout(binding = 3) readonly buffer VBuf { A_TYPE 
v[]; }; +layout(binding = 4) readonly buffer ABuf { A_TYPE a[]; }; +layout(binding = 5) readonly buffer BBuf { A_TYPE b[]; }; +layout(binding = 6) readonly buffer StateBuf { A_TYPE state_in[]; }; +layout(binding = 7) buffer DstBuf { A_TYPE dst[]; }; + +shared A_TYPE _r[BLOCK_SIZE], _w[BLOCK_SIZE], _k[BLOCK_SIZE], _a[BLOCK_SIZE], _b[BLOCK_SIZE]; + +void main() { + const uint head_size = BLOCK_SIZE; + const uint batch_id = gl_WorkGroupID.x / H; + const uint head_id = gl_WorkGroupID.x % H; + const uint tid = gl_LocalInvocationID.x; + + const uint state_size = C * head_size; + const uint n_seq_tokens = T / B; + + if (batch_id >= B || head_id >= H) { + return; + } + + A_TYPE state[BLOCK_SIZE]; + [[unroll]] for (uint i = 0; i < head_size; i++) { + state[i] = state_in[batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i]; + } + + const uint start_t = batch_id * n_seq_tokens * C + head_id * head_size + tid; + const uint end_t = (batch_id + 1) * n_seq_tokens * C + head_id * head_size + tid; + + for (uint t = start_t; t < end_t; t += C) { + barrier(); + _r[tid] = r[t]; + _w[tid] = w[t]; + _k[tid] = k[t]; + _a[tid] = a[t]; + _b[tid] = b[t]; + barrier(); + + A_TYPE sa = 0.0; + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + vec4 a_vec = vec4(_a[j], _a[j+1], _a[j+2], _a[j+3]); + sa += dot(s_vec, a_vec); + } + + const A_TYPE v_val = v[t]; + A_TYPE y = 0.0; + + [[unroll]] for (uint j = 0; j < head_size; j += 4) { + vec4 r_vec = vec4(_r[j], _r[j+1], _r[j+2], _r[j+3]); + vec4 w_vec = vec4(_w[j], _w[j+1], _w[j+2], _w[j+3]); + vec4 k_vec = vec4(_k[j], _k[j+1], _k[j+2], _k[j+3]); + vec4 b_vec = vec4(_b[j], _b[j+1], _b[j+2], _b[j+3]); + vec4 s_vec = vec4(state[j], state[j+1], state[j+2], state[j+3]); + + vec4 kv = k_vec * v_val; + s_vec = s_vec * w_vec + kv + sa * b_vec; + y += dot(r_vec, s_vec); + + state[j] = s_vec.x; + state[j+1] = s_vec.y; + state[j+2] = s_vec.z; + state[j+3] = s_vec.w; + } + + dst[t] = y; + } + + [[unroll]] for (uint i = 0; i < head_size; i++) { + dst[T * C + batch_id * state_size + head_id * head_size * head_size + + tid * head_size + i] = state[i]; + } +} diff --git a/scripts/build_windows.ps1 b/scripts/build_windows.ps1 index 2359d7cc58..f1cd3fea26 100644 --- a/scripts/build_windows.ps1 +++ b/scripts/build_windows.ps1 @@ -184,6 +184,18 @@ function buildROCm() { } } +function buildVulkan(){ + if ($env:VULKAN_SDK) { + write-host "Building Vulkan backend libraries" + & cmake --fresh --preset Vulkan --install-prefix $script:DIST_DIR -DOLLAMA_RUNNER_DIR="vulkan" + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --build --preset Vulkan --config Release --parallel $script:JOBS + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + & cmake --install build --component Vulkan --strip + if ($LASTEXITCODE -ne 0) { exit($LASTEXITCODE)} + } +} + function buildOllama() { mkdir -Force -path "${script:DIST_DIR}\" write-host "Building ollama CLI" @@ -297,6 +309,7 @@ try { buildCUDA12 buildCUDA13 buildROCm + buildVulkan buildOllama buildApp gatherDependencies @@ -315,4 +328,4 @@ try { } finally { set-location $script:SRC_DIR $env:PKG_VERSION="" -} +} \ No newline at end of file