Slack CoT Scratchpad (#421)

This commit is contained in:
Yuhong Sun
2023-09-10 16:56:44 -07:00
committed by GitHub
parent 1d847bfd23
commit 6c795dfa6c
9 changed files with 150 additions and 99 deletions

View File

@@ -42,6 +42,7 @@ def handle_message(
user=None,
db_session=db_session,
answer_generation_timeout=answer_generation_timeout,
real_time_flow=False,
)
if not answer.error_msg:
return answer

View File

@@ -209,7 +209,7 @@ DANSWER_BOT_NUM_DOCS_TO_DISPLAY = int(
)
DANSWER_BOT_NUM_RETRIES = int(os.environ.get("DANSWER_BOT_NUM_RETRIES", "5"))
DANSWER_BOT_ANSWER_GENERATION_TIMEOUT = int(
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "60")
os.environ.get("DANSWER_BOT_ANSWER_GENERATION_TIMEOUT", "90")
)
DANSWER_BOT_DISPLAY_ERROR_MSGS = os.environ.get(
"DANSWER_BOT_DISPLAY_ERROR_MSGS", ""

View File

@@ -33,6 +33,7 @@ def answer_qa_query(
db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
answer_generation_timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
) -> QAResponse:
query = question.query
filters = question.filters
@@ -88,7 +89,9 @@ def answer_qa_query(
)
try:
qa_model = get_default_qa_model(timeout=answer_generation_timeout)
qa_model = get_default_qa_model(
timeout=answer_generation_timeout, real_time_flow=real_time_flow
)
except (UnknownModelError, OpenAIKeyMissing) as e:
return QAResponse(
answer=None,

View File

@@ -22,6 +22,7 @@ from danswer.direct_qa.qa_block import QABlock
from danswer.direct_qa.qa_block import QAHandler
from danswer.direct_qa.qa_block import SimpleChatQAHandler
from danswer.direct_qa.qa_block import SingleMessageQAHandler
from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.direct_qa.request_model import RequestCompletionQA
@@ -51,9 +52,13 @@ def check_model_api_key_is_valid(model_api_key: str) -> bool:
return False
def get_default_qa_handler(model: str) -> QAHandler:
def get_default_qa_handler(model: str, real_time_flow: bool = True) -> QAHandler:
if model == DanswerGenAIModel.OPENAI_CHAT.value:
return SingleMessageQAHandler()
return (
SingleMessageQAHandler()
if real_time_flow
else SingleMessageScratchpadHandler()
)
return SimpleChatQAHandler()
@@ -64,6 +69,7 @@ def get_default_qa_model(
model_host_type: str | None = GEN_AI_HOST_TYPE,
api_key: str | None = GEN_AI_API_KEY,
timeout: int = QA_TIMEOUT,
real_time_flow: bool = True,
**kwargs: Any,
) -> QAModel:
if not api_key:
@@ -76,7 +82,9 @@ def get_default_qa_model(
# un-used arguments will be ignored by the underlying `LLM` class
# if any args are missing, a `TypeError` will be thrown
llm = get_default_llm(timeout=timeout)
qa_handler = get_default_qa_handler(model=internal_model)
qa_handler = get_default_qa_handler(
model=internal_model, real_time_flow=real_time_flow
)
return QABlock(
llm=llm,

View File

@@ -13,21 +13,24 @@ from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import AnswerQuestionReturn
from danswer.direct_qa.interfaces import AnswerQuestionStreamReturn
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.direct_qa.qa_prompts import EMPTY_SAMPLE_JSON
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import QUESTION_PAT
from danswer.direct_qa.qa_prompts import SAMPLE_JSON_RESPONSE
from danswer.direct_qa.qa_prompts import THOUGHT_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.llm.llm import LLM
from danswer.llm.utils import dict_based_prompt_to_langchain_prompt
from danswer.llm.utils import str_prompt_to_langchain_prompt
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import escape_newlines
logger = setup_logger()
@@ -43,11 +46,26 @@ class QAHandler(abc.ABC):
) -> list[BaseMessage]:
raise NotImplementedError
@abc.abstractmethod
def process_response(
@property
def is_json_output(self) -> bool:
"""Does the model expected to output a valid json"""
return True
def process_llm_output(
self, model_output: str, context_chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, DanswerQuotes]:
return process_answer(
model_output, context_chunks, is_json_prompt=self.is_json_output
)
def process_llm_token_stream(
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
raise NotImplementedError
yield from process_model_tokens(
tokens=tokens,
context_docs=context_chunks,
is_json_prompt=self.is_json_output,
)
class JsonChatQAHandler(QAHandler):
@@ -60,19 +78,12 @@ class JsonChatQAHandler(QAHandler):
)
)
def process_response(
self,
tokens: Iterator[str],
context_chunks: list[InferenceChunk],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_chunks,
is_json_prompt=True,
)
class SimpleChatQAHandler(QAHandler):
@property
def is_json_output(self) -> bool:
return False
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
@@ -84,26 +95,11 @@ class SimpleChatQAHandler(QAHandler):
)
)
def process_response(
self,
tokens: Iterator[str],
context_chunks: list[InferenceChunk],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_chunks,
is_json_prompt=False,
)
class SingleMessageQAHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
context_docs_str = "\n".join(
f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks
)
@@ -115,27 +111,64 @@ class SingleMessageQAHandler(QAHandler):
"to provide accurate and detailed answers to diverse queries.\n"
"You ALWAYS responds in a json containing an answer and quotes that support the answer.\n"
"Your responses are as INFORMATIVE and DETAILED as possible.\n"
"If you don't know the answer, respond with "
f"{CODE_BLOCK_PAT.format(complete_answer_not_found_response)}"
"\nSample response:"
f"{CODE_BLOCK_PAT.format(json.dumps(SAMPLE_JSON_RESPONSE))}"
f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}"
f"{GENERAL_SEP_PAT}{QUESTION_PAT} {query}"
f"{GENERAL_SEP_PAT}Sample response:"
f"{CODE_BLOCK_PAT.format(json.dumps(EMPTY_SAMPLE_JSON))}\n"
f"{QUESTION_PAT} {query}"
"\nHint: Make the answer as detailed as possible and use a JSON! "
"Quotes MUST be EXACT substrings from provided documents!"
)
]
return prompt
def process_response(
self,
tokens: Iterator[str],
context_chunks: list[InferenceChunk],
class SingleMessageScratchpadHandler(QAHandler):
def build_prompt(
self, query: str, context_chunks: list[InferenceChunk]
) -> list[BaseMessage]:
cot_block = (
f"{THOUGHT_PAT} Let's think step by step. Use this section as a scratchpad.\n"
f"{json.dumps(EMPTY_SAMPLE_JSON)}"
)
context_docs_str = "\n".join(
f"{CODE_BLOCK_PAT.format(c.content)}" for c in context_chunks
)
prompt: list[BaseMessage] = [
HumanMessage(
content="You are a question answering system that is constantly learning and improving. "
"You can process and comprehend vast amounts of text and utilize this knowledge "
"to provide accurate and detailed answers to diverse queries.\n"
f"{GENERAL_SEP_PAT}CONTEXT:\n\n{context_docs_str}{GENERAL_SEP_PAT}"
f"You MUST use the following format:\n"
f"{CODE_BLOCK_PAT.format(cot_block)}\n"
f"Begin!\n{QUESTION_PAT} {query}"
)
]
return prompt
def process_llm_output(
self, model_output: str, context_chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, DanswerQuotes]:
logger.debug(model_output)
answer_start = model_output.find('{"answer":')
# Only found thoughts, no final answer
if answer_start == -1:
return DanswerAnswer(answer=None), DanswerQuotes(quotes=[])
final_json = escape_newlines(model_output[answer_start:])
return process_answer(
final_json, context_chunks, is_json_prompt=self.is_json_output
)
def process_llm_token_stream(
self, tokens: Iterator[str], context_chunks: list[InferenceChunk]
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_chunks,
is_json_prompt=True,
raise ValueError(
"This Scratchpad approach is not suitable for real time uses like streaming"
)
@@ -172,17 +205,6 @@ class JsonChatQAUnshackledHandler(QAHandler):
return prompt
def process_response(
self,
tokens: Iterator[str],
context_chunks: list[InferenceChunk],
) -> AnswerQuestionStreamReturn:
yield from process_model_tokens(
tokens=tokens,
context_docs=context_chunks,
is_json_prompt=True,
)
def _tiktoken_trim_chunks(
chunks: list[InferenceChunk], max_chunk_toks: int = 512
@@ -212,7 +234,7 @@ class QABlock(QAModel):
def warm_up_model(self) -> None:
"""This is called during server start up to load the models into memory
in case the chosen LLM is not accessed via API"""
self._llm.stream("Ignore this!")
self._llm.invoke("Ignore this!")
def answer_question(
self,
@@ -221,21 +243,9 @@ class QABlock(QAModel):
) -> AnswerQuestionReturn:
trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
tokens = self._llm.stream(prompt)
model_out = self._llm.invoke(prompt)
final_answer = ""
quotes = DanswerQuotes([])
for output in self._qa_handler.process_response(tokens, trimmed_context_docs):
if output is None:
continue
if isinstance(output, DanswerAnswerPiece):
if output.answer_piece:
final_answer += output.answer_piece
elif isinstance(output, DanswerQuotes):
quotes = output
return DanswerAnswer(final_answer), quotes
return self._qa_handler.process_llm_output(model_out, trimmed_context_docs)
def answer_question_stream(
self,
@@ -245,4 +255,6 @@ class QABlock(QAModel):
trimmed_context_docs = _tiktoken_trim_chunks(context_docs)
prompt = self._qa_handler.build_prompt(query, trimmed_context_docs)
tokens = self._llm.stream(prompt)
yield from self._qa_handler.process_response(tokens, trimmed_context_docs)
yield from self._qa_handler.process_llm_token_stream(
tokens, trimmed_context_docs
)

View File

@@ -11,9 +11,12 @@ CODE_BLOCK_PAT = "\n```\n{}\n```\n"
DOC_SEP_PAT = "---NEW DOCUMENT---"
DOC_CONTENT_START_PAT = "DOCUMENT CONTENTS:\n"
QUESTION_PAT = "Query:"
THOUGHT_PAT = "Thought:"
ANSWER_PAT = "Answer:"
FINAL_ANSWER_PAT = "Final Answer:"
UNCERTAINTY_PAT = "?"
QUOTE_PAT = "Quote:"
QUOTES_PAT_PLURAL = "Quotes:"
BASE_PROMPT = (
"Answer the query based on provided documents and quote relevant sections. "
@@ -31,6 +34,17 @@ SAMPLE_JSON_RESPONSE = {
"located on the Champ de Mars in France.",
],
}
EMPTY_SAMPLE_JSON = {
"answer": "Place your final answer here. It should be as DETAILED and INFORMATIVE as possible.",
"quotes": [
"each quote must be UNEDITED and EXACTLY as shown in the provided documents!",
"HINT the quotes are not shown to the user!",
],
}
ANSWER_NOT_FOUND_JSON = '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
SAMPLE_RESPONSE_COT = (
"Let's think step by step. The user is asking for the "
"location of the Eiffel Tower. The first document describes the Eiffel Tower "

View File

@@ -80,12 +80,17 @@ def extract_answer_quotes_json(
def separate_answer_quotes(
answer_raw: str,
answer_raw: str, is_json_prompt: bool = False
) -> Tuple[Optional[str], Optional[list[str]]]:
try:
model_raw_json = json.loads(answer_raw)
return extract_answer_quotes_json(model_raw_json)
except ValueError:
if is_json_prompt:
logger.error(
"Model did not output in json format as expected, "
"trying to parse it regardless"
)
return extract_answer_quotes_freeform(answer_raw)
@@ -149,9 +154,11 @@ def match_quotes_to_docs(
def process_answer(
answer_raw: str, chunks: list[InferenceChunk]
answer_raw: str,
chunks: list[InferenceChunk],
is_json_prompt: bool = True,
) -> tuple[DanswerAnswer, DanswerQuotes]:
answer, quote_strings = separate_answer_quotes(answer_raw)
answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")

View File

@@ -5,34 +5,35 @@ from dataclasses import asdict
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt
from danswer.direct_qa.qa_prompts import CODE_BLOCK_PAT
from danswer.direct_qa.qa_prompts import GENERAL_SEP_PAT
from danswer.llm.build import get_default_llm
from danswer.server.models import QueryValidationResponse
from danswer.server.utils import get_json_line
QUERY_PAT = "QUERY: "
REASONING_PAT = "REASONING: "
REASONING_PAT = "THOUGHT: "
ANSWERABLE_PAT = "ANSWERABLE: "
def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
ambiguous_example = (
f"{QUERY_PAT}What is this Slack channel about?\n"
ambiguous_example_question = f"{QUERY_PAT}What is this Slack channel about?"
ambiguous_example_answer = (
f"{REASONING_PAT}First the system must determine which Slack channel is "
f"being referred to. By fetching 5 documents related to Slack channel contents, "
f"it is not possible to determine which Slack channel the user is referring to.\n"
f"{ANSWERABLE_PAT}False"
)
debug_example = (
f"{QUERY_PAT}Danswer is unreachable.\n"
debug_example_question = f"{QUERY_PAT}Danswer is unreachable."
debug_example_answer = (
f"{REASONING_PAT}The system searches documents related to Danswer being "
f"unreachable. Assuming the documents from search contains situations where "
f"Danswer is not reachable and contains a fix, the query is answerable.\n"
f"Danswer is not reachable and contains a fix, the query may be answerable.\n"
f"{ANSWERABLE_PAT}True"
)
up_to_date_example = (
f"{QUERY_PAT}How many customers do we have\n"
up_to_date_example_question = f"{QUERY_PAT}How many customers do we have"
up_to_date_example_answer = (
f"{REASONING_PAT}Assuming the retrieved documents contain up to date customer "
f"acquisition information including a list of customers, the query can be answered. "
f"It is important to note that if the information only exists in a database, "
@@ -44,18 +45,18 @@ def get_query_validation_messages(user_query: str) -> list[dict[str, str]]:
{
"role": "user",
"content": "You are a helper tool to determine if a query is answerable using retrieval augmented "
f"generation. A system will try to answer the user query based on ONLY the top 5 most relevant "
f"documents found from search. Sources contain both up to date and proprietary information for "
f"the specific team. For named or unknown entities, assume the search will always find "
f"consistent knowledge about the entity.\n"
f"The system is not tuned for writing code nor for interfacing with structured data "
f"via query languages like SQL.\n"
f"Determine if that system should attempt to answer. "
f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n'
f"{CODE_BLOCK_PAT.format(ambiguous_example)}\n"
f"{CODE_BLOCK_PAT.format(debug_example)}\n"
f"{CODE_BLOCK_PAT.format(up_to_date_example)}\n"
f"{CODE_BLOCK_PAT.format(QUERY_PAT + user_query)}\n",
f"generation.\nThe main system will try to answer the user query based on ONLY the top 5 most relevant "
f"documents found from search.\nSources contain both up to date and proprietary information for "
f"the specific team.\nFor named or unknown entities, assume the search will find "
f"relevant and consistent knowledge about the entity.\n"
f"The system is not tuned for writing code.\n"
f"The system is not tuned for interfacing with structured data via query languages like SQL.\n"
f"Determine if that system should attempt to answer.\n"
f'"{ANSWERABLE_PAT}" must be exactly "True" or "False"\n{GENERAL_SEP_PAT}\n'
f"{ambiguous_example_question}{CODE_BLOCK_PAT.format(ambiguous_example_answer)}\n"
f"{debug_example_question}{CODE_BLOCK_PAT.format(debug_example_answer)}\n"
f"{up_to_date_example_question}{CODE_BLOCK_PAT.format(up_to_date_example_answer)}\n"
f"{QUERY_PAT + user_query}",
},
]
@@ -103,7 +104,8 @@ def stream_query_answerability(user_query: str) -> Iterator[str]:
if not reasoning_pat_found and REASONING_PAT in model_output:
reasoning_pat_found = True
remaining = model_output[len(REASONING_PAT) :]
reason_ind = model_output.find(REASONING_PAT)
remaining = model_output[reason_ind + len(REASONING_PAT) :]
if remaining:
yield get_json_line(asdict(DanswerAnswerPiece(answer_piece=remaining)))
continue

View File

@@ -4,6 +4,10 @@ import bs4
from bs4 import BeautifulSoup
def escape_newlines(s: str) -> str:
return re.sub(r"(?<!\\)\n", "\\\\n", s)
def clean_model_quote(quote: str, trim_length: int) -> str:
quote_clean = quote.strip()
if quote_clean[0] == '"':