Stop using untyped dicts to represent quotes (#310)

Chris Weaver 2023-08-17 14:53:55 -07:00 committed by GitHub
parent 81a4934bb8
commit f37ac76d3c
21 changed files with 247 additions and 239 deletions
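At its core, this commit replaces the nested untyped dict that quotes previously traveled through as (keyed by quote text, as built by the now-removed structure_quotes_for_response helper) with the typed DanswerQuote / DanswerQuotes dataclasses defined in the interfaces diff below. A minimal before/after sketch with hypothetical field values:

# Before: an untyped nested dict, keyed by the quote text
quotes_old: dict[str, dict[str, str | None]] = {
    "some quoted sentence": {
        "document_id": "doc-1",
        "link": "https://example.com/doc-1",
        "source_type": "web",
        "semantic_identifier": "Example Doc",
        "blurb": "A short preview of the document...",
    }
}

# After: typed dataclasses (see the interfaces changes below)
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes

quotes_new = DanswerQuotes(
    quotes=[
        DanswerQuote(
            quote="some quoted sentence",
            document_id="doc-1",
            link="https://example.com/doc-1",
            source_type="web",
            semantic_identifier="Example Doc",
            blurb="A short preview of the document...",
        )
    ]
)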

View File

@@ -15,6 +15,7 @@ METADATA = "metadata"
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
QUOTE = "quote"
class DocumentSource(str, Enum):

View File

@@ -1,111 +0,0 @@
from typing import Any
import pkg_resources
from openai.error import AuthenticationError
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_model_api_key_is_valid(model_api_key: str) -> bool:
if not model_api_key:
return False
qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)
# try for up to 2 timeouts (i.e. 10 seconds in total)
for _ in range(2):
try:
qa_model.answer_question("Do not respond", [])
return True
except AuthenticationError:
return False
except Exception as e:
logger.warning(f"GenAI API key failed for the following reason: {e}")
return False
def get_default_backend_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
**kwargs: Any,
) -> QAModel:
if not api_key:
try:
api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
pass
if internal_model in [
DanswerGenAIModel.GPT4ALL.value,
DanswerGenAIModel.GPT4ALL_CHAT.value,
]:
# gpt4all is not compatible with M1 Mac hardware as of Aug 2023
pkg_resources.get_distribution("gpt4all")
if internal_model == DanswerGenAIModel.OPENAI.value:
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
return GPT4AllCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.TRANSFORMERS:
return TransformerQA()
elif internal_model == DanswerGenAIModel.REQUEST.value:
if endpoint is None or model_host_type is None:
raise ValueError(
"Request based GenAI model requires an endpoint and host type"
)
if (
model_host_type == ModelHostType.HUGGINGFACE.value
or model_host_type == ModelHostType.COLAB_DEMO.value
):
# Assume the user is hosting the smallest-size LLMs, with weaker capabilities and
# tighter token limits. With the 7B Llama2 Chat model, there is a max limit of
# 1512 tokens; since this is the sum of input and output tokens, it cannot take in
# the full Danswer context
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
prompt_processor=WeakModelFreeformProcessor(),
timeout=timeout,
)
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
timeout=timeout,
)
else:
raise UnknownModelError(internal_model)
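Note: the module deleted above is not dropped. It reappears, nearly verbatim, later in this diff as danswer/direct_qa/llm_utils.py, with get_default_backend_qa_model renamed to get_default_llm. Callers migrate as sketched here:

# Before this commit:
from danswer.direct_qa import get_default_backend_qa_model
qa_model = get_default_backend_qa_model()

# After this commit:
from danswer.direct_qa.llm_utils import get_default_llm
qa_model = get_default_llm()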

View File

@@ -5,10 +5,9 @@ from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.qa_utils import structure_quotes_for_response
from danswer.direct_qa.llm_utils import get_default_llm
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
@@ -75,7 +74,7 @@ def answer_question(
)
try:
qa_model = get_default_backend_qa_model(timeout=answer_generation_timeout)
qa_model = get_default_llm(timeout=answer_generation_timeout)
except (UnknownModelError, OpenAIKeyMissing) as e:
return QAResponse(
answer=None,
@@ -104,7 +103,7 @@ def answer_question(
return QAResponse(
answer=answer.answer if answer else None,
quotes=structure_quotes_for_response(quotes),
quotes=quotes.quotes if quotes else None,
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
predicted_flow=predicted_flow,

View File

@@ -4,8 +4,12 @@ from typing import Any
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
@@ -85,7 +89,7 @@ class GPT4AllCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
@@ -101,12 +105,12 @@ class GPT4AllCompletionQA(QAModel):
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
answer, quotes = process_answer(model_output, context_docs)
return answer, quotes
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
@@ -150,7 +154,7 @@ class GPT4AllChatCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
@@ -177,7 +181,7 @@ class GPT4AllChatCompletionQA(QAModel):
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)

View File

@@ -7,8 +7,12 @@ from huggingface_hub.utils import HfHubHTTPError # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import FreeformProcessor
@@ -51,7 +55,7 @@ class HuggingFaceCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
@@ -68,7 +72,7 @@ class HuggingFaceCompletionQA(QAModel):
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
@@ -165,7 +169,7 @@ class HuggingFaceChatCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
model_output = self._get_hf_model_output(query, context_docs)
answer, quotes_dict = process_answer(model_output, context_docs)
@@ -174,7 +178,7 @@ class HuggingFaceChatCompletionQA(QAModel):
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
"""As of Aug 2023, HF conversational (chat) endpoints do not support streaming
So here it is faked by streaming characters within Danswer from the model output
"""

View File

@@ -1,7 +1,6 @@
import abc
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any
from danswer.chunking.models import InferenceChunk
@@ -11,6 +10,13 @@ class DanswerAnswer:
answer: str | None
@dataclass
class DanswerAnswerPiece:
"""A small piece of a complete answer. Used for streaming back answers."""
answer_piece: str | None # if None, specifies the end of an Answer
@dataclass
class DanswerQuote:
# This is during inference, so everything is a string by this point
@@ -22,6 +28,21 @@ class DanswerQuote:
blurb: str
@dataclass
class DanswerQuotes:
"""A little clunky, but making this into a separate class so that the result from
`answer_question_stream` is always a subclass of `dataclass` and can thus use `asdict()`
"""
quotes: list[DanswerQuote]
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
AnswerQuestionStreamReturn = Generator[
DanswerAnswerPiece | DanswerQuotes | None, None, None
]
class QAModel:
@property
def requires_api_key(self) -> bool:
@@ -39,7 +60,7 @@ class QAModel:
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
raise NotImplementedError
@abc.abstractmethod
@@ -47,5 +68,5 @@ class QAModel:
self,
query: str,
context_docs: list[InferenceChunk],
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
raise NotImplementedError
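The point of wrapping the quote list in DanswerQuotes rather than yielding a bare list is the one made in its docstring: every packet a QAModel streams back is a dataclass, so the server can serialize any of them the same way. A quick illustration of that property (not part of the commit):

from dataclasses import asdict

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes

# Both packet types serialize uniformly, which the streaming endpoint relies on
print(asdict(DanswerAnswerPiece(answer_piece="Hello")))  # {'answer_piece': 'Hello'}
print(asdict(DanswerQuotes(quotes=[])))                  # {'quotes': []}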

View File

@@ -0,0 +1,111 @@
from typing import Any
import pkg_resources
from openai.error import AuthenticationError
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_model_api_key_is_valid(model_api_key: str) -> bool:
if not model_api_key:
return False
qa_model = get_default_llm(api_key=model_api_key, timeout=5)
# try for up to 2 timeouts (i.e. 10 seconds in total)
for _ in range(2):
try:
qa_model.answer_question("Do not respond", [])
return True
except AuthenticationError:
return False
except Exception as e:
logger.warning(f"GenAI API key failed for the following reason: {e}")
return False
def get_default_llm(
internal_model: str = INTERNAL_MODEL_VERSION,
endpoint: str | None = GEN_AI_ENDPOINT,
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
**kwargs: Any,
) -> QAModel:
if not api_key:
try:
api_key = get_gen_ai_api_key()
except ConfigNotFoundError:
pass
if internal_model in [
DanswerGenAIModel.GPT4ALL.value,
DanswerGenAIModel.GPT4ALL_CHAT.value,
]:
# gpt4all is not compatible with M1 Mac hardware as of Aug 2023
pkg_resources.get_distribution("gpt4all")
if internal_model == DanswerGenAIModel.OPENAI.value:
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
return GPT4AllCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
return GPT4AllChatCompletionQA(**kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
elif internal_model == DanswerGenAIModel.TRANSFORMERS:
return TransformerQA()
elif internal_model == DanswerGenAIModel.REQUEST.value:
if endpoint is None or model_host_type is None:
raise ValueError(
"Request based GenAI model requires an endpoint and host type"
)
if (
model_host_type == ModelHostType.HUGGINGFACE.value
or model_host_type == ModelHostType.COLAB_DEMO.value
):
# Assume the user is hosting the smallest-size LLMs, with weaker capabilities and
# tighter token limits. With the 7B Llama2 Chat model, there is a max limit of
# 1512 tokens; since this is the sum of input and output tokens, it cannot take in
# the full Danswer context
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
prompt_processor=WeakModelFreeformProcessor(),
timeout=timeout,
)
return RequestCompletionQA(
endpoint=endpoint,
model_host_type=model_host_type,
api_key=api_key,
timeout=timeout,
)
else:
raise UnknownModelError(internal_model)
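get_default_llm keeps the old dispatch behavior under the new name, and every backend it can return now shares the typed AnswerQuestionReturn contract. A minimal usage sketch (not part of the commit; the query and empty context are placeholders):

from danswer.direct_qa.llm_utils import get_default_llm

qa_model = get_default_llm(timeout=10)
# answer is a DanswerAnswer, quotes is a DanswerQuotes -- no dict unpacking needed
answer, quotes = qa_model.answer_question("What is Danswer?", [])
print(answer.answer)
print([q.quote for q in quotes.quotes])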

View File

@@ -7,10 +7,13 @@ from transformers import QuestionAnsweringPipeline # type:ignore
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_utils import structure_quotes_for_response
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
@@ -104,7 +107,7 @@ class TransformerQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
danswer_quotes: list[DanswerQuote] = []
d_answers: list[str] = []
for chunk in context_docs:
@@ -118,11 +121,13 @@
for ind, answer in enumerate(d_answers, start=1)
]
combined_answer = "\n".join(answers_list)
return DanswerAnswer(answer=combined_answer), danswer_quotes
return DanswerAnswer(answer=combined_answer), DanswerQuotes(
quotes=danswer_quotes
)
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
quotes: list[DanswerQuote] = []
answers: list[str] = []
for chunk in context_docs:
@@ -135,13 +140,14 @@
answer_count = 1
for answer in answers:
if answer_count == 1:
yield {"answer_data": "Source 1: "}
yield DanswerAnswerPiece(answer_piece="Source 1: ")
else:
yield {"answer_data": f"\nSource {answer_count}: "}
yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
answer_count += 1
for char in answer.strip():
yield {"answer_data": char}
yield DanswerAnswerPiece(answer_piece=char)
yield {"answer_finished": True}
# signal end of answer
yield DanswerAnswerPiece(answer_piece=None)
yield structure_quotes_for_response(quotes)
yield DanswerQuotes(quotes=quotes)
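The rewritten stream above makes the new protocol concrete: zero or more DanswerAnswerPiece packets carry answer text, a piece with answer_piece=None replaces the old {"answer_finished": True} dict as the end-of-answer signal, and a single DanswerQuotes packet follows. A sketch of a consumer written against that contract (qa_model, query, and context_docs are assumed to already exist):

from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes

answer_parts: list[str] = []
quotes = None
for packet in qa_model.answer_question_stream(query, context_docs):
    if packet is None:
        continue
    if isinstance(packet, DanswerAnswerPiece):
        if packet.answer_piece is None:
            continue  # end-of-answer signal; quotes come next
        answer_parts.append(packet.answer_piece)
    elif isinstance(packet, DanswerQuotes):
        quotes = packet.quotes
answer = "".join(answer_parts)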

View File

@@ -22,8 +22,12 @@ from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
@@ -147,7 +151,7 @@ class OpenAICompletionQA(OpenAIQAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
filled_prompt = self.prompt_processor.fill_prompt(
@@ -177,7 +181,7 @@ class OpenAICompletionQA(OpenAIQAModel):
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
filled_prompt = self.prompt_processor.fill_prompt(
@@ -243,7 +247,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
messages = self.prompt_processor.fill_prompt(
@@ -276,12 +280,12 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
answer, quotes = process_answer(model_output, context_docs)
return answer, quotes
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
messages = self.prompt_processor.fill_prompt(

View File

@@ -2,7 +2,6 @@ import json
import math
import re
from collections.abc import Generator
from typing import Any
from typing import cast
from typing import Optional
from typing import Tuple
@@ -11,15 +10,12 @@ import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
@@ -37,24 +33,6 @@ def get_gen_ai_api_key() -> str:
)
def structure_quotes_for_response(
quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
if quotes is None:
return {}
response_quotes = {}
for quote in quotes:
response_quotes[quote.quote] = {
DOCUMENT_ID: quote.document_id,
SOURCE_LINK: quote.link,
SOURCE_TYPE: quote.source_type,
SEMANTIC_IDENTIFIER: quote.semantic_identifier,
BLURB: quote.blurb,
}
return response_quotes
def extract_answer_quotes_freeform(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -114,8 +92,8 @@ def match_quotes_to_docs(
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
fuzzy_search: bool = False,
prefix_only_length: int = 100,
) -> list[DanswerQuote]:
danswer_quotes = []
) -> DanswerQuotes:
danswer_quotes: list[DanswerQuote] = []
for quote in quotes:
max_edits = math.ceil(float(len(quote)) * max_error_percent)
@@ -145,23 +123,13 @@
# Extracting the link from the offset
curr_link = None
for link_offset, link in chunk.source_links.items():
# Should always find one because offset is at least 0 and there must be a 0 link_offset
# Should always find one because offset is at least 0 and there
# must be a 0 link_offset
if int(link_offset) <= offset:
curr_link = link
else:
danswer_quotes.append(
DanswerQuote(
quote=quote,
document_id=chunk.document_id,
link=curr_link,
source_type=chunk.source_type,
semantic_identifier=chunk.semantic_identifier,
blurb=chunk.blurb,
)
)
break
# If the offset is larger than the start of the last quote, it must be the last one
danswer_quotes.append(
DanswerQuote(
quote=quote,
@@ -174,24 +142,24 @@
)
break
return danswer_quotes
return DanswerQuotes(quotes=danswer_quotes)
def process_answer(
answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> tuple[DanswerAnswer, DanswerQuotes]:
answer, quote_strings = separate_answer_quotes(answer_raw)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")
else:
logger.debug("No answer extracted from raw output")
return DanswerAnswer(answer=None), []
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
logger.info(f"Answer: {answer}")
if not quote_strings:
logger.debug("No quotes extracted from raw output")
return DanswerAnswer(answer=answer), []
return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
logger.info(f"All quotes (including unmatched): {quote_strings}")
quotes = match_quotes_to_docs(quote_strings, chunks)
logger.info(f"Final quotes: {quotes}")
@@ -212,7 +180,7 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
def extract_quotes_from_completed_token_stream(
model_output: str, context_chunks: list[InferenceChunk]
) -> list[DanswerQuote]:
) -> DanswerQuotes:
logger.debug(model_output)
answer, quotes = process_answer(model_output, context_chunks)
if answer:
@@ -227,7 +195,7 @@ def process_model_tokens(
tokens: Generator[str, None, None],
context_docs: list[InferenceChunk],
is_json_prompt: bool = True,
) -> Generator[dict[str, Any], None, None]:
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
"""Yields Answer tokens back out in a dict for streaming to frontend
When Answer section ends, yields dict with answer_finished key
Collects all the tokens at the end to form the complete model output"""
@@ -255,21 +223,20 @@
if found_answer_start and not found_answer_end:
if is_json_prompt and stream_json_answer_end(model_previous, token):
found_answer_end = True
yield {"answer_finished": True}
yield DanswerAnswerPiece(answer_piece=None)
continue
elif not is_json_prompt:
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
found_answer_end = True
yield {"answer_finished": True}
yield DanswerAnswerPiece(answer_piece=None)
continue
if hold_quote + token in quote_pat_full:
hold_quote += token
continue
yield {"answer_data": hold_quote + token}
yield DanswerAnswerPiece(answer_piece=token)
hold_quote = ""
quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
yield structure_quotes_for_response(quotes)
yield extract_quotes_from_completed_token_stream(model_output, context_docs)
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:

View File

@@ -13,8 +13,8 @@ from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
@@ -236,7 +236,7 @@ class RequestCompletionQA(QAModel):
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
) -> AnswerQuestionReturn:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)
@@ -253,7 +253,7 @@
self,
query: str,
context_docs: list[InferenceChunk],
) -> Generator[dict[str, Any] | None, None, None]:
) -> AnswerQuestionStreamReturn:
model_api_response = self._get_request_response(
query, context_docs, stream=False
)

View File

@@ -13,6 +13,7 @@ from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.constants import DocumentSource
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchDoc
@@ -59,24 +60,22 @@ def _build_custom_semantic_identifier(
return semantic_identifier
def _process_quotes(
quotes: dict[str, dict[str, str | None]] | None
) -> tuple[str | None, list[str]]:
def _process_quotes(quotes: list[DanswerQuote] | None) -> tuple[str | None, list[str]]:
if not quotes:
return None, []
quote_lines: list[str] = []
doc_identifiers: list[str] = []
for quote_dict in quotes.values():
doc_id = str(quote_dict.get("document_id", ""))
doc_link = quote_dict.get("link")
doc_name = str(quote_dict.get("semantic_identifier", ""))
for quote in quotes:
doc_id = quote.document_id
doc_link = quote.link
doc_name = quote.semantic_identifier
if doc_link and doc_name and doc_id and doc_id not in doc_identifiers:
doc_identifiers.append(doc_id)
custom_semantic_identifier = _build_custom_semantic_identifier(
semantic_identifier=doc_name,
blurb=str(quote_dict.get("blurb", "")),
source=str(quote_dict.get("source_type", "")),
blurb=quote.blurb,
source=quote.source_type,
)
quote_lines.append(f"- <{doc_link}|{custom_semantic_identifier}>")

View File

@@ -33,7 +33,7 @@ from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import check_typesense_collection_exist
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.llm_utils import get_default_llm
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
from danswer.server.manage import router as admin_router
@@ -179,7 +179,7 @@ def get_application() -> FastAPI:
logger.info("Warming up local NLP models.")
warm_up_models()
qa_model = get_default_backend_qa_model()
qa_model = get_default_llm()
qa_model.warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")

View File

@@ -9,7 +9,6 @@ from fastapi import Request
from fastapi import Response
from fastapi import UploadFile
from fastapi_users.db import SQLAlchemyUserDatabase
from sqlalchemy.exc import IntegrityError
from sqlalchemy.ext.asyncio import AsyncSession
from sqlalchemy.orm import Session
@@ -53,14 +52,10 @@ from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.models import DeletionAttempt
from danswer.db.models import DeletionStatus
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.db.models import User
from danswer.direct_qa import check_model_api_key_is_valid
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_llm
from danswer.direct_qa.open_ai import get_gen_ai_api_key
from danswer.direct_qa.open_ai import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
@@ -361,7 +356,7 @@
) -> None:
# OpenAI key is only used for generative QA, so no need to validate this
# if it's turned off or if a non-OpenAI model is being used
if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key:
return
# Only validate every so often

View File

@@ -19,6 +19,7 @@ from danswer.db.models import DeletionAttempt
from danswer.db.models import DeletionStatus
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.search.models import QueryFlow
from danswer.search.models import SearchType
from danswer.server.utils import mask_credential_dict
@@ -110,7 +111,7 @@ class SearchResponse(BaseModel):
class QAResponse(SearchResponse):
answer: str | None # DanswerAnswer
quotes: dict[str, dict[str, str | None]] | None # restructured DanswerQuote
quotes: list[DanswerQuote] | None
predicted_flow: QueryFlow
predicted_search: SearchType
error_msg: str | None = None
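This is the externally visible contract change for API consumers: the quotes field of a QAResponse goes from an object keyed by quote text to a plain array of quote objects. A hypothetical serialized fragment under the new model (other SearchResponse fields omitted, values invented):

qa_response_fragment = {
    "answer": "Danswer is ...",
    "quotes": [
        {
            "quote": "some quoted sentence",
            "document_id": "doc-1",
            "link": "https://example.com/doc-1",
            "source_type": "web",
            "semantic_identifier": "Example Doc",
            "blurb": "A short preview of the document...",
        }
    ],
}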

View File

@@ -1,5 +1,6 @@
import json
from collections.abc import Generator
from dataclasses import asdict
from fastapi import APIRouter
from fastapi import Depends
@@ -12,10 +13,10 @@ from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm_utils import get_default_llm
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
@@ -166,7 +167,7 @@ def stream_direct_qa(
return
try:
qa_model = get_default_backend_qa_model()
qa_model = get_default_llm()
except (UnknownModelError, OpenAIKeyMissing) as e:
logger.exception("Unable to get QA model")
yield get_json_line({"error": str(e)})
@@ -178,16 +179,16 @@
"Chunks offset too large, should not retry this many times"
)
try:
for response_dict in qa_model.answer_question_stream(
for response_packet in qa_model.answer_question_stream(
query,
ranked_chunks[
chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
],
):
if response_dict is None:
if response_packet is None:
continue
logger.debug(f"Sending packet: {response_dict}")
yield get_json_line(response_dict)
logger.debug(f"Sending packet: {response_packet}")
yield get_json_line(asdict(response_packet))
except Exception as e:
# exception is logged in the answer_question method, no need to re-log
yield get_json_line({"error": str(e)})

View File

@@ -56,7 +56,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
const dedupedQuotes: Quote[] = [];
const seen = new Set<string>();
if (quotes) {
Object.values(quotes).forEach((quote) => {
quotes.forEach((quote) => {
if (!seen.has(quote.document_id)) {
dedupedQuotes.push(quote);
seen.add(quote.document_id);
@@ -109,7 +109,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
<a
key={quoteInfo.document_id}
className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
href={quoteInfo.link}
href={quoteInfo.link || undefined}
target="_blank"
rel="noopener noreferrer"
>

View File

@@ -71,7 +71,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
...(prevState || initialSearchResponse),
answer,
}));
const updateQuotes = (quotes: Record<string, Quote>) =>
const updateQuotes = (quotes: Quote[]) =>
setSearchResponse((prevState) => ({
...(prevState || initialSearchResponse),
quotes,

View File

@@ -13,11 +13,12 @@ export const SearchType = {
export type SearchType = (typeof SearchType)[keyof typeof SearchType];
export interface Quote {
quote: string;
document_id: string;
link: string;
link: string | null;
source_type: ValidSources;
blurb: string;
semantic_identifier: string | null;
semantic_identifier: string;
}
export interface DanswerDocument {
@@ -32,7 +33,7 @@ export interface SearchResponse {
suggestedSearchType: SearchType | null;
suggestedFlowType: FlowType | null;
answer: string | null;
quotes: Record<string, Quote> | null;
quotes: Quote[] | null;
documents: DanswerDocument[] | null;
error: string | null;
}
@@ -51,7 +52,7 @@ export interface SearchRequestArgs {
query: string;
sources: Source[];
updateCurrentAnswer: (val: string) => void;
updateQuotes: (quotes: Record<string, Quote>) => void;
updateQuotes: (quotes: Quote[]) => void;
updateDocs: (documents: DanswerDocument[]) => void;
updateSuggestedSearchType: (searchType: SearchType) => void;
updateSuggestedFlowType: (flowType: FlowType) => void;

View File

@@ -24,7 +24,7 @@ export const searchRequest = async ({
}
let answer = "";
let quotes: Record<string, Quote> | null = null;
let quotes: Quote[] | null = null;
let relevantDocuments: DanswerDocument[] | null = null;
try {
const response = await fetch("/api/direct-qa", {
@@ -54,7 +54,7 @@
const data = (await response.json()) as {
answer: string;
quotes: Record<string, Quote>;
quotes: Quote[];
top_ranked_docs: DanswerDocument[];
lower_ranked_docs: DanswerDocument[];
predicted_flow: FlowType;

View File

@@ -69,7 +69,7 @@ export const searchRequestStreamed = async ({
}
let answer = "";
let quotes: Record<string, Quote> | null = null;
let quotes: Quote[] | null = null;
let relevantDocuments: DanswerDocument[] | null = null;
try {
const response = await fetch("/api/stream-direct-qa", {
@@ -118,18 +118,17 @@
previousPartialChunk = partialChunk;
completedChunks.forEach((chunk) => {
// TODO: clean up response / this logic
const answerChunk = chunk.answer_data;
const answerChunk = chunk.answer_piece;
if (answerChunk) {
answer += answerChunk;
updateCurrentAnswer(answer);
return;
}
const answerFinished = chunk.answer_finished;
if (answerFinished) {
if (answerChunk === null) {
// set quotes as non-null to signify that the answer is finished and
// we're now looking for quotes
updateQuotes({});
updateQuotes([]);
if (
answer &&
!answer.endsWith(".") &&
@@ -168,9 +167,15 @@
return;
}
// if it doesn't match any of the above, assume it is a quote
quotes = chunk as Record<string, Quote>;
updateQuotes(quotes);
// Check for quote section
if (chunk.quotes) {
quotes = chunk.quotes as Quote[];
updateQuotes(quotes);
return;
}
// should never reach this
console.log("Unknown chunk:", chunk);
});
}
} catch (err) {