Fix Query History (#56)

This commit is contained in:
Yuhong Sun 2024-03-27 12:10:55 -07:00 committed by Chris Weaver
parent 7b16cb9562
commit d6c5c65b51
4 changed files with 57 additions and 138 deletions

View File

@ -1,105 +1,20 @@
import datetime import datetime
from collections.abc import Sequence
from typing import cast
from typing import Literal from typing import Literal
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import InstrumentedAttribute
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession from danswer.db.models import ChatSession
from danswer.db.models import User
SortByOptions = Literal["time_sent"] SortByOptions = Literal["time_sent"]
def build_query_history_query( def fetch_chat_sessions_by_time(
start: datetime.datetime, start: datetime.datetime, end: datetime.datetime, db_session: Session
end: datetime.datetime, ) -> list[ChatSession]:
sort_by_field: SortByOptions, chat_sessions = (
sort_by_direction: Literal["asc", "desc"], db_session.query(ChatSession)
offset: int, .filter(ChatSession.time_created >= start, ChatSession.time_created <= end)
limit: int | None, .all()
) -> Select[tuple[ChatMessage]]:
stmt = (
select(ChatMessage)
.where(
ChatMessage.time_sent >= start,
)
.where(
ChatMessage.time_sent <= end,
)
.where(
or_(
ChatMessage.message_type == MessageType.ASSISTANT,
ChatMessage.message_type == MessageType.USER,
),
)
) )
order_by_field = cast(InstrumentedAttribute, getattr(ChatMessage, sort_by_field)) return chat_sessions
if sort_by_direction == "asc":
stmt = stmt.order_by(order_by_field.asc())
else:
stmt = stmt.order_by(order_by_field.desc())
if offset:
stmt = stmt.offset(offset)
if limit:
stmt = stmt.limit(limit)
return stmt
def fetch_query_history(
db_session: Session,
start: datetime.datetime,
end: datetime.datetime,
sort_by_field: SortByOptions = "time_sent",
sort_by_direction: Literal["asc", "desc"] = "desc",
offset: int = 0,
limit: int | None = 500,
) -> Sequence[ChatMessage]:
stmt = build_query_history_query(
start=start,
end=end,
sort_by_field=sort_by_field,
sort_by_direction=sort_by_direction,
offset=offset,
limit=limit,
)
return db_session.scalars(stmt).all()
def fetch_query_history_with_user_email(
db_session: Session,
start: datetime.datetime,
end: datetime.datetime,
sort_by_field: SortByOptions = "time_sent",
sort_by_direction: Literal["asc", "desc"] = "desc",
offset: int = 0,
limit: int | None = 500,
) -> Sequence[tuple[ChatMessage, str | None]]:
subquery = build_query_history_query(
start=start,
end=end,
sort_by_field=sort_by_field,
sort_by_direction=sort_by_direction,
offset=offset,
limit=limit,
).subquery()
subquery_alias = aliased(ChatMessage, subquery)
stmt_with_user_email = (
select(subquery_alias, User.email) # type: ignore
.join(ChatSession, subquery_alias.chat_session_id == ChatSession.id)
.join(User, ChatSession.user_id == User.id, isouter=True)
)
return db_session.execute(stmt_with_user_email).all() # type: ignore

View File

@ -1,6 +1,5 @@
import csv import csv
import io import io
from collections import defaultdict
from datetime import datetime from datetime import datetime
from datetime import timedelta from datetime import timedelta
from datetime import timezone from datetime import timezone
@ -14,14 +13,14 @@ from sqlalchemy.orm import Session
import danswer.db.models as db_models import danswer.db.models as db_models
from danswer.auth.users import current_admin_user from danswer.auth.users import current_admin_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.configs.constants import MessageType from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType from danswer.configs.constants import QAFeedbackType
from danswer.db.chat import get_chat_session_by_id from danswer.db.chat import get_chat_session_by_id
from danswer.db.engine import get_session from danswer.db.engine import get_session
from danswer.db.models import ChatMessage from danswer.db.models import ChatMessage
from ee.danswer.db.query_history import ( from danswer.db.models import ChatSession
fetch_query_history_with_user_email, from ee.danswer.db.query_history import fetch_chat_sessions_by_time
)
router = APIRouter() router = APIRouter()
@ -80,36 +79,16 @@ class ChatSessionSnapshot(BaseModel):
user_email: str | None user_email: str | None
name: str | None name: str | None
messages: list[MessageSnapshot] messages: list[MessageSnapshot]
persona_name: str
time_created: datetime time_created: datetime
@classmethod
def build(
cls,
messages: list[ChatMessage],
) -> "ChatSessionSnapshot":
if len(messages) == 0:
raise ValueError("No messages provided")
chat_session = messages[0].chat_session
return cls(
id=chat_session.id,
user_email=chat_session.user.email if chat_session.user else None,
name=chat_session.description,
messages=[
MessageSnapshot.build(message)
for message in sorted(messages, key=lambda m: m.time_sent)
if message.message_type != MessageType.SYSTEM
],
time_created=chat_session.time_created,
)
class QuestionAnswerPairSnapshot(BaseModel): class QuestionAnswerPairSnapshot(BaseModel):
user_message: str user_message: str
ai_response: str ai_response: str
retrieved_documents: list[AbridgedSearchDoc] retrieved_documents: list[AbridgedSearchDoc]
feedback: QAFeedbackType | None feedback: QAFeedbackType | None
persona_name: str
time_created: datetime time_created: datetime
@classmethod @classmethod
@ -132,6 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
ai_response=ai_message.message, ai_response=ai_message.message,
retrieved_documents=ai_message.documents, retrieved_documents=ai_message.documents,
feedback=ai_message.feedback, feedback=ai_message.feedback,
persona_name=chat_session_snapshot.persona_name,
time_created=user_message.time_created, time_created=user_message.time_created,
) )
for user_message, ai_message in message_pairs for user_message, ai_message in message_pairs
@ -148,34 +128,27 @@ class QuestionAnswerPairSnapshot(BaseModel):
] ]
), ),
"feedback": self.feedback.value if self.feedback else "", "feedback": self.feedback.value if self.feedback else "",
"persona_name": self.persona_name,
"time_created": str(self.time_created), "time_created": str(self.time_created),
} }
def fetch_and_process_chat_session_history( def fetch_and_process_chat_session_history(
db_session: Session, db_session: Session,
start: datetime | None, start: datetime,
end: datetime | None, end: datetime,
feedback_type: QAFeedbackType | None, feedback_type: QAFeedbackType | None,
limit: int | None = 500, limit: int | None = 500,
) -> list[ChatSessionSnapshot]: ) -> list[ChatSessionSnapshot]:
chat_messages_with_user_email = fetch_query_history_with_user_email( chat_sessions = fetch_chat_sessions_by_time(
db_session=db_session, start=start, end=end, db_session=db_session
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
limit=limit,
) )
session_id_to_messages: dict[int, list[ChatMessage]] = defaultdict(list)
for message, _ in chat_messages_with_user_email:
session_id_to_messages[message.chat_session_id].append(message)
chat_session_snapshots = [ chat_session_snapshots = [
ChatSessionSnapshot.build(messages) snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for messages in session_id_to_messages.values() for chat_session in chat_sessions
] ]
if feedback_type: if feedback_type:
chat_session_snapshots = [ chat_session_snapshots = [
chat_session_snapshot chat_session_snapshot
@ -186,7 +159,32 @@ def fetch_and_process_chat_session_history(
) )
] ]
return chat_session_snapshots chat_session_snapshots.sort(key=lambda x: x.time_created, reverse=True)
return chat_session_snapshots[:limit]
def snapshot_from_chat_session(
chat_session: ChatSession,
db_session: Session,
) -> ChatSessionSnapshot:
last_message, messages = create_chat_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
return ChatSessionSnapshot(
id=chat_session.id,
user_email=chat_session.user.email if chat_session.user else None,
name=chat_session.description,
messages=[
MessageSnapshot.build(message)
for message in messages
if message.message_type != MessageType.SYSTEM
],
persona_name=chat_session.persona.name,
time_created=chat_session.time_created,
)
@router.get("/admin/chat-session-history") @router.get("/admin/chat-session-history")
@ -199,8 +197,11 @@ def get_chat_session_history(
) -> list[ChatSessionSnapshot]: ) -> list[ChatSessionSnapshot]:
return fetch_and_process_chat_session_history( return fetch_and_process_chat_session_history(
db_session=db_session, db_session=db_session,
start=start, start=start
end=end, or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
feedback_type=feedback_type, feedback_type=feedback_type,
) )
@ -223,7 +224,7 @@ def get_chat_session_admin(
400, f"Chat session with id '{chat_session_id}' does not exist." 400, f"Chat session with id '{chat_session_id}' does not exist."
) )
return ChatSessionSnapshot.build(messages=chat_session.messages) return snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
@router.get("/admin/query-history-csv") @router.get("/admin/query-history-csv")

View File

@ -31,5 +31,6 @@ export interface ChatSessionSnapshot {
user_email: string | null; user_email: string | null;
name: string | null; name: string | null;
messages: MessageSnapshot[]; messages: MessageSnapshot[];
persona_name: string | null;
time_created: string; time_created: string;
} }

View File

@ -61,6 +61,7 @@ function QueryHistoryTableRow({
<FeedbackBadge feedback={finalFeedback} /> <FeedbackBadge feedback={finalFeedback} />
</TableCell> </TableCell>
<TableCell>{chatSessionSnapshot.user_email || "-"}</TableCell> <TableCell>{chatSessionSnapshot.user_email || "-"}</TableCell>
<TableCell>{chatSessionSnapshot.persona_name || "Unknown"}</TableCell>
<TableCell> <TableCell>
{timestampToReadableDate(chatSessionSnapshot.time_created)} {timestampToReadableDate(chatSessionSnapshot.time_created)}
</TableCell> </TableCell>
@ -144,6 +145,7 @@ export function QueryHistoryTable() {
<TableHeaderCell>First AI Response</TableHeaderCell> <TableHeaderCell>First AI Response</TableHeaderCell>
<TableHeaderCell>Feedback</TableHeaderCell> <TableHeaderCell>Feedback</TableHeaderCell>
<TableHeaderCell>User</TableHeaderCell> <TableHeaderCell>User</TableHeaderCell>
<TableHeaderCell>Persona</TableHeaderCell>
<TableHeaderCell>Date</TableHeaderCell> <TableHeaderCell>Date</TableHeaderCell>
</TableRow> </TableRow>
</TableHead> </TableHead>