Improve LLM answer parsing

Authored by Weves on 2023-11-23 14:31:23 -08:00; committed by Chris Weaver
parent 13001ede98
commit 26c6651a03


@@ -1,4 +1,3 @@
-import json
 import math
 import re
 from collections.abc import Generator
@@ -24,12 +23,13 @@ from danswer.prompts.constants import UNCERTAINTY_PAT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import clean_up_code_blocks
+from danswer.utils.text_processing import extract_embedded_json
 from danswer.utils.text_processing import shared_precompare_cleanup
 
 logger = setup_logger()
 
 
-def extract_answer_quotes_freeform(
+def _extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     """Splits the model output into an Answer and 0 or more Quote sections.
@@ -61,7 +61,7 @@ def extract_answer_quotes_freeform(
     return answer, sections_clean[1:]
 
 
-def extract_answer_quotes_json(
+def _extract_answer_quotes_json(
     answer_dict: dict[str, str | list[str]]
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     answer_dict = {k.lower(): v for k, v in answer_dict.items()}
@@ -72,24 +72,30 @@ def extract_answer_quotes_json(
     return answer, quotes
 
 
-def separate_answer_quotes(
-    answer_raw: str, is_json_prompt: bool = False
-) -> Tuple[Optional[str], Optional[list[str]]]:
+def _extract_answer_json(raw_model_output: str) -> dict:
     try:
-        model_raw_json = json.loads(answer_raw, strict=False)
-        return extract_answer_quotes_json(model_raw_json)
-    except JSONDecodeError:
+        answer_json = extract_embedded_json(raw_model_output)
+    except (ValueError, JSONDecodeError):
         # LLMs get confused when handling the list in the json. Sometimes it doesn't attend
         # enough to the previous { token so it just ends the list of quotes and stops there
         # here, we add logic to try to fix this LLM error.
-        try:
-            model_raw_json = json.loads(answer_raw + "}", strict=False)
-            return extract_answer_quotes_json(model_raw_json)
-        except JSONDecodeError:
-            if is_json_prompt:
-                logger.error("Model did not output in json format as expected.")
-            raise
-    return extract_answer_quotes_freeform(answer_raw)
+        answer_json = extract_embedded_json(raw_model_output + "}")
+
+    if "answer" not in answer_json:
+        raise ValueError("Model did not output an answer as expected.")
+
+    return answer_json
+
+
+def separate_answer_quotes(
+    answer_raw: str, is_json_prompt: bool = False
+) -> Tuple[Optional[str], Optional[list[str]]]:
+    """Takes in a raw model output and pulls out the answer and the quotes sections."""
+    if is_json_prompt:
+        model_raw_json = _extract_answer_json(answer_raw)
+        return _extract_answer_quotes_json(model_raw_json)
+    return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
 
 
 def match_quotes_to_docs(
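For context on the new `_extract_answer_json` helper above: when the first parse fails, the raw output is re-parsed with a trailing `}` appended, since models often stop right after the quote list without closing the object. Below is a minimal standalone sketch of that repair-and-retry idea using plain `json.loads`; the helper name is illustrative only and this is not danswer's `extract_embedded_json` implementation.

```python
import json
from json.decoder import JSONDecodeError


def parse_llm_json_with_repair(raw_model_output: str) -> dict:
    # Drop any leading chatter before the first "{" so only the JSON object is parsed
    start = raw_model_output.find("{")
    candidate = raw_model_output[start:] if start != -1 else raw_model_output

    try:
        answer_json = json.loads(candidate, strict=False)
    except JSONDecodeError:
        # Common failure mode: the model ends the quote list and stops without
        # closing the surrounding object, so retry once with "}" appended
        answer_json = json.loads(candidate + "}", strict=False)

    if "answer" not in answer_json:
        raise ValueError("Model did not output an answer as expected.")
    return answer_json
```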
@@ -156,9 +162,10 @@ def process_answer(
     chunks: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> tuple[DanswerAnswer, DanswerQuotes]:
-    answer_clean = clean_up_code_blocks(answer_raw)
-    answer, quote_strings = separate_answer_quotes(answer_clean, is_json_prompt)
+    """Used (1) in the non-streaming case to process the model output
+    into an Answer and Quotes AND (2) after the complete streaming response
+    has been received to process the model output into an Answer and Quotes."""
+    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
     if answer == UNCERTAINTY_PAT or not answer:
         if answer == UNCERTAINTY_PAT:
             logger.debug("Answer matched UNCERTAINTY_PAT")
@@ -177,7 +184,7 @@ def process_answer(
     return DanswerAnswer(answer=answer), quotes
 
 
-def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
+def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
     next_token = next_token.replace('\\"', "")
     # If the previous character is an escape token, don't consider the first character of next_token
     # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling
@@ -188,7 +195,7 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
     return False
 
 
-def extract_quotes_from_completed_token_stream(
+def _extract_quotes_from_completed_token_stream(
     model_output: str, context_chunks: list[InferenceChunk]
 ) -> DanswerQuotes:
     answer, quotes = process_answer(model_output, context_chunks)
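The comments in `_stream_json_answer_end` above describe how the end of the streamed answer is detected: for a JSON prompt, the `"answer"` value ends at the first unescaped double quote. A rough sketch of that check, assuming the same escape-handling rules; the function name here is hypothetical, not the renamed helper itself.

```python
def answer_value_has_ended(answer_so_far: str, next_token: str) -> bool:
    # Quotes escaped inside the token itself do not terminate the answer
    next_token = next_token.replace('\\"', "")
    # If the text so far ends with an escape character, the first character of
    # next_token is escaped, so it should not count as the closing quote
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    return '"' in next_token
```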
@@ -205,7 +212,10 @@ def process_model_tokens(
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
-    """Yields Answer tokens back out in a dict for streaming to frontend
+    """Used in the streaming case to process the model output
+    into an Answer and Quotes
+
+    Yields Answer tokens back out in a dict for streaming to frontend
     When Answer section ends, yields dict with answer_finished key
     Collects all the tokens at the end to form the complete model output"""
     quote_pat = f"\n{QUOTE_PAT}"
@@ -228,14 +238,14 @@ def process_model_tokens(
             found_answer_start = True
 
             # Prevent heavy cases of hallucinations where model is not even providing a json until later
-            if is_json_prompt and len(model_output) > 20:
+            if is_json_prompt and len(model_output) > 40:
                 logger.warning("LLM did not produce json as prompted")
                 found_answer_end = True
 
             continue
 
         if found_answer_start and not found_answer_end:
-            if is_json_prompt and stream_json_answer_end(model_previous, token):
+            if is_json_prompt and _stream_json_answer_end(model_previous, token):
                 found_answer_end = True
                 yield DanswerAnswerPiece(answer_piece=None)
                 continue
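On the threshold change from 20 to 40 above: the guard fires when the `{"answer":` pattern only shows up after a long non-JSON preamble, which usually means the model hallucinated text before the object. A toy illustration of the idea; the names below are illustrative, not the actual variables in the function.

```python
ANSWER_START_PAT = '{"answer":'
MAX_PREAMBLE_CHARS = 40  # was 20 before this commit


def json_started_too_late(model_output_so_far: str) -> bool:
    # True when the JSON object finally begins but only after a suspiciously
    # long run of non-JSON text, i.e. the model rambled before answering
    return (
        ANSWER_START_PAT in model_output_so_far
        and len(model_output_so_far) > MAX_PREAMBLE_CHARS
    )
```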
@@ -252,20 +262,7 @@ def process_model_tokens(
     logger.debug(f"Raw Model QnA Output: {model_output}")
 
-    # for a JSON prompt, make sure that we're only passing through the "JSON part"
-    # since that is what `extract_quotes_from_completed_token_stream` expects
-    if is_json_prompt:
-        try:
-            json_answer_ind = model_output.index('{"answer":')
-            if json_answer_ind != 0:
-                model_output = model_output[json_answer_ind:]
-            end = model_output.rfind("}")
-            if end != -1:
-                model_output = model_output[: end + 1]
-        except ValueError:
-            logger.exception("Did not find answer pattern in response for JSON prompt")
-
-    yield extract_quotes_from_completed_token_stream(model_output, context_docs)
+    yield _extract_quotes_from_completed_token_stream(model_output, context_docs)
 
 
 def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
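A hedged usage sketch of the streaming path described in the docstring: feed a simulated token stream into `process_model_tokens` and consume answer pieces until the final quotes object arrives. It assumes the token iterator is the first positional parameter and that `chunks` is an existing `list[InferenceChunk]`; neither is shown in this diff.

```python
model_out = '{"answer": "Paris is the capital of France.", "quotes": []}'

for packet in process_model_tokens(
    simulate_streaming_response(model_out),  # assumed first parameter: token iterator
    context_docs=chunks,  # assumed: prebuilt list[InferenceChunk]
    is_json_prompt=True,
):
    if isinstance(packet, DanswerAnswerPiece):
        # Streamed answer tokens; answer_piece is None once the answer section ends
        print(packet.answer_piece)
    else:
        # Final DanswerQuotes produced from the complete model output
        print(packet)
```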