Mirror of https://github.com/danswer-ai/danswer.git
Add option to run a faster/cheaper LLM for secondary flows (#742)
parent df37387146
commit 0cc3d65839
@@ -70,6 +70,11 @@ INTENT_MODEL_VERSION = "danswer/intent-model"
 GEN_AI_MODEL_PROVIDER = os.environ.get("GEN_AI_MODEL_PROVIDER") or "openai"
 # If using Azure, it's the engine name, for example: Danswer
 GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
+# For secondary flows like extracting filters or deciding if a chunk is useful, we don't need
+# as powerful of a model as say GPT-4 so we can use an alternative that is faster and cheaper
+FAST_GEN_AI_MODEL_VERSION = (
+    os.environ.get("FAST_GEN_AI_MODEL_VERSION") or GEN_AI_MODEL_VERSION
+)
 
 # If the Generative AI model requires an API key for access, otherwise can leave blank
 GEN_AI_API_KEY = (
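For reference, a minimal standalone sketch of the fallback behavior this introduces (the environment reads mirror the config lines above; the printed values are illustrative):

import os

# If FAST_GEN_AI_MODEL_VERSION is unset or empty, secondary flows fall back
# to the primary GEN_AI_MODEL_VERSION, so existing deployments are unaffected.
primary_model = os.environ.get("GEN_AI_MODEL_VERSION") or "gpt-3.5-turbo"
fast_model = os.environ.get("FAST_GEN_AI_MODEL_VERSION") or primary_model
print(primary_model, fast_model)  # e.g. both "gpt-3.5-turbo" when only the primary is set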
@@ -1,5 +1,7 @@
 from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.gpt_4_all import DanswerGPT4All
@@ -8,18 +10,23 @@ from danswer.llm.utils import get_gen_ai_api_key
 
 
 def get_default_llm(
+    gen_ai_model_provider: str = GEN_AI_MODEL_PROVIDER,
     api_key: str | None = None,
     timeout: int = QA_TIMEOUT,
+    use_fast_llm: bool = False,
 ) -> LLM:
     """A single place to fetch the configured LLM for Danswer
     Also allows overriding certain LLM defaults"""
+    model_version = FAST_GEN_AI_MODEL_VERSION if use_fast_llm else GEN_AI_MODEL_VERSION
     if api_key is None:
         api_key = get_gen_ai_api_key()
 
-    if GEN_AI_MODEL_PROVIDER.lower() == "custom":
+    if gen_ai_model_provider.lower() == "custom":
         return CustomModelServer(api_key=api_key, timeout=timeout)
 
-    if GEN_AI_MODEL_PROVIDER.lower() == "gpt4all":
-        return DanswerGPT4All(timeout=timeout)
+    if gen_ai_model_provider.lower() == "gpt4all":
+        return DanswerGPT4All(model_version=model_version, timeout=timeout)
 
-    return DefaultMultiLLM(api_key=api_key, timeout=timeout)
+    return DefaultMultiLLM(
+        model_version=model_version, api_key=api_key, timeout=timeout
+    )
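A quick usage sketch of the new use_fast_llm flag (the import path below is an assumption for illustration; the hunk only shows the function itself, not the module it lives in):

# Assumed import location; only get_default_llm itself appears in the diff above.
from danswer.llm.factory import get_default_llm

primary_llm = get_default_llm()                            # uses GEN_AI_MODEL_VERSION
fast_llm = get_default_llm(use_fast_llm=True, timeout=5)   # uses FAST_GEN_AI_MODEL_VERSION

# Both handles resolve the provider the same way (custom / gpt4all / default);
# the fast handle just swaps in FAST_GEN_AI_MODEL_VERSION where the provider supports it.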
@@ -30,6 +30,7 @@ from danswer.configs.constants import AuthType
 from danswer.configs.model_configs import ASYM_PASSAGE_PREFIX
 from danswer.configs.model_configs import ASYM_QUERY_PREFIX
 from danswer.configs.model_configs import DOCUMENT_ENCODER_MODEL
+from danswer.configs.model_configs import FAST_GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
@@ -173,6 +174,10 @@ def get_application() -> FastAPI:
     else:
         logger.info(f"Using LLM Provider: {GEN_AI_MODEL_PROVIDER}")
         logger.info(f"Using LLM Model Version: {GEN_AI_MODEL_VERSION}")
+        if GEN_AI_MODEL_VERSION != FAST_GEN_AI_MODEL_VERSION:
+            logger.info(
+                f"Using Fast LLM Model Version: {FAST_GEN_AI_MODEL_VERSION}"
+            )
         if GEN_AI_API_ENDPOINT:
             logger.info(f"Using LLM Endpoint: {GEN_AI_API_ENDPOINT}")
 
@@ -35,7 +35,9 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
     # When running in a batch, it takes as long as the longest thread
     # And when running a large batch, one may fail and take the whole timeout
     # instead cap it to 5 seconds
-    model_output = get_default_llm(timeout=5).invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True, timeout=5).invoke(
+        filled_llm_prompt
+    )
     logger.debug(model_output)
 
     return _extract_usefulness(model_output)
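The 5-second cap matters because chunk evaluations run in parallel and the batch finishes only when the slowest call returns; a sketch of that pattern (the batch wrapper below is a hypothetical illustration, not the project's actual batch runner):

from concurrent.futures import ThreadPoolExecutor

# llm_eval_chunk is the function patched above; its import path is omitted here.
def eval_chunks_in_batch(query: str, chunk_contents: list[str]) -> list[bool]:
    # Each llm_eval_chunk call is bounded by its own 5s timeout, so one slow or
    # failed request cannot hold the whole batch for the full QA timeout.
    with ThreadPoolExecutor() as pool:
        futures = [pool.submit(llm_eval_chunk, query, content) for content in chunk_contents]
        return [future.result() for future in futures]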
@@ -155,7 +155,7 @@ def extract_source_filter(
 
     messages = _get_source_filter_messages(query=query, valid_sources=valid_sources)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_source_filters_from_llm_out(model_output)
@@ -147,7 +147,7 @@ def extract_time_filter(query: str) -> tuple[datetime | None, bool]:
 
     messages = _get_time_filter_messages(query)
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = get_default_llm().invoke(filled_llm_prompt)
+    model_output = get_default_llm(use_fast_llm=True).invoke(filled_llm_prompt)
     logger.debug(model_output)
 
     return _extract_time_filter_from_llm_out(model_output)
@@ -18,6 +18,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}
@@ -76,6 +77,7 @@ services:
     environment:
       - GEN_AI_MODEL_PROVIDER=${GEN_AI_MODEL_PROVIDER:-openai}
       - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
+      - FAST_GEN_AI_MODEL_VERSION=${FAST_GEN_AI_MODEL_VERSION:-}
       - GEN_AI_API_KEY=${GEN_AI_API_KEY:-}
       - GEN_AI_API_ENDPOINT=${GEN_AI_API_ENDPOINT:-}
       - GEN_AI_API_VERSION=${GEN_AI_API_VERSION:-}