From a7099a19174fe3ae44392241f81363334deffe4f Mon Sep 17 00:00:00 2001 From: Weves Date: Fri, 20 Oct 2023 11:12:55 -0700 Subject: [PATCH] Add retrieved_document_ids to QueryEvent --- ...b7f_added_retrieved_docs_to_query_event.py | 31 +++++++++++++++++++ backend/danswer/db/feedback.py | 20 ++++++++++-- backend/danswer/db/models.py | 6 ++++ backend/danswer/direct_qa/answer_question.py | 30 ++++++++++++------ backend/danswer/server/search_backend.py | 15 +++++++++ 5 files changed, 91 insertions(+), 11 deletions(-) create mode 100644 backend/alembic/versions/9d97fecfab7f_added_retrieved_docs_to_query_event.py diff --git a/backend/alembic/versions/9d97fecfab7f_added_retrieved_docs_to_query_event.py b/backend/alembic/versions/9d97fecfab7f_added_retrieved_docs_to_query_event.py new file mode 100644 index 000000000..088b50731 --- /dev/null +++ b/backend/alembic/versions/9d97fecfab7f_added_retrieved_docs_to_query_event.py @@ -0,0 +1,31 @@ +"""Added retrieved docs to query event + +Revision ID: 9d97fecfab7f +Revises: ffc707a226b4 +Create Date: 2023-10-20 12:22:31.930449 + +""" +from alembic import op +import sqlalchemy as sa +from sqlalchemy.dialects import postgresql + +# revision identifiers, used by Alembic. +revision = "9d97fecfab7f" +down_revision = "ffc707a226b4" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "query_event", + sa.Column( + "retrieved_document_ids", + postgresql.ARRAY(sa.String()), + nullable=True, + ), + ) + + +def downgrade() -> None: + op.drop_column("query_event", "retrieved_document_ids") diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py index be0e844a8..58db9a711 100644 --- a/backend/danswer/db/feedback.py +++ b/backend/danswer/db/feedback.py @@ -94,16 +94,18 @@ def update_document_hidden(db_session: Session, document_id: str, hidden: bool) def create_query_event( + db_session: Session, query: str, selected_flow: SearchType | None, llm_answer: str | None, user_id: UUID | None, - db_session: Session, + retrieved_document_ids: list[str] | None = None, ) -> int: query_event = QueryEvent( query=query, selected_search_flow=selected_flow, llm_answer=llm_answer, + retrieved_document_ids=retrieved_document_ids, user_id=user_id, ) db_session.add(query_event) @@ -113,10 +115,10 @@ def create_query_event( def update_query_event_feedback( + db_session: Session, feedback: QAFeedbackType, query_id: int, user_id: UUID | None, - db_session: Session, ) -> None: query_event = fetch_query_event_by_id(query_id, db_session) @@ -124,7 +126,21 @@ def update_query_event_feedback( raise ValueError("User trying to give feedback on a query run by another user.") query_event.feedback = feedback + db_session.commit() + +def update_query_event_retrieved_documents( + db_session: Session, + retrieved_document_ids: list[str], + query_id: int, + user_id: UUID | None, +) -> None: + query_event = fetch_query_event_by_id(query_id, db_session) + + if user_id != query_event.user_id: + raise ValueError("User trying to update docs on a query run by another user.") + + query_event.retrieved_document_ids = retrieved_document_ids db_session.commit() diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py index 8ad4d17c5..7b9643189 100644 --- a/backend/danswer/db/models.py +++ b/backend/danswer/db/models.py @@ -346,6 +346,12 @@ class QueryEvent(Base): Enum(SearchType), nullable=True ) llm_answer: Mapped[str | None] = mapped_column(Text, default=None) + # Document IDs of the top context documents retrieved for the query (if any) + # NOTE: not using a foreign key to enable easy deletion of documents without + # needing to adjust `QueryEvent` rows + retrieved_document_ids: Mapped[list[str] | None] = mapped_column( + postgresql.ARRAY(String), nullable=True + ) feedback: Mapped[QAFeedbackType | None] = mapped_column( Enum(QAFeedbackType), nullable=True ) diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py index f8793df50..58ddc2bd6 100644 --- a/backend/danswer/direct_qa/answer_question.py +++ b/backend/danswer/direct_qa/answer_question.py @@ -9,6 +9,7 @@ from danswer.configs.app_configs import QA_TIMEOUT from danswer.configs.constants import IGNORE_FOR_QA from danswer.datastores.document_index import get_default_document_index from danswer.db.feedback import create_query_event +from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.models import User from danswer.direct_qa.exceptions import OpenAIKeyMissing from danswer.direct_qa.exceptions import UnknownModelError @@ -55,7 +56,9 @@ def answer_qa_query( query_event_id = create_query_event( query=query, - selected_flow=SearchType.KEYWORD, + selected_flow=SearchType.KEYWORD + if question.use_keyword + else SearchType.SEMANTIC, llm_answer=None, user_id=user.id if user is not None else None, db_session=db_session, @@ -97,13 +100,22 @@ def answer_qa_query( query_event_id=query_event_id, ) + top_docs = chunks_to_search_docs(ranked_chunks) + unranked_top_docs = chunks_to_search_docs(unranked_chunks) + update_query_event_retrieved_documents( + db_session=db_session, + retrieved_document_ids=[doc.document_id for doc in top_docs], + query_id=query_event_id, + user_id=user_id, + ) + if disable_generative_answer: logger.debug("Skipping QA because generative AI is disabled") return QAResponse( answer=None, quotes=None, - top_ranked_docs=chunks_to_search_docs(ranked_chunks), - lower_ranked_docs=chunks_to_search_docs(unranked_chunks), + top_ranked_docs=top_docs, + lower_ranked_docs=unranked_top_docs, # set flow as search so frontend doesn't ask the user if they want # to run QA over more documents predicted_flow=QueryFlow.SEARCH, @@ -119,8 +131,8 @@ def answer_qa_query( return QAResponse( answer=None, quotes=None, - top_ranked_docs=chunks_to_search_docs(ranked_chunks), - lower_ranked_docs=chunks_to_search_docs(unranked_chunks), + top_ranked_docs=top_docs, + lower_ranked_docs=unranked_top_docs, predicted_flow=predicted_flow, predicted_search=predicted_search, error_msg=str(e), @@ -162,8 +174,8 @@ def answer_qa_query( return QAResponse( answer=d_answer.answer if d_answer else None, quotes=quotes.quotes if quotes else None, - top_ranked_docs=chunks_to_search_docs(ranked_chunks), - lower_ranked_docs=chunks_to_search_docs(unranked_chunks), + top_ranked_docs=top_docs, + lower_ranked_docs=unranked_top_docs, predicted_flow=predicted_flow, predicted_search=predicted_search, eval_res_valid=True if valid else False, @@ -174,8 +186,8 @@ def answer_qa_query( return QAResponse( answer=d_answer.answer if d_answer else None, quotes=quotes.quotes if quotes else None, - top_ranked_docs=chunks_to_search_docs(ranked_chunks), - lower_ranked_docs=chunks_to_search_docs(unranked_chunks), + top_ranked_docs=top_docs, + lower_ranked_docs=unranked_top_docs, predicted_flow=predicted_flow, predicted_search=predicted_search, error_msg=error_msg, diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py index ce22f985d..1343213ca 100644 --- a/backend/danswer/server/search_backend.py +++ b/backend/danswer/server/search_backend.py @@ -20,6 +20,7 @@ from danswer.db.engine import get_session from danswer.db.feedback import create_doc_retrieval_feedback from danswer.db.feedback import create_query_event from danswer.db.feedback import update_query_event_feedback +from danswer.db.feedback import update_query_event_retrieved_documents from danswer.db.models import User from danswer.direct_qa.answer_question import answer_qa_query from danswer.direct_qa.exceptions import OpenAIKeyMissing @@ -165,6 +166,12 @@ def semantic_search( top_docs = chunks_to_search_docs(ranked_chunks) other_top_docs = chunks_to_search_docs(unranked_chunks) + update_query_event_retrieved_documents( + db_session=db_session, + retrieved_document_ids=[doc.document_id for doc in top_docs], + query_id=query_event_id, + user_id=user_id, + ) return SearchResponse( top_ranked_docs=top_docs, @@ -203,6 +210,13 @@ def keyword_search( ) top_docs = chunks_to_search_docs(ranked_chunks) + update_query_event_retrieved_documents( + db_session=db_session, + retrieved_document_ids=[doc.document_id for doc in top_docs], + query_id=query_event_id, + user_id=user_id, + ) + return SearchResponse( top_ranked_docs=top_docs, lower_ranked_docs=None, query_event_id=query_event_id ) @@ -349,6 +363,7 @@ def stream_direct_qa( if question.use_keyword else SearchType.SEMANTIC, llm_answer=answer_so_far, + retrieved_document_ids=[doc.document_id for doc in top_docs], user_id=user_id, db_session=db_session, )