diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py
index 4439decc9..640bd4e0b 100644
--- a/backend/danswer/direct_qa/qa_utils.py
+++ b/backend/danswer/direct_qa/qa_utils.py
@@ -1,4 +1,3 @@
-import json
 import math
 import re
 from collections.abc import Generator
@@ -24,12 +23,13 @@ from danswer.prompts.constants import UNCERTAINTY_PAT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import clean_up_code_blocks
+from danswer.utils.text_processing import extract_embedded_json
 from danswer.utils.text_processing import shared_precompare_cleanup
 
 logger = setup_logger()
 
 
-def extract_answer_quotes_freeform(
+def _extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     """Splits the model output into an Answer and 0 or more Quote sections.
@@ -61,7 +61,7 @@ def extract_answer_quotes_freeform(
     return answer, sections_clean[1:]
 
 
-def extract_answer_quotes_json(
+def _extract_answer_quotes_json(
     answer_dict: dict[str, str | list[str]]
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     answer_dict = {k.lower(): v for k, v in answer_dict.items()}
@@ -72,24 +72,30 @@ def extract_answer_quotes_json(
     return answer, quotes
 
 
-def separate_answer_quotes(
-    answer_raw: str, is_json_prompt: bool = False
-) -> Tuple[Optional[str], Optional[list[str]]]:
+def _extract_answer_json(raw_model_output: str) -> dict:
     try:
-        model_raw_json = json.loads(answer_raw, strict=False)
-        return extract_answer_quotes_json(model_raw_json)
-    except JSONDecodeError:
+        answer_json = extract_embedded_json(raw_model_output)
+    except (ValueError, JSONDecodeError):
         # LLMs get confused when handling the list in the json. Sometimes it doesn't attend
         # enough to the previous { token so it just ends the list of quotes and stops there
         # here, we add logic to try to fix this LLM error.
-        try:
-            model_raw_json = json.loads(answer_raw + "}", strict=False)
-            return extract_answer_quotes_json(model_raw_json)
-        except JSONDecodeError:
-            if is_json_prompt:
-                logger.error("Model did not output in json format as expected.")
-            raise
-    return extract_answer_quotes_freeform(answer_raw)
+        answer_json = extract_embedded_json(raw_model_output + "}")
+
+    if "answer" not in answer_json:
+        raise ValueError("Model did not output an answer as expected.")
+
+    return answer_json
+
+
+def separate_answer_quotes(
+    answer_raw: str, is_json_prompt: bool = False
+) -> Tuple[Optional[str], Optional[list[str]]]:
+    """Takes in a raw model output and pulls out the answer and the quotes sections."""
+    if is_json_prompt:
+        model_raw_json = _extract_answer_json(answer_raw)
+        return _extract_answer_quotes_json(model_raw_json)
+
+    return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
 
 
 def match_quotes_to_docs(
@@ -156,9 +162,10 @@ def process_answer(
     chunks: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> tuple[DanswerAnswer, DanswerQuotes]:
-    answer_clean = clean_up_code_blocks(answer_raw)
-
-    answer, quote_strings = separate_answer_quotes(answer_clean, is_json_prompt)
+    """Used (1) in the non-streaming case to process the model output
+    into an Answer and Quotes AND (2) after the complete streaming response
+    has been received to process the model output into an Answer and Quotes."""
+    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
     if answer == UNCERTAINTY_PAT or not answer:
         if answer == UNCERTAINTY_PAT:
             logger.debug("Answer matched UNCERTAINTY_PAT")
@@ -177,7 +184,7 @@
     return DanswerAnswer(answer=answer), quotes
 
 
-def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
+def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
     next_token = next_token.replace('\\"', "")
     # If the previous character is an escape token, don't consider the first character of next_token
    # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling
@@ -188,7 +195,7 @@
     return False
 
 
-def extract_quotes_from_completed_token_stream(
+def _extract_quotes_from_completed_token_stream(
     model_output: str, context_chunks: list[InferenceChunk]
 ) -> DanswerQuotes:
     answer, quotes = process_answer(model_output, context_chunks)
@@ -205,7 +212,10 @@
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
-    """Yields Answer tokens back out in a dict for streaming to frontend
+    """Used in the streaming case to process the model output
+    into an Answer and Quotes
+
+    Yields Answer tokens back out in a dict for streaming to frontend
     When Answer section ends, yields dict with answer_finished key
     Collects all the tokens at the end to form the complete model output"""
     quote_pat = f"\n{QUOTE_PAT}"
@@ -228,14 +238,14 @@
             found_answer_start = True
 
             # Prevent heavy cases of hallucinations where model is not even providing a json until later
-            if is_json_prompt and len(model_output) > 20:
+            if is_json_prompt and len(model_output) > 40:
                 logger.warning("LLM did not produce json as prompted")
                 found_answer_end = True
 
             continue
 
         if found_answer_start and not found_answer_end:
-            if is_json_prompt and stream_json_answer_end(model_previous, token):
+            if is_json_prompt and _stream_json_answer_end(model_previous, token):
                 found_answer_end = True
                 yield DanswerAnswerPiece(answer_piece=None)
                 continue
@@ -252,20 +262,7 @@
 
     logger.debug(f"Raw Model QnA Output: {model_output}")
 
-    # for a JSON prompt, make sure that we're only passing through the "JSON part"
-    # since that is what `extract_quotes_from_completed_token_stream` expects
-    if is_json_prompt:
-        try:
-            json_answer_ind = model_output.index('{"answer":')
-            if json_answer_ind != 0:
-                model_output = model_output[json_answer_ind:]
-            end = model_output.rfind("}")
-            if end != -1:
-                model_output = model_output[: end + 1]
-        except ValueError:
-            logger.exception("Did not find answer pattern in response for JSON prompt")
-
-    yield extract_quotes_from_completed_token_stream(model_output, context_docs)
+    yield _extract_quotes_from_completed_token_stream(model_output, context_docs)
 
 
 def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
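
Reviewer note: extract_embedded_json is defined in backend/danswer/utils/text_processing.py and is not part of this diff. As a mental model for the error handling in _extract_answer_json, here is a minimal sketch of what the helper is assumed to do (slice out the outermost {...} span of the completion and parse it, raising ValueError when no JSON object is present). The name matches the import, but the body is an assumption, not the actual implementation:

    import json

    def extract_embedded_json(raw_text: str) -> dict:
        # Sketch only: the real helper lives in danswer.utils.text_processing.
        # Assumed behavior: parse the first '{' .. last '}' span of the raw
        # completion, raising ValueError when no JSON object is embedded.
        first = raw_text.find("{")
        last = raw_text.rfind("}")
        if first == -1 or last == -1 or last < first:
            raise ValueError("No embedded json found in model output")
        # strict=False tolerates raw control characters (e.g. newlines) inside
        # string values, which LLM completions emit frequently.
        return json.loads(raw_text[first : last + 1], strict=False)

Under that reading, except (ValueError, JSONDecodeError) covers both "no braces at all" and "braces present but unparseable", and the raw_model_output + "}" retry repairs the common case where the model closes the quotes list but drops the final brace.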
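For completeness, a small usage illustration of the repaired parsing path. The sample completions are made up, and the expected tuples assume _extract_answer_quotes_json returns the answer string and the quotes list (consistent with the "return answer, quotes" context above):

    from danswer.direct_qa.qa_utils import separate_answer_quotes

    # Well-formed completion: parsed directly via extract_embedded_json.
    full = '{"answer": "Paris", "quotes": ["The capital of France is Paris."]}'
    print(separate_answer_quotes(full, is_json_prompt=True))
    # expected: ('Paris', ['The capital of France is Paris.'])

    # Truncated completion (model dropped the final '}'): the first parse
    # fails, so _extract_answer_json retries with a '}' appended and recovers
    # the same answer/quotes pair.
    truncated = '{"answer": "Paris", "quotes": ["The capital of France is Paris."]'
    print(separate_answer_quotes(truncated, is_json_prompt=True))
    # expected: ('Paris', ['The capital of France is Paris.'])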