Support GPT4All in memory (#230)

This commit is contained in:
Yuhong Sun
2023-07-23 12:26:14 -07:00
committed by GitHub
parent 6684f1e5d5
commit d6ca865034
23 changed files with 1146 additions and 743 deletions

View File

@@ -4,6 +4,7 @@ from danswer.connectors.factory import instantiate_connector
from danswer.connectors.interfaces import LoadConnector
from danswer.connectors.interfaces import PollConnector
from danswer.connectors.models import InputType
from danswer.datastores.indexing_pipeline import build_indexing_pipeline
from danswer.db.connector import disable_connector
from danswer.db.connector import fetch_connectors
from danswer.db.connector_credential_pair import update_connector_credential_pair
@@ -21,7 +22,6 @@ from danswer.db.index_attempt import mark_attempt_succeeded
from danswer.db.models import Connector
from danswer.db.models import IndexAttempt
from danswer.db.models import IndexingStatus
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from sqlalchemy.orm import Session

View File

@@ -31,10 +31,16 @@ CROSS_EMBED_CONTEXT_SIZE = 512
BATCH_SIZE_ENCODE_CHUNKS = 8
# QA Model API Configs
# https://platform.openai.com/docs/models/model-endpoint-compatibility
# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
# Valid list:
# - openai-completion
# - openai-chat-completion
# - gpt4all-completion
# - gpt4all-chat-completion
INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
OPENAI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-3.5-turbo")
OPENAI_MAX_OUTPUT_TOKENS = 512
# For GPT4All, "ggml-model-gpt4all-falcon-q4_0.bin" is a tested model for the setting below
GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
GEN_AI_MAX_OUTPUT_TOKENS = 512
# Danswer custom Deep Learning Models
INTENT_MODEL_VERSION = "danswer/intent-model"
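
For illustration, a minimal sketch of selecting the in-memory GPT4All chat backend through the environment variables read above (variable names come from this file; the model file name is the tested one mentioned in the comment, and these must be set before the server process imports its config):

    import os

    # Hypothetical pre-launch configuration; export these before the Danswer
    # server process starts so the config module picks them up.
    os.environ["INTERNAL_MODEL"] = "gpt4all-chat-completion"
    os.environ["GEN_AI_MODEL_VERSION"] = "ggml-model-gpt4all-falcon-q4_0.bin"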

View File

@@ -1,18 +1,49 @@
from typing import Any
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.llm import OpenAIChatCompletionQA
from danswer.direct_qa.llm import OpenAICompletionQA
from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
from danswer.direct_qa.open_ai import OpenAICompletionQA
from openai.error import AuthenticationError
from openai.error import Timeout
def check_model_api_key_is_valid(model_api_key: str) -> bool:
if not model_api_key:
return False
qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)
# try for up to 2 timeouts (e.g. 10 seconds in total)
for _ in range(2):
try:
qa_model.answer_question("Do not respond", [])
return True
except AuthenticationError:
return False
except Timeout:
pass
return False
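
As a quick sanity sketch of the new model-agnostic checker: an empty key short-circuits to False without calling the model, while a bogus key (with an OpenAI-backed model configured) should hit AuthenticationError internally and also return False.

    # Hypothetical placeholder key; not a real credential.
    assert check_model_api_key_is_valid("") is False
    print(check_model_api_key_is_valid("sk-placeholder-invalid"))  # expected: False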
def get_default_backend_qa_model(
internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: Any
internal_model: str = INTERNAL_MODEL_VERSION,
api_key: str | None = None,
timeout: int = QA_TIMEOUT,
**kwargs: Any
) -> QAModel:
if internal_model == "openai-completion":
return OpenAICompletionQA(**kwargs)
return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == "openai-chat-completion":
return OpenAIChatCompletionQA(**kwargs)
return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
elif internal_model == "gpt4all-completion":
return GPT4AllCompletionQA(**kwargs)
elif internal_model == "gpt4all-chat-completion":
return GPT4AllChatCompletionQA(**kwargs)
else:
raise UnknownModelError(internal_model)
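
For context, a hedged usage sketch of the updated factory: the internal_model string selects the backend, and warm_up_model loads the GPT4All weights into memory (it is a no-op for API-backed models). This assumes the gpt4all package and model file are available locally; the query and empty context list are placeholders.

    from danswer.direct_qa import get_default_backend_qa_model

    qa_model = get_default_backend_qa_model(internal_model="gpt4all-chat-completion")
    qa_model.warm_up_model()  # loads the GPT4All weights into memory
    answer, quotes = qa_model.answer_question("Where is the Eiffel Tower?", [])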

View File

@@ -1,13 +1,13 @@
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.qa_utils import structure_quotes_for_response
from danswer.search.danswer_helper import query_intent
from danswer.search.keyword_search import retrieve_keyword_documents
from danswer.search.models import QueryFlow
@@ -26,7 +26,6 @@ logger = setup_logger()
def answer_question(
question: QuestionRequest,
user: User | None,
qa_model_timeout: int = QA_TIMEOUT,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> QAResponse:
query = question.query
@@ -74,7 +73,7 @@ def answer_question(
)
try:
qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
qa_model = get_default_backend_qa_model()
except (UnknownModelError, OpenAIKeyMissing) as e:
return QAResponse(
answer=None,
@@ -102,8 +101,8 @@ def answer_question(
error_msg = f"Error occurred in call to LLM - {e}"
return QAResponse(
answer=answer,
quotes=quotes,
answer=answer.answer if answer else None,
quotes=structure_quotes_for_response(quotes),
top_ranked_docs=chunks_to_search_docs(ranked_chunks),
lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
predicted_flow=predicted_flow,

View File

@@ -1,5 +1,10 @@
class OpenAIKeyMissing(Exception):
def __init__(self, msg: str = "Unable to find an OpenAI Key") -> None:
default_msg = (
"Unable to find existing OpenAI Key. "
'A new key can be added from "Keys" section of the Admin Panel'
)
def __init__(self, msg: str = default_msg) -> None:
super().__init__(msg)

View File

@@ -0,0 +1,173 @@
from collections.abc import Generator
from typing import Any
from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from gpt4all import GPT4All # type:ignore
logger = setup_logger()
GPT4ALL_MODEL: GPT4All | None = None
def get_gpt_4_all_model(
model_version: str = GEN_AI_MODEL_VERSION,
) -> GPT4All:
global GPT4ALL_MODEL
if GPT4ALL_MODEL is None:
GPT4ALL_MODEL = GPT4All(model_version)
return GPT4ALL_MODEL
def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"temp": 0,
**kwargs,
}
class GPT4AllCompletionQA(QAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False, # gpt4all models can't handle this atm
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
def warm_up_model(self) -> None:
get_gpt_4_all_model(self.model_version)
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
model_output = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=filled_prompt, max_tokens=self.max_output_tokens
),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
model_stream = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class GPT4AllChatCompletionQA(QAModel):
def __init__(
self,
prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
include_metadata: bool = False, # gpt4all models can't handle this atm
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.include_metadata = include_metadata
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
with gen_ai_model.chat_session():
context_msgs = filled_prompt[:-1]
user_query = filled_prompt[-1].get("content")
for message in context_msgs:
gen_ai_model.current_chat_session.append(message)
model_output = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=user_query, max_tokens=self.max_output_tokens
),
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
gen_ai_model = get_gpt_4_all_model(self.model_version)
with gen_ai_model.chat_session():
context_msgs = filled_prompt[:-1]
user_query = filled_prompt[-1].get("content")
for message in context_msgs:
gen_ai_model.current_chat_session.append(message)
model_stream = gen_ai_model.generate(
**_build_gpt4all_settings(
prompt=user_query, max_tokens=self.max_output_tokens
),
)
yield from process_model_tokens(
tokens=model_stream,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
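
A small usage sketch for the streaming path of the non-chat GPT4All class (the model file name is the tested one from the config comment; the empty context_docs list is only for brevity):

    qa = GPT4AllCompletionQA(model_version="ggml-model-gpt4all-falcon-q4_0.bin")
    qa.warm_up_model()  # loads the weights (downloading if needed) and caches them in the module-level global
    for packet in qa.answer_question_stream("What does Danswer do?", context_docs=[]):
        if packet and "answer_data" in packet:
            print(packet["answer_data"], end="")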

View File

@@ -1,17 +1,39 @@
import abc
from collections.abc import Generator
from dataclasses import dataclass
from typing import Any
from danswer.chunking.models import InferenceChunk
@dataclass
class DanswerAnswer:
answer: str | None
@dataclass
class DanswerQuote:
# This is during inference so everything is a string by this point
quote: str
document_id: str
link: str | None
source_type: str
semantic_identifier: str
blurb: str
class QAModel:
def warm_up_model(self) -> None:
"""This is called during server start up to load the models into memory
pass if model is accessed via API"""
pass
@abc.abstractmethod
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
raise NotImplementedError
@abc.abstractmethod

View File

@@ -1,25 +0,0 @@
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.llm import OpenAIQAModel
from openai.error import AuthenticationError
from openai.error import Timeout
def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
if not openai_api_key:
return False
qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
if not isinstance(qa_model, OpenAIQAModel):
raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")
# try for up to 2 timeouts (e.g. 10 seconds in total)
for _ in range(2):
try:
qa_model.answer_question("Do not respond", [])
return True
except AuthenticationError:
return False
except Timeout:
pass
return False

View File

@@ -1,488 +0,0 @@
import json
import math
import re
from collections.abc import Callable
from collections.abc import Generator
from functools import wraps
from typing import Any
from typing import cast
from typing import Dict
from typing import Literal
from typing import Optional
from typing import Tuple
from typing import TypeVar
from typing import Union
import openai
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import json_chat_processor
from danswer.direct_qa.qa_prompts import json_processor
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
from danswer.utils.timing import log_function_time
from openai.error import AuthenticationError
from openai.error import Timeout
logger = setup_logger()
def get_openai_api_key() -> str:
return OPENAI_API_KEY or cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
)
def get_json_line(json_dict: dict) -> str:
return json.dumps(json_dict) + "\n"
def extract_answer_quotes_freeform(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
null_answer_check = (
answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip()
)
# If model just gives back the uncertainty pattern to signify answer isn't found or nothing at all
# if null_answer_check == UNCERTAINTY_PAT or not null_answer_check:
# return None, None
# If no answer section, don't care about the quote
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
return None, None
# Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
answer_raw = answer_raw[len(ANSWER_PAT) :]
# Accept quote sections starting with the lower case version
answer_raw = answer_raw.replace(
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
) # Just in case model unreliable
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
sections_clean = [
str(section).strip() for section in sections if str(section).strip()
]
if not sections_clean:
return None, None
answer = str(sections_clean[0])
if len(sections) == 1:
return answer, None
return answer, sections_clean[1:]
def extract_answer_quotes_json(
answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
answer = str(answer_dict.get("answer"))
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
if isinstance(quotes, str):
quotes = [quotes]
return answer, quotes
def separate_answer_quotes(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
try:
model_raw_json = json.loads(answer_raw)
return extract_answer_quotes_json(model_raw_json)
except ValueError:
return extract_answer_quotes_freeform(answer_raw)
def match_quotes_to_docs(
quotes: list[str],
chunks: list[InferenceChunk],
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
fuzzy_search: bool = False,
prefix_only_length: int = 100,
) -> Dict[str, Dict[str, Union[str, int, None]]]:
quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {}
for quote in quotes:
max_edits = math.ceil(float(len(quote)) * max_error_percent)
for chunk in chunks:
if not chunk.source_links:
continue
quote_clean = shared_precompare_cleanup(
clean_model_quote(quote, trim_length=prefix_only_length)
)
chunk_clean = shared_precompare_cleanup(chunk.content)
# Finding the offset of the quote in the plain text
if fuzzy_search:
re_search_str = (
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
)
found = regex.search(re_search_str, chunk_clean)
if not found:
continue
offset = found.span()[0]
else:
if quote_clean not in chunk_clean:
continue
offset = chunk_clean.index(quote_clean)
# Extracting the link from the offset
curr_link = None
for link_offset, link in chunk.source_links.items():
# Should always find one because offset is at least 0 and there must be a 0 link_offset
if int(link_offset) <= offset:
curr_link = link
else:
quotes_dict[quote] = {
DOCUMENT_ID: chunk.document_id,
SOURCE_LINK: curr_link,
SOURCE_TYPE: chunk.source_type,
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
BLURB: chunk.blurb,
}
break
quotes_dict[quote] = {
DOCUMENT_ID: chunk.document_id,
SOURCE_LINK: curr_link,
SOURCE_TYPE: chunk.source_type,
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
BLURB: chunk.blurb,
}
break
return quotes_dict
def process_answer(
answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
answer, quote_strings = separate_answer_quotes(answer_raw)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")
else:
logger.debug("No answer extracted from raw output")
return None, None
logger.info(f"Answer: {answer}")
if not quote_strings:
logger.debug("No quotes extracted from raw output")
return answer, None
logger.info(f"All quotes (including unmatched): {quote_strings}")
quotes_dict = match_quotes_to_docs(quote_strings, chunks)
logger.info(f"Final quotes dict: {quotes_dict}")
return answer, quotes_dict
def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
next_token = next_token.replace('\\"', "")
# If the previous character is an escape token, don't consider the first character of next_token
if answer_so_far and answer_so_far[-1] == "\\":
next_token = next_token[1:]
if '"' in next_token:
return True
return False
F = TypeVar("F", bound=Callable)
Answer = str
Quotes = dict[str, dict[str, str | int | None]]
ModelType = Literal["ChatCompletion", "Completion"]
PromptProcessor = Callable[[str, list[str]], str]
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"temperature": 0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
**kwargs,
}
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
@wraps(openai_call)
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
try:
# if streamed, the call returns a generator
if kwargs.get("stream"):
def _generator() -> Generator[Any, None, None]:
yield from openai_call(*args, **kwargs)
return _generator()
return openai_call(*args, **kwargs)
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Timeout:
logger.exception("OpenAI API timed out for query: %s", query)
raise
except Exception as e:
logger.exception("Unexpected error with OpenAI API for query: %s", query)
raise
return cast(F, wrapped_call)
# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel):
pass
class OpenAICompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: Callable[
[str, list[InferenceChunk], bool], str
] = json_processor,
model_version: str = OPENAI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
api_key: str | None = None,
timeout: int | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.timeout = timeout
self.include_metadata = include_metadata
try:
self.api_key = api_key or get_openai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
filled_prompt = self.prompt_processor(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.Completion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=self.api_key,
prompt=filled_prompt,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
),
)
model_output = cast(str, response["choices"][0]["text"]).strip()
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.Completion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=self.api_key,
prompt=filled_prompt,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
stream=True,
),
)
model_output: str = ""
found_answer_start = False
found_answer_end = False
# iterate through the stream of events
for event in response:
event_text = cast(str, event["choices"][0]["text"])
model_previous = model_output
model_output += event_text
if not found_answer_start and '{"answer":"' in model_output.replace(
" ", ""
).replace("\n", ""):
found_answer_start = True
continue
if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text):
found_answer_end = True
yield {"answer_finished": True}
continue
yield {"answer_data": event_text}
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
if answer:
logger.info(answer)
else:
logger.warning(
"Answer extraction from model output failed, most likely no quotes provided"
)
if quotes_dict is None:
yield {}
else:
yield quotes_dict
class OpenAIChatCompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: Callable[
[str, list[InferenceChunk], bool], list[dict[str, str]]
] = json_chat_processor,
model_version: str = OPENAI_MODEL_VERSION,
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
timeout: int | None = None,
reflexion_try_count: int = 0,
api_key: str | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.reflexion_try_count = reflexion_try_count
self.timeout = timeout
self.include_metadata = include_metadata
try:
self.api_key = api_key or get_openai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
messages = self.prompt_processor(query, context_docs, self.include_metadata)
logger.debug(json.dumps(messages, indent=4))
model_output = ""
for _ in range(self.reflexion_try_count + 1):
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.ChatCompletion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=self.api_key,
messages=messages,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
),
)
model_output = cast(
str, response["choices"][0]["message"]["content"]
).strip()
assistant_msg = {"content": model_output, "role": "assistant"}
messages.extend([assistant_msg, get_chat_reflexion_msg()])
logger.info(
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
messages = self.prompt_processor(query, context_docs, self.include_metadata)
logger.debug(json.dumps(messages, indent=4))
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.ChatCompletion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=self.api_key,
messages=messages,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
stream=True,
),
)
model_output: str = ""
found_answer_start = False
found_answer_end = False
for event in response:
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
if (
"content" not in event_dict
): # could be a role message or empty termination
continue
event_text = event_dict["content"]
model_previous = model_output
model_output += event_text
logger.debug(f"GPT returned token: {event_text}")
if not found_answer_start and '{"answer":"' in model_output.replace(
" ", ""
).replace("\n", ""):
# Note, if the token that completes the pattern has additional text, for example if the token is "?
# Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
# event that the model outputs the UNCERTAINTY_PAT
found_answer_start = True
continue
if found_answer_start and not found_answer_end:
if stream_answer_end(model_previous, event_text):
found_answer_end = True
yield {"answer_finished": True}
continue
yield {"answer_data": event_text}
logger.debug(model_output)
_, quotes_dict = process_answer(model_output, context_docs)
yield {} if quotes_dict is None else quotes_dict

View File

@@ -0,0 +1,280 @@
import json
from abc import ABC
from collections.abc import Callable
from collections.abc import Generator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar
import openai
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from openai.error import AuthenticationError
from openai.error import Timeout
logger = setup_logger()
F = TypeVar("F", bound=Callable)
def get_openai_api_key() -> str:
return OPENAI_API_KEY or cast(
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
)
def _ensure_openai_api_key(api_key: str | None) -> str:
try:
return api_key or get_openai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
"""
Utility to add in some common default values so they don't have to be set every time.
"""
return {
"temperature": 0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
**kwargs,
}
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
@wraps(openai_call)
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
try:
# if streamed, the call returns a generator
if kwargs.get("stream"):
def _generator() -> Generator[Any, None, None]:
yield from openai_call(*args, **kwargs)
return _generator()
return openai_call(*args, **kwargs)
except AuthenticationError:
logger.exception("Failed to authenticate with OpenAI API")
raise
except Timeout:
logger.exception("OpenAI API timed out for query: %s", query)
raise
except Exception:
logger.exception("Unexpected error with OpenAI API for query: %s", query)
raise
return cast(F, wrapped_call)
# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel, ABC):
pass
class OpenAICompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
api_key: str | None = None,
timeout: int | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.timeout = timeout
self.include_metadata = include_metadata
try:
self.api_key = api_key or get_openai_api_key()
except ConfigNotFoundError:
raise OpenAIKeyMissing()
@staticmethod
def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
for event in response:
yield event["choices"][0]["text"]
@log_function_time()
def answer_question(
self, query: str, context_docs: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.Completion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=_ensure_openai_api_key(self.api_key),
prompt=filled_prompt,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
),
)
model_output = cast(str, response["choices"][0]["text"]).strip()
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
filled_prompt = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(filled_prompt)
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.Completion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=_ensure_openai_api_key(self.api_key),
prompt=filled_prompt,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
stream=True,
),
)
tokens = self._generate_tokens_from_response(response)
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
class OpenAIChatCompletionQA(OpenAIQAModel):
def __init__(
self,
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
model_version: str = GEN_AI_MODEL_VERSION,
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
timeout: int | None = None,
reflexion_try_count: int = 0,
api_key: str | None = None,
include_metadata: bool = INCLUDE_METADATA,
) -> None:
self.prompt_processor = prompt_processor
self.model_version = model_version
self.max_output_tokens = max_output_tokens
self.reflexion_try_count = reflexion_try_count
self.timeout = timeout
self.include_metadata = include_metadata
self.api_key = api_key
@staticmethod
def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
for event in response:
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
if (
"content" not in event_dict
): # could be a role message or empty termination
continue
yield event_dict["content"]
@log_function_time()
def answer_question(
self,
query: str,
context_docs: list[InferenceChunk],
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
messages = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(json.dumps(messages, indent=4))
model_output = ""
for _ in range(self.reflexion_try_count + 1):
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.ChatCompletion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=_ensure_openai_api_key(self.api_key),
messages=messages,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
),
)
model_output = cast(
str, response["choices"][0]["message"]["content"]
).strip()
assistant_msg = {"content": model_output, "role": "assistant"}
messages.extend([assistant_msg, get_json_chat_reflexion_msg()])
logger.info(
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
)
logger.debug(model_output)
answer, quotes_dict = process_answer(model_output, context_docs)
return answer, quotes_dict
def answer_question_stream(
self, query: str, context_docs: list[InferenceChunk]
) -> Generator[dict[str, Any] | None, None, None]:
messages = self.prompt_processor.fill_prompt(
query, context_docs, self.include_metadata
)
logger.debug(json.dumps(messages, indent=4))
openai_call = _handle_openai_exceptions_wrapper(
openai_call=openai.ChatCompletion.create,
query=query,
)
response = openai_call(
**_build_openai_settings(
api_key=_ensure_openai_api_key(self.api_key),
messages=messages,
model=self.model_version,
max_tokens=self.max_output_tokens,
request_timeout=self.timeout,
stream=True,
),
)
tokens = self._generate_tokens_from_response(response)
yield from process_model_tokens(
tokens=tokens,
context_docs=context_docs,
is_json_prompt=self.prompt_processor.specifies_json_output,
)
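
For reference, a hedged usage sketch of the refactored OpenAI chat backend (the key string is a placeholder; when no key is passed, the class now defers the lookup to _ensure_openai_api_key at request time):

    # Hypothetical usage; "sk-..." is a placeholder, not a real key.
    qa_model = OpenAIChatCompletionQA(api_key="sk-...", timeout=10)
    answer, quotes = qa_model.answer_question("Where is the Eiffel Tower?", context_docs=[])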

View File

@@ -1,3 +1,4 @@
import abc
import json
from danswer.chunking.models import InferenceChunk
@@ -16,10 +17,10 @@ QUOTE_PAT = "Quote:"
BASE_PROMPT = (
f"Answer the query based on provided documents and quote relevant sections. "
f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
f'Respond with "?" for the answer if the query cannot be answered based on the documents. '
f"The quotes must be EXACT substrings from the documents."
)
SAMPLE_QUESTION = "Where is the Eiffel Tower?"
SAMPLE_JSON_RESPONSE = {
@@ -31,7 +32,23 @@ SAMPLE_JSON_RESPONSE = {
}
def add_metadata_section(
def _append_acknowledge_doc_messages(
current_messages: list[dict[str, str]], new_chunk_content: str
) -> list[dict[str, str]]:
updated_messages = current_messages.copy()
updated_messages.extend(
[
{
"role": "user",
"content": new_chunk_content,
},
{"role": "assistant", "content": "Acknowledged"},
]
)
return updated_messages
def _add_metadata_section(
prompt_current: str,
chunk: InferenceChunk,
prepend_tab: bool = False,
@@ -67,192 +84,313 @@ def add_metadata_section(
return prompt_current
def json_processor(
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = False,
include_sep: bool = True,
) -> str:
prompt = (
BASE_PROMPT + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
class PromptProcessor(abc.ABC):
"""Take the most relevant chunks and fills out a LLM prompt using the chunk contents
and optionally metadata about the chunk"""
for chunk in chunks:
prompt += f"\n\n{DOC_SEP_PAT}\n"
if include_metadata:
prompt = add_metadata_section(
prompt, chunk, prepend_tab=False, include_sep=include_sep
)
@property
@abc.abstractmethod
def specifies_json_output(self) -> bool:
raise NotImplementedError
prompt += chunk.content
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
return prompt
@staticmethod
@abc.abstractmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str | list[dict[str, str]]:
raise NotImplementedError
def json_chat_processor(
question: str,
chunks: list[InferenceChunk],
include_metadata: bool = False,
include_sep: bool = False,
) -> list[dict[str, str]]:
metadata_prompt_section = "with metadata and contents " if include_metadata else ""
intro_msg = (
f"You are a Question Answering assistant that answers queries based on the provided most relevant documents.\n"
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
)
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
task_msg = (
"Now answer the next user query based on documents above and quote relevant sections.\n"
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative and concise.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
"If the query cannot be answered based on the documents, respond with "
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating the number of documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
full_context = ""
if include_metadata:
full_context = add_metadata_section(
full_context, chunk, prepend_tab=False, include_sep=include_sep
)
full_context += chunk.content
messages.extend(
[
{
"role": "user",
"content": full_context,
},
{"role": "assistant", "content": "Acknowledged"},
]
)
messages.append({"role": "system", "content": task_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
return messages
class NonChatPromptProcessor(PromptProcessor):
@staticmethod
@abc.abstractmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
raise NotImplementedError
# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may use again in future
class ChatPromptProcessor(PromptProcessor):
@staticmethod
@abc.abstractmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> list[dict[str, str]]:
raise NotImplementedError
# Chain of Thought approach works however has higher token cost (more expensive) and is slower.
# Should use this one if users ask questions that require logical reasoning.
def json_cot_variant_processor(question: str, documents: list[str]) -> str:
prompt = (
f"Answer the query based on provided documents and quote relevant sections. "
f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
class JsonProcessor(NonChatPromptProcessor):
@property
def specifies_json_output(self) -> bool:
return True
for document in documents:
prompt += f"\n{DOC_SEP_PAT}\n{document}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += "Reasoning:\n"
return prompt
# This one seems largely useless with a single example
# Model seems to take the one example of answering Yes and just does that too.
def json_reflexion_processor(question: str, documents: list[str]) -> str:
reflexion_str = "Does this fully answer the user query?"
prompt = (
BASE_PROMPT
+ f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
f"If No, generate a better json response to the query.\n"
f"Sample question and response:\n"
f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
f"{reflexion_str} Yes\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for document in documents:
prompt += f"\n---NEW CONTEXT DOCUMENT---\n{document}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
return prompt
# Initial design, works pretty well but not optimal
def freeform_processor(question: str, documents: list[str]) -> str:
prompt = (
f"Answer the query based on the documents below and quote the documents segments containing the answer. "
f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. '
f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. '
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for document in documents:
prompt += f"\n{DOC_SEP_PAT}\n{document}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += f"{ANSWER_PAT}\n"
return prompt
def freeform_chat_processor(
question: str, documents: list[str]
) -> list[dict[str, str]]:
sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
role_msg = (
f"You are a Question Answering assistant that answers queries based on provided documents. "
f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
f"Answer the question directly and concisely. "
f"Each quote should be a single continuous segment from a document. "
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
f"An example of quote sections may look like:\n{sample_quote}"
)
messages = [
{"role": "system", "content": role_msg},
]
for document in documents:
messages.extend(
[
{
"role": "user",
"content": f"Acknowledge the following document:\n{document}",
},
{"role": "assistant", "content": "Acknowledged"},
]
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
prompt = (
BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
messages.append(
{
"role": "user",
"content": f"Please now answer the following query based on the previously provided "
f"documents and quote the relevant sections of the documents\n{question}",
},
)
for chunk in chunks:
prompt += f"\n\n{DOC_SEP_PAT}\n"
if include_metadata:
prompt = _add_metadata_section(
prompt, chunk, prepend_tab=False, include_sep=True
)
return messages
prompt += chunk.content
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
return prompt
# Not very useful, have not seen it improve an answer based on this
# Sometimes gpt-3.5-turbo will just answer something worse like:
# 'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
# document. No revision is needed.'
def get_chat_reflexion_msg() -> dict[str, str]:
class JsonChatProcessor(ChatPromptProcessor):
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> list[dict[str, str]]:
metadata_prompt_section = (
"with metadata and contents " if include_metadata else ""
)
intro_msg = (
f"You are a Question Answering assistant that answers queries "
f"based on the provided most relevant documents.\n"
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
)
complete_answer_not_found_response = (
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
)
task_msg = (
"Now answer the next user query based on documents above and quote relevant sections.\n"
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
"All quotes MUST be EXACT substrings from provided documents.\n"
"Your responses should be informative and concise.\n"
"You MUST prioritize information from provided documents over internal knowledge.\n"
"If the query cannot be answered based on the documents, respond with "
f"{complete_answer_not_found_response}\n"
"If the query requires aggregating the number of documents, respond with "
'{"answer": "Aggregations not supported", "quotes": []}\n'
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
)
messages = [{"role": "system", "content": intro_msg}]
for chunk in chunks:
full_context = ""
if include_metadata:
full_context = _add_metadata_section(
full_context, chunk, prepend_tab=False, include_sep=False
)
full_context += chunk.content
messages = _append_acknowledge_doc_messages(messages, full_context)
messages.append({"role": "system", "content": task_msg})
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
return messages
class WeakModelFreeformProcessor(NonChatPromptProcessor):
"""Avoid using this one if the model is capable of using another prompt
Intended for models that can't follow complex instructions or have short context windows
This prompt only uses 1 reference document chunk
"""
@property
def specifies_json_output(self) -> bool:
return False
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
first_chunk_content = chunks[0].content if chunks else "No Document Provided"
prompt = (
f"Reference Document:\n{first_chunk_content}\n{GENERAL_SEP_PAT}"
f"Answer the user query below based on the reference document above. "
f'Respond with an "{ANSWER_PAT}" section and '
f'as many "{QUOTE_PAT}" sections as needed to support the answer.'
f"\n{GENERAL_SEP_PAT}"
f"{QUESTION_PAT} {question}\n"
f"{ANSWER_PAT}"
)
return prompt
class WeakChatModelFreeformProcessor(ChatPromptProcessor):
"""Avoid using this one if the model is capable of using another prompt
Intended for models that can't follow complex instructions or have short context windows
This prompt only uses 1 reference document chunk
"""
@property
def specifies_json_output(self) -> bool:
return False
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> list[dict[str, str]]:
first_chunk_content = chunks[0].content if chunks else "No Document Provided"
intro_msg = (
f"You are a question answering assistant. "
f'Respond to the query with an "{ANSWER_PAT}" section and '
f'as many "{QUOTE_PAT}" sections as needed to support the answer. '
f"Answer the user query based on the following document:\n\n{first_chunk_content}"
)
messages = [{"role": "system", "content": intro_msg}]
user_query = f"{QUESTION_PAT} {question}"
messages.append({"role": "user", "content": user_query})
return messages
# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may revisit in future
class FreeformProcessor(NonChatPromptProcessor):
@property
def specifies_json_output(self) -> bool:
return False
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
prompt = (
f"Answer the query based on the documents below and quote the documents segments containing the answer. "
f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. '
f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. '
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for chunk in chunks:
prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += f"{ANSWER_PAT}\n"
return prompt
class FreeformChatProcessor(ChatPromptProcessor):
@property
def specifies_json_output(self) -> bool:
return False
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> list[dict[str, str]]:
sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
role_msg = (
f"You are a Question Answering assistant that answers queries based on provided documents. "
f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
f"Answer the question directly and concisely. "
f"Each quote should be a single continuous segment from a document. "
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
f"An example of quote sections may look like:\n{sample_quote}"
)
messages = [
{"role": "system", "content": role_msg},
]
for chunk in chunks:
messages = _append_acknowledge_doc_messages(messages, chunk.content)
messages.append(
{
"role": "user",
"content": f"Please now answer the following query based on the previously provided "
f"documents and quote the relevant sections of the documents\n{question}",
},
)
return messages
class JsonCOTProcessor(NonChatPromptProcessor):
"""Chain of Thought allows model to explain out its reasoning to handle harder tests.
This prompt type works however has higher token cost (more expensive) and is slower.
Consider this one if users ask questions that require logical reasoning."""
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
prompt = (
f"Answer the query based on provided documents and quote relevant sections. "
f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for chunk in chunks:
prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
prompt += "Reasoning:\n"
return prompt
class JsonReflexionProcessor(NonChatPromptProcessor):
"""Reflexion prompting to attempt to have model evaluate its own answer.
This one seems largely useless when only given a single example
Model seems to take the one example of answering "Yes" and just does that too."""
@property
def specifies_json_output(self) -> bool:
return True
@staticmethod
def fill_prompt(
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
) -> str:
reflexion_str = "Does this fully answer the user query?"
prompt = (
BASE_PROMPT
+ f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
f"If No, generate a better json response to the query.\n"
f"Sample question and response:\n"
f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
f"{reflexion_str} Yes\n\n"
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
)
for chunk in chunks:
prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}"
prompt += "\n\n---\n\n"
prompt += f"{QUESTION_PAT}\n{question}\n"
return prompt
def get_json_chat_reflexion_msg() -> dict[str, str]:
"""With the models tried (curent as of Jul 2023), this has not been very useful.
Have not seen any answers improved based on this.
For models like gpt-3.5-turbo, it will often answer something like:
'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
document. No revision is needed.'"""
reflexion_content = (
"Is the assistant response a valid json that fully answer the user query? "
"If the response needs to be fixed or if an improvement is possible, provide a revised json. "

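A brief, hedged sketch of how the new prompt-processor classes are consumed (an empty chunk list is used only to keep the example short; real calls pass retrieved InferenceChunks):

    json_chat = JsonChatProcessor()
    messages = json_chat.fill_prompt("Where is the Eiffel Tower?", chunks=[])
    print(json_chat.specifies_json_output)  # True

    weak_prompt = WeakModelFreeformProcessor().fill_prompt("Where is the Eiffel Tower?", chunks=[])
    print(weak_prompt.splitlines()[0])  # "Reference Document:" preamble used for weak models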
View File

@@ -0,0 +1,247 @@
import json
import math
import re
from collections.abc import Generator
from typing import Any
from typing import Optional
from typing import Tuple
import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def structure_quotes_for_response(
quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
if quotes is None:
return {}
response_quotes = {}
for quote in quotes:
response_quotes[quote.quote] = {
DOCUMENT_ID: quote.document_id,
SOURCE_LINK: quote.link,
SOURCE_TYPE: quote.source_type,
SEMANTIC_IDENTIFIER: quote.semantic_identifier,
BLURB: quote.blurb,
}
return response_quotes
def extract_answer_quotes_freeform(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
"""Splits the model output into an Answer and 0 or more Quote sections.
Splits by the Quote pattern, if not exist then assume it's all answer and no quotes
"""
# If no answer section, don't care about the quote
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
return None, None
# Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
answer_raw = answer_raw[len(ANSWER_PAT) :]
# Accept quote sections starting with the lower case version
answer_raw = answer_raw.replace(
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
) # Just in case model unreliable
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
sections_clean = [
str(section).strip() for section in sections if str(section).strip()
]
if not sections_clean:
return None, None
answer = str(sections_clean[0])
if len(sections) == 1:
return answer, None
return answer, sections_clean[1:]
def extract_answer_quotes_json(
answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
answer = str(answer_dict.get("answer"))
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
if isinstance(quotes, str):
quotes = [quotes]
return answer, quotes
def separate_answer_quotes(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
try:
model_raw_json = json.loads(answer_raw)
return extract_answer_quotes_json(model_raw_json)
except ValueError:
return extract_answer_quotes_freeform(answer_raw)
def match_quotes_to_docs(
quotes: list[str],
chunks: list[InferenceChunk],
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
fuzzy_search: bool = False,
prefix_only_length: int = 100,
) -> list[DanswerQuote]:
danswer_quotes = []
for quote in quotes:
max_edits = math.ceil(float(len(quote)) * max_error_percent)
for chunk in chunks:
if not chunk.source_links:
continue
quote_clean = shared_precompare_cleanup(
clean_model_quote(quote, trim_length=prefix_only_length)
)
chunk_clean = shared_precompare_cleanup(chunk.content)
# Finding the offset of the quote in the plain text
if fuzzy_search:
re_search_str = (
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
)
found = regex.search(re_search_str, chunk_clean)
if not found:
continue
offset = found.span()[0]
else:
if quote_clean not in chunk_clean:
continue
offset = chunk_clean.index(quote_clean)
            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in chunk.source_links.items():
                # Should always find one because offset is at least 0 and there must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    # Passed the quote's offset; curr_link already holds the containing section's link
                    break

            # Append the quote exactly once, whether the loop broke early or ran to completion
            danswer_quotes.append(
                DanswerQuote(
                    quote=quote,
                    document_id=chunk.document_id,
                    link=curr_link,
                    source_type=chunk.source_type,
                    semantic_identifier=chunk.semantic_identifier,
                    blurb=chunk.blurb,
                )
            )
            break
return danswer_quotes
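
# Illustrative only: the fuzzy path above relies on the third-party `regex` module (not the stdlib `re`),
# where a group followed by {e<=N} tolerates up to N character edits.
#
#   import regex
#   match = regex.search(r"(danswer is opn source){e<=2}", "danswer is open source software")
#   # matches despite the typo; match.span()[0] is the offset used to look up the source link
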
def process_answer(
answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
answer, quote_strings = separate_answer_quotes(answer_raw)
if answer == UNCERTAINTY_PAT or not answer:
if answer == UNCERTAINTY_PAT:
logger.debug("Answer matched UNCERTAINTY_PAT")
else:
logger.debug("No answer extracted from raw output")
return DanswerAnswer(answer=None), []
logger.info(f"Answer: {answer}")
if not quote_strings:
logger.debug("No quotes extracted from raw output")
return DanswerAnswer(answer=answer), []
logger.info(f"All quotes (including unmatched): {quote_strings}")
quotes = match_quotes_to_docs(quote_strings, chunks)
logger.info(f"Final quotes: {quotes}")
return DanswerAnswer(answer=answer), quotes
def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
    """Returns True if next_token contains the unescaped closing '"' of the JSON answer value."""
    next_token = next_token.replace('\\"', "")
# If the previous character is an escape token, don't consider the first character of next_token
if answer_so_far and answer_so_far[-1] == "\\":
next_token = next_token[1:]
if '"' in next_token:
return True
return False
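
# Illustrative only: with a JSON prompt the streamed answer value ends at the first unescaped '"'.
#
#   stream_json_answer_end('{"answer": "The sky is blue', ' today."}')   # -> True
#   stream_json_answer_end('{"answer": "He typed \\', '" then stopped')  # -> False (escaped quote)
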
def extract_quotes_from_completed_token_stream(
model_output: str, context_chunks: list[InferenceChunk]
) -> list[DanswerQuote]:
logger.debug(model_output)
answer, quotes = process_answer(model_output, context_chunks)
if answer:
logger.info(answer)
elif model_output:
logger.warning("Answer extraction from model output failed.")
return quotes
def process_model_tokens(
tokens: Generator[str, None, None],
context_docs: list[InferenceChunk],
is_json_prompt: bool = True,
) -> Generator[dict[str, Any], None, None]:
"""Yields Answer tokens back out in a dict for streaming to frontend
When Answer section ends, yields dict with answer_finished key
Collects all the tokens at the end to form the complete model output"""
model_output: str = ""
    # JSON prompts must first emit the '{"answer": "' prefix before answer tokens can be streamed
    found_answer_start = False if is_json_prompt else True
found_answer_end = False
for token in tokens:
model_previous = model_output
model_output += token
trimmed_combine = model_output.replace(" ", "").replace("\n", "")
if not found_answer_start and '{"answer":"' in trimmed_combine:
            # Note: if the token that completes the pattern carries extra text (e.g. the token is '"?'),
            # the characters after the '"' are not streamed. This is intentional: it avoids streaming
            # the '?' in the event that the model outputs the UNCERTAINTY_PAT as its answer.
found_answer_start = True
continue
if found_answer_start and not found_answer_end:
if (is_json_prompt and stream_json_answer_end(model_previous, token)) or (
not is_json_prompt and f"\n{QUOTE_PAT}" in model_output
):
found_answer_end = True
yield {"answer_finished": True}
continue
yield {"answer_data": token}
quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
yield structure_quotes_for_response(quotes)
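
# Illustrative only: streaming a canned token sequence through process_model_tokens with a JSON
# prompt and no context chunks. The token boundaries below are made up.
#
#   tokens = iter(['{"answer": "', 'Paris', ' is the capital', '."', ', "quotes": []}'])
#   for packet in process_model_tokens(tokens, context_docs=[], is_json_prompt=True):
#       print(packet)
#   # {'answer_data': 'Paris'}
#   # {'answer_data': ' is the capital'}
#   # {'answer_finished': True}
#   # {}   <- no quotes could be structured since no context chunks were provided
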

View File

@@ -58,7 +58,7 @@ def _build_custom_semantic_identifier(
def _process_quotes(
quotes: dict[str, dict[str, str | int | None]] | None
quotes: dict[str, dict[str, str | None]] | None
) -> tuple[str | None, list[str]]:
if not quotes:
return None, []

View File

@@ -9,19 +9,21 @@ from danswer.auth.users import google_oauth_client
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import check_typesense_collection_exist
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
from danswer.direct_qa.llm import get_openai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.direct_qa import get_default_backend_qa_model
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
from danswer.server.manage import router as admin_router
@@ -121,21 +123,29 @@ def get_application() -> FastAPI:
warm_up_models,
)
from danswer.datastores.qdrant.indexing import create_qdrant_collection
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
if DISABLE_GENERATIVE_AI:
logger.info("Generative AI Q&A disabled")
else:
logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
auth_status = "off" if DISABLE_AUTH else "on"
logger.info(f"User auth is turned {auth_status}")
logger.info(f"User Authentication is turned {auth_status}")
if not ENABLE_OAUTH:
logger.debug("OAuth is turned off")
else:
if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
logger.warning("OAuth is turned on but incorrectly configured")
if not DISABLE_AUTH:
if not ENABLE_OAUTH:
logger.warning("OAuth is turned off")
else:
logger.debug("OAuth is turned on")
if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
logger.warning("OAuth is turned on but incorrectly configured")
else:
logger.debug("OAuth is turned on")
logger.info("Warming up local NLP models.")
warm_up_models()
qa_model = get_default_backend_qa_model()
qa_model.warm_up_model()
logger.info("Verifying query preprocessing (NLTK) data is downloaded")
nltk.download("stopwords")

View File

@@ -1,9 +1,5 @@
from typing import Any
from danswer.connectors.slack.connector import get_channel_info
from danswer.connectors.slack.connector import get_thread
from danswer.connectors.slack.connector import thread_to_doc
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from fastapi import APIRouter
from pydantic import BaseModel

View File

@@ -38,8 +38,10 @@ from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import User
from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
from danswer.direct_qa.llm import get_openai_api_key
from danswer.direct_qa import check_model_api_key_is_valid
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.open_ai import get_openai_api_key
from danswer.direct_qa.open_ai import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
@@ -102,7 +104,7 @@ def check_google_app_credentials_exist(
) -> dict[str, str]:
try:
return {"client_id": get_google_app_cred().web.client_id}
except ConfigNotFoundError as e:
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Google App Credentials not found")
@@ -295,19 +297,13 @@ def validate_existing_openai_api_key(
_: User = Depends(current_admin_user),
) -> None:
# OpenAI key is only used for generative QA, so no need to validate this
# if it's turned off
if DISABLE_GENERATIVE_AI:
# if it's turned off or if a non-OpenAI model is being used
if DISABLE_GENERATIVE_AI or not isinstance(
get_default_backend_qa_model(), OpenAIQAModel
):
return
# always check if key exists
try:
openai_api_key = get_openai_api_key()
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
# don't call OpenAI every single time, only validate every so often
# Only validate every so often
check_key_time = "openai_api_key_last_check_time"
kv_store = get_dynamic_config_store()
curr_time = datetime.now()
@@ -320,10 +316,17 @@ def validate_existing_openai_api_key(
# First time checking the key, nothing unusual
pass
try:
openai_api_key = get_openai_api_key()
except ConfigNotFoundError:
raise HTTPException(status_code=404, detail="Key not found")
except ValueError as e:
raise HTTPException(status_code=400, detail=str(e))
get_dynamic_config_store().store(check_key_time, curr_time.timestamp())
try:
is_valid = check_openai_api_key_is_valid(openai_api_key)
is_valid = check_model_api_key_is_valid(openai_api_key)
except ValueError:
# this is the case where they aren't using an OpenAI-based model
is_valid = True
@@ -356,7 +359,7 @@ def store_openai_api_key(
_: User = Depends(current_admin_user),
) -> None:
try:
is_valid = check_openai_api_key_is_valid(request.api_key)
is_valid = check_model_api_key_is_valid(request.api_key)
if not is_valid:
raise HTTPException(400, "Invalid API key provided")
get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
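
The key-validation flow above is deliberately time-gated so the model endpoint is not called on every
request. Below is a minimal sketch of that pattern, assuming a key-value store exposing load/store that
raises KeyError on a missing key and a check_key callable; all names here are illustrative, not the
project's actual API.

from datetime import datetime, timedelta

def validate_key_rate_limited(kv_store, check_key, min_interval: timedelta = timedelta(days=1)) -> bool:
    """Re-validate the stored key only when the last check is older than min_interval."""
    now = datetime.now()
    try:
        last_check = datetime.fromtimestamp(kv_store.load("model_api_key_last_check_time"))
        if now - last_check < min_interval:
            return True  # validated recently; skip the remote call
    except KeyError:
        pass  # first time checking the key
    kv_store.store("model_api_key_last_check_time", now.timestamp())
    return check_key()
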

View File

@@ -102,8 +102,8 @@ class SearchResponse(BaseModel):
class QAResponse(SearchResponse):
answer: str | None
quotes: dict[str, dict[str, str | int | None]] | None
answer: str | None # DanswerAnswer
quotes: dict[str, dict[str, str | None]] | None # restructured DanswerQuote
predicted_flow: QueryFlow
predicted_search: SearchType
error_msg: str | None = None

View File

@@ -1,10 +1,10 @@
import json
from collections.abc import Generator
from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
@@ -12,7 +12,6 @@ from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
@@ -35,6 +34,10 @@ logger = setup_logger()
router = APIRouter()
def get_json_line(json_dict: dict) -> str:
return json.dumps(json_dict) + "\n"
@router.get("/search-intent")
def get_search_type(
question: QuestionRequest = Depends(), _: User = Depends(current_user)
@@ -162,7 +165,7 @@ def stream_direct_qa(
return
try:
qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
qa_model = get_default_backend_qa_model()
except (UnknownModelError, OpenAIKeyMissing) as e:
logger.exception("Unable to get QA model")
yield get_json_line({"error": str(e)})

View File

@@ -10,6 +10,7 @@ filelock==3.12.0
google-api-python-client==2.86.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
gpt4all==1.0.5
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2

View File

@@ -2,8 +2,8 @@ import textwrap
import unittest
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.llm import match_quotes_to_docs
from danswer.direct_qa.llm import separate_answer_quotes
from danswer.direct_qa.qa_utils import match_quotes_to_docs
from danswer.direct_qa.qa_utils import separate_answer_quotes
class TestQAPostprocessing(unittest.TestCase):