Stop using untyped dicts to represent quotes (#310)

Author: Chris Weaver
Date: 2023-08-17 14:53:55 -07:00
Committed by: GitHub
Parent: 81a4934bb8
Commit: f37ac76d3c
21 changed files with 247 additions and 239 deletions
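
In short: quote payloads used to flow through the backend and over the wire as `dict[str, dict[str, str | None]]`, keyed by the quote text; after this commit they are typed dataclasses (`DanswerQuote`, `DanswerQuotes`) defined in danswer/direct_qa/interfaces.py. A minimal before/after sketch, using standalone stand-ins for the real classes and illustrative values:

from dataclasses import asdict
from dataclasses import dataclass


@dataclass
class DanswerQuote:
    # At inference time every field is already a plain string
    quote: str
    document_id: str
    link: str | None
    source_type: str
    semantic_identifier: str
    blurb: str


@dataclass
class DanswerQuotes:
    # Wrapper dataclass so streamed packets serialize uniformly via asdict()
    quotes: list[DanswerQuote]


# Before: untyped dict-of-dicts keyed by the quote text
old_payload = {"Danswer is open source.": {"document_id": "doc-1", "link": None}}

# After: a typed list, converted to a JSON-ready dict only at the API boundary
new_payload = DanswerQuotes(
    quotes=[
        DanswerQuote(
            quote="Danswer is open source.",
            document_id="doc-1",
            link=None,
            source_type="web",
            semantic_identifier="README",
            blurb="Danswer is an open source question-answering tool...",
        )
    ]
)
print(asdict(new_payload))  # {'quotes': [{'quote': 'Danswer is open source.', ...}]}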

View File

@@ -15,6 +15,7 @@ METADATA = "metadata"
 GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
 HTML_SEPARATOR = "\n"
 PUBLIC_DOC_PAT = "PUBLIC"
+QUOTE = "quote"


 class DocumentSource(str, Enum):

View File

@@ -1,111 +0,0 @@
from typing import Any

import pkg_resources
from openai.error import AuthenticationError

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger

logger = setup_logger()


def check_model_api_key_is_valid(model_api_key: str) -> bool:
    if not model_api_key:
        return False

    qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)

    # try for up to 2 timeouts (e.g. 10 seconds in total)
    for _ in range(2):
        try:
            qa_model.answer_question("Do not respond", [])
            return True
        except AuthenticationError:
            return False
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")

    return False


def get_default_backend_qa_model(
    internal_model: str = INTERNAL_MODEL_VERSION,
    endpoint: str | None = GEN_AI_ENDPOINT,
    model_host_type: str | None = GEN_AI_HOST_TYPE,
    api_key: str | None = GEN_AI_API_KEY,
    timeout: int = QA_TIMEOUT,
    **kwargs: Any,
) -> QAModel:
    if not api_key:
        try:
            api_key = get_gen_ai_api_key()
        except ConfigNotFoundError:
            pass

    if internal_model in [
        DanswerGenAIModel.GPT4ALL.value,
        DanswerGenAIModel.GPT4ALL_CHAT.value,
    ]:
        # gpt4all is not compatible M1 Mac hardware as of Aug 2023
        pkg_resources.get_distribution("gpt4all")

    if internal_model == DanswerGenAIModel.OPENAI.value:
        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
        return GPT4AllCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
        return GPT4AllChatCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
        return TransformerQA()
    elif internal_model == DanswerGenAIModel.REQUEST.value:
        if endpoint is None or model_host_type is None:
            raise ValueError(
                "Request based GenAI model requires an endpoint and host type"
            )
        if (
            model_host_type == ModelHostType.HUGGINGFACE.value
            or model_host_type == ModelHostType.COLAB_DEMO.value
        ):
            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
            # This is the sum of input and output tokens, so cannot take in full Danswer context
            return RequestCompletionQA(
                endpoint=endpoint,
                model_host_type=model_host_type,
                api_key=api_key,
                prompt_processor=WeakModelFreeformProcessor(),
                timeout=timeout,
            )
        return RequestCompletionQA(
            endpoint=endpoint,
            model_host_type=model_host_type,
            api_key=api_key,
            timeout=timeout,
        )
    else:
        raise UnknownModelError(internal_model)

View File

@@ -5,10 +5,9 @@ from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
-from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.qa_utils import structure_quotes_for_response
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -75,7 +74,7 @@ def answer_question(
     )
     try:
-        qa_model = get_default_backend_qa_model(timeout=answer_generation_timeout)
+        qa_model = get_default_llm(timeout=answer_generation_timeout)
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -104,7 +103,7 @@ def answer_question(
     return QAResponse(
         answer=answer.answer if answer else None,
-        quotes=structure_quotes_for_response(quotes),
+        quotes=quotes.quotes if quotes else None,
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
         predicted_flow=predicted_flow,

View File

@@ -4,8 +4,12 @@ from typing import Any
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
@@ -85,7 +89,7 @@ class GPT4AllCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -101,12 +105,12 @@ class GPT4AllCompletionQA(QAModel):
         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        answer, quotes = process_answer(model_output, context_docs)
+        return answer, quotes

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -150,7 +154,7 @@ class GPT4AllChatCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -177,7 +181,7 @@ class GPT4AllChatCompletionQA(QAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )

View File

@@ -7,8 +7,12 @@ from huggingface_hub.utils import HfHubHTTPError # type:ignore
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import FreeformProcessor
@@ -51,7 +55,7 @@ class HuggingFaceCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -68,7 +72,7 @@ class HuggingFaceCompletionQA(QAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         filled_prompt = self.prompt_processor.fill_prompt(
             query, context_docs, self.include_metadata
         )
@@ -165,7 +169,7 @@ class HuggingFaceChatCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         model_output = self._get_hf_model_output(query, context_docs)

         answer, quotes_dict = process_answer(model_output, context_docs)
@@ -174,7 +178,7 @@ class HuggingFaceChatCompletionQA(QAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         """As of Aug 2023, HF conversational (chat) endpoints do not support streaming
         So here it is faked by streaming characters within Danswer from the model output
         """

View File

@@ -1,7 +1,6 @@
 import abc
 from collections.abc import Generator
 from dataclasses import dataclass
-from typing import Any

 from danswer.chunking.models import InferenceChunk
@@ -11,6 +10,13 @@ class DanswerAnswer:
     answer: str | None


+@dataclass
+class DanswerAnswerPiece:
+    """A small piece of a complete answer. Used for streaming back answers."""
+
+    answer_piece: str | None  # if None, specifies the end of an Answer
+
 
 @dataclass
 class DanswerQuote:
     # This is during inference so everything is a string by this point
@@ -22,6 +28,21 @@ class DanswerQuote:
     blurb: str


+@dataclass
+class DanswerQuotes:
+    """A little clunky, but making this into a separate class so that the result from
+    `answer_question_stream` is always a subclass of `dataclass` and can thus use `asdict()`
+    """
+
+    quotes: list[DanswerQuote]
+
+
+AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
+AnswerQuestionStreamReturn = Generator[
+    DanswerAnswerPiece | DanswerQuotes | None, None, None
+]
+
 
 class QAModel:
     @property
     def requires_api_key(self) -> bool:
@@ -39,7 +60,7 @@ class QAModel:
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         raise NotImplementedError

     @abc.abstractmethod
@@ -47,5 +68,5 @@ class QAModel:
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         raise NotImplementedError
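
With these interfaces in place, streaming consumers can branch on the packet type instead of probing dict keys. A minimal consumer sketch (the `consume` helper is hypothetical; the imports are the module defined above):

from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes


def consume(stream: AnswerQuestionStreamReturn) -> None:
    for packet in stream:
        if packet is None:
            continue
        if isinstance(packet, DanswerAnswerPiece):
            if packet.answer_piece is None:
                print()  # a None piece marks the end of the answer section
            else:
                print(packet.answer_piece, end="")
        elif isinstance(packet, DanswerQuotes):
            print(f"Received {len(packet.quotes)} quotes")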

View File

@@ -0,0 +1,111 @@
from typing import Any

import pkg_resources
from openai.error import AuthenticationError

from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import DanswerGenAIModel
from danswer.configs.constants import ModelHostType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.configs.model_configs import GEN_AI_ENDPOINT
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.local_transformers import TransformerQA
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger

logger = setup_logger()


def check_model_api_key_is_valid(model_api_key: str) -> bool:
    if not model_api_key:
        return False

    qa_model = get_default_llm(api_key=model_api_key, timeout=5)

    # try for up to 2 timeouts (e.g. 10 seconds in total)
    for _ in range(2):
        try:
            qa_model.answer_question("Do not respond", [])
            return True
        except AuthenticationError:
            return False
        except Exception as e:
            logger.warning(f"GenAI API key failed for the following reason: {e}")

    return False


def get_default_llm(
    internal_model: str = INTERNAL_MODEL_VERSION,
    endpoint: str | None = GEN_AI_ENDPOINT,
    model_host_type: str | None = GEN_AI_HOST_TYPE,
    api_key: str | None = GEN_AI_API_KEY,
    timeout: int = QA_TIMEOUT,
    **kwargs: Any,
) -> QAModel:
    if not api_key:
        try:
            api_key = get_gen_ai_api_key()
        except ConfigNotFoundError:
            pass

    if internal_model in [
        DanswerGenAIModel.GPT4ALL.value,
        DanswerGenAIModel.GPT4ALL_CHAT.value,
    ]:
        # gpt4all is not compatible M1 Mac hardware as of Aug 2023
        pkg_resources.get_distribution("gpt4all")

    if internal_model == DanswerGenAIModel.OPENAI.value:
        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL.value:
        return GPT4AllCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
        return GPT4AllChatCompletionQA(**kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
        return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
        return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
    elif internal_model == DanswerGenAIModel.TRANSFORMERS:
        return TransformerQA()
    elif internal_model == DanswerGenAIModel.REQUEST.value:
        if endpoint is None or model_host_type is None:
            raise ValueError(
                "Request based GenAI model requires an endpoint and host type"
            )
        if (
            model_host_type == ModelHostType.HUGGINGFACE.value
            or model_host_type == ModelHostType.COLAB_DEMO.value
        ):
            # Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
            # With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
            # This is the sum of input and output tokens, so cannot take in full Danswer context
            return RequestCompletionQA(
                endpoint=endpoint,
                model_host_type=model_host_type,
                api_key=api_key,
                prompt_processor=WeakModelFreeformProcessor(),
                timeout=timeout,
            )
        return RequestCompletionQA(
            endpoint=endpoint,
            model_host_type=model_host_type,
            api_key=api_key,
            timeout=timeout,
        )
    else:
        raise UnknownModelError(internal_model)
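
Call sites pick up the renamed factory with unchanged behavior; a minimal usage sketch (the query text and timeout are illustrative):

from danswer.direct_qa.llm_utils import get_default_llm

qa_model = get_default_llm(timeout=10)
# answer_question returns AnswerQuestionReturn, i.e. (DanswerAnswer, DanswerQuotes)
answer, quotes = qa_model.answer_question("What is Danswer?", [])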

View File

@@ -7,10 +7,13 @@ from transformers import QuestionAnsweringPipeline # type:ignore
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.qa_utils import structure_quotes_for_response
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time
@@ -104,7 +107,7 @@ class TransformerQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         danswer_quotes: list[DanswerQuote] = []
         d_answers: list[str] = []
         for chunk in context_docs:
@@ -118,11 +121,13 @@ class TransformerQA(QAModel):
             for ind, answer in enumerate(d_answers, start=1)
         ]
         combined_answer = "\n".join(answers_list)
-        return DanswerAnswer(answer=combined_answer), danswer_quotes
+        return DanswerAnswer(answer=combined_answer), DanswerQuotes(
+            quotes=danswer_quotes
+        )

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         quotes: list[DanswerQuote] = []
         answers: list[str] = []
         for chunk in context_docs:
@@ -135,13 +140,14 @@ class TransformerQA(QAModel):
         answer_count = 1
         for answer in answers:
             if answer_count == 1:
-                yield {"answer_data": "Source 1: "}
+                yield DanswerAnswerPiece(answer_piece="Source 1: ")
             else:
-                yield {"answer_data": f"\nSource {answer_count}: "}
+                yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
             answer_count += 1
             for char in answer.strip():
-                yield {"answer_data": char}
+                yield DanswerAnswerPiece(answer_piece=char)

-        yield {"answer_finished": True}  # signal end of answer
+        yield DanswerAnswerPiece(answer_piece=None)

-        yield structure_quotes_for_response(quotes)
+        yield DanswerQuotes(quotes=quotes)

View File

@@ -22,8 +22,12 @@ from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import ChatPromptProcessor
 from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
@@ -147,7 +151,7 @@ class OpenAICompletionQA(OpenAIQAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         filled_prompt = self.prompt_processor.fill_prompt(
@@ -177,7 +181,7 @@ class OpenAICompletionQA(OpenAIQAModel):
     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         filled_prompt = self.prompt_processor.fill_prompt(
@@ -243,7 +247,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         messages = self.prompt_processor.fill_prompt(
@@ -276,12 +280,12 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
         logger.debug(model_output)

-        answer, quotes_dict = process_answer(model_output, context_docs)
-        return answer, quotes_dict
+        answer, quotes = process_answer(model_output, context_docs)
+        return answer, quotes

     def answer_question_stream(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)

         messages = self.prompt_processor.fill_prompt(

View File

@@ -2,7 +2,6 @@ import json
 import math
 import re
 from collections.abc import Generator
-from typing import Any
 from typing import cast
 from typing import Optional
 from typing import Tuple
@@ -11,15 +10,12 @@ import regex
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
-from danswer.configs.constants import BLURB
-from danswer.configs.constants import DOCUMENT_ID
 from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
-from danswer.configs.constants import SEMANTIC_IDENTIFIER
-from danswer.configs.constants import SOURCE_LINK
-from danswer.configs.constants import SOURCE_TYPE
 from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.direct_qa.interfaces import DanswerAnswer
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
@@ -37,24 +33,6 @@ def get_gen_ai_api_key() -> str:
     )


-def structure_quotes_for_response(
-    quotes: list[DanswerQuote] | None,
-) -> dict[str, dict[str, str | None]]:
-    if quotes is None:
-        return {}
-
-    response_quotes = {}
-    for quote in quotes:
-        response_quotes[quote.quote] = {
-            DOCUMENT_ID: quote.document_id,
-            SOURCE_LINK: quote.link,
-            SOURCE_TYPE: quote.source_type,
-            SEMANTIC_IDENTIFIER: quote.semantic_identifier,
-            BLURB: quote.blurb,
-        }
-    return response_quotes
-
-
 def extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -114,8 +92,8 @@ def match_quotes_to_docs(
     max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
     fuzzy_search: bool = False,
     prefix_only_length: int = 100,
-) -> list[DanswerQuote]:
-    danswer_quotes = []
+) -> DanswerQuotes:
+    danswer_quotes: list[DanswerQuote] = []
     for quote in quotes:
         max_edits = math.ceil(float(len(quote)) * max_error_percent)
@@ -145,23 +123,13 @@ def match_quotes_to_docs(
         # Extracting the link from the offset
         curr_link = None
         for link_offset, link in chunk.source_links.items():
-            # Should always find one because offset is at least 0 and there must be a 0 link_offset
+            # Should always find one because offset is at least 0 and there
+            # must be a 0 link_offset
             if int(link_offset) <= offset:
                 curr_link = link
             else:
-                danswer_quotes.append(
-                    DanswerQuote(
-                        quote=quote,
-                        document_id=chunk.document_id,
-                        link=curr_link,
-                        source_type=chunk.source_type,
-                        semantic_identifier=chunk.semantic_identifier,
-                        blurb=chunk.blurb,
-                    )
-                )
                 break

-        # If the offset is larger than the start of the last quote, it must be the last one
         danswer_quotes.append(
             DanswerQuote(
                 quote=quote,
@@ -174,24 +142,24 @@ def match_quotes_to_docs(
             )
         )
         break

-    return danswer_quotes
+    return DanswerQuotes(quotes=danswer_quotes)


 def process_answer(
     answer_raw: str, chunks: list[InferenceChunk]
-) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+) -> tuple[DanswerAnswer, DanswerQuotes]:
     answer, quote_strings = separate_answer_quotes(answer_raw)
     if answer == UNCERTAINTY_PAT or not answer:
         if answer == UNCERTAINTY_PAT:
             logger.debug("Answer matched UNCERTAINTY_PAT")
         else:
             logger.debug("No answer extracted from raw output")
-        return DanswerAnswer(answer=None), []
+        return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])

     logger.info(f"Answer: {answer}")
     if not quote_strings:
         logger.debug("No quotes extracted from raw output")
-        return DanswerAnswer(answer=answer), []
+        return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
     logger.info(f"All quotes (including unmatched): {quote_strings}")
     quotes = match_quotes_to_docs(quote_strings, chunks)
     logger.info(f"Final quotes: {quotes}")
@@ -212,7 +180,7 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
 def extract_quotes_from_completed_token_stream(
     model_output: str, context_chunks: list[InferenceChunk]
-) -> list[DanswerQuote]:
+) -> DanswerQuotes:
     logger.debug(model_output)
     answer, quotes = process_answer(model_output, context_chunks)
     if answer:
@@ -227,7 +195,7 @@ def process_model_tokens(
     tokens: Generator[str, None, None],
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
-) -> Generator[dict[str, Any], None, None]:
+) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
     """Yields Answer tokens back out in a dict for streaming to frontend
     When Answer section ends, yields dict with answer_finished key
     Collects all the tokens at the end to form the complete model output"""
@@ -255,21 +223,20 @@ def process_model_tokens(
         if found_answer_start and not found_answer_end:
             if is_json_prompt and stream_json_answer_end(model_previous, token):
                 found_answer_end = True
-                yield {"answer_finished": True}
+                yield DanswerAnswerPiece(answer_piece=None)
                 continue
             elif not is_json_prompt:
                 if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
                     found_answer_end = True
-                    yield {"answer_finished": True}
+                    yield DanswerAnswerPiece(answer_piece=None)
                     continue
                 if hold_quote + token in quote_pat_full:
                     hold_quote += token
                     continue
-            yield {"answer_data": hold_quote + token}
+            yield DanswerAnswerPiece(answer_piece=token)
             hold_quote = ""

-    quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
-    yield structure_quotes_for_response(quotes)
+    yield extract_quotes_from_completed_token_stream(model_output, context_docs)


 def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:

View File

@@ -13,8 +13,8 @@ from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.configs.model_configs import GEN_AI_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_HOST_TYPE
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
-from danswer.direct_qa.interfaces import DanswerAnswer
-from danswer.direct_qa.interfaces import DanswerQuote
+from danswer.direct_qa.interfaces import AnswerQuestionReturn
+from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import JsonProcessor
 from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
@@ -236,7 +236,7 @@ class RequestCompletionQA(QAModel):
     @log_function_time()
     def answer_question(
         self, query: str, context_docs: list[InferenceChunk]
-    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
+    ) -> AnswerQuestionReturn:
         model_api_response = self._get_request_response(
             query, context_docs, stream=False
         )
@@ -253,7 +253,7 @@ class RequestCompletionQA(QAModel):
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> Generator[dict[str, Any] | None, None, None]:
+    ) -> AnswerQuestionStreamReturn:
         model_api_response = self._get_request_response(
             query, context_docs, stream=False
         )

View File

@@ -13,6 +13,7 @@ from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
 from danswer.server.models import SearchDoc
@@ -59,24 +60,22 @@ def _build_custom_semantic_identifier(
     return semantic_identifier


-def _process_quotes(
-    quotes: dict[str, dict[str, str | None]] | None
-) -> tuple[str | None, list[str]]:
+def _process_quotes(quotes: list[DanswerQuote] | None) -> tuple[str | None, list[str]]:
     if not quotes:
         return None, []

     quote_lines: list[str] = []
     doc_identifiers: list[str] = []
-    for quote_dict in quotes.values():
-        doc_id = str(quote_dict.get("document_id", ""))
-        doc_link = quote_dict.get("link")
-        doc_name = str(quote_dict.get("semantic_identifier", ""))
+    for quote in quotes:
+        doc_id = quote.document_id
+        doc_link = quote.link
+        doc_name = quote.semantic_identifier
         if doc_link and doc_name and doc_id and doc_id not in doc_identifiers:
             doc_identifiers.append(doc_id)
             custom_semantic_identifier = _build_custom_semantic_identifier(
                 semantic_identifier=doc_name,
-                blurb=str(quote_dict.get("blurb", "")),
-                source=str(quote_dict.get("source_type", "")),
+                blurb=quote.blurb,
+                source=quote.source_type,
             )
             quote_lines.append(f"- <{doc_link}|{custom_semantic_identifier}>")

View File

@@ -33,7 +33,7 @@ from danswer.datastores.qdrant.indexing import list_qdrant_collections
 from danswer.datastores.typesense.store import check_typesense_collection_exist
 from danswer.datastores.typesense.store import create_typesense_collection
 from danswer.db.credentials import create_initial_public_credential
-from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.server.event_loading import router as event_processing_router
 from danswer.server.health import router as health_router
 from danswer.server.manage import router as admin_router
@@ -179,7 +179,7 @@ def get_application() -> FastAPI:
logger.info("Warming up local NLP models.") logger.info("Warming up local NLP models.")
warm_up_models() warm_up_models()
qa_model = get_default_backend_qa_model() qa_model = get_default_llm()
qa_model.warm_up_model() qa_model.warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded") logger.info("Verifying query preprocessing (NLTK) data is downloaded")

View File

@@ -9,7 +9,6 @@ from fastapi import Request
 from fastapi import Response
 from fastapi import UploadFile
 from fastapi_users.db import SQLAlchemyUserDatabase
-from sqlalchemy.exc import IntegrityError
 from sqlalchemy.ext.asyncio import AsyncSession
 from sqlalchemy.orm import Session
@@ -53,14 +52,10 @@ from danswer.db.engine import get_sqlalchemy_async_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_latest_index_attempts
 from danswer.db.models import DeletionAttempt
-from danswer.db.models import DeletionStatus
-from danswer.db.models import IndexAttempt
-from danswer.db.models import IndexingStatus
 from danswer.db.models import User
-from danswer.direct_qa import check_model_api_key_is_valid
-from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.direct_qa.open_ai import get_gen_ai_api_key
-from danswer.direct_qa.open_ai import OpenAIQAModel
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey
@@ -361,7 +356,7 @@ def validate_existing_genai_api_key(
 ) -> None:
     # OpenAI key is only used for generative QA, so no need to validate this
     # if it's turned off or if a non-OpenAI model is being used
-    if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
+    if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key:
         return

     # Only validate every so often

View File

@@ -19,6 +19,7 @@ from danswer.db.models import DeletionAttempt
 from danswer.db.models import DeletionStatus
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
+from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.search.models import QueryFlow
 from danswer.search.models import SearchType
 from danswer.server.utils import mask_credential_dict
@@ -110,7 +111,7 @@ class SearchResponse(BaseModel):
 class QAResponse(SearchResponse):
     answer: str | None  # DanswerAnswer
-    quotes: dict[str, dict[str, str | None]] | None  # restructured DanswerQuote
+    quotes: list[DanswerQuote] | None
     predicted_flow: QueryFlow
     predicted_search: SearchType
     error_msg: str | None = None

View File

@@ -1,5 +1,6 @@
 import json
 from collections.abc import Generator
+from dataclasses import asdict

 from fastapi import APIRouter
 from fastapi import Depends
@@ -12,10 +13,10 @@ from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
-from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.answer_question import answer_question
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.llm_utils import get_default_llm
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -166,7 +167,7 @@ def stream_direct_qa(
             return

         try:
-            qa_model = get_default_backend_qa_model()
+            qa_model = get_default_llm()
         except (UnknownModelError, OpenAIKeyMissing) as e:
             logger.exception("Unable to get QA model")
             yield get_json_line({"error": str(e)})
@@ -178,16 +179,16 @@ def stream_direct_qa(
"Chunks offset too large, should not retry this many times" "Chunks offset too large, should not retry this many times"
) )
try: try:
for response_dict in qa_model.answer_question_stream( for response_packet in qa_model.answer_question_stream(
query, query,
ranked_chunks[ ranked_chunks[
chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
], ],
): ):
if response_dict is None: if response_packet is None:
continue continue
logger.debug(f"Sending packet: {response_dict}") logger.debug(f"Sending packet: {response_packet}")
yield get_json_line(response_dict) yield get_json_line(asdict(response_packet))
except Exception as e: except Exception as e:
# exception is logged in the answer_question method, no need to re-log # exception is logged in the answer_question method, no need to re-log
yield get_json_line({"error": str(e)}) yield get_json_line({"error": str(e)})

View File

@@ -56,7 +56,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
   const dedupedQuotes: Quote[] = [];
   const seen = new Set<string>();
   if (quotes) {
-    Object.values(quotes).forEach((quote) => {
+    quotes.forEach((quote) => {
       if (!seen.has(quote.document_id)) {
         dedupedQuotes.push(quote);
         seen.add(quote.document_id);
@@ -109,7 +109,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
             <a
               key={quoteInfo.document_id}
               className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
-              href={quoteInfo.link}
+              href={quoteInfo.link || undefined}
               target="_blank"
               rel="noopener noreferrer"
             >

View File

@@ -71,7 +71,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
       ...(prevState || initialSearchResponse),
       answer,
     }));
-  const updateQuotes = (quotes: Record<string, Quote>) =>
+  const updateQuotes = (quotes: Quote[]) =>
     setSearchResponse((prevState) => ({
       ...(prevState || initialSearchResponse),
       quotes,

View File

@@ -13,11 +13,12 @@ export const SearchType = {
 export type SearchType = (typeof SearchType)[keyof typeof SearchType];

 export interface Quote {
+  quote: string;
   document_id: string;
-  link: string;
+  link: string | null;
   source_type: ValidSources;
   blurb: string;
-  semantic_identifier: string | null;
+  semantic_identifier: string;
 }

 export interface DanswerDocument {
@@ -32,7 +33,7 @@ export interface SearchResponse {
   suggestedSearchType: SearchType | null;
   suggestedFlowType: FlowType | null;
   answer: string | null;
-  quotes: Record<string, Quote> | null;
+  quotes: Quote[] | null;
   documents: DanswerDocument[] | null;
   error: string | null;
 }
@@ -51,7 +52,7 @@ export interface SearchRequestArgs {
   query: string;
   sources: Source[];
   updateCurrentAnswer: (val: string) => void;
-  updateQuotes: (quotes: Record<string, Quote>) => void;
+  updateQuotes: (quotes: Quote[]) => void;
   updateDocs: (documents: DanswerDocument[]) => void;
   updateSuggestedSearchType: (searchType: SearchType) => void;
   updateSuggestedFlowType: (flowType: FlowType) => void;

View File

@@ -24,7 +24,7 @@ export const searchRequest = async ({
   }

   let answer = "";
-  let quotes: Record<string, Quote> | null = null;
+  let quotes: Quote[] | null = null;
   let relevantDocuments: DanswerDocument[] | null = null;
   try {
     const response = await fetch("/api/direct-qa", {
@@ -54,7 +54,7 @@ export const searchRequest = async ({
     const data = (await response.json()) as {
       answer: string;
-      quotes: Record<string, Quote>;
+      quotes: Quote[];
       top_ranked_docs: DanswerDocument[];
       lower_ranked_docs: DanswerDocument[];
       predicted_flow: FlowType;

View File

@@ -69,7 +69,7 @@ export const searchRequestStreamed = async ({
   }

   let answer = "";
-  let quotes: Record<string, Quote> | null = null;
+  let quotes: Quote[] | null = null;
   let relevantDocuments: DanswerDocument[] | null = null;
   try {
     const response = await fetch("/api/stream-direct-qa", {
@@ -118,18 +118,17 @@ export const searchRequestStreamed = async ({
         previousPartialChunk = partialChunk;
         completedChunks.forEach((chunk) => {
           // TODO: clean up response / this logic
-          const answerChunk = chunk.answer_data;
+          const answerChunk = chunk.answer_piece;
           if (answerChunk) {
             answer += answerChunk;
             updateCurrentAnswer(answer);
             return;
           }

-          const answerFinished = chunk.answer_finished;
-          if (answerFinished) {
+          if (answerChunk === null) {
             // set quotes as non-null to signify that the answer is finished and
             // we're now looking for quotes
-            updateQuotes({});
+            updateQuotes([]);
             if (
               answer &&
               !answer.endsWith(".") &&
@@ -168,9 +167,15 @@ export const searchRequestStreamed = async ({
             return;
           }

-          // if it doesn't match any of the above, assume it is a quote
-          quotes = chunk as Record<string, Quote>;
-          updateQuotes(quotes);
+          // Check for quote section
+          if (chunk.quotes) {
+            quotes = chunk.quotes as Quote[];
+            updateQuotes(quotes);
+            return;
+          }
+
+          // should never reach this
+          console.log("Unknown chunk:", chunk);
         });
       }
   } catch (err) {