mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-19 00:00:37 +02:00
236 lines
6.7 KiB
Python
236 lines
6.7 KiB
Python
import uuid
|
|
from datetime import datetime
|
|
from typing import Any
|
|
from typing import Literal
|
|
from typing import Optional
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from pydantic import BaseModel
|
|
from pydantic import Field
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_user
|
|
from onyx.configs.constants import MessageType
|
|
from onyx.db.chat import create_new_chat_message
|
|
from onyx.db.chat import get_chat_message
|
|
from onyx.db.chat import get_chat_messages_by_session
|
|
from onyx.db.chat import get_chat_session_by_id
|
|
from onyx.db.chat import get_or_create_root_message
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.models import User
|
|
from onyx.llm.utils import check_number_of_tokens
|
|
|
|
router = APIRouter(prefix="")
|
|
|
|
|
|
Role = Literal["user", "assistant"]
|
|
|
|
|
|
class MessageContent(BaseModel):
|
|
type: Literal["text"]
|
|
text: str
|
|
|
|
|
|
class Message(BaseModel):
|
|
id: str = Field(default_factory=lambda: f"msg_{uuid.uuid4()}")
|
|
object: Literal["thread.message"] = "thread.message"
|
|
created_at: int = Field(default_factory=lambda: int(datetime.now().timestamp()))
|
|
thread_id: str
|
|
role: Role
|
|
content: list[MessageContent]
|
|
file_ids: list[str] = []
|
|
assistant_id: Optional[str] = None
|
|
run_id: Optional[str] = None
|
|
metadata: Optional[dict[str, Any]] = None # Change this line to use dict[str, Any]
|
|
|
|
|
|
class CreateMessageRequest(BaseModel):
|
|
role: Role
|
|
content: str
|
|
file_ids: list[str] = []
|
|
metadata: Optional[dict] = None
|
|
|
|
|
|
class ListMessagesResponse(BaseModel):
|
|
object: Literal["list"] = "list"
|
|
data: list[Message]
|
|
first_id: str
|
|
last_id: str
|
|
has_more: bool
|
|
|
|
|
|
@router.post("/threads/{thread_id}/messages")
|
|
def create_message(
|
|
thread_id: str,
|
|
message: CreateMessageRequest,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> Message:
|
|
user_id = user.id if user else None
|
|
|
|
try:
|
|
chat_session = get_chat_session_by_id(
|
|
chat_session_id=uuid.UUID(thread_id),
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
except ValueError:
|
|
raise HTTPException(status_code=404, detail="Chat session not found")
|
|
|
|
chat_messages = get_chat_messages_by_session(
|
|
chat_session_id=chat_session.id,
|
|
user_id=user.id if user else None,
|
|
db_session=db_session,
|
|
)
|
|
latest_message = (
|
|
chat_messages[-1]
|
|
if chat_messages
|
|
else get_or_create_root_message(chat_session.id, db_session)
|
|
)
|
|
|
|
new_message = create_new_chat_message(
|
|
chat_session_id=chat_session.id,
|
|
parent_message=latest_message,
|
|
message=message.content,
|
|
prompt_id=chat_session.persona.prompts[0].id,
|
|
token_count=check_number_of_tokens(message.content),
|
|
message_type=(
|
|
MessageType.USER if message.role == "user" else MessageType.ASSISTANT
|
|
),
|
|
db_session=db_session,
|
|
)
|
|
|
|
return Message(
|
|
id=str(new_message.id),
|
|
thread_id=thread_id,
|
|
role="user",
|
|
content=[MessageContent(type="text", text=message.content)],
|
|
file_ids=message.file_ids,
|
|
metadata=message.metadata,
|
|
)
|
|
|
|
|
|
@router.get("/threads/{thread_id}/messages")
|
|
def list_messages(
|
|
thread_id: str,
|
|
limit: int = 20,
|
|
order: Literal["asc", "desc"] = "desc",
|
|
after: Optional[str] = None,
|
|
before: Optional[str] = None,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> ListMessagesResponse:
|
|
user_id = user.id if user else None
|
|
|
|
try:
|
|
chat_session = get_chat_session_by_id(
|
|
chat_session_id=uuid.UUID(thread_id),
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
except ValueError:
|
|
raise HTTPException(status_code=404, detail="Chat session not found")
|
|
|
|
messages = get_chat_messages_by_session(
|
|
chat_session_id=chat_session.id,
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
|
|
# Apply filtering based on after and before
|
|
if after:
|
|
messages = [m for m in messages if str(m.id) >= after]
|
|
if before:
|
|
messages = [m for m in messages if str(m.id) <= before]
|
|
|
|
# Apply ordering
|
|
messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))
|
|
|
|
# Apply limit
|
|
messages = messages[:limit]
|
|
|
|
data = [
|
|
Message(
|
|
id=str(m.id),
|
|
thread_id=thread_id,
|
|
role="user" if m.message_type == "user" else "assistant",
|
|
content=[MessageContent(type="text", text=m.message)],
|
|
created_at=int(m.time_sent.timestamp()),
|
|
)
|
|
for m in messages
|
|
]
|
|
|
|
return ListMessagesResponse(
|
|
data=data,
|
|
first_id=str(data[0].id) if data else "",
|
|
last_id=str(data[-1].id) if data else "",
|
|
has_more=len(messages) == limit,
|
|
)
|
|
|
|
|
|
@router.get("/threads/{thread_id}/messages/{message_id}")
|
|
def retrieve_message(
|
|
thread_id: str,
|
|
message_id: int,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> Message:
|
|
user_id = user.id if user else None
|
|
|
|
try:
|
|
chat_message = get_chat_message(
|
|
chat_message_id=message_id,
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
except ValueError:
|
|
raise HTTPException(status_code=404, detail="Message not found")
|
|
|
|
return Message(
|
|
id=str(chat_message.id),
|
|
thread_id=thread_id,
|
|
role="user" if chat_message.message_type == "user" else "assistant",
|
|
content=[MessageContent(type="text", text=chat_message.message)],
|
|
created_at=int(chat_message.time_sent.timestamp()),
|
|
)
|
|
|
|
|
|
class ModifyMessageRequest(BaseModel):
|
|
metadata: dict
|
|
|
|
|
|
@router.post("/threads/{thread_id}/messages/{message_id}")
|
|
def modify_message(
|
|
thread_id: str,
|
|
message_id: int,
|
|
request: ModifyMessageRequest,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> Message:
|
|
user_id = user.id if user else None
|
|
|
|
try:
|
|
chat_message = get_chat_message(
|
|
chat_message_id=message_id,
|
|
user_id=user_id,
|
|
db_session=db_session,
|
|
)
|
|
except ValueError:
|
|
raise HTTPException(status_code=404, detail="Message not found")
|
|
|
|
# Update metadata
|
|
# TODO: Uncomment this once we have metadata in the chat message
|
|
# chat_message.metadata = request.metadata
|
|
# db_session.commit()
|
|
|
|
return Message(
|
|
id=str(chat_message.id),
|
|
thread_id=thread_id,
|
|
role="user" if chat_message.message_type == "user" else "assistant",
|
|
content=[MessageContent(type="text", text=chat_message.message)],
|
|
created_at=int(chat_message.time_sent.timestamp()),
|
|
metadata=request.metadata,
|
|
)
|