mirror of
https://github.com/ollama/ollama.git
synced 2025-04-11 05:09:45 +02:00
llama: add qwen2vl support
This commit is contained in:
parent
8c9fb8eb73
commit
80d41b579b
147
llama/ggml-metal-embed.metal
vendored
147
llama/ggml-metal-embed.metal
vendored
@ -2081,6 +2081,7 @@ typedef struct {
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
int32_t sections[4];
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
@ -4785,8 +4786,148 @@ kernel void kernel_rope_neox(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_multi(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3];
|
||||
int sec_w = args.sections[1] + args.sections[0];
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
if (sector >= args.sections[0] && sector < sec_w) {
|
||||
theta_base = (float) pos[i2 + args.ne2];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 3];
|
||||
}
|
||||
|
||||
float theta = theta_base*pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_vision(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
int sect_dims = args.sections[0] + args.sections[1];
|
||||
int sec_w = args.sections[1] + args.sections[0];
|
||||
int sec_e = args.sections[2] + sec_w;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
const int ic = i0/2;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
if (sector >= args.sections[0] && sector < sec_w) {
|
||||
theta_base = (float) pos[i2 + args.ne2];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 3];
|
||||
}
|
||||
|
||||
int p = sector;
|
||||
if (sector >= sec_w + args.sections[2]) {
|
||||
p = sector - (sec_w + args.sections[2]);
|
||||
} else if (sector >= sec_w) {
|
||||
p = sector - sec_w;
|
||||
} else if (sector >= args.sections[0]) {
|
||||
p = sector - args.sections[0];
|
||||
}
|
||||
|
||||
const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
@ -4794,6 +4935,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
|
||||
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
|
||||
typedef void (im2col_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
|
1
llama/ggml-metal-impl.h
vendored
1
llama/ggml-metal-impl.h
vendored
@ -169,6 +169,7 @@ typedef struct {
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
int32_t sections[4];
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
|
146
llama/ggml-metal.metal
vendored
146
llama/ggml-metal.metal
vendored
@ -2594,8 +2594,148 @@ kernel void kernel_rope_neox(
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_multi(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3];
|
||||
int sec_w = args.sections[1] + args.sections[0];
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
if (i0 < args.n_dims) {
|
||||
const int ic = i0/2;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
if (sector >= args.sections[0] && sector < sec_w) {
|
||||
theta_base = (float) pos[i2 + args.ne2];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 3];
|
||||
}
|
||||
|
||||
float theta = theta_base*pow(args.freq_base, inv_ndims*i0);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims/2];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
} else {
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
|
||||
dst_data[0] = src[0];
|
||||
dst_data[1] = src[1];
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
template<typename T>
|
||||
kernel void kernel_rope_vision(
|
||||
constant ggml_metal_kargs_rope & args,
|
||||
device const char * src0,
|
||||
device const char * src1,
|
||||
device const char * src2,
|
||||
device char * dst,
|
||||
ushort tiitg[[thread_index_in_threadgroup]],
|
||||
ushort3 tptg [[threads_per_threadgroup]],
|
||||
uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
const int i3 = tgpig[2];
|
||||
const int i2 = tgpig[1];
|
||||
const int i1 = tgpig[0];
|
||||
|
||||
float corr_dims[2];
|
||||
rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
|
||||
device const int32_t * pos = (device const int32_t *) src1;
|
||||
|
||||
int sect_dims = args.sections[0] + args.sections[1];
|
||||
int sec_w = args.sections[1] + args.sections[0];
|
||||
int sec_e = args.sections[2] + sec_w;
|
||||
|
||||
const float inv_ndims = -1.f/args.n_dims;
|
||||
|
||||
float cos_theta;
|
||||
float sin_theta;
|
||||
|
||||
for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
const int ic = i0/2;
|
||||
const int sector = ic % sect_dims;
|
||||
|
||||
float theta_base = (float) pos[i2];
|
||||
if (sector >= args.sections[0] && sector < sec_w) {
|
||||
theta_base = (float) pos[i2 + args.ne2];
|
||||
}
|
||||
else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 2];
|
||||
}
|
||||
else if (sector >= sec_w + args.sections[2]) {
|
||||
theta_base = (float) pos[i2 + args.ne2 * 3];
|
||||
}
|
||||
|
||||
int p = sector;
|
||||
if (sector >= sec_w + args.sections[2]) {
|
||||
p = sector - (sec_w + args.sections[2]);
|
||||
} else if (sector >= sec_w) {
|
||||
p = sector - sec_w;
|
||||
} else if (sector >= args.sections[0]) {
|
||||
p = sector - args.sections[0];
|
||||
}
|
||||
|
||||
const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p);
|
||||
|
||||
const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
|
||||
rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
|
||||
device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
|
||||
const float x0 = src[0];
|
||||
const float x1 = src[args.n_dims];
|
||||
|
||||
dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
@ -2603,6 +2743,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
|
||||
template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
|
||||
typedef void (im2col_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
||||
|
54
llama/ggml-metal_darwin_arm64.m
vendored
54
llama/ggml-metal_darwin_arm64.m
vendored
@ -328,6 +328,10 @@ enum ggml_metal_kernel_type {
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
||||
@ -928,6 +932,10 @@ static struct ggml_backend_metal_context * ggml_metal_init(ggml_backend_dev_t de
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
||||
@ -1155,16 +1163,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_NORM:
|
||||
return true;
|
||||
case GGML_OP_ROPE:
|
||||
{
|
||||
const int mode = ((const int32_t *) op->op_params)[2];
|
||||
if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
return false;
|
||||
}
|
||||
if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
return false;
|
||||
}
|
||||
return true;
|
||||
}
|
||||
return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
@ -3083,6 +3082,7 @@ static void ggml_metal_encode_node(
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
int32_t sections[4];
|
||||
|
||||
memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
|
||||
@ -3090,21 +3090,44 @@ static void ggml_metal_encode_node(
|
||||
memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
||||
memcpy(§ions, (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
|
||||
if (is_mrope) {
|
||||
GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
}
|
||||
|
||||
if (is_vision) {
|
||||
GGML_ASSERT(n_dims == ne00/2);
|
||||
}
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
if (!is_neox) {
|
||||
if (is_neox) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else if (is_mrope && !is_vision) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else if (is_vision) {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else {
|
||||
switch (src0->type) {
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
}
|
||||
@ -3135,6 +3158,7 @@ static void ggml_metal_encode_node(
|
||||
/*.attn_factor =*/ attn_factor,
|
||||
/*.beta_fast =*/ beta_fast,
|
||||
/*.beta_slow =*/ beta_slow,
|
||||
/*.sections =*/ {sections[0], sections[1], sections[2], sections[3]}
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
|
299
llama/patches/0014-qwen2vl-support.patch
Normal file
299
llama/patches/0014-qwen2vl-support.patch
Normal file
@ -0,0 +1,299 @@
|
||||
From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
|
||||
From: jmorganca <jmorganca@gmail.com>
|
||||
Date: Sun, 15 Dec 2024 23:56:24 -0800
|
||||
Subject: [PATCH] qwen2vl support
|
||||
|
||||
---
|
||||
ggml/src/ggml-metal/ggml-metal-impl.h | 1 +
|
||||
ggml/src/ggml-metal/ggml-metal.m | 54 +++++++---
|
||||
ggml/src/ggml-metal/ggml-metal.metal | 146 ++++++++++++++++++++++++++
|
||||
3 files changed, 186 insertions(+), 15 deletions(-)
|
||||
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal-impl.h b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
index e3dc25f1..766a4999 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal-impl.h
|
||||
@@ -143,6 +143,7 @@ typedef struct {
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
+ int32_t sections[4];
|
||||
} ggml_metal_kargs_rope;
|
||||
|
||||
typedef struct {
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.m b/ggml/src/ggml-metal/ggml-metal.m
|
||||
index 787fc713..806c9fd3 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.m
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.m
|
||||
@@ -302,6 +302,10 @@ static void ggml_backend_metal_device_rel(struct ggml_backend_metal_device_conte
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32,
|
||||
GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16,
|
||||
+ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32,
|
||||
+ GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16,
|
||||
+ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32,
|
||||
+ GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F16,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_F32,
|
||||
GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16,
|
||||
@@ -902,6 +906,10 @@ @implementation GGMLMetalClass
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16, rope_norm_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32, rope_neox_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16, rope_neox_f16, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32, rope_multi_f32, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16, rope_multi_f16, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32, rope_vision_f32, true);
|
||||
+ GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16, rope_vision_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F16, im2col_f16, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_F32, im2col_f32, true);
|
||||
GGML_METAL_ADD_KERNEL(GGML_METAL_KERNEL_TYPE_IM2COL_EXT_F16, im2col_ext_f16, true);
|
||||
@@ -1129,16 +1137,7 @@ static bool ggml_metal_supports_op(const struct ggml_backend_metal_device_contex
|
||||
case GGML_OP_NORM:
|
||||
return true;
|
||||
case GGML_OP_ROPE:
|
||||
- {
|
||||
- const int mode = ((const int32_t *) op->op_params)[2];
|
||||
- if (mode & GGML_ROPE_TYPE_MROPE) {
|
||||
- return false;
|
||||
- }
|
||||
- if (mode & GGML_ROPE_TYPE_VISION) {
|
||||
- return false;
|
||||
- }
|
||||
- return true;
|
||||
- }
|
||||
+ return true;
|
||||
case GGML_OP_IM2COL:
|
||||
return op->src[0]->type == GGML_TYPE_F16;
|
||||
case GGML_OP_POOL_1D:
|
||||
@@ -3057,6 +3056,7 @@ static void ggml_metal_encode_node(
|
||||
float attn_factor;
|
||||
float beta_fast;
|
||||
float beta_slow;
|
||||
+ int32_t sections[4];
|
||||
|
||||
memcpy(&freq_base, (const int32_t *) dst->op_params + 5, sizeof(float));
|
||||
memcpy(&freq_scale, (const int32_t *) dst->op_params + 6, sizeof(float));
|
||||
@@ -3064,21 +3064,44 @@ static void ggml_metal_encode_node(
|
||||
memcpy(&attn_factor, (const int32_t *) dst->op_params + 8, sizeof(float));
|
||||
memcpy(&beta_fast, (const int32_t *) dst->op_params + 9, sizeof(float));
|
||||
memcpy(&beta_slow, (const int32_t *) dst->op_params + 10, sizeof(float));
|
||||
+ memcpy(§ions, (const int32_t *) dst->op_params + 11, sizeof(int32_t)*4);
|
||||
|
||||
const bool is_neox = mode & GGML_ROPE_TYPE_NEOX;
|
||||
+ const bool is_mrope = mode & GGML_ROPE_TYPE_MROPE;
|
||||
+ const bool is_vision = mode == GGML_ROPE_TYPE_VISION;
|
||||
+
|
||||
+ if (is_mrope) {
|
||||
+ GGML_ASSERT(sections[0] > 0 || sections[1] > 0 || sections[2] > 0);
|
||||
+ }
|
||||
+
|
||||
+ if (is_vision) {
|
||||
+ GGML_ASSERT(n_dims == ne00/2);
|
||||
+ }
|
||||
|
||||
id<MTLComputePipelineState> pipeline = nil;
|
||||
|
||||
- if (!is_neox) {
|
||||
+ if (is_neox) {
|
||||
switch (src0->type) {
|
||||
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
+ default: GGML_ABORT("fatal error");
|
||||
+ };
|
||||
+ } else if (is_mrope && !is_vision) {
|
||||
+ switch (src0->type) {
|
||||
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F32].pipeline; break;
|
||||
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_MULTI_F16].pipeline; break;
|
||||
+ default: GGML_ABORT("fatal error");
|
||||
+ };
|
||||
+ } else if (is_vision) {
|
||||
+ switch (src0->type) {
|
||||
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F32].pipeline; break;
|
||||
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_VISION_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
} else {
|
||||
switch (src0->type) {
|
||||
- case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F32].pipeline; break;
|
||||
- case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NEOX_F16].pipeline; break;
|
||||
+ case GGML_TYPE_F32: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F32].pipeline; break;
|
||||
+ case GGML_TYPE_F16: pipeline = ctx->kernels[GGML_METAL_KERNEL_TYPE_ROPE_NORM_F16].pipeline; break;
|
||||
default: GGML_ABORT("fatal error");
|
||||
};
|
||||
}
|
||||
@@ -3109,6 +3132,7 @@ static void ggml_metal_encode_node(
|
||||
/*.attn_factor =*/ attn_factor,
|
||||
/*.beta_fast =*/ beta_fast,
|
||||
/*.beta_slow =*/ beta_slow,
|
||||
+ /*.sections =*/ {sections[0], sections[1], sections[2], sections[3]}
|
||||
};
|
||||
|
||||
[encoder setComputePipelineState:pipeline];
|
||||
diff --git a/ggml/src/ggml-metal/ggml-metal.metal b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
index 204c93e6..67b3240f 100644
|
||||
--- a/ggml/src/ggml-metal/ggml-metal.metal
|
||||
+++ b/ggml/src/ggml-metal/ggml-metal.metal
|
||||
@@ -2568,8 +2568,148 @@ kernel void kernel_rope_neox(
|
||||
}
|
||||
}
|
||||
|
||||
+
|
||||
+template<typename T>
|
||||
+kernel void kernel_rope_multi(
|
||||
+ constant ggml_metal_kargs_rope & args,
|
||||
+ device const char * src0,
|
||||
+ device const char * src1,
|
||||
+ device const char * src2,
|
||||
+ device char * dst,
|
||||
+ ushort tiitg[[thread_index_in_threadgroup]],
|
||||
+ ushort3 tptg [[threads_per_threadgroup]],
|
||||
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
+ const int i3 = tgpig[2];
|
||||
+ const int i2 = tgpig[1];
|
||||
+ const int i1 = tgpig[0];
|
||||
+
|
||||
+ float corr_dims[2];
|
||||
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
+
|
||||
+ device const int32_t * pos = (device const int32_t *) src1;
|
||||
+
|
||||
+ int sect_dims = args.sections[0] + args.sections[1] + args.sections[2] + args.sections[3];
|
||||
+ int sec_w = args.sections[1] + args.sections[0];
|
||||
+
|
||||
+ const float inv_ndims = -1.f/args.n_dims;
|
||||
+
|
||||
+ float cos_theta;
|
||||
+ float sin_theta;
|
||||
+
|
||||
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
+ if (i0 < args.n_dims) {
|
||||
+ const int ic = i0/2;
|
||||
+ const int sector = ic % sect_dims;
|
||||
+
|
||||
+ float theta_base = (float) pos[i2];
|
||||
+ if (sector >= args.sections[0] && sector < sec_w) {
|
||||
+ theta_base = (float) pos[i2 + args.ne2];
|
||||
+ }
|
||||
+ else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
|
||||
+ theta_base = (float) pos[i2 + args.ne2 * 2];
|
||||
+ }
|
||||
+ else if (sector >= sec_w + args.sections[2]) {
|
||||
+ theta_base = (float) pos[i2 + args.ne2 * 3];
|
||||
+ }
|
||||
+
|
||||
+ float theta = theta_base*pow(args.freq_base, inv_ndims*i0);
|
||||
+
|
||||
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
+
|
||||
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
+
|
||||
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
+
|
||||
+ const float x0 = src[0];
|
||||
+ const float x1 = src[args.n_dims/2];
|
||||
+
|
||||
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
+ dst_data[args.n_dims/2] = x0*sin_theta + x1*cos_theta;
|
||||
+ } else {
|
||||
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + i0*args.nb00);
|
||||
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + i0*args.nb0);
|
||||
+
|
||||
+ dst_data[0] = src[0];
|
||||
+ dst_data[1] = src[1];
|
||||
+ }
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+template<typename T>
|
||||
+kernel void kernel_rope_vision(
|
||||
+ constant ggml_metal_kargs_rope & args,
|
||||
+ device const char * src0,
|
||||
+ device const char * src1,
|
||||
+ device const char * src2,
|
||||
+ device char * dst,
|
||||
+ ushort tiitg[[thread_index_in_threadgroup]],
|
||||
+ ushort3 tptg [[threads_per_threadgroup]],
|
||||
+ uint3 tgpig[[threadgroup_position_in_grid]]) {
|
||||
+ const int i3 = tgpig[2];
|
||||
+ const int i2 = tgpig[1];
|
||||
+ const int i1 = tgpig[0];
|
||||
+
|
||||
+ float corr_dims[2];
|
||||
+ rope_yarn_corr_dims(args.n_dims, args.n_ctx_orig, args.freq_base, args.beta_fast, args.beta_slow, corr_dims);
|
||||
+
|
||||
+ device const int32_t * pos = (device const int32_t *) src1;
|
||||
+
|
||||
+ int sect_dims = args.sections[0] + args.sections[1];
|
||||
+ int sec_w = args.sections[1] + args.sections[0];
|
||||
+ int sec_e = args.sections[2] + sec_w;
|
||||
+
|
||||
+ const float inv_ndims = -1.f/args.n_dims;
|
||||
+
|
||||
+ float cos_theta;
|
||||
+ float sin_theta;
|
||||
+
|
||||
+ for (int i0 = 2*tiitg; i0 < args.ne0; i0 += 2*tptg.x) {
|
||||
+ const int ic = i0/2;
|
||||
+ const int sector = ic % sect_dims;
|
||||
+
|
||||
+ float theta_base = (float) pos[i2];
|
||||
+ if (sector >= args.sections[0] && sector < sec_w) {
|
||||
+ theta_base = (float) pos[i2 + args.ne2];
|
||||
+ }
|
||||
+ else if (sector >= sec_w && sector < sec_w + args.sections[2]) {
|
||||
+ theta_base = (float) pos[i2 + args.ne2 * 2];
|
||||
+ }
|
||||
+ else if (sector >= sec_w + args.sections[2]) {
|
||||
+ theta_base = (float) pos[i2 + args.ne2 * 3];
|
||||
+ }
|
||||
+
|
||||
+ int p = sector;
|
||||
+ if (sector >= sec_w + args.sections[2]) {
|
||||
+ p = sector - (sec_w + args.sections[2]);
|
||||
+ } else if (sector >= sec_w) {
|
||||
+ p = sector - sec_w;
|
||||
+ } else if (sector >= args.sections[0]) {
|
||||
+ p = sector - args.sections[0];
|
||||
+ }
|
||||
+
|
||||
+ const float theta = theta_base*pow(args.freq_base, inv_ndims*2.0f*p);
|
||||
+
|
||||
+ const float freq_factor = src2 != src0 ? ((device const float *) src2)[ic] : 1.0f;
|
||||
+
|
||||
+ rope_yarn(theta/freq_factor, args.freq_scale, corr_dims, i0, args.ext_factor, args.attn_factor, &cos_theta, &sin_theta);
|
||||
+
|
||||
+ device const T * const src = (device T *)(src0 + i3*args.nb03 + i2*args.nb02 + i1*args.nb01 + ic*args.nb00);
|
||||
+ device T * dst_data = (device T *)( dst + i3*args.nb3 + i2*args.nb2 + i1*args.nb1 + ic*args.nb0);
|
||||
+
|
||||
+ const float x0 = src[0];
|
||||
+ const float x1 = src[args.n_dims];
|
||||
+
|
||||
+ dst_data[0] = x0*cos_theta - x1*sin_theta;
|
||||
+ dst_data[args.n_dims] = x0*sin_theta + x1*cos_theta;
|
||||
+ }
|
||||
+}
|
||||
+
|
||||
+
|
||||
typedef decltype(kernel_rope_norm<float>) kernel_rope_norm_t;
|
||||
typedef decltype(kernel_rope_neox<float>) kernel_rope_neox_t;
|
||||
+typedef decltype(kernel_rope_multi<float>) kernel_rope_multi_t;
|
||||
+typedef decltype(kernel_rope_vision<float>) kernel_rope_vision_t;
|
||||
|
||||
template [[host_name("kernel_rope_norm_f32")]] kernel kernel_rope_norm_t kernel_rope_norm<float>;
|
||||
template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_rope_norm<half>;
|
||||
@@ -2577,6 +2717,12 @@ template [[host_name("kernel_rope_norm_f16")]] kernel kernel_rope_norm_t kernel_
|
||||
template [[host_name("kernel_rope_neox_f32")]] kernel kernel_rope_neox_t kernel_rope_neox<float>;
|
||||
template [[host_name("kernel_rope_neox_f16")]] kernel kernel_rope_neox_t kernel_rope_neox<half>;
|
||||
|
||||
+template [[host_name("kernel_rope_multi_f32")]] kernel kernel_rope_multi_t kernel_rope_multi<float>;
|
||||
+template [[host_name("kernel_rope_multi_f16")]] kernel kernel_rope_multi_t kernel_rope_multi<half>;
|
||||
+
|
||||
+template [[host_name("kernel_rope_vision_f32")]] kernel kernel_rope_vision_t kernel_rope_vision<float>;
|
||||
+template [[host_name("kernel_rope_vision_f16")]] kernel kernel_rope_vision_t kernel_rope_vision<half>;
|
||||
+
|
||||
typedef void (im2col_t)(
|
||||
device const float * x,
|
||||
device char * dst,
|
Loading…
x
Reference in New Issue
Block a user