import json
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from functools import wraps
from typing import Any
from typing import cast
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union

import openai
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import json_chat_processor
from danswer.direct_qa.qa_prompts import json_processor
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.utils.logging import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
from danswer.utils.timing import log_function_time
from openai.error import AuthenticationError
from openai.error import Timeout


logger = setup_logger()

def get_openai_api_key() -> str:
    return OPENAI_API_KEY or cast(
        str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
    )


def get_json_line(json_dict: dict) -> str:
    return json.dumps(json_dict) + "\n"
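
# Illustrative example: each streamed packet is serialized as one
# newline-delimited JSON (NDJSON) line.
#   get_json_line({"answer_data": "Paris"}) -> '{"answer_data": "Paris"}\n'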

def extract_answer_quotes_freeform(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    null_answer_check = (
        answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip()
    )

    # If the model just gives back the uncertainty pattern to signify the answer isn't found, or nothing at all
    if null_answer_check == UNCERTAINTY_PAT or not null_answer_check:
        return None, None

    # If there is no answer section, don't care about the quote
    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
        return None, None

    # Sometimes the model regenerates the "Answer:" pattern despite it being provided in the prompt
    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
        answer_raw = answer_raw[len(ANSWER_PAT) :]

    # Accept quote sections starting with the lower case version
    answer_raw = answer_raw.replace(
        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
    )  # Just in case the model is unreliable

    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
    sections_clean = [
        str(section).strip() for section in sections if str(section).strip()
    ]
    if not sections_clean:
        return None, None

    answer = str(sections_clean[0])
    if len(sections) == 1:
        return answer, None
    return answer, sections_clean[1:]
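
# Illustrative example (assumes ANSWER_PAT == "Answer:" and QUOTE_PAT == "Quote:"):
#   raw = 'Answer: Paris is the capital.\nQuote: Paris, the capital of France'
#   extract_answer_quotes_freeform(raw)
#   -> ("Paris is the capital.", ["Paris, the capital of France"])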

def extract_answer_quotes_json(
    answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
    answer = str(answer_dict.get("answer"))
    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
    if isinstance(quotes, str):
        quotes = [quotes]
    return answer, quotes
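
# Illustrative example (keys are lower-cased first, so casing doesn't matter,
# and a single quote string is normalized into a list):
#   extract_answer_quotes_json({"Answer": "Paris", "Quotes": "Paris is the capital"})
#   -> ("Paris", ["Paris is the capital"])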

def separate_answer_quotes(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    try:
        model_raw_json = json.loads(answer_raw)
        return extract_answer_quotes_json(model_raw_json)
    except ValueError:
        return extract_answer_quotes_freeform(answer_raw)
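
# Illustrative example: valid JSON takes the structured path, anything else
# falls back to freeform parsing.
#   separate_answer_quotes('{"answer": "Paris", "quotes": ["Paris is the capital"]}')
#   -> ("Paris", ["Paris is the capital"])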

def match_quotes_to_docs(
    quotes: list[str],
    chunks: list[InferenceChunk],
    max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
    fuzzy_search: bool = False,
    prefix_only_length: int = 100,
) -> Dict[str, Dict[str, Union[str, int, None]]]:
    quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {}
    for quote in quotes:
        max_edits = math.ceil(float(len(quote)) * max_error_percent)

        for chunk in chunks:
            if not chunk.source_links:
                continue

            quote_clean = shared_precompare_cleanup(
                clean_model_quote(quote, trim_length=prefix_only_length)
            )
            chunk_clean = shared_precompare_cleanup(chunk.content)

            # Finding the offset of the quote in the plain text
            if fuzzy_search:
                re_search_str = (
                    r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
                )
                found = regex.search(re_search_str, chunk_clean)
                if not found:
                    continue
                offset = found.span()[0]
            else:
                if quote_clean not in chunk_clean:
                    continue
                offset = chunk_clean.index(quote_clean)

            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in chunk.source_links.items():
                # Should always find one because offset is at least 0 and there must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    break
            quotes_dict[quote] = {
                DOCUMENT_ID: chunk.document_id,
                SOURCE_LINK: curr_link,
                SOURCE_TYPE: chunk.source_type,
                SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
                BLURB: chunk.blurb,
            }
            break
    return quotes_dict
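
# Illustrative example of the fuzzy pattern built above (uses the third-party
# `regex` module's fuzzy-matching syntax; `{e<=1}` permits up to one edit):
#   regex.search(r"(the capitol of france){e<=1}",
#                "paris is the capital of france")
#   matches despite the capitol/capital misspelling.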

def process_answer(
    answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
    answer, quote_strings = separate_answer_quotes(answer_raw)
    if not answer or not quote_strings:
        return None, None
    quotes_dict = match_quotes_to_docs(quote_strings, chunks)
    return answer, quotes_dict

def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False
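
# Illustrative examples: while streaming the '{"answer": "..."}' JSON, an
# unescaped double quote ends the answer value.
#   stream_answer_end('{"answer":"Paris', '",')       -> True
#   stream_answer_end('{"answer":"5 ft \\', '" tall') -> False (escaped quote)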

F = TypeVar("F", bound=Callable)

Answer = str
Quotes = dict[str, dict[str, str | int | None]]
ModelType = Literal["ChatCompletion", "Completion"]
PromptProcessor = Callable[[str, list[str]], str]

def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        **kwargs,
    }
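
# Illustrative example (hypothetical model name): caller-supplied values win
# because **kwargs is unpacked last.
#   _build_openai_settings(model="text-davinci-003", temperature=0.7)
#   -> {"temperature": 0.7, "top_p": 1, "frequency_penalty": 0,
#       "presence_penalty": 0, "model": "text-davinci-003"}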

def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
    @wraps(openai_call)
    def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
        try:
            # if streamed, the call returns a generator
            if kwargs.get("stream"):

                def _generator() -> Generator[Any, None, None]:
                    yield from openai_call(*args, **kwargs)

                return _generator()
            return openai_call(*args, **kwargs)
        except AuthenticationError:
            logger.exception("Failed to authenticate with OpenAI API")
            raise
        except Timeout:
            logger.exception("OpenAI API timed out for query: %s", query)
            raise
        except Exception:
            logger.exception("Unexpected error with OpenAI API for query: %s", query)
            raise

    return cast(F, wrapped_call)
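
# Illustrative usage (hypothetical query value), mirroring the call sites below:
#   safe_completion_call = _handle_openai_exceptions_wrapper(
#       openai_call=openai.Completion.create,
#       query="some user query",
#   )
#   response = safe_completion_call(prompt="...", model="text-davinci-003")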

# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel):
    pass

class OpenAICompletionQA(OpenAIQAModel):
    def __init__(
        self,
        prompt_processor: Callable[[str, list[str]], str] = json_processor,
        model_version: str = OPENAI_MODEL_VERSION,
        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
        api_key: str | None = None,
        timeout: int | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.api_key = api_key or get_openai_api_key()
        self.timeout = timeout

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        filled_prompt = self.prompt_processor(query, top_contents)
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=self.api_key,
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
            ),
        )
        model_output = cast(str, response["choices"][0]["text"]).strip()
        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        filled_prompt = self.prompt_processor(query, top_contents)
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=self.api_key,
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )
        model_output: str = ""
        found_answer_start = False
        found_answer_end = False
        # iterate through the stream of events
        for event in response:
            event_text = cast(str, event["choices"][0]["text"])
            model_previous = model_output
            model_output += event_text

            if not found_answer_start and '{"answer":"' in model_output.replace(
                " ", ""
            ).replace("\n", ""):
                found_answer_start = True
                continue

            if found_answer_start and not found_answer_end:
                if stream_answer_end(model_previous, event_text):
                    found_answer_end = True
                    yield {"answer_finished": True}
                    continue
                yield {"answer_data": event_text}

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        if answer:
            logger.info(answer)
        else:
            logger.warning(
                "Answer extraction from model output failed, most likely no quotes provided"
            )

        if quotes_dict is None:
            yield {}
        else:
            yield quotes_dict
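
# Illustrative note on the streaming protocol used by answer_question_stream
# above (and mirrored by OpenAIChatCompletionQA below): consumers first receive
# {"answer_data": "<token>"} packets while inside the JSON "answer" field, then
# {"answer_finished": True} once the closing quote is seen, and finally either
# the matched-quotes dict or {} if quote matching failed.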

class OpenAIChatCompletionQA(OpenAIQAModel):
    def __init__(
        self,
        prompt_processor: Callable[
            [str, list[str]], list[dict[str, str]]
        ] = json_chat_processor,
        model_version: str = OPENAI_MODEL_VERSION,
        max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
        timeout: int | None = None,
        reflexion_try_count: int = 0,
        api_key: str | None = None,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.reflexion_try_count = reflexion_try_count
        self.api_key = api_key or get_openai_api_key()
        self.timeout = timeout

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        messages = self.prompt_processor(query, top_contents)
        logger.debug(messages)
        model_output = ""
        for _ in range(self.reflexion_try_count + 1):
            openai_call = _handle_openai_exceptions_wrapper(
                openai_call=openai.ChatCompletion.create,
                query=query,
            )
            response = openai_call(
                **_build_openai_settings(
                    api_key=self.api_key,
                    messages=messages,
                    model=self.model_version,
                    max_tokens=self.max_output_tokens,
                    request_timeout=self.timeout,
                ),
            )
            model_output = cast(
                str, response["choices"][0]["message"]["content"]
            ).strip()
            assistant_msg = {"content": model_output, "role": "assistant"}
            messages.extend([assistant_msg, get_chat_reflexion_msg()])
            logger.info(
                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
            )

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict
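
    # Illustrative note: with reflexion_try_count=N, the loop above runs N+1
    # times; after each pass the assistant's output and a reflexion prompt are
    # appended to `messages` so the model can revise its own answer, and only
    # the final output is parsed for the answer and quotes.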

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        top_contents = [ranked_chunk.content for ranked_chunk in context_docs]
        messages = self.prompt_processor(query, top_contents)
        logger.debug(messages)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.ChatCompletion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=self.api_key,
                messages=messages,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )
        model_output: str = ""
        found_answer_start = False
        found_answer_end = False
        for event in response:
            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
            if (
                "content" not in event_dict
            ):  # could be a role message or empty termination
                continue
            event_text = event_dict["content"]
            model_previous = model_output
            model_output += event_text

            if not found_answer_start and '{"answer":"' in model_output.replace(
                " ", ""
            ).replace("\n", ""):
                found_answer_start = True
                continue

            if found_answer_start and not found_answer_end:
                if stream_answer_end(model_previous, event_text):
                    found_answer_end = True
                    yield {"answer_finished": True}
                    continue
                yield {"answer_data": event_text}

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        if answer:
            logger.info(answer)
        else:
            logger.warning(
                "Answer extraction from model output failed, most likely no quotes provided"
            )

        if quotes_dict is None:
            yield {}
        else:
            yield quotes_dict
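
# Illustrative end-to-end usage (hypothetical `retrieved_chunks`; assumes a
# valid OpenAI API key is available via config or the dynamic config store):
#   qa_model = OpenAIChatCompletionQA()
#   answer, quotes = qa_model.answer_question(
#       "What is the capital of France?", retrieved_chunks
#   )
#   for packet in qa_model.answer_question_stream(
#       "What is the capital of France?", retrieved_chunks
#   ):
#       ...  # e.g. forward each packet to the client via get_json_line(packet)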