diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index 6ec8dc44b..62699e415 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -15,6 +15,7 @@ METADATA = "metadata"
 GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
 HTML_SEPARATOR = "\n"
 PUBLIC_DOC_PAT = "PUBLIC"
+QUOTE = "quote"


 class DocumentSource(str, Enum):
diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py
index 92df801fa..e69de29bb 100644
--- a/backend/danswer/direct_qa/__init__.py
+++ b/backend/danswer/direct_qa/__init__.py
@@ -1,111 +0,0 @@
-from typing import Any
-
-import pkg_resources
-from openai.error import AuthenticationError
-
-from danswer.configs.app_configs import QA_TIMEOUT
-from danswer.configs.constants import DanswerGenAIModel
-from danswer.configs.constants import ModelHostType
-from danswer.configs.model_configs import GEN_AI_API_KEY
-from danswer.configs.model_configs import GEN_AI_ENDPOINT
-from danswer.configs.model_configs import GEN_AI_HOST_TYPE
-from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
-from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
-from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
-from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
-from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
-from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.local_transformers import TransformerQA
-from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
-from danswer.direct_qa.open_ai import OpenAICompletionQA
-from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
-from danswer.direct_qa.qa_utils import get_gen_ai_api_key
-from danswer.direct_qa.request_model import RequestCompletionQA
-from danswer.dynamic_configs.interface import ConfigNotFoundError
-from danswer.utils.logger import setup_logger
-
-logger = setup_logger()
-
-
-def check_model_api_key_is_valid(model_api_key: str) -> bool:
-    if not model_api_key:
-        return False
-
-    qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)
-
-    # try for up to 2 timeouts (e.g. 10 seconds in total)
-    for _ in range(2):
-        try:
-            qa_model.answer_question("Do not respond", [])
-            return True
-        except AuthenticationError:
-            return False
-        except Exception as e:
-            logger.warning(f"GenAI API key failed for the following reason: {e}")
-
-    return False
-
-
-def get_default_backend_qa_model(
-    internal_model: str = INTERNAL_MODEL_VERSION,
-    endpoint: str | None = GEN_AI_ENDPOINT,
-    model_host_type: str | None = GEN_AI_HOST_TYPE,
-    api_key: str | None = GEN_AI_API_KEY,
-    timeout: int = QA_TIMEOUT,
-    **kwargs: Any,
-) -> QAModel:
-    if not api_key:
-        try:
-            api_key = get_gen_ai_api_key()
-        except ConfigNotFoundError:
-            pass
-
-    if internal_model in [
-        DanswerGenAIModel.GPT4ALL.value,
-        DanswerGenAIModel.GPT4ALL_CHAT.value,
-    ]:
-        # gpt4all is not compatible M1 Mac hardware as of Aug 2023
-        pkg_resources.get_distribution("gpt4all")
-
-    if internal_model == DanswerGenAIModel.OPENAI.value:
-        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
-        return GPT4AllCompletionQA(**kwargs)
-    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
-        return GPT4AllChatCompletionQA(**kwargs)
-    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
-        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
-        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
-    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
-        return TransformerQA()
-    elif internal_model == DanswerGenAIModel.REQUEST.value:
-        if endpoint is None or model_host_type is None:
-            raise ValueError(
-                "Request based GenAI model requires an endpoint and host type"
-            )
-        if (
-            model_host_type == ModelHostType.HUGGINGFACE.value
-            or model_host_type == ModelHostType.COLAB_DEMO.value
-        ):
-            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
-            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
-            # This is the sum of input and output tokens, so cannot take in full Danswer context
-            return RequestCompletionQA(
-                endpoint=endpoint,
-                model_host_type=model_host_type,
-                api_key=api_key,
-                prompt_processor=WeakModelFreeformProcessor(),
-                timeout=timeout,
-            )
-        return RequestCompletionQA(
-            endpoint=endpoint,
-            model_host_type=model_host_type,
-            api_key=api_key,
-            timeout=timeout,
-        )
-    else:
-        raise UnknownModelError(internal_model)
diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py
index 96d33bf49..3494bdac4 100644
--- a/backend/danswer/direct_qa/answer_question.py
+++ b/backend/danswer/direct_qa/answer_question.py
@@ -5,10 +5,9 @@ from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
-from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.qa_utils import structure_quotes_for_response
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -75,7 +74,7 @@ def answer_question(
         )

     try:
-        qa_model = get_default_backend_qa_model(timeout=answer_generation_timeout)
+        qa_model = get_default_llm(timeout=answer_generation_timeout)
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -104,7 +103,7 @@ def answer_question(

     return QAResponse(
         answer=answer.answer if answer else None,
-        quotes=structure_quotes_for_response(quotes),
+        quotes=quotes.quotes if quotes else None,
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
         predicted_flow=predicted_flow,
diff --git a/backend/danswer/direct_qa/gpt_4_all.py b/backend/danswer/direct_qa/gpt_4_all.py
index 9402b234a..7b90c8290 100644
--- a/backend/danswer/direct_qa/gpt_4_all.py
+++ b/backend/danswer/direct_qa/gpt_4_all.py
@@ -4,8 +4,12 @@ from typing import Any
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
@@ -85,7 +89,7 @@ class GPT4AllCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -101,12 +105,12 @@ class GPT4AllCompletionQA(QAModel):

         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        answer, quotes = process_answer(model_output, context_docs)
+        return answer, quotes

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -150,7 +154,7 @@ class GPT4AllChatCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -177,7 +181,7 @@ class GPT4AllChatCompletionQA(QAModel):

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
diff --git a/backend/danswer/direct_qa/huggingface.py b/backend/danswer/direct_qa/huggingface.py
index ea8310f20..709e4c28b 100644
--- a/backend/danswer/direct_qa/huggingface.py
+++ b/backend/danswer/direct_qa/huggingface.py
@@ -7,8 +7,12 @@ from huggingface_hub.utils import HfHubHTTPError  # type:ignore

 from danswer.chunking.models import InferenceChunk
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import FreeformProcessor
@@ -51,7 +55,7 @@ class HuggingFaceCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -68,7 +72,7 @@ class HuggingFaceCompletionQA(QAModel):

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -165,7 +169,7 @@ class HuggingFaceChatCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         model_output = self._get_hf_model_output(query, context_docs)

         answer, quotes_dict = process_answer(model_output, context_docs)
@@ -174,7 +178,7 @@ class HuggingFaceChatCompletionQA(QAModel):

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         """As of Aug 2023, HF conversational (chat) endpoints do not support streaming
         So here it is faked by streaming characters within Danswer from the model output
         """
diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py
index 9a4fa8229..117a47d2a 100644
--- a/backend/danswer/direct_qa/interfaces.py
+++ b/backend/danswer/direct_qa/interfaces.py
@@ -1,7 +1,6 @@
 import abc
 from collections.abc import Generator
 from dataclasses import dataclass
-from typing import Any

 from danswer.chunking.models import InferenceChunk

@@ -11,6 +10,13 @@ class DanswerAnswer:
     answer: str | None


+@dataclass
+class DanswerAnswerPiece:
+    """A small piece of a complete answer. Used for streaming back answers."""
+
+    answer_piece: str | None  # if None, specifies the end of an Answer
+
+
 @dataclass
 class DanswerQuote:
     # This is during inference so everything is a string by this point
@@ -22,6 +28,21 @@ class DanswerQuote:
     blurb: str


+@dataclass
+class DanswerQuotes:
+    """A little clunky, but making this into a separate class so that the result from
+    `answer_question_stream` is always a dataclass and can thus use `asdict()`
+    """
+
+    quotes: list[DanswerQuote]
+
+
+AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
+AnswerQuestionStreamReturn = Generator[
+    DanswerAnswerPiece | DanswerQuotes | None, None, None
+]
+
+
 class QAModel:
     @property
     def requires_api_key(self) -> bool:
@@ -39,7 +60,7 @@ class QAModel:
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         raise NotImplementedError

     @abc.abstractmethod
@@ -47,5 +68,5 @@ class QAModel:
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         raise NotImplementedError
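Note on the design above: every packet `answer_question_stream` can yield is now a dataclass, so the HTTP layer can serialize any packet uniformly with `dataclasses.asdict` (see `search_backend.py` later in this diff). A quick illustration of the two packet shapes; the quote field values below are made up for the example:

```python
from dataclasses import asdict

from danswer.direct_qa.interfaces import DanswerAnswerPiece, DanswerQuote, DanswerQuotes

# One token of a streamed answer; answer_piece=None marks the end of the answer
print(asdict(DanswerAnswerPiece(answer_piece="Hello")))
# {'answer_piece': 'Hello'}

# All quotes arrive as a single packet once the answer is finished
# (field values here are illustrative only)
packet = DanswerQuotes(
    quotes=[
        DanswerQuote(
            quote="some exact text from a chunk",
            document_id="doc-1",
            link="https://example.com/doc-1",
            source_type="web",
            semantic_identifier="Example Doc",
            blurb="A short preview of the document...",
        )
    ]
)
print(asdict(packet))
# {'quotes': [{'quote': 'some exact text from a chunk', 'document_id': 'doc-1', ...}]}
```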
diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py
new file mode 100644
index 000000000..a49052c9f
--- /dev/null
+++ b/backend/danswer/direct_qa/llm_utils.py
@@ -0,0 +1,111 @@
+from typing import Any
+
+import pkg_resources
+from openai.error import AuthenticationError
+
+from danswer.configs.app_configs import QA_TIMEOUT
+from danswer.configs.constants import DanswerGenAIModel
+from danswer.configs.constants import ModelHostType
+from danswer.configs.model_configs import GEN_AI_API_KEY
+from danswer.configs.model_configs import GEN_AI_ENDPOINT
+from danswer.configs.model_configs import GEN_AI_HOST_TYPE
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
+from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
+from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
+from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
+from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
+from danswer.direct_qa.interfaces import QAModel
+from danswer.direct_qa.local_transformers import TransformerQA
+from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
+from danswer.direct_qa.open_ai import OpenAICompletionQA
+from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
+from danswer.direct_qa.qa_utils import get_gen_ai_api_key
+from danswer.direct_qa.request_model import RequestCompletionQA
+from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
+
+
+def check_model_api_key_is_valid(model_api_key: str) -> bool:
+    if not model_api_key:
+        return False
+
+    qa_model = get_default_llm(api_key=model_api_key, timeout=5)
+
+    # try for up to 2 timeouts (e.g. 10 seconds in total)
+    for _ in range(2):
+        try:
+            qa_model.answer_question("Do not respond", [])
+            return True
+        except AuthenticationError:
+            return False
+        except Exception as e:
+            logger.warning(f"GenAI API key failed for the following reason: {e}")
+
+    return False
+
+
+def get_default_llm(
+    internal_model: str = INTERNAL_MODEL_VERSION,
+    endpoint: str | None = GEN_AI_ENDPOINT,
+    model_host_type: str | None = GEN_AI_HOST_TYPE,
+    api_key: str | None = GEN_AI_API_KEY,
+    timeout: int = QA_TIMEOUT,
+    **kwargs: Any,
+) -> QAModel:
+    if not api_key:
+        try:
+            api_key = get_gen_ai_api_key()
+        except ConfigNotFoundError:
+            pass
+
+    if internal_model in [
+        DanswerGenAIModel.GPT4ALL.value,
+        DanswerGenAIModel.GPT4ALL_CHAT.value,
+    ]:
+        # gpt4all is not compatible with M1 Mac hardware as of Aug 2023
+        pkg_resources.get_distribution("gpt4all")
+
+    if internal_model == DanswerGenAIModel.OPENAI.value:
+        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
+    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
+        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
+    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
+        return GPT4AllCompletionQA(**kwargs)
+    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
+        return GPT4AllChatCompletionQA(**kwargs)
+    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
+        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
+    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
+        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
+    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
+        return TransformerQA()
+    elif internal_model == DanswerGenAIModel.REQUEST.value:
+        if endpoint is None or model_host_type is None:
+            raise ValueError(
+                "Request based GenAI model requires an endpoint and host type"
+            )
+        if (
+            model_host_type == ModelHostType.HUGGINGFACE.value
+            or model_host_type == ModelHostType.COLAB_DEMO.value
+        ):
+            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
+            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
+            # This is the sum of input and output tokens, so cannot take in full Danswer context
+            return RequestCompletionQA(
+                endpoint=endpoint,
+                model_host_type=model_host_type,
+                api_key=api_key,
+                prompt_processor=WeakModelFreeformProcessor(),
+                timeout=timeout,
+            )
+        return RequestCompletionQA(
+            endpoint=endpoint,
+            model_host_type=model_host_type,
+            api_key=api_key,
+            timeout=timeout,
+        )
+    else:
+        raise UnknownModelError(internal_model)
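For orientation: `llm_utils.py` is a verbatim move of the old `backend/danswer/direct_qa/__init__.py`, with `get_default_backend_qa_model` renamed to `get_default_llm`; the rest of the PR updates callers mechanically. A minimal sketch of how the factory is used — the query string and timeout are arbitrary, and the selected backend depends on your `INTERNAL_MODEL_VERSION` config:

```python
from danswer.direct_qa.llm_utils import get_default_llm

# Backend selection is driven by INTERNAL_MODEL_VERSION; extra kwargs are
# forwarded to the chosen QAModel's constructor.
qa_model = get_default_llm(timeout=10)

# Non-streaming: returns AnswerQuestionReturn, i.e. (DanswerAnswer, DanswerQuotes)
answer, quotes = qa_model.answer_question("What is Danswer?", [])

# Streaming: yields DanswerAnswerPiece packets, then a single DanswerQuotes packet
for packet in qa_model.answer_question_stream("What is Danswer?", []):
    print(packet)
```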
diff --git a/backend/danswer/direct_qa/local_transformers.py b/backend/danswer/direct_qa/local_transformers.py
index 17e986a84..92e160cd5 100644
--- a/backend/danswer/direct_qa/local_transformers.py
+++ b/backend/danswer/direct_qa/local_transformers.py
@@ -7,10 +7,13 @@ from transformers import QuestionAnsweringPipeline  # type:ignore

 from danswer.chunking.models import InferenceChunk
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_utils import structure_quotes_for_response
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time

@@ -104,7 +107,7 @@ class TransformerQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         danswer_quotes: list[DanswerQuote] = []
         d_answers: list[str] = []
         for chunk in context_docs:
@@ -118,11 +121,13 @@ class TransformerQA(QAModel):
             for ind, answer in enumerate(d_answers, start=1)
         ]
         combined_answer = "\n".join(answers_list)
-        return DanswerAnswer(answer=combined_answer), danswer_quotes
+        return DanswerAnswer(answer=combined_answer), DanswerQuotes(
+            quotes=danswer_quotes
+        )

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         quotes: list[DanswerQuote] = []
         answers: list[str] = []
         for chunk in context_docs:
@@ -135,13 +140,14 @@ class TransformerQA(QAModel):
         answer_count = 1
         for answer in answers:
             if answer_count == 1:
-                yield {"answer_data": "Source 1: "}
+                yield DanswerAnswerPiece(answer_piece="Source 1: ")
             else:
-                yield {"answer_data": f"\nSource {answer_count}: "}
+                yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
             answer_count += 1
             for char in answer.strip():
-                yield {"answer_data": char}
+                yield DanswerAnswerPiece(answer_piece=char)

-        yield {"answer_finished": True}
+        # signal end of answer
+        yield DanswerAnswerPiece(answer_piece=None)

-        yield structure_quotes_for_response(quotes)
+        yield DanswerQuotes(quotes=quotes)
diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py
index 54c938b03..f8694d9ae 100644
--- a/backend/danswer/direct_qa/open_ai.py
+++ b/backend/danswer/direct_qa/open_ai.py
@@ -22,8 +22,12 @@ from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
@@ -147,7 +151,7 @@ class OpenAICompletionQA(OpenAIQAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         filled_prompt = self.prompt_processor.fill_prompt(
@@ -177,7 +181,7 @@ class OpenAICompletionQA(OpenAIQAModel):

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         filled_prompt = self.prompt_processor.fill_prompt(
@@ -243,7 +247,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         messages = self.prompt_processor.fill_prompt(
@@ -276,12 +280,12 @@ class OpenAIChatCompletionQA(OpenAIQAModel):

         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        answer, quotes = process_answer(model_output, context_docs)
+        return answer, quotes

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         messages = self.prompt_processor.fill_prompt(
diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py
index 8b79f576a..eaa95ee88 100644
--- a/backend/danswer/direct_qa/qa_utils.py
+++ b/backend/danswer/direct_qa/qa_utils.py
@@ -2,7 +2,6 @@ import json
 import math
 import re
 from collections.abc import Generator
-from typing import Any
 from typing import cast
 from typing import Optional
 from typing import Tuple
@@ -11,15 +10,12 @@ import regex

 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
-from danswer.configs.constants import BLURB
-from danswer.configs.constants import DOCUMENT_ID
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
-from danswer.configs.constants import SEMANTIC_IDENTIFIER
-from danswer.configs.constants import SOURCE_LINK
-from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
@@ -37,24 +33,6 @@ def get_gen_ai_api_key() -> str:
     )


-def structure_quotes_for_response(
-    quotes: list[DanswerQuote] | None,
-) -> dict[str, dict[str, str | None]]:
-    if quotes is None:
-        return {}
-
-    response_quotes = {}
-    for quote in quotes:
-        response_quotes[quote.quote] = {
-            DOCUMENT_ID: quote.document_id,
-            SOURCE_LINK: quote.link,
-            SOURCE_TYPE: quote.source_type,
-            SEMANTIC_IDENTIFIER: quote.semantic_identifier,
-            BLURB: quote.blurb,
-        }
-    return response_quotes
-
-
 def extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -114,8 +92,8 @@ def match_quotes_to_docs(
     max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
     fuzzy_search: bool = False,
     prefix_only_length: int = 100,
-) -> list[DanswerQuote]:
-    danswer_quotes = []
+) -> DanswerQuotes:
+    danswer_quotes: list[DanswerQuote] = []
     for quote in quotes:
         max_edits = math.ceil(float(len(quote)) * max_error_percent)

@@ -145,23 +123,13 @@ def match_quotes_to_docs(
             # Extracting the link from the offset
             curr_link = None
             for link_offset, link in chunk.source_links.items():
-                # Should always find one because offset is at least 0 and there must be a 0 link_offset
+                # Should always find one because offset is at least 0 and there
+                # must be a 0 link_offset
                 if int(link_offset) <= offset:
                     curr_link = link
                 else:
-                    danswer_quotes.append(
-                        DanswerQuote(
-                            quote=quote,
-                            document_id=chunk.document_id,
-                            link=curr_link,
-                            source_type=chunk.source_type,
-                            semantic_identifier=chunk.semantic_identifier,
-                            blurb=chunk.blurb,
-                        )
-                    )
                     break

-            # If the offset is larger than the start of the last quote, it must be the last one
             danswer_quotes.append(
                 DanswerQuote(
                     quote=quote,
                     document_id=chunk.document_id,
                     link=curr_link,
                     source_type=chunk.source_type,
                     semantic_identifier=chunk.semantic_identifier,
                     blurb=chunk.blurb,
                 )
             )
             break

-    return danswer_quotes
+    return DanswerQuotes(quotes=danswer_quotes)


 def process_answer(
     answer_raw: str, chunks: list[InferenceChunk]
-) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+) -> tuple[DanswerAnswer, DanswerQuotes]:
     answer, quote_strings = separate_answer_quotes(answer_raw)
     if answer == UNCERTAINTY_PAT or not answer:
         if answer == UNCERTAINTY_PAT:
             logger.debug("Answer matched UNCERTAINTY_PAT")
         else:
             logger.debug("No answer extracted from raw output")
-        return DanswerAnswer(answer=None), []
+        return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])

     logger.info(f"Answer: {answer}")
     if not quote_strings:
         logger.debug("No quotes extracted from raw output")
-        return DanswerAnswer(answer=answer), []
+        return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
     logger.info(f"All quotes (including unmatched): {quote_strings}")
     quotes = match_quotes_to_docs(quote_strings, chunks)
     logger.info(f"Final quotes: {quotes}")
@@ -212,7 +180,7 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:

 def extract_quotes_from_completed_token_stream(
     model_output: str, context_chunks: list[InferenceChunk]
-) -> list[DanswerQuote]:
+) -> DanswerQuotes:
     logger.debug(model_output)
     answer, quotes = process_answer(model_output, context_chunks)
     if answer:
@@ -227,7 +195,7 @@ def process_model_tokens(
     tokens: Generator[str, None, None],
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
-) -> Generator[dict[str, Any], None, None]:
+) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
     """Yields Answer tokens back out in a dict for streaming to frontend
     When Answer section ends, yields dict with answer_finished key
     Collects all the tokens at the end to form the complete model output"""
@@ -255,21 +223,20 @@ def process_model_tokens(
         if found_answer_start and not found_answer_end:
             if is_json_prompt and stream_json_answer_end(model_previous, token):
                 found_answer_end = True
-                yield {"answer_finished": True}
+                yield DanswerAnswerPiece(answer_piece=None)
                 continue
             elif not is_json_prompt:
                 if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
                     found_answer_end = True
-                    yield {"answer_finished": True}
+                    yield DanswerAnswerPiece(answer_piece=None)
                     continue
                 if hold_quote + token in quote_pat_full:
                     hold_quote += token
                     continue
-            yield {"answer_data": hold_quote + token}
+            yield DanswerAnswerPiece(answer_piece=token)
             hold_quote = ""

-    quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
-    yield structure_quotes_for_response(quotes)
+    yield extract_quotes_from_completed_token_stream(model_output, context_docs)


 def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
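With `structure_quotes_for_response` gone, `qa_utils.py` yields the typed packets end to end: `match_quotes_to_docs` returns a `DanswerQuotes` wrapper instead of a bare list, and `process_model_tokens` emits `DanswerAnswerPiece` objects (with `answer_piece=None` as the end-of-answer marker) followed by one `DanswerQuotes`. A minimal sketch of a consumer that folds such a stream back into the non-streaming return shape, assuming that producer contract:

```python
from collections.abc import Iterable

from danswer.direct_qa.interfaces import (
    DanswerAnswer,
    DanswerAnswerPiece,
    DanswerQuotes,
)


def collect_stream(
    packets: Iterable[DanswerAnswerPiece | DanswerQuotes],
) -> tuple[DanswerAnswer, DanswerQuotes]:
    """Fold streamed packets back into (DanswerAnswer, DanswerQuotes).

    Assumes the contract above: answer pieces first, answer_piece=None as the
    end-of-answer signal, then a single quotes packet.
    """
    pieces: list[str] = []
    quotes = DanswerQuotes(quotes=[])
    for packet in packets:
        if isinstance(packet, DanswerAnswerPiece):
            if packet.answer_piece is not None:
                pieces.append(packet.answer_piece)
        elif isinstance(packet, DanswerQuotes):
            quotes = packet
    return DanswerAnswer(answer="".join(pieces) or None), quotes
```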
diff --git a/backend/danswer/direct_qa/request_model.py b/backend/danswer/direct_qa/request_model.py
index 600ad7058..26f5d98b2 100644
--- a/backend/danswer/direct_qa/request_model.py
+++ b/backend/danswer/direct_qa/request_model.py
@@ -13,8 +13,8 @@ from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.direct_qa.interfaces import DanswerAnswer
-from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
@@ -236,7 +236,7 @@ class RequestCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         model_api_response = self._get_request_response(
             query, context_docs, stream=False
         )
@@ -253,7 +253,7 @@ class RequestCompletionQA(QAModel):
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         model_api_response = self._get_request_response(
             query, context_docs, stream=False
         )
diff --git a/backend/danswer/listeners/slack_listener.py b/backend/danswer/listeners/slack_listener.py
index 0c3d89e79..efc654cb2 100644
--- a/backend/danswer/listeners/slack_listener.py
+++ b/backend/danswer/listeners/slack_listener.py
@@ -13,6 +13,7 @@ from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchDoc
@@ -59,24 +60,22 @@ def _build_custom_semantic_identifier(
     return semantic_identifier


-def _process_quotes(
-    quotes: dict[str, dict[str, str | None]] | None
-) -> tuple[str | None, list[str]]:
+def _process_quotes(quotes: list[DanswerQuote] | None) -> tuple[str | None, list[str]]:
     if not quotes:
         return None, []

     quote_lines: list[str] = []
     doc_identifiers: list[str] = []
-    for quote_dict in quotes.values():
-        doc_id = str(quote_dict.get("document_id", ""))
-        doc_link = quote_dict.get("link")
-        doc_name = str(quote_dict.get("semantic_identifier", ""))
+    for quote in quotes:
+        doc_id = quote.document_id
+        doc_link = quote.link
+        doc_name = quote.semantic_identifier
         if doc_link and doc_name and doc_id and doc_id not in doc_identifiers:
             doc_identifiers.append(doc_id)
             custom_semantic_identifier = _build_custom_semantic_identifier(
                 semantic_identifier=doc_name,
-                blurb=str(quote_dict.get("blurb", "")),
-                source=str(quote_dict.get("source_type", "")),
+                blurb=quote.blurb,
+                source=quote.source_type,
             )
             quote_lines.append(f"- <{doc_link}|{custom_semantic_identifier}>")
diff --git a/backend/danswer/main.py b/backend/danswer/main.py
index 410981c66..b17f6cb32 100644
--- a/backend/danswer/main.py
+++ b/backend/danswer/main.py
@@ -33,7 +33,7 @@ from danswer.datastores.qdrant.indexing import list_qdrant_collections
 from danswer.datastores.typesense.store import check_typesense_collection_exist
 from danswer.datastores.typesense.store import create_typesense_collection
 from danswer.db.credentials import create_initial_public_credential
-from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.server.event_loading import router as event_processing_router
 from danswer.server.health import router as health_router
 from danswer.server.manage import router as admin_router
@@ -179,7 +179,7 @@ def get_application() -> FastAPI:
             logger.info("Warming up local NLP models.")
             warm_up_models()

-            qa_model = get_default_backend_qa_model()
+            qa_model = get_default_llm()
             qa_model.warm_up_model()

             logger.info("Verifying query preprocessing (NLTK) data is downloaded")
diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py
index 9953780dd..ab9175a59 100644
--- a/backend/danswer/server/manage.py
+++ b/backend/danswer/server/manage.py
@@ -9,7 +9,6 @@ from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
 from fastapi_users.db import SQLAlchemyUserDatabase
-from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session

@@ -53,14 +52,10 @@ from danswer.db.engine import get_sqlalchemy_async_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_latest_index_attempts
 from danswer.db.models import DeletionAttempt
-from danswer.db.models import DeletionStatus
-from danswer.db.models import IndexAttempt
-from danswer.db.models import IndexingStatus
 from danswer.db.models import User
-from danswer.direct_qa import check_model_api_key_is_valid
-from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.direct_qa.open_ai import get_gen_ai_api_key
-from danswer.direct_qa.open_ai import OpenAIQAModel
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey
@@ -361,7 +356,7 @@ def validate_existing_genai_api_key(
 ) -> None:
     # OpenAI key is only used for generative QA, so no need to validate this
     # if it's turned off or if a non-OpenAI model is being used
-    if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
+    if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key:
        return

     # Only validate every so often
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
index 87e908f34..f98dd1630 100644
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -19,6 +19,7 @@ from danswer.db.models import DeletionAttempt
 from danswer.db.models import DeletionStatus
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
+from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.server.utils import mask_credential_dict
@@ -110,7 +111,7 @@ class SearchResponse(BaseModel):

 class QAResponse(SearchResponse):
     answer: str | None  # DanswerAnswer
-    quotes: dict[str, dict[str, str | None]] | None  # restructured DanswerQuote
+    quotes: list[DanswerQuote] | None
     predicted_flow: QueryFlow
     predicted_search: SearchType
     error_msg: str | None = None
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index 2e0e79e89..f1cad9d20 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -1,5 +1,6 @@
 import json
 from collections.abc import Generator
+from dataclasses import asdict

 from fastapi import APIRouter
 from fastapi import Depends
@@ -12,10 +13,10 @@ from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
-from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.answer_question import answer_question
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -166,7 +167,7 @@ def stream_direct_qa(
             return

         try:
-            qa_model = get_default_backend_qa_model()
+            qa_model = get_default_llm()
         except (UnknownModelError, OpenAIKeyMissing) as e:
             logger.exception("Unable to get QA model")
             yield get_json_line({"error": str(e)})
@@ -178,16 +179,16 @@ def stream_direct_qa(
                 "Chunks offset too large, should not retry this many times"
             )
         try:
-            for response_dict in qa_model.answer_question_stream(
+            for response_packet in qa_model.answer_question_stream(
                 query,
                 ranked_chunks[
                     chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
                 ],
             ):
-                if response_dict is None:
+                if response_packet is None:
                     continue
-                logger.debug(f"Sending packet: {response_dict}")
-                yield get_json_line(response_dict)
+                logger.debug(f"Sending packet: {response_packet}")
+                yield get_json_line(asdict(response_packet))
         except Exception as e:
             # exception is logged in the answer_question method, no need to re-log
             yield get_json_line({"error": str(e)})
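On the wire, `stream_direct_qa` now writes one JSON object per line via `get_json_line(asdict(response_packet))`, so clients distinguish packets by which keys are present — exactly what the frontend changes below do. A rough Python client sketch; the host, port, and request payload here are assumptions (the payload shape is not shown in this diff), not the documented API:

```python
import json

import requests  # third-party HTTP client, used here only for illustration

# Hypothetical request body; see QuestionRequest in danswer.server.models for
# the real schema.
body = {"query": "How does Danswer stream answers?"}

with requests.post(
    "http://localhost:8080/api/stream-direct-qa", json=body, stream=True
) as resp:
    answer = ""
    quotes = None
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue
        packet = json.loads(raw_line)
        if packet.get("error"):
            raise RuntimeError(packet["error"])
        if "answer_piece" in packet:
            if packet["answer_piece"] is None:
                continue  # end-of-answer marker; a quotes packet follows
            answer += packet["answer_piece"]
        elif "quotes" in packet:
            quotes = packet["quotes"]  # list of serialized DanswerQuote dicts
```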
diff --git a/web/src/components/search/SearchResultsDisplay.tsx b/web/src/components/search/SearchResultsDisplay.tsx
index 1771c8116..854c1571b 100644
--- a/web/src/components/search/SearchResultsDisplay.tsx
+++ b/web/src/components/search/SearchResultsDisplay.tsx
@@ -56,7 +56,7 @@ export const SearchResultsDisplay: React.FC = ({
   const dedupedQuotes: Quote[] = [];
   const seen = new Set();
   if (quotes) {
-    Object.values(quotes).forEach((quote) => {
+    quotes.forEach((quote) => {
       if (!seen.has(quote.document_id)) {
         dedupedQuotes.push(quote);
         seen.add(quote.document_id);
@@ -109,7 +109,7 @@ export const SearchResultsDisplay: React.FC = ({
diff --git a/web/src/components/search/SearchSection.tsx b/web/src/components/search/SearchSection.tsx
index 408f5d279..32e0d4b34 100644
--- a/web/src/components/search/SearchSection.tsx
+++ b/web/src/components/search/SearchSection.tsx
@@ -71,7 +71,7 @@ export const SearchSection: React.FC = ({
         ...(prevState || initialSearchResponse),
         answer,
       }));
-  const updateQuotes = (quotes: Record<string, Quote>) =>
+  const updateQuotes = (quotes: Quote[]) =>
     setSearchResponse((prevState) => ({
       ...(prevState || initialSearchResponse),
       quotes,
diff --git a/web/src/lib/search/interfaces.ts b/web/src/lib/search/interfaces.ts
index 5621af5f4..02bc8b004 100644
--- a/web/src/lib/search/interfaces.ts
+++ b/web/src/lib/search/interfaces.ts
@@ -13,11 +13,12 @@ export const SearchType = {
 export type SearchType = (typeof SearchType)[keyof typeof SearchType];

 export interface Quote {
+  quote: string;
   document_id: string;
-  link: string;
+  link: string | null;
   source_type: ValidSources;
   blurb: string;
-  semantic_identifier: string | null;
+  semantic_identifier: string;
 }

 export interface DanswerDocument {
@@ -32,7 +33,7 @@ export interface SearchResponse {
   suggestedSearchType: SearchType | null;
   suggestedFlowType: FlowType | null;
   answer: string | null;
-  quotes: Record<string, Quote> | null;
+  quotes: Quote[] | null;
   documents: DanswerDocument[] | null;
   error: string | null;
 }
@@ -51,7 +52,7 @@ export interface SearchRequestArgs {
   query: string;
   sources: Source[];
   updateCurrentAnswer: (val: string) => void;
-  updateQuotes: (quotes: Record<string, Quote>) => void;
+  updateQuotes: (quotes: Quote[]) => void;
   updateDocs: (documents: DanswerDocument[]) => void;
   updateSuggestedSearchType: (searchType: SearchType) => void;
   updateSuggestedFlowType: (flowType: FlowType) => void;
diff --git a/web/src/lib/search/qa.ts b/web/src/lib/search/qa.ts
index 40333fbd5..d7258ca05 100644
--- a/web/src/lib/search/qa.ts
+++ b/web/src/lib/search/qa.ts
@@ -24,7 +24,7 @@ export const searchRequest = async ({
   }

   let answer = "";
-  let quotes: Record<string, Quote> | null = null;
+  let quotes: Quote[] | null = null;
   let relevantDocuments: DanswerDocument[] | null = null;
   try {
     const response = await fetch("/api/direct-qa", {
@@ -54,7 +54,7 @@ export const searchRequest = async ({

     const data = (await response.json()) as {
       answer: string;
-      quotes: Record<string, Quote>;
+      quotes: Quote[];
       top_ranked_docs: DanswerDocument[];
       lower_ranked_docs: DanswerDocument[];
       predicted_flow: FlowType;
diff --git a/web/src/lib/search/streamingQa.ts b/web/src/lib/search/streamingQa.ts
index 13ce07115..50c0fe0d1 100644
--- a/web/src/lib/search/streamingQa.ts
+++ b/web/src/lib/search/streamingQa.ts
@@ -69,7 +69,7 @@ export const searchRequestStreamed = async ({
   }

   let answer = "";
-  let quotes: Record<string, Quote> | null = null;
+  let quotes: Quote[] | null = null;
   let relevantDocuments: DanswerDocument[] | null = null;
   try {
     const response = await fetch("/api/stream-direct-qa", {
@@ -118,18 +118,17 @@ export const searchRequestStreamed = async ({
       previousPartialChunk = partialChunk;
       completedChunks.forEach((chunk) => {
         // TODO: clean up response / this logic
-        const answerChunk = chunk.answer_data;
+        const answerChunk = chunk.answer_piece;
         if (answerChunk) {
           answer += answerChunk;
           updateCurrentAnswer(answer);
           return;
         }

-        const answerFinished = chunk.answer_finished;
-        if (answerFinished) {
+        if (answerChunk === null) {
           // set quotes as non-null to signify that the answer is finished and
           // we're now looking for quotes
-          updateQuotes({});
+          updateQuotes([]);
           if (
             answer &&
             !answer.endsWith(".") &&
@@ -168,9 +167,15 @@ export const searchRequestStreamed = async ({
           return;
         }

-        // if it doesn't match any of the above, assume it is a quote
-        quotes = chunk as Record<string, Quote>;
-        updateQuotes(quotes);
+        // Check for quote section
+        if (chunk.quotes) {
+          quotes = chunk.quotes as Quote[];
+          updateQuotes(quotes);
+          return;
+        }
+
+        // should never reach this
+        console.log("Unknown chunk:", chunk);
       });
     }
   } catch (err) {