Repository: https://github.com/danswer-ai/danswer.git
Commit: Support GPT4All in memory (#230)
@@ -4,6 +4,7 @@ from danswer.connectors.factory import instantiate_connector
 from danswer.connectors.interfaces import LoadConnector
 from danswer.connectors.interfaces import PollConnector
 from danswer.connectors.models import InputType
+from danswer.datastores.indexing_pipeline import build_indexing_pipeline
 from danswer.db.connector import disable_connector
 from danswer.db.connector import fetch_connectors
 from danswer.db.connector_credential_pair import update_connector_credential_pair
@@ -21,7 +22,6 @@ from danswer.db.index_attempt import mark_attempt_succeeded
 from danswer.db.models import Connector
 from danswer.db.models import IndexAttempt
 from danswer.db.models import IndexingStatus
-from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import setup_logger
 from sqlalchemy.orm import Session
 
@@ -31,10 +31,16 @@ CROSS_EMBED_CONTEXT_SIZE = 512
 BATCH_SIZE_ENCODE_CHUNKS = 8
 
 # QA Model API Configs
-# https://platform.openai.com/docs/models/model-endpoint-compatibility
+# refer to https://platform.openai.com/docs/models/model-endpoint-compatibility for OpenAI models
+# Valid list:
+# - openai-completion
+# - openai-chat-completion
+# - gpt4all-completion
+# - gpt4all-chat-completion
 INTERNAL_MODEL_VERSION = os.environ.get("INTERNAL_MODEL", "openai-chat-completion")
-OPENAI_MODEL_VERSION = os.environ.get("OPENAI_MODEL_VERSION", "gpt-3.5-turbo")
-OPENAI_MAX_OUTPUT_TOKENS = 512
+# For GPT4ALL, use "ggml-model-gpt4all-falcon-q4_0.bin" for the below for a tested model
+GEN_AI_MODEL_VERSION = os.environ.get("GEN_AI_MODEL_VERSION", "gpt-3.5-turbo")
+GEN_AI_MAX_OUTPUT_TOKENS = 512
 
 # Danswer custom Deep Learning Models
 INTENT_MODEL_VERSION = "danswer/intent-model"
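The new knobs can be exercised without code changes. A minimal sketch (not part of the commit) of selecting the in-memory GPT4All backend, assuming the variables are set before danswer's config module is imported; the model file name is the tested example from the comment above:

    import os

    # Assumption: set these before importing danswer so the defaults above are overridden.
    os.environ["INTERNAL_MODEL"] = "gpt4all-chat-completion"
    os.environ["GEN_AI_MODEL_VERSION"] = "ggml-model-gpt4all-falcon-q4_0.bin"

    from danswer.configs.model_configs import GEN_AI_MODEL_VERSION, INTERNAL_MODEL_VERSION

    print(INTERNAL_MODEL_VERSION)  # gpt4all-chat-completion
    print(GEN_AI_MODEL_VERSION)    # ggml-model-gpt4all-falcon-q4_0.bin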
@@ -1,18 +1,49 @@
 from typing import Any
 
+from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.gpt_4_all import GPT4AllChatCompletionQA
+from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA
 from danswer.direct_qa.interfaces import QAModel
-from danswer.direct_qa.llm import OpenAIChatCompletionQA
-from danswer.direct_qa.llm import OpenAICompletionQA
+from danswer.direct_qa.open_ai import OpenAIChatCompletionQA
+from danswer.direct_qa.open_ai import OpenAICompletionQA
+from openai.error import AuthenticationError
+from openai.error import Timeout
 
 
+def check_model_api_key_is_valid(model_api_key: str) -> bool:
+    if not model_api_key:
+        return False
+
+    qa_model = get_default_backend_qa_model(api_key=model_api_key, timeout=5)
+
+    # try for up to 2 timeouts (e.g. 10 seconds in total)
+    for _ in range(2):
+        try:
+            qa_model.answer_question("Do not respond", [])
+            return True
+        except AuthenticationError:
+            return False
+        except Timeout:
+            pass
+
+    return False
+
+
 def get_default_backend_qa_model(
-    internal_model: str = INTERNAL_MODEL_VERSION, **kwargs: Any
+    internal_model: str = INTERNAL_MODEL_VERSION,
+    api_key: str | None = None,
+    timeout: int = QA_TIMEOUT,
+    **kwargs: Any
 ) -> QAModel:
     if internal_model == "openai-completion":
-        return OpenAICompletionQA(**kwargs)
+        return OpenAICompletionQA(timeout=timeout, api_key=api_key, **kwargs)
     elif internal_model == "openai-chat-completion":
-        return OpenAIChatCompletionQA(**kwargs)
+        return OpenAIChatCompletionQA(timeout=timeout, api_key=api_key, **kwargs)
+    elif internal_model == "gpt4all-completion":
+        return GPT4AllCompletionQA(**kwargs)
+    elif internal_model == "gpt4all-chat-completion":
+        return GPT4AllChatCompletionQA(**kwargs)
     else:
         raise UnknownModelError(internal_model)
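For illustration only (not from the commit): how a caller might use the reworked factory and the new key check, assuming a GPT4All backend has been configured.

    from danswer.direct_qa import check_model_api_key_is_valid, get_default_backend_qa_model

    # Pick one of the valid INTERNAL_MODEL values listed in model_configs.
    qa_model = get_default_backend_qa_model(internal_model="gpt4all-completion")
    qa_model.warm_up_model()  # loads the local GPT4All weights once

    # Only meaningful for the OpenAI backends; "sk-..." is a placeholder key.
    if not check_model_api_key_is_valid("sk-..."):
        print("key rejected or timed out")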
@@ -1,13 +1,13 @@
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
-from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
 from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.qa_utils import structure_quotes_for_response
 from danswer.search.danswer_helper import query_intent
 from danswer.search.keyword_search import retrieve_keyword_documents
 from danswer.search.models import QueryFlow
@@ -26,7 +26,6 @@ logger = setup_logger()
 def answer_question(
     question: QuestionRequest,
     user: User | None,
-    qa_model_timeout: int = QA_TIMEOUT,
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
 ) -> QAResponse:
     query = question.query
@@ -74,7 +73,7 @@ def answer_question(
     )
 
     try:
-        qa_model = get_default_backend_qa_model(timeout=qa_model_timeout)
+        qa_model = get_default_backend_qa_model()
     except (UnknownModelError, OpenAIKeyMissing) as e:
         return QAResponse(
             answer=None,
@@ -102,8 +101,8 @@ def answer_question(
         error_msg = f"Error occurred in call to LLM - {e}"
 
     return QAResponse(
-        answer=answer,
-        quotes=quotes,
+        answer=answer.answer if answer else None,
+        quotes=structure_quotes_for_response(quotes),
         top_ranked_docs=chunks_to_search_docs(ranked_chunks),
         lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
         predicted_flow=predicted_flow,
@@ -1,5 +1,10 @@
 class OpenAIKeyMissing(Exception):
-    def __init__(self, msg: str = "Unable to find an OpenAI Key") -> None:
+    default_msg = (
+        "Unable to find existing OpenAI Key. "
+        'A new key can be added from "Keys" section of the Admin Panel'
+    )
+
+    def __init__(self, msg: str = default_msg) -> None:
         super().__init__(msg)
 
 
backend/danswer/direct_qa/gpt_4_all.py (new file, 173 lines)
@@ -0,0 +1,173 @@
from collections.abc import Generator
from typing import Any

from danswer.chunking.models import InferenceChunk
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_prompts import WeakChatModelFreeformProcessor
from danswer.direct_qa.qa_prompts import WeakModelFreeformProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from gpt4all import GPT4All  # type:ignore


logger = setup_logger()

GPT4ALL_MODEL: GPT4All | None = None


def get_gpt_4_all_model(
    model_version: str = GEN_AI_MODEL_VERSION,
) -> GPT4All:
    global GPT4ALL_MODEL
    if GPT4ALL_MODEL is None:
        GPT4ALL_MODEL = GPT4All(model_version)
    return GPT4ALL_MODEL


def _build_gpt4all_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temp": 0,
        **kwargs,
    }


class GPT4AllCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = WeakModelFreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,  # gpt4all models can't handle this atm
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    def warm_up_model(self) -> None:
        get_gpt_4_all_model(self.model_version)

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        model_output = gen_ai_model.generate(
            **_build_gpt4all_settings(
                prompt=filled_prompt, max_tokens=self.max_output_tokens
            ),
        )

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        model_stream = gen_ai_model.generate(
            **_build_gpt4all_settings(
                prompt=filled_prompt, max_tokens=self.max_output_tokens, streaming=True
            ),
        )

        yield from process_model_tokens(
            tokens=model_stream,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class GPT4AllChatCompletionQA(QAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = WeakChatModelFreeformProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        include_metadata: bool = False,  # gpt4all models can't handle this atm
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.include_metadata = include_metadata

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        with gen_ai_model.chat_session():
            context_msgs = filled_prompt[:-1]
            user_query = filled_prompt[-1].get("content")
            for message in context_msgs:
                gen_ai_model.current_chat_session.append(message)

            model_output = gen_ai_model.generate(
                **_build_gpt4all_settings(
                    prompt=user_query, max_tokens=self.max_output_tokens
                ),
            )

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        gen_ai_model = get_gpt_4_all_model(self.model_version)

        with gen_ai_model.chat_session():
            context_msgs = filled_prompt[:-1]
            user_query = filled_prompt[-1].get("content")
            for message in context_msgs:
                gen_ai_model.current_chat_session.append(message)

            model_stream = gen_ai_model.generate(
                **_build_gpt4all_settings(
                    prompt=user_query, max_tokens=self.max_output_tokens
                ),
            )

            yield from process_model_tokens(
                tokens=model_stream,
                context_docs=context_docs,
                is_json_prompt=self.prompt_processor.specifies_json_output,
            )
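A possible direct use of the new in-memory backend (sketch only, not part of the commit; assumes the GPT4All weights are available locally or can be downloaded by the gpt4all package):

    from danswer.direct_qa.gpt_4_all import GPT4AllCompletionQA

    qa = GPT4AllCompletionQA(model_version="ggml-model-gpt4all-falcon-q4_0.bin")
    qa.warm_up_model()  # populates the module-level GPT4ALL_MODEL singleton
    answer, quotes = qa.answer_question("Where is the Eiffel Tower?", context_docs=[])
    print(answer.answer)  # DanswerAnswer.answer; may be None if nothing could be extracted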
@@ -1,17 +1,39 @@
 import abc
 from collections.abc import Generator
+from dataclasses import dataclass
 from typing import Any
 
 from danswer.chunking.models import InferenceChunk
+
+
+@dataclass
+class DanswerAnswer:
+    answer: str | None
+
+
+@dataclass
+class DanswerQuote:
+    # This is during inference so everything is a string by this point
+    quote: str
+    document_id: str
+    link: str | None
+    source_type: str
+    semantic_identifier: str
+    blurb: str
 
 
 class QAModel:
+    def warm_up_model(self) -> None:
+        """This is called during server start up to load the models into memory
+        pass if model is accessed via API"""
+        pass
+
     @abc.abstractmethod
     def answer_question(
         self,
         query: str,
         context_docs: list[InferenceChunk],
-    ) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
+    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
         raise NotImplementedError
 
     @abc.abstractmethod
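Purely illustrative: constructing the response dataclasses introduced above (all field values are placeholders).

    from danswer.direct_qa.interfaces import DanswerAnswer, DanswerQuote

    answer = DanswerAnswer(answer="The Eiffel Tower is in Paris.")
    quote = DanswerQuote(
        quote="The Eiffel Tower is in Paris.",
        document_id="doc-1",
        link=None,
        source_type="web",
        semantic_identifier="Eiffel Tower",
        blurb="Introductory paragraph about the tower",
    )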
@@ -1,25 +0,0 @@ (file deleted)
from danswer.direct_qa import get_default_backend_qa_model
from danswer.direct_qa.llm import OpenAIQAModel
from openai.error import AuthenticationError
from openai.error import Timeout


def check_openai_api_key_is_valid(openai_api_key: str) -> bool:
    if not openai_api_key:
        return False

    qa_model = get_default_backend_qa_model(api_key=openai_api_key, timeout=5)
    if not isinstance(qa_model, OpenAIQAModel):
        raise ValueError("Cannot check OpenAI API key validity for non-OpenAI QA model")

    # try for up to 2 timeouts (e.g. 10 seconds in total)
    for _ in range(2):
        try:
            qa_model.answer_question("Do not respond", [])
            return True
        except AuthenticationError:
            return False
        except Timeout:
            pass

    return False
@@ -1,488 +0,0 @@
|
|||||||
import json
|
|
||||||
import math
|
|
||||||
import re
|
|
||||||
from collections.abc import Callable
|
|
||||||
from collections.abc import Generator
|
|
||||||
from functools import wraps
|
|
||||||
from typing import Any
|
|
||||||
from typing import cast
|
|
||||||
from typing import Dict
|
|
||||||
from typing import Literal
|
|
||||||
from typing import Optional
|
|
||||||
from typing import Tuple
|
|
||||||
from typing import TypeVar
|
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import openai
|
|
||||||
import regex
|
|
||||||
from danswer.chunking.models import InferenceChunk
|
|
||||||
from danswer.configs.app_configs import INCLUDE_METADATA
|
|
||||||
from danswer.configs.app_configs import OPENAI_API_KEY
|
|
||||||
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
|
|
||||||
from danswer.configs.constants import BLURB
|
|
||||||
from danswer.configs.constants import DOCUMENT_ID
|
|
||||||
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
|
|
||||||
from danswer.configs.constants import SEMANTIC_IDENTIFIER
|
|
||||||
from danswer.configs.constants import SOURCE_LINK
|
|
||||||
from danswer.configs.constants import SOURCE_TYPE
|
|
||||||
from danswer.configs.model_configs import OPENAI_MAX_OUTPUT_TOKENS
|
|
||||||
from danswer.configs.model_configs import OPENAI_MODEL_VERSION
|
|
||||||
from danswer.direct_qa.exceptions import OpenAIKeyMissing
|
|
||||||
from danswer.direct_qa.interfaces import QAModel
|
|
||||||
from danswer.direct_qa.qa_prompts import ANSWER_PAT
|
|
||||||
from danswer.direct_qa.qa_prompts import get_chat_reflexion_msg
|
|
||||||
from danswer.direct_qa.qa_prompts import json_chat_processor
|
|
||||||
from danswer.direct_qa.qa_prompts import json_processor
|
|
||||||
from danswer.direct_qa.qa_prompts import QUOTE_PAT
|
|
||||||
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
|
|
||||||
from danswer.dynamic_configs import get_dynamic_config_store
|
|
||||||
from danswer.dynamic_configs.interface import ConfigNotFoundError
|
|
||||||
from danswer.utils.logger import setup_logger
|
|
||||||
from danswer.utils.text_processing import clean_model_quote
|
|
||||||
from danswer.utils.text_processing import shared_precompare_cleanup
|
|
||||||
from danswer.utils.timing import log_function_time
|
|
||||||
from openai.error import AuthenticationError
|
|
||||||
from openai.error import Timeout
|
|
||||||
|
|
||||||
|
|
||||||
logger = setup_logger()
|
|
||||||
|
|
||||||
|
|
||||||
def get_openai_api_key() -> str:
|
|
||||||
return OPENAI_API_KEY or cast(
|
|
||||||
str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
|
|
||||||
)
|
|
||||||
|
|
||||||
|
|
||||||
def get_json_line(json_dict: dict) -> str:
|
|
||||||
return json.dumps(json_dict) + "\n"
|
|
||||||
|
|
||||||
|
|
||||||
def extract_answer_quotes_freeform(
|
|
||||||
answer_raw: str,
|
|
||||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
|
||||||
null_answer_check = (
|
|
||||||
answer_raw.replace(ANSWER_PAT, "").replace(QUOTE_PAT, "").strip()
|
|
||||||
)
|
|
||||||
|
|
||||||
# If model just gives back the uncertainty pattern to signify answer isn't found or nothing at all
|
|
||||||
# if null_answer_check == UNCERTAINTY_PAT or not null_answer_check:
|
|
||||||
# return None, None
|
|
||||||
|
|
||||||
# If no answer section, don't care about the quote
|
|
||||||
if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
# Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
|
|
||||||
if answer_raw.lower().startswith(ANSWER_PAT.lower()):
|
|
||||||
answer_raw = answer_raw[len(ANSWER_PAT) :]
|
|
||||||
|
|
||||||
# Accept quote sections starting with the lower case version
|
|
||||||
answer_raw = answer_raw.replace(
|
|
||||||
f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
|
|
||||||
) # Just in case model unreliable
|
|
||||||
|
|
||||||
sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
|
|
||||||
sections_clean = [
|
|
||||||
str(section).strip() for section in sections if str(section).strip()
|
|
||||||
]
|
|
||||||
if not sections_clean:
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
answer = str(sections_clean[0])
|
|
||||||
if len(sections) == 1:
|
|
||||||
return answer, None
|
|
||||||
return answer, sections_clean[1:]
|
|
||||||
|
|
||||||
|
|
||||||
def extract_answer_quotes_json(
|
|
||||||
answer_dict: dict[str, str | list[str]]
|
|
||||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
|
||||||
answer_dict = {k.lower(): v for k, v in answer_dict.items()}
|
|
||||||
answer = str(answer_dict.get("answer"))
|
|
||||||
quotes = answer_dict.get("quotes") or answer_dict.get("quote")
|
|
||||||
if isinstance(quotes, str):
|
|
||||||
quotes = [quotes]
|
|
||||||
return answer, quotes
|
|
||||||
|
|
||||||
|
|
||||||
def separate_answer_quotes(
|
|
||||||
answer_raw: str,
|
|
||||||
) -> Tuple[Optional[str], Optional[list[str]]]:
|
|
||||||
try:
|
|
||||||
model_raw_json = json.loads(answer_raw)
|
|
||||||
return extract_answer_quotes_json(model_raw_json)
|
|
||||||
except ValueError:
|
|
||||||
return extract_answer_quotes_freeform(answer_raw)
|
|
||||||
|
|
||||||
|
|
||||||
def match_quotes_to_docs(
|
|
||||||
quotes: list[str],
|
|
||||||
chunks: list[InferenceChunk],
|
|
||||||
max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
|
|
||||||
fuzzy_search: bool = False,
|
|
||||||
prefix_only_length: int = 100,
|
|
||||||
) -> Dict[str, Dict[str, Union[str, int, None]]]:
|
|
||||||
quotes_dict: dict[str, dict[str, Union[str, int, None]]] = {}
|
|
||||||
for quote in quotes:
|
|
||||||
max_edits = math.ceil(float(len(quote)) * max_error_percent)
|
|
||||||
|
|
||||||
for chunk in chunks:
|
|
||||||
if not chunk.source_links:
|
|
||||||
continue
|
|
||||||
|
|
||||||
quote_clean = shared_precompare_cleanup(
|
|
||||||
clean_model_quote(quote, trim_length=prefix_only_length)
|
|
||||||
)
|
|
||||||
chunk_clean = shared_precompare_cleanup(chunk.content)
|
|
||||||
|
|
||||||
# Finding the offset of the quote in the plain text
|
|
||||||
if fuzzy_search:
|
|
||||||
re_search_str = (
|
|
||||||
r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
|
|
||||||
)
|
|
||||||
found = regex.search(re_search_str, chunk_clean)
|
|
||||||
if not found:
|
|
||||||
continue
|
|
||||||
offset = found.span()[0]
|
|
||||||
else:
|
|
||||||
if quote_clean not in chunk_clean:
|
|
||||||
continue
|
|
||||||
offset = chunk_clean.index(quote_clean)
|
|
||||||
|
|
||||||
# Extracting the link from the offset
|
|
||||||
curr_link = None
|
|
||||||
for link_offset, link in chunk.source_links.items():
|
|
||||||
# Should always find one because offset is at least 0 and there must be a 0 link_offset
|
|
||||||
if int(link_offset) <= offset:
|
|
||||||
curr_link = link
|
|
||||||
else:
|
|
||||||
quotes_dict[quote] = {
|
|
||||||
DOCUMENT_ID: chunk.document_id,
|
|
||||||
SOURCE_LINK: curr_link,
|
|
||||||
SOURCE_TYPE: chunk.source_type,
|
|
||||||
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
|
|
||||||
BLURB: chunk.blurb,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
quotes_dict[quote] = {
|
|
||||||
DOCUMENT_ID: chunk.document_id,
|
|
||||||
SOURCE_LINK: curr_link,
|
|
||||||
SOURCE_TYPE: chunk.source_type,
|
|
||||||
SEMANTIC_IDENTIFIER: chunk.semantic_identifier,
|
|
||||||
BLURB: chunk.blurb,
|
|
||||||
}
|
|
||||||
break
|
|
||||||
return quotes_dict
|
|
||||||
|
|
||||||
|
|
||||||
def process_answer(
|
|
||||||
answer_raw: str, chunks: list[InferenceChunk]
|
|
||||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
|
||||||
answer, quote_strings = separate_answer_quotes(answer_raw)
|
|
||||||
if answer == UNCERTAINTY_PAT or not answer:
|
|
||||||
if answer == UNCERTAINTY_PAT:
|
|
||||||
logger.debug("Answer matched UNCERTAINTY_PAT")
|
|
||||||
else:
|
|
||||||
logger.debug("No answer extracted from raw output")
|
|
||||||
return None, None
|
|
||||||
|
|
||||||
logger.info(f"Answer: {answer}")
|
|
||||||
if not quote_strings:
|
|
||||||
logger.debug("No quotes extracted from raw output")
|
|
||||||
return answer, None
|
|
||||||
logger.info(f"All quotes (including unmatched): {quote_strings}")
|
|
||||||
quotes_dict = match_quotes_to_docs(quote_strings, chunks)
|
|
||||||
logger.info(f"Final quotes dict: {quotes_dict}")
|
|
||||||
|
|
||||||
return answer, quotes_dict
|
|
||||||
|
|
||||||
|
|
||||||
def stream_answer_end(answer_so_far: str, next_token: str) -> bool:
|
|
||||||
next_token = next_token.replace('\\"', "")
|
|
||||||
# If the previous character is an escape token, don't consider the first character of next_token
|
|
||||||
if answer_so_far and answer_so_far[-1] == "\\":
|
|
||||||
next_token = next_token[1:]
|
|
||||||
if '"' in next_token:
|
|
||||||
return True
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
F = TypeVar("F", bound=Callable)
|
|
||||||
|
|
||||||
Answer = str
|
|
||||||
Quotes = dict[str, dict[str, str | int | None]]
|
|
||||||
ModelType = Literal["ChatCompletion", "Completion"]
|
|
||||||
PromptProcessor = Callable[[str, list[str]], str]
|
|
||||||
|
|
||||||
|
|
||||||
def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
|
|
||||||
"""
|
|
||||||
Utility to add in some common default values so they don't have to be set every time.
|
|
||||||
"""
|
|
||||||
return {
|
|
||||||
"temperature": 0,
|
|
||||||
"top_p": 1,
|
|
||||||
"frequency_penalty": 0,
|
|
||||||
"presence_penalty": 0,
|
|
||||||
**kwargs,
|
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
|
|
||||||
@wraps(openai_call)
|
|
||||||
def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
|
|
||||||
try:
|
|
||||||
# if streamed, the call returns a generator
|
|
||||||
if kwargs.get("stream"):
|
|
||||||
|
|
||||||
def _generator() -> Generator[Any, None, None]:
|
|
||||||
yield from openai_call(*args, **kwargs)
|
|
||||||
|
|
||||||
return _generator()
|
|
||||||
return openai_call(*args, **kwargs)
|
|
||||||
except AuthenticationError:
|
|
||||||
logger.exception("Failed to authenticate with OpenAI API")
|
|
||||||
raise
|
|
||||||
except Timeout:
|
|
||||||
logger.exception("OpenAI API timed out for query: %s", query)
|
|
||||||
raise
|
|
||||||
except Exception as e:
|
|
||||||
logger.exception("Unexpected error with OpenAI API for query: %s", query)
|
|
||||||
raise
|
|
||||||
|
|
||||||
return cast(F, wrapped_call)
|
|
||||||
|
|
||||||
|
|
||||||
# used to check if the QAModel is an OpenAI model
|
|
||||||
class OpenAIQAModel(QAModel):
|
|
||||||
pass
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAICompletionQA(OpenAIQAModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prompt_processor: Callable[
|
|
||||||
[str, list[InferenceChunk], bool], str
|
|
||||||
] = json_processor,
|
|
||||||
model_version: str = OPENAI_MODEL_VERSION,
|
|
||||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
|
||||||
api_key: str | None = None,
|
|
||||||
timeout: int | None = None,
|
|
||||||
include_metadata: bool = INCLUDE_METADATA,
|
|
||||||
) -> None:
|
|
||||||
self.prompt_processor = prompt_processor
|
|
||||||
self.model_version = model_version
|
|
||||||
self.max_output_tokens = max_output_tokens
|
|
||||||
self.timeout = timeout
|
|
||||||
self.include_metadata = include_metadata
|
|
||||||
try:
|
|
||||||
self.api_key = api_key or get_openai_api_key()
|
|
||||||
except ConfigNotFoundError:
|
|
||||||
raise OpenAIKeyMissing()
|
|
||||||
|
|
||||||
@log_function_time()
|
|
||||||
def answer_question(
|
|
||||||
self, query: str, context_docs: list[InferenceChunk]
|
|
||||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
|
||||||
filled_prompt = self.prompt_processor(
|
|
||||||
query, context_docs, self.include_metadata
|
|
||||||
)
|
|
||||||
logger.debug(filled_prompt)
|
|
||||||
|
|
||||||
openai_call = _handle_openai_exceptions_wrapper(
|
|
||||||
openai_call=openai.Completion.create,
|
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
response = openai_call(
|
|
||||||
**_build_openai_settings(
|
|
||||||
api_key=self.api_key,
|
|
||||||
prompt=filled_prompt,
|
|
||||||
model=self.model_version,
|
|
||||||
max_tokens=self.max_output_tokens,
|
|
||||||
request_timeout=self.timeout,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model_output = cast(str, response["choices"][0]["text"]).strip()
|
|
||||||
logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
|
|
||||||
logger.debug(model_output)
|
|
||||||
|
|
||||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
|
||||||
return answer, quotes_dict
|
|
||||||
|
|
||||||
def answer_question_stream(
|
|
||||||
self, query: str, context_docs: list[InferenceChunk]
|
|
||||||
) -> Generator[dict[str, Any] | None, None, None]:
|
|
||||||
filled_prompt = self.prompt_processor(
|
|
||||||
query, context_docs, self.include_metadata
|
|
||||||
)
|
|
||||||
logger.debug(filled_prompt)
|
|
||||||
|
|
||||||
openai_call = _handle_openai_exceptions_wrapper(
|
|
||||||
openai_call=openai.Completion.create,
|
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
response = openai_call(
|
|
||||||
**_build_openai_settings(
|
|
||||||
api_key=self.api_key,
|
|
||||||
prompt=filled_prompt,
|
|
||||||
model=self.model_version,
|
|
||||||
max_tokens=self.max_output_tokens,
|
|
||||||
request_timeout=self.timeout,
|
|
||||||
stream=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model_output: str = ""
|
|
||||||
found_answer_start = False
|
|
||||||
found_answer_end = False
|
|
||||||
# iterate through the stream of events
|
|
||||||
for event in response:
|
|
||||||
event_text = cast(str, event["choices"][0]["text"])
|
|
||||||
model_previous = model_output
|
|
||||||
model_output += event_text
|
|
||||||
|
|
||||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
|
||||||
" ", ""
|
|
||||||
).replace("\n", ""):
|
|
||||||
found_answer_start = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
if found_answer_start and not found_answer_end:
|
|
||||||
if stream_answer_end(model_previous, event_text):
|
|
||||||
found_answer_end = True
|
|
||||||
yield {"answer_finished": True}
|
|
||||||
continue
|
|
||||||
yield {"answer_data": event_text}
|
|
||||||
|
|
||||||
logger.debug(model_output)
|
|
||||||
|
|
||||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
|
||||||
if answer:
|
|
||||||
logger.info(answer)
|
|
||||||
else:
|
|
||||||
logger.warning(
|
|
||||||
"Answer extraction from model output failed, most likely no quotes provided"
|
|
||||||
)
|
|
||||||
|
|
||||||
if quotes_dict is None:
|
|
||||||
yield {}
|
|
||||||
else:
|
|
||||||
yield quotes_dict
|
|
||||||
|
|
||||||
|
|
||||||
class OpenAIChatCompletionQA(OpenAIQAModel):
|
|
||||||
def __init__(
|
|
||||||
self,
|
|
||||||
prompt_processor: Callable[
|
|
||||||
[str, list[InferenceChunk], bool], list[dict[str, str]]
|
|
||||||
] = json_chat_processor,
|
|
||||||
model_version: str = OPENAI_MODEL_VERSION,
|
|
||||||
max_output_tokens: int = OPENAI_MAX_OUTPUT_TOKENS,
|
|
||||||
timeout: int | None = None,
|
|
||||||
reflexion_try_count: int = 0,
|
|
||||||
api_key: str | None = None,
|
|
||||||
include_metadata: bool = INCLUDE_METADATA,
|
|
||||||
) -> None:
|
|
||||||
self.prompt_processor = prompt_processor
|
|
||||||
self.model_version = model_version
|
|
||||||
self.max_output_tokens = max_output_tokens
|
|
||||||
self.reflexion_try_count = reflexion_try_count
|
|
||||||
self.timeout = timeout
|
|
||||||
self.include_metadata = include_metadata
|
|
||||||
try:
|
|
||||||
self.api_key = api_key or get_openai_api_key()
|
|
||||||
except ConfigNotFoundError:
|
|
||||||
raise OpenAIKeyMissing()
|
|
||||||
|
|
||||||
@log_function_time()
|
|
||||||
def answer_question(
|
|
||||||
self,
|
|
||||||
query: str,
|
|
||||||
context_docs: list[InferenceChunk],
|
|
||||||
) -> tuple[str | None, dict[str, dict[str, str | int | None]] | None]:
|
|
||||||
messages = self.prompt_processor(query, context_docs, self.include_metadata)
|
|
||||||
logger.debug(json.dumps(messages, indent=4))
|
|
||||||
model_output = ""
|
|
||||||
for _ in range(self.reflexion_try_count + 1):
|
|
||||||
openai_call = _handle_openai_exceptions_wrapper(
|
|
||||||
openai_call=openai.ChatCompletion.create,
|
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
response = openai_call(
|
|
||||||
**_build_openai_settings(
|
|
||||||
api_key=self.api_key,
|
|
||||||
messages=messages,
|
|
||||||
model=self.model_version,
|
|
||||||
max_tokens=self.max_output_tokens,
|
|
||||||
request_timeout=self.timeout,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model_output = cast(
|
|
||||||
str, response["choices"][0]["message"]["content"]
|
|
||||||
).strip()
|
|
||||||
assistant_msg = {"content": model_output, "role": "assistant"}
|
|
||||||
messages.extend([assistant_msg, get_chat_reflexion_msg()])
|
|
||||||
logger.info(
|
|
||||||
"OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
|
|
||||||
)
|
|
||||||
|
|
||||||
logger.debug(model_output)
|
|
||||||
|
|
||||||
answer, quotes_dict = process_answer(model_output, context_docs)
|
|
||||||
return answer, quotes_dict
|
|
||||||
|
|
||||||
def answer_question_stream(
|
|
||||||
self, query: str, context_docs: list[InferenceChunk]
|
|
||||||
) -> Generator[dict[str, Any] | None, None, None]:
|
|
||||||
messages = self.prompt_processor(query, context_docs, self.include_metadata)
|
|
||||||
logger.debug(json.dumps(messages, indent=4))
|
|
||||||
|
|
||||||
openai_call = _handle_openai_exceptions_wrapper(
|
|
||||||
openai_call=openai.ChatCompletion.create,
|
|
||||||
query=query,
|
|
||||||
)
|
|
||||||
response = openai_call(
|
|
||||||
**_build_openai_settings(
|
|
||||||
api_key=self.api_key,
|
|
||||||
messages=messages,
|
|
||||||
model=self.model_version,
|
|
||||||
max_tokens=self.max_output_tokens,
|
|
||||||
request_timeout=self.timeout,
|
|
||||||
stream=True,
|
|
||||||
),
|
|
||||||
)
|
|
||||||
model_output: str = ""
|
|
||||||
found_answer_start = False
|
|
||||||
found_answer_end = False
|
|
||||||
for event in response:
|
|
||||||
event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
|
|
||||||
if (
|
|
||||||
"content" not in event_dict
|
|
||||||
): # could be a role message or empty termination
|
|
||||||
continue
|
|
||||||
event_text = event_dict["content"]
|
|
||||||
model_previous = model_output
|
|
||||||
model_output += event_text
|
|
||||||
logger.debug(f"GPT returned token: {event_text}")
|
|
||||||
|
|
||||||
if not found_answer_start and '{"answer":"' in model_output.replace(
|
|
||||||
" ", ""
|
|
||||||
).replace("\n", ""):
|
|
||||||
# Note, if the token that completes the pattern has additional text, for example if the token is "?
|
|
||||||
# Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
|
|
||||||
# event that the model outputs the UNCERTAINTY_PAT
|
|
||||||
found_answer_start = True
|
|
||||||
continue
|
|
||||||
|
|
||||||
if found_answer_start and not found_answer_end:
|
|
||||||
if stream_answer_end(model_previous, event_text):
|
|
||||||
found_answer_end = True
|
|
||||||
yield {"answer_finished": True}
|
|
||||||
continue
|
|
||||||
yield {"answer_data": event_text}
|
|
||||||
|
|
||||||
logger.debug(model_output)
|
|
||||||
|
|
||||||
_, quotes_dict = process_answer(model_output, context_docs)
|
|
||||||
|
|
||||||
yield {} if quotes_dict is None else quotes_dict
|
|
backend/danswer/direct_qa/open_ai.py (new file, 280 lines)
@@ -0,0 +1,280 @@
import json
from abc import ABC
from collections.abc import Callable
from collections.abc import Generator
from functools import wraps
from typing import Any
from typing import cast
from typing import TypeVar

import openai
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import INCLUDE_METADATA
from danswer.configs.app_configs import OPENAI_API_KEY
from danswer.configs.constants import OPENAI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_prompts import ChatPromptProcessor
from danswer.direct_qa.qa_prompts import get_json_chat_reflexion_msg
from danswer.direct_qa.qa_prompts import JsonChatProcessor
from danswer.direct_qa.qa_prompts import JsonProcessor
from danswer.direct_qa.qa_prompts import NonChatPromptProcessor
from danswer.direct_qa.qa_utils import process_answer
from danswer.direct_qa.qa_utils import process_model_tokens
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from openai.error import AuthenticationError
from openai.error import Timeout


logger = setup_logger()

F = TypeVar("F", bound=Callable)


def get_openai_api_key() -> str:
    return OPENAI_API_KEY or cast(
        str, get_dynamic_config_store().load(OPENAI_API_KEY_STORAGE_KEY)
    )


def _ensure_openai_api_key(api_key: str | None) -> str:
    try:
        return api_key or get_openai_api_key()
    except ConfigNotFoundError:
        raise OpenAIKeyMissing()


def _build_openai_settings(**kwargs: Any) -> dict[str, Any]:
    """
    Utility to add in some common default values so they don't have to be set every time.
    """
    return {
        "temperature": 0,
        "top_p": 1,
        "frequency_penalty": 0,
        "presence_penalty": 0,
        **kwargs,
    }


def _handle_openai_exceptions_wrapper(openai_call: F, query: str) -> F:
    @wraps(openai_call)
    def wrapped_call(*args: list[Any], **kwargs: dict[str, Any]) -> Any:
        try:
            # if streamed, the call returns a generator
            if kwargs.get("stream"):

                def _generator() -> Generator[Any, None, None]:
                    yield from openai_call(*args, **kwargs)

                return _generator()
            return openai_call(*args, **kwargs)
        except AuthenticationError:
            logger.exception("Failed to authenticate with OpenAI API")
            raise
        except Timeout:
            logger.exception("OpenAI API timed out for query: %s", query)
            raise
        except Exception:
            logger.exception("Unexpected error with OpenAI API for query: %s", query)
            raise

    return cast(F, wrapped_call)


# used to check if the QAModel is an OpenAI model
class OpenAIQAModel(QAModel, ABC):
    pass


class OpenAICompletionQA(OpenAIQAModel):
    def __init__(
        self,
        prompt_processor: NonChatPromptProcessor = JsonProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        api_key: str | None = None,
        timeout: int | None = None,
        include_metadata: bool = INCLUDE_METADATA,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.timeout = timeout
        self.include_metadata = include_metadata
        try:
            self.api_key = api_key or get_openai_api_key()
        except ConfigNotFoundError:
            raise OpenAIKeyMissing()

    @staticmethod
    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
        for event in response:
            yield event["choices"][0]["text"]

    @log_function_time()
    def answer_question(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=_ensure_openai_api_key(self.api_key),
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
            ),
        )
        model_output = cast(str, response["choices"][0]["text"]).strip()
        logger.info("OpenAI Token Usage: " + str(response["usage"]).replace("\n", ""))
        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        filled_prompt = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(filled_prompt)

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.Completion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=_ensure_openai_api_key(self.api_key),
                prompt=filled_prompt,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )

        tokens = self._generate_tokens_from_response(response)

        yield from process_model_tokens(
            tokens=tokens,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )


class OpenAIChatCompletionQA(OpenAIQAModel):
    def __init__(
        self,
        prompt_processor: ChatPromptProcessor = JsonChatProcessor(),
        model_version: str = GEN_AI_MODEL_VERSION,
        max_output_tokens: int = GEN_AI_MAX_OUTPUT_TOKENS,
        timeout: int | None = None,
        reflexion_try_count: int = 0,
        api_key: str | None = None,
        include_metadata: bool = INCLUDE_METADATA,
    ) -> None:
        self.prompt_processor = prompt_processor
        self.model_version = model_version
        self.max_output_tokens = max_output_tokens
        self.reflexion_try_count = reflexion_try_count
        self.timeout = timeout
        self.include_metadata = include_metadata
        self.api_key = api_key

    @staticmethod
    def _generate_tokens_from_response(response: Any) -> Generator[str, None, None]:
        for event in response:
            event_dict = cast(dict[str, Any], event["choices"][0]["delta"])
            if (
                "content" not in event_dict
            ):  # could be a role message or empty termination
                continue
            yield event_dict["content"]

    @log_function_time()
    def answer_question(
        self,
        query: str,
        context_docs: list[InferenceChunk],
    ) -> tuple[DanswerAnswer, list[DanswerQuote]]:
        messages = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(json.dumps(messages, indent=4))
        model_output = ""
        for _ in range(self.reflexion_try_count + 1):
            openai_call = _handle_openai_exceptions_wrapper(
                openai_call=openai.ChatCompletion.create,
                query=query,
            )
            response = openai_call(
                **_build_openai_settings(
                    api_key=_ensure_openai_api_key(self.api_key),
                    messages=messages,
                    model=self.model_version,
                    max_tokens=self.max_output_tokens,
                    request_timeout=self.timeout,
                ),
            )
            model_output = cast(
                str, response["choices"][0]["message"]["content"]
            ).strip()
            assistant_msg = {"content": model_output, "role": "assistant"}
            messages.extend([assistant_msg, get_json_chat_reflexion_msg()])
            logger.info(
                "OpenAI Token Usage: " + str(response["usage"]).replace("\n", "")
            )

        logger.debug(model_output)

        answer, quotes_dict = process_answer(model_output, context_docs)
        return answer, quotes_dict

    def answer_question_stream(
        self, query: str, context_docs: list[InferenceChunk]
    ) -> Generator[dict[str, Any] | None, None, None]:
        messages = self.prompt_processor.fill_prompt(
            query, context_docs, self.include_metadata
        )
        logger.debug(json.dumps(messages, indent=4))

        openai_call = _handle_openai_exceptions_wrapper(
            openai_call=openai.ChatCompletion.create,
            query=query,
        )
        response = openai_call(
            **_build_openai_settings(
                api_key=_ensure_openai_api_key(self.api_key),
                messages=messages,
                model=self.model_version,
                max_tokens=self.max_output_tokens,
                request_timeout=self.timeout,
                stream=True,
            ),
        )

        tokens = self._generate_tokens_from_response(response)

        yield from process_model_tokens(
            tokens=tokens,
            context_docs=context_docs,
            is_json_prompt=self.prompt_processor.specifies_json_output,
        )
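Sketch of calling the relocated OpenAI chat backend directly (assumes a valid key; the exact packet shapes come from process_model_tokens in qa_utils, which is not shown in this diff):

    from danswer.direct_qa.open_ai import OpenAIChatCompletionQA

    qa = OpenAIChatCompletionQA(api_key="sk-...", timeout=10)  # placeholder key
    for packet in qa.answer_question_stream("Where is the Eiffel Tower?", context_docs=[]):
        # streamed answer pieces followed by quote metadata, per process_model_tokens
        print(packet)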
@@ -1,3 +1,4 @@
+import abc
 import json
 
 from danswer.chunking.models import InferenceChunk
@@ -16,10 +17,10 @@ QUOTE_PAT = "Quote:"
 BASE_PROMPT = (
     f"Answer the query based on provided documents and quote relevant sections. "
     f"Respond with a json containing a concise answer and up to three most relevant quotes from the documents. "
+    f'Respond with "?" for the answer if the query cannot be answered based on the documents. '
     f"The quotes must be EXACT substrings from the documents."
 )
 
 
 SAMPLE_QUESTION = "Where is the Eiffel Tower?"
 
 SAMPLE_JSON_RESPONSE = {
@@ -31,7 +32,23 @@ SAMPLE_JSON_RESPONSE = {
 }
 
 
-def add_metadata_section(
+def _append_acknowledge_doc_messages(
+    current_messages: list[dict[str, str]], new_chunk_content: str
+) -> list[dict[str, str]]:
+    updated_messages = current_messages.copy()
+    updated_messages.extend(
+        [
+            {
+                "role": "user",
+                "content": new_chunk_content,
+            },
+            {"role": "assistant", "content": "Acknowledged"},
+        ]
+    )
+    return updated_messages
+
+
+def _add_metadata_section(
     prompt_current: str,
     chunk: InferenceChunk,
     prepend_tab: bool = False,
@@ -67,192 +84,313 @@ def add_metadata_section(
     return prompt_current
 
 
-def json_processor(
-    question: str,
-    chunks: list[InferenceChunk],
-    include_metadata: bool = False,
-    include_sep: bool = True,
-) -> str:
-    prompt = (
-        BASE_PROMPT + f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
-        f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-    )
-
-    for chunk in chunks:
-        prompt += f"\n\n{DOC_SEP_PAT}\n"
-        if include_metadata:
-            prompt = add_metadata_section(
-                prompt, chunk, prepend_tab=False, include_sep=include_sep
-            )
-
-        prompt += chunk.content
-
-    prompt += "\n\n---\n\n"
-    prompt += f"{QUESTION_PAT}\n{question}\n"
-    return prompt
-
-
-def json_chat_processor(
-    question: str,
-    chunks: list[InferenceChunk],
-    include_metadata: bool = False,
-    include_sep: bool = False,
-) -> list[dict[str, str]]:
-    metadata_prompt_section = "with metadata and contents " if include_metadata else ""
-    intro_msg = (
-        f"You are a Question Answering assistant that answers queries based on the provided most relevant documents.\n"
-        f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
-    )
-
-    complete_answer_not_found_response = (
-        '{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
-    )
-    task_msg = (
-        "Now answer the next user query based on documents above and quote relevant sections.\n"
-        "Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
-        "All quotes MUST be EXACT substrings from provided documents.\n"
-        "Your responses should be informative and concise.\n"
-        "You MUST prioritize information from provided documents over internal knowledge.\n"
-        "If the query cannot be answered based on the documents, respond with "
-        f"{complete_answer_not_found_response}\n"
-        "If the query requires aggregating the number of documents, respond with "
-        '{"answer": "Aggregations not supported", "quotes": []}\n'
-        f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
-    )
-    messages = [{"role": "system", "content": intro_msg}]
-
-    for chunk in chunks:
-        full_context = ""
-        if include_metadata:
-            full_context = add_metadata_section(
-                full_context, chunk, prepend_tab=False, include_sep=include_sep
-            )
-        full_context += chunk.content
-        messages.extend(
-            [
-                {
-                    "role": "user",
-                    "content": full_context,
-                },
-                {"role": "assistant", "content": "Acknowledged"},
-            ]
-        )
-    messages.append({"role": "system", "content": task_msg})
-
-    messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
-
-    return messages
-
-
-# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may use again in future
-
-# Chain of Thought approach works however has higher token cost (more expensive) and is slower.
-# Should use this one if users ask questions that require logical reasoning.
-def json_cot_variant_processor(question: str, documents: list[str]) -> str:
-    prompt = (
-        f"Answer the query based on provided documents and quote relevant sections. "
-        f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
-        f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
-        f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
-        f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-    )
-
-    for document in documents:
-        prompt += f"\n{DOC_SEP_PAT}\n{document}"
-
-    prompt += "\n\n---\n\n"
-    prompt += f"{QUESTION_PAT}\n{question}\n"
-    prompt += "Reasoning:\n"
-    return prompt
-
-
-# This one seems largely useless with a single example
-# Model seems to take the one example of answering Yes and just does that too.
-def json_reflexion_processor(question: str, documents: list[str]) -> str:
-    reflexion_str = "Does this fully answer the user query?"
-    prompt = (
-        BASE_PROMPT
-        + f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
-        f"If No, generate a better json response to the query.\n"
-        f"Sample question and response:\n"
-        f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
-        f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
-        f"{reflexion_str} Yes\n\n"
-        f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
-    )
-
-    for document in documents:
-        prompt += f"\n---NEW CONTEXT DOCUMENT---\n{document}"
-
-    prompt += "\n\n---\n\n"
-    prompt += f"{QUESTION_PAT}\n{question}\n"
-    return prompt
-
-
-# Initial design, works pretty well but not optimal
-def freeform_processor(question: str, documents: list[str]) -> str:
-    prompt = (
-        f"Answer the query based on the documents below and quote the documents segments containing the answer. "
-        f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. '
-        f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. '
-        f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
-        f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
-    )
-
-    for document in documents:
-        prompt += f"\n{DOC_SEP_PAT}\n{document}"
-
-    prompt += "\n\n---\n\n"
-    prompt += f"{QUESTION_PAT}\n{question}\n"
-    prompt += f"{ANSWER_PAT}\n"
-    return prompt
-
-
-def freeform_chat_processor(
-    question: str, documents: list[str]
-) -> list[dict[str, str]]:
-    sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
-    role_msg = (
-        f"You are a Question Answering assistant that answers queries based on provided documents. "
-        f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
-        f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
-        f"Answer the question directly and concisely. "
-        f"Each quote should be a single continuous segment from a document. "
+class PromptProcessor(abc.ABC):
+    """Take the most relevant chunks and fills out a LLM prompt using the chunk contents
+    and optionally metadata about the chunk"""
+
+    @property
+    @abc.abstractmethod
+    def specifies_json_output(self) -> bool:
+        raise NotImplementedError
+
+    @staticmethod
+    @abc.abstractmethod
+    def fill_prompt(
+        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
+    ) -> str | list[dict[str, str]]:
+        raise NotImplementedError
+
+
+class NonChatPromptProcessor(PromptProcessor):
+    @staticmethod
+    @abc.abstractmethod
+    def fill_prompt(
+        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
+    ) -> str:
+        raise NotImplementedError
+
+
+class ChatPromptProcessor(PromptProcessor):
+    @staticmethod
+    @abc.abstractmethod
+    def fill_prompt(
+        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
+    ) -> list[dict[str, str]]:
+        raise NotImplementedError
+
+
+class JsonProcessor(NonChatPromptProcessor):
+    @property
+    def specifies_json_output(self) -> bool:
+        return True
+
+    @staticmethod
+    def fill_prompt(
+        question: str, chunks: list[InferenceChunk], include_metadata: bool = False
+    ) -> str:
+        prompt = (
+            BASE_PROMPT + f" Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
+            f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
|
||||||
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
|
|
||||||
f"An example of quote sections may look like:\n{sample_quote}"
|
|
||||||
)
|
|
||||||
|
|
||||||
messages = [
|
|
||||||
{"role": "system", "content": role_msg},
|
|
||||||
]
|
|
||||||
for document in documents:
|
|
||||||
messages.extend(
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"role": "user",
|
|
||||||
"content": f"Acknowledge the following document:\n{document}",
|
|
||||||
},
|
|
||||||
{"role": "assistant", "content": "Acknowledged"},
|
|
||||||
]
|
|
||||||
)
|
)
|
||||||
|
|
||||||
messages.append(
|
for chunk in chunks:
|
||||||
{
|
prompt += f"\n\n{DOC_SEP_PAT}\n"
|
||||||
"role": "user",
|
if include_metadata:
|
||||||
"content": f"Please now answer the following query based on the previously provided "
|
prompt = _add_metadata_section(
|
||||||
f"documents and quote the relevant sections of the documents\n{question}",
|
prompt, chunk, prepend_tab=False, include_sep=True
|
||||||
},
|
)
|
||||||
)
|
|
||||||
|
|
||||||
return messages
|
prompt += chunk.content
|
||||||
|
|
||||||
|
prompt += "\n\n---\n\n"
|
||||||
|
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
# Not very useful, have not seen it improve an answer based on this
|
class JsonChatProcessor(ChatPromptProcessor):
|
||||||
# Sometimes gpt-3.5-turbo will just answer something worse like:
|
@property
|
||||||
# 'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
|
def specifies_json_output(self) -> bool:
|
||||||
# document. No revision is needed.'
|
return True
|
||||||
def get_chat_reflexion_msg() -> dict[str, str]:
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
metadata_prompt_section = (
|
||||||
|
"with metadata and contents " if include_metadata else ""
|
||||||
|
)
|
||||||
|
intro_msg = (
|
||||||
|
f"You are a Question Answering assistant that answers queries "
|
||||||
|
f"based on the provided most relevant documents.\n"
|
||||||
|
f'Start by reading the following documents {metadata_prompt_section}and responding with "Acknowledged".'
|
||||||
|
)
|
||||||
|
|
||||||
|
complete_answer_not_found_response = (
|
||||||
|
'{"answer": "' + UNCERTAINTY_PAT + '", "quotes": []}'
|
||||||
|
)
|
||||||
|
task_msg = (
|
||||||
|
"Now answer the next user query based on documents above and quote relevant sections.\n"
|
||||||
|
"Respond with a JSON containing the answer and up to three most relevant quotes from the documents.\n"
|
||||||
|
"All quotes MUST be EXACT substrings from provided documents.\n"
|
||||||
|
"Your responses should be informative and concise.\n"
|
||||||
|
"You MUST prioritize information from provided documents over internal knowledge.\n"
|
||||||
|
"If the query cannot be answered based on the documents, respond with "
|
||||||
|
f"{complete_answer_not_found_response}\n"
|
||||||
|
"If the query requires aggregating the number of documents, respond with "
|
||||||
|
'{"answer": "Aggregations not supported", "quotes": []}\n'
|
||||||
|
f"Sample response:\n{json.dumps(SAMPLE_JSON_RESPONSE)}"
|
||||||
|
)
|
||||||
|
messages = [{"role": "system", "content": intro_msg}]
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
full_context = ""
|
||||||
|
if include_metadata:
|
||||||
|
full_context = _add_metadata_section(
|
||||||
|
full_context, chunk, prepend_tab=False, include_sep=False
|
||||||
|
)
|
||||||
|
full_context += chunk.content
|
||||||
|
messages = _append_acknowledge_doc_messages(messages, full_context)
|
||||||
|
messages.append({"role": "system", "content": task_msg})
|
||||||
|
|
||||||
|
messages.append({"role": "user", "content": f"{QUESTION_PAT}\n{question}\n"})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
class WeakModelFreeformProcessor(NonChatPromptProcessor):
|
||||||
|
"""Avoid using this one if the model is capable of using another prompt
|
||||||
|
Intended for models that can't follow complex instructions or have short context windows
|
||||||
|
This prompt only uses 1 reference document chunk
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def specifies_json_output(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> str:
|
||||||
|
first_chunk_content = chunks[0].content if chunks else "No Document Provided"
|
||||||
|
|
||||||
|
prompt = (
|
||||||
|
f"Reference Document:\n{first_chunk_content}\n{GENERAL_SEP_PAT}"
|
||||||
|
f"Answer the user query below based on the reference document above. "
|
||||||
|
f'Respond with an "{ANSWER_PAT}" section and '
|
||||||
|
f'as many "{QUOTE_PAT}" sections as needed to support the answer.'
|
||||||
|
f"\n{GENERAL_SEP_PAT}"
|
||||||
|
f"{QUESTION_PAT} {question}\n"
|
||||||
|
f"{ANSWER_PAT}"
|
||||||
|
)
|
||||||
|
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
class WeakChatModelFreeformProcessor(ChatPromptProcessor):
|
||||||
|
"""Avoid using this one if the model is capable of using another prompt
|
||||||
|
Intended for models that can't follow complex instructions or have short context windows
|
||||||
|
This prompt only uses 1 reference document chunk
|
||||||
|
"""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def specifies_json_output(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
first_chunk_content = chunks[0].content if chunks else "No Document Provided"
|
||||||
|
intro_msg = (
|
||||||
|
f"You are a question answering assistant. "
|
||||||
|
f'Respond to the query with an "{ANSWER_PAT}" section and '
|
||||||
|
f'as many "{QUOTE_PAT}" sections as needed to support the answer. '
|
||||||
|
f"Answer the user query based on the following document:\n\n{first_chunk_content}"
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [{"role": "system", "content": intro_msg}]
|
||||||
|
|
||||||
|
user_query = f"{QUESTION_PAT} {question}"
|
||||||
|
messages.append({"role": "user", "content": user_query})
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
# EVERYTHING BELOW IS DEPRECATED, kept around as reference, may revisit in future
|
||||||
|
|
||||||
|
|
||||||
|
class FreeformProcessor(NonChatPromptProcessor):
|
||||||
|
@property
|
||||||
|
def specifies_json_output(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> str:
|
||||||
|
prompt = (
|
||||||
|
f"Answer the query based on the documents below and quote the documents segments containing the answer. "
|
||||||
|
f'Respond with one "{ANSWER_PAT}" section and as many "{QUOTE_PAT}" sections as is relevant. '
|
||||||
|
f'Start each quote with "{QUOTE_PAT}". Each quote should be a single continuous segment from a document. '
|
||||||
|
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
|
||||||
|
f'Each document is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
|
||||||
|
|
||||||
|
prompt += "\n\n---\n\n"
|
||||||
|
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||||
|
prompt += f"{ANSWER_PAT}\n"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
class FreeformChatProcessor(ChatPromptProcessor):
|
||||||
|
@property
|
||||||
|
def specifies_json_output(self) -> bool:
|
||||||
|
return False
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> list[dict[str, str]]:
|
||||||
|
sample_quote = "Quote:\nThe hotdogs are freshly cooked.\n\nQuote:\nThey are very cheap at only a dollar each."
|
||||||
|
role_msg = (
|
||||||
|
f"You are a Question Answering assistant that answers queries based on provided documents. "
|
||||||
|
f'You will be asked to acknowledge a set of documents and then provide one "{ANSWER_PAT}" and '
|
||||||
|
f'as many "{QUOTE_PAT}" sections as is relevant to back up your answer. '
|
||||||
|
f"Answer the question directly and concisely. "
|
||||||
|
f"Each quote should be a single continuous segment from a document. "
|
||||||
|
f'If the query cannot be answered based on the documents, say "{UNCERTAINTY_PAT}". '
|
||||||
|
f"An example of quote sections may look like:\n{sample_quote}"
|
||||||
|
)
|
||||||
|
|
||||||
|
messages = [
|
||||||
|
{"role": "system", "content": role_msg},
|
||||||
|
]
|
||||||
|
for chunk in chunks:
|
||||||
|
messages = _append_acknowledge_doc_messages(messages, chunk.content)
|
||||||
|
|
||||||
|
messages.append(
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": f"Please now answer the following query based on the previously provided "
|
||||||
|
f"documents and quote the relevant sections of the documents\n{question}",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return messages
|
||||||
|
|
||||||
|
|
||||||
|
class JsonCOTProcessor(NonChatPromptProcessor):
|
||||||
|
"""Chain of Thought allows model to explain out its reasoning to handle harder tests.
|
||||||
|
This prompt type works however has higher token cost (more expensive) and is slower.
|
||||||
|
Consider this one if users ask questions that require logical reasoning."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def specifies_json_output(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> str:
|
||||||
|
prompt = (
|
||||||
|
f"Answer the query based on provided documents and quote relevant sections. "
|
||||||
|
f'Respond with a freeform reasoning section followed by "Final Answer:" with a '
|
||||||
|
f"json containing a concise answer to the query and up to three most relevant quotes from the documents.\n"
|
||||||
|
f"Sample answer json:\n{json.dumps(SAMPLE_JSON_RESPONSE)}\n\n"
|
||||||
|
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
prompt += f"\n{DOC_SEP_PAT}\n{chunk.content}"
|
||||||
|
|
||||||
|
prompt += "\n\n---\n\n"
|
||||||
|
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||||
|
prompt += "Reasoning:\n"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
class JsonReflexionProcessor(NonChatPromptProcessor):
|
||||||
|
"""Reflexion prompting to attempt to have model evaluate its own answer.
|
||||||
|
This one seems largely useless when only given a single example
|
||||||
|
Model seems to take the one example of answering "Yes" and just does that too."""
|
||||||
|
|
||||||
|
@property
|
||||||
|
def specifies_json_output(self) -> bool:
|
||||||
|
return True
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def fill_prompt(
|
||||||
|
question: str, chunks: list[InferenceChunk], include_metadata: bool = False
|
||||||
|
) -> str:
|
||||||
|
reflexion_str = "Does this fully answer the user query?"
|
||||||
|
prompt = (
|
||||||
|
BASE_PROMPT
|
||||||
|
+ f'After each generated json, ask "{reflexion_str}" and respond Yes or No. '
|
||||||
|
f"If No, generate a better json response to the query.\n"
|
||||||
|
f"Sample question and response:\n"
|
||||||
|
f"{QUESTION_PAT}\n{SAMPLE_QUESTION}\n"
|
||||||
|
f"{json.dumps(SAMPLE_JSON_RESPONSE)}\n"
|
||||||
|
f"{reflexion_str} Yes\n\n"
|
||||||
|
f'Each context document below is prefixed with "{DOC_SEP_PAT}".\n\n'
|
||||||
|
)
|
||||||
|
|
||||||
|
for chunk in chunks:
|
||||||
|
prompt += f"\n---NEW CONTEXT DOCUMENT---\n{chunk.content}"
|
||||||
|
|
||||||
|
prompt += "\n\n---\n\n"
|
||||||
|
prompt += f"{QUESTION_PAT}\n{question}\n"
|
||||||
|
return prompt
|
||||||
|
|
||||||
|
|
||||||
|
def get_json_chat_reflexion_msg() -> dict[str, str]:
|
||||||
|
"""With the models tried (curent as of Jul 2023), this has not been very useful.
|
||||||
|
Have not seen any answers improved based on this.
|
||||||
|
For models like gpt-3.5-turbo, it will often answer something like:
|
||||||
|
'The response is a valid json that fully answers the user query with quotes exactly matching sections of the source
|
||||||
|
document. No revision is needed.'"""
|
||||||
reflexion_content = (
|
reflexion_content = (
|
||||||
"Is the assistant response a valid json that fully answer the user query? "
|
"Is the assistant response a valid json that fully answer the user query? "
|
||||||
"If the response needs to be fixed or if an improvement is possible, provide a revised json. "
|
"If the response needs to be fixed or if an improvement is possible, provide a revised json. "
|
||||||
|
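Taken together, the JSON processors above ask the model for an object with an answer field and a quotes list. A minimal usage sketch (not part of the commit; the question string is made up, and an empty chunk list is passed only to keep the snippet self-contained):

# Sketch: building a prompt and the reply shape it asks for
from danswer.chunking.models import InferenceChunk
from danswer.direct_qa.qa_prompts import JsonProcessor

chunks: list[InferenceChunk] = []  # retrieved chunks would normally go here
prompt = JsonProcessor.fill_prompt(question="What is Danswer?", chunks=chunks)
# A compliant model reply looks roughly like:
# {"answer": "Danswer is ...", "quotes": ["exact substring from a provided document"]}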
backend/danswer/direct_qa/qa_utils.py (new file, 247 lines)
@@ -0,0 +1,247 @@
import json
import math
import re
from collections.abc import Generator
from typing import Any
from typing import Optional
from typing import Tuple

import regex
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import BLURB
from danswer.configs.constants import DOCUMENT_ID
from danswer.configs.constants import SEMANTIC_IDENTIFIER
from danswer.configs.constants import SOURCE_LINK
from danswer.configs.constants import SOURCE_TYPE
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.utils.logger import setup_logger
from danswer.utils.text_processing import clean_model_quote
from danswer.utils.text_processing import shared_precompare_cleanup

logger = setup_logger()


def structure_quotes_for_response(
    quotes: list[DanswerQuote] | None,
) -> dict[str, dict[str, str | None]]:
    if quotes is None:
        return {}

    response_quotes = {}
    for quote in quotes:
        response_quotes[quote.quote] = {
            DOCUMENT_ID: quote.document_id,
            SOURCE_LINK: quote.link,
            SOURCE_TYPE: quote.source_type,
            SEMANTIC_IDENTIFIER: quote.semantic_identifier,
            BLURB: quote.blurb,
        }
    return response_quotes


def extract_answer_quotes_freeform(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    """Splits the model output into an Answer and 0 or more Quote sections.
    Splits by the Quote pattern, if not exist then assume it's all answer and no quotes
    """
    # If no answer section, don't care about the quote
    if answer_raw.lower().strip().startswith(QUOTE_PAT.lower()):
        return None, None

    # Sometimes model regenerates the Answer: pattern despite it being provided in the prompt
    if answer_raw.lower().startswith(ANSWER_PAT.lower()):
        answer_raw = answer_raw[len(ANSWER_PAT) :]

    # Accept quote sections starting with the lower case version
    answer_raw = answer_raw.replace(
        f"\n{QUOTE_PAT}".lower(), f"\n{QUOTE_PAT}"
    ) # Just in case model unreliable

    sections = re.split(rf"(?<=\n){QUOTE_PAT}", answer_raw)
    sections_clean = [
        str(section).strip() for section in sections if str(section).strip()
    ]
    if not sections_clean:
        return None, None

    answer = str(sections_clean[0])
    if len(sections) == 1:
        return answer, None
    return answer, sections_clean[1:]


def extract_answer_quotes_json(
    answer_dict: dict[str, str | list[str]]
) -> Tuple[Optional[str], Optional[list[str]]]:
    answer_dict = {k.lower(): v for k, v in answer_dict.items()}
    answer = str(answer_dict.get("answer"))
    quotes = answer_dict.get("quotes") or answer_dict.get("quote")
    if isinstance(quotes, str):
        quotes = [quotes]
    return answer, quotes


def separate_answer_quotes(
    answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:
    try:
        model_raw_json = json.loads(answer_raw)
        return extract_answer_quotes_json(model_raw_json)
    except ValueError:
        return extract_answer_quotes_freeform(answer_raw)


def match_quotes_to_docs(
    quotes: list[str],
    chunks: list[InferenceChunk],
    max_error_percent: float = QUOTE_ALLOWED_ERROR_PERCENT,
    fuzzy_search: bool = False,
    prefix_only_length: int = 100,
) -> list[DanswerQuote]:
    danswer_quotes = []
    for quote in quotes:
        max_edits = math.ceil(float(len(quote)) * max_error_percent)

        for chunk in chunks:
            if not chunk.source_links:
                continue

            quote_clean = shared_precompare_cleanup(
                clean_model_quote(quote, trim_length=prefix_only_length)
            )
            chunk_clean = shared_precompare_cleanup(chunk.content)

            # Finding the offset of the quote in the plain text
            if fuzzy_search:
                re_search_str = (
                    r"(" + re.escape(quote_clean) + r"){e<=" + str(max_edits) + r"}"
                )
                found = regex.search(re_search_str, chunk_clean)
                if not found:
                    continue
                offset = found.span()[0]
            else:
                if quote_clean not in chunk_clean:
                    continue
                offset = chunk_clean.index(quote_clean)

            # Extracting the link from the offset
            curr_link = None
            for link_offset, link in chunk.source_links.items():
                # Should always find one because offset is at least 0 and there must be a 0 link_offset
                if int(link_offset) <= offset:
                    curr_link = link
                else:
                    danswer_quotes.append(
                        DanswerQuote(
                            quote=quote,
                            document_id=chunk.document_id,
                            link=curr_link,
                            source_type=chunk.source_type,
                            semantic_identifier=chunk.semantic_identifier,
                            blurb=chunk.blurb,
                        )
                    )
                    break

            # If the offset is larger than the start of the last quote, it must be the last one
            danswer_quotes.append(
                DanswerQuote(
                    quote=quote,
                    document_id=chunk.document_id,
                    link=curr_link,
                    source_type=chunk.source_type,
                    semantic_identifier=chunk.semantic_identifier,
                    blurb=chunk.blurb,
                )
            )
            break

    return danswer_quotes


def process_answer(
    answer_raw: str, chunks: list[InferenceChunk]
) -> tuple[DanswerAnswer, list[DanswerQuote]]:
    answer, quote_strings = separate_answer_quotes(answer_raw)
    if answer == UNCERTAINTY_PAT or not answer:
        if answer == UNCERTAINTY_PAT:
            logger.debug("Answer matched UNCERTAINTY_PAT")
        else:
            logger.debug("No answer extracted from raw output")
        return DanswerAnswer(answer=None), []

    logger.info(f"Answer: {answer}")
    if not quote_strings:
        logger.debug("No quotes extracted from raw output")
        return DanswerAnswer(answer=answer), []
    logger.info(f"All quotes (including unmatched): {quote_strings}")
    quotes = match_quotes_to_docs(quote_strings, chunks)
    logger.info(f"Final quotes: {quotes}")

    return DanswerAnswer(answer=answer), quotes


def stream_json_answer_end(answer_so_far: str, next_token: str) -> bool:
    next_token = next_token.replace('\\"', "")
    # If the previous character is an escape token, don't consider the first character of next_token
    if answer_so_far and answer_so_far[-1] == "\\":
        next_token = next_token[1:]
    if '"' in next_token:
        return True
    return False


def extract_quotes_from_completed_token_stream(
    model_output: str, context_chunks: list[InferenceChunk]
) -> list[DanswerQuote]:
    logger.debug(model_output)
    answer, quotes = process_answer(model_output, context_chunks)
    if answer:
        logger.info(answer)
    elif model_output:
        logger.warning("Answer extraction from model output failed.")

    return quotes


def process_model_tokens(
    tokens: Generator[str, None, None],
    context_docs: list[InferenceChunk],
    is_json_prompt: bool = True,
) -> Generator[dict[str, Any], None, None]:
    """Yields Answer tokens back out in a dict for streaming to frontend
    When Answer section ends, yields dict with answer_finished key
    Collects all the tokens at the end to form the complete model output"""
    model_output: str = ""
    found_answer_start = False if is_json_prompt else True
    found_answer_end = False
    for token in tokens:
        model_previous = model_output
        model_output += token

        trimmed_combine = model_output.replace(" ", "").replace("\n", "")
        if not found_answer_start and '{"answer":"' in trimmed_combine:
            # Note, if the token that completes the pattern has additional text, for example if the token is "?
            # Then the chars after " will not be streamed, but this is ok as it prevents streaming the ? in the
            # event that the model outputs the UNCERTAINTY_PAT
            found_answer_start = True
            continue

        if found_answer_start and not found_answer_end:
            if (is_json_prompt and stream_json_answer_end(model_previous, token)) or (
                not is_json_prompt and f"\n{QUOTE_PAT}" in model_output
            ):
                found_answer_end = True
                yield {"answer_finished": True}
                continue
            yield {"answer_data": token}

    quotes = extract_quotes_from_completed_token_stream(model_output, context_docs)
    yield structure_quotes_for_response(quotes)
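The helpers above turn raw model output back into structured data. A small usage sketch (not part of the commit; the strings are invented, and passing no chunks simply skips quote matching):

from danswer.direct_qa.qa_utils import match_quotes_to_docs
from danswer.direct_qa.qa_utils import separate_answer_quotes

raw = '{"answer": "Hot dogs cost one dollar.", "quotes": ["They are very cheap at only a dollar each."]}'
answer, quote_strings = separate_answer_quotes(raw)
# answer -> "Hot dogs cost one dollar."
# quote_strings -> ["They are very cheap at only a dollar each."]

# With no chunks to search, no DanswerQuote objects can be produced
assert match_quotes_to_docs(quote_strings, chunks=[]) == []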
@@ -58,7 +58,7 @@ def _build_custom_semantic_identifier(


 def _process_quotes(
-    quotes: dict[str, dict[str, str | int | None]] | None
+    quotes: dict[str, dict[str, str | None]] | None
 ) -> tuple[str | None, list[str]]:
     if not quotes:
         return None, []
@@ -9,19 +9,21 @@ from danswer.auth.users import google_oauth_client
 from danswer.configs.app_configs import APP_HOST
 from danswer.configs.app_configs import APP_PORT
 from danswer.configs.app_configs import DISABLE_AUTH
+from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import ENABLE_OAUTH
 from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_ID
 from danswer.configs.app_configs import GOOGLE_OAUTH_CLIENT_SECRET
+from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION
 from danswer.configs.app_configs import SECRET
 from danswer.configs.app_configs import TYPESENSE_DEFAULT_COLLECTION
 from danswer.configs.app_configs import WEB_DOMAIN
+from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
+from danswer.configs.model_configs import INTERNAL_MODEL_VERSION
 from danswer.datastores.qdrant.indexing import list_qdrant_collections
 from danswer.datastores.typesense.store import check_typesense_collection_exist
 from danswer.datastores.typesense.store import create_typesense_collection
 from danswer.db.credentials import create_initial_public_credential
-from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
-from danswer.direct_qa.llm import get_openai_api_key
-from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.direct_qa import get_default_backend_qa_model
 from danswer.server.event_loading import router as event_processing_router
 from danswer.server.health import router as health_router
 from danswer.server.manage import router as admin_router
@@ -121,21 +123,29 @@ def get_application() -> FastAPI:
             warm_up_models,
         )
         from danswer.datastores.qdrant.indexing import create_qdrant_collection
-        from danswer.configs.app_configs import QDRANT_DEFAULT_COLLECTION

+        if DISABLE_GENERATIVE_AI:
+            logger.info("Generative AI Q&A disabled")
+        else:
+            logger.info(f"Using Internal Model: {INTERNAL_MODEL_VERSION}")
+            logger.info(f"Actual LLM model version: {GEN_AI_MODEL_VERSION}")
+
         auth_status = "off" if DISABLE_AUTH else "on"
-        logger.info(f"User auth is turned {auth_status}")
+        logger.info(f"User Authentication is turned {auth_status}")

-        if not ENABLE_OAUTH:
-            logger.debug("OAuth is turned off")
-        else:
-            if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
-                logger.warning("OAuth is turned on but incorrectly configured")
-            else:
-                logger.debug("OAuth is turned on")
+        if not DISABLE_AUTH:
+            if not ENABLE_OAUTH:
+                logger.warning("OAuth is turned off")
+            else:
+                if not GOOGLE_OAUTH_CLIENT_ID or not GOOGLE_OAUTH_CLIENT_SECRET:
+                    logger.warning("OAuth is turned on but incorrectly configured")
+                else:
+                    logger.debug("OAuth is turned on")

         logger.info("Warming up local NLP models.")
         warm_up_models()
+        qa_model = get_default_backend_qa_model()
+        qa_model.warm_up_model()

         logger.info("Verifying query preprocessing (NLTK) data is downloaded")
         nltk.download("stopwords")
@@ -1,9 +1,5 @@
 from typing import Any

-from danswer.connectors.slack.connector import get_channel_info
-from danswer.connectors.slack.connector import get_thread
-from danswer.connectors.slack.connector import thread_to_doc
-from danswer.utils.indexing_pipeline import build_indexing_pipeline
 from danswer.utils.logger import setup_logger
 from fastapi import APIRouter
 from pydantic import BaseModel
@@ -38,8 +38,10 @@ from danswer.db.engine import get_session
 from danswer.db.engine import get_sqlalchemy_async_engine
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.models import User
-from danswer.direct_qa.key_validation import check_openai_api_key_is_valid
-from danswer.direct_qa.llm import get_openai_api_key
+from danswer.direct_qa import check_model_api_key_is_valid
+from danswer.direct_qa import get_default_backend_qa_model
+from danswer.direct_qa.open_ai import get_openai_api_key
+from danswer.direct_qa.open_ai import OpenAIQAModel
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey
@@ -102,7 +104,7 @@ def check_google_app_credentials_exist(
 ) -> dict[str, str]:
     try:
         return {"client_id": get_google_app_cred().web.client_id}
-    except ConfigNotFoundError as e:
+    except ConfigNotFoundError:
         raise HTTPException(status_code=404, detail="Google App Credentials not found")


@@ -295,19 +297,13 @@ def validate_existing_openai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
     # OpenAI key is only used for generative QA, so no need to validate this
-    # if it's turned off
-    if DISABLE_GENERATIVE_AI:
+    # if it's turned off or if a non-OpenAI model is being used
+    if DISABLE_GENERATIVE_AI or not isinstance(
+        get_default_backend_qa_model(), OpenAIQAModel
+    ):
         return

-    # always check if key exists
-    try:
-        openai_api_key = get_openai_api_key()
-    except ConfigNotFoundError:
-        raise HTTPException(status_code=404, detail="Key not found")
-    except ValueError as e:
-        raise HTTPException(status_code=400, detail=str(e))
-
-    # don't call OpenAI every single time, only validate every so often
+    # Only validate every so often
     check_key_time = "openai_api_key_last_check_time"
     kv_store = get_dynamic_config_store()
     curr_time = datetime.now()
@@ -320,10 +316,17 @@ def validate_existing_openai_api_key(
         # First time checking the key, nothing unusual
         pass

+    try:
+        openai_api_key = get_openai_api_key()
+    except ConfigNotFoundError:
+        raise HTTPException(status_code=404, detail="Key not found")
+    except ValueError as e:
+        raise HTTPException(status_code=400, detail=str(e))
+
     get_dynamic_config_store().store(check_key_time, curr_time.timestamp())

     try:
-        is_valid = check_openai_api_key_is_valid(openai_api_key)
+        is_valid = check_model_api_key_is_valid(openai_api_key)
     except ValueError:
         # this is the case where they aren't using an OpenAI-based model
         is_valid = True
@@ -356,7 +359,7 @@ def store_openai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
     try:
-        is_valid = check_openai_api_key_is_valid(request.api_key)
+        is_valid = check_model_api_key_is_valid(request.api_key)
         if not is_valid:
             raise HTTPException(400, "Invalid API key provided")
         get_dynamic_config_store().store(OPENAI_API_KEY_STORAGE_KEY, request.api_key)
@@ -102,8 +102,8 @@ class SearchResponse(BaseModel):


 class QAResponse(SearchResponse):
-    answer: str | None
-    quotes: dict[str, dict[str, str | int | None]] | None
+    answer: str | None  # DanswerAnswer
+    quotes: dict[str, dict[str, str | None]] | None  # restructured DanswerQuote
     predicted_flow: QueryFlow
     predicted_search: SearchType
     error_msg: str | None = None
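The narrower quotes type mirrors structure_quotes_for_response in qa_utils.py: every value in the per-quote dict is now a string or None. A hypothetical payload for illustration only (the literal key names come from danswer.configs.constants and are assumed here, not shown in this diff):

quotes = {
    "They are very cheap at only a dollar each.": {
        "document_id": "doc-123",  # DOCUMENT_ID
        "link": "https://example.com/menu",  # SOURCE_LINK
        "source_type": "web",  # SOURCE_TYPE
        "semantic_identifier": "Hot dog pricing",  # SEMANTIC_IDENTIFIER
        "blurb": "Prices at the stand",  # BLURB
    }
}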
@@ -1,10 +1,10 @@
+import json
 from collections.abc import Generator

 from danswer.auth.users import current_user
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
-from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.qdrant.store import QdrantIndex
 from danswer.datastores.typesense.store import TypesenseIndex
 from danswer.db.models import User
@@ -12,7 +12,6 @@ from danswer.direct_qa import get_default_backend_qa_model
 from danswer.direct_qa.answer_question import answer_question
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
-from danswer.direct_qa.llm import get_json_line
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.keyword_search import retrieve_keyword_documents
@@ -35,6 +34,10 @@ logger = setup_logger()
 router = APIRouter()


+def get_json_line(json_dict: dict) -> str:
+    return json.dumps(json_dict) + "\n"
+
+
 @router.get("/search-intent")
 def get_search_type(
     question: QuestionRequest = Depends(), _: User = Depends(current_user)
@@ -162,7 +165,7 @@ def stream_direct_qa(
         return

     try:
-        qa_model = get_default_backend_qa_model(timeout=QA_TIMEOUT)
+        qa_model = get_default_backend_qa_model()
     except (UnknownModelError, OpenAIKeyMissing) as e:
         logger.exception("Unable to get QA model")
         yield get_json_line({"error": str(e)})
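Since get_json_line emits one JSON object per line, the streaming QA endpoint can be consumed as newline-delimited JSON. A rough client sketch, assuming a hypothetical URL and request payload (the actual route and schema are not shown in this hunk; the packet keys match process_model_tokens above):

import json

import requests

with requests.post(
    "http://localhost:8080/stream-direct-qa",  # hypothetical endpoint
    json={"query": "What is Danswer?"},  # hypothetical request body
    stream=True,
) as resp:
    for line in resp.iter_lines():
        if not line:
            continue
        packet = json.loads(line)
        if "answer_data" in packet:
            print(packet["answer_data"], end="")  # incremental answer tokens
        elif packet.get("answer_finished"):
            print()  # answer done; the quotes dict arrives as the final line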
@@ -10,6 +10,7 @@ filelock==3.12.0
 google-api-python-client==2.86.0
 google-auth-httplib2==0.1.0
 google-auth-oauthlib==1.0.0
+gpt4all==1.0.5
 httpcore==0.16.3
 httpx==0.23.3
 httpx-oauth==0.11.2
@@ -2,8 +2,8 @@ import textwrap
 import unittest

 from danswer.chunking.models import InferenceChunk
-from danswer.direct_qa.llm import match_quotes_to_docs
-from danswer.direct_qa.llm import separate_answer_quotes
+from danswer.direct_qa.qa_utils import match_quotes_to_docs
+from danswer.direct_qa.qa_utils import separate_answer_quotes


 class TestQAPostprocessing(unittest.TestCase):
@@ -17,6 +17,8 @@ services:
     ports:
       - "8080:8080"
     environment:
+      - INTERNAL_MODEL_VERSION=${INTERNAL_MODEL_VERSION:-openai-chat-completion}
+      - GEN_AI_MODEL_VERSION=${GEN_AI_MODEL_VERSION:-gpt-3.5-turbo}
      - POSTGRES_HOST=relational_db
      - QDRANT_HOST=vector_db
      - TYPESENSE_HOST=search_engine
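These pass-throughs let a deployment switch the QA model without editing the compose file, for example by exporting the variables before starting the stack. The values below are just the documented OpenAI defaults; substitute the identifiers for your chosen backend, and adjust the compose invocation to the one used in this repo:

INTERNAL_MODEL_VERSION=openai-chat-completion GEN_AI_MODEL_VERSION=gpt-3.5-turbo docker compose up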
@@ -9,7 +9,7 @@ OPENAI_API_KEY=
 # Choose between "openai-chat-completion" and "openai-completion"
 INTERNAL_MODEL_VERSION=openai-chat-completion
 # Use a valid model for the choice above, consult https://platform.openai.com/docs/models/model-endpoint-compatibility
-OPENAI_MODEL_VERSION=gpt-4
+GEN_AI_MODEL_VERSION=gpt-4

 # Could be something like danswer.companyname.com. Requires additional setup if not localhost
 WEB_DOMAIN=http://localhost:3000
Block a user