From 0000000000000000000000000000000000000000 Mon Sep 17 00:00:00 2001 From: jmorganca Date: Tue, 8 Apr 2025 20:31:38 -0700 Subject: [PATCH] sort devices by score. In the ggml backend loading code, devices are now sorted by score, ensuring the device with the fastest acceleration is loaded --- ggml/src/ggml-backend-reg.cpp | 21 +++++++++++++-------- 1 file changed, 13 insertions(+), 8 deletions(-) diff --git a/ggml/src/ggml-backend-reg.cpp b/ggml/src/ggml-backend-reg.cpp index 136afec7..f794d9cf 100644 --- a/ggml/src/ggml-backend-reg.cpp +++ b/ggml/src/ggml-backend-reg.cpp @@ -175,7 +175,7 @@ struct ggml_backend_reg_entry { struct ggml_backend_registry { std::vector<ggml_backend_reg_entry> backends; - std::vector<ggml_backend_dev_t> devices; + std::vector<std::pair<ggml_backend_dev_t, int>> devices; ggml_backend_registry() { #ifdef GGML_USE_CUDA @@ -223,7 +223,7 @@ struct ggml_backend_registry { } } - void register_backend(ggml_backend_reg_t reg, dl_handle_ptr handle = nullptr) { + void register_backend(ggml_backend_reg_t reg, int score = -1, dl_handle_ptr handle = nullptr) { if (!reg) { return; } @@ -234,15 +234,20 @@ struct ggml_backend_registry { #endif backends.push_back({ reg, std::move(handle) }); for (size_t i = 0; i < ggml_backend_reg_dev_count(reg); i++) { - register_device(ggml_backend_reg_dev_get(reg, i)); + register_device(ggml_backend_reg_dev_get(reg, i), score); } } - void register_device(ggml_backend_dev_t device) { + void register_device(ggml_backend_dev_t device, int score = -1) { #ifndef NDEBUG GGML_LOG_DEBUG("%s: registered device %s (%s)\n", __func__, ggml_backend_dev_name(device), ggml_backend_dev_description(device)); #endif - devices.push_back(device); + devices.push_back({device, score}); + std::stable_sort(devices.begin(), devices.end(), + [](const auto & a, const auto & b) { + return a.second > b.second; + } + ); } ggml_backend_reg_t load_backend(const fs::path & path, bool silent) { @@ -286,7 +291,7 @@ struct ggml_backend_registry { GGML_LOG_INFO("%s: loaded %s backend from %s\n", __func__, 
ggml_backend_reg_name(reg), path_str(path).c_str()); - register_backend(reg, std::move(handle)); + register_backend(reg, score_fn ? score_fn() : -1, std::move(handle)); return reg; } @@ -309,7 +314,7 @@ struct ggml_backend_registry { // remove devices devices.erase( std::remove_if(devices.begin(), devices.end(), - [reg](ggml_backend_dev_t dev) { return ggml_backend_dev_backend_reg(dev) == reg; }), + [reg](std::pair<ggml_backend_dev_t, int> dev) { return ggml_backend_dev_backend_reg(dev.first) == reg; }), devices.end()); // remove backend @@ -367,7 +372,7 @@ size_t ggml_backend_dev_count() { ggml_backend_dev_t ggml_backend_dev_get(size_t index) { GGML_ASSERT(index < ggml_backend_dev_count()); - return get_reg().devices[index]; + return get_reg().devices[index].first; } ggml_backend_dev_t ggml_backend_dev_by_name(const char * name) {