from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.document_index import get_default_document_index
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time

logger = setup_logger()


@log_function_time()
def answer_question(
    question: QuestionRequest,
    user: User | None,
    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
    answer_generation_timeout: int = QA_TIMEOUT,
) -> QAResponse:
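    """Answer a user question over the indexed documents.

    Predicts the query intent, retrieves chunks via keyword or semantic
    search, and, unless generative AI is disabled, asks the configured QA
    model to answer from the top NUM_GENERATIVE_AI_INPUT_DOCS ranked chunks.

    Usage sketch (QuestionRequest field names are assumed from the
    attributes read below):

        request = QuestionRequest(
            query="How do I configure a connector?",
            collection="danswer_index",  # assumed default collection name
            filters=None,
            use_keyword=None,  # None -> use the predicted search type
            offset=None,
        )
        response = answer_question(request, user=None)
    """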
    query = question.query
    collection = question.collection
    filters = question.filters
    use_keyword = question.use_keyword
    offset_count = question.offset if question.offset is not None else 0
    logger.info(f"Received QA query: {query}")
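
    # Classify the query up front: the predicted search type chooses between
    # keyword and semantic retrieval, and the predicted flow tells the
    # frontend whether to treat this as plain search or full QA.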
    predicted_search, predicted_flow = query_intent(query)
    if use_keyword is None:
        use_keyword = predicted_search == SearchType.KEYWORD
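
    # Keyword retrieval yields a single ranked list; semantic retrieval also
    # yields a lower-ranked remainder that is surfaced as additional docs.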
    user_id = None if user is None else user.id
    if use_keyword:
        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
            query, user_id, filters, get_default_document_index(collection=collection)
        )
        unranked_chunks: list[InferenceChunk] | None = []
    else:
        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
            query, user_id, filters, get_default_document_index(collection=collection)
        )
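
    # Nothing matched the query/filters: return an empty response rather than
    # calling the LLM with no context.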
    if not ranked_chunks:
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=None,
            lower_ranked_docs=None,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
        )

    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            # set flow as search so frontend doesn't ask the user if they want
            # to run QA over more documents
            predicted_flow=QueryFlow.SEARCH,
            predicted_search=predicted_search,
        )
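
    # Building the QA model can fail (unknown model name, missing OpenAI key);
    # in that case fall back to search-only results and report the error.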
    try:
        qa_model = get_default_qa_model(timeout=answer_generation_timeout)
    except (UnknownModelError, OpenAIKeyMissing) as e:
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            error_msg=str(e),
        )
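
    # Each QA attempt consumes a window of NUM_GENERATIVE_AI_INPUT_DOCS chunks;
    # the request's offset pages further down the ranking on retries.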
    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
    if chunk_offset >= len(ranked_chunks):
        raise ValueError("Chunks offset too large, should not retry this many times")

    error_msg = None
    try:
        answer, quotes = qa_model.answer_question(
            query,
            ranked_chunks[chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS],
        )
    except Exception as e:
        # exception is logged in the answer_question method, no need to re-log
        answer, quotes = None, None
        error_msg = f"Error occurred in call to LLM - {e}"

    return QAResponse(
        answer=answer.answer if answer else None,
        quotes=quotes.quotes if quotes else None,
        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        error_msg=error_msg,
    )