mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-06 09:50:27 +02:00
346 lines
11 KiB
Python
346 lines
11 KiB
Python
from typing import Literal
|
|
from typing import Optional
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi import BackgroundTasks
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_user
|
|
from onyx.chat.process_message import stream_chat_message_objects
|
|
from onyx.configs.constants import MessageType
|
|
from onyx.context.search.models import RetrievalDetails
|
|
from onyx.db.chat import create_new_chat_message
|
|
from onyx.db.chat import get_chat_message
|
|
from onyx.db.chat import get_chat_messages_by_session
|
|
from onyx.db.chat import get_chat_session_by_id
|
|
from onyx.db.chat import get_or_create_root_message
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.models import ChatMessage
|
|
from onyx.db.models import User
|
|
from onyx.server.query_and_chat.models import ChatMessageDetail
|
|
from onyx.server.query_and_chat.models import CreateChatMessageRequest
|
|
from onyx.tools.tool_implementations.search.search_tool import SearchTool
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
|
|
# Module-level logger obtained from the project's setup_logger helper.
logger = setup_logger()
|
|
|
|
|
|
# Router on which the thread-run endpoints defined in this file are registered.
router = APIRouter()
|
|
|
|
|
|
class RunRequest(BaseModel):
    # Request body accepted by the create-run endpoint.

    # Assistant that should execute the run (the only required field).
    assistant_id: int
    # Optional model name; create_run echoes it back in its response.
    model: str | None = None
    # Text handed to the background task as the message for the run.
    instructions: str | None = None
    # Accepted for request compatibility; not read anywhere in this file.
    additional_instructions: str | None = None
    # Tool descriptors; the background task scans them for a search-tool
    # entry carrying "retrieval_details".
    tools: list[dict] | None = None
    # Free-form metadata; create_run echoes it back in its response.
    metadata: dict | None = None
|
|
|
|
|
|
# Closed set of status strings a RunResponse may carry.
RunStatus = Literal[
    "queued", "in_progress", "requires_action", "cancelling",
    "cancelled", "failed", "completed", "expired",
]
|
|
|
|
|
|
class RunResponse(BaseModel):
    # Response payload describing a single run.

    # Run identifier; the endpoints in this file use the string form of the
    # backing chat message id.
    id: str
    # Fixed discriminator string for this payload type.
    object: Literal["thread.run"]
    # Integer timestamp; the endpoints fill it from the message's time_sent.
    created_at: int
    assistant_id: int
    # Identifier of the chat session the run belongs to.
    thread_id: UUID
    status: RunStatus

    # Optional lifecycle timestamps; all default to None.
    started_at: int | None = None
    expires_at: int | None = None
    cancelled_at: int | None = None
    failed_at: int | None = None
    completed_at: int | None = None
    last_error: dict | None = None

    model: str
    instructions: str
    tools: list[dict]
    file_ids: list[str]
    metadata: dict | None = None
|
|
|
|
|
|
def process_run_in_background(
    message_id: int,
    parent_message_id: int,
    chat_session_id: UUID,
    assistant_id: int,
    instructions: str,
    tools: list[dict],
    user: User | None,
    db_session: Session,
) -> None:
    """Generate the assistant answer for a run and store it on the run message.

    The run is represented by the chat message ``message_id``. The function
    builds a ``CreateChatMessageRequest`` that reuses that message as the
    assistant message, streams ``stream_chat_message_objects`` and, for every
    ``ChatMessageDetail`` packet, copies the packet text onto the run message
    and commits. Any failure is written to the run message's ``error`` field.

    Args:
        message_id: Id of the chat message that represents the run.
        parent_message_id: Id of the message the run replies to; a falsy
            value is forwarded as ``None``.
        chat_session_id: Chat session (thread) the run belongs to.
        assistant_id: Forwarded as ``alternate_assistant_id``.
        instructions: Forwarded as the request's ``message``.
        tools: Tool descriptors from the request; the first entry whose
            ``"type"`` equals ``SearchTool.__name__`` and that has a truthy
            ``"retrieval_details"`` configures retrieval.
        user: Requesting user, or ``None``.
        db_session: Session used for all reads and writes.

    NOTE(review): ``db_session`` is the request-scoped session handed over by
    ``create_run``; confirm it is still open when background tasks execute
    (this depends on the FastAPI version's dependency teardown order).
    """
    user_id = user.id if user else None

    chat_session = get_chat_session_by_id(
        chat_session_id=chat_session_id,
        user_id=user_id,
        db_session=db_session,
    )
    # Fetched before any fallible setup so that setup errors can be recorded
    # on the run instead of leaving it without any status change.
    run_message = get_chat_message(message_id, user_id, db_session)

    try:
        search_tool_retrieval_details = RetrievalDetails()
        for tool in tools:
            # Tool dicts come from the request body, so "type" may be absent;
            # .get() skips such entries instead of raising KeyError.
            if tool.get("type") == SearchTool.__name__ and (
                retrieval_details := tool.get("retrieval_details")
            ):
                search_tool_retrieval_details = RetrievalDetails.model_validate(
                    retrieval_details
                )
                break

        new_msg_req = CreateChatMessageRequest(
            chat_session_id=chat_session_id,
            parent_message_id=int(parent_message_id) if parent_message_id else None,
            message=instructions,
            file_descriptors=[],
            prompt_id=chat_session.persona.prompts[0].id,
            search_doc_ids=None,
            retrieval_options=search_tool_retrieval_details,
            rerank_settings=None,
            query_override=None,
            regenerate=None,
            llm_override=None,
            prompt_override=None,
            alternate_assistant_id=assistant_id,
            use_existing_user_message=True,
            existing_assistant_message_id=message_id,
        )

        for packet in stream_chat_message_objects(
            new_msg_req=new_msg_req,
            user=user,
            db_session=db_session,
        ):
            if not isinstance(packet, ChatMessageDetail):
                continue

            # Re-read the run message so an error written elsewhere
            # (e.g. by cancel_run) is noticed.
            refreshed_message = get_chat_message(message_id, user_id, db_session)
            if not refreshed_message:
                continue
            run_message = refreshed_message

            # A populated error field means the run was cancelled: stop here.
            if run_message.error:
                return

            run_message.message = packet.message
            run_message.message_type = MessageType.ASSISTANT
            db_session.commit()
    except Exception as e:
        logger.exception("Error processing run in background")
        if run_message:
            # retrieve_run tests the error field for truthiness, so never
            # store an empty string for exceptions without a message.
            run_message.error = str(e) or type(e).__name__
            db_session.commit()
        return

    if not run_message:
        return

    db_session.refresh(run_message)
    if run_message.token_count == 0:
        run_message.error = "No tokens generated"
        db_session.commit()
|
|
|
|
|
|
@router.post("/threads/{thread_id}/runs")
def create_run(
    thread_id: UUID,
    run_request: RunRequest,
    background_tasks: BackgroundTasks,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    # Starts a run on a thread: creates an empty assistant message that
    # represents the run, schedules process_run_in_background to fill it,
    # and immediately answers with a "queued" RunResponse.
    # Responds 404 when the session lookup raises ValueError.
    requester_id = user.id if user else None

    try:
        chat_session = get_chat_session_by_id(
            chat_session_id=thread_id,
            user_id=requester_id,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Thread not found")

    # The run hangs off the last message of the session, or off the root
    # message when the session has no messages yet.
    history = get_chat_messages_by_session(
        chat_session_id=chat_session.id,
        user_id=requester_id,
        db_session=db_session,
    )
    if history:
        parent = history[-1]
    else:
        parent = get_or_create_root_message(chat_session.id, db_session)

    # Placeholder assistant message that stands for the run.
    run_message = create_new_chat_message(
        chat_session_id=chat_session.id,
        parent_message=parent,
        message="",
        prompt_id=chat_session.persona.prompts[0].id,
        token_count=0,
        message_type=MessageType.ASSISTANT,
        db_session=db_session,
        commit=False,
    )
    # Flush so the new message has an id before it is linked to its parent.
    db_session.flush()
    parent.latest_child_message = run_message.id
    db_session.commit()

    instructions = run_request.instructions or ""

    # Positional arguments follow process_run_in_background's parameter order.
    background_tasks.add_task(
        process_run_in_background,
        run_message.id,
        parent.id,
        chat_session.id,
        run_request.assistant_id,
        instructions,
        run_request.tools or [],
        user,
        db_session,
    )

    return RunResponse(
        id=str(run_message.id),
        object="thread.run",
        created_at=int(run_message.time_sent.timestamp()),
        assistant_id=run_request.assistant_id,
        thread_id=chat_session.id,
        status="queued",
        model=run_request.model or "default_model",
        instructions=instructions,
        tools=run_request.tools or [],
        file_ids=[],
        metadata=run_request.metadata,
    )
|
|
|
|
|
|
@router.get("/threads/{thread_id}/runs/{run_id}")
def retrieve_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Return the current state of a run.

    A run is backed by a chat message whose integer id is ``run_id``.

    Status mapping (later rules win):
      * ``queued`` - default;
      * ``in_progress`` - the message has text;
      * ``completed`` - the message has a non-zero token count;
      * ``cancelled`` - the error field holds the ``"Cancelled"`` marker
        written by ``cancel_run``;
      * ``failed`` - any other non-empty error.

    Raises:
        HTTPException: 404 when ``run_id`` is not an integer or the message
            cannot be retrieved for this user.
    """
    try:
        # int() rejects non-numeric ids with ValueError. Presumably the
        # lookup helper signals "not found" the same way, as
        # get_chat_session_by_id does in create_run; both map to 404.
        chat_message = get_chat_message(
            chat_message_id=int(run_id),
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found") from None

    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_session = chat_message.chat_session
    sent_at = int(chat_message.time_sent.timestamp())

    run_status: RunStatus = "queued"
    if chat_message.message:
        run_status = "in_progress"
    if chat_message.token_count != 0:
        run_status = "completed"
    if chat_message.error:
        # Only the marker written by cancel_run means "cancelled"; errors
        # recorded by the background task are real failures.
        run_status = "cancelled" if chat_message.error == "Cancelled" else "failed"

    return RunResponse(
        id=run_id,
        object="thread.run",
        created_at=sent_at,
        assistant_id=chat_session.persona_id or 0,
        thread_id=chat_session.id,
        status=run_status,
        started_at=sent_at,
        # Only the message's send time is available here, so it doubles as
        # the completion time once the message has text.
        completed_at=sent_at if chat_message.message else None,
        model=chat_session.current_alternate_model or "default_model",
        instructions="",  # per-message instructions are not stored
        tools=[],  # no direct equivalent for tools
        file_ids=(
            [file["id"] for file in chat_message.files] if chat_message.files else []
        ),
        metadata=None,  # metadata is not stored for individual messages
    )
|
|
|
|
|
|
@router.post("/threads/{thread_id}/runs/{run_id}/cancel")
def cancel_run(
    thread_id: UUID,
    run_id: str,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> RunResponse:
    """Mark a run as cancelled and return its updated state.

    There is no real cancellation mechanism; the run's chat message gets its
    ``error`` field set to ``"Cancelled"``, which the background task checks
    between packets and which ``retrieve_run`` reports as a terminal status.

    The message is loaded through ``get_chat_message`` with the caller's
    user id, the same user-scoped lookup ``retrieve_run`` uses, *before*
    anything is modified. Previously the row was fetched with an unscoped
    query that compared the id column to the raw string and the cancel was
    committed before any user check ran.

    Raises:
        HTTPException: 404 when ``run_id`` is not an integer or the message
            cannot be retrieved for this user.
    """
    try:
        # int() rejects non-numeric ids with ValueError. Presumably the
        # lookup helper signals "not found" the same way, as
        # get_chat_session_by_id does in create_run; both map to 404.
        chat_message = get_chat_message(
            chat_message_id=int(run_id),
            user_id=user.id if user else None,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Run not found") from None

    if not chat_message:
        raise HTTPException(status_code=404, detail="Run not found")

    chat_message.error = "Cancelled"
    db_session.commit()

    return retrieve_run(thread_id, run_id, user, db_session)
|
|
|
|
|
|
@router.get("/threads/{thread_id}/runs")
def list_runs(
    thread_id: UUID,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[RunResponse]:
    """List the runs of a thread; every message of the session is one run.

    Args:
        thread_id: Chat session whose messages are listed.
        limit: Maximum number of runs returned; negative values yield none.
        order: Sort direction by the message's send time.
        after: Cursor; keep only runs whose id is greater than this id.
        before: Cursor; keep only runs whose id is smaller than this id.
        user: Requesting user, or ``None``.
        db_session: Session used for the lookups.

    Raises:
        HTTPException: 400 when a cursor is not an integer id.
    """
    # Run ids are integer message ids rendered as strings, so cursors must be
    # compared numerically: as strings, "9" would sort after "10".
    try:
        after_id = int(after) if after else None
        before_id = int(before) if before else None
    except ValueError:
        raise HTTPException(
            status_code=400, detail="Invalid pagination cursor"
        ) from None

    chat_messages = get_chat_messages_by_session(
        chat_session_id=thread_id,
        user_id=user.id if user else None,
        db_session=db_session,
    )

    if after_id is not None:
        chat_messages = [msg for msg in chat_messages if msg.id > after_id]
    if before_id is not None:
        chat_messages = [msg for msg in chat_messages if msg.id < before_id]

    chat_messages = sorted(
        chat_messages, key=lambda msg: msg.time_sent, reverse=(order == "desc")
    )

    # Clamp so a negative limit cannot turn into a slice from the end.
    chat_messages = chat_messages[: max(limit, 0)]

    return [
        retrieve_run(thread_id, str(msg.id), user, db_session) for msg in chat_messages
    ]
|
|
|
|
|
|
@router.get("/threads/{thread_id}/runs/{run_id}/steps")
def list_run_steps(
    run_id: str,
    limit: int = 20,
    order: Literal["asc", "desc"] = "desc",
    after: Optional[str] = None,
    before: Optional[str] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> list[dict]:  # You may want to create a specific model for run steps
    # Run steps have no counterpart here. The endpoint exists for API
    # compatibility only: none of the parameters are read and the result
    # is always an empty list.
    no_steps: list[dict] = []
    return no_steps
|
|
|
|
|
|
# Additional helper functions can be added here if needed
|