Add retrieved_document_ids to QueryEvent

Weves 2023-10-20 11:12:55 -07:00 committed by Chris Weaver
parent 47ab273353
commit a7099a1917
5 changed files with 91 additions and 11 deletions

View File

@@ -0,0 +1,31 @@
"""Added retrieved docs to query event

Revision ID: 9d97fecfab7f
Revises: ffc707a226b4
Create Date: 2023-10-20 12:22:31.930449

"""
from alembic import op
import sqlalchemy as sa
from sqlalchemy.dialects import postgresql

# revision identifiers, used by Alembic.
revision = "9d97fecfab7f"
down_revision = "ffc707a226b4"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.add_column(
        "query_event",
        sa.Column(
            "retrieved_document_ids",
            postgresql.ARRAY(sa.String()),
            nullable=True,
        ),
    )


def downgrade() -> None:
    op.drop_column("query_event", "retrieved_document_ids")
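The new column is nullable, so existing query_event rows need no backfill and the migration is applied with the usual "alembic upgrade head". For illustration only (not part of this commit), a minimal Python sketch for confirming the column after upgrading; the connection string is a placeholder assumption:

# Illustrative only: confirm the new column exists after running "alembic upgrade head".
# The DSN is a placeholder assumption; substitute the deployment's real Postgres URL.
from sqlalchemy import create_engine, inspect

engine = create_engine("postgresql://postgres:password@localhost:5432/postgres")
columns = {col["name"] for col in inspect(engine).get_columns("query_event")}
assert "retrieved_document_ids" in columns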

View File

@@ -94,16 +94,18 @@ def update_document_hidden(db_session: Session, document_id: str, hidden: bool)
def create_query_event(
    db_session: Session,
    query: str,
    selected_flow: SearchType | None,
    llm_answer: str | None,
    user_id: UUID | None,
    db_session: Session,
    retrieved_document_ids: list[str] | None = None,
) -> int:
    query_event = QueryEvent(
        query=query,
        selected_search_flow=selected_flow,
        llm_answer=llm_answer,
        retrieved_document_ids=retrieved_document_ids,
        user_id=user_id,
    )
    db_session.add(query_event)
@@ -113,10 +115,10 @@ def create_query_event(
def update_query_event_feedback(
    db_session: Session,
    feedback: QAFeedbackType,
    query_id: int,
    user_id: UUID | None,
    db_session: Session,
) -> None:
    query_event = fetch_query_event_by_id(query_id, db_session)
@@ -124,7 +126,21 @@ def update_query_event_feedback(
        raise ValueError("User trying to give feedback on a query run by another user.")

    query_event.feedback = feedback
    db_session.commit()


def update_query_event_retrieved_documents(
    db_session: Session,
    retrieved_document_ids: list[str],
    query_id: int,
    user_id: UUID | None,
) -> None:
    query_event = fetch_query_event_by_id(query_id, db_session)

    if user_id != query_event.user_id:
        raise ValueError("User trying to update docs on a query run by another user.")

    query_event.retrieved_document_ids = retrieved_document_ids
    db_session.commit()
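As the call sites later in the commit show, a query event is created first and the retrieved document IDs can either be passed at creation time or attached afterwards once ranking has finished. For illustration only (not part of the diff), a minimal sketch of that two-step flow; record_query_with_docs is a hypothetical helper and the SearchType import path is assumed:

from uuid import UUID

from sqlalchemy.orm import Session

from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.search.models import SearchType  # assumed import path


def record_query_with_docs(
    db_session: Session,
    query: str,
    doc_ids: list[str],
    user_id: UUID | None,
) -> int:
    # Step 1: log the query up front, before any documents are retrieved.
    query_event_id = create_query_event(
        db_session=db_session,
        query=query,
        selected_flow=SearchType.SEMANTIC,
        llm_answer=None,
        user_id=user_id,
    )
    # Step 2: once ranking is done, attach the top document IDs to the same event.
    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=doc_ids,
        query_id=query_event_id,
        user_id=user_id,
    )
    return query_event_id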

View File

@@ -346,6 +346,12 @@ class QueryEvent(Base):
        Enum(SearchType), nullable=True
    )
    llm_answer: Mapped[str | None] = mapped_column(Text, default=None)
    # Document IDs of the top context documents retrieved for the query (if any)
    # NOTE: not using a foreign key to enable easy deletion of documents without
    # needing to adjust `QueryEvent` rows
    retrieved_document_ids: Mapped[list[str] | None] = mapped_column(
        postgresql.ARRAY(String), nullable=True
    )
    feedback: Mapped[QAFeedbackType | None] = mapped_column(
        Enum(QAFeedbackType), nullable=True
    )
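Since the IDs live in a plain Postgres text array rather than behind a foreign key, deleting a document never forces an update to QueryEvent rows, and the column can still be filtered with the usual array operators. For illustration only (not part of the commit), a sketch of finding the query events that retrieved a given document; find_events_for_document is a hypothetical helper:

from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.db.models import QueryEvent


def find_events_for_document(db_session: Session, document_id: str) -> list[QueryEvent]:
    # ARRAY.contains() renders as the Postgres @> operator, matching rows whose
    # retrieved_document_ids array includes the given document ID.
    stmt = select(QueryEvent).where(
        QueryEvent.retrieved_document_ids.contains([document_id])
    )
    return list(db_session.scalars(stmt))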

View File

@@ -9,6 +9,7 @@ from danswer.configs.app_configs import QA_TIMEOUT
from danswer.configs.constants import IGNORE_FOR_QA
from danswer.datastores.document_index import get_default_document_index
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
@@ -55,7 +56,9 @@ def answer_qa_query(
    query_event_id = create_query_event(
        query=query,
        selected_flow=SearchType.KEYWORD,
        selected_flow=SearchType.KEYWORD
        if question.use_keyword
        else SearchType.SEMANTIC,
        llm_answer=None,
        user_id=user.id if user is not None else None,
        db_session=db_session,
@@ -97,13 +100,22 @@ def answer_qa_query(
        query_event_id=query_event_id,
    )

    top_docs = chunks_to_search_docs(ranked_chunks)
    unranked_top_docs = chunks_to_search_docs(unranked_chunks)
    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=[doc.document_id for doc in top_docs],
        query_id=query_event_id,
        user_id=user_id,
    )

    if disable_generative_answer:
        logger.debug("Skipping QA because generative AI is disabled")
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            top_ranked_docs=top_docs,
            lower_ranked_docs=unranked_top_docs,
            # set flow as search so frontend doesn't ask the user if they want
            # to run QA over more documents
            predicted_flow=QueryFlow.SEARCH,
@@ -119,8 +131,8 @@ def answer_qa_query(
        return QAResponse(
            answer=None,
            quotes=None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            top_ranked_docs=top_docs,
            lower_ranked_docs=unranked_top_docs,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            error_msg=str(e),
@@ -162,8 +174,8 @@ def answer_qa_query(
        return QAResponse(
            answer=d_answer.answer if d_answer else None,
            quotes=quotes.quotes if quotes else None,
            top_ranked_docs=chunks_to_search_docs(ranked_chunks),
            lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
            top_ranked_docs=top_docs,
            lower_ranked_docs=unranked_top_docs,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            eval_res_valid=True if valid else False,
@@ -174,8 +186,8 @@ def answer_qa_query(
    return QAResponse(
        answer=d_answer.answer if d_answer else None,
        quotes=quotes.quotes if quotes else None,
        top_ranked_docs=chunks_to_search_docs(ranked_chunks),
        lower_ranked_docs=chunks_to_search_docs(unranked_chunks),
        top_ranked_docs=top_docs,
        lower_ranked_docs=unranked_top_docs,
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        error_msg=error_msg,

View File

@@ -20,6 +20,7 @@ from danswer.db.engine import get_session
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_feedback
from danswer.db.feedback import update_query_event_retrieved_documents
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.exceptions import OpenAIKeyMissing
@@ -165,6 +166,12 @@ def semantic_search(
    top_docs = chunks_to_search_docs(ranked_chunks)
    other_top_docs = chunks_to_search_docs(unranked_chunks)

    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=[doc.document_id for doc in top_docs],
        query_id=query_event_id,
        user_id=user_id,
    )

    return SearchResponse(
        top_ranked_docs=top_docs,
@@ -203,6 +210,13 @@ def keyword_search(
    )
    top_docs = chunks_to_search_docs(ranked_chunks)

    update_query_event_retrieved_documents(
        db_session=db_session,
        retrieved_document_ids=[doc.document_id for doc in top_docs],
        query_id=query_event_id,
        user_id=user_id,
    )

    return SearchResponse(
        top_ranked_docs=top_docs, lower_ranked_docs=None, query_event_id=query_event_id
    )
@@ -349,6 +363,7 @@ def stream_direct_qa(
            if question.use_keyword
            else SearchType.SEMANTIC,
            llm_answer=answer_so_far,
            retrieved_document_ids=[doc.document_id for doc in top_docs],
            user_id=user_id,
            db_session=db_session,
        )
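With the search endpoints, the QA path, and the streaming endpoint all persisting the top document IDs, a query event can later be audited or replayed. As a final illustration (not part of the commit), a minimal sketch of reading the stored IDs back; get_retrieved_document_ids is a hypothetical helper and assumes the integer primary key returned by create_query_event:

from sqlalchemy.orm import Session

from danswer.db.models import QueryEvent


def get_retrieved_document_ids(db_session: Session, query_event_id: int) -> list[str]:
    # Look the event up by primary key; the column is nullable, so fall back to [].
    query_event = db_session.get(QueryEvent, query_event_id)
    if query_event is None or query_event.retrieved_document_ids is None:
        return []
    return query_event.retrieved_document_ids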