From 0cc3d65839dc2cb8eaf717529821896beec54e3d Mon Sep 17 00:00:00 2001
From: Yuhong Sun
Date: Sun, 19 Nov 2023 17:48:42 -0800
Subject: [PATCH] Add option to run a faster/cheaper LLM for secondary flows (#742)

---
 backend/danswer/configs/model_configs.py         |  5 +++++
 backend/danswer/llm/factory.py                   | 15 +++++++++++----
 backend/danswer/main.py                          |  5 +++++
 .../secondary_llm_flows/chunk_usefulness.py      |  4 +++-
 .../danswer/secondary_llm_flows/source_filter.py |  2 +-
 .../danswer/secondary_llm_flows/time_filter.py   |  2 +-
 deployment/docker_compose/docker-compose.dev.yml |  2 ++
 7 files changed, 28 insertions(+), 7 deletions(-)

diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py
index 353853ac2..a91a20882 100644
--- a/backend/danswer/configs/model_configs.py
+++ b/backend/danswer/configs/model_configs.py
@@ -70,6 +70,11 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
 GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
 # If using Azure, it's the engine name, for example: Danswer
 GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
+# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
+# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
+FAST_GEN_AI_MODEL_VERSION = (
+    os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
+)
 
 # If the Generative AI model requires an API key for access, otherwise can leave blank
 GEN_AI_API_KEY = (
diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py
index 670def9eb..b4b06ac77 100644
--- a/backend/danswer/llm/factory.py
+++ b/backend/danswer/llm/factory.py
@@ -1,5 +1,7 @@
 from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.gpt_4_all import DanswerGPT4All
@@ -8,18 +10,23 @@ from danswer.llm.utils import get_gen_ai_api_key
 
 
 def get_default_llm(
+    gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
+    use_fast_llm: bool = False,
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
     if api_key is None:
         api_key = get_gen_ai_api_key()
 
-    if GEN_AI_MODEL_PROVIDER.lower() == "custom":
+    if gen_ai_model_provider.lower() == "custom":
         return CustomModelServer(api_key=api_key, timeout=timeout)
 
-    if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
-        return DanswerGPT4All(timeout=timeout)
+    if gen_ai_model_provider.lower() == "gpt4all":
+        return DanswerGPT4All(model_version=model_version, timeout=timeout)
 
-    return DefaultMultiLLM(api_key=api_key, timeout=timeout)
+    return DefaultMultiLLM(
+        model_version=model_version, api_key=api_key, timeout=timeout
+    )
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 222a6eed2..cf51b5c28 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -30,6 +30,7 @@ from danswer.configs.constants import AuthType
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -173,6 +174,10 @@ def get_application() -> FastAPI:
         else:
             logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
             logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
+            if GEN_AI_MODEL_VERSION != FAST_GEN_AI_MODEL_VERSION:
+                logger.info(
+                    f"Using Fast LLM Model Version: {FAST_GEN_AI_MODEL_VERSION}"
+                )
         if GEN_AI_API_ENDPOINT:
             logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
 
diff --git a/backend/danswer/secondary_llm_flows/chunk_usefulness.py b/backend/danswer/secondary_llm_flows/chunk_usefulness.py
index 057aa7f63..636401912 100644
--- a/backend/danswer/secondary_llm_flows/chunk_usefulness.py
+++ b/backend/danswer/secondary_llm_flows/chunk_usefulness.py
@@ -35,7 +35,9 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
+        filled_llm_prompt
+    )
     logger.debug(model_output)
 
     return _extract_usefulness(model_output)
diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py
index ed4bdbdf6..cd8f484b2 100644
--- a/backend/danswer/secondary_llm_flows/source_filter.py
+++ b/backend/danswer/secondary_llm_flows/source_filter.py
@@ -155,7 +155,7 @@ def extract_source_filter(
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
 
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_source_filters_from_llm_out(model_output)
diff --git a/backend/danswer/secondary_llm_flows/time_filter.py b/backend/danswer/secondary_llm_flows/time_filter.py
index be06d23cc..9bda77963 100644
--- a/backend/danswer/secondary_llm_flows/time_filter.py
+++ b/backend/danswer/secondary_llm_flows/time_filter.py
@@ -147,7 +147,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
 
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_time_filter_from_llm_out(model_output)
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index ac809ce84..5d56f1443 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -18,6 +18,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
@@ -76,6 +77,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
      - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
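
For reference, a minimal usage sketch of the change (not part of the patch): a secondary flow opts into the cheaper model via the new use_fast_llm flag, while the primary QA path is unchanged because use_fast_llm defaults to False and FAST_GEN_AI_MODEL_VERSION falls back to GEN_AI_MODEL_VERSION when the environment variable is unset. The dict_based_prompt_to_langchain_prompt import path and the placeholder message are assumptions based on how the flows above call it.

    # Illustrative sketch only, not part of the patch.
    from danswer.llm.factory import get_default_llm
    from danswer.llm.utils import dict_based_prompt_to_langchain_prompt  # assumed module path

    # Placeholder prompt; the real flows build messages via their own _get_*_messages helpers.
    messages = [{"role": "user", "content": "Is this chunk useful for the query? Answer Yes or No."}]
    filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)

    # Secondary flow: faster/cheaper FAST_GEN_AI_MODEL_VERSION with a capped timeout,
    # mirroring the chunk_usefulness.py change above.
    fast_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(filled_llm_prompt)

    # Primary QA flow: unchanged, still served by GEN_AI_MODEL_VERSION.
    full_output = get_default_llm().invoke(filled_llm_prompt)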