mirror of https://github.com/danswer-ai/danswer.git
More Cleanup and Deduplication (#675)
parent 9cd0c197e7
commit 73b653d324
@@ -1,4 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator

from sqlalchemy.orm import Session

@@ -6,7 +7,10 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.feedback import create_query_event
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -21,8 +25,11 @@ from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time

logger = setup_logger()

@@ -169,3 +176,114 @@ def answer_qa_query(
        favor_recent=favor_recent,
        error_msg=error_msg,
    )


@log_generator_function_time()
def answer_qa_query_stream(
    question: QuestionRequest,
    user: User | None,
    db_session: Session,
    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Iterator[str]:
    logger.debug(
        f"Received QA query ({question.search_type.value} search): {question.query}"
    )
    logger.debug(f"Query filters: {question.filters}")

    answer_so_far: str = ""
    query = question.query
    offset_count = question.offset if question.offset is not None else 0

    time_cutoff, favor_recent = extract_question_time_filters(question)
    question.filters.time_cutoff = time_cutoff
    question.favor_recent = favor_recent

    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
        question=question,
        user=user,
        db_session=db_session,
        document_index=get_default_document_index(),
    )

    # TODO retire this
    predicted_search, predicted_flow = query_intent(query)

    top_docs = chunks_to_search_docs(ranked_chunks)
    unranked_top_docs = chunks_to_search_docs(unranked_chunks)

    initial_response = RerankedRetrievalDocs(
        top_documents=top_docs,
        unranked_top_documents=unranked_top_docs,
        # if generative AI is disabled, set flow as search so frontend
        # doesn't ask the user if they want to run QA over more documents
        predicted_flow=QueryFlow.SEARCH
        if disable_generative_answer
        else predicted_flow,
        predicted_search=predicted_search,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
    ).dict()

    logger.debug(f"Sending Initial Retrival Results: {initial_response}")
    yield get_json_line(initial_response)

    if not ranked_chunks:
        logger.debug("No Documents Found")
        return

    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return

    try:
        qa_model = get_default_qa_model()
    except Exception as e:
        logger.exception("Unable to get QA model")
        error = StreamingError(error=str(e))
        yield get_json_line(error.dict())
        return

    # remove chunks marked as not applicable for QA (e.g. Google Drive file
    # types which can't be parsed). These chunks are useful to show in the
    # search results, but not for QA.
    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # get all chunks that fit into the token limit
    usable_chunks = get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
        offset=offset_count,
    )
    logger.debug(
        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
    )

    try:
        for response_packet in qa_model.answer_question_stream(query, usable_chunks):
            if response_packet is None:
                continue
            if (
                isinstance(response_packet, DanswerAnswerPiece)
                and response_packet.answer_piece
            ):
                answer_so_far = answer_so_far + response_packet.answer_piece
            logger.debug(f"Sending packet: {response_packet}")
            yield get_json_line(response_packet.dict())
    except Exception:
        # exception is logged in the answer_question method, no need to re-log
        logger.exception("Failed to run QA")
        error = StreamingError(error="The LLM failed to produce a useable response")
        yield get_json_line(error.dict())

    query_event_id = create_query_event(
        query=query,
        search_type=question.search_type,
        llm_answer=answer_so_far,
        retrieved_document_ids=[doc.document_id for doc in top_docs],
        user_id=None if user is None else user.id,
        db_session=db_session,
    )

    yield get_json_line({"query_event_id": query_event_id})
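The new answer_qa_query_stream above yields newline-delimited JSON packets: the reranked retrieval results first, then answer pieces, then the query_event_id. A minimal sketch of consuming that protocol, assuming get_json_line emits one newline-terminated JSON object per packet (collect_packets is illustrative, not part of this commit):

import json


def collect_packets(packet_stream):
    # Each yielded item is a single JSON document ending in a newline,
    # so the stream parses line by line.
    return [json.loads(raw) for raw in packet_stream if raw.strip()]


# Hypothetical usage in a test: the first packet holds the retrieval
# results, the last one the query_event_id used for feedback.
# packets = collect_packets(
#     answer_qa_query_stream(question=question, user=None, db_session=db_session)
# )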
@@ -1,5 +1,3 @@
from openai.error import AuthenticationError

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_block import QABlock
@@ -12,25 +10,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()


def check_model_api_key_is_valid(model_api_key: str) -> bool:
    if not model_api_key:
        return False

    llm = get_default_llm(api_key=model_api_key, timeout=10)

    # try for up to 2 timeouts (e.g. 10 seconds in total)
    for _ in range(2):
        try:
            llm.invoke("Do not respond")
            return True
        except AuthenticationError:
            return False
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")

    return False


# TODO introduce the prompt choice parameter
def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
    return (
@@ -4,7 +4,6 @@ import re
from collections.abc import Generator
from collections.abc import Iterator
from json.decoder import JSONDecodeError
from typing import cast
from typing import Optional
from typing import Tuple

@@ -12,8 +11,6 @@ import regex

from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
@@ -21,8 +18,6 @@ from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger
@@ -33,17 +28,6 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()


def get_gen_ai_api_key() -> str | None:
    # first check if the key has been provided by the UI
    try:
        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
    except ConfigNotFoundError:
        pass

    # if not provided by the UI, fallback to the env variable
    return GEN_AI_API_KEY


def extract_answer_quotes_freeform(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -1,22 +1,77 @@
import abc
from collections.abc import Iterator

import litellm # type:ignore
from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput

from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.utils import should_be_verbose
from danswer.utils.logger import setup_logger


logger = setup_logger()

# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False


class LangChainChatLLM(LLM, abc.ABC):
    @property
    @abc.abstractmethod
    def llm(self) -> BaseChatModel:
        raise NotImplementedError

    def _log_model_config(self) -> None:
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    @staticmethod
    def _log_prompt(prompt: LanguageModelInput) -> None:
        if isinstance(prompt, list):
            for ind, msg in enumerate(prompt):
                logger.debug(f"Message {ind}:\n{msg.content}")
        if isinstance(prompt, str):
            logger.debug(f"Prompt:\n{prompt}")

    def invoke(self, prompt: LanguageModelInput) -> str:
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)

        model_raw = self.llm.invoke(prompt).content
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{model_raw}")

        return model_raw

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)

        output_tokens = []
        for token in message_generator_to_string_generator(self.llm.stream(prompt)):
            output_tokens.append(token)
            yield token

        full_output = "".join(output_tokens)
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{full_output}")


def _get_model_str(
    model_provider: str | None,
    model_version: str | None,
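The LangChainChatLLM base class added above centralizes prompt/output logging and the invoke/stream plumbing, so a concrete model only has to supply the llm property. A hedged sketch of a subclass (ExampleLiteLLM and its model argument are illustrative, not part of this commit):

class ExampleLiteLLM(LangChainChatLLM):
    # Illustrative subclass: invoke, stream, and all logging behavior are
    # inherited; only the underlying LangChain chat model is provided here.
    def __init__(self, model: str = "gpt-3.5-turbo") -> None:
        self._llm = ChatLiteLLM(model=model)  # routed through litellm

    @property
    def llm(self) -> BaseChatModel:
        return self._llm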
@@ -1,10 +1,10 @@
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.multi_llm import DefaultMultiLLM
from danswer.llm.utils import get_gen_ai_api_key


def get_default_llm(
@@ -1,11 +1,8 @@
import abc
from collections.abc import Iterator

from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput

from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.llm.utils import message_generator_to_string_generator
from danswer.utils.logger import setup_logger


@@ -28,48 +25,3 @@ class LLM(abc.ABC):
    @abc.abstractmethod
    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        raise NotImplementedError


class LangChainChatLLM(LLM, abc.ABC):
    @property
    @abc.abstractmethod
    def llm(self) -> BaseChatModel:
        raise NotImplementedError

    def _log_model_config(self) -> None:
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    @staticmethod
    def _log_prompt(prompt: LanguageModelInput) -> None:
        if isinstance(prompt, list):
            for ind, msg in enumerate(prompt):
                logger.debug(f"Message {ind}:\n{msg.content}")
        if isinstance(prompt, str):
            logger.debug(f"Prompt:\n{prompt}")

    def invoke(self, prompt: LanguageModelInput) -> str:
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)

        model_raw = self.llm.invoke(prompt).content
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{model_raw}")

        return model_raw

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)

        output_tokens = []
        for token in message_generator_to_string_generator(self.llm.stream(prompt)):
            output_tokens.append(token)
            yield token

        full_output = "".join(output_tokens)
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{full_output}")
@@ -1,6 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast

import tiktoken
from langchain.prompts.base import StringPromptValue
@@ -14,9 +15,16 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage

from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.db.models import ChatMessage
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.interfaces import LLM
from danswer.utils.logger import setup_logger

logger = setup_logger()

_LLM_TOKENIZER: Callable[[str], Any] | None = None

@@ -107,7 +115,7 @@ def should_be_verbose() -> bool:
def check_number_of_tokens(
    text: str, encode_fn: Callable[[str], list] | None = None
) -> int:
    """Get's the number of tokens in the provided text, using the provided encoding
    """Gets the number of tokens in the provided text, using the provided encoding
    function. If none is provided, default to the tiktoken encoder used by GPT-3.5
    and GPT-4.
    """
@@ -116,3 +124,26 @@ def check_number_of_tokens(
        encode_fn = tiktoken.get_encoding("cl100k_base").encode

    return len(encode_fn(text))


def get_gen_ai_api_key() -> str | None:
    # first check if the key has been provided by the UI
    try:
        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
    except ConfigNotFoundError:
        pass

    # if not provided by the UI, fallback to the env variable
    return GEN_AI_API_KEY


def test_llm(llm: LLM) -> bool:
    # try for up to 2 timeouts (e.g. 10 seconds in total)
    for _ in range(2):
        try:
            llm.invoke("Do not respond")
            return True
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")

    return False
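get_gen_ai_api_key and test_llm now live in danswer.llm.utils, so the QA flow and the admin endpoints share one implementation. A sketch of the resulting validation pattern, mirroring how the manage endpoints below use it:

from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm

# UI-stored key takes precedence, falling back to the env var
api_key = get_gen_ai_api_key()
llm = get_default_llm(api_key=api_key, timeout=10)
if not test_llm(llm):
    raise ValueError("GenAI API key failed validation")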
@@ -23,12 +23,13 @@ from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.feedback import update_document_boost
from danswer.db.feedback import update_document_hidden
from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm
from danswer.server.models import ApiKey
from danswer.server.models import BoostDoc
from danswer.server.models import BoostUpdateRequest
@@ -132,7 +133,8 @@ def validate_existing_genai_api_key(
        raise HTTPException(status_code=404, detail="Key not found")

    try:
        is_valid = check_model_api_key_is_valid(genai_api_key)
        llm = get_default_llm(api_key=genai_api_key, timeout=10)
        is_valid = test_llm(llm)
    except ValueError:
        # this is the case where they aren't using an OpenAI-based model
        is_valid = True
@@ -168,9 +170,15 @@ def store_genai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    try:
        is_valid = check_model_api_key_is_valid(request.api_key)
        if not request.api_key:
            raise HTTPException(400, "No API key provided")

        llm = get_default_llm(api_key=request.api_key, timeout=10)
        is_valid = test_llm(llm)

        if not is_valid:
            raise HTTPException(400, "Invalid API key provided")

        get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
    except RuntimeError as e:
        raise HTTPException(400, str(e))
@@ -1,5 +1,3 @@
from collections.abc import Generator

from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -9,30 +7,19 @@ from sqlalchemy.orm import Session

from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.engine import get_session
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_feedback
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.direct_qa.answer_question import answer_qa_query_stream
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchQuery
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import search_chunks
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
@@ -41,13 +28,10 @@ from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time

logger = setup_logger()

@@ -190,144 +174,10 @@ def stream_direct_qa(
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    send_packet_debug_msg = "Sending Packet: {}"
    top_documents_key = "top_documents"
    unranked_top_docs_key = "unranked_top_documents"
    predicted_flow_key = "predicted_flow"
    predicted_search_key = "predicted_search"
    query_event_id_key = "query_event_id"

    logger.debug(
        f"Received QA query ({question.search_type.value} search): {question.query}"
    packets = answer_qa_query_stream(
        question=question, user=user, db_session=db_session
    )
    logger.debug(f"Query filters: {question.filters}")

    @log_generator_function_time()
    def stream_qa_portions(
        disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
    ) -> Generator[str, None, None]:
        answer_so_far: str = ""
        query = question.query
        offset_count = question.offset if question.offset is not None else 0

        time_cutoff, favor_recent = extract_question_time_filters(question)
        question.filters.time_cutoff = time_cutoff  # not used but just in case
        filters = question.filters

        user_acl_filters = build_access_filters_for_user(user, db_session)
        final_filters = IndexFilters(
            source_type=filters.source_type,
            document_set=filters.document_set,
            time_cutoff=time_cutoff,
            access_control_list=user_acl_filters,
        )

        search_query = SearchQuery(
            query=query,
            search_type=question.search_type,
            filters=final_filters,
            favor_recent=favor_recent,
        )

        ranked_chunks, unranked_chunks = search_chunks(
            query=search_query, document_index=get_default_document_index()
        )

        # TODO retire this
        predicted_search, predicted_flow = query_intent(query)

        if not ranked_chunks:
            logger.debug("No Documents Found")
            empty_docs_result = {
                top_documents_key: None,
                unranked_top_docs_key: None,
                predicted_flow_key: predicted_flow,
                predicted_search_key: predicted_search,
            }
            logger.debug(send_packet_debug_msg.format(empty_docs_result))
            yield get_json_line(empty_docs_result)
            return

        top_docs = chunks_to_search_docs(ranked_chunks)
        unranked_top_docs = chunks_to_search_docs(unranked_chunks)
        initial_response = RerankedRetrievalDocs(
            top_documents=top_docs,
            unranked_top_documents=unranked_top_docs,
            # if generative AI is disabled, set flow as search so frontend
            # doesn't ask the user if they want to run QA over more documents
            predicted_flow=QueryFlow.SEARCH
            if disable_generative_answer
            else predicted_flow,
            predicted_search=predicted_search,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
        ).dict()

        logger.debug(send_packet_debug_msg.format(initial_response))
        yield get_json_line(initial_response)

        if disable_generative_answer:
            logger.debug("Skipping QA because generative AI is disabled")
            return

        try:
            qa_model = get_default_qa_model()
        except Exception as e:
            logger.exception("Unable to get QA model")
            error = StreamingError(error=str(e))
            yield get_json_line(error.dict())
            return

        # remove chunks marked as not applicable for QA (e.g. Google Drive file
        # types which can't be parsed). These chunks are useful to show in the
        # search results, but not for QA.
        filtered_ranked_chunks = [
            chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
        ]

        # get all chunks that fit into the token limit
        usable_chunks = get_usable_chunks(
            chunks=filtered_ranked_chunks,
            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
            offset=offset_count,
        )
        logger.debug(
            f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
        )

        try:
            for response_packet in qa_model.answer_question_stream(
                query, usable_chunks
            ):
                if response_packet is None:
                    continue
                if (
                    isinstance(response_packet, DanswerAnswerPiece)
                    and response_packet.answer_piece
                ):
                    answer_so_far = answer_so_far + response_packet.answer_piece
                logger.debug(f"Sending packet: {response_packet}")
                yield get_json_line(response_packet.dict())
        except Exception:
            # exception is logged in the answer_question method, no need to re-log
            logger.exception("Failed to run QA")
            yield get_json_line(
                {"error": "The LLM failed to produce a useable response"}
            )

        query_event_id = create_query_event(
            query=query,
            search_type=question.search_type,
            llm_answer=answer_so_far,
            retrieved_document_ids=[doc.document_id for doc in top_docs],
            user_id=None if user is None else user.id,
            db_session=db_session,
        )

        yield get_json_line({query_event_id_key: query_event_id})
        return

    return StreamingResponse(stream_qa_portions(), media_type="application/json")
    return StreamingResponse(packets, media_type="application/json")


@router.post("/query-feedback")
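With the inline generator replaced by answer_qa_query_stream, stream_direct_qa still returns a newline-delimited JSON StreamingResponse. A client-side sketch (the URL and request body here are assumptions for illustration, not taken from this diff):

import json

import requests

with requests.post(
    "http://localhost:8080/stream-direct-qa",  # assumed route, for illustration
    json={"query": "What is Danswer?"},  # shape depends on QuestionRequest
    stream=True,
) as response:
    for line in response.iter_lines():
        if line:
            # one packet per line: retrieval docs, answer pieces, query_event_id
            print(json.loads(line))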