Support GPT4All in memory (#230)
@@ -4,6 +4,7 @@ from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import InputType
+from danswer.datastores.indexing_pipeline import build_indexing_pipeline
 from danswer.db.connector import disable_connector
 from danswer.db.connector import fetch_connectors
 from danswer.db.connector_credential_pair import update_connector_credential_pair
@@ -21,7 +22,6 @@ from danswer.db.index_attempt import mark_attempt_succeeded
 from danswer.db.models import Connector
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
-from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import setup_logger
 from sqlalchemy.orm import Session
@@ -31,10 +31,16 @@ CROSS_EMBED_CONTEXT_SIZE = 512
 BATCH_SIZE_ENCODE_CHUNKS = 8

 # QA Model API Configs
-# https://platform.openai.com/docs/models/model-endpoint-compatibility
+# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
+# Valid list:
+# - openai-completion
+# - openai-chat-completion
+# - gpt4all-completion
+# - gpt4all-chat-completion
 INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
-OPENAI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-3.5-turbo")
-OPENAI_MAX_OUTPUT_TOKENS = 512
+# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
+GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
+GEN_AI_MAX_OUTPUT_TOKENS = 512

 # Danswer custom Deep Learning Models
 INTENT_MODEL_VERSION = "danswer/intent-model"
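
For illustration, a minimal sketch of how these settings could point at the new in-memory backend. The environment variable names and the model file come straight from the config above; the snippet itself is an assumption about usage, not documented setup.

import os

# Hypothetical usage sketch: select the in-memory GPT4All backend and the model
# file called out above as tested. Set these before the config module is imported,
# since the defaults are read at import time via os.environ.get(...).
os.environ["INTERNAL_MODEL"] = "gpt4all-completion"
os.environ["GEN_AI_MODEL_VERSION"] = "ggml-model-gpt4all-falcon-q4_0.bin"

from danswer.configs import model_configs

print(model_configs.INTERNAL_MODEL_VERSION)  # gpt4all-completion
print(model_configs.GEN_AI_MODEL_VERSION)    # ggml-model-gpt4all-falcon-q4_0.bin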
@@ -1,18 +1,49 @@
 from typing import Any

+from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
+from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.llm import OpenAIChatCompletionQA
-from danswer.direct_qa.llm import OpenAICompletionQA
+from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
+from danswer.direct_qa.open_ai import OpenAICompletionQA
+from openai.error import AuthenticationError
+from openai.error import Timeout
+
+
+def check_model_api_key_is_valid(model_api_key: str) -> bool:
+    if not model_api_key:
+        return False
+
+    qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)
+
+    # try for up to 2 timeouts (e.g. 10 seconds in total)
+    for _ in range(2):
+        try:
+            qa_model.answer_question("Do not respond", [])
+            return True
+        except AuthenticationError:
+            return False
+        except Timeout:
+            pass
+
+    return False


 def get_default_backend_qa_model(
-    internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: Any
+    internal_model: str = INTERNAL_MODEL_VERSION,
+    api_key: str | None = None,
+    timeout: int = QA_TIMEOUT,
+    **kwargs: Any
 ) -> QAModel:
     if internal_model == "openai-completion":
-        return OpenAICompletionQA(**kwargs)
+        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == "openai-chat-completion":
-        return OpenAIChatCompletionQA(**kwargs)
+        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
+    elif internal_model == "gpt4all-completion":
+        return GPT4AllCompletionQA(**kwargs)
+    elif internal_model == "gpt4all-chat-completion":
+        return GPT4AllChatCompletionQA(**kwargs)
     else:
         raise UnknownModelError(internal_model)
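
As a usage sketch (assuming the package is importable; the GPT4All branches take no API key, so none is needed here), the factory can be asked for the in-memory backend explicitly instead of relying on the INTERNAL_MODEL_VERSION default:

from danswer.direct_qa import get_default_backend_qa_model

# Hypothetical example: request the in-memory GPT4All completion backend directly.
qa_model = get_default_backend_qa_model(internal_model="gpt4all-completion")

# For GPT4All this loads the weights into memory; the OpenAI backends don't
# override warm_up_model, so for them it is a no-op.
qa_model.warm_up_model()

answer, quotes = qa_model.answer_question("Where is the Eiffel Tower?", context_docs=[])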
@@ -1,13 +1,13 @@
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
-from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.qa_utils import structure_quotes_for_response
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -26,7 +26,6 @@ logger = setup_logger()
 def answer_question(
     question: QuestionRequest,
     user: User | None,
-    qa_model_timeout: int = QA_TIMEOUT,
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
 ) -> QAResponse:
     query = question.query
@@ -74,7 +73,7 @@ def answer_question(
     )

     try:
-        qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
+        qa_model = get_default_backend_qa_model()
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -102,8 +101,8 @@ def answer_question(
         error_msg = f"Error occurred in call to LLM - {e}"

     return QAResponse(
-        answer=answer,
-        quotes=quotes,
+        answer=answer.answer if answer else None,
+        quotes=structure_quotes_for_response(quotes),
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
         predicted_flow=predicted_flow,
@@ -1,5 +1,10 @@
 class OpenAIKeyMissing(Exception):
-    def __init__(self, msg: str = "Unable to find an OpenAI Key") -> None:
+    default_msg = (
+        "Unable to find existing OpenAI Key. "
+        'A new key can be added from "Keys" section of the Admin Panel'
+    )
+
+    def __init__(self, msg: str = default_msg) -> None:
         super().__init__(msg)
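
A small sketch of how the new default message surfaces to callers (the message text is the one defined above; the try/except is purely illustrative):

from danswer.direct_qa.exceptions import OpenAIKeyMissing

try:
    raise OpenAIKeyMissing()
except OpenAIKeyMissing as e:
    # Prints the default_msg defined on the class:
    # Unable to find existing OpenAI Key. A new key can be added from "Keys" section of the Admin Panel
    print(e)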
backend/danswer/direct_qa/gpt_4_all.py (new file, 173 lines)
@@ -0,0 +1,173 @@
from collections.abc import Generator
from typing import Any

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from gpt4all import GPT4All  # type:ignore


logger = setup_logger()

GPT4ALL_MODEL: GPT4All | None = None


def get_gpt_4_all_model(
    model_version: str = GEN_AI_MODEL_VERSION,
) -> GPT4All:
    global GPT4ALL_MODEL
    if GPT4ALL_MODEL is None:
        GPT4ALL_MODEL = GPT4All(model_version)
    return GPT4ALL_MODEL


def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temp": 0,
        **kwargs,
    }


class GPT4AllCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,  # gpt4all models can't handle this atm
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    def warm_up_model(self) -> None:
        get_gpt_4_all_model(self.model_version)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        model_output = gen_ai_model.generate(
            **_build_gpt4all_settings(
                prompt=filled_prompt, max_tokens=self.max_output_tokens
            ),
        )

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        model_stream = gen_ai_model.generate(
            **_build_gpt4all_settings(
                prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True
            ),
        )

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class GPT4AllChatCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,  # gpt4all models can't handle this atm
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        with gen_ai_model.chat_session():
            context_msgs = filled_prompt[:-1]
            user_query = filled_prompt[-1].get("content")
            for message in context_msgs:
                gen_ai_model.current_chat_session.append(message)

            model_output = gen_ai_model.generate(
                **_build_gpt4all_settings(
                    prompt=user_query, max_tokens=self.max_output_tokens
                ),
            )

            logger.debug(model_output)

            answer, quotes_dict = process_answer(model_output, context_docs)
            return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        with gen_ai_model.chat_session():
            context_msgs = filled_prompt[:-1]
            user_query = filled_prompt[-1].get("content")
            for message in context_msgs:
                gen_ai_model.current_chat_session.append(message)

            model_stream = gen_ai_model.generate(
                **_build_gpt4all_settings(
                    prompt=user_query, max_tokens=self.max_output_tokens
                ),
            )

            yield from process_model_tokens(
                tokens=model_stream,
                context_docs=context_docs,
                is_json_prompt=self.prompt_processor.specifies_json_output,
            )
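
Why this supports GPT4All "in memory": the module keeps a single GPT4All instance in a process-global, so the weights are loaded once (optionally up front via warm_up_model) and then reused by every request. A small sketch of that behavior, assuming the gpt4all package and a local model file are available:

from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA, get_gpt_4_all_model

# First call loads the weights into memory; later calls return the cached instance.
# Note the cached model is reused even if a different model_version is passed later.
model_a = get_gpt_4_all_model("ggml-model-gpt4all-falcon-q4_0.bin")
model_b = get_gpt_4_all_model()
assert model_a is model_b

# The QA wrapper goes through the same singleton, so warming up once is enough.
qa = GPT4AllCompletionQA(model_version="ggml-model-gpt4all-falcon-q4_0.bin")
qa.warm_up_model()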
@@ -1,17 +1,39 @@
 import abc
 from collections.abc import Generator
+from dataclasses import dataclass
 from typing import Any

 from danswer.chunking.models import InferenceChunk


+@dataclass
+class DanswerAnswer:
+    answer: str | None
+
+
+@dataclass
+class DanswerQuote:
+    # This is during inference so everything is a string by this point
+    quote: str
+    document_id: str
+    link: str | None
+    source_type: str
+    semantic_identifier: str
+    blurb: str
+
+
 class QAModel:
+    def warm_up_model(self) -> None:
+        """This is called during server start up to load the models into memory
+        pass if model is accessed via API"""
+        pass
+
     @abc.abstractmethod
     def answer_question(
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
         raise NotImplementedError

     @abc.abstractmethod
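
For orientation, a hypothetical toy implementation of the refreshed interface. The EchoQAModel name and its trivial behavior are invented purely to show the new return types; the streaming dict shapes mirror the ones the existing backends yield, and the stream signature is an assumption based on the other implementations in this commit.

from collections.abc import Generator
from typing import Any

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.interfaces import DanswerAnswer, DanswerQuote, QAModel


class EchoQAModel(QAModel):
    """Illustrative only: answers with the query itself and cites no quotes."""

    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        return DanswerAnswer(answer=f"You asked: {query}"), []

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        yield {"answer_data": f"You asked: {query}"}
        yield {}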
@@ -1,25 +0,0 @@
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.llm import OpenAIQAModel
from openai.error import AuthenticationError
from openai.error import Timeout


def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
    if not openai_api_key:
        return False

    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
    if not isinstance(qa_model, OpenAIQAModel):
        raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")

    # try for up to 2 timeouts (e.g. 10 seconds in total)
    for _ in range(2):
        try:
            qa_model.answer_question("Do not respond", [])
            return True
        except AuthenticationError:
            return False
        except Timeout:
            pass

    return False
@@ -1,488 +0,0 @@
|
||||
import json
|
||||
import math
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import Dict
|
||||
from typing import Literal
|
||||
from typing import Optional
|
||||
from typing import Tuple
|
||||
from typing import TypeVar
|
||||
from typing import Union
|
||||
|
||||
import openai
|
||||
import regex
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import INCLUDE_METADATA
|
||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
||||
from danswer.configs.constants import BLURB
|
||||
from danswer.configs.constants import DOCUMENT_ID
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
||||
from danswer.configs.constants import SOURCE_LINK
|
||||
from danswer.configs.constants import SOURCE_TYPE
|
||||
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
|
||||
from danswer.direct_qa.exceptions import OpenAIKeyMissing
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import ANSWER_PAT
|
||||
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
|
||||
from danswer.direct_qa.qa_prompts import json_chat_processor
|
||||
from danswer.direct_qa.qa_prompts import json_processor
|
||||
from danswer.direct_qa.qa_prompts import QUOTE_PAT
|
||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.text_processing import clean_model_quote
|
||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
||||
from danswer.utils.timing import log_function_time
|
||||
from openai.error import AuthenticationError
|
||||
from openai.error import Timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
|
||||
def get_openai_api_key() -> str:
|
||||
return OPENAI_API_KEY or cast(
|
||||
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
|
||||
)
|
||||
|
||||
|
||||
def get_json_line(json_dict: dict) -> str:
|
||||
return json.dumps(json_dict) + "\n"
|
||||
|
||||
|
||||
def extract_answer_quotes_freeform(
|
||||
answer_raw: str,
|
||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
||||
null_answer_check = (
|
||||
answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip()
|
||||
)
|
||||
|
||||
# If model just gives back the uncertainty pattern to signify answer isn't found or nothing at all
|
||||
# if null_answer_check == UNCERTAINTY_PAT or not null_answer_check:
|
||||
# return None, None
|
||||
|
||||
# If no answer section, don't care about the quote
|
||||
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
|
||||
return None, None
|
||||
|
||||
# Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
|
||||
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
|
||||
answer_raw = answer_raw[len(ANSWER_PAT) :]
|
||||
|
||||
# Accept quote sections starting with the lower case version
|
||||
answer_raw = answer_raw.replace(
|
||||
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
|
||||
) # Just in case model unreliable
|
||||
|
||||
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
|
||||
sections_clean = [
|
||||
str(section).strip() for section in sections if str(section).strip()
|
||||
]
|
||||
if not sections_clean:
|
||||
return None, None
|
||||
|
||||
answer = str(sections_clean[0])
|
||||
if len(sections) == 1:
|
||||
return answer, None
|
||||
return answer, sections_clean[1:]
|
||||
|
||||
|
||||
def extract_answer_quotes_json(
|
||||
answer_dict: dict[str, str | list[str]]
|
||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
||||
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
|
||||
answer = str(answer_dict.get("answer"))
|
||||
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
|
||||
if isinstance(quotes, str):
|
||||
quotes = [quotes]
|
||||
return answer, quotes
|
||||
|
||||
|
||||
def separate_answer_quotes(
|
||||
answer_raw: str,
|
||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
||||
try:
|
||||
model_raw_json = json.loads(answer_raw)
|
||||
return extract_answer_quotes_json(model_raw_json)
|
||||
except ValueError:
|
||||
return extract_answer_quotes_freeform(answer_raw)
|
||||
|
||||
|
||||
def match_quotes_to_docs(
|
||||
quotes: list[str],
|
||||
chunks: list[InferenceChunk],
|
||||
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
|
||||
fuzzy_search: bool = False,
|
||||
prefix_only_length: int = 100,
|
||||
) -> Dict[str, Dict[str, Union[str, int, None]]]:
|
||||
quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {}
|
||||
for quote in quotes:
|
||||
max_edits = math.ceil(float(len(quote)) * max_error_percent)
|
||||
|
||||
for chunk in chunks:
|
||||
if not chunk.source_links:
|
||||
continue
|
||||
|
||||
quote_clean = shared_precompare_cleanup(
|
||||
clean_model_quote(quote, trim_length=prefix_only_length)
|
||||
)
|
||||
chunk_clean = shared_precompare_cleanup(chunk.content)
|
||||
|
||||
# Finding the offset of the quote in the plain text
|
||||
if fuzzy_search:
|
||||
re_search_str = (
|
||||
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
|
||||
)
|
||||
found = regex.search(re_search_str, chunk_clean)
|
||||
if not found:
|
||||
continue
|
||||
offset = found.span()[0]
|
||||
else:
|
||||
if quote_clean not in chunk_clean:
|
||||
continue
|
||||
offset = chunk_clean.index(quote_clean)
|
||||
|
||||
# Extracting the link from the offset
|
||||
curr_link = None
|
||||
for link_offset, link in chunk.source_links.items():
|
||||
# Should always find one because offset is at least 0 and there must be a 0 link_offset
|
||||
if int(link_offset) <= offset:
|
||||
curr_link = link
|
||||
else:
|
||||
quotes_dict[quote] = {
|
||||
DOCUMENT_ID: chunk.document_id,
|
||||
SOURCE_LINK: curr_link,
|
||||
SOURCE_TYPE: chunk.source_type,
|
||||
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
|
||||
BLURB: chunk.blurb,
|
||||
}
|
||||
break
|
||||
quotes_dict[quote] = {
|
||||
DOCUMENT_ID: chunk.document_id,
|
||||
SOURCE_LINK: curr_link,
|
||||
SOURCE_TYPE: chunk.source_type,
|
||||
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
|
||||
BLURB: chunk.blurb,
|
||||
}
|
||||
break
|
||||
return quotes_dict
|
||||
|
||||
|
||||
def process_answer(
|
||||
answer_raw: str, chunks: list[InferenceChunk]
|
||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
||||
answer, quote_strings = separate_answer_quotes(answer_raw)
|
||||
if answer == UNCERTAINTY_PAT or not answer:
|
||||
if answer == UNCERTAINTY_PAT:
|
||||
logger.debug("Answer matched UNCERTAINTY_PAT")
|
||||
else:
|
||||
logger.debug("No answer extracted from raw output")
|
||||
return None, None
|
||||
|
||||
logger.info(f"Answer: {answer}")
|
||||
if not quote_strings:
|
||||
logger.debug("No quotes extracted from raw output")
|
||||
return answer, None
|
||||
logger.info(f"All quotes (including unmatched): {quote_strings}")
|
||||
quotes_dict = match_quotes_to_docs(quote_strings, chunks)
|
||||
logger.info(f"Final quotes dict: {quotes_dict}")
|
||||
|
||||
return answer, quotes_dict
|
||||
|
||||
|
||||
def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
|
||||
next_token = next_token.replace('\\"', "")
|
||||
# If the previous character is an escape token, don't consider the first character of next_token
|
||||
if answer_so_far and answer_so_far[-1] == "\\":
|
||||
next_token = next_token[1:]
|
||||
if '"' in next_token:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
Answer = str
|
||||
Quotes = dict[str, dict[str, str | int | None]]
|
||||
ModelType = Literal["ChatCompletion", "Completion"]
|
||||
PromptProcessor = Callable[[str, list[str]], str]
|
||||
|
||||
|
||||
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Utility to add in some common default values so they don't have to be set every time.
|
||||
"""
|
||||
return {
|
||||
"temperature": 0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
|
||||
@wraps(openai_call)
|
||||
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
||||
try:
|
||||
# if streamed, the call returns a generator
|
||||
if kwargs.get("stream"):
|
||||
|
||||
def _generator() -> Generator[Any, None, None]:
|
||||
yield from openai_call(*args, **kwargs)
|
||||
|
||||
return _generator()
|
||||
return openai_call(*args, **kwargs)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Timeout:
|
||||
logger.exception("OpenAI API timed out for query: %s", query)
|
||||
raise
|
||||
except Exception as e:
|
||||
logger.exception("Unexpected error with OpenAI API for query: %s", query)
|
||||
raise
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
# used to check if the QAModel is an OpenAI model
|
||||
class OpenAIQAModel(QAModel):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAICompletionQA(OpenAIQAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: Callable[
|
||||
[str, list[InferenceChunk], bool], str
|
||||
] = json_processor,
|
||||
model_version: str = OPENAI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
api_key: str | None = None,
|
||||
timeout: int | None = None,
|
||||
include_metadata: bool = INCLUDE_METADATA,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.timeout = timeout
|
||||
self.include_metadata = include_metadata
|
||||
try:
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
raise OpenAIKeyMissing()
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
||||
filled_prompt = self.prompt_processor(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.Completion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
prompt=filled_prompt,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = cast(str, response["choices"][0]["text"]).strip()
|
||||
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
filled_prompt = self.prompt_processor(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.Completion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
prompt=filled_prompt,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
# iterate through the stream of events
|
||||
for event in response:
|
||||
event_text = cast(str, event["choices"][0]["text"])
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
|
||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
||||
" ", ""
|
||||
).replace("\n", ""):
|
||||
found_answer_start = True
|
||||
continue
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
if answer:
|
||||
logger.info(answer)
|
||||
else:
|
||||
logger.warning(
|
||||
"Answer extraction from model output failed, most likely no quotes provided"
|
||||
)
|
||||
|
||||
if quotes_dict is None:
|
||||
yield {}
|
||||
else:
|
||||
yield quotes_dict
|
||||
|
||||
|
||||
class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: Callable[
|
||||
[str, list[InferenceChunk], bool], list[dict[str, str]]
|
||||
] = json_chat_processor,
|
||||
model_version: str = OPENAI_MODEL_VERSION,
|
||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
||||
timeout: int | None = None,
|
||||
reflexion_try_count: int = 0,
|
||||
api_key: str | None = None,
|
||||
include_metadata: bool = INCLUDE_METADATA,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.reflexion_try_count = reflexion_try_count
|
||||
self.timeout = timeout
|
||||
self.include_metadata = include_metadata
|
||||
try:
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
raise OpenAIKeyMissing()
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
||||
messages = self.prompt_processor(query, context_docs, self.include_metadata)
|
||||
logger.debug(json.dumps(messages, indent=4))
|
||||
model_output = ""
|
||||
for _ in range(self.reflexion_try_count + 1):
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.ChatCompletion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
messages=messages,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = cast(
|
||||
str, response["choices"][0]["message"]["content"]
|
||||
).strip()
|
||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
||||
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
messages = self.prompt_processor(query, context_docs, self.include_metadata)
|
||||
logger.debug(json.dumps(messages, indent=4))
|
||||
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.ChatCompletion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=self.api_key,
|
||||
messages=messages,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
model_output: str = ""
|
||||
found_answer_start = False
|
||||
found_answer_end = False
|
||||
for event in response:
|
||||
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
|
||||
if (
|
||||
"content" not in event_dict
|
||||
): # could be a role message or empty termination
|
||||
continue
|
||||
event_text = event_dict["content"]
|
||||
model_previous = model_output
|
||||
model_output += event_text
|
||||
logger.debug(f"GPT returned token: {event_text}")
|
||||
|
||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
||||
" ", ""
|
||||
).replace("\n", ""):
|
||||
# Note, if the token that completes the pattern has additional text, for example if the token is "?
|
||||
# Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
|
||||
# event that the model outputs the UNCERTAINTY_PAT
|
||||
found_answer_start = True
|
||||
continue
|
||||
|
||||
if found_answer_start and not found_answer_end:
|
||||
if stream_answer_end(model_previous, event_text):
|
||||
found_answer_end = True
|
||||
yield {"answer_finished": True}
|
||||
continue
|
||||
yield {"answer_data": event_text}
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
_, quotes_dict = process_answer(model_output, context_docs)
|
||||
|
||||
yield {} if quotes_dict is None else quotes_dict
|
backend/danswer/direct_qa/open_ai.py (new file, 280 lines)
@@ -0,0 +1,280 @@
|
||||
import json
|
||||
from abc import ABC
|
||||
from collections.abc import Callable
|
||||
from collections.abc import Generator
|
||||
from functools import wraps
|
||||
from typing import Any
|
||||
from typing import cast
|
||||
from typing import TypeVar
|
||||
|
||||
import openai
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
from danswer.configs.app_configs import INCLUDE_METADATA
|
||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
||||
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
|
||||
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
|
||||
from danswer.direct_qa.exceptions import OpenAIKeyMissing
|
||||
from danswer.direct_qa.interfaces import DanswerAnswer
|
||||
from danswer.direct_qa.interfaces import DanswerQuote
|
||||
from danswer.direct_qa.interfaces import QAModel
|
||||
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
|
||||
from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
|
||||
from danswer.direct_qa.qa_prompts import JsonChatProcessor
|
||||
from danswer.direct_qa.qa_prompts import JsonProcessor
|
||||
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
|
||||
from danswer.direct_qa.qa_utils import process_answer
|
||||
from danswer.direct_qa.qa_utils import process_model_tokens
|
||||
from danswer.dynamic_configs import get_dynamic_config_store
|
||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
||||
from danswer.utils.logger import setup_logger
|
||||
from danswer.utils.timing import log_function_time
|
||||
from openai.error import AuthenticationError
|
||||
from openai.error import Timeout
|
||||
|
||||
|
||||
logger = setup_logger()
|
||||
|
||||
F = TypeVar("F", bound=Callable)
|
||||
|
||||
|
||||
def get_openai_api_key() -> str:
|
||||
return OPENAI_API_KEY or cast(
|
||||
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
|
||||
)
|
||||
|
||||
|
||||
def _ensure_openai_api_key(api_key: str | None) -> str:
|
||||
try:
|
||||
return api_key or get_openai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
raise OpenAIKeyMissing()
|
||||
|
||||
|
||||
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
|
||||
"""
|
||||
Utility to add in some common default values so they don't have to be set every time.
|
||||
"""
|
||||
return {
|
||||
"temperature": 0,
|
||||
"top_p": 1,
|
||||
"frequency_penalty": 0,
|
||||
"presence_penalty": 0,
|
||||
**kwargs,
|
||||
}
|
||||
|
||||
|
||||
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
|
||||
@wraps(openai_call)
|
||||
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
||||
try:
|
||||
# if streamed, the call returns a generator
|
||||
if kwargs.get("stream"):
|
||||
|
||||
def _generator() -> Generator[Any, None, None]:
|
||||
yield from openai_call(*args, **kwargs)
|
||||
|
||||
return _generator()
|
||||
return openai_call(*args, **kwargs)
|
||||
except AuthenticationError:
|
||||
logger.exception("Failed to authenticate with OpenAI API")
|
||||
raise
|
||||
except Timeout:
|
||||
logger.exception("OpenAI API timed out for query: %s", query)
|
||||
raise
|
||||
except Exception:
|
||||
logger.exception("Unexpected error with OpenAI API for query: %s", query)
|
||||
raise
|
||||
|
||||
return cast(F, wrapped_call)
|
||||
|
||||
|
||||
# used to check if the QAModel is an OpenAI model
|
||||
class OpenAIQAModel(QAModel, ABC):
|
||||
pass
|
||||
|
||||
|
||||
class OpenAICompletionQA(OpenAIQAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: NonChatPromptProcessor = JsonProcessor(),
|
||||
model_version: str = GEN_AI_MODEL_VERSION,
|
||||
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
api_key: str | None = None,
|
||||
timeout: int | None = None,
|
||||
include_metadata: bool = INCLUDE_METADATA,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.timeout = timeout
|
||||
self.include_metadata = include_metadata
|
||||
try:
|
||||
self.api_key = api_key or get_openai_api_key()
|
||||
except ConfigNotFoundError:
|
||||
raise OpenAIKeyMissing()
|
||||
|
||||
@staticmethod
|
||||
def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
|
||||
for event in response:
|
||||
yield event["choices"][0]["text"]
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.Completion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=_ensure_openai_api_key(self.api_key),
|
||||
prompt=filled_prompt,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = cast(str, response["choices"][0]["text"]).strip()
|
||||
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
filled_prompt = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
logger.debug(filled_prompt)
|
||||
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.Completion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=_ensure_openai_api_key(self.api_key),
|
||||
prompt=filled_prompt,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
|
||||
tokens = self._generate_tokens_from_response(response)
|
||||
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=self.prompt_processor.specifies_json_output,
|
||||
)
|
||||
|
||||
|
||||
class OpenAIChatCompletionQA(OpenAIQAModel):
|
||||
def __init__(
|
||||
self,
|
||||
prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
|
||||
model_version: str = GEN_AI_MODEL_VERSION,
|
||||
max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
|
||||
timeout: int | None = None,
|
||||
reflexion_try_count: int = 0,
|
||||
api_key: str | None = None,
|
||||
include_metadata: bool = INCLUDE_METADATA,
|
||||
) -> None:
|
||||
self.prompt_processor = prompt_processor
|
||||
self.model_version = model_version
|
||||
self.max_output_tokens = max_output_tokens
|
||||
self.reflexion_try_count = reflexion_try_count
|
||||
self.timeout = timeout
|
||||
self.include_metadata = include_metadata
|
||||
self.api_key = api_key
|
||||
|
||||
@staticmethod
|
||||
def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
|
||||
for event in response:
|
||||
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
|
||||
if (
|
||||
"content" not in event_dict
|
||||
): # could be a role message or empty termination
|
||||
continue
|
||||
yield event_dict["content"]
|
||||
|
||||
@log_function_time()
|
||||
def answer_question(
|
||||
self,
|
||||
query: str,
|
||||
context_docs: list[InferenceChunk],
|
||||
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
|
||||
messages = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
logger.debug(json.dumps(messages, indent=4))
|
||||
model_output = ""
|
||||
for _ in range(self.reflexion_try_count + 1):
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.ChatCompletion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=_ensure_openai_api_key(self.api_key),
|
||||
messages=messages,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
),
|
||||
)
|
||||
model_output = cast(
|
||||
str, response["choices"][0]["message"]["content"]
|
||||
).strip()
|
||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
||||
messages.extend([assistant_msg, get_json_chat_reflexion_msg()])
|
||||
logger.info(
|
||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
||||
)
|
||||
|
||||
logger.debug(model_output)
|
||||
|
||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
||||
return answer, quotes_dict
|
||||
|
||||
def answer_question_stream(
|
||||
self, query: str, context_docs: list[InferenceChunk]
|
||||
) -> Generator[dict[str, Any] | None, None, None]:
|
||||
messages = self.prompt_processor.fill_prompt(
|
||||
query, context_docs, self.include_metadata
|
||||
)
|
||||
logger.debug(json.dumps(messages, indent=4))
|
||||
|
||||
openai_call = _handle_openai_exceptions_wrapper(
|
||||
openai_call=openai.ChatCompletion.create,
|
||||
query=query,
|
||||
)
|
||||
response = openai_call(
|
||||
**_build_openai_settings(
|
||||
api_key=_ensure_openai_api_key(self.api_key),
|
||||
messages=messages,
|
||||
model=self.model_version,
|
||||
max_tokens=self.max_output_tokens,
|
||||
request_timeout=self.timeout,
|
||||
stream=True,
|
||||
),
|
||||
)
|
||||
|
||||
tokens = self._generate_tokens_from_response(response)
|
||||
|
||||
yield from process_model_tokens(
|
||||
tokens=tokens,
|
||||
context_docs=context_docs,
|
||||
is_json_prompt=self.prompt_processor.specifies_json_output,
|
||||
)
|
@@ -1,3 +1,4 @@
|
||||
import abc
|
||||
import json
|
||||
|
||||
from danswer.chunking.models import InferenceChunk
|
||||
@@ -16,10 +17,10 @@ QUOTE_PAT = "Quote:"
|
||||
BASE_PROMPT = (
|
||||
f"Answer the query based on provided documents and quote relevant sections. "
|
||||
f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
|
||||
f'Respond with "?" for the answer if the query cannot be answered based on the documents. '
|
||||
f"The quotes must be EXACT substrings from the documents."
|
||||
)
|
||||
|
||||
|
||||
SAMPLE_QUESTION = "Where is the Eiffel Tower?"
|
||||
|
||||
SAMPLE_JSON_RESPONSE = {
|
||||
@@ -31,7 +32,23 @@ SAMPLE_JSON_RESPONSE = {
|
||||
}
|
||||
|
||||
|
||||
def add_metadata_section(
|
||||
def _append_acknowledge_doc_messages(
|
||||
current_messages: list[dict[str, str]], new_chunk_content: str
|
||||
) -> list[dict[str, str]]:
|
||||
updated_messages = current_messages.copy()
|
||||
updated_messages.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": new_chunk_content,
|
||||
},
|
||||
{"role": "assistant", "content": "Acknowledged"},
|
||||
]
|
||||
)
|
||||
return updated_messages
|
||||
|
||||
|
||||
def _add_metadata_section(
|
||||
prompt_current: str,
|
||||
chunk: InferenceChunk,
|
||||
prepend_tab: bool = False,
|
||||
@@ -67,192 +84,313 @@ def add_metadata_section(
|
||||
return prompt_current
|
||||
|
||||
|
||||
def json_processor(
|
||||
question: str,
|
||||
chunks: list[InferenceChunk],
|
||||
include_metadata: bool = False,
|
||||
include_sep: bool = True,
|
||||
) -> str:
|
||||
prompt = (
|
||||
BASE_PROMPT + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
|
||||
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||
)
|
||||
class PromptProcessor(abc.ABC):
|
||||
"""Take the most relevant chunks and fills out a LLM prompt using the chunk contents
|
||||
and optionally metadata about the chunk"""
|
||||
|
||||
for chunk in chunks:
|
||||
prompt += f"\n\n{DOC_SEP_PAT}\n"
|
||||
if include_metadata:
|
||||
prompt = add_metadata_section(
|
||||
prompt, chunk, prepend_tab=False, include_sep=include_sep
|
||||
)
|
||||
@property
|
||||
@abc.abstractmethod
|
||||
def specifies_json_output(self) -> bool:
|
||||
raise NotImplementedError
|
||||
|
||||
prompt += chunk.content
|
||||
|
||||
prompt += "\n\n---\n\n"
|
||||
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||
return prompt
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def fill_prompt(
|
||||
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||
) -> str | list[dict[str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def json_chat_processor(
|
||||
question: str,
|
||||
chunks: list[InferenceChunk],
|
||||
include_metadata: bool = False,
|
||||
include_sep: bool = False,
|
||||
) -> list[dict[str, str]]:
|
||||
metadata_prompt_section = "with metadata and contents " if include_metadata else ""
|
||||
intro_msg = (
|
||||
f"You are a Question Answering assistant that answers queries based on the provided most relevant documents.\n"
|
||||
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
|
||||
)
|
||||
|
||||
complete_answer_not_found_response = (
|
||||
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
|
||||
)
|
||||
task_msg = (
|
||||
"Now answer the next user query based on documents above and quote relevant sections.\n"
|
||||
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
|
||||
"All quotes MUST be EXACT substrings from provided documents.\n"
|
||||
"Your responses should be informative and concise.\n"
|
||||
"You MUST prioritize information from provided documents over internal knowledge.\n"
|
||||
"If the query cannot be answered based on the documents, respond with "
|
||||
f"{complete_answer_not_found_response}\n"
|
||||
"If the query requires aggregating the number of documents, respond with "
|
||||
'{"answer": "Aggregations not supported", "quotes": []}\n'
|
||||
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
|
||||
)
|
||||
messages = [{"role": "system", "content": intro_msg}]
|
||||
|
||||
for chunk in chunks:
|
||||
full_context = ""
|
||||
if include_metadata:
|
||||
full_context = add_metadata_section(
|
||||
full_context, chunk, prepend_tab=False, include_sep=include_sep
|
||||
)
|
||||
full_context += chunk.content
|
||||
messages.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": full_context,
|
||||
},
|
||||
{"role": "assistant", "content": "Acknowledged"},
|
||||
]
|
||||
)
|
||||
messages.append({"role": "system", "content": task_msg})
|
||||
|
||||
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
|
||||
|
||||
return messages
|
||||
class NonChatPromptProcessor(PromptProcessor):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def fill_prompt(
|
||||
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||
) -> str:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may use again in future
|
||||
class ChatPromptProcessor(PromptProcessor):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def fill_prompt(
|
||||
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||
) -> list[dict[str, str]]:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
# Chain of Thought approach works however has higher token cost (more expensive) and is slower.
|
||||
# Should use this one if users ask questions that require logical reasoning.
|
||||
def json_cot_variant_processor(question: str, documents: list[str]) -> str:
|
||||
prompt = (
|
||||
f"Answer the query based on provided documents and quote relevant sections. "
|
||||
f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
|
||||
f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
|
||||
f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
|
||||
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||
)
|
||||
class JsonProcessor(NonChatPromptProcessor):
|
||||
@property
|
||||
def specifies_json_output(self) -> bool:
|
||||
return True
|
||||
|
||||
for document in documents:
|
||||
prompt += f"\n{DOC_SEP_PAT}\n{document}"
|
||||
|
||||
prompt += "\n\n---\n\n"
|
||||
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||
prompt += "Reasoning:\n"
|
||||
return prompt
|
||||
|
||||
|
||||
# This one seems largely useless with a single example
|
||||
# Model seems to take the one example of answering Yes and just does that too.
|
||||
def json_reflexion_processor(question: str, documents: list[str]) -> str:
|
||||
reflexion_str = "Does this fully answer the user query?"
|
||||
prompt = (
|
||||
BASE_PROMPT
|
||||
+ f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
|
||||
f"If No, generate a better json response to the query.\n"
|
||||
f"Sample question and response:\n"
|
||||
f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
|
||||
f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
|
||||
f"{reflexion_str} Yes\n\n"
|
||||
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
prompt += f"\n---NEW CONTEXT DOCUMENT---\n{document}"
|
||||
|
||||
prompt += "\n\n---\n\n"
|
||||
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||
return prompt
|
||||
|
||||
|
||||
# Initial design, works pretty well but not optimal
|
||||
def freeform_processor(question: str, documents: list[str]) -> str:
|
||||
prompt = (
|
||||
f"Answer the query based on the documents below and quote the documents segments containing the answer. "
|
||||
f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. '
|
||||
f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. '
|
||||
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
|
||||
f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||
)
|
||||
|
||||
for document in documents:
|
||||
prompt += f"\n{DOC_SEP_PAT}\n{document}"
|
||||
|
||||
prompt += "\n\n---\n\n"
|
||||
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||
prompt += f"{ANSWER_PAT}\n"
|
||||
return prompt
|
||||
|
||||
|
||||
def freeform_chat_processor(
|
||||
question: str, documents: list[str]
|
||||
) -> list[dict[str, str]]:
|
||||
sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
|
||||
role_msg = (
|
||||
f"You are a Question Answering assistant that answers queries based on provided documents. "
|
||||
f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
|
||||
f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
|
||||
f"Answer the question directly and concisely. "
|
||||
f"Each quote should be a single continuous segment from a document. "
|
||||
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
|
||||
f"An example of quote sections may look like:\n{sample_quote}"
|
||||
)
|
||||
|
||||
messages = [
|
||||
{"role": "system", "content": role_msg},
|
||||
]
|
||||
for document in documents:
|
||||
messages.extend(
|
||||
[
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Acknowledge the following document:\n{document}",
|
||||
},
|
||||
{"role": "assistant", "content": "Acknowledged"},
|
||||
]
|
||||
@staticmethod
|
||||
def fill_prompt(
|
||||
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||
) -> str:
|
||||
prompt = (
|
||||
BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
|
||||
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||
)
|
||||
|
||||
messages.append(
|
||||
{
|
||||
"role": "user",
|
||||
"content": f"Please now answer the following query based on the previously provided "
|
||||
f"documents and quote the relevant sections of the documents\n{question}",
|
||||
},
|
||||
)
|
||||
for chunk in chunks:
|
||||
prompt += f"\n\n{DOC_SEP_PAT}\n"
|
||||
if include_metadata:
|
||||
prompt = _add_metadata_section(
|
||||
prompt, chunk, prepend_tab=False, include_sep=True
|
||||
)
|
||||
|
||||
return messages
|
||||
prompt += chunk.content
|
||||
|
||||
prompt += "\n\n---\n\n"
|
||||
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||
return prompt
|
||||
|
||||
|
||||
# Not very useful, have not seen it improve an answer based on this
|
||||
# Sometimes gpt-3.5-turbo will just answer something worse like:
|
||||
# 'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
|
||||
# document. No revision is needed.'
|
||||
def get_chat_reflexion_msg() -> dict[str, str]:
|
||||
class JsonChatProcessor(ChatPromptProcessor):
|
||||
@property
|
||||
def specifies_json_output(self) -> bool:
|
||||
return True
|
||||
|
||||
@staticmethod
|
||||
def fill_prompt(
|
||||
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||
) -> list[dict[str, str]]:
|
||||
metadata_prompt_section = (
|
||||
"with metadata and contents " if include_metadata else ""
|
||||
)
|
||||
intro_msg = (
|
||||
f"You are a Question Answering assistant that answers queries "
|
||||
f"based on the provided most relevant documents.\n"
|
||||
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
|
||||
)
|
||||
|
||||
complete_answer_not_found_response = (
|
||||
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
|
||||
)
|
||||
task_msg = (
|
||||
"Now answer the next user query based on documents above and quote relevant sections.\n"
|
||||
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
|
||||
"All quotes MUST be EXACT substrings from provided documents.\n"
|
||||
"Your responses should be informative and concise.\n"
|
||||
"You MUST prioritize information from provided documents over internal knowledge.\n"
|
||||
"If the query cannot be answered based on the documents, respond with "
|
||||
f"{complete_answer_not_found_response}\n"
|
||||
"If the query requires aggregating the number of documents, respond with "
|
||||
'{"answer": "Aggregations not supported", "quotes": []}\n'
|
||||
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
|
||||
)
|
||||
messages = [{"role": "system", "content": intro_msg}]
|
||||
|
||||
for chunk in chunks:
|
||||
full_context = ""
|
||||
if include_metadata:
|
||||
full_context = _add_metadata_section(
|
||||
full_context, chunk, prepend_tab=False, include_sep=False
|
||||
)
|
||||
full_context += chunk.content
|
||||
messages = _append_acknowledge_doc_messages(messages, full_context)
|
||||
messages.append({"role": "system", "content": task_msg})
|
||||
|
||||
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
class WeakModelFreeformProcessor(NonChatPromptProcessor):
    """Avoid using this one if the model is capable of using another prompt
    Intended for models that can't follow complex instructions or have short context windows
    This prompt only uses 1 reference document chunk
    """

    @property
    def specifies_json_output(self) -> bool:
        return False

    @staticmethod
    def fill_prompt(
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> str:
        first_chunk_content = chunks[0].content if chunks else "No Document Provided"

        prompt = (
            f"Reference Document:\n{first_chunk_content}\n{GENERAL_SEP_PAT}"
            f"Answer the user query below based on the reference document above. "
            f'Respond with an "{ANSWER_PAT}" section and '
            f'as many "{QUOTE_PAT}" sections as needed to support the answer.'
            f"\n{GENERAL_SEP_PAT}"
            f"{QUESTION_PAT} {question}\n"
            f"{ANSWER_PAT}"
        )

        return prompt

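# Minimal sketch (not part of the diff) of the single-document prompt built above,
# assuming this processor lives in danswer.direct_qa.qa_prompts as the pattern
# constants imported elsewhere suggest; an empty chunk list exercises the
# "No Document Provided" fallback.
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor

example_prompt = WeakModelFreeformProcessor.fill_prompt("Where are the offices?", chunks=[])
print(example_prompt)  # Freeform prompt ending with the answer pattern for the model to complete
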
class WeakChatModelFreeformProcessor(ChatPromptProcessor):
    """Avoid using this one if the model is capable of using another prompt
    Intended for models that can't follow complex instructions or have short context windows
    This prompt only uses 1 reference document chunk
    """

    @property
    def specifies_json_output(self) -> bool:
        return False

    @staticmethod
    def fill_prompt(
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> list[dict[str, str]]:
        first_chunk_content = chunks[0].content if chunks else "No Document Provided"
        intro_msg = (
            f"You are a question answering assistant. "
            f'Respond to the query with an "{ANSWER_PAT}" section and '
            f'as many "{QUOTE_PAT}" sections as needed to support the answer. '
            f"Answer the user query based on the following document:\n\n{first_chunk_content}"
        )

        messages = [{"role": "system", "content": intro_msg}]

        user_query = f"{QUESTION_PAT} {question}"
        messages.append({"role": "user", "content": user_query})

        return messages

# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may revisit in future


class FreeformProcessor(NonChatPromptProcessor):
    @property
    def specifies_json_output(self) -> bool:
        return False

    @staticmethod
    def fill_prompt(
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> str:
        prompt = (
            f"Answer the query based on the documents below and quote the documents segments containing the answer. "
            f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. '
            f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. '
            f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
            f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
        )

        for chunk in chunks:
            prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"

        prompt += "\n\n---\n\n"
        prompt += f"{QUESTION_PAT}\n{question}\n"
        prompt += f"{ANSWER_PAT}\n"
        return prompt

class FreeformChatProcessor(ChatPromptProcessor):
    @property
    def specifies_json_output(self) -> bool:
        return False

    @staticmethod
    def fill_prompt(
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> list[dict[str, str]]:
        sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
        role_msg = (
            f"You are a Question Answering assistant that answers queries based on provided documents. "
            f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
            f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
            f"Answer the question directly and concisely. "
            f"Each quote should be a single continuous segment from a document. "
            f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
            f"An example of quote sections may look like:\n{sample_quote}"
        )

        messages = [
            {"role": "system", "content": role_msg},
        ]
        for chunk in chunks:
            messages = _append_acknowledge_doc_messages(messages, chunk.content)

        messages.append(
            {
                "role": "user",
                "content": f"Please now answer the following query based on the previously provided "
                f"documents and quote the relevant sections of the documents\n{question}",
            },
        )

        return messages

class JsonCOTProcessor(NonChatPromptProcessor):
    """Chain of Thought prompting lets the model lay out its reasoning before answering, which helps with harder questions.
    This prompt type works, but it has a higher token cost (more expensive) and is slower.
    Consider this one if users ask questions that require logical reasoning."""

    @property
    def specifies_json_output(self) -> bool:
        return True

    @staticmethod
    def fill_prompt(
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> str:
        prompt = (
            f"Answer the query based on provided documents and quote relevant sections. "
            f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
            f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
            f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
        )

        for chunk in chunks:
            prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"

        prompt += "\n\n---\n\n"
        prompt += f"{QUESTION_PAT}\n{question}\n"
        prompt += "Reasoning:\n"
        return prompt

class JsonReflexionProcessor(NonChatPromptProcessor):
    """Reflexion prompting to attempt to have the model evaluate its own answer.
    This one seems largely useless when only given a single example.
    The model seems to take the one example of answering "Yes" and just does that too."""

    @property
    def specifies_json_output(self) -> bool:
        return True

    @staticmethod
    def fill_prompt(
        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
    ) -> str:
        reflexion_str = "Does this fully answer the user query?"
        prompt = (
            BASE_PROMPT
            + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
            f"If No, generate a better json response to the query.\n"
            f"Sample question and response:\n"
            f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
            f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
            f"{reflexion_str} Yes\n\n"
            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
        )

        for chunk in chunks:
            prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}"

        prompt += "\n\n---\n\n"
        prompt += f"{QUESTION_PAT}\n{question}\n"
        return prompt

def get_json_chat_reflexion_msg() -> dict[str, str]:
    """With the models tried (current as of Jul 2023), this has not been very useful.
    Have not seen any answers improved based on this.
    For models like gpt-3.5-turbo, it will often answer something like:
    'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
    document. No revision is needed.'"""
    reflexion_content = (
        "Is the assistant response a valid json that fully answers the user query? "
        "If the response needs to be fixed or if an improvement is possible, provide a revised json. "
backend/danswer/direct_qa/qa_utils.py (new file)
@@ -0,0 +1,247 @@
import json
import math
import re
from collections.abc import Generator
from typing import Any
from typing import Optional
from typing import Tuple

import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup

logger = setup_logger()


def structure_quotes_for_response(
    quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
    if quotes is None:
        return {}

    response_quotes = {}
    for quote in quotes:
        response_quotes[quote.quote] = {
            DOCUMENT_ID: quote.document_id,
            SOURCE_LINK: quote.link,
            SOURCE_TYPE: quote.source_type,
            SEMANTIC_IDENTIFIER: quote.semantic_identifier,
            BLURB: quote.blurb,
        }
    return response_quotes

def extract_answer_quotes_freeform(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    """Splits the model output into an Answer and 0 or more Quote sections.
    Splits by the Quote pattern; if it does not exist, assume it's all answer and no quotes.
    """
    # If no answer section, don't care about the quote
    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
        return None, None

    # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
        answer_raw = answer_raw[len(ANSWER_PAT) :]

    # Accept quote sections starting with the lower case version
    answer_raw = answer_raw.replace(
        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
    )  # Just in case model unreliable

    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
    sections_clean = [
        str(section).strip() for section in sections if str(section).strip()
    ]
    if not sections_clean:
        return None, None

    answer = str(sections_clean[0])
    if len(sections) == 1:
        return answer, None
    return answer, sections_clean[1:]


def extract_answer_quotes_json(
    answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
    answer = str(answer_dict.get("answer"))
    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
    if isinstance(quotes, str):
        quotes = [quotes]
    return answer, quotes


def separate_answer_quotes(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    try:
        model_raw_json = json.loads(answer_raw)
        return extract_answer_quotes_json(model_raw_json)
    except ValueError:
        return extract_answer_quotes_freeform(answer_raw)

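# Quick illustration (not part of the diff) of the two parsing paths above, using
# hypothetical model outputs and assuming ANSWER_PAT / QUOTE_PAT are the "Answer:" /
# "Quote:" patterns referenced by the prompts.
from danswer.direct_qa.qa_utils import separate_answer_quotes

# JSON-style output takes the json.loads path:
separate_answer_quotes('{"answer": "Paris", "quotes": ["Paris is the capital."]}')
# -> ("Paris", ["Paris is the capital."])

# Freeform output falls back to splitting on the Quote pattern:
separate_answer_quotes("Answer:\nParis\nQuote:\nParis is the capital.")
# -> ("Paris", ["Paris is the capital."])
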
def match_quotes_to_docs(
    quotes: list[str],
    chunks: list[InferenceChunk],
    max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
    fuzzy_search: bool = False,
    prefix_only_length: int = 100,
) -> list[DanswerQuote]:
    danswer_quotes = []
    for quote in quotes:
        max_edits = math.ceil(float(len(quote)) * max_error_percent)

        for chunk in chunks:
            if not chunk.source_links:
                continue

            quote_clean = shared_precompare_cleanup(
                clean_model_quote(quote, trim_length=prefix_only_length)
            )
            chunk_clean = shared_precompare_cleanup(chunk.content)

            # Finding the offset of the quote in the plain text
            if fuzzy_search:
                re_search_str = (
                    r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
                )
                found = regex.search(re_search_str, chunk_clean)
                if not found:
                    continue
                offset = found.span()[0]
            else:
                if quote_clean not in chunk_clean:
                    continue
                offset = chunk_clean.index(quote_clean)

            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in chunk.source_links.items():
                # Should always find one because offset is at least 0 and there must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    danswer_quotes.append(
                        DanswerQuote(
                            quote=quote,
                            document_id=chunk.document_id,
                            link=curr_link,
                            source_type=chunk.source_type,
                            semantic_identifier=chunk.semantic_identifier,
                            blurb=chunk.blurb,
                        )
                    )
                    break

            # If the quote's offset is past the start of the last link, it belongs to the last link
            danswer_quotes.append(
                DanswerQuote(
                    quote=quote,
                    document_id=chunk.document_id,
                    link=curr_link,
                    source_type=chunk.source_type,
                    semantic_identifier=chunk.semantic_identifier,
                    blurb=chunk.blurb,
                )
            )
            break

    return danswer_quotes

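# Small aside (not part of the diff): the fuzzy branch above relies on the third-party
# `regex` module's fuzzy-matching syntax, where {e<=N} allows up to N edits against the
# preceding group. A tiny hypothetical illustration:
import regex

regex.search(r"(danswer){e<=2}", "see the danswr docs")  # matches despite the missing "e"
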
def process_answer(
    answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
    answer, quote_strings = separate_answer_quotes(answer_raw)
    if answer == UNCERTAINTY_PAT or not answer:
        if answer == UNCERTAINTY_PAT:
            logger.debug("Answer matched UNCERTAINTY_PAT")
        else:
            logger.debug("No answer extracted from raw output")
        return DanswerAnswer(answer=None), []

    logger.info(f"Answer: {answer}")
    if not quote_strings:
        logger.debug("No quotes extracted from raw output")
        return DanswerAnswer(answer=answer), []
    logger.info(f"All quotes (including unmatched): {quote_strings}")
    quotes = match_quotes_to_docs(quote_strings, chunks)
    logger.info(f"Final quotes: {quotes}")

    return DanswerAnswer(answer=answer), quotes

def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False

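# Quick illustration (not part of the diff) of the end-of-answer check above, using
# hypothetical partial-answer and token values:
stream_json_answer_end('{"answer": "Par', 'is"')  # -> True, the unescaped closing quote arrived
stream_json_answer_end('{"answer": "Par', 'is')   # -> False, still inside the answer string
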
def extract_quotes_from_completed_token_stream(
    model_output: str, context_chunks: list[InferenceChunk]
) -> list[DanswerQuote]:
    logger.debug(model_output)
    answer, quotes = process_answer(model_output, context_chunks)
    if answer:
        logger.info(answer)
    elif model_output:
        logger.warning("Answer extraction from model output failed.")

    return quotes

def process_model_tokens(
    tokens: Generator[str, None, None],
    context_docs: list[InferenceChunk],
    is_json_prompt: bool = True,
) -> Generator[dict[str, Any], None, None]:
    """Yields Answer tokens back out in a dict for streaming to frontend
    When Answer section ends, yields dict with answer_finished key
    Collects all the tokens at the end to form the complete model output"""
    model_output: str = ""
    found_answer_start = False if is_json_prompt else True
    found_answer_end = False
    for token in tokens:
        model_previous = model_output
        model_output += token

        trimmed_combine = model_output.replace(" ", "").replace("\n", "")
        if not found_answer_start and '{"answer":"' in trimmed_combine:
            # Note, if the token that completes the pattern has additional text, for example if the token is "?
            # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
            # event that the model outputs the UNCERTAINTY_PAT
            found_answer_start = True
            continue

        if found_answer_start and not found_answer_end:
            if (is_json_prompt and stream_json_answer_end(model_previous, token)) or (
                not is_json_prompt and f"\n{QUOTE_PAT}" in model_output
            ):
                found_answer_end = True
                yield {"answer_finished": True}
                continue
            yield {"answer_data": token}

    quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
    yield structure_quotes_for_response(quotes)
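# Self-contained sketch (not part of the diff) of how the streaming generator above
# behaves for a JSON-style prompt, using a hypothetical token stream and no context docs:
from danswer.direct_qa.qa_utils import process_model_tokens

def _fake_tokens():
    for tok in ['{"answer": "', "Paris", '", ', '"quotes": []}']:
        yield tok

for packet in process_model_tokens(_fake_tokens(), context_docs=[], is_json_prompt=True):
    print(packet)
# Expected roughly: {"answer_data": "Paris"}, then {"answer_finished": True},
# then the final (empty) quotes dict once the stream completes.
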
@@ -58,7 +58,7 @@ def _build_custom_semantic_identifier(


def _process_quotes(
    quotes: dict[str, dict[str, str | int | None]] | None
    quotes: dict[str, dict[str, str | None]] | None
) -> tuple[str | None, list[str]]:
    if not quotes:
        return None, []
@@ -9,19 +9,21 @@ from danswer.auth.users import google_oauth_client
from danswer.configs.app_configs import APP_HOST
from danswer.configs.app_configs import APP_PORT
from danswer.configs.app_configs import DISABLE_AUTH
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import ENABLE_OAUTH
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
from danswer.configs.app_configs import SECRET
from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
from danswer.configs.app_configs import WEB_DOMAIN
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
from danswer.datastores.qdrant.indexing import list_qdrant_collections
from danswer.datastores.typesense.store import check_typesense_collection_exist
from danswer.datastores.typesense.store import create_typesense_collection
from danswer.db.credentials import create_initial_public_credential
from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
from danswer.direct_qa.llm import get_openai_api_key
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.direct_qa import get_default_backend_qa_model
from danswer.server.event_loading import router as event_processing_router
from danswer.server.health import router as health_router
from danswer.server.manage import router as admin_router
@@ -121,21 +123,29 @@ def get_application() -> FastAPI:
        warm_up_models,
    )
    from danswer.datastores.qdrant.indexing import create_qdrant_collection
    from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION

    if DISABLE_GENERATIVE_AI:
        logger.info("Generative AI Q&A disabled")
    else:
        logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
        logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")

    auth_status = "off" if DISABLE_AUTH else "on"
    logger.info(f"User auth is turned {auth_status}")
    logger.info(f"User Authentication is turned {auth_status}")

    if not ENABLE_OAUTH:
        logger.debug("OAuth is turned off")
    else:
        if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
            logger.warning("OAuth is turned on but incorrectly configured")
    if not DISABLE_AUTH:
        if not ENABLE_OAUTH:
            logger.warning("OAuth is turned off")
        else:
            logger.debug("OAuth is turned on")
            if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
                logger.warning("OAuth is turned on but incorrectly configured")
            else:
                logger.debug("OAuth is turned on")

    logger.info("Warming up local NLP models.")
    warm_up_models()
    qa_model = get_default_backend_qa_model()
    qa_model.warm_up_model()

    logger.info("Verifying query preprocessing (NLTK) data is downloaded")
    nltk.download("stopwords")
@@ -1,9 +1,5 @@
from typing import Any

from danswer.connectors.slack.connector import get_channel_info
from danswer.connectors.slack.connector import get_thread
from danswer.connectors.slack.connector import thread_to_doc
from danswer.utils.indexing_pipeline import build_indexing_pipeline
from danswer.utils.logger import setup_logger
from fastapi import APIRouter
from pydantic import BaseModel
@@ -38,8 +38,10 @@ from danswer.db.engine import get_session
from danswer.db.engine import get_sqlalchemy_async_engine
from danswer.db.index_attempt import create_index_attempt
from danswer.db.models import User
from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
from danswer.direct_qa.llm import get_openai_api_key
from danswer.direct_qa import check_model_api_key_is_valid
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.open_ai import get_openai_api_key
from danswer.direct_qa.open_ai import OpenAIQAModel
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
@@ -102,7 +104,7 @@ def check_google_app_credentials_exist(
) -> dict[str, str]:
    try:
        return {"client_id": get_google_app_cred().web.client_id}
    except ConfigNotFoundError as e:
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Google App Credentials not found")

@@ -295,19 +297,13 @@ def validate_existing_openai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    # OpenAI key is only used for generative QA, so no need to validate this
    # if it's turned off
    if DISABLE_GENERATIVE_AI:
    # if it's turned off or if a non-OpenAI model is being used
    if DISABLE_GENERATIVE_AI or not isinstance(
        get_default_backend_qa_model(), OpenAIQAModel
    ):
        return

    # always check if key exists
    try:
        openai_api_key = get_openai_api_key()
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    # don't call OpenAI every single time, only validate every so often
    # Only validate every so often
    check_key_time = "openai_api_key_last_check_time"
    kv_store = get_dynamic_config_store()
    curr_time = datetime.now()
@@ -320,10 +316,17 @@ def validate_existing_openai_api_key(
        # First time checking the key, nothing unusual
        pass

    try:
        openai_api_key = get_openai_api_key()
    except ConfigNotFoundError:
        raise HTTPException(status_code=404, detail="Key not found")
    except ValueError as e:
        raise HTTPException(status_code=400, detail=str(e))

    get_dynamic_config_store().store(check_key_time, curr_time.timestamp())

    try:
        is_valid = check_openai_api_key_is_valid(openai_api_key)
        is_valid = check_model_api_key_is_valid(openai_api_key)
    except ValueError:
        # this is the case where they aren't using an OpenAI-based model
        is_valid = True
@@ -356,7 +359,7 @@ def store_openai_api_key(
    _: User = Depends(current_admin_user),
) -> None:
    try:
        is_valid = check_openai_api_key_is_valid(request.api_key)
        is_valid = check_model_api_key_is_valid(request.api_key)
        if not is_valid:
            raise HTTPException(400, "Invalid API key provided")
        get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)

@@ -102,8 +102,8 @@ class SearchResponse(BaseModel):


class QAResponse(SearchResponse):
    answer: str | None
    quotes: dict[str, dict[str, str | int | None]] | None
    answer: str | None  # DanswerAnswer
    quotes: dict[str, dict[str, str | None]] | None  # restructured DanswerQuote
    predicted_flow: QueryFlow
    predicted_search: SearchType
    error_msg: str | None = None

@@ -1,10 +1,10 @@
import json
from collections.abc import Generator

from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.qdrant.store import QdrantIndex
from danswer.datastores.typesense.store import TypesenseIndex
from danswer.db.models import User
@@ -12,7 +12,6 @@ from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.llm import get_json_line
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.keyword_search import retrieve_keyword_documents
@@ -35,6 +34,10 @@ logger = setup_logger()
router = APIRouter()


def get_json_line(json_dict: dict) -> str:
    return json.dumps(json_dict) + "\n"


@router.get("/search-intent")
def get_search_type(
    question: QuestionRequest = Depends(), _: User = Depends(current_user)
@@ -162,7 +165,7 @@ def stream_direct_qa(
            return

        try:
            qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
            qa_model = get_default_backend_qa_model()
        except (UnknownModelError, OpenAIKeyMissing) as e:
            logger.exception("Unable to get QA model")
            yield get_json_line({"error": str(e)})

@@ -10,6 +10,7 @@ filelock==3.12.0
google-api-python-client==2.86.0
google-auth-httplib2==0.1.0
google-auth-oauthlib==1.0.0
gpt4all==1.0.5
httpcore==0.16.3
httpx==0.23.3
httpx-oauth==0.11.2

@@ -2,8 +2,8 @@ import textwrap
import unittest

from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.llm import match_quotes_to_docs
from danswer.direct_qa.llm import separate_answer_quotes
from danswer.direct_qa.qa_utils import match_quotes_to_docs
from danswer.direct_qa.qa_utils import separate_answer_quotes


class TestQAPostprocessing(unittest.TestCase):