Improve LLM answer parsing

Author: Weves
Date:   2023-11-23 14:31:23 -08:00
Committed by: Chris Weaver
Parent: 13001ede98
Commit: 26c6651a03


@@ -1,4 +1,3 @@
-import json
 import math
 import re
 from collections.abc import Generator
@@ -24,12 +23,13 @@ from danswer.prompts.constants import UNCERTAINTY_PAT
 from danswer.utils.logger import setup_logger
 from danswer.utils.text_processing import clean_model_quote
 from danswer.utils.text_processing import clean_up_code_blocks
+from danswer.utils.text_processing import extract_embedded_json
 from danswer.utils.text_processing import shared_precompare_cleanup
 
 logger = setup_logger()
 
 
-def extract_answer_quotes_freeform(
+def _extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     """Splits the model output into an Answer and 0 or more Quote sections.
@@ -61,7 +61,7 @@ def extract_answer_quotes_freeform(
     return answer, sections_clean[1:]
 
 
-def extract_answer_quotes_json(
+def _extract_answer_quotes_json(
     answer_dict: dict[str, str | list[str]]
 ) -> Tuple[Optional[str], Optional[list[str]]]:
     answer_dict = {k.lower(): v for k, v in answer_dict.items()}
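For reference, `_extract_answer_quotes_json` consumes a dict parsed from the model's JSON output, roughly of the following shape. The "quotes" key name is an assumption (only "answer" is checked explicitly elsewhere in this diff) and the values are invented for illustration:

answer_dict = {
    "Answer": "The model output is split into an answer and supporting quotes.",
    "Quotes": ["The answer and quotes are extracted from the raw model output."],
}
# After the lowercasing step above, the helper would return approximately:
# ("The model output is split into ...", ["The answer and quotes are extracted ..."])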
@@ -72,24 +72,30 @@ def extract_answer_quotes_json(
     return answer, quotes
 
 
-def separate_answer_quotes(
-    answer_raw: str, is_json_prompt: bool = False
-) -> Tuple[Optional[str], Optional[list[str]]]:
+def _extract_answer_json(raw_model_output: str) -> dict:
     try:
-        model_raw_json = json.loads(answer_raw, strict=False)
-        return extract_answer_quotes_json(model_raw_json)
-    except JSONDecodeError:
+        answer_json = extract_embedded_json(raw_model_output)
+    except (ValueError, JSONDecodeError):
         # LLMs get confused when handling the list in the json. Sometimes it doesn't attend
         # enough to the previous { token so it just ends the list of quotes and stops there
         # here, we add logic to try to fix this LLM error.
-        try:
-            model_raw_json = json.loads(answer_raw + "}", strict=False)
-            return extract_answer_quotes_json(model_raw_json)
-        except JSONDecodeError:
-            if is_json_prompt:
-                logger.error("Model did not output in json format as expected.")
-                raise
-            return extract_answer_quotes_freeform(answer_raw)
+        answer_json = extract_embedded_json(raw_model_output + "}")
+
+    if "answer" not in answer_json:
+        raise ValueError("Model did not output an answer as expected.")
+
+    return answer_json
+
+
+def separate_answer_quotes(
+    answer_raw: str, is_json_prompt: bool = False
+) -> Tuple[Optional[str], Optional[list[str]]]:
+    """Takes in a raw model output and pulls out the answer and the quotes sections."""
+    if is_json_prompt:
+        model_raw_json = _extract_answer_json(answer_raw)
+        return _extract_answer_quotes_json(model_raw_json)
+
+    return _extract_answer_quotes_freeform(clean_up_code_blocks(answer_raw))
 
 
 def match_quotes_to_docs(
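`extract_embedded_json` comes from `danswer.utils.text_processing` and is not shown in this diff; it takes over the work of the removed `json.loads` calls and of the JSON-trimming block deleted further down. A minimal sketch of such a helper, assuming it simply locates the outermost braces and parses leniently (not the actual implementation):

import json


def extract_embedded_json(raw_text: str) -> dict:
    # Sketch only: find the first "{" and the last "}" so that any prose the
    # model emits around the JSON object is ignored, then parse leniently.
    start = raw_text.find("{")
    end = raw_text.rfind("}")
    if start == -1 or end == -1:
        raise ValueError("No embedded JSON found in model output")
    return json.loads(raw_text[start : end + 1], strict=False)

With a helper along those lines, `_extract_answer_json` retries with a trailing "}" appended, which covers the failure mode described in the comment above: the model closes the list of quotes but forgets to close the outer object.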
@@ -156,9 +162,10 @@ def process_answer(
     chunks: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> tuple[DanswerAnswer, DanswerQuotes]:
-    answer_clean = clean_up_code_blocks(answer_raw)
-
-    answer, quote_strings = separate_answer_quotes(answer_clean, is_json_prompt)
+    """Used (1) in the non-streaming case to process the model output
+    into an Answer and Quotes AND (2) after the complete streaming response
+    has been received to process the model output into an Answer and Quotes."""
+    answer, quote_strings = separate_answer_quotes(answer_raw, is_json_prompt)
     if answer == UNCERTAINTY_PAT or not answer:
         if answer == UNCERTAINTY_PAT:
             logger.debug("Answer matched UNCERTAINTY_PAT")
@@ -177,7 +184,7 @@ def process_answer(
     return DanswerAnswer(answer=answer), quotes
 
 
-def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
+def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
     next_token = next_token.replace('\\"', "")
     # If the previous character is an escape token, don't consider the first character of next_token
     # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling
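Only the opening lines of `_stream_json_answer_end` appear in this hunk; judging from them, the helper decides whether the streamed "answer" value has been closed by an unescaped quote. A hedged reconstruction of the rest, which is an assumption since the closing lines are not part of this diff:

def _stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token
    # This does not work if it's an escaped escape sign before the " but this is rare, not worth handling
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    # Assumed: any remaining unescaped quote means the "answer" string has ended
    return '"' in next_token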
@@ -188,7 +195,7 @@ def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
     return False
 
 
-def extract_quotes_from_completed_token_stream(
+def _extract_quotes_from_completed_token_stream(
     model_output: str, context_chunks: list[InferenceChunk]
 ) -> DanswerQuotes:
     answer, quotes = process_answer(model_output, context_chunks)
@@ -205,7 +212,10 @@ def process_model_tokens(
     context_docs: list[InferenceChunk],
     is_json_prompt: bool = True,
 ) -> Generator[DanswerAnswerPiece | DanswerQuotes, None, None]:
-    """Yields Answer tokens back out in a dict for streaming to frontend
+    """Used in the streaming case to process the model output
+    into an Answer and Quotes
+
+    Yields Answer tokens back out in a dict for streaming to frontend
     When Answer section ends, yields dict with answer_finished key
     Collects all the tokens at the end to form the complete model output"""
     quote_pat = f"\n{QUOTE_PAT}"
@@ -228,14 +238,14 @@ def process_model_tokens(
             found_answer_start = True
 
             # Prevent heavy cases of hallucinations where model is not even providing a json until later
-            if is_json_prompt and len(model_output) > 20:
+            if is_json_prompt and len(model_output) > 40:
                 logger.warning("LLM did not produce json as prompted")
                 found_answer_end = True
 
             continue
 
         if found_answer_start and not found_answer_end:
-            if is_json_prompt and stream_json_answer_end(model_previous, token):
+            if is_json_prompt and _stream_json_answer_end(model_previous, token):
                 found_answer_end = True
                 yield DanswerAnswerPiece(answer_piece=None)
                 continue
@@ -252,20 +262,7 @@ def process_model_tokens(
 
     logger.debug(f"Raw Model QnA Output: {model_output}")
 
-    # for a JSON prompt, make sure that we're only passing through the "JSON part"
-    # since that is what `extract_quotes_from_completed_token_stream` expects
-    if is_json_prompt:
-        try:
-            json_answer_ind = model_output.index('{"answer":')
-            if json_answer_ind != 0:
-                model_output = model_output[json_answer_ind:]
-            end = model_output.rfind("}")
-            if end != -1:
-                model_output = model_output[: end + 1]
-        except ValueError:
-            logger.exception("Did not find answer pattern in response for JSON prompt")
-
-    yield extract_quotes_from_completed_token_stream(model_output, context_docs)
+    yield _extract_quotes_from_completed_token_stream(model_output, context_docs)
 
 
 def simulate_streaming_response(model_out: str) -> Generator[str, None, None]:
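End to end, the streaming path now yields answer pieces followed by a single DanswerQuotes object, with all JSON recovery delegated to `_extract_answer_json` rather than the trimming block removed above. A rough driver for it, reusing `simulate_streaming_response` from this file; the model output (including the "quotes" key name) and the empty chunk list are placeholders, and the arguments are passed positionally because the first parameter name of `process_model_tokens` is not visible in this diff:

model_out = '{"answer": "Paris is the capital of France.", "quotes": ["Paris is the capital of France."]}'
context_chunks: list[InferenceChunk] = []  # normally the retrieved chunks backing the quotes

for packet in process_model_tokens(simulate_streaming_response(model_out), context_chunks):
    if isinstance(packet, DanswerAnswerPiece):
        print(packet.answer_piece)  # streamed answer text; None signals that the answer section ended
    else:
        print(packet)  # the final DanswerQuotes built by _extract_quotes_from_completed_token_stream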