Mirror of https://github.com/danswer-ai/danswer.git
Learn from feedback backend (#343)

---------

Co-authored-by: Weves <chrisweaver101@gmail.com>

parent c43a403b71
commit b2a51283d1
backend/alembic/versions/d929f0c1c6af_feedback_feature.py  (new file, +93)
@@ -0,0 +1,93 @@
"""Feedback Feature

Revision ID: d929f0c1c6af
Revises: 8aabb57f3b49
Create Date: 2023-08-27 13:03:54.274987

"""
import fastapi_users_db_sqlalchemy
from alembic import op
import sqlalchemy as sa


# revision identifiers, used by Alembic.
revision = "d929f0c1c6af"
down_revision = "8aabb57f3b49"
branch_labels = None
depends_on = None


def upgrade() -> None:
    op.create_table(
        "query_event",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("query", sa.String(), nullable=False),
        sa.Column(
            "selected_search_flow",
            sa.Enum("KEYWORD", "SEMANTIC", name="searchtype"),
            nullable=True,
        ),
        sa.Column("llm_answer", sa.String(), nullable=True),
        sa.Column(
            "feedback",
            sa.Enum("LIKE", "DISLIKE", name="qafeedbacktype"),
            nullable=True,
        ),
        sa.Column(
            "user_id",
            fastapi_users_db_sqlalchemy.generics.GUID(),
            nullable=True,
        ),
        sa.Column(
            "time_created",
            sa.DateTime(timezone=True),
            server_default=sa.text("now()"),
            nullable=False,
        ),
        sa.ForeignKeyConstraint(
            ["user_id"],
            ["user.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.create_table(
        "document_retrieval_feedback",
        sa.Column("id", sa.Integer(), nullable=False),
        sa.Column("qa_event_id", sa.Integer(), nullable=False),
        sa.Column("document_id", sa.String(), nullable=False),
        sa.Column("document_rank", sa.Integer(), nullable=False),
        sa.Column("clicked", sa.Boolean(), nullable=False),
        sa.Column(
            "feedback",
            sa.Enum(
                "ENDORSE",
                "REJECT",
                "HIDE",
                "UNHIDE",
                name="searchfeedbacktype",
            ),
            nullable=True,
        ),
        sa.ForeignKeyConstraint(
            ["document_id"],
            ["document.id"],
        ),
        sa.ForeignKeyConstraint(
            ["qa_event_id"],
            ["query_event.id"],
        ),
        sa.PrimaryKeyConstraint("id"),
    )
    op.add_column("document", sa.Column("boost", sa.Integer(), nullable=False))
    op.add_column("document", sa.Column("hidden", sa.Boolean(), nullable=False))
    op.add_column("document", sa.Column("semantic_id", sa.String(), nullable=False))
    op.add_column("document", sa.Column("link", sa.String(), nullable=True))


def downgrade() -> None:
    op.drop_column("document", "link")
    op.drop_column("document", "semantic_id")
    op.drop_column("document", "hidden")
    op.drop_column("document", "boost")
    op.drop_table("document_retrieval_feedback")
    op.drop_table("query_event")
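[Note, not part of the commit] Applying a revision like this one is normally done through Alembic; a minimal sketch, assuming alembic.ini lives under backend/ as in a standard setup (the path is an assumption):

from alembic import command
from alembic.config import Config

# Upgrade the database to the new feedback revision (or use "head").
alembic_cfg = Config("backend/alembic.ini")  # path is an assumption
command.upgrade(alembic_cfg, "d929f0c1c6af")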
@@ -90,7 +90,7 @@ def _delete_connector_credential_pair(
def _get_user(
    credential: Credential,
) -> str:
-    if credential.public_doc:
+    if credential.public_doc or not credential.user:
        return PUBLIC_DOC_PAT

    return str(credential.user.id)
@@ -23,6 +23,7 @@ from danswer.db.connector_credential_pair import update_connector_credential_pai
from danswer.db.credentials import backend_update_credential_json
from danswer.db.engine import get_db_current_time
from danswer.db.engine import get_sqlalchemy_engine
+from danswer.db.feedback import create_document_metadata
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_index_attempt
from danswer.db.index_attempt import get_inprogress_index_attempts
@@ -246,6 +247,7 @@ def _run_indexing(
        logger.debug(
            f"Indexing batch of documents: {[doc.to_short_descriptor() for doc in doc_batch]}"
        )

        index_user_id = (
            None if db_credential.public_doc else db_credential.user_id
        )
@@ -18,6 +18,7 @@ HTML_SEPARATOR = "\n"
PUBLIC_DOC_PAT = "PUBLIC"
QUOTE = "quote"
BOOST = "boost"
+DEFAULT_BOOST = 0


class DocumentSource(str, Enum):
@@ -66,3 +67,15 @@ class ModelHostType(str, Enum):
    # https://medium.com/@yuhongsun96/host-a-llama-2-api-on-gpu-for-free-a5311463c183
    COLAB_DEMO = "colab-demo"
    # TODO support for Azure, AWS, GCP GenAI model hosting
+
+
+class QAFeedbackType(str, Enum):
+    LIKE = "like"  # User likes the answer, used for metrics
+    DISLIKE = "dislike"  # User dislikes the answer, used for metrics
+
+
+class SearchFeedbackType(str, Enum):
+    ENDORSE = "endorse"  # boost this document for all future queries
+    REJECT = "reject"  # down-boost this document for all future queries
+    HIDE = "hide"  # mark this document as untrusted, hide from LLM
+    UNHIDE = "unhide"
@@ -12,6 +12,15 @@ from danswer.connectors.models import IndexAttemptMetadata


DEFAULT_BATCH_SIZE = 30
+BOOST_MULTIPLIER = 1.2
+
+
+def translate_boost_count_to_multiplier(boost: int) -> float:
+    if boost > 0:
+        return BOOST_MULTIPLIER**boost
+    elif boost < 0:
+        return 1 / (BOOST_MULTIPLIER**boost)
+    return 1


def get_uuid_from_chunk(
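[Note, not part of the commit] A quick illustration of how a document's accumulated boost count is turned into a ranking multiplier by the helper added above:

from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier

print(translate_boost_count_to_multiplier(0))  # 1 -> neutral document
print(translate_boost_count_to_multiplier(3))  # ~1.73 (1.2 ** 3) -> ranked higher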
@@ -32,6 +32,7 @@ class IndexingPipelineProtocol(Protocol):
def _upsert_insertion_records(
    insertion_records: set[DocumentInsertionRecord],
    index_attempt_metadata: IndexAttemptMetadata,
+    doc_m_data_lookup: dict[str, tuple[str, str]],
) -> None:
    with Session(get_sqlalchemy_engine()) as session:
        upsert_documents_complete(
@@ -40,9 +41,11 @@ def _upsert_insertion_records(
                DocumentMetadata(
                    connector_id=index_attempt_metadata.connector_id,
                    credential_id=index_attempt_metadata.credential_id,
-                    document_id=insertion_record.document_id,
+                    document_id=i_r.document_id,
+                    semantic_identifier=doc_m_data_lookup[i_r.document_id][0],
+                    first_link=doc_m_data_lookup[i_r.document_id][1],
                )
-                for insertion_record in insertion_records
+                for i_r in insertion_records
            ],
        )

@@ -62,6 +65,11 @@ def _get_net_new_documents(
    return net_new_documents


+def _extract_minimal_document_metadata(doc: Document) -> tuple[str, str]:
+    first_link = next((section.link for section in doc.sections if section.link), "")
+    return doc.semantic_identifier, first_link
+
+
def _indexing_pipeline(
    *,
    chunker: Chunker,
@@ -73,6 +81,11 @@ def _indexing_pipeline(
    """Takes different pieces of the indexing pipeline and applies it to a batch of documents
    Note that the documents should already be batched at this point so that it does not inflate the
    memory requirements"""
+
+    document_metadata_lookup = {
+        doc.id: _extract_minimal_document_metadata(doc) for doc in documents
+    }
+
    chunks: list[DocAwareChunk] = list(
        chain(*[chunker.chunk(document=document) for document in documents])
    )
@@ -92,6 +105,7 @@
        _upsert_insertion_records(
            insertion_records=insertion_records,
            index_attempt_metadata=index_attempt_metadata,
+            doc_m_data_lookup=document_metadata_lookup,
        )
    except Exception as e:
        logger.error(
@@ -22,6 +22,8 @@ class DocumentMetadata:
    connector_id: int
    credential_id: int
    document_id: str
+    semantic_identifier: str
+    first_link: str


@dataclass
@@ -32,7 +34,7 @@ class UpdateRequest:
    document_ids: list[str]
    # all other fields will be left alone
    allowed_users: list[str] | None = None
-    boost: int | None = None
+    boost: float | None = None


class Verifiable(abc.ABC):
@@ -341,16 +341,20 @@ class VespaIndex(DocumentIndex):
                logger.error("Update request received but nothing to update")
                continue

-            update_dict: dict[str, dict[str, list[str] | int]] = {"fields": {}}
+            update_dict: dict[str, dict] = {"fields": {}}
            if update_request.boost:
-                update_dict["fields"][BOOST] = update_request.boost
+                update_dict["fields"][BOOST] = {"assign": update_request.boost}
            if update_request.allowed_users:
-                update_dict["fields"][ALLOWED_USERS] = update_request.allowed_users
+                update_dict["fields"][ALLOWED_USERS] = {
+                    "assign": update_request.allowed_users
+                }

            for document_id in update_request.document_ids:
                for doc_chunk_id in _get_vespa_chunk_ids_by_document_id(document_id):
                    url = f"{DOCUMENT_ID_ENDPOINT}/{doc_chunk_id}"
-                    res = requests.put(url, headers=json_header, json=update_dict)
+                    res = requests.put(
+                        url, headers=json_header, data=json.dumps(update_dict)
+                    )

                    try:
                        res.raise_for_status()
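[Note, not part of the commit] Vespa's partial-update API expects each field wrapped in an operation object such as "assign", which is what the change above now builds; roughly, the request body looks like this (values illustrative, and the allowed-users field name is an assumption):

update_body = {
    "fields": {
        "boost": {"assign": 1.44},                # field name from the BOOST constant
        "allowed_users": {"assign": ["user_x"]},  # exact ALLOWED_USERS field name assumed
    }
}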
@@ -7,8 +7,9 @@ from sqlalchemy import select
from sqlalchemy.dialects.postgresql import insert
from sqlalchemy.orm import Session

+from danswer.configs.constants import DEFAULT_BOOST
from danswer.datastores.interfaces import DocumentMetadata
-from danswer.db.models import Document
+from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentByConnectorCredentialPair
from danswer.db.utils import model_to_dict
from danswer.utils.logger import setup_logger
@@ -20,7 +21,7 @@ def get_documents_with_single_connector_credential_pair(
    db_session: Session,
    connector_id: int,
    credential_id: int,
-) -> Sequence[Document]:
+) -> Sequence[DbDocument]:
    initial_doc_ids_stmt = select(DocumentByConnectorCredentialPair.id).where(
        and_(
            DocumentByConnectorCredentialPair.connector_id == connector_id,
@@ -31,17 +32,17 @@ def get_documents_with_single_connector_credential_pair(
    # Filter it down to the documents with only a single connector/credential pair
    # Meaning if this connector/credential pair is removed, this doc should be gone
    trimmed_doc_ids_stmt = (
-        select(Document.id)
+        select(DbDocument.id)
        .join(
            DocumentByConnectorCredentialPair,
-            DocumentByConnectorCredentialPair.id == Document.id,
+            DocumentByConnectorCredentialPair.id == DbDocument.id,
        )
-        .where(Document.id.in_(initial_doc_ids_stmt))
-        .group_by(Document.id)
+        .where(DbDocument.id.in_(initial_doc_ids_stmt))
+        .group_by(DbDocument.id)
        .having(func.count(DocumentByConnectorCredentialPair.id) == 1)
    )

-    stmt = select(Document).where(Document.id.in_(trimmed_doc_ids_stmt))
+    stmt = select(DbDocument).where(DbDocument.id.in_(trimmed_doc_ids_stmt))
    return db_session.scalars(stmt).all()

@@ -60,13 +61,13 @@ def get_document_by_connector_credential_pairs_indexed_by_multiple(
    # Filter it down to the documents with more than 1 connector/credential pair
    # Meaning if this connector/credential pair is removed, this doc is still accessible
    trimmed_doc_ids_stmt = (
-        select(Document.id)
+        select(DbDocument.id)
        .join(
            DocumentByConnectorCredentialPair,
-            DocumentByConnectorCredentialPair.id == Document.id,
+            DocumentByConnectorCredentialPair.id == DbDocument.id,
        )
-        .where(Document.id.in_(initial_doc_ids_stmt))
-        .group_by(Document.id)
+        .where(DbDocument.id.in_(initial_doc_ids_stmt))
+        .group_by(DbDocument.id)
        .having(func.count(DocumentByConnectorCredentialPair.id) > 1)
    )

@@ -81,13 +82,25 @@ def upsert_documents(
    db_session: Session, document_metadata_batch: list[DocumentMetadata]
) -> None:
    """NOTE: this function is Postgres specific. Not all DBs support the ON CONFLICT clause."""
-    seen_document_ids: set[str] = set()
+    seen_documents: dict[str, DocumentMetadata] = {}
    for document_metadata in document_metadata_batch:
-        if document_metadata.document_id not in seen_document_ids:
-            seen_document_ids.add(document_metadata.document_id)
+        doc_id = document_metadata.document_id
+        if doc_id not in seen_documents:
+            seen_documents[doc_id] = document_metadata

-    insert_stmt = insert(Document).values(
-        [model_to_dict(Document(id=doc_id)) for doc_id in seen_document_ids]
+    insert_stmt = insert(DbDocument).values(
+        [
+            model_to_dict(
+                DbDocument(
+                    id=doc.document_id,
+                    boost=DEFAULT_BOOST,
+                    hidden=False,
+                    semantic_id=doc.semantic_identifier,
+                    link=doc.first_link,
+                )
+            )
+            for doc in seen_documents.values()
+        ]
    )
    # for now, there are no columns to update. If more metadata is added, then this
    # needs to change to an `on_conflict_do_update`
@@ -120,7 +133,8 @@ def upsert_document_by_connector_credential_pair(


def upsert_documents_complete(
-    db_session: Session, document_metadata_batch: list[DocumentMetadata]
+    db_session: Session,
+    document_metadata_batch: list[DocumentMetadata],
) -> None:
    upsert_documents(db_session, document_metadata_batch)
    upsert_document_by_connector_credential_pair(db_session, document_metadata_batch)
@@ -140,7 +154,7 @@ def delete_document_by_connector_credential_pair(


def delete_documents(db_session: Session, document_ids: list[str]) -> None:
-    db_session.execute(delete(Document).where(Document.id.in_(document_ids)))
+    db_session.execute(delete(DbDocument).where(DbDocument.id.in_(document_ids)))


def delete_documents_complete(db_session: Session, document_ids: list[str]) -> None:
backend/danswer/db/feedback.py  (new file, +156)
@@ -0,0 +1,156 @@
from uuid import UUID

from sqlalchemy import asc
from sqlalchemy import desc
from sqlalchemy import select
from sqlalchemy.orm import Session

from danswer.configs.constants import QAFeedbackType
from danswer.configs.constants import SearchFeedbackType
from danswer.datastores.datastore_utils import translate_boost_count_to_multiplier
from danswer.datastores.document_index import get_default_document_index
from danswer.datastores.interfaces import UpdateRequest
from danswer.db.models import Document as DbDocument
from danswer.db.models import DocumentRetrievalFeedback
from danswer.db.models import QueryEvent
from danswer.search.models import SearchType


def fetch_query_event_by_id(query_id: int, db_session: Session) -> QueryEvent:
    stmt = select(QueryEvent).where(QueryEvent.id == query_id)
    result = db_session.execute(stmt)
    query_event = result.scalar_one_or_none()

    if not query_event:
        raise ValueError("Invalid Query Event provided for updating")

    return query_event


def fetch_doc_m_by_id(doc_id: str, db_session: Session) -> DbDocument:
    stmt = select(DbDocument).where(DbDocument.id == doc_id)
    result = db_session.execute(stmt)
    doc_m = result.scalar_one_or_none()

    if not doc_m:
        raise ValueError("Invalid Document provided for updating")

    return doc_m


def fetch_docs_ranked_by_boost(
    db_session: Session, ascending: bool = False, limit: int = 100
) -> list[DbDocument]:
    order_func = asc if ascending else desc
    stmt = select(DbDocument).order_by(order_func(DbDocument.boost)).limit(limit)
    result = db_session.execute(stmt)
    doc_m_list = result.scalars().all()

    return list(doc_m_list)


def create_document_metadata(
    doc_id: str,
    semantic_id: str,
    link: str | None,
    db_session: Session,
) -> None:
    try:
        fetch_doc_m_by_id(doc_id, db_session)
        return
    except ValueError:
        # Document already exists, don't reset its data
        pass

    DbDocument(
        id=doc_id,
        semantic_id=semantic_id,
        link=link,
    )


def create_query_event(
    query: str,
    selected_flow: SearchType | None,
    llm_answer: str | None,
    user_id: UUID | None,
    db_session: Session,
) -> int:
    query_event = QueryEvent(
        query=query,
        selected_search_flow=selected_flow,
        llm_answer=llm_answer,
        user_id=user_id,
    )
    db_session.add(query_event)
    db_session.commit()

    return query_event.id


def update_query_event_feedback(
    feedback: QAFeedbackType,
    query_id: int,
    user_id: UUID | None,
    db_session: Session,
) -> None:
    query_event = fetch_query_event_by_id(query_id, db_session)

    if user_id != query_event.user_id:
        raise ValueError("User trying to give feedback on a query run by another user.")

    query_event.feedback = feedback

    db_session.commit()


def create_doc_retrieval_feedback(
    qa_event_id: int,
    document_id: str,
    document_rank: int,
    user_id: UUID | None,
    db_session: Session,
    clicked: bool = False,
    feedback: SearchFeedbackType | None = None,
) -> None:
    if not clicked and feedback is None:
        raise ValueError("No action taken, not valid feedback")

    query_event = fetch_query_event_by_id(qa_event_id, db_session)

    if user_id != query_event.user_id:
        raise ValueError("User trying to give feedback on a query run by another user.")

    doc_m = fetch_doc_m_by_id(document_id, db_session)

    retrieval_feedback = DocumentRetrievalFeedback(
        qa_event_id=qa_event_id,
        document_id=document_id,
        document_rank=document_rank,
        clicked=clicked,
        feedback=feedback,
    )

    if feedback is not None:
        if feedback == SearchFeedbackType.ENDORSE:
            doc_m.boost += 1
        elif feedback == SearchFeedbackType.REJECT:
            doc_m.boost -= 1
        elif feedback == SearchFeedbackType.HIDE:
            doc_m.hidden = True
        elif feedback == SearchFeedbackType.UNHIDE:
            doc_m.hidden = False
        else:
            raise ValueError("Unhandled document feedback type")

    if feedback in [SearchFeedbackType.ENDORSE, SearchFeedbackType.REJECT]:
        document_index = get_default_document_index()
        update = UpdateRequest(
            document_ids=[document_id],
            boost=translate_boost_count_to_multiplier(doc_m.boost),
        )
        # Updates are generally batched for efficiency, this case only 1 doc/value is updated
        document_index.update([update])

    db_session.add(retrieval_feedback)
    db_session.commit()
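[Note, not part of the commit] A rough usage sketch of the new helpers, recording a query event and then a thumbs-up on it; the engine/session handling follows the pattern used elsewhere in this diff, and the query text is illustrative:

from sqlalchemy.orm import Session

from danswer.configs.constants import QAFeedbackType
from danswer.db.engine import get_sqlalchemy_engine
from danswer.db.feedback import create_query_event, update_query_event_feedback

with Session(get_sqlalchemy_engine()) as db_session:
    query_id = create_query_event(
        query="how do I rotate my API key?",  # illustrative query
        selected_flow=None,                   # None -> user let the system pick the flow
        llm_answer=None,
        user_id=None,                         # anonymous / auth disabled
        db_session=db_session,
    )
    update_query_event_feedback(
        feedback=QAFeedbackType.LIKE,
        query_id=query_id,
        user_id=None,
        db_session=db_session,
    )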
@@ -23,8 +23,12 @@ from sqlalchemy.orm import mapped_column
from sqlalchemy.orm import relationship

from danswer.auth.schemas import UserRole
+from danswer.configs.constants import DEFAULT_BOOST
from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import QAFeedbackType
+from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
+from danswer.search.models import SearchType


class IndexingStatus(str, PyEnum):
@@ -61,6 +65,9 @@ class User(SQLAlchemyBaseUserTableUUID, Base):
    credentials: Mapped[List["Credential"]] = relationship(
        "Credential", back_populates="user", lazy="joined"
    )
+    query_events: Mapped[List["QueryEvent"]] = relationship(
+        "QueryEvent", back_populates="user"
+    )


class AccessToken(SQLAlchemyBaseAccessTokenTableUUID, Base):
@@ -162,7 +169,7 @@ class Credential(Base):
    deletion_attempt: Mapped[Optional["DeletionAttempt"]] = relationship(
        "DeletionAttempt", back_populates="credential"
    )
-    user: Mapped[User] = relationship("User", back_populates="credentials")
+    user: Mapped[User | None] = relationship("User", back_populates="credentials")


class IndexAttempt(Base):
@@ -258,17 +265,6 @@ class DeletionAttempt(Base):
    )


-class Document(Base):
-    """Represents a single documents from a source. This is used to store
-    document level metadata, but currently nothing is stored"""
-
-    __tablename__ = "document"
-
-    # this should correspond to the ID of the document (as is passed around
-    # in Danswer)
-    id: Mapped[str] = mapped_column(String, primary_key=True)
-
-
class DocumentByConnectorCredentialPair(Base):
    """Represents an indexing of a document by a specific connector / credential
    pair"""
@@ -289,3 +285,72 @@ class DocumentByConnectorCredentialPair(Base):
    credential: Mapped[Credential] = relationship(
        "Credential", back_populates="documents_by_credential"
    )
+
+
+class QueryEvent(Base):
+    __tablename__ = "query_event"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    query: Mapped[str] = mapped_column(String())
+    # search_flow refers to user selection, None if user used auto
+    selected_search_flow: Mapped[SearchType | None] = mapped_column(
+        Enum(SearchType), nullable=True
+    )
+    llm_answer: Mapped[str | None] = mapped_column(String(), default=None)
+    feedback: Mapped[QAFeedbackType | None] = mapped_column(
+        Enum(QAFeedbackType), nullable=True
+    )
+    user_id: Mapped[UUID | None] = mapped_column(ForeignKey("user.id"), nullable=True)
+    time_created: Mapped[datetime.datetime] = mapped_column(
+        DateTime(timezone=True),
+        server_default=func.now(),
+    )
+
+    user: Mapped[User | None] = relationship("User", back_populates="query_events")
+    document_feedbacks: Mapped[List["DocumentRetrievalFeedback"]] = relationship(
+        "DocumentRetrievalFeedback", back_populates="qa_event"
+    )
+
+
+class DocumentRetrievalFeedback(Base):
+    __tablename__ = "document_retrieval_feedback"
+
+    id: Mapped[int] = mapped_column(primary_key=True)
+    qa_event_id: Mapped[int] = mapped_column(
+        ForeignKey("query_event.id"),
+    )
+    document_id: Mapped[str] = mapped_column(
+        ForeignKey("document.id"),
+    )
+    # How high up this document is in the results, 1 for first
+    document_rank: Mapped[int] = mapped_column(Integer)
+    clicked: Mapped[bool] = mapped_column(Boolean, default=False)
+    feedback: Mapped[SearchFeedbackType | None] = mapped_column(
+        Enum(SearchFeedbackType), nullable=True
+    )
+
+    qa_event: Mapped[QueryEvent] = relationship(
+        "QueryEvent", back_populates="document_feedbacks"
+    )
+    document: Mapped["Document"] = relationship(
+        "Document", back_populates="retrieval_feedbacks"
+    )
+
+
+class Document(Base):
+    __tablename__ = "document"
+
+    # this should correspond to the ID of the document (as is passed around
+    # in Danswer)
+    id: Mapped[str] = mapped_column(String, primary_key=True)
+    # 0 for neutral, positive for mostly endorse, negative for mostly reject
+    boost: Mapped[int] = mapped_column(Integer, default=DEFAULT_BOOST)
+    hidden: Mapped[bool] = mapped_column(Boolean, default=False)
+    semantic_id: Mapped[str] = mapped_column(String)
+    # First Section's link
+    link: Mapped[str | None] = mapped_column(String, nullable=True)
+    # TODO if more sensitive data is added here for display, make sure to add user/group permission
+
+    retrieval_feedbacks: Mapped[List[DocumentRetrievalFeedback]] = relationship(
+        "DocumentRetrievalFeedback", back_populates="document"
+    )
@@ -1,8 +1,11 @@
+from sqlalchemy.orm import Session
+
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.configs.app_configs import QA_TIMEOUT
from danswer.datastores.document_index import get_default_document_index
+from danswer.db.feedback import create_query_event
from danswer.db.models import User
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
@@ -22,19 +25,27 @@ logger = setup_logger()


@log_function_time()
-def answer_question(
+def answer_qa_query(
    question: QuestionRequest,
    user: User | None,
+    db_session: Session,
    disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
    answer_generation_timeout: int = QA_TIMEOUT,
) -> QAResponse:
    query = question.query
-    collection = question.collection
    filters = question.filters
    use_keyword = question.use_keyword
    offset_count = question.offset if question.offset is not None else 0
    logger.info(f"Received QA query: {query}")

+    query_event_id = create_query_event(
+        query=query,
+        selected_flow=SearchType.KEYWORD,
+        llm_answer=None,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+
    predicted_search, predicted_flow = query_intent(query)
    if use_keyword is None:
        use_keyword = predicted_search == SearchType.KEYWORD
@@ -42,12 +53,12 @@ def answer_question(
    user_id = None if user is None else user.id
    if use_keyword:
        ranked_chunks: list[InferenceChunk] | None = retrieve_keyword_documents(
-            query, user_id, filters, get_default_document_index(collection=collection)
+            query, user_id, filters, get_default_document_index()
        )
        unranked_chunks: list[InferenceChunk] | None = []
    else:
        ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-            query, user_id, filters, get_default_document_index(collection=collection)
+            query, user_id, filters, get_default_document_index()
        )
    if not ranked_chunks:
        return QAResponse(
@@ -57,6 +68,7 @@ def answer_question(
            lower_ranked_docs=None,
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
+            query_event_id=query_event_id,
        )

    if disable_generative_answer:
@@ -70,6 +82,7 @@ def answer_question(
            # to run QA over more documents
            predicted_flow=QueryFlow.SEARCH,
            predicted_search=predicted_search,
+            query_event_id=query_event_id,
        )

    try:
@@ -83,6 +96,7 @@ def answer_question(
            predicted_flow=predicted_flow,
            predicted_search=predicted_search,
            error_msg=str(e),
+            query_event_id=query_event_id,
        )

    chunk_offset = offset_count * NUM_GENERATIVE_AI_INPUT_DOCS
@@ -108,4 +122,5 @@ def answer_question(
        predicted_flow=predicted_flow,
        predicted_search=predicted_search,
        error_msg=error_msg,
+        query_event_id=query_event_id,
    )
@@ -9,6 +9,7 @@ from slack_sdk import WebClient
from slack_sdk.socket_mode import SocketModeClient
from slack_sdk.socket_mode.request import SocketModeRequest
from slack_sdk.socket_mode.response import SocketModeResponse
+from sqlalchemy.orm import Session

from danswer.configs.app_configs import DANSWER_BOT_ANSWER_GENERATION_TIMEOUT
from danswer.configs.app_configs import DANSWER_BOT_DISPLAY_ERROR_MSGS
@@ -18,7 +19,8 @@ from danswer.configs.app_configs import DOCUMENT_INDEX_NAME
from danswer.configs.constants import DocumentSource
from danswer.connectors.slack.utils import make_slack_api_rate_limited
from danswer.connectors.slack.utils import UserIdReplacer
-from danswer.direct_qa.answer_question import answer_question
+from danswer.db.engine import get_sqlalchemy_engine
+from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.interfaces import DanswerQuote
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
@@ -228,17 +230,19 @@ def process_slack_event(client: SocketModeClient, req: SocketModeRequest) -> Non
        logger=cast(logging.Logger, logger),
    )
    def _get_answer(question: QuestionRequest) -> QAResponse:
-        answer = answer_question(
-            question=question,
-            user=None,
-            answer_generation_timeout=DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
-        )
-        if not answer.error_msg:
-            return answer
-        else:
-            raise RuntimeError(answer.error_msg)
+        engine = get_sqlalchemy_engine()
+        with Session(engine, expire_on_commit=False) as db_session:
+            answer = answer_qa_query(
+                question=question,
+                user=None,
+                db_session=db_session,
+                answer_generation_timeout=DANSWER_BOT_ANSWER_GENERATION_TIMEOUT,
+            )
+            if not answer.error_msg:
+                return answer
+            else:
+                raise RuntimeError(answer.error_msg)

    answer = None
    try:
        answer = _get_answer(
            QuestionRequest(
@@ -1,3 +1,4 @@
+import os
from typing import Any

from langchain.chat_models.azure_openai import AzureChatOpenAI
@@ -22,6 +23,11 @@ class AzureGPT(LangChainChatLLM):
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
+        # set a dummy API key if not specified so that LangChain doesn't throw an
+        # exception when trying to initialize the LLM which would prevent the API
+        # server from starting up
+        if not api_key:
+            api_key = os.environ.get("OPENAI_API_KEY") or "dummy_api_key"
        self._llm = AzureChatOpenAI(
            model=model_version,
            openai_api_type="azure",
@@ -1,3 +1,4 @@
+import os
from typing import Any

from langchain.chat_models.openai import ChatOpenAI
@@ -16,6 +17,11 @@ class OpenAIGPT(LangChainChatLLM):
        *args: list[Any],
        **kwargs: dict[str, Any]
    ):
+        # set a dummy API key if not specified so that LangChain doesn't throw an
+        # exception when trying to initialize the LLM which would prevent the API
+        # server from starting up
+        if not api_key:
+            api_key = os.environ.get("OPENAI_API_KEY") or "dummy_api_key"
        self._llm = ChatOpenAI(
            model=model_version,
            openai_api_key=api_key,
@@ -49,6 +49,7 @@ from danswer.db.deletion_attempt import check_deletion_attempt_is_allowed
from danswer.db.deletion_attempt import create_deletion_attempt
from danswer.db.deletion_attempt import get_deletion_attempts
from danswer.db.engine import get_session
+from danswer.db.feedback import fetch_docs_ranked_by_boost
from danswer.db.index_attempt import create_index_attempt
from danswer.db.index_attempt import get_latest_index_attempts
from danswer.db.models import DeletionAttempt
@@ -61,6 +62,7 @@ from danswer.dynamic_configs.interface import ConfigNotFoundError
from danswer.server.models import ApiKey
from danswer.server.models import AuthStatus
from danswer.server.models import AuthUrl
+from danswer.server.models import BoostDoc
from danswer.server.models import ConnectorBase
from danswer.server.models import ConnectorCredentialPairIdentifier
from danswer.server.models import ConnectorIndexingStatus
@@ -79,7 +81,6 @@ from danswer.server.models import StatusResponse
from danswer.server.models import UserRoleResponse
from danswer.utils.logger import setup_logger

-
router = APIRouter(prefix="/manage")
logger = setup_logger()

@@ -89,6 +90,28 @@ _GOOGLE_DRIVE_CREDENTIAL_ID_COOKIE_NAME = "google_drive_credential_id"
"""Admin only API endpoints"""


+@router.get("/admin/doc-boosts")
+def get_most_boosted_docs(
+    ascending: bool,
+    limit: int,
+    _: User | None = Depends(current_admin_user),
+    db_session: Session = Depends(get_session),
+) -> list[BoostDoc]:
+    boost_docs = fetch_docs_ranked_by_boost(
+        ascending=ascending, limit=limit, db_session=db_session
+    )
+    return [
+        BoostDoc(
+            document_id=doc.id,
+            semantic_id=doc.semantic_id,
+            link=doc.link or "",
+            boost=doc.boost,
+            hidden=doc.hidden,
+        )
+        for doc in boost_docs
+    ]
+
+
@router.get("/admin/connector/google-drive/app-credential")
def check_google_app_credentials_exist(
    _: User = Depends(current_admin_user),
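[Note, not part of the commit] A hypothetical call against the new admin endpoint; the host, port, and auth cookie are assumptions about a local deployment:

import requests

resp = requests.get(
    "http://localhost:8080/manage/admin/doc-boosts",
    params={"ascending": False, "limit": 10},
    cookies={"fastapiusersauth": "<admin session cookie>"},  # cookie name is an assumption
)
print(resp.json())  # list of BoostDoc entries: document_id, semantic_id, link, boost, hidden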
@@ -11,6 +11,8 @@ from pydantic.generics import GenericModel

from danswer.configs.app_configs import MASK_CREDENTIAL_PREFIX
from danswer.configs.constants import DocumentSource
+from danswer.configs.constants import QAFeedbackType
+from danswer.configs.constants import SearchFeedbackType
from danswer.connectors.models import InputType
from danswer.datastores.interfaces import IndexFilter
from danswer.db.models import Connector
@@ -105,6 +107,14 @@ class UserRoleResponse(BaseModel):
    role: str


+class BoostDoc(BaseModel):
+    document_id: str
+    semantic_id: str
+    link: str
+    boost: int
+    hidden: bool
+
+
class SearchDoc(BaseModel):
    document_id: str
    semantic_identifier: str
@@ -121,10 +131,24 @@ class QuestionRequest(BaseModel):
    offset: int | None


+class QAFeedbackRequest(BaseModel):
+    query_id: int
+    feedback: QAFeedbackType
+
+
+class SearchFeedbackRequest(BaseModel):
+    query_id: int
+    document_id: str
+    document_rank: int
+    click: bool
+    search_feedback: SearchFeedbackType
+
+
class SearchResponse(BaseModel):
    # For semantic search, top docs are reranked, the remaining are as ordered from retrieval
    top_ranked_docs: list[SearchDoc] | None
    lower_ranked_docs: list[SearchDoc] | None
+    query_event_id: int


class QAResponse(SearchResponse):
@@ -5,16 +5,22 @@ from dataclasses import asdict
from fastapi import APIRouter
from fastapi import Depends
from fastapi.responses import StreamingResponse
+from sqlalchemy.orm import Session

from danswer.auth.users import current_user
from danswer.chunking.models import InferenceChunk
from danswer.configs.app_configs import DISABLE_GENERATIVE_AI
from danswer.configs.app_configs import NUM_GENERATIVE_AI_INPUT_DOCS
from danswer.datastores.document_index import get_default_document_index
+from danswer.db.engine import get_session
+from danswer.db.feedback import create_doc_retrieval_feedback
+from danswer.db.feedback import create_query_event
+from danswer.db.feedback import update_query_event_feedback
from danswer.db.models import User
-from danswer.direct_qa.answer_question import answer_question
+from danswer.direct_qa.answer_question import answer_qa_query
from danswer.direct_qa.exceptions import OpenAIKeyMissing
from danswer.direct_qa.exceptions import UnknownModelError
from danswer.direct_qa.interfaces import DanswerAnswerPiece
from danswer.direct_qa.llm_utils import get_default_qa_model
from danswer.search.danswer_helper import query_intent
from danswer.search.danswer_helper import recommend_search_flow
@@ -24,8 +30,10 @@ from danswer.search.models import SearchType
from danswer.search.semantic_search import chunks_to_search_docs
from danswer.search.semantic_search import retrieve_ranked_documents
from danswer.server.models import HelperResponse
+from danswer.server.models import QAFeedbackRequest
from danswer.server.models import QAResponse
from danswer.server.models import QuestionRequest
+from danswer.server.models import SearchFeedbackRequest
from danswer.server.models import SearchResponse
from danswer.utils.logger import setup_logger
from danswer.utils.timing import log_generator_function_time
@@ -50,62 +58,95 @@ def get_search_type(

@router.post("/semantic-search")
def semantic_search(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User = Depends(current_user),
+    db_session: Session = Depends(get_session),
) -> SearchResponse:
    query = question.query
-    collection = question.collection
    filters = question.filters
    logger.info(f"Received semantic search query: {query}")

+    query_event_id = create_query_event(
+        query=query,
+        selected_flow=SearchType.SEMANTIC,
+        llm_answer=None,
+        user_id=user.id,
+        db_session=db_session,
+    )
+
    user_id = None if user is None else user.id
    ranked_chunks, unranked_chunks = retrieve_ranked_documents(
-        query, user_id, filters, get_default_document_index(collection=collection)
+        query, user_id, filters, get_default_document_index()
    )
    if not ranked_chunks:
-        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
+        return SearchResponse(
+            top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id
+        )

    top_docs = chunks_to_search_docs(ranked_chunks)
    other_top_docs = chunks_to_search_docs(unranked_chunks)

-    return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=other_top_docs)
+    return SearchResponse(
+        top_ranked_docs=top_docs,
+        lower_ranked_docs=other_top_docs,
+        query_event_id=query_event_id,
+    )


@router.post("/keyword-search")
def keyword_search(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User = Depends(current_user),
+    db_session: Session = Depends(get_session),
) -> SearchResponse:
    query = question.query
-    collection = question.collection
    filters = question.filters
    logger.info(f"Received keyword search query: {query}")

+    query_event_id = create_query_event(
+        query=query,
+        selected_flow=SearchType.KEYWORD,
+        llm_answer=None,
+        user_id=user.id,
+        db_session=db_session,
+    )
+
    user_id = None if user is None else user.id
    ranked_chunks = retrieve_keyword_documents(
-        query, user_id, filters, get_default_document_index(collection=collection)
+        query, user_id, filters, get_default_document_index()
    )
    if not ranked_chunks:
-        return SearchResponse(top_ranked_docs=None, lower_ranked_docs=None)
+        return SearchResponse(
+            top_ranked_docs=None, lower_ranked_docs=None, query_event_id=query_event_id
+        )

    top_docs = chunks_to_search_docs(ranked_chunks)
-    return SearchResponse(top_ranked_docs=top_docs, lower_ranked_docs=None)
+    return SearchResponse(
+        top_ranked_docs=top_docs, lower_ranked_docs=None, query_event_id=query_event_id
+    )


@router.post("/direct-qa")
def direct_qa(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
) -> QAResponse:
-    return answer_question(question=question, user=user)
+    return answer_qa_query(question=question, user=user, db_session=db_session)


@router.post("/stream-direct-qa")
def stream_direct_qa(
-    question: QuestionRequest, user: User = Depends(current_user)
+    question: QuestionRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
) -> StreamingResponse:
    send_packet_debug_msg = "Sending Packet: {}"
    top_documents_key = "top_documents"
    unranked_top_docs_key = "unranked_top_documents"
    predicted_flow_key = "predicted_flow"
    predicted_search_key = "predicted_search"
+    query_event_id_key = "query_event_id"

    logger.debug(f"Received QA query: {question.query}")
    logger.debug(f"Query filters: {question.filters}")
@@ -116,8 +157,8 @@ def stream_direct_qa(
    def stream_qa_portions(
        disable_generative_answer: bool = DISABLE_GENERATIVE_AI,
    ) -> Generator[str, None, None]:
+        answer_so_far: str = ""
        query = question.query
-        collection = question.collection
        filters = question.filters
        use_keyword = question.use_keyword
        offset_count = question.offset if question.offset is not None else 0
@@ -132,7 +173,7 @@ def stream_direct_qa(
                query,
                user_id,
                filters,
-                get_default_document_index(collection=collection),
+                get_default_document_index(),
            )
            unranked_chunks: list[InferenceChunk] | None = []
        else:
@@ -140,7 +181,7 @@ def stream_direct_qa(
                query,
                user_id,
                filters,
-                get_default_document_index(collection=collection),
+                get_default_document_index(),
            )
        if not ranked_chunks:
            logger.debug("No Documents Found")
@@ -194,6 +235,11 @@ def stream_direct_qa(
            ):
                if response_packet is None:
                    continue
+                if (
+                    isinstance(response_packet, DanswerAnswerPiece)
+                    and response_packet.answer_piece
+                ):
+                    answer_so_far = answer_so_far + response_packet.answer_piece
                logger.debug(f"Sending packet: {response_packet}")
                yield get_json_line(asdict(response_packet))
        except Exception as e:
@@ -201,6 +247,49 @@ def stream_direct_qa(
            yield get_json_line({"error": str(e)})
            logger.exception("Failed to run QA")

+        query_event_id = create_query_event(
+            query=query,
+            selected_flow=SearchType.KEYWORD
+            if question.use_keyword
+            else SearchType.SEMANTIC,
+            llm_answer=answer_so_far,
+            user_id=user_id,
+            db_session=db_session,
+        )
+
+        yield get_json_line({query_event_id_key: query_event_id})
+
+        return
+
    return StreamingResponse(stream_qa_portions(), media_type="application/json")


+@router.post("/query-feedback")
+def process_query_feedback(
+    feedback: QAFeedbackRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> None:
+    update_query_event_feedback(
+        feedback=feedback.feedback,
+        query_id=feedback.query_id,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
+
+
+@router.post("/doc-retrieval-feedback")
+def process_doc_retrieval_feedback(
+    feedback: SearchFeedbackRequest,
+    user: User | None = Depends(current_user),
+    db_session: Session = Depends(get_session),
+) -> None:
+    create_doc_retrieval_feedback(
+        qa_event_id=feedback.query_id,
+        document_id=feedback.document_id,
+        document_rank=feedback.document_rank,
+        clicked=feedback.click,
+        feedback=feedback.search_feedback,
+        user_id=user.id if user is not None else None,
+        db_session=db_session,
+    )
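[Note, not part of the commit] A rough sketch of exercising the two new feedback endpoints, assuming a local backend on port 8080, auth disabled, and no extra router prefix on these routes:

import requests

BASE = "http://localhost:8080"

# Thumbs-up on a previous answer; query_id is the query_event_id returned by the QA endpoints.
requests.post(f"{BASE}/query-feedback", json={"query_id": 42, "feedback": "like"})

# Endorse one of the retrieved documents so it is boosted for future queries.
requests.post(
    f"{BASE}/doc-retrieval-feedback",
    json={
        "query_id": 42,
        "document_id": "some-doc-id",  # illustrative document id
        "document_rank": 1,
        "click": False,
        "search_feedback": "endorse",
    },
)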