From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Thu, 16 Oct 2025 20:37:19 -0700 Subject: [PATCH] interleave multi rope Since ollama doesn't use mrope for anything else, repurpose it to mean the interleaved variant used by qwen3vl. --- ggml/src/ggml-cpu/ops.cpp | 7 ++----- ggml/src/ggml-cuda/rope.cu | 12 +++--------- ggml/src/ggml-metal/ggml-metal.metal | 10 +++------- ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp | 12 +++--------- 4 files changed, 11 insertions(+), 30 deletions(-) diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp index 31478dd8e..4d1ed207e 100644 --- a/ggml/src/ggml-cpu/ops.cpp +++ b/ggml/src/ggml-cpu/ops.cpp @@ -5509,15 +5509,12 @@ static void ggml_mrope_cache_init( } float theta = theta_t; - if (sector >= sections[0] && sector < sec_w) { + if (sector % 3 == 1 && sector < 1 + 3 * sections[1]) { theta = theta_h; } - else if (sector >= sec_w && sector < sec_w + sections[2]) { + else if (sector % 3 == 2 && sector < 2 + 3 * sections[2]) { theta = theta_w; } - else if (sector >= sec_w + sections[2]) { - theta = theta_e; - } rope_yarn( theta/ff, freq_scale, corr_dims, i0, ext_factor, mscale, &cache[i0 + 0], &cache[i0 + 1] diff --git a/ggml/src/ggml-cuda/rope.cu b/ggml/src/ggml-cuda/rope.cu index d058504cd..287fe9d2c 100644 --- a/ggml/src/ggml-cuda/rope.cu +++ b/ggml/src/ggml-cuda/rope.cu @@ -151,19 +151,13 @@ static __global__ void rope_multi( const int sec_w = sections.v[1] + sections.v[0]; const int sector = (i0 / 2) % sect_dims; - float theta_base = 0.0; - if (sector < sections.v[0]) { - theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); - } - else if (sector >= sections.v[0] && sector < sec_w) { + float theta_base = pos[channel_x]*powf(theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 1 + 3 * sections.v[1]) { theta_base = pos[channel_x + ne2 * 1]*powf(theta_scale, i0/2.0f); } - else if (sector >= sec_w && sector < sec_w + sections.v[2]) { + else if (sector % 3 == 2 && 
sector < 2 + 3 * sections.v[2]) { theta_base = pos[channel_x + ne2 * 2]*powf(theta_scale, i0/2.0f); } - else if (sector >= sec_w + sections.v[2]) { - theta_base = pos[channel_x + ne2 * 3]*powf(theta_scale, i0/2.0f); - } const float freq_factor = has_ff ? freq_factors[i0/2] : 1.0f; diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal index 375a0c7fd..9866c96b4 100644 --- a/ggml/src/ggml-metal/ggml-metal.metal +++ b/ggml/src/ggml-metal/ggml-metal.metal @@ -3858,15 +3858,11 @@ kernel void kernel_rope_multi( const int sec_w012 = args.sect_0 + args.sect_1 + args.sect_2; // end of section 2 const int sector = ic % sect_dims; - float theta_base; - if (sector < args.sect_0) { - theta_base = (float) pos[i2]; - } else if (sector < sec_w01) { + float theta_base = (float) pos[i2]; + if (sector % 3 == 1 && sector < 1 + 3 * args.sect_1) { theta_base = (float) pos[i2 + args.ne02]; - } else if (sector < sec_w012) { + } else if (sector % 3 == 2 && sector < 2 + 3 * args.sect_2) { theta_base = (float) pos[i2 + args.ne02 * 2]; - } else { - theta_base = (float) pos[i2 + args.ne02 * 3]; } // end of mrope diff --git a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp index 111286b49..6fc2b42f8 100644 --- a/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp +++ b/ggml/src/ggml-vulkan/vulkan-shaders/rope_multi.comp @@ -31,19 +31,13 @@ void main() { const int sec_w = p.sections[1] + p.sections[0]; const uint sector = (i0 / 2) % sect_dims; - float theta_base = 0.0; - if (sector < p.sections[0]) { - theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); - } - else if (sector >= p.sections[0] && sector < sec_w) { + float theta_base = data_pos[channel_x]*pow(p.theta_scale, i0/2.0f); + if (sector % 3 == 1 && sector < 1 + 3 * p.sections[1]) { theta_base = data_pos[channel_x + ne2 * 1]*pow(p.theta_scale, i0/2.0f); } - else if (sector >= sec_w && sector < sec_w + p.sections[2]) { + else if (sector % 
3 == 2 && sector < 2 + 3 * p.sections[2]) { theta_base = data_pos[channel_x + ne2 * 2]*pow(p.theta_scale, i0/2.0f); } - else if (sector >= sec_w + p.sections[2]) { - theta_base = data_pos[channel_x + ne2 * 3]*pow(p.theta_scale, i0/2.0f); - } const float freq_factor = p.has_ff != 0 ? data_ff[i0/2] : 1.0f;