From 49df03da9af6b0050ebbf50676f7db569a2b54d9 Mon Sep 17 00:00:00 2001 From: Michael Yang Date: Tue, 11 Feb 2025 23:36:53 +0000 Subject: [PATCH] fix: harden backend loading (#9024) * wrap ggml_backend_load_best in try/catch * ignore non-ollama paths --- .../patches/0017-try-catch-backend-load.patch | 69 +++++++++++++++++++ ml/backend/ggml/ggml/src/ggml-backend-reg.cpp | 45 ++++++------ ml/backend/ggml/ggml/src/ggml.go | 5 ++ 3 files changed, 97 insertions(+), 22 deletions(-) create mode 100644 llama/patches/0017-try-catch-backend-load.patch diff --git a/llama/patches/0017-try-catch-backend-load.patch b/llama/patches/0017-try-catch-backend-load.patch new file mode 100644 index 000000000..b48b75889 --- /dev/null +++ b/llama/patches/0017-try-catch-backend-load.patch @@ -0,0 +1,69 @@ +From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 +From: Michael Yang +Date: Tue, 11 Feb 2025 14:06:36 -0800 +Subject: [PATCH] try/catch backend load + +--- + ggml/src/ggml-backend-reg.cpp | 45 ++++++++++++++++++----------------- + 1 file changed, 23 insertions(+), 22 deletions(-) + +diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp +index ac5cda07..374c3b21 100644 +--- a/ggml/src/ggml-backend-reg.cpp ++++ b/ggml/src/ggml-backend-reg.cpp +@@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, + } + fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); + for (const auto & entry : dir_it) { +- if (entry.is_regular_file()) { +- std::wstring filename = entry.path().filename().wstring(); +- std::wstring ext = entry.path().extension().wstring(); +- if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { +- dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; +- if (!handle && !silent) { +- GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); +- } +- if (handle) { ++ try { ++ if (entry.is_regular_file()) { ++ std::wstring filename = entry.path().filename().wstring(); ++ std::wstring ext = entry.path().extension().wstring(); ++ if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { ++ dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; ++ if (!handle) { ++ GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); ++ continue; ++ } ++ + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); +- if (score_fn) { +- int s = score_fn(); +-#ifndef NDEBUG +- GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); +-#endif +- if (s > best_score) { +- best_score = s; +- best_path = entry.path().wstring(); +- } +- } else { +- if (!silent) { +- GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); +- } ++ if (!score_fn) { ++ GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); ++ continue; ++ } ++ ++ int s = score_fn(); ++ GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); ++ if (s > best_score) { ++ best_score = s; ++ best_path = entry.path().wstring(); + } + } + } ++ } catch (const std::exception & e) { ++ GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); + } + } + } diff --git a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp index ac5cda072..374c3b219 100644 --- a/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp +++ b/ml/backend/ggml/ggml/src/ggml-backend-reg.cpp @@ -512,32 +512,33 @@ static ggml_backend_reg_t ggml_backend_load_best(const char * name, bool silent, } fs::directory_iterator dir_it(search_path, fs::directory_options::skip_permission_denied); for (const auto & entry : dir_it) { - if (entry.is_regular_file()) { - std::wstring filename = entry.path().filename().wstring(); - std::wstring ext = entry.path().extension().wstring(); - if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { - dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; - if (!handle && !silent) { - GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); - } - if (handle) { + try { + if (entry.is_regular_file()) { + std::wstring filename = entry.path().filename().wstring(); + std::wstring ext = entry.path().extension().wstring(); + if (filename.find(file_prefix) == 0 && ext == backend_filename_suffix()) { + dl_handle_ptr handle { dl_load_library(entry.path().wstring()) }; + if (!handle) { + GGML_LOG_ERROR("%s: failed to load %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + continue; + } + auto score_fn = (ggml_backend_score_t) dl_get_sym(handle.get(), "ggml_backend_score"); - if (score_fn) { - int s = score_fn(); -#ifndef NDEBUG - GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); -#endif - if (s > best_score) { - best_score = s; - best_path = entry.path().wstring(); - } - } else { - if (!silent) { - GGML_LOG_INFO("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); - } + if (!score_fn) { + GGML_LOG_DEBUG("%s: failed to find ggml_backend_score in %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str()); + continue; + } + + int s = score_fn(); + GGML_LOG_DEBUG("%s: %s score: %d\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), s); + if (s > best_score) { + best_score = s; + best_path = entry.path().wstring(); } } } + } catch (const std::exception & e) { + GGML_LOG_ERROR("%s: failed to load %s: %s\n", __func__, utf16_to_utf8(entry.path().wstring()).c_str(), e.what()); } } } diff --git a/ml/backend/ggml/ggml/src/ggml.go b/ml/backend/ggml/ggml/src/ggml.go index 94b0d1853..3920e37dc 100644 --- a/ml/backend/ggml/ggml/src/ggml.go +++ b/ml/backend/ggml/ggml/src/ggml.go @@ -79,6 +79,11 @@ var OnceLoad = sync.OnceFunc(func() { continue } + if abspath != filepath.Dir(exe) && !strings.Contains(abspath, filepath.FromSlash("lib/ollama")) { + slog.Debug("skipping path which is not part of ollama", "path", abspath) + continue + } + if _, ok := visited[abspath]; !ok { func() { slog.Debug("ggml backend load all from path", "path", abspath)