Fix fast models

Authored by Weves on 2024-06-28 17:13:23 -07:00, committed by Chris Weaver
parent ed550986a6
commit 415960564d
21 changed files with 148 additions and 94 deletions

View File

@@ -47,7 +47,8 @@ from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.answering.models import PromptConfig
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.search.enums import OptionalSearchSetting
from danswer.search.retrieval.search_runner import inference_documents_from_ids
@@ -244,7 +245,7 @@ def stream_chat_message_objects(
)
try:
llm = get_llm_for_persona(
llm, fast_llm = get_llms_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
@@ -425,6 +426,7 @@ def stream_chat_message_objects(
retrieval_options=retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
selected_docs=selected_llm_docs,
chunks_above=new_msg_req.chunks_above,
@@ -498,10 +500,14 @@ def stream_chat_message_objects(
prompt_config=prompt_config,
llm=(
llm
or get_llm_for_persona(
persona=persona,
llm_override=new_msg_req.llm_override or chat_session.llm_override,
additional_headers=litellm_additional_headers,
or get_main_llm_from_tuple(
get_llms_for_persona(
persona=persona,
llm_override=(
new_msg_req.llm_override or chat_session.llm_override
),
additional_headers=litellm_additional_headers,
)
)
),
message_history=[

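A minimal sketch of the caller-side pattern introduced in this hunk (persona, the override objects, and litellm_additional_headers are assumed to be in scope as above; variable names are illustrative):

# The persona factory now returns a (primary LLM, fast LLM) pair instead of a single LLM
llm, fast_llm = get_llms_for_persona(
    persona=persona,
    llm_override=new_msg_req.llm_override or chat_session.llm_override,
    additional_headers=litellm_additional_headers,
)

# Call sites that only want the primary model unwrap the tuple explicitly
llm = llm or get_main_llm_from_tuple(get_llms_for_persona(persona=persona))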
View File

@@ -50,7 +50,7 @@ from danswer.db.persona import fetch_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.utils import check_number_of_tokens
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
@@ -324,7 +324,7 @@ def handle_message(
Persona,
fetch_persona_by_id(db_session, new_message_request.persona_id),
)
llm = get_llm_for_persona(persona)
llm, _ = get_llms_for_persona(persona)
# In cases of threads, split the available tokens between docs and thread context
input_tokens = get_max_input_tokens(

View File

@@ -30,7 +30,7 @@ from danswer.danswerbot.slack.tokens import fetch_tokens
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.users import get_user_by_email
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.one_shot_answer.models import ThreadMessage
@@ -58,7 +58,7 @@ def rephrase_slack_message(msg: str) -> str:
return messages
try:
llm = get_default_llm(use_fast_llm=False, timeout=5)
llm, _ = get_default_llms(timeout=5)
except GenAIDisabledException:
logger.warning("Unable to rephrase Slack user message, Gen AI disabled")
return msg

View File

@@ -8,7 +8,8 @@ from danswer.db.models import Persona
from danswer.db.persona import get_default_prompt__read_only
from danswer.file_store.utils import InMemoryChatFile
from danswer.llm.answering.models import PromptConfig
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.interfaces import LLMConfig
from danswer.llm.utils import build_content_with_imgs
from danswer.llm.utils import check_number_of_tokens
@@ -99,7 +100,7 @@ def compute_max_document_tokens_for_persona(
prompt = persona.prompts[0] if persona.prompts else get_default_prompt__read_only()
return compute_max_document_tokens(
prompt_config=PromptConfig.from_model(prompt),
llm_config=get_llm_for_persona(persona).config,
llm_config=get_main_llm_from_tuple(get_llms_for_persona(persona)).config,
actual_user_input=actual_user_input,
max_llm_token_override=max_llm_token_override,
)

View File

@@ -12,65 +12,92 @@ from danswer.llm.interfaces import LLM
from danswer.llm.override_models import LLMOverride
def get_llm_for_persona(
def get_main_llm_from_tuple(
llms: tuple[LLM, LLM],
) -> LLM:
return llms[0]
def get_llms_for_persona(
persona: Persona,
llm_override: LLMOverride | None = None,
additional_headers: dict[str, str] | None = None,
) -> LLM:
) -> tuple[LLM, LLM]:
model_provider_override = llm_override.model_provider if llm_override else None
model_version_override = llm_override.model_version if llm_override else None
temperature_override = llm_override.temperature if llm_override else None
return get_default_llm(
model_provider_name=(
model_provider_override or persona.llm_model_provider_override
),
model_version=(model_version_override or persona.llm_model_version_override),
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
)
provider_name = model_provider_override or persona.llm_model_provider_override
if not provider_name:
return get_default_llms(
temperature=temperature_override or GEN_AI_TEMPERATURE,
additional_headers=additional_headers,
)
with get_session_context_manager() as db_session:
llm_provider = fetch_provider(db_session, provider_name)
if not llm_provider:
raise ValueError("No LLM provider found")
model = model_version_override or persona.llm_model_version_override
fast_model = llm_provider.fast_default_model_name or llm_provider.default_model_name
if not model:
raise ValueError("No model name found")
if not fast_model:
raise ValueError("No fast model name found")
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
additional_headers=additional_headers,
)
return _create_llm(model), _create_llm(fast_model)
def get_default_llm(
def get_default_llms(
timeout: int = QA_TIMEOUT,
temperature: float = GEN_AI_TEMPERATURE,
use_fast_llm: bool = False,
model_provider_name: str | None = None,
model_version: str | None = None,
additional_headers: dict[str, str] | None = None,
) -> LLM:
) -> tuple[LLM, LLM]:
if DISABLE_GENERATIVE_AI:
raise GenAIDisabledException()
# TODO: pass this in
with get_session_context_manager() as session:
if model_provider_name is None:
llm_provider = fetch_default_provider(session)
else:
llm_provider = fetch_provider(session, model_provider_name)
with get_session_context_manager() as db_session:
llm_provider = fetch_default_provider(db_session)
if not llm_provider:
raise ValueError("No default LLM provider found")
model_name = model_version or (
(llm_provider.fast_default_model_name or llm_provider.default_model_name)
if use_fast_llm
else llm_provider.default_model_name
model_name = llm_provider.default_model_name
fast_model_name = (
llm_provider.fast_default_model_name or llm_provider.default_model_name
)
if not model_name:
raise ValueError("No default model name found")
if not fast_model_name:
raise ValueError("No fast default model name found")
return get_llm(
provider=llm_provider.provider,
model=model_name,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
)
def _create_llm(model: str) -> LLM:
return get_llm(
provider=llm_provider.provider,
model=model,
api_key=llm_provider.api_key,
api_base=llm_provider.api_base,
api_version=llm_provider.api_version,
custom_config=llm_provider.custom_config,
timeout=timeout,
temperature=temperature,
additional_headers=additional_headers,
)
return _create_llm(model_name), _create_llm(fast_model_name)
def get_llm(

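Taken together, the reworked factory exposes a small tuple-based API. A usage sketch, assuming a default LLM provider is configured and a persona object is already loaded:

from danswer.llm.factory import get_default_llms, get_llms_for_persona, get_main_llm_from_tuple

# Both helpers return (primary LLM, fast LLM) built from the same provider
llm, fast_llm = get_default_llms(timeout=10)
persona_llm, persona_fast_llm = get_llms_for_persona(persona=persona)

# Convenience wrapper for call sites that only need the primary model
primary = get_main_llm_from_tuple(get_default_llms())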
View File

@@ -30,7 +30,8 @@ from danswer.llm.answering.models import CitationConfig
from danswer.llm.answering.models import DocumentPruningConfig
from danswer.llm.answering.models import PromptConfig
from danswer.llm.answering.models import QuotesConfig
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_default_llm_token_encode
from danswer.one_shot_answer.models import DirectQARequest
from danswer.one_shot_answer.models import OneShotQAResponse
@@ -156,7 +157,7 @@ def stream_answer_objects(
commit=True,
)
llm = get_llm_for_persona(persona=chat_session.persona)
llm, fast_llm = get_llms_for_persona(persona=chat_session.persona)
prompt_config = PromptConfig.from_model(prompt)
document_pruning_config = DocumentPruningConfig(
max_chunks=int(
@@ -174,6 +175,7 @@ def stream_answer_objects(
retrieval_options=query_req.retrieval_options,
prompt_config=prompt_config,
llm=llm,
fast_llm=fast_llm,
pruning_config=document_pruning_config,
bypass_acl=bypass_acl,
)
@@ -187,7 +189,7 @@ def stream_answer_objects(
question=query_msg.message,
answer_style_config=answer_config,
prompt_config=PromptConfig.from_model(prompt),
llm=get_llm_for_persona(persona=chat_session.persona),
llm=get_main_llm_from_tuple(get_llms_for_persona(persona=chat_session.persona)),
single_message_history=history_str,
tools=[search_tool],
force_use_tool=ForceUseTool(

View File

@@ -56,6 +56,7 @@ class SearchPipeline:
search_request: SearchRequest,
user: User | None,
llm: LLM,
fast_llm: LLM,
db_session: Session,
bypass_acl: bool = False, # NOTE: VERY DANGEROUS, USE WITH CAUTION
retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None]
@@ -65,6 +66,7 @@ class SearchPipeline:
self.search_request = search_request
self.user = user
self.llm = llm
self.fast_llm = fast_llm
self.db_session = db_session
self.bypass_acl = bypass_acl
self.retrieval_metrics_callback = retrieval_metrics_callback
@@ -298,6 +300,7 @@ class SearchPipeline:
self._postprocessing_generator = search_postprocessing(
search_query=self.search_query,
retrieved_chunks=self.retrieved_chunks,
llm=self.fast_llm, # use fast_llm for relevance, since it is a relatively easier task
rerank_metrics_callback=self.rerank_metrics_callback,
)
self._reranked_chunks = cast(

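SearchPipeline now takes both models at construction time and keeps the fast one for postprocessing. A sketch of the new signature in use (the query text is a placeholder and db_session is assumed to be an open SQLAlchemy session):

llm, fast_llm = get_default_llms()
top_chunks = SearchPipeline(
    search_request=SearchRequest(query="example query"),
    user=None,
    llm=llm,
    fast_llm=fast_llm,  # used for LLM relevance filtering, not for answering
    db_session=db_session,
).reranked_chunks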
View File

@@ -9,6 +9,7 @@ from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN
from danswer.document_index.document_index_utils import (
translate_boost_count_to_multiplier,
)
from danswer.llm.interfaces import LLM
from danswer.search.models import ChunkMetric
from danswer.search.models import InferenceChunk
from danswer.search.models import MAX_METRICS_CONTENT
@@ -134,6 +135,7 @@ def rerank_chunks(
def filter_chunks(
query: SearchQuery,
chunks_to_filter: list[InferenceChunk],
llm: LLM,
) -> list[str]:
"""Filters chunks based on whether the LLM thought they were relevant to the query.
@@ -142,6 +144,7 @@ def filter_chunks(
llm_chunk_selection = llm_batch_eval_chunks(
query=query.query,
chunk_contents=[chunk.content for chunk in chunks_to_filter],
llm=llm,
)
return [
chunk.unique_id
@@ -153,6 +156,7 @@ def search_postprocessing(
def search_postprocessing(
search_query: SearchQuery,
retrieved_chunks: list[InferenceChunk],
llm: LLM,
rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None,
) -> Generator[list[InferenceChunk] | list[str], None, None]:
post_processing_tasks: list[FunctionCall] = []
@@ -184,7 +188,11 @@ def search_postprocessing(
post_processing_tasks.append(
FunctionCall(
filter_chunks,
(search_query, retrieved_chunks[: search_query.max_llm_filter_chunks]),
(
search_query,
retrieved_chunks[: search_query.max_llm_filter_chunks],
llm,
),
)
)
llm_filter_task_id = post_processing_tasks[-1].result_id

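With the module-level default gone, the LLM used for relevance filtering is an explicit argument all the way down. Roughly, the pipeline's call into this module now looks like the following sketch (names as in the hunks above):

# search_postprocessing receives the pipeline's fast_llm and forwards it
relevant_chunk_ids = filter_chunks(
    search_query,
    retrieved_chunks[: search_query.max_llm_filter_chunks],
    fast_llm,
)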
View File

@@ -1,5 +1,5 @@
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.answer_validation import ANSWER_VALIDITY_PROMPT
@@ -44,7 +44,7 @@ def get_answer_validity(
return True # If something is wrong, let's not toss away the answer
try:
llm = get_default_llm()
llm, _ = get_default_llms()
except GenAIDisabledException:
return True

View File

@@ -1,7 +1,6 @@
from collections.abc import Callable
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
from danswer.prompts.llm_chunk_filter import CHUNK_FILTER_PROMPT
@@ -12,7 +11,7 @@ from danswer.utils.threadpool_concurrency import run_functions_tuples_in_parallel
logger = setup_logger()
def llm_eval_chunk(query: str, chunk_content: str) -> bool:
def llm_eval_chunk(query: str, chunk_content: str, llm: LLM) -> bool:
def _get_usefulness_messages() -> list[dict[str, str]]:
messages = [
{
@@ -32,14 +31,6 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
return False
return True
# If Gen AI is disabled, none of the messages are more "useful" than any other
# All are marked not useful (False) so that the icon for Gen AI likes this answer
# is not shown for any result
try:
llm = get_default_llm(use_fast_llm=True, timeout=5)
except GenAIDisabledException:
return False
messages = _get_usefulness_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
# When running in a batch, it takes as long as the longest thread
@@ -52,11 +43,12 @@ def llm_eval_chunk(query: str, chunk_content: str) -> bool:
def llm_batch_eval_chunks(
query: str, chunk_contents: list[str], use_threads: bool = True
query: str, chunk_contents: list[str], llm: LLM, use_threads: bool = True
) -> list[bool]:
if use_threads:
functions_with_args: list[tuple[Callable, tuple]] = [
(llm_eval_chunk, (query, chunk_content)) for chunk_content in chunk_contents
(llm_eval_chunk, (query, chunk_content, llm))
for chunk_content in chunk_contents
]
logger.debug(
@@ -71,5 +63,6 @@ def llm_batch_eval_chunks(
else:
return [
llm_eval_chunk(query, chunk_content) for chunk_content in chunk_contents
llm_eval_chunk(query, chunk_content, llm)
for chunk_content in chunk_contents
]

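The chunk-evaluation helpers no longer build their own fast LLM; the caller injects one, which also makes the functions straightforward to exercise in isolation. A sketch (the query and chunk texts are made up):

_, fast_llm = get_default_llms(timeout=5)
useful_flags = llm_batch_eval_chunks(
    query="What changed in fast model handling?",
    chunk_contents=["first chunk text", "second chunk text"],
    llm=fast_llm,
    use_threads=True,  # evaluates the chunks in parallel threads
)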
View File

@@ -6,7 +6,7 @@ from danswer.configs.model_configs import GEN_AI_HISTORY_CUTOFF
from danswer.db.models import ChatMessage
from danswer.llm.answering.models import PreviousMessage
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.interfaces import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_to_string
@@ -33,7 +33,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
return messages
try:
llm = get_default_llm(use_fast_llm=True, timeout=5)
_, fast_llm = get_default_llms(timeout=5)
except GenAIDisabledException:
logger.warning(
"Unable to perform multilingual query expansion, Gen AI disabled"
@@ -42,7 +42,7 @@ def llm_multilingual_query_expansion(query: str, language: str) -> str:
messages = _get_rephrase_messages()
filled_llm_prompt = dict_based_prompt_to_langchain_prompt(messages)
model_output = message_to_string(llm.invoke(filled_llm_prompt))
model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))
logger.debug(model_output)
return model_output
@@ -148,7 +148,7 @@ def thread_based_query_rephrase(
if llm is None:
try:
llm = get_default_llm()
llm, _ = get_default_llms()
except GenAIDisabledException:
# If Generative AI is turned off, just return the original query
return user_query

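The secondary flows all follow the same unpacking convention, keeping only the member of the tuple they need. The two idioms used above, as a sketch (filled_llm_prompt is assumed to be built as in the surrounding code):

# Lightweight task: keep the fast model, discard the primary one
_, fast_llm = get_default_llms(timeout=5)
model_output = message_to_string(fast_llm.invoke(filled_llm_prompt))

# Heavier task: keep the primary model, discard the fast one
llm, _ = get_default_llms()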
View File

@@ -5,7 +5,7 @@ from danswer.chat.models import DanswerAnswerPiece
from danswer.chat.models import StreamingError
from danswer.configs.chat_configs import DISABLE_LLM_QUERY_ANSWERABILITY
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.utils import message_to_string
@@ -52,7 +52,7 @@ def get_query_answerability(
return "Query Answerability Evaluation feature is turned off", True
try:
llm = get_default_llm()
llm, _ = get_default_llms()
except GenAIDisabledException:
return "Generative AI is turned off - skipping check", True
@@ -79,7 +79,7 @@ def stream_query_answerability(
return
try:
llm = get_default_llm()
llm, _ = get_default_llms()
except GenAIDisabledException:
yield get_json_line(
QueryValidationResponse(

View File

@@ -159,11 +159,13 @@ def extract_source_filter(
if __name__ == "__main__":
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple
# Just for testing purposes
with Session(get_sqlalchemy_engine()) as db_session:
while True:
user_input = input("Query to Extract Sources: ")
sources = extract_source_filter(user_input, get_default_llm(), db_session)
sources = extract_source_filter(
user_input, get_main_llm_from_tuple(get_default_llms()), db_session
)
print(sources)

View File

@@ -156,10 +156,12 @@ def extract_time_filter(query: str, llm: LLM) -> tuple[datetime | None, bool]:
if __name__ == "__main__":
# Just for testing purposes, too tedious to unit test as it relies on an LLM
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms, get_main_llm_from_tuple
while True:
user_input = input("Query to Extract Time: ")
cutoff, recency_bias = extract_time_filter(user_input, get_default_llm())
cutoff, recency_bias = extract_time_filter(
user_input, get_main_llm_from_tuple(get_default_llms())
)
print(f"Time Cutoff: {cutoff}")
print(f"Favor Recent: {recency_bias}")

View File

@@ -7,7 +7,7 @@ from pydantic import BaseModel
from sqlalchemy.orm import Session
from danswer.db.engine import get_session
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.search.models import SearchRequest
from danswer.search.pipeline import SearchPipeline
from danswer.server.danswer_api.ingestion import api_key_dep
@@ -67,12 +67,14 @@ def gpt_search(
_: str | None = Depends(api_key_dep),
db_session: Session = Depends(get_session),
) -> GptSearchResponse:
llm, fast_llm = get_default_llms()
top_chunks = SearchPipeline(
search_request=SearchRequest(
query=search_request.query,
),
user=None,
llm=get_default_llm(),
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
).reranked_chunks

View File

@@ -24,7 +24,7 @@ from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs.factory import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.file_store.file_store import get_default_file_store
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.utils import test_llm
from danswer.server.documents.models import ConnectorCredentialPairIdentifier
from danswer.server.manage.models import BoostDoc
@@ -126,7 +126,7 @@ def validate_existing_genai_api_key(
pass
try:
llm = get_default_llm(timeout=10)
llm, __ = get_default_llms(timeout=10)
except ValueError:
raise HTTPException(status_code=404, detail="LLM not setup")

View File

@@ -13,7 +13,7 @@ from danswer.db.llm import remove_llm_provider
from danswer.db.llm import update_default_provider
from danswer.db.llm import upsert_llm_provider
from danswer.db.models import User
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llm
from danswer.llm.llm_provider_options import fetch_available_well_known_llms
from danswer.llm.llm_provider_options import WellKnownLLMProviderDescriptor
@@ -85,8 +85,7 @@ def test_default_provider(
_: User | None = Depends(current_admin_user),
) -> None:
try:
llm = get_default_llm()
fast_llm = get_default_llm(use_fast_llm=True)
llm, fast_llm = get_default_llms()
except ValueError:
logger.exception("Failed to fetch default LLM Provider")
raise HTTPException(status_code=400, detail="No LLM Provider setup")

View File

@@ -43,7 +43,7 @@ from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.exceptions import GenAIDisabledException
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.llm.headers import get_litellm_additional_request_headers
from danswer.llm.utils import get_default_llm_tokenizer
from danswer.secondary_llm_flows.chat_session_naming import (
@@ -224,7 +224,7 @@ def rename_chat_session(
full_history = history_msgs + [final_msg]
try:
llm = get_default_llm(
llm, _ = get_default_llms(
additional_headers=get_litellm_additional_request_headers(request.headers)
)
except GenAIDisabledException:

View File

@@ -69,6 +69,7 @@ class SearchTool(Tool):
retrieval_options: RetrievalDetails | None,
prompt_config: PromptConfig,
llm: LLM,
fast_llm: LLM,
pruning_config: DocumentPruningConfig,
# if specified, will not actually run a search and will instead return these
# sections. Used when the user selects specific docs to talk to
@@ -83,6 +84,7 @@ class SearchTool(Tool):
self.retrieval_options = retrieval_options
self.prompt_config = prompt_config
self.llm = llm
self.fast_llm = fast_llm
self.pruning_config = pruning_config
self.selected_docs = selected_docs
@@ -212,6 +214,7 @@ class SearchTool(Tool):
),
user=self.user,
llm=self.llm,
fast_llm=self.fast_llm,
bypass_acl=self.bypass_acl,
db_session=self.db_session,
)

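SearchTool mirrors the pipeline change: it stores both models and hands fast_llm down when it constructs the SearchPipeline. A construction sketch, with the other required arguments (user, db_session, and so on) elided for brevity:

search_tool = SearchTool(
    retrieval_options=retrieval_options,
    prompt_config=prompt_config,
    llm=llm,
    fast_llm=fast_llm,  # forwarded to SearchPipeline for relevance filtering
    pruning_config=document_pruning_config,
    selected_docs=selected_llm_docs,
)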
View File

@@ -10,8 +10,9 @@ from danswer.db.persona import get_persona_by_id
from danswer.llm.answering.prompts.citations_prompt import (
compute_max_document_tokens_for_persona,
)
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_llm_for_persona
from danswer.llm.factory import get_default_llms
from danswer.llm.factory import get_llms_for_persona
from danswer.llm.factory import get_main_llm_from_tuple
from danswer.llm.utils import get_max_input_tokens
from danswer.one_shot_answer.answer_question import get_search_answer
from danswer.one_shot_answer.models import DirectQARequest
@@ -41,7 +42,7 @@ def handle_search_request(
query = search_request.message
logger.info(f"Received document search query: {query}")
llm = get_default_llm()
llm, fast_llm = get_default_llms()
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
@@ -59,6 +60,7 @@ def handle_search_request(
),
user=user,
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
bypass_acl=False,
)
@@ -104,7 +106,9 @@ def get_answer_with_quote(
is_for_edit=False,
)
llm = get_default_llm() if not persona else get_llm_for_persona(persona)
llm = get_main_llm_from_tuple(
get_default_llms() if not persona else get_llms_for_persona(persona)
)
input_tokens = get_max_input_tokens(
model_name=llm.config.model_name, model_provider=llm.config.model_provider
)

View File

@@ -9,7 +9,7 @@ from sqlalchemy.orm import Session
from danswer.db.engine import get_sqlalchemy_engine
from danswer.llm.answering.doc_pruning import reorder_docs
from danswer.llm.factory import get_default_llm
from danswer.llm.factory import get_default_llms
from danswer.search.models import InferenceChunk
from danswer.search.models import RerankMetricsContainer
from danswer.search.models import RetrievalMetricsContainer
@@ -83,12 +83,14 @@ def get_search_results(
rerank_metrics = MetricsHander[RerankMetricsContainer]()
with Session(get_sqlalchemy_engine()) as db_session:
llm, fast_llm = get_default_llms()
search_pipeline = SearchPipeline(
search_request=SearchRequest(
query=query,
),
user=None,
llm=get_default_llm(),
llm=llm,
fast_llm=fast_llm,
db_session=db_session,
retrieval_metrics_callback=retrieval_metrics.record_metric,
rerank_metrics_callback=rerank_metrics.record_metric,