More Cleanup and Deduplication (#675)
(mirror of https://github.com/danswer-ai/danswer.git)
@@ -1,4 +1,5 @@
 from collections.abc import Callable
+from collections.abc import Iterator
 
 from sqlalchemy.orm import Session
 
@@ -6,7 +7,10 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.constants import IGNORE_FOR_QA
+from danswer.db.feedback import create_query_event
 from danswer.db.models import User
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
+from danswer.direct_qa.interfaces import StreamingError
 from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.direct_qa.models import LLMMetricsContainer
 from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -21,8 +25,11 @@ from danswer.secondary_llm_flows.answer_validation import get_answer_validity
 from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
+from danswer.server.models import RerankedRetrievalDocs
+from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_function_time
+from danswer.utils.timing import log_generator_function_time
 
 logger = setup_logger()
 
@@ -169,3 +176,114 @@ def answer_qa_query(
         favor_recent=favor_recent,
         error_msg=error_msg,
     )
+
+
+@log_generator_function_time()
+def answer_qa_query_stream(
+    question: QuestionRequest,
+    user: User | None,
+    db_session: Session,
+    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
+) -> Iterator[str]:
+    logger.debug(
+        f"Received QA query ({question.search_type.value} search): {question.query}"
+    )
+    logger.debug(f"Query filters: {question.filters}")
+
+    answer_so_far: str = ""
+    query = question.query
+    offset_count = question.offset if question.offset is not None else 0
+
+    time_cutoff, favor_recent = extract_question_time_filters(question)
+    question.filters.time_cutoff = time_cutoff
+    question.favor_recent = favor_recent
+
+    ranked_chunks, unranked_chunks, query_event_id = danswer_search(
+        question=question,
+        user=user,
+        db_session=db_session,
+        document_index=get_default_document_index(),
+    )
+
+    # TODO retire this
+    predicted_search, predicted_flow = query_intent(query)
+
+    top_docs = chunks_to_search_docs(ranked_chunks)
+    unranked_top_docs = chunks_to_search_docs(unranked_chunks)
+
+    initial_response = RerankedRetrievalDocs(
+        top_documents=top_docs,
+        unranked_top_documents=unranked_top_docs,
+        # if generative AI is disabled, set flow as search so frontend
+        # doesn't ask the user if they want to run QA over more documents
+        predicted_flow=QueryFlow.SEARCH
+        if disable_generative_answer
+        else predicted_flow,
+        predicted_search=predicted_search,
+        time_cutoff=time_cutoff,
+        favor_recent=favor_recent,
+    ).dict()
+
+    logger.debug(f"Sending Initial Retrieval Results: {initial_response}")
+    yield get_json_line(initial_response)
+
+    if not ranked_chunks:
+        logger.debug("No Documents Found")
+        return
+
+    if disable_generative_answer:
+        logger.debug("Skipping QA because generative AI is disabled")
+        return
+
+    try:
+        qa_model = get_default_qa_model()
+    except Exception as e:
+        logger.exception("Unable to get QA model")
+        error = StreamingError(error=str(e))
+        yield get_json_line(error.dict())
+        return
+
+    # remove chunks marked as not applicable for QA (e.g. Google Drive file
+    # types which can't be parsed). These chunks are useful to show in the
+    # search results, but not for QA.
+    filtered_ranked_chunks = [
+        chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
+    ]
+
+    # get all chunks that fit into the token limit
+    usable_chunks = get_usable_chunks(
+        chunks=filtered_ranked_chunks,
+        token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
+        offset=offset_count,
+    )
+    logger.debug(
+        f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
+    )
+
+    try:
+        for response_packet in qa_model.answer_question_stream(query, usable_chunks):
+            if response_packet is None:
+                continue
+            if (
+                isinstance(response_packet, DanswerAnswerPiece)
+                and response_packet.answer_piece
+            ):
+                answer_so_far = answer_so_far + response_packet.answer_piece
+            logger.debug(f"Sending packet: {response_packet}")
+            yield get_json_line(response_packet.dict())
+    except Exception:
+        # exception is logged in the answer_question method, no need to re-log
+        logger.exception("Failed to run QA")
+        error = StreamingError(error="The LLM failed to produce a usable response")
+        yield get_json_line(error.dict())
+
+    query_event_id = create_query_event(
+        query=query,
+        search_type=question.search_type,
+        llm_answer=answer_so_far,
+        retrieved_document_ids=[doc.document_id for doc in top_docs],
+        user_id=None if user is None else user.id,
+        db_session=db_session,
+    )
+
+    yield get_json_line({"query_event_id": query_event_id})
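
The new answer_qa_query_stream generator emits one JSON object per line: first the reranked retrieval results, then incremental DanswerAnswerPiece (or StreamingError) packets, and finally the query event id. A minimal consumer sketch in Python; the endpoint URL and payload shape are illustrative assumptions (the real route is the stream_direct_qa handler further down), and it assumes get_json_line produces newline-delimited JSON:

import json

import requests  # any HTTP client that supports streaming works

URL = "http://localhost:8080/stream-direct-qa"  # hypothetical route
payload = {"query": "What connectors does Danswer support?"}

with requests.post(URL, json=payload, stream=True) as resp:
    for raw_line in resp.iter_lines():
        if not raw_line:
            continue
        packet = json.loads(raw_line)
        if "top_documents" in packet:        # initial retrieval results
            continue
        if packet.get("answer_piece"):       # streamed answer fragment
            print(packet["answer_piece"], end="", flush=True)
        elif "query_event_id" in packet:     # final bookkeeping packet
            print(f"\nquery_event_id={packet['query_event_id']}")
        elif "error" in packet:              # StreamingError
            print(f"\nerror: {packet['error']}")
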
@@ -1,5 +1,3 @@
-from openai.error import AuthenticationError
-
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.direct_qa.interfaces import QAModel
 from danswer.direct_qa.qa_block import QABlock
@@ -12,25 +10,6 @@ from danswer.utils.logger import setup_logger
 logger = setup_logger()
 
 
-def check_model_api_key_is_valid(model_api_key: str) -> bool:
-    if not model_api_key:
-        return False
-
-    llm = get_default_llm(api_key=model_api_key, timeout=10)
-
-    # try for up to 2 timeouts (e.g. 10 seconds in total)
-    for _ in range(2):
-        try:
-            llm.invoke("Do not respond")
-            return True
-        except AuthenticationError:
-            return False
-        except Exception as e:
-            logger.warning(f"GenAI API key failed for the following reason: {e}")
-
-    return False
-
-
 # TODO introduce the prompt choice parameter
 def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
     return (
@@ -4,7 +4,6 @@ import re
 from collections.abc import Generator
 from collections.abc import Iterator
 from json.decoder import JSONDecodeError
-from typing import cast
 from typing import Optional
 from typing import Tuple
 
@@ -12,8 +11,6 @@ import regex
 
 from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
 from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
-from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
-from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.direct_qa.interfaces import DanswerAnswer
 from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.interfaces import DanswerQuote
@@ -21,8 +18,6 @@ from danswer.direct_qa.interfaces import DanswerQuotes
 from danswer.direct_qa.qa_prompts import ANSWER_PAT
 from danswer.direct_qa.qa_prompts import QUOTE_PAT
 from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
-from danswer.dynamic_configs import get_dynamic_config_store
-from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.indexing.models import InferenceChunk
 from danswer.llm.utils import check_number_of_tokens
 from danswer.utils.logger import setup_logger
@@ -33,17 +28,6 @@ from danswer.utils.text_processing import shared_precompare_cleanup
 logger = setup_logger()
 
 
-def get_gen_ai_api_key() -> str | None:
-    # first check if the key has been provided by the UI
-    try:
-        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
-    except ConfigNotFoundError:
-        pass
-
-    # if not provided by the UI, fallback to the env variable
-    return GEN_AI_API_KEY
-
-
 def extract_answer_quotes_freeform(
     answer_raw: str,
 ) -> Tuple[Optional[str], Optional[list[str]]]:
@@ -1,22 +1,77 @@
+import abc
+from collections.abc import Iterator
+
 import litellm  # type:ignore
 from langchain.chat_models import ChatLiteLLM
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema.language_model import LanguageModelInput
+
+from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
 from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
 from danswer.configs.model_configs import GEN_AI_API_VERSION
 from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
 from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
 from danswer.configs.model_configs import GEN_AI_TEMPERATURE
-from danswer.llm.interfaces import LangChainChatLLM
+from danswer.llm.interfaces import LLM
+from danswer.llm.utils import message_generator_to_string_generator
 from danswer.llm.utils import should_be_verbose
+from danswer.utils.logger import setup_logger
+
+
+logger = setup_logger()
+
 # If a user configures a different model and it doesn't support all the same
 # parameters like frequency and presence, just ignore them
 litellm.drop_params = True
 litellm.telemetry = False
 
 
+class LangChainChatLLM(LLM, abc.ABC):
+    @property
+    @abc.abstractmethod
+    def llm(self) -> BaseChatModel:
+        raise NotImplementedError
+
+    def _log_model_config(self) -> None:
+        logger.debug(
+            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
+        )
+
+    @staticmethod
+    def _log_prompt(prompt: LanguageModelInput) -> None:
+        if isinstance(prompt, list):
+            for ind, msg in enumerate(prompt):
+                logger.debug(f"Message {ind}:\n{msg.content}")
+        if isinstance(prompt, str):
+            logger.debug(f"Prompt:\n{prompt}")
+
+    def invoke(self, prompt: LanguageModelInput) -> str:
+        self._log_model_config()
+        if LOG_ALL_MODEL_INTERACTIONS:
+            self._log_prompt(prompt)
+
+        model_raw = self.llm.invoke(prompt).content
+        if LOG_ALL_MODEL_INTERACTIONS:
+            logger.debug(f"Raw Model Output:\n{model_raw}")
+
+        return model_raw
+
+    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
+        self._log_model_config()
+        if LOG_ALL_MODEL_INTERACTIONS:
+            self._log_prompt(prompt)
+
+        output_tokens = []
+        for token in message_generator_to_string_generator(self.llm.stream(prompt)):
+            output_tokens.append(token)
+            yield token
+
+        full_output = "".join(output_tokens)
+        if LOG_ALL_MODEL_INTERACTIONS:
+            logger.debug(f"Raw Model Output:\n{full_output}")
+
+
 def _get_model_str(
     model_provider: str | None,
     model_version: str | None,
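
LangChainChatLLM now lives next to the concrete chat models instead of in danswer.llm.interfaces. A concrete subclass only has to expose a BaseChatModel via the abstract llm property; invoke and stream then handle debug logging and token assembly. A minimal sketch, assuming ChatLiteLLM's standard model keyword; the class name MinimalLiteLLM is illustrative, not the repo's actual DefaultMultiLLM:

from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel


class MinimalLiteLLM(LangChainChatLLM):
    """Illustrative subclass: routes prompts through litellm via LangChain."""

    def __init__(self, model: str = "gpt-3.5-turbo") -> None:
        self._llm = ChatLiteLLM(model=model)

    @property
    def llm(self) -> BaseChatModel:
        return self._llm


# invoke() returns the full completion; stream() yields tokens as they arrive:
# answer = MinimalLiteLLM().invoke("Say hello")
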
@@ -1,10 +1,10 @@
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
-from danswer.direct_qa.qa_utils import get_gen_ai_api_key
+from danswer.llm.chat_llm import DefaultMultiLLM
 from danswer.llm.custom_llm import CustomModelServer
 from danswer.llm.gpt_4_all import DanswerGPT4All
 from danswer.llm.interfaces import LLM
-from danswer.llm.multi_llm import DefaultMultiLLM
+from danswer.llm.utils import get_gen_ai_api_key
 
 
 def get_default_llm(
@@ -1,11 +1,8 @@
 import abc
 from collections.abc import Iterator
 
-from langchain.chat_models.base import BaseChatModel
 from langchain.schema.language_model import LanguageModelInput
 
-from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
-from danswer.llm.utils import message_generator_to_string_generator
 from danswer.utils.logger import setup_logger
 
 
@@ -28,48 +25,3 @@ class LLM(abc.ABC):
     @abc.abstractmethod
     def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
         raise NotImplementedError
-
-
-class LangChainChatLLM(LLM, abc.ABC):
-    @property
-    @abc.abstractmethod
-    def llm(self) -> BaseChatModel:
-        raise NotImplementedError
-
-    def _log_model_config(self) -> None:
-        logger.debug(
-            f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
-        )
-
-    @staticmethod
-    def _log_prompt(prompt: LanguageModelInput) -> None:
-        if isinstance(prompt, list):
-            for ind, msg in enumerate(prompt):
-                logger.debug(f"Message {ind}:\n{msg.content}")
-        if isinstance(prompt, str):
-            logger.debug(f"Prompt:\n{prompt}")
-
-    def invoke(self, prompt: LanguageModelInput) -> str:
-        self._log_model_config()
-        if LOG_ALL_MODEL_INTERACTIONS:
-            self._log_prompt(prompt)
-
-        model_raw = self.llm.invoke(prompt).content
-        if LOG_ALL_MODEL_INTERACTIONS:
-            logger.debug(f"Raw Model Output:\n{model_raw}")
-
-        return model_raw
-
-    def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
-        self._log_model_config()
-        if LOG_ALL_MODEL_INTERACTIONS:
-            self._log_prompt(prompt)
-
-        output_tokens = []
-        for token in message_generator_to_string_generator(self.llm.stream(prompt)):
-            output_tokens.append(token)
-            yield token
-
-        full_output = "".join(output_tokens)
-        if LOG_ALL_MODEL_INTERACTIONS:
-            logger.debug(f"Raw Model Output:\n{full_output}")
@@ -1,6 +1,7 @@
 from collections.abc import Callable
 from collections.abc import Iterator
 from typing import Any
+from typing import cast
 
 import tiktoken
 from langchain.prompts.base import StringPromptValue
@@ -14,9 +15,16 @@ from langchain.schema.messages import HumanMessage
 from langchain.schema.messages import SystemMessage
 
 from danswer.configs.app_configs import LOG_LEVEL
+from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
 from danswer.configs.constants import MessageType
+from danswer.configs.model_configs import GEN_AI_API_KEY
 from danswer.db.models import ChatMessage
+from danswer.dynamic_configs import get_dynamic_config_store
+from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.interfaces import LLM
+from danswer.utils.logger import setup_logger
+
+logger = setup_logger()
 
 
 _LLM_TOKENIZER: Callable[[str], Any] | None = None
 
@@ -107,7 +115,7 @@ def should_be_verbose() -> bool:
 def check_number_of_tokens(
     text: str, encode_fn: Callable[[str], list] | None = None
 ) -> int:
-    """Get's the number of tokens in the provided text, using the provided encoding
+    """Gets the number of tokens in the provided text, using the provided encoding
     function. If none is provided, default to the tiktoken encoder used by GPT-3.5
     and GPT-4.
     """
@@ -116,3 +124,26 @@ def check_number_of_tokens(
         encode_fn = tiktoken.get_encoding("cl100k_base").encode
 
     return len(encode_fn(text))
+
+
+def get_gen_ai_api_key() -> str | None:
+    # first check if the key has been provided by the UI
+    try:
+        return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
+    except ConfigNotFoundError:
+        pass
+
+    # if not provided by the UI, fallback to the env variable
+    return GEN_AI_API_KEY
+
+
+def test_llm(llm: LLM) -> bool:
+    # try for up to 2 timeouts (e.g. 10 seconds in total)
+    for _ in range(2):
+        try:
+            llm.invoke("Do not respond")
+            return True
+        except Exception as e:
+            logger.warning(f"GenAI API key failed for the following reason: {e}")
+
+    return False
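
get_gen_ai_api_key (moved here from danswer.direct_qa.qa_utils) and the new test_llm replace the old check_model_api_key_is_valid; validation is now provider-agnostic, treating any exception as a failure instead of special-casing OpenAI's AuthenticationError. A hedged sketch of how the pieces compose, mirroring the admin endpoints updated below (the timeout value matches the one used there):

from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm

# UI-provided key if present, otherwise the GEN_AI_API_KEY env setting
api_key = get_gen_ai_api_key()
llm = get_default_llm(api_key=api_key, timeout=10)
if not test_llm(llm):
    raise RuntimeError("GenAI API key or model configuration appears invalid")
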
@@ -23,12 +23,13 @@ from danswer.db.feedback import fetch_docs_ranked_by_boost
 from danswer.db.feedback import update_document_boost
 from danswer.db.feedback import update_document_hidden
 from danswer.db.models import User
-from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
 from danswer.direct_qa.llm_utils import get_default_qa_model
-from danswer.direct_qa.qa_utils import get_gen_ai_api_key
 from danswer.document_index.factory import get_default_document_index
 from danswer.dynamic_configs import get_dynamic_config_store
 from danswer.dynamic_configs.interface import ConfigNotFoundError
+from danswer.llm.factory import get_default_llm
+from danswer.llm.utils import get_gen_ai_api_key
+from danswer.llm.utils import test_llm
 from danswer.server.models import ApiKey
 from danswer.server.models import BoostDoc
 from danswer.server.models import BoostUpdateRequest
@@ -132,7 +133,8 @@ def validate_existing_genai_api_key(
         raise HTTPException(status_code=404, detail="Key not found")
 
     try:
-        is_valid = check_model_api_key_is_valid(genai_api_key)
+        llm = get_default_llm(api_key=genai_api_key, timeout=10)
+        is_valid = test_llm(llm)
     except ValueError:
         # this is the case where they aren't using an OpenAI-based model
         is_valid = True
@@ -168,9 +170,15 @@ def store_genai_api_key(
     _: User = Depends(current_admin_user),
 ) -> None:
     try:
-        is_valid = check_model_api_key_is_valid(request.api_key)
+        if not request.api_key:
+            raise HTTPException(400, "No API key provided")
+
+        llm = get_default_llm(api_key=request.api_key, timeout=10)
+        is_valid = test_llm(llm)
+
         if not is_valid:
             raise HTTPException(400, "Invalid API key provided")
+
         get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
     except RuntimeError as e:
         raise HTTPException(400, str(e))
@@ -1,5 +1,3 @@
-from collections.abc import Generator
-
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi import HTTPException
@@ -9,30 +7,19 @@ from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_admin_user
 from danswer.auth.users import current_user
-from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
-from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
-from danswer.configs.constants import IGNORE_FOR_QA
 from danswer.db.engine import get_session
 from danswer.db.feedback import create_doc_retrieval_feedback
-from danswer.db.feedback import create_query_event
 from danswer.db.feedback import update_query_event_feedback
 from danswer.db.models import User
 from danswer.direct_qa.answer_question import answer_qa_query
-from danswer.direct_qa.interfaces import DanswerAnswerPiece
-from danswer.direct_qa.interfaces import StreamingError
-from danswer.direct_qa.llm_utils import get_default_qa_model
-from danswer.direct_qa.qa_utils import get_usable_chunks
+from danswer.direct_qa.answer_question import answer_qa_query_stream
 from danswer.document_index.factory import get_default_document_index
 from danswer.document_index.vespa.index import VespaIndex
 from danswer.search.access_filters import build_access_filters_for_user
-from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
 from danswer.search.models import IndexFilters
-from danswer.search.models import QueryFlow
-from danswer.search.models import SearchQuery
 from danswer.search.search_runner import chunks_to_search_docs
 from danswer.search.search_runner import danswer_search
-from danswer.search.search_runner import search_chunks
 from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
 from danswer.secondary_llm_flows.query_validation import get_query_answerability
 from danswer.secondary_llm_flows.query_validation import stream_query_answerability
@@ -41,13 +28,10 @@ from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
 from danswer.server.models import QueryValidationResponse
 from danswer.server.models import QuestionRequest
-from danswer.server.models import RerankedRetrievalDocs
 from danswer.server.models import SearchDoc
 from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
-from danswer.server.utils import get_json_line
 from danswer.utils.logger import setup_logger
-from danswer.utils.timing import log_generator_function_time
 
 logger = setup_logger()
 
@@ -190,144 +174,10 @@ def stream_direct_qa(
     user: User | None = Depends(current_user),
     db_session: Session = Depends(get_session),
 ) -> StreamingResponse:
-    send_packet_debug_msg = "Sending Packet: {}"
-    top_documents_key = "top_documents"
-    unranked_top_docs_key = "unranked_top_documents"
-    predicted_flow_key = "predicted_flow"
-    predicted_search_key = "predicted_search"
-    query_event_id_key = "query_event_id"
-
-    logger.debug(
-        f"Received QA query ({question.search_type.value} search): {question.query}"
-    )
-    logger.debug(f"Query filters: {question.filters}")
-
-    @log_generator_function_time()
-    def stream_qa_portions(
-        disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
-    ) -> Generator[str, None, None]:
-        answer_so_far: str = ""
-        query = question.query
-        offset_count = question.offset if question.offset is not None else 0
-
-        time_cutoff, favor_recent = extract_question_time_filters(question)
-        question.filters.time_cutoff = time_cutoff  # not used but just in case
-        filters = question.filters
-
-        user_acl_filters = build_access_filters_for_user(user, db_session)
-        final_filters = IndexFilters(
-            source_type=filters.source_type,
-            document_set=filters.document_set,
-            time_cutoff=time_cutoff,
-            access_control_list=user_acl_filters,
-        )
-
-        search_query = SearchQuery(
-            query=query,
-            search_type=question.search_type,
-            filters=final_filters,
-            favor_recent=favor_recent,
-        )
-
-        ranked_chunks, unranked_chunks = search_chunks(
-            query=search_query, document_index=get_default_document_index()
-        )
-
-        # TODO retire this
-        predicted_search, predicted_flow = query_intent(query)
-
-        if not ranked_chunks:
-            logger.debug("No Documents Found")
-            empty_docs_result = {
-                top_documents_key: None,
-                unranked_top_docs_key: None,
-                predicted_flow_key: predicted_flow,
-                predicted_search_key: predicted_search,
-            }
-            logger.debug(send_packet_debug_msg.format(empty_docs_result))
-            yield get_json_line(empty_docs_result)
-            return
-
-        top_docs = chunks_to_search_docs(ranked_chunks)
-        unranked_top_docs = chunks_to_search_docs(unranked_chunks)
-        initial_response = RerankedRetrievalDocs(
-            top_documents=top_docs,
-            unranked_top_documents=unranked_top_docs,
-            # if generative AI is disabled, set flow as search so frontend
-            # doesn't ask the user if they want to run QA over more documents
-            predicted_flow=QueryFlow.SEARCH
-            if disable_generative_answer
-            else predicted_flow,
-            predicted_search=predicted_search,
-            time_cutoff=time_cutoff,
-            favor_recent=favor_recent,
-        ).dict()
-
-        logger.debug(send_packet_debug_msg.format(initial_response))
-        yield get_json_line(initial_response)
-
-        if disable_generative_answer:
-            logger.debug("Skipping QA because generative AI is disabled")
-            return
-
-        try:
-            qa_model = get_default_qa_model()
-        except Exception as e:
-            logger.exception("Unable to get QA model")
-            error = StreamingError(error=str(e))
-            yield get_json_line(error.dict())
-            return
-
-        # remove chunks marked as not applicable for QA (e.g. Google Drive file
-        # types which can't be parsed). These chunks are useful to show in the
-        # search results, but not for QA.
-        filtered_ranked_chunks = [
-            chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
-        ]
-
-        # get all chunks that fit into the token limit
-        usable_chunks = get_usable_chunks(
-            chunks=filtered_ranked_chunks,
-            token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
-            offset=offset_count,
-        )
-        logger.debug(
-            f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
-        )
-
-        try:
-            for response_packet in qa_model.answer_question_stream(
-                query, usable_chunks
-            ):
-                if response_packet is None:
-                    continue
-                if (
-                    isinstance(response_packet, DanswerAnswerPiece)
-                    and response_packet.answer_piece
-                ):
-                    answer_so_far = answer_so_far + response_packet.answer_piece
-                logger.debug(f"Sending packet: {response_packet}")
-                yield get_json_line(response_packet.dict())
-        except Exception:
-            # exception is logged in the answer_question method, no need to re-log
-            logger.exception("Failed to run QA")
-            yield get_json_line(
-                {"error": "The LLM failed to produce a useable response"}
-            )
-
-        query_event_id = create_query_event(
-            query=query,
-            search_type=question.search_type,
-            llm_answer=answer_so_far,
-            retrieved_document_ids=[doc.document_id for doc in top_docs],
-            user_id=None if user is None else user.id,
-            db_session=db_session,
-        )
-
-        yield get_json_line({query_event_id_key: query_event_id})
-        return
-
-    return StreamingResponse(stream_qa_portions(), media_type="application/json")
+    packets = answer_qa_query_stream(
+        question=question, user=user, db_session=db_session
+    )
+    return StreamingResponse(packets, media_type="application/json")
 
 
 @router.post("/query-feedback")