From 415960564d92159a54079631c79ec655ad506717 Mon Sep 17 00:00:00 2001
From: Weves
Date: Fri, 28 Jun 2024 17:13:23 -0700
Subject: [PATCH] Fix fast models

---
 backend/danswer/chat/process_message.py       | 18 ++--
 .../slack/handlers/handle_message.py          |  4 +-
 backend/danswer/danswerbot/slack/utils.py     |  4 +-
 .../llm/answering/prompts/citations_prompt.py |  5 +-
 backend/danswer/llm/factory.py                | 99 ++++++++++++-------
 .../one_shot_answer/answer_question.py        |  8 +-
 backend/danswer/search/pipeline.py            |  3 +
 .../search/postprocessing/postprocessing.py   | 10 +-
 .../secondary_llm_flows/answer_validation.py  |  4 +-
 .../secondary_llm_flows/chunk_usefulness.py   | 21 ++--
 .../secondary_llm_flows/query_expansion.py    |  8 +-
 .../secondary_llm_flows/query_validation.py   |  6 +-
 .../secondary_llm_flows/source_filter.py      |  6 +-
 .../secondary_llm_flows/time_filter.py        |  6 +-
 backend/danswer/server/gpts/api.py            |  6 +-
 .../danswer/server/manage/administrative.py   |  4 +-
 backend/danswer/server/manage/llm/api.py      |  5 +-
 .../server/query_and_chat/chat_backend.py     |  4 +-
 backend/danswer/tools/search/search_tool.py   |  3 +
 .../server/query_and_chat/query_backend.py    | 12 ++-
 .../regression/search_quality/eval_search.py  |  6 +-
 21 files changed, 148 insertions(+), 94 deletions(-)

diff --git a/backend/danswer/chat/process_message.py b/backend/danswer/chat/process_message.py
index 58a575745..9dab81de4 100644
--- a/backend/danswer/chat/process_message.py
+++ b/backend/danswer/chat/process_message.py
@@ -47,7 +47,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.retrieval.search_runner import inference_documents_from_ids
@@ -244,7 +245,7 @@ def stream_chat_message_objects(
     )
 
     try:
-        llm = get_llm_for_persona(
+        llm, fast_llm = get_llms_for_persona(
             persona=persona,
             llm_override=new_msg_req.llm_override or chat_session.llm_override,
             additional_headers=litellm_additional_headers,
@@ -425,6 +426,7 @@
             retrieval_options=retrieval_options,
             prompt_config=prompt_config,
             llm=llm,
+            fast_llm=fast_llm,
             pruning_config=document_pruning_config,
             selected_docs=selected_llm_docs,
             chunks_above=new_msg_req.chunks_above,
@@ -498,10 +500,14 @@
             prompt_config=prompt_config,
             llm=(
                 llm
-                or get_llm_for_persona(
-                    persona=persona,
-                    llm_override=new_msg_req.llm_override or chat_session.llm_override,
-                    additional_headers=litellm_additional_headers,
+                or get_main_llm_from_tuple(
+                    get_llms_for_persona(
+                        persona=persona,
+                        llm_override=(
+                            new_msg_req.llm_override or chat_session.llm_override
+                        ),
+                        additional_headers=litellm_additional_headers,
+                    )
                 )
             ),
             message_history=[
diff --git a/backend/danswer/danswerbot/slack/handlers/handle_message.py b/backend/danswer/danswerbot/slack/handlers/handle_message.py
index abaf094ce..4573e0bd9 100644
--- a/backend/danswer/danswerbot/slack/handlers/handle_message.py
+++ b/backend/danswer/danswerbot/slack/handlers/handle_message.py
@@ -50,7 +50,7 @@ from danswer.db.persona import fetch_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
@@ -324,7 +324,7 @@ def handle_message(
         Persona,
         fetch_persona_by_id(db_session, new_message_request.persona_id),
     )
-    llm = get_llm_for_persona(persona)
+    llm, _ = get_llms_for_persona(persona)
 
     # In cases of threads, split the available tokens between docs and thread context
     input_tokens = get_max_input_tokens(
diff --git a/backend/danswer/danswerbot/slack/utils.py b/backend/danswer/danswerbot/slack/utils.py
index 12f7b4662..ab4396ced 100644
--- a/backend/danswer/danswerbot/slack/utils.py
+++ b/backend/danswer/danswerbot/slack/utils.py
@@ -30,7 +30,7 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.users import get_user_by_email
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.one_shot_answer.models import ThreadMessage
@@ -58,7 +58,7 @@ def rephrase_slack_message(msg: str) -> str:
         return messages
 
     try:
-        llm = get_default_llm(use_fast_llm=False, timeout=5)
+        llm, _ = get_default_llms(timeout=5)
     except GenAIDisabledException:
         logger.warning("Unable to rephrase Slack user message, Gen AI disabled")
         return msg
diff --git a/backend/danswer/llm/answering/prompts/citations_prompt.py b/backend/danswer/llm/answering/prompts/citations_prompt.py
index 3235ddc17..69f727318 100644
--- a/backend/danswer/llm/answering/prompts/citations_prompt.py
+++ b/backend/danswer/llm/answering/prompts/citations_prompt.py
@@ -8,7 +8,8 @@ from danswer.db.models import Persona
 from danswer.db.persona import get_default_prompt__read_only
 from danswer.file_store.utils import InMemoryChatFile
 from danswer.llm.answering.models import PromptConfig
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import build_content_with_imgs
 from danswer.llm.utils import check_number_of_tokens
@@ -99,7 +100,7 @@ def compute_max_document_tokens_for_persona(
     prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
     return compute_max_document_tokens(
         prompt_config=PromptConfig.from_model(prompt),
-        llm_config=get_llm_for_persona(persona).config,
+        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
         actual_user_input=actual_user_input,
         max_llm_token_override=max_llm_token_override,
     )
diff --git a/backend/danswer/llm/factory.py b/backend/danswer/llm/factory.py
index edad6a295..f57bfb524 100644
--- a/backend/danswer/llm/factory.py
+++ b/backend/danswer/llm/factory.py
@@ -12,65 +12,92 @@ from danswer.llm.interfaces import LLM
 from danswer.llm.override_models import LLMOverride
 
 
-def get_llm_for_persona(
+def get_main_llm_from_tuple(
+    llms: tuple[LLM, LLM],
+) -> LLM:
+    return llms[0]
+
+
+def get_llms_for_persona(
     persona: Persona,
     llm_override: LLMOverride | None = None,
     additional_headers: dict[str, str] | None = None,
-) -> LLM:
+) -> tuple[LLM, LLM]:
     model_provider_override = llm_override.model_provider if llm_override else None
     model_version_override = llm_override.model_version if llm_override else None
     temperature_override = llm_override.temperature if llm_override else None
 
-    return get_default_llm(
-        model_provider_name=(
-            model_provider_override or persona.llm_model_provider_override
-        ),
-        model_version=(model_version_override or persona.llm_model_version_override),
-        temperature=temperature_override or GEN_AI_TEMPERATURE,
-        additional_headers=additional_headers,
-    )
+    provider_name = model_provider_override or persona.llm_model_provider_override
+    if not provider_name:
+        return get_default_llms(
+            temperature=temperature_override or GEN_AI_TEMPERATURE,
+            additional_headers=additional_headers,
+        )
+
+    with get_session_context_manager() as db_session:
+        llm_provider = fetch_provider(db_session, provider_name)
+
+    if not llm_provider:
+        raise ValueError("No LLM provider found")
+
+    model = model_version_override or persona.llm_model_version_override
+    fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
+    if not model:
+        raise ValueError("No model name found")
+    if not fast_model:
+        raise ValueError("No fast model name found")
+
+    def _create_llm(model: str) -> LLM:
+        return get_llm(
+            provider=llm_provider.provider,
+            model=model,
+            api_key=llm_provider.api_key,
+            api_base=llm_provider.api_base,
+            api_version=llm_provider.api_version,
+            custom_config=llm_provider.custom_config,
+            additional_headers=additional_headers,
+        )
+
+    return _create_llm(model), _create_llm(fast_model)
 
 
-def get_default_llm(
+def get_default_llms(
     timeout: int = QA_TIMEOUT,
     temperature: float = GEN_AI_TEMPERATURE,
-    use_fast_llm: bool = False,
-    model_provider_name: str | None = None,
-    model_version: str | None = None,
     additional_headers: dict[str, str] | None = None,
-) -> LLM:
+) -> tuple[LLM, LLM]:
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()
 
-    # TODO: pass this in
-    with get_session_context_manager() as session:
-        if model_provider_name is None:
-            llm_provider = fetch_default_provider(session)
-        else:
-            llm_provider = fetch_provider(session, model_provider_name)
+    with get_session_context_manager() as db_session:
+        llm_provider = fetch_default_provider(db_session)
 
     if not llm_provider:
         raise ValueError("No default LLM provider found")
 
-    model_name = model_version or (
-        (llm_provider.fast_default_model_name or llm_provider.default_model_name)
-        if use_fast_llm
-        else llm_provider.default_model_name
+    model_name = llm_provider.default_model_name
+    fast_model_name = (
+        llm_provider.fast_default_model_name or llm_provider.default_model_name
     )
     if not model_name:
         raise ValueError("No default model name found")
+    if not fast_model_name:
+        raise ValueError("No fast default model name found")
 
-    return get_llm(
-        provider=llm_provider.provider,
-        model=model_name,
-        api_key=llm_provider.api_key,
-        api_base=llm_provider.api_base,
-        api_version=llm_provider.api_version,
-        custom_config=llm_provider.custom_config,
-        timeout=timeout,
-        temperature=temperature,
-        additional_headers=additional_headers,
-    )
+    def _create_llm(model: str) -> LLM:
+        return get_llm(
+            provider=llm_provider.provider,
+            model=model,
+            api_key=llm_provider.api_key,
+            api_base=llm_provider.api_base,
+            api_version=llm_provider.api_version,
+            custom_config=llm_provider.custom_config,
+            timeout=timeout,
+            temperature=temperature,
+            additional_headers=additional_headers,
+        )
+
+    return _create_llm(model_name), _create_llm(fast_model_name)
 
 
 def get_llm(
diff --git a/backend/danswer/one_shot_answer/answer_question.py b/backend/danswer/one_shot_answer/answer_question.py
index 5199d1e3c..b315c3662 100644
--- a/backend/danswer/one_shot_answer/answer_question.py
+++ b/backend/danswer/one_shot_answer/answer_question.py
@@ -30,7 +30,8 @@ from danswer.llm.answering.models import CitationConfig
 from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.models import QuotesConfig
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
@@ -156,7 +157,7 @@
         commit=True,
     )
 
-    llm = get_llm_for_persona(persona=chat_session.persona)
+    llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
     prompt_config = PromptConfig.from_model(prompt)
     document_pruning_config = DocumentPruningConfig(
         max_chunks=int(
@@ -174,6 +175,7 @@
         retrieval_options=query_req.retrieval_options,
         prompt_config=prompt_config,
         llm=llm,
+        fast_llm=fast_llm,
         pruning_config=document_pruning_config,
         bypass_acl=bypass_acl,
     )
@@ -187,7 +189,7 @@
         question=query_msg.message,
         answer_style_config=answer_config,
         prompt_config=PromptConfig.from_model(prompt),
-        llm=get_llm_for_persona(persona=chat_session.persona),
+        llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
         single_message_history=history_str,
         tools=[search_tool],
         force_use_tool=ForceUseTool(
diff --git a/backend/danswer/search/pipeline.py b/backend/danswer/search/pipeline.py
index 7f68178bf..98b1a8716 100644
--- a/backend/danswer/search/pipeline.py
+++ b/backend/danswer/search/pipeline.py
@@ -56,6 +56,7 @@ class SearchPipeline:
         search_request: SearchRequest,
         user: User | None,
         llm: LLM,
+        fast_llm: LLM,
         db_session: Session,
         bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
         retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -65,6 +66,7 @@ class SearchPipeline:
         self.search_request = search_request
         self.user = user
         self.llm = llm
+        self.fast_llm = fast_llm
         self.db_session = db_session
         self.bypass_acl = bypass_acl
         self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -298,6 +300,7 @@ class SearchPipeline:
         self._postprocessing_generator = search_postprocessing(
             search_query=self.search_query,
             retrieved_chunks=self.retrieved_chunks,
+            llm=self.fast_llm,  # use fast_llm for relevance, since it is a relatively easier task
             rerank_metrics_callback=self.rerank_metrics_callback,
         )
         self._reranked_chunks = cast(
diff --git a/backend/danswer/search/postprocessing/postprocessing.py b/backend/danswer/search/postprocessing/postprocessing.py
index f7c750eaf..3b36bcff3 100644
--- a/backend/danswer/search/postprocessing/postprocessing.py
+++ b/backend/danswer/search/postprocessing/postprocessing.py
@@ -9,6 +9,7 @@ from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
 from danswer.document_index.document_index_utils import (
     translate_boost_count_to_multiplier,
 )
+from danswer.llm.interfaces import LLM
 from danswer.search.models import ChunkMetric
 from danswer.search.models import InferenceChunk
 from danswer.search.models import MAX_METRICS_CONTENT
@@ -134,6 +135,7 @@ def rerank_chunks(
 def filter_chunks(
     query: SearchQuery,
     chunks_to_filter: list[InferenceChunk],
+    llm: LLM,
 ) -> list[str]:
     """Filters chunks based on whether the LLM thought they were relevant to the query.
 
@@ -142,6 +144,7 @@ def filter_chunks(
     llm_chunk_selection = llm_batch_eval_chunks(
         query=query.query,
         chunk_contents=[chunk.content for chunk in chunks_to_filter],
+        llm=llm,
     )
     return [
         chunk.unique_id
@@ -153,6 +156,7 @@ def filter_chunks(
 def search_postprocessing(
     search_query: SearchQuery,
     retrieved_chunks: list[InferenceChunk],
+    llm: LLM,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
 ) -> Generator[list[InferenceChunk] | list[str], None, None]:
     post_processing_tasks: list[FunctionCall] = []
@@ -184,7 +188,11 @@ def search_postprocessing(
         post_processing_tasks.append(
             FunctionCall(
                 filter_chunks,
-                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
+                (
+                    search_query,
+                    retrieved_chunks[: search_query.max_llm_filter_chunks],
+                    llm,
+                ),
             )
         )
         llm_filter_task_id = post_processing_tasks[-1].result_id
diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py
index 2ef3787c1..685871095 100644
--- a/backend/danswer/secondary_llm_flows/answer_validation.py
+++ b/backend/danswer/secondary_llm_flows/answer_validation.py
@@ -1,5 +1,5 @@
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -44,7 +44,7 @@ def get_answer_validity(
         return True  # If something is wrong, let's not toss away the answer
 
     try:
-        llm = get_default_llm()
+        llm, _ = get_default_llms()
     except GenAIDisabledException:
         return True
 
diff --git a/backend/danswer/secondary_llm_flows/chunk_usefulness.py b/backend/danswer/secondary_llm_flows/chunk_usefulness.py
index d37feb0c0..8148a3713 100644
--- a/backend/danswer/secondary_llm_flows/chunk_usefulness.py
+++ b/backend/danswer/secondary_llm_flows/chunk_usefulness.py
@@ -1,7 +1,6 @@
 from collections.abc import Callable
 
-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -12,7 +11,7 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle
 logger = setup_logger()
 
 
-def llm_eval_chunk(query: str, chunk_content: str) -> bool:
+def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
     def _get_usefulness_messages() -> list[dict[str, str]]:
         messages = [
             {
@@ -32,14 +31,6 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
             return False
         return True
 
-    # If Gen AI is disabled, none of the messages are more "useful" than any other
-    # All are marked not useful (False) so that the icon for Gen AI likes this answer
-    # is not shown for any result
-    try:
-        llm = get_default_llm(use_fast_llm=True, timeout=5)
-    except GenAIDisabledException:
-        return False
-
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     # When running in a batch, it takes as long as the longest thread
@@ -52,11 +43,12 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
 
 
 def llm_batch_eval_chunks(
-    query: str, chunk_contents: list[str], use_threads: bool = True
+    query: str, chunk_contents: list[str], llm: LLM, use_threads: bool = True
 ) -> list[bool]:
     if use_threads:
         functions_with_args: list[tuple[Callable, tuple]] = [
-            (llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
+            (llm_eval_chunk, (query, chunk_content, llm))
+            for chunk_content in chunk_contents
         ]
 
         logger.debug(
@@ -71,5 +63,6 @@ def llm_batch_eval_chunks(
 
     else:
         return [
-            llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
+            llm_eval_chunk(query, chunk_content, llm)
+            for chunk_content in chunk_contents
         ]
diff --git a/backend/danswer/secondary_llm_flows/query_expansion.py b/backend/danswer/secondary_llm_flows/query_expansion.py
index 2f221bfa9..2aa17855d 100644
--- a/backend/danswer/secondary_llm_flows/query_expansion.py
+++ b/backend/danswer/secondary_llm_flows/query_expansion.py
@@ -6,7 +6,7 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.db.models import ChatMessage
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
@@ -33,7 +33,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
         return messages
 
     try:
-        llm = get_default_llm(use_fast_llm=True, timeout=5)
+        _, fast_llm = get_default_llms(timeout=5)
     except GenAIDisabledException:
         logger.warning(
             "Unable to perform multilingual query expansion, Gen AI disabled"
@@ -42,7 +42,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
 
     messages = _get_rephrase_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = message_to_string(llm.invoke(filled_llm_prompt))
+    model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))
     logger.debug(model_output)
 
     return model_output
@@ -148,7 +148,7 @@ def thread_based_query_rephrase(
 
     if llm is None:
         try:
-            llm = get_default_llm()
+            llm, _ = get_default_llms()
         except GenAIDisabledException:
             # If Generative AI is turned off, just return the original query
             return user_query
diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py
index 4130b7ee3..bbc1ef412 100644
--- a/backend/danswer/secondary_llm_flows/query_validation.py
+++ b/backend/danswer/secondary_llm_flows/query_validation.py
@@ -5,7 +5,7 @@ from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_generator_to_string_generator
 from danswer.llm.utils import message_to_string
@@ -52,7 +52,7 @@ def get_query_answerability(
         return "Query Answerability Evaluation feature is turned off", True
 
     try:
-        llm = get_default_llm()
+        llm, _ = get_default_llms()
     except GenAIDisabledException:
         return "Generative AI is turned off - skipping check", True
 
@@ -79,7 +79,7 @@ def stream_query_answerability(
         return
 
     try:
-        llm = get_default_llm()
+        llm, _ = get_default_llms()
     except GenAIDisabledException:
         yield get_json_line(
             QueryValidationResponse(
diff --git a/backend/danswer/secondary_llm_flows/source_filter.py b/backend/danswer/secondary_llm_flows/source_filter.py
index 78dc504ab..802a14f42 100644
--- a/backend/danswer/secondary_llm_flows/source_filter.py
+++ b/backend/danswer/secondary_llm_flows/source_filter.py
@@ -159,11 +159,13 @@ def extract_source_filter(
 
 
 if __name__ == "__main__":
-    from danswer.llm.factory import get_default_llm
+    from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple
 
     # Just for testing purposes
     with Session(get_sqlalchemy_engine()) as db_session:
         while True:
             user_input = input("Query to Extract Sources: ")
-            sources = extract_source_filter(user_input, get_default_llm(), db_session)
+            sources = extract_source_filter(
+                user_input, get_main_llm_from_tuple(get_default_llms()), db_session
+            )
             print(sources)
diff --git a/backend/danswer/secondary_llm_flows/time_filter.py b/backend/danswer/secondary_llm_flows/time_filter.py
index 7de6efb3d..aef32d7bd 100644
--- a/backend/danswer/secondary_llm_flows/time_filter.py
+++ b/backend/danswer/secondary_llm_flows/time_filter.py
@@ -156,10 +156,12 @@ def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
 
 if __name__ == "__main__":
     # Just for testing purposes, too tedious to unit test as it relies on an LLM
-    from danswer.llm.factory import get_default_llm
+    from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple
 
     while True:
         user_input = input("Query to Extract Time: ")
-        cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
+        cutoff, recency_bias = extract_time_filter(
+            user_input, get_main_llm_from_tuple(get_default_llms())
+        )
         print(f"Time Cutoff: {cutoff}")
         print(f"Favor Recent: {recency_bias}")
diff --git a/backend/danswer/server/gpts/api.py b/backend/danswer/server/gpts/api.py
index cefb92341..84b0078ee 100644
--- a/backend/danswer/server/gpts/api.py
+++ b/backend/danswer/server/gpts/api.py
@@ -7,7 +7,7 @@ from pydantic import BaseModel
 from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_session
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
@@ -67,12 +67,14 @@ def gpt_search(
     _: str | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
+    llm, fast_llm = get_default_llms()
     top_chunks = SearchPipeline(
         search_request=SearchRequest(
             query=search_request.query,
         ),
         user=None,
-        llm=get_default_llm(),
+        llm=llm,
+        fast_llm=fast_llm,
         db_session=db_session,
     ).reranked_chunks
 
diff --git a/backend/danswer/server/manage/administrative.py b/backend/danswer/server/manage/administrative.py
index 0535ff815..d6a52917f 100644
--- a/backend/danswer/server/manage/administrative.py
+++ b/backend/danswer/server/manage/administrative.py
@@ -24,7 +24,7 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.file_store.file_store import get_default_file_store
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import test_llm
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
 from danswer.server.manage.models import BoostDoc
@@ -126,7 +126,7 @@ def validate_existing_genai_api_key(
         pass
 
     try:
-        llm = get_default_llm(timeout=10)
+        llm, __ = get_default_llms(timeout=10)
     except ValueError:
         raise HTTPException(status_code=404, detail="LLM not setup")
 
diff --git a/backend/danswer/server/manage/llm/api.py b/backend/danswer/server/manage/llm/api.py
index 6e21a29b1..4df00b529 100644
--- a/backend/danswer/server/manage/llm/api.py
+++ b/backend/danswer/server/manage/llm/api.py
@@ -13,7 +13,7 @@ from danswer.db.llm import remove_llm_provider
 from danswer.db.llm import update_default_provider
 from danswer.db.llm import upsert_llm_provider
 from danswer.db.models import User
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.factory import get_llm
 from danswer.llm.llm_provider_options import fetch_available_well_known_llms
 from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
@@ -85,8 +85,7 @@ def test_default_provider(
     _: User | None = Depends(current_admin_user),
 ) -> None:
     try:
-        llm = get_default_llm()
-        fast_llm = get_default_llm(use_fast_llm=True)
+        llm, fast_llm = get_default_llms()
     except ValueError:
         logger.exception("Failed to fetch default LLM Provider")
         raise HTTPException(status_code=400, detail="No LLM Provider setup")
diff --git a/backend/danswer/server/query_and_chat/chat_backend.py b/backend/danswer/server/query_and_chat/chat_backend.py
index 57c63b14d..646660c9f 100644
--- a/backend/danswer/server/query_and_chat/chat_backend.py
+++ b/backend/danswer/server/query_and_chat/chat_backend.py
@@ -43,7 +43,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.headers import get_litellm_additional_request_headers
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
@@ -224,7 +224,7 @@ def rename_chat_session(
     full_history = history_msgs + [final_msg]
 
     try:
-        llm = get_default_llm(
+        llm, _ = get_default_llms(
             additional_headers=get_litellm_additional_request_headers(request.headers)
         )
     except GenAIDisabledException:
diff --git a/backend/danswer/tools/search/search_tool.py b/backend/danswer/tools/search/search_tool.py
index 5975d085b..75770e69f 100644
--- a/backend/danswer/tools/search/search_tool.py
+++ b/backend/danswer/tools/search/search_tool.py
@@ -69,6 +69,7 @@ class SearchTool(Tool):
         retrieval_options: RetrievalDetails | None,
         prompt_config: PromptConfig,
         llm: LLM,
+        fast_llm: LLM,
         pruning_config: DocumentPruningConfig,
         # if specified, will not actually run a search and will instead return these
         # sections. Used when the user selects specific docs to talk to
@@ -83,6 +84,7 @@ class SearchTool(Tool):
         self.retrieval_options = retrieval_options
         self.prompt_config = prompt_config
         self.llm = llm
+        self.fast_llm = fast_llm
         self.pruning_config = pruning_config
         self.selected_docs = selected_docs
 
@@ -212,6 +214,7 @@ class SearchTool(Tool):
             ),
             user=self.user,
             llm=self.llm,
+            fast_llm=self.fast_llm,
             bypass_acl=self.bypass_acl,
             db_session=self.db_session,
         )
diff --git a/backend/ee/danswer/server/query_and_chat/query_backend.py b/backend/ee/danswer/server/query_and_chat/query_backend.py
index 1e02ac0aa..f61772c6a 100644
--- a/backend/ee/danswer/server/query_and_chat/query_backend.py
+++ b/backend/ee/danswer/server/query_and_chat/query_backend.py
@@ -10,8 +10,9 @@ from danswer.db.persona import get_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
-from danswer.llm.factory import get_default_llm
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_default_llms
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
@@ -41,7 +42,7 @@ def handle_search_request(
     query = search_request.message
     logger.info(f"Received document search query: {query}")
 
-    llm = get_default_llm()
+    llm, fast_llm = get_default_llms()
     search_pipeline = SearchPipeline(
         search_request=SearchRequest(
             query=query,
@@ -59,6 +60,7 @@ def handle_search_request(
         ),
         user=user,
         llm=llm,
+        fast_llm=fast_llm,
         db_session=db_session,
         bypass_acl=False,
     )
@@ -104,7 +106,9 @@ def get_answer_with_quote(
         is_for_edit=False,
     )
 
-    llm = get_default_llm() if not persona else get_llm_for_persona(persona)
+    llm = get_main_llm_from_tuple(
+        get_default_llms() if not persona else get_llms_for_persona(persona)
+    )
     input_tokens = get_max_input_tokens(
         model_name=llm.config.model_name, model_provider=llm.config.model_provider
     )
diff --git a/backend/tests/regression/search_quality/eval_search.py b/backend/tests/regression/search_quality/eval_search.py
index e67fe7c5b..9567a6e50 100644
--- a/backend/tests/regression/search_quality/eval_search.py
+++ b/backend/tests/regression/search_quality/eval_search.py
@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
 
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.llm.answering.doc_pruning import reorder_docs
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.search.models import InferenceChunk
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -83,12 +83,14 @@ def get_search_results(
     rerank_metrics = MetricsHander[RerankMetricsContainer]()
 
    with Session(get_sqlalchemy_engine()) as db_session:
+        llm, fast_llm = get_default_llms()
         search_pipeline = SearchPipeline(
             search_request=SearchRequest(
                 query=query,
             ),
             user=None,
-            llm=get_default_llm(),
+            llm=llm,
+            fast_llm=fast_llm,
             db_session=db_session,
             retrieval_metrics_callback=retrieval_metrics.record_metric,
             rerank_metrics_callback=rerank_metrics.record_metric,
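
Usage sketch (illustrative, not part of the patch): a minimal example of how callers consume the reworked factory API after this change, based only on the signatures shown above. The persona and db_session values are assumed to come from the surrounding request context.

    from danswer.llm.factory import get_default_llms
    from danswer.llm.factory import get_llms_for_persona
    from danswer.llm.factory import get_main_llm_from_tuple
    from danswer.search.models import SearchRequest
    from danswer.search.pipeline import SearchPipeline

    # Both factories now return a (primary, fast) pair instead of a single LLM.
    llm, fast_llm = get_default_llms(timeout=5)

    # Persona-scoped callers unpack the pair, or use the helper when only the
    # primary model is needed (persona is assumed to be loaded elsewhere).
    primary = get_main_llm_from_tuple(get_llms_for_persona(persona))

    # SearchPipeline now receives both models; the fast one drives the LLM chunk
    # relevance filter in postprocessing (db_session is assumed).
    top_chunks = SearchPipeline(
        search_request=SearchRequest(query="example query"),
        user=None,
        llm=llm,
        fast_llm=fast_llm,
        db_session=db_session,
    ).reranked_chunks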