# danswer/backend/danswer/direct_qa/question_answer.py

import json
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from functools import wraps
from typing import Any
from typing import cast
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import openai
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import json_chat_processor
from danswer.direct_qa.qa_prompts import json_processor
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logging import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
from danswer.utils.timing import log_function_time
from openai.error import AuthenticationError
from openai.error import Timeout
logger = setup_logger()


def get_openai_api_key() -> str:
    return OPENAI_API_KEY or cast(
        str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
    )


def get_json_line(json_dict: dict) -> str:
    return json.dumps(json_dict) + "\n"


def extract_answer_quotes_freeform(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
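    """Split a free-form model response into an answer and a list of quotes.

    The response is expected to contain an ANSWER_PAT section optionally followed by
    one or more QUOTE_PAT sections, each starting on its own line. Returns (None, None)
    when no usable answer section is found; returns (answer, None) when there are no
    quote sections.
    """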
    null_answer_check = (
        answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip()
    )
    # If model just gives back the uncertainty pattern to signify answer isn't found or nothing at all
    if null_answer_check == UNCERTAINTY_PAT or not null_answer_check:
        return None, None

    # If no answer section, don't care about the quote
    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
        return None, None

    # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
        answer_raw = answer_raw[len(ANSWER_PAT) :]

    # Accept quote sections starting with the lower case version
    answer_raw = answer_raw.replace(
        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
    )  # Just in case model unreliable

    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
    sections_clean = [
        str(section).strip() for section in sections if str(section).strip()
    ]
    if not sections_clean:
        return None, None

    answer = str(sections_clean[0])
    if len(sections) == 1:
        return answer, None
    return answer, sections_clean[1:]


def extract_answer_quotes_json(
    answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
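    """Pull the answer and quotes out of an already-parsed JSON model response.

    Keys are matched case-insensitively and a single quote string is normalized to a
    one-element list. For example, {"Answer": "42", "Quotes": "it is 42"} becomes
    ("42", ["it is 42"]).
    """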
    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
    answer = str(answer_dict.get("answer"))
    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
    if isinstance(quotes, str):
        quotes = [quotes]
    return answer, quotes


def separate_answer_quotes(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
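    """Parse a raw model response into (answer, quotes).

    Tries to interpret the response as JSON first and falls back to the free-form
    ANSWER_PAT / QUOTE_PAT format if JSON parsing fails.
    """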
    try:
        model_raw_json = json.loads(answer_raw)
        return extract_answer_quotes_json(model_raw_json)
    except ValueError:
        return extract_answer_quotes_freeform(answer_raw)


def match_quotes_to_docs(
    quotes: list[str],
    chunks: list[InferenceChunk],
    max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
    fuzzy_search: bool = False,
    prefix_only_length: int = 100,
) -> Dict[str, Dict[str, Union[str, int, None]]]:
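    """Map each extracted quote back to the document chunk it came from.

    For every quote, the first chunk whose cleaned content contains the (possibly
    truncated) quote is taken as the source; with fuzzy_search enabled, up to roughly
    max_error_percent of the quote may differ as edits. The source link for the last
    section starting at or before the match offset is attached, along with the
    document id, source type, semantic identifier, and blurb.
    """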
    quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {}
    for quote in quotes:
        max_edits = math.ceil(float(len(quote)) * max_error_percent)
        for chunk in chunks:
            if not chunk.source_links:
                continue

            quote_clean = shared_precompare_cleanup(
                clean_model_quote(quote, trim_length=prefix_only_length)
            )
            chunk_clean = shared_precompare_cleanup(chunk.content)

            # Finding the offset of the quote in the plain text
            if fuzzy_search:
                re_search_str = (
                    r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
                )
                found = regex.search(re_search_str, chunk_clean)
                if not found:
                    continue
                offset = found.span()[0]
            else:
                if quote_clean not in chunk_clean:
                    continue
                offset = chunk_clean.index(quote_clean)

            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in chunk.source_links.items():
                # Should always find one because offset is at least 0 and there must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    quotes_dict[quote] = {
                        DOCUMENT_ID: chunk.document_id,
                        SOURCE_LINK: curr_link,
                        SOURCE_TYPE: chunk.source_type,
                        SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
                        BLURB: chunk.blurb,
                    }
                    break

            quotes_dict[quote] = {
                DOCUMENT_ID: chunk.document_id,
                SOURCE_LINK: curr_link,
                SOURCE_TYPE: chunk.source_type,
                SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
                BLURB: chunk.blurb,
            }
            break
    return quotes_dict


def process_answer(
    answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
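    """Turn a raw model response into a final (answer, quotes_dict) pair.

    Returns (None, None) when either no answer or no quotes could be extracted from
    the model output.
    """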
    answer, quote_strings = separate_answer_quotes(answer_raw)
    if not answer or not quote_strings:
        return None, None
    quotes_dict = match_quotes_to_docs(quote_strings, chunks)
    return answer, quotes_dict


def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
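    """Return True once the streamed JSON "answer" value has been terminated.

    A bare double quote in next_token closes the answer string; quotes escaped with a
    backslash (either within next_token or carried over from the end of answer_so_far)
    are ignored.
    """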
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False


F = TypeVar("F", bound=Callable)

Answer = str
Quotes = dict[str, dict[str, str | int | None]]
ModelType = Literal["ChatCompletion", "Completion"]
PromptProcessor = Callable[[str, list[str]], str]


def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        **kwargs,
    }


def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
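    """Wrap an OpenAI API call so that failures are logged before being re-raised.

    Streaming calls are wrapped in a generator, so the request is only issued once the
    caller starts iterating; errors raised mid-stream surface directly at the caller's
    iteration point rather than being logged here.
    """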
    @wraps(openai_call)
    def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
        try:
            # if streamed, the call returns a generator
            if kwargs.get("stream"):

                def _generator() -> Generator[Any, None, None]:
                    yield from openai_call(*args, **kwargs)

                return _generator()
            return openai_call(*args, **kwargs)
        except AuthenticationError:
            logger.exception("Failed to authenticate with OpenAI API")
            raise
        except Timeout:
            logger.exception("OpenAI API timed out for query: %s", query)
            raise
        except Exception as e:
            logger.exception("Unexpected error with OpenAI API for query: %s", query)
            raise

    return cast(F, wrapped_call)


# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel):
    pass


class OpenAICompletionQA(OpenAIQAModel):
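    """QAModel backed by the OpenAI Completion (plain text prompt) endpoint.

    The prompt_processor renders the query and context chunks into a single prompt
    string; the model response is then parsed back into an answer plus supporting
    quotes.
    """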

    def __init__(
        self,
        prompt_processor: Callable[[str, list[str]], str] = json_processor,
        model_version: str = OPENAI_MODEL_VERSION,
        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
        api_key: str | None = None,
        timeout: int | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.api_key = api_key or get_openai_api_key()
        self.timeout = timeout

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
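        """Answer the query in a single blocking call.

        Builds the prompt from the ranked context chunks, calls the Completion
        endpoint once, and returns the parsed (answer, quotes_dict); both values are
        None when the model output could not be parsed.
        """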
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        filled_prompt = self.prompt_processor(query, top_contents)
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=self.api_key,
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
            ),
        )
        model_output = cast(str, response["choices"][0]["text"]).strip()
        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
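        """Stream the answer token by token.

        Yields {"answer_data": token} for each streamed piece of the answer, a single
        {"answer_finished": True} once the answer's closing quote is seen, and finally
        the quotes dict (or {} if no quotes could be matched).
        """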
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        filled_prompt = self.prompt_processor(query, top_contents)
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=self.api_key,
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )
        model_output: str = ""
        found_answer_start = False
        found_answer_end = False
        # iterate through the stream of events
        for event in response:
            event_text = cast(str, event["choices"][0]["text"])
            model_previous = model_output
            model_output += event_text

            if not found_answer_start and '{"answer":"' in model_output.replace(
                " ", ""
            ).replace("\n", ""):
                found_answer_start = True
                continue

            if found_answer_start and not found_answer_end:
                if stream_answer_end(model_previous, event_text):
                    found_answer_end = True
                    yield {"answer_finished": True}
                    continue
                yield {"answer_data": event_text}

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        if answer:
            logger.info(answer)
        else:
            logger.warning(
                "Answer extraction from model output failed, most likely no quotes provided"
            )

        if quotes_dict is None:
            yield {}
        else:
            yield quotes_dict


class OpenAIChatCompletionQA(OpenAIQAModel):
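    """QAModel backed by the OpenAI ChatCompletion (messages) endpoint.

    The prompt_processor renders the query and context chunks into a list of chat
    messages. If reflexion_try_count > 0, each extra round feeds the model its own
    answer together with a reflexion prompt and asks again; the last output is the one
    that gets parsed.
    """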

    def __init__(
        self,
        prompt_processor: Callable[
            [str, list[str]], list[dict[str, str]]
        ] = json_chat_processor,
        model_version: str = OPENAI_MODEL_VERSION,
        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
        timeout: int | None = None,
        reflexion_try_count: int = 0,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.reflexion_try_count = reflexion_try_count
        self.api_key = api_key or get_openai_api_key()
        self.timeout = timeout

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
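        """Answer the query in a single blocking call, with optional reflexion rounds.

        Only the final model output (after reflexion_try_count + 1 rounds) is parsed
        into (answer, quotes_dict); both values are None when parsing fails.
        """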
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        messages = self.prompt_processor(query, top_contents)
        logger.debug(messages)

        model_output = ""
        for _ in range(self.reflexion_try_count + 1):
            openai_call = _handle_openai_exceptions_wrapper(
                openai_call=openai.ChatCompletion.create,
                query=query,
            )
            response = openai_call(
                **_build_openai_settings(
                    api_key=self.api_key,
                    messages=messages,
                    model=self.model_version,
                    max_tokens=self.max_output_tokens,
                    request_timeout=self.timeout,
                ),
            )
            model_output = cast(
                str, response["choices"][0]["message"]["content"]
            ).strip()
            assistant_msg = {"content": model_output, "role": "assistant"}
            messages.extend([assistant_msg, get_chat_reflexion_msg()])
            logger.info(
                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
            )

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
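        """Stream the answer token by token (reflexion is not applied when streaming).

        Yields {"answer_data": token} for each streamed piece of the answer, then
        {"answer_finished": True}, and finally the quotes dict (or {} if no quotes
        could be matched).
        """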
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        messages = self.prompt_processor(query, top_contents)
        logger.debug(messages)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.ChatCompletion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=self.api_key,
                messages=messages,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )
        model_output: str = ""
        found_answer_start = False
        found_answer_end = False
        for event in response:
            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
            if (
                "content" not in event_dict
            ):  # could be a role message or empty termination
                continue
            event_text = event_dict["content"]
            model_previous = model_output
            model_output += event_text

            if not found_answer_start and '{"answer":"' in model_output.replace(
                " ", ""
            ).replace("\n", ""):
                found_answer_start = True
                continue

            if found_answer_start and not found_answer_end:
                if stream_answer_end(model_previous, event_text):
                    found_answer_end = True
                    yield {"answer_finished": True}
                    continue
                yield {"answer_data": event_text}

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        if answer:
            logger.info(answer)
        else:
            logger.warning(
                "Answer extraction from model output failed, most likely no quotes provided"
            )

        if quotes_dict is None:
            yield {}
        else:
            yield quotes_dict