Mirror of https://github.com/danswer-ai/danswer.git (synced 2025-04-07 19:38:19 +02:00)
Fix fast models
commit 415960564d
parent ed550986a6
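This commit replaces the single-LLM factory helpers with tuple-returning ones, so call sites now receive both the primary and the fast model. A minimal sketch of the intended call pattern, based only on the names and signatures visible in the diff below (a configured default LLM provider is assumed):

    from danswer.llm.factory import get_default_llms
    from danswer.llm.factory import get_main_llm_from_tuple

    # get_default_llms() returns a (llm, fast_llm) pair instead of a single LLM
    llm, fast_llm = get_default_llms(timeout=10)

    # callers that only need the primary model unwrap the pair with the new helper
    main_llm = get_main_llm_from_tuple(get_default_llms())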
@@ -47,7 +47,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.search.enums import OptionalSearchSetting
 from danswer.search.retrieval.search_runner import inference_documents_from_ids
@@ -244,7 +245,7 @@ def stream_chat_message_objects(
         )

     try:
-        llm = get_llm_for_persona(
+        llm, fast_llm = get_llms_for_persona(
             persona=persona,
             llm_override=new_msg_req.llm_override or chat_session.llm_override,
             additional_headers=litellm_additional_headers,
@@ -425,6 +426,7 @@ def stream_chat_message_objects(
             retrieval_options=retrieval_options,
             prompt_config=prompt_config,
             llm=llm,
+            fast_llm=fast_llm,
             pruning_config=document_pruning_config,
             selected_docs=selected_llm_docs,
             chunks_above=new_msg_req.chunks_above,
@@ -498,10 +500,14 @@ def stream_chat_message_objects(
             prompt_config=prompt_config,
             llm=(
                 llm
-                or get_llm_for_persona(
-                    persona=persona,
-                    llm_override=new_msg_req.llm_override or chat_session.llm_override,
-                    additional_headers=litellm_additional_headers,
+                or get_main_llm_from_tuple(
+                    get_llms_for_persona(
+                        persona=persona,
+                        llm_override=(
+                            new_msg_req.llm_override or chat_session.llm_override
+                        ),
+                        additional_headers=litellm_additional_headers,
+                    )
                 )
             ),
             message_history=[

@@ -50,7 +50,7 @@ from danswer.db.persona import fetch_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
 from danswer.llm.utils import check_number_of_tokens
 from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
@@ -324,7 +324,7 @@ def handle_message(
             Persona,
             fetch_persona_by_id(db_session, new_message_request.persona_id),
         )
-        llm = get_llm_for_persona(persona)
+        llm, _ = get_llms_for_persona(persona)

         # In cases of threads, split the available tokens between docs and thread context
         input_tokens = get_max_input_tokens(

@@ -30,7 +30,7 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.db.users import get_user_by_email
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.one_shot_answer.models import ThreadMessage
@@ -58,7 +58,7 @@ def rephrase_slack_message(msg: str) -> str:
         return messages

     try:
-        llm = get_default_llm(use_fast_llm=False, timeout=5)
+        llm, _ = get_default_llms(timeout=5)
     except GenAIDisabledException:
         logger.warning("Unable to rephrase Slack user message, Gen AI disabled")
         return msg

@@ -8,7 +8,8 @@ from danswer.db.models import Persona
 from danswer.db.persona import get_default_prompt__read_only
 from danswer.file_store.utils import InMemoryChatFile
 from danswer.llm.answering.models import PromptConfig
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.interfaces import LLMConfig
 from danswer.llm.utils import build_content_with_imgs
 from danswer.llm.utils import check_number_of_tokens
@@ -99,7 +100,7 @@ def compute_max_document_tokens_for_persona(
     prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
     return compute_max_document_tokens(
         prompt_config=PromptConfig.from_model(prompt),
-        llm_config=get_llm_for_persona(persona).config,
+        llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
         actual_user_input=actual_user_input,
         max_llm_token_override=max_llm_token_override,
     )

@@ -12,65 +12,92 @@ from danswer.llm.interfaces import LLM
 from danswer.llm.override_models import LLMOverride


-def get_llm_for_persona(
+def get_main_llm_from_tuple(
+    llms: tuple[LLM, LLM],
+) -> LLM:
+    return llms[0]
+
+
+def get_llms_for_persona(
     persona: Persona,
     llm_override: LLMOverride | None = None,
     additional_headers: dict[str, str] | None = None,
-) -> LLM:
+) -> tuple[LLM, LLM]:
     model_provider_override = llm_override.model_provider if llm_override else None
     model_version_override = llm_override.model_version if llm_override else None
     temperature_override = llm_override.temperature if llm_override else None

-    return get_default_llm(
-        model_provider_name=(
-            model_provider_override or persona.llm_model_provider_override
-        ),
-        model_version=(model_version_override or persona.llm_model_version_override),
-        temperature=temperature_override or GEN_AI_TEMPERATURE,
-        additional_headers=additional_headers,
-    )
+    provider_name = model_provider_override or persona.llm_model_provider_override
+    if not provider_name:
+        return get_default_llms(
+            temperature=temperature_override or GEN_AI_TEMPERATURE,
+            additional_headers=additional_headers,
+        )
+
+    with get_session_context_manager() as db_session:
+        llm_provider = fetch_provider(db_session, provider_name)
+
+    if not llm_provider:
+        raise ValueError("No LLM provider found")
+
+    model = model_version_override or persona.llm_model_version_override
+    fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
+    if not model:
+        raise ValueError("No model name found")
+    if not fast_model:
+        raise ValueError("No fast model name found")
+
+    def _create_llm(model: str) -> LLM:
+        return get_llm(
+            provider=llm_provider.provider,
+            model=model,
+            api_key=llm_provider.api_key,
+            api_base=llm_provider.api_base,
+            api_version=llm_provider.api_version,
+            custom_config=llm_provider.custom_config,
+            additional_headers=additional_headers,
+        )
+
+    return _create_llm(model), _create_llm(fast_model)


-def get_default_llm(
+def get_default_llms(
     timeout: int = QA_TIMEOUT,
     temperature: float = GEN_AI_TEMPERATURE,
-    use_fast_llm: bool = False,
-    model_provider_name: str | None = None,
-    model_version: str | None = None,
     additional_headers: dict[str, str] | None = None,
-) -> LLM:
+) -> tuple[LLM, LLM]:
     if DISABLE_GENERATIVE_AI:
         raise GenAIDisabledException()

     # TODO: pass this in
-    with get_session_context_manager() as session:
-        if model_provider_name is None:
-            llm_provider = fetch_default_provider(session)
-        else:
-            llm_provider = fetch_provider(session, model_provider_name)
+    with get_session_context_manager() as db_session:
+        llm_provider = fetch_default_provider(db_session)

     if not llm_provider:
         raise ValueError("No default LLM provider found")

-    model_name = model_version or (
-        (llm_provider.fast_default_model_name or llm_provider.default_model_name)
-        if use_fast_llm
-        else llm_provider.default_model_name
+    model_name = llm_provider.default_model_name
+    fast_model_name = (
+        llm_provider.fast_default_model_name or llm_provider.default_model_name
     )
     if not model_name:
         raise ValueError("No default model name found")
+    if not fast_model_name:
+        raise ValueError("No fast default model name found")

-    return get_llm(
-        provider=llm_provider.provider,
-        model=model_name,
-        api_key=llm_provider.api_key,
-        api_base=llm_provider.api_base,
-        api_version=llm_provider.api_version,
-        custom_config=llm_provider.custom_config,
-        timeout=timeout,
-        temperature=temperature,
-        additional_headers=additional_headers,
-    )
+    def _create_llm(model: str) -> LLM:
+        return get_llm(
+            provider=llm_provider.provider,
+            model=model,
+            api_key=llm_provider.api_key,
+            api_base=llm_provider.api_base,
+            api_version=llm_provider.api_version,
+            custom_config=llm_provider.custom_config,
+            timeout=timeout,
+            temperature=temperature,
+            additional_headers=additional_headers,
+        )
+
+    return _create_llm(model_name), _create_llm(fast_model_name)


 def get_llm(

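For persona-scoped callers the reworked factory has the same shape. A short usage sketch assuming a Persona row already loaded by the caller (the persona variable and any override objects are placeholders, not part of this diff):

    from danswer.llm.factory import get_llms_for_persona
    from danswer.llm.factory import get_main_llm_from_tuple

    # returns (llm, fast_llm); the fast model falls back to the provider's
    # default_model_name when no fast_default_model_name is configured
    llm, fast_llm = get_llms_for_persona(
        persona=persona,
        llm_override=None,
        additional_headers=None,
    )

    # old get_llm_for_persona(...) call sites that only want the primary model
    main_llm = get_main_llm_from_tuple(get_llms_for_persona(persona=persona))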
@@ -30,7 +30,8 @@ from danswer.llm.answering.models import CitationConfig
 from danswer.llm.answering.models import DocumentPruningConfig
 from danswer.llm.answering.models import PromptConfig
 from danswer.llm.answering.models import QuotesConfig
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.utils import get_default_llm_token_encode
 from danswer.one_shot_answer.models import DirectQARequest
 from danswer.one_shot_answer.models import OneShotQAResponse
@@ -156,7 +157,7 @@ def stream_answer_objects(
         commit=True,
     )

-    llm = get_llm_for_persona(persona=chat_session.persona)
+    llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
     prompt_config = PromptConfig.from_model(prompt)
     document_pruning_config = DocumentPruningConfig(
         max_chunks=int(
@@ -174,6 +175,7 @@ def stream_answer_objects(
         retrieval_options=query_req.retrieval_options,
         prompt_config=prompt_config,
         llm=llm,
+        fast_llm=fast_llm,
         pruning_config=document_pruning_config,
         bypass_acl=bypass_acl,
     )
@@ -187,7 +189,7 @@ def stream_answer_objects(
         question=query_msg.message,
         answer_style_config=answer_config,
         prompt_config=PromptConfig.from_model(prompt),
-        llm=get_llm_for_persona(persona=chat_session.persona),
+        llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
         single_message_history=history_str,
         tools=[search_tool],
         force_use_tool=ForceUseTool(

@@ -56,6 +56,7 @@ class SearchPipeline:
         search_request: SearchRequest,
         user: User | None,
         llm: LLM,
+        fast_llm: LLM,
         db_session: Session,
         bypass_acl: bool = False,  # NOTE: VERY DANGEROUS, USE WITH CAUTION
         retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -65,6 +66,7 @@ class SearchPipeline:
         self.search_request = search_request
         self.user = user
         self.llm = llm
+        self.fast_llm = fast_llm
         self.db_session = db_session
         self.bypass_acl = bypass_acl
         self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -298,6 +300,7 @@ class SearchPipeline:
         self._postprocessing_generator = search_postprocessing(
             search_query=self.search_query,
             retrieved_chunks=self.retrieved_chunks,
+            llm=self.fast_llm,  # use fast_llm for relevance, since it is a relatively easier task
             rerank_metrics_callback=self.rerank_metrics_callback,
         )
         self._reranked_chunks = cast(

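SearchPipeline now takes the fast model alongside the primary one and routes it to the LLM relevance-filtering step. A minimal construction sketch mirroring the updated call sites later in this diff (the query string is a placeholder):

    from sqlalchemy.orm import Session

    from danswer.db.engine import get_sqlalchemy_engine
    from danswer.llm.factory import get_default_llms
    from danswer.search.models import SearchRequest
    from danswer.search.pipeline import SearchPipeline

    with Session(get_sqlalchemy_engine()) as db_session:
        llm, fast_llm = get_default_llms()
        top_chunks = SearchPipeline(
            search_request=SearchRequest(query="example query"),
            user=None,
            llm=llm,
            fast_llm=fast_llm,  # used only for the cheaper relevance-filter task
            db_session=db_session,
        ).reranked_chunks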
@@ -9,6 +9,7 @@ from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
 from danswer.document_index.document_index_utils import (
     translate_boost_count_to_multiplier,
 )
+from danswer.llm.interfaces import LLM
 from danswer.search.models import ChunkMetric
 from danswer.search.models import InferenceChunk
 from danswer.search.models import MAX_METRICS_CONTENT
@@ -134,6 +135,7 @@ def rerank_chunks(
 def filter_chunks(
     query: SearchQuery,
     chunks_to_filter: list[InferenceChunk],
+    llm: LLM,
 ) -> list[str]:
     """Filters chunks based on whether the LLM thought they were relevant to the query.

@@ -142,6 +144,7 @@ def filter_chunks(
     llm_chunk_selection = llm_batch_eval_chunks(
         query=query.query,
         chunk_contents=[chunk.content for chunk in chunks_to_filter],
+        llm=llm,
     )
     return [
         chunk.unique_id
@@ -153,6 +156,7 @@ def filter_chunks(
 def search_postprocessing(
     search_query: SearchQuery,
     retrieved_chunks: list[InferenceChunk],
+    llm: LLM,
     rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
 ) -> Generator[list[InferenceChunk] | list[str], None, None]:
     post_processing_tasks: list[FunctionCall] = []
@@ -184,7 +188,11 @@ def search_postprocessing(
         post_processing_tasks.append(
             FunctionCall(
                 filter_chunks,
-                (search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
+                (
+                    search_query,
+                    retrieved_chunks[: search_query.max_llm_filter_chunks],
+                    llm,
+                ),
             )
         )
         llm_filter_task_id = post_processing_tasks[-1].result_id

@@ -1,5 +1,5 @@
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -44,7 +44,7 @@ def get_answer_validity(
         return True  # If something is wrong, let's not toss away the answer

     try:
-        llm = get_default_llm()
+        llm, _ = get_default_llms()
     except GenAIDisabledException:
         return True

@@ -1,7 +1,6 @@
 from collections.abc import Callable

-from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
 from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -12,7 +11,7 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_paralle
 logger = setup_logger()


-def llm_eval_chunk(query: str, chunk_content: str) -> bool:
+def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
     def _get_usefulness_messages() -> list[dict[str, str]]:
         messages = [
             {
@@ -32,14 +31,6 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
             return False
         return True

-    # If Gen AI is disabled, none of the messages are more "useful" than any other
-    # All are marked not useful (False) so that the icon for Gen AI likes this answer
-    # is not shown for any result
-    try:
-        llm = get_default_llm(use_fast_llm=True, timeout=5)
-    except GenAIDisabledException:
-        return False
-
     messages = _get_usefulness_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
     # When running in a batch, it takes as long as the longest thread
@@ -52,11 +43,12 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:


 def llm_batch_eval_chunks(
-    query: str, chunk_contents: list[str], use_threads: bool = True
+    query: str, chunk_contents: list[str], llm: LLM, use_threads: bool = True
 ) -> list[bool]:
     if use_threads:
         functions_with_args: list[tuple[Callable, tuple]] = [
-            (llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
+            (llm_eval_chunk, (query, chunk_content, llm))
+            for chunk_content in chunk_contents
         ]

         logger.debug(
@@ -71,5 +63,6 @@ def llm_batch_eval_chunks(

     else:
         return [
-            llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
+            llm_eval_chunk(query, chunk_content, llm)
+            for chunk_content in chunk_contents
         ]

@@ -6,7 +6,7 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
 from danswer.db.models import ChatMessage
 from danswer.llm.answering.models import PreviousMessage
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.interfaces import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_to_string
@@ -33,7 +33,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
         return messages

     try:
-        llm = get_default_llm(use_fast_llm=True, timeout=5)
+        _, fast_llm = get_default_llms(timeout=5)
     except GenAIDisabledException:
         logger.warning(
             "Unable to perform multilingual query expansion, Gen AI disabled"
@@ -42,7 +42,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:

     messages = _get_rephrase_messages()
     filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
-    model_output = message_to_string(llm.invoke(filled_llm_prompt))
+    model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))
     logger.debug(model_output)

     return model_output
@@ -148,7 +148,7 @@ def thread_based_query_rephrase(

     if llm is None:
         try:
-            llm = get_default_llm()
+            llm, _ = get_default_llms()
         except GenAIDisabledException:
             # If Generative AI is turned off, just return the original query
             return user_query

@@ -5,7 +5,7 @@ from danswer.chat.models import DanswerAnswerPiece
 from danswer.chat.models import StreamingError
 from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import message_generator_to_string_generator
 from danswer.llm.utils import message_to_string
@@ -52,7 +52,7 @@ def get_query_answerability(
         return "Query Answerability Evaluation feature is turned off", True

     try:
-        llm = get_default_llm()
+        llm, _ = get_default_llms()
     except GenAIDisabledException:
         return "Generative AI is turned off - skipping check", True

@@ -79,7 +79,7 @@ def stream_query_answerability(
         return

     try:
-        llm = get_default_llm()
+        llm, _ = get_default_llms()
     except GenAIDisabledException:
         yield get_json_line(
             QueryValidationResponse(

@@ -159,11 +159,13 @@ def extract_source_filter(


 if __name__ == "__main__":
-    from danswer.llm.factory import get_default_llm
+    from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple

     # Just for testing purposes
     with Session(get_sqlalchemy_engine()) as db_session:
         while True:
             user_input = input("Query to Extract Sources: ")
-            sources = extract_source_filter(user_input, get_default_llm(), db_session)
+            sources = extract_source_filter(
+                user_input, get_main_llm_from_tuple(get_default_llms()), db_session
+            )
             print(sources)

@@ -156,10 +156,12 @@ def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:

 if __name__ == "__main__":
     # Just for testing purposes, too tedious to unit test as it relies on an LLM
-    from danswer.llm.factory import get_default_llm
+    from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple

     while True:
         user_input = input("Query to Extract Time: ")
-        cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
+        cutoff, recency_bias = extract_time_filter(
+            user_input, get_main_llm_from_tuple(get_default_llms())
+        )
         print(f"Time Cutoff: {cutoff}")
         print(f"Favor Recent: {recency_bias}")

@@ -7,7 +7,7 @@ from pydantic import BaseModel
 from sqlalchemy.orm import Session

 from danswer.db.engine import get_session
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.search.models import SearchRequest
 from danswer.search.pipeline import SearchPipeline
 from danswer.server.danswer_api.ingestion import api_key_dep
@@ -67,12 +67,14 @@ def gpt_search(
     _: str | None = Depends(api_key_dep),
     db_session: Session = Depends(get_session),
 ) -> GptSearchResponse:
+    llm, fast_llm = get_default_llms()
     top_chunks = SearchPipeline(
         search_request=SearchRequest(
             query=search_request.query,
         ),
         user=None,
-        llm=get_default_llm(),
+        llm=llm,
+        fast_llm=fast_llm,
         db_session=db_session,
     ).reranked_chunks

@@ -24,7 +24,7 @@ from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs.factory import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.file_store.file_store import get_default_file_store
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.utils import test_llm
 from danswer.server.documents.models import ConnectorCredentialPairIdentifier
 from danswer.server.manage.models import BoostDoc
@@ -126,7 +126,7 @@ def validate_existing_genai_api_key(
         pass

     try:
-        llm = get_default_llm(timeout=10)
+        llm, __ = get_default_llms(timeout=10)
     except ValueError:
         raise HTTPException(status_code=404, detail="LLM not setup")

@@ -13,7 +13,7 @@ from danswer.db.llm import remove_llm_provider
 from danswer.db.llm import update_default_provider
 from danswer.db.llm import upsert_llm_provider
 from danswer.db.models import User
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.factory import get_llm
 from danswer.llm.llm_provider_options import fetch_available_well_known_llms
 from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
@@ -85,8 +85,7 @@ def test_default_provider(
     _: User | None = Depends(current_admin_user),
 ) -> None:
     try:
-        llm = get_default_llm()
-        fast_llm = get_default_llm(use_fast_llm=True)
+        llm, fast_llm = get_default_llms()
     except ValueError:
         logger.exception("Failed to fetch default LLM Provider")
         raise HTTPException(status_code=400, detail="No LLM Provider setup")

@@ -43,7 +43,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
 from danswer.llm.exceptions import GenAIDisabledException
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.llm.headers import get_litellm_additional_request_headers
 from danswer.llm.utils import get_default_llm_tokenizer
 from danswer.secondary_llm_flows.chat_session_naming import (
@@ -224,7 +224,7 @@ def rename_chat_session(
     full_history = history_msgs + [final_msg]

     try:
-        llm = get_default_llm(
+        llm, _ = get_default_llms(
             additional_headers=get_litellm_additional_request_headers(request.headers)
         )
     except GenAIDisabledException:

@@ -69,6 +69,7 @@ class SearchTool(Tool):
         retrieval_options: RetrievalDetails | None,
         prompt_config: PromptConfig,
         llm: LLM,
+        fast_llm: LLM,
         pruning_config: DocumentPruningConfig,
         # if specified, will not actually run a search and will instead return these
         # sections. Used when the user selects specific docs to talk to
@@ -83,6 +84,7 @@ class SearchTool(Tool):
         self.retrieval_options = retrieval_options
         self.prompt_config = prompt_config
         self.llm = llm
+        self.fast_llm = fast_llm
         self.pruning_config = pruning_config

         self.selected_docs = selected_docs
@@ -212,6 +214,7 @@ class SearchTool(Tool):
             ),
             user=self.user,
             llm=self.llm,
+            fast_llm=self.fast_llm,
             bypass_acl=self.bypass_acl,
             db_session=self.db_session,
         )

@@ -10,8 +10,9 @@ from danswer.db.persona import get_persona_by_id
 from danswer.llm.answering.prompts.citations_prompt import (
     compute_max_document_tokens_for_persona,
 )
-from danswer.llm.factory import get_default_llm
-from danswer.llm.factory import get_llm_for_persona
+from danswer.llm.factory import get_default_llms
+from danswer.llm.factory import get_llms_for_persona
+from danswer.llm.factory import get_main_llm_from_tuple
 from danswer.llm.utils import get_max_input_tokens
 from danswer.one_shot_answer.answer_question import get_search_answer
 from danswer.one_shot_answer.models import DirectQARequest
@@ -41,7 +42,7 @@ def handle_search_request(
     query = search_request.message
     logger.info(f"Received document search query: {query}")

-    llm = get_default_llm()
+    llm, fast_llm = get_default_llms()
     search_pipeline = SearchPipeline(
         search_request=SearchRequest(
             query=query,
@@ -59,6 +60,7 @@ def handle_search_request(
         ),
         user=user,
         llm=llm,
+        fast_llm=fast_llm,
         db_session=db_session,
         bypass_acl=False,
     )
@@ -104,7 +106,9 @@ def get_answer_with_quote(
         is_for_edit=False,
     )

-    llm = get_default_llm() if not persona else get_llm_for_persona(persona)
+    llm = get_main_llm_from_tuple(
+        get_default_llms() if not persona else get_llms_for_persona(persona)
+    )
     input_tokens = get_max_input_tokens(
         model_name=llm.config.model_name, model_provider=llm.config.model_provider
     )

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session

 from danswer.db.engine import get_sqlalchemy_engine
 from danswer.llm.answering.doc_pruning import reorder_docs
-from danswer.llm.factory import get_default_llm
+from danswer.llm.factory import get_default_llms
 from danswer.search.models import InferenceChunk
 from danswer.search.models import RerankMetricsContainer
 from danswer.search.models import RetrievalMetricsContainer
@@ -83,12 +83,14 @@ def get_search_results(
     rerank_metrics = MetricsHander[RerankMetricsContainer]()

     with Session(get_sqlalchemy_engine()) as db_session:
+        llm, fast_llm = get_default_llms()
         search_pipeline = SearchPipeline(
             search_request=SearchRequest(
                 query=query,
             ),
             user=None,
-            llm=get_default_llm(),
+            llm=llm,
+            fast_llm=fast_llm,
             db_session=db_session,
             retrieval_metrics_callback=retrieval_metrics.record_metric,
             rerank_metrics_callback=rerank_metrics.record_metric,