Fix Query History (#56)

This commit is contained in:
Yuhong Sun 2024-03-27 12:10:55 -07:00 committed by Chris Weaver
parent 7b16cb9562
commit d6c5c65b51
4 changed files with 57 additions and 138 deletions

View File

@ -1,105 +1,20 @@
import datetime
from collections.abc import Sequence
from typing import cast
from typing import Literal
from sqlalchemy import or_
from sqlalchemy import Select
from sqlalchemy import select
from sqlalchemy.orm import aliased
from sqlalchemy.orm import Session
from sqlalchemy.orm.attributes import InstrumentedAttribute
from danswer.configs.constants import MessageType
from danswer.db.models import ChatMessage
from danswer.db.models import ChatSession
from danswer.db.models import User
SortByOptions = Literal["time_sent"]
def build_query_history_query(
start: datetime.datetime,
end: datetime.datetime,
sort_by_field: SortByOptions,
sort_by_direction: Literal["asc", "desc"],
offset: int,
limit: int | None,
) -> Select[tuple[ChatMessage]]:
stmt = (
select(ChatMessage)
.where(
ChatMessage.time_sent >= start,
)
.where(
ChatMessage.time_sent <= end,
)
.where(
or_(
ChatMessage.message_type == MessageType.ASSISTANT,
ChatMessage.message_type == MessageType.USER,
),
)
def fetch_chat_sessions_by_time(
start: datetime.datetime, end: datetime.datetime, db_session: Session
) -> list[ChatSession]:
chat_sessions = (
db_session.query(ChatSession)
.filter(ChatSession.time_created >= start, ChatSession.time_created <= end)
.all()
)
order_by_field = cast(InstrumentedAttribute, getattr(ChatMessage, sort_by_field))
if sort_by_direction == "asc":
stmt = stmt.order_by(order_by_field.asc())
else:
stmt = stmt.order_by(order_by_field.desc())
if offset:
stmt = stmt.offset(offset)
if limit:
stmt = stmt.limit(limit)
return stmt
def fetch_query_history(
db_session: Session,
start: datetime.datetime,
end: datetime.datetime,
sort_by_field: SortByOptions = "time_sent",
sort_by_direction: Literal["asc", "desc"] = "desc",
offset: int = 0,
limit: int | None = 500,
) -> Sequence[ChatMessage]:
stmt = build_query_history_query(
start=start,
end=end,
sort_by_field=sort_by_field,
sort_by_direction=sort_by_direction,
offset=offset,
limit=limit,
)
return db_session.scalars(stmt).all()
def fetch_query_history_with_user_email(
db_session: Session,
start: datetime.datetime,
end: datetime.datetime,
sort_by_field: SortByOptions = "time_sent",
sort_by_direction: Literal["asc", "desc"] = "desc",
offset: int = 0,
limit: int | None = 500,
) -> Sequence[tuple[ChatMessage, str | None]]:
subquery = build_query_history_query(
start=start,
end=end,
sort_by_field=sort_by_field,
sort_by_direction=sort_by_direction,
offset=offset,
limit=limit,
).subquery()
subquery_alias = aliased(ChatMessage, subquery)
stmt_with_user_email = (
select(subquery_alias, User.email) # type: ignore
.join(ChatSession, subquery_alias.chat_session_id == ChatSession.id)
.join(User, ChatSession.user_id == User.id, isouter=True)
)
return db_session.execute(stmt_with_user_email).all() # type: ignore
return chat_sessions

View File

@ -1,6 +1,5 @@
import csv
import io
from collections import defaultdict
from datetime import datetime
from datetime import timedelta
from datetime import timezone
@ -14,14 +13,14 @@ from sqlalchemy.orm import Session
import danswer.db.models as db_models
from danswer.auth.users import current_admin_user
from danswer.chat.chat_utils import create_chat_chain
from danswer.configs.constants import MessageType
from danswer.configs.constants import QAFeedbackType
from danswer.db.chat import get_chat_session_by_id
from danswer.db.engine import get_session
from danswer.db.models import ChatMessage
from ee.danswer.db.query_history import (
fetch_query_history_with_user_email,
)
from danswer.db.models import ChatSession
from ee.danswer.db.query_history import fetch_chat_sessions_by_time
router = APIRouter()
@ -80,36 +79,16 @@ class ChatSessionSnapshot(BaseModel):
user_email: str | None
name: str | None
messages: list[MessageSnapshot]
persona_name: str
time_created: datetime
@classmethod
def build(
cls,
messages: list[ChatMessage],
) -> "ChatSessionSnapshot":
if len(messages) == 0:
raise ValueError("No messages provided")
chat_session = messages[0].chat_session
return cls(
id=chat_session.id,
user_email=chat_session.user.email if chat_session.user else None,
name=chat_session.description,
messages=[
MessageSnapshot.build(message)
for message in sorted(messages, key=lambda m: m.time_sent)
if message.message_type != MessageType.SYSTEM
],
time_created=chat_session.time_created,
)
class QuestionAnswerPairSnapshot(BaseModel):
user_message: str
ai_response: str
retrieved_documents: list[AbridgedSearchDoc]
feedback: QAFeedbackType | None
persona_name: str
time_created: datetime
@classmethod
@ -132,6 +111,7 @@ class QuestionAnswerPairSnapshot(BaseModel):
ai_response=ai_message.message,
retrieved_documents=ai_message.documents,
feedback=ai_message.feedback,
persona_name=chat_session_snapshot.persona_name,
time_created=user_message.time_created,
)
for user_message, ai_message in message_pairs
@ -148,34 +128,27 @@ class QuestionAnswerPairSnapshot(BaseModel):
]
),
"feedback": self.feedback.value if self.feedback else "",
"persona_name": self.persona_name,
"time_created": str(self.time_created),
}
def fetch_and_process_chat_session_history(
db_session: Session,
start: datetime | None,
end: datetime | None,
start: datetime,
end: datetime,
feedback_type: QAFeedbackType | None,
limit: int | None = 500,
) -> list[ChatSessionSnapshot]:
chat_messages_with_user_email = fetch_query_history_with_user_email(
db_session=db_session,
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
limit=limit,
chat_sessions = fetch_chat_sessions_by_time(
start=start, end=end, db_session=db_session
)
session_id_to_messages: dict[int, list[ChatMessage]] = defaultdict(list)
for message, _ in chat_messages_with_user_email:
session_id_to_messages[message.chat_session_id].append(message)
chat_session_snapshots = [
ChatSessionSnapshot.build(messages)
for messages in session_id_to_messages.values()
snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
for chat_session in chat_sessions
]
if feedback_type:
chat_session_snapshots = [
chat_session_snapshot
@ -186,7 +159,32 @@ def fetch_and_process_chat_session_history(
)
]
return chat_session_snapshots
chat_session_snapshots.sort(key=lambda x: x.time_created, reverse=True)
return chat_session_snapshots[:limit]
def snapshot_from_chat_session(
chat_session: ChatSession,
db_session: Session,
) -> ChatSessionSnapshot:
last_message, messages = create_chat_chain(
chat_session_id=chat_session.id, db_session=db_session
)
messages.append(last_message)
return ChatSessionSnapshot(
id=chat_session.id,
user_email=chat_session.user.email if chat_session.user else None,
name=chat_session.description,
messages=[
MessageSnapshot.build(message)
for message in messages
if message.message_type != MessageType.SYSTEM
],
persona_name=chat_session.persona.name,
time_created=chat_session.time_created,
)
@router.get("/admin/chat-session-history")
@ -199,8 +197,11 @@ def get_chat_session_history(
) -> list[ChatSessionSnapshot]:
return fetch_and_process_chat_session_history(
db_session=db_session,
start=start,
end=end,
start=start
or (
datetime.now(tz=timezone.utc) - timedelta(days=30)
), # default is 30d lookback
end=end or datetime.now(tz=timezone.utc),
feedback_type=feedback_type,
)
@ -223,7 +224,7 @@ def get_chat_session_admin(
400, f"Chat session with id '{chat_session_id}' does not exist."
)
return ChatSessionSnapshot.build(messages=chat_session.messages)
return snapshot_from_chat_session(chat_session=chat_session, db_session=db_session)
@router.get("/admin/query-history-csv")

View File

@ -31,5 +31,6 @@ export interface ChatSessionSnapshot {
user_email: string | null;
name: string | null;
messages: MessageSnapshot[];
persona_name: string | null;
time_created: string;
}

View File

@ -61,6 +61,7 @@ function QueryHistoryTableRow({
<FeedbackBadge feedback={finalFeedback} />
</TableCell>
<TableCell>{chatSessionSnapshot.user_email || "-"}</TableCell>
<TableCell>{chatSessionSnapshot.persona_name || "Unknown"}</TableCell>
<TableCell>
{timestampToReadableDate(chatSessionSnapshot.time_created)}
</TableCell>
@ -144,6 +145,7 @@ export function QueryHistoryTable() {
<TableHeaderCell>First AI Response</TableHeaderCell>
<TableHeaderCell>Feedback</TableHeaderCell>
<TableHeaderCell>User</TableHeaderCell>
<TableHeaderCell>Persona</TableHeaderCell>
<TableHeaderCell>Date</TableHeaderCell>
</TableRow>
</TableHead>