From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001
From: Michael Yang
Date: Thu, 1 May 2025 13:45:12 -0700
Subject: [PATCH] add argsort and cuda copy for i32

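Add an I32 argsort implementation for the CPU backend, a CUDA bitonic
argsort kernel for I32 inputs, and a CUDA I32 -> I32 copy kernel, and
register them in the corresponding argsort and cpy dispatch paths.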
---
 ggml/src/ggml-cpu/ops.cpp        |  43 +++++++++++++
 ggml/src/ggml-cuda/argsort.cu    | 102 ++++++++++++++++++++++++++++++-
 ggml/src/ggml-cuda/cpy-utils.cuh |   6 ++
 ggml/src/ggml-cuda/cpy.cu        |  43 +++++++++++++
 4 files changed, 192 insertions(+), 2 deletions(-)

diff --git a/ggml/src/ggml-cpu/ops.cpp b/ggml/src/ggml-cpu/ops.cpp
index 854f1c2b..a2924757 100644
--- a/ggml/src/ggml-cpu/ops.cpp
+++ b/ggml/src/ggml-cpu/ops.cpp
@@ -8146,6 +8146,45 @@ static void ggml_compute_forward_argsort_f32(
     }
 }
 
+static void ggml_compute_forward_argsort_i32(
+        const ggml_compute_params * params,
+        ggml_tensor * dst) {
+
+    const ggml_tensor * src0 = dst->src[0];
+
+    GGML_TENSOR_UNARY_OP_LOCALS
+
+    GGML_ASSERT(nb0 == sizeof(int32_t));
+
+    const int ith = params->ith;
+    const int nth = params->nth;
+
+    const int64_t nr = ggml_nrows(src0);
+
+    ggml_sort_order order = (ggml_sort_order) ggml_get_op_params_i32(dst, 0);
+
+    for (int64_t i = ith; i < nr; i += nth) {
+        int32_t * dst_data = (int32_t *)((char *) dst->data + i*nb1);
+        const int32_t * src_data = (int32_t *)((char *) src0->data + i*nb01);
+
+        for (int64_t j = 0; j < ne0; j++) {
+            dst_data[j] = j;
+        }
+
+        // C doesn't have a functional sort, so we do a bubble sort instead
+        for (int64_t j = 0; j < ne0; j++) {
+            for (int64_t k = j + 1; k < ne0; k++) {
+                if ((order == GGML_SORT_ORDER_ASC  && src_data[dst_data[j]] > src_data[dst_data[k]]) ||
+                    (order == GGML_SORT_ORDER_DESC && src_data[dst_data[j]] < src_data[dst_data[k]])) {
+                    int32_t tmp = dst_data[j];
+                    dst_data[j] = dst_data[k];
+                    dst_data[k] = tmp;
+                }
+            }
+        }
+    }
+}
+
 void ggml_compute_forward_argsort(
         const ggml_compute_params * params,
         ggml_tensor * dst) {
@@ -8157,6 +8196,10 @@ void ggml_compute_forward_argsort(
             {
                 ggml_compute_forward_argsort_f32(params, dst);
             } break;
+        case GGML_TYPE_I32:
+            {
+                ggml_compute_forward_argsort_i32(params, dst);
+            } break;
         default:
             {
                 GGML_ABORT("fatal error");
diff --git a/ggml/src/ggml-cuda/argsort.cu b/ggml/src/ggml-cuda/argsort.cu
index 607ded85..53b02634 100644
--- a/ggml/src/ggml-cuda/argsort.cu
+++ b/ggml/src/ggml-cuda/argsort.cu
@@ -85,13 +85,107 @@ static void argsort_f32_i32_cuda(const float * x, int * dst, const int ncols, co
     }
 }
 
+
+template<ggml_sort_order order>
+static __global__ void k_argsort_i32_i32(const int32_t * x, int * dst, const int ncols, const int ncols_pad) {
+    extern __shared__ int shared_mem[];
+    int * indices = shared_mem;
+
+    const int tid = threadIdx.x;
+    const int row = blockIdx.y;
+
+    // Initialize all indices, handling the case where threads < ncols_pad
+    for (int i = tid; i < ncols_pad; i += blockDim.x) {
+        indices[i] = i < ncols ? i : 0; // Use 0 for padding indices
+    }
+    __syncthreads();
+
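+    // The loop nest below is a standard in-place bitonic sorting network over the
+    // ncols_pad index slots; compare/exchange steps are skipped when either slot
+    // falls in the padding region, and only the first ncols indices of each row
+    // are written back to dst.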
+    // Bitonic sort
+    for (int k = 2; k <= ncols_pad; k *= 2) {
+        for (int j = k/2; j > 0; j /= 2) {
+            for (int i = tid; i < ncols_pad; i += blockDim.x) {
+                const int ij = i ^ j;
+                if (ij > i) {
+                    // Only compare values within the actual data range
+                    if (i < ncols && ij < ncols) {
+                        if ((i & k) == 0) {
+                            if (order == GGML_SORT_ORDER_ASC) {
+                                if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+                                    int tmp = indices[i];
+                                    indices[i] = indices[ij];
+                                    indices[ij] = tmp;
+                                }
+                            } else {
+                                if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+                                    int tmp = indices[i];
+                                    indices[i] = indices[ij];
+                                    indices[ij] = tmp;
+                                }
+                            }
+                        } else {
+                            if (order == GGML_SORT_ORDER_ASC) {
+                                if (x[row * ncols + indices[i]] < x[row * ncols + indices[ij]]) {
+                                    int tmp = indices[i];
+                                    indices[i] = indices[ij];
+                                    indices[ij] = tmp;
+                                }
+                            } else {
+                                if (x[row * ncols + indices[i]] > x[row * ncols + indices[ij]]) {
+                                    int tmp = indices[i];
+                                    indices[i] = indices[ij];
+                                    indices[ij] = tmp;
+                                }
+                            }
+                        }
+                    }
+                }
+            }
+            __syncthreads();
+        }
+    }
+
+    // Write sorted indices to output, only threads handling valid data
+    for (int i = tid; i < ncols; i += blockDim.x) {
+        dst[row * ncols + i] = indices[i];
+    }
+}
+
+static void argsort_i32_i32_cuda(const int32_t * x, int * dst, const int ncols, const int nrows, ggml_sort_order order, cudaStream_t stream) {
+    // Bitonic sort requires ncols to be power of 2
+    const int ncols_pad = next_power_of_2(ncols);
+
+    // Ensure thread count doesn't exceed maximum (typically 1024)
+    const int max_threads = 1024; // This is the typical max for most GPUs
+    const int threads_per_block = ncols_pad > max_threads ? max_threads : ncols_pad;
+
+    const dim3 block_dims(threads_per_block, 1, 1);
+    const dim3 block_nums(1, nrows, 1);
+    const size_t shared_mem = ncols_pad * sizeof(int);
+
+    // Check if shared memory size is within limits
+    const size_t max_shared_mem = ggml_cuda_info().devices[ggml_cuda_get_device()].smpb;
+
+    // Instead of logging an error, use GGML_ASSERT with a descriptive message
+    GGML_ASSERT(shared_mem <= max_shared_mem && "argsort: required shared memory exceeds device limit");
+
+    // Launch kernels with the updated thread configuration
+    if (order == GGML_SORT_ORDER_ASC) {
+        k_argsort_i32_i32<GGML_SORT_ORDER_ASC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+    } else if (order == GGML_SORT_ORDER_DESC) {
+        k_argsort_i32_i32<GGML_SORT_ORDER_DESC><<<block_nums, block_dims, shared_mem, stream>>>(x, dst, ncols, ncols_pad);
+    } else {
+        GGML_ABORT("fatal error");
+    }
+}
+
+
 void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
     const ggml_tensor * src0 = dst->src[0];
     const float * src0_d = (const float *)src0->data;
     float * dst_d = (float *)dst->data;
     cudaStream_t stream = ctx.stream();
 
-    GGML_ASSERT(src0->type == GGML_TYPE_F32);
+    GGML_ASSERT(src0->type == GGML_TYPE_F32 || src0->type == GGML_TYPE_I32);
     GGML_ASSERT( dst->type == GGML_TYPE_I32);
     GGML_ASSERT(ggml_is_contiguous(src0));
 
@@ -100,5 +194,9 @@ void ggml_cuda_op_argsort(ggml_backend_cuda_context & ctx, ggml_tensor * dst) {
 
     enum ggml_sort_order order = (enum ggml_sort_order) dst->op_params[0];
 
-    argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+    if (src0->type == GGML_TYPE_I32) {
+        argsort_i32_i32_cuda((const int32_t *)src0_d, (int *)dst_d, ncols, nrows, order, stream);
+    } else {
+        argsort_f32_i32_cuda(src0_d, (int *)dst_d, ncols, nrows, order, stream);
+    }
 }
diff --git a/ggml/src/ggml-cuda/cpy-utils.cuh b/ggml/src/ggml-cuda/cpy-utils.cuh
index 410c12b7..b8e9e107 100644
--- a/ggml/src/ggml-cuda/cpy-utils.cuh
+++ b/ggml/src/ggml-cuda/cpy-utils.cuh
@@ -223,3 +223,9 @@ template<typename src_t, typename dst_t>
 static __device__ void cpy_1_flt(const char * cxi, char * cdsti) {
     convert_flt((const src_t *)cxi, (dst_t *)cdsti);
 }
+
+static __device__ void cpy_1_i32_i32(const char * cxi, char * cdsti) {
+    const int32_t * src = (const int32_t *)cxi;
+    int32_t * dst = (int32_t *)cdsti;
+    *dst = *src;
+}
diff --git a/ggml/src/ggml-cuda/cpy.cu b/ggml/src/ggml-cuda/cpy.cu
index f9bb0256..9c3774e5 100644
--- a/ggml/src/ggml-cuda/cpy.cu
+++ b/ggml/src/ggml-cuda/cpy.cu
@@ -278,6 +278,47 @@ static void ggml_cpy_f32_iq4_nl_cuda(
         (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index++);
 }
 
+template <cpy_kernel_t cpy_1>
+static __global__ void cpy_i32_i32(
+    const char *cx, char *cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    char ** cdst_indirect, int graph_cpynode_index) {
+
+    const int64_t i = blockDim.x * blockIdx.x + threadIdx.x;
+
+    if (i >= ne) {
+        return;
+    }
+
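+    // Decompose the flat element index into 4-D coordinates separately for the
+    // source and the destination so that differing strides (nb00..nb03 vs
+    // nb10..nb13) are handled correctly.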
+    const int64_t i03 = i / (ne00 * ne01 * ne02);
+    const int64_t i02 = (i - i03 * ne00 * ne01 * ne02) / (ne00 * ne01);
+    const int64_t i01 = (i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00) / ne00;
+    const int64_t i00 = i - i03 * ne00 * ne01 * ne02 - i02 * ne01 * ne00 - i01 * ne00;
+    const int64_t x_offset = i00 * nb00 + i01 * nb01 + i02 * nb02 + i03 * nb03;
+
+    const int64_t i13 = i / (ne10 * ne11 * ne12);
+    const int64_t i12 = (i - i13 * ne10 * ne11 * ne12) / (ne10 * ne11);
+    const int64_t i11 = (i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11) / ne10;
+    const int64_t i10 = i - i13 * ne10 * ne11 * ne12 - i12 * ne10 * ne11 - i11 * ne10;
+    const int64_t dst_offset = i10 * nb10 + i11 * nb11 + i12 * nb12 + i13 * nb13;
+
+    char * cdst_ptr = (cdst_indirect != nullptr) ? cdst_indirect[graph_cpynode_index] : cdst;
+    cpy_1(cx + x_offset, cdst_ptr + dst_offset);
+}
+
+static void ggml_cpy_i32_i32_cuda(
+    const char * cx, char * cdst, const int ne,
+    const int ne00, const int ne01, const int ne02, const int nb00, const int nb01, const int nb02, const int nb03,
+    const int ne10, const int ne11, const int ne12, const int nb10, const int nb11, const int nb12, const int nb13,
+    cudaStream_t stream, char ** cdst_indirect, int graph_cpynode_index) {
+
+    const int num_blocks = (ne + CUDA_CPY_BLOCK_SIZE - 1) / CUDA_CPY_BLOCK_SIZE;
+    cpy_i32_i32<cpy_1_i32_i32><<<num_blocks, CUDA_CPY_BLOCK_SIZE, 0, stream>>>
+        (cx, cdst, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, cdst_indirect, graph_cpynode_index);
+}
+
 void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, ggml_tensor * src1, bool disable_indirection_for_this_node) {
     const int64_t ne = ggml_nelements(src0);
     GGML_ASSERT(ne == ggml_nelements(src1));
@@ -369,6 +410,8 @@ void ggml_cuda_cpy(ggml_backend_cuda_context & ctx, const ggml_tensor * src0, gg
         ggml_cpy_flt_cuda<half, half> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_F16 && src1->type == GGML_TYPE_F32) {
         ggml_cpy_flt_cuda<half, float> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
+    } else if (src0->type == GGML_TYPE_I32 && src1->type == GGML_TYPE_I32) {
+        ggml_cpy_i32_i32_cuda(src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_BF16) {
         ggml_cpy_flt_cuda<nv_bfloat16, nv_bfloat16> (src0_ddc, src1_ddc, ne, ne00, ne01, ne02, nb00, nb01, nb02, nb03, ne10, ne11, ne12, nb10, nb11, nb12, nb13, main_stream, dest_ptrs_d, graph_cpynode_index);
     } else if (src0->type == GGML_TYPE_BF16 && src1->type == GGML_TYPE_F16) {