hagen-danswer 3e58cf2667
Added ability to use a tag to insert the current datetime in prompts (#3697)
* Added ability to use a tag to insert the current datetime in prompts

* made tagging logic more robust

* rename

* k

---------

Co-authored-by: Yuhong Sun <yuhongsun96@gmail.com>
2025-01-22 16:17:20 +00:00

438 lines
14 KiB
Python

import uuid
from uuid import UUID
from fastapi import APIRouter
from fastapi import Depends
from fastapi import HTTPException
from fastapi import Query
from fastapi import UploadFile
from pydantic import BaseModel
from sqlalchemy.exc import IntegrityError
from sqlalchemy.orm import Session
from onyx.auth.users import current_admin_user
from onyx.auth.users import current_chat_accesssible_user
from onyx.auth.users import current_curator_or_admin_user
from onyx.auth.users import current_limited_user
from onyx.auth.users import current_user
from onyx.configs.constants import FileOrigin
from onyx.configs.constants import MilestoneRecordType
from onyx.configs.constants import NotificationType
from onyx.db.engine import get_current_tenant_id
from onyx.db.engine import get_session
from onyx.db.models import StarterMessageModel as StarterMessage
from onyx.db.models import User
from onyx.db.notification import create_notification
from onyx.db.persona import create_assistant_label
from onyx.db.persona import create_update_persona
from onyx.db.persona import delete_persona_label
from onyx.db.persona import get_assistant_labels
from onyx.db.persona import get_persona_by_id
from onyx.db.persona import get_personas_for_user
from onyx.db.persona import mark_persona_as_deleted
from onyx.db.persona import mark_persona_as_not_deleted
from onyx.db.persona import update_all_personas_display_priority
from onyx.db.persona import update_persona_label
from onyx.db.persona import update_persona_public_status
from onyx.db.persona import update_persona_shared_users
from onyx.db.persona import update_persona_visibility
from onyx.db.prompts import build_prompt_name_from_persona_name
from onyx.db.prompts import upsert_prompt
from onyx.file_store.file_store import get_default_file_store
from onyx.file_store.models import ChatFileType
from onyx.secondary_llm_flows.starter_message_creation import (
generate_starter_messages,
)
from onyx.server.features.persona.models import GenerateStarterMessageRequest
from onyx.server.features.persona.models import ImageGenerationToolStatus
from onyx.server.features.persona.models import PersonaLabelCreate
from onyx.server.features.persona.models import PersonaLabelResponse
from onyx.server.features.persona.models import PersonaSharedNotificationData
from onyx.server.features.persona.models import PersonaSnapshot
from onyx.server.features.persona.models import PersonaUpsertRequest
from onyx.server.features.persona.models import PromptSnapshot
from onyx.server.models import DisplayPriorityRequest
from onyx.tools.utils import is_image_generation_available
from onyx.utils.logger import setup_logger
from onyx.utils.telemetry import create_milestone_and_report
logger = setup_logger()
admin_router = APIRouter(prefix="/admin/persona")
basic_router = APIRouter(prefix="/persona")
class IsVisibleRequest(BaseModel):
is_visible: bool
class IsPublicRequest(BaseModel):
is_public: bool
@admin_router.patch("/{persona_id}/visible")
def patch_persona_visibility(
persona_id: int,
is_visible_request: IsVisibleRequest,
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_visibility(
persona_id=persona_id,
is_visible=is_visible_request.is_visible,
db_session=db_session,
user=user,
)
@basic_router.patch("/{persona_id}/public")
def patch_user_presona_public_status(
persona_id: int,
is_public_request: IsPublicRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
try:
update_persona_public_status(
persona_id=persona_id,
is_public=is_public_request.is_public,
db_session=db_session,
user=user,
)
except ValueError as e:
logger.exception("Failed to update persona public status")
raise HTTPException(status_code=403, detail=str(e))
@admin_router.put("/display-priority")
def patch_persona_display_priority(
display_priority_request: DisplayPriorityRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_all_personas_display_priority(
display_priority_map=display_priority_request.display_priority_map,
db_session=db_session,
)
@admin_router.get("")
def list_personas_admin(
user: User | None = Depends(current_curator_or_admin_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
get_editable: bool = Query(False, description="If true, return editable personas"),
) -> list[PersonaSnapshot]:
return [
PersonaSnapshot.from_model(persona)
for persona in get_personas_for_user(
db_session=db_session,
user=user,
get_editable=get_editable,
include_deleted=include_deleted,
joinedload_all=True,
)
]
@admin_router.patch("/{persona_id}/undelete")
def undelete_persona(
persona_id: int,
user: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
mark_persona_as_not_deleted(
persona_id=persona_id,
user=user,
db_session=db_session,
)
# used for assistat profile pictures
@admin_router.post("/upload-image")
def upload_file(
file: UploadFile,
db_session: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> dict[str, str]:
file_store = get_default_file_store(db_session)
file_type = ChatFileType.IMAGE
file_id = str(uuid.uuid4())
file_store.save_file(
file_name=file_id,
content=file.file,
display_name=file.filename,
file_origin=FileOrigin.CHAT_UPLOAD,
file_type=file.content_type or file_type.value,
)
return {"file_id": file_id}
"""Endpoints for all"""
@basic_router.post("")
def create_persona(
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
tenant_id: str | None = Depends(get_current_tenant_id),
) -> PersonaSnapshot:
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids
and len(persona_upsert_request.prompt_ids) > 0
else None
)
prompt = upsert_prompt(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
datetime_aware=persona_upsert_request.datetime_aware,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
prompt_snapshot = PromptSnapshot.from_model(prompt)
persona_upsert_request.prompt_ids = [prompt.id]
persona_snapshot = create_update_persona(
persona_id=None,
create_persona_request=persona_upsert_request,
user=user,
db_session=db_session,
)
persona_snapshot.prompts = [prompt_snapshot]
create_milestone_and_report(
user=user,
distinct_id=tenant_id or "N/A",
event_type=MilestoneRecordType.CREATED_ASSISTANT,
properties=None,
db_session=db_session,
)
return persona_snapshot
# NOTE: This endpoint cannot update persona configuration options that
# are core to the persona, such as its display priority and
# whether or not the assistant is a built-in / default assistant
@basic_router.patch("/{persona_id}")
def update_persona(
persona_id: int,
persona_upsert_request: PersonaUpsertRequest,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
prompt_id = (
persona_upsert_request.prompt_ids[0]
if persona_upsert_request.prompt_ids
and len(persona_upsert_request.prompt_ids) > 0
else None
)
prompt = upsert_prompt(
db_session=db_session,
user=user,
name=build_prompt_name_from_persona_name(persona_upsert_request.name),
datetime_aware=persona_upsert_request.datetime_aware,
system_prompt=persona_upsert_request.system_prompt,
task_prompt=persona_upsert_request.task_prompt,
include_citations=persona_upsert_request.include_citations,
prompt_id=prompt_id,
)
prompt_snapshot = PromptSnapshot.from_model(prompt)
persona_upsert_request.prompt_ids = [prompt.id]
persona_snapshot = create_update_persona(
persona_id=persona_id,
create_persona_request=persona_upsert_request,
user=user,
db_session=db_session,
)
persona_snapshot.prompts = [prompt_snapshot]
return persona_snapshot
class PersonaLabelPatchRequest(BaseModel):
label_name: str
@basic_router.get("/labels")
def get_labels(
db: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> list[PersonaLabelResponse]:
return [
PersonaLabelResponse.from_model(label)
for label in get_assistant_labels(db_session=db)
]
@basic_router.post("/labels")
def create_label(
label: PersonaLabelCreate,
db: Session = Depends(get_session),
_: User | None = Depends(current_user),
) -> PersonaLabelResponse:
"""Create a new assistant label"""
try:
label_model = create_assistant_label(name=label.name, db_session=db)
return PersonaLabelResponse.from_model(label_model)
except IntegrityError:
raise HTTPException(
status_code=400,
detail=f"Label with name '{label.name}' already exists. Please choose a different name.",
)
@admin_router.patch("/label/{label_id}")
def patch_persona_label(
label_id: int,
persona_label_patch_request: PersonaLabelPatchRequest,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_label(
label_id=label_id,
label_name=persona_label_patch_request.label_name,
db_session=db_session,
)
@admin_router.delete("/label/{label_id}")
def delete_label(
label_id: int,
_: User | None = Depends(current_admin_user),
db_session: Session = Depends(get_session),
) -> None:
delete_persona_label(label_id=label_id, db_session=db_session)
class PersonaShareRequest(BaseModel):
user_ids: list[UUID]
# We notify each user when a user is shared with them
@basic_router.patch("/{persona_id}/share")
def share_persona(
persona_id: int,
persona_share_request: PersonaShareRequest,
user: User = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
update_persona_shared_users(
persona_id=persona_id,
user_ids=persona_share_request.user_ids,
user=user,
db_session=db_session,
)
for user_id in persona_share_request.user_ids:
# Don't notify the user that they have access to their own persona
if user_id != user.id:
create_notification(
user_id=user_id,
notif_type=NotificationType.PERSONA_SHARED,
db_session=db_session,
additional_data=PersonaSharedNotificationData(
persona_id=persona_id,
).model_dump(),
)
@basic_router.delete("/{persona_id}")
def delete_persona(
persona_id: int,
user: User | None = Depends(current_user),
db_session: Session = Depends(get_session),
) -> None:
mark_persona_as_deleted(
persona_id=persona_id,
user=user,
db_session=db_session,
)
@basic_router.get("/image-generation-tool")
def get_image_generation_tool(
_: User
| None = Depends(current_user), # User param not used but kept for consistency
db_session: Session = Depends(get_session),
) -> ImageGenerationToolStatus: # Use bool instead of str for boolean values
is_available = is_image_generation_available(db_session=db_session)
return ImageGenerationToolStatus(is_available=is_available)
@basic_router.get("")
def list_personas(
user: User | None = Depends(current_chat_accesssible_user),
db_session: Session = Depends(get_session),
include_deleted: bool = False,
persona_ids: list[int] = Query(None),
) -> list[PersonaSnapshot]:
personas = get_personas_for_user(
user=user,
include_deleted=include_deleted,
db_session=db_session,
get_editable=False,
joinedload_all=True,
)
if persona_ids:
personas = [p for p in personas if p.id in persona_ids]
# Filter out personas with unavailable tools
personas = [
p
for p in personas
if not (
any(tool.in_code_tool_id == "ImageGenerationTool" for tool in p.tools)
and not is_image_generation_available(db_session=db_session)
)
]
return [PersonaSnapshot.from_model(p) for p in personas]
@basic_router.get("/{persona_id}")
def get_persona(
persona_id: int,
user: User | None = Depends(current_limited_user),
db_session: Session = Depends(get_session),
) -> PersonaSnapshot:
return PersonaSnapshot.from_model(
get_persona_by_id(
persona_id=persona_id,
user=user,
db_session=db_session,
is_for_edit=False,
)
)
@basic_router.post("/assistant-prompt-refresh")
def build_assistant_prompts(
generate_persona_prompt_request: GenerateStarterMessageRequest,
db_session: Session = Depends(get_session),
user: User | None = Depends(current_user),
) -> list[StarterMessage]:
try:
logger.info(
f"Generating {generate_persona_prompt_request.generation_count} starter messages"
f" for user: {user.id if user else 'Anonymous'}",
)
starter_messages = generate_starter_messages(
name=generate_persona_prompt_request.name,
description=generate_persona_prompt_request.description,
instructions=generate_persona_prompt_request.instructions,
document_set_ids=generate_persona_prompt_request.document_set_ids,
generation_count=generate_persona_prompt_request.generation_count,
db_session=db_session,
user=user,
)
return starter_messages
except Exception as e:
logger.exception("Failed to generate starter messages")
raise HTTPException(status_code=500, detail=str(e))