From 3269535a4c19e5b1f3178645a136e753df8ed9ba Mon Sep 17 00:00:00 2001 From: Daniel Hiltgen Date: Fri, 15 Dec 2023 14:27:27 -0800 Subject: [PATCH] Refine handling of shim presence This allows the CPU only builds to work on systems with Radeon cards --- llm/llm.go | 10 ++++++---- llm/shim_ext_server.go | 5 ++--- 2 files changed, 8 insertions(+), 7 deletions(-) diff --git a/llm/llm.go b/llm/llm.go index 86dd3346d..69ea705f2 100644 --- a/llm/llm.go +++ b/llm/llm.go @@ -22,6 +22,9 @@ type LLM interface { Close() } +// Set to false on linux/windows if we are able to load the shim +var ShimPresent = false + func New(workDir, model string, adapters, projectors []string, opts api.Options) (LLM, error) { if _, err := os.Stat(model); err != nil { return nil, err @@ -79,11 +82,10 @@ func New(workDir, model string, adapters, projectors []string, opts api.Options) opts.RopeFrequencyBase = 0.0 opts.RopeFrequencyScale = 0.0 gpuInfo := gpu.GetGPUInfo() - switch gpuInfo.Driver { - case "ROCM": + if gpuInfo.Driver == "ROCM" && ShimPresent { return newRocmShimExtServer(model, adapters, projectors, ggml.NumLayers(), opts) - default: - // Rely on the built-in CUDA based server which will fall back to CPU + } else { + // Rely on the built-in CUDA/Metal based server which will fall back to CPU return newLlamaExtServer(model, adapters, projectors, ggml.NumLayers(), opts) } } diff --git a/llm/shim_ext_server.go b/llm/shim_ext_server.go index 0e7bcfae4..7505adaa5 100644 --- a/llm/shim_ext_server.go +++ b/llm/shim_ext_server.go @@ -30,7 +30,6 @@ import ( var libEmbed embed.FS var RocmShimMissing = fmt.Errorf("ROCm shim library not included in this build of ollama. Radeon GPUs are not supported") -var NoShim = true type shimExtServer struct { s C.struct_rocm_llama_server @@ -78,7 +77,7 @@ func (llm *shimExtServer) llama_server_release_json_resp(json_resp **C.char) { } func newRocmShimExtServer(model string, adapters, projectors []string, numLayers int64, opts api.Options) (extServer, error) { - if NoShim { + if !ShimPresent { return nil, RocmShimMissing } log.Printf("Loading ROCM llm server") @@ -207,6 +206,6 @@ func extractLib(workDir string) error { case err != nil: return fmt.Errorf("stat ROCm shim %s: %v", files[0], err) } - NoShim = false + ShimPresent = true return nil }