Learn from feedback backend (#343)

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>
Yuhong Sun 2023-08-28 13:29:29 -07:00 committed by GitHub
parent c43a403b71
commit b2a51283d1
18 changed files with 610 additions and 71 deletions

View File

@ -0,0 +1,93 @@
"""Feedback Feature
Revision ID: d929f0c1c6af
Revises: 8aabb57f3b49
Create Date: 2023-08-27 13:03:54.274987
"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = "d929f0c1c6af"
down_revision = "8aabb57f3b49"
branch_labels = None
depends_on = None
def upgrade() -> None:
op.create_table(
"query_event",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("query", sa.String(), nullable=False),
sa.Column(
"selected_search_flow",
sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"),
nullable=True,
),
sa.Column("llm_answer", sa.String(), nullable=True),
sa.Column(
"feedback",
sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"),
nullable=True,
),
sa.Column(
"user_id",
fastapi_users_db_sqlalchemy.generics.GUID(),
nullable=True,
),
sa.Column(
"time_created",
sa.DateTime(timezone=True),
server_default=sa.text("now()"),
nullable=False,
),
sa.ForeignKeyConstraint(
["user_id"],
["user.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.create_table(
"document_retrieval_feedback",
sa.Column("id", sa.Integer(), nullable=False),
sa.Column("qa_event_id", sa.Integer(), nullable=False),
sa.Column("document_id", sa.String(), nullable=False),
sa.Column("document_rank", sa.Integer(), nullable=False),
sa.Column("clicked", sa.Boolean(), nullable=False),
sa.Column(
"feedback",
sa.Enum(
"ENDORSE",
"REJECT",
"HIDE",
"UNHIDE",
name="searchfeedbacktype",
),
nullable=True,
),
sa.ForeignKeyConstraint(
["document_id"],
["document.id"],
),
sa.ForeignKeyConstraint(
["qa_event_id"],
["query_event.id"],
),
sa.PrimaryKeyConstraint("id"),
)
op.add_column("document", sa.Column("boost", sa.Integer(), nullable=False))
op.add_column("document", sa.Column("hidden", sa.Boolean(), nullable=False))
op.add_column("document", sa.Column("semantic_id", sa.String(), nullable=False))
op.add_column("document", sa.Column("link", sa.String(), nullable=True))
def downgrade() -> None:
op.drop_column("document", "link")
op.drop_column("document", "semantic_id")
op.drop_column("document", "hidden")
op.drop_column("document", "boost")
op.drop_table("document_retrieval_feedback")
op.drop_table("query_event")

View File

@ -90,7 +90,7 @@ def _delete_connector_credential_pair(
def _get_user(
credential: Credential,
) -> str:
if credential.public_doc:
if credential.public_doc or not credential.user:
return PUBLIC_DOC_PAT
return str(credential.user.id)

View File

@ -23,6 +23,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_document_metadata
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@ -246,6 +247,7 @@ def _run_indexing(
logger.debug(
f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
)
index_user_id = (
None if db_credential.public_doc else db_credential.user_id
)

View File

@ -18,6 +18,7 @@ HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
QUOTE = "quote"
BOOST = "boost"
DEFAULT_BOOST = 0
class DocumentSource(str, Enum):
@ -66,3 +67,15 @@ class ModelHostType(str, Enum):
# https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
COLAB_DEMO = "colab-demo"
# TODO support for Azure, AWS, GCP GenAI model hosting
class QAFeedbackType(str, Enum):
LIKE = "like" # User likes the answer, used for metrics
DISLIKE = "dislike" # User dislikes the answer, used for metrics
class SearchFeedbackType(str, Enum):
ENDORSE = "endorse" # boost this document for all future queries
REJECT = "reject" # down-boost this document for all future queries
HIDE = "hide" # mark this document as untrusted, hide from LLM
UNHIDE = "unhide"  # undo a previous hide, making the document visible to the LLM again

View File

@ -12,6 +12,15 @@ from danswer.connectors.models import IndexAttemptMetadata
DEFAULT_BATCH_SIZE = 30
BOOST_MULTIPLIER = 1.2
def translate_boost_count_to_multiplier(boost: int) -> float:
if boost > 0:
return BOOST_MULTIPLIER**boost
elif boost < 0:
return 1 / (BOOST_MULTIPLIER ** abs(boost))  # negative boost yields a multiplier below 1
return 1
def get_uuid_from_chunk(

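A standalone sketch of the boost-to-multiplier mapping may help; negative boosts should shrink the multiplier below 1. This re-implements the helper above outside the module purely so the values can be printed:

# Standalone sketch, not part of the committed module.
BOOST_MULTIPLIER = 1.2

def translate_boost_count_to_multiplier(boost: int) -> float:
    if boost > 0:
        return BOOST_MULTIPLIER**boost
    elif boost < 0:
        return 1 / (BOOST_MULTIPLIER ** abs(boost))  # below 1 for down-boosted docs
    return 1

for b in (-2, -1, 0, 1, 2):
    print(b, round(translate_boost_count_to_multiplier(b), 3))
# -2 0.694, -1 0.833, 0 1, 1 1.2, 2 1.44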
View File

@ -32,6 +32,7 @@ class IndexingPipelineProtocol(Protocol):
def _upsert_insertion_records(
insertion_records: set[DocumentInsertionRecord],
index_attempt_metadata: IndexAttemptMetadata,
doc_m_data_lookup: dict[str, tuple[str, str]],
) -> None:
with Session(get_sqlalchemy_engine()) as session:
upsert_documents_complete(
@ -40,9 +41,11 @@ def _upsert_insertion_records(
DocumentMetadata(
connector_id=index_attempt_metadata.connector_id,
credential_id=index_attempt_metadata.credential_id,
document_id=insertion_record.document_id,
document_id=i_r.document_id,
semantic_identifier=doc_m_data_lookup[i_r.document_id][0],
first_link=doc_m_data_lookup[i_r.document_id][1],
)
for insertion_record in insertion_records
for i_r in insertion_records
],
)
@ -62,6 +65,11 @@ def _get_net_new_documents(
return net_new_documents
def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]:
first_link = next((section.link for section in doc.sections if section.link), "")
return doc.semantic_identifier, first_link
def _indexing_pipeline(
*,
chunker: Chunker,
@ -73,6 +81,11 @@ def _indexing_pipeline(
"""Takes different pieces of the indexing pipeline and applies it to a batch of documents
Note that the documents should already be batched at this point so that it does not inflate the
memory requirements"""
document_metadata_lookup = {
doc.id: _extract_minimal_document_metadata(doc) for doc in documents
}
chunks: list[DocAwareChunk] = list(
chain(*[chunker.chunk(document=document) for document in documents])
)
@ -92,6 +105,7 @@ def _indexing_pipeline(
_upsert_insertion_records(
insertion_records=insertion_records,
index_attempt_metadata=index_attempt_metadata,
doc_m_data_lookup=document_metadata_lookup,
)
except Exception as e:
logger.error(

View File

@ -22,6 +22,8 @@ class DocumentMetadata:
connector_id: int
credential_id: int
document_id: str
semantic_identifier: str
first_link: str
@dataclass
@ -32,7 +34,7 @@ class UpdateRequest:
document_ids: list[str]
# all other fields will be left alone
allowed_users: list[str] | None = None
boost: int | None = None
boost: float | None = None
class Verifiable(abc.ABC):

View File

@ -341,16 +341,20 @@ class VespaIndex(DocumentIndex):
logger.error("Update request received but nothing to update")
continue
update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}}
update_dict: dict[str, dict] = {"fields": {}}
if update_request.boost:
update_dict["fields"][BOOST] = update_request.boost
update_dict["fields"][BOOST] = {"assign": update_request.boost}
if update_request.allowed_users:
update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users
update_dict["fields"][ALLOWED_USERS] = {
"assign": update_request.allowed_users
}
for document_id in update_request.document_ids:
for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id):
url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
res = requests.put(url, headers=json_header, json=update_dict)
res = requests.put(
url, headers=json_header, data=json.dumps(update_dict)
)
try:
res.raise_for_status()

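For reference, the partial-update body that the loop above PUTs for each chunk of a boosted document looks roughly like this; the multiplier value and chunk id are illustrative, and json_header is assumed to be a plain Content-Type: application/json header:

# Illustrative sketch of the Vespa partial-update payload, not repo code.
import json

update_dict = {"fields": {"boost": {"assign": 1.44}}}
body = json.dumps(update_dict)
print(body)  # {"fields": {"boost": {"assign": 1.44}}}
# requests.put(f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}", headers=json_header, data=body)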
View File

@ -7,8 +7,9 @@ from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session
from danswer.configs.constants import DEFAULT_BOOST
from danswer.datastores.interfaces import DocumentMetadata
from danswer.db.models import Document
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
from danswer.utils.logger import setup_logger
@ -20,7 +21,7 @@ def get_documents_with_single_connector_credential_pair(
db_session: Session,
connector_id: int,
credential_id: int,
) -> Sequence[Document]:
) -> Sequence[DbDocument]:
initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
and_(
DocumentByConnectorCredentialPair.connector_id == connector_id,
@ -31,17 +32,17 @@ def get_documents_with_single_connector_credential_pair(
# Filter it down to the documents with only a single connector/credential pair
# Meaning if this connector/credential pair is removed, this doc should be gone
trimmed_doc_ids_stmt = (
select(Document.id)
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,
DocumentByConnectorCredentialPair.id == DbDocument.id,
)
.where(Document.id.in_(initial_doc_ids_stmt))
.group_by(Document.id)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) == 1)
)
stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt))
stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt))
return db_session.scalars(stmt).all()
@ -60,13 +61,13 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
# Filter it down to the documents with more than 1 connector/credential pair
# Meaning if this connector/credential pair is removed, this doc is still accessible
trimmed_doc_ids_stmt = (
select(Document.id)
select(DbDocument.id)
.join(
DocumentByConnectorCredentialPair,
DocumentByConnectorCredentialPair.id == Document.id,
DocumentByConnectorCredentialPair.id == DbDocument.id,
)
.where(Document.id.in_(initial_doc_ids_stmt))
.group_by(Document.id)
.where(DbDocument.id.in_(initial_doc_ids_stmt))
.group_by(DbDocument.id)
.having(func.count(DocumentByConnectorCredentialPair.id) > 1)
)
@ -81,13 +82,25 @@ def upsert_documents(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
"""NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
seen_document_ids: set[str] = set()
seen_documents: dict[str, DocumentMetadata] = {}
for document_metadata in document_metadata_batch:
if document_metadata.document_id not in seen_document_ids:
seen_document_ids.add(document_metadata.document_id)
doc_id = document_metadata.document_id
if doc_id not in seen_documents:
seen_documents[doc_id] = document_metadata
insert_stmt = insert(Document).values(
[model_to_dict(Document(id=doc_id)) for doc_id in seen_document_ids]
insert_stmt = insert(DbDocument).values(
[
model_to_dict(
DbDocument(
id=doc.document_id,
boost=DEFAULT_BOOST,
hidden=False,
semantic_id=doc.semantic_identifier,
link=doc.first_link,
)
)
for doc in seen_documents.values()
]
)
# for now, there are no columns to update. If more metadata is added, then this
# needs to change to an `on_conflict_do_update`
@ -120,7 +133,8 @@ def upsert_document_by_connector_credential_pair(
def upsert_documents_complete(
db_session: Session, document_metadata_batch: list[DocumentMetadata]
db_session: Session,
document_metadata_batch: list[DocumentMetadata],
) -> None:
upsert_documents(db_session, document_metadata_batch)
upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
@ -140,7 +154,7 @@ def delete_document_by_connector_credential_pair(
def delete_documents(db_session: Session, document_ids: list[str]) -> None:
db_session.execute(delete(Document).where(Document.id.in_(document_ids)))
db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))
def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:

View File

@ -0,0 +1,156 @@
from uuid import UUID
from sqlalchemy import asc
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.orm import Session
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentRetrievalFeedback
from danswer.db.models import QueryEvent
from danswer.search.models import SearchType
def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
stmt = select(QueryEvent).where(QueryEvent.id == query_id)
result = db_session.execute(stmt)
query_event = result.scalar_one_or_none()
if not query_event:
raise ValueError("Invalid Query Event provided for updating")
return query_event
def fetch_doc_m_by_id(doc_id: str, db_session: Session) -> DbDocument:
stmt = select(DbDocument).where(DbDocument.id == doc_id)
result = db_session.execute(stmt)
doc_m = result.scalar_one_or_none()
if not doc_m:
raise ValueError("Invalid Document provided for updating")
return doc_m
def fetch_docs_ranked_by_boost(
db_session: Session, ascending: bool = False, limit: int = 100
) -> list[DbDocument]:
order_func = asc if ascending else desc
stmt = select(DbDocument).order_by(order_func(DbDocument.boost)).limit(limit)
result = db_session.execute(stmt)
doc_m_list = result.scalars().all()
return list(doc_m_list)
def create_document_metadata(
doc_id: str,
semantic_id: str,
link: str | None,
db_session: Session,
) -> None:
try:
fetch_doc_m_by_id(doc_id, db_session)
# Document already exists, don't reset its data
return
except ValueError:
# Document is not tracked yet, fall through and create it
pass
db_session.add(
DbDocument(
id=doc_id,
semantic_id=semantic_id,
link=link,
)
)
db_session.commit()
def create_query_event(
query: str,
selected_flow: SearchType | None,
llm_answer: str | None,
user_id: UUID | None,
db_session: Session,
) -> int:
query_event = QueryEvent(
query=query,
selected_search_flow=selected_flow,
llm_answer=llm_answer,
user_id=user_id,
)
db_session.add(query_event)
db_session.commit()
return query_event.id
def update_query_event_feedback(
feedback: QAFeedbackType,
query_id: int,
user_id: UUID | None,
db_session: Session,
) -> None:
query_event = fetch_query_event_by_id(query_id, db_session)
if user_id != query_event.user_id:
raise ValueError("User trying to give feedback on a query run by another user.")
query_event.feedback = feedback
db_session.commit()
def create_doc_retrieval_feedback(
qa_event_id: int,
document_id: str,
document_rank: int,
user_id: UUID | None,
db_session: Session,
clicked: bool = False,
feedback: SearchFeedbackType | None = None,
) -> None:
if not clicked and feedback is None:
raise ValueError("No action taken, not valid feedback")
query_event = fetch_query_event_by_id(qa_event_id, db_session)
if user_id != query_event.user_id:
raise ValueError("User trying to give feedback on a query run by another user.")
doc_m = fetch_doc_m_by_id(document_id, db_session)
retrieval_feedback = DocumentRetrievalFeedback(
qa_event_id=qa_event_id,
document_id=document_id,
document_rank=document_rank,
clicked=clicked,
feedback=feedback,
)
if feedback is not None:
if feedback == SearchFeedbackType.ENDORSE:
doc_m.boost += 1
elif feedback == SearchFeedbackType.REJECT:
doc_m.boost -= 1
elif feedback == SearchFeedbackType.HIDE:
doc_m.hidden = True
elif feedback == SearchFeedbackType.UNHIDE:
doc_m.hidden = False
else:
raise ValueError("Unhandled document feedback type")
if feedback in [SearchFeedbackType.ENDORSE, SearchFeedbackType.REJECT]:
document_index = get_default_document_index()
update = UpdateRequest(
document_ids=[document_id],
boost=translate_boost_count_to_multiplier(doc_m.boost),
)
# Updates are generally batched for efficiency; in this case only one doc/value is updated
document_index.update([update])
db_session.add(retrieval_feedback)
db_session.commit()

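The intended call order is to record the query event first, then attach document feedback to the returned id. A minimal usage sketch, assuming no authenticated user and an illustrative document id that would have to exist in the document table:

# Hypothetical usage sketch; "example-doc-id" and the query text are illustrative.
from sqlalchemy.orm import Session

from danswer.configs.constants import SearchFeedbackType
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.search.models import SearchType

with Session(get_sqlalchemy_engine()) as db_session:
    query_id = create_query_event(
        query="how do I rotate my API key?",
        selected_flow=SearchType.SEMANTIC,
        llm_answer=None,
        user_id=None,  # auth disabled / anonymous
        db_session=db_session,
    )
    # Endorse the top result: bumps Document.boost by 1 and pushes the new
    # multiplier to the document index.
    create_doc_retrieval_feedback(
        qa_event_id=query_id,
        document_id="example-doc-id",
        document_rank=1,
        user_id=None,
        db_session=db_session,
        feedback=SearchFeedbackType.ENDORSE,
    )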
View File

@ -23,8 +23,12 @@ from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship
from danswer.auth.schemas import UserRole
from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
from danswer.search.models import SearchType
class IndexingStatus(str, PyEnum):
@ -61,6 +65,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
credentials: Mapped[List["Credential"]] = relationship(
"Credential", back_populates="user", lazy="joined"
)
query_events: Mapped[List["QueryEvent"]] = relationship(
"QueryEvent", back_populates="user"
)
class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@ -162,7 +169,7 @@ class Credential(Base):
deletion_attempt: Mapped[Optional["DeletionAttempt"]] = relationship(
"DeletionAttempt", back_populates="credential"
)
user: Mapped[User] = relationship("User", back_populates="credentials")
user: Mapped[User | None] = relationship("User", back_populates="credentials")
class IndexAttempt(Base):
@ -258,17 +265,6 @@ class DeletionAttempt(Base):
)
class Document(Base):
"""Represents a single documents from a source. This is used to store
document level metadata, but currently nothing is stored"""
__tablename__ = "document"
# this should correspond to the ID of the document (as is passed around
# in Danswer)
id: Mapped[str] = mapped_column(String, primary_key=True)
class DocumentByConnectorCredentialPair(Base):
"""Represents an indexing of a document by a specific connector / credential
pair"""
@ -289,3 +285,72 @@ class DocumentByConnectorCredentialPair(Base):
credential: Mapped[Credential] = relationship(
"Credential", back_populates="documents_by_credential"
)
class QueryEvent(Base):
__tablename__ = "query_event"
id: Mapped[int] = mapped_column(primary_key=True)
query: Mapped[str] = mapped_column(String())
# search_flow refers to user selection, None if user used auto
selected_search_flow: Mapped[SearchType | None] = mapped_column(
Enum(SearchType), nullable=True
)
llm_answer: Mapped[str | None] = mapped_column(String(), default=None)
feedback: Mapped[QAFeedbackType | None] = mapped_column(
Enum(QAFeedbackType), nullable=True
)
user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
time_created: Mapped[datetime.datetime] = mapped_column(
DateTime(timezone=True),
server_default=func.now(),
)
user: Mapped[User | None] = relationship("User", back_populates="query_events")
document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
"DocumentRetrievalFeedback", back_populates="qa_event"
)
class DocumentRetrievalFeedback(Base):
__tablename__ = "document_retrieval_feedback"
id: Mapped[int] = mapped_column(primary_key=True)
qa_event_id: Mapped[int] = mapped_column(
ForeignKey("query_event.id"),
)
document_id: Mapped[str] = mapped_column(
ForeignKey("document.id"),
)
# How high up this document is in the results, 1 for first
document_rank: Mapped[int] = mapped_column(Integer)
clicked: Mapped[bool] = mapped_column(Boolean, default=False)
feedback: Mapped[SearchFeedbackType | None] = mapped_column(
Enum(SearchFeedbackType), nullable=True
)
qa_event: Mapped[QueryEvent] = relationship(
"QueryEvent", back_populates="document_feedbacks"
)
document: Mapped["Document"] = relationship(
"Document", back_populates="retrieval_feedbacks"
)
class Document(Base):
__tablename__ = "document"
# this should correspond to the ID of the document (as is passed around
# in Danswer)
id: Mapped[str] = mapped_column(String, primary_key=True)
# 0 for neutral, positive for mostly endorse, negative for mostly reject
boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
hidden: Mapped[bool] = mapped_column(Boolean, default=False)
semantic_id: Mapped[str] = mapped_column(String)
# First Section's link
link: Mapped[str | None] = mapped_column(String, nullable=True)
# TODO if more sensitive data is added here for display, make sure to add user/group permission
retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
"DocumentRetrievalFeedback", back_populates="document"
)

View File

@ -1,8 +1,11 @@
from sqlalchemy.orm import Session
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.document_index import get_default_document_index
from danswer.db.feedback import create_query_event
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
@ -22,19 +25,27 @@ logger = setup_logger()
@log_function_time()
def answer_question(
def answer_qa_query(
question: QuestionRequest,
user: User | None,
db_session: Session,
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
answer_generation_timeout: int = QA_TIMEOUT,
) -> QAResponse:
query = question.query
collection = question.collection
filters = question.filters
use_keyword = question.use_keyword
offset_count = question.offset if question.offset is not None else 0
logger.info(f"Received QA query: {query}")
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.KEYWORD,
llm_answer=None,
user_id=user.id if user is not None else None,
db_session=db_session,
)
predicted_search, predicted_flow = query_intent(query)
if use_keyword is None:
use_keyword = predicted_search == SearchType.KEYWORD
@ -42,12 +53,12 @@ def answer_question(
user_id = None if user is None else user.id
if use_keyword:
ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
query, user_id, filters, get_default_document_index(collection=collection)
query, user_id, filters, get_default_document_index()
)
unranked_chunks: list[InferenceChunk] | None = []
else:
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, get_default_document_index(collection=collection)
query, user_id, filters, get_default_document_index()
)
if not ranked_chunks:
return QAResponse(
@ -57,6 +68,7 @@ def answer_question(
lower_ranked_docs=None,
predicted_flow=predicted_flow,
predicted_search=predicted_search,
query_event_id=query_event_id,
)
if disable_generative_answer:
@ -70,6 +82,7 @@ def answer_question(
# to run QA over more documents
predicted_flow=QueryFlow.SEARCH,
predicted_search=predicted_search,
query_event_id=query_event_id,
)
try:
@ -83,6 +96,7 @@ def answer_question(
predicted_flow=predicted_flow,
predicted_search=predicted_search,
error_msg=str(e),
query_event_id=query_event_id,
)
chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
@ -108,4 +122,5 @@ def answer_question(
predicted_flow=predicted_flow,
predicted_search=predicted_search,
error_msg=error_msg,
query_event_id=query_event_id,
)

View File

@ -9,6 +9,7 @@ from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
from sqlalchemy.orm import Session
from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.app_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
@ -18,7 +19,8 @@ from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.constants import DocumentSource
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer
from danswer.direct_qa.answer_question import answer_question
from danswer.db.engine import get_sqlalchemy_engine
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
@ -228,17 +230,19 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
logger=cast(logging.Logger, logger),
)
def _get_answer(question: QuestionRequest) -> QAResponse:
answer = answer_question(
question=question,
user=None,
answer_generation_timeout=DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
engine = get_sqlalchemy_engine()
with Session(engine, expire_on_commit=False) as db_session:
answer = answer_qa_query(
question=question,
user=None,
db_session=db_session,
answer_generation_timeout=DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
)
if not answer.error_msg:
return answer
else:
raise RuntimeError(answer.error_msg)
answer = None
try:
answer = _get_answer(
QuestionRequest(

View File

@ -1,3 +1,4 @@
import os
from typing import Any
from langchain.chat_models.azure_openai import AzureChatOpenAI
@ -22,6 +23,11 @@ class AzureGPT(LangChainChatLLM):
*args: list[Any],
**kwargs: dict[str, Any]
):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = os.environ.get("OPENAI_API_KEY") or "dummy_api_key"
self._llm = AzureChatOpenAI(
model=model_version,
openai_api_type="azure",

View File

@ -1,3 +1,4 @@
import os
from typing import Any
from langchain.chat_models.openai import ChatOpenAI
@ -16,6 +17,11 @@ class OpenAIGPT(LangChainChatLLM):
*args: list[Any],
**kwargs: dict[str, Any]
):
# set a dummy API key if not specified so that LangChain doesn't throw an
# exception when trying to initialize the LLM which would prevent the API
# server from starting up
if not api_key:
api_key = os.environ.get("OPENAI_API_KEY") or "dummy_api_key"
self._llm = ChatOpenAI(
model=model_version,
openai_api_key=api_key,

View File

@ -49,6 +49,7 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.deletion_attempt import create_deletion_attempt
from danswer.db.deletion_attempt import get_deletion_attempts
from danswer.db.engine import get_session
from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.models import DeletionAttempt
@ -61,6 +62,7 @@ from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
from danswer.server.models import AuthStatus
from danswer.server.models import AuthUrl
from danswer.server.models import BoostDoc
from danswer.server.models import ConnectorBase
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.server.models import ConnectorIndexingStatus
@ -79,7 +81,6 @@ from danswer.server.models import StatusResponse
from danswer.server.models import UserRoleResponse
from danswer.utils.logger import setup_logger
router = APIRouter(prefix="/manage")
logger = setup_logger()
@ -89,6 +90,28 @@ _GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id"
"""Admin only API endpoints"""
@router.get("/admin/doc-boosts")
def get_most_boosted_docs(
ascending: bool,
limit: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> list[BoostDoc]:
boost_docs = fetch_docs_ranked_by_boost(
ascending=ascending, limit=limit, db_session=db_session
)
return [
BoostDoc(
document_id=doc.id,
semantic_id=doc.semantic_id,
link=doc.link or "",
boost=doc.boost,
hidden=doc.hidden,
)
for doc in boost_docs
]
@router.get("/admin/connector/google-drive/app-credential")
def check_google_app_credentials_exist(
_: User = Depends(current_admin_user),

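A hedged example of calling the new GET /manage/admin/doc-boosts endpoint; the host, port, and auth cookie name are assumptions about a local deployment, not part of the commit:

# Illustrative only: host/port and cookie name depend on the deployment.
import requests

resp = requests.get(
    "http://localhost:8080/manage/admin/doc-boosts",
    params={"ascending": "false", "limit": "10"},
    cookies={"fastapiusersauth": "<admin session cookie>"},  # hypothetical cookie name
)
resp.raise_for_status()
for doc in resp.json():
    print(doc["boost"], doc["hidden"], doc["semantic_id"], doc["link"])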
View File

@ -11,6 +11,8 @@ from pydantic.generics import GenericModel
from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
from danswer.datastores.interfaces import IndexFilter
from danswer.db.models import Connector
@ -105,6 +107,14 @@ class UserRoleResponse(BaseModel):
role: str
class BoostDoc(BaseModel):
document_id: str
semantic_id: str
link: str
boost: int
hidden: bool
class SearchDoc(BaseModel):
document_id: str
semantic_identifier: str
@ -121,10 +131,24 @@ class QuestionRequest(BaseModel):
offset: int | None
class QAFeedbackRequest(BaseModel):
query_id: int
feedback: QAFeedbackType
class SearchFeedbackRequest(BaseModel):
query_id: int
document_id: str
document_rank: int
click: bool
search_feedback: SearchFeedbackType
class SearchResponse(BaseModel):
# For semantic search, top docs are reranked; the remaining keep their retrieval order
top_ranked_docs: list[SearchDoc] | None
lower_ranked_docs: list[SearchDoc] | None
query_event_id: int
class QAResponse(SearchResponse):

View File

@ -5,16 +5,22 @@ from dataclasses import asdict
from fastapi import APIRouter
from fastapi import Depends
from fastapi.responses import StreamingResponse
from sqlalchemy.orm import Session
from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.datastores.document_index import get_default_document_index
from danswer.db.engine import get_session
from danswer.db.feedback import create_doc_retrieval_feedback
from danswer.db.feedback import create_query_event
from danswer.db.feedback import update_query_event_feedback
from danswer.db.models import User
from danswer.direct_qa.answer_question import answer_question
from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
@ -24,8 +30,10 @@ from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import HelperResponse
from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@ -50,62 +58,95 @@ def get_search_type(
@router.post("/semantic-search")
def semantic_search(
question: QuestionRequest, user: User = Depends(current_user)
question: QuestionRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchResponse:
query = question.query
collection = question.collection
filters = question.filters
logger.info(f"Received semantic search query: {query}")
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.SEMANTIC,
llm_answer=None,
user_id=user.id if user is not None else None,
db_session=db_session,
)
user_id = None if user is None else user.id
ranked_chunks, unranked_chunks = retrieve_ranked_documents(
query, user_id, filters, get_default_document_index(collection=collection)
query, user_id, filters, get_default_document_index()
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
return SearchResponse(
top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id
)
top_docs = chunks_to_search_docs(ranked_chunks)
other_top_docs = chunks_to_search_docs(unranked_chunks)
return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=other_top_docs)
return SearchResponse(
top_ranked_docs=top_docs,
lower_ranked_docs=other_top_docs,
query_event_id=query_event_id,
)
@router.post("/keyword-search")
def keyword_search(
question: QuestionRequest, user: User = Depends(current_user)
question: QuestionRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> SearchResponse:
query = question.query
collection = question.collection
filters = question.filters
logger.info(f"Received keyword search query: {query}")
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.KEYWORD,
llm_answer=None,
user_id=user.id if user is not None else None,
db_session=db_session,
)
user_id = None if user is None else user.id
ranked_chunks = retrieve_keyword_documents(
query, user_id, filters, get_default_document_index(collection=collection)
query, user_id, filters, get_default_document_index()
)
if not ranked_chunks:
return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
return SearchResponse(
top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id
)
top_docs = chunks_to_search_docs(ranked_chunks)
return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=None)
return SearchResponse(
top_ranked_docs=top_docs, lower_ranked_docs=None, query_event_id=query_event_id
)
@router.post("/direct-qa")
def direct_qa(
question: QuestionRequest, user: User = Depends(current_user)
question: QuestionRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> QAResponse:
return answer_question(question=question, user=user)
return answer_qa_query(question=question, user=user, db_session=db_session)
@router.post("/stream-direct-qa")
def stream_direct_qa(
question: QuestionRequest, user: User = Depends(current_user)
question: QuestionRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> StreamingResponse:
send_packet_debug_msg = "Sending Packet: {}"
top_documents_key = "top_documents"
unranked_top_docs_key = "unranked_top_documents"
predicted_flow_key = "predicted_flow"
predicted_search_key = "predicted_search"
query_event_id_key = "query_event_id"
logger.debug(f"Received QA query: {question.query}")
logger.debug(f"Query filters: {question.filters}")
@ -116,8 +157,8 @@ def stream_direct_qa(
def stream_qa_portions(
disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
) -> Generator[str, None, None]:
answer_so_far: str = ""
query = question.query
collection = question.collection
filters = question.filters
use_keyword = question.use_keyword
offset_count = question.offset if question.offset is not None else 0
@ -132,7 +173,7 @@ def stream_direct_qa(
query,
user_id,
filters,
get_default_document_index(collection=collection),
get_default_document_index(),
)
unranked_chunks: list[InferenceChunk] | None = []
else:
@ -140,7 +181,7 @@ def stream_direct_qa(
query,
user_id,
filters,
get_default_document_index(collection=collection),
get_default_document_index(),
)
if not ranked_chunks:
logger.debug("No Documents Found")
@ -194,6 +235,11 @@ def stream_direct_qa(
):
if response_packet is None:
continue
if (
isinstance(response_packet, DanswerAnswerPiece)
and response_packet.answer_piece
):
answer_so_far = answer_so_far + response_packet.answer_piece
logger.debug(f"Sending packet: {response_packet}")
yield get_json_line(asdict(response_packet))
except Exception as e:
@ -201,6 +247,49 @@ def stream_direct_qa(
yield get_json_line({"error": str(e)})
logger.exception("Failed to run QA")
query_event_id = create_query_event(
query=query,
selected_flow=SearchType.KEYWORD
if question.use_keyword
else SearchType.SEMANTIC,
llm_answer=answer_so_far,
user_id=user_id,
db_session=db_session,
)
yield get_json_line({query_event_id_key: query_event_id})
return
return StreamingResponse(stream_qa_portions(), media_type="application/json")
@router.post("/query-feedback")
def process_query_feedback(
feedback: QAFeedbackRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_query_event_feedback(
feedback=feedback.feedback,
query_id=feedback.query_id,
user_id=user.id if user is not None else None,
db_session=db_session,
)
@router.post("/doc-retrieval-feedback")
def process_doc_retrieval_feedback(
feedback: SearchFeedbackRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
create_doc_retrieval_feedback(
qa_event_id=feedback.query_id,
document_id=feedback.document_id,
document_rank=feedback.document_rank,
clicked=feedback.click,
feedback=feedback.search_feedback,
user_id=user.id if user is not None else None,
db_session=db_session,
)
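
Finally, a hedged sketch of exercising the two new feedback routes; the base URL is an assumption, the ids are illustrative, and the query_event_id would come from a prior search or QA response:

# Illustrative only; the string values match QAFeedbackType/SearchFeedbackType above.
import requests

BASE = "http://localhost:8080"  # assumed API server address

# Thumbs-up on the LLM answer for query event 42 (illustrative id)
requests.post(f"{BASE}/query-feedback", json={"query_id": 42, "feedback": "like"})

# Endorse the document shown at rank 1 for the same query event
requests.post(
    f"{BASE}/doc-retrieval-feedback",
    json={
        "query_id": 42,
        "document_id": "example-doc-id",
        "document_rank": 1,
        "click": False,
        "search_feedback": "endorse",
    },
)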