Mirror of https://github.com/ollama/ollama.git

Add MLX Backend POC

The cache still has some bugs.

parent 7b3c3135de
commit c8f346dc46
@@ -130,3 +130,8 @@ if(CMAKE_HIP_COMPILER)
  endforeach()
endif()
endif()

if(CMAKE_SYSTEM_NAME MATCHES "Darwin")
  message(STATUS "Setting up MLX (this takes a while...)")
  add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/mlx)
endif()
@@ -257,6 +257,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
    }

    if c.config.MaskDType != ml.DTypeF32 {
        // TODO - MLX not covered here...
        out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...)
        ctx.Forward(maskTensor.Copy(ctx, out))
        maskTensor = out

@@ -266,6 +267,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) {
}

func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) {
    // TODO this wont work on MLX as is - needs to be adjusted for SliceUpdate
    for i, key := range c.keys {
        if key == nil {
            continue
@@ -431,41 +433,48 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) {
    value := c.values[c.curLayer]

    kHeadDim := key.Dim(2)
    vHeadDim := value.Dim(2)
    numKVHeads := key.Dim(1)
    rowSize := key.Stride(0)
    cachedSize := c.curMask.Dim(1)
    // slog.Info("Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "rowSize", rowSize, "cachedSize", cachedSize)

    key = key.View(ctx, rowSize*c.curCellRange.min,
        []int{cachedSize, numKVHeads, kHeadDim},
        []int{key.Stride(0), key.Stride(1)},
    )
    // slog.Info("Get", "key", key)
    // panic("XXX")

    if c.config.PermutedV {
        vHeadDim := value.Dim(1)
        elemSize := value.Stride(2)

        value = value.View(ctx, elemSize*c.curCellRange.min,
            []int{numKVHeads, vHeadDim, cachedSize},
            []int{value.Stride(0), value.Stride(1)},
        )
    // Potential abstraction to work around differences in cache tensor handling.
    if su, ok := ctx.(ml.SliceUpdate); ok {
        start := []int{c.curCellRange.min, 0, 0}
        kStop := []int{c.curCellRange.min + cachedSize, numKVHeads, kHeadDim}
        vStop := []int{c.curCellRange.min + cachedSize, numKVHeads, vHeadDim}
        strides := []int{1, 1, 1}
        key = su.Slice(key, start, kStop, strides)
        value = su.Slice(value, start, vStop, strides)
    } else {
        vHeadDim := value.Dim(2)
        rowSize := value.Stride(0)

        value = value.View(ctx, rowSize*c.curCellRange.min,
            []int{cachedSize, numKVHeads, vHeadDim},
            []int{value.Stride(0), value.Stride(1)},
        key = key.View(ctx, rowSize*c.curCellRange.min,
            []int{cachedSize, numKVHeads, kHeadDim},
            []int{key.Stride(0), key.Stride(1)},
        )
    }

    // TODO The mask changes from X,X to 1,X, and with the Row-order change
    // the 1 becomes trailing and messes up later operations
    // This isn't the right solution, but works around it...
    if c.curMask.Dim(1) == 1 {
        return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3)
    if c.config.PermutedV {
        vHeadDim := value.Dim(1)
        elemSize := value.Stride(2)

        value = value.View(ctx, elemSize*c.curCellRange.min,
            []int{numKVHeads, vHeadDim, cachedSize},
            []int{value.Stride(0), value.Stride(1)},
        )
    } else {
        vHeadDim := value.Dim(2)
        rowSize := value.Stride(0)

        value = value.View(ctx, rowSize*c.curCellRange.min,
            []int{cachedSize, numKVHeads, vHeadDim},
            []int{value.Stride(0), value.Stride(1)},
        )
    }
    // TODO The mask changes from X,X to 1,X, and with the Row-order change
    // the 1 becomes trailing and messes up later operations
    // This isn't the right solution, but works around it...
    if c.curMask.Dim(1) == 1 {
        return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3)
    }
    }

    return key, value, c.curMask
@@ -495,20 +504,35 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) {
        } else {
            c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, vHeadDim)
        }
        // slog.Info("Cache Put", "c.keys[c.curLayer]", c.keys[c.curLayer])
        // slog.Info("Cache Put", "c.values[c.curLayer]", c.values[c.curLayer])
    }

    rowSize := c.keys[c.curLayer].Stride(0)
    ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil)))

    if c.config.PermutedV {
        elemSize := c.values[c.curLayer].Stride(2)

        value = value.Permute(ctx, 1, 2, 0, 3)
        ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize})))
    // Potential abstraction to work around differences in cache tensor handling.
    if su, ok := ctx.(ml.SliceUpdate); ok {
        start := []int{c.curLoc, 0, 0}
        kStop := []int{c.curLoc + batchSize, numKVHeads, kHeadDim}
        vStop := []int{c.curLoc + batchSize, numKVHeads, vHeadDim}
        strides := []int{1, 1, 1}
        su.SliceUpdate(c.keys[c.curLayer], key, start, kStop, strides)
        su.SliceUpdate(c.values[c.curLayer], value, start, vStop, strides)
        ctx.Forward(c.keys[c.curLayer])
        ctx.Forward(c.values[c.curLayer])
    } else {
        rowSize := c.values[c.curLayer].Stride(0)
        // GGML pattern
        rowSize := c.keys[c.curLayer].Stride(0)
        ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil)))

        ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil)))
        if c.config.PermutedV {
            elemSize := c.values[c.curLayer].Stride(2)

            value = value.Permute(ctx, 1, 2, 0, 3)
            ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize})))
        } else {
            rowSize := c.values[c.curLayer].Stride(0)

            ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil)))
        }
    }
}
@@ -565,6 +589,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error {
            continue
        }

        // TODO - this also needs adjusting to support MLX with SliceUpdate
        kHeadDim := key.Dim(2)
        numKVHeads := key.Dim(1)
        rowSize := key.Stride(0)
@@ -457,7 +457,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0
    panic("not implemented")
}

func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors, freqs ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor {
    panic("not implemented")
}
@@ -28,6 +28,7 @@ var errorPrefixes = []string{
    "error loading model",
    "GGML_ASSERT",
    "Deepseek2 does not support K-shift",
    "panic:",
}

func (w *StatusWriter) Write(b []byte) (int, error) {
@@ -4,6 +4,7 @@ import (
    "bytes"
    "encoding/binary"
    "fmt"
    "log/slog"
    "os"
    "slices"
    "strconv"
@@ -87,7 +88,13 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro
}

func NewBackend(f *os.File, params BackendParams) (Backend, error) {
    if backend, ok := backends["ggml"]; ok {
    be := os.Getenv("OLLAMA_BACKEND")
    if be == "" {
        be = "ggml"
        slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override")
    }
    slog.Info("Loading new engine", "backend", be)
    if backend, ok := backends[be]; ok {
        return backend(f, params)
    }
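The blank `_` imports added to ml/backend (further down in this commit) suggest each engine registers itself with this registry as an import side effect, and NewBackend then picks one by name, defaulting to "ggml" unless OLLAMA_BACKEND is set. A rough sketch of what that registration could look like for the MLX backend; only the RegisterBackend signature comes from the hunk above, and the package layout and New constructor shown here are assumptions, not part of the commit:

// Sketch only: side-effect registration of a backend under the name
// selected at runtime via OLLAMA_BACKEND=mlx.
package mlx

import (
    "fmt"
    "os"

    "github.com/ollama/ollama/ml"
)

func init() {
    ml.RegisterBackend("mlx", New)
}

// New is a placeholder constructor; the real one would live in ml/backend/mlx/mlx.go.
func New(f *os.File, params ml.BackendParams) (ml.Backend, error) {
    // ... build and return the MLX-backed ml.Backend here ...
    return nil, fmt.Errorf("sketch only")
}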
@@ -122,6 +129,18 @@ type Context interface {
    Abort(Tensor) // Evaluate the graph up to this point, retrieve the data from the tensor and dump it to a json file for comparison
}

// Usage:
//
//  if su, ok := ctx.(ml.SliceUpdate); ok {
//      su.SliceUpdate(...)
//  } else {
//      // view + copy operations
//  }
type SliceUpdate interface {
    SliceUpdate(target, source Tensor, start, stop, strides []int)
    Slice(source Tensor, start, stop, strides []int) Tensor
}

type Tensor interface {
    Dim(n int) int
    Stride(n int) int
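The SliceUpdate interface lets cache code ask whether the active backend supports strided slice reads and writes (the MLX style) instead of GGML's view-plus-copy pattern. A condensed caller-side sketch of the pattern the Causal cache uses in the hunks above; writeRows, its parameters, and the 3-D layout are illustrative only and assume the ml package is imported:

// Prefer SliceUpdate when the context provides it, otherwise fall back to
// the GGML view + copy pattern.
func writeRows(ctx ml.Context, dst, src ml.Tensor, loc, rows, heads, headDim int) {
    if su, ok := ctx.(ml.SliceUpdate); ok {
        start := []int{loc, 0, 0}
        stop := []int{loc + rows, heads, headDim}
        strides := []int{1, 1, 1}
        su.SliceUpdate(dst, src, start, stop, strides)
        ctx.Forward(dst)
        return
    }
    rowSize := dst.Stride(0)
    ctx.Forward(src.Copy(ctx, dst.View(ctx, rowSize*loc, []int{rows * heads * headDim}, nil)))
}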
@@ -145,7 +164,7 @@ type Tensor interface {
    AvgPool2D(ctx Context, k, s int, p float32) Tensor
    Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor

    RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor
    RoPE(ctx Context, positionIDs, ropeFactors, freqs Tensor, dim, ropeType uint32, base, scale float32) Tensor

    Tanh(ctx Context) Tensor
    GELU(ctx Context) Tensor
@@ -318,3 +337,15 @@ func (dt DType) String() string {
        return "unknown"
    }
}

func (dt DType) Sizeof() int64 {
    // TODO call underlying API?
    switch dt {
    case DTypeF32:
        return 4
    case DTypeI32:
        return 4
    default:
        panic("unrecognized type")
    }
}
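Sizeof returns the element width in bytes for the dtypes the POC handles; a small hypothetical helper showing the kind of buffer-size arithmetic it enables (the tensorBytes name is illustrative and assumes the ml package is imported):

// Bytes needed for a dense tensor with the given dimensions.
func tensorBytes(dt ml.DType, dims ...int) int64 {
    n := int64(1)
    for _, d := range dims {
        n *= int64(d)
    }
    return n * dt.Sizeof() // 4 bytes per element for DTypeF32 or DTypeI32
}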
@@ -2,4 +2,5 @@ package backend

import (
    _ "github.com/ollama/ollama/ml/backend/ggml"
    _ "github.com/ollama/ollama/ml/backend/mlx"
)
@@ -1053,7 +1053,15 @@ const (
    ropeTypeVision C.int = 24
)

func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor {
func (t *Tensor) RoPE(
    ctx ml.Context,
    positionIDs ml.Tensor,
    ropeFactors ml.Tensor,
    freqs ml.Tensor, // Unused on GGML
    ropeDim, ropeType uint32,
    ropeBase,
    ropeScale float32,
) ml.Tensor {
    if ropeFactors == nil {
        ropeFactors = &Tensor{b: t.b, nDims: 0}
    }
ml/backend/mlx/CMakeLists.txt (new vendored file, 36 lines)

@@ -0,0 +1,36 @@
include(FetchContent)

set(MLX_C_BUILD_EXAMPLES OFF)

set(MLX_BUILD_GGUF OFF)
set(MLX_BUILD_SAFETENSORS OFF)

function(set_target_output_directory _target)
  if(TARGET ${_target})
    set_target_properties(${_target} PROPERTIES
      RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib
      LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib
      ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib
    )
  endif()
endfunction()

execute_process(
  COMMAND
    zsh "-c"
    "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'"
  OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY)

if(NOT MLX_METAL_VERSION)
  message(STATUS "`xcrun metal` error. Setting MLX_BUILD_METAL=OFF")
  set(MLX_BUILD_METAL OFF)
endif()

FetchContent_Declare(
  mlx-c
  GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git"
  GIT_TAG v0.1.0)
FetchContent_MakeAvailable(mlx-c)

set_target_output_directory(mlx)
set_target_output_directory(mlxc)
ml/backend/mlx/mlx.go (new file, 1075 lines)
File diff suppressed because it is too large.

ml/backend/mlx/quant.go (new file, 328 lines)

@@ -0,0 +1,328 @@
package mlx

/*
#include <stdio.h>
#include <string.h>

#include "mlx/c/array.h"
#include "mlx/c/ops.h"

// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp

void unpack_32_4(uint8_t* data, int8_t* dst) {
    memset(dst, 0, 16);
    for (int j = 0; j < 16; ++j) {
        uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes.
        if (j % 2 != 0) {
            x <<= 4;
        }
        dst[j / 2] += x;
    }
    // Last 16 weights are in the higher bits
    for (int j = 0; j < 16; ++j) {
        uint8_t x = (data[j + 2] >> 4);
        if (j % 2 != 0) {
            x <<= 4;
        }
        dst[8 + j / 2] += x;
    }
}

// Extracts (weight, scales, biases) from Q4_0 tensors.
// Data layout is: |16 bit scale|32 x 4bit weights|.
void extract_q4_0_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
    const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights
    uint8_t* weights = mlx_array_data_uint8(*weights_arr);
    float16_t* scales = mlx_array_data_float16(*scales_arr);
    float16_t* biases = mlx_array_data_float16(*biases_arr);
    for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
        scales[i] = *((float16_t*)data);
        biases[i] = -8 * scales[i];
        unpack_32_4(data, weights);
        weights += 16;
        data += bytes_per_block;
    }
}

// Extracts (weight, scales, biases) from Q4_1 tensors.
// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|.
void extract_q4_1_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
    const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights
    uint8_t* weights = mlx_array_data_uint8(*weights_arr);
    float16_t* scales = mlx_array_data_float16(*scales_arr);
    float16_t* biases = mlx_array_data_float16(*biases_arr);
    for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
        scales[i] = *((float16_t*)data);
        biases[i] = *((float16_t*)(data) + 1);
        unpack_32_4(data, weights);
        weights += 16;
        data += bytes_per_block;
    }
}

// Extracts (weight, scales, biases) from Q8_0 tensors.
// Data layout is: |16 bit scale|32 x 8bit weights|.
void extract_q8_0_data(
    uint8_t* data,
    mlx_array* weights_arr,
    mlx_array* scales_arr,
    mlx_array* biases_arr) {
    const uint64_t weights_per_block = 32;
    const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights
    uint8_t* weights = mlx_array_data_uint8(*weights_arr);
    float16_t* scales = mlx_array_data_float16(*scales_arr);
    float16_t* biases = mlx_array_data_float16(*biases_arr);
    for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) {
        uint8_t* block_data = data + i * bytes_per_block;
        scales[i] = *((float16_t*)block_data);
        biases[i] = -128 * scales[i];
        for (int64_t j = 0; j < weights_per_block; ++j) {
            uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes.
            // Original data is in int8_t, so we add a bias of -128 and invert the
            // first bit.
            x ^= 1 << 7;
            weights[i * weights_per_block + j] = x;
        }
    }
}

// Drived from ggml-quants.c

#define QK_K 256

// 6-bit quantization
// weight is represented as x = a * q
// 16 blocks of 16 elements each
// Effectively 6.5625 bits per weight
typedef struct {
    uint8_t ql[QK_K/2];      // quants, lower 4 bits
    uint8_t qh[QK_K/4];      // quants, upper 2 bits
    int8_t  scales[QK_K/16]; // scales, quantized with 8 bits
    uint16_t d;              // super-block scale
} block_q6_K;

void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) {
    const int64_t nb = k / QK_K;
    block_q6_K *x = (block_q6_K *)vx;
    float16_t* y = (float16_t *)vy;

    for (int i = 0; i < nb; i++) {
        float16_t d = 0.0;
        memcpy(&d, &x[i].d, sizeof(d));

        const uint8_t * restrict ql = x[i].ql;
        const uint8_t * restrict qh = x[i].qh;
        const int8_t * restrict sc = x[i].scales;

        for (int n = 0; n < QK_K; n += 128) {
            for (int l = 0; l < 32; ++l) {
                int is = l/16;
                const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32;
                const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32;
                const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32;
                const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32;
                y[l + 0] = d * sc[is + 0] * q1;
                y[l + 32] = d * sc[is + 2] * q2;
                y[l + 64] = d * sc[is + 4] * q3;
                y[l + 96] = d * sc[is + 6] * q4;
            }
            y += 128;
            ql += 64;
            qh += 32;
            sc += 8;
        }
    }
}

#define K_SCALE_SIZE 12
#define GGML_COMMON_AGGR_U
#define GGML_COMMON_AGGR_S

// 4-bit quantization
// 8 blocks of 32 elements each
// weight is represented as x = a * q + b
// Effectively 4.5 bits per weight
typedef struct {
    union {
        struct {
            uint16_t d;    // super-block scale for quantized scales
            uint16_t dmin; // super-block scale for quantized mins
        } GGML_COMMON_AGGR_S;
        uint16_t dm;
    } GGML_COMMON_AGGR_U;
    uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits
    uint8_t qs[QK_K/2];           // 4--bit quants
} block_q4_K;

static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) {
    if (j < 4) {
        *d = q[j] & 63; *m = q[j + 4] & 63;
    } else {
        *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4);
        *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4);
    }
}

void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) {
    block_q4_K *x = (block_q4_K *)vx;
    float16_t* y = (float16_t *)vy;
    const int nb = k / QK_K;

    for (int i = 0; i < nb; i++) {
        const uint8_t * q = x[i].qs;
        float16_t d = 0.0;
        memcpy(&d, &x[i].d, sizeof(d));
        float16_t min = 0.0;
        memcpy(&min, &x[i].dmin, sizeof(d));

        int is = 0;
        uint8_t sc, m;
        for (int j = 0; j < QK_K; j += 64) {
            get_scale_min_k4(is + 0, x[i].scales, &sc, &m);
            const float16_t d1 = d * sc; const float16_t m1 = min * m;
            get_scale_min_k4(is + 1, x[i].scales, &sc, &m);
            const float16_t d2 = d * sc; const float16_t m2 = min * m;
            for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1;
            for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2;
            q += 32; is += 2;
        }
    }
}

*/
import "C"

import (
    "fmt"
    "unsafe"

    "github.com/x448/float16"
)

func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
    shape := append([]C.int{}, final_shape...)
    var weights_per_byte C.int
    if dtype == 2 || dtype == 3 {
        weights_per_byte = 2
    } else if dtype == 8 {
        weights_per_byte = 1
    } else {
        return r, fmt.Errorf("unsupported tensor type %d", dtype)
    }

    weights_per_block := C.int(32)
    if shape[len(shape)-1]%weights_per_block != 0 {
        return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1])
    }

    weights_shape := append([]C.int{}, shape...)
    weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4)
    w_nbytes := C.int(unsafe.Sizeof(uint32(0)))
    for i := range weights_shape {
        w_nbytes *= weights_shape[i]
    }
    w_data := make([]byte, w_nbytes)
    cbytes := C.CBytes(w_data)
    defer C.free(cbytes)
    weights := C.mlx_array_new_data(
        cbytes,
        &weights_shape[0],
        C.int(len(weights_shape)),
        C.MLX_UINT32,
    )

    // For scales and bias
    shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block
    sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0)))
    for i := range shape {
        sb_nbytes *= shape[i]
    }

    s_data := make([]byte, sb_nbytes)
    cbytes = C.CBytes(s_data)
    defer C.free(cbytes)
    scales := C.mlx_array_new_data(
        cbytes,
        &shape[0],
        C.int(len(shape)),
        C.MLX_FLOAT16,
    )
    b_data := make([]byte, sb_nbytes)
    cbytes = C.CBytes(b_data)
    defer C.free(cbytes)
    biases := C.mlx_array_new_data(
        cbytes,
        &shape[0],
        C.int(len(shape)),
        C.MLX_FLOAT16,
    )
    var bits C.int
    switch dtype {
    case 2:
        C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
        bits = 4
    case 3:
        C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases)
        bits = 4
    case 8:
        C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases)
        bits = 8
    }
    C.mlx_dequantize(
        &r,
        weights,
        scales,
        biases,
        32, // group size
        bits,
        stream,
    )
    C.mlx_array_free(weights)
    C.mlx_array_free(scales)
    C.mlx_array_free(biases)

    return r, nil
}

func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) {
    size := 1
    for _, d := range shape {
        size *= int(d)
    }
    fdata := make([]float16.Float16, size)
    switch dtype {
    case 14:
        C.dequant_row_q6_K(
            data,
            unsafe.Pointer(&fdata[0]),
            C.int(size),
        )

    case 12:
        C.dequant_row_q4_K(
            data,
            unsafe.Pointer(&fdata[0]),
            C.int(size),
        )
    default:
        return r, fmt.Errorf("unsupported K quant")
    }

    r = C.mlx_array_new_data(
        unsafe.Pointer(&fdata[0]),
        &shape[0],
        C.int(len(shape)),
        C.MLX_FLOAT16,
    )
    return r, nil
}
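For reference, the block layouts documented in the comments above work out to a fixed byte count per 32-weight GGUF block. A small illustrative helper (not part of the commit, and assuming "fmt" is imported) that reproduces the arithmetic, using the same dtype codes as gguf_load_quantized:

// dtype codes follow gguf_load_quantized above: 2 = Q4_0, 3 = Q4_1, 8 = Q8_0.
func ggufQuantBytes(dtype uint32, elements int64) (int64, error) {
    const weightsPerBlock = 32
    var bytesPerBlock int64
    switch dtype {
    case 2:
        bytesPerBlock = 18 // 2-byte fp16 scale + 32 x 4-bit weights
    case 3:
        bytesPerBlock = 20 // 2-byte scale + 2-byte bias + 32 x 4-bit weights
    case 8:
        bytesPerBlock = 34 // 2-byte scale + 32 x 1-byte weights
    default:
        return 0, fmt.Errorf("unsupported tensor type %d", dtype)
    }
    if elements%weightsPerBlock != 0 {
        return 0, fmt.Errorf("%d weights is not a multiple of the block size %d", elements, weightsPerBlock)
    }
    return elements / weightsPerBlock * bytesPerBlock, nil
}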
@@ -1,6 +1,8 @@
package nn

import "github.com/ollama/ollama/ml"
import (
    "github.com/ollama/ollama/ml"
)

type Linear struct {
    Weight ml.Tensor `gguf:"weight"`
@@ -82,7 +82,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen)
    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
    q = q.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)

    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -92,7 +92,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen)
    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)
    k = k.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen)

@@ -122,7 +122,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
    return key.RoPE(ctx, shift, nil, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil
}

type MLP struct {
@@ -96,7 +96,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen)
    q = sa.QueryNorm.Forward(ctx, q, opts.eps)
    q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
    q = q.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

    if opts.largeModelScaling {
        q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads)))

@@ -107,7 +107,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos
    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen)
    k = sa.KeyNorm.Forward(ctx, k, opts.eps)
    k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)
    k = k.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen)

@@ -125,7 +125,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T
        ropeBase = m.TextOptions.ropeGlobalBase
    }

    return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
    return key.RoPE(ctx, shift, nil, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil
}

type TextMLP struct {
@@ -76,15 +76,14 @@ type SelfAttention struct {
func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor {
    batchSize := hiddenState.Dim(0) // TODO Consider renaming "L" as this is the sequence length, not batch size
    headDim := opts.hiddenSize / opts.numHeads
    ropeType := uint32(0)

    q := sa.Query.Forward(ctx, hiddenState)
    q = q.Reshape(ctx, batchSize, opts.numHeads, -1)
    q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
    q = LlamaRoPE(ctx, q, positionIDs, sa.RopeFactors, opts)

    k := sa.Key.Forward(ctx, hiddenState)
    k = k.Reshape(ctx, batchSize, opts.numKVHeads, -1)
    k = k.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
    k = LlamaRoPE(ctx, k, positionIDs, sa.RopeFactors, opts)

    v := sa.Value.Forward(ctx, hiddenState)
    v = v.Reshape(ctx, batchSize, opts.numKVHeads, -1)

@@ -97,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten
}

func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil
    return LlamaRoPE(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors, m.Options), nil
}

type MLP struct {
model/models/llama/utils.go (new file, 82 lines)

@@ -0,0 +1,82 @@
package llama

import (
    "math"
    "sync"

    "github.com/ollama/ollama/ml"
    "github.com/ollama/ollama/model"
)

func LlamaRoPE(ctx ml.Context, x, positionIDs, ropeFactors ml.Tensor, opts *Options) ml.Tensor {
    var once sync.Once
    var _freqs ml.Tensor
    dims := opts.ropeDim
    onceBody := func() {
        // Reference: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/rope_utils.py#L9

        base := opts.ropeBase // aka rope_scale
        if base == 0 {
            base = 10000.0
        }
        low_freq_factor := opts.ropeScale // ???
        high_freq_factor := float32(4.0)  // TODO should attempt to get from metadata
        factor := float32(8.0)            // metadata?
        old_context_len := float32(8192)  // metadata? (aka original_max_position_embeddings)

        // Calcs...
        low_freq_wavelen := float32(old_context_len) / low_freq_factor
        high_freq_wavelen := float32(old_context_len) / high_freq_factor

        // freqs = base ** (mx.model.ArangeF32(0, dims, 2) / dims)
        freqs := model.ArangeF32(0, float32(dims), 2)
        for i := range freqs {
            freqs[i] = (float32)(math.Pow(float64(base), float64(freqs[i])/float64(dims)))
        }
        // wavelens = 2 * mx.pi * freqs
        wavelens := make([]float32, len(freqs))
        for i := range wavelens {
            wavelens[i] = freqs[i] * 2 * float32(math.Pi)
        }
        // freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs)
        for i := range freqs {
            if wavelens[i] > low_freq_wavelen {
                freqs[i] = freqs[i] * factor
            }
        }
        // is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen)
        is_medium_freq := make([]bool, len(freqs))
        for i := range freqs {
            is_medium_freq[i] = (wavelens[i] > high_freq_wavelen) && (wavelens[i] < low_freq_wavelen)
        }
        // smooth_factors = (old_context_len / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor)
        smooth_factors := make([]float32, len(freqs))
        for i := range freqs {
            smooth_factors[i] = ((old_context_len)/wavelens[i] - (low_freq_factor)) / ((high_freq_factor) - (low_freq_factor))
        }
        // smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors)
        smooth_freqs := make([]float32, len(freqs))
        for i := range freqs {
            smooth_freqs[i] = freqs[i] / ((1-smooth_factors[i])/factor + (smooth_factors[i]))
        }
        // _freqs = mx.where(is_medium_freq, smooth_freqs, freqs)
        for i := range freqs {
            if is_medium_freq[i] {
                freqs[i] = float32(smooth_freqs[i])
            }
        }
        _freqs, _ = ctx.Input().FromFloatSlice(freqs, len(freqs))
    }
    once.Do(onceBody)

    return x.RoPE(
        ctx,
        positionIDs,
        ropeFactors,
        _freqs,
        dims,
        0,      // type
        500000, // base
        1.0,    // scale
    )
}
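In equation form, what the loops in LlamaRoPE compute (following the mlx-lm rope_utils reference linked in the code, with the hard-coded defaults that the TODOs say should come from metadata):

$$f_i = \mathrm{base}^{2i/d},\qquad \lambda_i = 2\pi f_i,\qquad
\lambda_{\mathrm{low}} = \frac{\texttt{old\_context\_len}}{\texttt{low\_freq\_factor}},\qquad
\lambda_{\mathrm{high}} = \frac{\texttt{old\_context\_len}}{\texttt{high\_freq\_factor}}$$

$$f_i \leftarrow
\begin{cases}
\texttt{factor}\cdot f_i, & \lambda_i > \lambda_{\mathrm{low}} \\
\dfrac{f_i}{(1 - s_i)/\texttt{factor} + s_i}, & \lambda_{\mathrm{high}} < \lambda_i < \lambda_{\mathrm{low}} \\
f_i, & \text{otherwise}
\end{cases}
\qquad s_i = \frac{\texttt{old\_context\_len}/\lambda_i - \texttt{low\_freq\_factor}}{\texttt{high\_freq\_factor} - \texttt{low\_freq\_factor}}$$

Here d = opts.ropeDim, base defaults to 10000, old_context_len = 8192, factor = 8, and high_freq_factor = 4 are hard-coded, and low_freq_factor is taken from opts.ropeScale; the resulting frequency table is uploaded once via ctx.Input().FromFloatSlice and passed to RoPE as the new freqs argument.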
@@ -24,11 +24,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m

    query := sa.Query.Forward(ctx, hiddenState)
    query = query.Reshape(ctx, batchSize, opts.numHeads, headDim)
    query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
    query = query.RoPE(ctx, positions, nil /* TODO freqs */, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

    key := sa.Key.Forward(ctx, hiddenState)
    key = key.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
    key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
    key = key.RoPE(ctx, positions, nil /* TODO freqs */, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)

    value := sa.Value.Forward(ctx, hiddenState)
    value = value.Reshape(ctx, batchSize, opts.numKVHeads, headDim)

@@ -42,8 +42,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m

func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
    // This will only get called for layers in the cache, which are just the self attention layers
    if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
        return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
    if _, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
        // return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
        panic("NOT YET IMPLEMENTED")
    }

    return key, nil