diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py index bc2fd10d049a..7ed879299296 100755 --- a/backend/danswer/background/update.py +++ b/backend/danswer/background/update.py @@ -4,6 +4,7 @@ from danswer.connectors.factory import instantiate_connector from danswer.connectors.interfaces import LoadConnector from danswer.connectors.interfaces import PollConnector from danswer.connectors.models import InputType +from danswer.datastores.indexing_pipeline import build_indexing_pipeline from danswer.db.connector import disable_connector from danswer.db.connector import fetch_connectors from danswer.db.connector_credential_pair import update_connector_credential_pair @@ -21,7 +22,6 @@ from danswer.db.index_attempt import mark_attempt_succeeded from danswer.db.models import Connector from danswer.db.models import IndexAttempt from danswer.db.models import IndexingStatus -from danswer.utils.indexing_pipeline import build_indexing_pipeline from danswer.utils.logger import setup_logger from sqlalchemy.orm import Session diff --git a/backend/danswer/configs/model_configs.py b/backend/danswer/configs/model_configs.py index d10ab00f8775..67028b439e4d 100644 --- a/backend/danswer/configs/model_configs.py +++ b/backend/danswer/configs/model_configs.py @@ -31,10 +31,16 @@ CROSS_EMBED_CONTEXT_SIZE = 512 BATCH_SIZE_ENCODE_CHUNKS = 8 # QA Model API Configs -# https://platform.openai.com/docs/models/model-endpoint-compatibility +# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models +# Valid list: +# - openai-completion +# - openai-chat-completion +# - gpt4all-completion +# - gpt4all-chat-completion INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion") -OPENAI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-3.5-turbo") -OPENAI_MAX_OUTPUT_TOKENS = 512 +# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model +GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo") +GEN_AI_MAX_OUTPUT_TOKENS = 512 # Danswer custom Deep Learning Models INTENT_MODEL_VERSION = "danswer/intent-model" diff --git a/backend/danswer/utils/indexing_pipeline.py b/backend/danswer/datastores/indexing_pipeline.py similarity index 100% rename from backend/danswer/utils/indexing_pipeline.py rename to backend/danswer/datastores/indexing_pipeline.py diff --git a/backend/danswer/direct_qa/__init__.py b/backend/danswer/direct_qa/__init__.py index a7267d03a4c6..1f28a1384acc 100644 --- a/backend/danswer/direct_qa/__init__.py +++ b/backend/danswer/direct_qa/__init__.py @@ -1,18 +1,49 @@ from typing import Any +from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.direct_qa.exceptions import UnknownModelError +from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA +from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.llm import OpenAIChatCompletionQA -from danswer.direct_qa.llm import OpenAICompletionQA +from danswer.direct_qa.open_ai import OpenAIChatCompletionQA +from danswer.direct_qa.open_ai import OpenAICompletionQA +from openai.error import AuthenticationError +from openai.error import Timeout + + +def check_model_api_key_is_valid(model_api_key: str) -> bool: + if not model_api_key: + return False + + qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5) + + # try for up to 2 timeouts 
(e.g. 10 seconds in total) + for _ in range(2): + try: + qa_model.answer_question("Do not respond", []) + return True + except AuthenticationError: + return False + except Timeout: + pass + + return False def get_default_backend_qa_model( - internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: Any + internal_model: str = INTERNAL_MODEL_VERSION, + api_key: str | None = None, + timeout: int = QA_TIMEOUT, + **kwargs: Any ) -> QAModel: if internal_model == "openai-completion": - return OpenAICompletionQA(**kwargs) + return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs) elif internal_model == "openai-chat-completion": - return OpenAIChatCompletionQA(**kwargs) + return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs) + elif internal_model == "gpt4all-completion": + return GPT4AllCompletionQA(**kwargs) + elif internal_model == "gpt4all-chat-completion": + return GPT4AllChatCompletionQA(**kwargs) else: raise UnknownModelError(internal_model) diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index e688a3f67ddc..29b4ba1a26fa 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -1,13 +1,13 @@ from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS -from danswer.configs.app_configs import QA_TIMEOUT from danswer.datastores.qdrant.store import QdrantIndex from danswer.datastores.typesense.store import TypesenseIndex from danswer.db.models import User from danswer.direct_qa import get_default_backend_qa_model from danswer.direct_qa.exceptions import OpenAIKeyMissing from danswer.direct_qa.exceptions import UnknownModelError +from danswer.direct_qa.qa_utils import structure_quotes_for_response from danswer.search.danswer_helper import query_intent from danswer.search.keyword_search import retrieve_keyword_documents from danswer.search.models import QueryFlow @@ -26,7 +26,6 @@ logger = setup_logger() def answer_question( question: QuestionRequest, user: User | None, - qa_model_timeout: int = QA_TIMEOUT, disable_generative_answer: bool = DISABLE_GENERATIVE_AI, ) -> QAResponse: query = question.query @@ -74,7 +73,7 @@ def answer_question( ) try: - qa_model = get_default_backend_qa_model(timeout=qa_model_timeout) + qa_model = get_default_backend_qa_model() except (UnknownModelError, OpenAIKeyMissing) as e: return QAResponse( answer=None, @@ -102,8 +101,8 @@ def answer_question( error_msg = f"Error occurred in call to LLM - {e}" return QAResponse( - answer=answer, - quotes=quotes, + answer=answer.answer if answer else None, + quotes=structure_quotes_for_response(quotes), top_ranked_docs=chunks_to_search_docs(ranked_chunks), lower_ranked_docs=chunks_to_search_docs(unranked_chunks), predicted_flow=predicted_flow, diff --git a/backend/danswer/direct_qa/exceptions.py b/backend/danswer/direct_qa/exceptions.py index ca599c7239f4..eb0434a7b1aa 100644 --- a/backend/danswer/direct_qa/exceptions.py +++ b/backend/danswer/direct_qa/exceptions.py @@ -1,5 +1,10 @@ class OpenAIKeyMissing(Exception): - def __init__(self, msg: str = "Unable to find an OpenAI Key") -> None: + default_msg = ( + "Unable to find existing OpenAI Key. 
" + 'A new key can be added from "Keys" section of the Admin Panel' + ) + + def __init__(self, msg: str = default_msg) -> None: super().__init__(msg) diff --git a/backend/danswer/direct_qa/gpt_4_all.py b/backend/danswer/direct_qa/gpt_4_all.py new file mode 100644 index 000000000000..16dc6d21bee2 --- /dev/null +++ b/backend/danswer/direct_qa/gpt_4_all.py @@ -0,0 +1,173 @@ +from collections.abc import Generator +from typing import Any + +from danswer.chunking.models import InferenceChunk +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.direct_qa.interfaces import DanswerAnswer +from danswer.direct_qa.interfaces import DanswerQuote +from danswer.direct_qa.interfaces import QAModel +from danswer.direct_qa.qa_prompts import ChatPromptProcessor +from danswer.direct_qa.qa_prompts import NonChatPromptProcessor +from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor +from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor +from danswer.direct_qa.qa_utils import process_answer +from danswer.direct_qa.qa_utils import process_model_tokens +from danswer.utils.logger import setup_logger +from danswer.utils.timing import log_function_time +from gpt4all import GPT4All # type:ignore + + +logger = setup_logger() + +GPT4ALL_MODEL: GPT4All | None = None + + +def get_gpt_4_all_model( + model_version: str = GEN_AI_MODEL_VERSION, +) -> GPT4All: + global GPT4ALL_MODEL + if GPT4ALL_MODEL is None: + GPT4ALL_MODEL = GPT4All(model_version) + return GPT4ALL_MODEL + + +def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]: + """ + Utility to add in some common default values so they don't have to be set every time. + """ + return { + "temp": 0, + **kwargs, + } + + +class GPT4AllCompletionQA(QAModel): + def __init__( + self, + prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(), + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + include_metadata: bool = False, # gpt4all models can't handle this atm + ) -> None: + self.prompt_processor = prompt_processor + self.model_version = model_version + self.max_output_tokens = max_output_tokens + self.include_metadata = include_metadata + + def warm_up_model(self) -> None: + get_gpt_4_all_model(self.model_version) + + @log_function_time() + def answer_question( + self, query: str, context_docs: list[InferenceChunk] + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + gen_ai_model = get_gpt_4_all_model(self.model_version) + + model_output = gen_ai_model.generate( + **_build_gpt4all_settings( + prompt=filled_prompt, max_tokens=self.max_output_tokens + ), + ) + + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + return answer, quotes_dict + + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + gen_ai_model = get_gpt_4_all_model(self.model_version) + + model_stream = gen_ai_model.generate( + **_build_gpt4all_settings( + prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True + ), + ) + + yield from process_model_tokens( + tokens=model_stream, + context_docs=context_docs, + 
is_json_prompt=self.prompt_processor.specifies_json_output, + ) + + +class GPT4AllChatCompletionQA(QAModel): + def __init__( + self, + prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(), + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + include_metadata: bool = False, # gpt4all models can't handle this atm + ) -> None: + self.prompt_processor = prompt_processor + self.model_version = model_version + self.max_output_tokens = max_output_tokens + self.include_metadata = include_metadata + + @log_function_time() + def answer_question( + self, query: str, context_docs: list[InferenceChunk] + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + gen_ai_model = get_gpt_4_all_model(self.model_version) + + with gen_ai_model.chat_session(): + context_msgs = filled_prompt[:-1] + user_query = filled_prompt[-1].get("content") + for message in context_msgs: + gen_ai_model.current_chat_session.append(message) + + model_output = gen_ai_model.generate( + **_build_gpt4all_settings( + prompt=user_query, max_tokens=self.max_output_tokens + ), + ) + + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + return answer, quotes_dict + + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + gen_ai_model = get_gpt_4_all_model(self.model_version) + + with gen_ai_model.chat_session(): + context_msgs = filled_prompt[:-1] + user_query = filled_prompt[-1].get("content") + for message in context_msgs: + gen_ai_model.current_chat_session.append(message) + + model_stream = gen_ai_model.generate( + **_build_gpt4all_settings( + prompt=user_query, max_tokens=self.max_output_tokens + ), + ) + + yield from process_model_tokens( + tokens=model_stream, + context_docs=context_docs, + is_json_prompt=self.prompt_processor.specifies_json_output, + ) diff --git a/backend/danswer/direct_qa/interfaces.py b/backend/danswer/direct_qa/interfaces.py index 5958a1530cad..a40a7af4247c 100644 --- a/backend/danswer/direct_qa/interfaces.py +++ b/backend/danswer/direct_qa/interfaces.py @@ -1,17 +1,39 @@ import abc from collections.abc import Generator +from dataclasses import dataclass from typing import Any from danswer.chunking.models import InferenceChunk +@dataclass +class DanswerAnswer: + answer: str | None + + +@dataclass +class DanswerQuote: + # This is during inference so everything is a string by this point + quote: str + document_id: str + link: str | None + source_type: str + semantic_identifier: str + blurb: str + + class QAModel: + def warm_up_model(self) -> None: + """This is called during server start up to load the models into memory + pass if model is accessed via API""" + pass + @abc.abstractmethod def answer_question( self, query: str, context_docs: list[InferenceChunk], - ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]: + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: raise NotImplementedError @abc.abstractmethod diff --git a/backend/danswer/direct_qa/key_validation.py b/backend/danswer/direct_qa/key_validation.py deleted file mode 100644 index c1e6375d9108..000000000000 --- a/backend/danswer/direct_qa/key_validation.py +++ /dev/null @@ -1,25 +0,0 @@ 
-from danswer.direct_qa import get_default_backend_qa_model -from danswer.direct_qa.llm import OpenAIQAModel -from openai.error import AuthenticationError -from openai.error import Timeout - - -def check_openai_api_key_is_valid(openai_api_key: str) -> bool: - if not openai_api_key: - return False - - qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5) - if not isinstance(qa_model, OpenAIQAModel): - raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model") - - # try for up to 2 timeouts (e.g. 10 seconds in total) - for _ in range(2): - try: - qa_model.answer_question("Do not respond", []) - return True - except AuthenticationError: - return False - except Timeout: - pass - - return False diff --git a/backend/danswer/direct_qa/llm.py b/backend/danswer/direct_qa/llm.py deleted file mode 100644 index fa5b85f5fefb..000000000000 --- a/backend/danswer/direct_qa/llm.py +++ /dev/null @@ -1,488 +0,0 @@ -import json -import math -import re -from collections.abc import Callable -from collections.abc import Generator -from functools import wraps -from typing import Any -from typing import cast -from typing import Dict -from typing import Literal -from typing import Optional -from typing import Tuple -from typing import TypeVar -from typing import Union - -import openai -import regex -from danswer.chunking.models import InferenceChunk -from danswer.configs.app_configs import INCLUDE_METADATA -from danswer.configs.app_configs import OPENAI_API_KEY -from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT -from danswer.configs.constants import BLURB -from danswer.configs.constants import DOCUMENT_ID -from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY -from danswer.configs.constants import SEMANTIC_IDENTIFIER -from danswer.configs.constants import SOURCE_LINK -from danswer.configs.constants import SOURCE_TYPE -from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS -from danswer.configs.model_configs import OPENAI_MODEL_VERSION -from danswer.direct_qa.exceptions import OpenAIKeyMissing -from danswer.direct_qa.interfaces import QAModel -from danswer.direct_qa.qa_prompts import ANSWER_PAT -from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg -from danswer.direct_qa.qa_prompts import json_chat_processor -from danswer.direct_qa.qa_prompts import json_processor -from danswer.direct_qa.qa_prompts import QUOTE_PAT -from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT -from danswer.dynamic_configs import get_dynamic_config_store -from danswer.dynamic_configs.interface import ConfigNotFoundError -from danswer.utils.logger import setup_logger -from danswer.utils.text_processing import clean_model_quote -from danswer.utils.text_processing import shared_precompare_cleanup -from danswer.utils.timing import log_function_time -from openai.error import AuthenticationError -from openai.error import Timeout - - -logger = setup_logger() - - -def get_openai_api_key() -> str: - return OPENAI_API_KEY or cast( - str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY) - ) - - -def get_json_line(json_dict: dict) -> str: - return json.dumps(json_dict) + "\n" - - -def extract_answer_quotes_freeform( - answer_raw: str, -) -> Tuple[Optional[str], Optional[list[str]]]: - null_answer_check = ( - answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip() - ) - - # If model just gives back the uncertainty pattern to signify answer isn't found or nothing at all - # if null_answer_check == UNCERTAINTY_PAT or not 
null_answer_check: - # return None, None - - # If no answer section, don't care about the quote - if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): - return None, None - - # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt - if answer_raw.lower().startswith(ANSWER_PAT.lower()): - answer_raw = answer_raw[len(ANSWER_PAT) :] - - # Accept quote sections starting with the lower case version - answer_raw = answer_raw.replace( - f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" - ) # Just in case model unreliable - - sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) - sections_clean = [ - str(section).strip() for section in sections if str(section).strip() - ] - if not sections_clean: - return None, None - - answer = str(sections_clean[0]) - if len(sections) == 1: - return answer, None - return answer, sections_clean[1:] - - -def extract_answer_quotes_json( - answer_dict: dict[str, str | list[str]] -) -> Tuple[Optional[str], Optional[list[str]]]: - answer_dict = {k.lower(): v for k, v in answer_dict.items()} - answer = str(answer_dict.get("answer")) - quotes = answer_dict.get("quotes") or answer_dict.get("quote") - if isinstance(quotes, str): - quotes = [quotes] - return answer, quotes - - -def separate_answer_quotes( - answer_raw: str, -) -> Tuple[Optional[str], Optional[list[str]]]: - try: - model_raw_json = json.loads(answer_raw) - return extract_answer_quotes_json(model_raw_json) - except ValueError: - return extract_answer_quotes_freeform(answer_raw) - - -def match_quotes_to_docs( - quotes: list[str], - chunks: list[InferenceChunk], - max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, - fuzzy_search: bool = False, - prefix_only_length: int = 100, -) -> Dict[str, Dict[str, Union[str, int, None]]]: - quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {} - for quote in quotes: - max_edits = math.ceil(float(len(quote)) * max_error_percent) - - for chunk in chunks: - if not chunk.source_links: - continue - - quote_clean = shared_precompare_cleanup( - clean_model_quote(quote, trim_length=prefix_only_length) - ) - chunk_clean = shared_precompare_cleanup(chunk.content) - - # Finding the offset of the quote in the plain text - if fuzzy_search: - re_search_str = ( - r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" - ) - found = regex.search(re_search_str, chunk_clean) - if not found: - continue - offset = found.span()[0] - else: - if quote_clean not in chunk_clean: - continue - offset = chunk_clean.index(quote_clean) - - # Extracting the link from the offset - curr_link = None - for link_offset, link in chunk.source_links.items(): - # Should always find one because offset is at least 0 and there must be a 0 link_offset - if int(link_offset) <= offset: - curr_link = link - else: - quotes_dict[quote] = { - DOCUMENT_ID: chunk.document_id, - SOURCE_LINK: curr_link, - SOURCE_TYPE: chunk.source_type, - SEMANTIC_IDENTIFIER: chunk.semantic_identifier, - BLURB: chunk.blurb, - } - break - quotes_dict[quote] = { - DOCUMENT_ID: chunk.document_id, - SOURCE_LINK: curr_link, - SOURCE_TYPE: chunk.source_type, - SEMANTIC_IDENTIFIER: chunk.semantic_identifier, - BLURB: chunk.blurb, - } - break - return quotes_dict - - -def process_answer( - answer_raw: str, chunks: list[InferenceChunk] -) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]: - answer, quote_strings = separate_answer_quotes(answer_raw) - if answer == UNCERTAINTY_PAT or not answer: - if answer == UNCERTAINTY_PAT: - logger.debug("Answer matched UNCERTAINTY_PAT") - 
else: - logger.debug("No answer extracted from raw output") - return None, None - - logger.info(f"Answer: {answer}") - if not quote_strings: - logger.debug("No quotes extracted from raw output") - return answer, None - logger.info(f"All quotes (including unmatched): {quote_strings}") - quotes_dict = match_quotes_to_docs(quote_strings, chunks) - logger.info(f"Final quotes dict: {quotes_dict}") - - return answer, quotes_dict - - -def stream_answer_end(answer_so_far: str, next_token: str) -> bool: - next_token = next_token.replace('\\"', "") - # If the previous character is an escape token, don't consider the first character of next_token - if answer_so_far and answer_so_far[-1] == "\\": - next_token = next_token[1:] - if '"' in next_token: - return True - return False - - -F = TypeVar("F", bound=Callable) - -Answer = str -Quotes = dict[str, dict[str, str | int | None]] -ModelType = Literal["ChatCompletion", "Completion"] -PromptProcessor = Callable[[str, list[str]], str] - - -def _build_openai_settings(**kwargs: Any) -> dict[str, Any]: - """ - Utility to add in some common default values so they don't have to be set every time. - """ - return { - "temperature": 0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - **kwargs, - } - - -def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F: - @wraps(openai_call) - def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any: - try: - # if streamed, the call returns a generator - if kwargs.get("stream"): - - def _generator() -> Generator[Any, None, None]: - yield from openai_call(*args, **kwargs) - - return _generator() - return openai_call(*args, **kwargs) - except AuthenticationError: - logger.exception("Failed to authenticate with OpenAI API") - raise - except Timeout: - logger.exception("OpenAI API timed out for query: %s", query) - raise - except Exception as e: - logger.exception("Unexpected error with OpenAI API for query: %s", query) - raise - - return cast(F, wrapped_call) - - -# used to check if the QAModel is an OpenAI model -class OpenAIQAModel(QAModel): - pass - - -class OpenAICompletionQA(OpenAIQAModel): - def __init__( - self, - prompt_processor: Callable[ - [str, list[InferenceChunk], bool], str - ] = json_processor, - model_version: str = OPENAI_MODEL_VERSION, - max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS, - api_key: str | None = None, - timeout: int | None = None, - include_metadata: bool = INCLUDE_METADATA, - ) -> None: - self.prompt_processor = prompt_processor - self.model_version = model_version - self.max_output_tokens = max_output_tokens - self.timeout = timeout - self.include_metadata = include_metadata - try: - self.api_key = api_key or get_openai_api_key() - except ConfigNotFoundError: - raise OpenAIKeyMissing() - - @log_function_time() - def answer_question( - self, query: str, context_docs: list[InferenceChunk] - ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]: - filled_prompt = self.prompt_processor( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.Completion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=self.api_key, - prompt=filled_prompt, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - ), - ) - model_output = cast(str, response["choices"][0]["text"]).strip() - logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")) - 
logger.debug(model_output) - - answer, quotes_dict = process_answer(model_output, context_docs) - return answer, quotes_dict - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> Generator[dict[str, Any] | None, None, None]: - filled_prompt = self.prompt_processor( - query, context_docs, self.include_metadata - ) - logger.debug(filled_prompt) - - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.Completion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=self.api_key, - prompt=filled_prompt, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - stream=True, - ), - ) - model_output: str = "" - found_answer_start = False - found_answer_end = False - # iterate through the stream of events - for event in response: - event_text = cast(str, event["choices"][0]["text"]) - model_previous = model_output - model_output += event_text - - if not found_answer_start and '{"answer":"' in model_output.replace( - " ", "" - ).replace("\n", ""): - found_answer_start = True - continue - - if found_answer_start and not found_answer_end: - if stream_answer_end(model_previous, event_text): - found_answer_end = True - yield {"answer_finished": True} - continue - yield {"answer_data": event_text} - - logger.debug(model_output) - - answer, quotes_dict = process_answer(model_output, context_docs) - if answer: - logger.info(answer) - else: - logger.warning( - "Answer extraction from model output failed, most likely no quotes provided" - ) - - if quotes_dict is None: - yield {} - else: - yield quotes_dict - - -class OpenAIChatCompletionQA(OpenAIQAModel): - def __init__( - self, - prompt_processor: Callable[ - [str, list[InferenceChunk], bool], list[dict[str, str]] - ] = json_chat_processor, - model_version: str = OPENAI_MODEL_VERSION, - max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS, - timeout: int | None = None, - reflexion_try_count: int = 0, - api_key: str | None = None, - include_metadata: bool = INCLUDE_METADATA, - ) -> None: - self.prompt_processor = prompt_processor - self.model_version = model_version - self.max_output_tokens = max_output_tokens - self.reflexion_try_count = reflexion_try_count - self.timeout = timeout - self.include_metadata = include_metadata - try: - self.api_key = api_key or get_openai_api_key() - except ConfigNotFoundError: - raise OpenAIKeyMissing() - - @log_function_time() - def answer_question( - self, - query: str, - context_docs: list[InferenceChunk], - ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]: - messages = self.prompt_processor(query, context_docs, self.include_metadata) - logger.debug(json.dumps(messages, indent=4)) - model_output = "" - for _ in range(self.reflexion_try_count + 1): - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.ChatCompletion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=self.api_key, - messages=messages, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - ), - ) - model_output = cast( - str, response["choices"][0]["message"]["content"] - ).strip() - assistant_msg = {"content": model_output, "role": "assistant"} - messages.extend([assistant_msg, get_chat_reflexion_msg()]) - logger.info( - "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "") - ) - - logger.debug(model_output) - - answer, quotes_dict = process_answer(model_output, context_docs) - return answer, 
quotes_dict - - def answer_question_stream( - self, query: str, context_docs: list[InferenceChunk] - ) -> Generator[dict[str, Any] | None, None, None]: - messages = self.prompt_processor(query, context_docs, self.include_metadata) - logger.debug(json.dumps(messages, indent=4)) - - openai_call = _handle_openai_exceptions_wrapper( - openai_call=openai.ChatCompletion.create, - query=query, - ) - response = openai_call( - **_build_openai_settings( - api_key=self.api_key, - messages=messages, - model=self.model_version, - max_tokens=self.max_output_tokens, - request_timeout=self.timeout, - stream=True, - ), - ) - model_output: str = "" - found_answer_start = False - found_answer_end = False - for event in response: - event_dict = cast(dict[str, Any], event["choices"][0]["delta"]) - if ( - "content" not in event_dict - ): # could be a role message or empty termination - continue - event_text = event_dict["content"] - model_previous = model_output - model_output += event_text - logger.debug(f"GPT returned token: {event_text}") - - if not found_answer_start and '{"answer":"' in model_output.replace( - " ", "" - ).replace("\n", ""): - # Note, if the token that completes the pattern has additional text, for example if the token is "? - # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the - # event that the model outputs the UNCERTAINTY_PAT - found_answer_start = True - continue - - if found_answer_start and not found_answer_end: - if stream_answer_end(model_previous, event_text): - found_answer_end = True - yield {"answer_finished": True} - continue - yield {"answer_data": event_text} - - logger.debug(model_output) - - _, quotes_dict = process_answer(model_output, context_docs) - - yield {} if quotes_dict is None else quotes_dict diff --git a/backend/danswer/direct_qa/open_ai.py b/backend/danswer/direct_qa/open_ai.py new file mode 100644 index 000000000000..87aac726c624 --- /dev/null +++ b/backend/danswer/direct_qa/open_ai.py @@ -0,0 +1,280 @@ +import json +from abc import ABC +from collections.abc import Callable +from collections.abc import Generator +from functools import wraps +from typing import Any +from typing import cast +from typing import TypeVar + +import openai +from danswer.chunking.models import InferenceChunk +from danswer.configs.app_configs import INCLUDE_METADATA +from danswer.configs.app_configs import OPENAI_API_KEY +from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY +from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.direct_qa.exceptions import OpenAIKeyMissing +from danswer.direct_qa.interfaces import DanswerAnswer +from danswer.direct_qa.interfaces import DanswerQuote +from danswer.direct_qa.interfaces import QAModel +from danswer.direct_qa.qa_prompts import ChatPromptProcessor +from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg +from danswer.direct_qa.qa_prompts import JsonChatProcessor +from danswer.direct_qa.qa_prompts import JsonProcessor +from danswer.direct_qa.qa_prompts import NonChatPromptProcessor +from danswer.direct_qa.qa_utils import process_answer +from danswer.direct_qa.qa_utils import process_model_tokens +from danswer.dynamic_configs import get_dynamic_config_store +from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.utils.logger import setup_logger +from danswer.utils.timing import log_function_time +from openai.error import AuthenticationError +from openai.error 
import Timeout + + +logger = setup_logger() + +F = TypeVar("F", bound=Callable) + + +def get_openai_api_key() -> str: + return OPENAI_API_KEY or cast( + str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY) + ) + + +def _ensure_openai_api_key(api_key: str | None) -> str: + try: + return api_key or get_openai_api_key() + except ConfigNotFoundError: + raise OpenAIKeyMissing() + + +def _build_openai_settings(**kwargs: Any) -> dict[str, Any]: + """ + Utility to add in some common default values so they don't have to be set every time. + """ + return { + "temperature": 0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + **kwargs, + } + + +def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F: + @wraps(openai_call) + def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any: + try: + # if streamed, the call returns a generator + if kwargs.get("stream"): + + def _generator() -> Generator[Any, None, None]: + yield from openai_call(*args, **kwargs) + + return _generator() + return openai_call(*args, **kwargs) + except AuthenticationError: + logger.exception("Failed to authenticate with OpenAI API") + raise + except Timeout: + logger.exception("OpenAI API timed out for query: %s", query) + raise + except Exception: + logger.exception("Unexpected error with OpenAI API for query: %s", query) + raise + + return cast(F, wrapped_call) + + +# used to check if the QAModel is an OpenAI model +class OpenAIQAModel(QAModel, ABC): + pass + + +class OpenAICompletionQA(OpenAIQAModel): + def __init__( + self, + prompt_processor: NonChatPromptProcessor = JsonProcessor(), + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + api_key: str | None = None, + timeout: int | None = None, + include_metadata: bool = INCLUDE_METADATA, + ) -> None: + self.prompt_processor = prompt_processor + self.model_version = model_version + self.max_output_tokens = max_output_tokens + self.timeout = timeout + self.include_metadata = include_metadata + try: + self.api_key = api_key or get_openai_api_key() + except ConfigNotFoundError: + raise OpenAIKeyMissing() + + @staticmethod + def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]: + for event in response: + yield event["choices"][0]["text"] + + @log_function_time() + def answer_question( + self, query: str, context_docs: list[InferenceChunk] + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.Completion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( + api_key=_ensure_openai_api_key(self.api_key), + prompt=filled_prompt, + model=self.model_version, + max_tokens=self.max_output_tokens, + request_timeout=self.timeout, + ), + ) + model_output = cast(str, response["choices"][0]["text"]).strip() + logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")) + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + return answer, quotes_dict + + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + filled_prompt = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(filled_prompt) + + openai_call = _handle_openai_exceptions_wrapper( + 
openai_call=openai.Completion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( + api_key=_ensure_openai_api_key(self.api_key), + prompt=filled_prompt, + model=self.model_version, + max_tokens=self.max_output_tokens, + request_timeout=self.timeout, + stream=True, + ), + ) + + tokens = self._generate_tokens_from_response(response) + + yield from process_model_tokens( + tokens=tokens, + context_docs=context_docs, + is_json_prompt=self.prompt_processor.specifies_json_output, + ) + + +class OpenAIChatCompletionQA(OpenAIQAModel): + def __init__( + self, + prompt_processor: ChatPromptProcessor = JsonChatProcessor(), + model_version: str = GEN_AI_MODEL_VERSION, + max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS, + timeout: int | None = None, + reflexion_try_count: int = 0, + api_key: str | None = None, + include_metadata: bool = INCLUDE_METADATA, + ) -> None: + self.prompt_processor = prompt_processor + self.model_version = model_version + self.max_output_tokens = max_output_tokens + self.reflexion_try_count = reflexion_try_count + self.timeout = timeout + self.include_metadata = include_metadata + self.api_key = api_key + + @staticmethod + def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]: + for event in response: + event_dict = cast(dict[str, Any], event["choices"][0]["delta"]) + if ( + "content" not in event_dict + ): # could be a role message or empty termination + continue + yield event_dict["content"] + + @log_function_time() + def answer_question( + self, + query: str, + context_docs: list[InferenceChunk], + ) -> tuple[DanswerAnswer, list[DanswerQuote]]: + messages = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(json.dumps(messages, indent=4)) + model_output = "" + for _ in range(self.reflexion_try_count + 1): + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.ChatCompletion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( + api_key=_ensure_openai_api_key(self.api_key), + messages=messages, + model=self.model_version, + max_tokens=self.max_output_tokens, + request_timeout=self.timeout, + ), + ) + model_output = cast( + str, response["choices"][0]["message"]["content"] + ).strip() + assistant_msg = {"content": model_output, "role": "assistant"} + messages.extend([assistant_msg, get_json_chat_reflexion_msg()]) + logger.info( + "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "") + ) + + logger.debug(model_output) + + answer, quotes_dict = process_answer(model_output, context_docs) + return answer, quotes_dict + + def answer_question_stream( + self, query: str, context_docs: list[InferenceChunk] + ) -> Generator[dict[str, Any] | None, None, None]: + messages = self.prompt_processor.fill_prompt( + query, context_docs, self.include_metadata + ) + logger.debug(json.dumps(messages, indent=4)) + + openai_call = _handle_openai_exceptions_wrapper( + openai_call=openai.ChatCompletion.create, + query=query, + ) + response = openai_call( + **_build_openai_settings( + api_key=_ensure_openai_api_key(self.api_key), + messages=messages, + model=self.model_version, + max_tokens=self.max_output_tokens, + request_timeout=self.timeout, + stream=True, + ), + ) + + tokens = self._generate_tokens_from_response(response) + + yield from process_model_tokens( + tokens=tokens, + context_docs=context_docs, + is_json_prompt=self.prompt_processor.specifies_json_output, + ) diff --git a/backend/danswer/direct_qa/qa_prompts.py 
b/backend/danswer/direct_qa/qa_prompts.py index f2a74d8cb506..1e1ca9548a9e 100644 --- a/backend/danswer/direct_qa/qa_prompts.py +++ b/backend/danswer/direct_qa/qa_prompts.py @@ -1,3 +1,4 @@ +import abc import json from danswer.chunking.models import InferenceChunk @@ -16,10 +17,10 @@ QUOTE_PAT = "Quote:" BASE_PROMPT = ( f"Answer the query based on provided documents and quote relevant sections. " f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. " + f'Respond with "?" for the answer if the query cannot be answered based on the documents. ' f"The quotes must be EXACT substrings from the documents." ) - SAMPLE_QUESTION = "Where is the Eiffel Tower?" SAMPLE_JSON_RESPONSE = { @@ -31,7 +32,23 @@ SAMPLE_JSON_RESPONSE = { } -def add_metadata_section( +def _append_acknowledge_doc_messages( + current_messages: list[dict[str, str]], new_chunk_content: str +) -> list[dict[str, str]]: + updated_messages = current_messages.copy() + updated_messages.extend( + [ + { + "role": "user", + "content": new_chunk_content, + }, + {"role": "assistant", "content": "Acknowledged"}, + ] + ) + return updated_messages + + +def _add_metadata_section( prompt_current: str, chunk: InferenceChunk, prepend_tab: bool = False, @@ -67,192 +84,313 @@ def add_metadata_section( return prompt_current -def json_processor( - question: str, - chunks: list[InferenceChunk], - include_metadata: bool = False, - include_sep: bool = True, -) -> str: - prompt = ( - BASE_PROMPT + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n" - f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' - ) +class PromptProcessor(abc.ABC): + """Take the most relevant chunks and fills out a LLM prompt using the chunk contents + and optionally metadata about the chunk""" - for chunk in chunks: - prompt += f"\n\n{DOC_SEP_PAT}\n" - if include_metadata: - prompt = add_metadata_section( - prompt, chunk, prepend_tab=False, include_sep=include_sep - ) + @property + @abc.abstractmethod + def specifies_json_output(self) -> bool: + raise NotImplementedError - prompt += chunk.content - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - return prompt + @staticmethod + @abc.abstractmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str | list[dict[str, str]]: + raise NotImplementedError -def json_chat_processor( - question: str, - chunks: list[InferenceChunk], - include_metadata: bool = False, - include_sep: bool = False, -) -> list[dict[str, str]]: - metadata_prompt_section = "with metadata and contents " if include_metadata else "" - intro_msg = ( - f"You are a Question Answering assistant that answers queries based on the provided most relevant documents.\n" - f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".' 
- ) - - complete_answer_not_found_response = ( - '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' - ) - task_msg = ( - "Now answer the next user query based on documents above and quote relevant sections.\n" - "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n" - "All quotes MUST be EXACT substrings from provided documents.\n" - "Your responses should be informative and concise.\n" - "You MUST prioritize information from provided documents over internal knowledge.\n" - "If the query cannot be answered based on the documents, respond with " - f"{complete_answer_not_found_response}\n" - "If the query requires aggregating the number of documents, respond with " - '{"answer": "Aggregations not supported", "quotes": []}\n' - f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}" - ) - messages = [{"role": "system", "content": intro_msg}] - - for chunk in chunks: - full_context = "" - if include_metadata: - full_context = add_metadata_section( - full_context, chunk, prepend_tab=False, include_sep=include_sep - ) - full_context += chunk.content - messages.extend( - [ - { - "role": "user", - "content": full_context, - }, - {"role": "assistant", "content": "Acknowledged"}, - ] - ) - messages.append({"role": "system", "content": task_msg}) - - messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"}) - - return messages +class NonChatPromptProcessor(PromptProcessor): + @staticmethod + @abc.abstractmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str: + raise NotImplementedError -# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may use again in future +class ChatPromptProcessor(PromptProcessor): + @staticmethod + @abc.abstractmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> list[dict[str, str]]: + raise NotImplementedError -# Chain of Thought approach works however has higher token cost (more expensive) and is slower. -# Should use this one if users ask questions that require logical reasoning. -def json_cot_variant_processor(question: str, documents: list[str]) -> str: - prompt = ( - f"Answer the query based on provided documents and quote relevant sections. " - f'Respond with a freeform reasoning section followed by "Final Answer:" with a ' - f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n" - f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n" - f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' - ) +class JsonProcessor(NonChatPromptProcessor): + @property + def specifies_json_output(self) -> bool: + return True - for document in documents: - prompt += f"\n{DOC_SEP_PAT}\n{document}" - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - prompt += "Reasoning:\n" - return prompt - - -# This one seems largely useless with a single example -# Model seems to take the one example of answering Yes and just does that too. -def json_reflexion_processor(question: str, documents: list[str]) -> str: - reflexion_str = "Does this fully answer the user query?" - prompt = ( - BASE_PROMPT - + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. 
' - f"If No, generate a better json response to the query.\n" - f"Sample question and response:\n" - f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n" - f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n" - f"{reflexion_str} Yes\n\n" - f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' - ) - - for document in documents: - prompt += f"\n---NEW CONTEXT DOCUMENT---\n{document}" - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - return prompt - - -# Initial design, works pretty well but not optimal -def freeform_processor(question: str, documents: list[str]) -> str: - prompt = ( - f"Answer the query based on the documents below and quote the documents segments containing the answer. " - f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. ' - f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. ' - f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". ' - f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n' - ) - - for document in documents: - prompt += f"\n{DOC_SEP_PAT}\n{document}" - - prompt += "\n\n---\n\n" - prompt += f"{QUESTION_PAT}\n{question}\n" - prompt += f"{ANSWER_PAT}\n" - return prompt - - -def freeform_chat_processor( - question: str, documents: list[str] -) -> list[dict[str, str]]: - sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each." - role_msg = ( - f"You are a Question Answering assistant that answers queries based on provided documents. " - f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and ' - f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. ' - f"Answer the question directly and concisely. " - f"Each quote should be a single continuous segment from a document. " - f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". ' - f"An example of quote sections may look like:\n{sample_quote}" - ) - - messages = [ - {"role": "system", "content": role_msg}, - ] - for document in documents: - messages.extend( - [ - { - "role": "user", - "content": f"Acknowledge the following document:\n{document}", - }, - {"role": "assistant", "content": "Acknowledged"}, - ] + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str: + prompt = ( + BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n" + f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' ) - messages.append( - { - "role": "user", - "content": f"Please now answer the following query based on the previously provided " - f"documents and quote the relevant sections of the documents\n{question}", - }, - ) + for chunk in chunks: + prompt += f"\n\n{DOC_SEP_PAT}\n" + if include_metadata: + prompt = _add_metadata_section( + prompt, chunk, prepend_tab=False, include_sep=True + ) - return messages + prompt += chunk.content + + prompt += "\n\n---\n\n" + prompt += f"{QUESTION_PAT}\n{question}\n" + return prompt -# Not very useful, have not seen it improve an answer based on this -# Sometimes gpt-3.5-turbo will just answer something worse like: -# 'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source -# document. No revision is needed.' 
-def get_chat_reflexion_msg() -> dict[str, str]: +class JsonChatProcessor(ChatPromptProcessor): + @property + def specifies_json_output(self) -> bool: + return True + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> list[dict[str, str]]: + metadata_prompt_section = ( + "with metadata and contents " if include_metadata else "" + ) + intro_msg = ( + f"You are a Question Answering assistant that answers queries " + f"based on the provided most relevant documents.\n" + f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".' + ) + + complete_answer_not_found_response = ( + '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}' + ) + task_msg = ( + "Now answer the next user query based on documents above and quote relevant sections.\n" + "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n" + "All quotes MUST be EXACT substrings from provided documents.\n" + "Your responses should be informative and concise.\n" + "You MUST prioritize information from provided documents over internal knowledge.\n" + "If the query cannot be answered based on the documents, respond with " + f"{complete_answer_not_found_response}\n" + "If the query requires aggregating the number of documents, respond with " + '{"answer": "Aggregations not supported", "quotes": []}\n' + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}" + ) + messages = [{"role": "system", "content": intro_msg}] + + for chunk in chunks: + full_context = "" + if include_metadata: + full_context = _add_metadata_section( + full_context, chunk, prepend_tab=False, include_sep=False + ) + full_context += chunk.content + messages = _append_acknowledge_doc_messages(messages, full_context) + messages.append({"role": "system", "content": task_msg}) + + messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"}) + + return messages + + +class WeakModelFreeformProcessor(NonChatPromptProcessor): + """Avoid using this one if the model is capable of using another prompt + Intended for models that can't follow complex instructions or have short context windows + This prompt only uses 1 reference document chunk + """ + + @property + def specifies_json_output(self) -> bool: + return False + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str: + first_chunk_content = chunks[0].content if chunks else "No Document Provided" + + prompt = ( + f"Reference Document:\n{first_chunk_content}\n{GENERAL_SEP_PAT}" + f"Answer the user query below based on the reference document above. " + f'Respond with an "{ANSWER_PAT}" section and ' + f'as many "{QUOTE_PAT}" sections as needed to support the answer.' 
+ f"\n{GENERAL_SEP_PAT}" + f"{QUESTION_PAT} {question}\n" + f"{ANSWER_PAT}" + ) + + return prompt + + +class WeakChatModelFreeformProcessor(ChatPromptProcessor): + """Avoid using this one if the model is capable of using another prompt + Intended for models that can't follow complex instructions or have short context windows + This prompt only uses 1 reference document chunk + """ + + @property + def specifies_json_output(self) -> bool: + return False + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> list[dict[str, str]]: + first_chunk_content = chunks[0].content if chunks else "No Document Provided" + intro_msg = ( + f"You are a question answering assistant. " + f'Respond to the query with an "{ANSWER_PAT}" section and ' + f'as many "{QUOTE_PAT}" sections as needed to support the answer. ' + f"Answer the user query based on the following document:\n\n{first_chunk_content}" + ) + + messages = [{"role": "system", "content": intro_msg}] + + user_query = f"{QUESTION_PAT} {question}" + messages.append({"role": "user", "content": user_query}) + + return messages + + +# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may revisit in future + + +class FreeformProcessor(NonChatPromptProcessor): + @property + def specifies_json_output(self) -> bool: + return False + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str: + prompt = ( + f"Answer the query based on the documents below and quote the documents segments containing the answer. " + f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. ' + f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. ' + f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". ' + f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n' + ) + + for chunk in chunks: + prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}" + + prompt += "\n\n---\n\n" + prompt += f"{QUESTION_PAT}\n{question}\n" + prompt += f"{ANSWER_PAT}\n" + return prompt + + +class FreeformChatProcessor(ChatPromptProcessor): + @property + def specifies_json_output(self) -> bool: + return False + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> list[dict[str, str]]: + sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each." + role_msg = ( + f"You are a Question Answering assistant that answers queries based on provided documents. " + f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and ' + f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. ' + f"Answer the question directly and concisely. " + f"Each quote should be a single continuous segment from a document. " + f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". 
' + f"An example of quote sections may look like:\n{sample_quote}" + ) + + messages = [ + {"role": "system", "content": role_msg}, + ] + for chunk in chunks: + messages = _append_acknowledge_doc_messages(messages, chunk.content) + + messages.append( + { + "role": "user", + "content": f"Please now answer the following query based on the previously provided " + f"documents and quote the relevant sections of the documents\n{question}", + }, + ) + + return messages + + +class JsonCOTProcessor(NonChatPromptProcessor): + """Chain of Thought allows model to explain out its reasoning to handle harder tests. + This prompt type works however has higher token cost (more expensive) and is slower. + Consider this one if users ask questions that require logical reasoning.""" + + @property + def specifies_json_output(self) -> bool: + return True + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str: + prompt = ( + f"Answer the query based on provided documents and quote relevant sections. " + f'Respond with a freeform reasoning section followed by "Final Answer:" with a ' + f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n" + f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n" + f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' + ) + + for chunk in chunks: + prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}" + + prompt += "\n\n---\n\n" + prompt += f"{QUESTION_PAT}\n{question}\n" + prompt += "Reasoning:\n" + return prompt + + +class JsonReflexionProcessor(NonChatPromptProcessor): + """Reflexion prompting to attempt to have model evaluate its own answer. + This one seems largely useless when only given a single example + Model seems to take the one example of answering "Yes" and just does that too.""" + + @property + def specifies_json_output(self) -> bool: + return True + + @staticmethod + def fill_prompt( + question: str, chunks: list[InferenceChunk], include_metadata: bool = False + ) -> str: + reflexion_str = "Does this fully answer the user query?" + prompt = ( + BASE_PROMPT + + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. ' + f"If No, generate a better json response to the query.\n" + f"Sample question and response:\n" + f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n" + f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n" + f"{reflexion_str} Yes\n\n" + f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n' + ) + + for chunk in chunks: + prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}" + + prompt += "\n\n---\n\n" + prompt += f"{QUESTION_PAT}\n{question}\n" + return prompt + + +def get_json_chat_reflexion_msg() -> dict[str, str]: + """With the models tried (curent as of Jul 2023), this has not been very useful. + Have not seen any answers improved based on this. + For models like gpt-3.5-turbo, it will often answer something like: + 'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source + document. No revision is needed.'""" reflexion_content = ( "Is the assistant response a valid json that fully answer the user query? " "If the response needs to be fixed or if an improvement is possible, provide a revised json. 
" diff --git a/backend/danswer/direct_qa/qa_utils.py b/backend/danswer/direct_qa/qa_utils.py new file mode 100644 index 000000000000..4687fc6c7e79 --- /dev/null +++ b/backend/danswer/direct_qa/qa_utils.py @@ -0,0 +1,247 @@ +import json +import math +import re +from collections.abc import Generator +from typing import Any +from typing import Optional +from typing import Tuple + +import regex +from danswer.chunking.models import InferenceChunk +from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT +from danswer.configs.constants import BLURB +from danswer.configs.constants import DOCUMENT_ID +from danswer.configs.constants import SEMANTIC_IDENTIFIER +from danswer.configs.constants import SOURCE_LINK +from danswer.configs.constants import SOURCE_TYPE +from danswer.direct_qa.interfaces import DanswerAnswer +from danswer.direct_qa.interfaces import DanswerQuote +from danswer.direct_qa.qa_prompts import ANSWER_PAT +from danswer.direct_qa.qa_prompts import QUOTE_PAT +from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT +from danswer.utils.logger import setup_logger +from danswer.utils.text_processing import clean_model_quote +from danswer.utils.text_processing import shared_precompare_cleanup + +logger = setup_logger() + + +def structure_quotes_for_response( + quotes: list[DanswerQuote] | None, +) -> dict[str, dict[str, str | None]]: + if quotes is None: + return {} + + response_quotes = {} + for quote in quotes: + response_quotes[quote.quote] = { + DOCUMENT_ID: quote.document_id, + SOURCE_LINK: quote.link, + SOURCE_TYPE: quote.source_type, + SEMANTIC_IDENTIFIER: quote.semantic_identifier, + BLURB: quote.blurb, + } + return response_quotes + + +def extract_answer_quotes_freeform( + answer_raw: str, +) -> Tuple[Optional[str], Optional[list[str]]]: + """Splits the model output into an Answer and 0 or more Quote sections. 
+ Splits by the Quote pattern, if not exist then assume it's all answer and no quotes + """ + # If no answer section, don't care about the quote + if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()): + return None, None + + # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt + if answer_raw.lower().startswith(ANSWER_PAT.lower()): + answer_raw = answer_raw[len(ANSWER_PAT) :] + + # Accept quote sections starting with the lower case version + answer_raw = answer_raw.replace( + f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}" + ) # Just in case model unreliable + + sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw) + sections_clean = [ + str(section).strip() for section in sections if str(section).strip() + ] + if not sections_clean: + return None, None + + answer = str(sections_clean[0]) + if len(sections) == 1: + return answer, None + return answer, sections_clean[1:] + + +def extract_answer_quotes_json( + answer_dict: dict[str, str | list[str]] +) -> Tuple[Optional[str], Optional[list[str]]]: + answer_dict = {k.lower(): v for k, v in answer_dict.items()} + answer = str(answer_dict.get("answer")) + quotes = answer_dict.get("quotes") or answer_dict.get("quote") + if isinstance(quotes, str): + quotes = [quotes] + return answer, quotes + + +def separate_answer_quotes( + answer_raw: str, +) -> Tuple[Optional[str], Optional[list[str]]]: + try: + model_raw_json = json.loads(answer_raw) + return extract_answer_quotes_json(model_raw_json) + except ValueError: + return extract_answer_quotes_freeform(answer_raw) + + +def match_quotes_to_docs( + quotes: list[str], + chunks: list[InferenceChunk], + max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT, + fuzzy_search: bool = False, + prefix_only_length: int = 100, +) -> list[DanswerQuote]: + danswer_quotes = [] + for quote in quotes: + max_edits = math.ceil(float(len(quote)) * max_error_percent) + + for chunk in chunks: + if not chunk.source_links: + continue + + quote_clean = shared_precompare_cleanup( + clean_model_quote(quote, trim_length=prefix_only_length) + ) + chunk_clean = shared_precompare_cleanup(chunk.content) + + # Finding the offset of the quote in the plain text + if fuzzy_search: + re_search_str = ( + r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}" + ) + found = regex.search(re_search_str, chunk_clean) + if not found: + continue + offset = found.span()[0] + else: + if quote_clean not in chunk_clean: + continue + offset = chunk_clean.index(quote_clean) + + # Extracting the link from the offset + curr_link = None + for link_offset, link in chunk.source_links.items(): + # Should always find one because offset is at least 0 and there must be a 0 link_offset + if int(link_offset) <= offset: + curr_link = link + else: + danswer_quotes.append( + DanswerQuote( + quote=quote, + document_id=chunk.document_id, + link=curr_link, + source_type=chunk.source_type, + semantic_identifier=chunk.semantic_identifier, + blurb=chunk.blurb, + ) + ) + break + + # If the offset is larger than the start of the last quote, it must be the last one + danswer_quotes.append( + DanswerQuote( + quote=quote, + document_id=chunk.document_id, + link=curr_link, + source_type=chunk.source_type, + semantic_identifier=chunk.semantic_identifier, + blurb=chunk.blurb, + ) + ) + break + + return danswer_quotes + + +def process_answer( + answer_raw: str, chunks: list[InferenceChunk] +) -> tuple[DanswerAnswer, list[DanswerQuote]]: + answer, quote_strings = separate_answer_quotes(answer_raw) + if answer == UNCERTAINTY_PAT 
or not answer: + if answer == UNCERTAINTY_PAT: + logger.debug("Answer matched UNCERTAINTY_PAT") + else: + logger.debug("No answer extracted from raw output") + return DanswerAnswer(answer=None), [] + + logger.info(f"Answer: {answer}") + if not quote_strings: + logger.debug("No quotes extracted from raw output") + return DanswerAnswer(answer=answer), [] + logger.info(f"All quotes (including unmatched): {quote_strings}") + quotes = match_quotes_to_docs(quote_strings, chunks) + logger.info(f"Final quotes: {quotes}") + + return DanswerAnswer(answer=answer), quotes + + +def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool: + next_token = next_token.replace('\\"', "") + # If the previous character is an escape token, don't consider the first character of next_token + if answer_so_far and answer_so_far[-1] == "\\": + next_token = next_token[1:] + if '"' in next_token: + return True + return False + + +def extract_quotes_from_completed_token_stream( + model_output: str, context_chunks: list[InferenceChunk] +) -> list[DanswerQuote]: + logger.debug(model_output) + answer, quotes = process_answer(model_output, context_chunks) + if answer: + logger.info(answer) + elif model_output: + logger.warning("Answer extraction from model output failed.") + + return quotes + + +def process_model_tokens( + tokens: Generator[str, None, None], + context_docs: list[InferenceChunk], + is_json_prompt: bool = True, +) -> Generator[dict[str, Any], None, None]: + """Yields Answer tokens back out in a dict for streaming to frontend + When Answer section ends, yields dict with answer_finished key + Collects all the tokens at the end to form the complete model output""" + model_output: str = "" + found_answer_start = False if is_json_prompt else True + found_answer_end = False + for token in tokens: + model_previous = model_output + model_output += token + + trimmed_combine = model_output.replace(" ", "").replace("\n", "") + if not found_answer_start and '{"answer":"' in trimmed_combine: + # Note, if the token that completes the pattern has additional text, for example if the token is "? + # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? 
in the + # event that the model outputs the UNCERTAINTY_PAT + found_answer_start = True + continue + + if found_answer_start and not found_answer_end: + if (is_json_prompt and stream_json_answer_end(model_previous, token)) or ( + not is_json_prompt and f"\n{QUOTE_PAT}" in model_output + ): + found_answer_end = True + yield {"answer_finished": True} + continue + yield {"answer_data": token} + + quotes = extract_quotes_from_completed_token_stream(model_output, context_docs) + yield structure_quotes_for_response(quotes) diff --git a/backend/danswer/listeners/slack_listener.py b/backend/danswer/listeners/slack_listener.py index e9a02c2ced25..360b2b91cae7 100644 --- a/backend/danswer/listeners/slack_listener.py +++ b/backend/danswer/listeners/slack_listener.py @@ -58,7 +58,7 @@ def _build_custom_semantic_identifier( def _process_quotes( - quotes: dict[str, dict[str, str | int | None]] | None + quotes: dict[str, dict[str, str | None]] | None ) -> tuple[str | None, list[str]]: if not quotes: return None, [] diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 2332582634ce..dc003a6b9f06 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -9,19 +9,21 @@ from danswer.auth.users import google_oauth_client from danswer.configs.app_configs import APP_HOST from danswer.configs.app_configs import APP_PORT from danswer.configs.app_configs import DISABLE_AUTH +from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import ENABLE_OAUTH from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET +from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION from danswer.configs.app_configs import SECRET from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION from danswer.configs.app_configs import WEB_DOMAIN +from danswer.configs.model_configs import GEN_AI_MODEL_VERSION +from danswer.configs.model_configs import INTERNAL_MODEL_VERSION from danswer.datastores.qdrant.indexing import list_qdrant_collections from danswer.datastores.typesense.store import check_typesense_collection_exist from danswer.datastores.typesense.store import create_typesense_collection from danswer.db.credentials import create_initial_public_credential -from danswer.direct_qa.key_validation import check_openai_api_key_is_valid -from danswer.direct_qa.llm import get_openai_api_key -from danswer.dynamic_configs.interface import ConfigNotFoundError +from danswer.direct_qa import get_default_backend_qa_model from danswer.server.event_loading import router as event_processing_router from danswer.server.health import router as health_router from danswer.server.manage import router as admin_router @@ -121,21 +123,29 @@ def get_application() -> FastAPI: warm_up_models, ) from danswer.datastores.qdrant.indexing import create_qdrant_collection - from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION + + if DISABLE_GENERATIVE_AI: + logger.info("Generative AI Q&A disabled") + else: + logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}") + logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}") auth_status = "off" if DISABLE_AUTH else "on" - logger.info(f"User auth is turned {auth_status}") + logger.info(f"User Authentication is turned {auth_status}") - if not ENABLE_OAUTH: - logger.debug("OAuth is turned off") - else: - if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET: - logger.warning("OAuth is turned on but incorrectly configured") + if not 
DISABLE_AUTH: + if not ENABLE_OAUTH: + logger.warning("OAuth is turned off") else: - logger.debug("OAuth is turned on") + if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET: + logger.warning("OAuth is turned on but incorrectly configured") + else: + logger.debug("OAuth is turned on") logger.info("Warming up local NLP models.") warm_up_models() + qa_model = get_default_backend_qa_model() + qa_model.warm_up_model() logger.info("Verifying query preprocessing (NLTK) data is downloaded") nltk.download("stopwords") diff --git a/backend/danswer/server/event_loading.py b/backend/danswer/server/event_loading.py index c27051237035..906de97df861 100644 --- a/backend/danswer/server/event_loading.py +++ b/backend/danswer/server/event_loading.py @@ -1,9 +1,5 @@ from typing import Any -from danswer.connectors.slack.connector import get_channel_info -from danswer.connectors.slack.connector import get_thread -from danswer.connectors.slack.connector import thread_to_doc -from danswer.utils.indexing_pipeline import build_indexing_pipeline from danswer.utils.logger import setup_logger from fastapi import APIRouter from pydantic import BaseModel diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index 74edf241b1d8..9594095b6ede 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -38,8 +38,10 @@ from danswer.db.engine import get_session from danswer.db.engine import get_sqlalchemy_async_engine from danswer.db.index_attempt import create_index_attempt from danswer.db.models import User -from danswer.direct_qa.key_validation import check_openai_api_key_is_valid -from danswer.direct_qa.llm import get_openai_api_key +from danswer.direct_qa import check_model_api_key_is_valid +from danswer.direct_qa import get_default_backend_qa_model +from danswer.direct_qa.open_ai import get_openai_api_key +from danswer.direct_qa.open_ai import OpenAIQAModel from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.server.models import ApiKey @@ -102,7 +104,7 @@ def check_google_app_credentials_exist( ) -> dict[str, str]: try: return {"client_id": get_google_app_cred().web.client_id} - except ConfigNotFoundError as e: + except ConfigNotFoundError: raise HTTPException(status_code=404, detail="Google App Credentials not found") @@ -295,19 +297,13 @@ def validate_existing_openai_api_key( _: User = Depends(current_admin_user), ) -> None: # OpenAI key is only used for generative QA, so no need to validate this - # if it's turned off - if DISABLE_GENERATIVE_AI: + # if it's turned off or if a non-OpenAI model is being used + if DISABLE_GENERATIVE_AI or not isinstance( + get_default_backend_qa_model(), OpenAIQAModel + ): return - # always check if key exists - try: - openai_api_key = get_openai_api_key() - except ConfigNotFoundError: - raise HTTPException(status_code=404, detail="Key not found") - except ValueError as e: - raise HTTPException(status_code=400, detail=str(e)) - - # don't call OpenAI every single time, only validate every so often + # Only validate every so often check_key_time = "openai_api_key_last_check_time" kv_store = get_dynamic_config_store() curr_time = datetime.now() @@ -320,10 +316,17 @@ def validate_existing_openai_api_key( # First time checking the key, nothing unusual pass + try: + openai_api_key = get_openai_api_key() + except ConfigNotFoundError: + raise HTTPException(status_code=404, detail="Key not found") + except ValueError as e: + raise 
HTTPException(status_code=400, detail=str(e)) + get_dynamic_config_store().store(check_key_time, curr_time.timestamp()) try: - is_valid = check_openai_api_key_is_valid(openai_api_key) + is_valid = check_model_api_key_is_valid(openai_api_key) except ValueError: # this is the case where they aren't using an OpenAI-based model is_valid = True @@ -356,7 +359,7 @@ def store_openai_api_key( _: User = Depends(current_admin_user), ) -> None: try: - is_valid = check_openai_api_key_is_valid(request.api_key) + is_valid = check_model_api_key_is_valid(request.api_key) if not is_valid: raise HTTPException(400, "Invalid API key provided") get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key) diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py index 2e0941f2258a..e18d27f8569e 100644 --- a/backend/danswer/server/models.py +++ b/backend/danswer/server/models.py @@ -102,8 +102,8 @@ class SearchResponse(BaseModel): class QAResponse(SearchResponse): - answer: str | None - quotes: dict[str, dict[str, str | int | None]] | None + answer: str | None # DanswerAnswer + quotes: dict[str, dict[str, str | None]] | None # restructured DanswerQuote predicted_flow: QueryFlow predicted_search: SearchType error_msg: str | None = None diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index ebf770b55b93..65316f137b1c 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -1,10 +1,10 @@ +import json from collections.abc import Generator from danswer.auth.users import current_user from danswer.chunking.models import InferenceChunk from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS -from danswer.configs.app_configs import QA_TIMEOUT from danswer.datastores.qdrant.store import QdrantIndex from danswer.datastores.typesense.store import TypesenseIndex from danswer.db.models import User @@ -12,7 +12,6 @@ from danswer.direct_qa import get_default_backend_qa_model from danswer.direct_qa.answer_question import answer_question from danswer.direct_qa.exceptions import OpenAIKeyMissing from danswer.direct_qa.exceptions import UnknownModelError -from danswer.direct_qa.llm import get_json_line from danswer.search.danswer_helper import query_intent from danswer.search.danswer_helper import recommend_search_flow from danswer.search.keyword_search import retrieve_keyword_documents @@ -35,6 +34,10 @@ logger = setup_logger() router = APIRouter() +def get_json_line(json_dict: dict) -> str: + return json.dumps(json_dict) + "\n" + + @router.get("/search-intent") def get_search_type( question: QuestionRequest = Depends(), _: User = Depends(current_user) @@ -162,7 +165,7 @@ def stream_direct_qa( return try: - qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT) + qa_model = get_default_backend_qa_model() except (UnknownModelError, OpenAIKeyMissing) as e: logger.exception("Unable to get QA model") yield get_json_line({"error": str(e)}) diff --git a/backend/requirements/default.txt b/backend/requirements/default.txt index 34d19f7972db..4ffd03249e2e 100644 --- a/backend/requirements/default.txt +++ b/backend/requirements/default.txt @@ -10,6 +10,7 @@ filelock==3.12.0 google-api-python-client==2.86.0 google-auth-httplib2==0.1.0 google-auth-oauthlib==1.0.0 +gpt4all==1.0.5 httpcore==0.16.3 httpx==0.23.3 httpx-oauth==0.11.2 diff --git a/backend/tests/unit/qa_service/direct_qa/test_question_answer.py 
b/backend/tests/unit/qa_service/direct_qa/test_question_answer.py
index 0249c82b4ed3..ecd5cd05a0b3 100644
--- a/backend/tests/unit/qa_service/direct_qa/test_question_answer.py
+++ b/backend/tests/unit/qa_service/direct_qa/test_question_answer.py
@@ -2,8 +2,8 @@ import textwrap
 import unittest
 
 from danswer.chunking.models import InferenceChunk
-from danswer.direct_qa.llm import match_quotes_to_docs
-from danswer.direct_qa.llm import separate_answer_quotes
+from danswer.direct_qa.qa_utils import match_quotes_to_docs
+from danswer.direct_qa.qa_utils import separate_answer_quotes
 
 
 class TestQAPostprocessing(unittest.TestCase):
diff --git a/deployment/docker_compose/docker-compose.dev.yml b/deployment/docker_compose/docker-compose.dev.yml
index 17771a4a05b2..b8c1738f1c18 100644
--- a/deployment/docker_compose/docker-compose.dev.yml
+++ b/deployment/docker_compose/docker-compose.dev.yml
@@ -17,6 +17,8 @@ services:
     ports:
       - "8080:8080"
     environment:
+      - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
+      - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
       - POSTGRES_HOST=relational_db
       - QDRANT_HOST=vector_db
       - TYPESENSE_HOST=search_engine
diff --git a/deployment/docker_compose/env.prod.template b/deployment/docker_compose/env.prod.template
index 034c7045cbfc..21b6cab43f7a 100644
--- a/deployment/docker_compose/env.prod.template
+++ b/deployment/docker_compose/env.prod.template
@@ -9,7 +9,7 @@ OPENAI_API_KEY=
 # Choose between "openai-chat-completion" and "openai-completion"
 INTERNAL_MODEL_VERSION=openai-chat-completion
 # Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
-OPENAI_MODEL_VERSION=gpt-4
+GEN_AI_MODEL_VERSION=gpt-4
 
 # Could be something like danswer.companyname.com. Requires additional setup if not localhost
 WEB_DOMAIN=http://localhost:3000
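
Editor's note: the qa_utils.py helpers introduced in this diff are the post-processing core of the change. separate_answer_quotes first tries to parse the raw model output as JSON and only falls back to splitting on the ANSWER_PAT/QUOTE_PAT markers from qa_prompts, and process_answer then matches any extracted quotes back to the retrieved chunks. Below is a minimal, illustrative sketch of how these helpers behave; the sample model outputs are invented, and passing an empty chunk list simply skips quote-to-document matching.

from danswer.direct_qa.qa_prompts import ANSWER_PAT, QUOTE_PAT
from danswer.direct_qa.qa_utils import process_answer, separate_answer_quotes

# JSON-style output, as requested by the json prompt processors
json_output = (
    '{"answer": "Hot dogs cost one dollar.",'
    ' "quotes": ["They are very cheap at only a dollar each."]}'
)
answer, quote_strings = separate_answer_quotes(json_output)
# answer == "Hot dogs cost one dollar."
# quote_strings == ["They are very cheap at only a dollar each."]

# Non-JSON output falls back to splitting on the answer/quote patterns
freeform_output = (
    f"{ANSWER_PAT} Hot dogs cost one dollar.\n"
    f"{QUOTE_PAT} They are very cheap at only a dollar each."
)
answer, quote_strings = separate_answer_quotes(freeform_output)

# process_answer additionally maps quotes back to source chunks; with no
# retrieved chunks there is nothing to match, so the quote list comes back empty
danswer_answer, danswer_quotes = process_answer(json_output, chunks=[])
# danswer_answer.answer == "Hot dogs cost one dollar.", danswer_quotes == []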