mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-06 09:50:27 +02:00
* Combined Persona and Prompt API * quality * added tests * consolidated models and got rid of redundant fields * tenant appreciation day * reverted default
274 lines
7.7 KiB
Python
274 lines
7.7 KiB
Python
from typing import Any
|
|
from typing import Optional
|
|
from uuid import uuid4
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from fastapi import Query
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_user
|
|
from onyx.context.search.enums import RecencyBiasSetting
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.models import Persona
|
|
from onyx.db.models import User
|
|
from onyx.db.persona import get_persona_by_id
|
|
from onyx.db.persona import get_personas_for_user
|
|
from onyx.db.persona import mark_persona_as_deleted
|
|
from onyx.db.persona import upsert_persona
|
|
from onyx.db.prompts import upsert_prompt
|
|
from onyx.db.tools import get_tool_by_name
|
|
from onyx.utils.logger import setup_logger
|
|
|
|
# Logger shared by every handler in this module.
logger = setup_logger()

# Router for the assistant endpoints defined below; every route is mounted
# under the "/assistants" prefix.
router = APIRouter(prefix="/assistants")
|
|
|
|
|
|
# Base models


class AssistantObject(BaseModel):
    """An assistant as returned by this API (built from a Persona)."""

    id: int  # persona primary key
    object: str = "assistant"  # constant type tag for clients
    created_at: int
    name: str | None = None
    description: str | None = None
    model: str
    instructions: str | None = None
    tools: list[dict[str, Any]]
    file_ids: list[str]
    metadata: dict[str, Any] | None = None
|
|
|
|
|
|
class CreateAssistantRequest(BaseModel):
    """Request body for creating an assistant.

    Only ``model`` is required; every other field may be omitted.
    """

    model: str
    name: str | None = None
    description: str | None = None
    instructions: str | None = None
    tools: list[dict[str, Any]] | None = None
    file_ids: list[str] | None = None
    metadata: dict[str, Any] | None = None
|
|
|
|
|
|
class ModifyAssistantRequest(BaseModel):
    """Request body for a partial assistant update.

    Every field is optional; callers send only the fields they want changed.
    """

    model: str | None = None
    name: str | None = None
    description: str | None = None
    instructions: str | None = None
    tools: list[dict[str, Any]] | None = None
    file_ids: list[str] | None = None
    metadata: dict[str, Any] | None = None
|
|
|
|
|
|
class DeleteAssistantResponse(BaseModel):
    """Acknowledgement returned after an assistant is deleted."""

    id: int  # id of the deleted assistant
    object: str = "assistant.deleted"  # constant type tag for clients
    deleted: bool
|
|
|
|
|
|
class ListAssistantsResponse(BaseModel):
    """One page of assistants plus the cursors needed to fetch neighbours."""

    object: str = "list"  # constant type tag for clients
    data: list[AssistantObject]
    first_id: int | None = None  # id of the first item in ``data``, if any
    last_id: int | None = None  # id of the last item in ``data``, if any
    has_more: bool
|
|
|
|
|
|
def persona_to_assistant(persona: Persona) -> AssistantObject:
    """Build the API view of a persona.

    ``model`` falls back to "gpt-3.5-turbo" when the persona has no model
    override, ``instructions`` is the system prompt of the persona's first
    prompt (``None`` when it has none), and ``created_at`` is always 0.
    File ids and metadata are not supported and are always returned empty.
    """
    tool_entries: list[dict[str, Any]] = [
        {
            "type": tool.display_name,
            "function": {
                "name": tool.name,
                "description": tool.description,
                "schema": tool.openapi_schema,
            },
        }
        for tool in persona.tools
    ]

    system_prompt: Optional[str] = None
    if persona.prompts:
        system_prompt = persona.prompts[0].system_prompt

    return AssistantObject(
        id=persona.id,
        created_at=0,
        name=persona.name,
        description=persona.description,
        model=persona.llm_model_version_override or "gpt-3.5-turbo",
        instructions=system_prompt,
        tools=tool_entries,
        file_ids=[],
        metadata={},
    )
|
|
|
|
|
|
# API endpoints


def _resolve_tool_ids(
    tools: Optional[list[dict[str, Any]]], db_session: Session
) -> list[int]:
    """Map request tool specs to database tool ids.

    Each entry's ``type`` value is looked up with ``get_tool_by_name``;
    entries without a (non-empty) ``type`` are ignored.

    Raises:
        HTTPException: 404 if a referenced tool does not exist.
    """
    tool_ids: list[int] = []
    for tool in tools or []:
        tool_type = tool.get("type")
        if not tool_type:
            continue

        try:
            tool_db = get_tool_by_name(tool_type, db_session)
        except ValueError as err:
            # An unknown tool rejects the whole request (it is not skipped).
            logger.error(f"Tool {tool_type} not found in database")
            raise HTTPException(
                status_code=404, detail=f"Tool {tool_type} not found in database"
            ) from err
        tool_ids.append(tool_db.id)
    return tool_ids


@router.post("")
def create_assistant(
    request: CreateAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    """Create an assistant (a persona) from the request.

    When ``instructions`` are given, they become the system prompt of a new
    prompt linked to the persona. Otherwise the persona is linked to prompt
    id 0 (presumably the default prompt; confirm). Tools are looked up by
    their ``type`` value.

    Raises:
        HTTPException: 404 if a requested tool does not exist.
    """
    # Resolve tools before writing anything, so an unknown tool is rejected
    # without leaving behind a prompt that no persona uses.
    tool_ids = _resolve_tool_ids(request.tools, db_session)

    prompt = None
    if request.instructions:
        prompt = upsert_prompt(
            user=user,
            name=f"Prompt for {request.name or 'New Assistant'}",
            description="Auto-generated prompt",
            system_prompt=request.instructions,
            task_prompt="",
            include_citations=True,
            datetime_aware=True,
            personas=[],
            db_session=db_session,
        )

    persona = upsert_persona(
        user=user,
        name=request.name or f"Assistant-{uuid4()}",
        description=request.description or "",
        num_chunks=25,
        llm_relevance_filter=True,
        llm_filter_extraction=True,
        recency_bias=RecencyBiasSetting.AUTO,
        llm_model_provider_override=None,
        llm_model_version_override=request.model,
        starter_messages=None,
        is_public=False,
        db_session=db_session,
        prompt_ids=[prompt.id] if prompt else [0],
        document_set_ids=[],
        tool_ids=tool_ids,
        icon_color=None,
        icon_shape=None,
        is_visible=True,
    )

    if prompt:
        # The prompt was created with no personas; link it now that the
        # persona exists.
        prompt.personas = [persona]
        db_session.commit()

    return persona_to_assistant(persona)
|
|
|
|
|
|
@router.get("/{assistant_id}")
def retrieve_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    """Fetch a single assistant visible to ``user``.

    Raises:
        HTTPException: 404 when ``get_persona_by_id`` raises ``ValueError``
            or returns nothing.
    """
    found: Persona | None = None
    try:
        found = get_persona_by_id(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
            is_for_edit=False,
        )
    except ValueError:
        pass  # handled by the same not-found check as an empty result

    if not found:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return persona_to_assistant(found)
|
|
|
|
|
|
@router.post("/{assistant_id}")
def modify_assistant(
    assistant_id: int,
    request: ModifyAssistantRequest,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> AssistantObject:
    """Apply a partial update to an assistant and return its new state.

    Only fields present in the request body are applied. They map to the
    persona as follows:

    * ``name`` and ``description`` go to the persona directly.
    * ``model`` goes to ``llm_model_version_override``.
    * ``instructions`` goes to the system prompt of the persona's prompt.
    * ``tools`` replaces the persona's tools, looked up by their ``type``.

    ``file_ids`` and ``metadata`` are accepted but not stored.

    Raises:
        HTTPException: 404 if the assistant is not found (``get_persona_by_id``
            raises ``ValueError`` or returns nothing) or if a requested tool
            does not exist.
    """
    try:
        persona = get_persona_by_id(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
            is_for_edit=True,
        )
    except ValueError:
        persona = None
    if not persona:
        raise HTTPException(status_code=404, detail="Assistant not found")

    update_data = request.model_dump(exclude_unset=True)

    # Resolve tools before touching the persona, so an unknown tool rejects
    # the request without leaving partial changes on the session.
    new_tools = None
    if "tools" in update_data:
        new_tools = []
        for tool in update_data["tools"] or []:
            tool_type = tool.get("type")
            if not tool_type:
                continue
            try:
                new_tools.append(get_tool_by_name(tool_type, db_session))
            except ValueError as err:
                logger.error(f"Tool {tool_type} not found in database")
                raise HTTPException(
                    status_code=404,
                    detail=f"Tool {tool_type} not found in database",
                ) from err

    # Map fields explicitly. The request's field names do not match the
    # persona attributes that persona_to_assistant reads, so a blind setattr
    # would not persist them.
    if update_data.get("name") is not None:
        # An explicit null name is ignored rather than written.
        persona.name = update_data["name"]
    if "description" in update_data:
        persona.description = update_data["description"] or ""
    if "model" in update_data:
        # null clears the override.
        persona.llm_model_version_override = update_data["model"]
    if new_tools is not None:
        persona.tools = new_tools

    instructions = update_data.get("instructions")
    if instructions is not None:
        current_prompt = persona.prompts[0] if persona.prompts else None
        owned_exclusively = current_prompt is not None and [
            p.id for p in current_prompt.personas
        ] == [persona.id]
        if owned_exclusively:
            current_prompt.system_prompt = instructions
        else:
            # The prompt is missing or shared with other personas (e.g.
            # prompt id 0, which instruction-less assistants are linked to),
            # so editing it in place would change those personas too. Give
            # this persona its own prompt instead.
            new_prompt = upsert_prompt(
                user=user,
                name=f"Prompt for {persona.name}",
                description="Auto-generated prompt",
                system_prompt=instructions,
                task_prompt="",
                include_citations=True,
                datetime_aware=True,
                personas=[],
                db_session=db_session,
            )
            persona.prompts = [new_prompt]

    db_session.commit()
    return persona_to_assistant(persona)
|
|
|
|
|
|
@router.delete("/{assistant_id}")
def delete_assistant(
    assistant_id: int,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> DeleteAssistantResponse:
    """Mark an assistant's persona as deleted via ``mark_persona_as_deleted``.

    Raises:
        HTTPException: 404 when ``mark_persona_as_deleted`` raises
            ``ValueError``.
    """
    try:
        mark_persona_as_deleted(
            persona_id=assistant_id,
            user=user,
            db_session=db_session,
        )
    except ValueError:
        raise HTTPException(status_code=404, detail="Assistant not found")
    return DeleteAssistantResponse(id=assistant_id, deleted=True)
|
|
|
|
|
|
def _paginate_by_id(
    personas: list[Persona],
    limit: int,
    descending: bool,
    after: Optional[int],
    before: Optional[int],
) -> tuple[list[Persona], bool]:
    """Select one page of personas ordered by id.

    Cursors are read relative to the requested order. ``after`` keeps the
    personas listed after the cursor id (the next page). ``before`` keeps
    those listed before it (the previous page); when only ``before`` is
    given, the page holds the personas nearest the cursor.

    Returns:
        The page (at most ``limit`` personas), and whether more personas
        matched the cursors than fit in the page.
    """
    ordered = sorted(personas, key=lambda p: p.id, reverse=descending)

    def follows(a: int, b: int) -> bool:
        # True when id ``a`` is listed after id ``b`` in the requested order.
        return a < b if descending else a > b

    # Compare against None so that a cursor of 0 is still honoured.
    if after is not None:
        ordered = [p for p in ordered if follows(p.id, after)]
    if before is not None:
        ordered = [p for p in ordered if follows(before, p.id)]

    # Decide has_more before truncating; comparing the truncated length to
    # ``limit`` reports True even when nothing is left.
    has_more = len(ordered) > limit
    if before is not None and after is None:
        return ordered[max(len(ordered) - limit, 0) :], has_more
    return ordered[:limit], has_more


@router.get("")
def list_assistants(
    limit: int = Query(20, ge=1, le=100),
    order: str = Query("desc", pattern="^(asc|desc)$"),
    after: Optional[int] = None,
    before: Optional[int] = None,
    user: User | None = Depends(current_user),
    db_session: Session = Depends(get_session),
) -> ListAssistantsResponse:
    """List the assistants visible to ``user``, one page at a time.

    Args:
        limit: Maximum number of assistants to return (1-100).
        order: "asc" or "desc" by assistant id.
        after: Return assistants listed after this id in ``order``.
        before: Return assistants listed before this id in ``order``.
    """
    personas = list(
        get_personas_for_user(
            user=user,
            db_session=db_session,
            get_editable=False,
            joinedload_all=True,
        )
    )

    page, has_more = _paginate_by_id(
        personas,
        limit=limit,
        descending=(order == "desc"),
        after=after,
        before=before,
    )
    assistants = [persona_to_assistant(p) for p in page]

    return ListAssistantsResponse(
        data=assistants,
        first_id=assistants[0].id if assistants else None,
        last_id=assistants[-1].id if assistants else None,
        has_more=has_more,
    )
|