From 80d41b579b76fe2b10134e7e3fa0a184f6917e8d Mon Sep 17 00:00:00 2001 From: jmorganca Date: Sun, 15 Dec 2024 23:54:18 -0800 Subject: [PATCH] llama: add qwen2vl support --- llama/ggml-metal-embed.metal | 147 +++++++++++ llama/ggml-metal-impl.h | 1 + llama/ggml-metal.metal | 146 +++++++++++ llama/ggml-metal_darwin_arm64.m | 54 ++-- llama/patches/0014-qwen2vl-support.patch | 299 +++++++++++++++++++++++ 5 files changed, 632 insertions(+), 15 deletions(-) create mode 100644 llama/patches/0014-qwen2vl-support.patch diff --git a/llama/ggml-metal-embed.metal b/llama/ggml-metal-embed.metal index f45d869e9..72ed30948 100644 --- a/llama/ggml-metal-embed.metal +++ b/llama/ggml-metal-embed.metal @@ -2081,6 +2081,7 @@ typedef struct { float attn_factor; float beta_fast; float beta_slow; + int32_t sections[4]; } ggml_metal_kargs_rope; typedef struct { @@ -4785,8 +4786,148 @@ kernel void kernel_rope_neox( } } + +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3]; + int sec_w = args.sections[1] + args.sections[0]; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + const int sector = ic % sect_dims; + + float theta_base = (float) pos[i2]; + if (sector >= args.sections[0] && sector < sec_w) { + theta_base = (float) pos[i2 + args.ne2]; + } + else if (sector >= sec_w && sector < sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 2]; + } + else if (sector >= sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 3]; + } + + float theta = theta_base*pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + int sect_dims = args.sections[0] + args.sections[1]; + int sec_w = args.sections[1] + args.sections[0]; + int sec_e = args.sections[2] + sec_w; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + const int ic = i0/2; + const int sector = ic % sect_dims; + + float theta_base = (float) pos[i2]; + if (sector >= args.sections[0] && sector < sec_w) { + theta_base = (float) pos[i2 + args.ne2]; + } + else if (sector >= sec_w && sector < sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 2]; + } + else if (sector >= sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 3]; + } + + int p = sector; + if (sector >= sec_w + args.sections[2]) { + p = sector - (sec_w + args.sections[2]); + } else if (sector >= sec_w) { + p = sector - sec_w; + } else if (sector >= args.sections[0]) { + p = sector - args.sections[0]; + } + + const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; + } +} + + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; @@ -4794,6 +4935,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + typedef void (im2col_t)( device const float * x, device char * dst, diff --git a/llama/ggml-metal-impl.h b/llama/ggml-metal-impl.h index 982b6f9dc..ab33da1ac 100644 --- a/llama/ggml-metal-impl.h +++ b/llama/ggml-metal-impl.h @@ -169,6 +169,7 @@ typedef struct { float attn_factor; float beta_fast; float beta_slow; + int32_t sections[4]; } ggml_metal_kargs_rope; typedef struct { diff --git a/llama/ggml-metal.metal b/llama/ggml-metal.metal index 8552f726b..b6964ec56 100644 --- a/llama/ggml-metal.metal +++ b/llama/ggml-metal.metal @@ -2594,8 +2594,148 @@ kernel void kernel_rope_neox( } } + +template +kernel void kernel_rope_multi( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3]; + int sec_w = args.sections[1] + args.sections[0]; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + if (i0 < args.n_dims) { + const int ic = i0/2; + const int sector = ic % sect_dims; + + float theta_base = (float) pos[i2]; + if (sector >= args.sections[0] && sector < sec_w) { + theta_base = (float) pos[i2 + args.ne2]; + } + else if (sector >= sec_w && sector < sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 2]; + } + else if (sector >= sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 3]; + } + + float theta = theta_base*pow(args.freq_base, inv_ndims*i0); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims/2]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; + } else { + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); + + dst_data[0] = src[0]; + dst_data[1] = src[1]; + } + } +} + +template +kernel void kernel_rope_vision( + constant ggml_metal_kargs_rope & args, + device const char * src0, + device const char * src1, + device const char * src2, + device char * dst, + ushort tiitg[[thread_index_in_threadgroup]], + ushort3 tptg [[threads_per_threadgroup]], + uint3 tgpig[[threadgroup_position_in_grid]]) { + const int i3 = tgpig[2]; + const int i2 = tgpig[1]; + const int i1 = tgpig[0]; + + float corr_dims[2]; + rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); + + device const int32_t * pos = (device const int32_t *) src1; + + int sect_dims = args.sections[0] + args.sections[1]; + int sec_w = args.sections[1] + args.sections[0]; + int sec_e = args.sections[2] + sec_w; + + const float inv_ndims = -1.f/args.n_dims; + + float cos_theta; + float sin_theta; + + for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { + const int ic = i0/2; + const int sector = ic % sect_dims; + + float theta_base = (float) pos[i2]; + if (sector >= args.sections[0] && sector < sec_w) { + theta_base = (float) pos[i2 + args.ne2]; + } + else if (sector >= sec_w && sector < sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 2]; + } + else if (sector >= sec_w + args.sections[2]) { + theta_base = (float) pos[i2 + args.ne2 * 3]; + } + + int p = sector; + if (sector >= sec_w + args.sections[2]) { + p = sector - (sec_w + args.sections[2]); + } else if (sector >= sec_w) { + p = sector - sec_w; + } else if (sector >= args.sections[0]) { + p = sector - args.sections[0]; + } + + const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p); + + const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; + + rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); + + device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); + device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); + + const float x0 = src[0]; + const float x1 = src[args.n_dims]; + + dst_data[0] = x0*cos_theta - x1*sin_theta; + dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; + } +} + + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; typedef decltype(kernel_rope_neox) kernel_rope_neox_t; +typedef decltype(kernel_rope_multi) kernel_rope_multi_t; +typedef decltype(kernel_rope_vision) kernel_rope_vision_t; template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; @@ -2603,6 +2743,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; +template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; +template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; + +template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; +template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; + typedef void (im2col_t)( device const float * x, device char * dst, diff --git a/llama/ggml-metal_darwin_arm64.m b/llama/ggml-metal_darwin_arm64.m index 56d8a7549..128d46004 100644 --- a/llama/ggml-metal_darwin_arm64.m +++ b/llama/ggml-metal_darwin_arm64.m @@ -328,6 +328,10 @@ enum ggml_metal_kernel_type { GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, + GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, + GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F16, GGML_METAL_KERNEL_TYPE_IM2COL_F32, GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, @@ -928,6 +932,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); @@ -1155,16 +1163,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex case GGML_OP_NORM: return true; case GGML_OP_ROPE: - { - const int mode = ((const int32_t *) op->op_params)[2]; - if (mode & GGML_ROPE_TYPE_MROPE) { - return false; - } - if (mode & GGML_ROPE_TYPE_VISION) { - return false; - } - return true; - } + return true; case GGML_OP_IM2COL: return op->src[0]->type == GGML_TYPE_F16; case GGML_OP_POOL_1D: @@ -3083,6 +3082,7 @@ static void ggml_metal_encode_node( float attn_factor; float beta_fast; float beta_slow; + int32_t sections[4]; memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); @@ -3090,21 +3090,44 @@ static void ggml_metal_encode_node( memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); + memcpy(§ions, (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4); const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; + const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; + const bool is_vision = mode == GGML_ROPE_TYPE_VISION; + + if (is_mrope) { + GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); + } + + if (is_vision) { + GGML_ASSERT(n_dims == ne00/2); + } id pipeline = nil; - if (!is_neox) { + if (is_neox) { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_mrope && !is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else if (is_vision) { + switch (src0->type) { + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } else { switch (src0->type) { - case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; - case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; + case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; + case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; default: GGML_ABORT("fatal error"); }; } @@ -3135,6 +3158,7 @@ static void ggml_metal_encode_node( /*.attn_factor =*/ attn_factor, /*.beta_fast =*/ beta_fast, /*.beta_slow =*/ beta_slow, + /*.sections =*/ {sections[0], sections[1], sections[2], sections[3]} }; [encoder setComputePipelineState:pipeline]; diff --git a/llama/patches/0014-qwen2vl-support.patch b/llama/patches/0014-qwen2vl-support.patch new file mode 100644 index 000000000..14865c8fb --- /dev/null +++ b/llama/patches/0014-qwen2vl-support.patch @@ -0,0 +1,299 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: jmorganca +Date: Sun, 15 Dec 2024 23:56:24 -0800 +Subject: [PATCH] qwen2vl support + +--- + ggml/src/ggml-metal/ggml-metal-impl.h | 1 + + ggml/src/ggml-metal/ggml-metal.m | 54 +++++++--- + ggml/src/ggml-metal/ggml-metal.metal | 146 ++++++++++++++++++++++++++ + 3 files changed, 186 insertions(+), 15 deletions(-) + +diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h +index e3dc25f1..766a4999 100644 +--- a/ggml/src/ggml-metal/ggml-metal-impl.h ++++ b/ggml/src/ggml-metal/ggml-metal-impl.h +@@ -143,6 +143,7 @@ typedef struct { + float attn_factor; + float beta_fast; + float beta_slow; ++ int32_t sections[4]; + } ggml_metal_kargs_rope; + + typedef struct { +diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m +index 787fc713..806c9fd3 100644 +--- a/ggml/src/ggml-metal/ggml-metal.m ++++ b/ggml/src/ggml-metal/ggml-metal.m +@@ -302,6 +302,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte + GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, + GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, ++ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, ++ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, ++ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, ++ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F16, + GGML_METAL_KERNEL_TYPE_IM2COL_F32, + GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, +@@ -902,6 +906,10 @@ @implementation GGMLMetalClass + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true); ++ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true); + GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true); +@@ -1129,16 +1137,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex + case GGML_OP_NORM: + return true; + case GGML_OP_ROPE: +- { +- const int mode = ((const int32_t *) op->op_params)[2]; +- if (mode & GGML_ROPE_TYPE_MROPE) { +- return false; +- } +- if (mode & GGML_ROPE_TYPE_VISION) { +- return false; +- } +- return true; +- } ++ return true; + case GGML_OP_IM2COL: + return op->src[0]->type == GGML_TYPE_F16; + case GGML_OP_POOL_1D: +@@ -3057,6 +3056,7 @@ static void ggml_metal_encode_node( + float attn_factor; + float beta_fast; + float beta_slow; ++ int32_t sections[4]; + + memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float)); + memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float)); +@@ -3064,21 +3064,44 @@ static void ggml_metal_encode_node( + memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float)); + memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float)); + memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float)); ++ memcpy(§ions, (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4); + + const bool is_neox = mode & GGML_ROPE_TYPE_NEOX; ++ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE; ++ const bool is_vision = mode == GGML_ROPE_TYPE_VISION; ++ ++ if (is_mrope) { ++ GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0); ++ } ++ ++ if (is_vision) { ++ GGML_ASSERT(n_dims == ne00/2); ++ } + + id pipeline = nil; + +- if (!is_neox) { ++ if (is_neox) { + switch (src0->type) { +- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; +- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; ++ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; ++ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; ++ default: GGML_ABORT("fatal error"); ++ }; ++ } else if (is_mrope && !is_vision) { ++ switch (src0->type) { ++ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break; ++ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break; ++ default: GGML_ABORT("fatal error"); ++ }; ++ } else if (is_vision) { ++ switch (src0->type) { ++ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break; ++ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } else { + switch (src0->type) { +- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break; +- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break; ++ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break; ++ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break; + default: GGML_ABORT("fatal error"); + }; + } +@@ -3109,6 +3132,7 @@ static void ggml_metal_encode_node( + /*.attn_factor =*/ attn_factor, + /*.beta_fast =*/ beta_fast, + /*.beta_slow =*/ beta_slow, ++ /*.sections =*/ {sections[0], sections[1], sections[2], sections[3]} + }; + + [encoder setComputePipelineState:pipeline]; +diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal +index 204c93e6..67b3240f 100644 +--- a/ggml/src/ggml-metal/ggml-metal.metal ++++ b/ggml/src/ggml-metal/ggml-metal.metal +@@ -2568,8 +2568,148 @@ kernel void kernel_rope_neox( + } + } + ++ ++template ++kernel void kernel_rope_multi( ++ constant ggml_metal_kargs_rope & args, ++ device const char * src0, ++ device const char * src1, ++ device const char * src2, ++ device char * dst, ++ ushort tiitg[[thread_index_in_threadgroup]], ++ ushort3 tptg [[threads_per_threadgroup]], ++ uint3 tgpig[[threadgroup_position_in_grid]]) { ++ const int i3 = tgpig[2]; ++ const int i2 = tgpig[1]; ++ const int i1 = tgpig[0]; ++ ++ float corr_dims[2]; ++ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); ++ ++ device const int32_t * pos = (device const int32_t *) src1; ++ ++ int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3]; ++ int sec_w = args.sections[1] + args.sections[0]; ++ ++ const float inv_ndims = -1.f/args.n_dims; ++ ++ float cos_theta; ++ float sin_theta; ++ ++ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { ++ if (i0 < args.n_dims) { ++ const int ic = i0/2; ++ const int sector = ic % sect_dims; ++ ++ float theta_base = (float) pos[i2]; ++ if (sector >= args.sections[0] && sector < sec_w) { ++ theta_base = (float) pos[i2 + args.ne2]; ++ } ++ else if (sector >= sec_w && sector < sec_w + args.sections[2]) { ++ theta_base = (float) pos[i2 + args.ne2 * 2]; ++ } ++ else if (sector >= sec_w + args.sections[2]) { ++ theta_base = (float) pos[i2 + args.ne2 * 3]; ++ } ++ ++ float theta = theta_base*pow(args.freq_base, inv_ndims*i0); ++ ++ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; ++ ++ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); ++ ++ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); ++ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); ++ ++ const float x0 = src[0]; ++ const float x1 = src[args.n_dims/2]; ++ ++ dst_data[0] = x0*cos_theta - x1*sin_theta; ++ dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta; ++ } else { ++ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00); ++ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0); ++ ++ dst_data[0] = src[0]; ++ dst_data[1] = src[1]; ++ } ++ } ++} ++ ++template ++kernel void kernel_rope_vision( ++ constant ggml_metal_kargs_rope & args, ++ device const char * src0, ++ device const char * src1, ++ device const char * src2, ++ device char * dst, ++ ushort tiitg[[thread_index_in_threadgroup]], ++ ushort3 tptg [[threads_per_threadgroup]], ++ uint3 tgpig[[threadgroup_position_in_grid]]) { ++ const int i3 = tgpig[2]; ++ const int i2 = tgpig[1]; ++ const int i1 = tgpig[0]; ++ ++ float corr_dims[2]; ++ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims); ++ ++ device const int32_t * pos = (device const int32_t *) src1; ++ ++ int sect_dims = args.sections[0] + args.sections[1]; ++ int sec_w = args.sections[1] + args.sections[0]; ++ int sec_e = args.sections[2] + sec_w; ++ ++ const float inv_ndims = -1.f/args.n_dims; ++ ++ float cos_theta; ++ float sin_theta; ++ ++ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) { ++ const int ic = i0/2; ++ const int sector = ic % sect_dims; ++ ++ float theta_base = (float) pos[i2]; ++ if (sector >= args.sections[0] && sector < sec_w) { ++ theta_base = (float) pos[i2 + args.ne2]; ++ } ++ else if (sector >= sec_w && sector < sec_w + args.sections[2]) { ++ theta_base = (float) pos[i2 + args.ne2 * 2]; ++ } ++ else if (sector >= sec_w + args.sections[2]) { ++ theta_base = (float) pos[i2 + args.ne2 * 3]; ++ } ++ ++ int p = sector; ++ if (sector >= sec_w + args.sections[2]) { ++ p = sector - (sec_w + args.sections[2]); ++ } else if (sector >= sec_w) { ++ p = sector - sec_w; ++ } else if (sector >= args.sections[0]) { ++ p = sector - args.sections[0]; ++ } ++ ++ const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p); ++ ++ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f; ++ ++ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta); ++ ++ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00); ++ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0); ++ ++ const float x0 = src[0]; ++ const float x1 = src[args.n_dims]; ++ ++ dst_data[0] = x0*cos_theta - x1*sin_theta; ++ dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta; ++ } ++} ++ ++ + typedef decltype(kernel_rope_norm) kernel_rope_norm_t; + typedef decltype(kernel_rope_neox) kernel_rope_neox_t; ++typedef decltype(kernel_rope_multi) kernel_rope_multi_t; ++typedef decltype(kernel_rope_vision) kernel_rope_vision_t; + + template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm; + template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm; +@@ -2577,6 +2717,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_ + template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox; + template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox; + ++template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi; ++template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi; ++ ++template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision; ++template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision; ++ + typedef void (im2col_t)( + device const float * x, + device char * dst,