Add option to run a faster/cheaper LLM for secondary flows (#742)

Yuhong Sun 2023-11-19 17:48:42 -08:00 committed by GitHub
parent df37387146
commit 0cc3d65839
7 changed files with 28 additions and 7 deletions

View File

@@ -70,6 +70,11 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
 GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
 # If using Azure, it's the engine name, for example: Danswer
 GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
+# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
+# as powerful a model as, say, GPT-4, so we can use a faster and cheaper alternative
+FAST_GEN_AI_MODEL_VERSION = (
+    os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
+)
 # If the Generative AI model requires an API key for access, otherwise can leave blank
 GEN_AI_API_KEY = (
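
A minimal standalone sketch of how this fallback resolves (not Danswer code; the environment values are illustrative):

    import os

    def resolve_fast_model() -> str:
        # Mirrors the two lines above: unset *or* empty (the docker-compose
        # default further down is an empty string) falls back to the primary
        # model, because "" is falsy under Python's `or`.
        primary = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
        return os.environ.get("FAST_GEN_AI_MODEL_VERSION") or primary

    os.environ["GEN_AI_MODEL_VERSION"] = "gpt-4"       # illustrative values
    os.environ["FAST_GEN_AI_MODEL_VERSION"] = ""
    print(resolve_fast_model())                        # -> gpt-4 (falls back)

    os.environ["FAST_GEN_AI_MODEL_VERSION"] = "gpt-3.5-turbo"
    print(resolve_fast_model())                        # -> gpt-3.5-turbo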

View File

@@ -1,5 +1,7 @@
 from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.gpt_4_all import DanswerGPT4All
@@ -8,18 +10,23 @@ from danswer.llm.utils import get_gen_ai_api_key
 def get_default_llm(
     gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
+    use_fast_llm: bool = False,
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
     if api_key is None:
         api_key = get_gen_ai_api_key()
-    if GEN_AI_MODEL_PROVIDER.lower() == "custom":
+    if gen_ai_model_provider.lower() == "custom":
         return CustomModelServer(api_key=api_key, timeout=timeout)
-    if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
-        return DanswerGPT4All(timeout=timeout)
+    if gen_ai_model_provider.lower() == "gpt4all":
+        return DanswerGPT4All(model_version=model_version, timeout=timeout)
-    return DefaultMultiLLM(api_key=api_key, timeout=timeout)
+    return DefaultMultiLLM(
+        model_version=model_version, api_key=api_key, timeout=timeout
+    )
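
Call sites only need the new flag; a short usage sketch, assuming this function lives in danswer.llm.factory (the file path is not shown above):

    from danswer.llm.factory import get_default_llm  # import path assumed

    # Primary flows keep the defaults: GEN_AI_MODEL_VERSION and QA_TIMEOUT.
    primary_llm = get_default_llm()

    # Secondary flows opt in to the cheaper model; model_version then resolves
    # to FAST_GEN_AI_MODEL_VERSION, and a tighter timeout can be passed as before.
    fast_llm = get_default_llm(use_fast_llm=True, timeout=5)

    # Both objects expose the same LLM interface, so existing .invoke() calls
    # work unchanged, e.g. model_output = fast_llm.invoke(filled_llm_prompt).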

View File

@@ -30,6 +30,7 @@ from danswer.configs.constants import AuthType
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -173,6 +174,10 @@ def get_application() -> FastAPI:
     else:
         logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
         logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
+        if GEN_AI_MODEL_VERSION != FAST_GEN_AI_MODEL_VERSION:
+            logger.info(
+                f"Using Fast LLM Model Version: {FAST_GEN_AI_MODEL_VERSION}"
+            )
         if GEN_AI_API_ENDPOINT:
             logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
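
With illustrative values GEN_AI_MODEL_VERSION=gpt-4 and FAST_GEN_AI_MODEL_VERSION=gpt-3.5-turbo, startup would log roughly:

    Using LLM Provider: openai
    Using LLM Model Version: gpt-4
    Using Fast LLM Model Version: gpt-3.5-turbo

The third line is skipped when the two versions match, which is the default whenever FAST_GEN_AI_MODEL_VERSION is left unset.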

View File

@@ -35,7 +35,9 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
+        filled_llm_prompt
+    )
     logger.debug(model_output)
     return _extract_usefulness(model_output)
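
The comment above is about batch behaviour: chunks are judged concurrently, so the batch finishes only when its slowest call does. A minimal sketch of that pattern (a hypothetical helper, not this file's actual batching code) shows why each fast-LLM call is capped at 5 seconds rather than the full QA_TIMEOUT:

    from concurrent.futures import ThreadPoolExecutor

    def eval_chunks_concurrently(query: str, chunk_contents: list[str]) -> list[bool]:
        # Hypothetical illustration: each chunk is judged independently, and
        # pool.map returns only when the slowest call resolves, so one hung
        # request would otherwise stall the whole batch for the full timeout.
        with ThreadPoolExecutor(max_workers=8) as pool:
            return list(
                pool.map(lambda content: llm_eval_chunk(query, content), chunk_contents)
            )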

View File

@@ -155,7 +155,7 @@ def extract_source_filter(
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
     return _extract_source_filters_from_llm_out(model_output)

View File

@@ -147,7 +147,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
     return _extract_time_filter_from_llm_out(model_output)

View File

@@ -18,6 +18,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
@@ -76,6 +77,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}