More Cleanup and Deduplication (#675)

Yuhong Sun 2023-11-01 16:03:48 -07:00 committed by GitHub
parent 9cd0c197e7
commit 73b653d324
9 changed files with 224 additions and 247 deletions

View File

@@ -1,4 +1,5 @@
from collections.abc import Callable
from collections.abc import Iterator
from sqlalchemy.orm import Session
@@ -6,7 +7,10 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.feedback import create_query_event
from danswer.db.models import User
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.models import LLMMetricsContainer
from danswer.direct_qa.qa_utils import get_usable_chunks
@@ -21,8 +25,11 @@ from danswer.secondary_llm_flows.answer_validation import get_answer_validity
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_function_time
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
@@ -169,3 +176,114 @@ def answer_qa_query(
favor_recent=favor_recent,
error_msg=error_msg,
)
@log_generator_function_time()
def answer_qa_query_stream(
question: QuestionRequest,
user: User | None,
db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Iterator[str]:
logger.debug(
f"Received QA query ({question.search_type.value} search): {question.query}"
)
logger.debug(f"Query filters: {question.filters}")
answer_so_far: str = ""
query = question.query
offset_count = question.offset if question.offset is not None else 0
time_cutoff, favor_recent = extract_question_time_filters(question)
question.filters.time_cutoff = time_cutoff
question.favor_recent = favor_recent
ranked_chunks, unranked_chunks, query_event_id = danswer_search(
question=question,
user=user,
db_session=db_session,
document_index=get_default_document_index(),
)
# TODO retire this
predicted_search, predicted_flow = query_intent(query)
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
initial_response = RerankedRetrievalDocs(
top_documents=top_docs,
unranked_top_documents=unranked_top_docs,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH
if disable_generative_answer
else predicted_flow,
predicted_search=predicted_search,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
).dict()
logger.debug(f"Sending Initial Retrival Results: {initial_response}")
yield get_json_line(initial_response)
if not ranked_chunks:
logger.debug("No Documents Found")
return
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return
try:
qa_model = get_default_qa_model()
except Exception as e:
logger.exception("Unable to get QA model")
error = StreamingError(error=str(e))
yield get_json_line(error.dict())
return
# remove chunks marked as not applicable for QA (e.g. Google Drive file
# types which can't be parsed). These chunks are useful to show in the
# search results, but not for QA.
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
)
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
)
try:
for response_packet in qa_model.answer_question_stream(query, usable_chunks):
if response_packet is None:
continue
if (
isinstance(response_packet, DanswerAnswerPiece)
and response_packet.answer_piece
):
answer_so_far = answer_so_far + response_packet.answer_piece
logger.debug(f"Sending packet: {response_packet}")
yield get_json_line(response_packet.dict())
except Exception:
# exception is logged in the answer_question method, no need to re-log
logger.exception("Failed to run QA")
error = StreamingError(error="The LLM failed to produce a useable response")
yield get_json_line(error.dict())
query_event_id = create_query_event(
query=query,
search_type=question.search_type,
llm_answer=answer_so_far,
retrieved_document_ids=[doc.document_id for doc in top_docs],
user_id=None if user is None else user.id,
db_session=db_session,
)
yield get_json_line({"query_event_id": query_event_id})

View File

@@ -1,5 +1,3 @@
from openai.error import AuthenticationError
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.direct_qa.interfaces import QAModel
from danswer.direct_qa.qa_block import QABlock
@@ -12,25 +10,6 @@ from danswer.utils.logger import setup_logger
logger = setup_logger()
def check_model_api_key_is_valid(model_api_key: str) -> bool:
if not model_api_key:
return False
llm = get_default_llm(api_key=model_api_key, timeout=10)
# try for up to 2 timeouts (e.g. 10 seconds in total)
for _ in range(2):
try:
llm.invoke("Do not respond")
return True
except AuthenticationError:
return False
except Exception as e:
logger.warning(f"GenAI API key failed for the following reason: {e}")
return False
# TODO introduce the prompt choice parameter
def get_default_qa_handler(real_time_flow: bool = True) -> QAHandler:
return (

View File

@@ -4,7 +4,6 @@ import re
from collections.abc import Generator
from collections.abc import Iterator
from json.decoder import JSONDecodeError
from typing import cast
from typing import Optional
from typing import Tuple
@@ -12,8 +11,6 @@ import regex
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.app_configs import QUOTE_ALLOWED_ERROR_PERCENT
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.direct_qa.interfaces import DanswerAnswer
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import DanswerQuote
@@ -21,8 +18,6 @@ from danswer.direct_qa.interfaces import DanswerQuotes
from danswer.direct_qa.qa_prompts import ANSWER_PAT
from danswer.direct_qa.qa_prompts import QUOTE_PAT
from danswer.direct_qa.qa_prompts import UNCERTAINTY_PAT
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.indexing.models import InferenceChunk
from danswer.llm.utils import check_number_of_tokens
from danswer.utils.logger import setup_logger
@@ -33,17 +28,6 @@ from danswer.utils.text_processing import shared_precompare_cleanup
logger = setup_logger()
def get_gen_ai_api_key() -> str | None:
# first check if the key has been provided by the UI
try:
return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
except ConfigNotFoundError:
pass
# if not provided by the UI, fallback to the env variable
return GEN_AI_API_KEY
def extract_answer_quotes_freeform(
answer_raw: str,
) -> Tuple[Optional[str], Optional[list[str]]]:

View File

@@ -1,22 +1,77 @@
import abc
from collections.abc import Iterator
import litellm # type:ignore
from langchain.chat_models import ChatLiteLLM
from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.configs.model_configs import GEN_AI_API_ENDPOINT
from danswer.configs.model_configs import GEN_AI_API_VERSION
from danswer.configs.model_configs import GEN_AI_MAX_OUTPUT_TOKENS
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.configs.model_configs import GEN_AI_MODEL_VERSION
from danswer.configs.model_configs import GEN_AI_TEMPERATURE
from danswer.llm.interfaces import LangChainChatLLM
from danswer.llm.interfaces import LLM
from danswer.llm.utils import message_generator_to_string_generator
from danswer.llm.utils import should_be_verbose
from danswer.utils.logger import setup_logger
logger = setup_logger()
# If a user configures a different model and it doesn't support all the same
# parameters like frequency and presence, just ignore them
litellm.drop_params = True
litellm.telemetry = False
class LangChainChatLLM(LLM, abc.ABC):
@property
@abc.abstractmethod
def llm(self) -> BaseChatModel:
raise NotImplementedError
def _log_model_config(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
@staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
for ind, msg in enumerate(prompt):
logger.debug(f"Message {ind}:\n{msg.content}")
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")
def invoke(self, prompt: LanguageModelInput) -> str:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt)
model_raw = self.llm.invoke(prompt).content
if LOG_ALL_MODEL_INTERACTIONS:
logger.debug(f"Raw Model Output:\n{model_raw}")
return model_raw
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt)
output_tokens = []
for token in message_generator_to_string_generator(self.llm.stream(prompt)):
output_tokens.append(token)
yield token
full_output = "".join(output_tokens)
if LOG_ALL_MODEL_INTERACTIONS:
logger.debug(f"Raw Model Output:\n{full_output}")
def _get_model_str(
model_provider: str | None,
model_version: str | None,

View File

@@ -1,10 +1,10 @@
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.model_configs import GEN_AI_MODEL_PROVIDER
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.llm.chat_llm import DefaultMultiLLM
from danswer.llm.custom_llm import CustomModelServer
from danswer.llm.gpt_4_all import DanswerGPT4All
from danswer.llm.interfaces import LLM
from danswer.llm.multi_llm import DefaultMultiLLM
from danswer.llm.utils import get_gen_ai_api_key
def get_default_llm(

View File

@@ -1,11 +1,8 @@
import abc
from collections.abc import Iterator
from langchain.chat_models.base import BaseChatModel
from langchain.schema.language_model import LanguageModelInput
from danswer.configs.app_configs import LOG_ALL_MODEL_INTERACTIONS
from danswer.llm.utils import message_generator_to_string_generator
from danswer.utils.logger import setup_logger
@@ -28,48 +25,3 @@ class LLM(abc.ABC):
@abc.abstractmethod
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
raise NotImplementedError
class LangChainChatLLM(LLM, abc.ABC):
@property
@abc.abstractmethod
def llm(self) -> BaseChatModel:
raise NotImplementedError
def _log_model_config(self) -> None:
logger.debug(
f"Model Class: {self.llm.__class__.__name__}, Model Config: {self.llm.__dict__}"
)
@staticmethod
def _log_prompt(prompt: LanguageModelInput) -> None:
if isinstance(prompt, list):
for ind, msg in enumerate(prompt):
logger.debug(f"Message {ind}:\n{msg.content}")
if isinstance(prompt, str):
logger.debug(f"Prompt:\n{prompt}")
def invoke(self, prompt: LanguageModelInput) -> str:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt)
model_raw = self.llm.invoke(prompt).content
if LOG_ALL_MODEL_INTERACTIONS:
logger.debug(f"Raw Model Output:\n{model_raw}")
return model_raw
def stream(self, prompt: LanguageModelInput) -> Iterator[str]:
self._log_model_config()
if LOG_ALL_MODEL_INTERACTIONS:
self._log_prompt(prompt)
output_tokens = []
for token in message_generator_to_string_generator(self.llm.stream(prompt)):
output_tokens.append(token)
yield token
full_output = "".join(output_tokens)
if LOG_ALL_MODEL_INTERACTIONS:
logger.debug(f"Raw Model Output:\n{full_output}")

View File

@@ -1,6 +1,7 @@
from collections.abc import Callable
from collections.abc import Iterator
from typing import Any
from typing import cast
import tiktoken
from langchain.prompts.base import StringPromptValue
@@ -14,9 +15,16 @@ from langchain.schema.messages import HumanMessage
from langchain.schema.messages import SystemMessage
from danswer.configs.app_configs import LOG_LEVEL
from danswer.configs.constants import GEN_AI_API_KEY_STORAGE_KEY
from danswer.configs.constants import MessageType
from danswer.configs.model_configs import GEN_AI_API_KEY
from danswer.db.models import ChatMessage
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.interfaces import LLM
from danswer.utils.logger import setup_logger
logger = setup_logger()
_LLM_TOKENIZER: Callable[[str], Any] | None = None
@@ -107,7 +115,7 @@ def should_be_verbose() -> bool:
def check_number_of_tokens(
text: str, encode_fn: Callable[[str], list] | None = None
) -> int:
"""Get's the number of tokens in the provided text, using the provided encoding
"""Gets the number of tokens in the provided text, using the provided encoding
function. If none is provided, default to the tiktoken encoder used by GPT-3.5
and GPT-4.
"""
@@ -116,3 +124,26 @@ def check_number_of_tokens(
encode_fn = tiktoken.get_encoding("cl100k_base").encode
return len(encode_fn(text))
def get_gen_ai_api_key() -> str | None:
# first check if the key has been provided by the UI
try:
return cast(str, get_dynamic_config_store().load(GEN_AI_API_KEY_STORAGE_KEY))
except ConfigNotFoundError:
pass
# if not provided by the UI, fallback to the env variable
return GEN_AI_API_KEY
def test_llm(llm: LLM) -> bool:
# try for up to 2 timeouts (e.g. 10 seconds in total)
for _ in range(2):
try:
llm.invoke("Do not respond")
return True
except Exception as e:
logger.warning(f"GenAI API key failed for the following reason: {e}")
return False
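test_llm pairs with get_default_llm for API-key validation, which is exactly how the admin endpoints below now use it. A small sketch of the pattern (the wrapper name is illustrative, not from the source):

from danswer.llm.factory import get_default_llm
from danswer.llm.utils import test_llm

def api_key_is_usable(api_key: str) -> bool:
    if not api_key:
        return False
    # Short timeout keeps validation fast; test_llm already retries once
    llm = get_default_llm(api_key=api_key, timeout=10)
    return test_llm(llm)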

View File

@@ -23,12 +23,13 @@ from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.feedback import update_document_boost
from danswer.db.feedback import update_document_hidden
from danswer.db.models import User
from danswer.direct_qa.llm_utils import check_model_api_key_is_valid
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_gen_ai_api_key
from danswer.document_index.factory import get_default_document_index
from danswer.dynamic_configs import get_dynamic_config_store
from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.llm.factory import get_default_llm
from danswer.llm.utils import get_gen_ai_api_key
from danswer.llm.utils import test_llm
from danswer.server.models import ApiKey
from danswer.server.models import BoostDoc
from danswer.server.models import BoostUpdateRequest
@@ -132,7 +133,8 @@ def validate_existing_genai_api_key(
raise HTTPException(status_code=404, detail="Key not found")
try:
is_valid = check_model_api_key_is_valid(genai_api_key)
llm = get_default_llm(api_key=genai_api_key, timeout=10)
is_valid = test_llm(llm)
except ValueError:
# this is the case where they aren't using an OpenAI-based model
is_valid = True
@@ -168,9 +170,15 @@ def store_genai_api_key(
_: User = Depends(current_admin_user),
) -> None:
try:
is_valid = check_model_api_key_is_valid(request.api_key)
if not request.api_key:
raise HTTPException(400, "No API key provided")
llm = get_default_llm(api_key=request.api_key, timeout=10)
is_valid = test_llm(llm)
if not is_valid:
raise HTTPException(400, "Invalid API key provided")
get_dynamic_config_store().store(GEN_AI_API_KEY_STORAGE_KEY, request.api_key)
except RuntimeError as e:
raise HTTPException(400, str(e))

View File

@@ -1,5 +1,3 @@
from collections.abc import Generator
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
@@ -9,30 +7,19 @@ from sqlalchemy.orm import Session
from danswer.auth.users import current_admin_user
from danswer.auth.users import current_user
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.db.engine import get_session
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_feedback
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.interfaces import StreamingError
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.direct_qa.qa_utils import get_usable_chunks
from danswer.direct_qa.answer_question import answer_qa_query_stream
from danswer.document_index.factory import get_default_document_index
from danswer.document_index.vespa.index import VespaIndex
from danswer.search.access_filters import build_access_filters_for_user
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
from danswer.search.models import IndexFilters
from danswer.search.models import QueryFlow
from danswer.search.models import SearchQuery
from danswer.search.search_runner import chunks_to_search_docs
from danswer.search.search_runner import danswer_search
from danswer.search.search_runner import search_chunks
from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters
from danswer.secondary_llm_flows.query_validation import get_query_answerability
from danswer.secondary_llm_flows.query_validation import stream_query_answerability
@@ -41,13 +28,10 @@ from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QueryValidationResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import RerankedRetrievalDocs
from danswer.server.models import SearchDoc
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.server.utils import get_json_line
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
logger = setup_logger()
@@ -190,144 +174,10 @@ def stream_direct_qa(
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
send_packet_debug_msg = "Sending Packet: {}"
top_documents_key = "top_documents"
unranked_top_docs_key = "unranked_top_documents"
predicted_flow_key = "predicted_flow"
predicted_search_key = "predicted_search"
query_event_id_key = "query_event_id"
logger.debug(
f"Received QA query ({question.search_type.value} search): {question.query}"
)
logger.debug(f"Query filters: {question.filters}")
packets = answer_qa_query_stream(
question=question, user=user, db_session=db_session
)
@log_generator_function_time()
def stream_qa_portions(
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Generator[str, None, None]:
answer_so_far: str = ""
query = question.query
offset_count = question.offset if question.offset is not None else 0
time_cutoff, favor_recent = extract_question_time_filters(question)
question.filters.time_cutoff = time_cutoff # not used but just in case
filters = question.filters
user_acl_filters = build_access_filters_for_user(user, db_session)
final_filters = IndexFilters(
source_type=filters.source_type,
document_set=filters.document_set,
time_cutoff=time_cutoff,
access_control_list=user_acl_filters,
)
search_query = SearchQuery(
query=query,
search_type=question.search_type,
filters=final_filters,
favor_recent=favor_recent,
)
ranked_chunks, unranked_chunks = search_chunks(
query=search_query, document_index=get_default_document_index()
)
# TODO retire this
predicted_search, predicted_flow = query_intent(query)
if not ranked_chunks:
logger.debug("No Documents Found")
empty_docs_result = {
top_documents_key: None,
unranked_top_docs_key: None,
predicted_flow_key: predicted_flow,
predicted_search_key: predicted_search,
}
logger.debug(send_packet_debug_msg.format(empty_docs_result))
yield get_json_line(empty_docs_result)
return
top_docs = chunks_to_search_docs(ranked_chunks)
unranked_top_docs = chunks_to_search_docs(unranked_chunks)
initial_response = RerankedRetrievalDocs(
top_documents=top_docs,
unranked_top_documents=unranked_top_docs,
# if generative AI is disabled, set flow as search so frontend
# doesn't ask the user if they want to run QA over more documents
predicted_flow=QueryFlow.SEARCH
if disable_generative_answer
else predicted_flow,
predicted_search=predicted_search,
time_cutoff=time_cutoff,
favor_recent=favor_recent,
).dict()
logger.debug(send_packet_debug_msg.format(initial_response))
yield get_json_line(initial_response)
if disable_generative_answer:
logger.debug("Skipping QA because generative AI is disabled")
return
try:
qa_model = get_default_qa_model()
except Exception as e:
logger.exception("Unable to get QA model")
error = StreamingError(error=str(e))
yield get_json_line(error.dict())
return
# remove chunks marked as not applicable for QA (e.g. Google Drive file
# types which can't be parsed). These chunks are useful to show in the
# search results, but not for QA.
filtered_ranked_chunks = [
chunk for chunk in ranked_chunks if not chunk.metadata.get(IGNORE_FOR_QA)
]
# get all chunks that fit into the token limit
usable_chunks = get_usable_chunks(
chunks=filtered_ranked_chunks,
token_limit=NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL,
offset=offset_count,
)
logger.debug(
f"Chunks fed to LLM: {[chunk.semantic_identifier for chunk in usable_chunks]}"
)
try:
for response_packet in qa_model.answer_question_stream(
query, usable_chunks
):
if response_packet is None:
continue
if (
isinstance(response_packet, DanswerAnswerPiece)
and response_packet.answer_piece
):
answer_so_far = answer_so_far + response_packet.answer_piece
logger.debug(f"Sending packet: {response_packet}")
yield get_json_line(response_packet.dict())
except Exception:
# exception is logged in the answer_question method, no need to re-log
logger.exception("Failed to run QA")
yield get_json_line(
{"error": "The LLM failed to produce a useable response"}
)
query_event_id = create_query_event(
query=query,
search_type=question.search_type,
llm_answer=answer_so_far,
retrieved_document_ids=[doc.document_id for doc in top_docs],
user_id=None if user is None else user.id,
db_session=db_session,
)
yield get_json_line({query_event_id_key: query_event_id})
return
return StreamingResponse(stream_qa_portions(), media_type="application/json")
return StreamingResponse(packets, media_type="application/json")
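After this refactor the endpoint is pure delegation: build the packet generator, hand it to StreamingResponse. A hedged reconstruction of the slimmed handler, assembled from the hunks above (the route path is an assumption, not shown in this diff):

@router.post("/stream-direct-qa")  # path assumed for illustration
def stream_direct_qa(
    question: QuestionRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    # All search, QA, and query-event bookkeeping now lives in one place
    packets = answer_qa_query_stream(
        question=question, user=user, db_session=db_session
    )
    return StreamingResponse(packets, media_type="application/json")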
@router.post("/query-feedback")