Add option to run a faster/cheaper LLM for secondary flows (#742)

Yuhong Sun
2023-11-19 17:48:42 -08:00
committed by GitHub
parent df37387146
commit 0cc3d65839
7 changed files with 28 additions and 7 deletions

View File

@@ -70,6 +70,11 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
 GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
 # If using Azure, it's the engine name, for example: Danswer
 GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
+# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
+# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
+FAST_GEN_AI_MODEL_VERSION = (
+    os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
+)
 # If the Generative AI model requires an API key for access, otherwise can leave blank
 GEN_AI_API_KEY = (
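As a quick illustration of the fallback behavior introduced above (a minimal sketch; the model names and the standalone-script form are assumptions, not part of this commit), leaving FAST_GEN_AI_MODEL_VERSION unset keeps the previous single-model behavior:

# Sketch of the new fallback; values are illustrative, not from this commit
import os

os.environ["GEN_AI_MODEL_VERSION"] = "gpt-4"  # assumed primary model for the main QA flow
# FAST_GEN_AI_MODEL_VERSION deliberately left unset

GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
FAST_GEN_AI_MODEL_VERSION = (
    os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
)
print(FAST_GEN_AI_MODEL_VERSION)  # prints "gpt-4": secondary flows fall back to the main model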

View File

@@ -1,5 +1,7 @@
 from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.gpt_4_all import DanswerGPT4All
@@ -8,18 +10,23 @@ from danswer.llm.utils import get_gen_ai_api_key
 def get_default_llm(
+    gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
+    use_fast_llm: bool = False,
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
     if api_key is None:
         api_key = get_gen_ai_api_key()
-    if GEN_AI_MODEL_PROVIDER.lower() == "custom":
+    if gen_ai_model_provider.lower() == "custom":
         return CustomModelServer(api_key=api_key, timeout=timeout)
-    if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
-        return DanswerGPT4All(timeout=timeout)
+    if gen_ai_model_provider.lower() == "gpt4all":
+        return DanswerGPT4All(model_version=model_version, timeout=timeout)
-    return DefaultMultiLLM(api_key=api_key, timeout=timeout)
+    return DefaultMultiLLM(
+        model_version=model_version, api_key=api_key, timeout=timeout
+    )
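For context, a minimal usage sketch of the updated factory (the import path and the prompt text are assumptions for illustration; the actual call sites appear in the hunks below):

# Sketch only: the import path is assumed, it is not shown in this diff
from danswer.llm.factory import get_default_llm

primary_llm = get_default_llm()  # resolves to GEN_AI_MODEL_VERSION
fast_llm = get_default_llm(use_fast_llm=True, timeout=5)  # resolves to FAST_GEN_AI_MODEL_VERSION
model_output = fast_llm.invoke("hypothetical secondary-flow prompt")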

View File

@@ -30,6 +30,7 @@ from danswer.configs.constants import AuthType
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -173,6 +174,10 @@ def get_application() -> FastAPI:
     else:
         logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
         logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
+        if GEN_AI_MODEL_VERSION != FAST_GEN_AI_MODEL_VERSION:
+            logger.info(
+                f"Using Fast LLM Model Version: {FAST_GEN_AI_MODEL_VERSION}"
+            )
     if GEN_AI_API_ENDPOINT:
         logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")

View File

@@ -35,7 +35,9 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
+        filled_llm_prompt
+    )
     logger.debug(model_output)
     return _extract_usefulness(model_output)

View File

@@ -155,7 +155,7 @@ def extract_source_filter(
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
     return _extract_source_filters_from_llm_out(model_output)

View File

@@ -147,7 +147,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
     return _extract_time_filter_from_llm_out(model_output)

View File

@@ -18,6 +18,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
@@ -76,6 +77,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}