Slack CoT Scratchpad (#421)
Commit to https://github.com/danswer-ai/danswer.git
@@ -42,6 +42,7 @@ def handle_message(
         user=None,
         db_session=db_session,
         answer_generation_timeout=answer_generation_timeout,
+        real_time_flow=False,
     )
     if not answer.error_msg:
         return answer
@@ -209,7 +209,7 @@ DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
 )
 DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
 DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
-    os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "60")
+    os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
 )
 DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
     "DANSWER_BOT_DISPLAY_ERROR_MSGS", ""
@@ -33,6 +33,7 @@ def answer_qa_query(
     db_session: Session,
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     answer_generation_timeout: int = QA_TIMEOUT,
+    real_time_flow: bool = True,
 ) -> QAResponse:
     query = question.query
     filters = question.filters
@@ -88,7 +89,9 @@ def answer_qa_query(
     )

     try:
-        qa_model = get_default_qa_model(timeout=answer_generation_timeout)
+        qa_model = get_default_qa_model(
+            timeout=answer_generation_timeout, real_time_flow=real_time_flow
+        )
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_block import QABlock
 from danswer.direct_qa.qa_block import QAHandler
 from danswer.direct_qa.qa_block import SimpleChatQAHandler
 from danswer.direct_qa.qa_block import SingleMessageQAHandler
+from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
 from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.direct_qa.request_model import RequestCompletionQA
@@ -51,9 +52,13 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
     return False


-def get_default_qa_handler(model: str) -> QAHandler:
+def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
     if model == DanswerGenAIModel.OPENAI_CHAT.value:
-        return SingleMessageQAHandler()
+        return (
+            SingleMessageQAHandler()
+            if real_time_flow
+            else SingleMessageScratchpadHandler()
+        )

     return SimpleChatQAHandler()
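The hunk above turns the OpenAI chat case into a two-way choice on `real_time_flow`. A minimal standalone sketch of that selection, with the two handler classes stubbed out purely for illustration (the real ones live in `danswer.direct_qa.qa_block`, per the imports earlier in this diff):

# Standalone sketch of the selection logic above; handler classes are stubs.
class SingleMessageQAHandler:
    """Streams a JSON answer directly (used by real-time flows)."""

class SingleMessageScratchpadHandler:
    """Writes chain-of-thought first, then the JSON answer (batch flows)."""

def pick_handler(real_time_flow: bool = True):
    # Streaming UIs keep the plain JSON handler; the Slack bot passes
    # real_time_flow=False and gets the scratchpad variant instead.
    return (
        SingleMessageQAHandler()
        if real_time_flow
        else SingleMessageScratchpadHandler()
    )

print(type(pick_handler(real_time_flow=False)).__name__)
# SingleMessageScratchpadHandler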
@@ -64,6 +69,7 @@ def get_default_qa_model(
     model_host_type: str | None = GEN_AI_HOST_TYPE,
     api_key: str | None = GEN_AI_API_KEY,
     timeout: int = QA_TIMEOUT,
+    real_time_flow: bool = True,
     **kwargs: Any,
 ) -> QAModel:
     if not api_key:
@@ -76,7 +82,9 @@ def get_default_qa_model(
     # un-used arguments will be ignored by the underlying `LLM` class
     # if any args are missing, a `TypeError` will be thrown
     llm = get_default_llm(timeout=timeout)
-    qa_handler = get_default_qa_handler(model=internal_model)
+    qa_handler = get_default_qa_handler(
+        model=internal_model, real_time_flow=real_time_flow
+    )

     return QABlock(
         llm=llm,
@@ -13,21 +13,24 @@ from danswer.chunking.models import InferenceChunk
 from danswer.direct_qa.interfaces import AnswerQuestionReturn
 from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
 from danswer.direct_qa.interfaces import DanswerAnswer
-from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
 from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
 from danswer.direct_qa.qa_prompts import JsonChatProcessor
 from danswer.direct_qa.qa_prompts import QUESTION_PAT
 from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
+from danswer.direct_qa.qa_prompts import THOUGHT_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
 from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
+from danswer.direct_qa.qa_utils import process_answer
 from danswer.direct_qa.qa_utils import process_model_tokens
 from danswer.llm.llm import LLM
 from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
 from danswer.llm.utils import str_prompt_to_langchain_prompt
 from danswer.utils.logger import setup_logger
+from danswer.utils.text_processing import escape_newlines

 logger = setup_logger()
@@ -43,11 +46,26 @@ class QAHandler(abc.ABC):
     ) -> list[BaseMessage]:
         raise NotImplementedError

-    @abc.abstractmethod
-    def process_response(
+    @property
+    def is_json_output(self) -> bool:
+        """Does the model expected to output a valid json"""
+        return True
+
+    def process_llm_output(
+        self, model_output: str, context_chunks: list[InferenceChunk]
+    ) -> tuple[DanswerAnswer, DanswerQuotes]:
+        return process_answer(
+            model_output, context_chunks, is_json_prompt=self.is_json_output
+        )
+
+    def process_llm_token_stream(
         self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
     ) -> AnswerQuestionStreamReturn:
-        raise NotImplementedError
+        yield from process_model_tokens(
+            tokens=tokens,
+            context_docs=context_chunks,
+            is_json_prompt=self.is_json_output,
+        )


 class JsonChatQAHandler(QAHandler):
@@ -60,19 +78,12 @@ class JsonChatQAHandler(QAHandler):
             )
         )

-    def process_response(
-        self,
-        tokens: Iterator[str],
-        context_chunks: list[InferenceChunk],
-    ) -> AnswerQuestionStreamReturn:
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_chunks,
-            is_json_prompt=True,
-        )
-

 class SimpleChatQAHandler(QAHandler):
+    @property
+    def is_json_output(self) -> bool:
+        return False
+
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
@@ -84,26 +95,11 @@ class SimpleChatQAHandler(QAHandler):
             )
         )

-    def process_response(
-        self,
-        tokens: Iterator[str],
-        context_chunks: list[InferenceChunk],
-    ) -> AnswerQuestionStreamReturn:
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_chunks,
-            is_json_prompt=False,
-        )
-

 class SingleMessageQAHandler(QAHandler):
     def build_prompt(
         self, query: str, context_chunks: list[InferenceChunk]
     ) -> list[BaseMessage]:
-        complete_answer_not_found_response = (
-            '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-        )
-
         context_docs_str = "\n".join(
             f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks
         )
@@ -115,27 +111,64 @@ class SingleMessageQAHandler(QAHandler):
                 "to provide accurate and detailed answers to diverse queries.\n"
                 "You ALWAYS responds in a json containing an answer and quotes that support the answer.\n"
                 "Your responses are as INFORMATIVE and DETAILED as possible.\n"
-                "If you don't know the answer, respond with "
-                f"{CODE_BLOCK_PAT.format(complete_answer_not_found_response)}"
-                "\nSample response:"
-                f"{CODE_BLOCK_PAT.format(json.dumps(SAMPLE_JSON_RESPONSE))}"
                 f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
-                f"{GENERAL_SEP_PAT}{QUESTION_PAT} {query}"
+                f"{GENERAL_SEP_PAT}Sample response:"
+                f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n"
+                f"{QUESTION_PAT} {query}"
                 "\nHint: Make the answer as detailed as possible and use a JSON! "
                 "Quotes MUST be EXACT substrings from provided documents!"
             )
         ]
         return prompt

-    def process_response(
-        self,
-        tokens: Iterator[str],
-        context_chunks: list[InferenceChunk],
+
+class SingleMessageScratchpadHandler(QAHandler):
+    def build_prompt(
+        self, query: str, context_chunks: list[InferenceChunk]
+    ) -> list[BaseMessage]:
+        cot_block = (
+            f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
+            f"{json.dumps(EMPTY_SAMPLE_JSON)}"
+        )
+
+        context_docs_str = "\n".join(
+            f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks
+        )
+
+        prompt: list[BaseMessage] = [
+            HumanMessage(
+                content="You are a question answering system that is constantly learning and improving. "
+                "You can process and comprehend vast amounts of text and utilize this knowledge "
+                "to provide accurate and detailed answers to diverse queries.\n"
+                f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}"
+                f"You MUST use the following format:\n"
+                f"{CODE_BLOCK_PAT.format(cot_block)}\n"
+                f"Begin!\n{QUESTION_PAT} {query}"
+            )
+        ]
+        return prompt
+
+    def process_llm_output(
+        self, model_output: str, context_chunks: list[InferenceChunk]
+    ) -> tuple[DanswerAnswer, DanswerQuotes]:
+        logger.debug(model_output)
+
+        answer_start = model_output.find('{"answer":')
+        # Only found thoughts, no final answer
+        if answer_start == -1:
+            return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
+
+        final_json = escape_newlines(model_output[answer_start:])
+
+        return process_answer(
+            final_json, context_chunks, is_json_prompt=self.is_json_output
+        )
+
+    def process_llm_token_stream(
+        self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
     ) -> AnswerQuestionStreamReturn:
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_chunks,
-            is_json_prompt=True,
+        raise ValueError(
+            "This Scratchpad approach is not suitable for real time uses like streaming"
         )
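The scratchpad handler's `process_llm_output` relies on the final answer starting with a fixed JSON prefix. A quick standalone repro of that extraction step, with an invented model output for illustration:

# Invented model output, for illustration only: free-form "Thought:" text
# followed by the final JSON, as the scratchpad prompt requests.
model_output = (
    "Thought: Let's think step by step. The documents say the tower is in Paris.\n"
    '{"answer": "Paris", "quotes": ["located on the Champ de Mars"]}'
)

answer_start = model_output.find('{"answer":')
# Everything before the JSON prefix is scratchpad text and gets discarded;
# if the prefix never appears, the model only produced thoughts.
final_json = model_output[answer_start:] if answer_start != -1 else None
print(final_json)
# {"answer": "Paris", "quotes": ["located on the Champ de Mars"]}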
@@ -172,17 +205,6 @@ class JsonChatQAUnshackledHandler(QAHandler):

         return prompt

-    def process_response(
-        self,
-        tokens: Iterator[str],
-        context_chunks: list[InferenceChunk],
-    ) -> AnswerQuestionStreamReturn:
-        yield from process_model_tokens(
-            tokens=tokens,
-            context_docs=context_chunks,
-            is_json_prompt=True,
-        )
-

 def _tiktoken_trim_chunks(
     chunks: list[InferenceChunk], max_chunk_toks: int = 512
@@ -212,7 +234,7 @@ class QABlock(QAModel):
     def warm_up_model(self) -> None:
         """This is called during server start up to load the models into memory
         in case the chosen LLM is not accessed via API"""
-        self._llm.stream("Ignore this!")
+        self._llm.invoke("Ignore this!")

     def answer_question(
         self,
@@ -221,21 +243,9 @@ class QABlock(QAModel):
     ) -> AnswerQuestionReturn:
         trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
         prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
-        tokens = self._llm.stream(prompt)
+        model_out = self._llm.invoke(prompt)

-        final_answer = ""
-        quotes = DanswerQuotes([])
-        for output in self._qa_handler.process_response(tokens, trimmed_context_docs):
-            if output is None:
-                continue
-
-            if isinstance(output, DanswerAnswerPiece):
-                if output.answer_piece:
-                    final_answer += output.answer_piece
-            elif isinstance(output, DanswerQuotes):
-                quotes = output
-
-        return DanswerAnswer(final_answer), quotes
+        return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)

     def answer_question_stream(
         self,
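With `process_llm_output` moved onto the handler, the blocking path no longer re-assembles streamed tokens; it makes one `invoke` call and parses the whole string. A rough before/after sketch with a stubbed LLM (this is not the real `danswer.llm` interface, just an illustration of the two call shapes):

# Stubbed LLM, for illustration only.
class FakeLLM:
    def stream(self, prompt):  # old path: iterate tokens, accumulate pieces
        yield from ['{"answer": ', '"Paris"', ', "quotes": []}']

    def invoke(self, prompt):  # new path: one blocking completion
        return "".join(self.stream(prompt))

llm = FakeLLM()
model_out = llm.invoke("prompt goes here")  # whole output at once
print(model_out)  # {"answer": "Paris", "quotes": []}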
@@ -245,4 +255,6 @@ class QABlock(QAModel):
         trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
         prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
         tokens = self._llm.stream(prompt)
-        yield from self._qa_handler.process_response(tokens, trimmed_context_docs)
+        yield from self._qa_handler.process_llm_token_stream(
+            tokens, trimmed_context_docs
+        )
@@ -11,9 +11,12 @@ CODE_BLOCK_PAT = "\n```\n{}\n```\n"
 DOC_SEP_PAT = "---NEW DOCUMENT---"
 DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
 QUESTION_PAT = "Query:"
+THOUGHT_PAT = "Thought:"
 ANSWER_PAT = "Answer:"
+FINAL_ANSWER_PAT = "Final Answer:"
 UNCERTAINTY_PAT = "?"
 QUOTE_PAT = "Quote:"
+QUOTES_PAT_PLURAL = "Quotes:"

 BASE_PROMPT = (
     "Answer the query based on provided documents and quote relevant sections. "
@@ -31,6 +34,17 @@ SAMPLE_JSON_RESPONSE = {
         "located on the Champ de Mars in France.",
     ],
 }
+
+EMPTY_SAMPLE_JSON = {
+    "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
+    "quotes": [
+        "each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
+        "HINT the quotes are not shown to the user!",
+    ],
+}
+
+ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
+
 SAMPLE_RESPONSE_COT = (
     "Let's think step by step. The user is asking for the "
     "location of the Eiffel Tower. The first document describes the Eiffel Tower "
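To see what the scratchpad prompt actually injects, here is the format block rendered with the constants added above (values copied verbatim from the hunks; the printed output is shown abbreviated in comments):

import json

# Constants copied verbatim from the hunks above.
THOUGHT_PAT = "Thought:"
EMPTY_SAMPLE_JSON = {
    "answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
    "quotes": [
        "each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
        "HINT the quotes are not shown to the user!",
    ],
}

cot_block = (
    f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
    f"{json.dumps(EMPTY_SAMPLE_JSON)}"
)
print(cot_block)
# Thought: Let's think step by step. Use this section as a scratchpad.
# {"answer": "Place your final answer here. ...", "quotes": [...]}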
@@ -80,12 +80,17 @@ def extract_answer_quotes_json(


 def separate_answer_quotes(
-    answer_raw: str,
+    answer_raw: str, is_json_prompt: bool = False
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     try:
         model_raw_json = json.loads(answer_raw)
         return extract_answer_quotes_json(model_raw_json)
     except ValueError:
+        if is_json_prompt:
+            logger.error(
+                "Model did not output in json format as expected, "
+                "trying to parse it regardless"
+            )
         return extract_answer_quotes_freeform(answer_raw)
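The new `is_json_prompt` flag only changes logging; parsing still tries JSON first and falls back to freeform extraction either way. A minimal repro of that control flow, with the freeform extractor stubbed out:

import json

def parse(answer_raw: str, is_json_prompt: bool = False):
    try:
        return json.loads(answer_raw)
    except ValueError:  # json.JSONDecodeError subclasses ValueError
        if is_json_prompt:
            print("warning: expected json, parsing freeform instead")
        return None  # extract_answer_quotes_freeform(answer_raw) would run here

print(parse('{"answer": "Paris", "quotes": []}', is_json_prompt=True))  # dict
print(parse("Thought: not json at all", is_json_prompt=True))           # warns, then None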
@@ -149,9 +154,11 @@ def match_quotes_to_docs(


 def process_answer(
-    answer_raw: str, chunks: list[InferenceChunk]
+    answer_raw: str,
+    chunks: list[InferenceChunk],
+    is_json_prompt: bool = True,
 ) -> tuple[DanswerAnswer, DanswerQuotes]:
-    answer, quote_strings = separate_answer_quotes(answer_raw)
+    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
     if answer == UNCERTAINTY_PAT or not answer:
         if answer == UNCERTAINTY_PAT:
             logger.debug("Answer matched UNCERTAINTY_PAT")
@@ -5,34 +5,35 @@ from dataclasses import asdict
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
 from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
+from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
 from danswer.llm.build import get_default_llm
 from danswer.server.models import QueryValidationResponse
 from danswer.server.utils import get_json_line

 QUERY_PAT = "QUERY: "
-REASONING_PAT = "REASONING: "
+REASONING_PAT = "THOUGHT: "
 ANSWERABLE_PAT = "ANSWERABLE: "


 def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
-    ambiguous_example = (
-        f"{QUERY_PAT}What is this Slack channel about?\n"
+    ambiguous_example_question = f"{QUERY_PAT}What is this Slack channel about?"
+    ambiguous_example_answer = (
         f"{REASONING_PAT}First the system must determine which Slack channel is "
         f"being referred to. By fetching 5 documents related to Slack channel contents, "
         f"it is not possible to determine which Slack channel the user is referring to.\n"
         f"{ANSWERABLE_PAT}False"
     )

-    debug_example = (
-        f"{QUERY_PAT}Danswer is unreachable.\n"
+    debug_example_question = f"{QUERY_PAT}Danswer is unreachable."
+    debug_example_answer = (
         f"{REASONING_PAT}The system searches documents related to Danswer being "
         f"unreachable. Assuming the documents from search contains situations where "
-        f"Danswer is not reachable and contains a fix, the query is answerable.\n"
+        f"Danswer is not reachable and contains a fix, the query may be answerable.\n"
         f"{ANSWERABLE_PAT}True"
     )

-    up_to_date_example = (
-        f"{QUERY_PAT}How many customers do we have\n"
+    up_to_date_example_question = f"{QUERY_PAT}How many customers do we have"
+    up_to_date_example_answer = (
         f"{REASONING_PAT}Assuming the retrieved documents contain up to date customer "
         f"acquisition information including a list of customers, the query can be answered. "
         f"It is important to note that if the information only exists in a database, "
@@ -44,18 +45,18 @@ def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
         {
             "role": "user",
             "content": "You are a helper tool to determine if a query is answerable using retrieval augmented "
-            f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
-            f"documents found from search. Sources contain both up to date and proprietary information for "
-            f"the specific team. For named or unknown entities, assume the search will always find "
-            f"consistent knowledge about the entity.\n"
-            f"The system is not tuned for writing code nor for interfacing with structured data "
-            f"via query languages like SQL.\n"
-            f"Determine if that system should attempt to answer. "
-            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n'
-            f"{CODE_BLOCK_PAT.format(ambiguous_example)}\n"
-            f"{CODE_BLOCK_PAT.format(debug_example)}\n"
-            f"{CODE_BLOCK_PAT.format(up_to_date_example)}\n"
-            f"{CODE_BLOCK_PAT.format(QUERY_PAT + user_query)}\n",
+            f"generation.\nThe main system will try to answer the user query based on ONLY the top 5 most relevant "
+            f"documents found from search.\nSources contain both up to date and proprietary information for "
+            f"the specific team.\nFor named or unknown entities, assume the search will find "
+            f"relevant and consistent knowledge about the entity.\n"
+            f"The system is not tuned for writing code.\n"
+            f"The system is not tuned for interfacing with structured data via query languages like SQL.\n"
+            f"Determine if that system should attempt to answer.\n"
+            f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n{GENERAL_SEP_PAT}\n'
+            f"{ambiguous_example_question}{CODE_BLOCK_PAT.format(ambiguous_example_answer)}\n"
+            f"{debug_example_question}{CODE_BLOCK_PAT.format(debug_example_answer)}\n"
+            f"{up_to_date_example_question}{CODE_BLOCK_PAT.format(up_to_date_example_answer)}\n"
+            f"{QUERY_PAT + user_query}",
         },
     ]
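The restructure above moves each example question outside its fenced block, so only the model's expected answer is fenced. Rendering one example shows the new shape (constants copied from the hunks in this diff):

# Constants copied from the hunks above.
CODE_BLOCK_PAT = "\n```\n{}\n```\n"
QUERY_PAT = "QUERY: "
REASONING_PAT = "THOUGHT: "
ANSWERABLE_PAT = "ANSWERABLE: "

debug_example_question = f"{QUERY_PAT}Danswer is unreachable."
debug_example_answer = (
    f"{REASONING_PAT}The system searches documents related to Danswer being "
    f"unreachable. Assuming the documents from search contains situations where "
    f"Danswer is not reachable and contains a fix, the query may be answerable.\n"
    f"{ANSWERABLE_PAT}True"
)

# The question sits outside the fence; only the expected answer is fenced.
print(f"{debug_example_question}{CODE_BLOCK_PAT.format(debug_example_answer)}")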
@@ -103,7 +104,8 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:

         if not reasoning_pat_found and REASONING_PAT in model_output:
             reasoning_pat_found = True
-            remaining = model_output[len(REASONING_PAT) :]
+            reason_ind = model_output.find(REASONING_PAT)
+            remaining = model_output[reason_ind + len(REASONING_PAT) :]
             if remaining:
                 yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
             continue
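The one-line fix above matters whenever the model emits any preamble before the reasoning marker: the old slice assumed the marker sat at index 0. A standalone repro (the preamble string is invented):

REASONING_PAT = "THOUGHT: "

# Invented output with a preamble before the marker.
model_output = "Sure! THOUGHT: The query names a known entity, so it is answerable."

buggy = model_output[len(REASONING_PAT):]      # old: slice from a fixed offset
reason_ind = model_output.find(REASONING_PAT)  # new: locate the marker first
fixed = model_output[reason_ind + len(REASONING_PAT):]

print(buggy)  # "UGHT: The query names a known entity..." (garbled)
print(fixed)  # "The query names a known entity, so it is answerable."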
@@ -4,6 +4,10 @@ import bs4
 from bs4 import BeautifulSoup


+def escape_newlines(s: str) -> str:
+    return re.sub(r"(?<!\\)\n", "\\\\n", s)
+
+
 def clean_model_quote(quote: str, trim_length: int) -> str:
     quote_clean = quote.strip()
     if quote_clean[0] == '"':
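A quick check of the new helper: it escapes bare newlines inside the extracted answer so that `json.loads` succeeds, while the negative lookbehind skips newlines already preceded by a backslash:

import json
import re

def escape_newlines(s: str) -> str:  # copied from the hunk above
    return re.sub(r"(?<!\\)\n", "\\\\n", s)

raw = '{"answer": "line one\nline two", "quotes": []}'  # invalid JSON: bare newline
print(json.loads(escape_newlines(raw))["answer"])
# line one
# line two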