mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-04-12 05:49:36 +02:00
Stop using untyped dicts to represent quotes (#310)
This commit is contained in:
parent
81a4934bb8
commit
f37ac76d3c
@ -15,6 +15,7 @@ METADATA = "metadata"
|
||||
GEN_AI_API_KEY_STORAGE_KEY = "genai_api_key"
|
||||
HTML_SEPARATOR = "\n"
|
||||
PUBLIC_DOC_PAT = "PUBLIC"
|
||||
QUOTE = "quote"
|
||||
|
||||
|
||||
class DocumentSource(str, Enum):
|
||||
|
@ -1,111 +0,0 @@
|
||||
from typing import Any
|
||||
|
||||
import pkg_resources
|
||||
from openai.error import AuthenticationError
|
||||
|
||||
from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import DanswerGenAIModel
|
||||
from danswer.configs.constants import ModelHostType
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_ENDPOINT
|
||||
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
|
||||
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
|
||||
from danswer.direct_qa.exceptions import UnknownModelError
|
||||
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
|
||||
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
|
||||
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
|
||||
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.local_transformers import TransformerQA
|
||||
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
|
||||
from danswer.direct_qa.open_ai import OpenAICompletionQA
|
||||
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
|
||||
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
|
||||
from danswer.direct_qa.request_model import RequestCompletionQA
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_model_api_key_is_valid(model_api_key: str) -> bool:
|
||||
if not model_api_key:
|
||||
return False
|
||||
|
||||
qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)
|
||||
|
||||
# try for up to 2 timeouts (e.g. 10 seconds in total)
|
||||
for _ in range(2):
|
||||
try:
|
||||
qa_model.answer_question("Do not respond", [])
|
||||
return True
|
||||
except AuthenticationError:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"GenAI API key failed for the following reason: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_default_backend_qa_model(
|
||||
internal_model: str = INTERNAL_MODEL_VERSION,
|
||||
endpoint: str | None = GEN_AI_ENDPOINT,
|
||||
model_host_type: str | None = GEN_AI_HOST_TYPE,
|
||||
api_key: str | None = GEN_AI_API_KEY,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
**kwargs: Any,
|
||||
) -> QAModel:
|
||||
if not api_key:
|
||||
try:
|
||||
api_key = get_gen_ai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
pass
|
||||
|
||||
if internal_model in [
|
||||
DanswerGenAIModel.GPT4ALL.value,
|
||||
DanswerGenAIModel.GPT4ALL_CHAT.value,
|
||||
]:
|
||||
# gpt4all is not compatible M1 Mac hardware as of Aug 2023
|
||||
pkg_resources.get_distribution("gpt4all")
|
||||
|
||||
if internal_model == DanswerGenAIModel.OPENAI.value:
|
||||
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
|
||||
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
|
||||
return GPT4AllCompletionQA(**kwargs)
|
||||
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
|
||||
return GPT4AllChatCompletionQA(**kwargs)
|
||||
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
|
||||
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
|
||||
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.TRANSFORMERS:
|
||||
return TransformerQA()
|
||||
elif internal_model == DanswerGenAIModel.REQUEST.value:
|
||||
if endpoint is None or model_host_type is None:
|
||||
raise ValueError(
|
||||
"Request based GenAI model requires an endpoint and host type"
|
||||
)
|
||||
if (
|
||||
model_host_type == ModelHostType.HUGGINGFACE.value
|
||||
or model_host_type == ModelHostType.COLAB_DEMO.value
|
||||
):
|
||||
# Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
|
||||
# With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
|
||||
# This is the sum of input and output tokens, so cannot take in full Danswer context
|
||||
return RequestCompletionQA(
|
||||
endpoint=endpoint,
|
||||
model_host_type=model_host_type,
|
||||
api_key=api_key,
|
||||
prompt_processor=WeakModelFreeformProcessor(),
|
||||
timeout=timeout,
|
||||
)
|
||||
return RequestCompletionQA(
|
||||
endpoint=endpoint,
|
||||
model_host_type=model_host_type,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
raise UnknownModelError(internal_model)
|
@ -5,10 +5,9 @@ from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.datastores.qdrant.store import QdrantIndex
|
||||
from danswer.datastores.typesense.store import TypesenseIndex
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.exceptions import OpenAIKeyMissing
|
||||
from danswer.direct_qa.exceptions import UnknownModelError
|
||||
from danswer.direct_qa.qa_utils import structure_quotes_for_response
|
||||
from danswer.direct_qa.llm_utils import get_default_llm
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.keyword_search import retrieve_keyword_documents
|
||||
from danswer.search.models import QueryFlow
|
||||
@ -75,7 +74,7 @@ def answer_question(
|
||||
)
|
||||
|
||||
try:
|
||||
qa_model = get_default_backend_qa_model(timeout=answer_generation_timeout)
|
||||
qa_model = get_default_llm(timeout=answer_generation_timeout)
|
||||
except (UnknownModelError, OpenAIKeyMissing) as e:
|
||||
return QAResponse(
|
||||
answer=None,
|
||||
@ -104,7 +103,7 @@ def answer_question(
|
||||
|
||||
return QAResponse(
|
||||
answer=answer.answer if answer else None,
|
||||
quotes=structure_quotes_for_response(quotes),
|
||||
quotes=quotes.quotes if quotes else None,
|
||||
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
|
||||
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
|
||||
predicted_flow=predicted_flow,
|
||||
|
@ -4,8 +4,12 @@ from typing import Any
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
|
||||
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
|
||||
@ -85,7 +89,7 @@ class GPT4AllCompletionQA(QAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
@ -101,12 +105,12 @@ class GPT4AllCompletionQA(QAModel):
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
answer, quotes = process_answer(model_output, context_docs)
|
||||
return answer, quotes
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
@ -150,7 +154,7 @@ class GPT4AllChatCompletionQA(QAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
@ -177,7 +181,7 @@ class GPT4AllChatCompletionQA(QAModel):
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
|
@ -7,8 +7,12 @@ from huggingface_hub.utils import HfHubHTTPError # type:ignore
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
|
||||
from danswer.direct_qa.qa_prompts import FreeformProcessor
|
||||
@ -51,7 +55,7 @@ class HuggingFaceCompletionQA(QAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
@ -68,7 +72,7 @@ class HuggingFaceCompletionQA(QAModel):
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
@ -165,7 +169,7 @@ class HuggingFaceChatCompletionQA(QAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
model_output = self._get_hf_model_output(query, context_docs)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
@ -174,7 +178,7 @@ class HuggingFaceChatCompletionQA(QAModel):
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
"""As of Aug 2023, HF conversational (chat) endpoints do not support streaming
|
||||
So here it is faked by streaming characters within Danswer from the model output
|
||||
"""
|
||||
|
@ -1,7 +1,6 @@
|
||||
import abc
|
||||
from collections.abc import Generator
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
|
||||
@ -11,6 +10,13 @@ class DanswerAnswer:
|
||||
answer: str | None
|
||||
|
||||
|
||||
@dataclass
|
||||
class DanswerAnswerPiece:
|
||||
"""A small piece of a complete answer. Used for streaming back answers."""
|
||||
|
||||
answer_piece: str | None # if None, specifies the end of an Answer
|
||||
|
||||
|
||||
@dataclass
|
||||
class DanswerQuote:
|
||||
# This is during inference so everything is a string by this point
|
||||
@ -22,6 +28,21 @@ class DanswerQuote:
|
||||
blurb: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DanswerQuotes:
|
||||
"""A little clunky, but making this into a separate class so that the result from
|
||||
`answer_question_stream` is always a subclass of `dataclass` and can thus use `asdict()`
|
||||
"""
|
||||
|
||||
quotes: list[DanswerQuote]
|
||||
|
||||
|
||||
AnswerQuestionReturn = tuple[DanswerAnswer, DanswerQuotes]
|
||||
AnswerQuestionStreamReturn = Generator[
|
||||
DanswerAnswerPiece | DanswerQuotes | None, None, None
|
||||
]
|
||||
|
||||
|
||||
class QAModel:
|
||||
@property
|
||||
def requires_api_key(self) -> bool:
|
||||
@ -39,7 +60,7 @@ class QAModel:
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
raise NotImplementedError
|
||||
|
||||
@abc.abstractmethod
|
||||
@ -47,5 +68,5 @@ class QAModel:
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
raise NotImplementedError
|
||||
|
111
backend/danswer/direct_qa/llm_utils.py
Normal file
111
backend/danswer/direct_qa/llm_utils.py
Normal file
@ -0,0 +1,111 @@
|
||||
from typing import Any
|
||||
|
||||
import pkg_resources
|
||||
from openai.error import AuthenticationError
|
||||
|
||||
from danswer.configs.app_configs import QA_TIMEOUT
|
||||
from danswer.configs.constants import DanswerGenAIModel
|
||||
from danswer.configs.constants import ModelHostType
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_ENDPOINT
|
||||
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
|
||||
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
|
||||
from danswer.direct_qa.exceptions import UnknownModelError
|
||||
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
|
||||
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
|
||||
from danswer.direct_qa.huggingface import HuggingFaceChatCompletionQA
|
||||
from danswer.direct_qa.huggingface import HuggingFaceCompletionQA
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.local_transformers import TransformerQA
|
||||
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
|
||||
from danswer.direct_qa.open_ai import OpenAICompletionQA
|
||||
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
|
||||
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
|
||||
from danswer.direct_qa.request_model import RequestCompletionQA
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.utils.logger import setup_logger
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def check_model_api_key_is_valid(model_api_key: str) -> bool:
|
||||
if not model_api_key:
|
||||
return False
|
||||
|
||||
qa_model = get_default_llm(api_key=model_api_key, timeout=5)
|
||||
|
||||
# try for up to 2 timeouts (e.g. 10 seconds in total)
|
||||
for _ in range(2):
|
||||
try:
|
||||
qa_model.answer_question("Do not respond", [])
|
||||
return True
|
||||
except AuthenticationError:
|
||||
return False
|
||||
except Exception as e:
|
||||
logger.warning(f"GenAI API key failed for the following reason: {e}")
|
||||
|
||||
return False
|
||||
|
||||
|
||||
def get_default_llm(
|
||||
internal_model: str = INTERNAL_MODEL_VERSION,
|
||||
endpoint: str | None = GEN_AI_ENDPOINT,
|
||||
model_host_type: str | None = GEN_AI_HOST_TYPE,
|
||||
api_key: str | None = GEN_AI_API_KEY,
|
||||
timeout: int = QA_TIMEOUT,
|
||||
**kwargs: Any,
|
||||
) -> QAModel:
|
||||
if not api_key:
|
||||
try:
|
||||
api_key = get_gen_ai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
pass
|
||||
|
||||
if internal_model in [
|
||||
DanswerGenAIModel.GPT4ALL.value,
|
||||
DanswerGenAIModel.GPT4ALL_CHAT.value,
|
||||
]:
|
||||
# gpt4all is not compatible M1 Mac hardware as of Aug 2023
|
||||
pkg_resources.get_distribution("gpt4all")
|
||||
|
||||
if internal_model == DanswerGenAIModel.OPENAI.value:
|
||||
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.OPENAI_CHAT.value:
|
||||
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.GPT4ALL.value:
|
||||
return GPT4AllCompletionQA(**kwargs)
|
||||
elif internal_model == DanswerGenAIModel.GPT4ALL_CHAT.value:
|
||||
return GPT4AllChatCompletionQA(**kwargs)
|
||||
elif internal_model == DanswerGenAIModel.HUGGINGFACE.value:
|
||||
return HuggingFaceCompletionQA(api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.HUGGINGFACE_CHAT.value:
|
||||
return HuggingFaceChatCompletionQA(api_key=api_key, **kwargs)
|
||||
elif internal_model == DanswerGenAIModel.TRANSFORMERS:
|
||||
return TransformerQA()
|
||||
elif internal_model == DanswerGenAIModel.REQUEST.value:
|
||||
if endpoint is None or model_host_type is None:
|
||||
raise ValueError(
|
||||
"Request based GenAI model requires an endpoint and host type"
|
||||
)
|
||||
if (
|
||||
model_host_type == ModelHostType.HUGGINGFACE.value
|
||||
or model_host_type == ModelHostType.COLAB_DEMO.value
|
||||
):
|
||||
# Assuming user is hosting the smallest size LLMs with weaker capabilities and token limits
|
||||
# With the 7B Llama2 Chat model, there is a max limit of 1512 tokens
|
||||
# This is the sum of input and output tokens, so cannot take in full Danswer context
|
||||
return RequestCompletionQA(
|
||||
endpoint=endpoint,
|
||||
model_host_type=model_host_type,
|
||||
api_key=api_key,
|
||||
prompt_processor=WeakModelFreeformProcessor(),
|
||||
timeout=timeout,
|
||||
)
|
||||
return RequestCompletionQA(
|
||||
endpoint=endpoint,
|
||||
model_host_type=model_host_type,
|
||||
api_key=api_key,
|
||||
timeout=timeout,
|
||||
)
|
||||
else:
|
||||
raise UnknownModelError(internal_model)
|
@ -7,10 +7,13 @@ from transformers import QuestionAnsweringPipeline # type:ignore
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_utils import structure_quotes_for_response
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
|
||||
@ -104,7 +107,7 @@ class TransformerQA(QAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
danswer_quotes: list[DanswerQuote] = []
|
||||
d_answers: list[str] = []
|
||||
for chunk in context_docs:
|
||||
@ -118,11 +121,13 @@ class TransformerQA(QAModel):
|
||||
for ind, answer in enumerate(d_answers, start=1)
|
||||
]
|
||||
combined_answer = "\n".join(answers_list)
|
||||
return DanswerAnswer(answer=combined_answer), danswer_quotes
|
||||
return DanswerAnswer(answer=combined_answer), DanswerQuotes(
|
||||
quotes=danswer_quotes
|
||||
)
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
quotes: list[DanswerQuote] = []
|
||||
answers: list[str] = []
|
||||
for chunk in context_docs:
|
||||
@ -135,13 +140,14 @@ class TransformerQA(QAModel):
|
||||
answer_count = 1
|
||||
for answer in answers:
|
||||
if answer_count == 1:
|
||||
yield {"answer_data": "Source 1: "}
|
||||
yield DanswerAnswerPiece(answer_piece="Source 1: ")
|
||||
else:
|
||||
yield {"answer_data": f"\nSource {answer_count}: "}
|
||||
yield DanswerAnswerPiece(answer_piece=f"\nSource {answer_count}: ")
|
||||
answer_count += 1
|
||||
for char in answer.strip():
|
||||
yield {"answer_data": char}
|
||||
yield DanswerAnswerPiece(answer_piece=char)
|
||||
|
||||
yield {"answer_finished": True}
|
||||
# signal end of answer
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
|
||||
yield structure_quotes_for_response(quotes)
|
||||
yield DanswerQuotes(quotes=quotes)
|
||||
|
@ -22,8 +22,12 @@ from danswer.configs.model_configs import AZURE_DEPLOYMENT_ID
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.direct_qa.exceptions import OpenAIKeyMissing
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
|
||||
from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
|
||||
@ -147,7 +151,7 @@ class OpenAICompletionQA(OpenAIQAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
|
||||
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
@ -177,7 +181,7 @@ class OpenAICompletionQA(OpenAIQAModel):
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
|
||||
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
@ -243,7 +247,7 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
|
||||
|
||||
messages = self.prompt_processor.fill_prompt(
|
||||
@ -276,12 +280,12 @@ class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
answer, quotes = process_answer(model_output, context_docs)
|
||||
return answer, quotes
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
context_docs = _tiktoken_trim_chunks(context_docs, self.model_version)
|
||||
|
||||
messages = self.prompt_processor.fill_prompt(
|
||||
|
@ -2,7 +2,6 @@ import json
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Generator
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
@ -11,15 +10,12 @@ import regex
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import DOCUMENT_ID
|
||||
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
from danswer.configs.constants import SOURCE_LINK
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerAnswerPiece
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import DanswerQuotes
|
||||
from danswer.direct_qa.qa_prompts import ANSWER_PAT
|
||||
from danswer.direct_qa.qa_prompts import QUOTE_PAT
|
||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
||||
@ -37,24 +33,6 @@ def get_gen_ai_api_key() -> str:
|
||||
)
|
||||
|
||||
|
||||
def structure_quotes_for_response(
|
||||
quotes: list[DanswerQuote] | None,
|
||||
) -> dict[str, dict[str, str | None]]:
|
||||
if quotes is None:
|
||||
return {}
|
||||
|
||||
response_quotes = {}
|
||||
for quote in quotes:
|
||||
response_quotes[quote.quote] = {
|
||||
DOCUMENT_ID: quote.document_id,
|
||||
SOURCE_LINK: quote.link,
|
||||
SOURCE_TYPE: quote.source_type,
|
||||
SEMANTIC_IDENTIFIER: quote.semantic_identifier,
|
||||
BLURB: quote.blurb,
|
||||
}
|
||||
return response_quotes
|
||||
|
||||
|
||||
def extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
||||
@ -114,8 +92,8 @@ def match_quotes_to_docs(
|
||||
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
|
||||
fuzzy_search: bool = False,
|
||||
prefix_only_length: int = 100,
|
||||
) -> list[DanswerQuote]:
|
||||
danswer_quotes = []
|
||||
) -> DanswerQuotes:
|
||||
danswer_quotes: list[DanswerQuote] = []
|
||||
for quote in quotes:
|
||||
max_edits = math.ceil(float(len(quote)) * max_error_percent)
|
||||
|
||||
@ -145,23 +123,13 @@ def match_quotes_to_docs(
|
||||
# Extracting the link from the offset
|
||||
curr_link = None
|
||||
for link_offset, link in chunk.source_links.items():
|
||||
# Should always find one because offset is at least 0 and there must be a 0 link_offset
|
||||
# Should always find one because offset is at least 0 and there
|
||||
# must be a 0 link_offset
|
||||
if int(link_offset) <= offset:
|
||||
curr_link = link
|
||||
else:
|
||||
danswer_quotes.append(
|
||||
DanswerQuote(
|
||||
quote=quote,
|
||||
document_id=chunk.document_id,
|
||||
link=curr_link,
|
||||
source_type=chunk.source_type,
|
||||
semantic_identifier=chunk.semantic_identifier,
|
||||
blurb=chunk.blurb,
|
||||
)
|
||||
)
|
||||
break
|
||||
|
||||
# If the offset is larger than the start of the last quote, it must be the last one
|
||||
danswer_quotes.append(
|
||||
DanswerQuote(
|
||||
quote=quote,
|
||||
@ -174,24 +142,24 @@ def match_quotes_to_docs(
|
||||
)
|
||||
break
|
||||
|
||||
return danswer_quotes
|
||||
return DanswerQuotes(quotes=danswer_quotes)
|
||||
|
||||
|
||||
def process_answer(
|
||||
answer_raw: str, chunks: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> tuple[DanswerAnswer, DanswerQuotes]:
|
||||
answer, quote_strings = separate_answer_quotes(answer_raw)
|
||||
if answer == UNCERTAINTY_PAT or not answer:
|
||||
if answer == UNCERTAINTY_PAT:
|
||||
logger.debug("Answer matched UNCERTAINTY_PAT")
|
||||
else:
|
||||
logger.debug("No answer extracted from raw output")
|
||||
return DanswerAnswer(answer=None), []
|
||||
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
|
||||
|
||||
logger.info(f"Answer: {answer}")
|
||||
if not quote_strings:
|
||||
logger.debug("No quotes extracted from raw output")
|
||||
return DanswerAnswer(answer=answer), []
|
||||
return DanswerAnswer(answer=answer), DanswerQuotes(quotes=[])
|
||||
logger.info(f"All quotes (including unmatched): {quote_strings}")
|
||||
quotes = match_quotes_to_docs(quote_strings, chunks)
|
||||
logger.info(f"Final quotes: {quotes}")
|
||||
@ -212,7 +180,7 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
|
||||
def extract_quotes_from_completed_token_stream(
|
||||
model_output: str, context_chunks: list[InferenceChunk]
|
||||
) -> list[DanswerQuote]:
|
||||
) -> DanswerQuotes:
|
||||
logger.debug(model_output)
|
||||
answer, quotes = process_answer(model_output, context_chunks)
|
||||
if answer:
|
||||
@ -227,7 +195,7 @@ def process_model_tokens(
|
||||
tokens: Generator[str, None, None],
|
||||
context_docs: list[InferenceChunk],
|
||||
is_json_prompt: bool = True,
|
||||
) -> Generator[dict[str, Any], None, None]:
|
||||
) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
|
||||
"""Yields Answer tokens back out in a dict for streaming to frontend
|
||||
When Answer section ends, yields dict with answer_finished key
|
||||
Collects all the tokens at the end to form the complete model output"""
|
||||
@ -255,21 +223,20 @@ def process_model_tokens(
|
||||
if found_answer_start and not found_answer_end:
|
||||
if is_json_prompt and stream_json_answer_end(model_previous, token):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
continue
|
||||
elif not is_json_prompt:
|
||||
if quote_pat in hold_quote + token or quote_loose in hold_quote + token:
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
yield DanswerAnswerPiece(answer_piece=None)
|
||||
continue
|
||||
if hold_quote + token in quote_pat_full:
|
||||
hold_quote += token
|
||||
continue
|
||||
yield {"answer_data": hold_quote + token}
|
||||
yield DanswerAnswerPiece(answer_piece=token)
|
||||
hold_quote = ""
|
||||
|
||||
quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
|
||||
yield structure_quotes_for_response(quotes)
|
||||
yield extract_quotes_from_completed_token_stream(model_output, context_docs)
|
||||
|
||||
|
||||
def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
|
||||
|
@ -13,8 +13,8 @@ from danswer.configs.model_configs import GEN_AI_API_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_ENDPOINT
|
||||
from danswer.configs.model_configs import GEN_AI_HOST_TYPE
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionReturn
|
||||
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import JsonProcessor
|
||||
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
|
||||
@ -236,7 +236,7 @@ class RequestCompletionQA(QAModel):
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
) -> AnswerQuestionReturn:
|
||||
model_api_response = self._get_request_response(
|
||||
query, context_docs, stream=False
|
||||
)
|
||||
@ -253,7 +253,7 @@ class RequestCompletionQA(QAModel):
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
) -> AnswerQuestionStreamReturn:
|
||||
model_api_response = self._get_request_response(
|
||||
query, context_docs, stream=False
|
||||
)
|
||||
|
@ -13,6 +13,7 @@ from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
|
||||
from danswer.configs.constants import DocumentSource
|
||||
from danswer.connectors.slack.utils import make_slack_api_rate_limited
|
||||
from danswer.direct_qa.answer_question import answer_question
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.server.models import QAResponse
|
||||
from danswer.server.models import QuestionRequest
|
||||
from danswer.server.models import SearchDoc
|
||||
@ -59,24 +60,22 @@ def _build_custom_semantic_identifier(
|
||||
return semantic_identifier
|
||||
|
||||
|
||||
def _process_quotes(
|
||||
quotes: dict[str, dict[str, str | None]] | None
|
||||
) -> tuple[str | None, list[str]]:
|
||||
def _process_quotes(quotes: list[DanswerQuote] | None) -> tuple[str | None, list[str]]:
|
||||
if not quotes:
|
||||
return None, []
|
||||
|
||||
quote_lines: list[str] = []
|
||||
doc_identifiers: list[str] = []
|
||||
for quote_dict in quotes.values():
|
||||
doc_id = str(quote_dict.get("document_id", ""))
|
||||
doc_link = quote_dict.get("link")
|
||||
doc_name = str(quote_dict.get("semantic_identifier", ""))
|
||||
for quote in quotes:
|
||||
doc_id = quote.document_id
|
||||
doc_link = quote.link
|
||||
doc_name = quote.semantic_identifier
|
||||
if doc_link and doc_name and doc_id and doc_id not in doc_identifiers:
|
||||
doc_identifiers.append(doc_id)
|
||||
custom_semantic_identifier = _build_custom_semantic_identifier(
|
||||
semantic_identifier=doc_name,
|
||||
blurb=str(quote_dict.get("blurb", "")),
|
||||
source=str(quote_dict.get("source_type", "")),
|
||||
blurb=quote.blurb,
|
||||
source=quote.source_type,
|
||||
)
|
||||
quote_lines.append(f"- <{doc_link}|{custom_semantic_identifier}>")
|
||||
|
||||
|
@ -33,7 +33,7 @@ from danswer.datastores.qdrant.indexing import list_qdrant_collections
|
||||
from danswer.datastores.typesense.store import check_typesense_collection_exist
|
||||
from danswer.datastores.typesense.store import create_typesense_collection
|
||||
from danswer.db.credentials import create_initial_public_credential
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.llm_utils import get_default_llm
|
||||
from danswer.server.event_loading import router as event_processing_router
|
||||
from danswer.server.health import router as health_router
|
||||
from danswer.server.manage import router as admin_router
|
||||
@ -179,7 +179,7 @@ def get_application() -> FastAPI:
|
||||
|
||||
logger.info("Warming up local NLP models.")
|
||||
warm_up_models()
|
||||
qa_model = get_default_backend_qa_model()
|
||||
qa_model = get_default_llm()
|
||||
qa_model.warm_up_model()
|
||||
|
||||
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
|
||||
|
@ -9,7 +9,6 @@ from fastapi import Request
|
||||
from fastapi import Response
|
||||
from fastapi import UploadFile
|
||||
from fastapi_users.db import SQLAlchemyUserDatabase
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
from sqlalchemy.ext.asyncio import AsyncSession
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
@ -53,14 +52,10 @@ from danswer.db.engine import get_sqlalchemy_async_engine
|
||||
from danswer.db.index_attempt import create_index_attempt
|
||||
from danswer.db.index_attempt import get_latest_index_attempts
|
||||
from danswer.db.models import DeletionAttempt
|
||||
from danswer.db.models import DeletionStatus
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa import check_model_api_key_is_valid
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
|
||||
from danswer.direct_qa.llm_utils import get_default_llm
|
||||
from danswer.direct_qa.open_ai import get_gen_ai_api_key
|
||||
from danswer.direct_qa.open_ai import OpenAIQAModel
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.server.models import ApiKey
|
||||
@ -361,7 +356,7 @@ def validate_existing_genai_api_key(
|
||||
) -> None:
|
||||
# OpenAI key is only used for generative QA, so no need to validate this
|
||||
# if it's turned off or if a non-OpenAI model is being used
|
||||
if DISABLE_GENERATIVE_AI or not get_default_backend_qa_model().requires_api_key:
|
||||
if DISABLE_GENERATIVE_AI or not get_default_llm().requires_api_key:
|
||||
return
|
||||
|
||||
# Only validate every so often
|
||||
|
@ -19,6 +19,7 @@ from danswer.db.models import DeletionAttempt
|
||||
from danswer.db.models import DeletionStatus
|
||||
from danswer.db.models import IndexAttempt
|
||||
from danswer.db.models import IndexingStatus
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.search.models import QueryFlow
|
||||
from danswer.search.models import SearchType
|
||||
from danswer.server.utils import mask_credential_dict
|
||||
@ -110,7 +111,7 @@ class SearchResponse(BaseModel):
|
||||
|
||||
class QAResponse(SearchResponse):
|
||||
answer: str | None # DanswerAnswer
|
||||
quotes: dict[str, dict[str, str | None]] | None # restructured DanswerQuote
|
||||
quotes: list[DanswerQuote] | None
|
||||
predicted_flow: QueryFlow
|
||||
predicted_search: SearchType
|
||||
error_msg: str | None = None
|
||||
|
@ -1,5 +1,6 @@
|
||||
import json
|
||||
from collections.abc import Generator
|
||||
from dataclasses import asdict
|
||||
|
||||
from fastapi import APIRouter
|
||||
from fastapi import Depends
|
||||
@ -12,10 +13,10 @@ from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
|
||||
from danswer.datastores.qdrant.store import QdrantIndex
|
||||
from danswer.datastores.typesense.store import TypesenseIndex
|
||||
from danswer.db.models import User
|
||||
from danswer.direct_qa import get_default_backend_qa_model
|
||||
from danswer.direct_qa.answer_question import answer_question
|
||||
from danswer.direct_qa.exceptions import OpenAIKeyMissing
|
||||
from danswer.direct_qa.exceptions import UnknownModelError
|
||||
from danswer.direct_qa.llm_utils import get_default_llm
|
||||
from danswer.search.danswer_helper import query_intent
|
||||
from danswer.search.danswer_helper import recommend_search_flow
|
||||
from danswer.search.keyword_search import retrieve_keyword_documents
|
||||
@ -166,7 +167,7 @@ def stream_direct_qa(
|
||||
return
|
||||
|
||||
try:
|
||||
qa_model = get_default_backend_qa_model()
|
||||
qa_model = get_default_llm()
|
||||
except (UnknownModelError, OpenAIKeyMissing) as e:
|
||||
logger.exception("Unable to get QA model")
|
||||
yield get_json_line({"error": str(e)})
|
||||
@ -178,16 +179,16 @@ def stream_direct_qa(
|
||||
"Chunks offset too large, should not retry this many times"
|
||||
)
|
||||
try:
|
||||
for response_dict in qa_model.answer_question_stream(
|
||||
for response_packet in qa_model.answer_question_stream(
|
||||
query,
|
||||
ranked_chunks[
|
||||
chunk_offset : chunk_offset + NUM_GENERATIVE_AI_INPUT_DOCS
|
||||
],
|
||||
):
|
||||
if response_dict is None:
|
||||
if response_packet is None:
|
||||
continue
|
||||
logger.debug(f"Sending packet: {response_dict}")
|
||||
yield get_json_line(response_dict)
|
||||
logger.debug(f"Sending packet: {response_packet}")
|
||||
yield get_json_line(asdict(response_packet))
|
||||
except Exception as e:
|
||||
# exception is logged in the answer_question method, no need to re-log
|
||||
yield get_json_line({"error": str(e)})
|
||||
|
@ -56,7 +56,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
|
||||
const dedupedQuotes: Quote[] = [];
|
||||
const seen = new Set<string>();
|
||||
if (quotes) {
|
||||
Object.values(quotes).forEach((quote) => {
|
||||
quotes.forEach((quote) => {
|
||||
if (!seen.has(quote.document_id)) {
|
||||
dedupedQuotes.push(quote);
|
||||
seen.add(quote.document_id);
|
||||
@ -109,7 +109,7 @@ export const SearchResultsDisplay: React.FC<SearchResultsDisplayProps> = ({
|
||||
<a
|
||||
key={quoteInfo.document_id}
|
||||
className="p-2 ml-1 border border-gray-800 rounded-lg text-sm flex max-w-[280px] hover:bg-gray-800"
|
||||
href={quoteInfo.link}
|
||||
href={quoteInfo.link || undefined}
|
||||
target="_blank"
|
||||
rel="noopener noreferrer"
|
||||
>
|
||||
|
@ -71,7 +71,7 @@ export const SearchSection: React.FC<SearchSectionProps> = ({
|
||||
...(prevState || initialSearchResponse),
|
||||
answer,
|
||||
}));
|
||||
const updateQuotes = (quotes: Record<string, Quote>) =>
|
||||
const updateQuotes = (quotes: Quote[]) =>
|
||||
setSearchResponse((prevState) => ({
|
||||
...(prevState || initialSearchResponse),
|
||||
quotes,
|
||||
|
@ -13,11 +13,12 @@ export const SearchType = {
|
||||
export type SearchType = (typeof SearchType)[keyof typeof SearchType];
|
||||
|
||||
export interface Quote {
|
||||
quote: string;
|
||||
document_id: string;
|
||||
link: string;
|
||||
link: string | null;
|
||||
source_type: ValidSources;
|
||||
blurb: string;
|
||||
semantic_identifier: string | null;
|
||||
semantic_identifier: string;
|
||||
}
|
||||
|
||||
export interface DanswerDocument {
|
||||
@ -32,7 +33,7 @@ export interface SearchResponse {
|
||||
suggestedSearchType: SearchType | null;
|
||||
suggestedFlowType: FlowType | null;
|
||||
answer: string | null;
|
||||
quotes: Record<string, Quote> | null;
|
||||
quotes: Quote[] | null;
|
||||
documents: DanswerDocument[] | null;
|
||||
error: string | null;
|
||||
}
|
||||
@ -51,7 +52,7 @@ export interface SearchRequestArgs {
|
||||
query: string;
|
||||
sources: Source[];
|
||||
updateCurrentAnswer: (val: string) => void;
|
||||
updateQuotes: (quotes: Record<string, Quote>) => void;
|
||||
updateQuotes: (quotes: Quote[]) => void;
|
||||
updateDocs: (documents: DanswerDocument[]) => void;
|
||||
updateSuggestedSearchType: (searchType: SearchType) => void;
|
||||
updateSuggestedFlowType: (flowType: FlowType) => void;
|
||||
|
@ -24,7 +24,7 @@ export const searchRequest = async ({
|
||||
}
|
||||
|
||||
let answer = "";
|
||||
let quotes: Record<string, Quote> | null = null;
|
||||
let quotes: Quote[] | null = null;
|
||||
let relevantDocuments: DanswerDocument[] | null = null;
|
||||
try {
|
||||
const response = await fetch("/api/direct-qa", {
|
||||
@ -54,7 +54,7 @@ export const searchRequest = async ({
|
||||
|
||||
const data = (await response.json()) as {
|
||||
answer: string;
|
||||
quotes: Record<string, Quote>;
|
||||
quotes: Quote[];
|
||||
top_ranked_docs: DanswerDocument[];
|
||||
lower_ranked_docs: DanswerDocument[];
|
||||
predicted_flow: FlowType;
|
||||
|
@ -69,7 +69,7 @@ export const searchRequestStreamed = async ({
|
||||
}
|
||||
|
||||
let answer = "";
|
||||
let quotes: Record<string, Quote> | null = null;
|
||||
let quotes: Quote[] | null = null;
|
||||
let relevantDocuments: DanswerDocument[] | null = null;
|
||||
try {
|
||||
const response = await fetch("/api/stream-direct-qa", {
|
||||
@ -118,18 +118,17 @@ export const searchRequestStreamed = async ({
|
||||
previousPartialChunk = partialChunk;
|
||||
completedChunks.forEach((chunk) => {
|
||||
// TODO: clean up response / this logic
|
||||
const answerChunk = chunk.answer_data;
|
||||
const answerChunk = chunk.answer_piece;
|
||||
if (answerChunk) {
|
||||
answer += answerChunk;
|
||||
updateCurrentAnswer(answer);
|
||||
return;
|
||||
}
|
||||
|
||||
const answerFinished = chunk.answer_finished;
|
||||
if (answerFinished) {
|
||||
if (answerChunk === null) {
|
||||
// set quotes as non-null to signify that the answer is finished and
|
||||
// we're now looking for quotes
|
||||
updateQuotes({});
|
||||
updateQuotes([]);
|
||||
if (
|
||||
answer &&
|
||||
!answer.endsWith(".") &&
|
||||
@ -168,9 +167,15 @@ export const searchRequestStreamed = async ({
|
||||
return;
|
||||
}
|
||||
|
||||
// if it doesn't match any of the above, assume it is a quote
|
||||
quotes = chunk as Record<string, Quote>;
|
||||
updateQuotes(quotes);
|
||||
// Check for quote section
|
||||
if (chunk.quotes) {
|
||||
quotes = chunk.quotes as Quote[];
|
||||
updateQuotes(quotes);
|
||||
return;
|
||||
}
|
||||
|
||||
// should never reach this
|
||||
console.log("Unknown chunk:", chunk);
|
||||
});
|
||||
}
|
||||
} catch (err) {
|
||||
|
Loading…
x
Reference in New Issue
Block a user