From 61c9343a7e3a4f77ddf38472d901c9a5cb8cbff9 Mon Sep 17 00:00:00 2001 From: Yuhong Sun Date: Tue, 31 Oct 2023 23:25:26 -0700 Subject: [PATCH] Clean Up Duplicate Code (#670) --- backend/danswer/background/celery/celery.py | 2 +- backend/danswer/chat/chat_llm.py | 4 +- .../slack/handlers/handle_feedback.py | 2 + backend/danswer/db/feedback.py | 16 +++-- backend/danswer/direct_qa/answer_question.py | 51 +++------------- backend/danswer/direct_qa/llm_utils.py | 2 +- backend/danswer/document_index/__init__.py | 7 --- backend/danswer/document_index/factory.py | 7 +++ backend/danswer/indexing/indexing_pipeline.py | 2 +- backend/danswer/llm/{build.py => factory.py} | 0 backend/danswer/main.py | 2 +- backend/danswer/search/search_runner.py | 59 +++++++++++++++++++ .../secondary_llm_flows/answer_validation.py | 2 +- .../secondary_llm_flows/chat_helpers.py | 2 +- .../secondary_llm_flows/extract_filters.py | 2 +- .../secondary_llm_flows/query_validation.py | 2 +- backend/danswer/server/manage.py | 3 + backend/danswer/server/search_backend.py | 43 +++----------- backend/danswer/utils/acl.py | 2 +- 19 files changed, 109 insertions(+), 101 deletions(-) create mode 100644 backend/danswer/document_index/factory.py rename backend/danswer/llm/{build.py => factory.py} (100%) diff --git a/backend/danswer/background/celery/celery.py b/backend/danswer/background/celery/celery.py index 07d09afdc..0ada09b53 100644 --- a/backend/danswer/background/celery/celery.py +++ b/backend/danswer/background/celery/celery.py @@ -28,7 +28,7 @@ from danswer.db.engine import SYNC_DB_API from danswer.db.models import DocumentSet from danswer.db.tasks import check_live_task_not_timed_out from danswer.db.tasks import get_latest_task -from danswer.document_index import get_default_document_index +from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import UpdateRequest from danswer.utils.batching import batch_generator diff --git a/backend/danswer/chat/chat_llm.py b/backend/danswer/chat/chat_llm.py index 6b0c3e8c2..0ebceb712 100644 --- a/backend/danswer/chat/chat_llm.py +++ b/backend/danswer/chat/chat_llm.py @@ -31,9 +31,9 @@ from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import DanswerChatModelOut from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.qa_utils import get_usable_chunks -from danswer.document_index import get_default_document_index +from danswer.document_index.factory import get_default_document_index from danswer.indexing.models import InferenceChunk -from danswer.llm.build import get_default_llm +from danswer.llm.factory import get_default_llm from danswer.llm.interfaces import LLM from danswer.llm.utils import get_default_llm_tokenizer from danswer.llm.utils import translate_danswer_msg_to_langchain diff --git a/backend/danswer/danswerbot/slack/handlers/handle_feedback.py b/backend/danswer/danswerbot/slack/handlers/handle_feedback.py index 160e75f33..49bc03cd1 100644 --- a/backend/danswer/danswerbot/slack/handlers/handle_feedback.py +++ b/backend/danswer/danswerbot/slack/handlers/handle_feedback.py @@ -9,6 +9,7 @@ from danswer.danswerbot.slack.utils import decompose_block_id from danswer.db.engine import get_sqlalchemy_engine from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.feedback import update_query_event_feedback +from danswer.document_index.factory import get_default_document_index def handle_slack_feedback( @@ -45,6 +46,7 @@ def handle_slack_feedback( document_id=doc_id, document_rank=doc_rank, user_id=None, + document_index=get_default_document_index(), db_session=db_session, clicked=False, # Not tracking this for Slack feedback=SearchFeedbackType.ENDORSE diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index 18bf9540c..620ef95a3 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -15,7 +15,7 @@ from danswer.db.models import ChatMessageFeedback from danswer.db.models import Document as DbDocument from danswer.db.models import DocumentRetrievalFeedback from danswer.db.models import QueryEvent -from danswer.document_index import get_default_document_index +from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import UpdateRequest from danswer.search.models import SearchType @@ -57,7 +57,9 @@ def fetch_docs_ranked_by_boost( return list(doc_list) -def update_document_boost(db_session: Session, document_id: str, boost: int) -> None: +def update_document_boost( + db_session: Session, document_id: str, boost: int, document_index: DocumentIndex +) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) result = db_session.execute(stmt).scalar_one_or_none() if result is None: @@ -70,12 +72,14 @@ def update_document_boost(db_session: Session, document_id: str, boost: int) -> boost=boost, ) - get_default_document_index().update([update]) + document_index.update([update]) db_session.commit() -def update_document_hidden(db_session: Session, document_id: str, hidden: bool) -> None: +def update_document_hidden( + db_session: Session, document_id: str, hidden: bool, document_index: DocumentIndex +) -> None: stmt = select(DbDocument).where(DbDocument.id == document_id) result = db_session.execute(stmt).scalar_one_or_none() if result is None: @@ -88,7 +92,7 @@ def update_document_hidden(db_session: Session, document_id: str, hidden: bool) hidden=hidden, ) - get_default_document_index().update([update]) + document_index.update([update]) db_session.commit() @@ -149,6 +153,7 @@ def create_doc_retrieval_feedback( document_id: str, document_rank: int, user_id: UUID | None, + document_index: DocumentIndex, db_session: Session, clicked: bool = False, feedback: SearchFeedbackType | None = None, @@ -185,7 +190,6 @@ def create_doc_retrieval_feedback( raise ValueError("Unhandled document feedback type") if feedback in [SearchFeedbackType.ENDORSE, SearchFeedbackType.REJECT]: - document_index = get_default_document_index() update = UpdateRequest( document_ids=[document_id], boost=doc_m.boost, diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index 8642e17e9..9cdfc2709 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -6,22 +6,17 @@ from danswer.configs.app_configs import DISABLE_GENERATIVE_AI from danswer.configs.app_configs import NUM_DOCUMENT_TOKENS_FED_TO_GENERATIVE_MODEL from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.constants import IGNORE_FOR_QA -from danswer.db.feedback import create_query_event -from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.models import User from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.models import LLMMetricsContainer from danswer.direct_qa.qa_utils import get_usable_chunks -from danswer.document_index import get_default_document_index -from danswer.search.access_filters import build_access_filters_for_user +from danswer.document_index.factory import get_default_document_index from danswer.search.danswer_helper import query_intent -from danswer.search.models import IndexFilters from danswer.search.models import QueryFlow from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer -from danswer.search.models import SearchQuery from danswer.search.search_runner import chunks_to_search_docs -from danswer.search.search_runner import search_chunks +from danswer.search.search_runner import danswer_search from danswer.secondary_llm_flows.answer_validation import get_answer_validity from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters from danswer.server.models import QAResponse @@ -52,41 +47,20 @@ def answer_qa_query( time_cutoff, favor_recent = extract_question_time_filters(question) question.filters.time_cutoff = time_cutoff - filters = question.filters + question.favor_recent = favor_recent - query_event_id = create_query_event( - query=query, - search_type=question.search_type, - llm_answer=None, - user_id=user.id if user is not None else None, + ranked_chunks, unranked_chunks, query_event_id = danswer_search( + question=question, + user=user, db_session=db_session, - ) - - user_id = None if user is None else user.id - user_acl_filters = build_access_filters_for_user(user, db_session) - final_filters = IndexFilters( - source_type=filters.source_type, - document_set=filters.document_set, - time_cutoff=time_cutoff, - access_control_list=user_acl_filters, - ) - search_query = SearchQuery( - query=query, - search_type=question.search_type, - filters=final_filters, - favor_recent=True if question.favor_recent is None else question.favor_recent, - ) - - # TODO retire this - predicted_search, predicted_flow = query_intent(query) - - ranked_chunks, unranked_chunks = search_chunks( - query=search_query, document_index=get_default_document_index(), retrieval_metrics_callback=retrieval_metrics_callback, rerank_metrics_callback=rerank_metrics_callback, ) + # TODO retire this + predicted_search, predicted_flow = query_intent(query) + if not ranked_chunks: return QAResponse( answer=None, @@ -103,13 +77,6 @@ def answer_qa_query( top_docs = chunks_to_search_docs(ranked_chunks) unranked_top_docs = chunks_to_search_docs(unranked_chunks) - update_query_event_retrieved_documents( - db_session=db_session, - retrieved_document_ids=[doc.document_id for doc in top_docs], - query_id=query_event_id, - user_id=user_id, - ) - if disable_generative_answer: logger.debug("Skipping QA because generative AI is disabled") return QAResponse( diff --git a/backend/danswer/direct_qa/llm_utils.py b/backend/danswer/direct_qa/llm_utils.py index a43b05a21..b4c00566c 100644 --- a/backend/danswer/direct_qa/llm_utils.py +++ b/backend/danswer/direct_qa/llm_utils.py @@ -6,7 +6,7 @@ from danswer.direct_qa.qa_block import QABlock from danswer.direct_qa.qa_block import QAHandler from danswer.direct_qa.qa_block import SingleMessageQAHandler from danswer.direct_qa.qa_block import SingleMessageScratchpadHandler -from danswer.llm.build import get_default_llm +from danswer.llm.factory import get_default_llm from danswer.utils.logger import setup_logger logger = setup_logger() diff --git a/backend/danswer/document_index/__init__.py b/backend/danswer/document_index/__init__.py index 40a44cf6a..e69de29bb 100644 --- a/backend/danswer/document_index/__init__.py +++ b/backend/danswer/document_index/__init__.py @@ -1,7 +0,0 @@ -from danswer.document_index.interfaces import DocumentIndex -from danswer.document_index.vespa.index import VespaIndex - - -def get_default_document_index() -> DocumentIndex: - # Currently only supporting Vespa - return VespaIndex() diff --git a/backend/danswer/document_index/factory.py b/backend/danswer/document_index/factory.py new file mode 100644 index 000000000..40a44cf6a --- /dev/null +++ b/backend/danswer/document_index/factory.py @@ -0,0 +1,7 @@ +from danswer.document_index.interfaces import DocumentIndex +from danswer.document_index.vespa.index import VespaIndex + + +def get_default_document_index() -> DocumentIndex: + # Currently only supporting Vespa + return VespaIndex() diff --git a/backend/danswer/indexing/indexing_pipeline.py b/backend/danswer/indexing/indexing_pipeline.py index 1d41c681c..c1b52895a 100644 --- a/backend/danswer/indexing/indexing_pipeline.py +++ b/backend/danswer/indexing/indexing_pipeline.py @@ -11,7 +11,7 @@ from danswer.db.document import prepare_to_modify_documents from danswer.db.document import upsert_documents_complete from danswer.db.document_set import fetch_document_sets_for_documents from danswer.db.engine import get_sqlalchemy_engine -from danswer.document_index import get_default_document_index +from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import DocumentIndex from danswer.document_index.interfaces import DocumentMetadata from danswer.indexing.chunker import Chunker diff --git a/backend/danswer/llm/build.py b/backend/danswer/llm/factory.py similarity index 100% rename from backend/danswer/llm/build.py rename to backend/danswer/llm/factory.py diff --git a/backend/danswer/main.py b/backend/danswer/main.py index 82f0ac26e..ef98a360e 100644 --- a/backend/danswer/main.py +++ b/backend/danswer/main.py @@ -31,7 +31,7 @@ from danswer.configs.model_configs import GEN_AI_MODEL_VERSION from danswer.configs.model_configs import SKIP_RERANKING from danswer.db.credentials import create_initial_public_credential from danswer.direct_qa.llm_utils import get_default_qa_model -from danswer.document_index import get_default_document_index +from danswer.document_index.factory import get_default_document_index from danswer.server.cc_pair.api import router as cc_pair_router from danswer.server.chat_backend import router as chat_router from danswer.server.connector import router as connector_router diff --git a/backend/danswer/search/search_runner.py b/backend/danswer/search/search_runner.py index 4038d58d3..ba7f1eae3 100644 --- a/backend/danswer/search/search_runner.py +++ b/backend/danswer/search/search_runner.py @@ -5,6 +5,7 @@ from nltk.corpus import stopwords # type:ignore from nltk.stem import WordNetLemmatizer # type:ignore from nltk.tokenize import word_tokenize # type:ignore from sentence_transformers import SentenceTransformer # type: ignore +from sqlalchemy.orm import Session from danswer.configs.model_configs import ASYM_QUERY_PREFIX from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MAX @@ -12,12 +13,17 @@ from danswer.configs.model_configs import CROSS_ENCODER_RANGE_MIN from danswer.configs.model_configs import NORMALIZE_EMBEDDINGS from danswer.configs.model_configs import SIM_SCORE_RANGE_HIGH from danswer.configs.model_configs import SIM_SCORE_RANGE_LOW +from danswer.db.feedback import create_query_event +from danswer.db.feedback import update_query_event_retrieved_documents +from danswer.db.models import User from danswer.document_index.document_index_utils import ( translate_boost_count_to_multiplier, ) from danswer.document_index.interfaces import DocumentIndex from danswer.indexing.models import InferenceChunk +from danswer.search.access_filters import build_access_filters_for_user from danswer.search.models import ChunkMetric +from danswer.search.models import IndexFilters from danswer.search.models import MAX_METRICS_CONTENT from danswer.search.models import RerankMetricsContainer from danswer.search.models import RetrievalMetricsContainer @@ -25,10 +31,12 @@ from danswer.search.models import SearchQuery from danswer.search.models import SearchType from danswer.search.search_nlp_models import get_default_embedding_model from danswer.search.search_nlp_models import get_default_reranking_model_ensemble +from danswer.server.models import QuestionRequest from danswer.server.models import SearchDoc from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time + logger = setup_logger() @@ -322,3 +330,54 @@ def search_chunks( _log_top_chunk_links(query.search_type.value, ranked_chunks) return ranked_chunks, top_chunks[query.num_rerank :] + + +def danswer_search( + question: QuestionRequest, + user: User | None, + db_session: Session, + document_index: DocumentIndex, + retrieval_metrics_callback: Callable[[RetrievalMetricsContainer], None] + | None = None, + rerank_metrics_callback: Callable[[RerankMetricsContainer], None] | None = None, +) -> tuple[list[InferenceChunk] | None, list[InferenceChunk] | None, int]: + query_event_id = create_query_event( + query=question.query, + search_type=question.search_type, + llm_answer=None, + user_id=user.id if user is not None else None, + db_session=db_session, + ) + + user_acl_filters = build_access_filters_for_user(user, db_session) + final_filters = IndexFilters( + source_type=question.filters.source_type, + document_set=question.filters.document_set, + time_cutoff=question.filters.time_cutoff, + access_control_list=user_acl_filters, + ) + + search_query = SearchQuery( + query=question.query, + search_type=question.search_type, + filters=final_filters, + favor_recent=True if question.favor_recent is None else question.favor_recent, + ) + + ranked_chunks, unranked_chunks = search_chunks( + query=search_query, + document_index=document_index, + retrieval_metrics_callback=retrieval_metrics_callback, + rerank_metrics_callback=rerank_metrics_callback, + ) + + retrieved_ids = [doc.document_id for doc in ranked_chunks] if ranked_chunks else [] + + update_query_event_retrieved_documents( + db_session=db_session, + retrieved_document_ids=retrieved_ids, + query_id=query_event_id, + user_id=None if user is None else user.id, + ) + + return ranked_chunks, unranked_chunks, query_event_id diff --git a/backend/danswer/secondary_llm_flows/answer_validation.py b/backend/danswer/secondary_llm_flows/answer_validation.py index 866471e4b..189dece6f 100644 --- a/backend/danswer/secondary_llm_flows/answer_validation.py +++ b/backend/danswer/secondary_llm_flows/answer_validation.py @@ -2,7 +2,7 @@ from danswer.configs.constants import ANSWER_PAT from danswer.configs.constants import CODE_BLOCK_PAT from danswer.configs.constants import QUESTION_PAT from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt -from danswer.llm.build import get_default_llm +from danswer.llm.factory import get_default_llm from danswer.utils.logger import setup_logger from danswer.utils.timing import log_function_time diff --git a/backend/danswer/secondary_llm_flows/chat_helpers.py b/backend/danswer/secondary_llm_flows/chat_helpers.py index 7925a3fa6..2a60f94f9 100644 --- a/backend/danswer/secondary_llm_flows/chat_helpers.py +++ b/backend/danswer/secondary_llm_flows/chat_helpers.py @@ -1,4 +1,4 @@ -from danswer.llm.build import get_default_llm +from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt diff --git a/backend/danswer/secondary_llm_flows/extract_filters.py b/backend/danswer/secondary_llm_flows/extract_filters.py index 8fcc32663..df4efe7ae 100644 --- a/backend/danswer/secondary_llm_flows/extract_filters.py +++ b/backend/danswer/secondary_llm_flows/extract_filters.py @@ -6,7 +6,7 @@ from datetime import timezone from dateutil.parser import parse from danswer.configs.app_configs import DISABLE_TIME_FILTER_EXTRACTION -from danswer.llm.build import get_default_llm +from danswer.llm.factory import get_default_llm from danswer.llm.utils import dict_based_prompt_to_langchain_prompt from danswer.server.models import QuestionRequest from danswer.utils.logger import setup_logger diff --git a/backend/danswer/secondary_llm_flows/query_validation.py b/backend/danswer/secondary_llm_flows/query_validation.py index 14a9623b0..22e8ba15d 100644 --- a/backend/danswer/secondary_llm_flows/query_validation.py +++ b/backend/danswer/secondary_llm_flows/query_validation.py @@ -6,7 +6,7 @@ from danswer.configs.constants import GENERAL_SEP_PAT from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.qa_block import dict_based_prompt_to_langchain_prompt -from danswer.llm.build import get_default_llm +from danswer.llm.factory import get_default_llm from danswer.server.models import QueryValidationResponse from danswer.server.utils import get_json_line from danswer.utils.logger import setup_logger diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py index c496da1c6..547454533 100644 --- a/backend/danswer/server/manage.py +++ b/backend/danswer/server/manage.py @@ -26,6 +26,7 @@ from danswer.db.models import User from danswer.direct_qa.llm_utils import check_model_api_key_is_valid from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.qa_utils import get_gen_ai_api_key +from danswer.document_index.factory import get_default_document_index from danswer.dynamic_configs import get_dynamic_config_store from danswer.dynamic_configs.interface import ConfigNotFoundError from danswer.server.models import ApiKey @@ -79,6 +80,7 @@ def document_boost_update( db_session=db_session, document_id=boost_update.document_id, boost=boost_update.boost, + document_index=get_default_document_index(), ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) @@ -95,6 +97,7 @@ def document_hidden_update( db_session=db_session, document_id=hidden_update.document_id, hidden=hidden_update.hidden, + document_index=get_default_document_index(), ) except ValueError as e: raise HTTPException(status_code=400, detail=str(e)) diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index 4682a5ba3..ea5cea7d7 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -16,14 +16,13 @@ from danswer.db.engine import get_session from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.feedback import create_query_event from danswer.db.feedback import update_query_event_feedback -from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.models import User from danswer.direct_qa.answer_question import answer_qa_query from danswer.direct_qa.interfaces import DanswerAnswerPiece from danswer.direct_qa.interfaces import StreamingError from danswer.direct_qa.llm_utils import get_default_qa_model from danswer.direct_qa.qa_utils import get_usable_chunks -from danswer.document_index import get_default_document_index +from danswer.document_index.factory import get_default_document_index from danswer.document_index.vespa.index import VespaIndex from danswer.search.access_filters import build_access_filters_for_user from danswer.search.danswer_helper import query_intent @@ -32,6 +31,7 @@ from danswer.search.models import IndexFilters from danswer.search.models import QueryFlow from danswer.search.models import SearchQuery from danswer.search.search_runner import chunks_to_search_docs +from danswer.search.search_runner import danswer_search from danswer.search.search_runner import search_chunks from danswer.secondary_llm_flows.extract_filters import extract_question_time_filters from danswer.secondary_llm_flows.query_validation import get_query_answerability @@ -143,34 +143,13 @@ def handle_search_request( time_cutoff, favor_recent = extract_question_time_filters(question) question.filters.time_cutoff = time_cutoff - filters = question.filters - user_id = None if user is None else user.id + question.favor_recent = favor_recent - query_event_id = create_query_event( - query=query, - search_type=question.search_type, - llm_answer=None, - user_id=user_id, + ranked_chunks, unranked_chunks, query_event_id = danswer_search( + question=question, + user=user, db_session=db_session, - ) - - user_acl_filters = build_access_filters_for_user(user, db_session) - final_filters = IndexFilters( - source_type=filters.source_type, - document_set=filters.document_set, - time_cutoff=filters.time_cutoff, - access_control_list=user_acl_filters, - ) - - search_query = SearchQuery( - query=query, - search_type=question.search_type, - filters=final_filters, - favor_recent=favor_recent, - ) - - ranked_chunks, unranked_chunks = search_chunks( - query=search_query, document_index=get_default_document_index() + document_index=get_default_document_index(), ) if not ranked_chunks: @@ -185,13 +164,6 @@ def handle_search_request( top_docs = chunks_to_search_docs(ranked_chunks) lower_top_docs = chunks_to_search_docs(unranked_chunks) - update_query_event_retrieved_documents( - db_session=db_session, - retrieved_document_ids=[doc.document_id for doc in top_docs], - query_id=query_event_id, - user_id=user_id, - ) - return SearchResponse( top_ranked_docs=top_docs, lower_ranked_docs=lower_top_docs or None, @@ -385,5 +357,6 @@ def process_doc_retrieval_feedback( clicked=feedback.click, feedback=feedback.search_feedback, user_id=user.id if user is not None else None, + document_index=get_default_document_index(), db_session=db_session, ) diff --git a/backend/danswer/utils/acl.py b/backend/danswer/utils/acl.py index aa6576ceb..3cc31ccf2 100644 --- a/backend/danswer/utils/acl.py +++ b/backend/danswer/utils/acl.py @@ -7,7 +7,7 @@ from danswer.access.models import DocumentAccess from danswer.db.document import get_acccess_info_for_documents from danswer.db.engine import get_sqlalchemy_engine from danswer.db.models import Document -from danswer.document_index import get_default_document_index +from danswer.document_index.factory import get_default_document_index from danswer.document_index.interfaces import UpdateRequest from danswer.document_index.vespa.index import VespaIndex from danswer.dynamic_configs import get_dynamic_config_store