From d5a0d8d904baaf66a5326463a409fe4fa09b2dd2 Mon Sep 17 00:00:00 2001 From: Jesse Gross Date: Thu, 29 May 2025 12:21:48 -0700 Subject: [PATCH] llm: New memory management This changes the memory allocation strategy from upfront estimation to tracking actual allocations done by the engine and reacting to that. The goal is avoid issues caused by both under-estimation (crashing) and over-estimation (low performance due to under-utilized GPUs). It is currently opt-in and can be enabled for models running on the Ollama engine by setting OLLAMA_NEW_ESTIMATES=1. Behavior in other cases is unchanged and will continue to use the existing estimates. --- discover/amd_linux.go | 65 +- envconfig/config.go | 3 + fs/ggml/ggml.go | 2 + llama/llama.go | 16 + ... 0016-add-C-API-for-mtmd_input_text.patch} | 0 ...rary-prevent-rocm-cuda-mixed-loading.patch | 32 - ...no-power-throttling-win32-with-gnuc.patch} | 0 ...ch => 0018-BF16-macos-version-guard.patch} | 0 ...0019-Enable-CUDA-Graphs-for-gemma3n.patch} | 0 ...le-ggml-blas-on-macos-v13-and-older.patch} | 0 ...fix-mtmd-audio.cpp-build-on-windows.patch} | 0 ...de.patch => 0022-ggml-No-alloc-mode.patch} | 0 llm/memory.go | 94 +- llm/memory_test.go | 10 +- llm/server.go | 1049 ++++++++++++++--- llm/server_test.go | 169 +++ ml/backend.go | 162 ++- ml/backend/ggml/ggml.go | 226 ++-- ml/backend/ggml/ggml/src/ggml-backend-reg.cpp | 12 +- runner/llamarunner/runner.go | 146 ++- runner/ollamarunner/runner.go | 215 ++-- server/routes.go | 6 +- server/routes_generate_test.go | 6 +- server/routes_harmony_streaming_test.go | 9 +- server/sched.go | 369 ++---- server/sched_test.go | 169 +-- 26 files changed, 1860 insertions(+), 900 deletions(-) rename llama/patches/{0017-add-C-API-for-mtmd_input_text.patch => 0016-add-C-API-for-mtmd_input_text.patch} (100%) delete mode 100644 llama/patches/0016-temporary-prevent-rocm-cuda-mixed-loading.patch rename llama/patches/{0018-no-power-throttling-win32-with-gnuc.patch => 0017-no-power-throttling-win32-with-gnuc.patch} (100%) rename llama/patches/{0019-BF16-macos-version-guard.patch => 0018-BF16-macos-version-guard.patch} (100%) rename llama/patches/{0020-Enable-CUDA-Graphs-for-gemma3n.patch => 0019-Enable-CUDA-Graphs-for-gemma3n.patch} (100%) rename llama/patches/{0021-Disable-ggml-blas-on-macos-v13-and-older.patch => 0020-Disable-ggml-blas-on-macos-v13-and-older.patch} (100%) rename llama/patches/{0022-fix-mtmd-audio.cpp-build-on-windows.patch => 0021-fix-mtmd-audio.cpp-build-on-windows.patch} (100%) rename llama/patches/{0023-ggml-No-alloc-mode.patch => 0022-ggml-No-alloc-mode.patch} (100%) diff --git a/discover/amd_linux.go b/discover/amd_linux.go index dc9a4e185d..ebffbdf66c 100644 --- a/discover/amd_linux.go +++ b/discover/amd_linux.go @@ -97,6 +97,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { return a < b }) gpuCount := 0 + gpuOrdinalID := 0 for _, match := range matches { slog.Debug("evaluating amdgpu node " + match) fp, err := os.Open(match) @@ -187,10 +188,6 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { continue } - // Keep track of numeric IDs based on valid GPUs - gpuID := gpuCount - gpuCount += 1 - // Look up the memory for the current node totalMemory := uint64(0) usedMemory := uint64(0) @@ -269,7 +266,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { if uniqueID != 0 { ID = fmt.Sprintf("GPU-%016x", uniqueID) } else { - ID = strconv.Itoa(gpuID) + ID = strconv.Itoa(gpuOrdinalID) } gpuInfo := RocmGPUInfo{ @@ -287,13 +284,40 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { DriverMinor: driverMinor, }, usedFilepath: usedFile, - index: gpuID, + index: gpuCount, } + // Keep track of numeric IDs based on valid GPUs + gpuCount += 1 + + // If the user wants to filter to a subset of devices, filter out if we aren't a match + if len(visibleDevices) > 0 { + include := false + for _, visible := range visibleDevices { + if (uniqueID != 0 && visible == gpuInfo.ID) || visible == strconv.Itoa(gpuInfo.index) { + include = true + break + } + } + if !include { + reason := "filtering out device per user request" + slog.Info(reason, "id", gpuInfo.ID, "index", gpuInfo.index, "visible_devices", visibleDevices) + unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ + GpuInfo: gpuInfo.GpuInfo, + Reason: reason, + }) + + continue + } + } + + // Ordinal IDs are based on the visible GPUs + gpuOrdinalID += 1 + // iGPU detection, remove this check once we can support an iGPU variant of the rocm library if totalMemory < IGPUMemLimit { reason := "unsupported Radeon iGPU detected skipping" - slog.Info(reason, "id", gpuID, "total", format.HumanBytes2(totalMemory)) + slog.Info(reason, "id", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ GpuInfo: gpuInfo.GpuInfo, Reason: reason, @@ -306,7 +330,7 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { } if int(major) < minVer { reason := fmt.Sprintf("amdgpu too old gfx%d%x%x", major, minor, patch) - slog.Warn(reason, "gpu", gpuID) + slog.Warn(reason, "gpu", gpuInfo.ID) unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ GpuInfo: gpuInfo.GpuInfo, Reason: reason, @@ -315,29 +339,8 @@ func AMDGetGPUInfo() ([]RocmGPUInfo, error) { continue } - slog.Debug("amdgpu memory", "gpu", gpuID, "total", format.HumanBytes2(totalMemory)) - slog.Debug("amdgpu memory", "gpu", gpuID, "available", format.HumanBytes2(totalMemory-usedMemory)) - - // If the user wants to filter to a subset of devices, filter out if we aren't a match - if len(visibleDevices) > 0 { - include := false - for _, visible := range visibleDevices { - if visible == gpuInfo.ID || visible == strconv.Itoa(gpuInfo.index) { - include = true - break - } - } - if !include { - reason := "filtering out device per user request" - slog.Info(reason, "id", gpuInfo.ID, "visible_devices", visibleDevices) - unsupportedGPUs = append(unsupportedGPUs, UnsupportedGPUInfo{ - GpuInfo: gpuInfo.GpuInfo, - Reason: reason, - }) - - continue - } - } + slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "total", format.HumanBytes2(totalMemory)) + slog.Debug("amdgpu memory", "gpu", gpuInfo.ID, "available", format.HumanBytes2(totalMemory-usedMemory)) // Final validation is gfx compatibility - load the library if we haven't already loaded it // even if the user overrides, we still need to validate the library diff --git a/envconfig/config.go b/envconfig/config.go index 7fc0188703..868813ae85 100644 --- a/envconfig/config.go +++ b/envconfig/config.go @@ -185,6 +185,8 @@ var ( ContextLength = Uint("OLLAMA_CONTEXT_LENGTH", 4096) // Auth enables authentication between the Ollama client and server UseAuth = Bool("OLLAMA_AUTH") + // Enable the new memory estimation logic + NewMemoryEstimates = Bool("OLLAMA_NEW_ESTIMATES") ) func String(s string) func() string { @@ -270,6 +272,7 @@ func AsMap() map[string]EnvVar { "OLLAMA_MULTIUSER_CACHE": {"OLLAMA_MULTIUSER_CACHE", MultiUserCache(), "Optimize prompt caching for multi-user scenarios"}, "OLLAMA_CONTEXT_LENGTH": {"OLLAMA_CONTEXT_LENGTH", ContextLength(), "Context length to use unless otherwise specified (default: 4096)"}, "OLLAMA_NEW_ENGINE": {"OLLAMA_NEW_ENGINE", NewEngine(), "Enable the new Ollama engine"}, + "OLLAMA_NEW_ESTIMATES": {"OLLAMA_NEW_ESTIMATES", NewMemoryEstimates(), "Enable the new memory estimation logic"}, // Informational "HTTP_PROXY": {"HTTP_PROXY", String("HTTP_PROXY")(), "HTTP proxy"}, diff --git a/fs/ggml/ggml.go b/fs/ggml/ggml.go index 008d1e0fcd..1fef745249 100644 --- a/fs/ggml/ggml.go +++ b/fs/ggml/ggml.go @@ -480,6 +480,8 @@ func Decode(rs io.ReadSeeker, maxArraySize int) (*GGML, error) { } func (f GGML) GraphSize(context, batch uint64, numParallel int, kvCacheType string) (kv []uint64, partialOffload, fullOffload uint64) { + context *= uint64(numParallel) + embedding := f.KV().EmbeddingLength() heads := f.KV().HeadCountMax() headsKV := f.KV().HeadCountKVMax() diff --git a/llama/llama.go b/llama/llama.go index 4885949b73..ac2c112c29 100644 --- a/llama/llama.go +++ b/llama/llama.go @@ -62,6 +62,22 @@ func BackendInit() { C.llama_backend_init() } +func EnumerateGPUs() []string { + var ids []string + + for i := range C.ggml_backend_dev_count() { + device := C.ggml_backend_dev_get(i) + + if C.ggml_backend_dev_type(device) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + var props C.struct_ggml_backend_dev_props + C.ggml_backend_dev_get_props(device, &props) + ids = append(ids, C.GoString(props.id)) + } + } + + return ids +} + func GetModelArch(modelPath string) (string, error) { mp := C.CString(modelPath) defer C.free(unsafe.Pointer(mp)) diff --git a/llama/patches/0017-add-C-API-for-mtmd_input_text.patch b/llama/patches/0016-add-C-API-for-mtmd_input_text.patch similarity index 100% rename from llama/patches/0017-add-C-API-for-mtmd_input_text.patch rename to llama/patches/0016-add-C-API-for-mtmd_input_text.patch diff --git a/llama/patches/0016-temporary-prevent-rocm-cuda-mixed-loading.patch b/llama/patches/0016-temporary-prevent-rocm-cuda-mixed-loading.patch deleted file mode 100644 index f085e0c7ca..0000000000 --- a/llama/patches/0016-temporary-prevent-rocm-cuda-mixed-loading.patch +++ /dev/null @@ -1,32 +0,0 @@ -From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 -From: Daniel Hiltgen -Date: Sun, 22 Jun 2025 09:22:05 -0700 -Subject: [PATCH] temporary prevent rocm+cuda mixed loading - ---- - ggml/src/ggml-backend-reg.cpp | 12 ++++++++++-- - 1 file changed, 10 insertions(+), 2 deletions(-) - -diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp -index 3040b2aa..f1e9c180 100644 ---- a/ggml/src/ggml-backend-reg.cpp -+++ b/ggml/src/ggml-backend-reg.cpp -@@ -581,8 +581,16 @@ void ggml_backend_load_all_from_path(const char * dir_path) { - - ggml_backend_load_best("blas", silent, dir_path); - ggml_backend_load_best("cann", silent, dir_path); -- ggml_backend_load_best("cuda", silent, dir_path); -- ggml_backend_load_best("hip", silent, dir_path); -+ -+ // Avoid mixed hip+cuda configurations -+ const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES"); -+ const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES"); -+ if (!hip_devices && !rocr_devices) { -+ ggml_backend_load_best("cuda", silent, dir_path); -+ } else { -+ ggml_backend_load_best("hip", silent, dir_path); -+ } -+ - ggml_backend_load_best("metal", silent, dir_path); - ggml_backend_load_best("rpc", silent, dir_path); - ggml_backend_load_best("sycl", silent, dir_path); diff --git a/llama/patches/0018-no-power-throttling-win32-with-gnuc.patch b/llama/patches/0017-no-power-throttling-win32-with-gnuc.patch similarity index 100% rename from llama/patches/0018-no-power-throttling-win32-with-gnuc.patch rename to llama/patches/0017-no-power-throttling-win32-with-gnuc.patch diff --git a/llama/patches/0019-BF16-macos-version-guard.patch b/llama/patches/0018-BF16-macos-version-guard.patch similarity index 100% rename from llama/patches/0019-BF16-macos-version-guard.patch rename to llama/patches/0018-BF16-macos-version-guard.patch diff --git a/llama/patches/0020-Enable-CUDA-Graphs-for-gemma3n.patch b/llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch similarity index 100% rename from llama/patches/0020-Enable-CUDA-Graphs-for-gemma3n.patch rename to llama/patches/0019-Enable-CUDA-Graphs-for-gemma3n.patch diff --git a/llama/patches/0021-Disable-ggml-blas-on-macos-v13-and-older.patch b/llama/patches/0020-Disable-ggml-blas-on-macos-v13-and-older.patch similarity index 100% rename from llama/patches/0021-Disable-ggml-blas-on-macos-v13-and-older.patch rename to llama/patches/0020-Disable-ggml-blas-on-macos-v13-and-older.patch diff --git a/llama/patches/0022-fix-mtmd-audio.cpp-build-on-windows.patch b/llama/patches/0021-fix-mtmd-audio.cpp-build-on-windows.patch similarity index 100% rename from llama/patches/0022-fix-mtmd-audio.cpp-build-on-windows.patch rename to llama/patches/0021-fix-mtmd-audio.cpp-build-on-windows.patch diff --git a/llama/patches/0023-ggml-No-alloc-mode.patch b/llama/patches/0022-ggml-No-alloc-mode.patch similarity index 100% rename from llama/patches/0023-ggml-No-alloc-mode.patch rename to llama/patches/0022-ggml-No-alloc-mode.patch diff --git a/llm/memory.go b/llm/memory.go index b530000464..ee4be74196 100644 --- a/llm/memory.go +++ b/llm/memory.go @@ -4,7 +4,7 @@ import ( "fmt" "log/slog" "os" - "strconv" + "sort" "strings" "github.com/ollama/ollama/api" @@ -14,13 +14,79 @@ import ( "github.com/ollama/ollama/fs/ggml" ) +// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits +// The list of GPUs returned will always be the same brand (library) +// If the model can not be fit fully within the available GPU(s) nil is returned +func pickBestFullFitByLibrary(f *ggml.GGML, modelPath string, projectors []string, adapters []string, opts api.Options, gpus discover.GpuInfoList, numParallel int) discover.GpuInfoList { + for _, gl := range gpus.ByLibrary() { + sgl := append(make(discover.GpuInfoList, 0, len(gl)), gl...) + + // TODO - potentially sort by performance capability, existing models loaded, etc. + // TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them + // Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups + sort.Sort(sort.Reverse(discover.ByFreeMemory(sgl))) + + if !envconfig.SchedSpread() { + // Try to pack into as few GPUs as possible, starting from 1 GPU + for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ { + gpuSubset := sgl[:numGPUs] + ok, estimatedVRAM := PredictServerFit(gpuSubset, f, adapters, projectors, opts, numParallel) + + if ok { + slog.Info("new model will fit in available VRAM across minimum required GPUs, loading", + "model", modelPath, + "library", sgl[0].Library, + "parallel", numParallel, + "required", format.HumanBytes2(estimatedVRAM), + "gpus", numGPUs) + return gpuSubset + } + } + } else { + // TODO future refinements + // - if multiple Libraries, see if any single GPU in any Library will fit + // - try subsets of GPUs instead of just falling back to 1 or all in a family + + // Now try all the GPUS (OLLAMA_SCHED_SPREAD is set) + if ok, estimatedVRAM := PredictServerFit(sgl, f, adapters, projectors, opts, numParallel); ok { + slog.Info("new model will fit in available VRAM, loading", + "model", modelPath, + "library", sgl[0].Library, + "parallel", numParallel, + "required", format.HumanBytes2(estimatedVRAM), + "gpus", len(sgl)) + return sgl + } + } + } + return nil +} + +// If multiple Libraries are detected, pick the Library which loads the most layers for the model +func pickBestPartialFitByLibrary(f *ggml.GGML, projectors []string, adapters []string, opts api.Options, gpus discover.GpuInfoList, numParallel int) discover.GpuInfoList { + byLibrary := gpus.ByLibrary() + if len(byLibrary) <= 1 { + return gpus + } + var bestEstimate uint64 + var bestFit int + for i, gl := range byLibrary { + _, estimatedVRAM := PredictServerFit(gl, f, adapters, projectors, opts, numParallel) + if estimatedVRAM > bestEstimate { + bestEstimate = estimatedVRAM + bestFit = i + } + } + return byLibrary[bestFit] +} + // This algorithm looks for a complete fit to determine if we need to unload other models func PredictServerFit(allGpus discover.GpuInfoList, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (bool, uint64) { // Split up the GPUs by type and try them var estimatedVRAM uint64 for _, gpus := range allGpus.ByLibrary() { var layerCount int - estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel) + estimate := estimateGPULayers(gpus, f, projectors, opts, numParallel) layerCount, estimatedVRAM = estimate.Layers, estimate.VRAMSize if opts.NumGPU < 0 { if layerCount > 0 && layerCount >= int(f.KV().BlockCount()+1) { @@ -49,7 +115,7 @@ type MemoryEstimate struct { TotalSize uint64 // For multi-GPU scenarios, this provides the tensor split parameter - TensorSplit string + TensorSplit []int // For multi-GPU scenarios, this is the size in bytes per GPU GPUSizes []uint64 @@ -71,7 +137,7 @@ type MemoryEstimate struct { // Given a model and one or more GPU targets, predict how many layers and bytes we can load, and the total size // The GPUs provided must all be the same Library -func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate { +func estimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []string, opts api.Options, numParallel int) MemoryEstimate { // Graph size for a partial offload, applies to all GPUs var graphPartialOffload uint64 @@ -112,13 +178,9 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin for _, projector := range projectors { llamaEngineProjectorWeights += projectorMemoryRequirements(projector) - - // multimodal models require at least 2048 context - opts.NumCtx = max(opts.NumCtx, 2048) } if llamaEngineProjectorWeights == 0 { ollamaEngineProjectorWeights, ollamaEngineProjectorGraph = f.VisionGraphSize() - opts.NumCtx = max(opts.NumCtx, 2048) } layers := f.Tensors().GroupLayers() @@ -184,7 +246,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin // Reduce set of GPUs to only those that have sufficient space to fit overhead and at least one layer var layerCount int - layerCounts := make([]int, len(gpus)) + tensorSplit := make([]int, len(gpus)) gpuAllocations := make([]uint64, len(gpus)) type gs struct { i int @@ -248,7 +310,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) if g.g.FreeMemory > overhead+used+layerSize { gpuAllocations[g.i] += layerSize - layerCounts[g.i]++ + tensorSplit[g.i]++ layerCount++ break } else { @@ -273,7 +335,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin used := gpuAllocations[g.i] + max(graphPartialOffload, graphFullOffload) if g.g.FreeMemory > overhead+used+memoryLastLayer { gpuAllocations[g.i] += memoryLastLayer - layerCounts[g.i]++ + tensorSplit[g.i]++ layerCount++ break } @@ -288,7 +350,7 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin // Add the applicable (full or partial) graph allocations for i := range gpus { - if layerCounts[i] <= 0 { + if tensorSplit[i] <= 0 { continue } if fullyLoaded { @@ -310,14 +372,6 @@ func EstimateGPULayers(gpus []discover.GpuInfo, f *ggml.GGML, projectors []strin } memoryRequiredTotal = memoryRequiredPartial + overflow - tensorSplit := "" - if len(gpus) > 1 { - splits := make([]string, len(gpus)) - for i, count := range layerCounts { - splits[i] = strconv.Itoa(count) - } - tensorSplit = strings.Join(splits, ",") - } allocationsList := []string{} for _, a := range gpuAllocations { allocationsList = append(allocationsList, format.HumanBytes2(a)) diff --git a/llm/memory_test.go b/llm/memory_test.go index 1d4f7a98cc..49851006c6 100644 --- a/llm/memory_test.go +++ b/llm/memory_test.go @@ -61,7 +61,7 @@ func TestEstimateGPULayers(t *testing.T) { projectors := []string{} opts := api.DefaultOptions() t.Run("cpu", func(t *testing.T) { - estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1) + estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1) assert.Equal(t, 0, estimate.Layers) assert.Equal(t, uint64(0), estimate.Graph) }) @@ -88,7 +88,7 @@ func TestEstimateGPULayers(t *testing.T) { // Nested array: GPU0 layer space, GPU1 layer space, expected gpu0, expected gpu1 for i, s := range []struct { layer0, layer1 uint64 - expect0, expect1 uint64 + expect0, expect1 int }{ {1, 1, 1, 1}, {2, 1, 2, 1}, @@ -112,9 +112,9 @@ func TestEstimateGPULayers(t *testing.T) { gpus[1].FreeMemory += gpuMinimumMemory + layerSize + s.layer1*layerSize + 1 gpus[0].FreeMemory += max(graphFullOffload, graphPartialOffload) gpus[1].FreeMemory += max(graphFullOffload, graphPartialOffload) - estimate := EstimateGPULayers(gpus, ggml, projectors, opts, 1) - assert.Equal(t, int(s.expect0+s.expect1), estimate.Layers, "scenario %d: %v", i, s) - assert.Equal(t, fmt.Sprintf("%d,%d", s.expect0, s.expect1), estimate.TensorSplit, "scenario %d: %v", i, s) + estimate := estimateGPULayers(gpus, ggml, projectors, opts, 1) + assert.Equal(t, s.expect0+s.expect1, estimate.Layers, "scenario %d: %v", i, s) + assert.Equal(t, []int{s.expect0, s.expect1}, estimate.TensorSplit, "scenario %d: %v", i, s) var layerSums uint64 for _, b := range estimate.GPUSizes { layerSums += b diff --git a/llm/server.go b/llm/server.go index 7d921f1443..01224a166f 100644 --- a/llm/server.go +++ b/llm/server.go @@ -18,6 +18,7 @@ import ( "path/filepath" "runtime" "slices" + "sort" "strconv" "strings" "sync" @@ -32,6 +33,7 @@ import ( "github.com/ollama/ollama/fs/ggml" "github.com/ollama/ollama/llama" "github.com/ollama/ollama/logutil" + "github.com/ollama/ollama/ml" "github.com/ollama/ollama/model" ) @@ -63,6 +65,8 @@ func (e filteredEnv) LogValue() slog.Value { } type LlamaServer interface { + ModelPath() string + Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error Ping(ctx context.Context) error WaitUntilRunning(ctx context.Context) error Completion(ctx context.Context, req CompletionRequest, fn func(CompletionResponse)) error @@ -70,13 +74,13 @@ type LlamaServer interface { Tokenize(ctx context.Context, content string) ([]int, error) Detokenize(ctx context.Context, tokens []int) (string, error) Close() error - EstimatedVRAM() uint64 // Total VRAM across all GPUs - EstimatedTotal() uint64 - EstimatedVRAMByGPU(gpuID string) uint64 + VRAMSize() uint64 // Total VRAM across all GPUs + TotalSize() uint64 + VRAMByGPU(gpuID string) uint64 Pid() int } -// llmServer is an instance of the llama.cpp server +// llmServer is an instance of a runner hosting a single model type llmServer struct { port int cmd *exec.Cmd @@ -86,25 +90,38 @@ type llmServer struct { numParallel int modelPath string + loadRequest LoadRequest // Parameters used to initialize the runner + // llamaModel is an instance of the cgo llama.cpp model definition // nil if this server is running the new engine llamaModel *llama.Model - llamaModelLock sync.Mutex + llamaModelLock *sync.Mutex // textProcessor handles text encoding/decoding for the model in the Ollama engine // nil if this server is running the llama.cpp based engine textProcessor model.TextProcessor - estimate MemoryEstimate - totalLayers uint64 - // gpuCount int - gpus discover.GpuInfoList // Recorded just before the model loaded, free space will be incorrect - loadDuration time.Duration // Record how long it took the model to load + totalLayers uint64 + loadStart time.Time // Record how long it took the model to load loadProgress float32 sem *semaphore.Weighted } +type llamaServer struct { + llmServer + + ggml *ggml.GGML + gpus discover.GpuInfoList // The set of GPUs covered by the memory estimate + estimate MemoryEstimate +} + +type ollamaServer struct { + llmServer + + mem *ml.BackendMemory +} + // LoadModel will load a model from disk. The model must be in the GGML format. // // It collects array values for arrays with a size less than or equal to @@ -126,81 +143,57 @@ func LoadModel(model string, maxArraySize int) (*ggml.GGML, error) { } // NewLlamaServer will run a server for the given GPUs -// The gpu list must be a single family. func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, adapters, projectors []string, opts api.Options, numParallel int) (LlamaServer, error) { - systemInfo := discover.GetSystemInfo() - systemTotalMemory := systemInfo.System.TotalMemory - systemFreeMemory := systemInfo.System.FreeMemory - systemSwapFreeMemory := systemInfo.System.FreeSwap - slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory)) + var llamaModel *llama.Model + var textProcessor model.TextProcessor + var err error + if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { + textProcessor, err = model.NewTextProcessor(modelPath) + if err != nil { + // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner + slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) + } + } + if textProcessor == nil { + llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true}) + if err != nil { + return nil, err + } + } - // If the user wants zero GPU layers, reset the gpu list to be CPU/system ram info - if opts.NumGPU == 0 { - gpus = discover.GetCPUInfo() + newEstimates := textProcessor != nil && envconfig.NewMemoryEstimates() + if newEstimates { + slog.Info("enabling new memory estimates") } // Verify the requested context size is <= the model training size trainCtx := f.KV().ContextLength() - if opts.NumCtx/numParallel > int(trainCtx) && trainCtx > 0 { - slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "num_parallel", numParallel, "n_ctx_train", trainCtx) - opts.NumCtx = int(trainCtx) * numParallel + if opts.NumCtx > int(trainCtx) && trainCtx > 0 { + slog.Warn("requested context size too large for model", "num_ctx", opts.NumCtx, "n_ctx_train", trainCtx) + opts.NumCtx = int(trainCtx) } - estimate := EstimateGPULayers(gpus, f, projectors, opts, numParallel) - if len(gpus) > 1 || gpus[0].Library != "cpu" { - switch { - case gpus[0].Library == "metal" && estimate.VRAMSize > systemTotalMemory: - // disable partial offloading when model is greater than total system memory as this - // can lead to locking up the system - opts.NumGPU = 0 - case gpus[0].Library != "metal" && estimate.Layers == 0: - // Don't bother loading into the GPU if no layers can fit - gpus = discover.GetCPUInfo() - case opts.NumGPU < 0 && estimate.Layers > 0 && gpus[0].Library != "cpu": - opts.NumGPU = estimate.Layers - } + loadRequest := LoadRequest{LoraPath: adapters, KvSize: opts.NumCtx * numParallel, BatchSize: opts.NumBatch, Parallel: numParallel, MultiUserCache: envconfig.MultiUserCache()} + + defaultThreads := discover.GetSystemInfo().GetOptimalThreadCount() + if opts.NumThread > 0 { + loadRequest.NumThreads = opts.NumThread + } else if defaultThreads > 0 { + loadRequest.NumThreads = defaultThreads } - // On linux and windows, over-allocating CPU memory will almost always result in an error - // Darwin has fully dynamic swap so has no direct concept of free swap space - if runtime.GOOS != "darwin" { - systemMemoryRequired := estimate.TotalSize - estimate.VRAMSize - available := systemFreeMemory + systemSwapFreeMemory - if systemMemoryRequired > available { - slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", available, "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "swap", format.HumanBytes2(systemSwapFreeMemory)) - return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) - } - } - - slog.Info("offload", "", estimate) - - params := []string{ - "--model", modelPath, - "--ctx-size", strconv.Itoa(opts.NumCtx), - "--batch-size", strconv.Itoa(opts.NumBatch), - } - - if opts.NumGPU >= 0 { - params = append(params, "--n-gpu-layers", strconv.Itoa(opts.NumGPU)) - } + // TODO - NUMA support currently doesn't work properly if opts.MainGPU > 0 { - params = append(params, "--main-gpu", strconv.Itoa(opts.MainGPU)) + loadRequest.MainGPU = opts.MainGPU } - if len(adapters) > 0 { - for _, adapter := range adapters { - params = append(params, "--lora", adapter) - } - } - - defaultThreads := systemInfo.GetOptimalThreadCount() - if opts.NumThread > 0 { - params = append(params, "--threads", strconv.Itoa(opts.NumThread)) - } else if defaultThreads > 0 { - params = append(params, "--threads", strconv.Itoa(defaultThreads)) + if len(projectors) > 0 && llamaModel != nil { + loadRequest.ProjectorPath = projectors[0] } + // This will disable flash attention unless all GPUs on the system support it, even if we end up selecting a subset + // that can handle it. fa := envconfig.FlashAttention() if fa && !gpus.FlashAttentionSupported() { slog.Warn("flash attention enabled but not supported by gpu") @@ -216,12 +209,12 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a if fa { slog.Info("enabling flash attention") - params = append(params, "--flash-attn") + loadRequest.FlashAttention = true // Flash Attention also supports kv cache quantization // Enable if the requested and kv cache type is supported by the model if kvct != "" && f.SupportsKVCacheType(kvct) { - params = append(params, "--kv-cache-type", kvct) + loadRequest.KvCacheType = kvct } else { slog.Warn("kv cache type not supported by model", "type", kvct) } @@ -229,66 +222,45 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a slog.Warn("quantized kv cache requested but flash attention disabled", "type", kvct) } - // mmap has issues with partial offloading on metal - for _, g := range gpus { - if g.Library == "metal" && - uint64(opts.NumGPU) > 0 && - uint64(opts.NumGPU) < f.KV().BlockCount()+1 { - opts.UseMMap = new(bool) - *opts.UseMMap = false - } - } - - // Windows CUDA should not use mmap for best performance - // Linux with a model larger than free space, mmap leads to thrashing - // For CPU loads we want the memory to be allocated, not FS cache - if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && opts.UseMMap == nil) || - (runtime.GOOS == "linux" && systemFreeMemory < estimate.TotalSize && opts.UseMMap == nil) || - (gpus[0].Library == "cpu" && opts.UseMMap == nil) || - (opts.UseMMap != nil && !*opts.UseMMap) { - params = append(params, "--no-mmap") - } - - // TODO - NUMA support currently doesn't work properly - - params = append(params, "--parallel", strconv.Itoa(numParallel)) - - if estimate.TensorSplit != "" { - params = append(params, "--tensor-split", estimate.TensorSplit) - } - - if envconfig.MultiUserCache() { - params = append(params, "--multiuser-cache") - } - - libs := make(map[string]string) + availableLibs := make(map[string]string) if entries, err := os.ReadDir(discover.LibOllamaPath); err == nil { for _, entry := range entries { - libs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name()) + availableLibs[entry.Name()] = filepath.Join(discover.LibOllamaPath, entry.Name()) } } - lib := gpus[0].RunnerName() + var gpuLibs []string + for _, gpu := range gpus { + gpuLibs = append(gpuLibs, gpu.RunnerName()) + } + requested := envconfig.LLMLibrary() - if libs[requested] != "" { + if availableLibs[requested] != "" { slog.Info("using requested gpu library", "requested", requested) - lib = requested + gpuLibs = []string{requested} } var compatible []string - for k := range libs { - // exact match first - if k == lib { - compatible = append([]string{k}, compatible...) - continue + for _, gpuLib := range gpuLibs { + var matchingLibs []string + for k := range availableLibs { + // exact match first + if k == gpuLib { + matchingLibs = append([]string{k}, matchingLibs...) + continue + } + + // then match the family (e.g. 'cuda') + if strings.Split(k, "_")[0] == strings.Split(gpuLib, "_")[0] { + matchingLibs = append(matchingLibs, k) + } } - // then match the family (e.g. 'cuda') - if strings.Split(k, "_")[0] == strings.Split(lib, "_")[0] { - compatible = append(compatible, k) + if len(matchingLibs) > 0 { + compatible = append(compatible, matchingLibs[0]) } } - slog.Debug("compatible gpu libraries", "compatible", compatible) + exe, err := os.Executable() if err != nil { return nil, fmt.Errorf("unable to lookup executable path: %w", err) @@ -298,26 +270,6 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a exe = eval } - var llamaModel *llama.Model - var textProcessor model.TextProcessor - if envconfig.NewEngine() || f.KV().OllamaEngineRequired() { - textProcessor, err = model.NewTextProcessor(modelPath) - if err != nil { - // To prepare for opt-out mode, instead of treating this as an error, we fallback to the old runner - slog.Debug("model not yet supported by Ollama engine, switching to compatibility mode", "model", modelPath, "error", err) - } - } - if textProcessor == nil { - llamaModel, err = llama.LoadModelFromFile(modelPath, llama.ModelParams{VocabOnly: true}) - if err != nil { - return nil, err - } - } - - if len(projectors) > 0 && llamaModel != nil { - params = append(params, "--mmproj", projectors[0]) - } - // iterate through compatible GPU libraries such as 'cuda_v12', 'rocm', etc. // adding each library's respective path to the LD_LIBRARY_PATH, until finally running // without any LD_LIBRARY_PATH flags @@ -334,14 +286,14 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a slog.Debug("ResolveTCPAddr failed, using random port") port = rand.Intn(65535-49152) + 49152 // get a random port in the ephemeral range } - finalParams := []string{"runner"} + params := []string{"runner"} if textProcessor != nil { // New engine // TODO - if we have failure to load scenarios, add logic to retry with the old runner - finalParams = append(finalParams, "--ollama-engine") + params = append(params, "--ollama-engine") } - finalParams = append(finalParams, params...) - finalParams = append(finalParams, "--port", strconv.Itoa(port)) + params = append(params, "--model", modelPath) + params = append(params, "--port", strconv.Itoa(port)) var pathEnv string switch runtime.GOOS { @@ -361,38 +313,39 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } ggmlPaths := []string{discover.LibOllamaPath} - if len(compatible) > 0 { - c := compatible[0] - if libpath, ok := libs[c]; ok { + for _, c := range compatible { + if libpath, ok := availableLibs[c]; ok { slog.Debug("adding gpu library", "path", libpath) libraryPaths = append([]string{libpath}, libraryPaths...) ggmlPaths = append(ggmlPaths, libpath) } } - if gpus[0].DependencyPath != nil { - slog.Debug("adding gpu dependency paths", "paths", gpus[0].DependencyPath) - // assume gpus from the same library have the same dependency path - libraryPaths = append(gpus[0].DependencyPath, libraryPaths...) + for _, gpu := range gpus { + if gpu.DependencyPath != nil { + slog.Debug("adding gpu dependency paths", "paths", gpu.DependencyPath) + libraryPaths = append(gpu.DependencyPath, libraryPaths...) + } } // finally, add the root library path libraryPaths = append(libraryPaths, discover.LibOllamaPath) - s := &llmServer{ - port: port, - cmd: exec.Command(exe, finalParams...), - status: NewStatusWriter(os.Stderr), - options: opts, - modelPath: modelPath, - llamaModel: llamaModel, - textProcessor: textProcessor, - estimate: estimate, - numParallel: numParallel, - sem: semaphore.NewWeighted(int64(numParallel)), - totalLayers: f.KV().BlockCount() + 1, - gpus: gpus, - done: make(chan error, 1), + s := llmServer{ + port: port, + cmd: exec.Command(exe, params...), + status: NewStatusWriter(os.Stderr), + options: opts, + modelPath: modelPath, + loadRequest: loadRequest, + llamaModel: llamaModel, + llamaModelLock: &sync.Mutex{}, + textProcessor: textProcessor, + numParallel: numParallel, + sem: semaphore.NewWeighted(int64(numParallel)), + totalLayers: f.KV().BlockCount() + 1, + loadStart: time.Now(), + done: make(chan error, 1), } s.cmd.Env = os.Environ() @@ -406,20 +359,15 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a for _, gpu := range gpus { envWorkarounds = append(envWorkarounds, gpu.EnvWorkarounds...) } - visibleDevicesEnv, visibleDevicesEnvVal := gpus.GetVisibleDevicesEnv() pathEnvVal := strings.Join(libraryPaths, string(filepath.ListSeparator)) - // Update or add the path and visible devices variable with our adjusted version + // Update or add the path variable with our adjusted version pathNeeded := true - devicesNeeded := visibleDevicesEnv != "" for i := range s.cmd.Env { cmp := strings.SplitN(s.cmd.Env[i], "=", 2) if strings.EqualFold(cmp[0], pathEnv) { s.cmd.Env[i] = pathEnv + "=" + pathEnvVal pathNeeded = false - } else if devicesNeeded && strings.EqualFold(cmp[0], visibleDevicesEnv) { - s.cmd.Env[i] = visibleDevicesEnv + "=" + visibleDevicesEnvVal - devicesNeeded = false } else if len(envWorkarounds) != 0 { for _, kv := range envWorkarounds { if strings.EqualFold(cmp[0], kv[0]) { @@ -431,11 +379,8 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a if pathNeeded { s.cmd.Env = append(s.cmd.Env, pathEnv+"="+pathEnvVal) } - if devicesNeeded { - s.cmd.Env = append(s.cmd.Env, visibleDevicesEnv+"="+visibleDevicesEnvVal) - } - slog.Info("starting llama server", "cmd", s.cmd) + slog.Info("starting runner", "cmd", s.cmd) slog.Debug("subprocess", "", filteredEnv(s.cmd.Env)) if err = s.cmd.Start(); err != nil { @@ -471,15 +416,703 @@ func NewLlamaServer(gpus discover.GpuInfoList, modelPath string, f *ggml.GGML, a } }() - return s, nil + if newEstimates { + return &ollamaServer{llmServer: s}, nil + } else { + return &llamaServer{llmServer: s, ggml: f}, nil + } } } +func (s *llmServer) ModelPath() string { + return s.modelPath +} + +type LoadOperation int + +// The order of these constants are significant because we iterate over the operations. They +// should be in order of increasingly loading the model. +const ( + LoadOperationFit LoadOperation = iota // Return memory requirements but do not allocate + LoadOperationAlloc // Allocate memory but do not load the weights + LoadOperationCommit // Load weights - further changes cannot be made after this + LoadOperationClose // Close model and free memory +) + +func (o LoadOperation) String() string { + switch o { + case LoadOperationFit: + return "fit" + case LoadOperationAlloc: + return "alloc" + case LoadOperationCommit: + return "commit" + case LoadOperationClose: + return "close" + default: + return "unknown" + } +} + +type LoadRequest struct { + Operation LoadOperation + + LoraPath []string + Parallel int + BatchSize int + FlashAttention bool + KvSize int + KvCacheType string + NumThreads int + GPULayers ml.GPULayersList + MultiUserCache bool + + // Legacy fields - not used with the Ollama engine + ProjectorPath string + MainGPU int + UseMmap bool +} + +type LoadResponse struct { + Success bool + Memory ml.BackendMemory +} + +var ErrLoadRequiredFull = errors.New("unable to load full model on GPU") + +func (s *llamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { + systemInfo := discover.GetSystemInfo() + systemTotalMemory := systemInfo.System.TotalMemory + systemFreeMemory := systemInfo.System.FreeMemory + systemSwapFreeMemory := systemInfo.System.FreeSwap + slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory)) + + g := pickBestFullFitByLibrary(s.ggml, s.modelPath, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel) + if g == nil { + if !requireFull { + g = pickBestPartialFitByLibrary(s.ggml, []string{s.loadRequest.ProjectorPath}, s.loadRequest.LoraPath, s.options, gpus, s.numParallel) + } else { + return ErrLoadRequiredFull + } + } + + gpus = g + s.estimate = estimateGPULayers(gpus, s.ggml, []string{s.loadRequest.ProjectorPath}, s.options, s.numParallel) + + if len(gpus) > 1 || gpus[0].Library != "cpu" { + switch { + case gpus[0].Library == "metal" && s.estimate.VRAMSize > systemInfo.System.TotalMemory: + // disable partial offloading when model is greater than total system memory as this + // can lead to locking up the system + s.options.NumGPU = 0 + case gpus[0].Library != "metal" && s.estimate.Layers == 0: + // Don't bother loading into the GPU if no layers can fit + gpus = discover.GetCPUInfo() + case s.options.NumGPU < 0 && s.estimate.Layers > 0 && gpus[0].Library != "cpu": + s.options.NumGPU = s.estimate.Layers + } + } + + // On linux and windows, over-allocating CPU memory will almost always result in an error + // Darwin has fully dynamic swap so has no direct concept of free swap space + if runtime.GOOS != "darwin" { + systemMemoryRequired := s.estimate.TotalSize - s.estimate.VRAMSize + available := systemInfo.System.FreeMemory + systemInfo.System.FreeSwap + if systemMemoryRequired > available { + slog.Warn("model request too large for system", "requested", format.HumanBytes2(systemMemoryRequired), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.System.TotalMemory), "free", format.HumanBytes2(systemInfo.System.FreeMemory), "swap", format.HumanBytes2(systemInfo.System.FreeSwap)) + return fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(systemMemoryRequired), format.HumanBytes2(available)) + } + } + + if requireFull && len(gpus) == 1 && gpus[0].Library == "cpu" && s.estimate.TotalSize > gpus[0].FreeMemory { + return ErrLoadRequiredFull + } + + slog.Info("offload", "", s.estimate) + + s.gpus = gpus + s.loadRequest.GPULayers = createGPULayers(s.estimate, s.ggml, gpus, s.options.NumGPU) + + // Mmap is only supported on the llama engine + if s.textProcessor == nil { + s.loadRequest.UseMmap = true + + // mmap has issues with partial offloading on metal + for _, g := range gpus { + if g.Library == "metal" && + uint64(s.options.NumGPU) > 0 && + uint64(s.options.NumGPU) < s.ggml.KV().BlockCount()+1 { + s.options.UseMMap = new(bool) + *s.options.UseMMap = false + } + } + + // Windows CUDA should not use mmap for best performance + // Linux with a model larger than free space, mmap leads to thrashing + // For CPU loads we want the memory to be allocated, not FS cache + if (runtime.GOOS == "windows" && gpus[0].Library == "cuda" && s.options.UseMMap == nil) || + (runtime.GOOS == "linux" && systemInfo.System.FreeMemory < s.estimate.TotalSize && s.options.UseMMap == nil) || + (gpus[0].Library == "cpu" && s.options.UseMMap == nil) || + (s.options.UseMMap != nil && !*s.options.UseMMap) { + s.loadRequest.UseMmap = false + } + } + + if err := s.waitUntilRunnerLaunched(ctx); err != nil { + return err + } + + resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit) + if err != nil { + return err + } + + // On the Ollama engine, we can print out a summary of the memory allocations. + // We don't have this for the llama engine but it does something similar itself. + if s.textProcessor != nil { + resp.Memory.Log(slog.LevelInfo) + } + + if !resp.Success { + slog.Warn("failed to allocate memory for model", "memory", resp.Memory) + return errors.New("failed to allocate memory for model") + } + + // The llama engine does its memory allocations together with model loading, so we + // need to wait until it is done to ensure that we have accurate memory data before + // loading the next model + if s.textProcessor == nil { + return s.WaitUntilRunning(ctx) + } else { + return nil + } +} + +// createGPULayers maps from the tensor splits assigned by the memory estimates to explicit assignment +// of particular layers onto GPUs +func createGPULayers(estimate MemoryEstimate, ggml *ggml.GGML, gpus discover.GpuInfoList, numGPU int) ml.GPULayersList { + if numGPU <= 0 { + return nil + } + + gpuLayers := make(ml.GPULayersList, len(gpus)) + for i := range gpuLayers { + gpuLayers[i].ID = gpus[i].ID + } + + var sum float32 + splits := make([]float32, len(estimate.TensorSplit)) + // cumulative sum of all splits + for i := range splits { + sum += float32(estimate.TensorSplit[i]) + splits[i] = sum + } + + if sum <= 0 { + return nil + } + + // normalize splits + for i := range splits { + splits[i] /= sum + } + + blocks := int(ggml.KV().BlockCount()) + gpuRangeStart := max(0, blocks-numGPU) + gpuRangeStop := min(gpuRangeStart+numGPU, blocks+1) + for i := range blocks + 1 { + if i < gpuRangeStart || i >= gpuRangeStop { + continue + } + + index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f }) + if index < 0 || index >= len(gpus) { + continue + } + + gpuLayers[index].Layers = append(gpuLayers[index].Layers, i) + } + + return gpuLayers +} + +// Load finds the optimal layout of layers to offload on GPUs based on no initial information about the size of the model +// It does this by: +// 1. Assigning the full model to the GPU with the largest available free memory +// 2. Attempting to allocate the layout and receiving the memory requirements in response +// 3. Creating a new layout based on the updated memory information +// 4. Going back to step 2 and looping until we either stabilize on a particular layout or discover that we have entered a cycle +// +// This process is repeated for higher levels of loading the model (fit, allocate, commit). The earlier levels are quicker, +// allowing for faster iteration, but may return less information. +func (s *ollamaServer) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { + var success bool + defer func() { + if !success { + s.initModel(ctx, LoadRequest{}, LoadOperationClose) + } + s.mem.Log(slog.LevelInfo) + }() + + slog.Info("loading model", "model layers", s.totalLayers, "requested", s.options.NumGPU) + + systemInfo := discover.GetSystemInfo() + systemTotalMemory := systemInfo.System.TotalMemory + systemFreeMemory := systemInfo.System.FreeMemory + systemSwapFreeMemory := systemInfo.System.FreeSwap + slog.Info("system memory", "total", format.HumanBytes2(systemTotalMemory), "free", format.HumanBytes2(systemFreeMemory), "free_swap", format.HumanBytes2(systemSwapFreeMemory)) + + if !(len(gpus) == 1 && gpus[0].Library == "cpu") { + for _, gpu := range gpus { + slog.Info("gpu memory", "id", gpu.ID, + "available", format.HumanBytes2(gpu.FreeMemory-envconfig.GpuOverhead()-gpu.MinimumMemory), + "free", format.HumanBytes2(gpu.FreeMemory), + "minimum", format.HumanBytes2(gpu.MinimumMemory), + "overhead", format.HumanBytes2(envconfig.GpuOverhead())) + } + } + + pastAllocations := make(map[uint64]struct{}) + var backoff float32 + + gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) + if err != nil { + return err + } + + if err := s.waitUntilRunnerLaunched(ctx); err != nil { + return err + } + +nextOperation: + for operation := LoadOperationFit; operation < LoadOperationCommit; operation++ { + nextLoad: + for { + s.loadRequest.GPULayers = gpuLayers + resp, err := s.initModel(ctx, s.loadRequest, operation) + if err != nil { + return err + } + + resp.Memory.Log(slog.LevelDebug) + slog.Debug("memory", "success", resp.Success, "required", resp.Memory) + + pastAllocations[gpuLayers.Hash()] = struct{}{} + s.mem = &resp.Memory + + for { + newGPULayers, err := s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) + if err != nil { + return err + } + + slog.Debug("new layout created", "layers", newGPULayers) + + // We get additional memory information over time, which will reduce the number of + // layers that can fit, so fewer layers is actually better. As long as we haven't seen + // this layout before and it doesn't have more layers than the last one, we can keep + // trying to see if we can do better. + if _, ok := pastAllocations[newGPULayers.Hash()]; !ok && newGPULayers.Sum() <= gpuLayers.Sum() { + gpuLayers = newGPULayers + continue nextLoad + } + + // If we are looping around a few different layouts due to graphs moving off and on + // GPUs, make sure that we try out the intermediate states. For example, if we are + // looping between offloading 39 and 41 layers, we should also check 40. + // + // This switches strategies to force an incremental number of layers to be offloaded + // and checking the memory layout. If the allocation succeeds and creating a new layout + // without forcing offload yields the same or greater number of layers offloaded, then + // the trial is successful. + // + // This alternate strategy does not introduce the possibility of loops with the overall + // state machine, as it exits this code block either with a successful result, moving + // to the next operation or the original number of layers offloaded. + if s.options.NumGPU < 0 && newGPULayers.Sum()-gpuLayers.Sum() > 1 { + for i := newGPULayers.Sum() - 1; i >= gpuLayers.Sum(); i-- { + slog.Debug("exploring intermediate layers", "layer", i) + + s.options.NumGPU = i + newGPULayers, err = s.createLayout(systemInfo, gpus, s.mem, requireFull, backoff) + s.options.NumGPU = -1 + if err != nil { + return err + } + + slog.Debug("new layout created", "layers", newGPULayers) + + s.loadRequest.GPULayers = newGPULayers + resp, err = s.initModel(ctx, s.loadRequest, operation) + if err != nil { + return err + } + + resp.Memory.Log(slog.LevelDebug) + slog.Debug("memory", "success", resp.Success, "required", resp.Memory) + + if resp.Success { + verifyGPULayers, err := s.createLayout(systemInfo, gpus, &resp.Memory, requireFull, backoff) + if err != nil { + return err + } + + slog.Debug("verifying layout", "layers", verifyGPULayers) + + if newGPULayers.Sum() <= verifyGPULayers.Sum() { + gpuLayers = newGPULayers + + // Since we are going backwards (increasing the number of layers), ensure that + // we can come back down if needed + clear(pastAllocations) + + continue nextOperation + } + } + } + } + + // If we generated a layout a second time or go backwards, then we've converged. Use the last + // layout before the repeat, which is already allocated. + if resp.Success { + continue nextOperation + } + + if s.options.NumGPU >= 0 { + return fmt.Errorf("memory layout cannot be allocated with num_gpu = %v", s.options.NumGPU) + } + + // Memory allocation failed even though we created a layout that we thought should + // fit in available memory. This could happen if either our free memory reports + // are incorrect or if available memory is changing between layout and allocation + // time. Apply an exponential backoff to try to find the real amount of available + // space. + if backoff > 1 { + slog.Warn("memory layout cannot be allocated", "memory", resp.Memory) + return errors.New("memory layout cannot be allocated") + } else if backoff == 0 { + backoff = 0.01 + } else { + backoff *= 2 + } + + slog.Info("model layout did not fit, applying backoff", "backoff", fmt.Sprintf("%.2f", backoff)) + } + } + } + + s.loadRequest.GPULayers = gpuLayers + resp, err := s.initModel(ctx, s.loadRequest, LoadOperationCommit) + if err != nil { + return err + } + + success = resp.Success + s.mem = &resp.Memory + + if !success { + slog.Warn("failed to commit memory for model", "memory", resp.Memory) + return errors.New("failed to commit memory for model") + } + + return nil +} + +// createLayout uses the current best view of memory requirements and creates a layout of model layers on GPUs. +// It does this by: +// - Calculating how much space each layer requires +// - Calculating how much space each GPU has available for layers, based on free memory and space occupied by the graph +// - Assigning layers +// - Ensuring that we don't exceed limits, such as requirements about partial offloading or system memory +func (s *ollamaServer) createLayout(systemInfo discover.SystemInfo, systemGPUs discover.GpuInfoList, memory *ml.BackendMemory, requireFull bool, backoff float32) (ml.GPULayersList, error) { + if s.totalLayers == 0 || s.options.NumGPU == 0 || len(systemGPUs) == 0 || (len(systemGPUs) == 1 && systemGPUs[0].Library == "cpu") { + return ml.GPULayersList{}, nil + } + + gpus := append(make(discover.GpuInfoList, 0, len(systemGPUs)), systemGPUs...) + sort.Sort(sort.Reverse(discover.ByFreeMemory(gpus))) + + if memory == nil { + memory = &ml.BackendMemory{CPU: ml.DeviceMemory{ + Weights: make([]ml.Memory, s.totalLayers), + Cache: make([]ml.Memory, s.totalLayers), + }} + } + + layers := make([]uint64, len(memory.CPU.Weights)) + for i := range layers { + for j := range memory.GPUs { + layers[i] += memory.GPUs[j].Weights[i].Size + layers[i] += memory.GPUs[j].Cache[i].Size + } + layers[i] += memory.CPU.Weights[i].Size + layers[i] += memory.CPU.Cache[i].Size + slog.Log(context.TODO(), logutil.LevelTrace, "layer to assign", "layer", i, "size", format.HumanBytes2(layers[i])) + } + + gpuLayers := ml.GPULayersList{} + for _, gl := range gpus.ByLibrary() { + // If a GPU already has a graph allocated on it, then we should continue to use it. + // Otherwise, we lose information that we got from previous allocations, which can + // cause cycling. Plus, we get more information about required allocation from each + // iteration, so it doesn't make sense that a later iteration would use fewer GPUs. + lastUsedGPU := 0 + for i := range gl { + found := false + for j := range memory.GPUs { + if gl[i].ID == memory.GPUs[j].ID { + if memory.GPUs[j].Graph.Size != 0 { + lastUsedGPU = i + } + + reserved := uint64(float32(gl[i].FreeMemory)*backoff) + gl[i].MinimumMemory + envconfig.GpuOverhead() + memory.GPUs[j].Graph.Size + if gl[i].FreeMemory > reserved { + gl[i].FreeMemory -= reserved + } else { + gl[i].FreeMemory = 0 + } + + slog.Debug("available gpu", "id", gl[i].ID, + "available layer vram", format.HumanBytes2(gl[i].FreeMemory), + "backoff", fmt.Sprintf("%.2f", backoff), "minimum", format.HumanBytes2(gl[i].MinimumMemory), + "overhead", format.HumanBytes2(envconfig.GpuOverhead()), + "graph", format.HumanBytes2(memory.GPUs[j].Graph.Size)) + + found = true + break + } + } + if !found { + // The runner doesn't report seeing this GPU + gl[i].FreeMemory = 0 + } + } + + libraryGpuLayers := assignLayers(layers, gl, s.options.NumGPU, lastUsedGPU) + if libraryGpuLayers.Sum() > gpuLayers.Sum() { + gpuLayers = libraryGpuLayers + } + } + + // These sizes will only increase as we go through additional iterations and get additional information. + cpuSize := memory.InputWeights.Size + memory.CPU.Graph.Size + var vramSize uint64 + for _, gl := range gpuLayers { + for _, gpu := range memory.GPUs { + if gl.ID == gpu.ID { + vramSize += gpu.Graph.Size + break + } + } + } + +nextLayer: + for i := range layers { + for _, g := range gpuLayers { + for _, gl := range g.Layers { + if i == gl { + vramSize += layers[i] + continue nextLayer + } + } + } + cpuSize += layers[i] + } + + if requireFull { + if gpuLayers.Sum() < len(layers) && (s.options.NumGPU < 0 || gpuLayers.Sum() < s.options.NumGPU) { + return nil, ErrLoadRequiredFull + } + + if cpuSize > systemInfo.System.FreeMemory { + return nil, ErrLoadRequiredFull + } + } + + // On linux and windows, over-allocating CPU memory will almost always result in an error + // Darwin has fully dynamic swap so has no direct concept of free swap space + if runtime.GOOS != "darwin" { + available := systemInfo.System.FreeMemory + systemInfo.System.FreeSwap + if cpuSize > available { + slog.Warn("model request too large for system", "requested", format.HumanBytes2(cpuSize), "available", format.HumanBytes2(available), "total", format.HumanBytes2(systemInfo.System.TotalMemory), "free", format.HumanBytes2(systemInfo.System.FreeMemory), "swap", format.HumanBytes2(systemInfo.System.FreeSwap)) + return nil, fmt.Errorf("model requires more system memory (%s) than is available (%s)", format.HumanBytes2(cpuSize), format.HumanBytes2(available)) + } + } else { + if vramSize > systemInfo.System.TotalMemory { + // disable partial offloading when model is greater than total system memory as this + // can lead to locking up the system + s.options.NumGPU = 0 + gpuLayers = ml.GPULayersList{} + } + } + + if gpuLayers.Sum() == 0 { + slog.Debug("insufficient VRAM to load any model layers") + } + + return gpuLayers, nil +} + +// assignLayers packs the maximum number of layers onto the smallest set of GPUs and comes up with a layer assignment +func assignLayers(layers []uint64, gpus discover.GpuInfoList, requestedLayers int, lastUsedGPU int) (gpuLayers ml.GPULayersList) { + // If we can't fit everything then prefer offloading layers other than the output layer + for range 2 { + // requestedLayers may be -1 if nothing was requested + requestedLayers = min(len(layers), requestedLayers) + + if !envconfig.SchedSpread() { + for i := lastUsedGPU; i < len(gpus); i++ { + // Try to pack things into as few GPUs as possible + forceRequest := i == len(gpus)-1 + gpuLayers = findBestFit(layers, gpus[:i+1], requestedLayers, forceRequest) + if gpuLayers.Sum() == len(layers) || gpuLayers.Sum() == requestedLayers { + break + } + } + } else { + gpuLayers = findBestFit(layers, gpus, requestedLayers, true) + } + + // We only stop if we've gotten all of the layers - even if we got requestedLayers, we still + // might want to try dropping the output layer. + if gpuLayers.Sum() == len(layers) { + return gpuLayers + } + + layers = layers[:len(layers)-1] + } + + return gpuLayers +} + +// findBestFit binary searches to find the smallest capacity factor that can fit +// the max number of layers. The capacity factor is multiplied by the free space on +// each GPU and a small one will force even balancing. +func findBestFit(layers []uint64, gpus discover.GpuInfoList, requestedLayers int, forceRequest bool) (gpuLayers ml.GPULayersList) { + var high float32 = 1 + var low float32 = 0 + + // If we need to fulfill the requested number of layers, pretend we have almost infinite VRAM + if requestedLayers >= 0 && forceRequest { + high = 1000 + } + + bestAssignments := greedyFit(layers, gpus, high, requestedLayers) + maxNumGPU := bestAssignments.Sum() + if maxNumGPU == 0 { + return bestAssignments + } + + for high-low > 1e-6 { + mid := (low + high) / 2 + assignments := greedyFit(layers, gpus, mid, requestedLayers) + if assignments.Sum() == maxNumGPU { + high = mid + bestAssignments = assignments + } else { + low = mid + } + } + + return bestAssignments +} + +// greedyFit assigns layers incrementally to GPUs, spilling over as each runs out of free space +func greedyFit(layers []uint64, gpus discover.GpuInfoList, capacity float32, requestedLayers int) (gpuLayers ml.GPULayersList) { + device := len(gpus) - 1 + gpuLayers = ml.GPULayersList{{ID: gpus[device].ID}} + freeSpace := uint64(float32(gpus[device].FreeMemory) * capacity) + for i := len(layers) - 1; i >= 0; i-- { + if requestedLayers >= 0 && len(layers)-1-i >= requestedLayers { + break + } + + for { + if layers[i] <= freeSpace { + gpuLayers[0].Layers = append([]int{i}, gpuLayers[0].Layers...) + freeSpace -= layers[i] + break + } + + device-- + if device < 0 { + return gpuLayers + } + gpuLayers = append(ml.GPULayersList{{ID: gpus[device].ID}}, gpuLayers...) + freeSpace = uint64(float32(gpus[device].FreeMemory) * capacity) + } + } + + return gpuLayers +} + +// waitUntilRunnerLaunched sleeps until the runner subprocess is alive enough +// to respond to status requests +func (s *llmServer) waitUntilRunnerLaunched(ctx context.Context) error { + for { + _, err := s.getServerStatus(ctx) + if err == nil { + break + } + + t := time.NewTimer(10 * time.Millisecond) + select { + case <-t.C: + continue + case <-ctx.Done(): + return ctx.Err() + } + } + + return nil +} + +// initModel sends a load request to the runner based on the request operation (fit, alloc, commit) +// and parameters +func (s *llmServer) initModel(ctx context.Context, req LoadRequest, operation LoadOperation) (*LoadResponse, error) { + req.Operation = operation + + data, err := json.Marshal(req) + if err != nil { + return nil, fmt.Errorf("error marshaling load data: %w", err) + } + + r, err := http.NewRequestWithContext(ctx, http.MethodPost, fmt.Sprintf("http://127.0.0.1:%d/load", s.port), bytes.NewBuffer(data)) + if err != nil { + return nil, fmt.Errorf("error creating load request: %w", err) + } + r.Header.Set("Content-Type", "application/json") + + resp, err := http.DefaultClient.Do(r) + if err != nil { + return nil, fmt.Errorf("do load request: %w", err) + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + return nil, fmt.Errorf("read load request: %w", err) + } + + if resp.StatusCode >= 400 { + log.Printf("llm load error: %s", body) + return nil, fmt.Errorf("%s", body) + } + + var llmResp LoadResponse + if err := json.Unmarshal(body, &llmResp); err != nil { + return nil, fmt.Errorf("load unmarshal encode response: %w", err) + } + + return &llmResp, nil +} + type ServerStatus int const ( // iota is reset to 0 ServerStatusReady ServerStatus = iota ServerStatusNoSlotsAvailable + ServerStatusLaunched ServerStatusLoadingModel ServerStatusNotResponding ServerStatusError @@ -491,6 +1124,8 @@ func (s ServerStatus) String() string { return "llm server ready" case ServerStatusNoSlotsAvailable: return "llm busy - no slots available" + case ServerStatusLaunched: + return "llm server launched" case ServerStatusLoadingModel: return "llm server loading model" case ServerStatusNotResponding: @@ -551,7 +1186,7 @@ func (s *llmServer) getServerStatus(ctx context.Context) (ServerStatus, error) { case ServerStatusLoadingModel: s.loadProgress = ssr.Progress return ssr.Status, nil - case ServerStatusReady, ServerStatusNoSlotsAvailable: + case ServerStatusLaunched, ServerStatusReady, ServerStatusNoSlotsAvailable: return ssr.Status, nil default: return ssr.Status, fmt.Errorf("server error: %+v", ssr) @@ -591,7 +1226,6 @@ func (s *llmServer) Ping(ctx context.Context) error { } func (s *llmServer) WaitUntilRunning(ctx context.Context) error { - start := time.Now() stallDuration := envconfig.LoadTimeout() // If no progress happens stallTimer := time.Now().Add(stallDuration) // give up if we stall @@ -633,8 +1267,7 @@ func (s *llmServer) WaitUntilRunning(ctx context.Context) error { } switch status { case ServerStatusReady: - s.loadDuration = time.Since(start) - slog.Info(fmt.Sprintf("llama runner started in %0.2f seconds", s.loadDuration.Seconds())) + slog.Info(fmt.Sprintf("llama runner started in %0.2f seconds", time.Since(s.loadStart).Seconds())) return nil default: lastStatus = status @@ -1044,15 +1677,15 @@ func (s *llmServer) Close() error { return nil } -func (s *llmServer) EstimatedVRAM() uint64 { +func (s *llamaServer) VRAMSize() uint64 { return s.estimate.VRAMSize } -func (s *llmServer) EstimatedTotal() uint64 { +func (s *llamaServer) TotalSize() uint64 { return s.estimate.TotalSize } -func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 { +func (s *llamaServer) VRAMByGPU(gpuID string) uint64 { for i, gpu := range s.gpus { if gpu.ID == gpuID { if i < len(s.estimate.GPUSizes) { @@ -1062,3 +1695,59 @@ func (s *llmServer) EstimatedVRAMByGPU(gpuID string) uint64 { } return 0 } + +func (s *ollamaServer) VRAMSize() uint64 { + if s.mem == nil { + return 0 + } + + var mem uint64 + + for _, g := range s.mem.GPUs { + mem += g.Allocated() + } + + // Some elements are always on CPU. However, if we have allocated all layers + // on the GPU then include the CPU components as well, to represent complete offloading. + noCPULayers := true + for i := range s.mem.CPU.Weights { + if s.mem.CPU.Weights[i].Size != 0 || s.mem.CPU.Cache[i].Size != 0 { + noCPULayers = false + break + } + } + if noCPULayers { + mem += s.mem.InputWeights.Size + mem += s.mem.CPU.Graph.Size + } + + return mem +} + +func (s *ollamaServer) TotalSize() uint64 { + if s.mem == nil { + return 0 + } + + mem := s.mem.InputWeights.Size + mem += s.mem.CPU.Allocated() + for _, g := range s.mem.GPUs { + mem += g.Allocated() + } + + return mem +} + +func (s *ollamaServer) VRAMByGPU(gpuID string) uint64 { + if s.mem == nil { + return 0 + } + + for _, g := range s.mem.GPUs { + if g.ID == gpuID { + return g.Allocated() + } + } + + return 0 +} diff --git a/llm/server_test.go b/llm/server_test.go index b6a8705e5a..4eed82bce3 100644 --- a/llm/server_test.go +++ b/llm/server_test.go @@ -8,9 +8,178 @@ import ( "testing" "github.com/ollama/ollama/api" + "github.com/ollama/ollama/discover" + "github.com/ollama/ollama/format" + "github.com/ollama/ollama/ml" "golang.org/x/sync/semaphore" ) +func TestLLMServerFitGPU(t *testing.T) { + type gpu struct { + library string + free int + } + + tests := []struct { + name string + gpus []gpu + layers []int + numGPU int + requireFull bool + expected ml.GPULayersList + expectedErr error + }{ + { + name: "No GPU", + layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{}, + }, + { + name: "Full single GPU", + gpus: []gpu{{free: 256 * format.MebiByte}}, + layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2}}}, + }, + { + name: "Partial single GPU", + gpus: []gpu{{free: 256 * format.MebiByte}}, + layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1, 2}}}, + }, + { + name: "Single GPU with numGPU 1", + gpus: []gpu{{free: 256 * format.MebiByte}}, + layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: 1, + expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{1}}}, + }, + { + name: "Single GPU with numGPU 0", + gpus: []gpu{{free: 256 * format.MebiByte}}, + layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: 0, + expected: ml.GPULayersList{}, + }, + { + name: "Single GPU with numGPU 999", + gpus: []gpu{{free: 256 * format.MebiByte}}, + layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, + numGPU: 999, + expected: ml.GPULayersList{{ID: "gpu0", Layers: []int{0, 1, 2, 3}}}, + }, + { + name: "Multi GPU fits on one", + gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1, 2}}}, + }, + { + name: "Multi GPU split", + gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1, 2}}}, + }, + { + name: "Multi GPU partial", + gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}}, + }, + { + name: "Multi GPU numGPU 1", + gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + layers: []int{50 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: 1, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{1}}}, + }, + { + name: "Multi GPU numGPU 2", + gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + layers: []int{256 * format.MebiByte, 50 * format.MebiByte, 50 * format.MebiByte}, + numGPU: 2, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0}}, {ID: "gpu0", Layers: []int{1}}}, + }, + { + name: "Multi GPU numGPU 999", + gpus: []gpu{{free: 128 * format.MebiByte}, {free: 256 * format.MebiByte}}, + layers: []int{256 * format.MebiByte, 256 * format.MebiByte, 50 * format.MebiByte}, + numGPU: 999, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}, {ID: "gpu0", Layers: []int{2}}}, + }, + { + name: "Multi GPU different libraries", + gpus: []gpu{{library: "cuda", free: 128 * format.MebiByte}, {library: "rocm", free: 256 * format.MebiByte}}, + layers: []int{128 * format.MebiByte, 128 * format.MebiByte, 50 * format.MebiByte}, + numGPU: -1, + expected: ml.GPULayersList{{ID: "gpu1", Layers: []int{0, 1}}}, + }, + { + name: "requireFull", + gpus: []gpu{{free: 256 * format.MebiByte}}, + layers: []int{100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte, 100 * format.MebiByte}, + numGPU: -1, + requireFull: true, + expectedErr: ErrLoadRequiredFull, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + var systemInfo discover.SystemInfo + systemInfo.System.TotalMemory = format.GibiByte + systemInfo.System.FreeMemory = 512 * format.MebiByte + systemInfo.System.FreeSwap = 256 * format.MebiByte + + gpus := make(discover.GpuInfoList, len(tt.gpus)) + for i := range tt.gpus { + gpus[i].ID = fmt.Sprintf("gpu%d", i) + gpus[i].Library = tt.gpus[i].library + gpus[i].FreeMemory = uint64(tt.gpus[i].free) + } + + s := &ollamaServer{ + llmServer: llmServer{ + totalLayers: uint64(len(tt.layers)), + options: api.Options{ + Runner: api.Runner{ + NumGPU: tt.numGPU, + }, + }, + }, + } + + s.mem = &ml.BackendMemory{CPU: ml.DeviceMemory{ + Weights: make([]ml.Memory, s.totalLayers), + Cache: make([]ml.Memory, s.totalLayers), + }, GPUs: make([]ml.DeviceMemory, len(gpus))} + + for i := range tt.layers { + s.mem.CPU.Weights[i].Size = uint64(tt.layers[i]) + } + + for i := range s.mem.GPUs { + s.mem.GPUs[i].ID = fmt.Sprintf("gpu%d", i) + s.mem.GPUs[i].Weights = make([]ml.Memory, s.totalLayers) + s.mem.GPUs[i].Cache = make([]ml.Memory, s.totalLayers) + } + + gpuLayers, err := s.createLayout(systemInfo, gpus, s.mem, tt.requireFull, 0) + if err != tt.expectedErr { + t.Fatalf("fitGPU returned error: %v", err) + } + if gpuLayers.Hash() != tt.expected.Hash() { + t.Errorf("fitGPU assigned %v, want %v", gpuLayers, tt.expected) + } + }) + } +} + func TestLLMServerCompletionFormat(t *testing.T) { // This test was written to fix an already deployed issue. It is a bit // of a mess, and but it's good enough, until we can refactoring the diff --git a/ml/backend.go b/ml/backend.go index 3eb84c726e..638a05d144 100644 --- a/ml/backend.go +++ b/ml/backend.go @@ -5,12 +5,14 @@ import ( "context" "encoding/binary" "fmt" + "hash/maphash" "log/slog" "math" "slices" "strconv" "strings" + "github.com/ollama/ollama/format" "github.com/ollama/ollama/fs" ) @@ -58,19 +60,89 @@ type CacheConfig struct { MaskBatchPadding int } +// GPULayers is a set of layers to be allocated on a single GPU +type GPULayers struct { + // ID is the identifier of the GPU, as reported in DeviceMemory + ID string + + // Layers is a set of layer indicies to load + Layers []int +} + +func (g GPULayers) String() string { + if len(g.Layers) == 0 { + return "" + } + + slices.Sort(g.Layers) + + contiguous := true + base := g.Layers[0] + for i := range g.Layers { + if g.Layers[i] != base+i { + contiguous = false + break + } + } + + if contiguous { + return fmt.Sprintf("ID:%v Layers:%v(%v..%v)", g.ID, len(g.Layers), g.Layers[0], g.Layers[len(g.Layers)-1]) + } else { + return fmt.Sprintf("ID:%v Layers:%v%v", g.ID, len(g.Layers), g.Layers) + } +} + +// GPULayersList is a set of layer allocations across multiple GPUs +type GPULayersList []GPULayers + +func (l GPULayersList) String() string { + if l.Sum() > 0 { + return fmt.Sprintf("%v%v", l.Sum(), []GPULayers(l)) + } else { + return fmt.Sprintf("%v", []GPULayers(l)) + } +} + +// Sum is the total number of layers assigned across all GPUs +func (l GPULayersList) Sum() int { + var sum int + + for _, g := range l { + sum += len(g.Layers) + } + + return sum +} + +var h maphash.Hash + +// Hash is an identifier of this layer assignment +func (l GPULayersList) Hash() uint64 { + h.Reset() + for _, g := range l { + if len(g.Layers) > 0 { + h.WriteString(g.ID) + for _, l := range g.Layers { + binary.Write(&h, binary.NativeEndian, int64(l)) + } + } + } + + return h.Sum64() +} + // BackendParams controls how the backend loads and executes models type BackendParams struct { + // AllocMemory causes the backend to allocate memory for the model. If + // false, this is only being used for discovering the required amount of + // memory and cannot load the model for running. + AllocMemory bool + // NumThreads sets the number of threads to use if running on the CPU NumThreads int - // MainGPU is the index of the primary GPU to use - MainGPU int - - // NumGPULayers is the number of layers to offload to GPUs - NumGPULayers int - - // TensorSplit is the fraction of the model to offload to each GPU - TensorSplit []float32 + // GPULayers is the set of layers to offload to GPUs + GPULayers GPULayersList // FlashAttention indicates that we should use a fused flash attention kernel FlashAttention bool @@ -141,6 +213,28 @@ type DeviceMemory struct { Graph Memory } +// Allocated returns the total size of the memory that has been successfully +// allocated on this device +func (m DeviceMemory) Allocated() uint64 { + var mem uint64 + + for _, w := range m.Weights { + if w.Status == Allocated { + mem += w.Size + } + } + for _, c := range m.Cache { + if c.Status == Allocated { + mem += c.Size + } + } + if m.Graph.Status == Allocated { + mem += m.Graph.Size + } + + return mem +} + func memoryPresent(mem []Memory) bool { return slices.ContainsFunc(mem, func(m Memory) bool { return m.Size != 0 }) } @@ -197,6 +291,58 @@ func (m BackendMemory) LogValue() slog.Value { return slog.GroupValue(attrs...) } +func sumMemory(mem []Memory) uint64 { + var sum uint64 + + for _, m := range mem { + sum += m.Size + } + + return sum +} + +// Log prints a high level summary of the memory (allocated or not) +func (m BackendMemory) Log(level slog.Level) { + var total uint64 + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.InputWeights.Size + sumMemory(m.CPU.Weights); sum > 0 { + slog.Log(context.TODO(), level, "model weights", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := sumMemory(gpu.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := sumMemory(m.CPU.Cache); sum > 0 { + slog.Log(context.TODO(), level, "kv cache", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + for _, gpu := range m.GPUs { + if sum := gpu.Graph.Size; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", gpu.Name, "size", format.HumanBytes2(sum)) + total += sum + } + } + if sum := m.CPU.Graph.Size; sum > 0 { + slog.Log(context.TODO(), level, "compute graph", "device", m.CPU.Name, "size", format.HumanBytes2(sum)) + total += sum + } + + if total > 0 { + slog.Log(context.TODO(), level, "total memory", "size", format.HumanBytes2(total)) + } +} + var backends = make(map[string]func(string, BackendParams) (Backend, error)) func RegisterBackend(name string, f func(string, BackendParams) (Backend, error)) { diff --git a/ml/backend/ggml/ggml.go b/ml/backend/ggml/ggml.go index d64e27fa74..f8121cc530 100644 --- a/ml/backend/ggml/ggml.go +++ b/ml/backend/ggml/ggml.go @@ -10,6 +10,7 @@ import "C" import ( "context" + "errors" "fmt" "io" "log/slog" @@ -62,12 +63,21 @@ var initDevices = sync.OnceFunc(func() { } }) +type layerDevice struct { + d C.ggml_backend_dev_t + bt C.ggml_backend_buffer_type_t +} + type Backend struct { // modelPath is the location of the model data modelPath string meta *fsggml.GGML + // allocMemory means that memory should be allocated for tensors and not + // just a dry run + allocMemory bool + // tensorLoadTargets maps from the name of the tensor in the file // to the name that is used by the model definition tensorLoadTargets map[string][]string @@ -78,11 +88,14 @@ type Backend struct { tensors map[string]*C.struct_ggml_tensor - // input is the backend used for inputs + // input is the backend buffer type used for inputs input C.ggml_backend_buffer_type_t + // output is the backend device used for outputs + output C.ggml_backend_dev_t + // layers is the backend used for repeating layers - layers map[int]C.ggml_backend_buffer_type_t + layers map[int]layerDevice // requiredMemory is the cumulative memory allocations needed by the backend requiredMemory *ml.BackendMemory @@ -99,6 +112,8 @@ type Backend struct { weightBuffers map[*C.struct_ggml_context]C.ggml_backend_buffer_t } +var once sync.Once + func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { r, err := os.Open(modelPath) if err != nil { @@ -111,15 +126,17 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { return nil, err } - slog.Info( - "", - "architecture", meta.KV().Architecture(), - "file_type", meta.KV().FileType(), - "name", meta.KV().String("general.name"), - "description", meta.KV().String("general.description"), - "num_tensors", len(meta.Tensors().Items()), - "num_key_values", len(meta.KV()), - ) + once.Do(func() { + slog.Info( + "", + "architecture", meta.KV().Architecture(), + "file_type", meta.KV().FileType(), + "name", meta.KV().String("general.name"), + "description", meta.KV().String("general.description"), + "num_tensors", len(meta.Tensors().Items()), + "num_key_values", len(meta.KV()), + ) + }) initDevices() @@ -139,7 +156,10 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { switch C.ggml_backend_dev_type(d) { case C.GGML_BACKEND_DEVICE_TYPE_CPU, C.GGML_BACKEND_DEVICE_TYPE_ACCEL: - cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, C.ggml_backend_dev_buffer_type(d)) + bt := C.ggml_backend_dev_buffer_type(d) + cpuDeviceBufferType.bts = append(cpuDeviceBufferType.bts, bt) + C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory)) + btDeviceMemory[C.ggml_backend_dev_buffer_type(d)] = &requiredMemory.CPU } } @@ -160,6 +180,8 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { d: d, bts: append([]C.ggml_backend_buffer_type_t{bt}, cpuDeviceBufferType.bts...), }) + C.ggml_backend_buft_set_alloc(bt, C.bool(params.AllocMemory)) + btDeviceMemory[bt] = &requiredMemory.GPUs[i] requiredMemory.GPUs[i].Name = C.GoString(C.ggml_backend_dev_name(d)) var props C.struct_ggml_backend_dev_props @@ -169,56 +191,25 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { requiredMemory.GPUs[i].Cache = make([]ml.Memory, blocks+1) } - useDefaultSplit := true - for _, s := range params.TensorSplit { - if s != 0 { - useDefaultSplit = false - break - } - } - - // calculate splits - splits := make([]float32, len(gpus)) - if useDefaultSplit { - // default: split on free memory - for i := range splits { - var free, total C.size_t - C.ggml_backend_dev_memory(gpus[i], &free, &total) - splits[i] = float32(free) - } - } else { - splits = params.TensorSplit - } - - var sum float32 - // cumulative sum of all splits - for i := range splits { - sum += splits[i] - splits[i] = sum - } - - // normalize splits - for i := range splits { - splits[i] /= sum - } - // inputs always use cpu input := cpuDeviceBufferType - // define a range of gpu layers. anything outside of this range is assigned to the cpu - gpuRangeStart := max(0, blocks-params.NumGPULayers) - gpuRangeStop := min(gpuRangeStart+params.NumGPULayers, blocks+1) - assignLayer := func(i int) deviceBufferType { - if i < gpuRangeStart || i >= gpuRangeStop { - return cpuDeviceBufferType + assignLayer := func(layer int) deviceBufferType { + for _, p := range params.GPULayers { + for _, l := range p.Layers { + if l == layer { + for i := range requiredMemory.GPUs { + if requiredMemory.GPUs[i].ID == p.ID { + return gpuDeviceBufferTypes[i] + } + } + + return cpuDeviceBufferType + } + } } - index := slices.IndexFunc(splits, func(f float32) bool { return float32(i-gpuRangeStart)/float32(gpuRangeStop-gpuRangeStart) < f }) - if index < 0 || index >= len(gpuDeviceBufferTypes) { - return cpuDeviceBufferType - } - - return gpuDeviceBufferTypes[index] + return cpuDeviceBufferType } // repeating layers are assigned based on their index in reverse order, e.g. i / (block_count + 1) @@ -284,7 +275,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { size := pad(C.ggml_backend_buft_get_alloc_size(bt, tt), C.ggml_backend_buft_get_alignment(bt)) if layer == -1 { // Assume that InputWeights can be allocated - they're always in system memory and can't be moved in any case - requiredMemory.InputWeights.Status = ml.Allocated + if params.AllocMemory { + requiredMemory.InputWeights.Status = ml.Allocated + } requiredMemory.InputWeights.Size += uint64(size) } else { btDeviceMemory[bt].Weights[layer].Size += uint64(size) @@ -355,12 +348,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { } b := C.ggml_backend_alloc_ctx_tensors_from_buft(c, bt) - for i := range btDeviceMemory[bt].Weights { - if btDeviceMemory[bt].Weights[i].Size != 0 { - if b != nil { - btDeviceMemory[bt].Weights[i].Status = ml.Allocated - } else { - btDeviceMemory[bt].Weights[i].Status = ml.Failed + if params.AllocMemory { + for i := range btDeviceMemory[bt].Weights { + if btDeviceMemory[bt].Weights[i].Size != 0 { + if b != nil { + btDeviceMemory[bt].Weights[i].Status = ml.Allocated + } else { + btDeviceMemory[bt].Weights[i].Status = ml.Failed + } } } } @@ -381,28 +376,9 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { bbs[c] = b } - // Mimic llama runner logs summarizing layers and memory - gpuLayers := 0 - for _, layer := range layers { - if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU { - gpuLayers++ - } - } - slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers)) - - switch C.ggml_backend_dev_type(output.d) { - case C.GGML_BACKEND_DEVICE_TYPE_CPU: - slog.Info("offloading output layer to CPU") - case C.GGML_BACKEND_DEVICE_TYPE_GPU: - slog.Info("offloading output layer to GPU") - gpuLayers++ - case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: - slog.Info("offloading output layer to ACCEL") - } - slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(layers)+1)) - for bs := range maps.Values(bbs) { - slog.Info("model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) + slog.Log(context.TODO(), logutil.LevelTrace, "model weights", "buffer", C.GoString(C.ggml_backend_buffer_name(bs)), + "size", format.HumanBytes2(uint64(C.ggml_backend_buffer_get_size(bs)))) } // map tensor names to tensors for easy lookup later @@ -423,6 +399,13 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { b := backends[d] bt := C.ggml_backend_get_default_buffer_type(b) + // Always include CPU as a fallback but otherwise, just use the devices where we assigned layers + if !slices.Contains(cpuDeviceBufferType.bts, bt) { + if c, ok := ctxs[bt]; !ok || C.ggml_get_first_tensor(c) == nil { + continue + } + } + deviceBufferTypes[d] = bt schedBackends = append(schedBackends, b) @@ -437,6 +420,7 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { maxGraphNodes := max(8192, len(meta.Tensors().Items())*5) return &Backend{ modelPath: modelPath, + allocMemory: params.AllocMemory, flashAttention: params.FlashAttention, meta: meta, tensorLoadTargets: targets, @@ -452,10 +436,14 @@ func New(modelPath string, params ml.BackendParams) (ml.Backend, error) { schedBackends: schedBackends, schedBufts: schedBufts, input: deviceBufferTypes[input.d], - layers: func() map[int]C.ggml_backend_buffer_type_t { - m := make(map[int]C.ggml_backend_buffer_type_t) + output: output.d, + layers: func() map[int]layerDevice { + m := make(map[int]layerDevice) for i, layer := range layers { - m[i] = deviceBufferTypes[layer.d] + m[i] = layerDevice{ + d: layer.d, + bt: deviceBufferTypes[layer.d], + } } return m }(), @@ -484,6 +472,30 @@ func (b *Backend) Close() { } func (b *Backend) Load(ctx context.Context, progress func(float32)) error { + if !b.allocMemory { + return errors.New("cannot load model without memory allocation") + } + + // Mimic llama runner logs summarizing layers and memory + gpuLayers := 0 + for layer := range maps.Values(b.layers) { + if C.ggml_backend_dev_type(layer.d) == C.GGML_BACKEND_DEVICE_TYPE_GPU { + gpuLayers++ + } + } + slog.Info(fmt.Sprintf("offloading %d repeating layers to GPU", gpuLayers)) + + switch C.ggml_backend_dev_type(b.output) { + case C.GGML_BACKEND_DEVICE_TYPE_CPU: + slog.Info("offloading output layer to CPU") + case C.GGML_BACKEND_DEVICE_TYPE_GPU: + slog.Info("offloading output layer to GPU") + gpuLayers++ + case C.GGML_BACKEND_DEVICE_TYPE_ACCEL: + slog.Info("offloading output layer to ACCEL") + } + slog.Info(fmt.Sprintf("offloaded %d/%d layers to GPU", gpuLayers, len(b.layers)+1)) + var doneBytes atomic.Uint64 totalBytes := uint64(b.meta.Length) - b.meta.Tensors().Offset @@ -730,11 +742,11 @@ func (c *Context) Input() ml.Context { } func (c *Context) Layer(i int) ml.Context { - if buft, ok := c.b.layers[i]; ok { + if layer, ok := c.b.layers[i]; ok { return &Context{ b: c.b, ctx: c.ctx, - buft: buft, + buft: layer.bt, allocatedBuffers: c.allocatedBuffers, maxGraphNodes: c.maxGraphNodes, layer: i, @@ -792,14 +804,16 @@ func (c *Context) Reserve() { graph := &c.b.btDeviceMemory[c.b.schedBufts[i]].Graph graph.Size += uint64(bufferStatus.size) - if bufferStatus.allocated && graph.Status != ml.Failed { - graph.Status = ml.Allocated - } else { - graph.Status = ml.Failed + if c.b.allocMemory { + if bufferStatus.allocated && graph.Status != ml.Failed { + graph.Status = ml.Allocated + } else { + graph.Status = ml.Failed + } } - slog.Info("compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), - "size", format.HumanBytes2(uint64(bufferStatus.size))) + slog.Log(context.TODO(), logutil.LevelTrace, "compute graph", "backend", C.GoString(C.ggml_backend_name(c.b.schedBackends[i])), + "buffer_type", C.GoString(C.ggml_backend_buft_name(c.b.schedBufts[i])), "size", format.HumanBytes2(uint64(bufferStatus.size))) } if !reserved { @@ -868,10 +882,12 @@ func (c *Context) newTensor(dtype ml.DType, shape []int) ml.Tensor { cache := &c.b.btDeviceMemory[c.buft].Cache[c.layer] cache.Size += uint64(size) - if b != nil { - cache.Status = ml.Allocated - } else { - cache.Status = ml.Failed + if c.b.allocMemory { + if b != nil { + cache.Status = ml.Allocated + } else { + cache.Status = ml.Failed + } } } @@ -890,7 +906,9 @@ func (c *Context) Empty(dtype ml.DType, shape ...int) ml.Tensor { func (c *Context) Zeros(dtype ml.DType, shape ...int) ml.Tensor { t := c.newTensor(dtype, shape) - C.ggml_set_zero(t.(*Tensor).t) + if c.b.allocMemory { + C.ggml_set_zero(t.(*Tensor).t) + } return t } @@ -915,7 +933,7 @@ func (c *Context) FromFloatSlice(s []float32, shape ...int) ml.Tensor { t := c.newTensor(ml.DTypeF32, shape) - if len(s) > 0 { + if c.b.allocMemory && len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } @@ -927,7 +945,7 @@ func (c *Context) FromIntSlice(s []int32, shape ...int) ml.Tensor { t := c.newTensor(ml.DTypeI32, shape) - if len(s) > 0 { + if c.b.allocMemory && len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } @@ -1550,7 +1568,7 @@ func (t *Tensor) Clamp(ctx ml.Context, min, max float32) ml.Tensor { func (c Context) FromBytes(dtype ml.DType, s []uint8, shape ...int) ml.Tensor { // Unchecked to handle quantized types t := c.newTensor(dtype, shape) - if len(s) > 0 { + if c.b.allocMemory && len(s) > 0 { C.ggml_backend_tensor_set(t.(*Tensor).t, unsafe.Pointer(&s[0]), 0, C.ggml_nbytes(t.(*Tensor).t)) } diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index f1e9c18014..3040b2aa1e 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -581,16 +581,8 @@ void ggml_backend_load_all_from_path(const char * dir_path) { ggml_backend_load_best("blas", silent, dir_path); ggml_backend_load_best("cann", silent, dir_path); - - // Avoid mixed hip+cuda configurations - const char * hip_devices = std::getenv("HIP_VISIBLE_DEVICES"); - const char * rocr_devices = std::getenv("ROCR_VISIBLE_DEVICES"); - if (!hip_devices && !rocr_devices) { - ggml_backend_load_best("cuda", silent, dir_path); - } else { - ggml_backend_load_best("hip", silent, dir_path); - } - + ggml_backend_load_best("cuda", silent, dir_path); + ggml_backend_load_best("hip", silent, dir_path); ggml_backend_load_best("metal", silent, dir_path); ggml_backend_load_best("rpc", silent, dir_path); ggml_backend_load_best("sycl", silent, dir_path); diff --git a/runner/llamarunner/runner.go b/runner/llamarunner/runner.go index 7aa9b96a28..791492bbbd 100644 --- a/runner/llamarunner/runner.go +++ b/runner/llamarunner/runner.go @@ -12,7 +12,6 @@ import ( "net/http" "os" "regexp" - "runtime" "strconv" "strings" "sync" @@ -216,6 +215,12 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input, error) } type Server struct { + // modelPath is the location of the model to be loaded + modelPath string + + // loadMu prevents more than one load attempt from occurring at a time + loadMu sync.Mutex + // is the server ready to process requests? // protects access to model and image ready sync.WaitGroup @@ -723,21 +728,12 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) { } } -type multiLPath []string - -func (m *multiLPath) Set(value string) error { - *m = append(*m, value) - return nil -} - -func (m *multiLPath) String() string { - return strings.Join(*m, ", ") -} - +// loadModel allocates memory based on the given parameters and loads the weights. The +// memory allocated is worst case for text models but not for vision. func (s *Server) loadModel( params llama.ModelParams, mpath string, - lpath multiLPath, + lpath []string, ppath string, kvSize int, kvCacheType string, @@ -757,12 +753,10 @@ func (s *Server) loadModel( panic(err) } - if lpath.String() != "" { - for _, path := range lpath { - err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads) - if err != nil { - panic(err) - } + for _, path := range lpath { + err := s.model.ApplyLoraFromFile(s.lc, path, 1.0, threads) + if err != nil { + panic(err) } } @@ -783,26 +777,81 @@ func (s *Server) loadModel( s.ready.Done() } +// load is the handler called by the Ollama server to process different +// load operations +func (s *Server) load(w http.ResponseWriter, r *http.Request) { + s.loadMu.Lock() + defer s.loadMu.Unlock() + + w.Header().Set("Content-Type", "application/json") + + if s.status != llm.ServerStatusLaunched { + http.Error(w, "model already loaded", http.StatusInternalServerError) + return + } + + var req llm.LoadRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + slog.Info("load", "request", req) + + switch req.Operation { + // LoadOperationFit and LoadOperationAlloc have no meaning here - just return a successful response + + case llm.LoadOperationCommit: + s.batchSize = req.BatchSize + s.parallel = req.Parallel + s.seqs = make([]*Sequence, s.parallel) + s.seqsSem = semaphore.NewWeighted(int64(s.parallel)) + + gpuIDs := llama.EnumerateGPUs() + tensorSplit := make([]float32, len(gpuIDs)) + numGPU := 0 + for i := range gpuIDs { + for _, layers := range req.GPULayers { + if gpuIDs[i] == layers.ID { + tensorSplit[i] = float32(len(layers.Layers)) + numGPU += len(layers.Layers) + } + } + } + + params := llama.ModelParams{ + NumGpuLayers: numGPU, + MainGpu: req.MainGPU, + UseMmap: req.UseMmap && len(req.LoraPath) == 0, + TensorSplit: tensorSplit, + Progress: func(progress float32) { + s.progress = progress + }, + } + + s.status = llm.ServerStatusLoadingModel + go s.loadModel(params, s.modelPath, req.LoraPath, req.ProjectorPath, req.KvSize, req.KvCacheType, req.FlashAttention, req.NumThreads, req.MultiUserCache) + + case llm.LoadOperationClose: + // No-op for us + if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } + return + } + + resp := llm.LoadResponse{Success: true} + if err := json.NewEncoder(w).Encode(&resp); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + return + } +} + func Execute(args []string) error { fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") - ppath := fs.String("mmproj", "", "Path to projector binary file") - parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") - batchSize := fs.Int("batch-size", 512, "Batch size") - nGpuLayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") - mainGpu := fs.Int("main-gpu", 0, "Main GPU") - flashAttention := fs.Bool("flash-attn", false, "Enable flash attention") - kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") - kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") port := fs.Int("port", 8080, "Port to expose the server on") - threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") _ = fs.Bool("verbose", false, "verbose output (default: disabled)") - noMmap := fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") - tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") - multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") - - var lpaths multiLPath - fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)") fs.Usage = func() { fmt.Fprintf(fs.Output(), "Runner usage\n") @@ -817,35 +866,11 @@ func Execute(args []string) error { llama.BackendInit() server := &Server{ - batchSize: *batchSize, - parallel: *parallel, - seqs: make([]*Sequence, *parallel), - seqsSem: semaphore.NewWeighted(int64(*parallel)), - status: llm.ServerStatusLoadingModel, - } - - var tensorSplitFloats []float32 - if *tensorSplit != "" { - splits := strings.Split(*tensorSplit, ",") - tensorSplitFloats = make([]float32, len(splits)) - for i, s := range splits { - f, _ := strconv.ParseFloat(s, 32) - tensorSplitFloats[i] = float32(f) - } - } - - params := llama.ModelParams{ - NumGpuLayers: *nGpuLayers, - MainGpu: *mainGpu, - UseMmap: !*noMmap && lpaths.String() == "", - TensorSplit: tensorSplitFloats, - Progress: func(progress float32) { - server.progress = progress - }, + modelPath: *mpath, + status: llm.ServerStatusLaunched, } server.ready.Add(1) - go server.loadModel(params, *mpath, lpaths, *ppath, *kvSize, *kvCacheType, *flashAttention, *threads, *multiUserCache) server.cond = sync.NewCond(&server.mu) @@ -863,6 +888,7 @@ func Execute(args []string) error { defer listener.Close() mux := http.NewServeMux() + mux.HandleFunc("POST /load", server.load) mux.HandleFunc("/embedding", server.embeddings) mux.HandleFunc("/completion", server.completion) mux.HandleFunc("/health", server.health) diff --git a/runner/ollamarunner/runner.go b/runner/ollamarunner/runner.go index cebe30deff..2f41f68f22 100644 --- a/runner/ollamarunner/runner.go +++ b/runner/ollamarunner/runner.go @@ -14,6 +14,7 @@ import ( "net" "net/http" "os" + "reflect" "regexp" "runtime" "strconv" @@ -259,6 +260,16 @@ func (s *Server) inputs(prompt string, images []llm.ImageData) ([]input.Input, [ } type Server struct { + // modelPath is the location of the model to be loaded + modelPath string + + // loadMu prevents more than one load attempt from occurring at a time + loadMu sync.Mutex + + // lastLoad is the load request from the previous load attempt. Used to + // detect if we can reuse an existing memory allocation. + lastLoad llm.LoadRequest + // is the server ready to process requests? // protects access to model and image ready sync.WaitGroup @@ -720,17 +731,6 @@ func (s *Server) health(w http.ResponseWriter, r *http.Request) { } } -type multiLPath []string - -func (m *multiLPath) Set(value string) error { - *m = append(*m, value) - return nil -} - -func (m *multiLPath) String() string { - return strings.Join(*m, ", ") -} - func (s *Server) reserveWorstCaseGraph() error { ctx := s.model.Backend().NewContext() defer ctx.Close() @@ -828,15 +828,28 @@ func (s *Server) reserveWorstCaseGraph() error { return nil } -func (s *Server) initModel( +// allocModel pre-allocates the maximum needed memory for a model +// based on the given parameters +func (s *Server) allocModel( mpath string, params ml.BackendParams, - lpath multiLPath, + loraPath []string, parallel int, kvCacheType string, kvSize int, multiUserCache bool, -) error { +) (panicErr error) { + // Convert memory allocation panics to errors + defer func() { + if r := recover(); r != nil { + if err, ok := r.(error); ok { + panicErr = err + } else { + panic(r) + } + } + }() + var err error s.model, err = model.New(mpath, params) if err != nil { @@ -844,7 +857,7 @@ func (s *Server) initModel( } // TODO(jessegross): LoRA loading - if lpath.String() != "" { + if len(loraPath) > 0 { return errors.New("loras are not yet implemented") } @@ -865,63 +878,122 @@ func (s *Server) initModel( return s.reserveWorstCaseGraph() } -func (s *Server) load( - ctx context.Context, - mpath string, - params ml.BackendParams, - lpath multiLPath, - parallel int, - kvCacheType string, - kvSize int, - multiUserCache bool, -) { - err := s.initModel(mpath, params, lpath, parallel, kvCacheType, kvSize, multiUserCache) - if err != nil { - var noMem ml.ErrNoMem - if errors.As(err, &noMem) { - // We can't yet handle this but in the future we will - s.cache.Close() - if s.model != nil { - s.model.Backend().Close() - } - } - - panic(err) +// closeModel frees all memory associated with a model +func (s *Server) closeModel() { + s.cache.Close() + s.cache = nil + if s.model != nil { + s.model.Backend().Close() + s.model = nil } +} - slog.Debug("memory", "allocated", s.model.Backend().BackendMemory()) - - err = s.model.Backend().Load(ctx, +// loadModel loads the weights for a model. The memory must already +// have been allocated with allocModel +func (s *Server) loadModel() { + err := s.model.Backend().Load(context.TODO(), func(progress float32) { s.progress = progress }) if err != nil { - panic(err) + panic(fmt.Errorf("failed to load model: %v", err)) } s.status = llm.ServerStatusReady s.ready.Done() } +// load is the handler called by the Ollama server to process different +// load operations +func (s *Server) load(w http.ResponseWriter, r *http.Request) { + s.loadMu.Lock() + defer s.loadMu.Unlock() + + w.Header().Set("Content-Type", "application/json") + + if s.status != llm.ServerStatusLaunched { + http.Error(w, "model already loaded", http.StatusInternalServerError) + return + } + + var req llm.LoadRequest + if err := json.NewDecoder(r.Body).Decode(&req); err != nil { + http.Error(w, "bad request", http.StatusBadRequest) + return + } + + slog.Info("load", "request", req) + + if req.Operation == llm.LoadOperationClose { + s.closeModel() + if err := json.NewEncoder(w).Encode(&llm.LoadResponse{}); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } + return + } + + s.lastLoad.Operation = req.Operation + loadModel := s.model == nil || !reflect.DeepEqual(req, s.lastLoad) + + s.lastLoad = req + + if loadModel { + s.closeModel() + + params := ml.BackendParams{ + AllocMemory: req.Operation != llm.LoadOperationFit, + NumThreads: req.NumThreads, + GPULayers: req.GPULayers, + FlashAttention: req.FlashAttention, + } + + s.batchSize = req.BatchSize + + err := s.allocModel(s.modelPath, params, req.LoraPath, req.Parallel, req.KvCacheType, req.KvSize, req.MultiUserCache) + if err != nil { + s.closeModel() + + var noMem ml.ErrNoMem + if errors.As(err, &noMem) { + resp := llm.LoadResponse{Success: false, Memory: noMem.BackendMemory} + if err := json.NewEncoder(w).Encode(&resp); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + } + + return + } + + http.Error(w, fmt.Sprintf("failed to initialize model: %v", err), http.StatusInternalServerError) + return + } + } + + mem := s.model.Backend().BackendMemory() + + switch req.Operation { + case llm.LoadOperationFit: + // LoadOperationFit can't be used for anything else, so just close it + s.closeModel() + + // LoadOperationAlloc should stay open for future operations + + case llm.LoadOperationCommit: + s.status = llm.ServerStatusLoadingModel + go s.loadModel() + } + + resp := llm.LoadResponse{Success: true, Memory: mem} + if err := json.NewEncoder(w).Encode(&resp); err != nil { + http.Error(w, fmt.Sprintf("failed to encode response: %v", err), http.StatusInternalServerError) + return + } +} + func Execute(args []string) error { fs := flag.NewFlagSet("runner", flag.ExitOnError) mpath := fs.String("model", "", "Path to model binary file") - parallel := fs.Int("parallel", 1, "Number of sequences to handle simultaneously") - batchSize := fs.Int("batch-size", 512, "Batch size") - numGPULayers := fs.Int("n-gpu-layers", 0, "Number of layers to offload to GPU") - mainGPU := fs.Int("main-gpu", 0, "Main GPU") - flashAttention := fs.Bool("flash-attn", false, "Enable flash attention") - kvSize := fs.Int("ctx-size", 2048, "Context (or KV cache) size") - kvCacheType := fs.String("kv-cache-type", "", "quantization type for KV cache (default: f16)") port := fs.Int("port", 8080, "Port to expose the server on") - threads := fs.Int("threads", runtime.NumCPU(), "Number of threads to use during generation") _ = fs.Bool("verbose", false, "verbose output (default: disabled)") - _ = fs.Bool("no-mmap", false, "do not memory-map model (slower load but may reduce pageouts if not using mlock)") - tensorSplit := fs.String("tensor-split", "", "fraction of the model to offload to each GPU, comma-separated list of proportions") - multiUserCache := fs.Bool("multiuser-cache", false, "optimize input cache algorithm for multiple users") - - var lpaths multiLPath - fs.Var(&lpaths, "lora", "Path to lora layer file (can be specified multiple times)") fs.Usage = func() { fmt.Fprintf(fs.Output(), "Runner usage\n") @@ -933,39 +1005,17 @@ func Execute(args []string) error { slog.SetDefault(logutil.NewLogger(os.Stderr, envconfig.LogLevel())) slog.Info("starting ollama engine") + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + server := &Server{ - batchSize: *batchSize, - status: llm.ServerStatusLoadingModel, + modelPath: *mpath, + status: llm.ServerStatusLaunched, } server.cond = sync.NewCond(&server.mu) server.ready.Add(1) - ctx, cancel := context.WithCancel(context.Background()) - defer cancel() - - // TODO(jessegross): Parameters that need to be implemented: - // no-mmap - - var tensorSplitFloats []float32 - if *tensorSplit != "" { - splits := strings.Split(*tensorSplit, ",") - tensorSplitFloats = make([]float32, len(splits)) - for i, s := range splits { - f, _ := strconv.ParseFloat(s, 32) - tensorSplitFloats[i] = float32(f) - } - } - - params := ml.BackendParams{ - NumThreads: *threads, - NumGPULayers: *numGPULayers, - MainGPU: *mainGPU, - TensorSplit: tensorSplitFloats, - FlashAttention: *flashAttention, - } - - go server.load(ctx, *mpath, params, lpaths, *parallel, *kvCacheType, *kvSize, *multiUserCache) go server.run(ctx) addr := "127.0.0.1:" + strconv.Itoa(*port) @@ -978,6 +1028,7 @@ func Execute(args []string) error { mux := http.NewServeMux() // TODO: support embeddings + mux.HandleFunc("POST /load", server.load) mux.HandleFunc("POST /embedding", func(w http.ResponseWriter, r *http.Request) { http.Error(w, "this model does not support embeddings", http.StatusNotImplemented) }) diff --git a/server/routes.go b/server/routes.go index 5d37c83a90..99b1b300ae 100644 --- a/server/routes.go +++ b/server/routes.go @@ -1477,14 +1477,14 @@ func (s *Server) PsHandler(c *gin.Context) { mr := api.ProcessModelResponse{ Model: model.ShortName, Name: model.ShortName, - Size: int64(v.estimatedTotal), - SizeVRAM: int64(v.estimatedVRAM), + Size: int64(v.totalSize), + SizeVRAM: int64(v.vramSize), Digest: model.Digest, Details: modelDetails, ExpiresAt: v.expiresAt, } if v.Options != nil { - mr.ContextLength = v.Options.NumCtx / v.numParallel + mr.ContextLength = v.Options.NumCtx } // The scheduler waits to set expiresAt, so if a model is loading it's // possible that it will be set to the unix epoch. For those cases, just diff --git a/server/routes_generate_test.go b/server/routes_generate_test.go index 506071edfa..a57975f16b 100644 --- a/server/routes_generate_test.go +++ b/server/routes_generate_test.go @@ -77,12 +77,13 @@ func TestGenerateChat(t *testing.T) { getGpuFn: discover.GetGPUInfo, getCpuFn: discover.GetCPUInfo, reschedDelay: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } + return false }, }, } @@ -620,12 +621,13 @@ func TestGenerate(t *testing.T) { getGpuFn: discover.GetGPUInfo, getCpuFn: discover.GetCPUInfo, reschedDelay: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { // add small delay to simulate loading time.Sleep(time.Millisecond) req.successCh <- &runnerRef{ llama: &mock, } + return false }, }, } diff --git a/server/routes_harmony_streaming_test.go b/server/routes_harmony_streaming_test.go index 1b86f84c12..b1ede4e39e 100644 --- a/server/routes_harmony_streaming_test.go +++ b/server/routes_harmony_streaming_test.go @@ -277,10 +277,11 @@ func TestChatHarmonyParserStreamingRealtime(t *testing.T) { getGpuFn: discover.GetGPUInfo, getCpuFn: discover.GetCPUInfo, reschedDelay: 100 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ llama: &mock, } + return false }, }, } @@ -427,10 +428,11 @@ func TestChatHarmonyParserStreamingSimple(t *testing.T) { getGpuFn: discover.GetGPUInfo, getCpuFn: discover.GetCPUInfo, reschedDelay: 100 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ llama: &mock, } + return false }, }, } @@ -608,10 +610,11 @@ func TestChatHarmonyParserStreaming(t *testing.T) { getGpuFn: discover.GetGPUInfo, getCpuFn: discover.GetCPUInfo, reschedDelay: 250 * time.Millisecond, - loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ int) { + loadFn: func(req *LlmRequest, _ *ggml.GGML, _ discover.GpuInfoList, _ bool) bool { req.successCh <- &runnerRef{ llama: &mock, } + return false }, }, } diff --git a/server/sched.go b/server/sched.go index 40e6e5f722..c501c0e85d 100644 --- a/server/sched.go +++ b/server/sched.go @@ -28,7 +28,6 @@ type LlmRequest struct { ctx context.Context //nolint:containedctx model *Model opts api.Options - origNumCtx int // Track the initial ctx request sessionDuration *api.Duration successCh chan *runnerRef errCh chan error @@ -41,10 +40,17 @@ type Scheduler struct { expiredCh chan *runnerRef unloadedCh chan any - loaded map[string]*runnerRef + // loadedMu protects loaded and activeLoading loadedMu sync.Mutex - loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int) + // activeLoading is the model that we are currently working on loading, + // including by evicting one or more other models. We can only load + // one model at a time but new requests to models that already loaded can + // happen in parallel + activeLoading llm.LlamaServer + loaded map[string]*runnerRef + + loadFn func(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool newServerFn func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) getGpuFn func() discover.GpuInfoList getCpuFn func() discover.GpuInfoList @@ -56,9 +62,6 @@ type Scheduler struct { // on a large GPU can cause stalling var defaultModelsPerGPU = 3 -// Default automatic value for parallel setting -var defaultParallel = 1 - var ErrMaxQueue = errors.New("server busy, please try again. maximum pending requests exceeded") func InitScheduler(ctx context.Context) *Scheduler { @@ -79,24 +82,36 @@ func InitScheduler(ctx context.Context) *Scheduler { } // context must be canceled to decrement ref count and release the runner -func (s *Scheduler) GetRunner(c context.Context, model *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) { +func (s *Scheduler) GetRunner(c context.Context, m *Model, opts api.Options, sessionDuration *api.Duration) (chan *runnerRef, chan error) { if opts.NumCtx < 4 { opts.NumCtx = 4 } + if m.CheckCapabilities(model.CapabilityVision) == nil { + // multimodal models require at least 2048 context + opts.NumCtx = max(opts.NumCtx, 2048) + } + req := &LlmRequest{ ctx: c, - model: model, + model: m, opts: opts, sessionDuration: sessionDuration, - successCh: make(chan *runnerRef), + successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), } - select { - case s.pendingReqCh <- req: - default: - req.errCh <- ErrMaxQueue + s.loadedMu.Lock() + runner := s.loaded[req.model.ModelPath] + s.loadedMu.Unlock() + if runner != nil && !runner.needsReload(c, req) { + req.useLoadedRunner(runner, s.finishedReqCh) + } else { + select { + case s.pendingReqCh <- req: + default: + req.errCh <- ErrMaxQueue + } } return req.successCh, req.errCh } @@ -122,21 +137,11 @@ func (s *Scheduler) processPending(ctx context.Context) { case pending := <-s.pendingReqCh: // Block other requests until we get this pending request running pending.schedAttempts++ - if pending.origNumCtx == 0 { - pending.origNumCtx = pending.opts.NumCtx - } if pending.ctx.Err() != nil { slog.Debug("pending request cancelled or timed out, skipping scheduling") continue } - numParallel := int(envconfig.NumParallel()) - // `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1 - // ref: https://github.com/ollama/ollama/issues/4165 - if slices.Contains(pending.model.Config.ModelFamilies, "mllama") && numParallel != 1 { - numParallel = 1 - slog.Warn("mllama does not currently support parallel requests") - } for { var runnerToExpire *runnerRef @@ -195,84 +200,26 @@ func (s *Scheduler) processPending(ctx context.Context) { break } - // Embedding models should always be loaded with parallel=1 - if pending.model.CheckCapabilities(model.CapabilityCompletion) != nil { - numParallel = 1 - } + // Update free memory from currently loaded models + s.updateFreeSpace(gpus) - // Evaluate if the model will fit in the available system memory, or if we should unload a model first - if len(gpus) == 1 && gpus[0].Library == "cpu" { - // simplifying assumption of defaultParallel when in CPU mode - if numParallel <= 0 { - numParallel = defaultParallel - } - - pending.opts.NumCtx = pending.origNumCtx * numParallel - - if loadedCount == 0 { - slog.Debug("cpu mode with first model, loading") - s.loadFn(pending, ggml, gpus, numParallel) - break - } - runnerToExpire = s.maybeFindCPURunnerToUnload(pending, ggml, gpus) - if runnerToExpire == nil { - slog.Debug("cpu mode with available system memory or first model, loading") - s.loadFn(pending, ggml, gpus, numParallel) - break - } - // else we need to expire a runner - } else if loadedCount == 0 { + if loadedCount == 0 { // No models loaded. Load the model but prefer the best fit. slog.Debug("loading first model", "model", pending.model.ModelPath) - g := pickBestFullFitByLibrary(pending, ggml, gpus, &numParallel) - if g != nil { - gpus = g - } else { - // Only allow partial loads when this is the first model - gpus = pickBestPartialFitByLibrary(pending, ggml, gpus, &numParallel) - } - s.loadFn(pending, ggml, gpus, numParallel) + s.loadFn(pending, ggml, gpus, false) break } - if runnerToExpire == nil { - // More than one loaded model, so we have to see if the - // new one fits - // - // We want to avoid loading on any GPUs that have other - // models still loading on them to avoid potential races - // with VRAM consumption ramping up during load - availGpus := s.filterGPUsWithoutLoadingModels(gpus) + // More than one loaded model, so we have to see if the + // new one fits - // Update free memory from currently loaded models - s.updateFreeSpace(availGpus) - fitGpus := pickBestFullFitByLibrary(pending, ggml, availGpus, &numParallel) - if fitGpus != nil { - slog.Debug("new model fits with existing models, loading") - s.loadFn(pending, ggml, fitGpus, numParallel) - break - } - - // We couldn't find a set of GPUs to fully load the new - // model. If no other models are loading (both GPU lists - // are the same) then we need to unload another model to - // make room - if len(availGpus) < len(gpus) { - // There are other requests pending, and this one - // needs more time, so put it on the back of the - // queue so that we might satisfy other pending - // requests that aren't blocked - go func() { - // Process in a go routine to avoid deadlocking - // the scheduler if our queue is full - slog.Debug("delaying scheduling while other models finish loading", "attempts", pending.schedAttempts, "model", pending.model.ModelPath) - time.Sleep(s.reschedDelay) - s.pendingReqCh <- pending - }() - break - } - runnerToExpire = s.findRunnerToUnload() + needEvict := s.loadFn(pending, ggml, gpus, true) + if !needEvict { + slog.Debug("new model fits with existing models, loading") + break } + + runnerToExpire = s.findRunnerToUnload() } if runnerToExpire == nil { @@ -293,8 +240,6 @@ func (s *Scheduler) processPending(ctx context.Context) { } runnerToExpire.refMu.Unlock() // Wait for the unload to happen - // Note: at this point we're queueing up all incoming requests, even if they were for - // a different model that's loaded and not scheduled to be removed. slog.Debug("waiting for pending requests to complete and unload to occur", "runner", runnerToExpire) select { case <-ctx.Done(): @@ -434,26 +379,72 @@ func (pending *LlmRequest) useLoadedRunner(runner *runnerRef, finished chan *Llm }() } -func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel int) { +// load creates a new model based on req and loads it. If requireFull is true then the model must be loaded fully onto GPUs +// (if any). Returns whether the scheduler needs to evict a model to make this one fit. +func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, requireFull bool) bool { + numParallel := int(envconfig.NumParallel()) if numParallel < 1 { numParallel = 1 } + + // Embedding models should always be loaded with parallel=1 + if req.model.CheckCapabilities(model.CapabilityCompletion) != nil { + numParallel = 1 + } + + // `mllama` is a snowflake and uses an encoder cache which cannot be used with num_parallel > 1 + // ref: https://github.com/ollama/ollama/issues/4165 + if slices.Contains(req.model.Config.ModelFamilies, "mllama") && numParallel != 1 { + numParallel = 1 + slog.Warn("mllama does not currently support parallel requests") + } + sessionDuration := envconfig.KeepAlive() if req.sessionDuration != nil { sessionDuration = req.sessionDuration.Duration } - llama, err := s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) - if err != nil { - // some older models are not compatible with newer versions of llama.cpp - // show a generalized compatibility error until there is a better way to - // check for model compatibility - if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") { - err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + + s.loadedMu.Lock() + llama := s.activeLoading + + if llama == nil { + var err error + llama, err = s.newServerFn(gpus, req.model.ModelPath, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, numParallel) + if err != nil { + // some older models are not compatible with newer versions of llama.cpp + // show a generalized compatibility error until there is a better way to + // check for model compatibility + if errors.Is(err, ggml.ErrUnsupportedFormat) || strings.Contains(err.Error(), "failed to load model") { + err = fmt.Errorf("%v: this model may be incompatible with your version of Ollama. If you previously pulled this model, try updating it by running `ollama pull %s`", err, req.model.ShortName) + } + slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) + req.errCh <- err + s.loadedMu.Unlock() + return false + } + + s.activeLoading = llama + } else { + if s.activeLoading.ModelPath() != req.model.ModelPath { + panic(fmt.Errorf("attempting to load different model after eviction (original %v new %v)", s.activeLoading.ModelPath(), req.model.ModelPath)) } - slog.Info("NewLlamaServer failed", "model", req.model.ModelPath, "error", err) - req.errCh <- err - return } + + s.loadedMu.Unlock() + + err := llama.Load(req.ctx, gpus, requireFull) + if err != nil { + if errors.Is(err, llm.ErrLoadRequiredFull) { + return true + } + + slog.Info("Load failed", "model", req.model.ModelPath, "error", err) + s.activeLoading.Close() + s.activeLoading = nil + req.errCh <- err + return false + } + runner := &runnerRef{ model: req.model, modelPath: req.model.ModelPath, @@ -461,8 +452,8 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis Options: &req.opts, sessionDuration: sessionDuration, gpus: gpus, - estimatedVRAM: llama.EstimatedVRAM(), - estimatedTotal: llama.EstimatedTotal(), + vramSize: llama.VRAMSize(), + totalSize: llama.TotalSize(), loading: true, pid: llama.Pid(), } @@ -477,6 +468,7 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis oldRunner.unload() oldRunner.refMu.Unlock() } + s.activeLoading = nil s.loaded[req.model.ModelPath] = runner slog.Info("loaded runners", "count", len(s.loaded)) s.loadedMu.Unlock() @@ -503,6 +495,8 @@ func (s *Scheduler) load(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoLis }() req.successCh <- runner }() + + return false } func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) { @@ -521,7 +515,7 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) { r.refMu.Lock() if r.llama != nil { for _, gpu := range allGpus { - predMap[predKey{gpu.Library, gpu.ID}] += r.llama.EstimatedVRAMByGPU(gpu.ID) + predMap[predKey{gpu.Library, gpu.ID}] += r.llama.VRAMByGPU(gpu.ID) } } else { slog.Warn("unexpected nil runner reference, memory prediction may be incorrect") @@ -548,41 +542,17 @@ func (s *Scheduler) updateFreeSpace(allGpus discover.GpuInfoList) { } } -// While models are loading the VRAM consumption numbers will be indeterminate, so we have -// to avoid scheduling another model on the same GPU(s) that haven't stabilized. -// This routine returns the set of GPUs that do not have an active loading model. -// If all GPUs have loading models, an empty list will be returned (not a single CPU entry) -func (s *Scheduler) filterGPUsWithoutLoadingModels(allGpus discover.GpuInfoList) discover.GpuInfoList { - ret := append(discover.GpuInfoList{}, allGpus...) - s.loadedMu.Lock() - defer s.loadedMu.Unlock() - for _, runner := range s.loaded { - if runner.loading { - slog.Debug("overlapping loads detected", "gpus", runner.gpus, "model", runner.modelPath) - for _, busyGPU := range runner.gpus { - for i := range ret { - if ret[i].ID == busyGPU.ID { - ret = append(ret[:i], ret[i+1:]...) - break - } - } - } - } - } - return ret -} - // TODO consolidate sched_types.go type runnerRef struct { refMu sync.Mutex refCount uint // prevent unloading if > 0 - llama llm.LlamaServer - pid int - loading bool // True only during initial load, then false forever - gpus discover.GpuInfoList // Recorded at time of provisioning - estimatedVRAM uint64 - estimatedTotal uint64 + llama llm.LlamaServer + pid int + loading bool // True only during initial load, then false forever + gpus discover.GpuInfoList // Recorded at time of provisioning + vramSize uint64 + totalSize uint64 sessionDuration time.Duration expireTimer *time.Timer @@ -631,9 +601,6 @@ func (runner *runnerRef) needsReload(ctx context.Context, req *LlmRequest) bool optsNew.NumGPU = -1 } - // Normalize the NumCtx for parallelism - optsExisting.NumCtx = optsExisting.NumCtx / runner.numParallel - ctx, cancel := context.WithTimeout(ctx, timeout) defer cancel() if !reflect.DeepEqual(runner.model.AdapterPaths, req.model.AdapterPaths) || // have the adapters changed? @@ -694,7 +661,7 @@ func (runner *runnerRef) waitForVRAMRecovery() chan any { freeMemoryNow += gpu.FreeMemory } // If we're within ~80% of the estimated memory usage recovered, bail out - if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.estimatedVRAM)*0.8 { + if float32(freeMemoryNow-freeMemoryBefore) > float32(runner.vramSize)*0.8 { slog.Debug(fmt.Sprintf("gpu VRAM free memory converged after %0.2f seconds", time.Since(start).Seconds()), "runner", runner) finished <- struct{}{} return @@ -719,8 +686,8 @@ func (runner *runnerRef) LogValue() slog.Value { ) } attrs = append(attrs, - slog.String("size", format.HumanBytes2(runner.estimatedTotal)), - slog.String("vram", format.HumanBytes2(runner.estimatedVRAM)), + slog.String("size", format.HumanBytes2(runner.totalSize)), + slog.String("vram", format.HumanBytes2(runner.vramSize)), slog.Int("parallel", runner.numParallel), slog.Int("pid", runner.pid), slog.String("model", runner.modelPath), @@ -750,95 +717,7 @@ func (a ByDurationAndName) Less(i, j int) bool { // type BySize []*runnerRef // func (a BySize) Len() int { return len(a) } // func (a BySize) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -// func (a BySize) Less(i, j int) bool { return a[i].estimatedVRAM < a[j].estimatedVRAM } - -// pickBestFullFitByLibrary will try to find the optimal placement of the model in the available GPUs where the model fully fits -// The list of GPUs returned will always be the same brand (library) -// If the model can not be fit fully within the available GPU(s) nil is returned -// If numParallel is <= 0, this will attempt try to optimize parallelism based on available VRAM, and adjust -// opts.NumCtx accordingly -func pickBestFullFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList { - var numParallelToTry []int - if *numParallel <= 0 { - // If no specific parallel setting was provided, try larger then smaller, always end with 1 - numParallelToTry = append(numParallelToTry, defaultParallel, 1) - } else { - numParallelToTry = []int{*numParallel} - } - - for _, gl := range gpus.ByLibrary() { - sgl := append(make(discover.GpuInfoList, 0, len(gl)), gl...) - - // TODO - potentially sort by performance capability, existing models loaded, etc. - // TODO - Eliminate any GPUs that already have envconfig.MaxRunners loaded on them - // Note: at present, this will favor most current available VRAM descending and ignoring faster GPU speed in mixed setups - sort.Sort(sort.Reverse(discover.ByFreeMemory(sgl))) - - if !envconfig.SchedSpread() { - for _, p := range numParallelToTry { - req.opts.NumCtx = req.origNumCtx * p - // Try to pack into as few GPUs as possible, starting from 1 GPU - for numGPUs := 1; numGPUs <= len(sgl); numGPUs++ { - gpuSubset := sgl[:numGPUs] - ok, estimatedVRAM := llm.PredictServerFit(gpuSubset, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p) - - if ok { - slog.Info("new model will fit in available VRAM across minimum required GPUs, loading", - "model", req.model.ModelPath, - "library", sgl[0].Library, - "parallel", p, - "required", format.HumanBytes2(estimatedVRAM), - "gpus", numGPUs) - *numParallel = p - return gpuSubset - } - } - } - } else { - // TODO future refinements - // - if multiple Libraries, see if any single GPU in any Library will fit - // - try subsets of GPUs instead of just falling back to 1 or all in a family - - // Now try all the GPUS (OLLAMA_SCHED_SPREAD is set) - for _, p := range numParallelToTry { - req.opts.NumCtx = req.origNumCtx * p - if ok, estimatedVRAM := llm.PredictServerFit(sgl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, p); ok { - slog.Info("new model will fit in available VRAM, loading", - "model", req.model.ModelPath, - "library", sgl[0].Library, - "parallel", p, - "required", format.HumanBytes2(estimatedVRAM), - "gpus", len(sgl)) - *numParallel = p - return sgl - } - } - } - } - return nil -} - -// If multiple Libraries are detected, pick the Library which loads the most layers for the model -func pickBestPartialFitByLibrary(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList, numParallel *int) discover.GpuInfoList { - if *numParallel <= 0 { - *numParallel = 1 - req.opts.NumCtx = req.origNumCtx - } - byLibrary := gpus.ByLibrary() - if len(byLibrary) <= 1 { - return gpus - } - var bestEstimate uint64 - var bestFit int - for i, gl := range byLibrary { - _, estimatedVRAM := llm.PredictServerFit(gl, f, req.model.AdapterPaths, req.model.ProjectorPaths, req.opts, *numParallel) - if estimatedVRAM > bestEstimate { - bestEstimate = estimatedVRAM - bestFit = i - } - } - return byLibrary[bestFit] -} +// func (a BySize) Less(i, j int) bool { return a[i].vramSize < a[j].vramSize } // findRunnerToUnload finds a runner to unload to make room for a new model func (s *Scheduler) findRunnerToUnload() *runnerRef { @@ -875,6 +754,13 @@ func (s *Scheduler) findRunnerToUnload() *runnerRef { func (s *Scheduler) unloadAllRunners() { s.loadedMu.Lock() defer s.loadedMu.Unlock() + + if s.activeLoading != nil { + slog.Debug("shutting down currently loading runner") + s.activeLoading.Close() + s.activeLoading = nil + } + for model, runner := range s.loaded { if runner.llama != nil { slog.Debug("shutting down runner", "model", model) @@ -901,18 +787,3 @@ func (s *Scheduler) expireRunner(model *Model) { runner.refMu.Unlock() } } - -// If other runners are loaded, make sure the pending request will fit in system memory -// If not, pick a runner to unload, else return nil and the request can be loaded -func (s *Scheduler) maybeFindCPURunnerToUnload(req *LlmRequest, f *ggml.GGML, gpus discover.GpuInfoList) *runnerRef { - slog.Debug("evaluating if CPU model load will fit in available system memory") - estimate := llm.EstimateGPULayers(gpus, f, req.model.ProjectorPaths, req.opts, req.opts.NumCtx/req.origNumCtx) - if estimate.TotalSize <= gpus[0].FreeMemory { - slog.Debug("cpu inference mode, model fits in available system memory", "model", format.HumanBytes2(estimate.TotalSize), "available", format.HumanBytes2(gpus[0].FreeMemory)) - return nil - } - - // TODO - optimization: try to find CPU only runners first, or partial offloads with enough in system memory to make room - - return s.findRunnerToUnload() -} diff --git a/server/sched_test.go b/server/sched_test.go index 3892fbbab5..0acd591189 100644 --- a/server/sched_test.go +++ b/server/sched_test.go @@ -52,7 +52,7 @@ func TestLoad(t *testing.T) { return nil, errors.New("something failed to load model blah") } gpus := discover.GpuInfoList{} - s.load(req, f, gpus, 0) + s.load(req, f, gpus, false) require.Empty(t, req.successCh) require.Len(t, req.errCh, 1) s.loadedMu.Lock() @@ -61,16 +61,17 @@ func TestLoad(t *testing.T) { err := <-req.errCh require.Contains(t, err.Error(), "this model may be incompatible") - server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}} + server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}} s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + server.modelPath = model return server, nil } - s.load(req, f, gpus, 0) + s.load(req, f, gpus, false) select { case err := <-req.errCh: require.NoError(t, err) case resp := <-req.successCh: - require.Equal(t, uint64(10), resp.estimatedVRAM) + require.Equal(t, uint64(10), resp.vramSize) require.Equal(t, uint(1), resp.refCount) s.loadedMu.Lock() require.Len(t, s.loaded, 1) @@ -79,7 +80,7 @@ func TestLoad(t *testing.T) { req.model.ModelPath = "dummy_model_path" server.waitResp = errors.New("wait failure") - s.load(req, f, gpus, 0) + s.load(req, f, gpus, false) select { case err := <-req.errCh: require.Contains(t, err.Error(), "wait failure") @@ -104,10 +105,11 @@ type reqBundle struct { } func (scenario *reqBundle) newServer(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + scenario.srv.modelPath = model return scenario.srv, nil } -func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, estimatedVRAM uint64, duration *api.Duration) *reqBundle { +func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, vramSize uint64, duration *api.Duration) *reqBundle { b := &reqBundle{} b.ctx, b.ctxDone = context.WithCancel(ctx) t.Helper() @@ -144,7 +146,7 @@ func newScenarioRequest(t *testing.T, ctx context.Context, modelName string, est successCh: make(chan *runnerRef, 1), errCh: make(chan error, 1), } - b.srv = &mockLlm{estimatedVRAM: estimatedVRAM, estimatedVRAMByGPU: map[string]uint64{"": estimatedVRAM}} + b.srv = &mockLlm{vramSize: vramSize, vramByGPU: map[string]uint64{"": vramSize}} return b } @@ -262,10 +264,10 @@ func TestRequestsMultipleLoadedModels(t *testing.T) { // Multiple loaded models a := newScenarioRequest(t, ctx, "ollama-model-3a", 1*format.GigaByte, nil) - b := newScenarioRequest(t, ctx, "ollama-model-3b", 24*format.GigaByte, nil) - c := newScenarioRequest(t, ctx, "ollama-model-4a", 30, nil) - c.req.opts.NumGPU = 0 // CPU load, will be allowed - d := newScenarioRequest(t, ctx, "ollama-model-3c", 30, nil) // Needs prior unloaded + b := newScenarioRequest(t, ctx, "ollama-model-3b", 10*format.GigaByte, nil) + c := newScenarioRequest(t, ctx, "ollama-model-4a", 10*format.GigaByte, nil) + c.req.opts.NumGPU = 0 // CPU load, will be allowed + d := newScenarioRequest(t, ctx, "ollama-model-3c", 10*format.GigaByte, nil) // Needs prior unloaded t.Setenv("OLLAMA_MAX_LOADED_MODELS", "1") s.newServerFn = a.newServer @@ -418,11 +420,12 @@ func TestExpireRunner(t *testing.T) { var f *ggml.GGML gpus := discover.GpuInfoList{} - server := &mockLlm{estimatedVRAM: 10, estimatedVRAMByGPU: map[string]uint64{}} + server := &mockLlm{vramSize: 10, vramByGPU: map[string]uint64{}} s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { + server.modelPath = model return server, nil } - s.load(req, f, gpus, 0) + s.load(req, f, gpus, false) select { case err := <-req.errCh: @@ -506,7 +509,7 @@ func TestUseLoadedRunner(t *testing.T) { sessionDuration: &api.Duration{Duration: 2}, } finished := make(chan *LlmRequest) - llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} + llm1 := &mockLlm{vramByGPU: map[string]uint64{}} r1 := &runnerRef{llama: llm1, sessionDuration: 1, numParallel: 1} req.useLoadedRunner(r1, finished) require.Equal(t, uint(1), r1.refCount) @@ -541,8 +544,8 @@ func TestUpdateFreeSpace(t *testing.T) { gpus[0].FreeMemory = 900 gpus[1].TotalMemory = 2000 gpus[1].FreeMemory = 1900 - llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 50, "2": 50}} - llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{"1": 125, "2": 75}} + llm1 := &mockLlm{vramByGPU: map[string]uint64{"1": 50, "2": 50}} + llm2 := &mockLlm{vramByGPU: map[string]uint64{"1": 125, "2": 75}} r1 := &runnerRef{llama: llm1, gpus: gpus, numParallel: 1} r2 := &runnerRef{llama: llm2, gpus: gpus, numParallel: 1} @@ -557,40 +560,6 @@ func TestUpdateFreeSpace(t *testing.T) { require.Equal(t, uint64(2000-50-75), gpus[1].FreeMemory) } -func TestFilterGPUsWithoutLoadingModels(t *testing.T) { - ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) - defer done() - gpus := discover.GpuInfoList{ - { - Library: "cuda", - ID: "0", - }, - { - Library: "cuda", - ID: "1", - }, - } - r1 := &runnerRef{gpus: discover.GpuInfoList{gpus[0]}, loading: true} - - s := InitScheduler(ctx) - s.loadedMu.Lock() - s.loaded["a"] = r1 - s.loadedMu.Unlock() - - tmp := s.filterGPUsWithoutLoadingModels(gpus) - require.Len(t, tmp, 1) - require.Equal(t, "1", tmp[0].ID) - - r1.gpus = discover.GpuInfoList{gpus[1]} - tmp = s.filterGPUsWithoutLoadingModels(gpus) - require.Len(t, tmp, 1) - require.Equal(t, "0", tmp[0].ID) - - r1.gpus = discover.GpuInfoList{} - tmp = s.filterGPUsWithoutLoadingModels(gpus) - require.Len(t, tmp, 2) -} - func TestFindRunnerToUnload(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) defer done() @@ -615,7 +584,7 @@ func TestNeedsReload(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) defer done() - llm := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} + llm := &mockLlm{vramByGPU: map[string]uint64{}} do := api.DefaultOptions() runner := &runnerRef{ model: &Model{ @@ -662,8 +631,8 @@ func TestUnloadAllRunners(t *testing.T) { ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) defer done() - llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} - llm2 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} + llm1 := &mockLlm{vramByGPU: map[string]uint64{}} + llm2 := &mockLlm{vramByGPU: map[string]uint64{}} s := InitScheduler(ctx) s.unloadAllRunners() @@ -681,7 +650,7 @@ func TestUnloadAllRunners(t *testing.T) { } func TestUnload(t *testing.T) { - llm1 := &mockLlm{estimatedVRAMByGPU: map[string]uint64{}} + llm1 := &mockLlm{vramByGPU: map[string]uint64{}} r1 := &runnerRef{llama: llm1, numParallel: 1} r2 := &runnerRef{model: &Model{AdapterPaths: []string{"A"}}, numParallel: 1} r1.unload() @@ -707,62 +676,40 @@ func TestAlreadyCanceled(t *testing.T) { require.Empty(t, scenario1a.req.successCh) } -func TestHomogeneousGPUs(t *testing.T) { - ctx, done := context.WithTimeout(t.Context(), 100*time.Millisecond) - defer done() - s := InitScheduler(ctx) - - s.getGpuFn = func() discover.GpuInfoList { - // Set memory values to require the model to be spread - gpus := []discover.GpuInfo{ - {Library: "cuda"}, - {Library: "rocm"}, - } - gpus[0].TotalMemory = 1 * format.GibiByte - gpus[0].FreeMemory = 256 * format.MebiByte - gpus[1].TotalMemory = 1 * format.GibiByte - gpus[1].FreeMemory = 256 * format.MebiByte - return gpus - } - s.getCpuFn = getCpuFn - a := newScenarioRequest(t, ctx, "ollama-model-1", 10, &api.Duration{Duration: 5 * time.Millisecond}) - s.newServerFn = func(gpus discover.GpuInfoList, model string, f *ggml.GGML, adapters []string, projectors []string, opts api.Options, numParallel int) (llm.LlamaServer, error) { - require.Len(t, gpus, 1) - return a.newServer(gpus, model, f, adapters, projectors, opts, numParallel) - } - slog.Info("a") - s.pendingReqCh <- a.req - require.Len(t, s.pendingReqCh, 1) - s.Run(ctx) - select { - case resp := <-a.req.successCh: - require.Equal(t, resp.llama, a.srv) - require.Empty(t, s.pendingReqCh) - require.Empty(t, a.req.errCh) - case err := <-a.req.errCh: - t.Fatal(err.Error()) - case <-ctx.Done(): - t.Fatal("timeout") - } -} - type mockLlm struct { - pingResp error - waitResp error - completionResp error - embeddingResp []float32 - embeddingRespErr error - tokenizeResp []int - tokenizeRespErr error - detokenizeResp string - detonekizeRespErr error - closeResp error - closeCalled bool - estimatedVRAM uint64 - estimatedTotal uint64 - estimatedVRAMByGPU map[string]uint64 + modelPath string + pingResp error + waitResp error + completionResp error + embeddingResp []float32 + embeddingRespErr error + tokenizeResp []int + tokenizeRespErr error + detokenizeResp string + detonekizeRespErr error + closeResp error + closeCalled bool + vramSize uint64 + totalSize uint64 + vramByGPU map[string]uint64 } +func (s *mockLlm) ModelPath() string { + return s.modelPath +} + +func (s *mockLlm) Load(ctx context.Context, gpus discover.GpuInfoList, requireFull bool) error { + if requireFull { + for _, g := range gpus { + if g.FreeMemory >= s.vramSize { + return nil + } + } + + return llm.ErrLoadRequiredFull + } + return nil +} func (s *mockLlm) Ping(ctx context.Context) error { return s.pingResp } func (s *mockLlm) WaitUntilRunning(ctx context.Context) error { return s.waitResp } func (s *mockLlm) Completion(ctx context.Context, req llm.CompletionRequest, fn func(llm.CompletionResponse)) error { @@ -785,7 +732,7 @@ func (s *mockLlm) Close() error { s.closeCalled = true return s.closeResp } -func (s *mockLlm) EstimatedVRAM() uint64 { return s.estimatedVRAM } -func (s *mockLlm) EstimatedTotal() uint64 { return s.estimatedTotal } -func (s *mockLlm) EstimatedVRAMByGPU(gpuid string) uint64 { return s.estimatedVRAMByGPU[gpuid] } -func (s *mockLlm) Pid() int { return -1 } +func (s *mockLlm) VRAMSize() uint64 { return s.vramSize } +func (s *mockLlm) TotalSize() uint64 { return s.totalSize } +func (s *mockLlm) VRAMByGPU(gpuid string) uint64 { return s.vramByGPU[gpuid] } +func (s *mockLlm) Pid() int { return -1 }