2024-12-13 09:56:10 -08:00

236 lines
6.7 KiB
Python

import uuid
from datetime import datetime
from typing import Any
from typing import Literal
from typing import Optional
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from pydantic import BaseModel
from pydantic import Field
from sqlalchemy.orm import Session
from onyx.auth.users import current_user
from onyx.configs.constants import MessageType
from onyx.db.chat import create_new_chat_message
from onyx.db.chat import get_chat_message
from onyx.db.chat import get_chat_messages_by_session
from onyx.db.chat import get_chat_session_by_id
from onyx.db.chat import get_or_create_root_message
from onyx.db.engine import get_session
from onyx.db.models import User
from onyx.llm.utils import check_number_of_tokens
# Router that collects the thread-message endpoints declared in this module;
# the empty prefix means routes are registered exactly as written.
router = APIRouter(prefix="")
# The only author roles a message may carry in request and response models.
Role = Literal["user", "assistant"]
class MessageContent(BaseModel):
    """One content part of a message; only plain-text parts are modeled."""

    # Kind of content part; "text" is the only accepted value.
    type: Literal["text"]
    # The text carried by this part.
    text: str
class Message(BaseModel):
    """Response model describing a single message inside a thread.

    ``id`` and ``created_at`` are generated when the caller does not supply
    them; every endpoint in this module passes an explicit ``id``.
    """

    # Defaults to "msg_" followed by a freshly generated UUID4.
    id: str = Field(default_factory=lambda: "msg_" + str(uuid.uuid4()))
    # Constant object-type tag for this model.
    object: Literal["thread.message"] = "thread.message"
    # Whole seconds since the Unix epoch; defaults to the current time.
    created_at: int = Field(
        default_factory=lambda: int(datetime.now().timestamp())
    )
    # Identifier of the thread the message belongs to.
    thread_id: str
    # Author of the message.
    role: Role
    # Ordered content parts of the message.
    content: list[MessageContent]
    # Identifiers of attached files, empty when there are none.
    file_ids: list[str] = []
    assistant_id: str | None = None
    run_id: str | None = None
    # Free-form key/value data attached to the message.
    metadata: dict[str, Any] | None = None
class CreateMessageRequest(BaseModel):
    """Request body for creating a message in a thread."""

    # Author of the new message.
    role: Role
    # Plain-text body of the new message.
    content: str
    # Identifiers of files to attach, empty when there are none.
    file_ids: list[str] = []
    # Free-form key/value data. Typed like ``Message.metadata`` because the
    # value is handed to that field unchanged when the response is built.
    metadata: Optional[dict[str, Any]] = None
class ListMessagesResponse(BaseModel):
    """Response model for one page of messages from a thread."""

    # Constant object-type tag for this model.
    object: Literal["list"] = "list"
    # Messages on this page, in the order they are returned.
    data: list[Message]
    # Id of the first message in ``data``.
    first_id: str
    # Id of the last message in ``data``.
    last_id: str
    # Whether further messages exist beyond this page.
    has_more: bool
@router.post("/threads/{thread_id}/messages")
def create_message(
    thread_id: str,
    message: CreateMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    """Append a new message to the end of a thread.

    Args:
        thread_id: Chat session id, given as a UUID string.
        message: Role, text, and optional file ids / metadata of the message.
        user: The current user, or None when no user is resolved.
        db_session: Database session used for all lookups and the insert.

    Returns:
        The created message, carrying the id assigned by the database and the
        role, content, file ids and metadata from the request.

    Raises:
        HTTPException: 404 when ``thread_id`` is not a valid UUID or the
            session lookup raises ``ValueError``.
    """
    user_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError as err:
        # Covers both a malformed UUID and a ValueError from the lookup.
        raise HTTPException(
            status_code=404, detail="Chat session not found"
        ) from err

    chat_messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user_id,
        db_session=db_session,
    )

    # The new message is chained onto the last existing one; an empty thread
    # hangs it off the session's root message instead.
    # NOTE(review): assumes the list above is ordered oldest-to-newest - confirm.
    latest_message = (
        chat_messages[-1]
        if chat_messages
        else get_or_create_root_message(chat_session.id, db_session)
    )

    # NOTE(review): this assumes the session has a persona with at least one
    # prompt; otherwise the attribute/index access below raises - confirm.
    new_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=latest_message,
        message=message.content,
        prompt_id=chat_session.persona.prompts[0].id,
        token_count=check_number_of_tokens(message.content),
        message_type=(
            MessageType.USER if message.role == "user" else MessageType.ASSISTANT
        ),
        db_session=db_session,
    )

    return Message(
        id=str(new_message.id),
        thread_id=thread_id,
        # Echo the requested role; it was previously hard-coded to "user",
        # which misreported messages stored as assistant messages.
        role=message.role,
        content=[MessageContent(type="text", text=message.content)],
        file_ids=message.file_ids,
        metadata=message.metadata,
    )
@router.get("/threads/{thread_id}/messages")
def list_messages(
    thread_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListMessagesResponse:
    """Return one page of messages from a thread.

    Args:
        thread_id: Chat session id, given as a UUID string.
        limit: Maximum number of messages to return; negative values are
            treated as 0.
        order: Sort direction by message id.
        after: If set, keep only messages whose id is >= this message id.
        before: If set, keep only messages whose id is <= this message id.
        user: The current user, or None when no user is resolved.
        db_session: Database session used for the lookups.

    Returns:
        The page of messages plus the first/last id on the page (empty
        strings when the page is empty) and whether more messages remain.

    Raises:
        HTTPException: 404 when ``thread_id`` is not a valid UUID or the
            session lookup raises ``ValueError``; 400 when ``after`` or
            ``before`` is not an integer message id.
    """
    user_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=uuid.UUID(thread_id),
            user_id=user_id,
            db_session=db_session,
        )
    except ValueError as err:
        # Covers both a malformed UUID and a ValueError from the lookup.
        raise HTTPException(
            status_code=404, detail="Chat session not found"
        ) from err

    # Message ids are integers (see the ``message_id: int`` path parameters),
    # so cursors must be compared numerically. Comparing their string forms
    # orders "10" before "9" and silently drops or keeps the wrong messages.
    try:
        after_id = int(after) if after else None
        before_id = int(before) if before else None
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail="Invalid message id cursor"
        ) from err

    messages = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=user_id,
        db_session=db_session,
    )

    # Both bounds are inclusive, matching the previous behaviour.
    # NOTE(review): cursor pagination usually excludes the cursor item itself;
    # confirm whether clients expect ``after=last_id`` to repeat that message.
    if after_id is not None:
        messages = [m for m in messages if m.id >= after_id]
    if before_id is not None:
        messages = [m for m in messages if m.id <= before_id]

    messages = sorted(messages, key=lambda m: m.id, reverse=(order == "desc"))

    # A negative limit would otherwise slice from the end of the list.
    limit = max(limit, 0)
    # Decide this before truncating: a full page does not by itself mean that
    # more messages exist.
    has_more = len(messages) > limit
    page = messages[:limit]

    data = [
        Message(
            id=str(m.id),
            thread_id=thread_id,
            role="user" if m.message_type == "user" else "assistant",
            content=[MessageContent(type="text", text=m.message)],
            created_at=int(m.time_sent.timestamp()),
        )
        for m in page
    ]

    return ListMessagesResponse(
        data=data,
        first_id=data[0].id if data else "",
        last_id=data[-1].id if data else "",
        has_more=has_more,
    )
@router.get("/threads/{thread_id}/messages/{message_id}")
def retrieve_message(
    thread_id: str,
    message_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    """Fetch a single message by its id.

    ``thread_id`` is only echoed back in the response; it is not used for the
    lookup. A ``ValueError`` from the lookup is reported as HTTP 404.
    """
    requester_id = None if not user else user.id

    try:
        found = get_chat_message(
            chat_message_id=message_id,
            user_id=requester_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    found_id = str(found.id)
    # Anything that is not a user message is reported as an assistant message.
    if found.message_type == "user":
        author = "user"
    else:
        author = "assistant"
    text_part = MessageContent(type="text", text=found.message)
    sent_at = int(found.time_sent.timestamp())

    return Message(
        id=found_id,
        thread_id=thread_id,
        role=author,
        content=[text_part],
        created_at=sent_at,
    )
class ModifyMessageRequest(BaseModel):
    """Request body for modifying a message; only metadata can be supplied."""

    # Free-form key/value data. Typed like ``Message.metadata`` because the
    # value is handed to that field unchanged when the response is built.
    metadata: dict[str, Any]
@router.post("/threads/{thread_id}/messages/{message_id}")
def modify_message(
    thread_id: str,
    message_id: int,
    request: ModifyMessageRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> Message:
    """Return a message with the metadata supplied in the request.

    The metadata is only echoed in the response; nothing is written to the
    database yet (see the TODO below). ``thread_id`` is likewise only echoed
    back. A ``ValueError`` from the lookup is reported as HTTP 404.
    """
    requester_id = None if not user else user.id

    try:
        target = get_chat_message(
            chat_message_id=message_id,
            user_id=requester_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Message not found")

    # Update metadata
    # TODO: Uncomment this once we have metadata in the chat message
    # chat_message.metadata = request.metadata
    # db_session.commit()

    target_id = str(target.id)
    # Anything that is not a user message is reported as an assistant message.
    if target.message_type == "user":
        author = "user"
    else:
        author = "assistant"
    text_part = MessageContent(type="text", text=target.message)
    sent_at = int(target.time_sent.timestamp())

    return Message(
        id=target_id,
        thread_id=thread_id,
        role=author,
        content=[text_part],
        created_at=sent_at,
        metadata=request.metadata,
    )