mirror of
https://github.com/danswer-ai/danswer.git
synced 2025-05-03 16:30:21 +02:00
* Added ability to use a tag to insert the current datetime in prompts * made tagging logic more robust * rename * k --------- Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
438 lines
14 KiB
Python
438 lines
14 KiB
Python
import uuid
|
|
from uuid import UUID
|
|
|
|
from fastapi import APIRouter
|
|
from fastapi import Depends
|
|
from fastapi import HTTPException
|
|
from fastapi import Query
|
|
from fastapi import UploadFile
|
|
from pydantic import BaseModel
|
|
from sqlalchemy.exc import IntegrityError
|
|
from sqlalchemy.orm import Session
|
|
|
|
from onyx.auth.users import current_admin_user
|
|
from onyx.auth.users import current_chat_accesssible_user
|
|
from onyx.auth.users import current_curator_or_admin_user
|
|
from onyx.auth.users import current_limited_user
|
|
from onyx.auth.users import current_user
|
|
from onyx.configs.constants import FileOrigin
|
|
from onyx.configs.constants import MilestoneRecordType
|
|
from onyx.configs.constants import NotificationType
|
|
from onyx.db.engine import get_current_tenant_id
|
|
from onyx.db.engine import get_session
|
|
from onyx.db.models import StarterMessageModel as StarterMessage
|
|
from onyx.db.models import User
|
|
from onyx.db.notification import create_notification
|
|
from onyx.db.persona import create_assistant_label
|
|
from onyx.db.persona import create_update_persona
|
|
from onyx.db.persona import delete_persona_label
|
|
from onyx.db.persona import get_assistant_labels
|
|
from onyx.db.persona import get_persona_by_id
|
|
from onyx.db.persona import get_personas_for_user
|
|
from onyx.db.persona import mark_persona_as_deleted
|
|
from onyx.db.persona import mark_persona_as_not_deleted
|
|
from onyx.db.persona import update_all_personas_display_priority
|
|
from onyx.db.persona import update_persona_label
|
|
from onyx.db.persona import update_persona_public_status
|
|
from onyx.db.persona import update_persona_shared_users
|
|
from onyx.db.persona import update_persona_visibility
|
|
from onyx.db.prompts import build_prompt_name_from_persona_name
|
|
from onyx.db.prompts import upsert_prompt
|
|
from onyx.file_store.file_store import get_default_file_store
|
|
from onyx.file_store.models import ChatFileType
|
|
from onyx.secondary_llm_flows.starter_message_creation import (
|
|
generate_starter_messages,
|
|
)
|
|
from onyx.server.features.persona.models import GenerateStarterMessageRequest
|
|
from onyx.server.features.persona.models import ImageGenerationToolStatus
|
|
from onyx.server.features.persona.models import PersonaLabelCreate
|
|
from onyx.server.features.persona.models import PersonaLabelResponse
|
|
from onyx.server.features.persona.models import PersonaSharedNotificationData
|
|
from onyx.server.features.persona.models import PersonaSnapshot
|
|
from onyx.server.features.persona.models import PersonaUpsertRequest
|
|
from onyx.server.features.persona.models import PromptSnapshot
|
|
from onyx.server.models import DisplayPriorityRequest
|
|
from onyx.tools.utils import is_image_generation_available
|
|
from onyx.utils.logger import setup_logger
|
|
from onyx.utils.telemetry import create_milestone_and_report
|
|
|
|
|
|
logger = setup_logger()
|
|
|
|
|
|
admin_router = APIRouter(prefix="/admin/persona")
|
|
basic_router = APIRouter(prefix="/persona")
|
|
|
|
|
|
class IsVisibleRequest(BaseModel):
|
|
is_visible: bool
|
|
|
|
|
|
class IsPublicRequest(BaseModel):
|
|
is_public: bool
|
|
|
|
|
|
@admin_router.patch("/{persona_id}/visible")
|
|
def patch_persona_visibility(
|
|
persona_id: int,
|
|
is_visible_request: IsVisibleRequest,
|
|
user: User | None = Depends(current_curator_or_admin_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
update_persona_visibility(
|
|
persona_id=persona_id,
|
|
is_visible=is_visible_request.is_visible,
|
|
db_session=db_session,
|
|
user=user,
|
|
)
|
|
|
|
|
|
@basic_router.patch("/{persona_id}/public")
|
|
def patch_user_presona_public_status(
|
|
persona_id: int,
|
|
is_public_request: IsPublicRequest,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
try:
|
|
update_persona_public_status(
|
|
persona_id=persona_id,
|
|
is_public=is_public_request.is_public,
|
|
db_session=db_session,
|
|
user=user,
|
|
)
|
|
except ValueError as e:
|
|
logger.exception("Failed to update persona public status")
|
|
raise HTTPException(status_code=403, detail=str(e))
|
|
|
|
|
|
@admin_router.put("/display-priority")
|
|
def patch_persona_display_priority(
|
|
display_priority_request: DisplayPriorityRequest,
|
|
_: User | None = Depends(current_admin_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
update_all_personas_display_priority(
|
|
display_priority_map=display_priority_request.display_priority_map,
|
|
db_session=db_session,
|
|
)
|
|
|
|
|
|
@admin_router.get("")
|
|
def list_personas_admin(
|
|
user: User | None = Depends(current_curator_or_admin_user),
|
|
db_session: Session = Depends(get_session),
|
|
include_deleted: bool = False,
|
|
get_editable: bool = Query(False, description="If true, return editable personas"),
|
|
) -> list[PersonaSnapshot]:
|
|
return [
|
|
PersonaSnapshot.from_model(persona)
|
|
for persona in get_personas_for_user(
|
|
db_session=db_session,
|
|
user=user,
|
|
get_editable=get_editable,
|
|
include_deleted=include_deleted,
|
|
joinedload_all=True,
|
|
)
|
|
]
|
|
|
|
|
|
@admin_router.patch("/{persona_id}/undelete")
|
|
def undelete_persona(
|
|
persona_id: int,
|
|
user: User | None = Depends(current_admin_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
mark_persona_as_not_deleted(
|
|
persona_id=persona_id,
|
|
user=user,
|
|
db_session=db_session,
|
|
)
|
|
|
|
|
|
# used for assistat profile pictures
|
|
@admin_router.post("/upload-image")
|
|
def upload_file(
|
|
file: UploadFile,
|
|
db_session: Session = Depends(get_session),
|
|
_: User | None = Depends(current_user),
|
|
) -> dict[str, str]:
|
|
file_store = get_default_file_store(db_session)
|
|
file_type = ChatFileType.IMAGE
|
|
file_id = str(uuid.uuid4())
|
|
file_store.save_file(
|
|
file_name=file_id,
|
|
content=file.file,
|
|
display_name=file.filename,
|
|
file_origin=FileOrigin.CHAT_UPLOAD,
|
|
file_type=file.content_type or file_type.value,
|
|
)
|
|
return {"file_id": file_id}
|
|
|
|
|
|
"""Endpoints for all"""
|
|
|
|
|
|
@basic_router.post("")
|
|
def create_persona(
|
|
persona_upsert_request: PersonaUpsertRequest,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
tenant_id: str | None = Depends(get_current_tenant_id),
|
|
) -> PersonaSnapshot:
|
|
prompt_id = (
|
|
persona_upsert_request.prompt_ids[0]
|
|
if persona_upsert_request.prompt_ids
|
|
and len(persona_upsert_request.prompt_ids) > 0
|
|
else None
|
|
)
|
|
prompt = upsert_prompt(
|
|
db_session=db_session,
|
|
user=user,
|
|
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
|
system_prompt=persona_upsert_request.system_prompt,
|
|
task_prompt=persona_upsert_request.task_prompt,
|
|
datetime_aware=persona_upsert_request.datetime_aware,
|
|
include_citations=persona_upsert_request.include_citations,
|
|
prompt_id=prompt_id,
|
|
)
|
|
prompt_snapshot = PromptSnapshot.from_model(prompt)
|
|
persona_upsert_request.prompt_ids = [prompt.id]
|
|
persona_snapshot = create_update_persona(
|
|
persona_id=None,
|
|
create_persona_request=persona_upsert_request,
|
|
user=user,
|
|
db_session=db_session,
|
|
)
|
|
persona_snapshot.prompts = [prompt_snapshot]
|
|
create_milestone_and_report(
|
|
user=user,
|
|
distinct_id=tenant_id or "N/A",
|
|
event_type=MilestoneRecordType.CREATED_ASSISTANT,
|
|
properties=None,
|
|
db_session=db_session,
|
|
)
|
|
|
|
return persona_snapshot
|
|
|
|
|
|
# NOTE: This endpoint cannot update persona configuration options that
|
|
# are core to the persona, such as its display priority and
|
|
# whether or not the assistant is a built-in / default assistant
|
|
@basic_router.patch("/{persona_id}")
|
|
def update_persona(
|
|
persona_id: int,
|
|
persona_upsert_request: PersonaUpsertRequest,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> PersonaSnapshot:
|
|
prompt_id = (
|
|
persona_upsert_request.prompt_ids[0]
|
|
if persona_upsert_request.prompt_ids
|
|
and len(persona_upsert_request.prompt_ids) > 0
|
|
else None
|
|
)
|
|
prompt = upsert_prompt(
|
|
db_session=db_session,
|
|
user=user,
|
|
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
|
|
datetime_aware=persona_upsert_request.datetime_aware,
|
|
system_prompt=persona_upsert_request.system_prompt,
|
|
task_prompt=persona_upsert_request.task_prompt,
|
|
include_citations=persona_upsert_request.include_citations,
|
|
prompt_id=prompt_id,
|
|
)
|
|
prompt_snapshot = PromptSnapshot.from_model(prompt)
|
|
persona_upsert_request.prompt_ids = [prompt.id]
|
|
persona_snapshot = create_update_persona(
|
|
persona_id=persona_id,
|
|
create_persona_request=persona_upsert_request,
|
|
user=user,
|
|
db_session=db_session,
|
|
)
|
|
persona_snapshot.prompts = [prompt_snapshot]
|
|
return persona_snapshot
|
|
|
|
|
|
class PersonaLabelPatchRequest(BaseModel):
|
|
label_name: str
|
|
|
|
|
|
@basic_router.get("/labels")
|
|
def get_labels(
|
|
db: Session = Depends(get_session),
|
|
_: User | None = Depends(current_user),
|
|
) -> list[PersonaLabelResponse]:
|
|
return [
|
|
PersonaLabelResponse.from_model(label)
|
|
for label in get_assistant_labels(db_session=db)
|
|
]
|
|
|
|
|
|
@basic_router.post("/labels")
|
|
def create_label(
|
|
label: PersonaLabelCreate,
|
|
db: Session = Depends(get_session),
|
|
_: User | None = Depends(current_user),
|
|
) -> PersonaLabelResponse:
|
|
"""Create a new assistant label"""
|
|
try:
|
|
label_model = create_assistant_label(name=label.name, db_session=db)
|
|
return PersonaLabelResponse.from_model(label_model)
|
|
except IntegrityError:
|
|
raise HTTPException(
|
|
status_code=400,
|
|
detail=f"Label with name '{label.name}' already exists. Please choose a different name.",
|
|
)
|
|
|
|
|
|
@admin_router.patch("/label/{label_id}")
|
|
def patch_persona_label(
|
|
label_id: int,
|
|
persona_label_patch_request: PersonaLabelPatchRequest,
|
|
_: User | None = Depends(current_admin_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
update_persona_label(
|
|
label_id=label_id,
|
|
label_name=persona_label_patch_request.label_name,
|
|
db_session=db_session,
|
|
)
|
|
|
|
|
|
@admin_router.delete("/label/{label_id}")
|
|
def delete_label(
|
|
label_id: int,
|
|
_: User | None = Depends(current_admin_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
delete_persona_label(label_id=label_id, db_session=db_session)
|
|
|
|
|
|
class PersonaShareRequest(BaseModel):
|
|
user_ids: list[UUID]
|
|
|
|
|
|
# We notify each user when a user is shared with them
|
|
@basic_router.patch("/{persona_id}/share")
|
|
def share_persona(
|
|
persona_id: int,
|
|
persona_share_request: PersonaShareRequest,
|
|
user: User = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
update_persona_shared_users(
|
|
persona_id=persona_id,
|
|
user_ids=persona_share_request.user_ids,
|
|
user=user,
|
|
db_session=db_session,
|
|
)
|
|
|
|
for user_id in persona_share_request.user_ids:
|
|
# Don't notify the user that they have access to their own persona
|
|
if user_id != user.id:
|
|
create_notification(
|
|
user_id=user_id,
|
|
notif_type=NotificationType.PERSONA_SHARED,
|
|
db_session=db_session,
|
|
additional_data=PersonaSharedNotificationData(
|
|
persona_id=persona_id,
|
|
).model_dump(),
|
|
)
|
|
|
|
|
|
@basic_router.delete("/{persona_id}")
|
|
def delete_persona(
|
|
persona_id: int,
|
|
user: User | None = Depends(current_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> None:
|
|
mark_persona_as_deleted(
|
|
persona_id=persona_id,
|
|
user=user,
|
|
db_session=db_session,
|
|
)
|
|
|
|
|
|
@basic_router.get("/image-generation-tool")
|
|
def get_image_generation_tool(
|
|
_: User
|
|
| None = Depends(current_user), # User param not used but kept for consistency
|
|
db_session: Session = Depends(get_session),
|
|
) -> ImageGenerationToolStatus: # Use bool instead of str for boolean values
|
|
is_available = is_image_generation_available(db_session=db_session)
|
|
return ImageGenerationToolStatus(is_available=is_available)
|
|
|
|
|
|
@basic_router.get("")
|
|
def list_personas(
|
|
user: User | None = Depends(current_chat_accesssible_user),
|
|
db_session: Session = Depends(get_session),
|
|
include_deleted: bool = False,
|
|
persona_ids: list[int] = Query(None),
|
|
) -> list[PersonaSnapshot]:
|
|
personas = get_personas_for_user(
|
|
user=user,
|
|
include_deleted=include_deleted,
|
|
db_session=db_session,
|
|
get_editable=False,
|
|
joinedload_all=True,
|
|
)
|
|
|
|
if persona_ids:
|
|
personas = [p for p in personas if p.id in persona_ids]
|
|
|
|
# Filter out personas with unavailable tools
|
|
personas = [
|
|
p
|
|
for p in personas
|
|
if not (
|
|
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
|
|
and not is_image_generation_available(db_session=db_session)
|
|
)
|
|
]
|
|
|
|
return [PersonaSnapshot.from_model(p) for p in personas]
|
|
|
|
|
|
@basic_router.get("/{persona_id}")
|
|
def get_persona(
|
|
persona_id: int,
|
|
user: User | None = Depends(current_limited_user),
|
|
db_session: Session = Depends(get_session),
|
|
) -> PersonaSnapshot:
|
|
return PersonaSnapshot.from_model(
|
|
get_persona_by_id(
|
|
persona_id=persona_id,
|
|
user=user,
|
|
db_session=db_session,
|
|
is_for_edit=False,
|
|
)
|
|
)
|
|
|
|
|
|
@basic_router.post("/assistant-prompt-refresh")
|
|
def build_assistant_prompts(
|
|
generate_persona_prompt_request: GenerateStarterMessageRequest,
|
|
db_session: Session = Depends(get_session),
|
|
user: User | None = Depends(current_user),
|
|
) -> list[StarterMessage]:
|
|
try:
|
|
logger.info(
|
|
f"Generating {generate_persona_prompt_request.generation_count} starter messages"
|
|
f" for user: {user.id if user else 'Anonymous'}",
|
|
)
|
|
starter_messages = generate_starter_messages(
|
|
name=generate_persona_prompt_request.name,
|
|
description=generate_persona_prompt_request.description,
|
|
instructions=generate_persona_prompt_request.instructions,
|
|
document_set_ids=generate_persona_prompt_request.document_set_ids,
|
|
generation_count=generate_persona_prompt_request.generation_count,
|
|
db_session=db_session,
|
|
user=user,
|
|
)
|
|
return starter_messages
|
|
except Exception as e:
|
|
logger.exception("Failed to generate starter messages")
|
|
raise HTTPException(status_code=500, detail=str(e))
|