diff --git a/backend/alembic/versions/d929f0c1c6af_feedback_feature.py b/backend/alembic/versions/d929f0c1c6af_feedback_feature.py
new file mode 100644
index 000000000..985880e40
--- /dev/null
+++ b/backend/alembic/versions/d929f0c1c6af_feedback_feature.py
@@ -0,0 +1,93 @@
+"""Feedback Feature
+
+Revision ID: d929f0c1c6af
+Revises: 8aabb57f3b49
+Create Date: 2023-08-27 13:03:54.274987
+
+"""
+import fastapi_users_db_sqlalchemy
+from alembic import op
+import sqlalchemy as sa
+
+
+# revision identifiers, used by Alembic.
+revision = "d929f0c1c6af"
+down_revision = "8aabb57f3b49"
+branch_labels = None
+depends_on = None
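+
+# Apply with `alembic upgrade head`; revert with `alembic downgrade 8aabb57f3b49`
+# (the down_revision above). Standard Alembic usage, not specific to this change.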
+
+
+def upgrade() -> None:
+    op.create_table(
+        "query_event",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("query", sa.String(), nullable=False),
+        sa.Column(
+            "selected_search_flow",
+            sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"),
+            nullable=True,
+        ),
+        sa.Column("llm_answer", sa.String(), nullable=True),
+        sa.Column(
+            "feedback",
+            sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"),
+            nullable=True,
+        ),
+        sa.Column(
+            "user_id",
+            fastapi_users_db_sqlalchemy.generics.GUID(),
+            nullable=True,
+        ),
+        sa.Column(
+            "time_created",
+            sa.DateTime(timezone=True),
+            server_default=sa.text("now()"),
+            nullable=False,
+        ),
+        sa.ForeignKeyConstraint(
+            ["user_id"],
+            ["user.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    op.create_table(
+        "document_retrieval_feedback",
+        sa.Column("id", sa.Integer(), nullable=False),
+        sa.Column("qa_event_id", sa.Integer(), nullable=False),
+        sa.Column("document_id", sa.String(), nullable=False),
+        sa.Column("document_rank", sa.Integer(), nullable=False),
+        sa.Column("clicked", sa.Boolean(), nullable=False),
+        sa.Column(
+            "feedback",
+            sa.Enum(
+                "ENDORSE",
+                "REJECT",
+                "HIDE",
+                "UNHIDE",
+                name="searchfeedbacktype",
+            ),
+            nullable=True,
+        ),
+        sa.ForeignKeyConstraint(
+            ["document_id"],
+            ["document.id"],
+        ),
+        sa.ForeignKeyConstraint(
+            ["qa_event_id"],
+            ["query_event.id"],
+        ),
+        sa.PrimaryKeyConstraint("id"),
+    )
+    # server defaults are needed so adding NOT NULL columns succeeds when the
+    # document table already has rows
+    op.add_column(
+        "document",
+        sa.Column("boost", sa.Integer(), nullable=False, server_default="0"),
+    )
+    op.add_column(
+        "document",
+        sa.Column("hidden", sa.Boolean(), nullable=False, server_default="false"),
+    )
+    op.add_column(
+        "document",
+        sa.Column("semantic_id", sa.String(), nullable=False, server_default=""),
+    )
+    op.add_column("document", sa.Column("link", sa.String(), nullable=True))
+
+
+def downgrade() -> None:
+    op.drop_column("document", "link")
+    op.drop_column("document", "semantic_id")
+    op.drop_column("document", "hidden")
+    op.drop_column("document", "boost")
+    op.drop_table("document_retrieval_feedback")
+    op.drop_table("query_event")
diff --git a/backend/danswer/background/connector_deletion.py b/backend/danswer/background/connector_deletion.py
index f3b27c2c1..e5d0c6d2e 100644
--- a/backend/danswer/background/connector_deletion.py
+++ b/backend/danswer/background/connector_deletion.py
@@ -90,7 +90,7 @@ def _delete_connector_credential_pair(
 def _get_user(
     credential: Credential,
 ) -> str:
-    if credential.public_doc:
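+    # a credential may no longer have an attached user (e.g. the user was
+    # deleted); its documents are then treated as public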
+    if credential.public_doc or not credential.user:
         return PUBLIC_DOC_PAT
 
     return str(credential.user.id)
diff --git a/backend/danswer/background/update.py b/backend/danswer/background/update.py
index e5a2a1515..ba310232a 100755
--- a/backend/danswer/background/update.py
+++ b/backend/danswer/background/update.py
@@ -23,6 +23,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
 from danswer.db.credentials import backend_update_credential_json
 from danswer.db.engine import get_db_current_time
 from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.feedback import create_document_metadata
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_index_attempt
 from danswer.db.index_attempt import get_inprogress_index_attempts
@@ -246,6 +247,7 @@ def _run_indexing(
         logger.debug(
             f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
         )
+
         index_user_id = (
             None if db_credential.public_doc else db_credential.user_id
         )
diff --git a/backend/danswer/configs/constants.py b/backend/danswer/configs/constants.py
index 467e7e7a1..7ce83b49d 100644
--- a/backend/danswer/configs/constants.py
+++ b/backend/danswer/configs/constants.py
@@ -18,6 +18,7 @@ HTML_SEPARATOR = "\n"
 PUBLIC_DOC_PAT = "PUBLIC"
 QUOTE = "quote"
 BOOST = "boost"
+DEFAULT_BOOST = 0
 
 
 class DocumentSource(str, Enum):
@@ -66,3 +67,15 @@ class ModelHostType(str, Enum):
     # https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
     COLAB_DEMO = "colab-demo"
     # TODO support for Azure, AWS, GCP GenAI model hosting
+
+
+class QAFeedbackType(str, Enum):
+    LIKE = "like"  # User likes the answer, used for metrics
+    DISLIKE = "dislike"  # User dislikes the answer, used for metrics
+
+
+class SearchFeedbackType(str, Enum):
+    ENDORSE = "endorse"  # boost this document for all future queries
+    REJECT = "reject"  # down-boost this document for all future queries
+    HIDE = "hide"  # mark this document as untrusted, hide it from the LLM
+    UNHIDE = "unhide"
diff --git a/backend/danswer/datastores/datastore_utils.py b/backend/danswer/datastores/datastore_utils.py
index 9b8a4aab2..b1723f5c7 100644
--- a/backend/danswer/datastores/datastore_utils.py
+++ b/backend/danswer/datastores/datastore_utils.py
@@ -12,6 +12,15 @@ from danswer.connectors.models import IndexAttemptMetadata
 
 
 DEFAULT_BATCH_SIZE = 30
+BOOST_MULTIPLIER = 1.2
+
+
+def translate_boost_count_to_multiplier(boost: int) -> float:
+    if boost > 0:
+        return BOOST_MULTIPLIER**boost
+    elif boost < 0:
+        # a negative count must shrink the multiplier, hence the abs()
+        return 1 / (BOOST_MULTIPLIER ** abs(boost))
+    return 1
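+
+
+# A quick sanity check of the mapping (with BOOST_MULTIPLIER = 1.2):
+#   boost =  2  ->  1.2**2        == 1.44  (surfaced more often)
+#   boost =  0  ->  1             (neutral)
+#   boost = -2  ->  1 / (1.2**2)  ~= 0.69  (down-boosted)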
 
 
 def get_uuid_from_chunk(
diff --git a/backend/danswer/datastores/indexing_pipeline.py b/backend/danswer/datastores/indexing_pipeline.py
index 5bcd9d3f1..7fe7a55d7 100644
--- a/backend/danswer/datastores/indexing_pipeline.py
+++ b/backend/danswer/datastores/indexing_pipeline.py
@@ -32,6 +32,7 @@ class IndexingPipelineProtocol(Protocol):
 def _upsert_insertion_records(
     insertion_records: set[DocumentInsertionRecord],
     index_attempt_metadata: IndexAttemptMetadata,
+    doc_m_data_lookup: dict[str, tuple[str, str]],
 ) -> None:
     with Session(get_sqlalchemy_engine()) as session:
         upsert_documents_complete(
@@ -40,9 +41,11 @@ def _upsert_insertion_records(
             DocumentMetadata(
                 connector_id=index_attempt_metadata.connector_id,
                 credential_id=index_attempt_metadata.credential_id,
-                document_id=insertion_record.document_id,
+                document_id=i_r.document_id,
+                semantic_identifier=doc_m_data_lookup[i_r.document_id][0],
+                first_link=doc_m_data_lookup[i_r.document_id][1],
             )
-            for insertion_record in insertion_records
+            for i_r in insertion_records
         ],
     )
@@ -62,6 +65,11 @@ def _get_net_new_documents(
     return net_new_documents
 
 
+def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]:
+    first_link = next((section.link for section in doc.sections if section.link), "")
+    return doc.semantic_identifier, first_link
+
+
 def _indexing_pipeline(
     *,
     chunker: Chunker,
@@ -73,6 +81,11 @@ def _indexing_pipeline(
     """Takes different pieces of the indexing pipeline and applies it to a batch of documents
     Note that the documents should already be batched at this point so that it does not
    inflate the memory requirements"""
+
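+    # doc id -> (semantic_identifier, first_link), captured before chunking so
+    # document-level metadata can be persisted alongside the insertion records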
+    document_metadata_lookup = {
+        doc.id: _extract_minimal_document_metadata(doc) for doc in documents
+    }
+
     chunks: list[DocAwareChunk] = list(
         chain(*[chunker.chunk(document=document) for document in documents])
     )
@@ -92,6 +105,7 @@ def _indexing_pipeline(
         _upsert_insertion_records(
             insertion_records=insertion_records,
             index_attempt_metadata=index_attempt_metadata,
+            doc_m_data_lookup=document_metadata_lookup,
         )
     except Exception as e:
         logger.error(
diff --git a/backend/danswer/datastores/interfaces.py b/backend/danswer/datastores/interfaces.py
index 28b9814b2..e5452b6de 100644
--- a/backend/danswer/datastores/interfaces.py
+++ b/backend/danswer/datastores/interfaces.py
@@ -22,6 +22,8 @@ class DocumentMetadata:
     connector_id: int
     credential_id: int
     document_id: str
+    semantic_identifier: str
+    first_link: str
 
 
 @dataclass
@@ -32,7 +34,7 @@ class UpdateRequest:
     document_ids: list[str]
     # all other fields will be left alone
     allowed_users: list[str] | None = None
-    boost: int | None = None
+    boost: float | None = None
 
 
 class Verifiable(abc.ABC):
diff --git a/backend/danswer/datastores/vespa/store.py b/backend/danswer/datastores/vespa/store.py
index 13ca70e19..57e6d87ef 100644
--- a/backend/danswer/datastores/vespa/store.py
+++ b/backend/danswer/datastores/vespa/store.py
@@ -341,16 +341,20 @@ class VespaIndex(DocumentIndex):
                 logger.error("Update request received but nothing to update")
                 continue
 
-            update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}}
+            update_dict: dict[str, dict] = {"fields": {}}
             if update_request.boost:
-                update_dict["fields"][BOOST] = update_request.boost
+                update_dict["fields"][BOOST] = {"assign": update_request.boost}
             if update_request.allowed_users:
-                update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users
+                update_dict["fields"][ALLOWED_USERS] = {
+                    "assign": update_request.allowed_users
+                }
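+            # this builds Vespa's partial-update payload, e.g.
+            #   {"fields": {"boost": {"assign": 1.44}}}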
 
             for document_id in update_request.document_ids:
                 for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id):
                     url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
-                    res = requests.put(url, headers=json_header, json=update_dict)
+                    res = requests.put(
+                        url, headers=json_header, data=json.dumps(update_dict)
+                    )
 
                     try:
                         res.raise_for_status()
diff --git a/backend/danswer/db/document.py b/backend/danswer/db/document.py
index c35db0c86..eab0db814 100644
--- a/backend/danswer/db/document.py
+++ b/backend/danswer/db/document.py
@@ -7,8 +7,9 @@ from sqlalchemy import select
 from sqlalchemy.dialects.postgresql import insert
 from sqlalchemy.orm import Session
 
+from danswer.configs.constants import DEFAULT_BOOST
 from danswer.datastores.interfaces import DocumentMetadata
-from danswer.db.models import Document
+from danswer.db.models import Document as DbDocument
 from danswer.db.models import DocumentByConnectorCredentialPair
 from danswer.db.utils import model_to_dict
 from danswer.utils.logger import setup_logger
@@ -20,7 +21,7 @@ def get_documents_with_single_connector_credential_pair(
     db_session: Session,
     connector_id: int,
     credential_id: int,
-) -> Sequence[Document]:
+) -> Sequence[DbDocument]:
     initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
         and_(
             DocumentByConnectorCredentialPair.connector_id == connector_id,
@@ -31,17 +32,17 @@ def get_documents_with_single_connector_credential_pair(
 
     # Filter it down to the documents with only a single connector/credential pair
     # Meaning if this connector/credential pair is removed, this doc should be gone
     trimmed_doc_ids_stmt = (
-        select(Document.id)
+        select(DbDocument.id)
         .join(
             DocumentByConnectorCredentialPair,
-            DocumentByConnectorCredentialPair.id == Document.id,
+            DocumentByConnectorCredentialPair.id == DbDocument.id,
         )
-        .where(Document.id.in_(initial_doc_ids_stmt))
-        .group_by(Document.id)
+        .where(DbDocument.id.in_(initial_doc_ids_stmt))
+        .group_by(DbDocument.id)
         .having(func.count(DocumentByConnectorCredentialPair.id) == 1)
     )
 
-    stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt))
+    stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt))
     return db_session.scalars(stmt).all()
@@ -60,13 +61,13 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
     # Filter it down to the documents with more than 1 connector/credential pair
     # Meaning if this connector/credential pair is removed, this doc is still accessible
     trimmed_doc_ids_stmt = (
-        select(Document.id)
+        select(DbDocument.id)
         .join(
             DocumentByConnectorCredentialPair,
-            DocumentByConnectorCredentialPair.id == Document.id,
+            DocumentByConnectorCredentialPair.id == DbDocument.id,
         )
-        .where(Document.id.in_(initial_doc_ids_stmt))
-        .group_by(Document.id)
+        .where(DbDocument.id.in_(initial_doc_ids_stmt))
+        .group_by(DbDocument.id)
         .having(func.count(DocumentByConnectorCredentialPair.id) > 1)
     )
@@ -81,13 +82,25 @@ def upsert_documents(
     db_session: Session, document_metadata_batch: list[DocumentMetadata]
 ) -> None:
     """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
-    seen_document_ids: set[str] = set()
+    seen_documents: dict[str, DocumentMetadata] = {}
     for document_metadata in document_metadata_batch:
-        if document_metadata.document_id not in seen_document_ids:
-            seen_document_ids.add(document_metadata.document_id)
+        doc_id = document_metadata.document_id
+        if doc_id not in seen_documents:
+            seen_documents[doc_id] = document_metadata
 
-    insert_stmt = insert(Document).values(
-        [model_to_dict(Document(id=doc_id)) for doc_id in seen_document_ids]
+    insert_stmt = insert(DbDocument).values(
+        [
+            model_to_dict(
+                DbDocument(
+                    id=doc.document_id,
+                    boost=DEFAULT_BOOST,
+                    hidden=False,
+                    semantic_id=doc.semantic_identifier,
+                    link=doc.first_link,
+                )
+            )
+            for doc in seen_documents.values()
+        ]
     )
     # for now, there are no columns to update. If more metadata is added, then this
     # needs to change to an `on_conflict_do_update`
@@ -120,7 +133,8 @@ def upsert_document_by_connector_credential_pair(
 
 
 def upsert_documents_complete(
-    db_session: Session, document_metadata_batch: list[DocumentMetadata]
+    db_session: Session,
+    document_metadata_batch: list[DocumentMetadata],
 ) -> None:
     upsert_documents(db_session, document_metadata_batch)
     upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
@@ -140,7 +154,7 @@ def delete_document_by_connector_credential_pair(
 
 
 def delete_documents(db_session: Session, document_ids: list[str]) -> None:
-    db_session.execute(delete(Document).where(Document.id.in_(document_ids)))
+    db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))
 
 
 def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
diff --git a/backend/danswer/db/feedback.py b/backend/danswer/db/feedback.py
new file mode 100644
index 000000000..adc00546e
--- /dev/null
+++ b/backend/danswer/db/feedback.py
@@ -0,0 +1,156 @@
+from uuid import UUID
+
+from sqlalchemy import asc
+from sqlalchemy import desc
+from sqlalchemy import select
+from sqlalchemy.orm import Session
+
+from danswer.configs.constants import QAFeedbackType
+from danswer.configs.constants import SearchFeedbackType
+from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
+from danswer.datastores.document_index import get_default_document_index
+from danswer.datastores.interfaces import UpdateRequest
+from danswer.db.models import Document as DbDocument
+from danswer.db.models import DocumentRetrievalFeedback
+from danswer.db.models import QueryEvent
+from danswer.search.models import SearchType
+
+
+def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
+    stmt = select(QueryEvent).where(QueryEvent.id == query_id)
+    result = db_session.execute(stmt)
+    query_event = result.scalar_one_or_none()
+
+    if not query_event:
+        raise ValueError("Invalid Query Event provided for updating")
+
+    return query_event
+
+
+def fetch_doc_m_by_id(doc_id: str, db_session: Session) -> DbDocument:
+    stmt = select(DbDocument).where(DbDocument.id == doc_id)
+    result = db_session.execute(stmt)
+    doc_m = result.scalar_one_or_none()
+
+    if not doc_m:
+        raise ValueError("Invalid Document provided for updating")
+
+    return doc_m
+
+
+def fetch_docs_ranked_by_boost(
+    db_session: Session, ascending: bool = False, limit: int = 100
+) -> list[DbDocument]:
+    order_func = asc if ascending else desc
+    stmt = select(DbDocument).order_by(order_func(DbDocument.boost)).limit(limit)
+    result = db_session.execute(stmt)
+    doc_m_list = result.scalars().all()
+
+    return list(doc_m_list)
+
+
+def create_document_metadata(
+    doc_id: str,
+    semantic_id: str,
+    link: str | None,
+    db_session: Session,
+) -> None:
+    try:
+        fetch_doc_m_by_id(doc_id, db_session)
+        # Document already exists, don't reset its data
+        return
+    except ValueError:
+        # Document does not exist yet, create it below
+        pass
+
+    db_session.add(
+        DbDocument(
+            id=doc_id,
+            semantic_id=semantic_id,
+            link=link,
+        )
+    )
+    db_session.commit()
+
+
+def create_query_event(
+    query: str,
+    selected_flow: SearchType | None,
+    llm_answer: str | None,
+    user_id: UUID | None,
+    db_session: Session,
+) -> int:
+    query_event = QueryEvent(
+        query=query,
+        selected_search_flow=selected_flow,
+        llm_answer=llm_answer,
+        user_id=user_id,
+    )
+    db_session.add(query_event)
+    db_session.commit()
+
+    return query_event.id
+
+
+def update_query_event_feedback(
+    feedback: QAFeedbackType,
+    query_id: int,
+    user_id: UUID | None,
+    db_session: Session,
+) -> None:
+    query_event = fetch_query_event_by_id(query_id, db_session)
+
+    if user_id != query_event.user_id:
+        raise ValueError("User trying to give feedback on a query run by another user.")
+
+    query_event.feedback = feedback
+
+    db_session.commit()
+
+
+def create_doc_retrieval_feedback(
+    qa_event_id: int,
+    document_id: str,
+    document_rank: int,
+    user_id: UUID | None,
+    db_session: Session,
+    clicked: bool = False,
+    feedback: SearchFeedbackType | None = None,
+) -> None:
+    if not clicked and feedback is None:
+        raise ValueError("No action taken, not valid feedback")
+
+    query_event = fetch_query_event_by_id(qa_event_id, db_session)
+
+    if user_id != query_event.user_id:
+        raise ValueError("User trying to give feedback on a query run by another user.")
+
+    doc_m = fetch_doc_m_by_id(document_id, db_session)
+
+    retrieval_feedback = DocumentRetrievalFeedback(
+        qa_event_id=qa_event_id,
+        document_id=document_id,
+        document_rank=document_rank,
+        clicked=clicked,
+        feedback=feedback,
+    )
+
+    if feedback is not None:
+        if feedback == SearchFeedbackType.ENDORSE:
+            doc_m.boost += 1
+        elif feedback == SearchFeedbackType.REJECT:
+            doc_m.boost -= 1
+        elif feedback == SearchFeedbackType.HIDE:
+            doc_m.hidden = True
+        elif feedback == SearchFeedbackType.UNHIDE:
+            doc_m.hidden = False
+        else:
+            raise ValueError("Unhandled document feedback type")
+
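+    # Postgres keeps the raw signed boost count; the document index is sent the
+    # translated multiplier instead (see translate_boost_count_to_multiplier)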
+    if feedback in [SearchFeedbackType.ENDORSE, SearchFeedbackType.REJECT]:
+        document_index = get_default_document_index()
+        update = UpdateRequest(
+            document_ids=[document_id],
+            boost=translate_boost_count_to_multiplier(doc_m.boost),
+        )
+        # updates are generally batched for efficiency; here only a single
+        # document's value changes
+        document_index.update([update])
+
+    db_session.add(retrieval_feedback)
+    db_session.commit()
diff --git a/backend/danswer/db/models.py b/backend/danswer/db/models.py
index e8b42be56..7d845b5d7 100644
--- a/backend/danswer/db/models.py
+++ b/backend/danswer/db/models.py
@@ -23,8 +23,12 @@ from sqlalchemy.orm import mapped_column
 from sqlalchemy.orm import relationship
 
 from danswer.auth.schemas import UserRole
+from danswer.configs.constants import DEFAULT_BOOST
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import QAFeedbackType
+from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
+from danswer.search.models import SearchType
 
 
 class IndexingStatus(str, PyEnum):
@@ -61,6 +65,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
     credentials: Mapped[List["Credential"]] = relationship(
         "Credential", back_populates="user", lazy="joined"
     )
+    query_events: Mapped[List["QueryEvent"]] = relationship(
+        "QueryEvent", back_populates="user"
+    )
 
 
 class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@@ -162,7 +169,7 @@ class Credential(Base):
     deletion_attempt: Mapped[Optional["DeletionAttempt"]] = relationship(
         "DeletionAttempt", back_populates="credential"
     )
-    user: Mapped[User] = relationship("User", back_populates="credentials")
+    user: Mapped[User | None] = relationship("User", back_populates="credentials")
 
 
 class IndexAttempt(Base):
@@ -258,17 +265,6 @@ class DeletionAttempt(Base):
     )
 
 
-class Document(Base):
-    """Represents a single documents from a source. This is used to store
-    document level metadata, but currently nothing is stored"""
-
-    __tablename__ = "document"
-
-    # this should correspond to the ID of the document (as is passed around
-    # in Danswer)
-    id: Mapped[str] = mapped_column(String, primary_key=True)
-
-
 class DocumentByConnectorCredentialPair(Base):
     """Represents an indexing of a document by a specific connector / credential pair"""
@@ -289,3 +285,72 @@ class DocumentByConnectorCredentialPair(Base):
     credential: Mapped[Credential] = relationship(
         "Credential", back_populates="documents_by_credential"
     )
+
+
+class QueryEvent(Base):
+    __tablename__ = "query_event"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    query: Mapped[str] = mapped_column(String())
+    # selected_search_flow is the user's explicit selection; None if auto was used
+    selected_search_flow: Mapped[SearchType | None] = mapped_column(
+        Enum(SearchType), nullable=True
+    )
+    llm_answer: Mapped[str | None] = mapped_column(String(), default=None)
+    feedback: Mapped[QAFeedbackType | None] = mapped_column(
+        Enum(QAFeedbackType), nullable=True
+    )
+    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
+    time_created: Mapped[datetime.datetime] = mapped_column(
+        DateTime(timezone=True),
+        server_default=func.now(),
+    )
+
+    user: Mapped[User | None] = relationship("User", back_populates="query_events")
+    document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
+        "DocumentRetrievalFeedback", back_populates="qa_event"
+    )
+
+
+class DocumentRetrievalFeedback(Base):
+    __tablename__ = "document_retrieval_feedback"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    qa_event_id: Mapped[int] = mapped_column(
+        ForeignKey("query_event.id"),
+    )
+    document_id: Mapped[str] = mapped_column(
+        ForeignKey("document.id"),
+    )
+    # How high up this document is in the results, 1 for first
+    document_rank: Mapped[int] = mapped_column(Integer)
+    clicked: Mapped[bool] = mapped_column(Boolean, default=False)
+    feedback: Mapped[SearchFeedbackType | None] = mapped_column(
+        Enum(SearchFeedbackType), nullable=True
+    )
+
+    qa_event: Mapped[QueryEvent] = relationship(
+        "QueryEvent", back_populates="document_feedbacks"
+    )
+    document: Mapped["Document"] = relationship(
+        "Document", back_populates="retrieval_feedbacks"
+    )
+
+
+class Document(Base):
+    __tablename__ = "document"
+
+    # this should correspond to the ID of the document (as is passed around
+    # in Danswer)
+    id: Mapped[str] = mapped_column(String, primary_key=True)
+    # 0 for neutral, positive for mostly endorse, negative for mostly reject
+    boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
+    hidden: Mapped[bool] = mapped_column(Boolean, default=False)
+    semantic_id: Mapped[str] = mapped_column(String)
+    # First Section's link
+    link: Mapped[str | None] = mapped_column(String, nullable=True)
+    # TODO if more sensitive data is added here for display, make sure to add user/group permission
+
+    retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
+        "DocumentRetrievalFeedback", back_populates="document"
+    )
diff --git a/backend/danswer/direct_qa/answer_question.py b/backend/danswer/direct_qa/answer_question.py
index 7af0b6688..d810f323a 100644
--- a/backend/danswer/direct_qa/answer_question.py
+++ b/backend/danswer/direct_qa/answer_question.py
@@ -1,8 +1,11 @@
+from sqlalchemy.orm import Session
+
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.configs.app_configs import QA_TIMEOUT
 from danswer.datastores.document_index import get_default_document_index
+from danswer.db.feedback import create_query_event
 from danswer.db.models import User
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
@@ -22,19 +25,27 @@ logger = setup_logger()
 
 
 @log_function_time()
-def answer_question(
+def answer_qa_query(
     question: QuestionRequest,
     user: User | None,
+    db_session: Session,
     disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     answer_generation_timeout: int = QA_TIMEOUT,
 ) -> QAResponse:
     query = question.query
-    collection = question.collection
     filters = question.filters
     use_keyword = question.use_keyword
     offset_count = question.offset if question.offset is not None else 0
     logger.info(f"Received QA query: {query}")
 
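+    # create the query event up front so its id can be returned with every
+    # response below, including the error and no-result paths; the feedback
+    # endpoints later reference this id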
+    query_event_id = create_query_event(
+        query=query,
+        selected_flow=SearchType.KEYWORD
+        if use_keyword
+        else SearchType.SEMANTIC,
+        llm_answer=None,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+
     predicted_search, predicted_flow = query_intent(query)
     if use_keyword is None:
         use_keyword = predicted_search == SearchType.KEYWORD
@@ -42,12 +53,12 @@ def answer_question(
     user_id = None if user is None else user.id
     if use_keyword:
         ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
-            query, user_id, filters, get_default_document_index(collection=collection)
+            query, user_id, filters, get_default_document_index()
         )
         unranked_chunks: list[InferenceChunk] | None = []
     else:
         ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-            query, user_id, filters, get_default_document_index(collection=collection)
+            query, user_id, filters, get_default_document_index()
         )
     if not ranked_chunks:
         return QAResponse(
@@ -57,6 +68,7 @@ def answer_question(
             lower_ranked_docs=None,
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
+            query_event_id=query_event_id,
         )
 
     if disable_generative_answer:
@@ -70,6 +82,7 @@ def answer_question(
             # to run QA over more documents
             predicted_flow=QueryFlow.SEARCH,
             predicted_search=predicted_search,
+            query_event_id=query_event_id,
         )
 
     try:
@@ -83,6 +96,7 @@ def answer_question(
             predicted_flow=predicted_flow,
             predicted_search=predicted_search,
             error_msg=str(e),
+            query_event_id=query_event_id,
         )
 
     chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
@@ -108,4 +122,5 @@ def answer_question(
         predicted_flow=predicted_flow,
         predicted_search=predicted_search,
         error_msg=error_msg,
+        query_event_id=query_event_id,
     )
diff --git a/backend/danswer/listeners/slack_listener.py b/backend/danswer/listeners/slack_listener.py
index f13e710d1..f9ed5e21e 100644
--- a/backend/danswer/listeners/slack_listener.py
+++ b/backend/danswer/listeners/slack_listener.py
@@ -9,6 +9,7 @@ from slack_sdk import WebClient
 from slack_sdk.socket_mode import SocketModeClient
 from slack_sdk.socket_mode.request import SocketModeRequest
 from slack_sdk.socket_mode.response import SocketModeResponse
+from sqlalchemy.orm import Session
 
 from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
 from danswer.configs.app_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
@@ -18,7 +19,8 @@ from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
 from danswer.configs.constants import DocumentSource
 from danswer.connectors.slack.utils import make_slack_api_rate_limited
 from danswer.connectors.slack.utils import UserIdReplacer
-from danswer.direct_qa.answer_question import answer_question
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.direct_qa.answer_question import answer_qa_query
 from danswer.direct_qa.interfaces import DanswerQuote
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
@@ -228,17 +230,19 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
         logger=cast(logging.Logger, logger),
     )
     def _get_answer(question: QuestionRequest) -> QAResponse:
-        answer = answer_question(
-            question=question,
-            user=None,
-            answer_generation_timeout=DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
-        )
-        if not answer.error_msg:
-            return answer
-        else:
-            raise RuntimeError(answer.error_msg)
+        engine = get_sqlalchemy_engine()
+        with Session(engine, expire_on_commit=False) as db_session:
+            answer = answer_qa_query(
+                question=question,
+                user=None,
+                db_session=db_session,
+                answer_generation_timeout=DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
+            )
+            if not answer.error_msg:
+                return answer
+            else:
+                raise RuntimeError(answer.error_msg)
 
-    answer = None
     try:
         answer = _get_answer(
             QuestionRequest(
diff --git a/backend/danswer/llm/azure.py b/backend/danswer/llm/azure.py
index 49a91afac..cce164466 100644
--- a/backend/danswer/llm/azure.py
+++ b/backend/danswer/llm/azure.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any
 
 from langchain.chat_models.azure_openai import AzureChatOpenAI
@@ -22,6 +23,11 @@ class AzureGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
+        # set a dummy API key if not specified so that LangChain doesn't throw an
+        # exception when trying to initialize the LLM, which would prevent the API
+        # server from starting up
+        if not api_key:
+            api_key = os.environ.get("OPENAI_API_KEY") or "dummy_api_key"
         self._llm = AzureChatOpenAI(
             model=model_version,
             openai_api_type="azure",
diff --git a/backend/danswer/llm/openai.py b/backend/danswer/llm/openai.py
index 4aa9274a0..891e52586 100644
--- a/backend/danswer/llm/openai.py
+++ b/backend/danswer/llm/openai.py
@@ -1,3 +1,4 @@
+import os
 from typing import Any
 
 from langchain.chat_models.openai import ChatOpenAI
@@ -16,6 +17,11 @@ class OpenAIGPT(LangChainChatLLM):
         *args: list[Any],
         **kwargs: dict[str, Any]
     ):
+        # set a dummy API key if not specified so that LangChain doesn't throw an
+        # exception when trying to initialize the LLM, which would prevent the API
+        # server from starting up
+        if not api_key:
+            api_key = os.environ.get("OPENAI_API_KEY") or "dummy_api_key"
         self._llm = ChatOpenAI(
             model=model_version,
             openai_api_key=api_key,
diff --git a/backend/danswer/server/manage.py b/backend/danswer/server/manage.py
index 799e5bdaa..a4d897083 100644
--- a/backend/danswer/server/manage.py
+++ b/backend/danswer/server/manage.py
@@ -49,6 +49,7 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
 from danswer.db.deletion_attempt import create_deletion_attempt
 from danswer.db.deletion_attempt import get_deletion_attempts
 from danswer.db.engine import get_session
+from danswer.db.feedback import fetch_docs_ranked_by_boost
 from danswer.db.index_attempt import create_index_attempt
 from danswer.db.index_attempt import get_latest_index_attempts
 from danswer.db.models import DeletionAttempt
@@ -61,6 +62,7 @@ from danswer.dynamic_configs.interface import ConfigNotFoundError
 from danswer.server.models import ApiKey
 from danswer.server.models import AuthStatus
 from danswer.server.models import AuthUrl
+from danswer.server.models import BoostDoc
 from danswer.server.models import ConnectorBase
 from danswer.server.models import ConnectorCredentialPairIdentifier
 from danswer.server.models import ConnectorIndexingStatus
@@ -79,7 +81,6 @@ from danswer.server.models import StatusResponse
 from danswer.server.models import UserRoleResponse
 from danswer.utils.logger import setup_logger
 
-
 router = APIRouter(prefix="/manage")
 logger = setup_logger()
@@ -89,6 +90,28 @@ _GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id"
 
 """Admin only API endpoints"""
 
+
+@router.get("/admin/doc-boosts")
+def get_most_boosted_docs(
+    ascending: bool,
+    limit: int,
+    _: User | None = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+) -> list[BoostDoc]:
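+    # example (illustrative): GET /manage/admin/doc-boosts?ascending=false&limit=10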
+    boost_docs = fetch_docs_ranked_by_boost(
+        ascending=ascending, limit=limit, db_session=db_session
+    )
+    return [
+        BoostDoc(
+            document_id=doc.id,
+            semantic_id=doc.semantic_id,
+            link=doc.link or "",
+            boost=doc.boost,
+            hidden=doc.hidden,
+        )
+        for doc in boost_docs
+    ]
+
+
 @router.get("/admin/connector/google-drive/app-credential")
 def check_google_app_credentials_exist(
     _: User = Depends(current_admin_user),
diff --git a/backend/danswer/server/models.py b/backend/danswer/server/models.py
index f79ce9072..1052b7e23 100644
--- a/backend/danswer/server/models.py
+++ b/backend/danswer/server/models.py
@@ -11,6 +11,8 @@ from pydantic.generics import GenericModel
 
 from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
 from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import QAFeedbackType
+from danswer.configs.constants import SearchFeedbackType
 from danswer.connectors.models import InputType
 from danswer.datastores.interfaces import IndexFilter
 from danswer.db.models import Connector
@@ -105,6 +107,14 @@ class UserRoleResponse(BaseModel):
     role: str
 
 
+class BoostDoc(BaseModel):
+    document_id: str
+    semantic_id: str
+    link: str
+    boost: int
+    hidden: bool
+
+
 class SearchDoc(BaseModel):
     document_id: str
     semantic_identifier: str
@@ -121,10 +131,24 @@ class QuestionRequest(BaseModel):
     offset: int | None
 
 
+class QAFeedbackRequest(BaseModel):
+    query_id: int
+    feedback: QAFeedbackType
+
+
+class SearchFeedbackRequest(BaseModel):
+    query_id: int
+    document_id: str
+    document_rank: int
+    click: bool
+    search_feedback: SearchFeedbackType
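+
+
+# Illustrative request payloads (example values, not from the source):
+#   QAFeedbackRequest:     {"query_id": 1, "feedback": "like"}
+#   SearchFeedbackRequest: {"query_id": 1, "document_id": "<doc id>",
+#                           "document_rank": 1, "click": false,
+#                           "search_feedback": "endorse"}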
+
+
 class SearchResponse(BaseModel):
     # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
     top_ranked_docs: list[SearchDoc] | None
     lower_ranked_docs: list[SearchDoc] | None
+    query_event_id: int
 
 
 class QAResponse(SearchResponse):
diff --git a/backend/danswer/server/search_backend.py b/backend/danswer/server/search_backend.py
index 1e4cc47f3..497095980 100644
--- a/backend/danswer/server/search_backend.py
+++ b/backend/danswer/server/search_backend.py
@@ -5,16 +5,22 @@ from dataclasses import asdict
 from fastapi import APIRouter
 from fastapi import Depends
 from fastapi.responses import StreamingResponse
+from sqlalchemy.orm import Session
 
 from danswer.auth.users import current_user
 from danswer.chunking.models import InferenceChunk
 from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
 from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
 from danswer.datastores.document_index import get_default_document_index
+from danswer.db.engine import get_session
+from danswer.db.feedback import create_doc_retrieval_feedback
+from danswer.db.feedback import create_query_event
+from danswer.db.feedback import update_query_event_feedback
 from danswer.db.models import User
-from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.answer_question import answer_qa_query
 from danswer.direct_qa.exceptions import OpenAIKeyMissing
 from danswer.direct_qa.exceptions import UnknownModelError
+from danswer.direct_qa.interfaces import DanswerAnswerPiece
 from danswer.direct_qa.llm_utils import get_default_qa_model
 from danswer.search.danswer_helper import query_intent
 from danswer.search.danswer_helper import recommend_search_flow
@@ -24,8 +30,10 @@ from danswer.search.models import SearchType
 from danswer.search.semantic_search import chunks_to_search_docs
 from danswer.search.semantic_search import retrieve_ranked_documents
 from danswer.server.models import HelperResponse
+from danswer.server.models import QAFeedbackRequest
 from danswer.server.models import QAResponse
 from danswer.server.models import QuestionRequest
+from danswer.server.models import SearchFeedbackRequest
 from danswer.server.models import SearchResponse
 from danswer.utils.logger import setup_logger
 from danswer.utils.timing import log_generator_function_time
@@ -50,62 +58,95 @@ def get_search_type(
 
 @router.post("/semantic-search")
 def semantic_search(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User = Depends(current_user),
+    db_session: Session = Depends(get_session),
 ) -> SearchResponse:
     query = question.query
-    collection = question.collection
     filters = question.filters
     logger.info(f"Received semantic search query: {query}")
 
+    query_event_id = create_query_event(
+        query=query,
+        selected_flow=SearchType.SEMANTIC,
+        llm_answer=None,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+
     user_id = None if user is None else user.id
     ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-        query, user_id, filters, get_default_document_index(collection=collection)
+        query, user_id, filters, get_default_document_index()
     )
     if not ranked_chunks:
-        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
+        return SearchResponse(
+            top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id
+        )
 
     top_docs = chunks_to_search_docs(ranked_chunks)
     other_top_docs = chunks_to_search_docs(unranked_chunks)
 
-    return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=other_top_docs)
+    return SearchResponse(
+        top_ranked_docs=top_docs,
+        lower_ranked_docs=other_top_docs,
+        query_event_id=query_event_id,
+    )
 
 
 @router.post("/keyword-search")
 def keyword_search(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User = Depends(current_user),
+    db_session: Session = Depends(get_session),
 ) -> SearchResponse:
     query = question.query
-    collection = question.collection
     filters = question.filters
     logger.info(f"Received keyword search query: {query}")
 
+    query_event_id = create_query_event(
+        query=query,
+        selected_flow=SearchType.KEYWORD,
+        llm_answer=None,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+
     user_id = None if user is None else user.id
     ranked_chunks = retrieve_keyword_documents(
-        query, user_id, filters, get_default_document_index(collection=collection)
+        query, user_id, filters, get_default_document_index()
     )
     if not ranked_chunks:
-        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
+        return SearchResponse(
+            top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id
+        )
 
     top_docs = chunks_to_search_docs(ranked_chunks)
 
-    return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=None)
+    return SearchResponse(
+        top_ranked_docs=top_docs, lower_ranked_docs=None, query_event_id=query_event_id
+    )
 
 
 @router.post("/direct-qa")
 def direct_qa(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
 ) -> QAResponse:
-    return answer_question(question=question, user=user)
+    return answer_qa_query(question=question, user=user, db_session=db_session)
 
 
 @router.post("/stream-direct-qa")
 def stream_direct_qa(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
 ) -> StreamingResponse:
     send_packet_debug_msg = "Sending Packet: {}"
     top_documents_key = "top_documents"
     unranked_top_docs_key = "unranked_top_documents"
     predicted_flow_key = "predicted_flow"
     predicted_search_key = "predicted_search"
+    query_event_id_key = "query_event_id"
 
     logger.debug(f"Received QA query: {question.query}")
     logger.debug(f"Query filters: {question.filters}")
@@ -116,8 +157,8 @@ def stream_direct_qa(
     def stream_qa_portions(
         disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
     ) -> Generator[str, None, None]:
+        answer_so_far: str = ""
         query = question.query
-        collection = question.collection
         filters = question.filters
         use_keyword = question.use_keyword
         offset_count = question.offset if question.offset is not None else 0
@@ -132,7 +173,7 @@ def stream_direct_qa(
                 query,
                 user_id,
                 filters,
-                get_default_document_index(collection=collection),
+                get_default_document_index(),
             )
             unranked_chunks: list[InferenceChunk] | None = []
         else:
@@ -140,7 +181,7 @@ def stream_direct_qa(
                 query,
                 user_id,
                 filters,
-                get_default_document_index(collection=collection),
+                get_default_document_index(),
             )
         if not ranked_chunks:
             logger.debug("No Documents Found")
@@ -194,6 +235,11 @@ def stream_direct_qa(
         ):
             if response_packet is None:
                 continue
+            if (
+                isinstance(response_packet, DanswerAnswerPiece)
+                and response_packet.answer_piece
+            ):
+                answer_so_far = answer_so_far + response_packet.answer_piece
             logger.debug(f"Sending packet: {response_packet}")
             yield get_json_line(asdict(response_packet))
     except Exception as e:
@@ -201,6 +247,49 @@ def stream_direct_qa(
             yield get_json_line({"error": str(e)})
             logger.exception("Failed to run QA")
 
+        query_event_id = create_query_event(
+            query=query,
+            selected_flow=SearchType.KEYWORD
+            if question.use_keyword
+            else SearchType.SEMANTIC,
+            llm_answer=answer_so_far,
+            user_id=user_id,
+            db_session=db_session,
+        )
+
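+        # emit the query event id as the final packet of the stream so the
+        # client can attach feedback to this query afterwards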
+        yield get_json_line({query_event_id_key: query_event_id})
+
+        return
 
     return StreamingResponse(stream_qa_portions(), media_type="application/json")
+
+
+@router.post("/query-feedback")
+def process_query_feedback(
+    feedback: QAFeedbackRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> None:
+    update_query_event_feedback(
+        feedback=feedback.feedback,
+        query_id=feedback.query_id,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+
+
+@router.post("/doc-retrieval-feedback")
+def process_doc_retrieval_feedback(
+    feedback: SearchFeedbackRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> None:
+    create_doc_retrieval_feedback(
+        qa_event_id=feedback.query_id,
+        document_id=feedback.document_id,
+        document_rank=feedback.document_rank,
+        clicked=feedback.click,
+        feedback=feedback.search_feedback,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )