More Cleanup and Deduplication (#675)

This commit is contained in:
Yuhong Sun
2023-11-01 16:03:48 -07:00
committed by GitHub
parent 9cd0c197e7
commit 73b653d324
9 changed files with 224 additions and 247 deletions

View File

@@ -1,4 +1,5 @@
from collections.abc import Callable from collections.abc import Callable
from collections.abc import Iterator
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
@@ -6,7 +7,10 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.feedback import create_query_event
from danswer.db.models import User from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -21,8 +25,11 @@ from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.server.models import QAResponse from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
logger = setup_logger() logger = setup_logger()
@@ -169,3 +176,114 @@ def answer_qa_query(
favor_recent=favor_recent, favor_recent=favor_recent,
error_msg=error_msg, error_msg=error_msg,
) )
@log_generator_function_time()
def answer_qa_query_stream(
    question: QuestionRequest,
    user: User | None,
    db_session: Session,
    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Iterator[str]:
    """Run the full QA flow and stream results as newline-delimited JSON strings.

    Yields, in order:
      1. the reranked retrieval results (RerankedRetrievalDocs, serialized),
      2. LLM answer pieces / streaming errors as they are produced,
      3. a final packet containing the query_event_id for feedback tracking.

    Args:
        question: the user query plus search type, filters, and offset.
        user: the requesting user, or None for anonymous access.
        db_session: SQLAlchemy session used for search + query-event logging.
        disable_generative_answer: when True, only retrieval results are
            streamed and the LLM step is skipped.
    """
    logger.debug(
        f"Received QA query ({question.search_type.value} search): {question.query}"
    )
    logger.debug(f"Query filters: {question.filters}")

    answer_so_far: str = ""
    query = question.query
    offset_count = question.offset if question.offset is not None else 0

    # derive time-based filters from the question text and attach them to the
    # question before running the search
    time_cutoff, favor_recent = extract_question_time_filters(question)
    question.filters.time_cutoff = time_cutoff
    question.favor_recent = favor_recent

    # NOTE(review): danswer_search already returns a query_event_id, but it is
    # shadowed by the create_query_event call at the end of this generator —
    # confirm whether creating a second query event is intended
    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
        question=question,
        user=user,
        db_session=db_session,
        document_index=get_default_document_index(),
    )

    # TODO retire this
    predicted_search, predicted_flow = query_intent(query)

    top_docs = chunks_to_search_docs(ranked_chunks)
    unranked_top_docs = chunks_to_search_docs(unranked_chunks)

    initial_response = RerankedRetrievalDocs(
        top_documents=top_docs,
        unranked_top_documents=unranked_top_docs,
        # if generative AI is disabled, set flow as search so frontend
        # doesn't ask the user if they want to run QA over more documents
        predicted_flow=QueryFlow.SEARCH
        if disable_generative_answer
        else predicted_flow,
        predicted_search=predicted_search,
        time_cutoff=time_cutoff,
        favor_recent=favor_recent,
    ).dict()
    # fix: log message previously misspelled "Retrival"
    logger.debug(f"Sending Initial Retrieval Results: {initial_response}")
    yield get_json_line(initial_response)

    if not ranked_chunks:
        logger.debug("No Documents Found")
        return

    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return

    try:
        qa_model = get_default_qa_model()
    except Exception as e:
        logger.exception("Unable to get QA model")
        error = StreamingError(error=str(e))
        yield get_json_line(error.dict())
        return

    # remove chunks marked as not applicable for QA (e.g. Google Drive file
    # types which can't be parsed). These chunks are useful to show in the
    # search results, but not for QA.
    filtered_ranked_chunks = [
        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
    ]

    # get all chunks that fit into the token limit
    usable_chunks = get_usable_chunks(
        chunks=filtered_ranked_chunks,
        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
        offset=offset_count,
    )
    logger.debug(
        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
    )

    try:
        for response_packet in qa_model.answer_question_stream(query, usable_chunks):
            if response_packet is None:
                continue
            if (
                isinstance(response_packet, DanswerAnswerPiece)
                and response_packet.answer_piece
            ):
                # accumulate the streamed answer so it can be persisted below
                answer_so_far = answer_so_far + response_packet.answer_piece
            logger.debug(f"Sending packet: {response_packet}")
            yield get_json_line(response_packet.dict())
    except Exception:
        # exception is logged in the answer_question method, no need to re-log
        logger.exception("Failed to run QA")
        error = StreamingError(error="The LLM failed to produce a useable response")
        yield get_json_line(error.dict())

    # persist the query event (with whatever answer was accumulated) so the
    # frontend can attach feedback to it
    query_event_id = create_query_event(
        query=query,
        search_type=question.search_type,
        llm_answer=answer_so_far,
        retrieved_document_ids=[doc.document_id for doc in top_docs],
        user_id=None if user is None else user.id,
        db_session=db_session,
    )
    yield get_json_line({"query_event_id": query_event_id})

View File

@@ -1,5 +1,3 @@
from openai.error import AuthenticationError
from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.app_configs import QA_TIMEOUT
from danswer.direct_qa.interfaces import QAModel from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_block import QABlock from danswer.direct_qa.qa_block import QABlock
@@ -12,25 +10,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger() logger = setup_logger()
def check_model_api_key_is_valid(model_api_key: str) -> bool:
    """Return True iff `model_api_key` can successfully drive the default LLM.

    An empty key is rejected immediately; otherwise the default LLM is probed
    with a trivial prompt. Auth failures return False right away; other errors
    are logged and retried once.
    """
    if not model_api_key:
        return False

    llm = get_default_llm(api_key=model_api_key, timeout=10)

    # try for up to 2 timeouts (e.g. 10 seconds in total)
    attempts_remaining = 2
    while attempts_remaining:
        attempts_remaining -= 1
        try:
            llm.invoke("Do not respond")
        except AuthenticationError:
            # definitively bad key — no point retrying
            return False
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")
        else:
            return True

    return False
# TODO introduce the prompt choice parameter # TODO introduce the prompt choice parameter
def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler: def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
return ( return (

View File

@@ -4,7 +4,6 @@ import re
from collections.abc import Generator from collections.abc import Generator
from collections.abc import Iterator from collections.abc import Iterator
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from typing import cast
from typing import Optional from typing import Optional
from typing import Tuple from typing import Tuple
@@ -12,8 +11,6 @@ import regex
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote from danswer.direct_qa.interfaces import DanswerQuote
@@ -21,8 +18,6 @@ from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.qa_prompts import ANSWER_PAT from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@@ -33,17 +28,6 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger() logger = setup_logger()
def get_gen_ai_api_key() -> str | None:
    """Fetch the GenAI API key, preferring a UI-provided value over the env var."""
    store = get_dynamic_config_store()
    try:
        ui_provided_key = store.load(GEN_AI_API_KEY_STORAGE_KEY)
    except ConfigNotFoundError:
        # nothing was set via the UI — fall back to the environment variable
        return GEN_AI_API_KEY
    return cast(str, ui_provided_key)
def extract_answer_quotes_freeform( def extract_answer_quotes_freeform(
answer_raw: str, answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]: ) -> Tuple[Optional[str], Optional[list[str]]]:

View File

@@ -1,22 +1,77 @@
import abc
from collections.abc import Iterator
import litellm # type:ignore import litellm # type:ignore
from langchain.chat_models import ChatLiteLLM from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.utils import should_be_verbose from danswer.llm.utils import should_be_verbose
from danswer.utils.logger import setup_logger
logger = setup_logger()
# If a user configures a different model and it doesn't support all the same # If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them # parameters like frequency and presence, just ignore them
litellm.drop_params = True litellm.drop_params = True
litellm.telemetry = False litellm.telemetry = False
class LangChainChatLLM(LLM, abc.ABC):
    """Adapter exposing a LangChain chat model through the Danswer LLM interface.

    Subclasses only need to supply the underlying chat model via the `llm`
    property; `invoke` and `stream` add optional debug logging of the model
    config, prompts, and outputs (gated by LOG_ALL_MODEL_INTERACTIONS).
    """

    @property
    @abc.abstractmethod
    def llm(self) -> BaseChatModel:
        # concrete subclasses must return their configured LangChain chat model
        raise NotImplementedError

    def _log_model_config(self) -> None:
        # dump the model class name and full config dict for debugging
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    @staticmethod
    def _log_prompt(prompt: LanguageModelInput) -> None:
        # a prompt may be a list of chat messages or a plain string
        if isinstance(prompt, list):
            for ind, msg in enumerate(prompt):
                logger.debug(f"Message {ind}:\n{msg.content}")
        if isinstance(prompt, str):
            logger.debug(f"Prompt:\n{prompt}")

    def invoke(self, prompt: LanguageModelInput) -> str:
        """Blocking call: return the model's full response as a string."""
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)
        model_raw = self.llm.invoke(prompt).content
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{model_raw}")
        return model_raw

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        """Streaming call: yield response tokens as they arrive."""
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)
        # collect tokens while yielding so the full output can be logged at the end
        output_tokens = []
        for token in message_generator_to_string_generator(self.llm.stream(prompt)):
            output_tokens.append(token)
            yield token
        full_output = "".join(output_tokens)
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{full_output}")
def _get_model_str( def _get_model_str(
model_provider: str | None, model_provider: str | None,
model_version: str | None, model_version: str | None,

View File

@@ -1,10 +1,10 @@
from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.direct_qa.qa_utils import get_gen_ai_api_key from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.gpt_4_all import DanswerGPT4All from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM from danswer.llm.interfaces import LLM
from danswer.llm.multi_llm import DefaultMultiLLM from danswer.llm.utils import get_gen_ai_api_key
def get_default_llm( def get_default_llm(

View File

@@ -1,11 +1,8 @@
import abc import abc
from collections.abc import Iterator from collections.abc import Iterator
from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput from langchain.schema.language_model import LanguageModelInput
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.llm.utils import message_generator_to_string_generator
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
@@ -28,48 +25,3 @@ class LLM(abc.ABC):
@abc.abstractmethod @abc.abstractmethod
def stream(self, prompt: LanguageModelInput) -> Iterator[str]: def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
raise NotImplementedError raise NotImplementedError
class LangChainChatLLM(LLM, abc.ABC):
    """Base LLM implementation backed by a LangChain chat model.

    Implementors provide the chat model via the abstract `llm` property;
    this class supplies `invoke`/`stream` with debug logging controlled by
    LOG_ALL_MODEL_INTERACTIONS.
    """

    @property
    @abc.abstractmethod
    def llm(self) -> BaseChatModel:
        # subclasses return the underlying LangChain chat model here
        raise NotImplementedError

    def _log_model_config(self) -> None:
        # log model class + config for troubleshooting
        logger.debug(
            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
        )

    @staticmethod
    def _log_prompt(prompt: LanguageModelInput) -> None:
        # handle both message-list and plain-string prompts
        if isinstance(prompt, list):
            for ind, msg in enumerate(prompt):
                logger.debug(f"Message {ind}:\n{msg.content}")
        if isinstance(prompt, str):
            logger.debug(f"Prompt:\n{prompt}")

    def invoke(self, prompt: LanguageModelInput) -> str:
        """Return the complete model response for `prompt` (blocking)."""
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)
        model_raw = self.llm.invoke(prompt).content
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{model_raw}")
        return model_raw

    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
        """Yield response tokens as the model produces them."""
        self._log_model_config()
        if LOG_ALL_MODEL_INTERACTIONS:
            self._log_prompt(prompt)
        # accumulate tokens so the complete output can be logged afterwards
        output_tokens = []
        for token in message_generator_to_string_generator(self.llm.stream(prompt)):
            output_tokens.append(token)
            yield token
        full_output = "".join(output_tokens)
        if LOG_ALL_MODEL_INTERACTIONS:
            logger.debug(f"Raw Model Output:\n{full_output}")

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable from collections.abc import Callable
from collections.abc import Iterator from collections.abc import Iterator
from typing import Any from typing import Any
from typing import cast
import tiktoken import tiktoken
from langchain.prompts.base import StringPromptValue from langchain.prompts.base import StringPromptValue
@@ -14,9 +15,16 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage from langchain.schema.messages import SystemMessage
from danswer.configs.app_configs import LOG_LEVEL from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.interfaces import LLM
from danswer.utils.logger import setup_logger
logger = setup_logger()
_LLM_TOKENIZER: Callable[[str], Any] | None = None _LLM_TOKENIZER: Callable[[str], Any] | None = None
@@ -107,7 +115,7 @@ def should_be_verbose() -> bool:
def check_number_of_tokens( def check_number_of_tokens(
text: str, encode_fn: Callable[[str], list] | None = None text: str, encode_fn: Callable[[str], list] | None = None
) -> int: ) -> int:
"""Get's the number of tokens in the provided text, using the provided encoding """Gets the number of tokens in the provided text, using the provided encoding
function. If none is provided, default to the tiktoken encoder used by GPT-3.5 function. If none is provided, default to the tiktoken encoder used by GPT-3.5
and GPT-4. and GPT-4.
""" """
@@ -116,3 +124,26 @@ def check_number_of_tokens(
encode_fn = tiktoken.get_encoding("cl100k_base").encode encode_fn = tiktoken.get_encoding("cl100k_base").encode
return len(encode_fn(text)) return len(encode_fn(text))
def get_gen_ai_api_key() -> str | None:
    """Resolve the GenAI API key: UI-configured value first, env variable second."""
    try:
        # a key provided through the admin UI takes precedence
        stored_key = get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY)
        return cast(str, stored_key)
    except ConfigNotFoundError:
        # not provided by the UI — fall back to the env variable
        return GEN_AI_API_KEY
def test_llm(llm: LLM) -> bool:
    """Probe `llm` with a trivial prompt; return True if any attempt succeeds."""
    # try for up to 2 timeouts (e.g. 10 seconds in total)
    attempt = 0
    while attempt < 2:
        attempt += 1
        try:
            llm.invoke("Do not respond")
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")
        else:
            return True

    return False

View File

@@ -23,12 +23,13 @@ from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.feedback import update_document_boost from danswer.db.feedback import update_document_boost
from danswer.db.feedback import update_document_hidden from danswer.db.feedback import update_document_hidden
from danswer.db.models import User from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm
from danswer.server.models import ApiKey from danswer.server.models import ApiKey
from danswer.server.models import BoostDoc from danswer.server.models import BoostDoc
from danswer.server.models import BoostUpdateRequest from danswer.server.models import BoostUpdateRequest
@@ -132,7 +133,8 @@ def validate_existing_genai_api_key(
raise HTTPException(status_code=404, detail="Key not found") raise HTTPException(status_code=404, detail="Key not found")
try: try:
is_valid = check_model_api_key_is_valid(genai_api_key) llm = get_default_llm(api_key=genai_api_key, timeout=10)
is_valid = test_llm(llm)
except ValueError: except ValueError:
# this is the case where they aren't using an OpenAI-based model # this is the case where they aren't using an OpenAI-based model
is_valid = True is_valid = True
@@ -168,9 +170,15 @@ def store_genai_api_key(
_: User = Depends(current_admin_user), _: User = Depends(current_admin_user),
) -> None: ) -> None:
try: try:
is_valid = check_model_api_key_is_valid(request.api_key) if not request.api_key:
raise HTTPException(400, "No API key provided")
llm = get_default_llm(api_key=request.api_key, timeout=10)
is_valid = test_llm(llm)
if not is_valid: if not is_valid:
raise HTTPException(400, "Invalid API key provided") raise HTTPException(400, "Invalid API key provided")
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key) get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
except RuntimeError as e: except RuntimeError as e:
raise HTTPException(400, str(e)) raise HTTPException(400, str(e))

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from fastapi import APIRouter from fastapi import APIRouter
from fastapi import Depends from fastapi import Depends
from fastapi import HTTPException from fastapi import HTTPException
@@ -9,30 +7,19 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_feedback from danswer.db.feedback import update_query_event_feedback
from danswer.db.models import User from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.answer_question import answer_qa_query_stream
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.document_index.factory import get_default_document_index from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex from danswer.document_index.vespa.index import VespaIndex
from danswer.search.access_filters import build_access_filters_for_user from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchQuery
from danswer.search.search_runner import chunks_to_search_docs from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import search_chunks
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.secondary_llm_flows.query_validation import get_query_answerability from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability from danswer.secondary_llm_flows.query_validation import stream_query_answerability
@@ -41,13 +28,10 @@ from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.models import SearchDoc from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger() logger = setup_logger()
@@ -190,144 +174,10 @@ def stream_direct_qa(
user: User | None = Depends(current_user), user: User | None = Depends(current_user),
db_session: Session = Depends(get_session), db_session: Session = Depends(get_session),
) -> StreamingResponse: ) -> StreamingResponse:
send_packet_debug_msg = "Sending Packet: {}" packets = answer_qa_query_stream(
top_documents_key = "top_documents" question=question, user=user, db_session=db_session
unranked_top_docs_key = "unranked_top_documents"
predicted_flow_key = "predicted_flow"
predicted_search_key = "predicted_search"
query_event_id_key = "query_event_id"
logger.debug(
f"Received QA query ({question.search_type.value} search): {question.query}"
) )
logger.debug(f"Query filters: {question.filters}") return StreamingResponse(packets, media_type="application/json")
    @log_generator_function_time()
    def stream_qa_portions(
        disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
    ) -> Generator[str, None, None]:
        """Closure over the enclosing request handler's locals (`question`,
        `user`, `db_session`, and the packet-key constants).

        Streams: retrieval results, then LLM answer pieces, then the
        query_event_id — each as a JSON line.
        """
        answer_so_far: str = ""
        query = question.query
        offset_count = question.offset if question.offset is not None else 0

        # derive time filters from the question text
        time_cutoff, favor_recent = extract_question_time_filters(question)
        question.filters.time_cutoff = time_cutoff  # not used but just in case

        filters = question.filters
        # restrict results to documents this user is allowed to see
        user_acl_filters = build_access_filters_for_user(user, db_session)
        final_filters = IndexFilters(
            source_type=filters.source_type,
            document_set=filters.document_set,
            time_cutoff=time_cutoff,
            access_control_list=user_acl_filters,
        )

        search_query = SearchQuery(
            query=query,
            search_type=question.search_type,
            filters=final_filters,
            favor_recent=favor_recent,
        )

        ranked_chunks, unranked_chunks = search_chunks(
            query=search_query, document_index=get_default_document_index()
        )

        # TODO retire this
        predicted_search, predicted_flow = query_intent(query)

        if not ranked_chunks:
            logger.debug("No Documents Found")
            # still emit a packet so the frontend knows the search completed empty
            empty_docs_result = {
                top_documents_key: None,
                unranked_top_docs_key: None,
                predicted_flow_key: predicted_flow,
                predicted_search_key: predicted_search,
            }
            logger.debug(send_packet_debug_msg.format(empty_docs_result))
            yield get_json_line(empty_docs_result)
            return

        top_docs = chunks_to_search_docs(ranked_chunks)
        unranked_top_docs = chunks_to_search_docs(unranked_chunks)

        initial_response = RerankedRetrievalDocs(
            top_documents=top_docs,
            unranked_top_documents=unranked_top_docs,
            # if generative AI is disabled, set flow as search so frontend
            # doesn't ask the user if they want to run QA over more documents
            predicted_flow=QueryFlow.SEARCH
            if disable_generative_answer
            else predicted_flow,
            predicted_search=predicted_search,
            time_cutoff=time_cutoff,
            favor_recent=favor_recent,
        ).dict()

        logger.debug(send_packet_debug_msg.format(initial_response))
        yield get_json_line(initial_response)

        if disable_generative_answer:
            logger.debug("Skipping QA because generative AI is disabled")
            return

        try:
            qa_model = get_default_qa_model()
        except Exception as e:
            logger.exception("Unable to get QA model")
            error = StreamingError(error=str(e))
            yield get_json_line(error.dict())
            return

        # remove chunks marked as not applicable for QA (e.g. Google Drive file
        # types which can't be parsed). These chunks are useful to show in the
        # search results, but not for QA.
        filtered_ranked_chunks = [
            chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
        ]

        # get all chunks that fit into the token limit
        usable_chunks = get_usable_chunks(
            chunks=filtered_ranked_chunks,
            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
            offset=offset_count,
        )
        logger.debug(
            f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
        )

        try:
            for response_packet in qa_model.answer_question_stream(
                query, usable_chunks
            ):
                if response_packet is None:
                    continue
                if (
                    isinstance(response_packet, DanswerAnswerPiece)
                    and response_packet.answer_piece
                ):
                    # accumulate so the answer can be persisted below
                    answer_so_far = answer_so_far + response_packet.answer_piece
                logger.debug(f"Sending packet: {response_packet}")
                yield get_json_line(response_packet.dict())
        except Exception:
            # exception is logged in the answer_question method, no need to re-log
            logger.exception("Failed to run QA")
            yield get_json_line(
                {"error": "The LLM failed to produce a useable response"}
            )

        # record the query event so feedback can reference it
        query_event_id = create_query_event(
            query=query,
            search_type=question.search_type,
            llm_answer=answer_so_far,
            retrieved_document_ids=[doc.document_id for doc in top_docs],
            user_id=None if user is None else user.id,
            db_session=db_session,
        )

        yield get_json_line({query_event_id_key: query_event_id})
        return
return StreamingResponse(stream_qa_portions(), media_type="application/json")
@router.post("/query-feedback") @router.post("/query-feedback")