From c8f346dc46f6d2269e2dc57ae037ac9b03e80ea2 Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Wed, 12 Mar 2025 11:15:18 -0700 Subject: [PATCH] Add MLX Backend POC The cache still has some bugs. --- CMakeLists.txt | 5 + kvcache/causal.go | 101 ++- kvcache/causal_test.go | 2 +- llm/status.go | 1 + ml/backend.go | 35 +- ml/backend/backend.go | 1 + ml/backend/ggml/ggml.go | 10 +- ml/backend/mlx/CMakeLists.txt | 36 + ml/backend/mlx/mlx.go | 1075 +++++++++++++++++++++++++++++ ml/backend/mlx/quant.go | 328 +++++++++ ml/nn/linear.go | 4 +- model/models/gemma2/model.go | 6 +- model/models/gemma3/model_text.go | 6 +- model/models/llama/model.go | 7 +- model/models/llama/utils.go | 82 +++ model/models/mllama/model_text.go | 9 +- 16 files changed, 1651 insertions(+), 57 deletions(-) create mode 100644 ml/backend/mlx/CMakeLists.txt create mode 100644 ml/backend/mlx/mlx.go create mode 100644 ml/backend/mlx/quant.go create mode 100644 model/models/llama/utils.go diff --git a/CMakeLists.txt b/CMakeLists.txt index 034fc7d79..959ce2e38 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -130,3 +130,8 @@ if(CMAKE_HIP_COMPILER) endforeach() endif() endif() + +if(CMAKE_SYSTEM_NAME MATCHES "Darwin") + message(STATUS "Setting up MLX (this takes a while...)") + add_subdirectory(${CMAKE_CURRENT_SOURCE_DIR}/ml/backend/mlx) +endif() \ No newline at end of file diff --git a/kvcache/causal.go b/kvcache/causal.go index 17c4a83a1..560c744e5 100644 --- a/kvcache/causal.go +++ b/kvcache/causal.go @@ -257,6 +257,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { } if c.config.MaskDType != ml.DTypeF32 { + // TODO - MLX not covered here... out := ctx.Input().Empty(c.config.MaskDType, maskTensor.Shape()...) ctx.Forward(maskTensor.Copy(ctx, out)) maskTensor = out @@ -266,6 +267,7 @@ func (c *Causal) buildMask(ctx ml.Context) (ml.Tensor, error) { } func (c *Causal) moveCells(ctx ml.Context, src, dst, len int) { + // TODO this wont work on MLX as is - needs to be adjusted for SliceUpdate for i, key := range c.keys { if key == nil { continue @@ -431,41 +433,48 @@ func (c *Causal) Get(ctx ml.Context) (ml.Tensor, ml.Tensor, ml.Tensor) { value := c.values[c.curLayer] kHeadDim := key.Dim(2) + vHeadDim := value.Dim(2) numKVHeads := key.Dim(1) rowSize := key.Stride(0) cachedSize := c.curMask.Dim(1) - // slog.Info("Get", "kHeadDim", kHeadDim, "numKVHeads", numKVHeads, "rowSize", rowSize, "cachedSize", cachedSize) - key = key.View(ctx, rowSize*c.curCellRange.min, - []int{cachedSize, numKVHeads, kHeadDim}, - []int{key.Stride(0), key.Stride(1)}, - ) - // slog.Info("Get", "key", key) - // panic("XXX") - - if c.config.PermutedV { - vHeadDim := value.Dim(1) - elemSize := value.Stride(2) - - value = value.View(ctx, elemSize*c.curCellRange.min, - []int{numKVHeads, vHeadDim, cachedSize}, - []int{value.Stride(0), value.Stride(1)}, - ) + // Potential abstraction to work around differences in cache tensor handling. 
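+	// Hedged sketch of the idea (relying on the ml.SliceUpdate interface added to
+	// ml/backend.go below): a backend with native strided slicing can return the
+	// cached window [curCellRange.min, curCellRange.min+cachedSize) directly as a
+	// [cachedSize, numKVHeads, headDim] tensor, while GGML keeps the existing
+	// View-over-the-flat-buffer pattern in the else branch.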
+ if su, ok := ctx.(ml.SliceUpdate); ok { + start := []int{c.curCellRange.min, 0, 0} + kStop := []int{c.curCellRange.min + cachedSize, numKVHeads, kHeadDim} + vStop := []int{c.curCellRange.min + cachedSize, numKVHeads, vHeadDim} + strides := []int{1, 1, 1} + key = su.Slice(key, start, kStop, strides) + value = su.Slice(value, start, vStop, strides) } else { - vHeadDim := value.Dim(2) - rowSize := value.Stride(0) - - value = value.View(ctx, rowSize*c.curCellRange.min, - []int{cachedSize, numKVHeads, vHeadDim}, - []int{value.Stride(0), value.Stride(1)}, + key = key.View(ctx, rowSize*c.curCellRange.min, + []int{cachedSize, numKVHeads, kHeadDim}, + []int{key.Stride(0), key.Stride(1)}, ) - } - // TODO The mask changes from X,X to 1,X, and with the Row-order change - // the 1 becomes trailing and messes up later operations - // This isn't the right solution, but works around it... - if c.curMask.Dim(1) == 1 { - return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3) + if c.config.PermutedV { + vHeadDim := value.Dim(1) + elemSize := value.Stride(2) + + value = value.View(ctx, elemSize*c.curCellRange.min, + []int{numKVHeads, vHeadDim, cachedSize}, + []int{value.Stride(0), value.Stride(1)}, + ) + } else { + vHeadDim := value.Dim(2) + rowSize := value.Stride(0) + + value = value.View(ctx, rowSize*c.curCellRange.min, + []int{cachedSize, numKVHeads, vHeadDim}, + []int{value.Stride(0), value.Stride(1)}, + ) + } + // TODO The mask changes from X,X to 1,X, and with the Row-order change + // the 1 becomes trailing and messes up later operations + // This isn't the right solution, but works around it... + if c.curMask.Dim(1) == 1 { + return key, value, c.curMask.Permute(ctx, 1, 0, 2, 3) + } } return key, value, c.curMask @@ -495,20 +504,35 @@ func (c *Causal) Put(ctx ml.Context, key, value ml.Tensor) { } else { c.values[c.curLayer] = c.ctxs[c.curLayer].Zeros(c.DType, int(c.Capacity), numKVHeads, vHeadDim) } + // slog.Info("Cache Put", "c.keys[c.curLayer]", c.keys[c.curLayer]) + // slog.Info("Cache Put", "c.values[c.curLayer]", c.values[c.curLayer]) } - rowSize := c.keys[c.curLayer].Stride(0) - ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil))) - - if c.config.PermutedV { - elemSize := c.values[c.curLayer].Stride(2) - - value = value.Permute(ctx, 1, 2, 0, 3) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize}))) + // Potential abstraction to work around differences in cache tensor handling. 
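+	// Write-path counterpart of the Get() sketch above: with SliceUpdate the new
+	// batch is written in place into rows [curLoc, curLoc+batchSize) of the
+	// [Capacity, numKVHeads, headDim] cache tensors allocated above, and the whole
+	// cache tensor is then handed to Forward so the lazy update gets evaluated;
+	// the else branch keeps the GGML view+copy behaviour.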
+ if su, ok := ctx.(ml.SliceUpdate); ok { + start := []int{c.curLoc, 0, 0} + kStop := []int{c.curLoc + batchSize, numKVHeads, kHeadDim} + vStop := []int{c.curLoc + batchSize, numKVHeads, vHeadDim} + strides := []int{1, 1, 1} + su.SliceUpdate(c.keys[c.curLayer], key, start, kStop, strides) + su.SliceUpdate(c.values[c.curLayer], value, start, vStop, strides) + ctx.Forward(c.keys[c.curLayer]) + ctx.Forward(c.values[c.curLayer]) } else { - rowSize := c.values[c.curLayer].Stride(0) + // GGML pattern + rowSize := c.keys[c.curLayer].Stride(0) + ctx.Forward(key.Copy(ctx, c.keys[c.curLayer].View(ctx, rowSize*c.curLoc, []int{kHeadDim * numKVHeads * batchSize}, nil))) - ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil))) + if c.config.PermutedV { + elemSize := c.values[c.curLayer].Stride(2) + + value = value.Permute(ctx, 1, 2, 0, 3) + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, elemSize*c.curLoc, []int{vHeadDim * numKVHeads, batchSize}, []int{int(c.Capacity) * elemSize}))) + } else { + rowSize := c.values[c.curLayer].Stride(0) + + ctx.Forward(value.Copy(ctx, c.values[c.curLayer].View(ctx, rowSize*c.curLoc, []int{vHeadDim * numKVHeads * batchSize}, nil))) + } } } @@ -565,6 +589,7 @@ func (c *Causal) shift(seq int, beginIndex, offset int32) error { continue } + // TODO - this also needs adjusting to support MLX with SliceUpdate kHeadDim := key.Dim(2) numKVHeads := key.Dim(1) rowSize := key.Stride(0) diff --git a/kvcache/causal_test.go b/kvcache/causal_test.go index 87cc60558..d06450f00 100644 --- a/kvcache/causal_test.go +++ b/kvcache/causal_test.go @@ -457,7 +457,7 @@ func (t *testTensor) Conv2D(ctx ml.Context, weight ml.Tensor, s0, s1, p0, p1, d0 panic("not implemented") } -func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor { +func (t *testTensor) RoPE(ctx ml.Context, positionIDs, ropeFactors, freqs ml.Tensor, dim, ropeType uint32, base, scale float32) ml.Tensor { panic("not implemented") } diff --git a/llm/status.go b/llm/status.go index 80f44e654..81090848e 100644 --- a/llm/status.go +++ b/llm/status.go @@ -28,6 +28,7 @@ var errorPrefixes = []string{ "error loading model", "GGML_ASSERT", "Deepseek2 does not support K-shift", + "panic:", } func (w *StatusWriter) Write(b []byte) (int, error) { diff --git a/ml/backend.go b/ml/backend.go index 134317750..78817859e 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -4,6 +4,7 @@ import ( "bytes" "encoding/binary" "fmt" + "log/slog" "os" "slices" "strconv" @@ -87,7 +88,13 @@ func RegisterBackend(name string, f func(*os.File, BackendParams) (Backend, erro } func NewBackend(f *os.File, params BackendParams) (Backend, error) { - if backend, ok := backends["ggml"]; ok { + be := os.Getenv("OLLAMA_BACKEND") + if be == "" { + be = "ggml" + slog.Info("Defaulting to " + be + ". Set OLLAMA_BACKEND to override") + } + slog.Info("Loading new engine", "backend", be) + if backend, ok := backends[be]; ok { return backend(f, params) } @@ -122,6 +129,18 @@ type Context interface { Abort(Tensor) // Evaluate the graph up to this point, retrieve the data from the tensor and dump it to a json file for comparison } +// Usage: +// +// if su, ok := ctx.(ml.SliceUpdate); ok { +// su.SliceUpdate(...) 
+// } else { +// // view + copy operations +// } +type SliceUpdate interface { + SliceUpdate(target, source Tensor, start, stop, strides []int) + Slice(source Tensor, start, stop, strides []int) Tensor +} + type Tensor interface { Dim(n int) int Stride(n int) int @@ -145,7 +164,7 @@ type Tensor interface { AvgPool2D(ctx Context, k, s int, p float32) Tensor Conv2D(ctx Context, weight Tensor, s0, s1, p0, p1, d0, d1 int) Tensor - RoPE(ctx Context, positionIDs, ropeFactors Tensor, dim, ropeType uint32, base, scale float32) Tensor + RoPE(ctx Context, positionIDs, ropeFactors, freqs Tensor, dim, ropeType uint32, base, scale float32) Tensor Tanh(ctx Context) Tensor GELU(ctx Context) Tensor @@ -318,3 +337,15 @@ func (dt DType) String() string { return "unknown" } } + +func (dt DType) Sizeof() int64 { + // TODO call underlying API? + switch dt { + case DTypeF32: + return 4 + case DTypeI32: + return 4 + default: + panic("unrecognized type") + } +} diff --git a/ml/backend/backend.go b/ml/backend/backend.go index 55063fb3b..3fdeec745 100644 --- a/ml/backend/backend.go +++ b/ml/backend/backend.go @@ -2,4 +2,5 @@ package backend import ( _ "github.com/ollama/ollama/ml/backend/ggml" + _ "github.com/ollama/ollama/ml/backend/mlx" ) diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index df2679a36..a2ba78985 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -1053,7 +1053,15 @@ const ( ropeTypeVision C.int = 24 ) -func (t *Tensor) RoPE(ctx ml.Context, positionIDs, ropeFactors ml.Tensor, ropeDim, ropeType uint32, ropeBase, ropeScale float32) ml.Tensor { +func (t *Tensor) RoPE( + ctx ml.Context, + positionIDs ml.Tensor, + ropeFactors ml.Tensor, + freqs ml.Tensor, // Unused on GGML + ropeDim, ropeType uint32, + ropeBase, + ropeScale float32, +) ml.Tensor { if ropeFactors == nil { ropeFactors = &Tensor{b: t.b, nDims: 0} } diff --git a/ml/backend/mlx/CMakeLists.txt b/ml/backend/mlx/CMakeLists.txt new file mode 100644 index 000000000..bcc21195b --- /dev/null +++ b/ml/backend/mlx/CMakeLists.txt @@ -0,0 +1,36 @@ +include(FetchContent) + +set(MLX_C_BUILD_EXAMPLES OFF) + +set(MLX_BUILD_GGUF OFF) +set(MLX_BUILD_SAFETENSORS OFF) + +function(set_target_output_directory _target) + if(TARGET ${_target}) + set_target_properties(${_target} PROPERTIES + RUNTIME_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + LIBRARY_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + ARCHIVE_OUTPUT_DIRECTORY ${CMAKE_BINARY_DIR}/lib + ) + endif() +endfunction() + +execute_process( + COMMAND + zsh "-c" + "echo \"__METAL_VERSION__\" | xcrun -sdk macosx metal ${XCRUN_FLAGS} -E -x metal -P - | tail -1 | tr -d '\n'" + OUTPUT_VARIABLE MLX_METAL_VERSION COMMAND_ERROR_IS_FATAL ANY) + +if(NOT MLX_METAL_VERSION) + message(STATUS "`xcrun metal` error. 
Setting MLX_BUILD_METAL=OFF") + set(MLX_BUILD_METAL OFF) +endif() + +FetchContent_Declare( + mlx-c + GIT_REPOSITORY "https://github.com/ml-explore/mlx-c.git" + GIT_TAG v0.1.0) +FetchContent_MakeAvailable(mlx-c) + +set_target_output_directory(mlx) +set_target_output_directory(mlxc) diff --git a/ml/backend/mlx/mlx.go b/ml/backend/mlx/mlx.go new file mode 100644 index 000000000..38f364603 --- /dev/null +++ b/ml/backend/mlx/mlx.go @@ -0,0 +1,1075 @@ +package mlx + +/* +#cgo CPPFLAGS: -I${SRCDIR}/../../../build/_deps/mlx-c-src +#cgo LDFLAGS: -L${SRCDIR}/../../../build/lib/ollama/ -lmlxc -lmlx +#cgo LDFLAGS: -framework Accelerate +#cgo LDFLAGS: -Wl,-rpath,${SRCDIR}/../../../build/lib/ollama/ +#include +#include "mlx/c/array.h" +#include "mlx/c/fast.h" +#include "mlx/c/ops.h" +#include "mlx/c/stream.h" +#include "mlx/c/transforms.h" +#include "mlx/c/error.h" +static inline size_t stride(const mlx_array a, int i) {return mlx_array_strides(a)[i];} + +extern void goStackTrace(); +static void error_handler(const char *msg, void* data) { + fprintf(stderr, "MLX error: %s\n", msg); + goStackTrace(); + exit(-1); +} +static void set_error_handler() {mlx_set_error_handler(&error_handler, NULL, NULL);} +static void* mlx_array_data_float16_asvoid(const mlx_array a) {return (void*)mlx_array_data_float16(a);} +*/ +import "C" + +import ( + "bytes" + "encoding/json" + "fmt" + "io" + "log/slog" + "math" + "os" + "runtime" + "runtime/debug" + "strings" + "sync" + "unsafe" + + fs "github.com/ollama/ollama/fs/ggml" + "github.com/ollama/ollama/ml" + "github.com/x448/float16" + "golang.org/x/sync/errgroup" +) + +func init() { + ml.RegisterBackend("mlx", New) + C.set_error_handler() +} + +//export goStackTrace +func goStackTrace() { + debug.PrintStack() +} + +func New(r *os.File, params ml.BackendParams) (ml.Backend, error) { + meta, n, err := fs.Decode(r, -1) + if err != nil { + return nil, err + } + + // TODO all this loading logic will be replaced by the new model loading abstraction, including any necessary transformations + // As currently structured, this likely causes a significant performance impact + + tensors := make(map[string]*Array, len(meta.Tensors().Items())) + sr := io.NewSectionReader(r, int64(meta.Tensors().Offset), n-int64(meta.Tensors().Offset)) + + slog.Info("initializing MLX GPU backend") + stream := C.mlx_default_gpu_stream_new() + + var g errgroup.Group + var mu sync.Mutex + vec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vec) + unmutate := func(name string, shape []C.int, r C.mlx_array) error { + // TODO - is this code memory access safe, or does the delayed processing cause potential memory access after Go frees the stack? + + // TODO performance: Since these operations are ~static yet cause a lot of additional nodes in the graph + // Ideally these should be applied "on the fly" at load time, so the tensor has the data ready to go. 
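+		// Hedged reading of what is being undone here: the GGUF converter permutes
+		// attn_q/attn_k rows per head, roughly
+		//   w.reshape(n_head, 2, head_dim/2, -1).swapaxes(1, 2).reshape(w.shape)
+		// so that llama.cpp's pairwise RoPE lines up. mlx_fast_rope is called with
+		// traditional=false (rotate-half), so the steps below invert that:
+		//   reshape(n_head, head_dim/2, 2, -1) -> swapaxes(1, 2) -> reshape back,
+		// then transpose like every other weight this backend loads.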
+ defer C.mlx_array_free(r) + + var n_head uint64 + if strings.Contains(name, "attn_q") { + n_head = meta.KV().HeadCount() // Q + } else { + n_head = meta.KV().HeadCountKV() // K + } + tmpShape := []C.int{C.int(n_head), C.int(math.Floor(math.Floor(float64(shape[0]) / float64(n_head) / float64(2)))), 2, shape[1]} + var shaped C.mlx_array + C.mlx_reshape(&shaped, r, &tmpShape[0], C.size_t(len(tmpShape)), stream) + defer C.mlx_array_free(shaped) + var swapped C.mlx_array + C.mlx_swapaxes( + &swapped, + shaped, + 1, + 2, + stream, + ) + defer C.mlx_array_free(swapped) + + var reshaped C.mlx_array + C.mlx_reshape( + &reshaped, + swapped, + &shape[0], + C.size_t(len(shape)), + stream, + ) + defer C.mlx_array_free(reshaped) + var a C.mlx_array + + C.mlx_transpose_all( + &a, + reshaped, + stream, + ) + mu.Lock() + defer mu.Unlock() + C.mlx_vector_array_append_value(vec, a) + tmp := &Array{a: a, name: name} + tensors[name] = tmp + return nil + } + for _, t := range meta.Tensors().Items() { + g.Go(func() error { + var b bytes.Buffer + n, err := io.Copy(&b, io.NewSectionReader(sr, int64(t.Offset), int64(t.Size()))) + if err != nil { + return err + } + + if n != int64(t.Size()) { + return fmt.Errorf("expected %d bytes, got %d", t.Size(), n) + } + + cbytes := C.CBytes(b.Bytes()) + defer C.free(cbytes) + + // Inverted + shape := make([]C.int, len(t.Shape)) + i := len(t.Shape) - 1 + for _, dim := range t.Shape { + shape[i] = C.int(dim) + i-- + } + var r C.mlx_array + + switch t.Kind { + case 0: // GGML_TYPE_F32 + a := C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT32, + ) + // MLX fp32 ops are significantly slower than fp16 + C.mlx_astype( + &r, + a, + C.MLX_FLOAT16, + stream, + ) + defer C.mlx_array_free(a) + case 1: // GGML_TYPE_F16 + r = C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + case 30: // GGML_TYPE_BF16 + r = C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_BFLOAT16, + ) + case 2, 8: // GGML_TYPE_Q4_0, GGML_TYPE_Q8_0 + // Note: theoretically GGML_TYPE_Q4_1 (3) should work, but spits out garbage so omitting for now + r, err = gguf_load_quantized(cbytes, t.Name, shape, t.Kind, stream) + if err != nil { + panic(err.Error()) + } + case 12, 14: // GGML_TYPE_Q4_K, GGML_TYPE_Q6_K + // TODO any special cases? + r, err = load_k_quantized(cbytes, t.Name, shape, t.Kind, stream) + if err != nil { + panic(err) + } + default: + return fmt.Errorf("unsupported dtype %v", t) + } + + var a C.mlx_array + + // Q/K are are mutated and we need to reverse that mutation + // TODO - this is only for llama based models and shouldn't be applied universally + // but only applies to some backends at the moment... maybe? + if strings.HasSuffix(t.Name, "attn_q.weight") || strings.HasSuffix(t.Name, "attn_q.bias") || strings.HasSuffix(t.Name, "attn_k.weight") || strings.HasSuffix(t.Name, "attn_k.bias") { + return unmutate(t.Name, shape, r) + } else if strings.Contains(t.Name, "token_embd.weight") { + // TODO bug in model code? Why is this one special compared to all the rest? 
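+			// (Best guess: token_embd.weight is only consumed through Rows/mlx_take for
+			// embedding lookups rather than matmul, so it is left untransposed, unlike
+			// the other weights. Unverified.)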
+ a = r + } else { + // TODO performance: this should be done to the data as it's loaded, not add additional operations in the graph + C.mlx_transpose_all( + &a, + r, + stream, + ) + defer C.mlx_array_free(r) + } + mu.Lock() + defer mu.Unlock() + C.mlx_vector_array_append_value(vec, a) + tmp := &Array{a: a, name: t.Name} + tmp.name = t.Name + tensors[t.Name] = tmp + return nil + }) + } + + if err := g.Wait(); err != nil { + return nil, err + } + C.mlx_async_eval(vec) + + return &Backend{ + meta: meta, + tensors: tensors, + }, nil +} + +type Backend struct { + meta *fs.GGML + tensors map[string]*Array +} + +// Config implements ml.Backend. +func (b *Backend) Config() ml.Config { + return b.meta.KV() +} + +// Get implements ml.Backend. +func (b *Backend) Get(name string) ml.Tensor { + if a, ok := b.tensors[name]; ok { + return a + } + + return nil +} + +func (b *Backend) NewContext() ml.Context { + return &Context{ + stream: C.mlx_default_gpu_stream_new(), + } +} + +func (b *Backend) NewContextSize(_ int) ml.Context { + return b.NewContext() +} + +func (b *Backend) SystemInfo() string { + // TODO implement this, maybe from metal.h calls... + return "" +} + +type Context struct { + stream C.mlx_stream + + mu sync.Mutex + arrays []C.mlx_array // TODO should we do some bookkeeping to ensure none of these Arrays are still lingering? +} + +// Close implements ml.Context. +func (c *Context) Close() { + // C.mlx_synchronize(c.stream) // ??? + C.mlx_stream_free(c.stream) + + c.mu.Lock() + defer c.mu.Unlock() + for _, a := range c.arrays { + C.mlx_array_free(a) + } +} + +// Compute implements ml.Context. +func (c *Context) Compute(tensors ...ml.Tensor) { + // TODO - for the zero tensor case this feels like it might not be correct... + needSync := true + sync := func() { + if needSync { + C.mlx_synchronize(c.stream) + needSync = false + } + } + + vec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vec) + for _, t := range tensors { + C.mlx_vector_array_append_value(vec, t.(*Array).a) + t.(*Array).sync = sync + } + C.mlx_async_eval(vec) +} + +// Forward implements ml.Context. +func (c *Context) Forward(tensors ...ml.Tensor) ml.Context { + vec := C.mlx_vector_array_new() + defer C.mlx_vector_array_free(vec) + needSync := true + sync := func() { + if needSync { + C.mlx_synchronize(c.stream) + needSync = false + } + } + + for _, t := range tensors { + t.(*Array).sync = sync + C.mlx_vector_array_append_value(vec, t.(*Array).a) + } + C.mlx_async_eval(vec) + return c +} + +// FromFloatSlice implements ml.Context. +func (c *Context) FromFloatSlice(s []float32, shape ...int) (ml.Tensor, error) { + u16s := make([]float16.Float16, len(s)) + for i := range u16s { + u16s[i] = float16.Fromfloat32(s[i]) + } + cshape := make([]C.int, len(shape)) + for i, dim := range shape { + cshape[i] = C.int(dim) + } + return newArray(c, + C.mlx_array_new_data( + unsafe.Pointer(&u16s[0]), + &cshape[0], + C.int(len(cshape)), + C.MLX_FLOAT16, + ), + ), nil +} + +// FromIntSlice implements ml.Context. +func (c *Context) FromIntSlice(s []int32, shape ...int) (ml.Tensor, error) { + cshape := make([]C.int, len(shape)) + for i, dim := range shape { + cshape[i] = C.int(dim) + } + return newArray(c, + C.mlx_array_new_data( + unsafe.Pointer(&s[0]), + &cshape[0], + C.int(len(cshape)), + C.MLX_INT32, + ), + ), nil +} + +func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { + // TODO more efficient impl? + return c.Zeros(dtype, shape...) +} + +// Zeros implements ml.Context. 
+func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { + if len(shape) < 1 || len(shape) > 4 { + panic("unsupported number of dimensions") + } + for _, dim := range shape { + if dim < 1 { + panic("invalid shape") + } + } + var dt C.mlx_dtype + switch dtype { + case ml.DTypeF32: + // TODO should we just force this to fp16? + dt = C.MLX_FLOAT32 + case ml.DTypeF16: + dt = C.MLX_FLOAT16 + case ml.DTypeI32: + dt = C.MLX_INT32 + default: + panic(fmt.Sprintf("unsupported dtype %d", dtype)) + } + sh := make([]C.int, len(shape)) + for i, s := range shape { + sh[i] = (C.int)(s) + } + + var r C.mlx_array + C.mlx_zeros( + &r, + &sh[0], + (C.size_t)(len(sh)), + dt, + c.stream, + ) + return newArray(c, r) +} + +func (c *Context) MaxGraphNodes() int { + // TODO actually wire up correctly + return 9999 +} + +func (c *Context) Input() ml.Context { + return c +} + +func (c *Context) Output() ml.Context { + return c +} + +func (c *Context) Layer(_ int) ml.Context { + return c +} + +type Array struct { + name string + a C.mlx_array + + sync func() +} + +func newArray(ctx *Context, a C.mlx_array) *Array { + // TODO measure impact and if this slows things down, make it conditional on some debugging flag at load time + var name string + _, f, l, ok := runtime.Caller(2) + if ok { + name = fmt.Sprintf("%s:%d", f, l) + } + + t := &Array{ + name: name, + a: a, + } + ctx.mu.Lock() + defer ctx.mu.Unlock() + ctx.arrays = append(ctx.arrays, a) + return t +} + +func (a *Array) LogValue() slog.Value { + // TODO this forces eval on every log message - find a pattern to make this configurable to aid in debugging + // str := C.mlx_string_new() + // C.mlx_array_tostring(&str, a.a) + // s := C.mlx_string_data(str) + // defer C.mlx_string_free(str) + // fmt.Println(C.GoString(s)) + dims := int(C.mlx_array_ndim(a.a)) + strides := make([]int, dims) + for i := range strides { + strides[i] = int(C.stride(a.a, (C.int)(i))) + } + + return slog.GroupValue( + slog.String("name", a.name), + slog.String("type", a.TypeString()), + slog.Any("shape", a.Shape()), + slog.Any("strides", strides), + // slog.String("values", C.GoString(s)), + ) +} + +// Add implements ml.Tensor. +func (a *Array) Add(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + C.mlx_add( + &r, + a.a, + a2.(*Array).a, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Bytes implements ml.Tensor. +func (a *Array) Bytes() []byte { + if a.sync != nil { + a.sync() + } + + l := (int)(C.mlx_array_nbytes(a.a)) + data := C.mlx_array_data_uint8(a.a) + if data == nil { + return nil + } + return unsafe.Slice((*byte)(data), l) +} + +// Concat implements ml.Tensor. +func (a *Array) Concat(ctx ml.Context, a2 ml.Tensor, dim int) ml.Tensor { + panic("unimplemented") +} + +// Contiguous implements ml.Tensor. +func (a *Array) Contiguous(ctx ml.Context) ml.Tensor { + var r C.mlx_array + C.mlx_contiguous( + &r, + a.a, + true, // TODO ??? + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Conv2D implements ml.Tensor. +func (a *Array) Conv2D(ctx ml.Context, weight ml.Tensor, s0 int, s1 int, p0 int, p1 int, d0 int, d1 int) ml.Tensor { + var r C.mlx_array + C.mlx_conv2d( + &r, + a.a, + weight.(*Array).a, + C.int(s0), + C.int(s1), + C.int(p0), + C.int(p1), + C.int(d0), + C.int(d1), + 1, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Copy implements ml.Tensor. 
+func (a *Array) Copy(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + C.mlx_copy( + &a2.(*Array).a, + a.a, + ctx.(*Context).stream, + ) + // TODO - view? + return newArray(ctx.(*Context), a2.(*Array).a) +} + +// DType implements ml.Tensor. +func (a *Array) DType() ml.DType { + switch C.mlx_array_dtype(a.a) { + // case C.MLX_BOOL: + // case C.MLX_UINT8: + // case C.MLX_UINT16: + // case C.MLX_UINT32: + // case C.MLX_UINT64: + // case C.MLX_INT8: + // case C.MLX_INT16: + case C.MLX_INT32: + return ml.DTypeI32 + // case C.MLX_INT64: + case C.MLX_FLOAT16: + return ml.DTypeF16 + case C.MLX_FLOAT32: + return ml.DTypeF32 + default: + panic("unsupported dtype") + } +} + +// Dim implements ml.Tensor. +func (a *Array) Dim(n int) int { + return int(C.mlx_array_dim(a.a, C.int(n))) +} + +// Floats implements ml.Tensor. +func (a *Array) Floats() []float32 { + if a.sync != nil { + a.sync() + } + l := (int)(C.mlx_array_size(a.a)) + + switch C.mlx_array_dtype(a.a) { + case C.MLX_BFLOAT16: + panic("bfloat16 not yet implemented") + case C.MLX_FLOAT16: + data := C.mlx_array_data_float16_asvoid(a.a) + if data == nil { + panic("nil data, wasn't eval'd") + } + u16s := unsafe.Slice((*uint16)(data), l) + f32s := make([]float32, len(u16s)) + for i := range u16s { + f32s[i] = float16.Frombits(u16s[i]).Float32() + } + return f32s + case C.MLX_FLOAT32: + data := C.mlx_array_data_float32(a.a) + if data == nil { + panic("nil data, wasn't eval'd") + } + f32s := unsafe.Slice((*float32)(data), l) + return f32s + default: + panic(fmt.Sprintf("unsupported dtype for Floats: %d", C.mlx_array_dtype(a.a))) + } +} + +// GELU implements ml.Tensor. +func (a *Array) GELU(ctx ml.Context) ml.Tensor { + panic("unimplemented") +} + +// Mul implements ml.Tensor. +func (a *Array) Mul(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + C.mlx_multiply( + &r, + a.a, + a2.(*Array).a, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Mulmat implements ml.Tensor. +func (a *Array) Mulmat(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + s := a.Shape() + strides := make([]int, len(s)) + for i := range s { + strides[i] = a.Stride(i) + } + sb := a2.Shape() + stridesb := make([]int, len(sb)) + for i := range sb { + stridesb[i] = a2.Stride(i) + } + C.mlx_matmul(&r, + a2.(*Array).a, + a.a, + ctx.(*Context).stream) + return newArray(ctx.(*Context), r) +} + +func (a *Array) MulmatFullPrec(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + return a.Mulmat(ctx, a2) +} + +// LayerNorm implements ml.Tensor. +func (a *Array) LayerNorm(ctx ml.Context, w, b ml.Tensor, eps float32) ml.Tensor { + var r C.mlx_array + C.mlx_fast_layer_norm( + &r, + a.a, + w.(*Array).a, + b.(*Array).a, + C.float(eps), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Pad implements ml.Tensor. +func (a *Array) Pad(ctx ml.Context, shape ...int) ml.Tensor { + panic("unimplemented") +} + +// Permute implements ml.Tensor. +func (a *Array) Permute(ctx ml.Context, shape ...int) ml.Tensor { + ndim := min(C.mlx_array_ndim(a.a), C.size_t(len(shape))) + var r C.mlx_array + sh := make([]C.int, ndim) + for i := range ndim { + sh[i] = (C.int)(shape[i]) + if int(sh[i]) >= int(ndim) { + slog.Error("Permute error", "tensor", a, "shape", shape) + panic("invalid pemute call") + } + } + C.mlx_transpose( + &r, + a.a, + &sh[0], + ndim, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// RMSNorm implements ml.Tensor. 
+func (a *Array) RMSNorm(ctx ml.Context, w ml.Tensor, eps float32) ml.Tensor { + var r C.mlx_array + C.mlx_fast_rms_norm( + &r, + a.a, + w.(*Array).a, + C.float(eps), + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Reshape implements ml.Tensor. +func (a *Array) Reshape(ctx ml.Context, shape ...int) ml.Tensor { + cshape := make([]C.int, len(shape)) + for i, dim := range shape { + cshape[i] = C.int(dim) + } + var r C.mlx_array + C.mlx_reshape(&r, a.a, &cshape[0], C.size_t(len(cshape)), ctx.(*Context).stream) + return newArray(ctx.(*Context), r) +} + +/* MLX breadcrumb for Fast RoPE +a (array) – Input array. +dims (int) – The feature dimensions to be rotated. If the input feature is larger than dims then the rest is left unchanged. +traditional (bool) – If set to True choose the traditional implementation which rotates consecutive dimensions. +base (float, optional) – The base used to compute angular frequency for each dimension in the positional encodings. Exactly one of base and freqs must be None. +scale (float) – The scale used to scale the positions. +offset (int or array) – The position offset to start at. +freqs (array, optional) – Optional frequencies to use with RoPE. If set, the base parameter must be None. Default: None. +*/ + +// Rope implements ml.Tensor. +func (a *Array) RoPE( + ctx ml.Context, + positionIDs ml.Tensor, // Unused in MLX + ropeFactors ml.Tensor, // Unused in MLX + freqs ml.Tensor, + dim uint32, + ropeType uint32, + base float32, + scale float32, +) ml.Tensor { + a = a.Reshape(ctx, append([]int{1}, a.Shape()...)...).Permute(ctx, 0, 2, 1, 3).(*Array) + // TODO figure out how to get offset wired up + offset := 0 + var r C.mlx_array + var b C.mlx_optional_float + var _freqs C.mlx_array + if base == 0 { + base = 10000 + } + if freqs == nil || len(freqs.Shape()) == 0 { + b.value = C.float(base) + b.has_value = true + } else { + _freqs = freqs.(*Array).a + } + + C.mlx_fast_rope( + &r, + a.a, + C.int(dim), + false, // traditional=false + b, + C.float(scale), + C.int(offset), + _freqs, + ctx.(*Context).stream, + ) + + res := newArray(ctx.(*Context), r).Permute(ctx, 0, 2, 1, 3) + return res.Reshape(ctx, res.Shape()[1:]...) +} + +// Rows implements ml.Tensor. +func (a *Array) Rows(ctx ml.Context, a2 ml.Tensor) ml.Tensor { + var r C.mlx_array + + // HACK! + // If the indicies is greater than 2 dimensions, assume axis 1 + var axis C.int + if C.mlx_array_ndim(a2.(*Array).a) > 1 { + axis = 1 + } else { + axis = 0 + } + C.mlx_take(&r, a.a, a2.(*Array).a, axis, ctx.(*Context).stream) + return newArray(ctx.(*Context), r) +} + +// SILU implements ml.Tensor. +func (a *Array) SILU(ctx ml.Context) ml.Tensor { + var sig C.mlx_array + C.mlx_sigmoid( + &sig, + a.a, + ctx.(*Context).stream, + ) + var r C.mlx_array + C.mlx_multiply( + &r, + a.a, + sig, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Scale implements ml.Tensor. +func (a *Array) Scale(ctx ml.Context, s float64) ml.Tensor { + scale := C.mlx_array_new_float(C.float(s)) + var r C.mlx_array + C.mlx_multiply( + &r, + a.a, + scale, + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Shape implements ml.Tensor. +func (a *Array) Shape() []int { + shape := make([]int, C.mlx_array_ndim(a.a)) + for i := range shape { + shape[i] = int(C.mlx_array_dim(a.a, C.int(i))) + } + + return shape +} + +// Softmax implements ml.Tensor. 
+func (a *Array) Softmax(ctx ml.Context) ml.Tensor { + var r C.mlx_array + axes := []C.int{-1} + C.mlx_softmax( + &r, + a.a, + &axes[0], + C.size_t(len(axes)), + false, // TODO - precise? + ctx.(*Context).stream, + ) + return newArray(ctx.(*Context), r) +} + +// Stack implements ml.Tensor. +func (a *Array) Stack(ctx ml.Context, dim int, s ...ml.Tensor) ml.Tensor { + panic("unimplemented") +} + +// Stride implements ml.Tensor. +func (a *Array) Stride(n int) int { + return (int)(C.stride(a.a, (C.int)(n))) +} + +// Tanh implements ml.Tensor. +func (a *Array) Tanh(ctx ml.Context) ml.Tensor { + panic("unimplemented") +} + +// Unpad implements ml.Tensor. +func (a *Array) Unpad(ctx ml.Context, shape ...int) ml.Tensor { + panic("unimplemented") +} + +// View implements ml.Tensor. +func (a *Array) View(ctx ml.Context, offset int, shape []int, stride []int) ml.Tensor { + if len(stride)+1 != len(shape) { + panic(fmt.Sprintf("malformed view request: shape=%v stride=%v", shape, stride)) + } + + var r C.mlx_array + var sh []C.int + var st []C.size_t + var stp *C.size_t + switch len(shape) { + case 1: + sh = []C.int{ + C.int(shape[0]), + } + case 2: + sh = []C.int{ + C.int(shape[0]), + C.int(shape[1]), + } + // st = []C.size_t{ + // C.size_t(stride[0]), + // } + case 3: + sh = []C.int{ + C.int(shape[0]), + C.int(shape[1]), + C.int(shape[2]), + } + // st = []C.size_t{ + // C.size_t(stride[0]), + // C.size_t(stride[1]), + // } + case 4: + sh = []C.int{ + C.int(shape[0]), + C.int(shape[1]), + C.int(shape[2]), + C.int(shape[3]), + } + // st = []C.size_t{ + // C.size_t(stride[0]), + // C.size_t(stride[1]), + // C.size_t(stride[2]), + // } + default: + panic("unsupported number of dimensions") + } + if len(st) > 0 { + stp = (*C.size_t)(unsafe.Pointer(&st[0])) + } + C.mlx_as_strided( + &r, + a.a, + (*C.int)(unsafe.Pointer(&sh[0])), + C.size_t(len(sh)), + stp, + C.size_t(len(st)), + C.size_t(offset), + ctx.(*Context).stream, + ) + + return newArray(ctx.(*Context), r) +} + +func (t *Array) ScaledDotProductAttention(ctx ml.Context, keys, values, mask ml.Tensor, scale float64) ml.Tensor { + var r C.mlx_array + var m C.mlx_array + if mask != nil { + m = mask.(*Array).a + } + + queries := t.Reshape(ctx, append([]int{1}, t.Shape()...)...).Permute(ctx, 0, 2, 1, 3) + keys = keys.Reshape(ctx, append([]int{1}, keys.Shape()...)...).Permute(ctx, 0, 2, 1, 3) + values = values.Reshape(ctx, append([]int{1}, values.Shape()...)...).Permute(ctx, 0, 2, 1, 3) + + C.mlx_fast_scaled_dot_product_attention( + &r, + queries.(*Array).a, + keys.(*Array).a, + values.(*Array).a, + C.float(scale), + m, + C.mlx_optional_int{}, + ctx.(*Context).stream, + ) + res := newArray(ctx.(*Context), r) + return res.Reshape(ctx, append([]int{}, res.Shape()[1:]...)...).Permute(ctx, 1, 0, 2, 3) +} + +func (t Array) AvgPool2D(ctx ml.Context, k, s int, p float32) ml.Tensor { + panic("NOT YET IMPLEMENTED") +} + +func (t Array) Set(ctx ml.Context, t2 ml.Tensor, offset int, strides ...int) ml.Tensor { + panic("NOT YET IMPLEMENTED") +} + +func (ctx *Context) SliceUpdate(target, source ml.Tensor, start, stop, strides []int) { + t := target.(*Array) + cStart := make([]C.int, len(start)) + for i := range start { + cStart[i] = C.int(start[i]) + } + cStop := make([]C.int, len(stop)) + for i := range stop { + cStop[i] = C.int(stop[i]) + } + cStrides := make([]C.int, len(strides)) + for i := range strides { + cStrides[i] = C.int(strides[i]) + } + var r C.mlx_array + C.mlx_slice_update( + &r, + t.a, + source.(*Array).a, + (*C.int)(unsafe.Pointer(&cStart[0])), + 
C.size_t(len(cStart)), + (*C.int)(unsafe.Pointer(&cStop[0])), + C.size_t(len(cStop)), + (*C.int)(unsafe.Pointer(&cStrides[0])), + C.size_t(len(cStrides)), + ctx.stream, + ) + // Release the old array and replace with the new one to ensure the same underlying buffer is used + C.mlx_array_free(t.a) + t.a = r +} + +func (ctx *Context) Slice(source ml.Tensor, start, stop, strides []int) ml.Tensor { + cStart := make([]C.int, len(start)) + for i := range start { + cStart[i] = C.int(start[i]) + } + cStop := make([]C.int, len(stop)) + for i := range stop { + cStop[i] = C.int(stop[i]) + } + cStrides := make([]C.int, len(strides)) + for i := range strides { + cStrides[i] = C.int(strides[i]) + } + var r C.mlx_array + C.mlx_slice( + &r, + source.(*Array).a, + (*C.int)(unsafe.Pointer(&cStart[0])), + C.size_t(len(cStart)), + (*C.int)(unsafe.Pointer(&cStop[0])), + C.size_t(len(cStop)), + (*C.int)(unsafe.Pointer(&cStrides[0])), + C.size_t(len(cStrides)), + ctx.stream, + ) + return newArray(ctx, r) +} + +// TODO remove this before merging - temporary debugging aid +func (c *Context) Abort(t ml.Tensor) { + // str := C.mlx_string_new() + // C.mlx_array_tostring(&str, t.(*Array).a) + // s := C.mlx_string_data(str) + // defer C.mlx_string_free(str) + debug.PrintStack() + // fmt.Printf("shape%v\n", t.Shape()) + // fmt.Println(C.GoString(s)) + + c.Compute(t) + f32 := t.Floats() + + filename := os.Getenv("OLLAMA_BACKEND") + ".json" + slog.Info("Writing tensors to", "filename", filename) + f, err := os.Create(filename) + if err != nil { + panic(err) + } + defer f.Close() + encoder := json.NewEncoder(f) + err = encoder.Encode(f32) + if err != nil { + panic(err) + } + + os.Exit(1) +} + +func (a *Array) TypeString() string { + switch C.mlx_array_dtype(a.a) { + case C.MLX_BOOL: + return "bool" + case C.MLX_UINT8: + return "uint8" + case C.MLX_UINT16: + return "uint16" + case C.MLX_UINT32: + return "uint32" + case C.MLX_UINT64: + return "uint64" + case C.MLX_INT8: + return "int8" + case C.MLX_INT16: + return "int16" + case C.MLX_INT32: + return "int32" + case C.MLX_INT64: + return "int64" + case C.MLX_FLOAT16: + return "float16" + case C.MLX_FLOAT32: + return "float32" + case C.MLX_BFLOAT16: + return "bfloat16" + case C.MLX_COMPLEX64: + return "complex64" + default: + return "unknown" + } +} diff --git a/ml/backend/mlx/quant.go b/ml/backend/mlx/quant.go new file mode 100644 index 000000000..e447eacd8 --- /dev/null +++ b/ml/backend/mlx/quant.go @@ -0,0 +1,328 @@ +package mlx + +/* +#include +#include + +#include "mlx/c/array.h" +#include "mlx/c/ops.h" + +// Derived from https://github.com/ml-explore/mlx/blob/main/mlx/io/gguf_quants.cpp + +void unpack_32_4(uint8_t* data, int8_t* dst) { + memset(dst, 0, 16); + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j + 2] & 0x0F); // j+2 to skip scale bytes. + if (j % 2 != 0) { + x <<= 4; + } + dst[j / 2] += x; + } + // Last 16 weights are in the higher bits + for (int j = 0; j < 16; ++j) { + uint8_t x = (data[j + 2] >> 4); + if (j % 2 != 0) { + x <<= 4; + } + dst[8 + j / 2] += x; + } +} + +// Extracts (weight, scales, biases) from Q4_0 tensors. +// Data layout is: |16 bit scale|32 x 4bit weights|. 
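+// In Q4_0 each 18-byte block decodes as w = d * (q - 8), with d the fp16 scale
+// and q a 4-bit quant in [0, 15]; storing scales[i] = d and biases[i] = -8 * d
+// lets mlx_dequantize reconstruct w = scales * q + biases (group size 32, 4 bits).
+// Worked example: d = 0.1, q = 15  ->  w = 0.1 * 15 - 0.8 = 0.7.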
+void extract_q4_0_data( + uint8_t* data, + mlx_array* weights_arr, + mlx_array* scales_arr, + mlx_array* biases_arr) { + const uint64_t bytes_per_block = 18; // 2 bytes scale, 32x0.5 byte weights + uint8_t* weights = mlx_array_data_uint8(*weights_arr); + float16_t* scales = mlx_array_data_float16(*scales_arr); + float16_t* biases = mlx_array_data_float16(*biases_arr); + for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { + scales[i] = *((float16_t*)data); + biases[i] = -8 * scales[i]; + unpack_32_4(data, weights); + weights += 16; + data += bytes_per_block; + } +} + +// Extracts (weight, scales, biases) from Q4_1 tensors. +// Data layout is: |16 bit scale|16 bit bias|32 x 4bit weights|. +void extract_q4_1_data( + uint8_t* data, + mlx_array* weights_arr, + mlx_array* scales_arr, + mlx_array* biases_arr) { + const uint64_t bytes_per_block = 20; // 2 bytes scale, 2 bytes bias, 32x0.5 byte weights + uint8_t* weights = mlx_array_data_uint8(*weights_arr); + float16_t* scales = mlx_array_data_float16(*scales_arr); + float16_t* biases = mlx_array_data_float16(*biases_arr); + for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { + scales[i] = *((float16_t*)data); + biases[i] = *((float16_t*)(data) + 1); + unpack_32_4(data, weights); + weights += 16; + data += bytes_per_block; + } +} + +// Extracts (weight, scales, biases) from Q8_0 tensors. +// Data layout is: |16 bit scale|32 x 8bit weights|. +void extract_q8_0_data( + uint8_t* data, + mlx_array* weights_arr, + mlx_array* scales_arr, + mlx_array* biases_arr) { + const uint64_t weights_per_block = 32; + const uint64_t bytes_per_block = 34; // 2 bytes scale, 32x1 byte weights + uint8_t* weights = mlx_array_data_uint8(*weights_arr); + float16_t* scales = mlx_array_data_float16(*scales_arr); + float16_t* biases = mlx_array_data_float16(*biases_arr); + for (int64_t i = 0; i < mlx_array_size(*scales_arr); i++) { + uint8_t* block_data = data + i * bytes_per_block; + scales[i] = *((float16_t*)block_data); + biases[i] = -128 * scales[i]; + for (int64_t j = 0; j < weights_per_block; ++j) { + uint8_t x = block_data[j + 2]; // j+2 to skip the scale bytes. + // Original data is in int8_t, so we add a bias of -128 and invert the + // first bit. 
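+      // Flipping the sign bit maps int8 two's complement onto uint8 offset-binary
+      // (q + 128), so scale * x + bias with bias = -128 * scale recovers the
+      // original d * q when mlx_dequantize runs.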
+ x ^= 1 << 7; + weights[i * weights_per_block + j] = x; + } + } +} + +// Drived from ggml-quants.c + +#define QK_K 256 + +// 6-bit quantization +// weight is represented as x = a * q +// 16 blocks of 16 elements each +// Effectively 6.5625 bits per weight +typedef struct { + uint8_t ql[QK_K/2]; // quants, lower 4 bits + uint8_t qh[QK_K/4]; // quants, upper 2 bits + int8_t scales[QK_K/16]; // scales, quantized with 8 bits + uint16_t d; // super-block scale +} block_q6_K; + +void dequant_row_q6_K(const void * restrict vx, void * restrict vy, int k) { + const int64_t nb = k / QK_K; + block_q6_K *x = (block_q6_K *)vx; + float16_t* y = (float16_t *)vy; + + for (int i = 0; i < nb; i++) { + float16_t d = 0.0; + memcpy(&d, &x[i].d, sizeof(d)); + + const uint8_t * restrict ql = x[i].ql; + const uint8_t * restrict qh = x[i].qh; + const int8_t * restrict sc = x[i].scales; + + for (int n = 0; n < QK_K; n += 128) { + for (int l = 0; l < 32; ++l) { + int is = l/16; + const int8_t q1 = (int8_t)((ql[l + 0] & 0xF) | (((qh[l] >> 0) & 3) << 4)) - 32; + const int8_t q2 = (int8_t)((ql[l + 32] & 0xF) | (((qh[l] >> 2) & 3) << 4)) - 32; + const int8_t q3 = (int8_t)((ql[l + 0] >> 4) | (((qh[l] >> 4) & 3) << 4)) - 32; + const int8_t q4 = (int8_t)((ql[l + 32] >> 4) | (((qh[l] >> 6) & 3) << 4)) - 32; + y[l + 0] = d * sc[is + 0] * q1; + y[l + 32] = d * sc[is + 2] * q2; + y[l + 64] = d * sc[is + 4] * q3; + y[l + 96] = d * sc[is + 6] * q4; + } + y += 128; + ql += 64; + qh += 32; + sc += 8; + } + } +} + +#define K_SCALE_SIZE 12 +#define GGML_COMMON_AGGR_U +#define GGML_COMMON_AGGR_S + +// 4-bit quantization +// 8 blocks of 32 elements each +// weight is represented as x = a * q + b +// Effectively 4.5 bits per weight +typedef struct { + union { + struct { + uint16_t d; // super-block scale for quantized scales + uint16_t dmin; // super-block scale for quantized mins + } GGML_COMMON_AGGR_S; + uint16_t dm; + } GGML_COMMON_AGGR_U; + uint8_t scales[K_SCALE_SIZE]; // scales and mins, quantized with 6 bits + uint8_t qs[QK_K/2]; // 4--bit quants +} block_q4_K; + +static inline void get_scale_min_k4(int j, const uint8_t * restrict q, uint8_t * restrict d, uint8_t * restrict m) { + if (j < 4) { + *d = q[j] & 63; *m = q[j + 4] & 63; + } else { + *d = (q[j+4] & 0xF) | ((q[j-4] >> 6) << 4); + *m = (q[j+4] >> 4) | ((q[j-0] >> 6) << 4); + } +} + +void dequant_row_q4_K(const void * restrict vx, void * restrict vy, int k) { + block_q4_K *x = (block_q4_K *)vx; + float16_t* y = (float16_t *)vy; + const int nb = k / QK_K; + + for (int i = 0; i < nb; i++) { + const uint8_t * q = x[i].qs; + float16_t d = 0.0; + memcpy(&d, &x[i].d, sizeof(d)); + float16_t min = 0.0; + memcpy(&min, &x[i].dmin, sizeof(d)); + + int is = 0; + uint8_t sc, m; + for (int j = 0; j < QK_K; j += 64) { + get_scale_min_k4(is + 0, x[i].scales, &sc, &m); + const float16_t d1 = d * sc; const float16_t m1 = min * m; + get_scale_min_k4(is + 1, x[i].scales, &sc, &m); + const float16_t d2 = d * sc; const float16_t m2 = min * m; + for (int l = 0; l < 32; ++l) *y++ = d1 * (q[l] & 0xF) - m1; + for (int l = 0; l < 32; ++l) *y++ = d2 * (q[l] >> 4) - m2; + q += 32; is += 2; + } + } +} + + + +*/ +import "C" + +import ( + "fmt" + "unsafe" + + "github.com/x448/float16" +) + +func gguf_load_quantized(data unsafe.Pointer, name string, final_shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) { + shape := append([]C.int{}, final_shape...) 
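+	// Layout handed to mlx_dequantize below: quants are re-packed 8-per-uint32 for
+	// the 4-bit types and 4-per-uint32 for Q8_0 (hence the last dimension divided
+	// by weights_per_byte*4), with one fp16 scale and one fp16 bias per 32-weight
+	// block.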
+ var weights_per_byte C.int + if dtype == 2 || dtype == 3 { + weights_per_byte = 2 + } else if dtype == 8 { + weights_per_byte = 1 + } else { + return r, fmt.Errorf("unsupported tensor type %d", dtype) + } + + weights_per_block := C.int(32) + if shape[len(shape)-1]%weights_per_block != 0 { + return r, fmt.Errorf("[load_gguf] tensor has incompatible last dim shape: %d", shape[len(shape)-1]) + } + + weights_shape := append([]C.int{}, shape...) + weights_shape[len(weights_shape)-1] /= (weights_per_byte * 4) + w_nbytes := C.int(unsafe.Sizeof(uint32(0))) + for i := range weights_shape { + w_nbytes *= weights_shape[i] + } + w_data := make([]byte, w_nbytes) + cbytes := C.CBytes(w_data) + defer C.free(cbytes) + weights := C.mlx_array_new_data( + cbytes, + &weights_shape[0], + C.int(len(weights_shape)), + C.MLX_UINT32, + ) + + // For scales and bias + shape[len(shape)-1] = shape[len(shape)-1] / weights_per_block + sb_nbytes := C.int(unsafe.Sizeof(float16.Float16(0))) + for i := range shape { + sb_nbytes *= shape[i] + } + + s_data := make([]byte, sb_nbytes) + cbytes = C.CBytes(s_data) + defer C.free(cbytes) + scales := C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + b_data := make([]byte, sb_nbytes) + cbytes = C.CBytes(b_data) + defer C.free(cbytes) + biases := C.mlx_array_new_data( + cbytes, + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + var bits C.int + switch dtype { + case 2: + C.extract_q4_0_data((*C.uint8_t)(data), &weights, &scales, &biases) + bits = 4 + case 3: + C.extract_q4_1_data((*C.uint8_t)(data), &weights, &scales, &biases) + bits = 4 + case 8: + C.extract_q8_0_data((*C.uint8_t)(data), &weights, &scales, &biases) + bits = 8 + } + C.mlx_dequantize( + &r, + weights, + scales, + biases, + 32, // group size + bits, + stream, + ) + C.mlx_array_free(weights) + C.mlx_array_free(scales) + C.mlx_array_free(biases) + + return r, nil +} + +func load_k_quantized(data unsafe.Pointer, name string, shape []C.int, dtype uint32, stream C.mlx_stream) (r C.mlx_array, err error) { + size := 1 + for _, d := range shape { + size *= int(d) + } + fdata := make([]float16.Float16, size) + switch dtype { + case 14: + C.dequant_row_q6_K( + data, + unsafe.Pointer(&fdata[0]), + C.int(size), + ) + + case 12: + C.dequant_row_q4_K( + data, + unsafe.Pointer(&fdata[0]), + C.int(size), + ) + default: + return r, fmt.Errorf("unsupported K quant") + } + + r = C.mlx_array_new_data( + unsafe.Pointer(&fdata[0]), + &shape[0], + C.int(len(shape)), + C.MLX_FLOAT16, + ) + return r, nil +} diff --git a/ml/nn/linear.go b/ml/nn/linear.go index 3985dd6c8..aa711ccc1 100644 --- a/ml/nn/linear.go +++ b/ml/nn/linear.go @@ -1,6 +1,8 @@ package nn -import "github.com/ollama/ollama/ml" +import ( + "github.com/ollama/ollama/ml" +) type Linear struct { Weight ml.Tensor `gguf:"weight"` diff --git a/model/models/gemma2/model.go b/model/models/gemma2/model.go index 6087f8d79..9bd483be6 100644 --- a/model/models/gemma2/model.go +++ b/model/models/gemma2/model.go @@ -82,7 +82,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -92,7 +92,7 @@ func (sa 
*SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, opts.ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen) @@ -122,7 +122,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil + return key.RoPE(ctx, shift, nil, nil, uint32(m.Options.attnKeyLen), uint32(2), m.Options.ropeBase, m.Options.ropeScale), nil } type MLP struct { diff --git a/model/models/gemma3/model_text.go b/model/models/gemma3/model_text.go index 6a57fc907..05042e740 100644 --- a/model/models/gemma3/model_text.go +++ b/model/models/gemma3/model_text.go @@ -96,7 +96,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, batchSize, opts.numHeads, opts.attnKeyLen) q = sa.QueryNorm.Forward(ctx, q, opts.eps) - q = q.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + q = q.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) if opts.largeModelScaling { q = q.Scale(ctx, 1.0/math.Sqrt(float64(opts.hiddenSize/opts.numHeads))) @@ -107,7 +107,7 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, layer int, hiddenState, pos k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnKeyLen) k = sa.KeyNorm.Forward(ctx, k, opts.eps) - k = k.RoPE(ctx, positionIDs, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) + k = k.RoPE(ctx, positionIDs, nil, nil, uint32(opts.attnKeyLen), ropeType, ropeBase, opts.ropeScale) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, batchSize, opts.numKVHeads, opts.attnValLen) @@ -125,7 +125,7 @@ func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.T ropeBase = m.TextOptions.ropeGlobalBase } - return key.RoPE(ctx, shift, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil + return key.RoPE(ctx, shift, nil, nil, uint32(m.TextOptions.attnKeyLen), uint32(2), ropeBase, m.TextOptions.ropeScale), nil } type TextMLP struct { diff --git a/model/models/llama/model.go b/model/models/llama/model.go index 452edadd9..bad60574a 100644 --- a/model/models/llama/model.go +++ b/model/models/llama/model.go @@ -76,15 +76,14 @@ type SelfAttention struct { func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Tensor, cache kvcache.Cache, opts *Options) ml.Tensor { batchSize := hiddenState.Dim(0) // TODO Consider renaming "L" as this is the sequence length, not batch size headDim := opts.hiddenSize / opts.numHeads - ropeType := uint32(0) q := sa.Query.Forward(ctx, hiddenState) q = q.Reshape(ctx, batchSize, opts.numHeads, -1) - q = q.RoPE(ctx, positionIDs, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + q = LlamaRoPE(ctx, q, positionIDs, sa.RopeFactors, opts) k := sa.Key.Forward(ctx, hiddenState) k = k.Reshape(ctx, batchSize, opts.numKVHeads, -1) - k = k.RoPE(ctx, positionIDs, 
sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) + k = LlamaRoPE(ctx, k, positionIDs, sa.RopeFactors, opts) v := sa.Value.Forward(ctx, hiddenState) v = v.Reshape(ctx, batchSize, opts.numKVHeads, -1) @@ -97,7 +96,7 @@ func (sa *SelfAttention) Forward(ctx ml.Context, hiddenState, positionIDs ml.Ten } func (m *Model) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) { - return key.RoPE(ctx, shift, m.Layers[layer].SelfAttention.RopeFactors, uint32(0), m.ropeDim, m.ropeBase, m.ropeScale), nil + return LlamaRoPE(ctx, key, shift, m.Layers[layer].SelfAttention.RopeFactors, m.Options), nil } type MLP struct { diff --git a/model/models/llama/utils.go b/model/models/llama/utils.go new file mode 100644 index 000000000..6826ddce4 --- /dev/null +++ b/model/models/llama/utils.go @@ -0,0 +1,82 @@ +package llama + +import ( + "math" + "sync" + + "github.com/ollama/ollama/ml" + "github.com/ollama/ollama/model" +) + +func LlamaRoPE(ctx ml.Context, x, positionIDs, ropeFactors ml.Tensor, opts *Options) ml.Tensor { + var once sync.Once + var _freqs ml.Tensor + dims := opts.ropeDim + onceBody := func() { + // Reference: https://github.com/ml-explore/mlx-examples/blob/main/llms/mlx_lm/models/rope_utils.py#L9 + + base := opts.ropeBase // aka rope_scale + if base == 0 { + base = 10000.0 + } + low_freq_factor := opts.ropeScale // ??? + high_freq_factor := float32(4.0) // TODO should attempt to get from metadata + factor := float32(8.0) // metadata? + old_context_len := float32(8192) // metadata? (aka original_max_position_embeddings) + + // Calcs... + low_freq_wavelen := float32(old_context_len) / low_freq_factor + high_freq_wavelen := float32(old_context_len) / high_freq_factor + + // freqs = base ** (mx.model.ArangeF32(0, dims, 2) / dims) + freqs := model.ArangeF32(0, float32(dims), 2) + for i := range freqs { + freqs[i] = (float32)(math.Pow(float64(base), float64(freqs[i])/float64(dims))) + } + // wavelens = 2 * mx.pi * freqs + wavelens := make([]float32, len(freqs)) + for i := range wavelens { + wavelens[i] = freqs[i] * 2 * float32(math.Pi) + } + // freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + for i := range freqs { + if wavelens[i] > low_freq_wavelen { + freqs[i] = freqs[i] * factor + } + } + // is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + is_medium_freq := make([]bool, len(freqs)) + for i := range freqs { + is_medium_freq[i] = (wavelens[i] > high_freq_wavelen) && (wavelens[i] < low_freq_wavelen) + } + // smooth_factors = (old_context_len / wavelens - low_freq_factor) / (high_freq_factor - low_freq_factor) + smooth_factors := make([]float32, len(freqs)) + for i := range freqs { + smooth_factors[i] = ((old_context_len)/wavelens[i] - (low_freq_factor)) / ((high_freq_factor) - (low_freq_factor)) + } + // smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + smooth_freqs := make([]float32, len(freqs)) + for i := range freqs { + smooth_freqs[i] = freqs[i] / ((1-smooth_factors[i])/factor + (smooth_factors[i])) + } + // _freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + for i := range freqs { + if is_medium_freq[i] { + freqs[i] = float32(smooth_freqs[i]) + } + } + _freqs, _ = ctx.Input().FromFloatSlice(freqs, len(freqs)) + } + once.Do(onceBody) + + return x.RoPE( + ctx, + positionIDs, + ropeFactors, + _freqs, + dims, + 0, // type + 500000, // base + 1.0, // scale + ) +} diff --git a/model/models/mllama/model_text.go b/model/models/mllama/model_text.go index cc2d0b26e..20b9337f1 100644 
--- a/model/models/mllama/model_text.go
+++ b/model/models/mllama/model_text.go
@@ -24,11 +24,11 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 	query := sa.Query.Forward(ctx, hiddenState)
 	query = query.Reshape(ctx, batchSize, opts.numHeads, headDim)
-	query = query.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	query = query.RoPE(ctx, positions, sa.RopeFactors, nil /* TODO freqs */, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	key := sa.Key.Forward(ctx, hiddenState)
 	key = key.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
-	key = key.RoPE(ctx, positions, sa.RopeFactors, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
+	key = key.RoPE(ctx, positions, sa.RopeFactors, nil /* TODO freqs */, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale)
 
 	value := sa.Value.Forward(ctx, hiddenState)
 	value = value.Reshape(ctx, batchSize, opts.numKVHeads, headDim)
@@ -42,8 +42,9 @@ func (sa *TextSelfAttention) Forward(ctx ml.Context, hiddenState, positions, _ m
 
 func (m *TextModel) Shift(ctx ml.Context, layer int, key, shift ml.Tensor) (ml.Tensor, error) {
 	// This will only get called for layers in the cache, which are just the self attention layers
-	if sa, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
-		return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+	if _, ok := m.Transformer.Layers[layer].(*TextSelfAttentionDecoderLayer); ok {
+		// return key.RoPE(ctx, shift, sa.SelfAttention.RopeFactors, m.ropeDim, uint32(0), m.ropeBase, m.ropeScale), nil
+		panic("NOT YET IMPLEMENTED")
 	}
 
 	return key, nil
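
A minimal sketch of how the widened RoPE signature is exercised by the two backends in this change (argument names follow ml/backend.go; the GGML implementation ignores freqs, while the MLX path in model/models/llama/utils.go drives mlx_fast_rope from freqs and ignores positionIDs/ropeFactors):

	// RoPE(ctx, positionIDs, ropeFactors, freqs, dim, ropeType, base, scale)
	q = q.RoPE(ctx, positionIDs, sa.RopeFactors, nil, opts.ropeDim, ropeType, opts.ropeBase, opts.ropeScale) // GGML: freqs unused
	q = q.RoPE(ctx, positionIDs, nil, freqs, opts.ropeDim, 0, opts.ropeBase, opts.ropeScale)                 // MLX: freqs computed as in LlamaRoPE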